summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java16
1 files changed, 16 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index dc17c657db9..9f3d7c01c6b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -225,6 +225,10 @@ public abstract class IndexedTensor implements Tensor {
b.append(tensor.get(index));
else if (tensor.type().valueType() == TensorType.Value.FLOAT)
b.append(tensor.getFloat(index));
+ else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
+ b.append(tensor.getFloat(index));
+ else if (tensor.type().valueType() == TensorType.Value.INT8)
+ b.append(tensor.getFloat(index));
else
throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
@@ -295,6 +299,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
else
@@ -315,6 +323,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
else
@@ -335,6 +347,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
else