summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java58
1 files changed, 32 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index dc17c657db9..c369fe96562 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -221,12 +221,14 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (tensor.type().valueType() == TensorType.Value.DOUBLE)
- b.append(tensor.get(index));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(tensor.getFloat(index));
- else
- throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ switch (tensor.type().valueType()) {
+ case DOUBLE: b.append(tensor.get(index)); break;
+ case FLOAT: b.append(tensor.getFloat(index)); break;
+ case BFLOAT16: b.append(tensor.getFloat(index)); break;
+ case INT8: b.append(tensor.getFloat(index)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
@@ -292,13 +294,14 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ }
}
/**
@@ -312,13 +315,14 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
}
/**
@@ -332,13 +336,15 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
}
private static void validateSizes(DimensionSizes sizes, int length) {