summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java110
1 files changed, 80 insertions, 30 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 9b298f1dffb..bcff4392c9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -27,32 +27,14 @@ public class TypedBinaryFormat {
private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6;
private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7;
+ private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
+ private static final int FLOAT_VALUE_TYPE = 1;
+
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
- if (tensor instanceof MixedTensor) {
- buffer.putInt1_4Bytes(MIXED_BINARY_FORMAT_TYPE);
- new MixedBinaryFormat().encode(buffer, tensor);
- }
- else if (tensor instanceof IndexedTensor) {
- switch (tensor.type().valueType()) {
- case DOUBLE:
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).encode(buffer, tensor);
- break;
- default:
- buffer.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
- new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).encode(buffer, tensor);
- break;
- }
- }
- else {
- buffer.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
- new SparseBinaryFormat().encode(buffer, tensor);
- }
- buffer.flip();
- byte[] result = new byte[buffer.remaining()];
- buffer.get(result);
- return result;
+ BinaryFormat encoder = getFormatEncoder(buffer, tensor);
+ encoder.encode(buffer, tensor);
+ return asByteArray(buffer);
}
/**
@@ -64,14 +46,82 @@ public class TypedBinaryFormat {
* @throws IllegalArgumentException if the tensor data was invalid
*/
public static Tensor decode(Optional<TensorType> type, GrowableByteBuffer buffer) {
- int formatType = buffer.getInt1_4Bytes();
+ BinaryFormat decoder = getFormatDecoder(buffer);
+ return decoder.decode(type, buffer);
+ }
+
+ private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
+ if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE);
+ return new MixedBinaryFormat();
+ }
+ if (tensor instanceof MixedTensor) {
+ encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new MixedBinaryFormat(tensor.type().valueType());
+ }
+ if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE);
+ return new DenseBinaryFormat();
+ }
+ if (tensor instanceof IndexedTensor) {
+ encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new DenseBinaryFormat(tensor.type().valueType());
+ }
+ if (tensor.type().valueType() == TensorType.Value.DOUBLE) {
+ encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE);
+ return new SparseBinaryFormat();
+ }
+ encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
+ encodeValueType(buffer, tensor.type().valueType());
+ return new SparseBinaryFormat(tensor.type().valueType());
+ }
+
+ private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) {
+ int formatType = decodeFormatType(buffer);
switch (formatType) {
- case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat().decode(type, buffer);
- case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat().decode(type, buffer);
- case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.DOUBLE_IS_DEFAULT).decode(type, buffer);
- case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(DenseBinaryFormat.EncodeType.NO_DEFAULT).decode(type, buffer);
- default: throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat();
+ case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat();
+ case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat();
+ case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer));
+ case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer));
+ case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer));
}
+ throw new IllegalArgumentException("Binary format type " + formatType + " is unknown");
+ }
+
+ private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) {
+ buffer.putInt1_4Bytes(formatType);
+ }
+
+ private static int decodeFormatType(GrowableByteBuffer buffer) {
+ return buffer.getInt1_4Bytes();
+ }
+
+ private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) {
+ switch (valueType) {
+ case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
+ case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+ default:
+ throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
+ }
+ }
+
+ private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) {
+ int valueType = buffer.getInt1_4Bytes();
+ switch (valueType) {
+ case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
+ case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+ }
+ throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+ }
+
+ private static byte[] asByteArray(GrowableByteBuffer buffer) {
+ buffer.flip();
+ byte[] result = new byte[buffer.remaining()];
+ buffer.get(result);
+ return result;
}
}