diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-03-01 10:39:52 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-03-01 10:39:52 +0100 |
commit | 05ab2e976349eb3016fa91020e161a8782bf00a5 (patch) | |
tree | d570863bbd636ddf908bf1d875efd21e5cbf9056 /vespajlib/src/main/java | |
parent | 0e1e603359c9018cea86d1716903c3ce365e529e (diff) |
Compute hash without serializing to string
Diffstat (limited to 'vespajlib/src/main/java')
27 files changed, 435 insertions, 213 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index c4588b79fa9..ca396ae5bf2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -355,6 +355,10 @@ public interface Tensor { @Override boolean equals(Object o); + /** Returns a hash computed deterministically from the content of this tensor */ + @Override + int hashCode(); + /** * Implement here to make this work across implementations. * Implementations must override equals and call this because this is an interface and cannot override equals. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index dbc8396d701..8a9a85d343c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.ToStringContext; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.Optional; /** @@ -62,6 +63,9 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti return name; } + @Override + public int hashCode() { return Objects.hash("variableTensor", name, requiredType); } + private void verifyType(TensorType givenType) { if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get())) throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 55dd8a7bc8a..d2762ad762d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -52,4 +52,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } + @Override + public int hashCode() { return Objects.hash("argmax", argument, dimensions); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index f1f0b9d67b0..baedf41bcb8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -52,4 +52,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } + @Override + public int hashCode() { return Objects.hash("argmin", argument, dimensions); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index 09f84e6747e..176847cec93 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -111,4 +111,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; } + @Override + public int hashCode() { return Objects.hash("cellcast", argument, valueType); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 6d4b15be991..abf0d89c2b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -31,6 +31,191 @@ import java.util.stream.Collectors; */ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + enum DimType { common, separate, concat } + + private final TensorFunction<NAMETYPE> argumentA, argumentB; + private final String dimension; + + public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { + Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); + Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); + Objects.requireNonNull(dimension, "The dimension cannot be null"); + this.argumentA = argumentA; + this.argumentB = argumentB; + this.dimension = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 2) + throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); + return new Concat<>(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; + } + + @Override + public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + if (a instanceof IndexedTensor && b instanceof IndexedTensor) { + return oldEvaluate(a, b); + } + var helper = new Helper(a, b, dimension); + return helper.result; + } + + private Tensor oldEvaluate(Tensor a, Tensor b) { + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); + + a = ensureIndexedDimension(dimension, a, concatType.valueType()); + b = ensureIndexedDimension(dimension, b, concatType.valueType()); + + IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor + IndexedTensor bIndexed = (IndexedTensor) b; + DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); + + Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + int[] aToIndexes = mapIndexes(a.type(), concatType); + int[] bToIndexes = mapIndexes(b.type(), concatType); + concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); + concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); + return builder.build(); + } + + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, + int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { + Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { + IndexedTensor.SubspaceIterator iaSubspace = ia.next(); + TensorAddress aAddress = iaSubspace.address(); + for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { + IndexedTensor.SubspaceIterator ibSubspace = ib.next(); + while (ibSubspace.hasNext()) { + Tensor.Cell bCell = ibSubspace.next(); + TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, + concatType, offset, dimension); + if (combinedAddress == null) continue; // incompatible + + builder.cell(combinedAddress, bCell.getValue()); + } + iaSubspace.reset(); + } + } + } + + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { + Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); + if ( dimension.isPresent() ) { + if ( ! dimension.get().isIndexed()) + throw new IllegalArgumentException("Concat in dimension '" + dimensionName + + "' requires that dimension to be indexed or absent, " + + "but got a tensor with type " + tensor.type()); + return tensor; + } + else { // extend tensor with this dimension + if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + throw new IllegalArgumentException("Concat requires an indexed tensor, " + + "but got a tensor with type " + tensor.type()); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); + return tensor.multiply(unitTensor); + } + + } + + /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ + private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { + DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); + for (int i = 0; i < concatSizes.dimensions(); i++) { + String currentDimension = concatType.dimensions().get(i).name(); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); + if (currentDimension.equals(concatDimension)) + concatSizes.set(i, aSize + bSize); + else if (aSize != 0 && bSize != 0 && aSize!=bSize ) + concatSizes.set(i, Math.min(aSize, bSize)); + else + concatSizes.set(i, Math.max(aSize, bSize)); + } + return concatSizes.build(); + } + + /** + * Combine two addresses, adding the offset to the concat dimension + * + * @return the combined address or null if the addresses are incompatible + * (in some other dimension than the concat dimension) + */ + private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; + Arrays.fill(combinedLabels, -1); + int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); + mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension + boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here + if ( ! compatible) return null; + return TensorAddress.of(combinedLabels); + } + + /** + * Returns the an array having one entry in order for each dimension of fromType + * containing the index at which toType contains the same dimension name. + * That is, if the returned array contains n at index i then + * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) + * If some dimension in fromType is not present in toType, the corresponding index will be -1 + */ + // TODO: Stolen from join + private int[] mapIndexes(TensorType fromType, TensorType toType) { + int[] toIndexes = new int[fromType.dimensions().size()]; + for (int i = 0; i < fromType.dimensions().size(); i++) + toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); + return toIndexes; + } + + /** + * Maps the content in the given list to the given array, using the given index map. + * + * @return true if the mapping was successful, false if one of the destination positions was + * occupied by a different value + */ + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { + for (int i = 0; i < from.size(); i++) { + int toIndex = indexMap[i]; + if (concatDimension == toIndex) { + to[toIndex] = from.numericLabel(i) + concatOffset; + } + else { + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); + } + } + return true; + } + static class CellVector { ArrayList<Double> values = new ArrayList<>(); void setValue(int ccDimIndex, double value) { @@ -57,8 +242,6 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } - enum DimType { common, separate, concat } - static class SplitHow { List<DimType> handleDims = new ArrayList<>(); long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); } @@ -76,7 +259,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET enum CombineHow { left, right, both, concat } List<CombineHow> combineHow = new ArrayList<>(); - + void aOnly(String dimName) { if (dimName.equals(concatDimension)) { splitInfoA.handleDims.add(DimType.concat); @@ -160,8 +343,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET static int concatDimensionSize(CellVectorMapMap data) { Set<Integer> sizes = new HashSet<>(); data.map.forEach((m, cvmap) -> - cvmap.map.forEach((e, vector) -> - sizes.add(vector.values.size()))); + cvmap.map.forEach((e, vector) -> + sizes.add(vector.values.size()))); if (sizes.isEmpty()) { return 1; } @@ -207,17 +390,17 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET var lhs = entry.getValue(); var rhs = b.map.get(common); lhs.map.forEach((leftOnly, leftCells) -> { - rhs.map.forEach((rightOnly, rightCells) -> { - for (int i = 0; i < leftCells.values.size(); i++) { - TensorAddress addr = combine(common, leftOnly, rightOnly, i); - builder.cell(addr, leftCells.values.get(i)); - } - for (int i = 0; i < rightCells.values.size(); i++) { - TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize); - builder.cell(addr, rightCells.values.get(i)); - } - }); + rhs.map.forEach((rightOnly, rightCells) -> { + for (int i = 0; i < leftCells.values.size(); i++) { + TensorAddress addr = combine(common, leftOnly, rightOnly, i); + builder.cell(addr, leftCells.values.get(i)); + } + for (int i = 0; i < rightCells.values.size(); i++) { + TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize); + builder.cell(addr, rightCells.values.get(i)); + } }); + }); } } return builder.build(); @@ -240,7 +423,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET commonLabels[commonIdx++] = addr.label(i); break; case separate: - separateLabels[separateIdx++] = addr.label(i); + separateLabels[separateIdx++] = addr.label(i); break; case concat: ccDimIndex = addr.numericLabel(i); @@ -257,184 +440,4 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } } - private final TensorFunction<NAMETYPE> argumentA, argumentB; - private final String dimension; - - public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { - Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); - Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); - Objects.requireNonNull(dimension, "The dimension cannot be null"); - this.argumentA = argumentA; - this.argumentB = argumentB; - this.dimension = dimension; - } - - @Override - public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } - - @Override - public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 2) - throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); - return new Concat<>(arguments.get(0), arguments.get(1), dimension); - } - - @Override - public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { - return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); - } - - @Override - public String toString(ToStringContext<NAMETYPE> context) { - return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; - } - - @Override - public TensorType type(TypeContext<NAMETYPE> context) { - return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); - } - - @Override - public Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor a = argumentA.evaluate(context); - Tensor b = argumentB.evaluate(context); - if (a instanceof IndexedTensor && b instanceof IndexedTensor) { - return oldEvaluate(a, b); - } - var helper = new Helper(a, b, dimension); - return helper.result; - } - - private Tensor oldEvaluate(Tensor a, Tensor b) { - TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); - - a = ensureIndexedDimension(dimension, a, concatType.valueType()); - b = ensureIndexedDimension(dimension, b, concatType.valueType()); - - IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor - IndexedTensor bIndexed = (IndexedTensor) b; - DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); - - Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); - int[] aToIndexes = mapIndexes(a.type(), concatType); - int[] bToIndexes = mapIndexes(b.type(), concatType); - concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); - concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder); - return builder.build(); - } - - private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, - int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { - Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); - for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { - IndexedTensor.SubspaceIterator iaSubspace = ia.next(); - TensorAddress aAddress = iaSubspace.address(); - for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { - IndexedTensor.SubspaceIterator ibSubspace = ib.next(); - while (ibSubspace.hasNext()) { - Tensor.Cell bCell = ibSubspace.next(); - TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, - concatType, offset, dimension); - if (combinedAddress == null) continue; // incompatible - - builder.cell(combinedAddress, bCell.getValue()); - } - iaSubspace.reset(); - } - } - } - - private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { - Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); - if ( dimension.isPresent() ) { - if ( ! dimension.get().isIndexed()) - throw new IllegalArgumentException("Concat in dimension '" + dimensionName + - "' requires that dimension to be indexed or absent, " + - "but got a tensor with type " + tensor.type()); - return tensor; - } - else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) - throw new IllegalArgumentException("Concat requires an indexed tensor, " + - "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) - .indexed(dimensionName, 1) - .build()) - .cell(1,0) - .build(); - return tensor.multiply(unitTensor); - } - - } - - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ - private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { - DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); - for (int i = 0; i < concatSizes.dimensions(); i++) { - String currentDimension = concatType.dimensions().get(i).name(); - long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); - long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); - if (currentDimension.equals(concatDimension)) - concatSizes.set(i, aSize + bSize); - else if (aSize != 0 && bSize != 0 && aSize!=bSize ) - concatSizes.set(i, Math.min(aSize, bSize)); - else - concatSizes.set(i, Math.max(aSize, bSize)); - } - return concatSizes.build(); - } - - /** - * Combine two addresses, adding the offset to the concat dimension - * - * @return the combined address or null if the addresses are incompatible - * (in some other dimension than the concat dimension) - */ - private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, long concatOffset, String concatDimension) { - long[] combinedLabels = new long[concatType.dimensions().size()]; - Arrays.fill(combinedLabels, -1); - int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); - mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension - boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here - if ( ! compatible) return null; - return TensorAddress.of(combinedLabels); - } - - /** - * Returns the an array having one entry in order for each dimension of fromType - * containing the index at which toType contains the same dimension name. - * That is, if the returned array contains n at index i then - * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) - * If some dimension in fromType is not present in toType, the corresponding index will be -1 - */ - // TODO: Stolen from join - private int[] mapIndexes(TensorType fromType, TensorType toType) { - int[] toIndexes = new int[fromType.dimensions().size()]; - for (int i = 0; i < fromType.dimensions().size(); i++) - toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1); - return toIndexes; - } - - /** - * Maps the content in the given list to the given array, using the given index map. - * - * @return true if the mapping was successful, false if one of the destination positions was - * occupied by a different value - */ - private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { - for (int i = 0; i < from.size(); i++) { - int toIndex = indexMap[i]; - if (concatDimension == toIndex) { - to[toIndex] = from.numericLabel(i) + concatOffset; - } - else { - if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; - to[toIndex] = from.numericLabel(i); - } - } - return true; - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index a0fd9272f54..92a72dfd280 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * A function which returns a constant tensor. @@ -49,4 +50,9 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti @Override public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); } + @Override + public int hashCode() { + return Objects.hash("constant", constant.hashCode()); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 92d89ec68f7..7218375de89 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -40,13 +41,16 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP return new Generate<>(type, diagFunction); } + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::name); + } + @Override public String toString(ToStringContext<NAMETYPE> context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } - private Stream<String> dimensionNames() { - return type.dimensions().stream().map(TensorType.Dimension::name); - } + @Override + public int hashCode() { return Objects.hash("diag", type, diagFunction); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 46992115c23..c402a1bde5b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -13,6 +13,7 @@ import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. @@ -45,13 +46,13 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens TensorType type() { return type; } + abstract String contentToString(ToStringContext<NAMETYPE> context); + @Override public String toString(ToStringContext<NAMETYPE> context) { return type().toString() + ":" + contentToString(context); } - abstract String contentToString(ToStringContext<NAMETYPE> context); - /** Creates a dynamic tensor function. The cell addresses must match the type. */ public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { return new MappedDynamicTensor<>(type, cells); @@ -98,6 +99,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return b.toString(); } + @Override + public int hashCode() { return Objects.hash("mappedDynamicTensor", type(), cells); } + } private static class IndexedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> { @@ -141,6 +145,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens return b.toString(); } + @Override + public int hashCode() { return Objects.hash("indexedDynamicTensor", type(), cells); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java index c049e5d41da..eee037c8dba 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * The <i>expand</i> tensor function returns a tensor with a new dimension of @@ -45,4 +46,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; } + @Override + public int hashCode() { return Objects.hash("expand", argument, dimensionName); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 54e83fa472f..3ad3e1114cc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -126,6 +126,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM return boundGenerator.toString(new GenerateToStringContext(context)); } + @Override + public int hashCode() { return Objects.hash("generate", type, freeGenerator, boundGenerator); } + /** * A context for generating all the values of a tensor produced by evaluating Generate. * This returns all the current index values as variables and falls back to delivering from the given diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 52bef482fb4..4ec5b196dbc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -80,6 +80,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } @Override + public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); } + + @Override public TensorType type(TypeContext<NAMETYPE> context) { return outputType(argumentA.type(context), argumentB.type(context)); } @@ -356,7 +359,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP return builder.build(); } - /** * Returns the an array having one entry in order for each dimension of fromType * containing the index at which toType contains the same dimension name. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index f47202d1b9f..38cc95ac6b2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -43,4 +44,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; } + @Override + public int hashCode() { return Objects.hash("l1_normalize", argument, dimension); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 8f4e2f466d4..4a676449657 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -45,4 +46,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; } + @Override + public int hashCode() { return Objects.hash("l2_normalize", argument, dimension); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 46772d8cbff..68645546be9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -75,4 +75,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE return "map(" + argument.toString(context) + ", " + mapper + ")"; } + @Override + public int hashCode() { return Objects.hash("map", argument, mapper); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 8ac6d711c48..3239ab1a70c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.Name; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -49,4 +50,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; } + @Override + public int hashCode() { return Objects.hash("matmul", argument1, argument2, dimension); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index adc84225a63..2b9dc709e0e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -70,11 +70,6 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY } @Override - public String toString(ToStringContext<NAMETYPE> context) { - return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; - } - - @Override public TensorType type(TypeContext<NAMETYPE> context) { return outputType(argumentA.type(context), argumentB.type(context)); } @@ -87,6 +82,15 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return evaluate(a, b, mergedType, merger); } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; + } + + @Override + public int hashCode() { return Objects.hash("merge", argumentA, argumentB, merger); } + static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) { // Choose merge algorithm if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name())) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 18c5db8e3a7..34b8eba3e67 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -42,6 +43,9 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } + @Override + public int hashCode() { return Objects.hash("random", type); } + private Stream<String> dimensionNames() { return type.dimensions().stream().map(TensorType.Dimension::toString); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 45b827db900..7053eeb0a1c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -50,4 +51,9 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY return type.dimensions().stream().map(TensorType.Dimension::toString); } + @Override + public int hashCode() { + return Objects.hash("range", type, rangeFunction); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 8841cff15e9..96465de6c0f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -107,6 +107,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return evaluate(this.argument.evaluate(context), dimensions, aggregator); } + @Override + public int hashCode() { + return Objects.hash("reduce", argument, dimensions, aggregator); + } + static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) { if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + @@ -191,6 +196,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET /** Resets the aggregator */ public abstract void reset(); + /** Returns a hash of this aggregator which only depends on its identity */ + @Override + public abstract int hashCode(); + } private static class AvgAggregator extends ValueAggregator { @@ -214,6 +223,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET valueCount = 0; valueSum = 0.0; } + + @Override + public int hashCode() { return "avgAggregator".hashCode(); } + } private static class CountAggregator extends ValueAggregator { @@ -234,6 +247,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public void reset() { valueCount = 0; } + + @Override + public int hashCode() { return "countAggregator".hashCode(); } + } private static class MaxAggregator extends ValueAggregator { @@ -255,6 +272,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public void reset() { maxValue = Double.NEGATIVE_INFINITY; } + + @Override + public int hashCode() { return "maxAggregator".hashCode(); } + } private static class MedianAggregator extends ValueAggregator { @@ -288,6 +309,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET values = new ArrayList<>(); } + @Override + public int hashCode() { return "medianAggregator".hashCode(); } + } private static class MinAggregator extends ValueAggregator { @@ -310,6 +334,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET minValue = Double.POSITIVE_INFINITY; } + @Override + public int hashCode() { return "minAggregator".hashCode(); } + } private static class ProdAggregator extends ValueAggregator { @@ -330,6 +357,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public void reset() { valueProd = 1.0; } + + @Override + public int hashCode() { return "prodAggregator".hashCode(); } + } private static class SumAggregator extends ValueAggregator { @@ -350,6 +381,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET public void reset() { valueSum = 0.0; } + + @Override + public int hashCode() { return "sumAggregator".hashCode(); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 7505355beed..ccb437ef5a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Arrays; import java.util.List; +import java.util.Objects; import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; @@ -322,6 +323,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N Reduce.commaSeparated(dimensions) + ")"; } + @Override + public int hashCode() { + return Objects.hash("reduce_join", argumentA, argumentB, combinator, aggregator, dimensions); + } + private static class MultiDimensionIterator { private final long[] bounds; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index a434ecba5cc..023e91e424f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -127,12 +127,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return TensorAddress.of(reorderedLabels); } - @Override - public String toString(ToStringContext<NAMETYPE> context) { - return "rename(" + argument.toString(context) + ", " + - toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); @@ -144,4 +138,13 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return b.toString(); } + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "rename(" + argument.toString(context) + ", " + + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; + } + + @Override + public int hashCode() { return Objects.hash("rename", argument, fromDimensions, toDimensions); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 517f6683cbf..2639e153923 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList; import java.util.Comparator; import java.util.List; +import java.util.Objects; import java.util.PriorityQueue; import java.util.concurrent.ThreadLocalRandom; import java.util.function.DoubleBinaryOperator; @@ -75,6 +76,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left + right; } @Override public String toString() { return "f(a,b)(a + b)"; } + @Override + public int hashCode() { return "add".hashCode(); } } public static class Equal implements DoubleBinaryOperator { @@ -82,6 +85,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } @Override public String toString() { return "f(a,b)(a==b)"; } + @Override + public int hashCode() { return "equal".hashCode(); } } public static class Greater implements DoubleBinaryOperator { @@ -89,6 +94,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; } @Override public String toString() { return "f(a,b)(a > b)"; } + @Override + public int hashCode() { return "greater".hashCode(); } } public static class Less implements DoubleBinaryOperator { @@ -96,6 +103,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; } @Override public String toString() { return "f(a,b)(a < b)"; } + @Override + public int hashCode() { return "less".hashCode(); } } public static class Max implements DoubleBinaryOperator { @@ -103,6 +112,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return Math.max(left, right); } @Override public String toString() { return "f(a,b)(max(a, b))"; } + @Override + public int hashCode() { return "max".hashCode(); } } public static class Min implements DoubleBinaryOperator { @@ -110,6 +121,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return Math.min(left, right); } @Override public String toString() { return "f(a,b)(min(a, b))"; } + @Override + public int hashCode() { return "min".hashCode(); } } public static class Mean implements DoubleBinaryOperator { @@ -117,6 +130,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return (left + right) / 2; } @Override public String toString() { return "f(a,b)((a + b) / 2)"; } + @Override + public int hashCode() { return "mean".hashCode(); } } public static class Multiply implements DoubleBinaryOperator { @@ -124,6 +139,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left * right; } @Override public String toString() { return "f(a,b)(a * b)"; } + @Override + public int hashCode() { return "multiply".hashCode(); } } public static class Pow implements DoubleBinaryOperator { @@ -131,6 +148,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return Math.pow(left, right); } @Override public String toString() { return "f(a,b)(pow(a, b))"; } + @Override + public int hashCode() { return "pow".hashCode(); } } public static class Divide implements DoubleBinaryOperator { @@ -138,6 +157,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left / right; } @Override public String toString() { return "f(a,b)(a / b)"; } + @Override + public int hashCode() { return "divide".hashCode(); } } public static class SquaredDifference implements DoubleBinaryOperator { @@ -145,6 +166,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return (left - right) * (left - right); } @Override public String toString() { return "f(a,b)((a-b) * (a-b))"; } + @Override + public int hashCode() { return "squareddifference".hashCode(); } } public static class Subtract implements DoubleBinaryOperator { @@ -152,6 +175,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return left - right; } @Override public String toString() { return "f(a,b)(a - b)"; } + @Override + public int hashCode() { return "subtract".hashCode(); } } @@ -172,6 +197,8 @@ public class ScalarFunctions { public double applyAsDouble(double left, double right) { return hamming(left, right); } @Override public String toString() { return "f(a,b)(hamming(a,b))"; } + @Override + public int hashCode() { return "hamming".hashCode(); } } @@ -182,6 +209,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.abs(operand); } @Override public String toString() { return "f(a)(fabs(a))"; } + @Override + public int hashCode() { return "abs".hashCode(); } } public static class Acos implements DoubleUnaryOperator { @@ -189,6 +218,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.acos(operand); } @Override public String toString() { return "f(a)(acos(a))"; } + @Override + public int hashCode() { return "acos".hashCode(); } } public static class Asin implements DoubleUnaryOperator { @@ -196,6 +227,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.asin(operand); } @Override public String toString() { return "f(a)(asin(a))"; } + @Override + public int hashCode() { return "asin".hashCode(); } } public static class Atan implements DoubleUnaryOperator { @@ -203,6 +236,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.atan(operand); } @Override public String toString() { return "f(a)(atan(a))"; } + @Override + public int hashCode() { return "atan".hashCode(); } } public static class Ceil implements DoubleUnaryOperator { @@ -210,6 +245,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.ceil(operand); } @Override public String toString() { return "f(a)(ceil(a))"; } + @Override + public int hashCode() { return "ceil".hashCode(); } } public static class Cos implements DoubleUnaryOperator { @@ -217,6 +254,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.cos(operand); } @Override public String toString() { return "f(a)(cos(a))"; } + @Override + public int hashCode() { return "cos".hashCode(); } } public static class Elu implements DoubleUnaryOperator { @@ -231,6 +270,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; } @Override public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; } + @Override + public int hashCode() { return Objects.hash("elu", alpha); } } public static class Exp implements DoubleUnaryOperator { @@ -238,6 +279,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.exp(operand); } @Override public String toString() { return "f(a)(exp(a))"; } + @Override + public int hashCode() { return "exp".hashCode(); } } public static class Floor implements DoubleUnaryOperator { @@ -245,6 +288,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.floor(operand); } @Override public String toString() { return "f(a)(floor(a))"; } + @Override + public int hashCode() { return "floor".hashCode(); } } public static class Log implements DoubleUnaryOperator { @@ -252,6 +297,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.log(operand); } @Override public String toString() { return "f(a)(log(a))"; } + @Override + public int hashCode() { return "log".hashCode(); } } public static class Neg implements DoubleUnaryOperator { @@ -259,6 +306,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return -operand; } @Override public String toString() { return "f(a)(-a)"; } + @Override + public int hashCode() { return "neg".hashCode(); } } public static class Reciprocal implements DoubleUnaryOperator { @@ -266,6 +315,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return 1.0 / operand; } @Override public String toString() { return "f(a)(1 / a)"; } + @Override + public int hashCode() { return "reciprocal".hashCode(); } } public static class Relu implements DoubleUnaryOperator { @@ -273,6 +324,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.max(operand, 0); } @Override public String toString() { return "f(a)(max(0, a))"; } + @Override + public int hashCode() { return "relu".hashCode(); } } public static class Selu implements DoubleUnaryOperator { @@ -290,6 +343,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); } @Override public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; } + @Override + public int hashCode() { return Objects.hash("selu", scale, alpha); } } public static class LeakyRelu implements DoubleUnaryOperator { @@ -304,6 +359,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); } @Override public String toString() { return "f(a)(max(" + alpha + " * a, a))"; } + @Override + public int hashCode() { return Objects.hash("leakyrelu", alpha); } } public static class Sin implements DoubleUnaryOperator { @@ -311,6 +368,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.sin(operand); } @Override public String toString() { return "f(a)(sin(a))"; } + @Override + public int hashCode() { return "sin".hashCode(); } } public static class Rsqrt implements DoubleUnaryOperator { @@ -318,6 +377,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } @Override public String toString() { return "f(a)(1.0 / sqrt(a))"; } + @Override + public int hashCode() { return "rsqrt".hashCode(); } } public static class Sigmoid implements DoubleUnaryOperator { @@ -325,6 +386,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); } @Override public String toString() { return "f(a)(1 / (1 + exp(-a)))"; } + @Override + public int hashCode() { return "sigmoid".hashCode(); } } public static class Sqrt implements DoubleUnaryOperator { @@ -332,6 +395,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.sqrt(operand); } @Override public String toString() { return "f(a)(sqrt(a))"; } + @Override + public int hashCode() { return "sqrt".hashCode(); } } public static class Square implements DoubleUnaryOperator { @@ -339,6 +404,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return operand * operand; } @Override public String toString() { return "f(a)(a * a)"; } + @Override + public int hashCode() { return "square".hashCode(); } } public static class Tan implements DoubleUnaryOperator { @@ -346,6 +413,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.tan(operand); } @Override public String toString() { return "f(a)(tan(a))"; } + @Override + public int hashCode() { return "tan".hashCode(); } } public static class Tanh implements DoubleUnaryOperator { @@ -353,6 +422,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return Math.tanh(operand); } @Override public String toString() { return "f(a)(tanh(a))"; } + @Override + public int hashCode() { return "tanh".hashCode(); } } public static class Erf implements DoubleUnaryOperator { @@ -410,6 +481,8 @@ public class ScalarFunctions { public double applyAsDouble(double operand) { return erf(operand); } @Override public String toString() { return "f(a)(erf(a))"; } + @Override + public int hashCode() { return "erf".hashCode(); } static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI); @@ -464,6 +537,8 @@ public class ScalarFunctions { } return b.toString(); } + @Override + public int hashCode() { return Objects.hash("equal", argumentNames); } } public static class Random implements Function<List<Long>, Double> { @@ -473,6 +548,8 @@ public class ScalarFunctions { } @Override public String toString() { return "random"; } + @Override + public int hashCode() { return "random".hashCode(); } } public static class SumElements implements Function<List<Long>, Double> { @@ -492,6 +569,8 @@ public class ScalarFunctions { public String toString() { return argumentNames.stream().collect(Collectors.joining("+")); } + @Override + public int hashCode() { return Objects.hash("sum", argumentNames); } } public static class Constant implements Function<List<Long>, Double> { @@ -506,6 +585,8 @@ public class ScalarFunctions { } @Override public String toString() { return Double.toString(value); } + @Override + public int hashCode() { return Objects.hash("constant", value); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index e3464255fac..39bddc3a3cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -166,6 +166,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return b.toString(); } + @Override + public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); } + public static class DimensionValue<NAMETYPE extends Name> { private final Optional<String> dimension; @@ -255,6 +258,10 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY return index.toString(context); } + @Override + public int hashCode() { return Objects.hash(dimension, label, index); } + + } private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { @@ -273,6 +280,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public String toString() { return String.valueOf(value); } + @Override + public int hashCode() { return Objects.hash("constantIntegerFunction", value); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 9ea9040831b..df8cd6d39cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -50,4 +51,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME return "softmax(" + argument.toString(context) + ", " + dimension + ")"; } + @Override + public int hashCode() { return Objects.hash("softmax", argument, dimension); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 1e1d1d3b5b9..503f414d8eb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -68,4 +68,8 @@ public abstract class TensorFunction<NAMETYPE extends Name> { @Override public String toString() { return toString(ToStringContext.empty()); } + /** Returns a hashcode computed from the data in this */ + @Override + public abstract int hashCode(); + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 0223ad4d588..bd4fc7b8336 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.Name; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -51,4 +52,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME dimension + ")"; } + @Override + public int hashCode() { return Objects.hash("xwplusb", x, w, b, dimension); } + } |