aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/dense_simple_expand_function/dense_simple_expand_function_test.cpp
blob: d89ebf44912c623e71f3bbb437da7fe4233bc32f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/instruction/dense_simple_expand_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/vespalib/gtest/gtest.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::eval::test;
using namespace vespalib::eval::tensor_function;

using Inner = DenseSimpleExpandFunction::Inner;

// Expected properties of a DenseSimpleExpandFunction node, handed to
// EvalFixture::verify<FunInfo>. LookFor names the node type to search for;
// presumably the fixture calls verify() on each matching node it finds in
// the optimized function (NOTE(review): confirm against eval_fixture.h).
struct FunInfo {
    using LookFor = DenseSimpleExpandFunction;
    // Which operand (LHS or RHS) the node is expected to report as inner().
    Inner inner;
    // Checks one optimized node: its result must be mutable and its inner
    // side must match the expectation stored above.
    void verify(const LookFor &fun) const {
        EXPECT_TRUE(fun.result_is_mutable());
        EXPECT_EQ(fun.inner(), inner);
    }
};

// Asserts that `expr` is optimized into exactly one DenseSimpleExpandFunction
// whose inner side is `inner`, across every cell-type combination reported by
// CellTypeUtils::list_types() (the `2` is presumably the operand count).
void verify_optimized(const vespalib::string &expr, Inner inner) {
    SCOPED_TRACE(expr.c_str());
    CellTypeSpace every_cell_type(CellTypeUtils::list_types(), 2);
    const FunInfo expected{inner};
    EvalFixture::verify<FunInfo>(expr, {expected}, every_cell_type);
}

// Asserts that `expr` is NOT turned into a DenseSimpleExpandFunction; an empty
// expectation list is passed to the fixture. Only double cells are checked.
void verify_not_optimized(const vespalib::string &expr) {
    SCOPED_TRACE(expr.c_str());
    CellTypeSpace double_cells_only({CellType::DOUBLE}, 2);
    EvalFixture::verify<FunInfo>(expr, {}, double_cells_only);
}

TEST(ExpandTest, simple_expand_is_optimized) {
    // Swapping the operands flips which side is expected to be inner.
    struct Case { const char *expr; Inner inner; };
    const Case cases[] = {
        {"join(a5,b3,f(x,y)(x*y))", Inner::RHS},
        {"join(b3,a5,f(x,y)(x*y))", Inner::LHS},
    };
    for (const Case &c : cases) {
        verify_optimized(c.expr, c.inner);
    }
}

TEST(ExpandTest, multiple_dimensions_are_supported) {
    // Either operand may span several dimensions; mirrored operand order
    // flips the expected inner side.
    struct Case { const char *expr; Inner inner; };
    const Case cases[] = {
        {"join(a5,x3y2,f(x,y)(x*y))", Inner::RHS},
        {"join(x3y2,a5,f(x,y)(x*y))", Inner::LHS},
        {"join(a5c3,x3y2,f(x,y)(x*y))", Inner::RHS},
        {"join(x3y2,a5c3,f(x,y)(x*y))", Inner::LHS},
    };
    for (const Case &c : cases) {
        verify_optimized(c.expr, c.inner);
    }
}

TEST(ExpandTest, trivial_dimensions_are_ignored) {
    // Size-1 dimensions (A1, B1, c1) must not prevent the optimization,
    // even though c1 is shared by both operands.
    struct Case { const char *expr; Inner inner; };
    const Case cases[] = {
        {"join(A1a5c1,B1b3c1,f(x,y)(x*y))", Inner::RHS},
        {"join(B1b3c1,A1a5c1,f(x,y)(x*y))", Inner::LHS},
    };
    for (const Case &c : cases) {
        verify_optimized(c.expr, c.inner);
    }
}

TEST(ExpandTest, simple_expand_handles_asymmetric_operations_correctly) {
    // Non-commutative join functions (subtraction, division) in both
    // operand orders.
    struct Case { const char *expr; Inner inner; };
    const Case cases[] = {
        {"join(a5,b3,f(x,y)(x-y))", Inner::RHS},
        {"join(b3,a5,f(x,y)(x-y))", Inner::LHS},
        {"join(a5,b3,f(x,y)(x/y))", Inner::RHS},
        {"join(b3,a5,f(x,y)(x/y))", Inner::LHS},
    };
    for (const Case &c : cases) {
        verify_optimized(c.expr, c.inner);
    }
}

#if 0
// XXX no way to really verify this now
TEST(ExpandTest, simple_expand_is_never_inplace) {
    verify_optimized("join(@a5,@b3,f(x,y)(x*y))", Inner::RHS);
    verify_optimized("join(@b3,@a5,f(x,y)(x*y))", Inner::LHS);
}
#endif

TEST(ExpandTest, interleaved_dimensions_are_not_optimized) {
    // 'b' sorts between 'a' and 'c', so b3 interleaves with a5c3 (assuming
    // dimensions are ordered by name).
    const char *const exprs[] = {
        "join(a5c3,b3,f(x,y)(x*y))",
        "join(b3,a5c3,f(x,y)(x*y))",
    };
    for (const char *expr : exprs) {
        verify_not_optimized(expr);
    }
}

TEST(ExpandTest, matching_dimensions_are_not_expanding) {
    // Both operands share dimension 'a', so this is not a pure expand.
    const char *const exprs[] = {
        "join(a5c3,a5,f(x,y)(x*y))",
        "join(a5,a5c3,f(x,y)(x*y))",
    };
    for (const char *expr : exprs) {
        verify_not_optimized(expr);
    }
}

TEST(ExpandTest, scalar_is_not_expanding) {
    // Joining with a scalar ("$1"; the '@' prefix appears to mark the
    // operand as mutable) must not become a simple expand.
    const char *const expr = "join(a5,@$1,f(x,y)(x*y))";
    verify_not_optimized(expr);
}

TEST(ExpandTest, unit_tensor_is_not_expanding) {
    // Operands made only of size-1 dimensions add nothing to expand over.
    const char *const exprs[] = {
        "join(a5,x1y1z1,f(x,y)(x+y))",
        "join(x1y1z1,a5,f(x,y)(x+y))",
        "join(a1b1c1,x1y1z1,f(x,y)(x+y))",
    };
    for (const char *expr : exprs) {
        verify_not_optimized(expr);
    }
}

TEST(ExpandTest, sparse_expand_is_not_optimized) {
    // "x1_1" presumably denotes a mapped (sparse) dimension; the dense-only
    // optimization must not apply.
    const char *const exprs[] = {
        "join(a5,x1_1,f(x,y)(x*y))",
        "join(x1_1,a5,f(x,y)(x*y))",
    };
    for (const char *expr : exprs) {
        verify_not_optimized(expr);
    }
}

TEST(ExpandTest, mixed_expand_is_not_optimized) {
    // "y1_1z2" presumably mixes a mapped dimension (y) with an indexed one
    // (z); the dense-only optimization must not apply.
    const char *const exprs[] = {
        "join(a5,y1_1z2,f(x,y)(x*y))",
        "join(y1_1z2,a5,f(x,y)(x*y))",
    };
    for (const char *expr : exprs) {
        verify_not_optimized(expr);
    }
}

// Test program entry point (vespalib gtest helper); presumably expands to a
// main() that runs all registered tests.
GTEST_MAIN_RUN_ALL_TESTS()