about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Arne H Juul <arnej27959@users.noreply.github.com>, 2020-05-12 16:29:12 +0200
committer: GitHub <noreply@github.com>, 2020-05-12 16:29:12 +0200
commit: b43e8090a860bf92b07153a2fd01d95f7fa2e548 (patch)
tree: 1ac17c129d4ceb0d53f4d9f6c3ec55586a16345b
parent: feb582188627adb3a2e3ede8c61b3e15cabd1d82 (diff)
parent: 6007048884ea5c222c0ad88862f7cc12992ac336 (diff)
Merge pull request #13221 from vespa-engine/arnej/allow-filter-bruteforce-nn
allow filter in bruteforce
-rw-r--r-- searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp 52
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp 3
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp 35
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h 4
4 files changed, 77 insertions, 17 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 5d933cb1285..e8c83b8548a 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -6,6 +6,7 @@
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/tensor/tensor.h>
+#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/common/feature.h>
#include <vespa/searchlib/fef/matchdata.h>
#include <vespa/searchlib/queryeval/nearest_neighbor_iterator.h>
@@ -23,6 +24,7 @@ LOG_SETUP("nearest_neighbor_test");
using search::feature_t;
using search::tensor::DenseTensorAttribute;
using search::AttributeVector;
+using search::BitVector;
using vespalib::eval::ValueType;
using CellType = vespalib::eval::ValueType::CellType;
using vespalib::eval::TensorSpec;
@@ -65,13 +67,15 @@ struct Fixture
vespalib::string _typeSpec;
std::shared_ptr<DenseTensorAttribute> _tensorAttr;
std::shared_ptr<AttributeVector> _attr;
+ std::unique_ptr<BitVector> _global_filter;
Fixture(const vespalib::string &typeSpec)
: _cfg(BasicType::TENSOR, CollectionType::SINGLE),
_name("test"),
_typeSpec(typeSpec),
_tensorAttr(),
- _attr()
+ _attr(),
+ _global_filter()
{
_cfg.setTensorType(ValueType::from_spec(typeSpec));
_tensorAttr = makeAttr();
@@ -93,6 +97,15 @@ struct Fixture
}
}
+ void setFilter(std::vector<uint32_t> docids) {  // Test helper: installs a new _global_filter and sets one bit per docid in `docids`.
+ uint32_t sz = _attr->getNumDocs();  // the bit vector is sized from the attribute's current document count
+ _global_filter = BitVector::create(sz);  // assigning the unique_ptr drops any filter installed earlier
+ for (uint32_t id : docids) {
+ EXPECT_LESS(id, sz);  // NOTE(review): presumably only records a test failure; setBit below still runs when id >= sz -- confirm
+ _global_filter->setBit(id);
+ }
+ }
+
void setTensor(uint32_t docId, const Tensor &tensor) {
ensureSpace(docId);
_tensorAttr->setTensor(docId, tensor);
@@ -119,7 +132,8 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun());
+ const BitVector *filter = env._global_filter.get();
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun());
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
} else {
@@ -158,13 +172,45 @@ TEST("require that NearestNeighborIterator returns expected results") {
TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat));
}
+void  // Shared body of the filtered-results test: checks that find_matches returns exactly the docids set in the fixture's global filter.
+verify_iterator_returns_filtered_results(const vespalib::string& attribute_tensor_type_spec,  // type spec used to construct the attribute Fixture
+ const vespalib::string& query_tensor_type_spec)  // type spec used to create the query tensors
+{
+ Fixture fixture(attribute_tensor_type_spec);
+ fixture.ensureSpace(6);  // done before setFilter, which sizes its bit vector from getNumDocs()
+ fixture.setFilter({1,3,4});  // bits for docids 2, 5 and 6 are not set by this call
+ fixture.setTensor(1, 3.0, 4.0);  // docids 1..6 all get a tensor, including the ones not in the filter
+ fixture.setTensor(2, 6.0, 8.0);
+ fixture.setTensor(3, 5.0, 12.0);
+ fixture.setTensor(4, 4.0, 3.0);
+ fixture.setTensor(5, 8.0, 6.0);
+ fixture.setTensor(6, 4.0, 3.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);  // first query tensor: cell values (0, 0)
+ SimpleResult result = find_matches<true>(fixture, *nullTensor);  // strict = true
+ SimpleResult nullExpect({1,3,4});  // same docids as passed to setFilter
+ EXPECT_EQUAL(result, nullExpect);
+ result = find_matches<false>(fixture, *nullTensor);  // strict = false must give the same hits
+ EXPECT_EQUAL(result, nullExpect);
+ auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0);  // second query tensor: cell values (9, 9)
+ SimpleResult farExpect({1,3,4});  // expectation is unchanged for the second query
+ result = find_matches<true>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
+ result = find_matches<false>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
+}
+
+TEST("require that NearestNeighborIterator returns filtered results") {  // runs the filtered-results check once with denseSpecDouble and once with denseSpecFloat
+ TEST_DO(verify_iterator_returns_filtered_results(denseSpecDouble, denseSpecDouble));  // attribute and query use the same spec
+ TEST_DO(verify_iterator_returns_filtered_results(denseSpecFloat, denseSpecFloat));  // attribute and query use the same spec
+}
+
template <bool strict>
std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, env.dist_fun());
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, nullptr, env.dist_fun());
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 68be4d35972..b0db678dfc6 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -121,7 +121,8 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
return NnsIndexIterator::create(tfmd, _found_hits, _dist_fun);
}
const vespalib::tensor::DenseTensorView &qT = *_query_tensor;
- return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap, _dist_fun);
+ return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor,
+ _distance_heap, nullptr, _dist_fun);
}
void
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index 68c6a1603d0..07e10271c55 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "nearest_neighbor_iterator.h"
+#include <vespa/searchlib/common/bitvector.h>
using search::tensor::DenseTensorAttribute;
using vespalib::ConstArrayRef;
@@ -29,7 +30,7 @@ is_compatible(const vespalib::eval::ValueType& lhs,
* Keeps a heap of the K best hit distances.
* Currently always does brute-force scanning, which is very expensive.
**/
-template <bool strict>
+template <bool strict, bool has_filter>
class NearestNeighborImpl : public NearestNeighborIterator
{
public:
@@ -48,11 +49,13 @@ public:
void doSeek(uint32_t docId) override {
double distanceLimit = params().distanceHeap.distanceLimit();
while (__builtin_expect((docId < getEndId()), true)) {
- double d = computeDistance(docId, distanceLimit);
- if (d <= distanceLimit) {
- _lastScore = d;
- setDocId(docId);
- return;
+ if ((!has_filter) || params().filter->testBit(docId)) {
+ double d = computeDistance(docId, distanceLimit);
+ if (d <= distanceLimit) {
+ _lastScore = d;
+ setDocId(docId);
+ return;
+ }
}
if (strict) {
++docId;
@@ -83,22 +86,23 @@ private:
double _lastScore;
};
-template <bool strict>
-NearestNeighborImpl<strict>::~NearestNeighborImpl() = default;
+template <bool strict, bool has_filter>
+NearestNeighborImpl<strict, has_filter>::~NearestNeighborImpl() = default;
namespace {
+template <bool has_filter>
std::unique_ptr<NearestNeighborIterator>
-resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params &params)
+resolve_strict(bool strict, const NearestNeighborIterator::Params &params)
{
CellType lct = params.queryTensor.fast_type().cell_type();
CellType rct = params.tensorAttribute.getTensorType().cell_type();
if (lct != rct) abort();
if (strict) {
- using NNI = NearestNeighborImpl<true>;
+ using NNI = NearestNeighborImpl<true, has_filter>;
return std::make_unique<NNI>(params);
} else {
- using NNI = NearestNeighborImpl<false>;
+ using NNI = NearestNeighborImpl<false, has_filter>;
return std::make_unique<NNI>(params);
}
}
@@ -112,11 +116,16 @@ NearestNeighborIterator::create(
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
NearestNeighborDistanceHeap &distanceHeap,
+ const search::BitVector *filter,
const search::tensor::DistanceFunction *dist_fun)
{
- Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, dist_fun);
- return resolve_strict_LCT_RCT(strict, params);
+ Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, filter, dist_fun);
+ if (filter) {
+ return resolve_strict<true>(strict, params);
+ } else {
+ return resolve_strict<false>(strict, params);
+ }
}
} // namespace
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
index 2a800f96710..9cbb1d39a91 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
@@ -25,17 +25,20 @@ public:
const DenseTensorView &queryTensor;
const DenseTensorAttribute &tensorAttribute;
NearestNeighborDistanceHeap &distanceHeap;
+ const search::BitVector *filter;
const search::tensor::DistanceFunction *distanceFunction;
Params(fef::TermFieldMatchData &tfmd_in,
const DenseTensorView &queryTensor_in,
const DenseTensorAttribute &tensorAttribute_in,
NearestNeighborDistanceHeap &distanceHeap_in,
+ const search::BitVector *filter_in,
const search::tensor::DistanceFunction *distanceFunction_in)
: tfmd(tfmd_in),
queryTensor(queryTensor_in),
tensorAttribute(tensorAttribute_in),
distanceHeap(distanceHeap_in),
+ filter(filter_in),
distanceFunction(distanceFunction_in)
{}
};
@@ -50,6 +53,7 @@ public:
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
NearestNeighborDistanceHeap &distanceHeap,
+ const search::BitVector *filter,
const search::tensor::DistanceFunction *dist_fun);
const Params& params() const { return _params; }