summaryrefslogtreecommitdiffstats
path: root/eval/src/tests
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-09-28 08:13:03 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-09-30 09:20:27 +0000
commitfff135ac0ccd2ae07edc49857abf7c305b2ac3a5 (patch)
treeac1889791678f06fd101c5cc3096dbdbfcb4515e /eval/src/tests
parent2f5a11f868291b34a3aa2c28817b36c5d0ed3d52 (diff)
best similarity function
Diffstat (limited to 'eval/src/tests')
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp14
-rw-r--r--eval/src/tests/instruction/best_similarity_function/CMakeLists.txt9
-rw-r--r--eval/src/tests/instruction/best_similarity_function/best_similarity_function_test.cpp148
-rw-r--r--eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp2
4 files changed, 171 insertions, 2 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index c457f68a614..3d4e2d41cb5 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -510,4 +510,18 @@ TEST("require that tensor function can be dumped for debugging") {
fprintf(stderr, "function dump -->[[%s]]<-- function dump\n", root.as_string().c_str());
}
+TEST("require that full tensor reduce expands dimension list") {
+ Stash stash;
+ const auto &num = inject(ValueType::from_spec("double"), 0, stash);
+ const auto &mat = inject(ValueType::from_spec("tensor(x[5],y[5])"), 1, stash);
+ const auto *reduce_num = as<Reduce>(reduce(num, Aggr::SUM, {}, stash));
+ const auto *reduce_mat = as<Reduce>(reduce(mat, Aggr::SUM, {}, stash));
+ ASSERT_TRUE(reduce_num);
+ ASSERT_TRUE(reduce_mat);
+ EXPECT_EQUAL(reduce_num->dimensions().size(), 0u);
+ ASSERT_EQUAL(reduce_mat->dimensions().size(), 2u);
+ EXPECT_EQUAL(reduce_mat->dimensions()[0], "x");
+ EXPECT_EQUAL(reduce_mat->dimensions()[1], "y");
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/instruction/best_similarity_function/CMakeLists.txt b/eval/src/tests/instruction/best_similarity_function/CMakeLists.txt
new file mode 100644
index 00000000000..fbcf435aad1
--- /dev/null
+++ b/eval/src/tests/instruction/best_similarity_function/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_best_similarity_function_test_app TEST
+ SOURCES
+ best_similarity_function_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_best_similarity_function_test_app COMMAND eval_best_similarity_function_test_app)
diff --git a/eval/src/tests/instruction/best_similarity_function/best_similarity_function_test.cpp b/eval/src/tests/instruction/best_similarity_function/best_similarity_function_test.cpp
new file mode 100644
index 00000000000..058b0f82678
--- /dev/null
+++ b/eval/src/tests/instruction/best_similarity_function/best_similarity_function_test.cpp
@@ -0,0 +1,148 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/eval/instruction/best_similarity_function.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib;
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+
+//-----------------------------------------------------------------------------
+
+void verify_impl(const TensorSpec &a, const TensorSpec &b, const vespalib::string &expr, bool optimized) {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add("a", a).add("b", b);
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<BestSimilarityFunction>().size(), optimized ? 1 : 0);
+}
+
+void verify(const TensorSpec &a, const TensorSpec &b, const vespalib::string &expr, bool optimized = true) {
+ verify_impl(a, b, expr, optimized);
+ verify_impl(b, a, expr, optimized);
+}
+
+//-----------------------------------------------------------------------------
+
+GenSpec gen_double(const vespalib::string &desc, int bias) {
+ return GenSpec::from_desc(desc).cells(CellType::DOUBLE).seq(N(bias));
+}
+
+GenSpec gen_float(const vespalib::string &desc, int bias) {
+ return GenSpec::from_desc(desc).cells(CellType::FLOAT).seq(N(bias));
+}
+
+GenSpec gen_int8(const vespalib::string &desc, int bias) {
+ return GenSpec::from_desc(desc).cells(CellType::INT8).seq(N(bias));
+}
+
+vespalib::string max_sim = "reduce(reduce(a*b,sum,d),max,b)";
+vespalib::string min_hamming = "reduce(reduce(hamming(a,b),sum,d),min,b)";
+
+//-----------------------------------------------------------------------------
+
+TEST(BestSimilarityFunctionTest, result_is_mutable) {
+ tensor_function::Inject child(ValueType::double_type(), 0);
+ BestSimilarityFunction node(ValueType::double_type(), child, child, nullptr, 1);
+ EXPECT_TRUE(node.result_is_mutable());
+}
+
+TEST(BestSimilarityFunctionTest, max_sim_can_be_optimized) {
+ verify(gen_float("A3_2B3d8", 3), gen_float("b5d8", 7), max_sim);
+ verify(gen_float("A3_2B3d8", 3), gen_float("b5_2d8", 7), max_sim);
+}
+
+TEST(BestSimilarityFunctionTest, min_hamming_can_be_optimized) {
+ verify(gen_int8("A3_2B3d8", 3), gen_int8("b5d8", 7), min_hamming);
+ verify(gen_int8("A3_2B3d8", 3), gen_int8("b5_2d8", 7), min_hamming);
+}
+
+TEST(BestSimilarityFunctionTest, result_can_be_sparse) {
+ verify(gen_float("A3_2d8", 3), gen_float("b5d8", 7), max_sim);
+ verify(gen_int8("A3_2d8", 3), gen_int8("b5_2d8", 7), min_hamming);
+}
+
+TEST(BestSimilarityFunctionTest, result_can_be_dense) {
+ verify(gen_float("B3d8", 3), gen_float("b5d8", 7), max_sim);
+ verify(gen_int8("B3d8", 3), gen_int8("b5_2d8", 7), min_hamming);
+}
+
+TEST(BestSimilarityFunctionTest, result_can_be_double) {
+ verify(gen_float("d8", 3), gen_float("b5d8", 7), max_sim);
+ verify(gen_int8("d8", 3), gen_int8("b5_2d8", 7), min_hamming);
+}
+
+TEST(BestSimilarityFunctionTest, primary_dimensions_can_be_trivial) {
+ verify(gen_float("d1", 3), gen_float("b1d1", 7), max_sim);
+ verify(gen_int8("d1", 3), gen_int8("b1d1", 7), min_hamming);
+}
+
+TEST(BestSimilarityFunctionTest, extra_trivial_dimensions_are_allowed) {
+ verify(gen_float("A1a1d8x1z1", 3), gen_float("a1b5c1d8x1y1", 7), max_sim);
+}
+
+TEST(BestSimilarityFunctionTest, allow_full_reduce_for_outer_dimension) {
+ vespalib::string my_max_sim = "reduce(reduce(a*b,sum,d),max)";
+ vespalib::string my_min_hamming = "reduce(reduce(hamming(a,b),sum,d),min)";
+ verify(gen_float("d8", 3), gen_float("b5d8", 7), my_max_sim);
+ verify(gen_int8("d8", 3), gen_int8("b5_2d8", 7), my_min_hamming);
+}
+
+//-----------------------------------------------------------------------------
+
+TEST(BestSimilarityFunctionTest, cell_type_must_match_operation) {
+ verify(gen_double("d8", 3), gen_double("b5d8", 7), max_sim, false);
+ verify(gen_float("d8", 3), gen_float("b5_2d8", 7), min_hamming, false);
+}
+
+vespalib::string max_sim_2d_dist = "reduce(reduce(a*b,sum,d,e),max,b)";
+
+TEST(BestSimilarityFunctionTest, similarity_must_use_1d_vector) {
+ verify(gen_float("d8_1", 3), gen_float("b5d8_1", 7), max_sim, false);
+ verify(gen_float("d8e1", 3), gen_float("b5d8e1", 7), max_sim_2d_dist, false);
+}
+
+vespalib::string inv_max_sim = "reduce(reduce(a*b,sum,b),max,d)";
+
+TEST(BestSimilarityFunctionTest, similarity_dimension_must_be_inner) {
+ verify(gen_float("d8e3", 3), gen_float("b5d8", 7), max_sim, false);
+ verify(gen_float("d8", 3), gen_float("b5d8", 7), inv_max_sim, false);
+}
+
+vespalib::string max_sim_2d_best = "reduce(reduce(a*b,sum,d),max,a,b)";
+
+TEST(BestSimilarityFunctionTest, alternatives_must_use_a_single_dimension) {
+ verify(gen_float("d8", 3), gen_float("a1b5d8", 7), max_sim_2d_best, false);
+}
+
+TEST(BestSimilarityFunctionTest, alternatives_dimension_can_not_be_common) {
+ verify(gen_float("b5d8", 3), gen_float("b5d8", 7), max_sim, false);
+}
+
+TEST(BestSimilarityFunctionTest, extra_common_nontrivial_dimensions_not_allowed) {
+ verify(gen_float("a3d8", 3), gen_float("a3b5d8", 7), max_sim, false);
+}
+
+TEST(BestSimilarityFunctionTest, secondary_tensor_must_not_contain_extra_nontrivial_dimensions) {
+ verify(gen_float("d8", 3), gen_float("a2b5d8", 7), max_sim, false);
+}
+
+//-----------------------------------------------------------------------------
+
+vespalib::string other_join = "reduce(reduce(a+b,sum,d),max,b)";
+vespalib::string mismatch_best = "reduce(reduce(a*b,sum,d),min,b)";
+
+TEST(BestSimilarityFunctionTest, similar_expressions_are_not_optimized) {
+ verify(gen_float("d8", 3), gen_float("b5d8", 7), other_join, false);
+ verify(gen_float("d8", 3), gen_float("b5d8", 7), mismatch_best, false);
+}
+
+//-----------------------------------------------------------------------------
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
index f68c089e784..4998885c6a6 100644
--- a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
+++ b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
@@ -115,11 +115,9 @@ TEST(SumMaxDotProduct, similar_expressions_are_not_optimized) {
vespalib::string max_sum_expr = "reduce(reduce(reduce(a*b,sum,z),sum,y),max,x)";
vespalib::string not_dp_expr1 = "reduce(reduce(reduce(a+b,sum,z),max,y),sum,x)";
vespalib::string not_dp_expr2 = "reduce(reduce(reduce(a*b,min,z),max,y),sum,x)";
- vespalib::string sum_all_expr = "reduce(reduce(reduce(a*b,sum,z),max,y),sum)";
assert_not_optimized(query, document, max_sum_expr);
assert_not_optimized(query, document, not_dp_expr1);
assert_not_optimized(query, document, not_dp_expr2);
- assert_not_optimized(query, document, sum_all_expr);
}
//-----------------------------------------------------------------------------