summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-09-03 09:21:39 +0000
committerArne Juul <arnej@verizonmedia.com>2020-09-03 09:21:39 +0000
commit7872cd8474006862b8c3b1aed0dd9591f97439b4 (patch)
tree5e91296278e92f84eb36a05db6b93d0e1582afc7 /eval
parentbed63d34ef760934ba45bb80d36699345c9416f5 (diff)
convert to specific tensor if possible
* when doing fallback decoding to SimpleTensor, check if its resulting type is something we can support with a more specific implementation. If so, convert by serializing the SimpleTensor and recursing once to deserialize().
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp14
1 files changed, 13 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index 8d9767374a2..eac5d0aa26c 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -55,8 +55,20 @@ encoding_to_cell_type(uint32_t cell_encoding) {
}
}
+std::unique_ptr<Tensor>
+wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple)
+{
+ if (Tensor::supported({simple->type()})) {
+ nbostream data;
+ eval::SimpleTensor::encode(*simple, data);
+ // note: some danger of infinite recursion here
+ return TypedBinaryFormat::deserialize(data);
+ }
+ return std::make_unique<WrappedSimpleTensor>(std::move(simple));
}
+} // namespace <unnamed>
+
void
TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
{
@@ -104,7 +116,7 @@ TypedBinaryFormat::deserialize(nbostream &stream)
case MIXED_BINARY_FORMAT_TYPE:
case MIXED_BINARY_FORMAT_WITH_CELLTYPE:
stream.adjustReadPos(read_pos - stream.rp());
- return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream));
+ return wrap_simple_tensor(eval::SimpleTensor::decode(stream));
default:
throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId));
}