summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2024-02-23 15:48:35 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2024-02-23 15:49:27 +0000
commit71800872d9292bfb9faf8684855a88dd6ee05f2a (patch)
tree9de8d3457086ca8ec856a159e17bb5339a470a58 /searchlib
parent16088f3e45877a3417a85ef31a8678aba00e56f7 (diff)
added generic type-erased flow class
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp49
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/flow.h50
2 files changed, 79 insertions, 20 deletions
diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
index 8b8b6c1282e..be70a037c98 100644
--- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
+++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp
@@ -14,6 +14,20 @@ double ordered_cost_of(const std::vector<FlowStats> &data, bool strict) {
return flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict));
}
+template <typename FLOW>
+double dual_ordered_cost_of(const std::vector<FlowStats> &data, bool strict) {
+ double result = flow::ordered_cost_of(flow::DirectAdapter(), data, FLOW(strict));
+ AnyFlow any_flow = AnyFlow::create<FLOW>(strict);
+ double total_cost = 0.0;
+ for (const auto &item: data) {
+ double child_cost = any_flow.strict() ? item.strict_cost : any_flow.flow() * item.cost;
+ any_flow.update_cost(total_cost, child_cost);
+ any_flow.add(item.estimate);
+ }
+ EXPECT_EQ(total_cost, result);
+ return result;
+}
+
std::vector<FlowStats> gen_data(size_t size) {
static std::mt19937 gen;
static std::uniform_real_distribution<double> estimate(0.1, 0.9);
@@ -115,13 +129,22 @@ struct ExpectFlow {
};
void verify_flow(auto flow, const std::vector<double> &est_list, const std::vector<ExpectFlow> &expect) {
+ FlowCalc calc = flow_calc<decltype(flow)>(InFlow(flow.strict(), flow.flow()));
+ AnyFlow any_flow = AnyFlow::create<decltype(flow)>(InFlow(flow.strict(), flow.flow()));
ASSERT_EQ(est_list.size() + 1, expect.size());
for (size_t i = 0; i < expect.size(); ++i) {
+ EXPECT_EQ(any_flow.flow(), flow.flow());
+ EXPECT_EQ(any_flow.estimate(), flow.estimate());
+ EXPECT_EQ(any_flow.strict(), flow.strict());
EXPECT_DOUBLE_EQ(flow.flow(), expect[i].flow);
EXPECT_DOUBLE_EQ(flow.estimate(), expect[i].est);
EXPECT_EQ(flow.strict(), expect[i].strict);
if (i < est_list.size()) {
+ EXPECT_EQ(calc(est_list[i]), flow.flow());
flow.add(est_list[i]);
+ any_flow.add(est_list[i]);
+ } else {
+ EXPECT_EQ(calc(0.5), flow.flow());
}
}
}
@@ -141,8 +164,6 @@ TEST(FlowTest, full_and_flow) {
{0.4, 0.4, false},
{0.4*0.7, 0.4*0.7, false},
{0.4*0.7*0.2, 0.4*0.7*0.2, false}});
- verify_flow_calc(flow_calc<AndFlow>(strict),
- {0.4, 0.7, 0.2}, {1.0, 0.4, 0.4*0.7, 0.4*0.7*0.2});
}
}
@@ -153,8 +174,6 @@ TEST(FlowTest, partial_and_flow) {
{in*0.4, in*0.4, false},
{in*0.4*0.7, in*0.4*0.7, false},
{in*0.4*0.7*0.2, in*0.4*0.7*0.2, false}});
- verify_flow_calc(flow_calc<AndFlow>(in),
- {0.4, 0.7, 0.2}, {in*1.0, in*0.4, in*0.4*0.7, in*0.4*0.7*0.2});
}
}
@@ -164,15 +183,11 @@ TEST(FlowTest, full_or_flow) {
{0.6, 1.0-0.6, false},
{0.6*0.3, 1.0-0.6*0.3, false},
{0.6*0.3*0.8, 1.0-0.6*0.3*0.8, false}});
- verify_flow_calc(flow_calc<OrFlow>(1.0),
- {0.4, 0.7, 0.2}, {1.0, 0.6, 0.6*0.3, 0.6*0.3*0.8});
verify_flow(OrFlow(true), {0.4, 0.7, 0.2},
{{1.0, 0.0, true},
{1.0, 1.0-0.6, true},
{1.0, 1.0-0.6*0.3, true},
{1.0, 1.0-0.6*0.3*0.8, true}});
- verify_flow_calc(flow_calc<OrFlow>(true),
- {0.4, 0.7, 0.2}, {1.0, 1.0, 1.0, 1.0});
}
TEST(FlowTest, partial_or_flow) {
@@ -182,8 +197,6 @@ TEST(FlowTest, partial_or_flow) {
{in*0.6, 1.0-in*0.6, false},
{in*0.6*0.3, 1.0-in*0.6*0.3, false},
{in*0.6*0.3*0.8, 1.0-in*0.6*0.3*0.8, false}});
- verify_flow_calc(flow_calc<OrFlow>(in),
- {0.4, 0.7, 0.2}, {in, in*0.6, in*0.6*0.3, in*0.6*0.3*0.8});
}
}
@@ -194,8 +207,6 @@ TEST(FlowTest, full_and_not_flow) {
{0.4, 0.4, false},
{0.4*0.3, 0.4*0.3, false},
{0.4*0.3*0.8, 0.4*0.3*0.8, false}});
- verify_flow_calc(flow_calc<AndNotFlow>(strict),
- {0.4, 0.7, 0.2}, {1.0, 0.4, 0.4*0.3, 0.4*0.3*0.8});
}
}
@@ -206,8 +217,6 @@ TEST(FlowTest, partial_and_not_flow) {
{in*0.4, in*0.4, false},
{in*0.4*0.3, in*0.4*0.3, false},
{in*0.4*0.3*0.8, in*0.4*0.3*0.8, false}});
- verify_flow_calc(flow_calc<AndNotFlow>(in),
- {0.4, 0.7, 0.2}, {in, in*0.4, in*0.4*0.3, in*0.4*0.3*0.8});
}
}
@@ -256,12 +265,12 @@ TEST(FlowTest, in_flow_strict_vs_rate_interaction) {
TEST(FlowTest, flow_cost) {
std::vector<FlowStats> data = {{0.4, 1.1, 0.6}, {0.7, 1.2, 0.5}, {0.2, 1.3, 0.4}};
- EXPECT_DOUBLE_EQ(ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3);
- EXPECT_DOUBLE_EQ(ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3);
- EXPECT_DOUBLE_EQ(ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3);
- EXPECT_DOUBLE_EQ(ordered_cost_of<OrFlow>(data, true), 0.6 + 0.5 + 0.4);
- EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3);
- EXPECT_DOUBLE_EQ(ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.7*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, false), 1.1 + 0.6*1.2 + 0.6*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<OrFlow>(data, true), 0.6 + 0.5 + 0.4);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, false), 1.1 + 0.4*1.2 + 0.4*0.3*1.3);
+ EXPECT_DOUBLE_EQ(dual_ordered_cost_of<AndNotFlow>(data, true), 0.6 + 0.4*1.2 + 0.4*0.3*1.3);
}
TEST(FlowTest, optimal_and_flow) {
diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h
index ade2516b509..9dd6d82a491 100644
--- a/searchlib/src/vespa/searchlib/queryeval/flow.h
+++ b/searchlib/src/vespa/searchlib/queryeval/flow.h
@@ -328,4 +328,54 @@ inline FlowCalc full_flow_calc(InFlow in_flow) {
return [flow](double) noexcept { return flow; };
}
+// type-erased flow wrapper
+class AnyFlow {
+private:
+ struct API {
+ virtual void add(double est) noexcept = 0;
+ virtual double flow() const noexcept = 0;
+ virtual bool strict() const noexcept = 0;
+ virtual double estimate() const noexcept = 0;
+ virtual void update_cost(double &total_cost, double child_cost) noexcept = 0;
+ virtual ~API() = default;
+ };
+ template <typename FLOW> struct Wrapper final : API {
+ FLOW _flow;
+ Wrapper(InFlow in_flow) noexcept : _flow(in_flow) {}
+ void add(double est) noexcept override { _flow.add(est); }
+ double flow() const noexcept override { return _flow.flow(); }
+ bool strict() const noexcept override { return _flow.strict(); }
+ double estimate() const noexcept override { return _flow.estimate(); }
+ void update_cost(double &total_cost, double child_cost) noexcept override { return _flow.update_cost(total_cost, child_cost); }
+ ~Wrapper() = default;
+ };
+ alignas(8) char _space[24];
+ API &api() noexcept { return *reinterpret_cast<API*>(_space); }
+ const API &api() const noexcept { return *reinterpret_cast<const API*>(_space); }
+ template <typename FLOW> struct type_tag{};
+ template <typename FLOW> AnyFlow(InFlow in_flow, type_tag<FLOW>) noexcept {
+ using stored_type = Wrapper<FLOW>;
+ static_assert(alignof(stored_type) <= alignof(_space));
+ static_assert(sizeof(stored_type) <= sizeof(_space));
+ stored_type *obj = ::new (static_cast<void*>(_space)) stored_type(in_flow);
+ API *upcasted = obj;
+ (void) upcasted;
+ assert(static_cast<void*>(upcasted) == static_cast<void*>(_space));
+ }
+public:
+ AnyFlow() = delete;
+ AnyFlow(AnyFlow &&) = delete;
+ AnyFlow(const AnyFlow &) = delete;
+ AnyFlow &operator=(AnyFlow &&) = delete;
+ AnyFlow &operator=(const AnyFlow &) = delete;
+ template <typename FLOW> static AnyFlow create(InFlow in_flow) noexcept {
+ return AnyFlow(in_flow, type_tag<FLOW>());
+ }
+ void add(double est) noexcept { api().add(est); }
+ double flow() const noexcept { return api().flow(); }
+ bool strict() const noexcept { return api().strict(); }
+ double estimate() const noexcept { return api().estimate(); }
+ void update_cost(double &total_cost, double child_cost) noexcept { api().update_cost(total_cost, child_cost); }
+};
+
}