summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-16 15:18:27 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-16 15:18:27 +0000
commit531f6fe7f568dc1f2a653d146dfc7b43b35ba7ea (patch)
treec9ab365d55b00e13d8c42beab6c1a42d6b7fc6fa /eval
parentf1cf307162effa0a79b0a5e099e0924b4792451d (diff)
convert result from map/apply to DoubleValue if type is double
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp9
1 files changed, 8 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index a8430bbac4d..39b900c3b30 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -165,7 +165,11 @@ DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash
assert(&a.engine() == this);
const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(a);
CellFunctionAdapter cell_function(op);
- return stash.create<TensorValue>(my_a.apply(cell_function));
+ auto result = my_a.apply(cell_function);
+ if (result->getType().is_double()) {
+ return stash.create<DoubleValue>(result->sum());
+ }
+ return stash.create<TensorValue>(std::move(result));
}
struct TensorOperationOverride : eval::DefaultOperationVisitor {
@@ -217,6 +221,9 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten
TensorOperationOverride tensor_override(my_a, my_b);
op.accept(tensor_override);
if (tensor_override.result) {
+ if (tensor_override.result->getType().is_double()) {
+ return stash.create<DoubleValue>(tensor_override.result->sum());
+ }
return stash.create<TensorValue>(std::move(tensor_override.result));
} else {
return stash.create<ErrorValue>();