Diffstat (limited to 'eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp')
-rw-r--r--  eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp | 50
1 file changed, 28 insertions(+), 22 deletions(-)
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index 99809601d9a..6b72dd9ca06 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -75,6 +75,7 @@ class Optimize
private:
struct ctor_tag{};
public:
+ using optimize_fun_t = InterpretedFunction::Options::optimize_fun_t;
enum class With { NONE, CUSTOM, PROD, SPECIFIC };
With with;
vespalib::string name;
@@ -92,6 +93,29 @@ public:
static Optimize specific(const vespalib::string &name_in, tensor_function_optimizer optimizer_in) {
return {With::SPECIFIC, name_in, {}, optimizer_in, {}};
}
+ optimize_fun_t make_optimize_fun() const {
+ switch (with) {
+ case Optimize::With::NONE: return do_not_optimize_tensor_function;
+ case Optimize::With::PROD: return optimize_tensor_function;
+ case Optimize::With::CUSTOM:
+ return [options=options](const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash)
+ ->const TensorFunction &
+ {
+ return optimize_tensor_function_impl(factory, function, stash, options);
+ };
+ case Optimize::With::SPECIFIC:
+ return [optimizer=optimizer](const ValueBuilderFactory &, const TensorFunction &function, Stash &stash)
+ ->const TensorFunction &
+ {
+ size_t count = 0;
+ const auto &result = apply_tensor_function_optimizer(function, optimizer, stash,
+ [&count](const auto &)noexcept{ ++count; });
+ EXPECT_EQ(count, 1);
+ return result;
+ };
+ }
+ abort();
+ }
~Optimize();
};
Optimize::~Optimize() = default;
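
Note (not part of the diff): make_optimize_fun() maps each With variant to a callback matching InterpretedFunction::Options::optimize_fun_t, i.e. one that takes the value builder factory, the un-optimized tensor function and a Stash, and returns the (possibly) optimized tensor function. A rough sketch of invoking such a callback by hand, with the names (prod_factory, fun, node_types, optimize) borrowed from the benchmark code below; an illustration only, not part of the commit:

    Stash stash;
    const TensorFunction &plain = make_tensor_function(prod_factory, fun->root(), node_types, stash);
    auto optimize_fun = optimize.make_optimize_fun();
    // roughly what the builder in the last hunk does internally before compiling
    const TensorFunction &optimized = optimize_fun(prod_factory, plain, stash);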
@@ -201,29 +225,11 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
}
NodeTypes node_types(*fun, param_types);
ASSERT_FALSE(node_types.get_type(fun->root()).is_error());
- Stash stash;
- const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
- const TensorFunction *optimized = nullptr;
- switch (optimize.with) {
- case Optimize::With::NONE:
- optimized = std::addressof(plain_fun);
- break;
- case Optimize::With::PROD:
- optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash));
- break;
- case Optimize::With::CUSTOM:
- optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash, optimize.options));
- break;
- case Optimize::With::SPECIFIC:
- size_t count = 0;
- optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash,
- [&count](const auto &)noexcept{ ++count; }));
- ASSERT_EQ(count, 1);
- break;
- }
- ASSERT_NE(optimized, nullptr);
CTFMetaData ctf_meta;
- InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta);
+ auto ifun = InterpretedFunction::opts(prod_factory)
+ .optimize(optimize.make_optimize_fun())
+ .meta(&ctf_meta)
+ .make(fun->root(), node_types);
InterpretedFunction::ProfiledContext pctx(ifun);
ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size());
std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero());
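
Note (not part of the diff): the manual Stash / make_tensor_function / switch-on-optimize.with sequence is replaced by the InterpretedFunction::opts(...) builder, which is handed the callback produced by make_optimize_fun() and the CTFMetaData sink before compiling the expression. As a hedged sketch, building an un-optimized interpreted function with the same builder might look like this (names taken from the test above; an illustration, not part of the commit):

    CTFMetaData meta;
    auto plain_ifun = InterpretedFunction::opts(prod_factory)
                          .optimize(do_not_optimize_tensor_function) // same as Optimize::With::NONE
                          .meta(&meta)
                          .make(fun->root(), node_types);

Inside the SPECIFIC lambda the test now uses EXPECT_EQ where the old code used ASSERT_EQ, likely because gtest's ASSERT_* macros bail out with a plain return and therefore only compile in functions returning void, while the lambda must return a const TensorFunction &.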