summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-10-23 11:17:41 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-10-23 11:17:41 +0000
commit94e2f9e38d3491391ca6a8c10e151f8f94dbe98b (patch)
tree23db524714d9f013384ed59599831ae3ee333cab /eval/src
parentd6ae51f8a026ff407926825cfb88b749e6e968ef (diff)
implement new map API
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp27
1 files changed, 23 insertions, 4 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index c9ed801dec2..35db1ed3b4b 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -199,12 +199,19 @@ DefaultTensorEngine::reduce(const Tensor &tensor, const BinaryOperation &op, con
return stash.create<ErrorValue>();
}
-struct CellFunctionAdapter : tensor::CellFunction {
+struct CellFunctionOpAdapter : tensor::CellFunction {
const eval::UnaryOperation &op;
- CellFunctionAdapter(const eval::UnaryOperation &op_in) : op(op_in) {}
+ CellFunctionOpAdapter(const eval::UnaryOperation &op_in) : op(op_in) {}
virtual double apply(double value) const override { return op.eval(value); }
};
+struct CellFunctionFunAdapter : tensor::CellFunction {
+ using map_fun_t = DefaultTensorEngine::map_fun_t;
+ map_fun_t fun;
+ CellFunctionFunAdapter(map_fun_t fun_in) : fun(fun_in) {}
+ virtual double apply(double value) const override { return fun(value); }
+};
+
const eval::Value &
DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash) const
{
@@ -213,7 +220,7 @@ DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash
if (!tensor::Tensor::supported({my_a.getType()})) {
return to_default(simple_engine().map(op, to_simple(my_a, stash), stash), stash);
}
- CellFunctionAdapter cell_function(op);
+ CellFunctionOpAdapter cell_function(op);
return to_value(my_a.apply(cell_function), stash);
}
@@ -295,7 +302,19 @@ DefaultTensorEngine::decode(nbostream &input, Stash &stash) const
const Value &
DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const
{
- return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash);
+ if (a.is_double()) {
+ return stash.create<DoubleValue>(function(a.as_double()));
+ } else if (auto tensor = a.as_tensor()) {
+ assert(&tensor->engine() == this);
+ const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor);
+ if (!tensor::Tensor::supported({my_a.getType()})) {
+ return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash);
+ }
+ CellFunctionFunAdapter cell_function(function);
+ return to_value(my_a.apply(cell_function), stash);
+ } else { // error
+ return a;
+ }
}
const Value &