aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
commit049e9a325c8142958909d0464da12a56e5a8f638 (patch)
tree31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
parentbccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff)
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java11
1 files changed, 8 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 236e9d31c39..d7cf5bffcfe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -33,7 +33,7 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE("double"), FLOAT("float");
+ DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16");
private final String id;
@@ -59,6 +59,9 @@ public class TensorType {
public static Value largestOf(Value value1, Value value2) {
if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ if (value1 == FLOAT || value2 == FLOAT) return FLOAT;
+ if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT;
+ if (value1 == INT8 || value2 == INT8) return FLOAT;
return FLOAT;
}
@@ -69,8 +72,10 @@ public class TensorType {
switch (valueTypeString) {
case "double" : return Value.DOUBLE;
case "float" : return Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
+ case "bfloat16" : return Value.BFLOAT16;
+ case "int8" : return Value.INT8;
+ default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " +
+ "'bfloat16', or 'int8' but was '" + valueTypeString + "'");
}
}