diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java | 9 |
1 files changed, 8 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5c47572c779..cddb283489c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -29,6 +29,8 @@ public class TypedBinaryFormat { private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing private static final int FLOAT_VALUE_TYPE = 1; + private static final int BFLOAT16_VALUE_TYPE = 2; + private static final int INT8_VALUE_TYPE = 3; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); @@ -113,6 +115,8 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break; case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break; + case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break; + case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break; default: throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType); } @@ -123,8 +127,11 @@ public class TypedBinaryFormat { switch (valueType) { case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE; case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT; + case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16; + case INT8_VALUE_TYPE: return TensorType.Value.INT8; } - throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal."); + throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " + + "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal."); } private static byte[] asByteArray(GrowableByteBuffer buffer) { |