aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 15:05:01 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 15:05:01 +0200
commitcb1ea8c336adb05c90200468b25fe4ab89ee803c (patch)
tree1d0b5da3bf31ec8be03c38818ad9a08d592de120 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
parent8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a (diff)
Resolve feedback from PR review
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
1 files changed, 32 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 9f3d7c01c6b..c369fe96562 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -221,16 +221,14 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (tensor.type().valueType() == TensorType.Value.DOUBLE)
- b.append(tensor.get(index));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(tensor.getFloat(index));
- else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
- b.append(tensor.getFloat(index));
- else if (tensor.type().valueType() == TensorType.Value.INT8)
- b.append(tensor.getFloat(index));
- else
- throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ switch (tensor.type().valueType()) {
+ case DOUBLE: b.append(tensor.get(index)); break;
+ case FLOAT: b.append(tensor.getFloat(index)); break;
+ case BFLOAT16: b.append(tensor.getFloat(index)); break;
+ case INT8: b.append(tensor.getFloat(index)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
@@ -296,17 +294,14 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ }
}
/**
@@ -320,17 +315,14 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
}
/**
@@ -344,17 +336,15 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
}
private static void validateSizes(DimensionSizes sizes, int length) {