summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java43
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java60
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java130
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java70
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java59
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java52
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java136
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java49
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java62
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java43
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java10
22 files changed, 630 insertions, 298 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 53f50fc4d02..085f9172095 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -78,6 +78,9 @@ class IndexedDoubleTensor extends IndexedTensor {
@Override
public Builder cell(TensorAddress address, double value) {
+ if (address == null) {
+ return null;
+ }
values[(int)toValueIndex(address, sizes(), type)] = value;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index f26174d9576..a428524612b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -171,10 +171,8 @@ public abstract class IndexedTensor implements Tensor {
}
static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) {
- if (address.isEmpty()) return 0;
-
long valueIndex = 0;
- for (int i = 0; i < address.size(); i++) {
+ for (int i = 0, sz = address.size(); i < sz; i++) {
long label = address.numericLabel(i);
if (label >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
@@ -893,8 +891,8 @@ public abstract class IndexedTensor implements Tensor {
private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
long size = 1;
- for (int iterateDimension : iterateDimensions)
- size *= sizes.size(iterateDimension);
+ for (int i = 0; i < iterateDimensions.size(); i++)
+ size *= sizes.size(iterateDimensions.get(i));
return size;
}
@@ -1060,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor {
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private long lastComputedSourceValueIndex = -1;
+ private long lastComputedSourceValueIndex = Tensor.INVALID_INDEX;
private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 95d1d70118a..d4469f447cb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -4,8 +4,6 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.impl.NumericTensorAddress;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Arrays;
@@ -111,7 +109,7 @@ public class MixedTensor implements Tensor {
return new Iterator<>() {
final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator();
DenseSubspace currBlock = null;
- final long[] labels = new long[index.indexedDimensions.size()];
+ final int[] labels = new int[index.indexedDimensions.size()];
int currOffset = index.denseSubspaceSize;
int prevOffset = -1;
@Override
@@ -127,7 +125,7 @@ public class MixedTensor implements Tensor {
if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1
index.denseOffsetToAddress(currOffset, labels);
}
- TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, labels);
+ TensorAddress fullAddr = currBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels);
prevOffset = currOffset;
double value = currBlock.cells[currOffset++];
return new Cell(fullAddr, value);
@@ -321,7 +319,7 @@ public class MixedTensor implements Tensor {
@Override
public Tensor.Builder cell(TensorAddress address, double value) {
- TensorAddress sparsePart = index.sparsePartialAddress(address);
+ TensorAddress sparsePart = address.sparsePartialAddress(index.sparseType, index.type.dimensions());
int denseOffset = index.denseOffsetOf(address);
double[] denseSubspace = denseSubspace(sparsePart);
denseSubspace[denseOffset] = value;
@@ -475,7 +473,7 @@ public class MixedTensor implements Tensor {
}
private DenseSubspace blockOf(TensorAddress address) {
- TensorAddress sparsePart = sparsePartialAddress(address);
+ TensorAddress sparsePart = address.sparsePartialAddress(sparseType, type.dimensions());
Integer blockNum = sparseMap.get(sparsePart);
if (blockNum == null || blockNum >= denseSubspaces.size()) {
return null;
@@ -502,19 +500,7 @@ public class MixedTensor implements Tensor {
return denseSubspaceSize;
}
- private TensorAddress sparsePartialAddress(TensorAddress address) {
- if (type.dimensions().size() != address.size())
- throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address);
- TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
- for (int i = 0; i < type.dimensions().size(); ++i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- if ( ! dimension.isIndexed())
- builder.add(dimension.name(), address.label(i));
- }
- return builder.build();
- }
-
- private void denseOffsetToAddress(long denseOffset, long [] labels) {
+ private void denseOffsetToAddress(long denseOffset, int [] labels) {
if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
throw new IllegalArgumentException("Offset out of bounds");
}
@@ -524,28 +510,11 @@ public class MixedTensor implements Tensor {
for (int i = 0; i < labels.length; ++i) {
innerSize /= indexedDimensionsSize[i];
- labels[i] = restSize / innerSize;
+ labels[i] = (int) (restSize / innerSize);
restSize %= innerSize;
}
}
- private TensorAddress fullAddressOf(TensorAddress sparsePart, long [] densePart) {
- String[] labels = new String[type.dimensions().size()];
- int mappedIndex = 0;
- int indexedIndex = 0;
- for (int i = 0; i < type.dimensions().size(); i++) {
- TensorType.Dimension d = type.dimensions().get(i);
- if (d.isIndexed()) {
- labels[i] = NumericTensorAddress.asString(densePart[indexedIndex]);
- indexedIndex++;
- } else {
- labels[i] = sparsePart.label(mappedIndex);
- mappedIndex++;
- }
- }
- return StringTensorAddress.unsafeOf(labels);
- }
-
@Override
public String toString() {
return "index into " + type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 3e41e6d94eb..da643d8c173 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -1,9 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.impl.StringTensorAddress;
-
-import java.util.Arrays;
+import com.yahoo.tensor.impl.Label;
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
@@ -20,7 +18,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final Object[] labels;
+ private final long[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -37,15 +35,15 @@ public class PartialAddress {
public long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return asLong(labels[i]);
- return -1;
+ return labels[i];
+ return Tensor.INVALID_INDEX;
}
/** Returns the label of this dimension, or null if no label is specified for it */
public String label(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
return null;
}
@@ -57,7 +55,7 @@ public class PartialAddress {
public String label(int i) {
if (i >= size())
throw new IllegalArgumentException("No label at position " + i + " in " + this);
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
}
public int size() { return dimensionNames.length; }
@@ -67,40 +65,14 @@ public class PartialAddress {
public TensorAddress asAddress(TensorType type) {
if (type.rank() != size())
throw new IllegalArgumentException(type + " has a different rank than " + this);
- if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) {
- long[] numericLabels = new long[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- long label = numericLabel(type.dimensions().get(i).name());
- if (label < 0)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- numericLabels[i] = label;
- }
- return TensorAddress.of(numericLabels);
- }
- else {
- String[] stringLabels = new String[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- String label = label(type.dimensions().get(i).name());
- if (label == null)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- stringLabels[i] = label;
- }
- return StringTensorAddress.unsafeOf(stringLabels);
- }
- }
-
- private long asLong(Object label) {
- if (label instanceof Long) {
- return (Long) label;
- }
- else {
- try {
- return Long.parseLong(label.toString());
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Label '" + label + "' is not numeric");
- }
+ long[] numericLabels = new long[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ long label = numericLabel(type.dimensions().get(i).name());
+ if (label == Tensor.INVALID_INDEX)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ numericLabels[i] = label;
}
+ return TensorAddress.of(numericLabels);
}
@Override
@@ -116,12 +88,12 @@ public class PartialAddress {
public static class Builder {
private String[] dimensionNames;
- private Object[] labels;
+ private long[] labels;
private int index = 0;
public Builder(int size) {
dimensionNames = new String[size];
- labels = new Object[size];
+ labels = new long[size];
}
public Builder add(String dimensionName, long label) {
@@ -133,7 +105,7 @@ public class PartialAddress {
public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
- labels[index] = label;
+ labels[index] = Label.toNumber(label);
index++;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index d034ac551f8..d650b88f202 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -20,7 +20,7 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import com.yahoo.tensor.functions.Expand;
-import com.yahoo.tensor.impl.NumericTensorAddress;
+import com.yahoo.tensor.impl.Label;
import java.util.ArrayList;
import java.util.Arrays;
@@ -55,6 +55,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming;
* @author bratseth
*/
public interface Tensor {
+ int INVALID_INDEX = -1;
// ----------------- Accessors
@@ -506,7 +507,7 @@ public interface Tensor {
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
- long getDirectIndex() { return -1; }
+ long getDirectIndex() { return INVALID_INDEX; }
/** Returns the value as a double */
@Override
@@ -626,7 +627,7 @@ public interface Tensor {
public TensorType type() { return tensorBuilder.type(); }
public CellBuilder label(String dimension, long label) {
- return label(dimension, NumericTensorAddress.asString(label));
+ return label(dimension, Label.fromNumber(label));
}
public Builder value(double cellValue) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 1b88a5d1b2f..59a5e2a49b1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,13 +1,11 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.impl.NumericTensorAddress;
-import com.yahoo.tensor.impl.StringTensorAddress;
-import net.jpountz.xxhash.XXHash32;
-import net.jpountz.xxhash.XXHashFactory;
+import com.yahoo.tensor.impl.Label;
+import com.yahoo.tensor.impl.TensorAddressAny;
-import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
/**
@@ -18,23 +16,25 @@ import java.util.Objects;
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32();
-
public static TensorAddress of(String[] labels) {
- return StringTensorAddress.of(labels);
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress ofLabels(String ... labels) {
- return StringTensorAddress.of(labels);
+ public static TensorAddress ofLabels(String... labels) {
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress of(long ... labels) {
- return NumericTensorAddress.of(labels);
+ public static TensorAddress of(long... labels) {
+ return TensorAddressAny.of(labels);
}
- private int cached_hash = 0;
+ public static TensorAddress of(int... labels) {
+ return TensorAddressAny.of(labels);
+ }
- /** Returns the number of labels in this */
+ /**
+ * Returns the number of labels in this
+ */
public abstract int size();
/**
@@ -67,32 +67,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
@Override
- public int hashCode() {
- if (cached_hash != 0) return cached_hash;
-
- int hash = 0;
- for (int i = 0; i < size(); i++) {
- String label = label(i);
- if (label != null) {
- byte [] buf = label.getBytes(StandardCharsets.UTF_8);
- hash = hasher.hash(buf, 0, buf.length, hash);
+ public String toString() {
+ StringBuilder sb = new StringBuilder("cell address (");
+ int sz = size();
+ if (sz > 0) {
+ sb.append(label(0));
+ for (int i = 1; i < sz; i++) {
+ sb.append(',').append(label(i));
}
}
- return cached_hash = hash;
- }
- @Override
- public boolean equals(Object o) {
- if (o == this) return true;
- if ( ! (o instanceof TensorAddress other)) return false;
- if (other.size() != this.size()) return false;
- for (int i = 0; i < this.size(); i++)
- if ( ! Objects.equals(this.label(i), other.label(i)))
- return false;
- return true;
+ return sb.append(')').toString();
}
- /** Returns this as a string on the appropriate form given the type */
+ /**
+ * Returns this as a string on the appropriate form given the type
+ */
public final String toString(TensorType type) {
StringBuilder b = new StringBuilder("{");
for (int i = 0; i < size(); i++) {
@@ -105,24 +95,72 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with appropriate quoting/escaping when necessary */
+ /**
+ * Returns a label as a string with appropriate quoting/escaping when necessary
+ */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
}
+ /** Returns an address with only some of the dimension */
+ public TensorAddress partialCopy(int[] indexMap) {
+ int[] labels = new int[indexMap.length];
+ for (int i = 0; i < labels.length; ++i) {
+ labels[i] = (int)numericLabel(indexMap[i]);
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Creates a complete address by taking the sparse dimmensions from this and the indexed from the densePart */
+ public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) {
+ int [] labels = new int[dimensions.size()];
+ int mappedIndex = 0;
+ int indexedIndex = 0;
+ for (int i = 0; i < labels.length; i++) {
+ TensorType.Dimension d = dimensions.get(i);
+ if (d.isIndexed()) {
+ labels[i] = densePart[indexedIndex];
+ indexedIndex++;
+ } else {
+ labels[i] = (int)numericLabel(mappedIndex);
+ mappedIndex++;
+ }
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Extracts the sparse(non-indexed) dimensions of the address */
+ public TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) {
+ if (dimensions.size() != size())
+ throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this);
+ TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
+ for (int i = 0; i < dimensions.size(); ++i) {
+ TensorType.Dimension dimension = dimensions.get(i);
+ if ( ! dimension.isIndexed())
+ builder.add(dimension.name(), (int)numericLabel(i));
+ }
+ return builder.build();
+ }
+
/** Builder of a tensor address */
public static class Builder {
final TensorType type;
- final String[] labels;
+ final int[] labels;
+
+ private static int [] createEmptyLabels(int size) {
+ int [] labels = new int[size];
+ Arrays.fill(labels, Tensor.INVALID_INDEX);
+ return labels;
+ }
public Builder(TensorType type) {
- this(type, new String[type.dimensions().size()]);
+ this(type, createEmptyLabels(type.dimensions().size()));
}
- private Builder(TensorType type, String[] labels) {
+ private Builder(TensorType type, int[] labels) {
this.type = type;
this.labels = labels;
}
@@ -152,6 +190,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
int labelIndex = type.indexOfDimensionAsInt(dimension);
if ( labelIndex < 0)
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
+ labels[labelIndex] = Label.toNumber(label);
+ return this;
+ }
+ public Builder add(String dimension, int label) {
+ Objects.requireNonNull(dimension, "dimension cannot be null");
+ int labelIndex = type.indexOfDimensionAsInt(dimension);
+ if ( labelIndex < 0)
+ throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
labels[labelIndex] = label;
return this;
}
@@ -166,14 +212,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
void validate() {
for (int i = 0; i < labels.length; i++)
- if (labels[i] == null)
+ if (labels[i] == Tensor.INVALID_INDEX)
throw new IllegalArgumentException("Missing a label for dimension '" +
type.dimensions().get(i).name() + "' for " + type);
}
public TensorAddress build() {
validate();
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
}
@@ -185,7 +231,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
super(type);
}
- private PartialBuilder(TensorType type, String[] labels) {
+ private PartialBuilder(TensorType type, int[] labels) {
super(type, labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index dcfee88d599..62ed4ad683c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -204,7 +204,7 @@ public class TensorType {
for (int i = 0; i < dimensions.size(); i++)
if (dimensions.get(i).name().equals(dimension))
return i;
- return -1;
+ return Tensor.INVALID_INDEX;
}
/* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 866b710b72e..37ca7f979a1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.Arrays;
import java.util.HashMap;
@@ -173,7 +172,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType concatType, long concatOffset, String concatDimension) {
long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
+ Arrays.fill(combinedLabels, Tensor.INVALID_INDEX);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
@@ -192,7 +191,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.INVALID_INDEX);
return toIndexes;
}
@@ -209,7 +208,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != from.numericLabel(i)) return false;
to[toIndex] = from.numericLabel(i);
}
}
@@ -369,7 +368,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
- return StringTensorAddress.unsafeOf(labels);
+ return TensorAddress.of(labels);
}
Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index e0ac549651c..047d8ee6ef0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,9 +12,11 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
+import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -206,7 +208,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
- TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
+ TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
Double subspaceValue = subspace.getAsDouble(subaddress);
if (subspaceValue != null) {
builder.cell(supercell.getKey(),
@@ -226,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return subspaceIndexes;
}
- private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
- String[] subspaceLabels = new String[subspaceIndexes.length];
- for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
- return StringTensorAddress.unsafeOf(subspaceLabels);
- }
-
/** Slow join which works for any two tensors */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
@@ -253,9 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
- Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
- Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
@@ -266,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
- PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize);
+ PartialAddress matchingBCells = sharedDimensionSize > 0
+ ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
+ : empty;
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
@@ -278,12 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
}
+ private static PartialAddress empty = new PartialAddress.Builder(0).build();
private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
Set<String> retainDimensions, int sharedDimensionSize) {
PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
- for (int i = 0; i < addressType.dimensions().size(); i++)
- if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
+ for (int i = 0; i < addressType.dimensions().size(); i++) {
+ String dimension = addressType.dimensions().get(i).name();
+ if (retainDimensions.contains(dimension))
+ builder.add(dimension, address.numericLabel(i));
+ }
return builder.build();
}
@@ -338,7 +338,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
+ TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
@@ -346,7 +346,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell bCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
+ TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
bCell.getKey(), bIndexesInJoined, joinedType);
@@ -377,11 +377,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
- String[] joinedLabels = new String[joinedType.dimensions().size()];
+ int[] joinedLabels = new int[joinedType.dimensions().size()];
+ Arrays.fill(joinedLabels, Tensor.INVALID_INDEX);
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return StringTensorAddress.unsafeOf(joinedLabels);
+ return TensorAddressAny.ofUnsafe(joinedLabels);
}
/**
@@ -390,11 +391,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
- for (int i = 0; i < from.size(); i++) {
+ private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) {
+ for (int i = 0, sz = from.size(); i < sz; i++) {
int toIndex = indexMap[i];
- String label = from.label(i);
- if (to[toIndex] != null && ! to[toIndex].equals(label)) return false;
+ int label = Convert.safe2Int(from.numericLabel(i));
+ if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label)
+ return false;
to[toIndex] = label;
}
return true;
@@ -417,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return typeBuilder.build();
}
- private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
- TensorAddress address = cell.getKey();
- String[] labels = new String[indexMap.length];
- for (int i = 0; i < labels.length; ++i) {
- labels[i] = address.label(indexMap[i]);
- }
- return StringTensorAddress.unsafeOf(labels);
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 77e82b818a7..0985e48c4e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Collections;
@@ -164,7 +163,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int reducedLabelIndex = 0;
for (int toKeep : indexesToKeep)
reducedLabels[reducedLabelIndex++] = address.label(toKeep);
- return StringTensorAddress.unsafeOf(reducedLabels);
+ return TensorAddress.of(reducedLabels);
}
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ecd302db361..910c5900495 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -8,7 +8,6 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.HashMap;
import java.util.Iterator;
@@ -123,7 +122,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
reorderedLabels[toIndexes[i]] = address.label(i);
- return StringTensorAddress.unsafeOf(reorderedLabels);
+ return TensorAddress.of(reorderedLabels);
}
private String toVectorString(List<String> elements) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
new file mode 100644
index 00000000000..0ab1454eb58
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
@@ -0,0 +1,70 @@
+package com.yahoo.tensor.impl;
+
+
+import com.yahoo.tensor.Tensor;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class Label {
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
+ private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>();
+ // Index 0 is unused, that is a valid positive number
+ // 1(-1) is reserved for the Tensor.INVALID_INDEX
+ private static volatile String [] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"};
+ private static int numUniqeStrings = 2;
+
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static int addNewUniqueString(String s) {
+ synchronized (string2Enum) {
+ if (numUniqeStrings >= uniqueStrings.length) {
+ uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2);
+ }
+ uniqueStrings[numUniqeStrings] = s;
+ return -numUniqeStrings++;
+ }
+ }
+
+ private static String asNumericString(long index) {
+ return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
+ private static boolean validNumericIndex(String s) {
+ for (int i = 0; i < s.length(); i++) {
+ char c = s.charAt(i);
+ if ((c < '0') || (c > '9')) return false;
+ }
+ return true;
+ }
+
+ public static int toNumber(String s) {
+ if (s == null) { return Tensor.INVALID_INDEX; }
+ try {
+ if (validNumericIndex(s)) {
+ return Integer.parseInt(s);
+ }
+ } catch (NumberFormatException e) {
+ }
+ return string2Enum.computeIfAbsent(s, Label::addNewUniqueString);
+ }
+ public static String fromNumber(int v) {
+ if (v >= 0) {
+ return asNumericString(v);
+ } else {
+ if (v == Tensor.INVALID_INDEX) { return null; }
+ return uniqueStrings[-v];
+ }
+ }
+ public static String fromNumber(long v) {
+ return fromNumber(Convert.safe2Int(v));
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
deleted file mode 100644
index 983074c9c90..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
+++ /dev/null
@@ -1,59 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
-public final class NumericTensorAddress extends TensorAddress {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
-
- private final long[] labels;
-
- private static String[] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
-
- private NumericTensorAddress(long[] labels) {
- this.labels = labels;
- }
-
- public static NumericTensorAddress of(long ... labels) {
- return new NumericTensorAddress(Arrays.copyOf(labels, labels.length));
- }
-
- public static NumericTensorAddress unsafeOf(long ... labels) {
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return asString(labels[i]); }
-
- @Override
- public long numericLabel(int i) { return labels[i]; }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- long[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = label;
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public String toString() {
- return "cell address (" + Arrays.stream(labels).mapToObj(NumericTensorAddress::asString).collect(Collectors.joining(",")) + ")";
- }
-
- public static String asString(long index) {
- return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
-
-}
-
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
deleted file mode 100644
index ca54494a19c..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Arrays;
-
-public final class StringTensorAddress extends TensorAddress {
-
- private final String[] labels;
-
- private StringTensorAddress(String [] labels) {
- this.labels = labels;
- }
-
- public static StringTensorAddress of(String[] labels) {
- return new StringTensorAddress(Arrays.copyOf(labels, labels.length));
- }
-
- public static StringTensorAddress unsafeOf(String[] labels) {
- return new StringTensorAddress(labels);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return labels[i]; }
-
- @Override
- public long numericLabel(int i) {
- try {
- return Long.parseLong(labels[i]);
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
- }
- }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- String[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = NumericTensorAddress.asString(label);
- return new StringTensorAddress(labels);
- }
-
-
- @Override
- public String toString() {
- return "cell address (" + String.join(",", labels) + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
new file mode 100644
index 00000000000..31863c99a74
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
@@ -0,0 +1,136 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import static com.yahoo.tensor.impl.Convert.safe2Int;
+import static com.yahoo.tensor.impl.Label.toNumber;
+import static com.yahoo.tensor.impl.Label.fromNumber;
+
+/**
+ * Parent of tensor address family centered around each dimension as int.
+ * A positive number represents a numeric index usable as a direect addressing.
+ * - 1 is representing an invalid/null address
+ * Other negative numbers are an enumeration maintained in {@link Label}
+ *
+ * @author baldersheim
+ */
+abstract public class TensorAddressAny extends TensorAddress {
+ @Override
+ public String label(int i) {
+ return fromNumber((int)numericLabel(i));
+ }
+
+ public static TensorAddress of() {
+ return TensorAddressEmpty.empty;
+ }
+ public static TensorAddress of(String label) {
+ return new TensorAddressAny1(toNumber(label));
+ }
+ public static TensorAddress of(String label0, String label1) {
+ return new TensorAddressAny2(toNumber(label0), toNumber(label1));
+ }
+ public static TensorAddress of(String label0, String label1, String label2) {
+ return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2));
+ }
+ public static TensorAddress of(String label0, String label1, String label2, String label3) {
+ return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3));
+ }
+ public static TensorAddress of(String [] labels) {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = toNumber(labels[i]);
+ }
+ return ofUnsafe(labelsAsInt);
+ }
+ public static TensorAddress of(int label) {
+ return new TensorAddressAny1(sanitize(label));
+ }
+ public static TensorAddress of(int label0, int label1) {
+ return new TensorAddressAny2(sanitize(label0), sanitize(label1));
+ }
+ public static TensorAddress of(int label0, int label1, int label2) {
+ return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2));
+ }
+ public static TensorAddress of(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3));
+ }
+ public static TensorAddress of(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> new TensorAddressAny1(sanitize(labels[0]));
+ case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1]));
+ case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]));
+ case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3]));
+ default -> {
+ for (int i = 0; i < labels.length; i++) {
+ sanitize(labels[i]);
+ }
+ yield new TensorAddressAnyN(labels);
+ }
+ };
+ }
+ public static TensorAddress of(long label) {
+ return of(safe2Int(label));
+ }
+
+ public static TensorAddress of(long label0, long label1) {
+ return of(safe2Int(label0), safe2Int(label1));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2, long label3) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3));
+ }
+
+ public static TensorAddress of(long ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(safe2Int(labels[0]));
+ case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]));
+ case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]));
+ case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3]));
+ default -> {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = safe2Int(labels[i]);
+ }
+ yield of(labelsAsInt);
+ }
+ };
+ }
+
+ private static TensorAddress ofUnsafe(int label) {
+ return new TensorAddressAny1(label);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1) {
+ return new TensorAddressAny2(label0, label1);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2) {
+ return new TensorAddressAny3(label0, label1, label2);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(label0, label1, label2, label3);
+ }
+ public static TensorAddress ofUnsafe(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(labels[0]);
+ case 2 -> ofUnsafe(labels[0], labels[1]);
+ case 3 -> ofUnsafe(labels[0], labels[1], labels[2]);
+ case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]);
+ default -> new TensorAddressAnyN(labels);
+ };
+ }
+ private static int sanitize(int label) {
+ if (label < Tensor.INVALID_INDEX) {
+ throw new IndexOutOfBoundsException("cell label " + label + " must be positive");
+ }
+ return label;
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
new file mode 100644
index 00000000000..a2b0d318a50
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
@@ -0,0 +1,37 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * Single dimension
+ * @author baldersheim
+ */
+final class TensorAddressAny1 extends TensorAddressAny {
+ private final int label;
+ TensorAddressAny1(int label) { this.label = label; }
+
+ @Override public int size() { return 1; }
+
+ @Override
+ public long numericLabel(int i) {
+ if (i == 0) {
+ return label;
+ }
+ throw new IndexOutOfBoundsException("Index is not zero: " + i);
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label));
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override public int hashCode() { return Math.abs(label); }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny1 any) && (label == any.label);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
new file mode 100644
index 00000000000..d77a689852f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
@@ -0,0 +1,49 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 2 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny2 extends TensorAddressAny {
+ private final int label0, label1;
+ TensorAddressAny2(int label0, int label1) {
+ this.label0 = label0;
+ this.label1 = label1;
+ }
+
+ @Override public int size() { return 2; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1);
+ case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0)));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
new file mode 100644
index 00000000000..95e14bd375c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
@@ -0,0 +1,57 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 3 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny3 extends TensorAddressAny {
+ private final int label0, label1, label2;
+ TensorAddressAny3(int label0, int label1, int label2) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ }
+
+ @Override public int size() { return 3; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2);
+ case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2);
+ case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny3 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
new file mode 100644
index 00000000000..8a45483340e
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
@@ -0,0 +1,62 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 4 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny4 extends TensorAddressAny {
+ private final int label0, label1, label2, label3;
+ TensorAddressAny4(int label0, int label1, int label2, int label3) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ this.label3 = label3;
+ }
+
+ @Override public int size() { return 4; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ case 3 -> label3;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3);
+ case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3);
+ case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3);
+ case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) |
+ (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny4 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2) &&
+ (label3 == any.label3);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
index 65d97b41404..acd7ed60722 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
@@ -1,11 +1,48 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
package com.yahoo.tensor.impl;
-public class TensorAddressAnyN extends TensorAdressAny {
- private final long [] labels;
- public TensorAddressAnyN(long [] labels) {
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Arrays;
+
+import static java.lang.Math.abs;
+
+/**
+ * N dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAnyN extends TensorAddressAny {
+ private final int [] labels;
+ TensorAddressAnyN(int [] labels) {
+ if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label");
this.labels = labels;
}
@Override public int size() { return labels.length; }
@Override public long numericLabel(int i) { return labels[i]; }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ int [] copy = Arrays.copyOf(labels, labels.length);
+ copy[labelIndex] = Convert.safe2Int(label);
+ return new TensorAddressAnyN(copy);
+ }
+
+ @Override public int hashCode() {
+ int hash = abs(labels[0]);
+ for (int i = 0; i < size(); i++) {
+ hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash)));
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (! (o instanceof TensorAddressAnyN any) || (size() != any.size())) return false;
+ for (int i = 0; i < size(); i++) {
+ if (labels[i] != any.labels[i]) return false;
+ }
+ return true;
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
new file mode 100644
index 00000000000..2d9cd3eed78
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
@@ -0,0 +1,26 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * 0 dimesional/empty address
+ * @author baldersheim
+ */
+final class TensorAddressEmpty extends TensorAddressAny {
+ static TensorAddress empty = new TensorAddressEmpty();
+ private TensorAddressEmpty() {}
+ @Override public int size() { return 0; }
+ @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override
+ public int hashCode() { return 0; }
+ @Override
+ public boolean equals(Object o) { return o instanceof TensorAddressEmpty; }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java
deleted file mode 100644
index 87593784841..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-abstract public class TensorAdressAny extends TensorAddress {
- @Override
- public String label(int i) {
- return null;
- }
-}