summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp')
-rw-r--r--searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp52
1 files changed, 42 insertions, 10 deletions
diff --git a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
index a32d52d06a0..72c9f9db165 100644
--- a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
+++ b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
@@ -86,8 +86,11 @@ struct FixtureBase : ImportedAttributeFixture {
const vespalib::string& vector,
DocId doc_id,
const vespalib::string& shared_param = "") {
+ check_executions<int32_t>([this](auto int_type){ this->setup_integer_mappings(int_type); },
+ {{BasicType::INT32}},
+ expected, vector, doc_id, shared_param);
check_executions<int64_t>([this](auto int_type){ this->setup_integer_mappings(int_type); },
- {{BasicType::INT32, BasicType::INT64}},
+ {{BasicType::INT64}},
expected, vector, doc_id, shared_param);
}
};
@@ -95,21 +98,47 @@ struct FixtureBase : ImportedAttributeFixture {
struct ArrayFixture : FixtureBase {
~ArrayFixture() override;
- void setup_integer_mappings(BasicType int_type) override {
- reset_with_array_value_reference_mappings<IntegerAttribute, int64_t>(
+ template <typename T>
+ void setup_integer_mappings_helper(BasicType int_type) {
+ reset_with_array_value_reference_mappings<IntegerAttribute, T>(
int_type,
{{DocId(1), dummy_gid(3), DocId(3), {{2, 3, 5}}},
{DocId(3), dummy_gid(7), DocId(7), {{7, 11}}},
{DocId(5), dummy_gid(8), DocId(8), {{13, 17, 19, 23}}}});
}
+ void setup_integer_mappings(BasicType int_type) override {
+ switch (int_type.type()) {
+ case BasicType::INT32:
+ setup_integer_mappings_helper<int32_t>(int_type);
+ break;
+ case BasicType::INT64:
+ setup_integer_mappings_helper<int64_t>(int_type);
+ break;
+ default:
+ TEST_FATAL("unexpected integer type");
+ }
+ }
- void setup_float_mappings(BasicType float_type) {
- reset_with_array_value_reference_mappings<FloatingPointAttribute, double>(
+ template <typename T>
+ void setup_float_mappings_helper(BasicType float_type) {
+ reset_with_array_value_reference_mappings<FloatingPointAttribute, T>(
float_type,
{{DocId(2), dummy_gid(4), DocId(4), {{2.2, 3.3, 5.5}}},
{DocId(4), dummy_gid(8), DocId(8), {{7.7, 11.11}}},
{DocId(6), dummy_gid(9), DocId(9), {{13.1, 17.2, 19.3, 23.4}}}});
}
+ void setup_float_mappings(BasicType float_type) {
+ switch(float_type.type()) {
+ case BasicType::FLOAT:
+ setup_float_mappings_helper<float>(float_type);
+ break;
+ case BasicType::DOUBLE:
+ setup_float_mappings_helper<double>(float_type);
+ break;
+ default:
+ TEST_FATAL("unexpected float type");
+ }
+ }
template <typename ExpectedType>
void check_prepare_state_output(const vespalib::eval::Value & tensor, const ExpectedType & expected) {
@@ -163,8 +192,11 @@ struct ArrayFixture : FixtureBase {
void check_all_float_executions(feature_t expected, const vespalib::string& vector,
DocId doc_id, const vespalib::string& shared_param = "")
{
+ check_executions<float>([this](auto float_type){ this->setup_float_mappings(float_type); },
+ {{BasicType::FLOAT}},
+ expected, vector, doc_id, shared_param);
check_executions<double>([this](auto float_type){ this->setup_float_mappings(float_type); },
- {{BasicType::FLOAT, BasicType::DOUBLE}},
+ {{BasicType::DOUBLE}},
expected, vector, doc_id, shared_param);
}
};
@@ -187,9 +219,9 @@ TEST_F("Zero-length float/double array query vector evaluates to zero", ArrayFix
f.check_all_float_executions(0, "[]", DocId(1));
}
-TEST_F("prepareSharedState emits i64 vector for i32 imported attribute", ArrayFixture) {
+TEST_F("prepareSharedState emits i32 vector for i32 imported attribute", ArrayFixture) {
f.setup_integer_mappings(BasicType::INT32);
- f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
+ f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int32_t>({101, 202, 303}));
}
TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFixture) {
@@ -197,9 +229,9 @@ TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFi
f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
}
-TEST_F("prepareSharedState emits double vector for float imported attribute", ArrayFixture) {
+TEST_F("prepareSharedState emits float vector for float imported attribute", ArrayFixture) {
f.setup_float_mappings(BasicType::FLOAT);
- f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+ f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<float>({10.1, 20.2, 30.3}));
}
TEST_F("prepareSharedState emits double vector for double imported attribute", ArrayFixture) {