diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-09-03 09:21:39 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-09-03 09:21:39 +0000 |
commit | 7872cd8474006862b8c3b1aed0dd9591f97439b4 (patch) | |
tree | 5e91296278e92f84eb36a05db6b93d0e1582afc7 /eval | |
parent | bed63d34ef760934ba45bb80d36699345c9416f5 (diff) |
convert to specific tensor if possible
* when doing fallback decoding to SimpleTensor, check if its
resulting type is something we can support with a more specific
implementation. If so, convert by serializing the SimpleTensor
and recursing once to deserialize().
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp | 14 |
1 files changed, 13 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index 8d9767374a2..eac5d0aa26c 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -55,8 +55,20 @@ encoding_to_cell_type(uint32_t cell_encoding) { } } +std::unique_ptr<Tensor> +wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple) +{ + if (Tensor::supported({simple->type()})) { + nbostream data; + eval::SimpleTensor::encode(*simple, data); + // note: some danger of infinite recursion here + return TypedBinaryFormat::deserialize(data); + } + return std::make_unique<WrappedSimpleTensor>(std::move(simple)); } +} // namespace <unnamed> + void TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) { @@ -104,7 +116,7 @@ TypedBinaryFormat::deserialize(nbostream &stream) case MIXED_BINARY_FORMAT_TYPE: case MIXED_BINARY_FORMAT_WITH_CELLTYPE: stream.adjustReadPos(read_pos - stream.rp()); - return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream)); + return wrap_simple_tensor(eval::SimpleTensor::decode(stream)); default: throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId)); } |