diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-02-16 15:02:30 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-02-16 15:02:30 +0000 |
commit | 472b51a42688540d0c8ad27cf1cf9177987b1e3b (patch) | |
tree | 27874dda0490562d01a2cfa2ed7a0ee890c199b3 /eval/src | |
parent | aa1c38bfb7bd0eb26d014abe9345a7cdc5ff3446 (diff) |
use simple tensor engine as (expensive) fallback
... for new immediate API in default tensor engine
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor_engine.cpp | 5 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 55 |
2 files changed, 39 insertions, 21 deletions
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index 265f0404dca..9e4e7993cde 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -26,10 +26,13 @@ const SimpleTensor &to_simple(const Tensor &tensor) { } const SimpleTensor &to_simple(const Value &value, Stash &stash) { + if (value.is_double()) { + return stash.create<SimpleTensor>(value.as_double()); + } if (auto tensor = value.as_tensor()) { return to_simple(*tensor); } - return stash.create<SimpleTensor>(value.as_double()); + return stash.create<SimpleTensor>(); // error } const Value &to_value(std::unique_ptr<SimpleTensor> tensor, Stash &stash) { diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 2eb83932d83..a8430bbac4d 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -5,6 +5,7 @@ #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/operation_visitor.h> +#include <vespa/eval/eval/simple_tensor_engine.h> #include "tensor.h" #include "dense/dense_tensor_builder.h" #include "dense/dense_tensor_function_compiler.h" @@ -17,6 +18,7 @@ using Value = eval::Value; using ErrorValue = eval::ErrorValue; using DoubleValue = eval::DoubleValue; using TensorValue = eval::TensorValue; +using TensorSpec = eval::TensorSpec; const DefaultTensorEngine DefaultTensorEngine::_engine; @@ -49,7 +51,7 @@ DefaultTensorEngine::to_string(const Tensor &tensor) const return my_tensor.toString(); } -eval::TensorSpec +TensorSpec DefaultTensorEngine::to_spec(const Tensor &tensor) const { assert(&tensor.engine() == this); @@ -223,48 +225,61 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten //----------------------------------------------------------------------------- +namespace { + +const eval::TensorEngine &simple_engine() { return eval::SimpleTensorEngine::ref(); } +const eval::TensorEngine &default_engine() { return DefaultTensorEngine::ref(); } + +// map tensors to simple tensors before fall-back evaluation +const Value &to_simple(const Value &value, Stash &stash) { + if (auto tensor = value.as_tensor()) { + TensorSpec spec = tensor->engine().to_spec(*tensor); + return stash.create<TensorValue>(simple_engine().create(spec)); + } + return value; +} + +// map tensors to default tensors after fall-back evaluation +const Value &to_default(const Value &value, Stash &stash) { + if (auto tensor = value.as_tensor()) { + TensorSpec spec = tensor->engine().to_spec(*tensor); + return stash.create<TensorValue>(default_engine().create(spec)); + } + return value; +} + +} // namespace vespalib::tensor::<unnamed> + +//----------------------------------------------------------------------------- + const Value & DefaultTensorEngine::map(const Value &a, const std::function<double(double)> &function, Stash &stash) const { - (void) a; - (void) function; - return stash.create<ErrorValue>(); + return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash); } const Value & DefaultTensorEngine::join(const Value &a, const Value &b, const std::function<double(double,double)> &function, Stash &stash) const { - (void) a; - (void) b; - (void) function; - return stash.create<ErrorValue>(); + return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); } const Value & DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const { - (void) a; - (void) aggr; - (void) dimensions; - return stash.create<ErrorValue>(); + return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash); } const Value & DefaultTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const { - (void) a; - (void) b; - (void) dimension; - return stash.create<ErrorValue>(); + return to_default(simple_engine().concat(to_simple(a, stash), to_simple(b, stash), dimension, stash), stash); } const Value & DefaultTensorEngine::rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const { - (void) a; - (void) from; - (void) to; - return stash.create<ErrorValue>(); + return to_default(simple_engine().rename(to_simple(a, stash), from, to, stash), stash); } //----------------------------------------------------------------------------- |