diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java | 110 |
1 files changed, 80 insertions, 30 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 9b298f1dffb..bcff4392c9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -27,32 +27,14 @@ public class TypedBinaryFormat {
     private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
     private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;
 
+    private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
+    private static final int FLOAT_VALUE_TYPE = 1;
+
     public static byte[] encode(Tensor tensor) {
         GrowableByteBuffer buffer = new GrowableByteBuffer();
-        if (tensor instanceof MixedTensor) {
-            buffer.putInt1_4Bytes(MIXED_BINARY_FORMAT_TYPE);
-            new MixedBinaryFormat().encode(buffer, tensor);
-        }
-        else if (tensor instanceof IndexedTensor) {
-            switch (tensor.type().valueType()) {
-                case DOUBLE:
-                    buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
-                    new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor);
-                    break;
-                default:
-                    buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
-                    new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor);
-                    break;
-            }
-        }
-        else {
-            buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
-            new SparseBinaryFormat().encode(buffer, tensor);
-        }
-        buffer.flip();
-        byte[] result = new byte[buffer.remaining()];
-        buffer.get(result);
-        return result;
+        BinaryFormat encoder = getFormatEncoder(buffer, tensor);
+        encoder.encode(buffer, tensor);
+        return asByteArray(buffer);
     }
 
     /**
@@ -64,14 +46,82 @@ public class TypedBinaryFormat {
      * @throws IllegalArgumentException if the tensor data was invalid
      */
     public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
-        int formatType = buffer.getInt1_4Bytes();
+        BinaryFormat decoder = getFormatDecoder(buffer);
+        return decoder.decode(type, buffer);
+    }
+
+    private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
+        if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+            encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
+            return new MixedBinaryFormat();
+        }
+        if (tensor instanceof MixedTensor) {
+            encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
+            encodeValueType(buffer, tensor.type().valueType());
+            return new MixedBinaryFormat(tensor.type().valueType());
+        }
+        if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+            encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
+            return new DenseBinaryFormat();
+        }
+        if (tensor instanceof IndexedTensor) {
+            encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
+            encodeValueType(buffer, tensor.type().valueType());
+            return new DenseBinaryFormat(tensor.type().valueType());
+        }
+        if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
+            encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
+            return new SparseBinaryFormat();
+        }
+        encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
+        encodeValueType(buffer, tensor.type().valueType());
+        return new SparseBinaryFormat(tensor.type().valueType());
+    }
+
+    private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
+        int formatType = decodeFormatType(buffer);
         switch (formatType) {
-            case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer);
-            case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
-            case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer);
-            case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer);
-            default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+            case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat();
+            case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat();
+            case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat();
+            case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer));
+            case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer));
+            case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer));
         }
+        throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+    }
+
+    private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
+        buffer.putInt1_4Bytes(formatType);
+    }
+
+    private static int decodeFormatType(GrowableByteBuffer buffer) {
+        return buffer.getInt1_4Bytes();
+    }
+
+    private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+        switch (valueType) {
+            case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
+            case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+            default:
+                throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
+        }
+    }
+
+    private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
+        int valueType = buffer.getInt1_4Bytes();
+        switch (valueType) {
+            case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
+            case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+        }
+        throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+    }
+
+    private static byte[] asByteArray(GrowableByteBuffer buffer) {
+        buffer.flip();
+        byte[] result = new byte[buffer.remaining()];
+        buffer.get(result);
+        return result;
     }
 
 }