summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-24 01:33:20 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-25 20:17:49 +0100
commita1e14c645d88fecfab1abb0072e0abc26677e752 (patch)
tree9d22ff810f3694912005ddc455ca3425699b2fe7 /vespajlib
parentd9fb5104948ad6b8758e5a902af3fad0f9e506ce (diff)
Make tensor addresses integer based instead of as strings.
Positive numbers are direct indexes, while strings that does not represent numbers are enumerated and represented with negative integers.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java43
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java60
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java130
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java70
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java59
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java52
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java136
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java49
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java57
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java62
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java43
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java7
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java48
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java28
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java31
28 files changed, 725 insertions, 331 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index df75a6f6d1f..1f44d90f924 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1265,7 +1265,9 @@
"public static com.yahoo.tensor.Tensor from(java.lang.String)",
"public static com.yahoo.tensor.Tensor from(double)"
],
- "fields" : [ ]
+ "fields" : [
+ "public static final int INVALID_INDEX"
+ ]
},
"com.yahoo.tensor.TensorAddress$Builder" : {
"superClass" : "java.lang.Object",
@@ -1277,6 +1279,7 @@
"public void <init>(com.yahoo.tensor.TensorType)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String)",
"public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, java.lang.String)",
+ "public com.yahoo.tensor.TensorAddress$Builder add(java.lang.String, int)",
"public com.yahoo.tensor.TensorAddress$Builder copy()",
"public com.yahoo.tensor.TensorType type()",
"public com.yahoo.tensor.TensorAddress build()"
@@ -1309,16 +1312,19 @@
"public static com.yahoo.tensor.TensorAddress of(java.lang.String[])",
"public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])",
"public static varargs com.yahoo.tensor.TensorAddress of(long[])",
+ "public static varargs com.yahoo.tensor.TensorAddress of(int[])",
"public abstract int size()",
"public abstract java.lang.String label(int)",
"public abstract long numericLabel(int)",
"public abstract com.yahoo.tensor.TensorAddress withLabel(int, long)",
"public final boolean isEmpty()",
"public int compareTo(com.yahoo.tensor.TensorAddress)",
- "public int hashCode()",
- "public boolean equals(java.lang.Object)",
+ "public java.lang.String toString()",
"public final java.lang.String toString(com.yahoo.tensor.TensorType)",
"public static java.lang.String labelToString(java.lang.String)",
+ "public com.yahoo.tensor.TensorAddress partialCopy(int[])",
+ "public com.yahoo.tensor.TensorAddress fullAddressOf(java.util.List, int[])",
+ "public com.yahoo.tensor.TensorAddress sparsePartialAddress(com.yahoo.tensor.TensorType, java.util.List)",
"public bridge synthetic int compareTo(java.lang.Object)"
],
"fields" : [ ]
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 53f50fc4d02..085f9172095 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -78,6 +78,9 @@ class IndexedDoubleTensor extends IndexedTensor {
@Override
public Builder cell(TensorAddress address, double value) {
+ if (address == null) {
+ return null;
+ }
values[(int)toValueIndex(address, sizes(), type)] = value;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index f26174d9576..a428524612b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -171,10 +171,8 @@ public abstract class IndexedTensor implements Tensor {
}
static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) {
- if (address.isEmpty()) return 0;
-
long valueIndex = 0;
- for (int i = 0; i < address.size(); i++) {
+ for (int i = 0, sz = address.size(); i < sz; i++) {
long label = address.numericLabel(i);
if (label >= sizes.size(i))
throw new IllegalArgumentException(address + " is not within the bounds of " + type);
@@ -893,8 +891,8 @@ public abstract class IndexedTensor implements Tensor {
private static long computeSize(DimensionSizes sizes, List<Integer> iterateDimensions) {
long size = 1;
- for (int iterateDimension : iterateDimensions)
- size *= sizes.size(iterateDimension);
+ for (int i = 0; i < iterateDimensions.size(); i++)
+ size *= sizes.size(iterateDimensions.get(i));
return size;
}
@@ -1060,7 +1058,7 @@ public abstract class IndexedTensor implements Tensor {
/** In this case we can reuse the source index computation for the iteration index */
private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes {
- private long lastComputedSourceValueIndex = -1;
+ private long lastComputedSourceValueIndex = Tensor.INVALID_INDEX;
private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List<Integer> iterateDimensions, long[] initialIndexes, long size) {
super(sizes, sizes, iterateDimensions, initialIndexes, size);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 95d1d70118a..d4469f447cb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -4,8 +4,6 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
-import com.yahoo.tensor.impl.NumericTensorAddress;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Arrays;
@@ -111,7 +109,7 @@ public class MixedTensor implements Tensor {
return new Iterator<>() {
final Iterator<DenseSubspace> blockIterator = index.denseSubspaces.iterator();
DenseSubspace currBlock = null;
- final long[] labels = new long[index.indexedDimensions.size()];
+ final int[] labels = new int[index.indexedDimensions.size()];
int currOffset = index.denseSubspaceSize;
int prevOffset = -1;
@Override
@@ -127,7 +125,7 @@ public class MixedTensor implements Tensor {
if (currOffset != prevOffset) { // Optimization for index.denseSubspaceSize == 1
index.denseOffsetToAddress(currOffset, labels);
}
- TensorAddress fullAddr = index.fullAddressOf(currBlock.sparseAddress, labels);
+ TensorAddress fullAddr = currBlock.sparseAddress.fullAddressOf(index.type.dimensions(), labels);
prevOffset = currOffset;
double value = currBlock.cells[currOffset++];
return new Cell(fullAddr, value);
@@ -321,7 +319,7 @@ public class MixedTensor implements Tensor {
@Override
public Tensor.Builder cell(TensorAddress address, double value) {
- TensorAddress sparsePart = index.sparsePartialAddress(address);
+ TensorAddress sparsePart = address.sparsePartialAddress(index.sparseType, index.type.dimensions());
int denseOffset = index.denseOffsetOf(address);
double[] denseSubspace = denseSubspace(sparsePart);
denseSubspace[denseOffset] = value;
@@ -475,7 +473,7 @@ public class MixedTensor implements Tensor {
}
private DenseSubspace blockOf(TensorAddress address) {
- TensorAddress sparsePart = sparsePartialAddress(address);
+ TensorAddress sparsePart = address.sparsePartialAddress(sparseType, type.dimensions());
Integer blockNum = sparseMap.get(sparsePart);
if (blockNum == null || blockNum >= denseSubspaces.size()) {
return null;
@@ -502,19 +500,7 @@ public class MixedTensor implements Tensor {
return denseSubspaceSize;
}
- private TensorAddress sparsePartialAddress(TensorAddress address) {
- if (type.dimensions().size() != address.size())
- throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address);
- TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
- for (int i = 0; i < type.dimensions().size(); ++i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- if ( ! dimension.isIndexed())
- builder.add(dimension.name(), address.label(i));
- }
- return builder.build();
- }
-
- private void denseOffsetToAddress(long denseOffset, long [] labels) {
+ private void denseOffsetToAddress(long denseOffset, int [] labels) {
if (denseOffset < 0 || denseOffset > denseSubspaceSize) {
throw new IllegalArgumentException("Offset out of bounds");
}
@@ -524,28 +510,11 @@ public class MixedTensor implements Tensor {
for (int i = 0; i < labels.length; ++i) {
innerSize /= indexedDimensionsSize[i];
- labels[i] = restSize / innerSize;
+ labels[i] = (int) (restSize / innerSize);
restSize %= innerSize;
}
}
- private TensorAddress fullAddressOf(TensorAddress sparsePart, long [] densePart) {
- String[] labels = new String[type.dimensions().size()];
- int mappedIndex = 0;
- int indexedIndex = 0;
- for (int i = 0; i < type.dimensions().size(); i++) {
- TensorType.Dimension d = type.dimensions().get(i);
- if (d.isIndexed()) {
- labels[i] = NumericTensorAddress.asString(densePart[indexedIndex]);
- indexedIndex++;
- } else {
- labels[i] = sparsePart.label(mappedIndex);
- mappedIndex++;
- }
- }
- return StringTensorAddress.unsafeOf(labels);
- }
-
@Override
public String toString() {
return "index into " + type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 3e41e6d94eb..da643d8c173 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -1,9 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.impl.StringTensorAddress;
-
-import java.util.Arrays;
+import com.yahoo.tensor.impl.Label;
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
@@ -20,7 +18,7 @@ public class PartialAddress {
// Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final Object[] labels;
+ private final long[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -37,15 +35,15 @@ public class PartialAddress {
public long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return asLong(labels[i]);
- return -1;
+ return labels[i];
+ return Tensor.INVALID_INDEX;
}
/** Returns the label of this dimension, or null if no label is specified for it */
public String label(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
return null;
}
@@ -57,7 +55,7 @@ public class PartialAddress {
public String label(int i) {
if (i >= size())
throw new IllegalArgumentException("No label at position " + i + " in " + this);
- return labels[i].toString();
+ return Label.fromNumber(labels[i]);
}
public int size() { return dimensionNames.length; }
@@ -67,40 +65,14 @@ public class PartialAddress {
public TensorAddress asAddress(TensorType type) {
if (type.rank() != size())
throw new IllegalArgumentException(type + " has a different rank than " + this);
- if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) {
- long[] numericLabels = new long[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- long label = numericLabel(type.dimensions().get(i).name());
- if (label < 0)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- numericLabels[i] = label;
- }
- return TensorAddress.of(numericLabels);
- }
- else {
- String[] stringLabels = new String[labels.length];
- for (int i = 0; i < type.dimensions().size(); i++) {
- String label = label(type.dimensions().get(i).name());
- if (label == null)
- throw new IllegalArgumentException(type + " dimension names does not match " + this);
- stringLabels[i] = label;
- }
- return StringTensorAddress.unsafeOf(stringLabels);
- }
- }
-
- private long asLong(Object label) {
- if (label instanceof Long) {
- return (Long) label;
- }
- else {
- try {
- return Long.parseLong(label.toString());
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Label '" + label + "' is not numeric");
- }
+ long[] numericLabels = new long[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ long label = numericLabel(type.dimensions().get(i).name());
+ if (label == Tensor.INVALID_INDEX)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ numericLabels[i] = label;
}
+ return TensorAddress.of(numericLabels);
}
@Override
@@ -116,12 +88,12 @@ public class PartialAddress {
public static class Builder {
private String[] dimensionNames;
- private Object[] labels;
+ private long[] labels;
private int index = 0;
public Builder(int size) {
dimensionNames = new String[size];
- labels = new Object[size];
+ labels = new long[size];
}
public Builder add(String dimensionName, long label) {
@@ -133,7 +105,7 @@ public class PartialAddress {
public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
- labels[index] = label;
+ labels[index] = Label.toNumber(label);
index++;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index d034ac551f8..d650b88f202 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -20,7 +20,7 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
import com.yahoo.tensor.functions.Expand;
-import com.yahoo.tensor.impl.NumericTensorAddress;
+import com.yahoo.tensor.impl.Label;
import java.util.ArrayList;
import java.util.Arrays;
@@ -55,6 +55,7 @@ import static com.yahoo.tensor.functions.ScalarFunctions.Hamming;
* @author bratseth
*/
public interface Tensor {
+ int INVALID_INDEX = -1;
// ----------------- Accessors
@@ -506,7 +507,7 @@ public interface Tensor {
* This is for optimizations mapping between tensors where this is possible without creating a
* TensorAddress.
*/
- long getDirectIndex() { return -1; }
+ long getDirectIndex() { return INVALID_INDEX; }
/** Returns the value as a double */
@Override
@@ -626,7 +627,7 @@ public interface Tensor {
public TensorType type() { return tensorBuilder.type(); }
public CellBuilder label(String dimension, long label) {
- return label(dimension, NumericTensorAddress.asString(label));
+ return label(dimension, Label.fromNumber(label));
}
public Builder value(double cellValue) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 1b88a5d1b2f..59a5e2a49b1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,13 +1,11 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.impl.NumericTensorAddress;
-import com.yahoo.tensor.impl.StringTensorAddress;
-import net.jpountz.xxhash.XXHash32;
-import net.jpountz.xxhash.XXHashFactory;
+import com.yahoo.tensor.impl.Label;
+import com.yahoo.tensor.impl.TensorAddressAny;
-import java.nio.charset.StandardCharsets;
import java.util.Arrays;
+import java.util.List;
import java.util.Objects;
/**
@@ -18,23 +16,25 @@ import java.util.Objects;
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final XXHash32 hasher = XXHashFactory.fastestJavaInstance().hash32();
-
public static TensorAddress of(String[] labels) {
- return StringTensorAddress.of(labels);
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress ofLabels(String ... labels) {
- return StringTensorAddress.of(labels);
+ public static TensorAddress ofLabels(String... labels) {
+ return TensorAddressAny.of(labels);
}
- public static TensorAddress of(long ... labels) {
- return NumericTensorAddress.of(labels);
+ public static TensorAddress of(long... labels) {
+ return TensorAddressAny.of(labels);
}
- private int cached_hash = 0;
+ public static TensorAddress of(int... labels) {
+ return TensorAddressAny.of(labels);
+ }
- /** Returns the number of labels in this */
+ /**
+ * Returns the number of labels in this
+ */
public abstract int size();
/**
@@ -67,32 +67,22 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
}
@Override
- public int hashCode() {
- if (cached_hash != 0) return cached_hash;
-
- int hash = 0;
- for (int i = 0; i < size(); i++) {
- String label = label(i);
- if (label != null) {
- byte [] buf = label.getBytes(StandardCharsets.UTF_8);
- hash = hasher.hash(buf, 0, buf.length, hash);
+ public String toString() {
+ StringBuilder sb = new StringBuilder("cell address (");
+ int sz = size();
+ if (sz > 0) {
+ sb.append(label(0));
+ for (int i = 1; i < sz; i++) {
+ sb.append(',').append(label(i));
}
}
- return cached_hash = hash;
- }
- @Override
- public boolean equals(Object o) {
- if (o == this) return true;
- if ( ! (o instanceof TensorAddress other)) return false;
- if (other.size() != this.size()) return false;
- for (int i = 0; i < this.size(); i++)
- if ( ! Objects.equals(this.label(i), other.label(i)))
- return false;
- return true;
+ return sb.append(')').toString();
}
- /** Returns this as a string on the appropriate form given the type */
+ /**
+ * Returns this as a string on the appropriate form given the type
+ */
public final String toString(TensorType type) {
StringBuilder b = new StringBuilder("{");
for (int i = 0; i < size(); i++) {
@@ -105,24 +95,72 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with appropriate quoting/escaping when necessary */
+ /**
+ * Returns a label as a string with appropriate quoting/escaping when necessary
+ */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
}
+ /** Returns an address with only some of the dimension */
+ public TensorAddress partialCopy(int[] indexMap) {
+ int[] labels = new int[indexMap.length];
+ for (int i = 0; i < labels.length; ++i) {
+ labels[i] = (int)numericLabel(indexMap[i]);
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Creates a complete address by taking the sparse dimmensions from this and the indexed from the densePart */
+ public TensorAddress fullAddressOf(List<TensorType.Dimension> dimensions, int [] densePart) {
+ int [] labels = new int[dimensions.size()];
+ int mappedIndex = 0;
+ int indexedIndex = 0;
+ for (int i = 0; i < labels.length; i++) {
+ TensorType.Dimension d = dimensions.get(i);
+ if (d.isIndexed()) {
+ labels[i] = densePart[indexedIndex];
+ indexedIndex++;
+ } else {
+ labels[i] = (int)numericLabel(mappedIndex);
+ mappedIndex++;
+ }
+ }
+ return TensorAddressAny.ofUnsafe(labels);
+ }
+
+ /** Extracts the sparse(non-indexed) dimensions of the address */
+ public TensorAddress sparsePartialAddress(TensorType sparseType, List<TensorType.Dimension> dimensions) {
+ if (dimensions.size() != size())
+ throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + this);
+ TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
+ for (int i = 0; i < dimensions.size(); ++i) {
+ TensorType.Dimension dimension = dimensions.get(i);
+ if ( ! dimension.isIndexed())
+ builder.add(dimension.name(), (int)numericLabel(i));
+ }
+ return builder.build();
+ }
+
/** Builder of a tensor address */
public static class Builder {
final TensorType type;
- final String[] labels;
+ final int[] labels;
+
+ private static int [] createEmptyLabels(int size) {
+ int [] labels = new int[size];
+ Arrays.fill(labels, Tensor.INVALID_INDEX);
+ return labels;
+ }
public Builder(TensorType type) {
- this(type, new String[type.dimensions().size()]);
+ this(type, createEmptyLabels(type.dimensions().size()));
}
- private Builder(TensorType type, String[] labels) {
+ private Builder(TensorType type, int[] labels) {
this.type = type;
this.labels = labels;
}
@@ -152,6 +190,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
int labelIndex = type.indexOfDimensionAsInt(dimension);
if ( labelIndex < 0)
throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
+ labels[labelIndex] = Label.toNumber(label);
+ return this;
+ }
+ public Builder add(String dimension, int label) {
+ Objects.requireNonNull(dimension, "dimension cannot be null");
+ int labelIndex = type.indexOfDimensionAsInt(dimension);
+ if ( labelIndex < 0)
+ throw new IllegalArgumentException(type + " does not contain dimension '" + dimension + "'");
labels[labelIndex] = label;
return this;
}
@@ -166,14 +212,14 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
void validate() {
for (int i = 0; i < labels.length; i++)
- if (labels[i] == null)
+ if (labels[i] == Tensor.INVALID_INDEX)
throw new IllegalArgumentException("Missing a label for dimension '" +
type.dimensions().get(i).name() + "' for " + type);
}
public TensorAddress build() {
validate();
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
}
@@ -185,7 +231,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
super(type);
}
- private PartialBuilder(TensorType type, String[] labels) {
+ private PartialBuilder(TensorType type, int[] labels) {
super(type, labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index dcfee88d599..62ed4ad683c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -204,7 +204,7 @@ public class TensorType {
for (int i = 0; i < dimensions.size(); i++)
if (dimensions.get(i).name().equals(dimension))
return i;
- return -1;
+ return Tensor.INVALID_INDEX;
}
/* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 866b710b72e..37ca7f979a1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.Arrays;
import java.util.HashMap;
@@ -173,7 +172,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType concatType, long concatOffset, String concatDimension) {
long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
+ Arrays.fill(combinedLabels, Tensor.INVALID_INDEX);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
@@ -192,7 +191,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
private int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(Tensor.INVALID_INDEX);
return toIndexes;
}
@@ -209,7 +208,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != from.numericLabel(i)) return false;
to[toIndex] = from.numericLabel(i);
}
}
@@ -369,7 +368,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
- return StringTensorAddress.unsafeOf(labels);
+ return TensorAddress.of(labels);
}
Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index e0ac549651c..047d8ee6ef0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -12,9 +12,11 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
+import com.yahoo.tensor.impl.Convert;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -206,7 +208,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> supercell = i.next();
- TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
+ TensorAddress subaddress = supercell.getKey().partialCopy(subspaceIndexes);
Double subspaceValue = subspace.getAsDouble(subaddress);
if (subspaceValue != null) {
builder.cell(supercell.getKey(),
@@ -226,13 +228,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return subspaceIndexes;
}
- private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
- String[] subspaceLabels = new String[subspaceIndexes.length];
- for (int i = 0; i < subspaceIndexes.length; i++)
- subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
- return StringTensorAddress.unsafeOf(subspaceLabels);
- }
-
/** Slow join which works for any two tensors */
private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
@@ -253,9 +248,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
DoubleBinaryOperator combinator) {
- Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> sharedDimensions = Set.copyOf(Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()));
int sharedDimensionSize = sharedDimensions.size(); // Expensive to compute size after intersection
- Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> dimensionsOnlyInA = Set.copyOf(Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()));
DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
@@ -266,7 +261,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
Tensor.Cell aCell = aSubspace.next();
- PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize);
+ PartialAddress matchingBCells = sharedDimensionSize > 0
+ ? partialAddress(a.type(), aSubspace.address(), sharedDimensions, sharedDimensionSize)
+ : empty;
// for each matching combination of dimensions ony in b
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
@@ -278,12 +275,15 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
}
+ private static PartialAddress empty = new PartialAddress.Builder(0).build();
private static PartialAddress partialAddress(TensorType addressType, TensorAddress address,
Set<String> retainDimensions, int sharedDimensionSize) {
PartialAddress.Builder builder = new PartialAddress.Builder(sharedDimensionSize);
- for (int i = 0; i < addressType.dimensions().size(); i++)
- if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
+ for (int i = 0; i < addressType.dimensions().size(); i++) {
+ String dimension = addressType.dimensions().get(i).name();
+ if (retainDimensions.contains(dimension))
+ builder.add(dimension, address.numericLabel(i));
+ }
return builder.build();
}
@@ -338,7 +338,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Map<TensorAddress, List<Tensor.Cell>> aCellsByCommonAddress = new HashMap<>(a.sizeAsInt());
for (Iterator<Tensor.Cell> cellIterator = a.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell aCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(aCell, aIndexesInCommon);
+ TensorAddress partialCommonAddress = aCell.getKey().partialCopy(aIndexesInCommon);
aCellsByCommonAddress.computeIfAbsent(partialCommonAddress, (key) -> new ArrayList<>()).add(aCell);
}
@@ -346,7 +346,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> cellIterator = b.cellIterator(); cellIterator.hasNext(); ) {
Tensor.Cell bCell = cellIterator.next();
- TensorAddress partialCommonAddress = partialCommonAddress(bCell, bIndexesInCommon);
+ TensorAddress partialCommonAddress = bCell.getKey().partialCopy(bIndexesInCommon);
for (Tensor.Cell aCell : aCellsByCommonAddress.getOrDefault(partialCommonAddress, List.of())) {
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aIndexesInJoined,
bCell.getKey(), bIndexesInJoined, joinedType);
@@ -377,11 +377,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType joinedType) {
- String[] joinedLabels = new String[joinedType.dimensions().size()];
+ int[] joinedLabels = new int[joinedType.dimensions().size()];
+ Arrays.fill(joinedLabels, Tensor.INVALID_INDEX);
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
if ( ! compatible) return null;
- return StringTensorAddress.unsafeOf(joinedLabels);
+ return TensorAddressAny.ofUnsafe(joinedLabels);
}
/**
@@ -390,11 +391,12 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
- for (int i = 0; i < from.size(); i++) {
+ private static boolean mapContent(TensorAddress from, int[] to, int[] indexMap) {
+ for (int i = 0, sz = from.size(); i < sz; i++) {
int toIndex = indexMap[i];
- String label = from.label(i);
- if (to[toIndex] != null && ! to[toIndex].equals(label)) return false;
+ int label = Convert.safe2Int(from.numericLabel(i));
+ if (to[toIndex] != Tensor.INVALID_INDEX && to[toIndex] != label)
+ return false;
to[toIndex] = label;
}
return true;
@@ -417,14 +419,5 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return typeBuilder.build();
}
- private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
- TensorAddress address = cell.getKey();
- String[] labels = new String[indexMap.length];
- for (int i = 0; i < labels.length; ++i) {
- labels[i] = address.label(indexMap[i]);
- }
- return StringTensorAddress.unsafeOf(labels);
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 77e82b818a7..0985e48c4e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.impl.Convert;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Collections;
@@ -164,7 +163,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int reducedLabelIndex = 0;
for (int toKeep : indexesToKeep)
reducedLabels[reducedLabelIndex++] = address.label(toKeep);
- return StringTensorAddress.unsafeOf(reducedLabels);
+ return TensorAddress.of(reducedLabels);
}
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ecd302db361..910c5900495 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -8,7 +8,6 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.HashMap;
import java.util.Iterator;
@@ -123,7 +122,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
reorderedLabels[toIndexes[i]] = address.label(i);
- return StringTensorAddress.unsafeOf(reorderedLabels);
+ return TensorAddress.of(reorderedLabels);
}
private String toVectorString(List<String> elements) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
new file mode 100644
index 00000000000..0ab1454eb58
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Label.java
@@ -0,0 +1,70 @@
+package com.yahoo.tensor.impl;
+
+
+import com.yahoo.tensor.Tensor;
+
+import java.util.Arrays;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+public class Label {
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
+ private final static Map<String, Integer> string2Enum = new ConcurrentHashMap<>();
+ // Index 0 is unused, that is a valid positive number
+ // 1(-1) is reserved for the Tensor.INVALID_INDEX
+ private static volatile String [] uniqueStrings = {"UNIQUE_UNUSED_MAGIC", "Tensor.INVALID_INDEX"};
+ private static int numUniqeStrings = 2;
+
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static int addNewUniqueString(String s) {
+ synchronized (string2Enum) {
+ if (numUniqeStrings >= uniqueStrings.length) {
+ uniqueStrings = Arrays.copyOf(uniqueStrings, uniqueStrings.length*2);
+ }
+ uniqueStrings[numUniqeStrings] = s;
+ return -numUniqeStrings++;
+ }
+ }
+
+ private static String asNumericString(long index) {
+ return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
+ private static boolean validNumericIndex(String s) {
+ for (int i = 0; i < s.length(); i++) {
+ char c = s.charAt(i);
+ if ((c < '0') || (c > '9')) return false;
+ }
+ return true;
+ }
+
+ public static int toNumber(String s) {
+ if (s == null) { return Tensor.INVALID_INDEX; }
+ try {
+ if (validNumericIndex(s)) {
+ return Integer.parseInt(s);
+ }
+ } catch (NumberFormatException e) {
+ }
+ return string2Enum.computeIfAbsent(s, Label::addNewUniqueString);
+ }
+ public static String fromNumber(int v) {
+ if (v >= 0) {
+ return asNumericString(v);
+ } else {
+ if (v == Tensor.INVALID_INDEX) { return null; }
+ return uniqueStrings[-v];
+ }
+ }
+ public static String fromNumber(long v) {
+ return fromNumber(Convert.safe2Int(v));
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
deleted file mode 100644
index 983074c9c90..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
+++ /dev/null
@@ -1,59 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Arrays;
-import java.util.stream.Collectors;
-
-public final class NumericTensorAddress extends TensorAddress {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
-
- private final long[] labels;
-
- private static String[] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
-
- private NumericTensorAddress(long[] labels) {
- this.labels = labels;
- }
-
- public static NumericTensorAddress of(long ... labels) {
- return new NumericTensorAddress(Arrays.copyOf(labels, labels.length));
- }
-
- public static NumericTensorAddress unsafeOf(long ... labels) {
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return asString(labels[i]); }
-
- @Override
- public long numericLabel(int i) { return labels[i]; }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- long[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = label;
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public String toString() {
- return "cell address (" + Arrays.stream(labels).mapToObj(NumericTensorAddress::asString).collect(Collectors.joining(",")) + ")";
- }
-
- public static String asString(long index) {
- return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
-
-}
-
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
deleted file mode 100644
index ca54494a19c..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
+++ /dev/null
@@ -1,52 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-import java.util.Arrays;
-
-public final class StringTensorAddress extends TensorAddress {
-
- private final String[] labels;
-
- private StringTensorAddress(String [] labels) {
- this.labels = labels;
- }
-
- public static StringTensorAddress of(String[] labels) {
- return new StringTensorAddress(Arrays.copyOf(labels, labels.length));
- }
-
- public static StringTensorAddress unsafeOf(String[] labels) {
- return new StringTensorAddress(labels);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return labels[i]; }
-
- @Override
- public long numericLabel(int i) {
- try {
- return Long.parseLong(labels[i]);
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
- }
- }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- String[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = NumericTensorAddress.asString(label);
- return new StringTensorAddress(labels);
- }
-
-
- @Override
- public String toString() {
- return "cell address (" + String.join(",", labels) + ")";
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
new file mode 100644
index 00000000000..31863c99a74
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
@@ -0,0 +1,136 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+
+import static com.yahoo.tensor.impl.Convert.safe2Int;
+import static com.yahoo.tensor.impl.Label.toNumber;
+import static com.yahoo.tensor.impl.Label.fromNumber;
+
+/**
+ * Parent of tensor address family centered around each dimension as int.
+ * A positive number represents a numeric index usable as a direect addressing.
+ * - 1 is representing an invalid/null address
+ * Other negative numbers are an enumeration maintained in {@link Label}
+ *
+ * @author baldersheim
+ */
+abstract public class TensorAddressAny extends TensorAddress {
+ @Override
+ public String label(int i) {
+ return fromNumber((int)numericLabel(i));
+ }
+
+ public static TensorAddress of() {
+ return TensorAddressEmpty.empty;
+ }
+ public static TensorAddress of(String label) {
+ return new TensorAddressAny1(toNumber(label));
+ }
+ public static TensorAddress of(String label0, String label1) {
+ return new TensorAddressAny2(toNumber(label0), toNumber(label1));
+ }
+ public static TensorAddress of(String label0, String label1, String label2) {
+ return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2));
+ }
+ public static TensorAddress of(String label0, String label1, String label2, String label3) {
+ return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3));
+ }
+ public static TensorAddress of(String [] labels) {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = toNumber(labels[i]);
+ }
+ return ofUnsafe(labelsAsInt);
+ }
+ public static TensorAddress of(int label) {
+ return new TensorAddressAny1(sanitize(label));
+ }
+ public static TensorAddress of(int label0, int label1) {
+ return new TensorAddressAny2(sanitize(label0), sanitize(label1));
+ }
+ public static TensorAddress of(int label0, int label1, int label2) {
+ return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2));
+ }
+ public static TensorAddress of(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3));
+ }
+ public static TensorAddress of(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> new TensorAddressAny1(sanitize(labels[0]));
+ case 2 -> new TensorAddressAny2(sanitize(labels[0]), sanitize(labels[1]));
+ case 3 -> new TensorAddressAny3(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]));
+ case 4 -> new TensorAddressAny4(sanitize(labels[0]), sanitize(labels[1]), sanitize(labels[2]), sanitize(labels[3]));
+ default -> {
+ for (int i = 0; i < labels.length; i++) {
+ sanitize(labels[i]);
+ }
+ yield new TensorAddressAnyN(labels);
+ }
+ };
+ }
+ public static TensorAddress of(long label) {
+ return of(safe2Int(label));
+ }
+
+ public static TensorAddress of(long label0, long label1) {
+ return of(safe2Int(label0), safe2Int(label1));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2));
+ }
+
+ public static TensorAddress of(long label0, long label1, long label2, long label3) {
+ return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3));
+ }
+
+ public static TensorAddress of(long ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(safe2Int(labels[0]));
+ case 2 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]));
+ case 3 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]));
+ case 4 -> ofUnsafe(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3]));
+ default -> {
+ int [] labelsAsInt = new int[labels.length];
+ for (int i = 0; i < labels.length; i++) {
+ labelsAsInt[i] = safe2Int(labels[i]);
+ }
+ yield of(labelsAsInt);
+ }
+ };
+ }
+
+ private static TensorAddress ofUnsafe(int label) {
+ return new TensorAddressAny1(label);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1) {
+ return new TensorAddressAny2(label0, label1);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2) {
+ return new TensorAddressAny3(label0, label1, label2);
+ }
+ private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) {
+ return new TensorAddressAny4(label0, label1, label2, label3);
+ }
+ public static TensorAddress ofUnsafe(int ... labels) {
+ return switch (labels.length) {
+ case 0 -> of();
+ case 1 -> ofUnsafe(labels[0]);
+ case 2 -> ofUnsafe(labels[0], labels[1]);
+ case 3 -> ofUnsafe(labels[0], labels[1], labels[2]);
+ case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]);
+ default -> new TensorAddressAnyN(labels);
+ };
+ }
+ private static int sanitize(int label) {
+ if (label < Tensor.INVALID_INDEX) {
+ throw new IndexOutOfBoundsException("cell label " + label + " must be positive");
+ }
+ return label;
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
new file mode 100644
index 00000000000..a2b0d318a50
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny1.java
@@ -0,0 +1,37 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * Single dimension
+ * @author baldersheim
+ */
+final class TensorAddressAny1 extends TensorAddressAny {
+ private final int label;
+ TensorAddressAny1(int label) { this.label = label; }
+
+ @Override public int size() { return 1; }
+
+ @Override
+ public long numericLabel(int i) {
+ if (i == 0) {
+ return label;
+ }
+ throw new IndexOutOfBoundsException("Index is not zero: " + i);
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ if (labelIndex == 0) return new TensorAddressAny1(Convert.safe2Int(label));
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override public int hashCode() { return Math.abs(label); }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny1 any) && (label == any.label);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
new file mode 100644
index 00000000000..d77a689852f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny2.java
@@ -0,0 +1,49 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 2 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny2 extends TensorAddressAny {
+ private final int label0, label1;
+ TensorAddressAny2(int label0, int label1) {
+ this.label0 = label0;
+ this.label1 = label1;
+ }
+
+ @Override public int size() { return 2; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,1]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny2(Convert.safe2Int(label), label1);
+ case 1 -> new TensorAddressAny2(label0, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) | (abs(label1) << 32 - Integer.numberOfLeadingZeros(abs(label0)));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny2 any) && (label0 == any.label0) && (label1 == any.label1);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
new file mode 100644
index 00000000000..95e14bd375c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny3.java
@@ -0,0 +1,57 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 3 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny3 extends TensorAddressAny {
+ private final int label0, label1, label2;
+ TensorAddressAny3(int label0, int label1, int label2) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ }
+
+ @Override public int size() { return 3; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,2]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny3(Convert.safe2Int(label), label1, label2);
+ case 1 -> new TensorAddressAny3(label0, Convert.safe2Int(label), label2);
+ case 2 -> new TensorAddressAny3(label0, label1, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny3 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
new file mode 100644
index 00000000000..8a45483340e
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny4.java
@@ -0,0 +1,62 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import static java.lang.Math.abs;
+
+/**
+ * 4 dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAny4 extends TensorAddressAny {
+ private final int label0, label1, label2, label3;
+ TensorAddressAny4(int label0, int label1, int label2, int label3) {
+ this.label0 = label0;
+ this.label1 = label1;
+ this.label2 = label2;
+ this.label3 = label3;
+ }
+
+ @Override public int size() { return 4; }
+
+ @Override
+ public long numericLabel(int i) {
+ return switch (i) {
+ case 0 -> label0;
+ case 1 -> label1;
+ case 2 -> label2;
+ case 3 -> label3;
+ default -> throw new IndexOutOfBoundsException("Index is not in [0,3]: " + i);
+ };
+ }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ return switch (labelIndex) {
+ case 0 -> new TensorAddressAny4(Convert.safe2Int(label), label1, label2, label3);
+ case 1 -> new TensorAddressAny4(label0, Convert.safe2Int(label), label2, label3);
+ case 2 -> new TensorAddressAny4(label0, label1, Convert.safe2Int(label), label3);
+ case 3 -> new TensorAddressAny4(label0, label1, label2, Convert.safe2Int(label));
+ default -> throw new IllegalArgumentException("No label " + labelIndex);
+ };
+ }
+
+ @Override
+ public int hashCode() {
+ return abs(label0) |
+ (abs(label1) << (1*32 - Integer.numberOfLeadingZeros(abs(label0)))) |
+ (abs(label2) << (2*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1))))) |
+ (abs(label3) << (3*32 - (Integer.numberOfLeadingZeros(abs(label0)) + Integer.numberOfLeadingZeros(abs(label1)) + Integer.numberOfLeadingZeros(abs(label1)))));
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ return (o instanceof TensorAddressAny4 any) &&
+ (label0 == any.label0) &&
+ (label1 == any.label1) &&
+ (label2 == any.label2) &&
+ (label3 == any.label3);
+ }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
index 65d97b41404..acd7ed60722 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAnyN.java
@@ -1,11 +1,48 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
package com.yahoo.tensor.impl;
-public class TensorAddressAnyN extends TensorAdressAny {
- private final long [] labels;
- public TensorAddressAnyN(long [] labels) {
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Arrays;
+
+import static java.lang.Math.abs;
+
+/**
+ * N dimensional address
+ * @author baldersheim
+ */
+final class TensorAddressAnyN extends TensorAddressAny {
+ private final int [] labels;
+ TensorAddressAnyN(int [] labels) {
+ if (labels.length < 1) throw new IllegalArgumentException("Need at least 1 label");
this.labels = labels;
}
@Override public int size() { return labels.length; }
@Override public long numericLabel(int i) { return labels[i]; }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ int [] copy = Arrays.copyOf(labels, labels.length);
+ copy[labelIndex] = Convert.safe2Int(label);
+ return new TensorAddressAnyN(copy);
+ }
+
+ @Override public int hashCode() {
+ int hash = abs(labels[0]);
+ for (int i = 0; i < size(); i++) {
+ hash = hash | (abs(labels[i]) << (32 - Integer.numberOfLeadingZeros(hash)));
+ }
+ return hash;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (! (o instanceof TensorAddressAnyN any) || (size() != any.size())) return false;
+ for (int i = 0; i < size(); i++) {
+ if (labels[i] != any.labels[i]) return false;
+ }
+ return true;
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
new file mode 100644
index 00000000000..2d9cd3eed78
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressEmpty.java
@@ -0,0 +1,26 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+/**
+ * 0 dimesional/empty address
+ * @author baldersheim
+ */
+final class TensorAddressEmpty extends TensorAddressAny {
+ static TensorAddress empty = new TensorAddressEmpty();
+ private TensorAddressEmpty() {}
+ @Override public int size() { return 0; }
+ @Override public long numericLabel(int i) { throw new IllegalArgumentException("Empty address with no labels"); }
+
+ @Override
+ public TensorAddress withLabel(int labelIndex, long label) {
+ throw new IllegalArgumentException("No label " + labelIndex);
+ }
+
+ @Override
+ public int hashCode() { return 0; }
+ @Override
+ public boolean equals(Object o) { return o instanceof TensorAddressEmpty; }
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java
deleted file mode 100644
index 87593784841..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAdressAny.java
+++ /dev/null
@@ -1,10 +0,0 @@
-package com.yahoo.tensor.impl;
-
-import com.yahoo.tensor.TensorAddress;
-
-abstract public class TensorAdressAny extends TensorAddress {
- @Override
- public String label(int i) {
- return null;
- }
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index afc95d295f0..528ca57d256 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -46,12 +46,7 @@ public class IndexedTensorTestCase {
@Test
public void testNegativeLabels() {
- TensorAddress numeric = TensorAddress.of(-1, 0, 1, 1234567, -1234567);
- assertEquals("-1", numeric.label(0));
- assertEquals("0", numeric.label(1));
- assertEquals("1", numeric.label(2));
- assertEquals("1234567", numeric.label(3));
- assertEquals("-1234567", numeric.label(4));
+ assertThrows(IndexOutOfBoundsException.class, () ->TensorAddress.of(-1, 0, 1, 1234567, -1234567));
}
private void verifyFloat(String spec) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
index 79202e3f07e..472ebca2360 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
@@ -1,8 +1,13 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import static com.yahoo.tensor.TensorAddress.of;
+import static com.yahoo.tensor.TensorAddress.ofLabels;
+
import org.junit.jupiter.api.Test;
+import java.util.Arrays;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
@@ -12,33 +17,56 @@ import static org.junit.jupiter.api.Assertions.assertNotEquals;
* @author baldersheim
*/
public class TensorAddressTestCase {
- private void equal(Object a, Object b) {
+ public static void equal(TensorAddress a, TensorAddress b) {
assertEquals(a.hashCode(), b.hashCode());
assertEquals(a, b);
+ assertEquals(a.size(), b.size());
+ for (int i = 0; i < a.size(); i++) {
+ assertEquals(a.label(i), b.label(i));
+ assertEquals(a.numericLabel(i), b.numericLabel(i));
+ }
}
- private void notEqual(Object a, Object b) {
+ public static void notEqual(TensorAddress a, TensorAddress b) {
assertNotEquals(a.hashCode(), b.hashCode()); // This might not hold, but is bad if not very rare
assertNotEquals(a, b);
}
@Test
void testStringVersusNumericAddressEquality() {
- equal(TensorAddress.ofLabels("1"), TensorAddress.of(1));
+ equal(ofLabels("1"), of(1));
}
@Test
void testInEquality() {
- notEqual(TensorAddress.ofLabels("1"), TensorAddress.ofLabels("2"));
- notEqual(TensorAddress.of(1), TensorAddress.of(2));
+ notEqual(ofLabels("1"), ofLabels("2"));
+ notEqual(of(1), of(2));
}
@Test
void testDimensionsEffectsEqualityAndHash() {
- notEqual(TensorAddress.ofLabels("1"), TensorAddress.ofLabels("1", "1"));
- notEqual(TensorAddress.of(1), TensorAddress.of(1, 1));
+ notEqual(ofLabels("1"), ofLabels("1", "1"));
+ notEqual(of(1), of(1, 1));
}
@Test
void testAllowNullDimension() {
- TensorAddress s1 = TensorAddress.ofLabels("1", null, "2");
- TensorAddress s2 = TensorAddress.ofLabels("1", "2");
+ TensorAddress s1 = ofLabels("1", null, "2");
+ TensorAddress s2 = ofLabels("1", "2");
assertNotEquals(s1, s2);
- assertEquals(s1.hashCode(), s2.hashCode());
+ assertEquals(-1, s1.numericLabel(1));
+ assertEquals(null, s1.label(1));
+ }
+
+ private static void verifyWithLabel(int dimensions) {
+ int [] indexes = new int[dimensions];
+ Arrays.fill(indexes, 1);
+ TensorAddress next = of(indexes);
+ for (int i = 0; i < dimensions; i++) {
+ indexes[i] = 3;
+ assertEquals(of(indexes), next = next.withLabel(i, 3));
+ }
}
+ @Test
+ void testWithLabel() {
+ for (int i=0; i < 10; i++) {
+ verifyWithLabel(i);
+ }
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 74237a218fb..91880c9af93 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -73,7 +73,7 @@ public class TensorFunctionBenchmark {
for (int i = 0; i < vectorCount; i++) {
Tensor.Builder builder = Tensor.Builder.of(type);
for (int j = 0; j < vectorSize; j++) {
- builder.cell().label("x", String.valueOf(j)).value(random.nextDouble());
+ builder.cell().label("x", j).value(random.nextDouble());
}
tensors.add(builder.build());
}
@@ -88,8 +88,8 @@ public class TensorFunctionBenchmark {
for (int i = 0; i < vectorCount; i++) {
for (int j = 0; j < vectorSize; j++) {
builder.cell()
- .label("i", String.valueOf(i))
- .label("x", String.valueOf(j))
+ .label("i", i)
+ .label("x", j)
.value(random.nextDouble());
}
}
@@ -110,6 +110,7 @@ public class TensorFunctionBenchmark {
double time = 0;
// ---------------- Indexed unbound:
+
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time);
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
@@ -132,6 +133,7 @@ public class TensorFunctionBenchmark {
// ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations):
time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time);
+
time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time);
@@ -143,16 +145,16 @@ public class TensorFunctionBenchmark {
System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time);
/* 2.4Ghz Intel Core i9, Macbook Pro 2019
- * Indexed unbound vectors, time per join: 0,067 ms
- * Indexed unbound matrix, time per join: 0,107 ms
- * Indexed bound vectors, time per join: 0,068 ms
- * Indexed bound matrix, time per join: 0,105 ms
- * Mapped vectors, time per join: 1,342 ms
- * Mapped matrix, time per join: 3,448 ms
- * Indexed vectors, x space time per join: 6,398 ms
- * Indexed matrix, x space time per join: 3,220 ms
- * Mapped vectors, x space time per join: 14,984 ms
- * Mapped matrix, x space time per join: 19,873 ms
+ Indexed unbound vectors, time per join: 0,066 ms
+ Indexed unbound matrix, time per join: 0,108 ms
+ Indexed bound vectors, time per join: 0,068 ms
+ Indexed bound matrix, time per join: 0,106 ms
+ Mapped vectors, time per join: 0,845 ms
+ Mapped matrix, time per join: 1,779 ms
+ Indexed vectors, x space time per join: 5,778 ms
+ Indexed matrix, x space time per join: 3,342 ms
+ Mapped vectors, x space time per join: 8,184 ms
+ Mapped matrix, x space time per join: 11,547 ms
*/
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 7cf0bd35b38..85619dca16c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -33,7 +33,7 @@ public class DynamicTensorTestCase {
public void testDynamicMappedRank1TensorFunction() {
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor<Name> t2 = DynamicTensor.from(sparse,
- Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
+ java.util.Map.of(new TensorAddress.Builder(sparse).add("x", "a").build(),
new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java
new file mode 100644
index 00000000000..ae13b95052b
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/impl/TensorAddressAnyTestCase.java
@@ -0,0 +1,31 @@
+package com.yahoo.tensor.impl;
+
+import static com.yahoo.tensor.impl.TensorAddressAny.of;
+import static com.yahoo.tensor.TensorAddressTestCase.equal;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import org.junit.jupiter.api.Test;
+
+public class TensorAddressAnyTestCase {
+ @Test
+ void testSize() {
+ for (int i = 0; i < 10; i++) {
+ int [] indexes = new int [i];
+ assertEquals(i, of(indexes).size());
+ }
+ }
+
+ @Test
+ void testNumericStringEquality() {
+ for (int i = 0; i < 10; i++) {
+ int [] numericIndexes = new int [i];
+ String [] stringIndexes = new String[i];
+ for (int j = 0; j < i; j++) {
+ numericIndexes[j] = j;
+ stringIndexes[j] = String.valueOf(j);
+ }
+ equal(of(stringIndexes), of(numericIndexes));
+ }
+ }
+
+}