aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/sparse_dot_product_function
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2021-02-01 15:13:22 +0000
committerHåvard Pettersen <havardpe@oath.com>2021-02-01 18:38:36 +0000
commit0261658338a6f7ad28bfca6f16f8a4b7c35d9cae (patch)
tree14f4494de17b6a64fc2f916fafcf66c1e723ec93 /eval/src/tests/instruction/sparse_dot_product_function
parentfe6300b1e9b81c09aa0235b5049439198c6a2206 (diff)
sparse dot product
Diffstat (limited to 'eval/src/tests/instruction/sparse_dot_product_function')
-rw-r--r--eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt9
-rw-r--r--eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp85
2 files changed, 94 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt
new file mode 100644
index 00000000000..076f1d79796
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_sparse_dot_product_function_test_app TEST
+ SOURCES
+ sparse_dot_product_function_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_sparse_dot_product_function_test_app COMMAND eval_sparse_dot_product_function_test_app)
diff --git a/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp
new file mode 100644
index 00000000000..65eab2778aa
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp
@@ -0,0 +1,85 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/instruction/sparse_dot_product_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
+
+//-----------------------------------------------------------------------------
+
+EvalFixture::ParamRepo make_params() {
+ return EvalFixture::ParamRepo()
+ .add("v1_x", GenSpec().map("x", 32, 1).seq_bias(3.0).gen())
+ .add("v1_x_f", GenSpec().map("x", 32, 1).seq_bias(3.0).cells_float().gen())
+ .add("v2_x", GenSpec().map("x", 16, 2).seq_bias(7.0).gen())
+ .add("v2_x_f", GenSpec().map("x", 16, 2).seq_bias(7.0).cells_float().gen())
+ .add("v3_y", GenSpec().map("y", 10, 1).gen())
+ .add("v4_xd", GenSpec().idx("x", 10).gen())
+ .add("m1_xy", GenSpec().map("x", 32, 1).map("y", 16, 2).seq_bias(3.0).gen())
+ .add("m2_xy", GenSpec().map("x", 16, 2).map("y", 32, 1).seq_bias(7.0).gen())
+ .add("m3_xym", GenSpec().map("x", 8, 1).idx("y", 5).gen());
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+void assert_optimized(const vespalib::string &expr) {
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+ EvalFixture test_fixture(test_factory, expr, param_repo, true);
+ EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 1u);
+ EXPECT_EQ(test_fixture.find_all<SparseDotProductFunction>().size(), 1u);
+ EXPECT_EQ(slow_fixture.find_all<SparseDotProductFunction>().size(), 0u);
+}
+
+void assert_not_optimized(const vespalib::string &expr) {
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
+ EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 0u);
+}
+
+//-----------------------------------------------------------------------------
+
+TEST(SparseDotProduct, expression_can_be_optimized)
+{
+ assert_optimized("reduce(v1_x*v2_x,sum,x)");
+ assert_optimized("reduce(v2_x*v1_x,sum)");
+ assert_optimized("reduce(v1_x*v2_x_f,sum)");
+ assert_optimized("reduce(v1_x_f*v2_x,sum)");
+ assert_optimized("reduce(v1_x_f*v2_x_f,sum)");
+}
+
+TEST(SparseDotProduct, multi_dimensional_expression_can_be_optimized)
+{
+ assert_optimized("reduce(m1_xy*m2_xy,sum,x,y)");
+ assert_optimized("reduce(m1_xy*m2_xy,sum)");
+}
+
+TEST(SparseDotProduct, embedded_dot_product_is_not_optimized)
+{
+ assert_not_optimized("reduce(m1_xy*v1_x,sum,x)");
+ assert_not_optimized("reduce(v1_x*m1_xy,sum,x)");
+}
+
+TEST(SparseDotProduct, similar_expressions_are_not_optimized)
+{
+ assert_not_optimized("reduce(m1_xy*v1_x,sum)");
+ assert_not_optimized("reduce(v1_x*v3_y,sum)");
+ assert_not_optimized("reduce(v2_x*v1_x,max)");
+ assert_not_optimized("reduce(v2_x+v1_x,sum)");
+ assert_not_optimized("reduce(v4_xd*v4_xd,sum)");
+ assert_not_optimized("reduce(m3_xym*m3_xym,sum)");
+}
+
+//-----------------------------------------------------------------------------
+
+GTEST_MAIN_RUN_ALL_TESTS()