diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java | 18 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 74 |
2 files changed, 45 insertions, 47 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 83a625f72ac..640fa609432 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -11,10 +11,19 @@ import java.util.Arrays; public final class DimensionSizes { private final long[] sizes; + private final long[] productOfSizesFromHereOn; + private final long totalSize; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; builder.sizes = null; // invalidate builder to avoid copying the array + this.productOfSizesFromHereOn = new long[sizes.length]; + long product = 1; + for (int i = sizes.length; i-- > 0; ) { + productOfSizesFromHereOn[i] = product; + product *= sizes[i]; + } + this.totalSize = product; } /** @@ -49,10 +58,11 @@ public final class DimensionSizes { /** Returns the product of the sizes of this */ public long totalSize() { - long productSize = 1; - for (long dimensionSize : sizes ) - productSize *= dimensionSize; - return productSize; + return totalSize; + } + + long productOfDimensionsAfter(int afterIndex) { + return productOfSizesFromHereOn[afterIndex]; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index b7ced9258b7..1319675f5d4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -108,7 +108,7 @@ public abstract class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return get((int)toValueIndex(address, dimensionSizes, type)); + return get(toValueIndex(address, dimensionSizes, type)); } catch (IllegalArgumentException e) { return 0.0; @@ -150,7 +150,7 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds"); - valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; + valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i]; } return valueIndex; } @@ -162,18 +162,11 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < address.size(); i++) { if (address.numericLabel(i) >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); } return valueIndex; } - private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - long product = 1; - for (int i = afterIndex + 1; i < sizes.dimensions(); i++) - product *= sizes.size(i); - return product; - } - void throwOnIncompatibleType(TensorType type) { if ( ! this.type().isRenamableTo(type)) throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + @@ -227,7 +220,7 @@ public abstract class IndexedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { @@ -250,8 +243,7 @@ public abstract class IndexedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (tensor.type().valueType()) { @@ -264,8 +256,7 @@ public abstract class IndexedTensor implements Tensor { } // end bracket and comma - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } if (index == maxCells && index < tensor.size()) b.append(", ...]"); @@ -327,14 +318,13 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -348,14 +338,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -369,14 +358,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } private static void validateSizes(DimensionSizes sizes, int length) { @@ -518,7 +506,7 @@ public abstract class IndexedTensor implements Tensor { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, - offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, + offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i, (List<Object>) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values for (long i = 0; i < currentDimension.size(); i++) { @@ -1091,8 +1079,8 @@ public abstract class IndexedTensor implements Tensor { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes); - this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes); + this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension); + this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; @@ -1156,7 +1144,7 @@ public abstract class IndexedTensor implements Tensor { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.step = productOfDimensionsAfter(iterateDimension, sizes); + this.step = sizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; |