aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com')
-rw-r--r--vespajlib/src/main/java/com/yahoo/compress/Hasher.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java114
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java271
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java62
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java58
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java202
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java36
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java92
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java115
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java83
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java154
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java53
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java61
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java66
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java53
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java44
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java4
35 files changed, 1205 insertions, 554 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java
index 92a9ed26085..7a3d34eca7b 100644
--- a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java
+++ b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java
@@ -8,8 +8,25 @@ import net.openhft.hashing.LongHashFunction;
* @author baldersheim
*/
public class Hasher {
+ private final LongHashFunction hasher;
/** Uses net.openhft.hashing.LongHashFunction.xx3() */
public static long xxh3(byte [] data) {
return LongHashFunction.xx3().hashBytes(data);
}
+ public static long xxh3(byte [] data, long seed) {
+ return LongHashFunction.xx3(seed).hashBytes(data);
+ }
+
+ private Hasher(LongHashFunction hasher) {
+ this.hasher = hasher;
+ }
+ public static Hasher withSeed(long seed) {
+ return new Hasher(LongHashFunction.xx3(seed));
+ }
+ public long hash(long v) {
+ return hasher.hashLong(v);
+ }
+ public long hash(String s) {
+ return hasher.hashChars(s);
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 83a625f72ac..640fa609432 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -11,10 +11,19 @@ import java.util.Arrays;
public final class DimensionSizes {
private final long[] sizes;
+ private final long[] productOfSizesFromHereOn;
+ private final long totalSize;
private DimensionSizes(Builder builder) {
this.sizes = builder.sizes;
builder.sizes = null; // invalidate builder to avoid copying the array
+ this.productOfSizesFromHereOn = new long[sizes.length];
+ long product = 1;
+ for (int i = sizes.length; i-- > 0; ) {
+ productOfSizesFromHereOn[i] = product;
+ product *= sizes[i];
+ }
+ this.totalSize = product;
}
/**
@@ -49,10 +58,11 @@ public final class DimensionSizes {
/** Returns the product of the sizes of this */
public long totalSize() {
- long productSize = 1;
- for (long dimensionSize : sizes )
- productSize *= dimensionSize;
- return productSize;
+ return totalSize;
+ }
+
+ long productOfDimensionsAfter(int afterIndex) {
+ return productOfSizesFromHereOn[afterIndex];
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
new file mode 100644
index 00000000000..cda3be47ddb
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
@@ -0,0 +1,55 @@
+package com.yahoo.tensor;
+
+/**
+ * Utility class for efficient access and iteration along dimensions in Indexed tensors.
+ * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration.
+ * long base = addr.getDirectIndex();
+ * long stride = addr.getStride(dimension)
+ * i = 0...size_of_dimension
+ * double value = tensor.get(base + i * stride);
+ *
+ * @author baldersheim
+ */
+public final class DirectIndexedAddress {
+
+ private final DimensionSizes sizes;
+ private final int[] indexes;
+ private long directIndex;
+
+ private DirectIndexedAddress(DimensionSizes sizes) {
+ this.sizes = sizes;
+ indexes = new int[sizes.dimensions()];
+ directIndex = 0;
+ }
+
+ public static DirectIndexedAddress of(DimensionSizes sizes) {
+ return new DirectIndexedAddress(sizes);
+ }
+
+ /** Sets the current index of a dimension */
+ public void setIndex(int dimension, int index) {
+ if (index < 0 || index >= sizes.size(dimension)) {
+ throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");
+ }
+ int diff = index - indexes[dimension];
+ directIndex += getStride(dimension) * diff;
+ indexes[dimension] = index;
+ }
+
+ /** Retrieve the index that can be used for direct lookup in an indexed tensor. */
+ public long getDirectIndex() { return directIndex; }
+
+ public long [] getIndexes() {
+ long[] asLong = new long[indexes.length];
+ for (int i=0; i < indexes.length; i++) {
+ asLong[i] = indexes[i];
+ }
+ return asLong;
+ }
+
+ /** returns the stride to be used for the given dimension */
+ public long getStride(int dimension) {
+ return sizes.productOfDimensionsAfter(dimension);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 548d39dd767..53f50fc4d02 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -22,6 +22,10 @@ class IndexedDoubleTensor extends IndexedTensor {
return values.length;
}
+ /** Once we can store more cells than an int we should drop this method. */
+ @Override
+ public int sizeAsInt() { return values.length; }
+
@Override
public double get(long valueIndex) { return values[(int)valueIndex]; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
index 26560a70ac4..3085ef1a843 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java
@@ -18,9 +18,11 @@ class IndexedFloatTensor extends IndexedTensor {
}
@Override
- public long size() {
- return values.length;
- }
+ public long size() { return values.length; }
+
+ /** Once we can store more cells than an int we should drop this. */
+ @Override
+ public int sizeAsInt() { return values.length; }
@Override
public double get(long valueIndex) { return getFloat(valueIndex); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 6a879fa533b..fc0473c635a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -90,9 +90,13 @@ public abstract class IndexedTensor implements Tensor {
* @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given
*/
public double get(long ... indexes) {
- return get((int)toValueIndex(indexes, dimensionSizes));
+ return get(toValueIndex(indexes, dimensionSizes));
}
+ public double get(DirectIndexedAddress address) {
+ return get(address.getDirectIndex());
+ }
+ public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); }
/**
* Returns the value at the given indexes as a float
*
@@ -108,7 +112,7 @@ public abstract class IndexedTensor implements Tensor {
public double get(TensorAddress address) {
// optimize for fast lookup within bounds:
try {
- return get((int)toValueIndex(address, dimensionSizes, type));
+ return get(toValueIndex(address, dimensionSizes, type));
}
catch (IllegalArgumentException e) {
return 0.0;
@@ -116,6 +120,17 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
+ public Double getAsDouble(TensorAddress address) {
+ try {
+ long index = toValueIndex(address, dimensionSizes, type);
+ if (index < 0 || size() <= index) return null;
+ return get(index);
+ } catch (IllegalArgumentException e) {
+ return null;
+ }
+ }
+
+ @Override
public boolean has(TensorAddress address) {
try {
long index = toValueIndex(address, dimensionSizes, type);
@@ -150,30 +165,22 @@ public abstract class IndexedTensor implements Tensor {
for (int i = 0; i < indexes.length; i++) {
if (indexes[i] >= sizes.size(i))
throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds");
- valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i];
+ valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i];
}
return valueIndex;
}
static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) {
- if (address.isEmpty()) return 0;
-
long valueIndex = 0;
- for (int i = 0; i < address.size(); i++) {
- if (address.numericLabel(i) >= sizes.size(i))
+ for (int i = 0, size = address.size(); i < size; i++) {
+ long label = address.numericLabel(i);
+ if (label >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
- valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i);
+ valueIndex += sizes.productOfDimensionsAfter(i) * label;
}
return valueIndex;
}
- private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) {
- long product = 1;
- for (int i = afterIndex + 1; i < sizes.dimensions(); i++)
- product *= sizes.size(i);
- return product;
- }
-
void throwOnIncompatibleType(TensorType type) {
if ( ! this.type().isRenamableTo(type))
throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type +
@@ -227,7 +234,7 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
@@ -250,8 +257,7 @@ public abstract class IndexedTensor implements Tensor {
b.append(", ");
// start brackets
- for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
- b.append("[");
+ b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
// value
switch (tensor.type().valueType()) {
@@ -264,8 +270,7 @@ public abstract class IndexedTensor implements Tensor {
}
// end bracket and comma
- for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
- b.append("]");
+ b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
}
if (index == maxCells && index < tensor.size())
b.append(", ...]");
@@ -286,7 +291,7 @@ public abstract class IndexedTensor implements Tensor {
}
public static Builder of(TensorType type) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type));
else
return new UnboundBuilder(type);
@@ -300,7 +305,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, float[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -314,7 +319,7 @@ public abstract class IndexedTensor implements Tensor {
* must not be further mutated by the caller
*/
public static Builder of(TensorType type, double[] values) {
- if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension))
+ if (type.hasOnlyIndexedBoundDimensions())
return of(type, BoundBuilder.dimensionSizesOf(type), values);
else
return new UnboundBuilder(type);
@@ -327,14 +332,13 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -348,14 +352,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
/**
@@ -369,14 +372,13 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
- switch (type.valueType()) {
- case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- default:
- throw new IllegalStateException("Unexpected value type " + type.valueType());
- }
+ return switch (type.valueType()) {
+ case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default -> throw new IllegalStateException("Unexpected value type " + type.valueType());
+ };
}
private static void validateSizes(DimensionSizes sizes, int length) {
@@ -518,7 +520,7 @@ public abstract class IndexedTensor implements Tensor {
if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension
for (long i = 0; i < currentDimension.size(); i++)
fillValues(currentDimensionIndex + 1,
- offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i,
+ offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i,
(List<Object>) currentDimension.get((int)i), sizes, values);
} else { // last dimension - fill values
for (long i = 0; i < currentDimension.size(); i++) {
@@ -623,11 +625,11 @@ public abstract class IndexedTensor implements Tensor {
private final class ValueIterator implements Iterator<Double> {
- private long count = 0;
+ private int count = 0;
@Override
public boolean hasNext() {
- return count < size();
+ return count < sizeAsInt();
}
@Override
@@ -889,8 +891,8 @@ public abstract class IndexedTensor implements Tensor {
private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
long size = 1;
- for (int iterateDimension : iterateDimensions)
- size *= sizes.size(iterateDimension);
+ for (int i = 0; i < iterateDimensions.size(); i++)
+ size *= sizes.size(iterateDimensions.get(i));
return size;
}
@@ -1056,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor {
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private long lastComputedSourceValueIndex = -1;
+ private long lastComputedSourceValueIndex = Tensor.invalidIndex;
private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
@@ -1091,8 +1093,8 @@ public abstract class IndexedTensor implements Tensor {
super(sourceSizes, iterateSizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes);
- this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes);
+ this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension);
+ this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
@@ -1156,7 +1158,7 @@ public abstract class IndexedTensor implements Tensor {
super(sizes, sizes, initialIndexes);
this.iterateDimension = iterateDimension;
this.size = size;
- this.step = productOfDimensionsAfter(iterateDimension, sizes);
+ this.step = sizes.productOfDimensionsAfter(iterateDimension);
// Initialize to the (virtual) position before the first cell
indexes[iterateDimension]--;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index e196569b18f..3e0df5f2261 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
-import java.util.function.DoubleBinaryOperator;
/**
* A sparse implementation of a tensor backed by a Map of cells to values.
@@ -31,6 +30,10 @@ public class MappedTensor implements Tensor {
@Override
public long size() { return cells.size(); }
+ /** Once we can store more cells than an int we should drop this. */
+ @Override
+ public int sizeAsInt() { return cells.size(); }
+
@Override
public double get(TensorAddress address) { return cells.getOrDefault(address, 0.0); }
@@ -38,6 +41,9 @@ public class MappedTensor implements Tensor {
public boolean has(TensorAddress address) { return cells.containsKey(address); }
@Override
+ public Double getAsDouble(TensorAddress address) { return cells.get(address); }
+
+ @Override
public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); }
@Override
@@ -79,7 +85,7 @@ public class MappedTensor implements Tensor {
@Override
public String toAbbreviatedString(boolean withType, boolean shortForms) {
- return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1)));
+ return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1)));
}
private String toString(boolean withType, boolean shortForms, long maxCells) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 5d5a5f74063..65c6677e7e3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -2,12 +2,13 @@
package com.yahoo.tensor;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.ArrayList;
import java.util.Arrays;
-import java.util.HashMap;
import java.util.Iterator;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -28,7 +29,6 @@ public class MixedTensor implements Tensor {
/** The dimension specification for this tensor */
private final TensorType type;
- private final int denseSubspaceSize;
// XXX consider using "record" instead
/** only exposed for internal use; subject to change without notice */
@@ -50,45 +50,15 @@ public class MixedTensor implements Tensor {
}
}
- /** The cells in the tensor */
- private final List<DenseSubspace> denseSubspaces;
-
/** only exposed for internal use; subject to change without notice */
- public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; }
+ public List<DenseSubspace> getInternalDenseSubspaces() { return index.denseSubspaces; }
/** An index structure over the cell list */
private final Index index;
- private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) {
+ private MixedTensor(TensorType type, Index index) {
this.type = type;
- this.denseSubspaceSize = index.denseSubspaceSize();
- this.denseSubspaces = List.copyOf(denseSubspaces);
this.index = index;
- if (this.denseSubspaceSize < 1) {
- throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize);
- }
- long count = 0;
- for (var block : this.denseSubspaces) {
- if (index.sparseMap.get(block.sparseAddress) != count) {
- throw new IllegalStateException("map vs list mismatch: block #"
- + count
- + " address maps to #"
- + index.sparseMap.get(block.sparseAddress));
- }
- if (block.cells.length != denseSubspaceSize) {
- throw new IllegalStateException("dense subspace size mismatch, expected "
- + denseSubspaceSize
- + " cells, but got: "
- + block.cells.length);
- }
- ++count;
- }
- if (count != index.sparseMap.size()) {
- throw new IllegalStateException("mismatch: list size is "
- + count
- + " but map size is "
- + index.sparseMap.size());
- }
}
/** Returns the tensor type */
@@ -97,32 +67,34 @@ public class MixedTensor implements Tensor {
/** Returns the size of the tensor measured in number of cells */
@Override
- public long size() { return denseSubspaces.size() * denseSubspaceSize; }
+ public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; }
/** Returns the value at the given address */
@Override
public double get(TensorAddress address) {
- int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > denseSubspaces.size()) {
+ var block = index.blockOf(address);
+ int denseOffset = index.denseOffsetOf(address);
+ if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) {
return 0.0;
}
+ return block.cells[denseOffset];
+ }
+
+ @Override
+ public Double getAsDouble(TensorAddress address) {
+ var block = index.blockOf(address);
int denseOffset = index.denseOffsetOf(address);
- var block = denseSubspaces.get(blockNum);
- if (denseOffset < 0 || denseOffset >= block.cells.length) {
- return 0.0;
+ if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) {
+ return null;
}
return block.cells[denseOffset];
}
@Override
public boolean has(TensorAddress address) {
- int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > denseSubspaces.size()) {
- return false;
- }
+ var block = index.blockOf(address);
int denseOffset = index.denseOffsetOf(address);
- var block = denseSubspaces.get(blockNum);
- return (denseOffset >= 0 && denseOffset < block.cells.length);
+ return (block != null && denseOffset >= 0 && denseOffset < block.cells.length);
}
/**
@@ -135,21 +107,30 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Cell> cellIterator() {
return new Iterator<>() {
- final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
- DenseSubspace currBlock = null;
- int currOffset = denseSubspaceSize;
+
+ final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator();
+ final int[] labels = new int[index.indexedDimensions.size()];
+ DenseSubspace currentBlock = null;
+ int currOffset = index.denseSubspaceSize;
+ int prevOffset = -1;
+
@Override
public boolean hasNext() {
- return (currOffset < denseSubspaceSize || blockIterator.hasNext());
+ return (currOffset < index.denseSubspaceSize || blockIterator.hasNext());
}
+
@Override
public Cell next() {
- if (currOffset == denseSubspaceSize) {
- currBlock = blockIterator.next();
+ if (currOffset == index.denseSubspaceSize) {
+ currentBlock = blockIterator.next();
currOffset = 0;
}
- TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, currOffset);
- double value = currBlock.cells[currOffset++];
+ if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1
+ index.denseOffsetToAddress(currOffset, labels);
+ }
+ TensorAddress fullAddr = currentBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels);
+ prevOffset = currOffset;
+ double value = currentBlock.cells[currOffset++];
return new Cell(fullAddr, value);
}
};
@@ -162,20 +143,23 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Double> valueIterator() {
return new Iterator<>() {
- final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
- double[] currBlock = null;
- int currOffset = denseSubspaceSize;
+
+ final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator();
+ double[] currentBlock = null;
+ int currOffset = index.denseSubspaceSize;
+
@Override
public boolean hasNext() {
- return (currOffset < denseSubspaceSize || blockIterator.hasNext());
+ return (currOffset < index.denseSubspaceSize || blockIterator.hasNext());
}
+
@Override
public Double next() {
- if (currOffset == denseSubspaceSize) {
- currBlock = blockIterator.next().cells;
+ if (currOffset == index.denseSubspaceSize) {
+ currentBlock = blockIterator.next().cells;
currOffset = 0;
}
- return currBlock[currOffset++];
+ return currentBlock[currOffset++];
}
};
}
@@ -197,24 +181,22 @@ public class MixedTensor implements Tensor {
throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" +
this.type + "', requested type: '" + type + "'");
}
- return new MixedTensor(other, denseSubspaces, index);
+ return new MixedTensor(other, index);
}
@Override
public Tensor remove(Set<TensorAddress> addresses) {
var indexBuilder = new Index.Builder(type);
- List<DenseSubspace> list = new ArrayList<>();
- for (var block : denseSubspaces) {
+ for (var block : index.denseSubspaces) {
if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part
- indexBuilder.addBlock(block.sparseAddress, list.size());
- list.add(block);
+ indexBuilder.addBlock(block);
}
}
- return new MixedTensor(type, list, indexBuilder.build());
+ return new MixedTensor(type, indexBuilder.build());
}
@Override
- public int hashCode() { return Objects.hash(type, denseSubspaces); }
+ public int hashCode() { return Objects.hash(type, index.denseSubspaces); }
@Override
public String toString() {
@@ -249,13 +231,14 @@ public class MixedTensor implements Tensor {
/** Returns the size of dense subspaces */
public long denseSubspaceSize() {
- return denseSubspaceSize;
+ return index.denseSubspaceSize;
}
/**
* Base class for building mixed tensors.
*/
public abstract static class Builder implements Tensor.Builder {
+ static final int INITIAL_HASH_CAPACITY = 1000;
final TensorType type;
@@ -265,10 +248,11 @@ public class MixedTensor implements Tensor {
* a temporary structure while finding dimension bounds.
*/
public static Builder of(TensorType type) {
- if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) {
- return new UnboundBuilder(type);
+ //TODO Wire in expected map size to avoid expensive resize
+ if (type.hasIndexedUnboundDimensions()) {
+ return new UnboundBuilder(type, INITIAL_HASH_CAPACITY);
} else {
- return new BoundBuilder(type);
+ return new BoundBuilder(type, INITIAL_HASH_CAPACITY);
}
}
@@ -306,13 +290,14 @@ public class MixedTensor implements Tensor {
public static class BoundBuilder extends Builder {
/** For each sparse partial address, hold a dense subspace */
- private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
+ private final Map<TensorAddress, double[]> denseSubspaceMap;
private final Index.Builder indexBuilder;
private final Index index;
private final TensorType denseSubtype;
- private BoundBuilder(TensorType type) {
+ private BoundBuilder(TensorType type, int expectedSize) {
super(type);
+ denseSubspaceMap = new LinkedHashMap<>(expectedSize, 0.5f);
indexBuilder = new Index.Builder(type);
index = indexBuilder.index();
denseSubtype = new TensorType(type.valueType(),
@@ -324,10 +309,7 @@ public class MixedTensor implements Tensor {
}
private double[] denseSubspace(TensorAddress sparseAddress) {
- if (!denseSubspaceMap.containsKey(sparseAddress)) {
- denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]);
- }
- return denseSubspaceMap.get(sparseAddress);
+ return denseSubspaceMap.computeIfAbsent(sparseAddress, (key) -> new double[(int)denseSubspaceSize()]);
}
public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) {
@@ -343,7 +325,7 @@ public class MixedTensor implements Tensor {
@Override
public Tensor.Builder cell(TensorAddress address, double value) {
- TensorAddress sparsePart = index.sparsePartialAddress(address);
+ TensorAddress sparsePart = address.mappedPartialAddress(index.sparseType, index.type.dimensions());
int denseOffset = index.denseOffsetOf(address);
double[] denseSubspace = denseSubspace(sparsePart);
denseSubspace[denseOffset] = value;
@@ -362,19 +344,20 @@ public class MixedTensor implements Tensor {
@Override
public MixedTensor build() {
- List<DenseSubspace> list = new ArrayList<>();
- for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
+ //TODO This can be solved more efficiently with a single map.
+ Set<Map.Entry<TensorAddress, double[]>> entrySet = denseSubspaceMap.entrySet();
+ for (Map.Entry<TensorAddress, double[]> entry : entrySet) {
TensorAddress sparsePart = entry.getKey();
double[] denseSubspace = entry.getValue();
var block = new DenseSubspace(sparsePart, denseSubspace);
- indexBuilder.addBlock(sparsePart, list.size());
- list.add(block);
+ indexBuilder.addBlock(block);
}
- return new MixedTensor(type, list, indexBuilder.build());
+ return new MixedTensor(type, indexBuilder.build());
}
public static BoundBuilder of(TensorType type) {
- return new BoundBuilder(type);
+ //TODO Wire in expected map size to avoid expensive resize
+ return new BoundBuilder(type, INITIAL_HASH_CAPACITY);
}
}
@@ -391,9 +374,9 @@ public class MixedTensor implements Tensor {
private final Map<TensorAddress, Double> cells;
private final long[] dimensionBounds;
- private UnboundBuilder(TensorType type) {
+ private UnboundBuilder(TensorType type, int expectedSize) {
super(type);
- cells = new HashMap<>();
+ cells = new LinkedHashMap<>(expectedSize, 0.5f);
dimensionBounds = new long[type.dimensions().size()];
}
@@ -412,7 +395,7 @@ public class MixedTensor implements Tensor {
@Override
public MixedTensor build() {
TensorType boundType = createBoundType();
- BoundBuilder builder = new BoundBuilder(boundType);
+ BoundBuilder builder = new BoundBuilder(boundType, cells.size());
for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) {
builder.cell(cell.getKey(), cell.getValue());
}
@@ -443,7 +426,8 @@ public class MixedTensor implements Tensor {
}
public static UnboundBuilder of(TensorType type) {
- return new UnboundBuilder(type);
+ //TODO Wire in expected map size to avoid expensive resize
+ return new UnboundBuilder(type, INITIAL_HASH_CAPACITY);
}
}
@@ -460,8 +444,10 @@ public class MixedTensor implements Tensor {
private final TensorType denseType;
private final List<TensorType.Dimension> mappedDimensions;
private final List<TensorType.Dimension> indexedDimensions;
+ private final int[] indexedDimensionsSize;
private ImmutableMap<TensorAddress, Integer> sparseMap;
+ private List<DenseSubspace> denseSubspaces;
private final int denseSubspaceSize;
static private int computeDSS(List<TensorType.Dimension> dimensions) {
@@ -477,17 +463,31 @@ public class MixedTensor implements Tensor {
this.type = type;
this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).toList();
this.indexedDimensions = type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList();
+ this.indexedDimensionsSize = new int[indexedDimensions.size()];
+ for (int i = 0; i < indexedDimensions.size(); i++) {
+ long dimensionSize = indexedDimensions.get(i).size().orElseThrow(() ->
+ new IllegalArgumentException("Unknown size of indexed dimension."));
+ indexedDimensionsSize[i] = (int)dimensionSize;
+ }
+
this.sparseType = createPartialType(type.valueType(), mappedDimensions);
this.denseType = createPartialType(type.valueType(), indexedDimensions);
this.denseSubspaceSize = computeDSS(this.indexedDimensions);
+ if (this.denseSubspaceSize < 1) {
+ throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize);
+ }
}
- int blockIndexOf(TensorAddress address) {
- TensorAddress sparsePart = sparsePartialAddress(address);
- return sparseMap.getOrDefault(sparsePart, -1);
+ private DenseSubspace blockOf(TensorAddress address) {
+ TensorAddress sparsePart = address.mappedPartialAddress(sparseType, type.dimensions());
+ Integer blockNum = sparseMap.get(sparsePart);
+ if (blockNum == null || blockNum >= denseSubspaces.size()) {
+ return null;
+ }
+ return denseSubspaces.get(blockNum);
}
- int denseOffsetOf(TensorAddress address) {
+ private int denseOffsetOf(TensorAddress address) {
long innerSize = 1;
long offset = 0;
for (int i = type.dimensions().size(); --i >= 0; ) {
@@ -506,54 +506,19 @@ public class MixedTensor implements Tensor {
return denseSubspaceSize;
}
- private TensorAddress sparsePartialAddress(TensorAddress address) {
- if (type.dimensions().size() != address.size())
- throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address);
- TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
- for (int i = 0; i < type.dimensions().size(); ++i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- if ( ! dimension.isIndexed())
- builder.add(dimension.name(), address.label(i));
- }
- return builder.build();
- }
-
- private TensorAddress denseOffsetToAddress(long denseOffset) {
+ private void denseOffsetToAddress(long denseOffset, int [] labels) {
if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
throw new IllegalArgumentException("Offset out of bounds");
}
long restSize = denseOffset;
long innerSize = denseSubspaceSize;
- long[] labels = new long[indexedDimensions.size()];
for (int i = 0; i < labels.length; ++i) {
- TensorType.Dimension dimension = indexedDimensions.get(i);
- long dimensionSize = dimension.size().orElseThrow(() ->
- new IllegalArgumentException("Unknown size of indexed dimension."));
-
- innerSize /= dimensionSize;
- labels[i] = restSize / innerSize;
+ innerSize /= indexedDimensionsSize[i];
+ labels[i] = (int) (restSize / innerSize);
restSize %= innerSize;
}
- return TensorAddress.of(labels);
- }
-
- TensorAddress fullAddressOf(TensorAddress sparsePart, long denseOffset) {
- TensorAddress densePart = denseOffsetToAddress(denseOffset);
- String[] labels = new String[type.dimensions().size()];
- int mappedIndex = 0;
- int indexedIndex = 0;
- for (TensorType.Dimension d : type.dimensions()) {
- if (d.isIndexed()) {
- labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex);
- indexedIndex++;
- } else {
- labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex);
- mappedIndex++;
- }
- }
- return TensorAddress.of(labels);
}
@Override
@@ -563,7 +528,7 @@ public class MixedTensor implements Tensor {
private String contentToString(MixedTensor tensor, long maxCells) {
if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller");
- if (mappedDimensions.size() == 0) {
+ if (mappedDimensions.isEmpty()) {
StringBuilder b = new StringBuilder();
int cellsWritten = denseSubspaceToString(tensor, 0, maxCells, b);
if (cellsWritten == maxCells && cellsWritten < tensor.size())
@@ -605,8 +570,7 @@ public class MixedTensor implements Tensor {
b.append(", ");
// start brackets
- for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
- b.append("[");
+ b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart())));
// value
switch (type.valueType()) {
@@ -619,32 +583,38 @@ public class MixedTensor implements Tensor {
}
// end bracket
- for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
- b.append("]");
+ b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd())));
}
return index;
}
private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) {
- return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset];
+ return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset];
}
- static class Builder {
+ private static class Builder {
private final Index index;
- private final ImmutableMap.Builder<TensorAddress, Integer> builder;
+ private final ImmutableMap.Builder<TensorAddress, Integer> builder = new ImmutableMap.Builder<>();
+ private final ImmutableList.Builder<DenseSubspace> listBuilder = new ImmutableList.Builder<>();
+ private int count = 0;
Builder(TensorType type) {
index = new Index(type);
- builder = new ImmutableMap.Builder<>();
}
- void addBlock(TensorAddress address, int sz) {
- builder.put(address, sz);
+ void addBlock(DenseSubspace block) {
+ if (block.cells.length != index.denseSubspaceSize) {
+ throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize
+ + " cells, but got: " + block.cells.length);
+ }
+ builder.put(block.sparseAddress, count++);
+ listBuilder.add(block);
}
Index build() {
index.sparseMap = builder.build();
+ index.denseSubspaces = listBuilder.build();
return index;
}
@@ -654,27 +624,16 @@ public class MixedTensor implements Tensor {
}
}
- private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder {
-
- private final TensorType type;
- private final double[] values;
-
- public DenseSubspaceBuilder(TensorType type, double[] values) {
- this.type = type;
- this.values = values;
- }
-
- @Override
- public TensorType type() { return type; }
+ private record DenseSubspaceBuilder(TensorType type, double[] values) implements IndexedTensor.DirectIndexBuilder {
@Override
public void cellByDirectIndex(long index, double value) {
- values[(int)index] = value;
+ values[(int) index] = value;
}
@Override
public void cellByDirectIndex(long index, float value) {
- values[(int)index] = value;
+ values[(int) index] = value;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index f1b3245ec80..8852bcd1ff3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -1,16 +1,16 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import java.util.Arrays;
+import com.yahoo.tensor.impl.Label;
/**
- * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
+ * An address to a subset of a tensors' cells, specifying a label for some, but not necessarily all, of the tensors
* dimensions.
*
* @author bratseth
*/
// Implementation notes:
-// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
+// - These are created in inner (though not innermost) loops, so they are implemented with minimal allocation.
// We also avoid non-essential error checking.
// - We can add support for string labels later without breaking the API
public class PartialAddress {
@@ -18,7 +18,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final Object[] labels;
+ private final long[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -35,15 +35,15 @@ public class PartialAddress {
public long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return asLong(labels[i]);
- return -1;
+ return labels[i];
+ return Tensor.invalidIndex;
}
/** Returns the label of this dimension, or null if no label is specified for it */
public String label(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
return null;
}
@@ -55,7 +55,7 @@ public class PartialAddress {
public String label(int i) {
if (i >= size())
throw new IllegalArgumentException("No label at position " + i + " in " + this);
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
}
public int size() { return dimensionNames.length; }
@@ -65,40 +65,14 @@ public class PartialAddress {
public TensorAddress asAddress(TensorType type) {
if (type.rank() != size())
throw new IllegalArgumentException(type + " has a different rank than " + this);
- if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) {
- long[] numericLabels = new long[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- long label = numericLabel(type.dimensions().get(i).name());
- if (label < 0)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- numericLabels[i] = label;
- }
- return TensorAddress.of(numericLabels);
- }
- else {
- String[] stringLabels = new String[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- String label = label(type.dimensions().get(i).name());
- if (label == null)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- stringLabels[i] = label;
- }
- return TensorAddress.of(stringLabels);
- }
- }
-
- private long asLong(Object label) {
- if (label instanceof Long) {
- return (Long) label;
- }
- else {
- try {
- return Long.parseLong(label.toString());
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Label '" + label + "' is not numeric");
- }
+ long[] numericLabels = new long[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ long label = numericLabel(type.dimensions().get(i).name());
+ if (label == Tensor.invalidIndex)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ numericLabels[i] = label;
}
+ return TensorAddress.of(numericLabels);
}
@Override
@@ -114,12 +88,12 @@ public class PartialAddress {
public static class Builder {
private String[] dimensionNames;
- private Object[] labels;
+ private long[] labels;
private int index = 0;
public Builder(int size) {
dimensionNames = new String[size];
- labels = new Object[size];
+ labels = new long[size];
}
public Builder add(String dimensionName, long label) {
@@ -131,7 +105,7 @@ public class PartialAddress {
public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
- labels[index] = label;
+ labels[index] = Label.toNumber(label);
index++;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 8a4179cdc1a..ac9dc4e4eca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import com.yahoo.tensor.functions.Expand;
+import com.yahoo.tensor.impl.Label;
import java.util.ArrayList;
import java.util.Arrays;
@@ -39,7 +40,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming;
* A multidimensional array which can be used in computations.
* <p>
* A tensor consists of a set of <i>dimension</i> names and a set of <i>cells</i> containing scalar <i>values</i>.
- * Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines
+ * Each cell is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines
* the location of that cell. Both dimensions and labels are string on the form of an identifier or integer.
* <p>
* The size of the set of dimensions of a tensor is called its <i>rank</i>.
@@ -55,6 +56,9 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming;
*/
public interface Tensor {
+ /** The constant signaling a nonexisting value in operations addressing tensor values by index. */
+ int invalidIndex = -1;
+
// ----------------- Accessors
TensorType type();
@@ -63,11 +67,25 @@ public interface Tensor {
default boolean isEmpty() { return size() == 0; }
/**
- * Returns the number of cells in this.
- * TODO Figure how to best return an int instead of a long
- * An int is large enough, and java is far better at int base loops than long
- **/
- long size();
+ * Returns the number of cells in this, allowing for very large tensors.
+ * Prefer sizeAsInt in implementations that cannot handle sizes outside the int range.
+ */
+ default long size() {
+ return sizeAsInt();
+ }
+
+ /**
+ * Returns the size of this as an int or throws an exception if it is too large to fit in an int.
+ * Prefer this over size() with implementations that only handle sizes in the int range.
+ *
+ * @throws IndexOutOfBoundsException if the size is too large to fit in an int
+ */
+ default int sizeAsInt() {
+ long size = size();
+ if (size > Integer.MAX_VALUE)
+ throw new IndexOutOfBoundsException("size = " + size + ", which is too large to fit in an int");
+ return (int) size;
+ }
/** Returns the value of a cell, or 0.0 if this cell does not exist */
double get(TensorAddress address);
@@ -75,6 +93,9 @@ public interface Tensor {
/** Returns true if this cell exists */
boolean has(TensorAddress address);
+ /** Returns the value at this address, or null of it does not exist. */
+ Double getAsDouble(TensorAddress address);
+
/**
* Returns the cell of this in some undefined order.
* A cell instances is only valid until next() is called.
@@ -97,7 +118,7 @@ public interface Tensor {
* @throws IllegalStateException if this does not have zero dimensions and one value
*/
default double asDouble() {
- if (type().dimensions().size() > 0)
+ if (!type().dimensions().isEmpty())
throw new IllegalStateException("Require a dimensionless tensor but has " + type());
if (size() == 0) return Double.NaN;
return valueIterator().next();
@@ -113,7 +134,7 @@ public interface Tensor {
/**
* Returns a new tensor where existing cells in this tensor have been
* modified according to the given operation and cells in the given map.
- * Cells in the map outside of existing cells are thus ignored.
+ * Cells in the map outside existing cells are thus ignored.
*
* @param op the modifying function
* @param cells the cells to modify
@@ -132,9 +153,9 @@ public interface Tensor {
/**
* Returns a new tensor where existing cells in this tensor have been
- * removed according to the given set of addresses. Only valid for sparse
+ * removed according to the given set of addresses. Only valid for mapped
* or mixed tensors. For mixed tensors, addresses are assumed to only
- * contain the sparse dimensions, as the entire dense subspace is removed.
+ * contain the mapped dimensions, as the entire indexed subspace is removed.
*
* @param addresses list of addresses to remove
* @return a new tensor where cells have been removed
@@ -484,11 +505,10 @@ public interface Tensor {
public TensorAddress getKey() { return address; }
/**
- * Returns the direct index which can be used to locate this cell, or -1 if not available.
- * This is for optimizations mapping between tensors where this is possible without creating a
- * TensorAddress.
+ * Returns the direct index which can be used to locate this cell, or Tensor.invalidIndex if not available.
+ * This is for optimizations mapping between tensors where this is possible without creating a TensorAddress.
*/
- long getDirectIndex() { return -1; }
+ long getDirectIndex() { return invalidIndex; }
/** Returns the value as a double */
@Override
@@ -537,8 +557,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
@@ -549,8 +569,8 @@ public interface Tensor {
/** Creates a suitable builder for the given type */
static Builder of(TensorType type, DimensionSizes dimensionSizes) {
- boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ boolean containsIndexed = type.hasIndexedDimensions();
+ boolean containsMapped = type.hasMappedDimensions();
if (containsIndexed && containsMapped)
return MixedTensor.Builder.of(type);
if (containsMapped)
@@ -608,7 +628,7 @@ public interface Tensor {
public TensorType type() { return tensorBuilder.type(); }
public CellBuilder label(String dimension, long label) {
- return label(dimension, String.valueOf(label));
+ return label(dimension, Label.fromNumber(label));
}
public Builder value(double cellValue) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index a1cb278c75a..4fa759668b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,10 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.Label;
+import com.yahoo.tensor.impl.TensorAddressAny;
+
import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
@@ -14,18 +17,20 @@ import java.util.stream.Collectors;
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
-
public static TensorAddress of(String[] labels) {
- return new StringTensorAddress(labels);
+ return TensorAddressAny.of(labels);
+ }
+
+ public static TensorAddress ofLabels(String... labels) {
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress ofLabels(String ... labels) {
- return new StringTensorAddress(labels);
+ public static TensorAddress of(long... labels) {
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress of(long ... labels) {
- return new NumericTensorAddress(labels);
+ public static TensorAddress of(int... labels) {
+ return TensorAddressAny.of(labels);
}
/** Returns the number of labels in this */
@@ -61,27 +66,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
@Override
- public int hashCode() {
- int result = 1;
- for (int i = 0; i < size(); i++) {
- if (label(i) != null)
- result = 31 * result + label(i).hashCode();
+ public String toString() {
+ StringBuilder sb = new StringBuilder("cell address (");
+ int size = size();
+ if (size > 0) {
+ sb.append(label(0));
+ for (int i = 1; i < size; i++) {
+ sb.append(',').append(label(i));
+ }
}
- return result;
- }
- @Override
- public boolean equals(Object o) {
- if (o == this) return true;
- if ( ! (o instanceof TensorAddress other)) return false;
- if (other.size() != this.size()) return false;
- for (int i = 0; i < this.size(); i++)
- if ( ! Objects.equals(this.label(i), other.label(i)))
- return false;
- return true;
+ return sb.append(')').toString();
}
- /** Returns this as a string on the appropriate form given the type */
+ /**
+ * Returns this as a string on the appropriate form given the type
+ */
public final String toString(TensorType type) {
StringBuilder b = new StringBuilder("{");
for (int i = 0; i < size(); i++) {
@@ -94,106 +94,78 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with appropriate quoting/escaping when necessary */
+ /**
+ * Returns a label as a string with appropriate quoting/escaping when necessary
+ */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
}
- private static String[] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
+ /** Returns an address with only some of the dimension. Ordering will also be according to indexMap */
+ public TensorAddress partialCopy(int[] indexMap) {
+ int[] labels = new int[indexMap.length];
+ for (int i = 0; i < labels.length; ++i) {
+ labels[i] = (int)numericLabel(indexMap[i]);
}
- return asStrings;
+ return TensorAddressAny.ofUnsafe(labels);
}
- private static String asString(long index) {
- return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
-
- private static final class StringTensorAddress extends TensorAddress {
-
- private final String[] labels;
-
- private StringTensorAddress(String ... labels) {
- this.labels = Arrays.copyOf(labels, labels.length);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return labels[i]; }
-
- @Override
- public long numericLabel(int i) {
- try {
- return Long.parseLong(labels[i]);
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
+ /** Creates a complete address by taking the mapped dimmensions from this and the indexed from the indexedPart */
+ public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int[] densePart) {
+ int[] labels = new int[dimensions.size()];
+ int mappedIndex = 0;
+ int indexedIndex = 0;
+ for (int i = 0; i < labels.length; i++) {
+ TensorType.Dimension d = dimensions.get(i);
+ if (d.isIndexed()) {
+ labels[i] = densePart[indexedIndex];
+ indexedIndex++;
+ } else {
+ labels[i] = (int)numericLabel(mappedIndex);
+ mappedIndex++;
}
}
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- String[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = TensorAddress.asString(label);
- return new StringTensorAddress(labels);
- }
-
-
- @Override
- public String toString() {
- return "cell address (" + String.join(",", labels) + ")";
- }
-
+ return TensorAddressAny.ofUnsafe(labels);
}
- private static final class NumericTensorAddress extends TensorAddress {
-
- private final long[] labels;
-
- private NumericTensorAddress(long[] labels) {
- this.labels = Arrays.copyOf(labels, labels.length);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return TensorAddress.asString(labels[i]); }
-
- @Override
- public long numericLabel(int i) { return labels[i]; }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- long[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = label;
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public String toString() {
- return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")";
+ /**
+ * Returns an address containing the mapped dimensions of this.
+ *
+ * @param mappedType the type of the mapped subset of the type this is an address of;
+ * which is also the type of the returned address
+ * @param dimensions all the dimensions of the type this is an address of
+ */
+ public TensorAddress mappedPartialAddress(TensorType mappedType, List<TensorType.Dimension> dimensions) {
+ if (dimensions.size() != size())
+ throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this);
+ TensorAddress.Builder builder = new TensorAddress.Builder(mappedType);
+ for (int i = 0; i < dimensions.size(); ++i) {
+ TensorType.Dimension dimension = dimensions.get(i);
+ if ( ! dimension.isIndexed())
+ builder.add(dimension.name(), (int)numericLabel(i));
}
-
+ return builder.build();
}
/** Builder of a tensor address */
public static class Builder {
final TensorType type;
- final String[] labels;
+ final int[] labels;
+
+ private static int[] createEmptyLabels(int size) {
+ int[] labels = new int[size];
+ Arrays.fill(labels, Tensor.invalidIndex);
+ return labels;
+ }
public Builder(TensorType type) {
- this(type, new String[type.dimensions().size()]);
+ this(type, createEmptyLabels(type.dimensions().size()));
}
- private Builder(TensorType type, String[] labels) {
+ private Builder(TensorType type, int[] labels) {
this.type = type;
this.labels = labels;
}
@@ -207,7 +179,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
var mappedSubtype = type.mappedSubtype();
if (mappedSubtype.rank() != 1)
throw new IllegalArgumentException("Cannot add a label without explicit dimension to a tensor of type " +
- type + ": Must have exactly one sparse dimension");
+ type + ": Must have exactly one mapped dimension");
add(mappedSubtype.dimensions().get(0).name(), label);
return this;
}
@@ -220,10 +192,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
public Builder add(String dimension, String label) {
Objects.requireNonNull(dimension, "dimension cannot be null");
Objects.requireNonNull(label, "label cannot be null");
- Optional<Integer> labelIndex = type.indexOfDimension(dimension);
- if ( labelIndex.isEmpty())
+ int labelIndex = type.indexOfDimensionAsInt(dimension);
+ if ( labelIndex < 0)
+ throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
+ labels[labelIndex] = Label.toNumber(label);
+ return this;
+ }
+
+ public Builder add(String dimension, long label) {
+ return add(dimension, Convert.safe2Int(label));
+ }
+ public Builder add(String dimension, int label) {
+ Objects.requireNonNull(dimension, "dimension cannot be null");
+ int labelIndex = type.indexOfDimensionAsInt(dimension);
+ if ( labelIndex < 0)
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
- labels[labelIndex.get()] = label;
+ labels[labelIndex] = label;
return this;
}
@@ -237,14 +221,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
void validate() {
for (int i = 0; i < labels.length; i++)
- if (labels[i] == null)
+ if (labels[i] == Tensor.invalidIndex)
throw new IllegalArgumentException("Missing a label for dimension '" +
type.dimensions().get(i).name() + "' for " + type);
}
public TensorAddress build() {
validate();
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
}
@@ -256,7 +240,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
super(type);
}
- private PartialBuilder(TensorType type, String[] labels) {
+ private PartialBuilder(TensorType type, int[] labels) {
super(type, labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index b30b664a5f7..6b81d023a9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -86,16 +87,20 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final List<Dimension> dimensions;
+ private final Set<String> dimensionNames;
private final TensorType mappedSubtype;
private final TensorType indexedSubtype;
+ private final int indexedUnBoundCount;
// only used to initialize the "empty" instance
private TensorType() {
this.valueType = Value.DOUBLE;
this.dimensions = List.of();
+ this.dimensionNames = Set.of();
this.mappedSubtype = this;
this.indexedSubtype = this;
+ indexedUnBoundCount = 0;
}
public TensorType(Value valueType, Collection<Dimension> dimensions) {
@@ -103,12 +108,25 @@ public class TensorType {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
+ ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>();
+ int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0;
+ for (Dimension dimension : dimensionList) {
+ namesbuilder.add(dimension.name());
+ Dimension.Type type = dimension.type();
+ switch (type) {
+ case indexedUnbound -> indexedUnBoundCount++;
+ case indexedBound -> indexedBoundCount++;
+ case mapped -> mappedCount++;
+ }
+ }
+ this.indexedUnBoundCount = indexedUnBoundCount;
+ dimensionNames = namesbuilder.build();
- if (dimensionList.stream().allMatch(Dimension::isIndexed)) {
+ if (mappedCount == 0) {
mappedSubtype = empty;
indexedSubtype = this;
}
- else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) {
+ else if ((indexedBoundCount + indexedUnBoundCount) == 0) {
mappedSubtype = this;
indexedSubtype = empty;
}
@@ -118,6 +136,11 @@ public class TensorType {
}
}
+ public boolean hasIndexedDimensions() { return indexedSubtype != empty; }
+ public boolean hasMappedDimensions() { return mappedSubtype != empty; }
+ public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); }
+ boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; }
+
static public Value combinedValueType(TensorType ... types) {
List<Value> valueTypes = new ArrayList<>();
for (TensorType type : types) {
@@ -161,7 +184,7 @@ public class TensorType {
/** Returns an immutable set of the names of the dimensions of this */
public Set<String> dimensionNames() {
- return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
+ return dimensionNames;
}
/** Returns the dimension with this name, or empty if not present */
@@ -176,6 +199,13 @@ public class TensorType {
return Optional.of(i);
return Optional.empty();
}
+ /** Returns the 0-base index of this dimension, or empty if it is not present */
+ public int indexOfDimensionAsInt(String dimension) {
+ for (int i = 0; i < dimensions.size(); i++)
+ if (dimensions.get(i).name().equals(dimension))
+ return i;
+ return Tensor.invalidIndex;
+ }
/* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
public Optional<Long> sizeOfDimension(String dimension) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 0e4fab95c87..9125b35ea5d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.Arrays;
import java.util.HashMap;
@@ -133,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return tensor;
}
else { // extend tensor with this dimension
- if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ if (tensor.type().hasMappedDimensions())
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
@@ -172,7 +173,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType concatType, long concatOffset, String concatDimension) {
long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
+ Arrays.fill(combinedLabels, Tensor.invalidIndex);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
@@ -191,7 +192,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.invalidIndex);
return toIndexes;
}
@@ -208,7 +209,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != from.numericLabel(i)) return false;
to[toIndex] = from.numericLabel(i);
}
}
@@ -354,21 +355,21 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
- String[] labels = new String[plan.resultType.rank()];
+ int[] labels = new int[plan.resultType.rank()];
int out = 0;
int m = 0;
int a = 0;
int b = 0;
for (var how : plan.combineHow) {
switch (how) {
- case left -> labels[out++] = leftOnly.label(a++);
- case right -> labels[out++] = rightOnly.label(b++);
- case both -> labels[out++] = match.label(m++);
- case concat -> labels[out++] = String.valueOf(concatDimIdx);
+ case left -> labels[out++] = (int) leftOnly.numericLabel(a++);
+ case right -> labels[out++] = (int) rightOnly.numericLabel(b++);
+ case both -> labels[out++] = (int) match.numericLabel(m++);
+ case concat -> labels[out++] = concatDimIdx;
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
@@ -398,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
CellVectorMapMap decompose(Tensor input, SplitHow how) {
var iter = input.cellIterator();
- String[] commonLabels = new String[(int)how.numCommon()];
- String[] separateLabels = new String[(int)how.numSeparate()];
+ int[] commonLabels = new int[(int)how.numCommon()];
+ int[] separateLabels = new int[(int)how.numSeparate()];
CellVectorMapMap result = new CellVectorMapMap();
while (iter.hasNext()) {
var cell = iter.next();
@@ -409,14 +410,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
- case common -> commonLabels[commonIdx++] = addr.label(i);
- case separate -> separateLabels[separateIdx++] = addr.label(i);
+ case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i);
+ case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i);
case concat -> ccDimIndex = addr.numericLabel(i);
default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i));
}
}
- TensorAddress commonAddr = TensorAddress.of(commonLabels);
- TensorAddress separateAddr = TensorAddress.of(separateLabels);
+ TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels);
+ TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels);
result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
}
return result;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 3b6e03186a3..b595b1a40cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 0)
+ if (!arguments.isEmpty())
throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
return this;
}
@@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
var result = new ArrayList<TensorFunction<NAMETYPE>>();
for (var fun : cells.values()) {
- fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
@@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
super(type);
- if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
+ if ( ! type.hasOnlyIndexedBoundDimensions())
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
"only indexed, bound dimensions, but this has " + type);
this.cells = List.copyOf(cells);
@@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() {
var result = new ArrayList<TensorFunction<NAMETYPE>>();
for (var fun : cells) {
- fun.asTensorFunction().ifPresent(tf -> result.add(tf));
+ fun.asTensorFunction().ifPresent(result::add);
}
return result;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 4c92e1e57a2..fb345264f56 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,8 +12,11 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -113,7 +116,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
- long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
@@ -128,8 +131,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
var key = aCell.getKey();
- if (b.has(key)) {
- builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ Double bVal = b.getAsDouble(key);
+ if (bVal != null) {
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
}
}
return builder.build();
@@ -144,7 +148,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
- if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
+ if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
@@ -169,7 +173,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder,
DoubleBinaryOperator combinator) {
- long joinedLength = Math.min(subspaceSize, superspaceSize);
+ int joinedLength = (int)Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -204,12 +208,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
- TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
- if (subspace.has(subaddress)) {
- double subspaceValue = subspace.get(subaddress);
+ TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
+ Double subspaceValue = subspace.getAsDouble(subaddress);
+ if (subspaceValue != null) {
builder.cell(supercell.getKey(),
- reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
- : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ reversedArgumentOrder
+ ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
}
return builder.build();
@@ -223,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return subspaceIndexes;
}
- private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
- String[] subspaceLabels = new String[subspaceIndexes.length];
- for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
- return TensorAddress.of(subspaceLabels);
- }
-
/** Slow join which works for any two tensors */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
@@ -250,8 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
- Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
- Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
+ int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
+ Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
@@ -262,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
- PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions);
+ PartialAddress matchingBCells = sharedDimensionSize > 0
+ ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
+ : empty;
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
@@ -274,11 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
}
- private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
- PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
- for (int i = 0; i < addressType.dimensions().size(); i++)
- if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
+ private static final PartialAddress empty = new PartialAddress.Builder(0).build();
+ private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
+ Set<String> retainDimensions, int sharedDimensionSize) {
+ PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
+ for (int i = 0; i < addressType.dimensions().size(); i++) {
+ String dimension = addressType.dimensions().get(i).name();
+ if (retainDimensions.contains(dimension))
+ builder.add(dimension, address.numericLabel(i));
+ }
return builder.build();
}
@@ -330,19 +335,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
int[] bIndexesInJoined = mapIndexes(b.type(), joinedType);
// Iterate once through the smaller tensor and construct a hash map for common dimensions
- Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>();
+ Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
- aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>());
- aCellsByCommonAddress.get(partialCommonAddress).add(aCell);
+ TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
+ aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
// Iterate once through the larger tensor and use the hash map to find joinable cells
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell bCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
+ TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
bCell.getKey(), bIndexesInJoined, joinedType);
@@ -358,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
/**
- * Returns the an array having one entry in order for each dimension of fromType
+ * Returns an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
* That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
@@ -367,17 +371,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
static int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name());
return toIndexes;
}
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
- String[] joinedLabels = new String[joinedType.dimensions().size()];
+ int[] joinedLabels = new int[joinedType.dimensions().size()];
+ Arrays.fill(joinedLabels, Tensor.invalidIndex);
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return TensorAddress.of(joinedLabels);
+ return TensorAddressAny.ofUnsafe(joinedLabels);
}
/**
@@ -386,11 +391,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
- for (int i = 0; i < from.size(); i++) {
+ private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) {
+ for (int i = 0, size = from.size(); i < size; i++) {
int toIndex = indexMap[i];
- if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false;
- to[toIndex] = from.label(i);
+ int label = Convert.safe2Int(from.numericLabel(i));
+ if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != label)
+ return false;
+ to[toIndex] = label;
}
return true;
}
@@ -412,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return typeBuilder.build();
}
- private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
- TensorAddress address = cell.getKey();
- String[] labels = new String[indexMap.length];
- for (int i = 0; i < labels.length; ++i) {
- labels[i] = address.label(indexMap[i]);
- }
- return TensorAddress.of(labels);
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
index c87ef42976d..aa9602339e9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
@@ -98,9 +98,9 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction
for (int i = 0; i < inputType.dimensions().size(); i++) {
var dim = inputType.dimensions().get(i);
if (dim.isMapped()) {
- mapAddrBuilder.add(dim.name(), fullAddr.label(i));
+ mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
} else {
- idxAddrBuilder.add(dim.name(), fullAddr.label(i));
+ idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
}
}
var mapAddr = mapAddrBuilder.build();
@@ -123,11 +123,11 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction
var addrBuilder = new TensorAddress.Builder(outputType);
for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) {
var dim = inputTypeMapped.dimensions().get(i);
- addrBuilder.add(dim.name(), mappedAddr.label(i));
+ addrBuilder.add(dim.name(), mappedAddr.numericLabel(i));
}
for (int i = 0; i < denseOutputDims.size(); i++) {
var dim = denseOutputDims.get(i);
- addrBuilder.add(dim.name(), denseAddr.label(i));
+ addrBuilder.add(dim.name(), denseAddr.numericLabel(i));
}
builder.cell(addrBuilder.build(), cell.getValue());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index 59394785382..ddad91dc060 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
var key = aCell.getKey();
- if (! b.has(key)) {
+ Double bVal = b.getAsDouble(key);
+ if (bVal == null) {
builder.cell(key, aCell.getValue());
} else if (combinator != null) {
- builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key)));
+ builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal));
}
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 8cf88610599..947fd6e0012 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -1,6 +1,8 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.DirectIndexedAddress;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -9,16 +11,15 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.impl.Convert;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
-import java.util.Set;
/**
* The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
@@ -112,32 +113,84 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
- if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
+ if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
- dimensions + ": Not all those dimensions are present in this tensor");
+ dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
- if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
+ if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) {
if (argument.isEmpty())
return Tensor.from(0.0);
else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
- return reduceIndexedVector((IndexedTensor)argument, aggregator);
+ return reduceIndexedVector((IndexedTensor) argument, aggregator);
else
return reduceAllGeneral(argument, aggregator);
+ }
TensorType reducedType = outputType(argument.type(), dimensions);
+ int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions);
+ int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce);
+ if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) {
+ return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator);
+ } else {
+ return reduceGeneral(argument, reducedType, indexesToKeep, aggregator);
+ }
+ }
+
+ private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) {
+ int currentIndex = reduce[reduceIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+ if (reduceIndex + 1 < reduce.length) {
+ int nextDimension = reduceIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ reduce(argument, aggregator, address, reduce, nextDimension);
+ }
+ } else {
+ address.setIndex(currentIndex, 0);
+ long increment = address.getStride(currentIndex);
+ long directIndex = address.getDirectIndex();
+ for (int i = 0; i < dimSize; i++) {
+ aggregator.aggregate(argument.get(directIndex + i * increment));
+ }
+ }
+ }
+
+ private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) {
+ if (keepIndex < toKeep.length) {
+ int currentIndex = toKeep[keepIndex];
+ int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex));
+
+ int nextKeep = keepIndex + 1;
+ for (int i = 0; i < dimSize; i++) {
+ address.setIndex(currentIndex, i);
+ destAddress.setIndex(keepIndex, i);
+ reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce);
+ }
+ } else {
+ ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
+ reduce(argument, valueAggregator, address, toReduce, 0);
+ builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes());
+ }
+
+ }
- // Reduce cells
- int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions);
+ private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) {
+
+ var reducedBuilder = IndexedTensor.Builder.of(reducedType);
+ DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType));
+ reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce);
+ return reducedBuilder.build();
+ }
+
+ private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) {
// TODO cells.size() is most likely an overestimate, and might need a better heuristic
// But the upside is larger than the downside.
- Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size());
+ Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey());
- ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
- if (aggr == null)
- aggr = aggregatingCells.get(reducedAddress);
+ TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep);
+ ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator));
aggr.aggregate(cell.getValue());
}
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
@@ -146,39 +199,43 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return reducedBuilder.build();
}
- private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) {
- Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2);
- for (String dimensionToRemove : dimensions)
- indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
- int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()];
+
+ private static int[] createIndexesToReduce(TensorType tensorType, List<String> dimensions) {
+ int[] indexesToReduce = new int[dimensions.size()];
+ for (int i = 0; i < dimensions.size(); i++) {
+ indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get();
+ }
+ return indexesToReduce;
+ }
+ private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) {
+ int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length];
int toKeepIndex = 0;
for (int i = 0; i < argumentType.rank(); i++) {
- if ( ! indexesToRemove.contains(i))
+ if ( ! contains(indexesToReduce, i))
indexesToKeep[toKeepIndex++] = i;
}
return indexesToKeep;
}
-
- private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) {
- String[] reducedLabels = new String[indexesToKeep.length];
- int reducedLabelIndex = 0;
- for (int toKeep : indexesToKeep)
- reducedLabels[reducedLabelIndex++] = address.label(toKeep);
- return TensorAddress.of(reducedLabels);
+ private static boolean contains(int[] list, int key) {
+ for (int candidate : list) {
+ if (candidate == key) return true;
+ }
+ return false;
}
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
valueAggregator.aggregate(i.next());
- return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
+ return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
}
private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
- for (int i = 0; i < argument.dimensionSizes().size(0); i++)
+ int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0));
+ for (int i = 0; i < dimensionSize ; i++)
valueAggregator.aggregate(argument.get(i));
- return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
+ return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build();
}
static abstract class ValueAggregator {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index aece782d296..2d5a0518747 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
return false;
if ( ! (a instanceof IndexedTensor))
return false;
- if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ if ( ! (a.type().hasOnlyIndexedBoundDimensions()))
return false;
if ( ! (b instanceof IndexedTensor))
return false;
- if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ if ( ! (b.type().hasOnlyIndexedBoundDimensions()))
return false;
TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index a2a3874eced..05db61f5395 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -35,7 +35,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
- if (fromDimensions.size() < 1)
+ if (fromDimensions.isEmpty())
throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
if (fromDimensions.size() != toDimensions.size())
throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
@@ -89,7 +89,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
String dimensionName = tensor.type().dimensions().get(i).name();
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
- toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
+ toIndexes[renamedType.indexOfDimension(newDimensionName).get()] = i;
}
// avoid building a new tensor if dimensions can simply be renamed
@@ -100,7 +100,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress renamedAddress = rename(cell.getKey(), toIndexes);
+ TensorAddress renamedAddress = cell.getKey().partialCopy(toIndexes);
builder.cell(renamedAddress, cell.getValue());
}
return builder.build();
@@ -118,13 +118,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return true;
}
- private TensorAddress rename(TensorAddress address, int[] toIndexes) {
- String[] reorderedLabels = new String[toIndexes.length];
- for (int i = 0; i < toIndexes.length; i++)
- reorderedLabels[toIndexes[i]] = address.label(i);
- return TensorAddress.of(reorderedLabels);
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 807f56b1a49..38ac42a5f1f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -131,7 +131,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
for (int i = 0; i < address.size(); i++) {
String dimension = type.dimensions().get(i).name();
if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
- b.add(dimension, address.label(i));
+ b.add(dimension, (int)address.numericLabel(i));
}
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java
new file mode 100644
index 00000000000..e2cb64fdd1f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java
@@ -0,0 +1,16 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.impl;
+
+/**
+ * Utility to make common conversions safe
+ *
+ * @author baldersheim
+ */
+public class Convert {
+ public static int safe2Int(long value) {
+ if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) {
+ throw new IndexOutOfBoundsException("value = " + value + ", which is too large to fit in an int");
+ }
+ return (int) value;
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
new file mode 100644
index 00000000000..7c1e8646245
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
@@ -0,0 +1,83 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.Tensor;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+/**
+ * A label is a value of a mapped dimension of a tensor.
+ * This class provides a mapping of labels to numbers which allow for more efficient computation with
+ * mapped tensor dimensions.
+ *
+ * @author baldersheim
+ */
+public class Label {
+
+ private static final String[] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
+
+ private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>();
+
+ // Index 0 is unused, that is a valid positive number
+ // 1(-1) is reserved for the Tensor.INVALID_INDEX
+ private static volatile String[] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"};
+ private static int numUniqeStrings = 2;
+
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String[] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static int addNewUniqueString(String s) {
+ synchronized (string2Enum) {
+ if (numUniqeStrings >= uniqueStrings.length) {
+ uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2);
+ }
+ uniqueStrings[numUniqeStrings] = s;
+ return -numUniqeStrings++;
+ }
+ }
+
+ private static String asNumericString(long index) {
+ return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
+ private static boolean validNumericIndex(String s) {
+ if (s.isEmpty() || ((s.length() > 1) && (s.charAt(0) == '0'))) return false;
+ for (int i = 0; i < s.length(); i++) {
+ char c = s.charAt(i);
+ if ((c < '0') || (c > '9')) return false;
+ }
+ return true;
+ }
+
+ public static int toNumber(String s) {
+ if (s == null) { return Tensor.invalidIndex; }
+ try {
+ if (validNumericIndex(s)) {
+ return Integer.parseInt(s);
+ }
+ } catch (NumberFormatException e) {
+ }
+ return string2Enum.computeIfAbsent(s, Label::addNewUniqueString);
+ }
+
+ public static String fromNumber(int v) {
+ if (v >= 0) {
+ return asNumericString(v);
+ } else {
+ if (v == Tensor.invalidIndex) { return null; }
+ return uniqueStrings[-v];
+ }
+ }
+
+ public static String fromNumber(long v) {
+ return fromNumber(Convert.safe2Int(v));
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
new file mode 100644
index 00000000000..2e70811a67c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
@@ -0,0 +1,154 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import static com.yahoo.tensor.impl.Convert.safe2Int;
+import static com.yahoo.tensor.impl.Label.toNumber;
+import static com.yahoo.tensor.impl.Label.fromNumber;
+
+/**
+ * Parent of tensor address family centered around each dimension as int.
+ * A positive number represents a numeric index usable as a direect addressing.
+ * - 1 is representing an invalid/null address
+ * Other negative numbers are an enumeration maintained in {@link Label}
+ *
+ * @author baldersheim
+ */
+abstract public class TensorAddressAny extends TensorAddress {
+
+ @Override
+ public String label(int i) {
+ return fromNumber((int)numericLabel(i));
+ }
+
+ public static TensorAddress of() {
+ return TensorAddressEmpty.empty;
+ }
+
+ public static TensorAddress of(String label) {
+ return new TensorAddressAny1(toNumber(label));
+ }
+
+ public static TensorAddress of(String label0, String label1) {
+ return new TensorAddressAny2(toNumber(label0), toNumber(label1));
+ }
+
+ public static TensorAddress of(String label0, String label1, String label2) {
+ return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2));
+ }
+
+ public static TensorAddress of(String label0, String label1, String label2, String label3) {
+ return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3));
+ }
+
+ public static TensorAddress of(String[] labels) {
+ int[] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = toNumber(labels[i]);
+ }
+ return ofUnsafe(labelsAsInt);
+ }
+
+ public static TensorAddress of(int label) {
+ return new TensorAddressAny1(sanitize(label));
+ }
+
+ public static TensorAddress of(int label0, int label1) {
+ return new TensorAddressAny2(sanitize(label0), sanitize(label1));
+ }
+
+ public static TensorAddress of(int label0, int label1, int label2) {
+ return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2));
+ }
+
+ public static TensorAddress of(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3));
+ }
+
+ public static TensorAddress of(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> new TensorAddressAny1(sanitize(labels[0]));
+ case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1]));
+ case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]));
+ case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3]));
+ default -> {
+ for (int i = 0; i < labels.length; i++) {
+ sanitize(labels[i]);
+ }
+ yield new TensorAddressAnyN(labels);
+ }
+ };
+ }
+
+ public static TensorAddress of(long label) {
+ return of(safe2Int(label));
+ }
+
+ public static TensorAddress of(long label0, long label1) {
+ return of(safe2Int(label0), safe2Int(label1));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2, long label3) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3));
+ }
+
+ public static TensorAddress of(long ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(safe2Int(labels[0]));
+ case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]));
+ case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]));
+ case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3]));
+ default -> {
+ int[] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = safe2Int(labels[i]);
+ }
+ yield of(labelsAsInt);
+ }
+ };
+ }
+
+ private static TensorAddress ofUnsafe(int label) {
+ return new TensorAddressAny1(label);
+ }
+
+ private static TensorAddress ofUnsafe(int label0, int label1) {
+ return new TensorAddressAny2(label0, label1);
+ }
+
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2) {
+ return new TensorAddressAny3(label0, label1, label2);
+ }
+
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(label0, label1, label2, label3);
+ }
+
+ public static TensorAddress ofUnsafe(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(labels[0]);
+ case 2 -> ofUnsafe(labels[0], labels[1]);
+ case 3 -> ofUnsafe(labels[0], labels[1], labels[2]);
+ case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]);
+ default -> new TensorAddressAnyN(labels);
+ };
+ }
+
+ private static int sanitize(int label) {
+ if (label < Tensor.invalidIndex) {
+ throw new IndexOutOfBoundsException("cell label " + label + " must be positive");
+ }
+ return label;
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
new file mode 100644
index 00000000000..a9be6173781
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
@@ -0,0 +1,41 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * A one-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressAny1 extends TensorAddressAny {
+
+ private final int label;
+
+ TensorAddressAny1(int label) { this.label = label; }
+
+ @Override public int size() { return 1; }
+
+ @Override
+ public long numericLabel(int i) {
+ if (i == 0) {
+ return label;
+ }
+ throw new IndexOutOfBoundsException("Index is not zero: " + i);
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label));
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override public int hashCode() { return Math.abs(label); }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny1 any) && (label == any.label);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
new file mode 100644
index 00000000000..43f65d495cf
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
@@ -0,0 +1,53 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * A two-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressAny2 extends TensorAddressAny {
+
+ private final int label0, label1;
+
+ TensorAddressAny2(int label0, int label1) {
+ this.label0 = label0;
+ this.label1 = label1;
+ }
+
+ @Override public int size() { return 2; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1);
+ case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0)));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
new file mode 100644
index 00000000000..c22ff47b3c4
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
@@ -0,0 +1,61 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * A three-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressAny3 extends TensorAddressAny {
+
+ private final int label0, label1, label2;
+
+ TensorAddressAny3(int label0, int label1, int label2) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ }
+
+ @Override public int size() { return 3; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2);
+ case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2);
+ case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny3 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
new file mode 100644
index 00000000000..6eb6b9216bf
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
@@ -0,0 +1,66 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * A four-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressAny4 extends TensorAddressAny {
+
+ private final int label0, label1, label2, label3;
+
+ TensorAddressAny4(int label0, int label1, int label2, int label3) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ this.label3 = label3;
+ }
+
+ @Override public int size() { return 4; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ case 3 -> label3;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3);
+ case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3);
+ case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3);
+ case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) |
+ (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny4 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2) &&
+ (label3 == any.label3);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
new file mode 100644
index 00000000000..d5bac62bf18
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
@@ -0,0 +1,53 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Arrays;
+
+import static java.lang.Math.abs;
+
+/**
+ * An n-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressAnyN extends TensorAddressAny {
+
+ private final int[] labels;
+
+ TensorAddressAnyN(int[] labels) {
+ if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label");
+ this.labels = labels;
+ }
+
+ @Override public int size() { return labels.length; }
+
+ @Override public long numericLabel(int i) { return labels[i]; }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ int[] copy = Arrays.copyOf(labels, labels.length);
+ copy[labelIndex] = Convert.safe2Int(label);
+ return new TensorAddressAnyN(copy);
+ }
+
+ @Override public int hashCode() {
+ int hash = abs(labels[0]);
+ for (int i = 0; i < size(); i++) {
+ hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash)));
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (! (o instanceof TensorAddressAnyN any) || (size() != any.size())) return false;
+ for (int i = 0; i < size(); i++) {
+ if (labels[i] != any.labels[i]) return false;
+ }
+ return true;
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
new file mode 100644
index 00000000000..eb7e62e913b
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
@@ -0,0 +1,33 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * A zero-dimensional address.
+ *
+ * @author baldersheim
+ */
+final class TensorAddressEmpty extends TensorAddressAny {
+
+ static TensorAddress empty = new TensorAddressEmpty();
+
+ private TensorAddressEmpty() {}
+
+ @Override public int size() { return 0; }
+
+ @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override
+ public int hashCode() { return 0; }
+
+ @Override
+ public boolean equals(Object o) { return o instanceof TensorAddressEmpty; }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java
new file mode 100644
index 00000000000..6b004bf2d02
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java
@@ -0,0 +1,6 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+@ExportPackage
+package com.yahoo.tensor.impl;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index ca9527fd681..32e74c0f132 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -56,22 +56,22 @@ public class DenseBinaryFormat implements BinaryFormat {
}
private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putDouble(tensor.get(i));
}
private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putFloat(tensor.getFloat(i));
}
private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
}
private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
- for (int i = 0; i < tensor.size(); i++)
+ for (int i = 0; i < tensor.sizeAsInt(); i++)
buffer.put((byte) tensor.getFloat(i));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 444ce02b14a..5598690e0bf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -16,15 +16,7 @@ import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.Name;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Slice;
-
-import java.util.ArrayList;
-import java.util.HashSet;
import java.util.Iterator;
-import java.util.List;
-import java.util.Set;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -60,8 +52,7 @@ public class JsonFormat {
// Short form for a single mapped dimension
Cursor parent = root == null ? slime.setObject() : root.setObject("cells");
encodeSingleDimensionCells((MappedTensor) tensor, parent);
- } else if (tensor instanceof MixedTensor &&
- tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) {
+ } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) {
// Short form for a mixed tensor
boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1;
Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() )
@@ -143,9 +134,9 @@ public class JsonFormat {
}
private static void encodeBlocks(MixedTensor tensor, Cursor cursor) {
- var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped())
+ var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped)
.map(d -> TensorType.Dimension.mapped(d.name())).toList();
- if (mappedDimensions.size() < 1) {
+ if (mappedDimensions.isEmpty()) {
throw new IllegalArgumentException("Should be ensured by caller");
}
@@ -179,23 +170,6 @@ public class JsonFormat {
cursor.setDouble(field, value);
}
- private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) {
- TensorAddress.Builder builder = new TensorAddress.Builder(subType);
- for (TensorType.Dimension dim : subType.dimensions()) {
- builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()).
- orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index"))));
- }
- return builder.build();
- }
-
- private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) {
- List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size());
- for (int i = 0; i < subAddress.size(); ++i) {
- sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i)));
- }
- return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate();
- }
-
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
@@ -204,7 +178,7 @@ public class JsonFormat {
if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
- else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
+ else if (root.field("values").valid() && ! builder.type().hasMappedDimensions())
decodeValuesAtTop(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
@@ -298,14 +272,14 @@ public class JsonFormat {
/** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */
private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
- boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
- boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
+ boolean hasIndexed = builder.type().hasIndexedDimensions();
+ boolean hasMapped = builder.type().hasMappedDimensions();
if (isArrayOfObjects(root))
decodeCells(root, builder);
else if ( ! hasMapped)
decodeValuesAtTop(root, builder);
- else if (hasMapped && hasIndexed)
+ else if (hasIndexed)
decodeBlocks(root, builder);
else
decodeCells(root, builder);
@@ -423,9 +397,7 @@ public class JsonFormat {
if (decoded.length == 0) {
throw new IllegalArgumentException("The block value string does not contain any values");
}
- for (int i = 0; i < decoded.length; i++) {
- values[i] = decoded[i];
- }
+ System.arraycopy(decoded, 0, values, 0, decoded.length);
} else {
throw new IllegalArgumentException("Expected a block to contain an array of values");
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index bdeb9add41a..3a117e41461 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -48,7 +48,7 @@ class SparseBinaryFormat implements BinaryFormat {
}
private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) {
- buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation
+ buffer.putInt1_4Bytes(tensor.sizeAsInt()); // XXX: Size truncation
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index d4b18c73f11..0a5c713f3e2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -55,8 +55,8 @@ public class TypedBinaryFormat {
}
private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) {
- boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
- boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMappedDimensions = tensor.type().hasMappedDimensions();
+ boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions();
boolean isMixed = hasMappedDimensions && hasIndexedDimensions;
// TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead