summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-04-02 15:13:22 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-04-03 09:38:00 +0000
commit092186debc2bb188b6c13b7a2d438225bd94afa5 (patch)
treef7c8de5e25763abfbb5ebac8b5cd28429005461f /eval
parent0cdbde812e5dd775d8f76793a2cdadf0bce1d4ae (diff)
delay preparing tensor lambda function for execution
This will allow implementation-specific tensor lambda optimizations to look at the lambda function and perform appropriate optimizations before it is converted to an interpreted function.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp4
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp14
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h20
3 files changed, 25 insertions, 13 deletions
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index b264c3c2fe7..f503532c1f9 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -122,13 +122,13 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
}
void make_lambda(const TensorLambda &node) {
- InterpretedFunction my_fun(tensor_engine, node.lambda().root(), types);
if (node.bindings().empty()) {
NoParams no_bound_params;
+ InterpretedFunction my_fun(tensor_engine, node.lambda().root(), types);
TensorSpec spec = tensor_function::Lambda::create_spec_impl(node.type(), no_bound_params, node.bindings(), my_fun);
make_const(node, *stash.create<Value::UP>(tensor_engine.from_spec(spec)));
} else {
- stack.push_back(tensor_function::lambda(node.type(), node.bindings(), std::move(my_fun), stash));
+ stack.push_back(tensor_function::lambda(node.type(), node.bindings(), node.lambda(), types.export_types(node.lambda().root()), stash));
}
}
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 40c4c614f20..9b7968f5092 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -135,8 +135,8 @@ void op_tensor_create(State &state, uint64_t param) {
}
void op_tensor_lambda(State &state, uint64_t param) {
- const Lambda &self = unwrap_param<Lambda>(param);
- TensorSpec spec = self.create_spec(*state.params);
+ const Lambda::Self &self = unwrap_param<Lambda::Self>(param);
+ TensorSpec spec = self.parent.create_spec(*state.params, self.fun);
const Value &result = *state.stash.create<Value::UP>(state.engine.from_spec(spec));
state.stack.emplace_back(result);
}
@@ -436,9 +436,11 @@ Lambda::create_spec_impl(const ValueType &type, const LazyParams &params, const
}
InterpretedFunction::Instruction
-Lambda::compile_self(const TensorEngine &, Stash &) const
+Lambda::compile_self(const TensorEngine &engine, Stash &stash) const
{
- return Instruction(op_tensor_lambda, wrap_param<Lambda>(*this));
+ InterpretedFunction fun(engine, _lambda->root(), _lambda_types);
+ Self &self = stash.create<Self>(*this, std::move(fun));
+ return Instruction(op_tensor_lambda, wrap_param<Self>(self));
}
void
@@ -578,8 +580,8 @@ const Node &create(const ValueType &type, const std::map<TensorSpec::Address,Nod
return stash.create<Create>(type, spec);
}
-const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash) {
- return stash.create<Lambda>(type, bindings, std::move(function));
+const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash) {
+ return stash.create<Lambda>(type, bindings, function, std::move(node_types));
}
const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash) {
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 3aedf56affc..2cc70f50b15 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -329,14 +329,24 @@ public:
class Lambda : public Node
{
using Super = Node;
+public:
+ struct Self {
+ const Lambda &parent;
+ InterpretedFunction fun;
+ Self(const Lambda &parent_in, InterpretedFunction fun_in)
+ : parent(parent_in), fun(std::move(fun_in)) {}
+ };
private:
std::vector<size_t> _bindings;
- InterpretedFunction _lambda;
+ std::shared_ptr<Function const> _lambda;
+ NodeTypes _lambda_types;
public:
- Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, InterpretedFunction lambda_in)
- : Node(result_type_in), _bindings(bindings_in), _lambda(std::move(lambda_in)) {}
+ Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, const Function &lambda_in, NodeTypes lambda_types_in)
+ : Node(result_type_in), _bindings(bindings_in), _lambda(lambda_in.shared_from_this()), _lambda_types(std::move(lambda_types_in)) {}
static TensorSpec create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun);
- TensorSpec create_spec(const LazyParams &params) const { return create_spec_impl(result_type(), params, _bindings, _lambda); }
+ TensorSpec create_spec(const LazyParams &params, const InterpretedFunction &fun) const {
+ return create_spec_impl(result_type(), params, _bindings, fun);
+ }
bool result_is_mutable() const override { return true; }
InterpretedFunction::Instruction compile_self(const TensorEngine &engine, Stash &stash) const final override;
void push_children(std::vector<Child::CREF> &children) const final override;
@@ -435,7 +445,7 @@ const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &s
const Node &merge(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash);
const Node &create(const ValueType &type, const std::map<TensorSpec::Address, Node::CREF> &spec, Stash &stash);
-const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, InterpretedFunction function, Stash &stash);
+const Node &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash);
const Node &peek(const Node &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, Node::CREF>> &spec, Stash &stash);
const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash);
const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash);