summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-16 15:02:30 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-16 15:02:30 +0000
commit472b51a42688540d0c8ad27cf1cf9177987b1e3b (patch)
tree27874dda0490562d01a2cfa2ed7a0ee890c199b3 /eval/src
parentaa1c38bfb7bd0eb26d014abe9345a7cdc5ff3446 (diff)
use simple tensor engine as (expensive) fallback
... for new immediate API in default tensor engine
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.cpp5
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp55
2 files changed, 39 insertions, 21 deletions
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
index 265f0404dca..9e4e7993cde 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
@@ -26,10 +26,13 @@ const SimpleTensor &to_simple(const Tensor &tensor) {
}
const SimpleTensor &to_simple(const Value &value, Stash &stash) {
+ if (value.is_double()) {
+ return stash.create<SimpleTensor>(value.as_double());
+ }
if (auto tensor = value.as_tensor()) {
return to_simple(*tensor);
}
- return stash.create<SimpleTensor>(value.as_double());
+ return stash.create<SimpleTensor>(); // error
}
const Value &to_value(std::unique_ptr<SimpleTensor> tensor, Stash &stash) {
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 2eb83932d83..a8430bbac4d 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -5,6 +5,7 @@
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/operation_visitor.h>
+#include <vespa/eval/eval/simple_tensor_engine.h>
#include "tensor.h"
#include "dense/dense_tensor_builder.h"
#include "dense/dense_tensor_function_compiler.h"
@@ -17,6 +18,7 @@ using Value = eval::Value;
using ErrorValue = eval::ErrorValue;
using DoubleValue = eval::DoubleValue;
using TensorValue = eval::TensorValue;
+using TensorSpec = eval::TensorSpec;
const DefaultTensorEngine DefaultTensorEngine::_engine;
@@ -49,7 +51,7 @@ DefaultTensorEngine::to_string(const Tensor &tensor) const
return my_tensor.toString();
}
-eval::TensorSpec
+TensorSpec
DefaultTensorEngine::to_spec(const Tensor &tensor) const
{
assert(&tensor.engine() == this);
@@ -223,48 +225,61 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten
//-----------------------------------------------------------------------------
+namespace {
+
+const eval::TensorEngine &simple_engine() { return eval::SimpleTensorEngine::ref(); }
+const eval::TensorEngine &default_engine() { return DefaultTensorEngine::ref(); }
+
+// map tensors to simple tensors before fall-back evaluation
+const Value &to_simple(const Value &value, Stash &stash) {
+ if (auto tensor = value.as_tensor()) {
+ TensorSpec spec = tensor->engine().to_spec(*tensor);
+ return stash.create<TensorValue>(simple_engine().create(spec));
+ }
+ return value;
+}
+
+// map tensors to default tensors after fall-back evaluation
+const Value &to_default(const Value &value, Stash &stash) {
+ if (auto tensor = value.as_tensor()) {
+ TensorSpec spec = tensor->engine().to_spec(*tensor);
+ return stash.create<TensorValue>(default_engine().create(spec));
+ }
+ return value;
+}
+
+} // namespace vespalib::tensor::<unnamed>
+
+//-----------------------------------------------------------------------------
+
const Value &
DefaultTensorEngine::map(const Value &a, const std::function<double(double)> &function, Stash &stash) const
{
- (void) a;
- (void) function;
- return stash.create<ErrorValue>();
+ return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash);
}
const Value &
DefaultTensorEngine::join(const Value &a, const Value &b, const std::function<double(double,double)> &function, Stash &stash) const
{
- (void) a;
- (void) b;
- (void) function;
- return stash.create<ErrorValue>();
+ return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
}
const Value &
DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const
{
- (void) a;
- (void) aggr;
- (void) dimensions;
- return stash.create<ErrorValue>();
+ return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash);
}
const Value &
DefaultTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const
{
- (void) a;
- (void) b;
- (void) dimension;
- return stash.create<ErrorValue>();
+ return to_default(simple_engine().concat(to_simple(a, stash), to_simple(b, stash), dimension, stash), stash);
}
const Value &
DefaultTensorEngine::rename(const Value &a, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash) const
{
- (void) a;
- (void) from;
- (void) to;
- return stash.create<ErrorValue>();
+ return to_default(simple_engine().rename(to_simple(a, stash), from, to, stash), stash);
}
//-----------------------------------------------------------------------------