diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-02-16 15:18:27 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-02-16 15:18:27 +0000 |
commit | 531f6fe7f568dc1f2a653d146dfc7b43b35ba7ea (patch) | |
tree | c9ab365d55b00e13d8c42beab6c1a42d6b7fc6fa /eval | |
parent | f1cf307162effa0a79b0a5e099e0924b4792451d (diff) |
convert result from map/apply to DoubleValue if type is double
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index a8430bbac4d..39b900c3b30 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -165,7 +165,11 @@ DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash assert(&a.engine() == this); const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(a); CellFunctionAdapter cell_function(op); - return stash.create<TensorValue>(my_a.apply(cell_function)); + auto result = my_a.apply(cell_function); + if (result->getType().is_double()) { + return stash.create<DoubleValue>(result->sum()); + } + return stash.create<TensorValue>(std::move(result)); } struct TensorOperationOverride : eval::DefaultOperationVisitor { @@ -217,6 +221,9 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten TensorOperationOverride tensor_override(my_a, my_b); op.accept(tensor_override); if (tensor_override.result) { + if (tensor_override.result->getType().is_double()) { + return stash.create<DoubleValue>(tensor_override.result->sum()); + } return stash.create<TensorValue>(std::move(tensor_override.result)); } else { return stash.create<ErrorValue>(); |