summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-05-12 14:21:34 +0000
committerArne Juul <arnej@verizonmedia.com>2020-05-12 14:29:42 +0000
commita72a4afc5d33c175f360460f727e5d51c9574fac (patch)
tree45c9aaa3c940d26ae99b22ed3077a6b8b69b3757 /searchlib
parentaa958f364f43e58a0e8fff81b4e2e77513f22a7b (diff)
own the filter in a class and use shared_from_this
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp17
-rw-r--r--searchlib/src/tests/queryeval/blueprint/mysearch.h2
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/blueprint.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/blueprint.h5
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/global_filter.h39
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp17
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h4
9 files changed, 69 insertions, 23 deletions
diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
index 4f3d5443ab8..60fed5e42bb 100644
--- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
+++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
@@ -70,8 +70,9 @@ TEST("test AndNot Blueprint") {
EXPECT_EQUAL(false, a.getState().want_global_filter());
a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, a.getState().want_global_filter());
- std::shared_ptr<BitVector> empty_global_filter;
- a.set_global_filter(empty_global_filter);
+ std::shared_ptr<GlobalFilter> empty_global_filter = GlobalFilter::create();
+ EXPECT_FALSE(empty_global_filter->has_filter());
+ a.set_global_filter(*empty_global_filter);
EXPECT_EQUAL(false, got_global_filter(a.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(a.getChild(1)));
}
@@ -145,8 +146,8 @@ TEST("test And Blueprint") {
EXPECT_EQUAL(false, a.getState().want_global_filter());
a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, a.getState().want_global_filter());
- std::shared_ptr<BitVector> empty_global_filter;
- a.set_global_filter(empty_global_filter);
+ std::shared_ptr<GlobalFilter> empty_global_filter = GlobalFilter::create();
+ a.set_global_filter(*empty_global_filter);
EXPECT_EQUAL(false, got_global_filter(a.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(a.getChild(1)));
}
@@ -225,8 +226,8 @@ TEST("test Or Blueprint") {
EXPECT_EQUAL(false, o.getState().want_global_filter());
o.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, o.getState().want_global_filter());
- std::shared_ptr<BitVector> empty_global_filter;
- o.set_global_filter(empty_global_filter);
+ std::shared_ptr<GlobalFilter> empty_global_filter = GlobalFilter::create();
+ o.set_global_filter(*empty_global_filter);
EXPECT_EQUAL(false, got_global_filter(o.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(o.getChild(o.childCnt() - 1)));
}
@@ -380,8 +381,8 @@ TEST("test Rank Blueprint") {
EXPECT_EQUAL(false, a.getState().want_global_filter());
a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create()));
EXPECT_EQUAL(true, a.getState().want_global_filter());
- std::shared_ptr<BitVector> empty_global_filter;
- a.set_global_filter(empty_global_filter);
+ std::shared_ptr<GlobalFilter> empty_global_filter = GlobalFilter::create();
+ a.set_global_filter(*empty_global_filter);
EXPECT_EQUAL(false, got_global_filter(a.getChild(0)));
EXPECT_EQUAL(true, got_global_filter(a.getChild(1)));
}
diff --git a/searchlib/src/tests/queryeval/blueprint/mysearch.h b/searchlib/src/tests/queryeval/blueprint/mysearch.h
index f0400fb1d90..ae1d47e8403 100644
--- a/searchlib/src/tests/queryeval/blueprint/mysearch.h
+++ b/searchlib/src/tests/queryeval/blueprint/mysearch.h
@@ -132,7 +132,7 @@ public:
set_cost_tier(value);
return *this;
}
- void set_global_filter(std::shared_ptr<BitVector>) override {
+ void set_global_filter(GlobalFilter &) override {
_got_global_filter = true;
}
bool got_global_filter() const { return _got_global_filter; }
diff --git a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt
index 0dcb0393473..dc36e0a1f7e 100644
--- a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt
@@ -19,6 +19,7 @@ vespa_add_library(searchlib_queryeval OBJECT
fake_searchable.cpp
field_spec.cpp
get_weight_from_node.cpp
+ global_filter.cpp
hitcollector.cpp
intermediate_blueprints.cpp
isourceselector.cpp
diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp
index 1f0ed4ed76f..190ce054b00 100644
--- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp
@@ -105,7 +105,7 @@ Blueprint::get_replacement()
}
void
-Blueprint::set_global_filter(std::shared_ptr<BitVector>)
+Blueprint::set_global_filter(GlobalFilter &)
{
}
@@ -371,7 +371,7 @@ IntermediateBlueprint::optimize(Blueprint* &self)
}
void
-IntermediateBlueprint::set_global_filter(std::shared_ptr<BitVector> global_filter)
+IntermediateBlueprint::set_global_filter(GlobalFilter &global_filter)
{
for (auto & child : _children) {
if (child->getState().want_global_filter()) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h
index 091a5b924ae..2f944b9271b 100644
--- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h
@@ -5,6 +5,7 @@
#include "field_spec.h"
#include "unpackinfo.h"
#include "executeinfo.h"
+#include "global_filter.h"
#include <vespa/searchlib/common/bitvector.h>
namespace vespalib { class ObjectVisitor; }
@@ -186,7 +187,7 @@ public:
virtual bool supports_termwise_children() const { return false; }
virtual bool always_needs_unpack() const { return false; }
- virtual void set_global_filter(std::shared_ptr<BitVector> global_filter);
+ virtual void set_global_filter(GlobalFilter &global_filter);
virtual const State &getState() const = 0;
const Blueprint &root() const;
@@ -273,7 +274,7 @@ public:
void setDocIdLimit(uint32_t limit) override final;
void optimize(Blueprint* &self) override final;
- void set_global_filter(std::shared_ptr<BitVector> global_filter) override;
+ void set_global_filter(GlobalFilter &global_filter) override;
IndexList find(const IPredicate & check) const;
size_t childCnt() const { return _children.size(); }
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
new file mode 100644
index 00000000000..849700b250e
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp
@@ -0,0 +1,3 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "global_filter.h"
diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
new file mode 100644
index 00000000000..9a9f6d0506f
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h
@@ -0,0 +1,39 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <memory>
+#include <vespa/searchlib/common/bitvector.h>
+
+namespace search::queryeval {
+
+class GlobalFilter : public std::enable_shared_from_this<GlobalFilter>
+{
+private:
+ struct ctor_tag {};
+ std::unique_ptr<search::BitVector> bit_vector;
+
+ GlobalFilter(const GlobalFilter &) = delete;
+ GlobalFilter(GlobalFilter &&) = delete;
+public:
+
+ GlobalFilter(ctor_tag, std::unique_ptr<search::BitVector> bit_vector_in)
+ : bit_vector(std::move(bit_vector_in))
+ {}
+
+ GlobalFilter(ctor_tag) : bit_vector() {}
+
+ ~GlobalFilter() {}
+
+ template<typename ... T>
+ static std::shared_ptr<GlobalFilter> create(T&& ... params) {
+ ctor_tag x;
+ return std::make_shared<GlobalFilter>(x, std::forward(params)...);
+ }
+
+ const search::BitVector *filter() const { return bit_vector.get(); }
+
+ bool has_filter() const { return (bool)bit_vector; }
+};
+
+} // namespace
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 3ea515e5cd4..de9688c4b43 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -60,7 +60,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_fallback_dist_fun(),
_distance_heap(target_num_hits),
_found_hits(),
- _global_filter()
+ _global_filter(GlobalFilter::create())
{
auto lct = _query_tensor->cellsRef().type;
auto rct = _attr_tensor.getTensorType().cell_type();
@@ -81,14 +81,14 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
NearestNeighborBlueprint::~NearestNeighborBlueprint() = default;
void
-NearestNeighborBlueprint::set_global_filter(std::shared_ptr<BitVector> global_filter)
+NearestNeighborBlueprint::set_global_filter(GlobalFilter &global_filter)
{
- _global_filter = global_filter;
+ _global_filter = global_filter.shared_from_this();
auto nns_index = _attr_tensor.nearest_neighbor_index();
if (_approximate && nns_index) {
uint32_t est_hits = _attr_tensor.getNumDocs();
- if (_global_filter) {
- uint32_t max_hits = _global_filter->countTrueBits();
+ if (_global_filter->has_filter()) {
+ uint32_t max_hits = _global_filter->filter()->countTrueBits();
if (max_hits * 10 < est_hits) {
// too many hits filtered out, use brute force implementation:
_approximate = false;
@@ -112,8 +112,9 @@ NearestNeighborBlueprint::perform_top_k()
if (lhs_type == rhs_type) {
auto lhs = _query_tensor->cellsRef();
uint32_t k = _target_num_hits;
- if (_global_filter) {
- _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits);
+ if (_global_filter->has_filter()) {
+ auto filter = _global_filter->filter();
+ _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits);
} else {
_found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits);
}
@@ -138,7 +139,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
}
const vespalib::tensor::DenseTensorView &qT = *_query_tensor;
return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor,
- _distance_heap, _global_filter.get(), _dist_fun);
+ _distance_heap, _global_filter->filter(), _dist_fun);
}
void
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index c1c6c28de37..e6506f777c3 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -28,7 +28,7 @@ private:
const search::tensor::DistanceFunction *_dist_fun;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
- std::shared_ptr<search::BitVector> _global_filter;
+ std::shared_ptr<GlobalFilter> _global_filter;
void perform_top_k();
public:
@@ -42,7 +42,7 @@ public:
const tensor::DenseTensorAttribute& get_attribute_tensor() const { return _attr_tensor; }
const vespalib::tensor::DenseTensorView& get_query_tensor() const { return *_query_tensor; }
uint32_t get_target_num_hits() const { return _target_num_hits; }
- void set_global_filter(std::shared_ptr<BitVector> global_filter) override;
+ void set_global_filter(GlobalFilter &global_filter) override;
std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda,
bool strict) const override;