diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 58 |
1 files changed, 32 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index dc17c657db9..c369fe96562 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -221,12 +221,14 @@ public abstract class IndexedTensor implements Tensor { b.append("["); // value - if (tensor.type().valueType() == TensorType.Value.DOUBLE) - b.append(tensor.get(index)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(tensor.getFloat(index)); - else - throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + switch (tensor.type().valueType()) { + case DOUBLE: b.append(tensor.get(index)); break; + case FLOAT: b.append(tensor.getFloat(index)); break; + case BFLOAT16: b.append(tensor.getFloat(index)); break; + case INT8: b.append(tensor.getFloat(index)); break; + default: + throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) @@ -292,13 +294,14 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + } } /** @@ -312,13 +315,14 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } } /** @@ -332,13 +336,15 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } } private static void validateSizes(DimensionSizes sizes, int length) { |