aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/l2_distance/l2_distance_test.cpp
blob: 114f0c21b8ac6ba626385112673ba890e68978f0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/l2_distance.h>
#include <vespa/vespalib/util/stash.h>
#include <vespa/vespalib/util/stringfmt.h>

#include <vespa/vespalib/util/require.h>
#include <vespa/vespalib/gtest/gtest.h>

using namespace vespalib;
using namespace vespalib::eval;
using namespace vespalib::eval::test;

// Value builder factory handed to every EvalFixture built by verify() below;
// bound to whatever FastValueBuilderFactory::get() returns.
const ValueBuilderFactory &prod_factory(FastValueBuilderFactory::get());

//-----------------------------------------------------------------------------

// Evaluates 'expr' with the parameters "a" and "b" bound to the given specs
// and checks two things:
//  1. the fixture result equals the result of EvalFixture::ref for the same
//     expression and parameters, and
//  2. the number of L2Distance nodes found in the fixture is exactly 1 when
//     'optimized' is true and 0 otherwise.
void verify(const TensorSpec &a, const TensorSpec &b, const vespalib::string &expr, bool optimized = true) {
    EvalFixture::ParamRepo repo;
    repo.add("a", a).add("b", b);
    // NOTE(review): the trailing 'true' presumably turns on optimization in
    // the fixture -- confirm against the EvalFixture constructor.
    EvalFixture fixture(prod_factory, expr, repo, true);
    const auto reference = EvalFixture::ref(expr, repo);
    EXPECT_EQ(fixture.result(), reference);
    const auto l2_nodes = fixture.find_all<L2Distance>();
    EXPECT_EQ(l2_nodes.size(), optimized ? 1 : 0);
}

// Runs verify() for every (cell type of a, cell type of b) combination listed
// by CellTypeUtils::list_types(). The optimization is only expected when the
// caller asks for it ('optimized'), both inputs use the same cell type, and
// that cell type is not BFLOAT16; every other combination is verified as
// not optimized.
void verify_cell_types(GenSpec a, GenSpec b, const vespalib::string &expr, bool optimized = true) {
    for (CellType lhs_type : CellTypeUtils::list_types()) {
        for (CellType rhs_type : CellTypeUtils::list_types()) {
            const bool same_type = (lhs_type == rhs_type);
            const bool is_bfloat16 = (lhs_type == CellType::BFLOAT16);
            const bool expect_optimized = optimized && same_type && !is_bfloat16;
            // work on copies so the caller-supplied specs keep their cell type
            verify(a.cpy().cells(lhs_type), b.cpy().cells(rhs_type), expr, expect_optimized);
        }
    }
}

//-----------------------------------------------------------------------------

// Builds a GenSpec from the textual description 'desc', using FLOAT cells and
// the sequence N(bias) as cell value generator.
GenSpec gen(const vespalib::string &desc, int bias) {
    GenSpec spec = GenSpec::from_desc(desc)
                       .cells(CellType::FLOAT)
                       .seq(N(bias));
    return spec;
}

//-----------------------------------------------------------------------------

// The two expression spellings exercised by the tests below: the difference
// of 'a' and 'b' squared with '^2', and the same difference squared with an
// explicit map lambda (x*x); both are reduced with 'sum'.
vespalib::string sq_l2("reduce((a-b)^2,sum)");
vespalib::string alt_sq_l2("reduce(map((a-b),f(x)(x*x)),sum)");

//-----------------------------------------------------------------------------

// Both spellings of the squared distance expression are checked against all
// cell type combinations (see verify_cell_types for which ones are expected
// to be optimized).
TEST(L2DistanceTest, squared_l2_distance_can_be_optimized) {
    for (const vespalib::string &expr : {sq_l2, alt_sq_l2}) {
        verify_cell_types(gen("x5", 3), gen("x5", 7), expr);
    }
}

// An additional "y1" dimension on either operand must not prevent the
// optimization (presumably a dimension of size 1 -- confirm against
// GenSpec::from_desc).
TEST(L2DistanceTest, trivial_dimensions_are_ignored) {
    GenSpec extra_on_lhs = gen("x5y1", 3);
    GenSpec plain_rhs = gen("x5", 7);
    verify(extra_on_lhs, plain_rhs, sq_l2);

    GenSpec plain_lhs = gen("x5", 3);
    GenSpec extra_on_rhs = gen("x5y1", 7);
    verify(plain_lhs, extra_on_rhs, sq_l2);
}

// Operands described by more than one dimension ("x5y3" on both sides) are
// still expected to be optimized.
TEST(L2DistanceTest, multiple_dimensions_can_be_used) {
    GenSpec lhs = gen("x5y3", 3);
    GenSpec rhs = gen("x5y3", 7);
    verify(lhs, rhs, sq_l2, true);
}

//-----------------------------------------------------------------------------

// None of these operand combinations may be optimized: descriptions using
// the "_1" suffix (presumably mapped dimensions -- confirm against
// GenSpec::from_desc) and operands built directly from a number via
// GenSpec(value).
TEST(L2DistanceTest, inputs_must_be_dense) {
    auto expect_not_optimized = [](GenSpec lhs, GenSpec rhs) {
        verify(lhs, rhs, sq_l2, false);
    };
    expect_not_optimized(gen("x5_1", 3), gen("x5_1", 7));
    expect_not_optimized(gen("x5_1y3", 3), gen("x5_1y3", 7));
    expect_not_optimized(gen("x5", 3), GenSpec(7));
    expect_not_optimized(GenSpec(3), gen("x5", 7));
}

// Reducing over only the 'x' dimension leaves 'y' in the result, so these
// expressions must not be turned into an L2Distance node.
TEST(L2DistanceTest, result_must_be_double) {
    const vespalib::string reduce_x_only("reduce((a-b)^2,sum,x)");
    verify(gen("x5y1", 3), gen("x5y1", 7), reduce_x_only, false);
    verify(gen("x5y1_1", 3), gen("x5y1_1", 7), reduce_x_only, false);
}

// When only one of the operands carries the extra "y3" dimension the
// expression must not be optimized, regardless of which side it is on.
TEST(L2DistanceTest, dimensions_must_match) {
    GenSpec wide_lhs = gen("x5y3", 3);
    GenSpec narrow_rhs = gen("x5", 7);
    verify(wide_lhs, narrow_rhs, sq_l2, false);

    GenSpec narrow_lhs = gen("x5", 3);
    GenSpec wide_rhs = gen("x5y3", 7);
    verify(narrow_lhs, wide_rhs, sq_l2, false);
}

// Near misses of the squared distance expression: a different reduce
// aggregator (prod), a different exponent (^3) and a different inner
// operation (a+b). None of them may produce an L2Distance node.
TEST(L2DistanceTest, similar_expressions_are_not_optimized) {
    const char *near_misses[] = {
        "reduce((a-b)^2,prod)",
        "reduce((a-b)^3,sum)",
        "reduce((a+b)^2,sum)",
    };
    for (const char *expr : near_misses) {
        verify(gen("x5", 3), gen("x5", 7), expr, false);
    }
}

//-----------------------------------------------------------------------------

// Macro defined outside this file (presumably by <vespa/vespalib/gtest/gtest.h>);
// presumably expands to the program entry point that runs every TEST above --
// confirm against the macro definition.
GTEST_MAIN_RUN_ALL_TESTS()