diff options
author | Håvard Pettersen <havardpe@oath.com> | 2020-04-02 15:13:22 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2020-04-03 09:38:00 +0000 |
commit | 092186debc2bb188b6c13b7a2d438225bd94afa5 (patch) | |
tree | f7c8de5e25763abfbb5ebac8b5cd28429005461f /eval | |
parent | 0cdbde812e5dd775d8f76793a2cdadf0bce1d4ae (diff) |
delay preparing tensor lambda function for execution
This will allow implementation-specific tensor lambda optimizations to
look at the lambda function and perform appropriate optimizations
before it is converted to an interpreted function.
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/make_tensor_function.cpp | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 14 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.h | 20 |
3 files changed, 25 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index b264c3c2fe7..f503532c1f9 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -122,13 +122,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { } void make_lambda(const TensorLambda &node) { - InterpretedFunction my_fun(tensor_engine, node.lambda().root(), types); if (node.bindings().empty()) { NoParams no_bound_params; + InterpretedFunction my_fun(tensor_engine, node.lambda().root(), types); TensorSpec spec = tensor_function::Lambda::create_spec_impl(node.type(), no_bound_params, node.bindings(), my_fun); make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec))); } else { - stack.push_back(tensor_function::lambda(node.type(), node.bindings(), std::move(my_fun), stash)); + stack.push_back(tensor_function::lambda(node.type(), node.bindings(), node.lambda(), types.export_types(node.lambda().root()), stash)); } } diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 40c4c614f20..9b7968f5092 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -135,8 +135,8 @@ void op_tensor_create(State &state, uint64_t param) { } void op_tensor_lambda(State &state, uint64_t param) { - const Lambda &self = unwrap_param<Lambda>(param); - TensorSpec spec = self.create_spec(*state.params); + const Lambda::Self &self = unwrap_param<Lambda::Self>(param); + TensorSpec spec = self.parent.create_spec(*state.params, self.fun); const Value &result = *state.stash.create<Value::UP>(state.engine.from_spec(spec)); state.stack.emplace_back(result); } @@ -436,9 +436,11 @@ Lambda::create_spec_impl(const ValueType &type, const LazyParams ¶ms, const } InterpretedFunction::Instruction -Lambda::compile_self(const TensorEngine &, Stash &) const +Lambda::compile_self(const TensorEngine &engine, Stash &stash) const { - return Instruction(op_tensor_lambda, wrap_param<Lambda>(*this)); + InterpretedFunction fun(engine, _lambda->root(), _lambda_types); + Self &self = stash.create<Self>(*this, std::move(fun)); + return Instruction(op_tensor_lambda, wrap_param<Self>(self)); } void @@ -578,8 +580,8 @@ const Node &create(const ValueType &type, const std::map<TensorSpec::Address,Nod return stash.create<Create>(type, spec); } -const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash) { - return stash.create<Lambda>(type, bindings, std::move(function)); +const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash) { + return stash.create<Lambda>(type, bindings, function, std::move(node_types)); } const Node &peek(const Node ¶m, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash) { diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index 3aedf56affc..2cc70f50b15 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -329,14 +329,24 @@ public: class Lambda : public Node { using Super = Node; +public: + struct Self { + const Lambda &parent; + InterpretedFunction fun; + Self(const Lambda &parent_in, InterpretedFunction fun_in) + : parent(parent_in), fun(std::move(fun_in)) {} + }; private: std::vector<size_t> _bindings; - InterpretedFunction _lambda; + std::shared_ptr<Function const> _lambda; + NodeTypes _lambda_types; public: - Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, InterpretedFunction lambda_in) - : Node(result_type_in), _bindings(bindings_in), _lambda(std::move(lambda_in)) {} + Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, const Function &lambda_in, NodeTypes lambda_types_in) + : Node(result_type_in), _bindings(bindings_in), _lambda(lambda_in.shared_from_this()), _lambda_types(std::move(lambda_types_in)) {} static TensorSpec create_spec_impl(const ValueType &type, const LazyParams ¶ms, const std::vector<size_t> &bind, const InterpretedFunction &fun); - TensorSpec create_spec(const LazyParams ¶ms) const { return create_spec_impl(result_type(), params, _bindings, _lambda); } + TensorSpec create_spec(const LazyParams ¶ms, const InterpretedFunction &fun) const { + return create_spec_impl(result_type(), params, _bindings, fun); + } bool result_is_mutable() const override { return true; } InterpretedFunction::Instruction compile_self(const TensorEngine &engine, Stash &stash) const final override; void push_children(std::vector<Child::CREF> &children) const final override; @@ -435,7 +445,7 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash); const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash); const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash); -const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash); +const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash); const Node &peek(const Node ¶m, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash); const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash); const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash); |