summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
blob: 8045958d9ba2944249ea89b6389888a2d4679729 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/log/log.h>
// Log component for this test binary. It is named after this test (the xw
// product function test) rather than the dot product test it was originally
// copied from, so that log output is attributed to the right test.
LOG_SETUP("dense_xw_product_function_test");

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/simple_tensor.h>
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_xw_product_function.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
#include <vespa/eval/eval/test/eval_fixture.h>

#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::eval::test;
using namespace vespalib::tensor;
using namespace vespalib::eval::tensor_function;

// Engine under test: every EvalFixture in this file evaluates its expression
// with the default tensor engine obtained here.
const TensorEngine &prod_engine = DefaultTensorEngine::ref();

// Cell value generator used when building the vector parameters in
// make_params(): index i maps to (3 + i) * 7, i.e. 21, 28, 35, ...
struct MyVecSeq : Sequence {
    double operator[](size_t i) const override {
        const double shifted = 3.0 + i;
        return shifted * 7.0;
    }
};

// Cell value generator used when building the matrix parameters in
// make_params(): index i maps to (5 + i) * 43, i.e. 215, 258, 301, ...
struct MyMatSeq : Sequence {
    double operator[](size_t i) const override {
        const double shifted = 5.0 + i;
        return shifted * 43.0;
    }
};

// Builds the repository of named tensor parameters referenced by the test
// expressions below. Parameter names encode their dimensions and sizes
// (for example "x2y3" is built from x(2) and y(3)); a trailing 'f' marks the
// parameters that are registered with an explicit tensor<float> type.
EvalFixture::ParamRepo make_params() {
    auto repo = EvalFixture::ParamRepo()
        // vectors over y, filled from MyVecSeq
        .add("y1", spec({y(1)}, MyVecSeq()))
        .add("y3", spec({y(3)}, MyVecSeq()))
        .add("y3f", spec({y(3)}, MyVecSeq()), "tensor<float>(y[3])")
        .add("y5", spec({y(5)}, MyVecSeq()))
        .add("y16", spec({y(16)}, MyVecSeq()))
        // matrices, filled from MyMatSeq
        .add("x1y1", spec({x(1),y(1)}, MyMatSeq()))
        .add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
        .add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
        .add("x2y3f", spec({x(2),y(3)}, MyMatSeq()), "tensor<float>(x[2],y[3])")
        .add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
        .add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
        .add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
        .add("y5z8", spec({y(5),z(8)}, MyMatSeq()))
        .add("x5y16", spec({x(5),y(16)}, MyMatSeq()))
        .add("y16z5", spec({y(16),z(5)}, MyMatSeq()));
    return repo;
}

// Parameter repository shared by all test cases in this file.
EvalFixture::ParamRepo param_repo = make_params();

// Evaluates 'expr' with prod_engine and checks that
//  - the result equals EvalFixture::ref(expr, param_repo), and
//  - the fixture contains exactly one DenseXWProductFunction node whose
//    properties match the expectations given by the caller.
//
// expr:     expression to evaluate; its parameters are looked up in param_repo
// vec_size: expected vectorSize() of the DenseXWProductFunction node
// res_size: expected resultSize() of the DenseXWProductFunction node
// happy:    expected matrixHasCommonDimensionInnermost() of the node
void verify_optimized(const vespalib::string &expr, size_t vec_size, size_t res_size, bool happy) {
    // NOTE(review): the trailing 'true' presumably enables optimization in the fixture - confirm against EvalFixture
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    auto info = fixture.find_all<DenseXWProductFunction>();
    // presumably ASSERT_ (unlike EXPECT_) stops the test on failure, guarding the info[0] accesses below - confirm
    ASSERT_EQUAL(info.size(), 1u);
    EXPECT_TRUE(info[0]->result_is_mutable());
    EXPECT_EQUAL(info[0]->vectorSize(), vec_size);
    EXPECT_EQUAL(info[0]->resultSize(), res_size);
    EXPECT_EQUAL(info[0]->matrixHasCommonDimensionInnermost(), happy);
}

// Evaluates 'expr' with prod_engine and checks that the result still equals
// EvalFixture::ref(expr, param_repo), but that the fixture contains no
// DenseXWProductFunction node at all.
//
// expr: expression to evaluate; its parameters are looked up in param_repo
void verify_not_optimized(const vespalib::string &expr) {
    EvalFixture fixture(prod_engine, expr, param_repo, true);
    EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
    auto info = fixture.find_all<DenseXWProductFunction>();
    EXPECT_TRUE(info.empty());
}

