diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2020-09-03 12:11:37 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-03 12:11:37 +0200 |
commit | 9156860dab6788bb7da0387cfef22a855e5d9e7c (patch) | |
tree | ed3b4a40dfa5923b13d7e7b1162671847c19639b | |
parent | 56190b768988935368bf59625f042035f7ec3b89 (diff) | |
parent | 7872cd8474006862b8c3b1aed0dd9591f97439b4 (diff) |
Merge pull request #14265 from vespa-engine/arnej/wrap-or-convert-simple-tensor
convert to specific tensor if possible
-rw-r--r-- | eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index 8d9767374a2..eac5d0aa26c 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -55,8 +55,20 @@ encoding_to_cell_type(uint32_t cell_encoding) { } } +std::unique_ptr<Tensor> +wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple) +{ + if (Tensor::supported({simple->type()})) { + nbostream data; + eval::SimpleTensor::encode(*simple, data); + // note: some danger of infinite recursion here + return TypedBinaryFormat::deserialize(data); + } + return std::make_unique<WrappedSimpleTensor>(std::move(simple)); } +} // namespace <unnamed> + void TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) { @@ -104,7 +116,7 @@ TypedBinaryFormat::deserialize(nbostream &stream) case MIXED_BINARY_FORMAT_TYPE: case MIXED_BINARY_FORMAT_WITH_CELLTYPE: stream.adjustReadPos(read_pos - stream.rp()); - return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream)); + return wrap_simple_tensor(eval::SimpleTensor::decode(stream)); default: throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId)); } |