path: root/eval/src/tests/tensor/dense_matmul_function/dense_matmul_function_test.cpp
// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
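
// Tests for the DenseMatMulFunction optimization: expressions of the form
// reduce(join(a,b,f(x,y)(x*y)),sum,dim) over two dense matrices sharing a
// single dimension should be evaluated as a direct matrix multiplication.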

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/simple_tensor.h>
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_matmul_function.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
#include <vespa/eval/eval/test/eval_fixture.h>

#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::eval::test;
using namespace vespalib::tensor;
using namespace vespalib::eval::tensor_function;

const TensorEngine &prod_engine = DefaultTensorEngine::ref();

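// Each add_matrix("a", 2, "d", 3) call registers a matrix parameter named
// "a2d3" with dimensions a[2] and d[3]; these names match the parameters
// referenced in the expressions below.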
EvalFixture::ParamRepo make_params() {
    return EvalFixture::ParamRepo()
        .add_matrix("a", 2, "d", 3)  // inner/inner
        .add_matrix("a", 2, "b", 5)  // inner/outer
        .add_matrix("b", 5, "c", 2)  // outer/outer
        .add_matrix("a", 2, "c", 3)  // not matching
        //------------------------------------------
        .add_matrix("b", 5, "d", 3); // fixed param
}
EvalFixture::ParamRepo param_repo = make_params();

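// Evaluate 'expr' with and without optimization, check that both results
// match the reference evaluation, and verify that exactly one
// DenseMatMulFunction was injected with the expected matrix sizes and with
// the common dimension in the expected (inner/outer) position for each
// operand.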
void verify_optimized(const vespalib::string &expr,
                      size_t lhs_size, size_t common_size, size_t rhs_size,
                      bool lhs_inner, bool rhs_inner)
{
    EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    EXPECT_EQUAL(fixture.result(), slow_fixture.result());
    auto info = fixture.find_all<DenseMatMulFunction>();
    ASSERT_EQUAL(info.size(), 1u);
    EXPECT_TRUE(info[0]->result_is_mutable());
    EXPECT_EQUAL(info[0]->lhs_size(), lhs_size);
    EXPECT_EQUAL(info[0]->common_size(), common_size);
    EXPECT_EQUAL(info[0]->rhs_size(), rhs_size);
    EXPECT_EQUAL(info[0]->lhs_common_inner(), lhs_inner);
    EXPECT_EQUAL(info[0]->rhs_common_inner(), rhs_inner);
}

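// Evaluate 'expr' with and without optimization, check that the results
// still match the reference evaluation, and verify that no
// DenseMatMulFunction was injected.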
void verify_not_optimized(const vespalib::string &expr) {
    EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    EXPECT_EQUAL(fixture.result(), slow_fixture.result());
    auto info = fixture.find_all<DenseMatMulFunction>();
    EXPECT_TRUE(info.empty());
}

TEST("require that matmul can be optimized") {
    TEST_DO(verify_optimized("reduce(a2d3*b5d3,sum,d)", 2, 3, 5, true, true));
}

TEST("require that matmul with lambda can be optimized") {
    TEST_DO(verify_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y)),sum,d)", 2, 3, 5, true, true));
}

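// The optimization should only trigger for a sum-reduce over the single
// dimension shared by both operands, joined with a plain 'x*y'
// multiplication; each case below breaks one of these requirements.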
TEST("require that expressions similar to matmul are not optimized") {
    TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,a)"));
    TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum,b)"));
    TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,prod,d)"));
    TEST_DO(verify_not_optimized("reduce(a2d3*b5d3,sum)"));
    TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*x)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x+y)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*x)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(y*y)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(join(a2d3,b5d3,f(x,y)(x*y*1)),sum,d)"));
    TEST_DO(verify_not_optimized("reduce(a2c3*b5d3,sum,d)"));
    TEST_DO(verify_not_optimized("reduce(a2c3*b5d3,sum,c)"));
}

TEST("require that xw product can be debug dumped") {
    EvalFixture fixture(prod_engine, "reduce(a2d3*b5d3,sum,d)", param_repo, true);
    auto info = fixture.find_all<DenseMatMulFunction>();
    ASSERT_EQUAL(info.size(), 1u);
    fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}

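// Build "reduce(a*b,sum,common)" for the given parameter names; appending
// "f" to a name selects the float-cell variant of that parameter.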
vespalib::string make_expr(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
                           bool float_a, bool float_b)
{
    return make_string("reduce(%s%s*%s%s,sum,%s)", a.c_str(), float_a ? "f" : "", b.c_str(), float_b ? "f" : "", common.c_str());
}

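// Verify the optimization for all four double/float cell-type combinations
// and for both operand orders, expecting the same matmul dimensions in
// every case.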
void verify_optimized_multi(const vespalib::string &a, const vespalib::string &b, const vespalib::string &common,
                            size_t lhs_size, size_t common_size, size_t rhs_size,
                            bool lhs_inner, bool rhs_inner)
{
    for (bool float_a: {false, true}) {
        for (bool float_b: {false, true}) {
            {
                auto expr = make_expr(a, b, common, float_a, float_b);
                TEST_STATE(expr.c_str());
                TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner));
            }
            {
                auto expr = make_expr(b, a, common, float_b, float_a);
                TEST_STATE(expr.c_str());
                TEST_DO(verify_optimized(expr, lhs_size, common_size, rhs_size, lhs_inner, rhs_inner));
            }
        }
    }
}

TEST("require that matmul inner/inner works correctly") {
    TEST_DO(verify_optimized_multi("a2d3", "b5d3", "d", 2, 3, 5, true, true));
}

TEST("require that matmul inner/outer works correctly") {
    TEST_DO(verify_optimized_multi("a2b5", "b5d3", "b", 2, 5, 3, true, false));
}

TEST("require that matmul outer/outer works correctly") {
    TEST_DO(verify_optimized_multi("b5c2", "b5d3", "b", 2, 5, 3, false, false));
}

TEST_MAIN() { TEST_RUN_ALL(); }