aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2018-01-16 15:10:10 +0000
committerHåvard Pettersen <havardpe@oath.com>2018-01-16 15:10:10 +0000
commit3108978de950d2751b6abd008750d30f302b4a3b (patch)
treed49c3762b99e17e149bf4fbd107e15a3ef942a55 /eval
parent3ec992ccf4f4f2246c85014fba8e1465ce6ecd0f (diff)
added 'const_value' and 'if_node' to tensor IR
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp38
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h57
2 files changed, 83 insertions, 12 deletions
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index a98051a489d..c3c1d66c645 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -43,6 +43,14 @@ Op2::push_children(std::vector<Child::CREF> &children) const
//-----------------------------------------------------------------------------
const Value &
+ConstValue::eval(const LazyParams &, Stash &) const
+{
+ return _value;
+}
+
+//-----------------------------------------------------------------------------
+
+const Value &
Inject::eval(const LazyParams &params, Stash &stash) const
{
return params.resolve(_param_idx, stash);
@@ -102,6 +110,28 @@ Rename::eval(const LazyParams &params, Stash &stash) const
//-----------------------------------------------------------------------------
+void
+If::push_children(std::vector<Child::CREF> &children) const
+{
+ children.emplace_back(_cond);
+ children.emplace_back(_true_child);
+ children.emplace_back(_false_child);
+}
+
+const Value &
+If::eval(const LazyParams &params, Stash &stash) const
+{
+ return (cond().eval(params, stash).as_bool()
+ ? true_child().eval(params, stash)
+ : false_child().eval(params, stash));
+}
+
+//-----------------------------------------------------------------------------
+
+const Node &const_value(const Value &value, Stash &stash) {
+ return stash.create<ConstValue>(value);
+}
+
const Node &inject(const ValueType &type, size_t param_idx, Stash &stash) {
return stash.create<Inject>(type, param_idx);
}
@@ -131,6 +161,14 @@ const Node &rename(const Node &child, const std::vector<vespalib::string> &from,
return stash.create<Rename>(result_type, child, from, to);
}
+const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash) {
+ ValueType result_type = true_child.result_type();
+ if (result_type != false_child.result_type()) {
+ result_type = ValueType::any_type();
+ }
+ return stash.create<If>(result_type, cond, true_child, false_child);
+}
+
} // namespace vespalib::eval::tensor_function
} // namespace vespalib::eval
} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 6df863ad818..ed19b6dce4a 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -33,9 +33,9 @@ class Tensor;
* node will directly represent a single call to the tensor engine
* immediate API.
*
- * The generic implementation-independent tree will then be optimized
- * (in-place, bottom-up) where sub-expressions may be replaced with
- * optimized implementation-specific alternatives.
+ * The generic tree will then be optimized (in-place, bottom-up) where
+ * sub-expressions may be replaced with optimized
+ * implementation-specific alternatives.
*
* This leaves us with a mixed-mode tree with some generic and some
* specialized nodes, that may be evaluated recursively.
@@ -106,7 +106,7 @@ private:
ValueType _result_type;
public:
Node(const ValueType &result_type_in) : _result_type(result_type_in) {}
- const ValueType &result_type() const override { return _result_type; }
+ const ValueType &result_type() const final override { return _result_type; }
};
//-----------------------------------------------------------------------------
@@ -147,16 +147,26 @@ public:
//-----------------------------------------------------------------------------
+class ConstValue : public Leaf
+{
+private:
+ const Value &_value;
+public:
+ ConstValue(const Value &value_in) : Leaf(value_in.type()), _value(value_in) {}
+ const Value &eval(const LazyParams &params, Stash &) const final override;
+};
+
+//-----------------------------------------------------------------------------
+
class Inject : public Leaf
{
private:
size_t _param_idx;
public:
- Inject(const ValueType &result_type_in,
- size_t param_idx_in)
+ Inject(const ValueType &result_type_in, size_t param_idx_in)
: Leaf(result_type_in), _param_idx(param_idx_in) {}
size_t param_idx() const { return _param_idx; }
- const Value &eval(const LazyParams &params, Stash &) const override;
+ const Value &eval(const LazyParams &params, Stash &) const final override;
};
//-----------------------------------------------------------------------------
@@ -174,7 +184,7 @@ public:
: Op1(result_type_in, child_in), _aggr(aggr_in), _dimensions(dimensions_in) {}
Aggr aggr() const { return _aggr; }
const std::vector<vespalib::string> &dimensions() const { return _dimensions; }
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
};
//-----------------------------------------------------------------------------
@@ -189,7 +199,7 @@ public:
map_fun_t function_in)
: Op1(result_type_in, child_in), _function(function_in) {}
map_fun_t function() const { return _function; }
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
};
//-----------------------------------------------------------------------------
@@ -205,7 +215,7 @@ public:
join_fun_t function_in)
: Op2(result_type_in, lhs_in, rhs_in), _function(function_in) {}
join_fun_t function() const { return _function; }
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
};
//-----------------------------------------------------------------------------
@@ -221,7 +231,7 @@ public:
const vespalib::string &dimension_in)
: Op2(result_type_in, lhs_in, rhs_in), _dimension(dimension_in) {}
const vespalib::string &dimension() const { return _dimension; }
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
};
//-----------------------------------------------------------------------------
@@ -239,17 +249,40 @@ public:
: Op1(result_type_in, child_in), _from(from_in), _to(to_in) {}
const std::vector<vespalib::string> &from() const { return _from; }
const std::vector<vespalib::string> &to() const { return _to; }
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
+};
+
+//-----------------------------------------------------------------------------
+
+class If : public Node
+{
+private:
+ Child _cond;
+ Child _true_child;
+ Child _false_child;
+public:
+ If(const ValueType &result_type_in,
+ const TensorFunction &cond_in,
+ const TensorFunction &true_child_in,
+ const TensorFunction &false_child_in)
+ : Node(result_type_in), _cond(cond_in), _true_child(true_child_in), _false_child(false_child_in) {}
+ const TensorFunction &cond() const { return _cond.get(); }
+ const TensorFunction &true_child() const { return _true_child.get(); }
+ const TensorFunction &false_child() const { return _false_child.get(); }
+ void push_children(std::vector<Child::CREF> &children) const final override;
+ const Value &eval(const LazyParams &params, Stash &stash) const final override;
};
//-----------------------------------------------------------------------------
+const Node &const_value(const Value &value, Stash &stash);
const Node &inject(const ValueType &type, size_t param_idx, Stash &stash);
const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash);
const Node &map(const Node &child, map_fun_t function, Stash &stash);
const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
const Node &concat(const Node &lhs, const Node &rhs, const vespalib::string &dimension, Stash &stash);
const Node &rename(const Node &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash);
+const Node &if_node(const Node &cond, const Node &true_child, const Node &false_child, Stash &stash);
} // namespace vespalib::eval::tensor_function
} // namespace vespalib::eval