aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/eval/interpreted_function.h
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/vespa/eval/eval/interpreted_function.h')
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.h33
1 files changed, 26 insertions, 7 deletions
diff --git a/eval/src/vespa/eval/eval/interpreted_function.h b/eval/src/vespa/eval/eval/interpreted_function.h
index 4d4c77f1116..bd54cccb7b6 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.h
+++ b/eval/src/vespa/eval/eval/interpreted_function.h
@@ -8,6 +8,7 @@
#include "lazy_params.h"
#include <vespa/vespalib/util/stash.h>
#include <vespa/vespalib/util/time.h>
+#include <functional>
namespace vespalib::eval {
@@ -99,7 +100,26 @@ public:
}
static Instruction nop();
};
-
+
+ class Options {
+ public:
+ using optimize_fun_t = std::function<const TensorFunction &(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash)>;
+ private:
+ const ValueBuilderFactory &_factory;
+ optimize_fun_t _optimize;
+ CTFMetaData *_meta;
+ public:
+ Options(const ValueBuilderFactory &factory_in);
+ const ValueBuilderFactory &factory() const noexcept { return _factory; }
+ Options &meta(CTFMetaData *value) noexcept { _meta = value; return *this; }
+ CTFMetaData *meta() const noexcept { return _meta; }
+ Options &optimize(optimize_fun_t value) { _optimize = std::move(value); return *this; }
+ optimize_fun_t optimize() const { return _optimize; }
+ InterpretedFunction make(const nodes::Node &root, const NodeTypes &types) const {
+ return InterpretedFunction(*this, root, types);
+ }
+ };
+
private:
std::vector<Instruction> _program;
Stash _stash;
@@ -107,13 +127,12 @@ private:
public:
using UP = std::unique_ptr<InterpretedFunction>;
- // for testing; use with care; the tensor function must be kept alive
- InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function, CTFMetaData *meta);
- InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function)
- : InterpretedFunction(factory, function, nullptr) {}
- InterpretedFunction(const ValueBuilderFactory &factory, const nodes::Node &root, const NodeTypes &types);
+ static Options opts(const ValueBuilderFactory &factory) { return Options(factory); }
+ // for testing; make sure to keep tensor function alive
+ InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function);
+ InterpretedFunction(const Options &opts, const nodes::Node &root, const NodeTypes &types);
InterpretedFunction(const ValueBuilderFactory &factory, const Function &function, const NodeTypes &types)
- : InterpretedFunction(factory, function.root(), types) {}
+ : InterpretedFunction(opts(factory), function.root(), types) {}
InterpretedFunction(InterpretedFunction &&rhs) = default;
~InterpretedFunction();
size_t program_size() const { return _program.size(); }