// Extracted from a cgit web view (navigation bar and line-number gutter removed).
// path: root/vespalib/src/tests/eval/tensor/eval_tensor_test.cpp
// blob: e381ae88cbed29ed68b4506c6bebae327fd25031
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/eval/function.h>
#include <vespa/vespalib/tensor/tensor.h>
#include <vespa/vespalib/eval/interpreted_function.h>
#include <vespa/vespalib/tensor/default_tensor_engine.h>

using namespace vespalib;
using namespace vespalib::eval;

struct Check {
    const Value &value;
    explicit Check(const Value &value_in) : value(value_in) {}
    bool operator==(const Check &rhs) const { return value.equal(rhs.value); }  
};

std::ostream &operator<<(std::ostream &os, const Check &check) {
    const Value &value = check.value;
    if (value.is_error()) {
        os << "EVAL ERROR";
    } else if (value.is_double()) {
        os << value.as_double();
    } else if (value.is_tensor()) {
        os << value.as_tensor();
    }
    return os;
}

struct Eval {
    InterpretedFunction::Context ctx;
    InterpretedFunction ifun;
    const Value *result;
    explicit Eval(const vespalib::string &expr)
        : ctx(), ifun(tensor::DefaultTensorEngine::ref(), Function::parse(expr)), result(&ifun.eval(ctx)) {}
    bool operator==(const Eval &rhs) const { return result->equal(*rhs.result); }
};

std::ostream &operator<<(std::ostream &os, const Eval &eval) {
    os << Check(*eval.result);
    return os;
}

// Evaluates the expression "{" twice (available in the body as f1 and f2).
// Both results are expected to be flagged as errors, and two error results
// are expected NOT to compare equal, even though they come from the same
// expression text.
TEST_FF("require that eval errors are not equal", Eval("{"), Eval("{")) {
    EXPECT_TRUE(f1.result->is_error());
    EXPECT_TRUE(f2.result->is_error());
    EXPECT_NOT_EQUAL(f1, f2);
}

// Sanity-checks Eval equality on tensor expressions: identical expressions
// compare equal, and each single difference below must make them unequal.
TEST("require that different tensors are not equal") {
    // identical expressions
    EXPECT_EQUAL(Eval("{{x:1}:1}"), Eval("{{x:1}:1}"));
    // different cell value (1 vs 2)
    EXPECT_NOT_EQUAL(Eval("{{x:1}:1}"), Eval("{{x:1}:2}"));
    // different coordinate within the same dimension (x:1 vs x:2)
    EXPECT_NOT_EQUAL(Eval("{{x:1}:1}"), Eval("{{x:2}:1}"));
    // different dimension name (x vs y)
    EXPECT_NOT_EQUAL(Eval("{{x:1}:1}"), Eval("{{y:1}:1}"));
    // additional cell on the right-hand side
    EXPECT_NOT_EQUAL(Eval("{{x:1}:1}"), Eval("{{x:1}:1,{x:2}:1}"));
}

// sum() with only a tensor argument is expected to add up all cell values
// (1 + 2 + 3) and compare equal to the plain number expression "6.0".
TEST("require that tensor sum works") {
    EXPECT_EQUAL(Eval("6.0"), Eval("sum({{x:1}:1,{x:2}:2,{x:3}:3})"));
}

// sum(tensor, dim) is expected to remove the named dimension by adding the
// cells that differ only in that dimension. Both cases use the same 2x2 input.
TEST("require that tensor sum over dimension works") {
    // summing over y: x:1 -> 1 + 3 = 4, x:2 -> 2 + 4 = 6
    EXPECT_EQUAL(Eval("{{x:1}:4,{x:2}:6}"), Eval("sum({{x:1,y:1}:1,{x:2,y:1}:2,{x:1,y:2}:3,{x:2,y:2}:4},y)"));
    // summing over x: y:1 -> 1 + 2 = 3, y:2 -> 3 + 4 = 7
    EXPECT_EQUAL(Eval("{{y:1}:3,{y:2}:7}"), Eval("sum({{x:1,y:1}:1,{x:2,y:1}:2,{x:1,y:2}:3,{x:2,y:2}:4},x)"));
}

// Adding two tensors over the same dimension: the cell present in both
// operands (x:2) is expected to hold the sum 2 + 3 = 5, while cells present in
// only one operand (x:1, x:3) keep their value. The second check swaps the
// operands and expects the same result.
TEST("require that tensor add works") {
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:5,{x:3}:3}"), Eval("{{x:1}:1,{x:2}:2} + {{x:2}:3,{x:3}:3}"));
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:5,{x:3}:3}"), Eval("{{x:2}:3,{x:3}:3} + {{x:1}:1,{x:2}:2}"));
}

