summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 23:48:13 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 23:48:13 +0100
commitbceb0e5d4dd71c12a87cd15e18d31ec7ca4957e7 (patch)
treeb4a90fd2dac2119bf74136d4e199ee8fadb56898 /vespajlib
parent914cad21b94a09f2ec340572491681eba8108834 (diff)
- Move computation of productOfDimensionsAfter to DimensionSizes.
- And then precompute them construction time.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
2 files changed, 45 insertions, 47 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 83a625f72ac..640fa609432 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -11,10 +11,19 @@ import java.util.Arrays;
public final class DimensionSizes {
private final long[] sizes;
+ private final long[] productOfSizesFromHereOn;
+ private final long totalSize;
private DimensionSizes(Builder builder) {
this.sizes = builder.sizes;
builder.sizes = null; // invalidate builder to avoid copying the array
+ this.productOfSizesFromHereOn = new long[sizes.length];
+ long product = 1;
+ for (int i = sizes.length; i-- > 0; ) {
+ productOfSizesFromHereOn[i] = product;
+ product *= sizes[i];
+ }
+ this.totalSize = product;
}
/**
@@ -49,10 +58,11 @@ public final class DimensionSizes {
/** Returns the product of the sizes of this */
public long totalSize() {
- long productSize = 1;
- for (long dimensionSize : sizes )
- productSize *= dimensionSize;
- return productSize;
+ return totalSize;
+ }
+
+ long productOfDimensionsAfter(int afterIndex) {
+ return productOfSizesFromHereOn[afterIndex];
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index b7ced9258b7..1319675f5d4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -108,7 +108,7 @@ public abstract class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return get((int)toValueIndex(address, dimensionSizes, type));
+ return get(toValueIndex(address, dimensionSizes, type));
}
catch (IllegalArgumentException e) {
return 0.0;
@@ -150,7 +150,7 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i))
throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds");
- valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
+ valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i];
}
return valueIndex;
}
@@ -162,18 +162,11 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < address.size(); i++) {
if (address.numericLabel(i) >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i);
}
return valueIndex;
}
- private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- long product = 1;
- for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
- product *= sizes.size(i);
- return product;
- }
-
void throwOnIncompatibleType(TensorType type) {
if ( ! this.type().isRenamableTo(type))
throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
@@ -227,7 +220,7 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
@@ -250,8 +243,7 @@ public abstract class IndexedTensor implements Tensor {
b.append(", ");
// start brackets
- for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
- b.append("[");
+ b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
// value
switch (tensor.type().valueType()) {
@@ -264,8 +256,7 @@ public abstract class IndexedTensor implements Tensor {
}
// end bracket and comma
- for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
- b.append("]");
+ b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
}
if (index == maxCells && index < tensor.size())
b.append(", ...]");
@@ -327,14 +318,13 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -348,14 +338,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -369,14 +358,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
private static void validateSizes(DimensionSizes sizes, int length) {
@@ -518,7 +506,7 @@ public abstract class IndexedTensor implements Tensor {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
- offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
+ offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i,
(List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
for (long i = 0; i < currentDimension.size(); i++) {
@@ -1091,8 +1079,8 @@ public abstract class IndexedTensor implements Tensor {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes);
- this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes);
+ this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension);
+ this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
@@ -1156,7 +1144,7 @@ public abstract class IndexedTensor implements Tensor {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.step = productOfDimensionsAfter(iterateDimension, sizes);
+ this.step = sizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;