diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 236e9d31c39..d7cf5bffcfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -33,7 +33,7 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE("double"), FLOAT("float"); + DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16"); private final String id; @@ -59,6 +59,9 @@ public class TensorType { public static Value largestOf(Value value1, Value value2) { if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + if (value1 == FLOAT || value2 == FLOAT) return FLOAT; + if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT; + if (value1 == INT8 || value2 == INT8) return FLOAT; return FLOAT; } @@ -69,8 +72,10 @@ public class TensorType { switch (valueTypeString) { case "double" : return Value.DOUBLE; case "float" : return Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); + case "bfloat16" : return Value.BFLOAT16; + case "int8" : return Value.INT8; + default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " + + "'bfloat16', or 'int8' but was '" + valueTypeString + "'"); } } |