aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp')
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp52
1 files changed, 49 insertions, 3 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 5d933cb1285..e8c83b8548a 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -6,6 +6,7 @@
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/tensor.h>
+#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/common/feature.h>
#include <vespa/searchlib/fef/matchdata.h>
#include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h>
@@ -23,6 +24,7 @@ LOG_SETUP("nearest_neighbor_test");
using search::feature_t;
using search::tensor::DenseTensorAttribute;
using search::AttributeVector;
+using search::BitVector;
using vespalib::eval::ValueType;
using CellType = vespalib::eval::ValueType::CellType;
using vespalib::eval::TensorSpec;
@@ -65,13 +67,15 @@ struct Fixture
vespalib::string _typeSpec;
std::shared_ptr<DenseTensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
+ std::unique_ptr<BitVector> _global_filter;
Fixture(const vespalib::string &typeSpec)
: _cfg(BasicType::TENSOR, CollectionType::SINGLE),
_name("test"),
_typeSpec(typeSpec),
_tensorAttr(),
- _attr()
+ _attr(),
+ _global_filter()
{
_cfg.setTensorType(ValueType::from_spec(typeSpec));
_tensorAttr = makeAttr();
@@ -93,6 +97,15 @@ struct Fixture
}
}
+ void setFilter(std::vector<uint32_t> docids) {
+ uint32_t sz = _attr->getNumDocs();
+ _global_filter = BitVector::create(sz);
+ for (uint32_t id : docids) {
+ EXPECT_LESS(id, sz);
+ _global_filter->setBit(id);
+ }
+ }
+
void setTensor(uint32_t docId, const Tensor &tensor) {
ensureSpace(docId);
_tensorAttr->setTensor(docId, tensor);
@@ -119,7 +132,8 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun());
+ const BitVector *filter = env._global_filter.get();
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun());
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
} else {
@@ -158,13 +172,45 @@ TEST("require that NearestNeighborIterator returns expected results") {
TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat));
}
+void
+verify_iterator_returns_filtered_results(const vespalib::string& attribute_tensor_type_spec,
+ const vespalib::string& query_tensor_type_spec)
+{
+ Fixture fixture(attribute_tensor_type_spec);
+ fixture.ensureSpace(6);
+ fixture.setFilter({1,3,4});
+ fixture.setTensor(1, 3.0, 4.0);
+ fixture.setTensor(2, 6.0, 8.0);
+ fixture.setTensor(3, 5.0, 12.0);
+ fixture.setTensor(4, 4.0, 3.0);
+ fixture.setTensor(5, 8.0, 6.0);
+ fixture.setTensor(6, 4.0, 3.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);
+ SimpleResult result = find_matches<true>(fixture, *nullTensor);
+ SimpleResult nullExpect({1,3,4});
+ EXPECT_EQUAL(result, nullExpect);
+ result = find_matches<false>(fixture, *nullTensor);
+ EXPECT_EQUAL(result, nullExpect);
+ auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0);
+ SimpleResult farExpect({1,3,4});
+ result = find_matches<true>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
+ result = find_matches<false>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
+}
+
+TEST("require that NearestNeighborIterator returns filtered results") {
+ TEST_DO(verify_iterator_returns_filtered_results(denseSpecDouble, denseSpecDouble));
+ TEST_DO(verify_iterator_returns_filtered_results(denseSpecFloat, denseSpecFloat));
+}
+
template <bool strict>
std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun());
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);