diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dc17c657db9..9f3d7c01c6b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -225,6 +225,10 @@ public abstract class IndexedTensor implements Tensor { b.append(tensor.get(index)); else if (tensor.type().valueType() == TensorType.Value.FLOAT) b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) + b.append(tensor.getFloat(index)); + else if (tensor.type().valueType() == TensorType.Value.INT8) + b.append(tensor.getFloat(index)); else throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); @@ -295,6 +299,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else @@ -315,6 +323,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); else @@ -335,6 +347,10 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.BFLOAT16) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + else if (type.valueType() == TensorType.Value.INT8) + return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); else |