diff options
Diffstat (limited to 'searchlib')
25 files changed, 689 insertions, 138 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index 7c4b7555158..fe6149e6fba 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -7,6 +7,7 @@ #include <vespa/searchlib/query/streaming/query.h> #include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h> #include <vespa/searchlib/query/streaming/wand_term.h> +#include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/searchlib/query/tree/querybuilder.h> #include <vespa/searchlib/query/tree/simplequery.h> #include <vespa/searchlib/query/tree/stackdumpcreator.h> @@ -1020,6 +1021,48 @@ TEST(StreamingQueryTest, wand_term) check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit"); } +TEST(StreamingQueryTest, weighted_set_term) +{ + search::streaming::WeightedSetTerm term({}, "index", 2); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(4)); + term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD)); + term.get_terms().back()->setWeight(Weight(13)); + EXPECT_EQ(2, term.get_terms().size()); + SimpleTermData td; + /* + * Search in fields 10, 11 and 12 (cf. fieldset in schema). + * Fields 11 and 12 have content for doc containing the keys. + * Fields 10 and 12 have valid handles and can be used for ranking. + * Field 11 does not have a valid handle, thus no associated match data. + */ + td.addField(10); + td.addField(11); + td.addField(12); + td.lookupField(10)->setHandle(0); + td.lookupField(12)->setHandle(1); + EXPECT_FALSE(term.evaluate()); + auto& q0 = *term.get_terms()[0]; + q0.add(0, 11, 0, 10); + q0.add(0, 12, 0, 10); + auto& q1 = *term.get_terms()[1]; + q1.add(0, 11, 0, 10); + q1.add(0, 12, 0, 10); + EXPECT_TRUE(term.evaluate()); + MatchData md(MatchData::params().numTermFields(2)); + term.unpack_match_data(23, td, md); + auto tmd0 = md.resolveTermField(0); + EXPECT_NE(23, tmd0->getDocId()); + auto tmd1 = md.resolveTermField(1); + EXPECT_EQ(23, tmd1->getDocId()); + using Weights = std::vector<int32_t>; + Weights weights; + for (auto& pos : *tmd1) { + weights.emplace_back(pos.getElementWeight()); + } + EXPECT_EQ((Weights{13, 4}), weights); +} + TEST(StreamingQueryTest, control_the_size_of_query_terms) { EXPECT_EQ(112u, sizeof(QueryTermSimple)); diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp index f800e124bdc..bbd2744119a 100644 --- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp @@ -24,10 +24,10 @@ class MyOr : public IntermediateBlueprint private: public: double calculate_cost() const final { - return cost_of(get_children(), OrFlow()); + return OrFlow::cost_of(get_children()); } double calculate_relative_estimate() const final { - return estimate_of(get_children(), OrFlow()); + return OrFlow::estimate_of(get_children()); } HitEstimate combine(const std::vector<HitEstimate> &data) const override { return max(data); diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index ab1c004c721..856ac2391f8 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -1380,7 +1380,7 @@ TEST("cost for ONEAR") { } TEST("cost for WEAKAND") { - verify_cost(make::WEAKAND(1000), calc_cost({{1.1, 0.8},{1.2, 0.7},{1.3, 0.5}})); + verify_cost(make::WEAKAND(1000), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}})); } TEST_MAIN() { TEST_DEBUG("lhs.out", "rhs.out"); TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp index ceda30f169a..9a9adeac2bc 100644 --- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp +++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp @@ -5,44 +5,46 @@ #include <vector> #include <random> -using search::queryeval::AndFlow; -using search::queryeval::OrFlow; +constexpr size_t loop_cnt = 64; + +using namespace search::queryeval; + +struct ItemAdapter { + double estimate(const auto &child) const noexcept { return child.rel_est; } + double cost(const auto &child) const noexcept { return child.cost; } + double strict_cost(const auto &child) const noexcept { return child.strict_cost; } +}; struct Item { double rel_est; double cost; - Item(double rel_est_in, double cost_in) noexcept - : rel_est(rel_est_in), cost(cost_in) {} - static void sort_for_and(std::vector<Item> &data) { - std::sort(data.begin(), data.end(), [](const Item &a, const Item &b) noexcept { - return (1.0 - a.rel_est) / a.cost > (1.0 - b.rel_est) / b.cost; - }); + double strict_cost; + Item(double rel_est_in, double cost_in, double strict_cost_in) noexcept + : rel_est(rel_est_in), cost(cost_in), strict_cost(strict_cost_in) {} + template <typename FLOW> static double estimate_of(std::vector<Item> &data) { + return FLOW::estimate_of(ItemAdapter(), data); } - static void sort_for_or(std::vector<Item> &data) { - std::sort(data.begin(), data.end(), [](const Item &a, const Item &b) noexcept { - return a.rel_est / a.cost > b.rel_est / b.cost; - }); + template <typename FLOW> static void sort(std::vector<Item> &data, bool strict) { + FLOW::sort(ItemAdapter(), data, strict); } - static double cost_of(const std::vector<Item> &data, auto flow) { - double cost = 0.0; - for (const Item &item: data) { - cost += flow.flow() * item.cost; - flow.add(item.rel_est); - } - return cost; + template <typename FLOW> static double cost_of(const std::vector<Item> &data, bool strict) { + return FLOW::cost_of(ItemAdapter(), data, strict); + } + template <typename FLOW> static double ordered_cost_of(const std::vector<Item> &data, bool strict) { + return flow::ordered_cost_of(ItemAdapter(), data, FLOW(1.0, strict)); } - static double cost_of_and(const std::vector<Item> &data) { return cost_of(data, AndFlow()); } - static double cost_of_or(const std::vector<Item> &data) { return cost_of(data, OrFlow()); } + auto operator <=>(const Item &rhs) const noexcept = default; }; std::vector<Item> gen_data(size_t size) { static std::mt19937 gen; - static std::uniform_real_distribution<double> rel_est(0.1, 0.9); - static std::uniform_real_distribution<double> cost(1.0, 10.0); + static std::uniform_real_distribution<double> rel_est(0.1, 0.9); + static std::uniform_real_distribution<double> cost(1.0, 10.0); + static std::uniform_real_distribution<double> strict_cost(0.1, 5.0); std::vector<Item> result; result.reserve(size); for (size_t i = 0; i < size; ++i) { - result.emplace_back(rel_est(gen), cost(gen)); + result.emplace_back(rel_est(gen), cost(gen), strict_cost(gen)); } return result; } @@ -80,37 +82,191 @@ TEST(FlowTest, perm_test) { EXPECT_EQ(seen.size(), 120); } +template <template <typename> typename ORDER> +void verify_ordering_is_strict_weak() { + auto cmp = ORDER(ItemAdapter()); + auto input = gen_data(7); + input.emplace_back(0.5, 1.5, 0.5); + input.emplace_back(0.5, 1.5, 0.5); + input.emplace_back(0.5, 1.5, 0.5); + input.emplace_back(0.0, 1.5, 0.5); + input.emplace_back(0.0, 1.5, 0.5); + input.emplace_back(0.5, 0.0, 0.5); + input.emplace_back(0.5, 0.0, 0.5); + input.emplace_back(0.5, 1.5, 0.0); + input.emplace_back(0.5, 1.5, 0.0); + input.emplace_back(0.0, 0.0, 0.0); + input.emplace_back(0.0, 0.0, 0.0); + std::vector<Item> output; + for (const Item &in: input) { + EXPECT_FALSE(cmp(in, in)); // Irreflexivity + size_t out_idx = 0; + bool lower = false; + bool upper = false; + for (const Item &out: output) { + if (cmp(out, in)) { + EXPECT_FALSE(cmp(in, out)); // Antisymmetry + EXPECT_FALSE(lower); // Transitivity + EXPECT_FALSE(upper); // Transitivity + ++out_idx; + } else { + lower = true; + if (cmp(in, out)) { + upper = true; + } else { + EXPECT_FALSE(upper); // Transitivity + } + } + } + output.insert(output.begin() + out_idx, in); + } +} + +TEST(FlowTest, and_ordering_is_strict_weak) { + verify_ordering_is_strict_weak<flow::MinAndCost>(); +} + +TEST(FlowTest, or_ordering_is_strict_weak) { + verify_ordering_is_strict_weak<flow::MinOrCost>(); +} + +TEST(FlowTest, strict_or_ordering_is_strict_weak) { + verify_ordering_is_strict_weak<flow::MinOrStrictCost>(); +} + +struct ExpectFlow { + double flow; + double est; + bool strict; +}; + +void verify_flow(auto flow, const std::vector<double> &est_list, const std::vector<ExpectFlow> &expect) { + ASSERT_EQ(est_list.size() + 1, expect.size()); + for (size_t i = 0; i < expect.size(); ++i) { + EXPECT_DOUBLE_EQ(flow.flow(), expect[i].flow); + EXPECT_DOUBLE_EQ(flow.estimate(), expect[i].est); + EXPECT_EQ(flow.strict(), expect[i].strict); + if (i < est_list.size()) { + flow.add(est_list[i]); + } + } +} + +TEST(FlowTest, basic_and_flow) { + for (double in: {1.0, 0.5, 0.25}) { + for (bool strict: {false, true}) { + verify_flow(AndFlow(in, strict), {0.4, 0.7, 0.2}, + {{in, 0.0, strict}, + {in*0.4, in*0.4, false}, + {in*0.4*0.7, in*0.4*0.7, false}, + {in*0.4*0.7*0.2, in*0.4*0.7*0.2, false}}); + } + } +} + +TEST(FlowTest, basic_or_flow) { + for (double in: {1.0, 0.5, 0.25}) { + for (bool strict: {false, true}) { + verify_flow(OrFlow(in, strict), {0.4, 0.7, 0.2}, + {{in, 0.0, strict}, + {in*0.6, 1.0-in*0.6, strict}, + {in*0.6*0.3, 1.0-in*0.6*0.3, strict}, + {in*0.6*0.3*0.8, 1.0-in*0.6*0.3*0.8, strict}}); + } + } +} + +TEST(FlowTest, basic_and_not_flow) { + for (double in: {1.0, 0.5, 0.25}) { + for (bool strict: {false, true}) { + verify_flow(AndNotFlow(in, strict), {0.4, 0.7, 0.2}, + {{in, 0.0, strict}, + {in*0.4, in*0.4, false}, + {in*0.4*0.3, in*0.4*0.3, false}, + {in*0.4*0.3*0.8, in*0.4*0.3*0.8, false}}); + } + } +} + +TEST(FlowTest, flow_cost) { + std::vector<Item> data = {{0.4, 1.1, 0.6}, {0.7, 1.2, 0.5}, {0.2, 1.3, 0.4}}; + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3); + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3); + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3); + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<OrFlow>(data, true), 0.6 + 0.6*0.5 + 0.6*0.3*0.4); + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3); + EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3); +} + TEST(FlowTest, optimal_and_flow) { - for (size_t i = 0; i < 256; ++i) { - auto data = gen_data(7); - Item::sort_for_and(data); - double min_cost = Item::cost_of_and(data); - double max_cost = 0.0; - auto check = [min_cost,&max_cost](const std::vector<Item> &my_data) noexcept { - double my_cost = Item::cost_of_and(my_data); - EXPECT_LE(min_cost, my_cost); - max_cost = std::max(max_cost, my_cost); - }; - each_perm(data, check); - fprintf(stderr, " and cost(%zu): min: %g, max: %g, factor: %g\n", - i, min_cost, max_cost, max_cost / min_cost); + for (size_t i = 0; i < loop_cnt; ++i) { + for (bool strict: {false, true}) { + auto data = gen_data(7); + double ref_est = Item::estimate_of<AndFlow>(data); + double min_cost = Item::cost_of<AndFlow>(data, strict); + double max_cost = 0.0; + Item::sort<AndFlow>(data, strict); + EXPECT_EQ(Item::ordered_cost_of<AndFlow>(data, strict), min_cost); + auto check = [&](const std::vector<Item> &my_data) noexcept { + double my_cost = Item::ordered_cost_of<AndFlow>(my_data, strict); + EXPECT_LE(min_cost, my_cost); + max_cost = std::max(max_cost, my_cost); + }; + each_perm(data, check); + if (loop_cnt < 1024 || i % 1024 == 0) { + fprintf(stderr, " AND cost(%zu,%s): min: %g, max: %g, factor: %g\n", + i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost); + } + EXPECT_NEAR(ref_est, Item::estimate_of<AndFlow>(data), 1e-9); + } } } TEST(FlowTest, optimal_or_flow) { - for (size_t i = 0; i < 256; ++i) { - auto data = gen_data(7); - Item::sort_for_or(data); - double min_cost = Item::cost_of_or(data); - double max_cost = 0.0; - auto check = [min_cost,&max_cost](const std::vector<Item> &my_data) noexcept { - double my_cost = Item::cost_of_or(my_data); - EXPECT_LE(min_cost, my_cost); - max_cost = std::max(max_cost, my_cost); - }; - each_perm(data, check); - fprintf(stderr, " or cost(%zu): min: %g, max: %g, factor: %g\n", - i, min_cost, max_cost, max_cost / min_cost); + for (size_t i = 0; i < loop_cnt; ++i) { + for (bool strict: {false, true}) { + auto data = gen_data(7); + double min_cost = Item::cost_of<OrFlow>(data, strict); + double max_cost = 0.0; + Item::sort<OrFlow>(data, strict); + EXPECT_EQ(Item::ordered_cost_of<OrFlow>(data, strict), min_cost); + auto check = [&](const std::vector<Item> &my_data) noexcept { + double my_cost = Item::ordered_cost_of<OrFlow>(my_data, strict); + EXPECT_LE(min_cost, my_cost); + max_cost = std::max(max_cost, my_cost); + }; + each_perm(data, check); + if (loop_cnt < 1024 || i % 1024 == 0) { + fprintf(stderr, " OR cost(%zu,%s): min: %g, max: %g, factor: %g\n", + i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost); + } + } + } +} + +TEST(FlowTest, optimal_and_not_flow) { + for (size_t i = 0; i < loop_cnt; ++i) { + for (bool strict: {false, true}) { + auto data = gen_data(7); + Item first = data[0]; + double min_cost = Item::cost_of<AndNotFlow>(data, strict); + double max_cost = 0.0; + Item::sort<AndNotFlow>(data, strict); + EXPECT_EQ(data[0], first); + EXPECT_EQ(Item::ordered_cost_of<AndNotFlow>(data, strict), min_cost); + auto check = [&](const std::vector<Item> &my_data) noexcept { + if (my_data[0] == first) { + double my_cost = Item::ordered_cost_of<AndNotFlow>(my_data, strict); + EXPECT_LE(min_cost, my_cost); + max_cost = std::max(max_cost, my_cost); + } + }; + each_perm(data, check); + if (loop_cnt < 1024 || i % 1024 == 0) { + fprintf(stderr, " ANDNOT cost(%zu,%s): min: %g, max: %g, factor: %g\n", + i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost); + } + } } } diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp index aa6d922f23f..a9f549a0bd9 100644 --- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp +++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp @@ -391,8 +391,11 @@ struct HeapFixture SearchIterator::UP sb(spec.create()); result.search(*sb); } + ~HeapFixture(); }; +HeapFixture::~HeapFixture() = default; + TEST(ParallelWeakAndTest, require_that_scores_are_collected_in_batches_before_adjusting_heap) { HeapFixture f; diff --git a/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp b/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp index 9409b2b26c4..689f9f085d0 100644 --- a/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp +++ b/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp @@ -37,8 +37,11 @@ struct SimpleWandFixture { SearchIterator::UP search(spec.create()); hits.search(*search); } + ~SimpleWandFixture(); }; +SimpleWandFixture::~SimpleWandFixture() = default; + struct AdvancedWandFixture { MyWandSpec spec; SimpleResult hits; @@ -51,8 +54,11 @@ struct AdvancedWandFixture { SearchIterator::UP search(spec.create()); hits.search(*search); } + ~AdvancedWandFixture(); }; +AdvancedWandFixture::~AdvancedWandFixture() = default; + struct WeightOrder { bool operator()(const wand::Term &t1, const wand::Term &t2) const { return (t1.weight < t2.weight); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 99d3ba3f7aa..01148c11c9c 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -1,7 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "attribute_weighted_set_blueprint.h" -#include "multi_term_filter.hpp" +#include "multi_term_hash_filter.hpp" #include <vespa/searchcommon/attribute/i_search_context.h> #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/matchdatalayout.h> @@ -73,7 +73,7 @@ make_multi_term_filter(fef::TermFieldMatchData& tfmd, const std::vector<int32_t>& weights, const std::vector<ISearchContext*>& contexts) { - using FilterType = attribute::MultiTermFilter<WrapperType>; + using FilterType = attribute::MultiTermHashFilter<WrapperType>; typename FilterType::TokenMap tokens; WrapperType wrapper(attr); for (size_t i = 0; i < contexts.size(); ++i) { diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h index adbf37d2dcd..9c3ea258fdc 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h @@ -18,7 +18,7 @@ namespace search::attribute { * @tparam WrapperType Type that wraps an attribute vector and provides access to the attribute value for a given docid. */ template <typename WrapperType> -class MultiTermFilter final : public queryeval::SearchIterator +class MultiTermHashFilter final : public queryeval::SearchIterator { public: using Key = typename WrapperType::TokenT; @@ -31,9 +31,9 @@ private: int32_t _weight; public: - MultiTermFilter(fef::TermFieldMatchData& tfmd, - WrapperType attr, - TokenMap&& map); + MultiTermHashFilter(fef::TermFieldMatchData& tfmd, + WrapperType attr, + TokenMap&& map); void and_hits_into(BitVector& result, uint32_t begin_id) override; void doSeek(uint32_t docId) override; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp index dc572aedbff..96d5b3ac1f3 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp @@ -2,16 +2,16 @@ #pragma once -#include "multi_term_filter.h" +#include "multi_term_hash_filter.h" #include <vespa/searchlib/common/bitvector.h> #include <vespa/searchlib/fef/termfieldmatchdata.h> namespace search::attribute { template <typename WrapperType> -MultiTermFilter<WrapperType>::MultiTermFilter(fef::TermFieldMatchData& tfmd, - WrapperType attr, - TokenMap&& map) +MultiTermHashFilter<WrapperType>::MultiTermHashFilter(fef::TermFieldMatchData& tfmd, + WrapperType attr, + TokenMap&& map) : _tfmd(tfmd), _attr(attr), _map(std::move(map)), @@ -21,7 +21,7 @@ MultiTermFilter<WrapperType>::MultiTermFilter(fef::TermFieldMatchData& tfmd, template <typename WrapperType> void -MultiTermFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id) +MultiTermHashFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id) { auto end = _map.end(); result.foreach_truebit([&, end](uint32_t key) { if ( _map.find(_attr.getToken(key)) == end) { result.clearBit(key); }}, begin_id); @@ -29,7 +29,7 @@ MultiTermFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id template <typename WrapperType> void -MultiTermFilter<WrapperType>::doSeek(uint32_t docId) +MultiTermHashFilter<WrapperType>::doSeek(uint32_t docId) { auto pos = _map.find(_attr.getToken(docId)); if (pos != _map.end()) { @@ -40,7 +40,7 @@ MultiTermFilter<WrapperType>::doSeek(uint32_t docId) template <typename WrapperType> void -MultiTermFilter<WrapperType>::doUnpack(uint32_t docId) +MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId) { _tfmd.reset(docId); fef::TermFieldMatchDataPosition pos; diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt index 9b53407aff5..05a75f4662e 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt @@ -10,5 +10,7 @@ vespa_add_library(searchlib_query_streaming OBJECT querynoderesultbase.cpp queryterm.cpp wand_term.cpp + weighted_set_term.cpp + regexp_term.cpp DEPENDS ) diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp index 3079ec31e8f..ca742aabe26 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp @@ -107,9 +107,7 @@ QueryConnector::create(ParseItem::ItemType type) case search::ParseItem::ITEM_AND: return std::make_unique<AndQueryNode>(); case search::ParseItem::ITEM_OR: case search::ParseItem::ITEM_WEAK_AND: return std::make_unique<OrQueryNode>(); - case search::ParseItem::ITEM_WEIGHTED_SET: case search::ParseItem::ITEM_EQUIV: return std::make_unique<EquivQueryNode>(); - case search::ParseItem::ITEM_WAND: return std::make_unique<OrQueryNode>(); case search::ParseItem::ITEM_NOT: return std::make_unique<AndNotQueryNode>(); case search::ParseItem::ITEM_PHRASE: return std::make_unique<PhraseQueryNode>(); case search::ParseItem::ITEM_SAME_ELEMENT: return std::make_unique<SameElementQueryNode>(); diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h index 8befa2fe7fa..84c693b86d0 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/query.h +++ b/searchlib/src/vespa/searchlib/query/streaming/query.h @@ -103,8 +103,7 @@ public: EquivQueryNode() noexcept : OrQueryNode("EQUIV") { } bool evaluate() const override; bool isFlattenable(ParseItem::ItemType type) const override { - return (type == ParseItem::ITEM_EQUIV) || - (type == ParseItem::ITEM_WEIGHTED_SET); + return (type == ParseItem::ITEM_EQUIV); } }; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index c24f41d16cf..2ee515f062a 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -2,10 +2,12 @@ #include "query.h" #include "nearest_neighbor_query_node.h" +#include "regexp_term.h" #include <vespa/searchlib/parsequery/stackdumpiterator.h> #include <vespa/searchlib/query/streaming/dot_product_term.h> #include <vespa/searchlib/query/streaming/in_term.h> #include <vespa/searchlib/query/streaming/wand_term.h> +#include <vespa/searchlib/query/streaming/weighted_set_term.h> #include <vespa/searchlib/query/tree/term_vector.h> #include <charconv> #include <vespa/log/log.h> @@ -40,7 +42,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_OR: case ParseItem::ITEM_WEAK_AND: case ParseItem::ITEM_EQUIV: - case ParseItem::ITEM_WEIGHTED_SET: case ParseItem::ITEM_NOT: case ParseItem::ITEM_PHRASE: case ParseItem::ITEM_SAME_ELEMENT: @@ -55,7 +56,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor nqn->distance(queryRep.getNearDistance()); } if ((type == ParseItem::ITEM_WEAK_AND) || - (type == ParseItem::ITEM_WEIGHTED_SET) || (type == ParseItem::ITEM_SAME_ELEMENT)) { qn->setIndex(queryRep.getIndexName()); @@ -146,7 +146,12 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor qn = std::make_unique<TrueNode>(); } else { Normalizing normalize_mode = factory.normalizing_mode(ssIndex); - auto qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); + std::unique_ptr<QueryTerm> qt; + if (sTerm != TermType::REGEXP) { + qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode); + } else { + qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode); + } qt->setWeight(queryRep.GetWeight()); qt->setUniqueId(queryRep.getUniqueId()); if (qt->isFuzzy()) { @@ -192,6 +197,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor case ParseItem::ITEM_WAND: qn = build_wand_term(factory, queryRep); break; + case ParseItem::ITEM_WEIGHTED_SET: + qn = build_weighted_set_term(factory, queryRep); + break; default: skip_unknown(queryRep); break; @@ -270,6 +278,16 @@ QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQuerySta return wand; } +std::unique_ptr<QueryNode> +QueryNode::build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep) +{ + auto ws = std::make_unique<WeightedSetTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity()); + ws->setWeight(queryRep.GetWeight()); + ws->setUniqueId(queryRep.getUniqueId()); + populate_multi_term(factory.normalizing_mode(ws->index()), *ws, queryRep); + return ws; +} + void QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep) { diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h index a0561b2e52e..454932c0a68 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h @@ -32,6 +32,7 @@ class QueryNode static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); + static std::unique_ptr<QueryNode> build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep); static void skip_unknown(SimpleQueryStackDumpIterator& queryRep); public: using UP = std::unique_ptr<QueryNode>; diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp index c58ec55de9f..d72a3371846 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp @@ -3,4 +3,22 @@ namespace search::streaming { +namespace { + +const char* to_str(Normalizing norm) noexcept { + switch (norm) { + case Normalizing::NONE: return "NONE"; + case Normalizing::LOWERCASE: return "LOWERCASE"; + case Normalizing::LOWERCASE_AND_FOLD: return "LOWERCASE_AND_FOLD"; + } + abort(); +} + +} + +std::ostream& operator<<(std::ostream& os, Normalizing n) { + os << to_str(n); + return os; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h index 74f872ad187..83fb27794a3 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h +++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h @@ -2,6 +2,7 @@ #pragma once #include <vespa/vespalib/stllike/string.h> +#include <iosfwd> #include <memory> namespace search::streaming { @@ -24,6 +25,8 @@ enum class Normalizing { LOWERCASE_AND_FOLD }; +std::ostream& operator<<(std::ostream&, Normalizing); + class QueryNodeResultFactory { public: virtual ~QueryNodeResultFactory() = default; diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 3950a179d67..3e05d381ee2 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -179,4 +179,10 @@ QueryTerm::as_multi_term() noexcept return nullptr; } +RegexpTerm* +QueryTerm::as_regexp_term() noexcept +{ + return nullptr; +} + } diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index 743998a630e..cd2bdd7eaec 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -13,6 +13,7 @@ namespace search::streaming { class NearestNeighborQueryNode; class MultiTerm; +class RegexpTerm; /** This is a leaf in the Query tree. All terms are leafs. @@ -93,6 +94,7 @@ public: void setFuzzyPrefixLength(uint32_t fuzzyPrefixLength) { _fuzzyPrefixLength = fuzzyPrefixLength; } virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept; virtual MultiTerm* as_multi_term() noexcept; + virtual RegexpTerm* as_regexp_term() noexcept; protected: using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>; string _index; diff --git a/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp new file mode 100644 index 00000000000..4508caa7072 --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp @@ -0,0 +1,27 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "regexp_term.h" + +namespace search::streaming { + +using vespalib::Regex; + +namespace { + +constexpr Regex::Options normalize_mode_to_regex_opts(Normalizing norm) noexcept { + return ((norm == Normalizing::NONE) + ? Regex::Options::None + : Regex::Options::IgnoreCase); +} + +} + +RegexpTerm::RegexpTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing) + : QueryTerm(std::move(result_base), term, index, type, normalizing), + _regexp(Regex::from_pattern({term.data(), term.size()}, normalize_mode_to_regex_opts(normalizing))) +{ +} + +RegexpTerm::~RegexpTerm() = default; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h new file mode 100644 index 00000000000..96d14eeb0bd --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h @@ -0,0 +1,25 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "queryterm.h" +#include <vespa/vespalib/regex/regex.h> + +namespace search::streaming { + +/** + * Query term that matches fields using a regular expression, with case sensitivity + * controlled by the provided Normalizing mode. + */ +class RegexpTerm : public QueryTerm { + vespalib::Regex _regexp; +public: + RegexpTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term, + const string& index, Type type, Normalizing normalizing); + ~RegexpTerm() override; + + RegexpTerm* as_regexp_term() noexcept override { return this; } + + [[nodiscard]] const vespalib::Regex& regexp() const noexcept { return _regexp; } +}; + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp new file mode 100644 index 00000000000..90d0be5d43c --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "weighted_set_term.h" +#include <vespa/searchlib/fef/itermdata.h> +#include <vespa/searchlib/fef/matchdata.h> +#include <vespa/vespalib/stllike/hash_map.hpp> + +using search::fef::ITermData; +using search::fef::MatchData; + +namespace search::streaming { + +WeightedSetTerm::WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms) + : MultiTerm(std::move(result_base), index, num_terms) +{ +} + +WeightedSetTerm::~WeightedSetTerm() = default; + +void +WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data) +{ + vespalib::hash_map<uint32_t,std::vector<double>> scores; + HitList hl_store; + for (const auto& term : _terms) { + auto& hl = term->evaluateHits(hl_store); + for (auto& hit : hl) { + scores[hit.context()].emplace_back(term->weight().percent()); + } + } + auto num_fields = td.numFields(); + for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) { + auto& tfd = td.field(field_idx); + auto field_id = tfd.getFieldId(); + if (scores.contains(field_id)) { + auto handle = tfd.getHandle(); + if (handle != fef::IllegalHandle) { + auto &field_scores = scores[field_id]; + std::sort(field_scores.begin(), field_scores.end(), std::greater()); + auto tmd = match_data.resolveTermField(tfd.getHandle()); + tmd->setFieldId(field_id); + tmd->reset(docid); + for (auto& field_score : field_scores) { + fef::TermFieldMatchDataPosition pos; + pos.setElementWeight(field_score); + tmd->appendPosition(pos); + } + } + } + } +} + +} diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h new file mode 100644 index 00000000000..4473e0fa45b --- /dev/null +++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h @@ -0,0 +1,20 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "multi_term.h" + +namespace search::streaming { + +/* + * A weighted set query term for streaming search. + */ +class WeightedSetTerm : public MultiTerm { + double _score_threshold; +public: + WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms); + ~WeightedSetTerm() override; + void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index a78dd092f5a..d998c2e343e 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -144,22 +144,6 @@ public: return (total_docs == 0) ? 0.0 : double(est) / double(total_docs); } - static double cost_of(const Children &children, auto flow) { - double cost = 0.0; - for (const auto &child: children) { - cost += flow.flow() * child->cost(); - flow.add(child->estimate()); - } - return cost; - } - - static double estimate_of(const Children &children, auto flow) { - for (const auto &child: children) { - flow.add(child->estimate()); - } - return flow.estimate(); - } - // utility that just takes maximum estimate static HitEstimate max(const std::vector<HitEstimate> &data); @@ -172,20 +156,6 @@ public: // lower limit for docid_limit: max child estimate static HitEstimate sat_sum(const std::vector<HitEstimate> &data, uint32_t docid_limit); - // sort children to minimize total cost of OR flow - struct MinimalOrCost { - bool operator () (const auto &a, const auto &b) const noexcept { - return a->estimate() / a->cost() > b->estimate() / b->cost(); - } - }; - - // sort children to minimize total cost of AND flow - struct MinimalAndCost { - bool operator () (const auto &a, const auto &b) const noexcept { - return (1.0 - a->estimate()) / a->cost() > (1.0 - b->estimate()) / b->cost(); - } - }; - // utility to get the greater estimate to sort first, higher tiers last struct TieredGreaterEstimate { bool operator () (const auto &a, const auto &b) const noexcept { diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h index 36c0a259feb..86ce6f8b93b 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow.h @@ -2,60 +2,261 @@ #pragma once #include <cstddef> - -namespace search::queryeval { +#include <algorithm> +#include <vespa/vespalib/util/small_vector.h> // Model how boolean result decisions flow through intermediate nodes // of different types based on relative estimates for sub-expressions -class AndFlow { +namespace search::queryeval { + +namespace flow { + +// the default adapter expects the shape of std::unique_ptr<Blueprint> +// with respect to estimate, cost and (coming soon) strict_cost. +struct DefaultAdapter { + double estimate(const auto &child) const noexcept { return child->estimate(); } + double cost(const auto &child) const noexcept { return child->cost(); } + // Estimate the per-document cost of strict evaluation of this + // child. This will typically be something like (estimate() * + // cost()) for leafs with posting lists. OR will aggregate strict + // cost by calculating the minimal OR flow of strict child + // costs. AND will aggregate strict cost by calculating the + // minimal AND flow where the cost of the first child is + // substituted by its strict cost. This value is currently not + // available in Blueprints. + double strict_cost(const auto &child) const noexcept { return child->cost(); } +}; + +template <typename ADAPTER, typename T> +struct IndirectAdapter { + const T &data; + [[no_unique_address]] ADAPTER adapter; + IndirectAdapter(ADAPTER adapter_in, const T &data_in) noexcept + : data(data_in), adapter(adapter_in) {} + double estimate(size_t child) const noexcept { return adapter.estimate(data[child]); } + double cost(size_t child) const noexcept { return adapter.cost(data[child]); } + double strict_cost(size_t child) const noexcept { return adapter.strict_cost(data[child]); } +}; + +auto make_index(const auto &children) { + vespalib::SmallVector<uint32_t> index(children.size()); + for (size_t i = 0; i < index.size(); ++i) { + index[i] = i; + } + return index; +} + +template <typename ADAPTER> +struct MinAndCost { + // sort children to minimize total cost of AND flow + [[no_unique_address]] ADAPTER adapter; + MinAndCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {} + bool operator () (const auto &a, const auto &b) const noexcept { + return (1.0 - adapter.estimate(a)) * adapter.cost(b) > (1.0 - adapter.estimate(b)) * adapter.cost(a); + } +}; + +template <typename ADAPTER> +struct MinOrCost { + // sort children to minimize total cost of OR flow + [[no_unique_address]] ADAPTER adapter; + MinOrCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {} + bool operator () (const auto &a, const auto &b) const noexcept { + return adapter.estimate(a) * adapter.cost(b) > adapter.estimate(b) * adapter.cost(a); + } +}; + +template <typename ADAPTER> +struct MinOrStrictCost { + // sort children to minimize total cost of strict OR flow + [[no_unique_address]] ADAPTER adapter; + MinOrStrictCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {} + bool operator () (const auto &a, const auto &b) const noexcept { + return adapter.estimate(a) * adapter.strict_cost(b) > adapter.estimate(b) * adapter.strict_cost(a); + } +}; + +template <typename ADAPTER, typename T, typename F> +double estimate_of(ADAPTER adapter, const T &children, F flow) { + for (const auto &child: children) { + flow.add(adapter.estimate(child)); + } + return flow.estimate(); +} + +template <template <typename> typename ORDER, typename ADAPTER, typename T> +void sort(ADAPTER adapter, T &children) { + std::sort(children.begin(), children.end(), ORDER(adapter)); +} + +template <template <typename> typename ORDER, typename ADAPTER, typename T> +void sort_partial(ADAPTER adapter, T &children, size_t offset) { + if (children.size() > offset) { + std::sort(children.begin() + offset, children.end(), ORDER(adapter)); + } +} + +template <typename ADAPTER, typename T, typename F> +double ordered_cost_of(ADAPTER adapter, const T &children, F flow) { + double cost = 0.0; + for (const auto &child: children) { + double child_cost = flow.strict() ? adapter.strict_cost(child) : adapter.cost(child); + cost += flow.flow() * child_cost; + flow.add(adapter.estimate(child)); + } + return cost; +} + +template <typename ADAPTER, typename T> +size_t select_strict_and_child(ADAPTER adapter, const T &children) { + size_t idx = 0; + double cost = 0.0; + size_t best_idx = 0; + double best_diff = 0.0; + double est = 1.0; + for (const auto &child: children) { + double child_cost = est * adapter.cost(child); + double child_strict_cost = adapter.strict_cost(child); + double child_est = adapter.estimate(child); + if (idx == 0) { + best_diff = child_strict_cost - child_cost; + } else { + double my_diff = (child_strict_cost + child_est * cost) - (cost + child_cost); + if (my_diff < best_diff) { + best_diff = my_diff; + best_idx = idx; + } + } + cost += child_cost; + est *= child_est; + ++idx; + } + return best_idx; +} + +} // flow + +template <typename FLOW> +struct FlowMixin { + static double estimate_of(auto adapter, const auto &children) { + return flow::estimate_of(adapter, children, FLOW(1.0, false)); + } + static double estimate_of(const auto &children) { + return estimate_of(flow::DefaultAdapter(), children); + } + static double cost_of(auto adapter, const auto &children, bool strict) { + auto my_adapter = flow::IndirectAdapter(adapter, children); + auto order = flow::make_index(children); + FLOW::sort(my_adapter, order, strict); + return flow::ordered_cost_of(my_adapter, order, FLOW(1.0, strict)); + } + static double cost_of(const auto &children, bool strict) { + return cost_of(flow::DefaultAdapter(), children, strict); + } + // TODO: remove + static double cost_of(const auto &children) { return cost_of(children, false); } +}; + +class AndFlow : public FlowMixin<AndFlow> { private: double _flow; - size_t _cnt; + bool _strict; + bool _first; public: - AndFlow(double in = 1.0) noexcept : _flow(in), _cnt(0) {} + AndFlow(double in, bool strict) noexcept + : _flow(in), _strict(strict), _first(true) {} void add(double est) noexcept { _flow *= est; - ++_cnt; + _first = false; } double flow() const noexcept { return _flow; } + bool strict() const noexcept { + return _strict && _first; + } double estimate() const noexcept { - return (_cnt > 0) ? _flow : 0.0; + return _first ? 0.0 : _flow; + } + static void sort(auto adapter, auto &children, bool strict) { + flow::sort<flow::MinAndCost>(adapter, children); + if (strict && children.size() > 1) { + size_t idx = flow::select_strict_and_child(adapter, children); + auto the_one = std::move(children[idx]); + for (; idx > 0; --idx) { + children[idx] = std::move(children[idx-1]); + } + children[0] = std::move(the_one); + } + } + // TODO: add strict + static void sort(auto &children) { + sort(flow::DefaultAdapter(), children, false); } }; -class OrFlow { +class OrFlow : public FlowMixin<OrFlow>{ private: double _flow; + bool _strict; + bool _first; public: - OrFlow(double in = 1.0) noexcept : _flow(in) {} + OrFlow(double in, bool strict) noexcept + : _flow(in), _strict(strict), _first(true) {} void add(double est) noexcept { _flow *= (1.0 - est); + _first = false; } double flow() const noexcept { return _flow; } + bool strict() const noexcept { + return _strict; + } double estimate() const noexcept { - return (1.0 - _flow); + return _first ? 0.0 : (1.0 - _flow); + } + static void sort(auto adapter, auto &children, bool strict) { + if (strict) { + flow::sort<flow::MinOrStrictCost>(adapter, children); + } else { + flow::sort<flow::MinOrCost>(adapter, children); + } + } + // TODO: add strict + static void sort(auto &children) { + sort(flow::DefaultAdapter(), children, false); } }; -class AndNotFlow { +class AndNotFlow : public FlowMixin<AndNotFlow> { private: double _flow; - size_t _cnt; + bool _strict; + bool _first; public: - AndNotFlow(double in = 1.0) noexcept : _flow(in), _cnt(0) {} + AndNotFlow(double in, bool strict) noexcept + : _flow(in), _strict(strict), _first(true) {} void add(double est) noexcept { - _flow *= (_cnt++ == 0) ? est : (1.0 - est); + _flow *= _first ? est : (1.0 - est); + _first = false; } double flow() const noexcept { return _flow; } + bool strict() const noexcept { + return _strict && _first; + } double estimate() const noexcept { - return (_cnt > 0) ? _flow : 0.0; + return _first ? 0.0 : _flow; + } + static void sort(auto adapter, auto &children, bool) { + flow::sort_partial<flow::MinOrCost>(adapter, children, 1); + } + // TODO: add strict + static void sort(auto &children) { + sort(flow::DefaultAdapter(), children, false); } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp index bebc1f433f7..e60fe3d3f85 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp @@ -89,13 +89,13 @@ need_normal_features_for_children(const IntermediateBlueprint &blueprint, fef::M double AndNotBlueprint::calculate_cost() const { - return cost_of(get_children(), AndNotFlow()); + return AndNotFlow::cost_of(get_children()); } double AndNotBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), AndNotFlow()); + return AndNotFlow::estimate_of(get_children()); } Blueprint::HitEstimate @@ -168,10 +168,10 @@ AndNotBlueprint::get_replacement() void AndNotBlueprint::sort(Children &children, bool sort_by_cost) const { - if (children.size() > 2) { - if (sort_by_cost) { - std::sort(children.begin() + 1, children.end(), MinimalOrCost()); - } else { + if (sort_by_cost) { + AndNotFlow::sort(children); + } else { + if (children.size() > 2) { std::sort(children.begin() + 1, children.end(), TieredGreaterEstimate()); } } @@ -214,12 +214,12 @@ AndNotBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) co double AndBlueprint::calculate_cost() const { - return cost_of(get_children(), AndFlow()); + return AndFlow::cost_of(get_children()); } double AndBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), AndFlow()); + return AndFlow::estimate_of(get_children()); } Blueprint::HitEstimate @@ -265,7 +265,7 @@ void AndBlueprint::sort(Children &children, bool sort_by_cost) const { if (sort_by_cost) { - std::sort(children.begin(), children.end(), MinimalAndCost()); + AndFlow::sort(children); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -323,12 +323,12 @@ OrBlueprint::~OrBlueprint() = default; double OrBlueprint::calculate_cost() const { - return cost_of(get_children(), OrFlow()); + return OrFlow::cost_of(get_children()); } double OrBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), OrFlow()); + return OrFlow::estimate_of(get_children()); } Blueprint::HitEstimate @@ -376,7 +376,7 @@ void OrBlueprint::sort(Children &children, bool sort_by_cost) const { if (sort_by_cost) { - std::sort(children.begin(), children.end(), MinimalOrCost()); + OrFlow::sort(children); } else { std::sort(children.begin(), children.end(), TieredGreaterEstimate()); } @@ -428,12 +428,12 @@ WeakAndBlueprint::~WeakAndBlueprint() = default; double WeakAndBlueprint::calculate_cost() const { - return cost_of(get_children(), OrFlow()); + return OrFlow::cost_of(get_children()); } double WeakAndBlueprint::calculate_relative_estimate() const { - double child_est = estimate_of(get_children(), OrFlow()); + double child_est = OrFlow::estimate_of(get_children()); double my_est = abs_to_rel_est(_n, get_docid_limit()); return std::min(my_est, child_est); } @@ -499,12 +499,12 @@ WeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) c double NearBlueprint::calculate_cost() const { - return cost_of(get_children(), AndFlow()) + childCnt() * 1.0; + return AndFlow::cost_of(get_children()) + childCnt() * 1.0; } double NearBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), AndFlow()); + return AndFlow::estimate_of(get_children()); } Blueprint::HitEstimate @@ -523,7 +523,7 @@ void NearBlueprint::sort(Children &children, bool sort_by_cost) const { if (sort_by_cost) { - std::sort(children.begin(), children.end(), MinimalAndCost()); + AndFlow::sort(children); } else { std::sort(children.begin(), children.end(), TieredLessEstimate()); } @@ -566,12 +566,12 @@ NearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons double ONearBlueprint::calculate_cost() const { - return cost_of(get_children(), AndFlow()) + (childCnt() * 1.0); + return AndFlow::cost_of(get_children()) + (childCnt() * 1.0); } double ONearBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), AndFlow()); + return AndFlow::estimate_of(get_children()); } Blueprint::HitEstimate @@ -741,7 +741,7 @@ SourceBlenderBlueprint::calculate_cost() const { double SourceBlenderBlueprint::calculate_relative_estimate() const { - return estimate_of(get_children(), OrFlow()); + return OrFlow::estimate_of(get_children()); } Blueprint::HitEstimate |