diff options
Diffstat (limited to 'searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp')
-rw-r--r-- | searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp | 52 |
1 files changed, 49 insertions, 3 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 5d933cb1285..e8c83b8548a 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -6,6 +6,7 @@ #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/eval/tensor/dense/dense_tensor.h> #include <vespa/eval/tensor/tensor.h> +#include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/common/feature.h> #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h> @@ -23,6 +24,7 @@ LOG_SETUP("nearest_neighbor_test"); using search::feature_t; using search::tensor::DenseTensorAttribute; using search::AttributeVector; +using search::BitVector; using vespalib::eval::ValueType; using CellType = vespalib::eval::ValueType::CellType; using vespalib::eval::TensorSpec; @@ -65,13 +67,15 @@ struct Fixture vespalib::string _typeSpec; std::shared_ptr<DenseTensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; + std::unique_ptr<BitVector> _global_filter; Fixture(const vespalib::string &typeSpec) : _cfg(BasicType::TENSOR, CollectionType::SINGLE), _name("test"), _typeSpec(typeSpec), _tensorAttr(), - _attr() + _attr(), + _global_filter() { _cfg.setTensorType(ValueType::from_spec(typeSpec)); _tensorAttr = makeAttr(); @@ -93,6 +97,15 @@ struct Fixture } } + void setFilter(std::vector<uint32_t> docids) { + uint32_t sz = _attr->getNumDocs(); + _global_filter = BitVector::create(sz); + for (uint32_t id : docids) { + EXPECT_LESS(id, sz); + _global_filter->setBit(id); + } + } + void setTensor(uint32_t docId, const Tensor &tensor) { ensureSpace(docId); _tensorAttr->setTensor(docId, tensor); @@ -119,7 +132,8 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) { auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun()); + const BitVector *filter = env._global_filter.get(); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); } else { @@ -158,13 +172,45 @@ TEST("require that NearestNeighborIterator returns expected results") { TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat)); } +void +verify_iterator_returns_filtered_results(const vespalib::string& attribute_tensor_type_spec, + const vespalib::string& query_tensor_type_spec) +{ + Fixture fixture(attribute_tensor_type_spec); + fixture.ensureSpace(6); + fixture.setFilter({1,3,4}); + fixture.setTensor(1, 3.0, 4.0); + fixture.setTensor(2, 6.0, 8.0); + fixture.setTensor(3, 5.0, 12.0); + fixture.setTensor(4, 4.0, 3.0); + fixture.setTensor(5, 8.0, 6.0); + fixture.setTensor(6, 4.0, 3.0); + auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0); + SimpleResult result = find_matches<true>(fixture, *nullTensor); + SimpleResult nullExpect({1,3,4}); + EXPECT_EQUAL(result, nullExpect); + result = find_matches<false>(fixture, *nullTensor); + EXPECT_EQUAL(result, nullExpect); + auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0); + SimpleResult farExpect({1,3,4}); + result = find_matches<true>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); + result = find_matches<false>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); +} + +TEST("require that NearestNeighborIterator returns filtered results") { + TEST_DO(verify_iterator_returns_filtered_results(denseSpecDouble, denseSpecDouble)); + TEST_DO(verify_iterator_returns_filtered_results(denseSpecFloat, denseSpecFloat)); +} + template <bool strict> std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun()); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); |