// Runs the xw product for several vector/result sizes. Each size is tested
// twice: once with a matrix that has the common dimension y last (x?y?,
// expected 'happy' == true) and once with y first (y?z?, expected
// 'happy' == false). Comments give "vector size -> result size".
TEST("require that xw product gives same results as reference join/reduce") {
    // 1 -> 1 happy/unhappy
    TEST_DO(verify_optimized("reduce(y1*x1y1,sum,y)", 1, 1, true));
    TEST_DO(verify_optimized("reduce(y1*y1z1,sum,y)", 1, 1, false));
    // 3 -> 2 happy/unhappy
    TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
    TEST_DO(verify_optimized("reduce(y3*y3z2,sum,y)", 3, 2, false));
    // 5 -> 8 happy/unhappy
    TEST_DO(verify_optimized("reduce(y5*x8y5,sum,y)", 5, 8, true));
    TEST_DO(verify_optimized("reduce(y5*y5z8,sum,y)", 5, 8, false));
    // 16 -> 5 happy/unhappy
    TEST_DO(verify_optimized("reduce(y16*x5y16,sum,y)", 16, 5, true));
    TEST_DO(verify_optimized("reduce(y16*y16z5,sum,y)", 16, 5, false));
}

// The same 3 -> 2 product written in four ways: with the vector on either
// side of the multiplication, and with '*' spelled out as an explicit join
// with a multiplying lambda. All are expected to be optimized identically.
TEST("require that various variants of xw product can be optimized") {
    TEST_DO(verify_optimized("reduce(y3*x2y3,sum,y)", 3, 2, true));
    TEST_DO(verify_optimized("reduce(x2y3*y3,sum,y)", 3, 2, true));
    TEST_DO(verify_optimized("reduce(join(y3,x2y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
    TEST_DO(verify_optimized("reduce(join(x2y3,y3,f(x,y)(x*y)),sum,y)", 3, 2, true));
}

// Near misses: each expression differs from a sum-reduced product over the
// common dimension y in exactly one respect and must not be optimized.
TEST("require that expressions similar to xw product are not optimized") {
    // reduces over x, which is not the dimension shared by vector and matrix
    TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum,x)"));
    // aggregator is prod, not sum
    TEST_DO(verify_not_optimized("reduce(y3*x2y3,prod,y)"));
    // no dimension given to reduce
    TEST_DO(verify_not_optimized("reduce(y3*x2y3,sum)"));
    // join function adds instead of multiplies
    TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x+y)),sum,y)"));
    // NOTE(review): disabled case; presumably y*x is still recognized as a plain
    // multiplication and therefore does get optimized - confirm before re-enabling
    // TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x)),sum,y)"));
    // join functions that multiply, but not the two inputs with each other
    TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(x*x)),sum,y)"));
    TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*y)),sum,y)"));
    // join function with an extra factor
    TEST_DO(verify_not_optimized("reduce(join(y3,x2y3,f(x,y)(y*x*1)),sum,y)"));
}

// The vector (dimension y) and the matrix (dimensions x and z) share no
// dimension, so neither reducing over y nor over z may be optimized.
TEST("require that xw products with incompatible dimensions are not optimized") {
    TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,y)"));
    TEST_DO(verify_not_optimized("reduce(y3*x2z3,sum,z)"));
}

// Checks that as_string() can be called on the optimized node. The dump is
// written to stderr for manual inspection; its content is not verified.
TEST("require that xw product can be debug dumped") {
    EvalFixture fixture(prod_engine, "reduce(y5*x8y5,sum,y)", param_repo, true);
    auto info = fixture.find_all<DenseXWProductFunction>();
    ASSERT_EQUAL(info.size(), 1u);
    EXPECT_TRUE(info[0]->result_is_mutable());
    fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}

// y3f and x2y3f are the parameters registered with a tensor<float> type.
// The optimization must not kick in when the vector, the matrix, or both
// have float cells.
TEST("require that optimization is disabled for tensors with non-double cells") {
    TEST_DO(verify_not_optimized("reduce(y3f*x2y3,sum,y)"));
    TEST_DO(verify_not_optimized("reduce(y3*x2y3f,sum,y)"));
    TEST_DO(verify_not_optimized("reduce(y3f*x2y3f,sum,y)"));
}

// Test program entry point: runs every TEST(...) case defined in this file.
TEST_MAIN() { TEST_RUN_ALL(); }