diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-08 15:05:01 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-08 15:05:01 +0200 |
commit | cb1ea8c336adb05c90200468b25fe4ab89ee803c (patch) | |
tree | 1d0b5da3bf31ec8be03c38818ad9a08d592de120 /vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | |
parent | 8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a (diff) |
Resolve feedback from PR review
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 74 |
1 files changed, 32 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 9f3d7c01c6b..c369fe96562 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -221,16 +221,14 @@ public abstract class IndexedTensor implements Tensor { b.append("["); // value - if (tensor.type().valueType() == TensorType.Value.DOUBLE) - b.append(tensor.get(index)); - else if (tensor.type().valueType() == TensorType.Value.FLOAT) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.BFLOAT16) - b.append(tensor.getFloat(index)); - else if (tensor.type().valueType() == TensorType.Value.INT8) - b.append(tensor.getFloat(index)); - else - throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + switch (tensor.type().valueType()) { + case DOUBLE: b.append(tensor.get(index)); break; + case FLOAT: b.append(tensor.getFloat(index)); break; + case BFLOAT16: b.append(tensor.getFloat(index)); break; + case INT8: b.append(tensor.getFloat(index)); break; + default: + throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); + } // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) @@ -296,17 +294,14 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + } } /** @@ -320,17 +315,14 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default + } } /** @@ -344,17 +336,15 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); + switch (type.valueType()) { + case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default: + return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default - if (type.valueType() == TensorType.Value.FLOAT) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.BFLOAT16) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.INT8) - return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - else if (type.valueType() == TensorType.Value.DOUBLE) - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - else - return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default + } } private static void validateSizes(DimensionSizes sizes, int length) { |