diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-05-12 16:29:12 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-05-12 16:29:12 +0200 |
commit | b43e8090a860bf92b07153a2fd01d95f7fa2e548 (patch) | |
tree | 1ac17c129d4ceb0d53f4d9f6c3ec55586a16345b | |
parent | feb582188627adb3a2e3ede8c61b3e15cabd1d82 (diff) | |
parent | 6007048884ea5c222c0ad88862f7cc12992ac336 (diff) |
Merge pull request #13221 from vespa-engine/arnej/allow-filter-bruteforce-nn
allow filter in bruteforce
4 files changed, 77 insertions, 17 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 5d933cb1285..e8c83b8548a 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -6,6 +6,7 @@ #include <vespa/eval/tensor/default_tensor_engine.h> #include <vespa/eval/tensor/dense/dense_tensor.h> #include <vespa/eval/tensor/tensor.h> +#include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/common/feature.h> #include <vespa/searchlib/fef/matchdata.h> #include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h> @@ -23,6 +24,7 @@ LOG_SETUP("nearest_neighbor_test"); using search::feature_t; using search::tensor::DenseTensorAttribute; using search::AttributeVector; +using search::BitVector; using vespalib::eval::ValueType; using CellType = vespalib::eval::ValueType::CellType; using vespalib::eval::TensorSpec; @@ -65,13 +67,15 @@ struct Fixture vespalib::string _typeSpec; std::shared_ptr<DenseTensorAttribute> _tensorAttr; std::shared_ptr<AttributeVector> _attr; + std::unique_ptr<BitVector> _global_filter; Fixture(const vespalib::string &typeSpec) : _cfg(BasicType::TENSOR, CollectionType::SINGLE), _name("test"), _typeSpec(typeSpec), _tensorAttr(), - _attr() + _attr(), + _global_filter() { _cfg.setTensorType(ValueType::from_spec(typeSpec)); _tensorAttr = makeAttr(); @@ -93,6 +97,15 @@ struct Fixture } } + void setFilter(std::vector<uint32_t> docids) { + uint32_t sz = _attr->getNumDocs(); + _global_filter = BitVector::create(sz); + for (uint32_t id : docids) { + EXPECT_LESS(id, sz); + _global_filter->setBit(id); + } + } + void setTensor(uint32_t docId, const Tensor &tensor) { ensureSpace(docId); _tensorAttr->setTensor(docId, tensor); @@ -119,7 +132,8 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) { auto &tfmd = *(md->resolveTermField(0)); auto &attr = 
*(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun()); + const BitVector *filter = env._global_filter.get(); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); } else { @@ -158,13 +172,45 @@ TEST("require that NearestNeighborIterator returns expected results") { TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat)); } +void +verify_iterator_returns_filtered_results(const vespalib::string& attribute_tensor_type_spec, + const vespalib::string& query_tensor_type_spec) +{ + Fixture fixture(attribute_tensor_type_spec); + fixture.ensureSpace(6); + fixture.setFilter({1,3,4}); + fixture.setTensor(1, 3.0, 4.0); + fixture.setTensor(2, 6.0, 8.0); + fixture.setTensor(3, 5.0, 12.0); + fixture.setTensor(4, 4.0, 3.0); + fixture.setTensor(5, 8.0, 6.0); + fixture.setTensor(6, 4.0, 3.0); + auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0); + SimpleResult result = find_matches<true>(fixture, *nullTensor); + SimpleResult nullExpect({1,3,4}); + EXPECT_EQUAL(result, nullExpect); + result = find_matches<false>(fixture, *nullTensor); + EXPECT_EQUAL(result, nullExpect); + auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0); + SimpleResult farExpect({1,3,4}); + result = find_matches<true>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); + result = find_matches<false>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); +} + +TEST("require that NearestNeighborIterator returns filtered results") { + TEST_DO(verify_iterator_returns_filtered_results(denseSpecDouble, denseSpecDouble)); + TEST_DO(verify_iterator_returns_filtered_results(denseSpecFloat, denseSpecFloat)); +} + template <bool strict> std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { auto md = 
MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun()); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun()); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 68be4d35972..b0db678dfc6 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -121,7 +121,8 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData return NnsIndexIterator::create(tfmd, _found_hits, _dist_fun); } const vespalib::tensor::DenseTensorView &qT = *_query_tensor; - return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap, _dist_fun); + return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, + _distance_heap, nullptr, _dist_fun); } void diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index 68c6a1603d0..07e10271c55 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -1,6 +1,7 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "nearest_neighbor_iterator.h" +#include <vespa/searchlib/common/bitvector.h> using search::tensor::DenseTensorAttribute; using vespalib::ConstArrayRef; @@ -29,7 +30,7 @@ is_compatible(const vespalib::eval::ValueType& lhs, * Keeps a heap of the K best hit distances. 
* Currently always does brute-force scanning, which is very expensive. **/ -template <bool strict> +template <bool strict, bool has_filter> class NearestNeighborImpl : public NearestNeighborIterator { public: @@ -48,11 +49,13 @@ public: void doSeek(uint32_t docId) override { double distanceLimit = params().distanceHeap.distanceLimit(); while (__builtin_expect((docId < getEndId()), true)) { - double d = computeDistance(docId, distanceLimit); - if (d <= distanceLimit) { - _lastScore = d; - setDocId(docId); - return; + if ((!has_filter) || params().filter->testBit(docId)) { + double d = computeDistance(docId, distanceLimit); + if (d <= distanceLimit) { + _lastScore = d; + setDocId(docId); + return; + } } if (strict) { ++docId; @@ -83,22 +86,23 @@ private: double _lastScore; }; -template <bool strict> -NearestNeighborImpl<strict>::~NearestNeighborImpl() = default; +template <bool strict, bool has_filter> +NearestNeighborImpl<strict, has_filter>::~NearestNeighborImpl() = default; namespace { +template <bool has_filter> std::unique_ptr<NearestNeighborIterator> -resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params &params) +resolve_strict(bool strict, const NearestNeighborIterator::Params &params) { CellType lct = params.queryTensor.fast_type().cell_type(); CellType rct = params.tensorAttribute.getTensorType().cell_type(); if (lct != rct) abort(); if (strict) { - using NNI = NearestNeighborImpl<true>; + using NNI = NearestNeighborImpl<true, has_filter>; return std::make_unique<NNI>(params); } else { - using NNI = NearestNeighborImpl<false>; + using NNI = NearestNeighborImpl<false, has_filter>; return std::make_unique<NNI>(params); } } @@ -112,11 +116,16 @@ NearestNeighborIterator::create( const vespalib::tensor::DenseTensorView &queryTensor, const search::tensor::DenseTensorAttribute &tensorAttribute, NearestNeighborDistanceHeap &distanceHeap, + const search::BitVector *filter, const search::tensor::DistanceFunction *dist_fun) { - Params params(tfmd, 
queryTensor, tensorAttribute, distanceHeap, dist_fun); - return resolve_strict_LCT_RCT(strict, params); + Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun); + if (filter) { + return resolve_strict<true>(strict, params); + } else { + return resolve_strict<false>(strict, params); + } } } // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h index 2a800f96710..9cbb1d39a91 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h @@ -25,17 +25,20 @@ public: const DenseTensorView &queryTensor; const DenseTensorAttribute &tensorAttribute; NearestNeighborDistanceHeap &distanceHeap; + const search::BitVector *filter; const search::tensor::DistanceFunction *distanceFunction; Params(fef::TermFieldMatchData &tfmd_in, const DenseTensorView &queryTensor_in, const DenseTensorAttribute &tensorAttribute_in, NearestNeighborDistanceHeap &distanceHeap_in, + const search::BitVector *filter_in, const search::tensor::DistanceFunction *distanceFunction_in) : tfmd(tfmd_in), queryTensor(queryTensor_in), tensorAttribute(tensorAttribute_in), distanceHeap(distanceHeap_in), + filter(filter_in), distanceFunction(distanceFunction_in) {} }; @@ -50,6 +53,7 @@ public: const vespalib::tensor::DenseTensorView &queryTensor, const search::tensor::DenseTensorAttribute &tensorAttribute, NearestNeighborDistanceHeap &distanceHeap, + const search::BitVector *filter, const search::tensor::DistanceFunction *dist_fun); const Params& params() const { return _params; } |