From a72a4afc5d33c175f360460f727e5d51c9574fac Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 12 May 2020 14:21:34 +0000 Subject: own the filter in a class and use shared_from_this --- .../blueprint/intermediate_blueprints_test.cpp | 17 +++++----- searchlib/src/tests/queryeval/blueprint/mysearch.h | 2 +- .../src/vespa/searchlib/queryeval/CMakeLists.txt | 1 + .../src/vespa/searchlib/queryeval/blueprint.cpp | 4 +-- .../src/vespa/searchlib/queryeval/blueprint.h | 5 +-- .../vespa/searchlib/queryeval/global_filter.cpp | 3 ++ .../src/vespa/searchlib/queryeval/global_filter.h | 39 ++++++++++++++++++++++ .../queryeval/nearest_neighbor_blueprint.cpp | 17 +++++----- .../queryeval/nearest_neighbor_blueprint.h | 4 +-- 9 files changed, 69 insertions(+), 23 deletions(-) create mode 100644 searchlib/src/vespa/searchlib/queryeval/global_filter.cpp create mode 100644 searchlib/src/vespa/searchlib/queryeval/global_filter.h (limited to 'searchlib') diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 4f3d5443ab8..60fed5e42bb 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -70,8 +70,9 @@ TEST("test AndNot Blueprint") { EXPECT_EQUAL(false, a.getState().want_global_filter()); a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, a.getState().want_global_filter()); - std::shared_ptr empty_global_filter; - a.set_global_filter(empty_global_filter); + std::shared_ptr empty_global_filter = GlobalFilter::create(); + EXPECT_FALSE(empty_global_filter->has_filter()); + a.set_global_filter(*empty_global_filter); EXPECT_EQUAL(false, got_global_filter(a.getChild(0))); EXPECT_EQUAL(true, got_global_filter(a.getChild(1))); } @@ -145,8 +146,8 @@ TEST("test And Blueprint") { EXPECT_EQUAL(false, a.getState().want_global_filter()); a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, a.getState().want_global_filter()); - std::shared_ptr empty_global_filter; - a.set_global_filter(empty_global_filter); + std::shared_ptr empty_global_filter = GlobalFilter::create(); + a.set_global_filter(*empty_global_filter); EXPECT_EQUAL(false, got_global_filter(a.getChild(0))); EXPECT_EQUAL(true, got_global_filter(a.getChild(1))); } @@ -225,8 +226,8 @@ TEST("test Or Blueprint") { EXPECT_EQUAL(false, o.getState().want_global_filter()); o.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, o.getState().want_global_filter()); - std::shared_ptr empty_global_filter; - o.set_global_filter(empty_global_filter); + std::shared_ptr empty_global_filter = GlobalFilter::create(); + o.set_global_filter(*empty_global_filter); EXPECT_EQUAL(false, got_global_filter(o.getChild(0))); EXPECT_EQUAL(true, got_global_filter(o.getChild(o.childCnt() - 1))); } @@ -380,8 +381,8 @@ TEST("test Rank Blueprint") { EXPECT_EQUAL(false, a.getState().want_global_filter()); a.addChild(ap(MyLeafSpec(20).addField(1, 1).want_global_filter().create())); EXPECT_EQUAL(true, a.getState().want_global_filter()); - std::shared_ptr empty_global_filter; - a.set_global_filter(empty_global_filter); + std::shared_ptr empty_global_filter = GlobalFilter::create(); + a.set_global_filter(*empty_global_filter); EXPECT_EQUAL(false, got_global_filter(a.getChild(0))); EXPECT_EQUAL(true, got_global_filter(a.getChild(1))); } diff --git a/searchlib/src/tests/queryeval/blueprint/mysearch.h b/searchlib/src/tests/queryeval/blueprint/mysearch.h index f0400fb1d90..ae1d47e8403 100644 --- a/searchlib/src/tests/queryeval/blueprint/mysearch.h +++ b/searchlib/src/tests/queryeval/blueprint/mysearch.h @@ -132,7 +132,7 @@ public: set_cost_tier(value); return *this; } - void set_global_filter(std::shared_ptr) override { + void set_global_filter(GlobalFilter &) override { _got_global_filter = true; } bool got_global_filter() const { return _got_global_filter; } diff --git a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt index 0dcb0393473..dc36e0a1f7e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/queryeval/CMakeLists.txt @@ -19,6 +19,7 @@ vespa_add_library(searchlib_queryeval OBJECT fake_searchable.cpp field_spec.cpp get_weight_from_node.cpp + global_filter.cpp hitcollector.cpp intermediate_blueprints.cpp isourceselector.cpp diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index 1f0ed4ed76f..190ce054b00 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -105,7 +105,7 @@ Blueprint::get_replacement() } void -Blueprint::set_global_filter(std::shared_ptr) +Blueprint::set_global_filter(GlobalFilter &) { } @@ -371,7 +371,7 @@ IntermediateBlueprint::optimize(Blueprint* &self) } void -IntermediateBlueprint::set_global_filter(std::shared_ptr global_filter) +IntermediateBlueprint::set_global_filter(GlobalFilter &global_filter) { for (auto & child : _children) { if (child->getState().want_global_filter()) { diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index 091a5b924ae..2f944b9271b 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -5,6 +5,7 @@ #include "field_spec.h" #include "unpackinfo.h" #include "executeinfo.h" +#include "global_filter.h" #include namespace vespalib { class ObjectVisitor; } @@ -186,7 +187,7 @@ public: virtual bool supports_termwise_children() const { return false; } virtual bool always_needs_unpack() const { return false; } - virtual void set_global_filter(std::shared_ptr global_filter); + virtual void set_global_filter(GlobalFilter &global_filter); virtual const State &getState() const = 0; const Blueprint &root() const; @@ -273,7 +274,7 @@ public: void setDocIdLimit(uint32_t limit) override final; void optimize(Blueprint* &self) override final; - void set_global_filter(std::shared_ptr global_filter) override; + void set_global_filter(GlobalFilter &global_filter) override; IndexList find(const IPredicate & check) const; size_t childCnt() const { return _children.size(); } diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp new file mode 100644 index 00000000000..849700b250e --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.cpp @@ -0,0 +1,3 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "global_filter.h" diff --git a/searchlib/src/vespa/searchlib/queryeval/global_filter.h b/searchlib/src/vespa/searchlib/queryeval/global_filter.h new file mode 100644 index 00000000000..9a9f6d0506f --- /dev/null +++ b/searchlib/src/vespa/searchlib/queryeval/global_filter.h @@ -0,0 +1,39 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include +#include + +namespace search::queryeval { + +class GlobalFilter : public std::enable_shared_from_this +{ +private: + struct ctor_tag {}; + std::unique_ptr bit_vector; + + GlobalFilter(const GlobalFilter &) = delete; + GlobalFilter(GlobalFilter &&) = delete; +public: + + GlobalFilter(ctor_tag, std::unique_ptr bit_vector_in) + : bit_vector(std::move(bit_vector_in)) + {} + + GlobalFilter(ctor_tag) : bit_vector() {} + + ~GlobalFilter() {} + + template + static std::shared_ptr create(T&& ... params) { + ctor_tag x; + return std::make_shared(x, std::forward(params)...); + } + + const search::BitVector *filter() const { return bit_vector.get(); } + + bool has_filter() const { return (bool)bit_vector; } +}; + +} // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 3ea515e5cd4..de9688c4b43 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -60,7 +60,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _fallback_dist_fun(), _distance_heap(target_num_hits), _found_hits(), - _global_filter() + _global_filter(GlobalFilter::create()) { auto lct = _query_tensor->cellsRef().type; auto rct = _attr_tensor.getTensorType().cell_type(); @@ -81,14 +81,14 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f NearestNeighborBlueprint::~NearestNeighborBlueprint() = default; void -NearestNeighborBlueprint::set_global_filter(std::shared_ptr global_filter) +NearestNeighborBlueprint::set_global_filter(GlobalFilter &global_filter) { - _global_filter = global_filter; + _global_filter = global_filter.shared_from_this(); auto nns_index = _attr_tensor.nearest_neighbor_index(); if (_approximate && nns_index) { uint32_t est_hits = _attr_tensor.getNumDocs(); - if (_global_filter) { - uint32_t max_hits = _global_filter->countTrueBits(); + if (_global_filter->has_filter()) { + uint32_t max_hits = _global_filter->filter()->countTrueBits(); if (max_hits * 10 < est_hits) { // too many hits filtered out, use brute force implementation: _approximate = false; @@ -112,8 +112,9 @@ NearestNeighborBlueprint::perform_top_k() if (lhs_type == rhs_type) { auto lhs = _query_tensor->cellsRef(); uint32_t k = _target_num_hits; - if (_global_filter) { - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *_global_filter, k + _explore_additional_hits); + if (_global_filter->has_filter()) { + auto filter = _global_filter->filter(); + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits); } else { _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits); } @@ -138,7 +139,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData } const vespalib::tensor::DenseTensorView &qT = *_query_tensor; return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, - _distance_heap, _global_filter.get(), _dist_fun); + _distance_heap, _global_filter->filter(), _dist_fun); } void diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index c1c6c28de37..e6506f777c3 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -28,7 +28,7 @@ private: const search::tensor::DistanceFunction *_dist_fun; mutable NearestNeighborDistanceHeap _distance_heap; std::vector _found_hits; - std::shared_ptr _global_filter; + std::shared_ptr _global_filter; void perform_top_k(); public: @@ -42,7 +42,7 @@ public: const tensor::DenseTensorAttribute& get_attribute_tensor() const { return _attr_tensor; } const vespalib::tensor::DenseTensorView& get_query_tensor() const { return *_query_tensor; } uint32_t get_target_num_hits() const { return _target_num_hits; } - void set_global_filter(std::shared_ptr global_filter) override; + void set_global_filter(GlobalFilter &global_filter) override; std::unique_ptr createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda, bool strict) const override; -- cgit v1.2.3