aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/test/eval_fixture.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/eval/test/eval_fixture.cpp')
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp21
1 files changed, 14 insertions, 7 deletions
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 562e92a4767..601ac2f3cdc 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -77,6 +77,16 @@ const TensorFunction &maybe_patch(bool allow_mutable, const TensorFunction &plai
return apply_tensor_function_optimizer(plain_fun, optimizer, stash);
}
+auto make_optimize_fun(std::set<size_t> mutable_set, bool allow_mutable, bool optimized, const TensorFunction **root_out) {
+ return [=](const ValueBuilderFactory &factory, const TensorFunction &plain_function, Stash &stash)->const TensorFunction &{
+ const auto &patched_function = maybe_patch(allow_mutable, plain_function, mutable_set, stash);
+ const auto &optimized_root = optimized ? optimize_tensor_function(factory, patched_function, stash) : patched_function;
+ // will be kept alive by the static stash of the interpreted function
+ *root_out = &optimized_root;
+ return optimized_root;
+ };
+}
+
std::vector<Value::UP> make_params(const ValueBuilderFactory &factory, const Function &function,
const ParamRepo &param_repo)
{
@@ -97,7 +107,7 @@ std::vector<Value::CREF> get_refs(const std::vector<Value::UP> &values) {
return result;
}
-} // namespace vespalib::eval::test
+} // namespace vespalib::eval::test::<unnamed>
ParamRepo &
EvalFixture::ParamRepo::add(const vespalib::string &name, TensorSpec value)
@@ -173,14 +183,11 @@ EvalFixture::EvalFixture(const ValueBuilderFactory &factory,
bool optimized,
bool allow_mutable)
: _factory(factory),
- _stash(),
_function(verify_function(Function::parse(expr))),
_node_types(get_types(*_function, param_repo)),
- _mutable_set(get_mutable(*_function, param_repo)),
- _plain_tensor_function(make_tensor_function(_factory, _function->root(), _node_types, _stash)),
- _patched_tensor_function(maybe_patch(allow_mutable, _plain_tensor_function, _mutable_set, _stash)),
- _tensor_function(optimized ? optimize_tensor_function(_factory, _patched_tensor_function, _stash) : _patched_tensor_function),
- _ifun(_factory, _tensor_function),
+ _optimized_root(nullptr),
+ _my_optimize(make_optimize_fun(get_mutable(*_function, param_repo), allow_mutable, optimized, &_optimized_root)),
+ _ifun(InterpretedFunction::opts(_factory).optimize(_my_optimize), _function->root(), _node_types),
_ictx(_ifun),
_param_values(make_params(_factory, *_function, param_repo)),
_params(get_refs(_param_values)),