diff options
Diffstat (limited to 'vespajlib/src/main/java')
35 files changed, 1205 insertions, 554 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java index 92a9ed26085..7a3d34eca7b 100644 --- a/vespajlib/src/main/java/com/yahoo/compress/Hasher.java +++ b/vespajlib/src/main/java/com/yahoo/compress/Hasher.java @@ -8,8 +8,25 @@ import net.openhft.hashing.LongHashFunction; * @author baldersheim */ public class Hasher { + private final LongHashFunction hasher; /** Uses net.openhft.hashing.LongHashFunction.xx3() */ public static long xxh3(byte [] data) { return LongHashFunction.xx3().hashBytes(data); } + public static long xxh3(byte [] data, long seed) { + return LongHashFunction.xx3(seed).hashBytes(data); + } + + private Hasher(LongHashFunction hasher) { + this.hasher = hasher; + } + public static Hasher withSeed(long seed) { + return new Hasher(LongHashFunction.xx3(seed)); + } + public long hash(long v) { + return hasher.hashLong(v); + } + public long hash(String s) { + return hasher.hashChars(s); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 83a625f72ac..640fa609432 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -11,10 +11,19 @@ import java.util.Arrays; public final class DimensionSizes { private final long[] sizes; + private final long[] productOfSizesFromHereOn; + private final long totalSize; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; builder.sizes = null; // invalidate builder to avoid copying the array + this.productOfSizesFromHereOn = new long[sizes.length]; + long product = 1; + for (int i = sizes.length; i-- > 0; ) { + productOfSizesFromHereOn[i] = product; + product *= sizes[i]; + } + this.totalSize = product; } /** @@ -49,10 +58,11 @@ public final class DimensionSizes { /** Returns the product of the sizes of this */ public long totalSize() { - long productSize = 1; - for (long dimensionSize : sizes ) - productSize *= dimensionSize; - return productSize; + return totalSize; + } + + long productOfDimensionsAfter(int afterIndex) { + return productOfSizesFromHereOn[afterIndex]; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java new file mode 100644 index 00000000000..cda3be47ddb --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java @@ -0,0 +1,55 @@ +package com.yahoo.tensor; + +/** + * Utility class for efficient access and iteration along dimensions in Indexed tensors. + * Usage: Use setIndex to lock the indexes of the dimensions that don't change in this iteration. + * long base = addr.getDirectIndex(); + * long stride = addr.getStride(dimension) + * i = 0...size_of_dimension + * double value = tensor.get(base + i * stride); + * + * @author baldersheim + */ +public final class DirectIndexedAddress { + + private final DimensionSizes sizes; + private final int[] indexes; + private long directIndex; + + private DirectIndexedAddress(DimensionSizes sizes) { + this.sizes = sizes; + indexes = new int[sizes.dimensions()]; + directIndex = 0; + } + + public static DirectIndexedAddress of(DimensionSizes sizes) { + return new DirectIndexedAddress(sizes); + } + + /** Sets the current index of a dimension */ + public void setIndex(int dimension, int index) { + if (index < 0 || index >= sizes.size(dimension)) { + throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">"); + } + int diff = index - indexes[dimension]; + directIndex += getStride(dimension) * diff; + indexes[dimension] = index; + } + + /** Retrieve the index that can be used for direct lookup in an indexed tensor. */ + public long getDirectIndex() { return directIndex; } + + public long [] getIndexes() { + long[] asLong = new long[indexes.length]; + for (int i=0; i < indexes.length; i++) { + asLong[i] = indexes[i]; + } + return asLong; + } + + /** returns the stride to be used for the given dimension */ + public long getStride(int dimension) { + return sizes.productOfDimensionsAfter(dimension); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 548d39dd767..53f50fc4d02 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -22,6 +22,10 @@ class IndexedDoubleTensor extends IndexedTensor { return values.length; } + /** Once we can store more cells than an int we should drop this method. */ + @Override + public int sizeAsInt() { return values.length; } + @Override public double get(long valueIndex) { return values[(int)valueIndex]; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 26560a70ac4..3085ef1a843 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -18,9 +18,11 @@ class IndexedFloatTensor extends IndexedTensor { } @Override - public long size() { - return values.length; - } + public long size() { return values.length; } + + /** Once we can store more cells than an int we should drop this. */ + @Override + public int sizeAsInt() { return values.length; } @Override public double get(long valueIndex) { return getFloat(valueIndex); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6a879fa533b..fc0473c635a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -90,9 +90,13 @@ public abstract class IndexedTensor implements Tensor { * @throws IllegalArgumentException if any of the indexes are out of bound or a wrong number of indexes are given */ public double get(long ... indexes) { - return get((int)toValueIndex(indexes, dimensionSizes)); + return get(toValueIndex(indexes, dimensionSizes)); } + public double get(DirectIndexedAddress address) { + return get(address.getDirectIndex()); + } + public DirectIndexedAddress directAddress() { return DirectIndexedAddress.of(dimensionSizes); } /** * Returns the value at the given indexes as a float * @@ -108,7 +112,7 @@ public abstract class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return get((int)toValueIndex(address, dimensionSizes, type)); + return get(toValueIndex(address, dimensionSizes, type)); } catch (IllegalArgumentException e) { return 0.0; @@ -116,6 +120,17 @@ public abstract class IndexedTensor implements Tensor { } @Override + public Double getAsDouble(TensorAddress address) { + try { + long index = toValueIndex(address, dimensionSizes, type); + if (index < 0 || size() <= index) return null; + return get(index); + } catch (IllegalArgumentException e) { + return null; + } + } + + @Override public boolean has(TensorAddress address) { try { long index = toValueIndex(address, dimensionSizes, type); @@ -150,30 +165,22 @@ public abstract class IndexedTensor implements Tensor { for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) throw new IllegalArgumentException(Arrays.toString(indexes) + " are not within bounds"); - valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; + valueIndex += sizes.productOfDimensionsAfter(i) * indexes[i]; } return valueIndex; } static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) { - if (address.isEmpty()) return 0; - long valueIndex = 0; - for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + for (int i = 0, size = address.size(); i < size; i++) { + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * label; } return valueIndex; } - private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - long product = 1; - for (int i = afterIndex + 1; i < sizes.dimensions(); i++) - product *= sizes.size(i); - return product; - } - void throwOnIncompatibleType(TensorType type) { if ( ! this.type().isRenamableTo(type)) throw new IllegalArgumentException("Can not change type from " + this.type() + " to " + type + @@ -227,7 +234,7 @@ public abstract class IndexedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { @@ -250,8 +257,7 @@ public abstract class IndexedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (tensor.type().valueType()) { @@ -264,8 +270,7 @@ public abstract class IndexedTensor implements Tensor { } // end bracket and comma - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } if (index == maxCells && index < tensor.size()) b.append(", ...]"); @@ -286,7 +291,7 @@ public abstract class IndexedTensor implements Tensor { } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -300,7 +305,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -314,7 +319,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -327,14 +332,13 @@ public abstract class IndexedTensor implements Tensor { */ public static Builder of(TensorType type, DimensionSizes sizes) { validate(type, sizes); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -348,14 +352,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, float[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } /** @@ -369,14 +372,13 @@ public abstract class IndexedTensor implements Tensor { public static Builder of(TensorType type, DimensionSizes sizes, double[] values) { validate(type, sizes); validateSizes(sizes, values.length); - switch (type.valueType()) { - case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); - case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); - default: - throw new IllegalStateException("Unexpected value type " + type.valueType()); - } + return switch (type.valueType()) { + case DOUBLE -> new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); + case FLOAT -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case BFLOAT16 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + case INT8 -> new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values); + default -> throw new IllegalStateException("Unexpected value type " + type.valueType()); + }; } private static void validateSizes(DimensionSizes sizes, int length) { @@ -518,7 +520,7 @@ public abstract class IndexedTensor implements Tensor { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, - offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, + offset + sizes.productOfDimensionsAfter(currentDimensionIndex) * i, (List<Object>) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values for (long i = 0; i < currentDimension.size(); i++) { @@ -623,11 +625,11 @@ public abstract class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator<Double> { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override @@ -889,8 +891,8 @@ public abstract class IndexedTensor implements Tensor { private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { long size = 1; - for (int iterateDimension : iterateDimensions) - size *= sizes.size(iterateDimension); + for (int i = 0; i < iterateDimensions.size(); i++) + size *= sizes.size(iterateDimensions.get(i)); return size; } @@ -1056,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor { /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { - private long lastComputedSourceValueIndex = -1; + private long lastComputedSourceValueIndex = Tensor.invalidIndex; private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); @@ -1091,8 +1093,8 @@ public abstract class IndexedTensor implements Tensor { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.sourceStep = productOfDimensionsAfter(iterateDimension, sourceSizes); - this.iterationStep = productOfDimensionsAfter(iterateDimension, iterateSizes); + this.sourceStep = sourceSizes.productOfDimensionsAfter(iterateDimension); + this.iterationStep = iterateSizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; @@ -1156,7 +1158,7 @@ public abstract class IndexedTensor implements Tensor { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; - this.step = productOfDimensionsAfter(iterateDimension, sizes); + this.step = sizes.productOfDimensionsAfter(iterateDimension); // Initialize to the (virtual) position before the first cell indexes[iterateDimension]--; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index e196569b18f..3e0df5f2261 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap; import java.util.Iterator; import java.util.Map; import java.util.Set; -import java.util.function.DoubleBinaryOperator; /** * A sparse implementation of a tensor backed by a Map of cells to values. @@ -31,6 +30,10 @@ public class MappedTensor implements Tensor { @Override public long size() { return cells.size(); } + /** Once we can store more cells than an int we should drop this. */ + @Override + public int sizeAsInt() { return cells.size(); } + @Override public double get(TensorAddress address) { return cells.getOrDefault(address, 0.0); } @@ -38,6 +41,9 @@ public class MappedTensor implements Tensor { public boolean has(TensorAddress address) { return cells.containsKey(address); } @Override + public Double getAsDouble(TensorAddress address) { return cells.get(address); } + + @Override public Iterator<Cell> cellIterator() { return new CellIteratorAdaptor(cells.entrySet().iterator()); } @Override @@ -79,7 +85,7 @@ public class MappedTensor implements Tensor { @Override public String toAbbreviatedString(boolean withType, boolean shortForms) { - return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + return toString(withType, shortForms, Math.max(2, 10 / (type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() + 1))); } private String toString(boolean withType, boolean shortForms, long maxCells) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 5d5a5f74063..65c6677e7e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -2,12 +2,13 @@ package com.yahoo.tensor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -28,7 +29,6 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int denseSubspaceSize; // XXX consider using "record" instead /** only exposed for internal use; subject to change without notice */ @@ -50,45 +50,15 @@ public class MixedTensor implements Tensor { } } - /** The cells in the tensor */ - private final List<DenseSubspace> denseSubspaces; - /** only exposed for internal use; subject to change without notice */ - public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } + public List<DenseSubspace> getInternalDenseSubspaces() { return index.denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { + private MixedTensor(TensorType type, Index index) { this.type = type; - this.denseSubspaceSize = index.denseSubspaceSize(); - this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.denseSubspaceSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); - } - long count = 0; - for (var block : this.denseSubspaces) { - if (index.sparseMap.get(block.sparseAddress) != count) { - throw new IllegalStateException("map vs list mismatch: block #" - + count - + " address maps to #" - + index.sparseMap.get(block.sparseAddress)); - } - if (block.cells.length != denseSubspaceSize) { - throw new IllegalStateException("dense subspace size mismatch, expected " - + denseSubspaceSize - + " cells, but got: " - + block.cells.length); - } - ++count; - } - if (count != index.sparseMap.size()) { - throw new IllegalStateException("mismatch: list size is " - + count - + " but map size is " - + index.sparseMap.size()); - } } /** Returns the tensor type */ @@ -97,32 +67,34 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return denseSubspaces.size() * denseSubspaceSize; } + public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { + var block = index.blockOf(address); + int denseOffset = index.denseOffsetOf(address); + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } + return block.cells[denseOffset]; + } + + @Override + public Double getAsDouble(TensorAddress address) { + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - if (denseOffset < 0 || denseOffset >= block.cells.length) { - return 0.0; + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { + return null; } return block.cells[denseOffset]; } @Override public boolean has(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return false; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - return (denseOffset >= 0 && denseOffset < block.cells.length); + return (block != null && denseOffset >= 0 && denseOffset < block.cells.length); } /** @@ -135,21 +107,30 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); - DenseSubspace currBlock = null; - int currOffset = denseSubspaceSize; + + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); + final int[] labels = new int[index.indexedDimensions.size()]; + DenseSubspace currentBlock = null; + int currOffset = index.denseSubspaceSize; + int prevOffset = -1; + @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } + @Override public Cell next() { - if (currOffset == denseSubspaceSize) { - currBlock = blockIterator.next(); + if (currOffset == index.denseSubspaceSize) { + currentBlock = blockIterator.next(); currOffset = 0; } - TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, currOffset); - double value = currBlock.cells[currOffset++]; + if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1 + index.denseOffsetToAddress(currOffset, labels); + } + TensorAddress fullAddr = currentBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels); + prevOffset = currOffset; + double value = currentBlock.cells[currOffset++]; return new Cell(fullAddr, value); } }; @@ -162,20 +143,23 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); - double[] currBlock = null; - int currOffset = denseSubspaceSize; + + final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); + double[] currentBlock = null; + int currOffset = index.denseSubspaceSize; + @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } + @Override public Double next() { - if (currOffset == denseSubspaceSize) { - currBlock = blockIterator.next().cells; + if (currOffset == index.denseSubspaceSize) { + currentBlock = blockIterator.next().cells; currOffset = 0; } - return currBlock[currOffset++]; + return currentBlock[currOffset++]; } }; } @@ -197,24 +181,22 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, denseSubspaces, index); + return new MixedTensor(other, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseSubspace> list = new ArrayList<>(); - for (var block : denseSubspaces) { + for (var block : index.denseSubspaces) { if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part - indexBuilder.addBlock(block.sparseAddress, list.size()); - list.add(block); + indexBuilder.addBlock(block); } } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } @Override - public int hashCode() { return Objects.hash(type, denseSubspaces); } + public int hashCode() { return Objects.hash(type, index.denseSubspaces); } @Override public String toString() { @@ -249,13 +231,14 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return denseSubspaceSize; + return index.denseSubspaceSize; } /** * Base class for building mixed tensors. */ public abstract static class Builder implements Tensor.Builder { + static final int INITIAL_HASH_CAPACITY = 1000; final TensorType type; @@ -265,10 +248,11 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. */ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + if (type.hasIndexedUnboundDimensions()) { + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } else { - return new BoundBuilder(type); + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -306,13 +290,14 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Map<TensorAddress, double[]> denseSubspaceMap; private final Index.Builder indexBuilder; private final Index index; private final TensorType denseSubtype; - private BoundBuilder(TensorType type) { + private BoundBuilder(TensorType type, int expectedSize) { super(type); + denseSubspaceMap = new LinkedHashMap<>(expectedSize, 0.5f); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); denseSubtype = new TensorType(type.valueType(), @@ -324,10 +309,7 @@ public class MixedTensor implements Tensor { } private double[] denseSubspace(TensorAddress sparseAddress) { - if (!denseSubspaceMap.containsKey(sparseAddress)) { - denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); - } - return denseSubspaceMap.get(sparseAddress); + return denseSubspaceMap.computeIfAbsent(sparseAddress, (key) -> new double[(int)denseSubspaceSize()]); } public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { @@ -343,7 +325,7 @@ public class MixedTensor implements Tensor { @Override public Tensor.Builder cell(TensorAddress address, double value) { - TensorAddress sparsePart = index.sparsePartialAddress(address); + TensorAddress sparsePart = address.mappedPartialAddress(index.sparseType, index.type.dimensions()); int denseOffset = index.denseOffsetOf(address); double[] denseSubspace = denseSubspace(sparsePart); denseSubspace[denseOffset] = value; @@ -362,19 +344,20 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseSubspace> list = new ArrayList<>(); - for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { + //TODO This can be solved more efficiently with a single map. + Set<Map.Entry<TensorAddress, double[]>> entrySet = denseSubspaceMap.entrySet(); + for (Map.Entry<TensorAddress, double[]> entry : entrySet) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); var block = new DenseSubspace(sparsePart, denseSubspace); - indexBuilder.addBlock(sparsePart, list.size()); - list.add(block); + indexBuilder.addBlock(block); } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } public static BoundBuilder of(TensorType type) { - return new BoundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new BoundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -391,9 +374,9 @@ public class MixedTensor implements Tensor { private final Map<TensorAddress, Double> cells; private final long[] dimensionBounds; - private UnboundBuilder(TensorType type) { + private UnboundBuilder(TensorType type, int expectedSize) { super(type); - cells = new HashMap<>(); + cells = new LinkedHashMap<>(expectedSize, 0.5f); dimensionBounds = new long[type.dimensions().size()]; } @@ -412,7 +395,7 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { TensorType boundType = createBoundType(); - BoundBuilder builder = new BoundBuilder(boundType); + BoundBuilder builder = new BoundBuilder(boundType, cells.size()); for (Map.Entry<TensorAddress, Double> cell : cells.entrySet()) { builder.cell(cell.getKey(), cell.getValue()); } @@ -443,7 +426,8 @@ public class MixedTensor implements Tensor { } public static UnboundBuilder of(TensorType type) { - return new UnboundBuilder(type); + //TODO Wire in expected map size to avoid expensive resize + return new UnboundBuilder(type, INITIAL_HASH_CAPACITY); } } @@ -460,8 +444,10 @@ public class MixedTensor implements Tensor { private final TensorType denseType; private final List<TensorType.Dimension> mappedDimensions; private final List<TensorType.Dimension> indexedDimensions; + private final int[] indexedDimensionsSize; private ImmutableMap<TensorAddress, Integer> sparseMap; + private List<DenseSubspace> denseSubspaces; private final int denseSubspaceSize; static private int computeDSS(List<TensorType.Dimension> dimensions) { @@ -477,17 +463,31 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).toList(); this.indexedDimensions = type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList(); + this.indexedDimensionsSize = new int[indexedDimensions.size()]; + for (int i = 0; i < indexedDimensions.size(); i++) { + long dimensionSize = indexedDimensions.get(i).size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")); + indexedDimensionsSize[i] = (int)dimensionSize; + } + this.sparseType = createPartialType(type.valueType(), mappedDimensions); this.denseType = createPartialType(type.valueType(), indexedDimensions); this.denseSubspaceSize = computeDSS(this.indexedDimensions); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); + } } - int blockIndexOf(TensorAddress address) { - TensorAddress sparsePart = sparsePartialAddress(address); - return sparseMap.getOrDefault(sparsePart, -1); + private DenseSubspace blockOf(TensorAddress address) { + TensorAddress sparsePart = address.mappedPartialAddress(sparseType, type.dimensions()); + Integer blockNum = sparseMap.get(sparsePart); + if (blockNum == null || blockNum >= denseSubspaces.size()) { + return null; + } + return denseSubspaces.get(blockNum); } - int denseOffsetOf(TensorAddress address) { + private int denseOffsetOf(TensorAddress address) { long innerSize = 1; long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { @@ -506,54 +506,19 @@ public class MixedTensor implements Tensor { return denseSubspaceSize; } - private TensorAddress sparsePartialAddress(TensorAddress address) { - if (type.dimensions().size() != address.size()) - throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address); - TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); - for (int i = 0; i < type.dimensions().size(); ++i) { - TensorType.Dimension dimension = type.dimensions().get(i); - if ( ! dimension.isIndexed()) - builder.add(dimension.name(), address.label(i)); - } - return builder.build(); - } - - private TensorAddress denseOffsetToAddress(long denseOffset) { + private void denseOffsetToAddress(long denseOffset, int [] labels) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } long restSize = denseOffset; long innerSize = denseSubspaceSize; - long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { - TensorType.Dimension dimension = indexedDimensions.get(i); - long dimensionSize = dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); - - innerSize /= dimensionSize; - labels[i] = restSize / innerSize; + innerSize /= indexedDimensionsSize[i]; + labels[i] = (int) (restSize / innerSize); restSize %= innerSize; } - return TensorAddress.of(labels); - } - - TensorAddress fullAddressOf(TensorAddress sparsePart, long denseOffset) { - TensorAddress densePart = denseOffsetToAddress(denseOffset); - String[] labels = new String[type.dimensions().size()]; - int mappedIndex = 0; - int indexedIndex = 0; - for (TensorType.Dimension d : type.dimensions()) { - if (d.isIndexed()) { - labels[mappedIndex + indexedIndex] = densePart.label(indexedIndex); - indexedIndex++; - } else { - labels[mappedIndex + indexedIndex] = sparsePart.label(mappedIndex); - mappedIndex++; - } - } - return TensorAddress.of(labels); } @Override @@ -563,7 +528,7 @@ public class MixedTensor implements Tensor { private String contentToString(MixedTensor tensor, long maxCells) { if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller"); - if (mappedDimensions.size() == 0) { + if (mappedDimensions.isEmpty()) { StringBuilder b = new StringBuilder(); int cellsWritten = denseSubspaceToString(tensor, 0, maxCells, b); if (cellsWritten == maxCells && cellsWritten < tensor.size()) @@ -605,8 +570,7 @@ public class MixedTensor implements Tensor { b.append(", "); // start brackets - for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) - b.append("["); + b.append("[".repeat(Math.max(0, indexes.nextDimensionsAtStart()))); // value switch (type.valueType()) { @@ -619,32 +583,38 @@ public class MixedTensor implements Tensor { } // end bracket - for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) - b.append("]"); + b.append("]".repeat(Math.max(0, indexes.nextDimensionsAtEnd()))); } return index; } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; + return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } - static class Builder { + private static class Builder { private final Index index; - private final ImmutableMap.Builder<TensorAddress, Integer> builder; + private final ImmutableMap.Builder<TensorAddress, Integer> builder = new ImmutableMap.Builder<>(); + private final ImmutableList.Builder<DenseSubspace> listBuilder = new ImmutableList.Builder<>(); + private int count = 0; Builder(TensorType type) { index = new Index(type); - builder = new ImmutableMap.Builder<>(); } - void addBlock(TensorAddress address, int sz) { - builder.put(address, sz); + void addBlock(DenseSubspace block) { + if (block.cells.length != index.denseSubspaceSize) { + throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize + + " cells, but got: " + block.cells.length); + } + builder.put(block.sparseAddress, count++); + listBuilder.add(block); } Index build() { index.sparseMap = builder.build(); + index.denseSubspaces = listBuilder.build(); return index; } @@ -654,27 +624,16 @@ public class MixedTensor implements Tensor { } } - private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { - - private final TensorType type; - private final double[] values; - - public DenseSubspaceBuilder(TensorType type, double[] values) { - this.type = type; - this.values = values; - } - - @Override - public TensorType type() { return type; } + private record DenseSubspaceBuilder(TensorType type, double[] values) implements IndexedTensor.DirectIndexBuilder { @Override public void cellByDirectIndex(long index, double value) { - values[(int)index] = value; + values[(int) index] = value; } @Override public void cellByDirectIndex(long index, float value) { - values[(int)index] = value; + values[(int) index] = value; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index f1b3245ec80..8852bcd1ff3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,16 +1,16 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import java.util.Arrays; +import com.yahoo.tensor.impl.Label; /** - * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors + * An address to a subset of a tensors' cells, specifying a label for some, but not necessarily all, of the tensors * dimensions. * * @author bratseth */ // Implementation notes: -// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. +// - These are created in inner (though not innermost) loops, so they are implemented with minimal allocation. // We also avoid non-essential error checking. // - We can add support for string labels later without breaking the API public class PartialAddress { @@ -18,7 +18,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final Object[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -35,15 +35,15 @@ public class PartialAddress { public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return asLong(labels[i]); - return -1; + return labels[i]; + return Tensor.invalidIndex; } /** Returns the label of this dimension, or null if no label is specified for it */ public String label(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i].toString(); + return Label.fromNumber(labels[i]); return null; } @@ -55,7 +55,7 @@ public class PartialAddress { public String label(int i) { if (i >= size()) throw new IllegalArgumentException("No label at position " + i + " in " + this); - return labels[i].toString(); + return Label.fromNumber(labels[i]); } public int size() { return dimensionNames.length; } @@ -65,40 +65,14 @@ public class PartialAddress { public TensorAddress asAddress(TensorType type) { if (type.rank() != size()) throw new IllegalArgumentException(type + " has a different rank than " + this); - if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { - long[] numericLabels = new long[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - long label = numericLabel(type.dimensions().get(i).name()); - if (label < 0) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - numericLabels[i] = label; - } - return TensorAddress.of(numericLabels); - } - else { - String[] stringLabels = new String[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - String label = label(type.dimensions().get(i).name()); - if (label == null) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - stringLabels[i] = label; - } - return TensorAddress.of(stringLabels); - } - } - - private long asLong(Object label) { - if (label instanceof Long) { - return (Long) label; - } - else { - try { - return Long.parseLong(label.toString()); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Label '" + label + "' is not numeric"); - } + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label == Tensor.invalidIndex) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; } + return TensorAddress.of(numericLabels); } @Override @@ -114,12 +88,12 @@ public class PartialAddress { public static class Builder { private String[] dimensionNames; - private Object[] labels; + private long[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new Object[size]; + labels = new long[size]; } public Builder add(String dimensionName, long label) { @@ -131,7 +105,7 @@ public class PartialAddress { public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; - labels[index] = label; + labels[index] = Label.toNumber(label); index++; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8a4179cdc1a..ac9dc4e4eca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -20,6 +20,7 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; import com.yahoo.tensor.functions.Expand; +import com.yahoo.tensor.impl.Label; import java.util.ArrayList; import java.util.Arrays; @@ -39,7 +40,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; * A multidimensional array which can be used in computations. * <p> * A tensor consists of a set of <i>dimension</i> names and a set of <i>cells</i> containing scalar <i>values</i>. - * Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines + * Each cell is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines * the location of that cell. Both dimensions and labels are string on the form of an identifier or integer. * <p> * The size of the set of dimensions of a tensor is called its <i>rank</i>. @@ -55,6 +56,9 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; */ public interface Tensor { + /** The constant signaling a nonexisting value in operations addressing tensor values by index. */ + int invalidIndex = -1; + // ----------------- Accessors TensorType type(); @@ -63,11 +67,25 @@ public interface Tensor { default boolean isEmpty() { return size() == 0; } /** - * Returns the number of cells in this. - * TODO Figure how to best return an int instead of a long - * An int is large enough, and java is far better at int base loops than long - **/ - long size(); + * Returns the number of cells in this, allowing for very large tensors. + * Prefer sizeAsInt in implementations that cannot handle sizes outside the int range. + */ + default long size() { + return sizeAsInt(); + } + + /** + * Returns the size of this as an int or throws an exception if it is too large to fit in an int. + * Prefer this over size() with implementations that only handle sizes in the int range. + * + * @throws IndexOutOfBoundsException if the size is too large to fit in an int + */ + default int sizeAsInt() { + long size = size(); + if (size > Integer.MAX_VALUE) + throw new IndexOutOfBoundsException("size = " + size + ", which is too large to fit in an int"); + return (int) size; + } /** Returns the value of a cell, or 0.0 if this cell does not exist */ double get(TensorAddress address); @@ -75,6 +93,9 @@ public interface Tensor { /** Returns true if this cell exists */ boolean has(TensorAddress address); + /** Returns the value at this address, or null of it does not exist. */ + Double getAsDouble(TensorAddress address); + /** * Returns the cell of this in some undefined order. * A cell instances is only valid until next() is called. @@ -97,7 +118,7 @@ public interface Tensor { * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -113,7 +134,7 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been * modified according to the given operation and cells in the given map. - * Cells in the map outside of existing cells are thus ignored. + * Cells in the map outside existing cells are thus ignored. * * @param op the modifying function * @param cells the cells to modify @@ -132,9 +153,9 @@ public interface Tensor { /** * Returns a new tensor where existing cells in this tensor have been - * removed according to the given set of addresses. Only valid for sparse + * removed according to the given set of addresses. Only valid for mapped * or mixed tensors. For mixed tensors, addresses are assumed to only - * contain the sparse dimensions, as the entire dense subspace is removed. + * contain the mapped dimensions, as the entire indexed subspace is removed. * * @param addresses list of addresses to remove * @return a new tensor where cells have been removed @@ -484,11 +505,10 @@ public interface Tensor { public TensorAddress getKey() { return address; } /** - * Returns the direct index which can be used to locate this cell, or -1 if not available. - * This is for optimizations mapping between tensors where this is possible without creating a - * TensorAddress. + * Returns the direct index which can be used to locate this cell, or Tensor.invalidIndex if not available. + * This is for optimizations mapping between tensors where this is possible without creating a TensorAddress. */ - long getDirectIndex() { return -1; } + long getDirectIndex() { return invalidIndex; } /** Returns the value as a double */ @Override @@ -537,8 +557,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -549,8 +569,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -608,7 +628,7 @@ public interface Tensor { public TensorType type() { return tensorBuilder.type(); } public CellBuilder label(String dimension, long label) { - return label(dimension, String.valueOf(label)); + return label(dimension, Label.fromNumber(label)); } public Builder value(double cellValue) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index a1cb278c75a..4fa759668b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,10 +1,13 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.Label; +import com.yahoo.tensor.impl.TensorAddressAny; + import java.util.Arrays; +import java.util.List; import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -14,18 +17,20 @@ import java.util.stream.Collectors; */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - public static TensorAddress of(String[] labels) { - return new StringTensorAddress(labels); + return TensorAddressAny.of(labels); + } + + public static TensorAddress ofLabels(String... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress ofLabels(String ... labels) { - return new StringTensorAddress(labels); + public static TensorAddress of(long... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress of(long ... labels) { - return new NumericTensorAddress(labels); + public static TensorAddress of(int... labels) { + return TensorAddressAny.of(labels); } /** Returns the number of labels in this */ @@ -61,27 +66,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { } @Override - public int hashCode() { - int result = 1; - for (int i = 0; i < size(); i++) { - if (label(i) != null) - result = 31 * result + label(i).hashCode(); + public String toString() { + StringBuilder sb = new StringBuilder("cell address ("); + int size = size(); + if (size > 0) { + sb.append(label(0)); + for (int i = 1; i < size; i++) { + sb.append(',').append(label(i)); + } } - return result; - } - @Override - public boolean equals(Object o) { - if (o == this) return true; - if ( ! (o instanceof TensorAddress other)) return false; - if (other.size() != this.size()) return false; - for (int i = 0; i < this.size(); i++) - if ( ! Objects.equals(this.label(i), other.label(i))) - return false; - return true; + return sb.append(')').toString(); } - /** Returns this as a string on the appropriate form given the type */ + /** + * Returns this as a string on the appropriate form given the type + */ public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { @@ -94,106 +94,78 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - /** Returns a label as a string with appropriate quoting/escaping when necessary */ + /** + * Returns a label as a string with appropriate quoting/escaping when necessary + */ public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; return "'" + label + "'"; } - private static String[] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); + /** Returns an address with only some of the dimension. Ordering will also be according to indexMap */ + public TensorAddress partialCopy(int[] indexMap) { + int[] labels = new int[indexMap.length]; + for (int i = 0; i < labels.length; ++i) { + labels[i] = (int)numericLabel(indexMap[i]); } - return asStrings; + return TensorAddressAny.ofUnsafe(labels); } - private static String asString(long index) { - return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } - - private static final class StringTensorAddress extends TensorAddress { - - private final String[] labels; - - private StringTensorAddress(String ... labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return labels[i]; } - - @Override - public long numericLabel(int i) { - try { - return Long.parseLong(labels[i]); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); + /** Creates a complete address by taking the mapped dimmensions from this and the indexed from the indexedPart */ + public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int[] densePart) { + int[] labels = new int[dimensions.size()]; + int mappedIndex = 0; + int indexedIndex = 0; + for (int i = 0; i < labels.length; i++) { + TensorType.Dimension d = dimensions.get(i); + if (d.isIndexed()) { + labels[i] = densePart[indexedIndex]; + indexedIndex++; + } else { + labels[i] = (int)numericLabel(mappedIndex); + mappedIndex++; } } - - @Override - public TensorAddress withLabel(int index, long label) { - String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = TensorAddress.asString(label); - return new StringTensorAddress(labels); - } - - - @Override - public String toString() { - return "cell address (" + String.join(",", labels) + ")"; - } - + return TensorAddressAny.ofUnsafe(labels); } - private static final class NumericTensorAddress extends TensorAddress { - - private final long[] labels; - - private NumericTensorAddress(long[] labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return TensorAddress.asString(labels[i]); } - - @Override - public long numericLabel(int i) { return labels[i]; } - - @Override - public TensorAddress withLabel(int index, long label) { - long[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = label; - return new NumericTensorAddress(labels); - } - - @Override - public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")"; + /** + * Returns an address containing the mapped dimensions of this. + * + * @param mappedType the type of the mapped subset of the type this is an address of; + * which is also the type of the returned address + * @param dimensions all the dimensions of the type this is an address of + */ + public TensorAddress mappedPartialAddress(TensorType mappedType, List<TensorType.Dimension> dimensions) { + if (dimensions.size() != size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this); + TensorAddress.Builder builder = new TensorAddress.Builder(mappedType); + for (int i = 0; i < dimensions.size(); ++i) { + TensorType.Dimension dimension = dimensions.get(i); + if ( ! dimension.isIndexed()) + builder.add(dimension.name(), (int)numericLabel(i)); } - + return builder.build(); } /** Builder of a tensor address */ public static class Builder { final TensorType type; - final String[] labels; + final int[] labels; + + private static int[] createEmptyLabels(int size) { + int[] labels = new int[size]; + Arrays.fill(labels, Tensor.invalidIndex); + return labels; + } public Builder(TensorType type) { - this(type, new String[type.dimensions().size()]); + this(type, createEmptyLabels(type.dimensions().size())); } - private Builder(TensorType type, String[] labels) { + private Builder(TensorType type, int[] labels) { this.type = type; this.labels = labels; } @@ -207,7 +179,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { var mappedSubtype = type.mappedSubtype(); if (mappedSubtype.rank() != 1) throw new IllegalArgumentException("Cannot add a label without explicit dimension to a tensor of type " + - type + ": Must have exactly one sparse dimension"); + type + ": Must have exactly one mapped dimension"); add(mappedSubtype.dimensions().get(0).name(), label); return this; } @@ -220,10 +192,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public Builder add(String dimension, String label) { Objects.requireNonNull(dimension, "dimension cannot be null"); Objects.requireNonNull(label, "label cannot be null"); - Optional<Integer> labelIndex = type.indexOfDimension(dimension); - if ( labelIndex.isEmpty()) + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) + throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); + labels[labelIndex] = Label.toNumber(label); + return this; + } + + public Builder add(String dimension, long label) { + return add(dimension, Convert.safe2Int(label)); + } + public Builder add(String dimension, int label) { + Objects.requireNonNull(dimension, "dimension cannot be null"); + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); - labels[labelIndex.get()] = label; + labels[labelIndex] = label; return this; } @@ -237,14 +221,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { void validate() { for (int i = 0; i < labels.length; i++) - if (labels[i] == null) + if (labels[i] == Tensor.invalidIndex) throw new IllegalArgumentException("Missing a label for dimension '" + type.dimensions().get(i).name() + "' for " + type); } public TensorAddress build() { validate(); - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } } @@ -256,7 +240,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { super(type); } - private PartialBuilder(TensorType type, String[] labels) { + private PartialBuilder(TensorType type, int[] labels) { super(type, labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..6b81d023a9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List<Dimension> dimensions; + private final Set<String> dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection<Dimension> dimensions) { @@ -103,12 +108,25 @@ public class TensorType { List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ public class TensorType { } } + public boolean hasIndexedDimensions() { return indexedSubtype != empty; } + public boolean hasMappedDimensions() { return mappedSubtype != empty; } + public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... types) { List<Value> valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ public class TensorType { /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ @@ -176,6 +199,13 @@ public class TensorType { return Optional.of(i); return Optional.empty(); } + /** Returns the 0-base index of this dimension, or empty if it is not present */ + public int indexOfDimensionAsInt(String dimension) { + for (int i = 0; i < dimensions.size(); i++) + if (dimensions.get(i).name().equals(dimension)) + return i; + return Tensor.invalidIndex; + } /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ public Optional<Long> sizeOfDimension(String dimension) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 0e4fab95c87..9125b35ea5d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.Arrays; import java.util.HashMap; @@ -133,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return tensor; } else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + if (tensor.type().hasMappedDimensions()) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) @@ -172,7 +173,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType concatType, long concatOffset, String concatDimension) { long[] combinedLabels = new long[concatType.dimensions().size()]; - Arrays.fill(combinedLabels, -1); + Arrays.fill(combinedLabels, Tensor.invalidIndex); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here @@ -191,7 +192,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.invalidIndex); return toIndexes; } @@ -208,7 +209,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != from.numericLabel(i)) return false; to[toIndex] = from.numericLabel(i); } } @@ -354,21 +355,21 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { - String[] labels = new String[plan.resultType.rank()]; + int[] labels = new int[plan.resultType.rank()]; int out = 0; int m = 0; int a = 0; int b = 0; for (var how : plan.combineHow) { switch (how) { - case left -> labels[out++] = leftOnly.label(a++); - case right -> labels[out++] = rightOnly.label(b++); - case both -> labels[out++] = match.label(m++); - case concat -> labels[out++] = String.valueOf(concatDimIdx); + case left -> labels[out++] = (int) leftOnly.numericLabel(a++); + case right -> labels[out++] = (int) rightOnly.numericLabel(b++); + case both -> labels[out++] = (int) match.numericLabel(m++); + case concat -> labels[out++] = concatDimIdx; default -> throw new IllegalArgumentException("cannot handle: " + how); } } - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } Tensor merge(CellVectorMapMap a, CellVectorMapMap b) { @@ -398,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET CellVectorMapMap decompose(Tensor input, SplitHow how) { var iter = input.cellIterator(); - String[] commonLabels = new String[(int)how.numCommon()]; - String[] separateLabels = new String[(int)how.numSeparate()]; + int[] commonLabels = new int[(int)how.numCommon()]; + int[] separateLabels = new int[(int)how.numSeparate()]; CellVectorMapMap result = new CellVectorMapMap(); while (iter.hasNext()) { var cell = iter.next(); @@ -409,14 +410,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case common -> commonLabels[commonIdx++] = addr.label(i); - case separate -> separateLabels[separateIdx++] = addr.label(i); + case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i); + case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i); case concat -> ccDimIndex = addr.numericLabel(i); default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i)); } } - TensorAddress commonAddr = TensorAddress.of(commonLabels); - TensorAddress separateAddr = TensorAddress.of(separateLabels); + TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels); + TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels); result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); } return result; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 3b6e03186a3..b595b1a40cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 0) + if (!arguments.isEmpty()) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells.values()) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } @@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); - if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + if ( ! type.hasOnlyIndexedBoundDimensions()) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + "only indexed, bound dimensions, but this has " + type); this.cells = List.copyOf(cells); @@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 4c92e1e57a2..fb345264f56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -12,8 +12,11 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -113,7 +116,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) { - long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + int joinedRank = (int)Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); @@ -128,8 +131,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (b.has(key)) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + Double bVal = b.getAsDouble(key); + if (bVal != null) { + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } return builder.build(); @@ -144,7 +148,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) { - if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes + if (subspace.isEmpty() || superspace.isEmpty()) // special case empty here to avoid doing it when finding sizes return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build(); DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace); @@ -169,7 +173,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder, DoubleBinaryOperator combinator) { - long joinedLength = Math.min(subspaceSize, superspaceSize); + int joinedLength = (int)Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -204,12 +208,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); - TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); - if (subspace.has(subaddress)) { - double subspaceValue = subspace.get(subaddress); + TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes); + Double subspaceValue = subspace.getAsDouble(subaddress); + if (subspaceValue != null) { builder.cell(supercell.getKey(), - reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) - : combinator.applyAsDouble(subspaceValue, supercell.getValue())); + reversedArgumentOrder + ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } } return builder.build(); @@ -223,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return subspaceIndexes; } - private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { - String[] subspaceLabels = new String[subspaceIndexes.length]; - for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); - return TensorAddress.of(subspaceLabels); - } - /** Slow join which works for any two tensors */ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) @@ -250,8 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { - Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); - Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames())); + int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection + Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames())); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); @@ -262,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + PartialAddress matchingBCells = sharedDimensionSize > 0 + ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize) + : empty; // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -274,11 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } - private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { - PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); - for (int i = 0; i < addressType.dimensions().size(); i++) - if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); + private static final PartialAddress empty = new PartialAddress.Builder(0).build(); + private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, + Set<String> retainDimensions, int sharedDimensionSize) { + PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); + for (int i = 0; i < addressType.dimensions().size(); i++) { + String dimension = addressType.dimensions().get(i).name(); + if (retainDimensions.contains(dimension)) + builder.add(dimension, address.numericLabel(i)); + } return builder.build(); } @@ -330,19 +335,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP int[] bIndexesInJoined = mapIndexes(b.type(), joinedType); // Iterate once through the smaller tensor and construct a hash map for common dimensions - Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(); + Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); - aCellsByCommonAddress.putIfAbsent(partialCommonAddress, new ArrayList<>()); - aCellsByCommonAddress.get(partialCommonAddress).add(aCell); + TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon); + aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } // Iterate once through the larger tensor and use the hash map to find joinable cells Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell bCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon); + TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) { TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType); @@ -358,7 +362,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } /** - * Returns the an array having one entry in order for each dimension of fromType + * Returns an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. * That is, if the returned array contains n at index i then * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) @@ -367,17 +371,18 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP static int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimensionAsInt(fromType.dimensions().get(i).name()); return toIndexes; } private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { - String[] joinedLabels = new String[joinedType.dimensions().size()]; + int[] joinedLabels = new int[joinedType.dimensions().size()]; + Arrays.fill(joinedLabels, Tensor.invalidIndex); mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; - return TensorAddress.of(joinedLabels); + return TensorAddressAny.ofUnsafe(joinedLabels); } /** @@ -386,11 +391,13 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { - for (int i = 0; i < from.size(); i++) { + private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) { + for (int i = 0, size = from.size(); i < size; i++) { int toIndex = indexMap[i]; - if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false; - to[toIndex] = from.label(i); + int label = Convert.safe2Int(from.numericLabel(i)); + if (to[toIndex] != Tensor.invalidIndex && to[toIndex] != label) + return false; + to[toIndex] = label; } return true; } @@ -412,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return typeBuilder.build(); } - private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { - TensorAddress address = cell.getKey(); - String[] labels = new String[indexMap.length]; - for (int i = 0; i < labels.length; ++i) { - labels[i] = address.label(indexMap[i]); - } - return TensorAddress.of(labels); - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java index c87ef42976d..aa9602339e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java @@ -98,9 +98,9 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction for (int i = 0; i < inputType.dimensions().size(); i++) { var dim = inputType.dimensions().get(i); if (dim.isMapped()) { - mapAddrBuilder.add(dim.name(), fullAddr.label(i)); + mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } else { - idxAddrBuilder.add(dim.name(), fullAddr.label(i)); + idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } } var mapAddr = mapAddrBuilder.build(); @@ -123,11 +123,11 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction var addrBuilder = new TensorAddress.Builder(outputType); for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) { var dim = inputTypeMapped.dimensions().get(i); - addrBuilder.add(dim.name(), mappedAddr.label(i)); + addrBuilder.add(dim.name(), mappedAddr.numericLabel(i)); } for (int i = 0; i < denseOutputDims.size(); i++) { var dim = denseOutputDims.get(i); - addrBuilder.add(dim.name(), denseAddr.label(i)); + addrBuilder.add(dim.name(), denseAddr.numericLabel(i)); } builder.cell(addrBuilder.build(), cell.getValue()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index 59394785382..ddad91dc060 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -121,10 +121,11 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> aCell = i.next(); var key = aCell.getKey(); - if (! b.has(key)) { + Double bVal = b.getAsDouble(key); + if (bVal == null) { builder.cell(key, aCell.getValue()); } else if (combinator != null) { - builder.cell(key, combinator.applyAsDouble(aCell.getValue(), b.get(key))); + builder.cell(key, combinator.applyAsDouble(aCell.getValue(), bVal)); } } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 8cf88610599..947fd6e0012 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -1,6 +1,8 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.DimensionSizes; +import com.yahoo.tensor.DirectIndexedAddress; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -9,16 +11,15 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.Set; /** * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions @@ -112,32 +113,84 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) { - if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) + if (!dimensions.isEmpty() && !argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + - dimensions + ": Not all those dimensions are present in this tensor"); + dimensions + ": Not all those dimensions are present in this tensor"); // Special case: Reduce all - if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) + if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size()) { if (argument.isEmpty()) return Tensor.from(0.0); else if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor) - return reduceIndexedVector((IndexedTensor)argument, aggregator); + return reduceIndexedVector((IndexedTensor) argument, aggregator); else return reduceAllGeneral(argument, aggregator); + } TensorType reducedType = outputType(argument.type(), dimensions); + int[] indexesToReduce = createIndexesToReduce(argument.type(), dimensions); + int[] indexesToKeep = createIndexesToKeep(argument.type(), indexesToReduce); + if (argument instanceof IndexedTensor indexedTensor && reducedType.hasOnlyIndexedBoundDimensions()) { + return reduceIndexedTensor(indexedTensor, reducedType, indexesToKeep, indexesToReduce, aggregator); + } else { + return reduceGeneral(argument, reducedType, indexesToKeep, aggregator); + } + } + + private static void reduce(IndexedTensor argument, ValueAggregator aggregator, DirectIndexedAddress address, int[] reduce, int reduceIndex) { + int currentIndex = reduce[reduceIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + if (reduceIndex + 1 < reduce.length) { + int nextDimension = reduceIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + reduce(argument, aggregator, address, reduce, nextDimension); + } + } else { + address.setIndex(currentIndex, 0); + long increment = address.getStride(currentIndex); + long directIndex = address.getDirectIndex(); + for (int i = 0; i < dimSize; i++) { + aggregator.aggregate(argument.get(directIndex + i * increment)); + } + } + } + + private static void reduce(IndexedTensor.Builder builder, DirectIndexedAddress destAddress, IndexedTensor argument, Aggregator aggregator, DirectIndexedAddress address, int[] toKeep, int keepIndex, int[] toReduce) { + if (keepIndex < toKeep.length) { + int currentIndex = toKeep[keepIndex]; + int dimSize = Convert.safe2Int(argument.dimensionSizes().size(currentIndex)); + + int nextKeep = keepIndex + 1; + for (int i = 0; i < dimSize; i++) { + address.setIndex(currentIndex, i); + destAddress.setIndex(keepIndex, i); + reduce(builder, destAddress, argument, aggregator, address, toKeep, nextKeep, toReduce); + } + } else { + ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); + reduce(argument, valueAggregator, address, toReduce, 0); + builder.cell(valueAggregator.aggregatedValue(), destAddress.getIndexes()); + } + + } - // Reduce cells - int[] indexesToKeep = createIndexesToKeep(argument.type(), dimensions); + private static Tensor reduceIndexedTensor(IndexedTensor argument, TensorType reducedType, int[] indexesToKeep, int[] indexesToReduce, Aggregator aggregator) { + + var reducedBuilder = IndexedTensor.Builder.of(reducedType); + DirectIndexedAddress reducedAddress = DirectIndexedAddress.of(DimensionSizes.of(reducedType)); + reduce(reducedBuilder, reducedAddress, argument, aggregator, argument.directAddress(), indexesToKeep, 0, indexesToReduce); + return reducedBuilder.build(); + } + + private static Tensor reduceGeneral(Tensor argument, TensorType reducedType, int[] indexesToKeep, Aggregator aggregator) { // TODO cells.size() is most likely an overestimate, and might need a better heuristic // But the upside is larger than the downside. - Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>((int)argument.size()); + Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(argument.sizeAsInt()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress reducedAddress = reduceDimensions(indexesToKeep, cell.getKey()); - ValueAggregator aggr = aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator)); - if (aggr == null) - aggr = aggregatingCells.get(reducedAddress); + TensorAddress reducedAddress = cell.getKey().partialCopy(indexesToKeep); + ValueAggregator aggr = aggregatingCells.computeIfAbsent(reducedAddress, (key) ->ValueAggregator.ofType(aggregator)); aggr.aggregate(cell.getValue()); } Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType); @@ -146,39 +199,43 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return reducedBuilder.build(); } - private static int[] createIndexesToKeep(TensorType argumentType, List<String> dimensions) { - Set<Integer> indexesToRemove = new HashSet<>(dimensions.size()*2); - for (String dimensionToRemove : dimensions) - indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get()); - int[] indexesToKeep = new int[argumentType.rank() - indexesToRemove.size()]; + + private static int[] createIndexesToReduce(TensorType tensorType, List<String> dimensions) { + int[] indexesToReduce = new int[dimensions.size()]; + for (int i = 0; i < dimensions.size(); i++) { + indexesToReduce[i] = tensorType.indexOfDimension(dimensions.get(i)).get(); + } + return indexesToReduce; + } + private static int[] createIndexesToKeep(TensorType argumentType, int[] indexesToReduce) { + int[] indexesToKeep = new int[argumentType.rank() - indexesToReduce.length]; int toKeepIndex = 0; for (int i = 0; i < argumentType.rank(); i++) { - if ( ! indexesToRemove.contains(i)) + if ( ! contains(indexesToReduce, i)) indexesToKeep[toKeepIndex++] = i; } return indexesToKeep; } - - private static TensorAddress reduceDimensions(int[] indexesToKeep, TensorAddress address) { - String[] reducedLabels = new String[indexesToKeep.length]; - int reducedLabelIndex = 0; - for (int toKeep : indexesToKeep) - reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return TensorAddress.of(reducedLabels); + private static boolean contains(int[] list, int key) { + for (int candidate : list) { + if (candidate == key) return true; + } + return false; } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.dimensionSizes().size(0); i++) + int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0)); + for (int i = 0; i < dimensionSize ; i++) valueAggregator.aggregate(argument.get(i)); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } static abstract class ValueAggregator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index aece782d296..2d5a0518747 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if ( ! (a instanceof IndexedTensor)) return false; - if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (a.type().hasOnlyIndexedBoundDimensions())) return false; if ( ! (b instanceof IndexedTensor)) return false; - if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (b.type().hasOnlyIndexedBoundDimensions())) return false; TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index a2a3874eced..05db61f5395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -35,7 +35,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); - if (fromDimensions.size() < 1) + if (fromDimensions.isEmpty()) throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension"); if (fromDimensions.size() != toDimensions.size()) throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " + @@ -89,7 +89,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET for (int i = 0; i < tensor.type().dimensions().size(); i++) { String dimensionName = tensor.type().dimensions().get(i).name(); String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName); - toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get(); + toIndexes[renamedType.indexOfDimension(newDimensionName).get()] = i; } // avoid building a new tensor if dimensions can simply be renamed @@ -100,7 +100,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET Tensor.Builder builder = Tensor.Builder.of(renamedType); for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> cell = i.next(); - TensorAddress renamedAddress = rename(cell.getKey(), toIndexes); + TensorAddress renamedAddress = cell.getKey().partialCopy(toIndexes); builder.cell(renamedAddress, cell.getValue()); } return builder.build(); @@ -118,13 +118,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return true; } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { - String[] reorderedLabels = new String[toIndexes.length]; - for (int i = 0; i < toIndexes.length; i++) - reorderedLabels[toIndexes[i]] = address.label(i); - return TensorAddress.of(reorderedLabels); - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 807f56b1a49..38ac42a5f1f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -131,7 +131,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (int i = 0; i < address.size(); i++) { String dimension = type.dimensions().get(i).name(); if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) - b.add(dimension, address.label(i)); + b.add(dimension, (int)address.numericLabel(i)); } return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java new file mode 100644 index 00000000000..e2cb64fdd1f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java @@ -0,0 +1,16 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +/** + * Utility to make common conversions safe + * + * @author baldersheim + */ +public class Convert { + public static int safe2Int(long value) { + if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException("value = " + value + ", which is too large to fit in an int"); + } + return (int) value; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java new file mode 100644 index 00000000000..7c1e8646245 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java @@ -0,0 +1,83 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.Tensor; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A label is a value of a mapped dimension of a tensor. + * This class provides a mapping of labels to numbers which allow for more efficient computation with + * mapped tensor dimensions. + * + * @author baldersheim + */ +public class Label { + + private static final String[] SMALL_INDEXES = createSmallIndexesAsStrings(1000); + + private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>(); + + // Index 0 is unused, that is a valid positive number + // 1(-1) is reserved for the Tensor.INVALID_INDEX + private static volatile String[] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"}; + private static int numUniqeStrings = 2; + + private static String[] createSmallIndexesAsStrings(int count) { + String[] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private static int addNewUniqueString(String s) { + synchronized (string2Enum) { + if (numUniqeStrings >= uniqueStrings.length) { + uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2); + } + uniqueStrings[numUniqeStrings] = s; + return -numUniqeStrings++; + } + } + + private static String asNumericString(long index) { + return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + + private static boolean validNumericIndex(String s) { + if (s.isEmpty() || ((s.length() > 1) && (s.charAt(0) == '0'))) return false; + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if ((c < '0') || (c > '9')) return false; + } + return true; + } + + public static int toNumber(String s) { + if (s == null) { return Tensor.invalidIndex; } + try { + if (validNumericIndex(s)) { + return Integer.parseInt(s); + } + } catch (NumberFormatException e) { + } + return string2Enum.computeIfAbsent(s, Label::addNewUniqueString); + } + + public static String fromNumber(int v) { + if (v >= 0) { + return asNumericString(v); + } else { + if (v == Tensor.invalidIndex) { return null; } + return uniqueStrings[-v]; + } + } + + public static String fromNumber(long v) { + return fromNumber(Convert.safe2Int(v)); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java new file mode 100644 index 00000000000..2e70811a67c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java @@ -0,0 +1,154 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import static com.yahoo.tensor.impl.Convert.safe2Int; +import static com.yahoo.tensor.impl.Label.toNumber; +import static com.yahoo.tensor.impl.Label.fromNumber; + +/** + * Parent of tensor address family centered around each dimension as int. + * A positive number represents a numeric index usable as a direect addressing. + * - 1 is representing an invalid/null address + * Other negative numbers are an enumeration maintained in {@link Label} + * + * @author baldersheim + */ +abstract public class TensorAddressAny extends TensorAddress { + + @Override + public String label(int i) { + return fromNumber((int)numericLabel(i)); + } + + public static TensorAddress of() { + return TensorAddressEmpty.empty; + } + + public static TensorAddress of(String label) { + return new TensorAddressAny1(toNumber(label)); + } + + public static TensorAddress of(String label0, String label1) { + return new TensorAddressAny2(toNumber(label0), toNumber(label1)); + } + + public static TensorAddress of(String label0, String label1, String label2) { + return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2)); + } + + public static TensorAddress of(String label0, String label1, String label2, String label3) { + return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3)); + } + + public static TensorAddress of(String[] labels) { + int[] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = toNumber(labels[i]); + } + return ofUnsafe(labelsAsInt); + } + + public static TensorAddress of(int label) { + return new TensorAddressAny1(sanitize(label)); + } + + public static TensorAddress of(int label0, int label1) { + return new TensorAddressAny2(sanitize(label0), sanitize(label1)); + } + + public static TensorAddress of(int label0, int label1, int label2) { + return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2)); + } + + public static TensorAddress of(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3)); + } + + public static TensorAddress of(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> new TensorAddressAny1(sanitize(labels[0])); + case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1])); + case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2])); + case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3])); + default -> { + for (int i = 0; i < labels.length; i++) { + sanitize(labels[i]); + } + yield new TensorAddressAnyN(labels); + } + }; + } + + public static TensorAddress of(long label) { + return of(safe2Int(label)); + } + + public static TensorAddress of(long label0, long label1) { + return of(safe2Int(label0), safe2Int(label1)); + } + + public static TensorAddress of(long label0, long label1, long label2) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2)); + } + + public static TensorAddress of(long label0, long label1, long label2, long label3) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3)); + } + + public static TensorAddress of(long ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(safe2Int(labels[0])); + case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1])); + case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2])); + case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3])); + default -> { + int[] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = safe2Int(labels[i]); + } + yield of(labelsAsInt); + } + }; + } + + private static TensorAddress ofUnsafe(int label) { + return new TensorAddressAny1(label); + } + + private static TensorAddress ofUnsafe(int label0, int label1) { + return new TensorAddressAny2(label0, label1); + } + + private static TensorAddress ofUnsafe(int label0, int label1, int label2) { + return new TensorAddressAny3(label0, label1, label2); + } + + private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(label0, label1, label2, label3); + } + + public static TensorAddress ofUnsafe(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(labels[0]); + case 2 -> ofUnsafe(labels[0], labels[1]); + case 3 -> ofUnsafe(labels[0], labels[1], labels[2]); + case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]); + default -> new TensorAddressAnyN(labels); + }; + } + + private static int sanitize(int label) { + if (label < Tensor.invalidIndex) { + throw new IndexOutOfBoundsException("cell label " + label + " must be positive"); + } + return label; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java new file mode 100644 index 00000000000..a9be6173781 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java @@ -0,0 +1,41 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * A one-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAny1 extends TensorAddressAny { + + private final int label; + + TensorAddressAny1(int label) { this.label = label; } + + @Override public int size() { return 1; } + + @Override + public long numericLabel(int i) { + if (i == 0) { + return label; + } + throw new IndexOutOfBoundsException("Index is not zero: " + i); + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label)); + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override public int hashCode() { return Math.abs(label); } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny1 any) && (label == any.label); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java new file mode 100644 index 00000000000..43f65d495cf --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A two-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAny2 extends TensorAddressAny { + + private final int label0, label1; + + TensorAddressAny2(int label0, int label1) { + this.label0 = label0; + this.label1 = label1; + } + + @Override public int size() { return 2; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1); + case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java new file mode 100644 index 00000000000..c22ff47b3c4 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java @@ -0,0 +1,61 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A three-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAny3 extends TensorAddressAny { + + private final int label0, label1, label2; + + TensorAddressAny3(int label0, int label1, int label2) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + } + + @Override public int size() { return 3; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2); + case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2); + case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny3 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java new file mode 100644 index 00000000000..6eb6b9216bf --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java @@ -0,0 +1,66 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * A four-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAny4 extends TensorAddressAny { + + private final int label0, label1, label2, label3; + + TensorAddressAny4(int label0, int label1, int label2, int label3) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + this.label3 = label3; + } + + @Override public int size() { return 4; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + case 3 -> label3; + default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3); + case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3); + case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3); + case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) | + (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny4 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2) && + (label3 == any.label3); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java new file mode 100644 index 00000000000..d5bac62bf18 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java @@ -0,0 +1,53 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; + +import static java.lang.Math.abs; + +/** + * An n-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressAnyN extends TensorAddressAny { + + private final int[] labels; + + TensorAddressAnyN(int[] labels) { + if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label"); + this.labels = labels; + } + + @Override public int size() { return labels.length; } + + @Override public long numericLabel(int i) { return labels[i]; } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + int[] copy = Arrays.copyOf(labels, labels.length); + copy[labelIndex] = Convert.safe2Int(label); + return new TensorAddressAnyN(copy); + } + + @Override public int hashCode() { + int hash = abs(labels[0]); + for (int i = 0; i < size(); i++) { + hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash))); + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (! (o instanceof TensorAddressAnyN any) || (size() != any.size())) return false; + for (int i = 0; i < size(); i++) { + if (labels[i] != any.labels[i]) return false; + } + return true; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java new file mode 100644 index 00000000000..eb7e62e913b --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java @@ -0,0 +1,33 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * A zero-dimensional address. + * + * @author baldersheim + */ +final class TensorAddressEmpty extends TensorAddressAny { + + static TensorAddress empty = new TensorAddressEmpty(); + + private TensorAddressEmpty() {} + + @Override public int size() { return 0; } + + @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override + public int hashCode() { return 0; } + + @Override + public boolean equals(Object o) { return o instanceof TensorAddressEmpty; } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java new file mode 100644 index 00000000000..6b004bf2d02 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java @@ -0,0 +1,6 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.tensor.impl; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index ca9527fd681..32e74c0f132 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -56,22 +56,22 @@ public class DenseBinaryFormat implements BinaryFormat { } private void encodeDoubleCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putDouble(tensor.get(i)); } private void encodeFloatCells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putFloat(tensor.getFloat(i)); } private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i))); } private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) { - for (int i = 0; i < tensor.size(); i++) + for (int i = 0; i < tensor.sizeAsInt(); i++) buffer.put((byte) tensor.getFloat(i)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 444ce02b14a..5598690e0bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -16,15 +16,7 @@ import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Slice; - -import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; -import java.util.List; -import java.util.Set; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -60,8 +52,7 @@ public class JsonFormat { // Short form for a single mapped dimension Cursor parent = root == null ? slime.setObject() : root.setObject("cells"); encodeSingleDimensionCells((MappedTensor) tensor, parent); - } else if (tensor instanceof MixedTensor && - tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) { + } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) { // Short form for a mixed tensor boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1; Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() ) @@ -143,9 +134,9 @@ public class JsonFormat { } private static void encodeBlocks(MixedTensor tensor, Cursor cursor) { - var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped()) + var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped) .map(d -> TensorType.Dimension.mapped(d.name())).toList(); - if (mappedDimensions.size() < 1) { + if (mappedDimensions.isEmpty()) { throw new IllegalArgumentException("Should be ensured by caller"); } @@ -179,23 +170,6 @@ public class JsonFormat { cursor.setDouble(field, value); } - private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) { - TensorAddress.Builder builder = new TensorAddress.Builder(subType); - for (TensorType.Dimension dim : subType.dimensions()) { - builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()). - orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index")))); - } - return builder.build(); - } - - private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) { - List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size()); - for (int i = 0; i < subAddress.size(); ++i) { - sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i))); - } - return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate(); - } - /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { @@ -204,7 +178,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) + else if (root.field("values").valid() && ! builder.type().hasMappedDimensions()) decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); @@ -298,14 +272,14 @@ public class JsonFormat { /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { - boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + boolean hasIndexed = builder.type().hasIndexedDimensions(); + boolean hasMapped = builder.type().hasMappedDimensions(); if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! hasMapped) decodeValuesAtTop(root, builder); - else if (hasMapped && hasIndexed) + else if (hasIndexed) decodeBlocks(root, builder); else decodeCells(root, builder); @@ -423,9 +397,7 @@ public class JsonFormat { if (decoded.length == 0) { throw new IllegalArgumentException("The block value string does not contain any values"); } - for (int i = 0; i < decoded.length; i++) { - values[i] = decoded[i]; - } + System.arraycopy(decoded, 0, values, 0, decoded.length); } else { throw new IllegalArgumentException("Expected a block to contain an array of values"); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index bdeb9add41a..3a117e41461 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -48,7 +48,7 @@ class SparseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation + buffer.putInt1_4Bytes(tensor.sizeAsInt()); // XXX: Size truncation switch (serializationValueType) { case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break; case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index d4b18c73f11..0a5c713f3e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -55,8 +55,8 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMappedDimensions = tensor.type().hasMappedDimensions(); + boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions(); boolean isMixed = hasMappedDimensions && hasIndexedDimensions; // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead |