diff options
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/simple_tensor_engine.cpp | 28 |
1 files changed, 11 insertions, 17 deletions
diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index dffc641e533..265f0404dca 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -26,10 +26,8 @@ const SimpleTensor &to_simple(const Tensor &tensor) { } const SimpleTensor &to_simple(const Value &value, Stash &stash) { - auto tensor = value.as_tensor(); - if (tensor) { - assert(&tensor->engine() == &SimpleTensorEngine::ref()); - return static_cast<const SimpleTensor &>(*tensor); + if (auto tensor = value.as_tensor()) { + return to_simple(*tensor); } return stash.create<SimpleTensor>(value.as_double()); } @@ -39,7 +37,11 @@ const Value &to_value(std::unique_ptr<SimpleTensor> tensor, Stash &stash) { assert(tensor->cells().size() == 1u); return stash.create<DoubleValue>(tensor->cells()[0].value); } - return stash.create<TensorValue>(std::move(tensor)); + if (tensor->type().is_tensor()) { + return stash.create<TensorValue>(std::move(tensor)); + } + assert(tensor->type().is_error()); + return stash.create<ErrorValue>(); } } // namespace vespalib::eval::<unnamed> @@ -49,26 +51,19 @@ const SimpleTensorEngine SimpleTensorEngine::_engine; ValueType SimpleTensorEngine::type_of(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor); - return simple_tensor.type(); + return to_simple(tensor).type(); } bool SimpleTensorEngine::equal(const Tensor &a, const Tensor &b) const { - assert(&a.engine() == this); - assert(&b.engine() == this); - const SimpleTensor &simple_a = static_cast<const SimpleTensor&>(a); - const SimpleTensor &simple_b = static_cast<const SimpleTensor&>(b); - return SimpleTensor::equal(simple_a, simple_b); + return SimpleTensor::equal(to_simple(a), to_simple(b)); } vespalib::string SimpleTensorEngine::to_string(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor); + const SimpleTensor &simple_tensor = to_simple(tensor); vespalib::string out = vespalib::make_string("simple(%s) {\n", simple_tensor.type().to_spec().c_str()); for (const auto &cell: simple_tensor.cells()) { size_t n = 0; @@ -92,8 +87,7 @@ SimpleTensorEngine::to_string(const Tensor &tensor) const TensorSpec SimpleTensorEngine::to_spec(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast<const SimpleTensor&>(tensor); + const SimpleTensor &simple_tensor = to_simple(tensor); ValueType type = simple_tensor.type(); const auto &dimensions = type.dimensions(); TensorSpec spec(type.to_spec()); |