summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/query/streaming_query_test.cpp43
-rw-r--r--searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp4
-rw-r--r--searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp2
-rw-r--r--searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp256
-rw-r--r--searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp3
-rw-r--r--searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp4
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h (renamed from searchlib/src/vespa/searchlib/attribute/multi_term_filter.h)8
-rw-r--r--searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp (renamed from searchlib/src/vespa/searchlib/attribute/multi_term_filter.hpp)14
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt2
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/query.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/query.h3
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.cpp24
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynode.h1
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp18
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h3
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/queryterm.h2
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp27
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/regexp_term.h25
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp53
-rw-r--r--searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h20
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/blueprint.h30
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/flow.h231
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp40
25 files changed, 689 insertions, 138 deletions
diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp
index 7c4b7555158..fe6149e6fba 100644
--- a/searchlib/src/tests/query/streaming_query_test.cpp
+++ b/searchlib/src/tests/query/streaming_query_test.cpp
@@ -7,6 +7,7 @@
#include <vespa/searchlib/query/streaming/query.h>
#include <vespa/searchlib/query/streaming/nearest_neighbor_query_node.h>
#include <vespa/searchlib/query/streaming/wand_term.h>
+#include <vespa/searchlib/query/streaming/weighted_set_term.h>
#include <vespa/searchlib/query/tree/querybuilder.h>
#include <vespa/searchlib/query/tree/simplequery.h>
#include <vespa/searchlib/query/tree/stackdumpcreator.h>
@@ -1020,6 +1021,48 @@ TEST(StreamingQueryTest, wand_term)
check_wand_term(exp_wand_score_field_11 + 1, "hidden score below limit");
}
+TEST(StreamingQueryTest, weighted_set_term)
+{
+ search::streaming::WeightedSetTerm term({}, "index", 2);
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "7", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(4));
+ term.add_term(std::make_unique<QueryTerm>(std::unique_ptr<QueryNodeResultBase>(), "9", "", QueryTermSimple::Type::WORD));
+ term.get_terms().back()->setWeight(Weight(13));
+ EXPECT_EQ(2, term.get_terms().size());
+ SimpleTermData td;
+ /*
+ * Search in fields 10, 11 and 12 (cf. fieldset in schema).
+ * Fields 11 and 12 have content for doc containing the keys.
+ * Fields 10 and 12 have valid handles and can be used for ranking.
+ * Field 11 does not have a valid handle, thus no associated match data.
+ */
+ td.addField(10);
+ td.addField(11);
+ td.addField(12);
+ td.lookupField(10)->setHandle(0);
+ td.lookupField(12)->setHandle(1);
+ EXPECT_FALSE(term.evaluate());
+ auto& q0 = *term.get_terms()[0];
+ q0.add(0, 11, 0, 10);
+ q0.add(0, 12, 0, 10);
+ auto& q1 = *term.get_terms()[1];
+ q1.add(0, 11, 0, 10);
+ q1.add(0, 12, 0, 10);
+ EXPECT_TRUE(term.evaluate());
+ MatchData md(MatchData::params().numTermFields(2));
+ term.unpack_match_data(23, td, md);
+ auto tmd0 = md.resolveTermField(0);
+ EXPECT_NE(23, tmd0->getDocId());
+ auto tmd1 = md.resolveTermField(1);
+ EXPECT_EQ(23, tmd1->getDocId());
+ using Weights = std::vector<int32_t>;
+ Weights weights;
+ for (auto& pos : *tmd1) {
+ weights.emplace_back(pos.getElementWeight());
+ }
+ EXPECT_EQ((Weights{13, 4}), weights);
+}
+
TEST(StreamingQueryTest, control_the_size_of_query_terms)
{
EXPECT_EQ(112u, sizeof(QueryTermSimple));
diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp
index f800e124bdc..bbd2744119a 100644
--- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp
+++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp
@@ -24,10 +24,10 @@ class MyOr : public IntermediateBlueprint
private:
public:
double calculate_cost() const final {
- return cost_of(get_children(), OrFlow());
+ return OrFlow::cost_of(get_children());
}
double calculate_relative_estimate() const final {
- return estimate_of(get_children(), OrFlow());
+ return OrFlow::estimate_of(get_children());
}
HitEstimate combine(const std::vector<HitEstimate> &data) const override {
return max(data);
diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
index ab1c004c721..856ac2391f8 100644
--- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
+++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp
@@ -1380,7 +1380,7 @@ TEST("cost for ONEAR") {
}
TEST("cost for WEAKAND") {
- verify_cost(make::WEAKAND(1000), calc_cost({{1.1, 0.8},{1.2, 0.7},{1.3, 0.5}}));
+ verify_cost(make::WEAKAND(1000), calc_cost({{1.3, 0.5},{1.2, 0.7},{1.1, 0.8}}));
}
TEST_MAIN() { TEST_DEBUG("lhs.out", "rhs.out"); TEST_RUN_ALL(); }
diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
index ceda30f169a..9a9adeac2bc 100644
--- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
+++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
@@ -5,44 +5,46 @@
#include <vector>
#include <random>
-using search::queryeval::AndFlow;
-using search::queryeval::OrFlow;
+constexpr size_t loop_cnt = 64;
+
+using namespace search::queryeval;
+
+struct ItemAdapter {
+ double estimate(const auto &child) const noexcept { return child.rel_est; }
+ double cost(const auto &child) const noexcept { return child.cost; }
+ double strict_cost(const auto &child) const noexcept { return child.strict_cost; }
+};
struct Item {
double rel_est;
double cost;
- Item(double rel_est_in, double cost_in) noexcept
- : rel_est(rel_est_in), cost(cost_in) {}
- static void sort_for_and(std::vector<Item> &data) {
- std::sort(data.begin(), data.end(), [](const Item &a, const Item &b) noexcept {
- return (1.0 - a.rel_est) / a.cost > (1.0 - b.rel_est) / b.cost;
- });
+ double strict_cost;
+ Item(double rel_est_in, double cost_in, double strict_cost_in) noexcept
+ : rel_est(rel_est_in), cost(cost_in), strict_cost(strict_cost_in) {}
+ template <typename FLOW> static double estimate_of(std::vector<Item> &data) {
+ return FLOW::estimate_of(ItemAdapter(), data);
}
- static void sort_for_or(std::vector<Item> &data) {
- std::sort(data.begin(), data.end(), [](const Item &a, const Item &b) noexcept {
- return a.rel_est / a.cost > b.rel_est / b.cost;
- });
+ template <typename FLOW> static void sort(std::vector<Item> &data, bool strict) {
+ FLOW::sort(ItemAdapter(), data, strict);
}
- static double cost_of(const std::vector<Item> &data, auto flow) {
- double cost = 0.0;
- for (const Item &item: data) {
- cost += flow.flow() * item.cost;
- flow.add(item.rel_est);
- }
- return cost;
+ template <typename FLOW> static double cost_of(const std::vector<Item> &data, bool strict) {
+ return FLOW::cost_of(ItemAdapter(), data, strict);
+ }
+ template <typename FLOW> static double ordered_cost_of(const std::vector<Item> &data, bool strict) {
+ return flow::ordered_cost_of(ItemAdapter(), data, FLOW(1.0, strict));
}
- static double cost_of_and(const std::vector<Item> &data) { return cost_of(data, AndFlow()); }
- static double cost_of_or(const std::vector<Item> &data) { return cost_of(data, OrFlow()); }
+ auto operator <=>(const Item &rhs) const noexcept = default;
};
std::vector<Item> gen_data(size_t size) {
static std::mt19937 gen;
- static std::uniform_real_distribution<double> rel_est(0.1, 0.9);
- static std::uniform_real_distribution<double> cost(1.0, 10.0);
+ static std::uniform_real_distribution<double> rel_est(0.1, 0.9);
+ static std::uniform_real_distribution<double> cost(1.0, 10.0);
+ static std::uniform_real_distribution<double> strict_cost(0.1, 5.0);
std::vector<Item> result;
result.reserve(size);
for (size_t i = 0; i < size; ++i) {
- result.emplace_back(rel_est(gen), cost(gen));
+ result.emplace_back(rel_est(gen), cost(gen), strict_cost(gen));
}
return result;
}
@@ -80,37 +82,191 @@ TEST(FlowTest, perm_test) {
EXPECT_EQ(seen.size(), 120);
}
+template <template <typename> typename ORDER>
+void verify_ordering_is_strict_weak() {
+ auto cmp = ORDER(ItemAdapter());
+ auto input = gen_data(7);
+ input.emplace_back(0.5, 1.5, 0.5);
+ input.emplace_back(0.5, 1.5, 0.5);
+ input.emplace_back(0.5, 1.5, 0.5);
+ input.emplace_back(0.0, 1.5, 0.5);
+ input.emplace_back(0.0, 1.5, 0.5);
+ input.emplace_back(0.5, 0.0, 0.5);
+ input.emplace_back(0.5, 0.0, 0.5);
+ input.emplace_back(0.5, 1.5, 0.0);
+ input.emplace_back(0.5, 1.5, 0.0);
+ input.emplace_back(0.0, 0.0, 0.0);
+ input.emplace_back(0.0, 0.0, 0.0);
+ std::vector<Item> output;
+ for (const Item &in: input) {
+ EXPECT_FALSE(cmp(in, in)); // Irreflexivity
+ size_t out_idx = 0;
+ bool lower = false;
+ bool upper = false;
+ for (const Item &out: output) {
+ if (cmp(out, in)) {
+ EXPECT_FALSE(cmp(in, out)); // Antisymmetry
+ EXPECT_FALSE(lower); // Transitivity
+ EXPECT_FALSE(upper); // Transitivity
+ ++out_idx;
+ } else {
+ lower = true;
+ if (cmp(in, out)) {
+ upper = true;
+ } else {
+ EXPECT_FALSE(upper); // Transitivity
+ }
+ }
+ }
+ output.insert(output.begin() + out_idx, in);
+ }
+}
+
+TEST(FlowTest, and_ordering_is_strict_weak) {
+ verify_ordering_is_strict_weak<flow::MinAndCost>();
+}
+
+TEST(FlowTest, or_ordering_is_strict_weak) {
+ verify_ordering_is_strict_weak<flow::MinOrCost>();
+}
+
+TEST(FlowTest, strict_or_ordering_is_strict_weak) {
+ verify_ordering_is_strict_weak<flow::MinOrStrictCost>();
+}
+
+struct ExpectFlow {
+ double flow;
+ double est;
+ bool strict;
+};
+
+void verify_flow(auto flow, const std::vector<double> &est_list, const std::vector<ExpectFlow> &expect) {
+ ASSERT_EQ(est_list.size() + 1, expect.size());
+ for (size_t i = 0; i < expect.size(); ++i) {
+ EXPECT_DOUBLE_EQ(flow.flow(), expect[i].flow);
+ EXPECT_DOUBLE_EQ(flow.estimate(), expect[i].est);
+ EXPECT_EQ(flow.strict(), expect[i].strict);
+ if (i < est_list.size()) {
+ flow.add(est_list[i]);
+ }
+ }
+}
+
+TEST(FlowTest, basic_and_flow) {
+ for (double in: {1.0, 0.5, 0.25}) {
+ for (bool strict: {false, true}) {
+ verify_flow(AndFlow(in, strict), {0.4, 0.7, 0.2},
+ {{in, 0.0, strict},
+ {in*0.4, in*0.4, false},
+ {in*0.4*0.7, in*0.4*0.7, false},
+ {in*0.4*0.7*0.2, in*0.4*0.7*0.2, false}});
+ }
+ }
+}
+
+TEST(FlowTest, basic_or_flow) {
+ for (double in: {1.0, 0.5, 0.25}) {
+ for (bool strict: {false, true}) {
+ verify_flow(OrFlow(in, strict), {0.4, 0.7, 0.2},
+ {{in, 0.0, strict},
+ {in*0.6, 1.0-in*0.6, strict},
+ {in*0.6*0.3, 1.0-in*0.6*0.3, strict},
+ {in*0.6*0.3*0.8, 1.0-in*0.6*0.3*0.8, strict}});
+ }
+ }
+}
+
+TEST(FlowTest, basic_and_not_flow) {
+ for (double in: {1.0, 0.5, 0.25}) {
+ for (bool strict: {false, true}) {
+ verify_flow(AndNotFlow(in, strict), {0.4, 0.7, 0.2},
+ {{in, 0.0, strict},
+ {in*0.4, in*0.4, false},
+ {in*0.4*0.3, in*0.4*0.3, false},
+ {in*0.4*0.3*0.8, in*0.4*0.3*0.8, false}});
+ }
+ }
+}
+
+TEST(FlowTest, flow_cost) {
+ std::vector<Item> data = {{0.4, 1.1, 0.6}, {0.7, 1.2, 0.5}, {0.2, 1.3, 0.4}};
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3);
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<OrFlow>(data, true), 0.6 + 0.6*0.5 + 0.6*0.3*0.4);
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3);
+ EXPECT_DOUBLE_EQ(Item::ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3);
+}
+
TEST(FlowTest, optimal_and_flow) {
- for (size_t i = 0; i < 256; ++i) {
- auto data = gen_data(7);
- Item::sort_for_and(data);
- double min_cost = Item::cost_of_and(data);
- double max_cost = 0.0;
- auto check = [min_cost,&max_cost](const std::vector<Item> &my_data) noexcept {
- double my_cost = Item::cost_of_and(my_data);
- EXPECT_LE(min_cost, my_cost);
- max_cost = std::max(max_cost, my_cost);
- };
- each_perm(data, check);
- fprintf(stderr, " and cost(%zu): min: %g, max: %g, factor: %g\n",
- i, min_cost, max_cost, max_cost / min_cost);
+ for (size_t i = 0; i < loop_cnt; ++i) {
+ for (bool strict: {false, true}) {
+ auto data = gen_data(7);
+ double ref_est = Item::estimate_of<AndFlow>(data);
+ double min_cost = Item::cost_of<AndFlow>(data, strict);
+ double max_cost = 0.0;
+ Item::sort<AndFlow>(data, strict);
+ EXPECT_EQ(Item::ordered_cost_of<AndFlow>(data, strict), min_cost);
+ auto check = [&](const std::vector<Item> &my_data) noexcept {
+ double my_cost = Item::ordered_cost_of<AndFlow>(my_data, strict);
+ EXPECT_LE(min_cost, my_cost);
+ max_cost = std::max(max_cost, my_cost);
+ };
+ each_perm(data, check);
+ if (loop_cnt < 1024 || i % 1024 == 0) {
+ fprintf(stderr, " AND cost(%zu,%s): min: %g, max: %g, factor: %g\n",
+ i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost);
+ }
+ EXPECT_NEAR(ref_est, Item::estimate_of<AndFlow>(data), 1e-9);
+ }
}
}
TEST(FlowTest, optimal_or_flow) {
- for (size_t i = 0; i < 256; ++i) {
- auto data = gen_data(7);
- Item::sort_for_or(data);
- double min_cost = Item::cost_of_or(data);
- double max_cost = 0.0;
- auto check = [min_cost,&max_cost](const std::vector<Item> &my_data) noexcept {
- double my_cost = Item::cost_of_or(my_data);
- EXPECT_LE(min_cost, my_cost);
- max_cost = std::max(max_cost, my_cost);
- };
- each_perm(data, check);
- fprintf(stderr, " or cost(%zu): min: %g, max: %g, factor: %g\n",
- i, min_cost, max_cost, max_cost / min_cost);
+ for (size_t i = 0; i < loop_cnt; ++i) {
+ for (bool strict: {false, true}) {
+ auto data = gen_data(7);
+ double min_cost = Item::cost_of<OrFlow>(data, strict);
+ double max_cost = 0.0;
+ Item::sort<OrFlow>(data, strict);
+ EXPECT_EQ(Item::ordered_cost_of<OrFlow>(data, strict), min_cost);
+ auto check = [&](const std::vector<Item> &my_data) noexcept {
+ double my_cost = Item::ordered_cost_of<OrFlow>(my_data, strict);
+ EXPECT_LE(min_cost, my_cost);
+ max_cost = std::max(max_cost, my_cost);
+ };
+ each_perm(data, check);
+ if (loop_cnt < 1024 || i % 1024 == 0) {
+ fprintf(stderr, " OR cost(%zu,%s): min: %g, max: %g, factor: %g\n",
+ i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost);
+ }
+ }
+ }
+}
+
+TEST(FlowTest, optimal_and_not_flow) {
+ for (size_t i = 0; i < loop_cnt; ++i) {
+ for (bool strict: {false, true}) {
+ auto data = gen_data(7);
+ Item first = data[0];
+ double min_cost = Item::cost_of<AndNotFlow>(data, strict);
+ double max_cost = 0.0;
+ Item::sort<AndNotFlow>(data, strict);
+ EXPECT_EQ(data[0], first);
+ EXPECT_EQ(Item::ordered_cost_of<AndNotFlow>(data, strict), min_cost);
+ auto check = [&](const std::vector<Item> &my_data) noexcept {
+ if (my_data[0] == first) {
+ double my_cost = Item::ordered_cost_of<AndNotFlow>(my_data, strict);
+ EXPECT_LE(min_cost, my_cost);
+ max_cost = std::max(max_cost, my_cost);
+ }
+ };
+ each_perm(data, check);
+ if (loop_cnt < 1024 || i % 1024 == 0) {
+ fprintf(stderr, " ANDNOT cost(%zu,%s): min: %g, max: %g, factor: %g\n",
+ i, strict ? "strict" : "non-strict", min_cost, max_cost, max_cost / min_cost);
+ }
+ }
}
}
diff --git a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
index aa6d922f23f..a9f549a0bd9 100644
--- a/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
+++ b/searchlib/src/tests/queryeval/parallel_weak_and/parallel_weak_and_test.cpp
@@ -391,8 +391,11 @@ struct HeapFixture
SearchIterator::UP sb(spec.create());
result.search(*sb);
}
+ ~HeapFixture();
};
+HeapFixture::~HeapFixture() = default;
+
TEST(ParallelWeakAndTest, require_that_scores_are_collected_in_batches_before_adjusting_heap)
{
HeapFixture f;
diff --git a/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp b/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp
index 9409b2b26c4..689f9f085d0 100644
--- a/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp
+++ b/searchlib/src/tests/queryeval/weak_and/weak_and_test.cpp
@@ -37,8 +37,11 @@ struct SimpleWandFixture {
SearchIterator::UP search(spec.create());
hits.search(*search);
}
+ ~SimpleWandFixture();
};
+SimpleWandFixture::~SimpleWandFixture() = default;
+
struct AdvancedWandFixture {
MyWandSpec spec;
SimpleResult hits;
@@ -51,8 +54,11 @@ struct AdvancedWandFixture {
SearchIterator::UP search(spec.create());
hits.search(*search);
}
+ ~AdvancedWandFixture();
};
+AdvancedWandFixture::~AdvancedWandFixture() = default;
+
struct WeightOrder {
bool operator()(const wand::Term &t1, const wand::Term &t2) const {
return (t1.weight < t2.weight);
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp
index 99d3ba3f7aa..01148c11c9c 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp
@@ -1,7 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "attribute_weighted_set_blueprint.h"
-#include "multi_term_filter.hpp"
+#include "multi_term_hash_filter.hpp"
#include <vespa/searchcommon/attribute/i_search_context.h>
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/fef/matchdatalayout.h>
@@ -73,7 +73,7 @@ make_multi_term_filter(fef::TermFieldMatchData& tfmd,
const std::vector<int32_t>& weights,
const std::vector<ISearchContext*>& contexts)
{
- using FilterType = attribute::MultiTermFilter<WrapperType>;
+ using FilterType = attribute::MultiTermHashFilter<WrapperType>;
typename FilterType::TokenMap tokens;
WrapperType wrapper(attr);
for (size_t i = 0; i < contexts.size(); ++i) {
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.h b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h
index adbf37d2dcd..9c3ea258fdc 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.h
+++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.h
@@ -18,7 +18,7 @@ namespace search::attribute {
* @tparam WrapperType Type that wraps an attribute vector and provides access to the attribute value for a given docid.
*/
template <typename WrapperType>
-class MultiTermFilter final : public queryeval::SearchIterator
+class MultiTermHashFilter final : public queryeval::SearchIterator
{
public:
using Key = typename WrapperType::TokenT;
@@ -31,9 +31,9 @@ private:
int32_t _weight;
public:
- MultiTermFilter(fef::TermFieldMatchData& tfmd,
- WrapperType attr,
- TokenMap&& map);
+ MultiTermHashFilter(fef::TermFieldMatchData& tfmd,
+ WrapperType attr,
+ TokenMap&& map);
void and_hits_into(BitVector& result, uint32_t begin_id) override;
void doSeek(uint32_t docId) override;
diff --git a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.hpp b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp
index dc572aedbff..96d5b3ac1f3 100644
--- a/searchlib/src/vespa/searchlib/attribute/multi_term_filter.hpp
+++ b/searchlib/src/vespa/searchlib/attribute/multi_term_hash_filter.hpp
@@ -2,16 +2,16 @@
#pragma once
-#include "multi_term_filter.h"
+#include "multi_term_hash_filter.h"
#include <vespa/searchlib/common/bitvector.h>
#include <vespa/searchlib/fef/termfieldmatchdata.h>
namespace search::attribute {
template <typename WrapperType>
-MultiTermFilter<WrapperType>::MultiTermFilter(fef::TermFieldMatchData& tfmd,
- WrapperType attr,
- TokenMap&& map)
+MultiTermHashFilter<WrapperType>::MultiTermHashFilter(fef::TermFieldMatchData& tfmd,
+ WrapperType attr,
+ TokenMap&& map)
: _tfmd(tfmd),
_attr(attr),
_map(std::move(map)),
@@ -21,7 +21,7 @@ MultiTermFilter<WrapperType>::MultiTermFilter(fef::TermFieldMatchData& tfmd,
template <typename WrapperType>
void
-MultiTermFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id)
+MultiTermHashFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id)
{
auto end = _map.end();
result.foreach_truebit([&, end](uint32_t key) { if ( _map.find(_attr.getToken(key)) == end) { result.clearBit(key); }}, begin_id);
@@ -29,7 +29,7 @@ MultiTermFilter<WrapperType>::and_hits_into(BitVector& result, uint32_t begin_id
template <typename WrapperType>
void
-MultiTermFilter<WrapperType>::doSeek(uint32_t docId)
+MultiTermHashFilter<WrapperType>::doSeek(uint32_t docId)
{
auto pos = _map.find(_attr.getToken(docId));
if (pos != _map.end()) {
@@ -40,7 +40,7 @@ MultiTermFilter<WrapperType>::doSeek(uint32_t docId)
template <typename WrapperType>
void
-MultiTermFilter<WrapperType>::doUnpack(uint32_t docId)
+MultiTermHashFilter<WrapperType>::doUnpack(uint32_t docId)
{
_tfmd.reset(docId);
fef::TermFieldMatchDataPosition pos;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
index 9b53407aff5..05a75f4662e 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/query/streaming/CMakeLists.txt
@@ -10,5 +10,7 @@ vespa_add_library(searchlib_query_streaming OBJECT
querynoderesultbase.cpp
queryterm.cpp
wand_term.cpp
+ weighted_set_term.cpp
+ regexp_term.cpp
DEPENDS
)
diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.cpp b/searchlib/src/vespa/searchlib/query/streaming/query.cpp
index 3079ec31e8f..ca742aabe26 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/query.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/query.cpp
@@ -107,9 +107,7 @@ QueryConnector::create(ParseItem::ItemType type)
case search::ParseItem::ITEM_AND: return std::make_unique<AndQueryNode>();
case search::ParseItem::ITEM_OR:
case search::ParseItem::ITEM_WEAK_AND: return std::make_unique<OrQueryNode>();
- case search::ParseItem::ITEM_WEIGHTED_SET:
case search::ParseItem::ITEM_EQUIV: return std::make_unique<EquivQueryNode>();
- case search::ParseItem::ITEM_WAND: return std::make_unique<OrQueryNode>();
case search::ParseItem::ITEM_NOT: return std::make_unique<AndNotQueryNode>();
case search::ParseItem::ITEM_PHRASE: return std::make_unique<PhraseQueryNode>();
case search::ParseItem::ITEM_SAME_ELEMENT: return std::make_unique<SameElementQueryNode>();
diff --git a/searchlib/src/vespa/searchlib/query/streaming/query.h b/searchlib/src/vespa/searchlib/query/streaming/query.h
index 8befa2fe7fa..84c693b86d0 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/query.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/query.h
@@ -103,8 +103,7 @@ public:
EquivQueryNode() noexcept : OrQueryNode("EQUIV") { }
bool evaluate() const override;
bool isFlattenable(ParseItem::ItemType type) const override {
- return (type == ParseItem::ITEM_EQUIV) ||
- (type == ParseItem::ITEM_WEIGHTED_SET);
+ return (type == ParseItem::ITEM_EQUIV);
}
};
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
index c24f41d16cf..2ee515f062a 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp
@@ -2,10 +2,12 @@
#include "query.h"
#include "nearest_neighbor_query_node.h"
+#include "regexp_term.h"
#include <vespa/searchlib/parsequery/stackdumpiterator.h>
#include <vespa/searchlib/query/streaming/dot_product_term.h>
#include <vespa/searchlib/query/streaming/in_term.h>
#include <vespa/searchlib/query/streaming/wand_term.h>
+#include <vespa/searchlib/query/streaming/weighted_set_term.h>
#include <vespa/searchlib/query/tree/term_vector.h>
#include <charconv>
#include <vespa/log/log.h>
@@ -40,7 +42,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_OR:
case ParseItem::ITEM_WEAK_AND:
case ParseItem::ITEM_EQUIV:
- case ParseItem::ITEM_WEIGHTED_SET:
case ParseItem::ITEM_NOT:
case ParseItem::ITEM_PHRASE:
case ParseItem::ITEM_SAME_ELEMENT:
@@ -55,7 +56,6 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
nqn->distance(queryRep.getNearDistance());
}
if ((type == ParseItem::ITEM_WEAK_AND) ||
- (type == ParseItem::ITEM_WEIGHTED_SET) ||
(type == ParseItem::ITEM_SAME_ELEMENT))
{
qn->setIndex(queryRep.getIndexName());
@@ -146,7 +146,12 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
qn = std::make_unique<TrueNode>();
} else {
Normalizing normalize_mode = factory.normalizing_mode(ssIndex);
- auto qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode);
+ std::unique_ptr<QueryTerm> qt;
+ if (sTerm != TermType::REGEXP) {
+ qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm, normalize_mode);
+ } else {
+ qt = std::make_unique<RegexpTerm>(factory.create(), ssTerm, ssIndex, TermType::REGEXP, normalize_mode);
+ }
qt->setWeight(queryRep.GetWeight());
qt->setUniqueId(queryRep.getUniqueId());
if (qt->isFuzzy()) {
@@ -192,6 +197,9 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor
case ParseItem::ITEM_WAND:
qn = build_wand_term(factory, queryRep);
break;
+ case ParseItem::ITEM_WEIGHTED_SET:
+ qn = build_weighted_set_term(factory, queryRep);
+ break;
default:
skip_unknown(queryRep);
break;
@@ -270,6 +278,16 @@ QueryNode::build_wand_term(const QueryNodeResultFactory& factory, SimpleQuerySta
return wand;
}
+std::unique_ptr<QueryNode>
+QueryNode::build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep)
+{
+ auto ws = std::make_unique<WeightedSetTerm>(factory.create(), queryRep.getIndexName(), queryRep.getArity());
+ ws->setWeight(queryRep.GetWeight());
+ ws->setUniqueId(queryRep.getUniqueId());
+ populate_multi_term(factory.normalizing_mode(ws->index()), *ws, queryRep);
+ return ws;
+}
+
void
QueryNode::skip_unknown(SimpleQueryStackDumpIterator& queryRep)
{
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.h b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
index a0561b2e52e..454932c0a68 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynode.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.h
@@ -32,6 +32,7 @@ class QueryNode
static void populate_multi_term(Normalizing string_normalize_mode, MultiTerm& mt, SimpleQueryStackDumpIterator& queryRep);
static std::unique_ptr<QueryNode> build_dot_product_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static std::unique_ptr<QueryNode> build_wand_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
+ static std::unique_ptr<QueryNode> build_weighted_set_term(const QueryNodeResultFactory& factory, SimpleQueryStackDumpIterator& queryRep);
static void skip_unknown(SimpleQueryStackDumpIterator& queryRep);
public:
using UP = std::unique_ptr<QueryNode>;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp
index c58ec55de9f..d72a3371846 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.cpp
@@ -3,4 +3,22 @@
namespace search::streaming {
+namespace {
+
+const char* to_str(Normalizing norm) noexcept {
+ switch (norm) {
+ case Normalizing::NONE: return "NONE";
+ case Normalizing::LOWERCASE: return "LOWERCASE";
+ case Normalizing::LOWERCASE_AND_FOLD: return "LOWERCASE_AND_FOLD";
+ }
+ abort();
+}
+
+}
+
+std::ostream& operator<<(std::ostream& os, Normalizing n) {
+ os << to_str(n);
+ return os;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h
index 74f872ad187..83fb27794a3 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/querynoderesultbase.h
@@ -2,6 +2,7 @@
#pragma once
#include <vespa/vespalib/stllike/string.h>
+#include <iosfwd>
#include <memory>
namespace search::streaming {
@@ -24,6 +25,8 @@ enum class Normalizing {
LOWERCASE_AND_FOLD
};
+std::ostream& operator<<(std::ostream&, Normalizing);
+
class QueryNodeResultFactory {
public:
virtual ~QueryNodeResultFactory() = default;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
index 3950a179d67..3e05d381ee2 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
+++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp
@@ -179,4 +179,10 @@ QueryTerm::as_multi_term() noexcept
return nullptr;
}
+RegexpTerm*
+QueryTerm::as_regexp_term() noexcept
+{
+ return nullptr;
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
index 743998a630e..cd2bdd7eaec 100644
--- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
+++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h
@@ -13,6 +13,7 @@ namespace search::streaming {
class NearestNeighborQueryNode;
class MultiTerm;
+class RegexpTerm;
/**
This is a leaf in the Query tree. All terms are leafs.
@@ -93,6 +94,7 @@ public:
void setFuzzyPrefixLength(uint32_t fuzzyPrefixLength) { _fuzzyPrefixLength = fuzzyPrefixLength; }
virtual NearestNeighborQueryNode* as_nearest_neighbor_query_node() noexcept;
virtual MultiTerm* as_multi_term() noexcept;
+ virtual RegexpTerm* as_regexp_term() noexcept;
protected:
using QueryNodeResultBaseContainer = std::unique_ptr<QueryNodeResultBase>;
string _index;
diff --git a/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp
new file mode 100644
index 00000000000..4508caa7072
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.cpp
@@ -0,0 +1,27 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include "regexp_term.h"
+
+namespace search::streaming {
+
+using vespalib::Regex;
+
+namespace {
+
+constexpr Regex::Options normalize_mode_to_regex_opts(Normalizing norm) noexcept {
+ return ((norm == Normalizing::NONE)
+ ? Regex::Options::None
+ : Regex::Options::IgnoreCase);
+}
+
+}
+
+RegexpTerm::RegexpTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term,
+ const string& index, Type type, Normalizing normalizing)
+ : QueryTerm(std::move(result_base), term, index, type, normalizing),
+ _regexp(Regex::from_pattern({term.data(), term.size()}, normalize_mode_to_regex_opts(normalizing)))
+{
+}
+
+RegexpTerm::~RegexpTerm() = default;
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h
new file mode 100644
index 00000000000..96d14eeb0bd
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/regexp_term.h
@@ -0,0 +1,25 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#pragma once
+
+#include "queryterm.h"
+#include <vespa/vespalib/regex/regex.h>
+
+namespace search::streaming {
+
+/**
+ * Query term that matches fields using a regular expression, with case sensitivity
+ * controlled by the provided Normalizing mode.
+ */
+class RegexpTerm : public QueryTerm {
+ vespalib::Regex _regexp;
+public:
+ RegexpTerm(std::unique_ptr<QueryNodeResultBase> result_base, stringref term,
+ const string& index, Type type, Normalizing normalizing);
+ ~RegexpTerm() override;
+
+ RegexpTerm* as_regexp_term() noexcept override { return this; }
+
+ [[nodiscard]] const vespalib::Regex& regexp() const noexcept { return _regexp; }
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp
new file mode 100644
index 00000000000..90d0be5d43c
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.cpp
@@ -0,0 +1,53 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "weighted_set_term.h"
+#include <vespa/searchlib/fef/itermdata.h>
+#include <vespa/searchlib/fef/matchdata.h>
+#include <vespa/vespalib/stllike/hash_map.hpp>
+
+using search::fef::ITermData;
+using search::fef::MatchData;
+
+namespace search::streaming {
+
+WeightedSetTerm::WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string & index, uint32_t num_terms)
+ : MultiTerm(std::move(result_base), index, num_terms)
+{
+}
+
+WeightedSetTerm::~WeightedSetTerm() = default;
+
+void
+WeightedSetTerm::unpack_match_data(uint32_t docid, const ITermData& td, MatchData& match_data)
+{
+ vespalib::hash_map<uint32_t,std::vector<double>> scores;
+ HitList hl_store;
+ for (const auto& term : _terms) {
+ auto& hl = term->evaluateHits(hl_store);
+ for (auto& hit : hl) {
+ scores[hit.context()].emplace_back(term->weight().percent());
+ }
+ }
+ auto num_fields = td.numFields();
+ for (uint32_t field_idx = 0; field_idx < num_fields; ++field_idx) {
+ auto& tfd = td.field(field_idx);
+ auto field_id = tfd.getFieldId();
+ if (scores.contains(field_id)) {
+ auto handle = tfd.getHandle();
+ if (handle != fef::IllegalHandle) {
+ auto &field_scores = scores[field_id];
+ std::sort(field_scores.begin(), field_scores.end(), std::greater());
+ auto tmd = match_data.resolveTermField(tfd.getHandle());
+ tmd->setFieldId(field_id);
+ tmd->reset(docid);
+ for (auto& field_score : field_scores) {
+ fef::TermFieldMatchDataPosition pos;
+ pos.setElementWeight(field_score);
+ tmd->appendPosition(pos);
+ }
+ }
+ }
+ }
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h
new file mode 100644
index 00000000000..4473e0fa45b
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/query/streaming/weighted_set_term.h
@@ -0,0 +1,20 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include "multi_term.h"
+
+namespace search::streaming {
+
+/*
+ * A weighted set query term for streaming search.
+ */
+class WeightedSetTerm : public MultiTerm {
+ double _score_threshold;
+public:
+ WeightedSetTerm(std::unique_ptr<QueryNodeResultBase> result_base, const string& index, uint32_t num_terms);
+ ~WeightedSetTerm() override;
+ void unpack_match_data(uint32_t docid, const fef::ITermData& td, fef::MatchData& match_data) override;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h
index a78dd092f5a..d998c2e343e 100644
--- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h
@@ -144,22 +144,6 @@ public:
return (total_docs == 0) ? 0.0 : double(est) / double(total_docs);
}
- static double cost_of(const Children &children, auto flow) {
- double cost = 0.0;
- for (const auto &child: children) {
- cost += flow.flow() * child->cost();
- flow.add(child->estimate());
- }
- return cost;
- }
-
- static double estimate_of(const Children &children, auto flow) {
- for (const auto &child: children) {
- flow.add(child->estimate());
- }
- return flow.estimate();
- }
-
// utility that just takes maximum estimate
static HitEstimate max(const std::vector<HitEstimate> &data);
@@ -172,20 +156,6 @@ public:
// lower limit for docid_limit: max child estimate
static HitEstimate sat_sum(const std::vector<HitEstimate> &data, uint32_t docid_limit);
- // sort children to minimize total cost of OR flow
- struct MinimalOrCost {
- bool operator () (const auto &a, const auto &b) const noexcept {
- return a->estimate() / a->cost() > b->estimate() / b->cost();
- }
- };
-
- // sort children to minimize total cost of AND flow
- struct MinimalAndCost {
- bool operator () (const auto &a, const auto &b) const noexcept {
- return (1.0 - a->estimate()) / a->cost() > (1.0 - b->estimate()) / b->cost();
- }
- };
-
// utility to get the greater estimate to sort first, higher tiers last
struct TieredGreaterEstimate {
bool operator () (const auto &a, const auto &b) const noexcept {
diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h
index 36c0a259feb..86ce6f8b93b 100644
--- a/searchlib/src/vespa/searchlib/queryeval/flow.h
+++ b/searchlib/src/vespa/searchlib/queryeval/flow.h
@@ -2,60 +2,261 @@
#pragma once
#include <cstddef>
-
-namespace search::queryeval {
+#include <algorithm>
+#include <vespa/vespalib/util/small_vector.h>
// Model how boolean result decisions flow through intermediate nodes
// of different types based on relative estimates for sub-expressions
-class AndFlow {
+namespace search::queryeval {
+
+namespace flow {
+
+// the default adapter expects the shape of std::unique_ptr<Blueprint>
+// with respect to estimate, cost and (coming soon) strict_cost.
+struct DefaultAdapter {
+ double estimate(const auto &child) const noexcept { return child->estimate(); }
+ double cost(const auto &child) const noexcept { return child->cost(); }
+ // Estimate the per-document cost of strict evaluation of this
+ // child. This will typically be something like (estimate() *
+ // cost()) for leafs with posting lists. OR will aggregate strict
+ // cost by calculating the minimal OR flow of strict child
+ // costs. AND will aggregate strict cost by calculating the
+ // minimal AND flow where the cost of the first child is
+ // substituted by its strict cost. This value is currently not
+ // available in Blueprints.
+ double strict_cost(const auto &child) const noexcept { return child->cost(); }
+};
+
+template <typename ADAPTER, typename T>
+struct IndirectAdapter {
+ const T &data;
+ [[no_unique_address]] ADAPTER adapter;
+ IndirectAdapter(ADAPTER adapter_in, const T &data_in) noexcept
+ : data(data_in), adapter(adapter_in) {}
+ double estimate(size_t child) const noexcept { return adapter.estimate(data[child]); }
+ double cost(size_t child) const noexcept { return adapter.cost(data[child]); }
+ double strict_cost(size_t child) const noexcept { return adapter.strict_cost(data[child]); }
+};
+
+auto make_index(const auto &children) {
+ vespalib::SmallVector<uint32_t> index(children.size());
+ for (size_t i = 0; i < index.size(); ++i) {
+ index[i] = i;
+ }
+ return index;
+}
+
+template <typename ADAPTER>
+struct MinAndCost {
+ // sort children to minimize total cost of AND flow
+ [[no_unique_address]] ADAPTER adapter;
+ MinAndCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {}
+ bool operator () (const auto &a, const auto &b) const noexcept {
+ return (1.0 - adapter.estimate(a)) * adapter.cost(b) > (1.0 - adapter.estimate(b)) * adapter.cost(a);
+ }
+};
+
+template <typename ADAPTER>
+struct MinOrCost {
+ // sort children to minimize total cost of OR flow
+ [[no_unique_address]] ADAPTER adapter;
+ MinOrCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {}
+ bool operator () (const auto &a, const auto &b) const noexcept {
+ return adapter.estimate(a) * adapter.cost(b) > adapter.estimate(b) * adapter.cost(a);
+ }
+};
+
+template <typename ADAPTER>
+struct MinOrStrictCost {
+ // sort children to minimize total cost of strict OR flow
+ [[no_unique_address]] ADAPTER adapter;
+ MinOrStrictCost(ADAPTER adapter_in) noexcept : adapter(adapter_in) {}
+ bool operator () (const auto &a, const auto &b) const noexcept {
+ return adapter.estimate(a) * adapter.strict_cost(b) > adapter.estimate(b) * adapter.strict_cost(a);
+ }
+};
+
+template <typename ADAPTER, typename T, typename F>
+double estimate_of(ADAPTER adapter, const T &children, F flow) {
+ for (const auto &child: children) {
+ flow.add(adapter.estimate(child));
+ }
+ return flow.estimate();
+}
+
+template <template <typename> typename ORDER, typename ADAPTER, typename T>
+void sort(ADAPTER adapter, T &children) {
+ std::sort(children.begin(), children.end(), ORDER(adapter));
+}
+
+template <template <typename> typename ORDER, typename ADAPTER, typename T>
+void sort_partial(ADAPTER adapter, T &children, size_t offset) {
+ if (children.size() > offset) {
+ std::sort(children.begin() + offset, children.end(), ORDER(adapter));
+ }
+}
+
+template <typename ADAPTER, typename T, typename F>
+double ordered_cost_of(ADAPTER adapter, const T &children, F flow) {
+ double cost = 0.0;
+ for (const auto &child: children) {
+ double child_cost = flow.strict() ? adapter.strict_cost(child) : adapter.cost(child);
+ cost += flow.flow() * child_cost;
+ flow.add(adapter.estimate(child));
+ }
+ return cost;
+}
+
+template <typename ADAPTER, typename T>
+size_t select_strict_and_child(ADAPTER adapter, const T &children) {
+ size_t idx = 0;
+ double cost = 0.0;
+ size_t best_idx = 0;
+ double best_diff = 0.0;
+ double est = 1.0;
+ for (const auto &child: children) {
+ double child_cost = est * adapter.cost(child);
+ double child_strict_cost = adapter.strict_cost(child);
+ double child_est = adapter.estimate(child);
+ if (idx == 0) {
+ best_diff = child_strict_cost - child_cost;
+ } else {
+ double my_diff = (child_strict_cost + child_est * cost) - (cost + child_cost);
+ if (my_diff < best_diff) {
+ best_diff = my_diff;
+ best_idx = idx;
+ }
+ }
+ cost += child_cost;
+ est *= child_est;
+ ++idx;
+ }
+ return best_idx;
+}
+
+} // flow
+
+template <typename FLOW>
+struct FlowMixin {
+ static double estimate_of(auto adapter, const auto &children) {
+ return flow::estimate_of(adapter, children, FLOW(1.0, false));
+ }
+ static double estimate_of(const auto &children) {
+ return estimate_of(flow::DefaultAdapter(), children);
+ }
+ static double cost_of(auto adapter, const auto &children, bool strict) {
+ auto my_adapter = flow::IndirectAdapter(adapter, children);
+ auto order = flow::make_index(children);
+ FLOW::sort(my_adapter, order, strict);
+ return flow::ordered_cost_of(my_adapter, order, FLOW(1.0, strict));
+ }
+ static double cost_of(const auto &children, bool strict) {
+ return cost_of(flow::DefaultAdapter(), children, strict);
+ }
+ // TODO: remove
+ static double cost_of(const auto &children) { return cost_of(children, false); }
+};
+
+class AndFlow : public FlowMixin<AndFlow> {
private:
double _flow;
- size_t _cnt;
+ bool _strict;
+ bool _first;
public:
- AndFlow(double in = 1.0) noexcept : _flow(in), _cnt(0) {}
+ AndFlow(double in, bool strict) noexcept
+ : _flow(in), _strict(strict), _first(true) {}
void add(double est) noexcept {
_flow *= est;
- ++_cnt;
+ _first = false;
}
double flow() const noexcept {
return _flow;
}
+ bool strict() const noexcept {
+ return _strict && _first;
+ }
double estimate() const noexcept {
- return (_cnt > 0) ? _flow : 0.0;
+ return _first ? 0.0 : _flow;
+ }
+ static void sort(auto adapter, auto &children, bool strict) {
+ flow::sort<flow::MinAndCost>(adapter, children);
+ if (strict && children.size() > 1) {
+ size_t idx = flow::select_strict_and_child(adapter, children);
+ auto the_one = std::move(children[idx]);
+ for (; idx > 0; --idx) {
+ children[idx] = std::move(children[idx-1]);
+ }
+ children[0] = std::move(the_one);
+ }
+ }
+ // TODO: add strict
+ static void sort(auto &children) {
+ sort(flow::DefaultAdapter(), children, false);
}
};
-class OrFlow {
+class OrFlow : public FlowMixin<OrFlow>{
private:
double _flow;
+ bool _strict;
+ bool _first;
public:
- OrFlow(double in = 1.0) noexcept : _flow(in) {}
+ OrFlow(double in, bool strict) noexcept
+ : _flow(in), _strict(strict), _first(true) {}
void add(double est) noexcept {
_flow *= (1.0 - est);
+ _first = false;
}
double flow() const noexcept {
return _flow;
}
+ bool strict() const noexcept {
+ return _strict;
+ }
double estimate() const noexcept {
- return (1.0 - _flow);
+ return _first ? 0.0 : (1.0 - _flow);
+ }
+ static void sort(auto adapter, auto &children, bool strict) {
+ if (strict) {
+ flow::sort<flow::MinOrStrictCost>(adapter, children);
+ } else {
+ flow::sort<flow::MinOrCost>(adapter, children);
+ }
+ }
+ // TODO: add strict
+ static void sort(auto &children) {
+ sort(flow::DefaultAdapter(), children, false);
}
};
-class AndNotFlow {
+class AndNotFlow : public FlowMixin<AndNotFlow> {
private:
double _flow;
- size_t _cnt;
+ bool _strict;
+ bool _first;
public:
- AndNotFlow(double in = 1.0) noexcept : _flow(in), _cnt(0) {}
+ AndNotFlow(double in, bool strict) noexcept
+ : _flow(in), _strict(strict), _first(true) {}
void add(double est) noexcept {
- _flow *= (_cnt++ == 0) ? est : (1.0 - est);
+ _flow *= _first ? est : (1.0 - est);
+ _first = false;
}
double flow() const noexcept {
return _flow;
}
+ bool strict() const noexcept {
+ return _strict && _first;
+ }
double estimate() const noexcept {
- return (_cnt > 0) ? _flow : 0.0;
+ return _first ? 0.0 : _flow;
+ }
+ static void sort(auto adapter, auto &children, bool) {
+ flow::sort_partial<flow::MinOrCost>(adapter, children, 1);
+ }
+ // TODO: add strict
+ static void sort(auto &children) {
+ sort(flow::DefaultAdapter(), children, false);
}
};
diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp
index bebc1f433f7..e60fe3d3f85 100644
--- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp
@@ -89,13 +89,13 @@ need_normal_features_for_children(const IntermediateBlueprint &blueprint, fef::M
double
AndNotBlueprint::calculate_cost() const
{
- return cost_of(get_children(), AndNotFlow());
+ return AndNotFlow::cost_of(get_children());
}
double
AndNotBlueprint::calculate_relative_estimate() const
{
- return estimate_of(get_children(), AndNotFlow());
+ return AndNotFlow::estimate_of(get_children());
}
Blueprint::HitEstimate
@@ -168,10 +168,10 @@ AndNotBlueprint::get_replacement()
void
AndNotBlueprint::sort(Children &children, bool sort_by_cost) const
{
- if (children.size() > 2) {
- if (sort_by_cost) {
- std::sort(children.begin() + 1, children.end(), MinimalOrCost());
- } else {
+ if (sort_by_cost) {
+ AndNotFlow::sort(children);
+ } else {
+ if (children.size() > 2) {
std::sort(children.begin() + 1, children.end(), TieredGreaterEstimate());
}
}
@@ -214,12 +214,12 @@ AndNotBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) co
double
AndBlueprint::calculate_cost() const {
- return cost_of(get_children(), AndFlow());
+ return AndFlow::cost_of(get_children());
}
double
AndBlueprint::calculate_relative_estimate() const {
- return estimate_of(get_children(), AndFlow());
+ return AndFlow::estimate_of(get_children());
}
Blueprint::HitEstimate
@@ -265,7 +265,7 @@ void
AndBlueprint::sort(Children &children, bool sort_by_cost) const
{
if (sort_by_cost) {
- std::sort(children.begin(), children.end(), MinimalAndCost());
+ AndFlow::sort(children);
} else {
std::sort(children.begin(), children.end(), TieredLessEstimate());
}
@@ -323,12 +323,12 @@ OrBlueprint::~OrBlueprint() = default;
double
OrBlueprint::calculate_cost() const {
- return cost_of(get_children(), OrFlow());
+ return OrFlow::cost_of(get_children());
}
double
OrBlueprint::calculate_relative_estimate() const {
- return estimate_of(get_children(), OrFlow());
+ return OrFlow::estimate_of(get_children());
}
Blueprint::HitEstimate
@@ -376,7 +376,7 @@ void
OrBlueprint::sort(Children &children, bool sort_by_cost) const
{
if (sort_by_cost) {
- std::sort(children.begin(), children.end(), MinimalOrCost());
+ OrFlow::sort(children);
} else {
std::sort(children.begin(), children.end(), TieredGreaterEstimate());
}
@@ -428,12 +428,12 @@ WeakAndBlueprint::~WeakAndBlueprint() = default;
double
WeakAndBlueprint::calculate_cost() const {
- return cost_of(get_children(), OrFlow());
+ return OrFlow::cost_of(get_children());
}
double
WeakAndBlueprint::calculate_relative_estimate() const {
- double child_est = estimate_of(get_children(), OrFlow());
+ double child_est = OrFlow::estimate_of(get_children());
double my_est = abs_to_rel_est(_n, get_docid_limit());
return std::min(my_est, child_est);
}
@@ -499,12 +499,12 @@ WeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) c
double
NearBlueprint::calculate_cost() const {
- return cost_of(get_children(), AndFlow()) + childCnt() * 1.0;
+ return AndFlow::cost_of(get_children()) + childCnt() * 1.0;
}
double
NearBlueprint::calculate_relative_estimate() const {
- return estimate_of(get_children(), AndFlow());
+ return AndFlow::estimate_of(get_children());
}
Blueprint::HitEstimate
@@ -523,7 +523,7 @@ void
NearBlueprint::sort(Children &children, bool sort_by_cost) const
{
if (sort_by_cost) {
- std::sort(children.begin(), children.end(), MinimalAndCost());
+ AndFlow::sort(children);
} else {
std::sort(children.begin(), children.end(), TieredLessEstimate());
}
@@ -566,12 +566,12 @@ NearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons
double
ONearBlueprint::calculate_cost() const {
- return cost_of(get_children(), AndFlow()) + (childCnt() * 1.0);
+ return AndFlow::cost_of(get_children()) + (childCnt() * 1.0);
}
double
ONearBlueprint::calculate_relative_estimate() const {
- return estimate_of(get_children(), AndFlow());
+ return AndFlow::estimate_of(get_children());
}
Blueprint::HitEstimate
@@ -741,7 +741,7 @@ SourceBlenderBlueprint::calculate_cost() const {
double
SourceBlenderBlueprint::calculate_relative_estimate() const {
- return estimate_of(get_children(), OrFlow());
+ return OrFlow::estimate_of(get_children());
}
Blueprint::HitEstimate