diff options
Diffstat (limited to 'vespajlib/src/main/java')
22 files changed, 630 insertions, 298 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 53f50fc4d02..085f9172095 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -78,6 +78,9 @@ class IndexedDoubleTensor extends IndexedTensor { @Override public Builder cell(TensorAddress address, double value) { + if (address == null) { + return null; + } values[(int)toValueIndex(address, sizes(), type)] = value; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index f26174d9576..a428524612b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -171,10 +171,8 @@ public abstract class IndexedTensor implements Tensor { } static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) { - if (address.isEmpty()) return 0; - long valueIndex = 0; - for (int i = 0; i < address.size(); i++) { + for (int i = 0, sz = address.size(); i < sz; i++) { long label = address.numericLabel(i); if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); @@ -893,8 +891,8 @@ public abstract class IndexedTensor implements Tensor { private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) { long size = 1; - for (int iterateDimension : iterateDimensions) - size *= sizes.size(iterateDimension); + for (int i = 0; i < iterateDimensions.size(); i++) + size *= sizes.size(iterateDimensions.get(i)); return size; } @@ -1060,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor { /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { - private long lastComputedSourceValueIndex = -1; + private long lastComputedSourceValueIndex = Tensor.INVALID_INDEX; private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 95d1d70118a..d4469f447cb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -4,8 +4,6 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.impl.NumericTensorAddress; -import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Arrays; @@ -111,7 +109,7 @@ public class MixedTensor implements Tensor { return new Iterator<>() { final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator(); DenseSubspace currBlock = null; - final long[] labels = new long[index.indexedDimensions.size()]; + final int[] labels = new int[index.indexedDimensions.size()]; int currOffset = index.denseSubspaceSize; int prevOffset = -1; @Override @@ -127,7 +125,7 @@ public class MixedTensor implements Tensor { if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1 index.denseOffsetToAddress(currOffset, labels); } - TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, labels); + TensorAddress fullAddr = currBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels); prevOffset = currOffset; double value = currBlock.cells[currOffset++]; return new Cell(fullAddr, value); @@ -321,7 +319,7 @@ public class MixedTensor implements Tensor { @Override public Tensor.Builder cell(TensorAddress address, double value) { - TensorAddress sparsePart = index.sparsePartialAddress(address); + TensorAddress sparsePart = address.sparsePartialAddress(index.sparseType, index.type.dimensions()); int denseOffset = index.denseOffsetOf(address); double[] denseSubspace = denseSubspace(sparsePart); denseSubspace[denseOffset] = value; @@ -475,7 +473,7 @@ public class MixedTensor implements Tensor { } private DenseSubspace blockOf(TensorAddress address) { - TensorAddress sparsePart = sparsePartialAddress(address); + TensorAddress sparsePart = address.sparsePartialAddress(sparseType, type.dimensions()); Integer blockNum = sparseMap.get(sparsePart); if (blockNum == null || blockNum >= denseSubspaces.size()) { return null; @@ -502,19 +500,7 @@ public class MixedTensor implements Tensor { return denseSubspaceSize; } - private TensorAddress sparsePartialAddress(TensorAddress address) { - if (type.dimensions().size() != address.size()) - throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address); - TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); - for (int i = 0; i < type.dimensions().size(); ++i) { - TensorType.Dimension dimension = type.dimensions().get(i); - if ( ! dimension.isIndexed()) - builder.add(dimension.name(), address.label(i)); - } - return builder.build(); - } - - private void denseOffsetToAddress(long denseOffset, long [] labels) { + private void denseOffsetToAddress(long denseOffset, int [] labels) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } @@ -524,28 +510,11 @@ public class MixedTensor implements Tensor { for (int i = 0; i < labels.length; ++i) { innerSize /= indexedDimensionsSize[i]; - labels[i] = restSize / innerSize; + labels[i] = (int) (restSize / innerSize); restSize %= innerSize; } } - private TensorAddress fullAddressOf(TensorAddress sparsePart, long [] densePart) { - String[] labels = new String[type.dimensions().size()]; - int mappedIndex = 0; - int indexedIndex = 0; - for (int i = 0; i < type.dimensions().size(); i++) { - TensorType.Dimension d = type.dimensions().get(i); - if (d.isIndexed()) { - labels[i] = NumericTensorAddress.asString(densePart[indexedIndex]); - indexedIndex++; - } else { - labels[i] = sparsePart.label(mappedIndex); - mappedIndex++; - } - } - return StringTensorAddress.unsafeOf(labels); - } - @Override public String toString() { return "index into " + type; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 3e41e6d94eb..da643d8c173 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,9 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.tensor.impl.StringTensorAddress; - -import java.util.Arrays; +import com.yahoo.tensor.impl.Label; /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors @@ -20,7 +18,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final Object[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -37,15 +35,15 @@ public class PartialAddress { public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return asLong(labels[i]); - return -1; + return labels[i]; + return Tensor.INVALID_INDEX; } /** Returns the label of this dimension, or null if no label is specified for it */ public String label(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i].toString(); + return Label.fromNumber(labels[i]); return null; } @@ -57,7 +55,7 @@ public class PartialAddress { public String label(int i) { if (i >= size()) throw new IllegalArgumentException("No label at position " + i + " in " + this); - return labels[i].toString(); + return Label.fromNumber(labels[i]); } public int size() { return dimensionNames.length; } @@ -67,40 +65,14 @@ public class PartialAddress { public TensorAddress asAddress(TensorType type) { if (type.rank() != size()) throw new IllegalArgumentException(type + " has a different rank than " + this); - if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { - long[] numericLabels = new long[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - long label = numericLabel(type.dimensions().get(i).name()); - if (label < 0) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - numericLabels[i] = label; - } - return TensorAddress.of(numericLabels); - } - else { - String[] stringLabels = new String[labels.length]; - for (int i = 0; i < type.dimensions().size(); i++) { - String label = label(type.dimensions().get(i).name()); - if (label == null) - throw new IllegalArgumentException(type + " dimension names does not match " + this); - stringLabels[i] = label; - } - return StringTensorAddress.unsafeOf(stringLabels); - } - } - - private long asLong(Object label) { - if (label instanceof Long) { - return (Long) label; - } - else { - try { - return Long.parseLong(label.toString()); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Label '" + label + "' is not numeric"); - } + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label == Tensor.INVALID_INDEX) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; } + return TensorAddress.of(numericLabels); } @Override @@ -116,12 +88,12 @@ public class PartialAddress { public static class Builder { private String[] dimensionNames; - private Object[] labels; + private long[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new Object[size]; + labels = new long[size]; } public Builder add(String dimensionName, long label) { @@ -133,7 +105,7 @@ public class PartialAddress { public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; - labels[index] = label; + labels[index] = Label.toNumber(label); index++; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index d034ac551f8..d650b88f202 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -20,7 +20,7 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.XwPlusB; import com.yahoo.tensor.functions.Expand; -import com.yahoo.tensor.impl.NumericTensorAddress; +import com.yahoo.tensor.impl.Label; import java.util.ArrayList; import java.util.Arrays; @@ -55,6 +55,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; * @author bratseth */ public interface Tensor { + int INVALID_INDEX = -1; // ----------------- Accessors @@ -506,7 +507,7 @@ public interface Tensor { * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ - long getDirectIndex() { return -1; } + long getDirectIndex() { return INVALID_INDEX; } /** Returns the value as a double */ @Override @@ -626,7 +627,7 @@ public interface Tensor { public TensorType type() { return tensorBuilder.type(); } public CellBuilder label(String dimension, long label) { - return label(dimension, NumericTensorAddress.asString(label)); + return label(dimension, Label.fromNumber(label)); } public Builder value(double cellValue) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 1b88a5d1b2f..59a5e2a49b1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,13 +1,11 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.tensor.impl.NumericTensorAddress; -import com.yahoo.tensor.impl.StringTensorAddress; -import net.jpountz.xxhash.XXHash32; -import net.jpountz.xxhash.XXHashFactory; +import com.yahoo.tensor.impl.Label; +import com.yahoo.tensor.impl.TensorAddressAny; -import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.List; import java.util.Objects; /** @@ -18,23 +16,25 @@ import java.util.Objects; */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32(); - public static TensorAddress of(String[] labels) { - return StringTensorAddress.of(labels); + return TensorAddressAny.of(labels); } - public static TensorAddress ofLabels(String ... labels) { - return StringTensorAddress.of(labels); + public static TensorAddress ofLabels(String... labels) { + return TensorAddressAny.of(labels); } - public static TensorAddress of(long ... labels) { - return NumericTensorAddress.of(labels); + public static TensorAddress of(long... labels) { + return TensorAddressAny.of(labels); } - private int cached_hash = 0; + public static TensorAddress of(int... labels) { + return TensorAddressAny.of(labels); + } - /** Returns the number of labels in this */ + /** + * Returns the number of labels in this + */ public abstract int size(); /** @@ -67,32 +67,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { } @Override - public int hashCode() { - if (cached_hash != 0) return cached_hash; - - int hash = 0; - for (int i = 0; i < size(); i++) { - String label = label(i); - if (label != null) { - byte [] buf = label.getBytes(StandardCharsets.UTF_8); - hash = hasher.hash(buf, 0, buf.length, hash); + public String toString() { + StringBuilder sb = new StringBuilder("cell address ("); + int sz = size(); + if (sz > 0) { + sb.append(label(0)); + for (int i = 1; i < sz; i++) { + sb.append(',').append(label(i)); } } - return cached_hash = hash; - } - @Override - public boolean equals(Object o) { - if (o == this) return true; - if ( ! (o instanceof TensorAddress other)) return false; - if (other.size() != this.size()) return false; - for (int i = 0; i < this.size(); i++) - if ( ! Objects.equals(this.label(i), other.label(i))) - return false; - return true; + return sb.append(')').toString(); } - /** Returns this as a string on the appropriate form given the type */ + /** + * Returns this as a string on the appropriate form given the type + */ public final String toString(TensorType type) { StringBuilder b = new StringBuilder("{"); for (int i = 0; i < size(); i++) { @@ -105,24 +95,72 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - /** Returns a label as a string with appropriate quoting/escaping when necessary */ + /** + * Returns a label as a string with appropriate quoting/escaping when necessary + */ public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; return "'" + label + "'"; } + /** Returns an address with only some of the dimension */ + public TensorAddress partialCopy(int[] indexMap) { + int[] labels = new int[indexMap.length]; + for (int i = 0; i < labels.length; ++i) { + labels[i] = (int)numericLabel(indexMap[i]); + } + return TensorAddressAny.ofUnsafe(labels); + } + + /** Creates a complete address by taking the sparse dimmensions from this and the indexed from the densePart */ + public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) { + int [] labels = new int[dimensions.size()]; + int mappedIndex = 0; + int indexedIndex = 0; + for (int i = 0; i < labels.length; i++) { + TensorType.Dimension d = dimensions.get(i); + if (d.isIndexed()) { + labels[i] = densePart[indexedIndex]; + indexedIndex++; + } else { + labels[i] = (int)numericLabel(mappedIndex); + mappedIndex++; + } + } + return TensorAddressAny.ofUnsafe(labels); + } + + /** Extracts the sparse(non-indexed) dimensions of the address */ + public TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) { + if (dimensions.size() != size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this); + TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); + for (int i = 0; i < dimensions.size(); ++i) { + TensorType.Dimension dimension = dimensions.get(i); + if ( ! dimension.isIndexed()) + builder.add(dimension.name(), (int)numericLabel(i)); + } + return builder.build(); + } + /** Builder of a tensor address */ public static class Builder { final TensorType type; - final String[] labels; + final int[] labels; + + private static int [] createEmptyLabels(int size) { + int [] labels = new int[size]; + Arrays.fill(labels, Tensor.INVALID_INDEX); + return labels; + } public Builder(TensorType type) { - this(type, new String[type.dimensions().size()]); + this(type, createEmptyLabels(type.dimensions().size())); } - private Builder(TensorType type, String[] labels) { + private Builder(TensorType type, int[] labels) { this.type = type; this.labels = labels; } @@ -152,6 +190,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { int labelIndex = type.indexOfDimensionAsInt(dimension); if ( labelIndex < 0) throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); + labels[labelIndex] = Label.toNumber(label); + return this; + } + public Builder add(String dimension, int label) { + Objects.requireNonNull(dimension, "dimension cannot be null"); + int labelIndex = type.indexOfDimensionAsInt(dimension); + if ( labelIndex < 0) + throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'"); labels[labelIndex] = label; return this; } @@ -166,14 +212,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { void validate() { for (int i = 0; i < labels.length; i++) - if (labels[i] == null) + if (labels[i] == Tensor.INVALID_INDEX) throw new IllegalArgumentException("Missing a label for dimension '" + type.dimensions().get(i).name() + "' for " + type); } public TensorAddress build() { validate(); - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } } @@ -185,7 +231,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { super(type); } - private PartialBuilder(TensorType type, String[] labels) { + private PartialBuilder(TensorType type, int[] labels) { super(type, labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index dcfee88d599..62ed4ad683c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -204,7 +204,7 @@ public class TensorType { for (int i = 0; i < dimensions.size(); i++) if (dimensions.get(i).name().equals(dimension)) return i; - return -1; + return Tensor.INVALID_INDEX; } /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 866b710b72e..37ca7f979a1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.impl.StringTensorAddress; import java.util.Arrays; import java.util.HashMap; @@ -173,7 +172,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType concatType, long concatOffset, String concatDimension) { long[] combinedLabels = new long[concatType.dimensions().size()]; - Arrays.fill(combinedLabels, -1); + Arrays.fill(combinedLabels, Tensor.INVALID_INDEX); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here @@ -192,7 +191,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET private int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.INVALID_INDEX); return toIndexes; } @@ -209,7 +208,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != from.numericLabel(i)) return false; to[toIndex] = from.numericLabel(i); } } @@ -369,7 +368,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET default -> throw new IllegalArgumentException("cannot handle: " + how); } } - return StringTensorAddress.unsafeOf(labels); + return TensorAddress.of(labels); } Tensor merge(CellVectorMapMap a, CellVectorMapMap b) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index e0ac549651c..047d8ee6ef0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -12,9 +12,11 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.impl.StringTensorAddress; +import com.yahoo.tensor.impl.Convert; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; @@ -206,7 +208,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) { Map.Entry<TensorAddress, Double> supercell = i.next(); - TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); + TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes); Double subspaceValue = subspace.getAsDouble(subaddress); if (subspaceValue != null) { builder.cell(supercell.getKey(), @@ -226,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return subspaceIndexes; } - private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) { - String[] subspaceLabels = new String[subspaceIndexes.length]; - for (int i = 0; i < subspaceIndexes.length; i++) - subspaceLabels[i] = superAddress.label(subspaceIndexes[i]); - return StringTensorAddress.unsafeOf(subspaceLabels); - } - /** Slow join which works for any two tensors */ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) { if (a instanceof IndexedTensor && b instanceof IndexedTensor) @@ -253,9 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder, DoubleBinaryOperator combinator) { - Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames())); int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection - Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames())); DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); @@ -266,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP // for each combination of dimensions in a which is also in b while (aSubspace.hasNext()) { Tensor.Cell aCell = aSubspace.next(); - PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize); + PartialAddress matchingBCells = sharedDimensionSize > 0 + ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize) + : empty; // for each matching combination of dimensions ony in b for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { Tensor.Cell bCell = bSubspace.next(); @@ -278,12 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } } + private static PartialAddress empty = new PartialAddress.Builder(0).build(); private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions, int sharedDimensionSize) { PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize); - for (int i = 0; i < addressType.dimensions().size(); i++) - if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); + for (int i = 0; i < addressType.dimensions().size(); i++) { + String dimension = addressType.dimensions().get(i).name(); + if (retainDimensions.contains(dimension)) + builder.add(dimension, address.numericLabel(i)); + } return builder.build(); } @@ -338,7 +338,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt()); for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell aCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon); + TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon); aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell); } @@ -346,7 +346,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) { Tensor.Cell bCell = cellIterator.next(); - TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon); + TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon); for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) { TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined, bCell.getKey(), bIndexesInJoined, joinedType); @@ -377,11 +377,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType joinedType) { - String[] joinedLabels = new String[joinedType.dimensions().size()]; + int[] joinedLabels = new int[joinedType.dimensions().size()]; + Arrays.fill(joinedLabels, Tensor.INVALID_INDEX); mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); if ( ! compatible) return null; - return StringTensorAddress.unsafeOf(joinedLabels); + return TensorAddressAny.ofUnsafe(joinedLabels); } /** @@ -390,11 +391,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) { - for (int i = 0; i < from.size(); i++) { + private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) { + for (int i = 0, sz = from.size(); i < sz; i++) { int toIndex = indexMap[i]; - String label = from.label(i); - if (to[toIndex] != null && ! to[toIndex].equals(label)) return false; + int label = Convert.safe2Int(from.numericLabel(i)); + if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label) + return false; to[toIndex] = label; } return true; @@ -417,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return typeBuilder.build(); } - private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) { - TensorAddress address = cell.getKey(); - String[] labels = new String[indexMap.length]; - for (int i = 0; i < labels.length; ++i) { - labels[i] = address.label(indexMap[i]); - } - return StringTensorAddress.unsafeOf(labels); - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 77e82b818a7..0985e48c4e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.impl.Convert; -import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Collections; @@ -164,7 +163,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int reducedLabelIndex = 0; for (int toKeep : indexesToKeep) reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return StringTensorAddress.unsafeOf(reducedLabels); + return TensorAddress.of(reducedLabels); } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ecd302db361..910c5900495 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.impl.StringTensorAddress; import java.util.HashMap; import java.util.Iterator; @@ -123,7 +122,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) reorderedLabels[toIndexes[i]] = address.label(i); - return StringTensorAddress.unsafeOf(reorderedLabels); + return TensorAddress.of(reorderedLabels); } private String toVectorString(List<String> elements) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java new file mode 100644 index 00000000000..0ab1454eb58 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java @@ -0,0 +1,70 @@ +package com.yahoo.tensor.impl; + + +import com.yahoo.tensor.Tensor; + +import java.util.Arrays; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +public class Label { + private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); + private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>(); + // Index 0 is unused, that is a valid positive number + // 1(-1) is reserved for the Tensor.INVALID_INDEX + private static volatile String [] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"}; + private static int numUniqeStrings = 2; + + private static String[] createSmallIndexesAsStrings(int count) { + String [] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private static int addNewUniqueString(String s) { + synchronized (string2Enum) { + if (numUniqeStrings >= uniqueStrings.length) { + uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2); + } + uniqueStrings[numUniqeStrings] = s; + return -numUniqeStrings++; + } + } + + private static String asNumericString(long index) { + return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + + private static boolean validNumericIndex(String s) { + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + if ((c < '0') || (c > '9')) return false; + } + return true; + } + + public static int toNumber(String s) { + if (s == null) { return Tensor.INVALID_INDEX; } + try { + if (validNumericIndex(s)) { + return Integer.parseInt(s); + } + } catch (NumberFormatException e) { + } + return string2Enum.computeIfAbsent(s, Label::addNewUniqueString); + } + public static String fromNumber(int v) { + if (v >= 0) { + return asNumericString(v); + } else { + if (v == Tensor.INVALID_INDEX) { return null; } + return uniqueStrings[-v]; + } + } + public static String fromNumber(long v) { + return fromNumber(Convert.safe2Int(v)); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java deleted file mode 100644 index 983074c9c90..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java +++ /dev/null @@ -1,59 +0,0 @@ -package com.yahoo.tensor.impl; - -import com.yahoo.tensor.TensorAddress; - -import java.util.Arrays; -import java.util.stream.Collectors; - -public final class NumericTensorAddress extends TensorAddress { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - - private final long[] labels; - - private static String[] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); - } - return asStrings; - } - - private NumericTensorAddress(long[] labels) { - this.labels = labels; - } - - public static NumericTensorAddress of(long ... labels) { - return new NumericTensorAddress(Arrays.copyOf(labels, labels.length)); - } - - public static NumericTensorAddress unsafeOf(long ... labels) { - return new NumericTensorAddress(labels); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return asString(labels[i]); } - - @Override - public long numericLabel(int i) { return labels[i]; } - - @Override - public TensorAddress withLabel(int index, long label) { - long[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = label; - return new NumericTensorAddress(labels); - } - - @Override - public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(NumericTensorAddress::asString).collect(Collectors.joining(",")) + ")"; - } - - public static String asString(long index) { - return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } - -} - diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java deleted file mode 100644 index ca54494a19c..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java +++ /dev/null @@ -1,52 +0,0 @@ -package com.yahoo.tensor.impl; - -import com.yahoo.tensor.TensorAddress; - -import java.util.Arrays; - -public final class StringTensorAddress extends TensorAddress { - - private final String[] labels; - - private StringTensorAddress(String [] labels) { - this.labels = labels; - } - - public static StringTensorAddress of(String[] labels) { - return new StringTensorAddress(Arrays.copyOf(labels, labels.length)); - } - - public static StringTensorAddress unsafeOf(String[] labels) { - return new StringTensorAddress(labels); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return labels[i]; } - - @Override - public long numericLabel(int i) { - try { - return Long.parseLong(labels[i]); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); - } - } - - @Override - public TensorAddress withLabel(int index, long label) { - String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = NumericTensorAddress.asString(label); - return new StringTensorAddress(labels); - } - - - @Override - public String toString() { - return "cell address (" + String.join(",", labels) + ")"; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java new file mode 100644 index 00000000000..31863c99a74 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java @@ -0,0 +1,136 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; + +import static com.yahoo.tensor.impl.Convert.safe2Int; +import static com.yahoo.tensor.impl.Label.toNumber; +import static com.yahoo.tensor.impl.Label.fromNumber; + +/** + * Parent of tensor address family centered around each dimension as int. + * A positive number represents a numeric index usable as a direect addressing. + * - 1 is representing an invalid/null address + * Other negative numbers are an enumeration maintained in {@link Label} + * + * @author baldersheim + */ +abstract public class TensorAddressAny extends TensorAddress { + @Override + public String label(int i) { + return fromNumber((int)numericLabel(i)); + } + + public static TensorAddress of() { + return TensorAddressEmpty.empty; + } + public static TensorAddress of(String label) { + return new TensorAddressAny1(toNumber(label)); + } + public static TensorAddress of(String label0, String label1) { + return new TensorAddressAny2(toNumber(label0), toNumber(label1)); + } + public static TensorAddress of(String label0, String label1, String label2) { + return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2)); + } + public static TensorAddress of(String label0, String label1, String label2, String label3) { + return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3)); + } + public static TensorAddress of(String [] labels) { + int [] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = toNumber(labels[i]); + } + return ofUnsafe(labelsAsInt); + } + public static TensorAddress of(int label) { + return new TensorAddressAny1(sanitize(label)); + } + public static TensorAddress of(int label0, int label1) { + return new TensorAddressAny2(sanitize(label0), sanitize(label1)); + } + public static TensorAddress of(int label0, int label1, int label2) { + return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2)); + } + public static TensorAddress of(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3)); + } + public static TensorAddress of(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> new TensorAddressAny1(sanitize(labels[0])); + case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1])); + case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2])); + case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3])); + default -> { + for (int i = 0; i < labels.length; i++) { + sanitize(labels[i]); + } + yield new TensorAddressAnyN(labels); + } + }; + } + public static TensorAddress of(long label) { + return of(safe2Int(label)); + } + + public static TensorAddress of(long label0, long label1) { + return of(safe2Int(label0), safe2Int(label1)); + } + + public static TensorAddress of(long label0, long label1, long label2) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2)); + } + + public static TensorAddress of(long label0, long label1, long label2, long label3) { + return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3)); + } + + public static TensorAddress of(long ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(safe2Int(labels[0])); + case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1])); + case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2])); + case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3])); + default -> { + int [] labelsAsInt = new int[labels.length]; + for (int i = 0; i < labels.length; i++) { + labelsAsInt[i] = safe2Int(labels[i]); + } + yield of(labelsAsInt); + } + }; + } + + private static TensorAddress ofUnsafe(int label) { + return new TensorAddressAny1(label); + } + private static TensorAddress ofUnsafe(int label0, int label1) { + return new TensorAddressAny2(label0, label1); + } + private static TensorAddress ofUnsafe(int label0, int label1, int label2) { + return new TensorAddressAny3(label0, label1, label2); + } + private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) { + return new TensorAddressAny4(label0, label1, label2, label3); + } + public static TensorAddress ofUnsafe(int ... labels) { + return switch (labels.length) { + case 0 -> of(); + case 1 -> ofUnsafe(labels[0]); + case 2 -> ofUnsafe(labels[0], labels[1]); + case 3 -> ofUnsafe(labels[0], labels[1], labels[2]); + case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]); + default -> new TensorAddressAnyN(labels); + }; + } + private static int sanitize(int label) { + if (label < Tensor.INVALID_INDEX) { + throw new IndexOutOfBoundsException("cell label " + label + " must be positive"); + } + return label; + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java new file mode 100644 index 00000000000..a2b0d318a50 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java @@ -0,0 +1,37 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * Single dimension + * @author baldersheim + */ +final class TensorAddressAny1 extends TensorAddressAny { + private final int label; + TensorAddressAny1(int label) { this.label = label; } + + @Override public int size() { return 1; } + + @Override + public long numericLabel(int i) { + if (i == 0) { + return label; + } + throw new IndexOutOfBoundsException("Index is not zero: " + i); + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label)); + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override public int hashCode() { return Math.abs(label); } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny1 any) && (label == any.label); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java new file mode 100644 index 00000000000..d77a689852f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * 2 dimensional address + * @author baldersheim + */ +final class TensorAddressAny2 extends TensorAddressAny { + private final int label0, label1; + TensorAddressAny2(int label0, int label1) { + this.label0 = label0; + this.label1 = label1; + } + + @Override public int size() { return 2; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1); + case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java new file mode 100644 index 00000000000..95e14bd375c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java @@ -0,0 +1,57 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * 3 dimensional address + * @author baldersheim + */ +final class TensorAddressAny3 extends TensorAddressAny { + private final int label0, label1, label2; + TensorAddressAny3(int label0, int label1, int label2) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + } + + @Override public int size() { return 3; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2); + case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2); + case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny3 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java new file mode 100644 index 00000000000..8a45483340e --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java @@ -0,0 +1,62 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import static java.lang.Math.abs; + +/** + * 4 dimensional address + * @author baldersheim + */ +final class TensorAddressAny4 extends TensorAddressAny { + private final int label0, label1, label2, label3; + TensorAddressAny4(int label0, int label1, int label2, int label3) { + this.label0 = label0; + this.label1 = label1; + this.label2 = label2; + this.label3 = label3; + } + + @Override public int size() { return 4; } + + @Override + public long numericLabel(int i) { + return switch (i) { + case 0 -> label0; + case 1 -> label1; + case 2 -> label2; + case 3 -> label3; + default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i); + }; + } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + return switch (labelIndex) { + case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3); + case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3); + case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3); + case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label)); + default -> throw new IllegalArgumentException("No label " + labelIndex); + }; + } + + @Override + public int hashCode() { + return abs(label0) | + (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) | + (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) | + (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1))))); + } + + @Override + public boolean equals(Object o) { + return (o instanceof TensorAddressAny4 any) && + (label0 == any.label0) && + (label1 == any.label1) && + (label2 == any.label2) && + (label3 == any.label3); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java index 65d97b41404..acd7ed60722 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java @@ -1,11 +1,48 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + package com.yahoo.tensor.impl; -public class TensorAddressAnyN extends TensorAdressAny { - private final long [] labels; - public TensorAddressAnyN(long [] labels) { +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; + +import static java.lang.Math.abs; + +/** + * N dimensional address + * @author baldersheim + */ +final class TensorAddressAnyN extends TensorAddressAny { + private final int [] labels; + TensorAddressAnyN(int [] labels) { + if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label"); this.labels = labels; } @Override public int size() { return labels.length; } @Override public long numericLabel(int i) { return labels[i]; } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + int [] copy = Arrays.copyOf(labels, labels.length); + copy[labelIndex] = Convert.safe2Int(label); + return new TensorAddressAnyN(copy); + } + + @Override public int hashCode() { + int hash = abs(labels[0]); + for (int i = 0; i < size(); i++) { + hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash))); + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (! (o instanceof TensorAddressAnyN any) || (size() != any.size())) return false; + for (int i = 0; i < size(); i++) { + if (labels[i] != any.labels[i]) return false; + } + return true; + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java new file mode 100644 index 00000000000..2d9cd3eed78 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java @@ -0,0 +1,26 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +/** + * 0 dimesional/empty address + * @author baldersheim + */ +final class TensorAddressEmpty extends TensorAddressAny { + static TensorAddress empty = new TensorAddressEmpty(); + private TensorAddressEmpty() {} + @Override public int size() { return 0; } + @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); } + + @Override + public TensorAddress withLabel(int labelIndex, long label) { + throw new IllegalArgumentException("No label " + labelIndex); + } + + @Override + public int hashCode() { return 0; } + @Override + public boolean equals(Object o) { return o instanceof TensorAddressEmpty; } +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java deleted file mode 100644 index 87593784841..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java +++ /dev/null @@ -1,10 +0,0 @@ -package com.yahoo.tensor.impl; - -import com.yahoo.tensor.TensorAddress; - -abstract public class TensorAdressAny extends TensorAddress { - @Override - public String label(int i) { - return null; - } -} |