diff options
author | Håvard Pettersen <havardpe@oath.com> | 2021-02-01 15:13:22 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2021-02-01 18:38:36 +0000 |
commit | 0261658338a6f7ad28bfca6f16f8a4b7c35d9cae (patch) | |
tree | 14f4494de17b6a64fc2f916fafcf66c1e723ec93 /eval/src/tests/instruction/sparse_dot_product_function | |
parent | fe6300b1e9b81c09aa0235b5049439198c6a2206 (diff) |
sparse dot product
Diffstat (limited to 'eval/src/tests/instruction/sparse_dot_product_function')
-rw-r--r-- | eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt | 9 | ||||
-rw-r--r-- | eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp | 85 |
2 files changed, 94 insertions, 0 deletions
diff --git a/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt new file mode 100644 index 00000000000..076f1d79796 --- /dev/null +++ b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_sparse_dot_product_function_test_app TEST + SOURCES + sparse_dot_product_function_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_sparse_dot_product_function_test_app COMMAND eval_sparse_dot_product_function_test_app) diff --git a/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp new file mode 100644 index 00000000000..65eab2778aa --- /dev/null +++ b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp @@ -0,0 +1,85 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/instruction/sparse_dot_product_function.h> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/eval/eval/test/gen_spec.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::test; + +const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); +const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get(); + +//----------------------------------------------------------------------------- + +EvalFixture::ParamRepo make_params() { + return EvalFixture::ParamRepo() + .add("v1_x", GenSpec().map("x", 32, 1).seq_bias(3.0).gen()) + .add("v1_x_f", GenSpec().map("x", 32, 1).seq_bias(3.0).cells_float().gen()) + .add("v2_x", GenSpec().map("x", 16, 2).seq_bias(7.0).gen()) + .add("v2_x_f", GenSpec().map("x", 16, 2).seq_bias(7.0).cells_float().gen()) + .add("v3_y", GenSpec().map("y", 10, 1).gen()) + .add("v4_xd", GenSpec().idx("x", 10).gen()) + .add("m1_xy", GenSpec().map("x", 32, 1).map("y", 16, 2).seq_bias(3.0).gen()) + .add("m2_xy", GenSpec().map("x", 16, 2).map("y", 32, 1).seq_bias(7.0).gen()) + .add("m3_xym", GenSpec().map("x", 8, 1).idx("y", 5).gen()); +} +EvalFixture::ParamRepo param_repo = make_params(); + +void assert_optimized(const vespalib::string &expr) { + EvalFixture fast_fixture(prod_factory, expr, param_repo, true); + EvalFixture test_fixture(test_factory, expr, param_repo, true); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 1u); + EXPECT_EQ(test_fixture.find_all<SparseDotProductFunction>().size(), 1u); + EXPECT_EQ(slow_fixture.find_all<SparseDotProductFunction>().size(), 0u); +} + +void assert_not_optimized(const vespalib::string &expr) { + EvalFixture fast_fixture(prod_factory, expr, param_repo, true); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 0u); +} + +//----------------------------------------------------------------------------- + +TEST(SparseDotProduct, expression_can_be_optimized) +{ + assert_optimized("reduce(v1_x*v2_x,sum,x)"); + assert_optimized("reduce(v2_x*v1_x,sum)"); + assert_optimized("reduce(v1_x*v2_x_f,sum)"); + assert_optimized("reduce(v1_x_f*v2_x,sum)"); + assert_optimized("reduce(v1_x_f*v2_x_f,sum)"); +} + +TEST(SparseDotProduct, multi_dimensional_expression_can_be_optimized) +{ + assert_optimized("reduce(m1_xy*m2_xy,sum,x,y)"); + assert_optimized("reduce(m1_xy*m2_xy,sum)"); +} + +TEST(SparseDotProduct, embedded_dot_product_is_not_optimized) +{ + assert_not_optimized("reduce(m1_xy*v1_x,sum,x)"); + assert_not_optimized("reduce(v1_x*m1_xy,sum,x)"); +} + +TEST(SparseDotProduct, similar_expressions_are_not_optimized) +{ + assert_not_optimized("reduce(m1_xy*v1_x,sum)"); + assert_not_optimized("reduce(v1_x*v3_y,sum)"); + assert_not_optimized("reduce(v2_x*v1_x,max)"); + assert_not_optimized("reduce(v2_x+v1_x,sum)"); + assert_not_optimized("reduce(v4_xd*v4_xd,sum)"); + assert_not_optimized("reduce(m3_xym*m3_xym,sum)"); +} + +//----------------------------------------------------------------------------- + +GTEST_MAIN_RUN_ALL_TESTS() |