summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2020-09-03 12:11:37 +0200
committerGitHub <noreply@github.com>2020-09-03 12:11:37 +0200
commit9156860dab6788bb7da0387cfef22a855e5d9e7c (patch)
treeed3b4a40dfa5923b13d7e7b1162671847c19639b
parent56190b768988935368bf59625f042035f7ec3b89 (diff)
parent7872cd8474006862b8c3b1aed0dd9591f97439b4 (diff)
Merge pull request #14265 from vespa-engine/arnej/wrap-or-convert-simple-tensor
convert to specific tensor if possible
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp14
1 files changed, 13 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index 8d9767374a2..eac5d0aa26c 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -55,8 +55,20 @@ encoding_to_cell_type(uint32_t cell_encoding) {
}
}
+std::unique_ptr<Tensor>
+wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple)
+{
+ if (Tensor::supported({simple->type()})) {
+ nbostream data;
+ eval::SimpleTensor::encode(*simple, data);
+ // note: some danger of infinite recursion here
+ return TypedBinaryFormat::deserialize(data);
+ }
+ return std::make_unique<WrappedSimpleTensor>(std::move(simple));
}
+} // namespace <unnamed>
+
void
TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
{
@@ -104,7 +116,7 @@ TypedBinaryFormat::deserialize(nbostream &stream)
case MIXED_BINARY_FORMAT_TYPE:
case MIXED_BINARY_FORMAT_WITH_CELLTYPE:
stream.adjustReadPos(read_pos - stream.rp());
- return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream));
+ return wrap_simple_tensor(eval::SimpleTensor::decode(stream));
default:
throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId));
}