diff options
Diffstat (limited to 'searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp')
-rw-r--r-- | searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp | 35 |
1 files changed, 22 insertions, 13 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index 68c6a1603d0..07e10271c55 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -1,6 +1,7 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "nearest_neighbor_iterator.h" +#include <vespa/searchlib/common/bitvector.h> using search::tensor::DenseTensorAttribute; using vespalib::ConstArrayRef; @@ -29,7 +30,7 @@ is_compatible(const vespalib::eval::ValueType& lhs, * Keeps a heap of the K best hit distances. * Currently always does brute-force scanning, which is very expensive. **/ -template <bool strict> +template <bool strict, bool has_filter> class NearestNeighborImpl : public NearestNeighborIterator { public: @@ -48,11 +49,13 @@ public: void doSeek(uint32_t docId) override { double distanceLimit = params().distanceHeap.distanceLimit(); while (__builtin_expect((docId < getEndId()), true)) { - double d = computeDistance(docId, distanceLimit); - if (d <= distanceLimit) { - _lastScore = d; - setDocId(docId); - return; + if ((!has_filter) || params().filter->testBit(docId)) { + double d = computeDistance(docId, distanceLimit); + if (d <= distanceLimit) { + _lastScore = d; + setDocId(docId); + return; + } } if (strict) { ++docId; @@ -83,22 +86,23 @@ private: double _lastScore; }; -template <bool strict> -NearestNeighborImpl<strict>::~NearestNeighborImpl() = default; +template <bool strict, bool has_filter> +NearestNeighborImpl<strict, has_filter>::~NearestNeighborImpl() = default; namespace { +template <bool has_filter> std::unique_ptr<NearestNeighborIterator> -resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params ¶ms) +resolve_strict(bool strict, const NearestNeighborIterator::Params ¶ms) { CellType lct = params.queryTensor.fast_type().cell_type(); CellType rct = params.tensorAttribute.getTensorType().cell_type(); if (lct != rct) abort(); if (strict) { - using NNI = NearestNeighborImpl<true>; + using NNI = NearestNeighborImpl<true, has_filter>; return std::make_unique<NNI>(params); } else { - using NNI = NearestNeighborImpl<false>; + using NNI = NearestNeighborImpl<false, has_filter>; return std::make_unique<NNI>(params); } } @@ -112,11 +116,16 @@ NearestNeighborIterator::create( const vespalib::tensor::DenseTensorView &queryTensor, const search::tensor::DenseTensorAttribute &tensorAttribute, NearestNeighborDistanceHeap &distanceHeap, + const search::BitVector *filter, const search::tensor::DistanceFunction *dist_fun) { - Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, dist_fun); - return resolve_strict_LCT_RCT(strict, params); + Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun); + if (filter) { + return resolve_strict<true>(strict, params); + } else { + return resolve_strict<false>(strict, params); + } } } // namespace |