summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-10 14:21:38 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-10 16:07:49 +0000
commit7b658207786c84aeeee69d6026953988b53ea68d (patch)
tree0ab3f18dc026eafbaf895973cbb7b6c4e9eed503 /eval
parent6f4eb0728a0ce1182dcc8d9434ff0d6b812d5e31 (diff)
enable converting simple tensors to errors
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor_engine.cpp28
1 files changed, 11 insertions, 17 deletions
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
index dffc641e533..265f0404dca 100644
--- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp
@@ -26,10 +26,8 @@ const SimpleTensor &to_simple(const Tensor &tensor) {
}
const SimpleTensor &to_simple(const Value &value, Stash &stash) {
- auto tensor = value.as_tensor();
- if (tensor) {
- assert(&tensor->engine() == &SimpleTensorEngine::ref());
- return static_cast<const SimpleTensor &>(*tensor);
+ if (auto tensor = value.as_tensor()) {
+ return to_simple(*tensor);
}
return stash.create<SimpleTensor>(value.as_double());
}
@@ -39,7 +37,11 @@ const Value &to_value(std::unique_ptr<SimpleTensor> tensor, Stash &stash) {
assert(tensor->cells().size() == 1u);
return stash.create<DoubleValue>(tensor->cells()[0].value);
}
- return stash.create<TensorValue>(std::move(tensor));
+ if (tensor->type().is_tensor()) {
+ return stash.create<TensorValue>(std::move(tensor));
+ }
+ assert(tensor->type().is_error());
+ return stash.create<ErrorValue>();
}
} // namespace vespalib::eval::<unnamed>
@@ -49,26 +51,19 @@ const SimpleTensorEngine SimpleTensorEngine::_engine;
ValueType
SimpleTensorEngine::type_of(const Tensor &tensor) const
{
- assert(&tensor.engine() == this);
- const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor);
- return simple_tensor.type();
+ return to_simple(tensor).type();
}
bool
SimpleTensorEngine::equal(const Tensor &a, const Tensor &b) const
{
- assert(&a.engine() == this);
- assert(&b.engine() == this);
- const SimpleTensor &simple_a = static_cast<const SimpleTensor&>(a);
- const SimpleTensor &simple_b = static_cast<const SimpleTensor&>(b);
- return SimpleTensor::equal(simple_a, simple_b);
+ return SimpleTensor::equal(to_simple(a), to_simple(b));
}
vespalib::string
SimpleTensorEngine::to_string(const Tensor &tensor) const
{
- assert(&tensor.engine() == this);
- const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor);
+ const SimpleTensor &simple_tensor = to_simple(tensor);
vespalib::string out = vespalib::make_string("simple(%s) {\n", simple_tensor.type().to_spec().c_str());
for (const auto &cell: simple_tensor.cells()) {
size_t n = 0;
@@ -92,8 +87,7 @@ SimpleTensorEngine::to_string(const Tensor &tensor) const
TensorSpec
SimpleTensorEngine::to_spec(const Tensor &tensor) const
{
- assert(&tensor.engine() == this);
- const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor);
+ const SimpleTensor &simple_tensor = to_simple(tensor);
ValueType type = simple_tensor.type();
const auto &dimensions = type.dimensions();
TensorSpec spec(type.to_spec());