summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
blob: c0823248538da19f647c5bf4869bcfb5a5a0f577 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/simple_tensor.h>
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_multi_matmul_function.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
#include <vespa/eval/eval/test/eval_fixture.h>

#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::eval::test;
using namespace vespalib::tensor;
using namespace vespalib::eval::tensor_function;

// Engine handed to every EvalFixture constructed in this file; bound once to
// the instance returned by DefaultTensorEngine::ref().
const TensorEngine &prod_engine = DefaultTensorEngine::ref();

// Builds the repository of dense parameters referenced by the test expressions
// in this file. Every add_dense() call is given a list of (name, size) pairs;
// the expressions refer to the resulting parameter by the concatenation of
// those pairs (e.g. "A2B1C3a2d3") - presumably add_dense derives the parameter
// name that way; confirm against EvalFixture::ParamRepo.
EvalFixture::ParamRepo make_params() {
    auto repo = EvalFixture::ParamRepo();

    // Operands whose trailing dimension pair varies between cases.
    repo.add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}});  // inner/inner
    repo.add_dense({{"B", 1}, {"C", 3}, {"a", 2}, {"d", 3}});            // inner/inner, missing A
    repo.add_dense({{"A", 1}, {"a", 2}, {"d", 3}});                      // inner/inner, single mat
    repo.add_dense({{"A", 2}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}});  // inner/inner, inverted
    repo.add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"b", 5}});  // inner/outer
    repo.add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"c", 2}});  // outer/outer
    repo.add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"a", 2}, {"c", 3}});  // not matching

    // The "fixed" operand that the parameters above are multiplied with.
    repo.add_dense({{"A", 2}, {"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}});  // fixed param
    repo.add_dense({{"B", 1}, {"C", 3}, {"b", 5}, {"d", 3}});            // fixed param, missing A
    repo.add_dense({{"A", 1}, {"b", 5}, {"d", 3}});                      // fixed param, single mat
    repo.add_dense({{"B", 5}, {"D", 3}, {"a", 2}, {"b", 1}, {"c", 3}});  // fixed param, inverted

    return repo;
}
// Shared parameter repository, built once by make_params() during static
// initialization and passed to every fixture and reference evaluation below.
EvalFixture::ParamRepo param_repo = make_params();

// Checks that `expr` evaluates to the expected result and that exactly one
// DenseMultiMatMulFunction with the expected properties is found in the
// fixture constructed with the final flag set to true.
//
// expr        - expression evaluated against the shared param_repo
// lhs_size, common_size, rhs_size, matmul_cnt
//             - expected values of the same-named accessors on the function
// lhs_inner, rhs_inner
//             - expected values of lhs_common_inner() / rhs_common_inner()
void verify_optimized(const vespalib::string &expr,
                      size_t lhs_size, size_t common_size, size_t rhs_size, size_t matmul_cnt,
                      bool lhs_inner, bool rhs_inner)
{
    // Two fixtures that differ only in the last constructor argument (false vs
    // true); presumably this toggles optimization - confirm against EvalFixture.
    EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    // The result must agree with both the reference evaluation and the 'slow' fixture.
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    EXPECT_EQUAL(fixture.result(), slow_fixture.result());
    auto info = fixture.find_all<DenseMultiMatMulFunction>();
    // ASSERT (not EXPECT) because info[0] is dereferenced by every check below;
    // presumably ASSERT_EQUAL stops the test on failure - confirm against the test kit.
    ASSERT_EQUAL(info.size(), 1u);
    EXPECT_TRUE(info[0]->result_is_mutable());
    EXPECT_EQUAL(info[0]->lhs_size(), lhs_size);
    EXPECT_EQUAL(info[0]->common_size(), common_size);
    EXPECT_EQUAL(info[0]->rhs_size(), rhs_size);
    EXPECT_EQUAL(info[0]->matmul_cnt(), matmul_cnt);
    EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner);
    EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner);
}

// Checks that `expr` still evaluates to the expected result but that no
// DenseMultiMatMulFunction is found in the fixture constructed with the final
// flag set to true.
void verify_not_optimized(const vespalib::string &expr) {
    // Two fixtures that differ only in the last constructor argument (false vs
    // true); presumably this toggles optimization - confirm against EvalFixture.
    EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    // Correctness is still required even when the rewrite does not apply.
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    EXPECT_EQUAL(fixture.result(), slow_fixture.result());
    auto info = fixture.find_all<DenseMultiMatMulFunction>();
    EXPECT_TRUE(info.empty());
}

// Baseline case: the product of A2B1C3a2d3 and A2B1C3b5d3, sum-reduced over
// 'd', must yield one DenseMultiMatMulFunction with lhs_size=2, common_size=3,
// rhs_size=5, matmul_cnt=6 and both common-inner flags true.
TEST("require that multi matmul can be optimized") {
    TEST_DO(verify_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", 2, 3, 5, 6, true, true));
}

// Same shape of expression as the baseline, but using the 'single mat'
// parameters (A1...); the expected matmul_cnt is 1.
TEST("require that single multi matmul can be optimized") {
    TEST_DO(verify_optimized("reduce(A1a2d3*A1b5d3,sum,d)", 2, 3, 5, 1, true, true));
}

// The product may also be written as a join with an explicit lambda; both
// argument orders (x*y and y*x) are expected to give the same function
// properties as the plain '*' form.
TEST("require that multi matmul with lambda can be optimized") {
    TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, 6, true, true));
    TEST_DO(verify_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*x)),sum,d)", 2, 3, 5, 6, true, true));
}

