summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
1 files changed, 31 insertions, 43 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index b7ced9258b7..1319675f5d4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -108,7 +108,7 @@ public abstract class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return get((int)toValueIndex(address, dimensionSizes, type));
+ return get(toValueIndex(address, dimensionSizes, type));
}
catch (IllegalArgumentException e) {
return 0.0;
@@ -150,7 +150,7 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i))
throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds");
- valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
+ valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i];
}
return valueIndex;
}
@@ -162,18 +162,11 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < address.size(); i++) {
if (address.numericLabel(i) >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i);
}
return valueIndex;
}
- private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- long product = 1;
- for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
- product *= sizes.size(i);
- return product;
- }
-
void throwOnIncompatibleType(TensorType type) {
if ( ! this.type().isRenamableTo(type))
throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
@@ -227,7 +220,7 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
@@ -250,8 +243,7 @@ public abstract class IndexedTensor implements Tensor {
b.append(", ");
// start brackets
- for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
- b.append("[");
+ b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
// value
switch (tensor.type().valueType()) {
@@ -264,8 +256,7 @@ public abstract class IndexedTensor implements Tensor {
}
// end bracket and comma
- for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
- b.append("]");
+ b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
}
if (index == maxCells && index < tensor.size())
b.append(", ...]");
@@ -327,14 +318,13 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -348,14 +338,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -369,14 +358,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
private static void validateSizes(DimensionSizes sizes, int length) {
@@ -518,7 +506,7 @@ public abstract class IndexedTensor implements Tensor {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
- offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
+ offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i,
(List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
for (long i = 0; i < currentDimension.size(); i++) {
@@ -1091,8 +1079,8 @@ public abstract class IndexedTensor implements Tensor {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes);
- this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes);
+ this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension);
+ this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
@@ -1156,7 +1144,7 @@ public abstract class IndexedTensor implements Tensor {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.step = productOfDimensionsAfter(iterateDimension, sizes);
+ this.step = sizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;