diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-02-02 10:03:07 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-02-02 10:03:59 +0000 |
commit | add5dcb82527b280208a9526a0dd21ba3f01e271 (patch) | |
tree | dd9393ce9f193b8a22cdffc82386ee7b2dd9809e /eval/src/tests/instruction/sum_max_dot_product_function | |
parent | c576a4373cff207e0b4699ca86c1edbed1b68ab1 (diff) |
use GenSpec in sum_max_dot_product_function_test
Diffstat (limited to 'eval/src/tests/instruction/sum_max_dot_product_function')
-rw-r--r-- | eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp | 51 |
1 files changed, 29 insertions, 22 deletions
diff --git a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp index 4b89f30d879..616649e914b 100644 --- a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp +++ b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp @@ -3,7 +3,7 @@ #include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/tensor_function.h> #include <vespa/eval/eval/test/eval_fixture.h> -#include <vespa/eval/eval/test/tensor_model.hpp> +#include <vespa/eval/eval/test/gen_spec.h> #include <vespa/eval/instruction/sum_max_dot_product_function.h> #include <vespa/vespalib/gtest/gtest.h> @@ -13,12 +13,6 @@ using namespace vespalib::eval::test; const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); -struct MyVecSeq : Sequence { - double bias; - double operator[](size_t i) const override { return (i + bias); } - MyVecSeq(double cellBias) : bias(cellBias) {} -}; - //----------------------------------------------------------------------------- vespalib::string main_expr = "reduce(reduce(reduce(a*b,sum,z),max,y),sum,x)"; @@ -34,7 +28,7 @@ void assert_optimized(const TensorSpec &a, const TensorSpec &b, size_t dp_size) auto info = fast_fixture.find_all<SumMaxDotProductFunction>(); ASSERT_EQ(info.size(), 1u); EXPECT_TRUE(info[0]->result_is_mutable()); - EXPECT_EQUAL(info[0]->dp_size(), dp_size); + EXPECT_EQ(info[0]->dp_size(), dp_size); } void assert_not_optimized(const TensorSpec &a, const TensorSpec &b, const vespalib::string &expr = main_expr) { @@ -51,10 +45,23 @@ void assert_not_optimized(const TensorSpec &a, const TensorSpec &b, const vespal //----------------------------------------------------------------------------- -auto query = spec(float_cells({x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5)); -auto document = spec(float_cells({y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5)); -auto empty_query = spec(float_cells({x({}),z(5)}), MyVecSeq(0.5)); -auto empty_document = spec(float_cells({y({}),z(5)}), MyVecSeq(2.5)); +GenSpec QueGen(size_t x_size, size_t z_size) { return GenSpec().cells_float().map("x", x_size).idx("z", z_size).seq_bias(0.5); } + +GenSpec DocGen(size_t y_size, size_t z_size) { return GenSpec().cells_float().map("y", y_size).idx("z", z_size).seq_bias(2.5); } + +GenSpec Que() { return QueGen(3, 5); } +GenSpec Doc() { return DocGen(6, 5); } + +GenSpec QueX0() { return QueGen(0, 5); } +GenSpec DocX0() { return DocGen(0, 5); } + +GenSpec QueZ1() { return QueGen(3, 1); } +GenSpec DocZ1() { return DocGen(6, 1); } + +auto query = Que().gen(); +auto document = Doc().gen(); +auto empty_query = QueX0().gen(); +auto empty_document = DocX0().gen(); TEST(SumMaxDotProduct, expressions_can_be_optimized) { @@ -66,24 +73,24 @@ TEST(SumMaxDotProduct, expressions_can_be_optimized) } TEST(SumMaxDotProduct, double_cells_are_not_optimized) { - auto double_query = spec({x({"0", "1", "2"}),z(5)}, MyVecSeq(0.5)); - auto double_document = spec({y({"0", "1", "2", "3", "4", "5"}),z(5)}, MyVecSeq(2.5)); + auto double_query = Que().cells_double().gen(); + auto double_document = Doc().cells_double().gen(); assert_not_optimized(query, double_document); assert_not_optimized(double_query, document); assert_not_optimized(double_query, double_document); } TEST(SumMaxDotProduct, trivial_dot_product_is_not_optimized) { - auto trivial_query = spec(float_cells({x({"0", "1", "2"}),z(1)}), MyVecSeq(0.5)); - auto trivial_document = spec(float_cells({y({"0", "1", "2", "3", "4", "5"}),z(1)}), MyVecSeq(2.5)); + auto trivial_query = QueZ1().gen(); + auto trivial_document = DocZ1().gen(); assert_not_optimized(trivial_query, trivial_document); } TEST(SumMaxDotProduct, additional_dimensions_are_not_optimized) { - auto extra_sparse_query = spec(float_cells({Domain("a", {"0"}),x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5)); - auto extra_dense_query = spec(float_cells({Domain("a", 1),x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5)); - auto extra_sparse_document = spec(float_cells({Domain("a", {"0"}),y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5)); - auto extra_dense_document = spec(float_cells({Domain("a", 1),y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5)); + auto extra_sparse_query = Que().map("a", 1).gen(); + auto extra_dense_query = Que().idx("a", 1).gen(); + auto extra_sparse_document = Doc().map("a", 1).gen(); + auto extra_dense_document = Doc().idx("a", 1).gen(); vespalib::string extra_sum_expr = "reduce(reduce(reduce(a*b,sum,z),max,y),sum,a,x)"; vespalib::string extra_max_expr = "reduce(reduce(reduce(a*b,sum,z),max,a,y),sum,x)"; assert_not_optimized(extra_sparse_query, document); @@ -97,8 +104,8 @@ TEST(SumMaxDotProduct, additional_dimensions_are_not_optimized) { } TEST(SumMaxDotProduct, more_dense_variants_are_not_optimized) { - auto dense_query = spec(float_cells({x(3),z(5)}), MyVecSeq(0.5)); - auto dense_document = spec(float_cells({y(5),z(5)}), MyVecSeq(2.5)); + auto dense_query = GenSpec().cells_float().idx("x", 3).idx("z", 5).seq_bias(0.5).gen(); + auto dense_document = GenSpec().cells_float().idx("y", 5).idx("z", 5).seq_bias(2.5).gen(); assert_not_optimized(dense_query, document); assert_not_optimized(query, dense_document); assert_not_optimized(dense_query, dense_document); |