diff options
Diffstat (limited to 'eval/src/tests/instruction')
3 files changed, 31 insertions, 25 deletions
diff --git a/eval/src/tests/instruction/dense_replace_type_function/dense_replace_type_function_test.cpp b/eval/src/tests/instruction/dense_replace_type_function/dense_replace_type_function_test.cpp index 0cb5a821136..748f38a3343 100644 --- a/eval/src/tests/instruction/dense_replace_type_function/dense_replace_type_function_test.cpp +++ b/eval/src/tests/instruction/dense_replace_type_function/dense_replace_type_function_test.cpp @@ -22,7 +22,7 @@ struct ChildMock : Leaf { bool is_mutable; ChildMock(const ValueType &type) : Leaf(type), is_mutable(true) {} bool result_is_mutable() const override { return is_mutable; } - InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &, Stash &) const override { abort(); } + InterpretedFunction::Instruction compile_self(const CTFContext &) const override { abort(); } }; struct Fixture { @@ -42,7 +42,7 @@ struct Fixture { { my_fun.push_children(children); state.stack.push_back(*my_value); - my_fun.compile_self(prod_factory, state.stash).perform(state); + my_fun.compile_self(CTFContext(prod_factory, state.stash, nullptr)).perform(state); ASSERT_EQUAL(children.size(), 1u); ASSERT_EQUAL(state.stack.size(), 1u); ASSERT_TRUE(!new_type.is_error()); diff --git a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp index 7f25edd50c8..df1661cc57f 100644 --- a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp +++ b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp @@ -115,7 +115,7 @@ TensorSpec tensor_function_peek(const TensorSpec &a, const ValueType &result_typ } const auto &func_param = tensor_function::inject(param->type(), 0, stash); const auto &peek_node = tensor_function::peek(func_param, func_spec, stash); - auto my_op = peek_node.compile_self(factory, stash); + auto my_op = peek_node.compile_self(CTFContext(factory, stash, nullptr)); InterpretedFunction::EvalSingle single(factory, my_op); return spec_from_value(single.eval(my_stack)); } diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp index 99809601d9a..6b72dd9ca06 100644 --- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp +++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp @@ -75,6 +75,7 @@ class Optimize private: struct ctor_tag{}; public: + using optimize_fun_t = InterpretedFunction::Options::optimize_fun_t; enum class With { NONE, CUSTOM, PROD, SPECIFIC }; With with; vespalib::string name; @@ -92,6 +93,29 @@ public: static Optimize specific(const vespalib::string &name_in, tensor_function_optimizer optimizer_in) { return {With::SPECIFIC, name_in, {}, optimizer_in, {}}; } + optimize_fun_t make_optimize_fun() const { + switch (with) { + case Optimize::With::NONE: return do_not_optimize_tensor_function; + case Optimize::With::PROD: return optimize_tensor_function; + case Optimize::With::CUSTOM: + return [options=options](const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash) + ->const TensorFunction & + { + return optimize_tensor_function_impl(factory, function, stash, options); + }; + case Optimize::With::SPECIFIC: + return [optimizer=optimizer](const ValueBuilderFactory &, const TensorFunction &function, Stash &stash) + ->const TensorFunction & + { + size_t count = 0; + const auto &result = apply_tensor_function_optimizer(function, optimizer, stash, + [&count](const auto &)noexcept{ ++count; }); + EXPECT_EQ(count, 1); + return result; + }; + } + abort(); + } ~Optimize(); }; Optimize::~Optimize() = default; @@ -201,29 +225,11 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) { } NodeTypes node_types(*fun, param_types); ASSERT_FALSE(node_types.get_type(fun->root()).is_error()); - Stash stash; - const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash); - const TensorFunction *optimized = nullptr; - switch (optimize.with) { - case Optimize::With::NONE: - optimized = std::addressof(plain_fun); - break; - case Optimize::With::PROD: - optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash)); - break; - case Optimize::With::CUSTOM: - optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash, optimize.options)); - break; - case Optimize::With::SPECIFIC: - size_t count = 0; - optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, - [&count](const auto &)noexcept{ ++count; })); - ASSERT_EQ(count, 1); - break; - } - ASSERT_NE(optimized, nullptr); CTFMetaData ctf_meta; - InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta); + auto ifun = InterpretedFunction::opts(prod_factory) + .optimize(optimize.make_optimize_fun()) + .meta(&ctf_meta) + .make(fun->root(), node_types); InterpretedFunction::ProfiledContext pctx(ifun); ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size()); std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero()); |