// Subtracting two tensors over the same dimension. The expected values show
// that a cell missing from one operand is treated like 0: a cell only on the
// left keeps its value, a cell only on the right is negated. The second check
// swaps the operands, which flips the sign of every expected cell.
TEST("require that tensor sub works") {
    // x:1 -> 1 - (missing) = 1, x:2 -> 2 - 3 = -1, x:3 -> (missing) - 3 = -3
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:-1,{x:3}:-3}"), Eval("{{x:1}:1,{x:2}:2} - {{x:2}:3,{x:3}:3}"));
    // x:1 -> (missing) - 1 = -1, x:2 -> 3 - 2 = 1, x:3 -> 3 - (missing) = 3
    EXPECT_EQUAL(Eval("{{x:1}:-1,{x:2}:1,{x:3}:3}"), Eval("{{x:2}:3,{x:3}:3} - {{x:1}:1,{x:2}:2}"));
}

// Multiplying a tensor over x with a tensor over y (no shared dimension) is
// expected to yield one cell per (x, y) combination holding the product of the
// two source cells: 1*3, 2*3, 1*4 and 2*4.
TEST("require that tensor multiply works") {
    EXPECT_EQUAL(Eval("{{x:1,y:1}:3,{x:2,y:1}:6,{x:1,y:2}:4,{x:2,y:2}:8}"), Eval("{{x:1}:1,{x:2}:2}*{{y:1}:3,{y:2}:4}"));
}

// min() of two tensors over the same dimension: the shared cell (x:2) is
// expected to hold min(2, 3) = 2, and cells present in only one operand
// (x:1, x:3) keep their value. The second check swaps the arguments and
// expects the same result.
TEST("require that tensor min works") {
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:2,{x:3}:3}"), Eval("min({{x:1}:1,{x:2}:2}, {{x:2}:3,{x:3}:3})"));
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:2,{x:3}:3}"), Eval("min({{x:2}:3,{x:3}:3}, {{x:1}:1,{x:2}:2})"));
}

// max() of two tensors over the same dimension: the shared cell (x:2) is
// expected to hold max(2, 3) = 3, and cells present in only one operand
// (x:1, x:3) keep their value. The second check swaps the arguments and
// expects the same result.
TEST("require that tensor max works") {
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:3,{x:3}:3}"), Eval("max({{x:1}:1,{x:2}:2}, {{x:2}:3,{x:3}:3})"));
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:3,{x:3}:3}"), Eval("max({{x:2}:3,{x:3}:3}, {{x:1}:1,{x:2}:2})"));
}

// match() is expected to keep only the cell present in both operands (x:2)
// and to give it the product of the two values, 2 * 3 = 6; cells present in
// only one operand (x:1, x:3) are absent from the expected result.
TEST("require that tensor match works") {
    EXPECT_EQUAL(Eval("{{x:2}:6}"), Eval("match({{x:1}:1,{x:2}:2},{{x:2}:3,{x:3}:3})"));
}

// Combining a tensor with a plain number is expected to apply the operation
// to every cell individually. Addition and subtraction are each checked with
// the number on the right and on the left of the tensor.
TEST("require that tensor cell function works") {
    // tensor + 2 and 2 + tensor: every cell is increased by 2
    EXPECT_EQUAL(Eval("{{x:1}:3,{x:2}:4,{x:3}:5}"), Eval("{{x:1}:1,{x:2}:2,{x:3}:3}+2"));
    EXPECT_EQUAL(Eval("{{x:1}:3,{x:2}:4,{x:3}:5}"), Eval("2+{{x:1}:1,{x:2}:2,{x:3}:3}"));
    // tensor - 2: every cell is decreased by 2
    EXPECT_EQUAL(Eval("{{x:1}:-1,{x:2}:0,{x:3}:1}"), Eval("{{x:1}:1,{x:2}:2,{x:3}:3}-2"));
    // 2 - tensor: every cell becomes 2 minus its original value
    EXPECT_EQUAL(Eval("{{x:1}:1,{x:2}:0,{x:3}:-1}"), Eval("2-{{x:1}:1,{x:2}:2,{x:3}:3}"));
}

// Program entry point supplied through the test kit macros.
// NOTE(review): TEST_RUN_ALL() presumably executes every TEST/TEST_FF declared
// in this file - confirm against the test kit documentation.
TEST_MAIN() { TEST_RUN_ALL(); }