// Near-miss expressions: each one differs from the baseline in a single aspect
// and must not produce a DenseMultiMatMulFunction.
TEST("require that expressions similar to multi matmul are not optimized") {
    // reducing over 'a' or 'b', each of which is named by only one operand
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,a)"));
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum,b)"));
    // aggregator other than sum
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,prod,d)"));
    // no reduce dimension given
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3,sum)"));
    // join lambdas that are not exactly a product of x and y
    TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x+y)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*x)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(y*y)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(A2B1C3a2d3,A2B1C3b5d3,f(x,y)(x*y*1)),sum,d)"));
    // left operand is the 'not matching' parameter (a2c3 instead of a2d3)
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2c3*A2B1C3b5d3,sum,d)"));
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2c3*A2B1C3b5d3,sum,c)"));
}

// Exactly one operand carries the 'f' suffix (the same suffix make_expr appends
// when float_cells is set), so the two operands differ in cell type.
TEST("require that multi matmul must have matching cell type") {
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3f*A2B1C3b5d3,sum,d)"));
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*A2B1C3b5d3f,sum,d)"));
}

// One operand uses a 'missing A' parameter, so the leading dimensions of the
// two operands differ (B1C3 vs A2B1C3); tried with either side lacking A.
TEST("require that multi matmul must have matching dimension prefix") {
    TEST_DO(verify_not_optimized("reduce(B1C3a2d3*A2B1C3b5d3,sum,d)"));
    TEST_DO(verify_not_optimized("reduce(A2B1C3a2d3*B1C3b5d3,sum,d)"));
}

// Uses the 'inverted' parameters: the reduced dimension D and the differing
// dimensions A/B come before a, b, c in the parameter names instead of after
// them. Checked in both operand orders.
TEST("require that multi matmul must have inner nesting of matmul dimensions") {
    TEST_DO(verify_not_optimized("reduce(A2D3a2b1c3*B5D3a2b1c3,sum,D)"));
    TEST_DO(verify_not_optimized("reduce(B5D3a2b1c3*A2D3a2b1c3,sum,D)"));
}

// Smoke test for as_string(): builds the baseline expression, requires that
// exactly one DenseMultiMatMulFunction is found and prints its string form to
// stderr. The printed content itself is not checked.
TEST("require that multi matmul function can be debug dumped") {
    EvalFixture fixture(prod_engine, "reduce(A2B1C3a2d3*A2B1C3b5d3,sum,d)", param_repo, true);
    auto info = fixture.find_all<DenseMultiMatMulFunction>();
    // Must hold before info[0] is dereferenced on the next line.
    ASSERT_EQUAL(info.size(), 1u);
    fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}

// Formats the expression "reduce(<a>*<b>,sum,<common>)". When float_cells is
// true an "f" is appended to both operand names; otherwise nothing is appended.
vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
                           bool float_cells)
{
    // Both operands always get the same suffix, so pick it once.
    const char *cell_suffix = float_cells ? "f" : "";
    return make_string("reduce(%s%s*%s%s,sum,%s)",
                       a.c_str(), cell_suffix,
                       b.c_str(), cell_suffix,
                       common.c_str());
}

// Runs verify_optimized() for every combination of cell-type suffix (without
// and with "f") and operand order (a*b and b*a). The same expected values are
// used for all four expressions.
void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
                            size_t lhs_size, size_t common_size, size_t rhs_size, size_t matmul_cnt,
                            bool lhs_inner, bool rhs_inner)
{
    // Builds one expression and verifies it; the expression text is registered
    // as test state for the duration of the check.
    auto check_order = [&](const vespalib::string &first, const vespalib::string &second, bool float_cells) {
        auto expr = make_expr(first, second, common, float_cells);
        TEST_STATE(expr.c_str());
        TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, matmul_cnt, lhs_inner, rhs_inner));
    };
    for (bool float_cells: {false, true}) {
        check_order(a, b, float_cells);
        check_order(b, a, float_cells);
    }
}

// Common dimension 'd' is the last dimension of both operands; both
// common-inner flags are expected to be true.
TEST("require that multi matmul inner/inner works correctly") {
    TEST_DO(verify_optimized_multi("A2B1C3a2d3", "A2B1C3b5d3", "d", 2, 3, 5, 6, true, true));
}

// Common dimension 'b' is last in one operand (a2b5) and second to last in the
// other (b5d3); expected flags are lhs_inner=true, rhs_inner=false.
TEST("require that multi matmul inner/outer works correctly") {
    TEST_DO(verify_optimized_multi("A2B1C3a2b5", "A2B1C3b5d3", "b", 2, 5, 3, 6, true, false));
}

// Common dimension 'b' is second to last in both operands (b5c2 and b5d3);
// both common-inner flags are expected to be false.
TEST("require that multi matmul outer/outer works correctly") {
    TEST_DO(verify_optimized_multi("A2B1C3b5c2", "A2B1C3b5d3", "b", 2, 5, 3, 6, false, false));
}

// Program entry point supplied by the test kit macros; TEST_RUN_ALL()
// presumably executes every TEST defined above - confirm against test_kit.h.
TEST_MAIN() { TEST_RUN_ALL(); }