summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java9
1 files changed, 8 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 5c47572c779..cddb283489c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -29,6 +29,8 @@ public class TypedBinaryFormat {
private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
private static final int FLOAT_VALUE_TYPE = 1;
+ private static final int BFLOAT16_VALUE_TYPE = 2;
+ private static final int INT8_VALUE_TYPE = 3;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
@@ -113,6 +115,8 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+ case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break;
+ case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break;
default:
throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
}
@@ -123,8 +127,11 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+ case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16;
+ case INT8_VALUE_TYPE: return TensorType.Value.INT8;
}
- throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+ throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " +
+ "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
}
private static byte[] asByteArray(GrowableByteBuffer buffer) {