diff options
author | Håvard Pettersen <havardpe@oath.com> | 2018-01-16 15:10:10 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2018-01-16 15:10:10 +0000 |
commit | 3108978de950d2751b6abd008750d30f302b4a3b (patch) | |
tree | d49c3762b99e17e149bf4fbd107e15a3ef942a55 /eval | |
parent | 3ec992ccf4f4f2246c85014fba8e1465ce6ecd0f (diff) |
added 'const_value' and 'if_node' to tensor IR
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.cpp | 38 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_function.h | 57 |
2 files changed, 83 insertions, 12 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index a98051a489d..c3c1d66c645 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -43,6 +43,14 @@ Op2::push_children(std::vector<Child::CREF> &children) const //----------------------------------------------------------------------------- const Value & +ConstValue::eval(const LazyParams &, Stash &) const +{ + return _value; +} + +//----------------------------------------------------------------------------- + +const Value & Inject::eval(const LazyParams ¶ms, Stash &stash) const { return params.resolve(_param_idx, stash); @@ -102,6 +110,28 @@ Rename::eval(const LazyParams ¶ms, Stash &stash) const //----------------------------------------------------------------------------- +void +If::push_children(std::vector<Child::CREF> &children) const +{ + children.emplace_back(_cond); + children.emplace_back(_true_child); + children.emplace_back(_false_child); +} + +const Value & +If::eval(const LazyParams ¶ms, Stash &stash) const +{ + return (cond().eval(params, stash).as_bool() + ? true_child().eval(params, stash) + : false_child().eval(params, stash)); +} + +//----------------------------------------------------------------------------- + +const Node &const_value(const Value &value, Stash &stash) { + return stash.create<ConstValue>(value); +} + const Node &inject(const ValueType &type, size_t param_idx, Stash &stash) { return stash.create<Inject>(type, param_idx); } @@ -131,6 +161,14 @@ const Node &rename(const Node &child, const std::vector<vespalib::string> &from, return stash.create<Rename>(result_type, child, from, to); } +const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash) { + ValueType result_type = true_child.result_type(); + if (result_type != false_child.result_type()) { + result_type = ValueType::any_type(); + } + return stash.create<If>(result_type, cond, true_child, false_child); +} + } // namespace vespalib::eval::tensor_function } // namespace vespalib::eval } // namespace vespalib diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index 6df863ad818..ed19b6dce4a 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -33,9 +33,9 @@ class Tensor; * node will directly represent a single call to the tensor engine * immediate API. * - * The generic implementation-independent tree will then be optimized - * (in-place, bottom-up) where sub-expressions may be replaced with - * optimized implementation-specific alternatives. + * The generic tree will then be optimized (in-place, bottom-up) where + * sub-expressions may be replaced with optimized + * implementation-specific alternatives. * * This leaves us with a mixed-mode tree with some generic and some * specialized nodes, that may be evaluated recursively. @@ -106,7 +106,7 @@ private: ValueType _result_type; public: Node(const ValueType &result_type_in) : _result_type(result_type_in) {} - const ValueType &result_type() const override { return _result_type; } + const ValueType &result_type() const final override { return _result_type; } }; //----------------------------------------------------------------------------- @@ -147,16 +147,26 @@ public: //----------------------------------------------------------------------------- +class ConstValue : public Leaf +{ +private: + const Value &_value; +public: + ConstValue(const Value &value_in) : Leaf(value_in.type()), _value(value_in) {} + const Value &eval(const LazyParams ¶ms, Stash &) const final override; +}; + +//----------------------------------------------------------------------------- + class Inject : public Leaf { private: size_t _param_idx; public: - Inject(const ValueType &result_type_in, - size_t param_idx_in) + Inject(const ValueType &result_type_in, size_t param_idx_in) : Leaf(result_type_in), _param_idx(param_idx_in) {} size_t param_idx() const { return _param_idx; } - const Value &eval(const LazyParams ¶ms, Stash &) const override; + const Value &eval(const LazyParams ¶ms, Stash &) const final override; }; //----------------------------------------------------------------------------- @@ -174,7 +184,7 @@ public: : Op1(result_type_in, child_in), _aggr(aggr_in), _dimensions(dimensions_in) {} Aggr aggr() const { return _aggr; } const std::vector<vespalib::string> &dimensions() const { return _dimensions; } - const Value &eval(const LazyParams ¶ms, Stash &stash) const override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -189,7 +199,7 @@ public: map_fun_t function_in) : Op1(result_type_in, child_in), _function(function_in) {} map_fun_t function() const { return _function; } - const Value &eval(const LazyParams ¶ms, Stash &stash) const override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -205,7 +215,7 @@ public: join_fun_t function_in) : Op2(result_type_in, lhs_in, rhs_in), _function(function_in) {} join_fun_t function() const { return _function; } - const Value &eval(const LazyParams ¶ms, Stash &stash) const override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -221,7 +231,7 @@ public: const vespalib::string &dimension_in) : Op2(result_type_in, lhs_in, rhs_in), _dimension(dimension_in) {} const vespalib::string &dimension() const { return _dimension; } - const Value &eval(const LazyParams ¶ms, Stash &stash) const override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; }; //----------------------------------------------------------------------------- @@ -239,17 +249,40 @@ public: : Op1(result_type_in, child_in), _from(from_in), _to(to_in) {} const std::vector<vespalib::string> &from() const { return _from; } const std::vector<vespalib::string> &to() const { return _to; } - const Value &eval(const LazyParams ¶ms, Stash &stash) const override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; +}; + +//----------------------------------------------------------------------------- + +class If : public Node +{ +private: + Child _cond; + Child _true_child; + Child _false_child; +public: + If(const ValueType &result_type_in, + const TensorFunction &cond_in, + const TensorFunction &true_child_in, + const TensorFunction &false_child_in) + : Node(result_type_in), _cond(cond_in), _true_child(true_child_in), _false_child(false_child_in) {} + const TensorFunction &cond() const { return _cond.get(); } + const TensorFunction &true_child() const { return _true_child.get(); } + const TensorFunction &false_child() const { return _false_child.get(); } + void push_children(std::vector<Child::CREF> &children) const final override; + const Value &eval(const LazyParams ¶ms, Stash &stash) const final override; }; //----------------------------------------------------------------------------- +const Node &const_value(const Value &value, Stash &stash); const Node &inject(const ValueType &type, size_t param_idx, Stash &stash); const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash); const Node &map(const Node &child, map_fun_t function, Stash &stash); const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash); const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash); const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash); +const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash); } // namespace vespalib::eval::tensor_function } // namespace vespalib::eval |