aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java395
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java81
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java4
27 files changed, 435 insertions, 213 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index c4588b79fa9..ca396ae5bf2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -355,6 +355,10 @@ public interface Tensor {
@Override
boolean equals(Object o);
+ /** Returns a hash computed deterministically from the content of this tensor */
+ @Override
+ int hashCode();
+
/**
* Implement here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index dbc8396d701..8a9a85d343c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.Optional;
/**
@@ -62,6 +63,9 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
return name;
}
+ @Override
+ public int hashCode() { return Objects.hash("variableTensor", name, requiredType); }
+
private void verifyType(TensorType givenType) {
if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 55dd8a7bc8a..d2762ad762d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -52,4 +52,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("argmax", argument, dimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index f1f0b9d67b0..baedf41bcb8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -52,4 +52,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("argmin", argument, dimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index 09f84e6747e..176847cec93 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -111,4 +111,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("cellcast", argument, valueType); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 6d4b15be991..abf0d89c2b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -31,6 +31,191 @@ import java.util.stream.Collectors;
*/
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+ enum DimType { common, separate, concat }
+
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
+ private final String dimension;
+
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(dimension, "The dimension cannot be null");
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 2)
+ throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
+ return oldEvaluate(a, b);
+ }
+ var helper = new Helper(a, b, dimension);
+ return helper.result;
+ }
+
+ private Tensor oldEvaluate(Tensor a, Tensor b) {
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
+
+ IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
+ IndexedTensor bIndexed = (IndexedTensor) b;
+ DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+
+ Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ int[] aToIndexes = mapIndexes(a.type(), concatType);
+ int[] bToIndexes = mapIndexes(b.type(), concatType);
+ concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
+ concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
+ return builder.build();
+ }
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
+ Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
+ IndexedTensor.SubspaceIterator iaSubspace = ia.next();
+ TensorAddress aAddress = iaSubspace.address();
+ for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
+ IndexedTensor.SubspaceIterator ibSubspace = ib.next();
+ while (ibSubspace.hasNext()) {
+ Tensor.Cell bCell = ibSubspace.next();
+ TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
+ concatType, offset, dimension);
+ if (combinedAddress == null) continue; // incompatible
+
+ builder.cell(combinedAddress, bCell.getValue());
+ }
+ iaSubspace.reset();
+ }
+ }
+ }
+
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
+ Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
+ if ( dimension.isPresent() ) {
+ if ( ! dimension.get().isIndexed())
+ throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
+ "' requires that dimension to be indexed or absent, " +
+ "but got a tensor with type " + tensor.type());
+ return tensor;
+ }
+ else { // extend tensor with this dimension
+ if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ throw new IllegalArgumentException("Concat requires an indexed tensor, " +
+ "but got a tensor with type " + tensor.type());
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
+ return tensor.multiply(unitTensor);
+ }
+
+ }
+
+ /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
+ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < concatSizes.dimensions(); i++) {
+ String currentDimension = concatType.dimensions().get(i).name();
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
+ if (currentDimension.equals(concatDimension))
+ concatSizes.set(i, aSize + bSize);
+ else if (aSize != 0 && bSize != 0 && aSize!=bSize )
+ concatSizes.set(i, Math.min(aSize, bSize));
+ else
+ concatSizes.set(i, Math.max(aSize, bSize));
+ }
+ return concatSizes.build();
+ }
+
+ /**
+ * Combine two addresses, adding the offset to the concat dimension
+ *
+ * @return the combined address or null if the addresses are incompatible
+ * (in some other dimension than the concat dimension)
+ */
+ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
+ Arrays.fill(combinedLabels, -1);
+ int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
+ mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ if ( ! compatible) return null;
+ return TensorAddress.of(combinedLabels);
+ }
+
+ /**
+ * Returns the an array having one entry in order for each dimension of fromType
+ * containing the index at which toType contains the same dimension name.
+ * That is, if the returned array contains n at index i then
+ * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
+ * If some dimension in fromType is not present in toType, the corresponding index will be -1
+ */
+ // TODO: Stolen from join
+ private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ int[] toIndexes = new int[fromType.dimensions().size()];
+ for (int i = 0; i < fromType.dimensions().size(); i++)
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ return toIndexes;
+ }
+
+ /**
+ * Maps the content in the given list to the given array, using the given index map.
+ *
+ * @return true if the mapping was successful, false if one of the destination positions was
+ * occupied by a different value
+ */
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
+ for (int i = 0; i < from.size(); i++) {
+ int toIndex = indexMap[i];
+ if (concatDimension == toIndex) {
+ to[toIndex] = from.numericLabel(i) + concatOffset;
+ }
+ else {
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
+ }
+ }
+ return true;
+ }
+
static class CellVector {
ArrayList<Double> values = new ArrayList<>();
void setValue(int ccDimIndex, double value) {
@@ -57,8 +242,6 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
- enum DimType { common, separate, concat }
-
static class SplitHow {
List<DimType> handleDims = new ArrayList<>();
long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); }
@@ -76,7 +259,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
enum CombineHow { left, right, both, concat }
List<CombineHow> combineHow = new ArrayList<>();
-
+
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoA.handleDims.add(DimType.concat);
@@ -160,8 +343,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
static int concatDimensionSize(CellVectorMapMap data) {
Set<Integer> sizes = new HashSet<>();
data.map.forEach((m, cvmap) ->
- cvmap.map.forEach((e, vector) ->
- sizes.add(vector.values.size())));
+ cvmap.map.forEach((e, vector) ->
+ sizes.add(vector.values.size())));
if (sizes.isEmpty()) {
return 1;
}
@@ -207,17 +390,17 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
var lhs = entry.getValue();
var rhs = b.map.get(common);
lhs.map.forEach((leftOnly, leftCells) -> {
- rhs.map.forEach((rightOnly, rightCells) -> {
- for (int i = 0; i < leftCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i);
- builder.cell(addr, leftCells.values.get(i));
- }
- for (int i = 0; i < rightCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
- builder.cell(addr, rightCells.values.get(i));
- }
- });
+ rhs.map.forEach((rightOnly, rightCells) -> {
+ for (int i = 0; i < leftCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i);
+ builder.cell(addr, leftCells.values.get(i));
+ }
+ for (int i = 0; i < rightCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
+ builder.cell(addr, rightCells.values.get(i));
+ }
});
+ });
}
}
return builder.build();
@@ -240,7 +423,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
commonLabels[commonIdx++] = addr.label(i);
break;
case separate:
- separateLabels[separateIdx++] = addr.label(i);
+ separateLabels[separateIdx++] = addr.label(i);
break;
case concat:
ccDimIndex = addr.numericLabel(i);
@@ -257,184 +440,4 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- private final TensorFunction<NAMETYPE> argumentA, argumentB;
- private final String dimension;
-
- public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
- Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
- Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
- Objects.requireNonNull(dimension, "The dimension cannot be null");
- this.argumentA = argumentA;
- this.argumentB = argumentB;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
-
- @Override
- public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 2)
- throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat<>(arguments.get(0), arguments.get(1), dimension);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
- return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
- }
-
- @Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
- }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor a = argumentA.evaluate(context);
- Tensor b = argumentB.evaluate(context);
- if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
- return oldEvaluate(a, b);
- }
- var helper = new Helper(a, b, dimension);
- return helper.result;
- }
-
- private Tensor oldEvaluate(Tensor a, Tensor b) {
- TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
-
- a = ensureIndexedDimension(dimension, a, concatType.valueType());
- b = ensureIndexedDimension(dimension, b, concatType.valueType());
-
- IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
- IndexedTensor bIndexed = (IndexedTensor) b;
- DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
-
- Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
- int[] aToIndexes = mapIndexes(a.type(), concatType);
- int[] bToIndexes = mapIndexes(b.type(), concatType);
- concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
- concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
- return builder.build();
- }
-
- private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
- int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
- Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
- IndexedTensor.SubspaceIterator iaSubspace = ia.next();
- TensorAddress aAddress = iaSubspace.address();
- for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
- IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- while (ibSubspace.hasNext()) {
- Tensor.Cell bCell = ibSubspace.next();
- TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
- concatType, offset, dimension);
- if (combinedAddress == null) continue; // incompatible
-
- builder.cell(combinedAddress, bCell.getValue());
- }
- iaSubspace.reset();
- }
- }
- }
-
- private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
- Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
- if ( dimension.isPresent() ) {
- if ( ! dimension.get().isIndexed())
- throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
- "' requires that dimension to be indexed or absent, " +
- "but got a tensor with type " + tensor.type());
- return tensor;
- }
- else { // extend tensor with this dimension
- if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
- throw new IllegalArgumentException("Concat requires an indexed tensor, " +
- "but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
- .indexed(dimensionName, 1)
- .build())
- .cell(1,0)
- .build();
- return tensor.multiply(unitTensor);
- }
-
- }
-
- /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
- private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
- for (int i = 0; i < concatSizes.dimensions(); i++) {
- String currentDimension = concatType.dimensions().get(i).name();
- long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
- long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
- if (currentDimension.equals(concatDimension))
- concatSizes.set(i, aSize + bSize);
- else if (aSize != 0 && bSize != 0 && aSize!=bSize )
- concatSizes.set(i, Math.min(aSize, bSize));
- else
- concatSizes.set(i, Math.max(aSize, bSize));
- }
- return concatSizes.build();
- }
-
- /**
- * Combine two addresses, adding the offset to the concat dimension
- *
- * @return the combined address or null if the addresses are incompatible
- * (in some other dimension than the concat dimension)
- */
- private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, long concatOffset, String concatDimension) {
- long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
- int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
- mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
- boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
- if ( ! compatible) return null;
- return TensorAddress.of(combinedLabels);
- }
-
- /**
- * Returns the an array having one entry in order for each dimension of fromType
- * containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
- * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
- * If some dimension in fromType is not present in toType, the corresponding index will be -1
- */
- // TODO: Stolen from join
- private int[] mapIndexes(TensorType fromType, TensorType toType) {
- int[] toIndexes = new int[fromType.dimensions().size()];
- for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
- return toIndexes;
- }
-
- /**
- * Maps the content in the given list to the given array, using the given index map.
- *
- * @return true if the mapping was successful, false if one of the destination positions was
- * occupied by a different value
- */
- private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
- for (int i = 0; i < from.size(); i++) {
- int toIndex = indexMap[i];
- if (concatDimension == toIndex) {
- to[toIndex] = from.numericLabel(i) + concatOffset;
- }
- else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
- to[toIndex] = from.numericLabel(i);
- }
- }
- return true;
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index a0fd9272f54..92a72dfd280 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* A function which returns a constant tensor.
@@ -49,4 +50,9 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
@Override
public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); }
+ @Override
+ public int hashCode() {
+ return Objects.hash("constant", constant.hashCode());
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 92d89ec68f7..7218375de89 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -40,13 +41,16 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP
return new Generate<>(type, diagFunction);
}
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::name);
+ }
+
@Override
public String toString(ToStringContext<NAMETYPE> context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
- private Stream<String> dimensionNames() {
- return type.dimensions().stream().map(TensorType.Dimension::name);
- }
+ @Override
+ public int hashCode() { return Objects.hash("diag", type, diagFunction); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 46992115c23..c402a1bde5b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
@@ -45,13 +46,13 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
TensorType type() { return type; }
+ abstract String contentToString(ToStringContext<NAMETYPE> context);
+
@Override
public String toString(ToStringContext<NAMETYPE> context) {
return type().toString() + ":" + contentToString(context);
}
- abstract String contentToString(ToStringContext<NAMETYPE> context);
-
/** Creates a dynamic tensor function. The cell addresses must match the type. */
public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
return new MappedDynamicTensor<>(type, cells);
@@ -98,6 +99,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("mappedDynamicTensor", type(), cells); }
+
}
private static class IndexedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> {
@@ -141,6 +145,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("indexedDynamicTensor", type(), cells); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
index c049e5d41da..eee037c8dba 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* The <i>expand</i> tensor function returns a tensor with a new dimension of
@@ -45,4 +46,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "expand(" + argument.toString(context) + ", " + dimensionName + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("expand", argument, dimensionName); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 54e83fa472f..3ad3e1114cc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -126,6 +126,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
return boundGenerator.toString(new GenerateToStringContext(context));
}
+ @Override
+ public int hashCode() { return Objects.hash("generate", type, freeGenerator, boundGenerator); }
+
/**
* A context for generating all the values of a tensor produced by evaluating Generate.
* This returns all the current index values as variables and falls back to delivering from the given
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 52bef482fb4..4ec5b196dbc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -80,6 +80,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
@Override
+ public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); }
+
+ @Override
public TensorType type(TypeContext<NAMETYPE> context) {
return outputType(argumentA.type(context), argumentB.type(context));
}
@@ -356,7 +359,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return builder.build();
}
-
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index f47202d1b9f..38cc95ac6b2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -43,4 +44,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("l1_normalize", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 8f4e2f466d4..4a676449657 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -45,4 +46,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("l2_normalize", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 46772d8cbff..68645546be9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -75,4 +75,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE
return "map(" + argument.toString(context) + ", " + mapper + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("map", argument, mapper); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 8ac6d711c48..3239ab1a70c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.Name;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -49,4 +50,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("matmul", argument1, argument2, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index adc84225a63..2b9dc709e0e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -70,11 +70,6 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
- }
-
- @Override
public TensorType type(TypeContext<NAMETYPE> context) {
return outputType(argumentA.type(context), argumentB.type(context));
}
@@ -87,6 +82,15 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return evaluate(a, b, mergedType, merger);
}
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("merge", argumentA, argumentB, merger); }
+
static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
// Choose merge algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 18c5db8e3a7..34b8eba3e67 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -42,6 +43,9 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("random", type); }
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 45b827db900..7053eeb0a1c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -50,4 +51,9 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("range", type, rangeFunction);
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 8841cff15e9..96465de6c0f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -107,6 +107,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return evaluate(this.argument.evaluate(context), dimensions, aggregator);
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("reduce", argument, dimensions, aggregator);
+ }
+
static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
@@ -191,6 +196,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
/** Resets the aggregator */
public abstract void reset();
+ /** Returns a hash of this aggregator which only depends on its identity */
+ @Override
+ public abstract int hashCode();
+
}
private static class AvgAggregator extends ValueAggregator {
@@ -214,6 +223,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
valueCount = 0;
valueSum = 0.0;
}
+
+ @Override
+ public int hashCode() { return "avgAggregator".hashCode(); }
+
}
private static class CountAggregator extends ValueAggregator {
@@ -234,6 +247,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueCount = 0;
}
+
+ @Override
+ public int hashCode() { return "countAggregator".hashCode(); }
+
}
private static class MaxAggregator extends ValueAggregator {
@@ -255,6 +272,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
maxValue = Double.NEGATIVE_INFINITY;
}
+
+ @Override
+ public int hashCode() { return "maxAggregator".hashCode(); }
+
}
private static class MedianAggregator extends ValueAggregator {
@@ -288,6 +309,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
values = new ArrayList<>();
}
+ @Override
+ public int hashCode() { return "medianAggregator".hashCode(); }
+
}
private static class MinAggregator extends ValueAggregator {
@@ -310,6 +334,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
minValue = Double.POSITIVE_INFINITY;
}
+ @Override
+ public int hashCode() { return "minAggregator".hashCode(); }
+
}
private static class ProdAggregator extends ValueAggregator {
@@ -330,6 +357,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueProd = 1.0;
}
+
+ @Override
+ public int hashCode() { return "prodAggregator".hashCode(); }
+
}
private static class SumAggregator extends ValueAggregator {
@@ -350,6 +381,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueSum = 0.0;
}
+
+ @Override
+ public int hashCode() { return "sumAggregator".hashCode(); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 7505355beed..ccb437ef5a7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Arrays;
import java.util.List;
+import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
@@ -322,6 +323,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("reduce_join", argumentA, argumentB, combinator, aggregator, dimensions);
+ }
+
private static class MultiDimensionIterator {
private final long[] bounds;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index a434ecba5cc..023e91e424f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -127,12 +127,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return TensorAddress.of(reorderedLabels);
}
- @Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "rename(" + argument.toString(context) + ", " +
- toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
@@ -144,4 +138,13 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return b.toString();
}
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "rename(" + argument.toString(context) + ", " +
+ toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("rename", argument, fromDimensions, toDimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 517f6683cbf..2639e153923 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import java.util.Comparator;
import java.util.List;
+import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
@@ -75,6 +76,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left + right; }
@Override
public String toString() { return "f(a,b)(a + b)"; }
+ @Override
+ public int hashCode() { return "add".hashCode(); }
}
public static class Equal implements DoubleBinaryOperator {
@@ -82,6 +85,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a==b)"; }
+ @Override
+ public int hashCode() { return "equal".hashCode(); }
}
public static class Greater implements DoubleBinaryOperator {
@@ -89,6 +94,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a > b)"; }
+ @Override
+ public int hashCode() { return "greater".hashCode(); }
}
public static class Less implements DoubleBinaryOperator {
@@ -96,6 +103,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a < b)"; }
+ @Override
+ public int hashCode() { return "less".hashCode(); }
}
public static class Max implements DoubleBinaryOperator {
@@ -103,6 +112,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@Override
public String toString() { return "f(a,b)(max(a, b))"; }
+ @Override
+ public int hashCode() { return "max".hashCode(); }
}
public static class Min implements DoubleBinaryOperator {
@@ -110,6 +121,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.min(left, right); }
@Override
public String toString() { return "f(a,b)(min(a, b))"; }
+ @Override
+ public int hashCode() { return "min".hashCode(); }
}
public static class Mean implements DoubleBinaryOperator {
@@ -117,6 +130,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return (left + right) / 2; }
@Override
public String toString() { return "f(a,b)((a + b) / 2)"; }
+ @Override
+ public int hashCode() { return "mean".hashCode(); }
}
public static class Multiply implements DoubleBinaryOperator {
@@ -124,6 +139,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
+ @Override
+ public int hashCode() { return "multiply".hashCode(); }
}
public static class Pow implements DoubleBinaryOperator {
@@ -131,6 +148,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
@Override
public String toString() { return "f(a,b)(pow(a, b))"; }
+ @Override
+ public int hashCode() { return "pow".hashCode(); }
}
public static class Divide implements DoubleBinaryOperator {
@@ -138,6 +157,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left / right; }
@Override
public String toString() { return "f(a,b)(a / b)"; }
+ @Override
+ public int hashCode() { return "divide".hashCode(); }
}
public static class SquaredDifference implements DoubleBinaryOperator {
@@ -145,6 +166,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return (left - right) * (left - right); }
@Override
public String toString() { return "f(a,b)((a-b) * (a-b))"; }
+ @Override
+ public int hashCode() { return "squareddifference".hashCode(); }
}
public static class Subtract implements DoubleBinaryOperator {
@@ -152,6 +175,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left - right; }
@Override
public String toString() { return "f(a,b)(a - b)"; }
+ @Override
+ public int hashCode() { return "subtract".hashCode(); }
}
@@ -172,6 +197,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return hamming(left, right); }
@Override
public String toString() { return "f(a,b)(hamming(a,b))"; }
+ @Override
+ public int hashCode() { return "hamming".hashCode(); }
}
@@ -182,6 +209,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.abs(operand); }
@Override
public String toString() { return "f(a)(fabs(a))"; }
+ @Override
+ public int hashCode() { return "abs".hashCode(); }
}
public static class Acos implements DoubleUnaryOperator {
@@ -189,6 +218,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
public String toString() { return "f(a)(acos(a))"; }
+ @Override
+ public int hashCode() { return "acos".hashCode(); }
}
public static class Asin implements DoubleUnaryOperator {
@@ -196,6 +227,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.asin(operand); }
@Override
public String toString() { return "f(a)(asin(a))"; }
+ @Override
+ public int hashCode() { return "asin".hashCode(); }
}
public static class Atan implements DoubleUnaryOperator {
@@ -203,6 +236,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.atan(operand); }
@Override
public String toString() { return "f(a)(atan(a))"; }
+ @Override
+ public int hashCode() { return "atan".hashCode(); }
}
public static class Ceil implements DoubleUnaryOperator {
@@ -210,6 +245,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.ceil(operand); }
@Override
public String toString() { return "f(a)(ceil(a))"; }
+ @Override
+ public int hashCode() { return "ceil".hashCode(); }
}
public static class Cos implements DoubleUnaryOperator {
@@ -217,6 +254,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.cos(operand); }
@Override
public String toString() { return "f(a)(cos(a))"; }
+ @Override
+ public int hashCode() { return "cos".hashCode(); }
}
public static class Elu implements DoubleUnaryOperator {
@@ -231,6 +270,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
@Override
public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
+ @Override
+ public int hashCode() { return Objects.hash("elu", alpha); }
}
public static class Exp implements DoubleUnaryOperator {
@@ -238,6 +279,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
public String toString() { return "f(a)(exp(a))"; }
+ @Override
+ public int hashCode() { return "exp".hashCode(); }
}
public static class Floor implements DoubleUnaryOperator {
@@ -245,6 +288,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.floor(operand); }
@Override
public String toString() { return "f(a)(floor(a))"; }
+ @Override
+ public int hashCode() { return "floor".hashCode(); }
}
public static class Log implements DoubleUnaryOperator {
@@ -252,6 +297,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.log(operand); }
@Override
public String toString() { return "f(a)(log(a))"; }
+ @Override
+ public int hashCode() { return "log".hashCode(); }
}
public static class Neg implements DoubleUnaryOperator {
@@ -259,6 +306,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return -operand; }
@Override
public String toString() { return "f(a)(-a)"; }
+ @Override
+ public int hashCode() { return "neg".hashCode(); }
}
public static class Reciprocal implements DoubleUnaryOperator {
@@ -266,6 +315,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / operand; }
@Override
public String toString() { return "f(a)(1 / a)"; }
+ @Override
+ public int hashCode() { return "reciprocal".hashCode(); }
}
public static class Relu implements DoubleUnaryOperator {
@@ -273,6 +324,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.max(operand, 0); }
@Override
public String toString() { return "f(a)(max(0, a))"; }
+ @Override
+ public int hashCode() { return "relu".hashCode(); }
}
public static class Selu implements DoubleUnaryOperator {
@@ -290,6 +343,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
@Override
public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; }
+ @Override
+ public int hashCode() { return Objects.hash("selu", scale, alpha); }
}
public static class LeakyRelu implements DoubleUnaryOperator {
@@ -304,6 +359,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
@Override
public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
+ @Override
+ public int hashCode() { return Objects.hash("leakyrelu", alpha); }
}
public static class Sin implements DoubleUnaryOperator {
@@ -311,6 +368,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.sin(operand); }
@Override
public String toString() { return "f(a)(sin(a))"; }
+ @Override
+ public int hashCode() { return "sin".hashCode(); }
}
public static class Rsqrt implements DoubleUnaryOperator {
@@ -318,6 +377,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@Override
public String toString() { return "f(a)(1.0 / sqrt(a))"; }
+ @Override
+ public int hashCode() { return "rsqrt".hashCode(); }
}
public static class Sigmoid implements DoubleUnaryOperator {
@@ -325,6 +386,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); }
@Override
public String toString() { return "f(a)(1 / (1 + exp(-a)))"; }
+ @Override
+ public int hashCode() { return "sigmoid".hashCode(); }
}
public static class Sqrt implements DoubleUnaryOperator {
@@ -332,6 +395,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
@Override
public String toString() { return "f(a)(sqrt(a))"; }
+ @Override
+ public int hashCode() { return "sqrt".hashCode(); }
}
public static class Square implements DoubleUnaryOperator {
@@ -339,6 +404,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return operand * operand; }
@Override
public String toString() { return "f(a)(a * a)"; }
+ @Override
+ public int hashCode() { return "square".hashCode(); }
}
public static class Tan implements DoubleUnaryOperator {
@@ -346,6 +413,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.tan(operand); }
@Override
public String toString() { return "f(a)(tan(a))"; }
+ @Override
+ public int hashCode() { return "tan".hashCode(); }
}
public static class Tanh implements DoubleUnaryOperator {
@@ -353,6 +422,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.tanh(operand); }
@Override
public String toString() { return "f(a)(tanh(a))"; }
+ @Override
+ public int hashCode() { return "tanh".hashCode(); }
}
public static class Erf implements DoubleUnaryOperator {
@@ -410,6 +481,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return erf(operand); }
@Override
public String toString() { return "f(a)(erf(a))"; }
+ @Override
+ public int hashCode() { return "erf".hashCode(); }
static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);
@@ -464,6 +537,8 @@ public class ScalarFunctions {
}
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("equal", argumentNames); }
}
public static class Random implements Function<List<Long>, Double> {
@@ -473,6 +548,8 @@ public class ScalarFunctions {
}
@Override
public String toString() { return "random"; }
+ @Override
+ public int hashCode() { return "random".hashCode(); }
}
public static class SumElements implements Function<List<Long>, Double> {
@@ -492,6 +569,8 @@ public class ScalarFunctions {
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
+ @Override
+ public int hashCode() { return Objects.hash("sum", argumentNames); }
}
public static class Constant implements Function<List<Long>, Double> {
@@ -506,6 +585,8 @@ public class ScalarFunctions {
}
@Override
public String toString() { return Double.toString(value); }
+ @Override
+ public int hashCode() { return Objects.hash("constant", value); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index e3464255fac..39bddc3a3cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -166,6 +166,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); }
+
public static class DimensionValue<NAMETYPE extends Name> {
private final Optional<String> dimension;
@@ -255,6 +258,10 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return index.toString(context);
}
+ @Override
+ public int hashCode() { return Objects.hash(dimension, label, index); }
+
+
}
private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
@@ -273,6 +280,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public String toString() { return String.valueOf(value); }
+ @Override
+ public int hashCode() { return Objects.hash("constantIntegerFunction", value); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index 9ea9040831b..df8cd6d39cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -50,4 +51,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("softmax", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 1e1d1d3b5b9..503f414d8eb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -68,4 +68,8 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
@Override
public String toString() { return toString(ToStringContext.empty()); }
+ /** Returns a hashcode computed from the data in this */
+ @Override
+ public abstract int hashCode();
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 0223ad4d588..bd4fc7b8336 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.Name;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -51,4 +52,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("xwplusb", x, w, b, dimension); }
+
}