aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/test
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/eval/test')
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp21
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.h10
2 files changed, 19 insertions, 12 deletions
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 562e92a4767..601ac2f3cdc 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -77,6 +77,16 @@ const TensorFunction &maybe_patch(bool allow_mutable, const TensorFunction &plai
return apply_tensor_function_optimizer(plain_fun, optimizer, stash);
}
+auto make_optimize_fun(std::set<size_t> mutable_set, bool allow_mutable, bool optimized, const TensorFunction **root_out) {
+ return [=](const ValueBuilderFactory &factory, const TensorFunction &plain_function, Stash &stash)->const TensorFunction &{
+ const auto &patched_function = maybe_patch(allow_mutable, plain_function, mutable_set, stash);
+ const auto &optimized_root = optimized ? optimize_tensor_function(factory, patched_function, stash) : patched_function;
+ // will be kept alive by the static stash of the interpreted function
+ *root_out = &optimized_root;
+ return optimized_root;
+ };
+}
+
std::vector<Value::UP> make_params(const ValueBuilderFactory &factory, const Function &function,
const ParamRepo &param_repo)
{
@@ -97,7 +107,7 @@ std::vector<Value::CREF> get_refs(const std::vector<Value::UP> &values) {
return result;
}
-} // namespace vespalib::eval::test
+} // namespace vespalib::eval::test::<unnamed>
ParamRepo &
EvalFixture::ParamRepo::add(const vespalib::string &name, TensorSpec value)
@@ -173,14 +183,11 @@ EvalFixture::EvalFixture(const ValueBuilderFactory &factory,
bool optimized,
bool allow_mutable)
: _factory(factory),
- _stash(),
_function(verify_function(Function::parse(expr))),
_node_types(get_types(*_function, param_repo)),
- _mutable_set(get_mutable(*_function, param_repo)),
- _plain_tensor_function(make_tensor_function(_factory, _function->root(), _node_types, _stash)),
- _patched_tensor_function(maybe_patch(allow_mutable, _plain_tensor_function, _mutable_set, _stash)),
- _tensor_function(optimized ? optimize_tensor_function(_factory, _patched_tensor_function, _stash) : _patched_tensor_function),
- _ifun(_factory, _tensor_function),
+ _optimized_root(nullptr),
+ _my_optimize(make_optimize_fun(get_mutable(*_function, param_repo), allow_mutable, optimized, &_optimized_root)),
+ _ifun(InterpretedFunction::opts(_factory).optimize(_my_optimize), _function->root(), _node_types),
_ictx(_ifun),
_param_values(make_params(_factory, *_function, param_repo)),
_params(get_refs(_param_values)),
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h
index 7e33b5417b6..ec3e05cc4dd 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.h
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.h
@@ -64,14 +64,13 @@ public:
};
private:
+ using optimize_fun_t = InterpretedFunction::Options::optimize_fun_t;
const ValueBuilderFactory &_factory;
- Stash _stash;
std::shared_ptr<Function const> _function;
NodeTypes _node_types;
std::set<size_t> _mutable_set;
- const TensorFunction &_plain_tensor_function;
- const TensorFunction &_patched_tensor_function;
- const TensorFunction &_tensor_function;
+ const TensorFunction *_optimized_root;
+ optimize_fun_t _my_optimize;
InterpretedFunction _ifun;
InterpretedFunction::Context _ictx;
std::vector<Value::UP> _param_values;
@@ -115,7 +114,8 @@ public:
template <typename T>
std::vector<const T *> find_all() const {
std::vector<const T *> list;
- find_all(_tensor_function, list);
+ REQUIRE(_optimized_root != nullptr);
+ find_all(*_optimized_root, list);
return list;
}
const Value &result_value() const { return _result_value; }