From 7b658207786c84aeeee69d6026953988b53ea68d Mon Sep 17 00:00:00 2001 From: Haavard Date: Fri, 10 Feb 2017 14:21:38 +0000 Subject: enable converting simple tensors to errors --- eval/src/vespa/eval/eval/simple_tensor_engine.cpp | 28 +++++++++-------------- 1 file changed, 11 insertions(+), 17 deletions(-) (limited to 'eval') diff --git a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp index dffc641e533..265f0404dca 100644 --- a/eval/src/vespa/eval/eval/simple_tensor_engine.cpp +++ b/eval/src/vespa/eval/eval/simple_tensor_engine.cpp @@ -26,10 +26,8 @@ const SimpleTensor &to_simple(const Tensor &tensor) { } const SimpleTensor &to_simple(const Value &value, Stash &stash) { - auto tensor = value.as_tensor(); - if (tensor) { - assert(&tensor->engine() == &SimpleTensorEngine::ref()); - return static_cast(*tensor); + if (auto tensor = value.as_tensor()) { + return to_simple(*tensor); } return stash.create(value.as_double()); } @@ -39,7 +37,11 @@ const Value &to_value(std::unique_ptr tensor, Stash &stash) { assert(tensor->cells().size() == 1u); return stash.create(tensor->cells()[0].value); } - return stash.create(std::move(tensor)); + if (tensor->type().is_tensor()) { + return stash.create(std::move(tensor)); + } + assert(tensor->type().is_error()); + return stash.create(); } } // namespace vespalib::eval:: @@ -49,26 +51,19 @@ const SimpleTensorEngine SimpleTensorEngine::_engine; ValueType SimpleTensorEngine::type_of(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast(tensor); - return simple_tensor.type(); + return to_simple(tensor).type(); } bool SimpleTensorEngine::equal(const Tensor &a, const Tensor &b) const { - assert(&a.engine() == this); - assert(&b.engine() == this); - const SimpleTensor &simple_a = static_cast(a); - const SimpleTensor &simple_b = static_cast(b); - return SimpleTensor::equal(simple_a, simple_b); + return SimpleTensor::equal(to_simple(a), to_simple(b)); } vespalib::string SimpleTensorEngine::to_string(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast(tensor); + const SimpleTensor &simple_tensor = to_simple(tensor); vespalib::string out = vespalib::make_string("simple(%s) {\n", simple_tensor.type().to_spec().c_str()); for (const auto &cell: simple_tensor.cells()) { size_t n = 0; @@ -92,8 +87,7 @@ SimpleTensorEngine::to_string(const Tensor &tensor) const TensorSpec SimpleTensorEngine::to_spec(const Tensor &tensor) const { - assert(&tensor.engine() == this); - const SimpleTensor &simple_tensor = static_cast(tensor); + const SimpleTensor &simple_tensor = to_simple(tensor); ValueType type = simple_tensor.type(); const auto &dimensions = type.dimensions(); TensorSpec spec(type.to_spec()); -- cgit v1.2.3