aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
blob: f27a2073159f56fe382558a54ed0ef877b137691 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/log/log.h>
// Log component name for this test binary; it should match this file
// (dense_xw_product_function_test), not the dot-product test it was copied from.
LOG_SETUP("dense_xw_product_function_test");

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/simple_tensor.h>
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_xw_product_function.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor_builder.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>

#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/util/stash.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::tensor;

// Engine used to compute the expected results (join + reduce).
const TensorEngine &ref_engine(SimpleTensorEngine::ref());
// Engine whose tensors are fed to DenseXWProductFunction under test.
const TensorEngine &prod_engine(DefaultTensorEngine::ref());

void verify_equal(const Value &expect, const Value &value) {
    const eval::Tensor *tensor = value.as_tensor();
    ASSERT_TRUE(tensor != nullptr);
    const eval::Tensor *expect_tensor = expect.as_tensor();
    ASSERT_TRUE(expect_tensor != nullptr);
    auto expect_spec = expect_tensor->engine().to_spec(expect);
    auto value_spec = tensor->engine().to_spec(value);
    EXPECT_EQUAL(expect_spec, value_spec);
}

/** Packs the given value references into a SimpleObjectParams for TensorFunction::eval. */
SimpleObjectParams wrap(std::vector<eval::Value::CREF> params) {
    SimpleObjectParams wrapped(params);
    return wrapped;
}

void verify_result(const TensorSpec &v, const TensorSpec &m, bool happy) {
    Stash stash;
    Value::UP ref_vec = ref_engine.from_spec(v);
    Value::UP ref_mat = ref_engine.from_spec(m);
    const Value &joined = ref_engine.join(*ref_vec, *ref_mat, operation::Mul::f, stash);
    const Value &expect = ref_engine.reduce(joined, Aggr::SUM, {"x"}, stash);

    Value::UP prod_vec = prod_engine.from_spec(v);
    Value::UP prod_mat = prod_engine.from_spec(m);

    DenseXWProductFunction fun1(expect.type(), 0, 1,
                                prod_vec->type().dimensions()[0].size,
                                expect.type().dimensions()[0].size,
                                happy);
    const Value &actual1 = fun1.eval(wrap({*prod_vec, *prod_mat}), stash);
    TEST_DO(verify_equal(expect, actual1));

    DenseXWProductFunction fun2(expect.type(), 1, 0,
                                prod_vec->type().dimensions()[0].size,
                                expect.type().dimensions()[0].size,
                                happy);
    const Value &actual2 = fun2.eval(wrap({*prod_mat, *prod_vec}), stash);
    TEST_DO(verify_equal(expect, actual2));
}

/**
 * Builds the spec of a 1-d dense tensor with dimension 'name' of size 'sz'.
 * The cell at index i holds (1 + i) * 16.
 */
TensorSpec make_vector(const vespalib::string &name, size_t sz) {
    TensorSpec spec(make_string("tensor(%s[%zu])", name.c_str(), sz));
    size_t idx = 0;
    while (idx < sz) {
        const double cell = (1.0 + idx) * 16.0;
        spec.add({{name, idx}}, cell);
        ++idx;
    }
    return spec;
}

/**
 * Builds the spec of a 2-d dense tensor with dimensions 'd1name' (size 'd1sz')
 * and 'd2name' (size 'd2sz'), in that order in the type string.
 * The cell at (i, j) holds 1 + 7*i + 43*j.
 */
TensorSpec make_matrix(const vespalib::string &d1name, size_t d1sz,
                       const vespalib::string &d2name, size_t d2sz)
{
    TensorSpec spec(make_string("tensor(%s[%zu],%s[%zu])",
                                d1name.c_str(), d1sz,
                                d2name.c_str(), d2sz));
    for (size_t row = 0; row < d1sz; ++row) {
        // Same evaluation order as (1.0 + row*7.0) + col*43.0.
        const double row_base = 1.0 + row*7.0;
        for (size_t col = 0; col < d2sz; ++col) {
            spec.add({{d1name,row},{d2name,col}}, row_base + col*43.0);
        }
    }
    return spec;
}

TEST("require that xw product gives same results as reference join/reduce") {
    // Each row: {size of common dimension "x", size of the other matrix dimension}.
    const size_t sizes[][2] = {{1, 1}, {3, 2}, {5, 8}, {16, 5}};
    for (const auto &s : sizes) {
        const size_t common = s[0];
        const size_t other = s[1];
        // "x" is the matrix's second dimension.
        verify_result(make_vector("x", common), make_matrix("o", other, "x", common), true);
        // "x" is the matrix's first dimension.
        verify_result(make_vector("x", common), make_matrix("x", common, "y", other), false);
    }
}

TEST_MAIN() {
    // Run all registered TEST cases.
    TEST_RUN_ALL();
}