summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-12-05 10:04:52 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-12-05 13:30:46 +0000
commitf093b271f1f6aafa37079a889ae5d621db275dcb (patch)
tree840482cbb42ecffb42bdef1bea0d7647cf25984f /searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
parentc6484309cfe178a5d2610405460cfb0d4a89db4c (diff)
Allow nearest neighbor operator where attribute tensor and query tensor have different cell types (float vs double).
Diffstat (limited to 'searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp')
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp43
1 files changed, 32 insertions, 11 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 25ff459c005..7bc582ab442 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -28,7 +28,8 @@ using vespalib::tensor::DefaultTensorEngine;
using namespace search::fef;
using namespace search::queryeval;
-vespalib::string denseSpec("tensor(x[2])");
+vespalib::string denseSpecDouble("tensor(x[2])");
+vespalib::string denseSpecFloat("tensor<float>(x[2])");
std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) {
auto value = DefaultTensorEngine::ref().from_spec(spec);
@@ -38,8 +39,8 @@ std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) {
return std::unique_ptr<DenseTensorView>(tensor);
}
-std::unique_ptr<DenseTensorView> createTensor(double v1, double v2) {
- return createTensor(TensorSpec(denseSpec).add({{"x", 0}}, v1)
+std::unique_ptr<DenseTensorView> createTensor(const vespalib::string& type_spec, double v1, double v2) {
+ return createTensor(TensorSpec(type_spec).add({{"x", 0}}, v1)
.add({{"x", 1}}, v2));
}
@@ -89,7 +90,7 @@ struct Fixture
}
void setTensor(uint32_t docId, double v1, double v2) {
- auto t = createTensor(v1, v2);
+ auto t = createTensor(_typeSpec, v1, v2);
setTensor(docId, *t);
}
};
@@ -108,8 +109,11 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
}
}
-TEST("require that NearestNeighborIterator returns expected results") {
- Fixture fixture(denseSpec);
+void
+verify_iterator_returns_expected_results(const vespalib::string& attribute_tensor_type_spec,
+ const vespalib::string& query_tensor_type_spec)
+{
+ Fixture fixture(attribute_tensor_type_spec);
fixture.ensureSpace(6);
fixture.setTensor(1, 3.0, 4.0);
fixture.setTensor(2, 6.0, 8.0);
@@ -117,13 +121,13 @@ TEST("require that NearestNeighborIterator returns expected results") {
fixture.setTensor(4, 4.0, 3.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- auto nullTensor = createTensor(0.0, 0.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);
SimpleResult result = find_matches<true>(fixture, *nullTensor);
SimpleResult nullExpect({1,2,4,6});
EXPECT_EQUAL(result, nullExpect);
result = find_matches<false>(fixture, *nullTensor);
EXPECT_EQUAL(result, nullExpect);
- auto farTensor = createTensor(9.0, 9.0);
+ auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0);
SimpleResult farExpect({1,2,3,5});
result = find_matches<true>(fixture, *farTensor);
EXPECT_EQUAL(result, farExpect);
@@ -131,6 +135,13 @@ TEST("require that NearestNeighborIterator returns expected results") {
EXPECT_EQUAL(result, farExpect);
}
+TEST("require that NearestNeighborIterator returns expected results") {
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecDouble));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecFloat));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecDouble));
+}
+
template <bool strict>
std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
@@ -152,8 +163,11 @@ std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
return rv;
}
-TEST("require that NearestNeighborIterator sets expected rawscore") {
- Fixture fixture(denseSpec);
+void
+verify_iterator_sets_expected_rawscore(const vespalib::string& attribute_tensor_type_spec,
+ const vespalib::string& query_tensor_type_spec)
+{
+ Fixture fixture(attribute_tensor_type_spec);
fixture.ensureSpace(6);
fixture.setTensor(1, 3.0, 4.0);
fixture.setTensor(2, 5.0, 12.0);
@@ -161,7 +175,7 @@ TEST("require that NearestNeighborIterator sets expected rawscore") {
fixture.setTensor(4, 5.0, 12.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- auto nullTensor = createTensor(0.0, 0.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);
std::vector<feature_t> got = get_rawscores<true>(fixture, *nullTensor);
std::vector<feature_t> expected{5.0, 13.0, 10.0, 10.0, 5.0};
EXPECT_EQUAL(got, expected);
@@ -169,4 +183,11 @@ TEST("require that NearestNeighborIterator sets expected rawscore") {
EXPECT_EQUAL(got, expected);
}
+TEST("require that NearestNeighborIterator sets expected rawscore") {
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecDouble));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecFloat));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecFloat));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecDouble));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }