From 472b51a42688540d0c8ad27cf1cf9177987b1e3b Mon Sep 17 00:00:00 2001 From: Haavard Date: Thu, 16 Feb 2017 15:02:30 +0000 Subject: use simple tensor engine as (expensive) fallback ... for new immediate API in default tensor engine --- eval/src/vespa/eval/eval/simple_tensor_engine.cpp | 5 +- .../vespa/eval/tensor/default_tensor_engine.cpp | 55 ++++++++++++++-------- 2 files changed, 39 insertions(+), 21 deletions(-) (limited to 'eval/src') diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index 265f0404dca..9e4e7993cde 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -26,10 +26,13 @@ const SimpleTensor &to_simple(const Tensor &tensor) { } const SimpleTensor &to_simple(const Value &value, Stash &stash) { + if (value.is_double()) { + return stash.create(value.as_double()); + } if (auto tensor = value.as_tensor()) { return to_simple(*tensor); } - return stash.create(value.as_double()); + return stash.create(); // error } const Value &to_value(std::unique_ptr tensor, Stash &stash) { diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 2eb83932d83..a8430bbac4d 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include "tensor.h" #include "dense/dense_tensor_builder.h" #include "dense/dense_tensor_function_compiler.h" @@ -17,6 +18,7 @@ using Value = eval::Value; using ErrorValue = eval::ErrorValue; using DoubleValue = eval::DoubleValue; using TensorValue = eval::TensorValue; +using TensorSpec = eval::TensorSpec; const DefaultTensorEngine DefaultTensorEngine::_engine; @@ -49,7 +51,7 @@ DefaultTensorEngine::to_string(const Tensor &tensor) const return my_tensor.toString(); } -eval::TensorSpec +TensorSpec DefaultTensorEngine::to_spec(const Tensor &tensor) const { assert(&tensor.engine() == this); @@ -223,48 +225,61 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten //----------------------------------------------------------------------------- +namespace { + +const eval::TensorEngine &simple_engine() { return eval::SimpleTensorEngine::ref(); } +const eval::TensorEngine &default_engine() { return DefaultTensorEngine::ref(); } + +// map tensors to simple tensors before fall-back evaluation +const Value &to_simple(const Value &value, Stash &stash) { + if (auto tensor = value.as_tensor()) { + TensorSpec spec = tensor->engine().to_spec(*tensor); + return stash.create(simple_engine().create(spec)); + } + return value; +} + +// map tensors to default tensors after fall-back evaluation +const Value &to_default(const Value &value, Stash &stash) { + if (auto tensor = value.as_tensor()) { + TensorSpec spec = tensor->engine().to_spec(*tensor); + return stash.create(default_engine().create(spec)); + } + return value; +} + +} // namespace vespalib::tensor:: + +//----------------------------------------------------------------------------- + const Value & DefaultTensorEngine::map(const Value &a, const std::function &function, Stash &stash) const { - (void) a; - (void) function; - return stash.create(); + return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash); } const Value & DefaultTensorEngine::join(const Value &a, const Value &b, const std::function &function, Stash &stash) const { - (void) a; - (void) b; - (void) function; - return stash.create(); + return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); } const Value & DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector &dimensions, Stash &stash) const { - (void) a; - (void) aggr; - (void) dimensions; - return stash.create(); + return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash); } const Value & DefaultTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const { - (void) a; - (void) b; - (void) dimension; - return stash.create(); + return to_default(simple_engine().concat(to_simple(a, stash), to_simple(b, stash), dimension, stash), stash); } const Value & DefaultTensorEngine::rename(const Value &a, const std::vector &from, const std::vector &to, Stash &stash) const { - (void) a; - (void) from; - (void) to; - return stash.create(); + return to_default(simple_engine().rename(to_simple(a, stash), from, to, stash), stash); } //----------------------------------------------------------------------------- -- cgit v1.2.3