aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp35
1 files changed, 22 insertions, 13 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index 68c6a1603d0..07e10271c55 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "nearest_neighbor_iterator.h"
+#include <vespa/searchlib/common/bitvector.h>
using search::tensor::DenseTensorAttribute;
using vespalib::ConstArrayRef;
@@ -29,7 +30,7 @@ is_compatible(const vespalib::eval::ValueType& lhs,
* Keeps a heap of the K best hit distances.
* Currently always does brute-force scanning, which is very expensive.
**/
-template <bool strict>
+template <bool strict, bool has_filter>
class NearestNeighborImpl : public NearestNeighborIterator
{
public:
@@ -48,11 +49,13 @@ public:
void doSeek(uint32_t docId) override {
double distanceLimit = params().distanceHeap.distanceLimit();
while (__builtin_expect((docId < getEndId()), true)) {
- double d = computeDistance(docId, distanceLimit);
- if (d <= distanceLimit) {
- _lastScore = d;
- setDocId(docId);
- return;
+ if ((!has_filter) || params().filter->testBit(docId)) {
+ double d = computeDistance(docId, distanceLimit);
+ if (d <= distanceLimit) {
+ _lastScore = d;
+ setDocId(docId);
+ return;
+ }
}
if (strict) {
++docId;
@@ -83,22 +86,23 @@ private:
double _lastScore;
};
-template <bool strict>
-NearestNeighborImpl<strict>::~NearestNeighborImpl() = default;
+template <bool strict, bool has_filter>
+NearestNeighborImpl<strict, has_filter>::~NearestNeighborImpl() = default;
namespace {
+template <bool has_filter>
std::unique_ptr<NearestNeighborIterator>
-resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params &params)
+resolve_strict(bool strict, const NearestNeighborIterator::Params &params)
{
CellType lct = params.queryTensor.fast_type().cell_type();
CellType rct = params.tensorAttribute.getTensorType().cell_type();
if (lct != rct) abort();
if (strict) {
- using NNI = NearestNeighborImpl<true>;
+ using NNI = NearestNeighborImpl<true, has_filter>;
return std::make_unique<NNI>(params);
} else {
- using NNI = NearestNeighborImpl<false>;
+ using NNI = NearestNeighborImpl<false, has_filter>;
return std::make_unique<NNI>(params);
}
}
@@ -112,11 +116,16 @@ NearestNeighborIterator::create(
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
NearestNeighborDistanceHeap &distanceHeap,
+ const search::BitVector *filter,
const search::tensor::DistanceFunction *dist_fun)
{
- Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, dist_fun);
- return resolve_strict_LCT_RCT(strict, params);
+ Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun);
+ if (filter) {
+ return resolve_strict<true>(strict, params);
+ } else {
+ return resolve_strict<false>(strict, params);
+ }
}
} // namespace