diff options
Diffstat (limited to 'eval/src/vespa/eval/eval/interpreted_function.h')
-rw-r--r-- | eval/src/vespa/eval/eval/interpreted_function.h | 33 |
1 files changed, 26 insertions, 7 deletions
diff --git a/eval/src/vespa/eval/eval/interpreted_function.h b/eval/src/vespa/eval/eval/interpreted_function.h index 4d4c77f1116..bd54cccb7b6 100644 --- a/eval/src/vespa/eval/eval/interpreted_function.h +++ b/eval/src/vespa/eval/eval/interpreted_function.h @@ -8,6 +8,7 @@ #include "lazy_params.h" #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/time.h> +#include <functional> namespace vespalib::eval { @@ -99,7 +100,26 @@ public: } static Instruction nop(); }; - + + class Options { + public: + using optimize_fun_t = std::function<const TensorFunction &(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash)>; + private: + const ValueBuilderFactory &_factory; + optimize_fun_t _optimize; + CTFMetaData *_meta; + public: + Options(const ValueBuilderFactory &factory_in); + const ValueBuilderFactory &factory() const noexcept { return _factory; } + Options &meta(CTFMetaData *value) noexcept { _meta = value; return *this; } + CTFMetaData *meta() const noexcept { return _meta; } + Options &optimize(optimize_fun_t value) { _optimize = std::move(value); return *this; } + optimize_fun_t optimize() const { return _optimize; } + InterpretedFunction make(const nodes::Node &root, const NodeTypes &types) const { + return InterpretedFunction(*this, root, types); + } + }; + private: std::vector<Instruction> _program; Stash _stash; @@ -107,13 +127,12 @@ private: public: using UP = std::unique_ptr<InterpretedFunction>; - // for testing; use with care; the tensor function must be kept alive - InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function, CTFMetaData *meta); - InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function) - : InterpretedFunction(factory, function, nullptr) {} - InterpretedFunction(const ValueBuilderFactory &factory, const nodes::Node &root, const NodeTypes &types); + static Options opts(const ValueBuilderFactory &factory) { return Options(factory); } + // for testing; make sure to keep tensor function alive + InterpretedFunction(const ValueBuilderFactory &factory, const TensorFunction &function); + InterpretedFunction(const Options &opts, const nodes::Node &root, const NodeTypes &types); InterpretedFunction(const ValueBuilderFactory &factory, const Function &function, const NodeTypes &types) - : InterpretedFunction(factory, function.root(), types) {} + : InterpretedFunction(opts(factory), function.root(), types) {} InterpretedFunction(InterpretedFunction &&rhs) = default; ~InterpretedFunction(); size_t program_size() const { return _program.size(); } |