aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/instruction/sum_max_dot_product_function
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-02-02 10:03:07 +0000
committerArne Juul <arnej@verizonmedia.com>2021-02-02 10:03:59 +0000
commitadd5dcb82527b280208a9526a0dd21ba3f01e271 (patch)
treedd9393ce9f193b8a22cdffc82386ee7b2dd9809e /eval/src/tests/instruction/sum_max_dot_product_function
parentc576a4373cff207e0b4699ca86c1edbed1b68ab1 (diff)
use GenSpec in sum_max_dot_product_function_test
Diffstat (limited to 'eval/src/tests/instruction/sum_max_dot_product_function')
-rw-r--r--eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp51
1 files changed, 29 insertions, 22 deletions
diff --git a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
index 4b89f30d879..616649e914b 100644
--- a/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
+++ b/eval/src/tests/instruction/sum_max_dot_product_function/sum_max_dot_product_function_test.cpp
@@ -3,7 +3,7 @@
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
-#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/sum_max_dot_product_function.h>
#include <vespa/vespalib/gtest/gtest.h>
@@ -13,12 +13,6 @@ using namespace vespalib::eval::test;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-struct MyVecSeq : Sequence {
- double bias;
- double operator[](size_t i) const override { return (i + bias); }
- MyVecSeq(double cellBias) : bias(cellBias) {}
-};
-
//-----------------------------------------------------------------------------
vespalib::string main_expr = "reduce(reduce(reduce(a*b,sum,z),max,y),sum,x)";
@@ -34,7 +28,7 @@ void assert_optimized(const TensorSpec &a, const TensorSpec &b, size_t dp_size)
auto info = fast_fixture.find_all<SumMaxDotProductFunction>();
ASSERT_EQ(info.size(), 1u);
EXPECT_TRUE(info[0]->result_is_mutable());
- EXPECT_EQUAL(info[0]->dp_size(), dp_size);
+ EXPECT_EQ(info[0]->dp_size(), dp_size);
}
void assert_not_optimized(const TensorSpec &a, const TensorSpec &b, const vespalib::string &expr = main_expr) {
@@ -51,10 +45,23 @@ void assert_not_optimized(const TensorSpec &a, const TensorSpec &b, const vespal
//-----------------------------------------------------------------------------
-auto query = spec(float_cells({x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5));
-auto document = spec(float_cells({y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5));
-auto empty_query = spec(float_cells({x({}),z(5)}), MyVecSeq(0.5));
-auto empty_document = spec(float_cells({y({}),z(5)}), MyVecSeq(2.5));
+GenSpec QueGen(size_t x_size, size_t z_size) { return GenSpec().cells_float().map("x", x_size).idx("z", z_size).seq_bias(0.5); }
+
+GenSpec DocGen(size_t y_size, size_t z_size) { return GenSpec().cells_float().map("y", y_size).idx("z", z_size).seq_bias(2.5); }
+
+GenSpec Que() { return QueGen(3, 5); }
+GenSpec Doc() { return DocGen(6, 5); }
+
+GenSpec QueX0() { return QueGen(0, 5); }
+GenSpec DocX0() { return DocGen(0, 5); }
+
+GenSpec QueZ1() { return QueGen(3, 1); }
+GenSpec DocZ1() { return DocGen(6, 1); }
+
+auto query = Que().gen();
+auto document = Doc().gen();
+auto empty_query = QueX0().gen();
+auto empty_document = DocX0().gen();
TEST(SumMaxDotProduct, expressions_can_be_optimized)
{
@@ -66,24 +73,24 @@ TEST(SumMaxDotProduct, expressions_can_be_optimized)
}
TEST(SumMaxDotProduct, double_cells_are_not_optimized) {
- auto double_query = spec({x({"0", "1", "2"}),z(5)}, MyVecSeq(0.5));
- auto double_document = spec({y({"0", "1", "2", "3", "4", "5"}),z(5)}, MyVecSeq(2.5));
+ auto double_query = Que().cells_double().gen();
+ auto double_document = Doc().cells_double().gen();
assert_not_optimized(query, double_document);
assert_not_optimized(double_query, document);
assert_not_optimized(double_query, double_document);
}
TEST(SumMaxDotProduct, trivial_dot_product_is_not_optimized) {
- auto trivial_query = spec(float_cells({x({"0", "1", "2"}),z(1)}), MyVecSeq(0.5));
- auto trivial_document = spec(float_cells({y({"0", "1", "2", "3", "4", "5"}),z(1)}), MyVecSeq(2.5));
+ auto trivial_query = QueZ1().gen();
+ auto trivial_document = DocZ1().gen();
assert_not_optimized(trivial_query, trivial_document);
}
TEST(SumMaxDotProduct, additional_dimensions_are_not_optimized) {
- auto extra_sparse_query = spec(float_cells({Domain("a", {"0"}),x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5));
- auto extra_dense_query = spec(float_cells({Domain("a", 1),x({"0", "1", "2"}),z(5)}), MyVecSeq(0.5));
- auto extra_sparse_document = spec(float_cells({Domain("a", {"0"}),y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5));
- auto extra_dense_document = spec(float_cells({Domain("a", 1),y({"0", "1", "2", "3", "4", "5"}),z(5)}), MyVecSeq(2.5));
+ auto extra_sparse_query = Que().map("a", 1).gen();
+ auto extra_dense_query = Que().idx("a", 1).gen();
+ auto extra_sparse_document = Doc().map("a", 1).gen();
+ auto extra_dense_document = Doc().idx("a", 1).gen();
vespalib::string extra_sum_expr = "reduce(reduce(reduce(a*b,sum,z),max,y),sum,a,x)";
vespalib::string extra_max_expr = "reduce(reduce(reduce(a*b,sum,z),max,a,y),sum,x)";
assert_not_optimized(extra_sparse_query, document);
@@ -97,8 +104,8 @@ TEST(SumMaxDotProduct, additional_dimensions_are_not_optimized) {
}
TEST(SumMaxDotProduct, more_dense_variants_are_not_optimized) {
- auto dense_query = spec(float_cells({x(3),z(5)}), MyVecSeq(0.5));
- auto dense_document = spec(float_cells({y(5),z(5)}), MyVecSeq(2.5));
+ auto dense_query = GenSpec().cells_float().idx("x", 3).idx("z", 5).seq_bias(0.5).gen();
+ auto dense_document = GenSpec().cells_float().idx("y", 5).idx("z", 5).seq_bias(2.5).gen();
assert_not_optimized(dense_query, document);
assert_not_optimized(query, dense_document);
assert_not_optimized(dense_query, dense_document);