diff options
Diffstat (limited to 'searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp')
-rw-r--r-- | searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp | 52 |
1 files changed, 42 insertions, 10 deletions
diff --git a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp index a32d52d06a0..72c9f9db165 100644 --- a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp +++ b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp @@ -86,8 +86,11 @@ struct FixtureBase : ImportedAttributeFixture { const vespalib::string& vector, DocId doc_id, const vespalib::string& shared_param = "") { + check_executions<int32_t>([this](auto int_type){ this->setup_integer_mappings(int_type); }, + {{BasicType::INT32}}, + expected, vector, doc_id, shared_param); check_executions<int64_t>([this](auto int_type){ this->setup_integer_mappings(int_type); }, - {{BasicType::INT32, BasicType::INT64}}, + {{BasicType::INT64}}, expected, vector, doc_id, shared_param); } }; @@ -95,21 +98,47 @@ struct FixtureBase : ImportedAttributeFixture { struct ArrayFixture : FixtureBase { ~ArrayFixture() override; - void setup_integer_mappings(BasicType int_type) override { - reset_with_array_value_reference_mappings<IntegerAttribute, int64_t>( + template <typename T> + void setup_integer_mappings_helper(BasicType int_type) { + reset_with_array_value_reference_mappings<IntegerAttribute, T>( int_type, {{DocId(1), dummy_gid(3), DocId(3), {{2, 3, 5}}}, {DocId(3), dummy_gid(7), DocId(7), {{7, 11}}}, {DocId(5), dummy_gid(8), DocId(8), {{13, 17, 19, 23}}}}); } + void setup_integer_mappings(BasicType int_type) override { + switch (int_type.type()) { + case BasicType::INT32: + setup_integer_mappings_helper<int32_t>(int_type); + break; + case BasicType::INT64: + setup_integer_mappings_helper<int64_t>(int_type); + break; + default: + TEST_FATAL("unexpected integer type"); + } + } - void setup_float_mappings(BasicType float_type) { - reset_with_array_value_reference_mappings<FloatingPointAttribute, double>( + template <typename T> + void setup_float_mappings_helper(BasicType float_type) { + reset_with_array_value_reference_mappings<FloatingPointAttribute, T>( float_type, {{DocId(2), dummy_gid(4), DocId(4), {{2.2, 3.3, 5.5}}}, {DocId(4), dummy_gid(8), DocId(8), {{7.7, 11.11}}}, {DocId(6), dummy_gid(9), DocId(9), {{13.1, 17.2, 19.3, 23.4}}}}); } + void setup_float_mappings(BasicType float_type) { + switch(float_type.type()) { + case BasicType::FLOAT: + setup_float_mappings_helper<float>(float_type); + break; + case BasicType::DOUBLE: + setup_float_mappings_helper<double>(float_type); + break; + default: + TEST_FATAL("unexpected float type"); + } + } template <typename ExpectedType> void check_prepare_state_output(const vespalib::eval::Value & tensor, const ExpectedType & expected) { @@ -163,8 +192,11 @@ struct ArrayFixture : FixtureBase { void check_all_float_executions(feature_t expected, const vespalib::string& vector, DocId doc_id, const vespalib::string& shared_param = "") { + check_executions<float>([this](auto float_type){ this->setup_float_mappings(float_type); }, + {{BasicType::FLOAT}}, + expected, vector, doc_id, shared_param); check_executions<double>([this](auto float_type){ this->setup_float_mappings(float_type); }, - {{BasicType::FLOAT, BasicType::DOUBLE}}, + {{BasicType::DOUBLE}}, expected, vector, doc_id, shared_param); } }; @@ -187,9 +219,9 @@ TEST_F("Zero-length float/double array query vector evaluates to zero", ArrayFix f.check_all_float_executions(0, "[]", DocId(1)); } -TEST_F("prepareSharedState emits i64 vector for i32 imported attribute", ArrayFixture) { +TEST_F("prepareSharedState emits i32 vector for i32 imported attribute", ArrayFixture) { f.setup_integer_mappings(BasicType::INT32); - f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303})); + f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int32_t>({101, 202, 303})); } TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFixture) { @@ -197,9 +229,9 @@ TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFi f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303})); } -TEST_F("prepareSharedState emits double vector for float imported attribute", ArrayFixture) { +TEST_F("prepareSharedState emits float vector for float imported attribute", ArrayFixture) { f.setup_float_mappings(BasicType::FLOAT); - f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3})); + f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<float>({10.1, 20.2, 30.3})); } TEST_F("prepareSharedState emits double vector for double imported attribute", ArrayFixture) { |