diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-10-23 11:17:41 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-10-23 11:17:41 +0000 |
commit | 94e2f9e38d3491391ca6a8c10e151f8f94dbe98b (patch) | |
tree | 23db524714d9f013384ed59599831ae3ee333cab /eval | |
parent | d6ae51f8a026ff407926825cfb88b749e6e968ef (diff) |
implement new map API
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 27 |
1 files changed, 23 insertions, 4 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index c9ed801dec2..35db1ed3b4b 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -199,12 +199,19 @@ DefaultTensorEngine::reduce(const Tensor &tensor, const BinaryOperation &op, con return stash.create<ErrorValue>(); } -struct CellFunctionAdapter : tensor::CellFunction { +struct CellFunctionOpAdapter : tensor::CellFunction { const eval::UnaryOperation &op; - CellFunctionAdapter(const eval::UnaryOperation &op_in) : op(op_in) {} + CellFunctionOpAdapter(const eval::UnaryOperation &op_in) : op(op_in) {} virtual double apply(double value) const override { return op.eval(value); } }; +struct CellFunctionFunAdapter : tensor::CellFunction { + using map_fun_t = DefaultTensorEngine::map_fun_t; + map_fun_t fun; + CellFunctionFunAdapter(map_fun_t fun_in) : fun(fun_in) {} + virtual double apply(double value) const override { return fun(value); } +}; + const eval::Value & DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash) const { @@ -213,7 +220,7 @@ DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash if (!tensor::Tensor::supported({my_a.getType()})) { return to_default(simple_engine().map(op, to_simple(my_a, stash), stash), stash); } - CellFunctionAdapter cell_function(op); + CellFunctionOpAdapter cell_function(op); return to_value(my_a.apply(cell_function), stash); } @@ -295,7 +302,19 @@ DefaultTensorEngine::decode(nbostream &input, Stash &stash) const const Value & DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const { - return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash); + if (a.is_double()) { + return stash.create<DoubleValue>(function(a.as_double())); + } else if (auto tensor = a.as_tensor()) { + assert(&tensor->engine() == this); + const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor); + if (!tensor::Tensor::supported({my_a.getType()})) { + return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash); + } + CellFunctionFunAdapter cell_function(function); + return to_value(my_a.apply(cell_function), stash); + } else { // error + return a; + } } const Value & |