summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-03-01 10:39:52 +0100
committerJon Bratseth <bratseth@gmail.com>2022-03-01 10:39:52 +0100
commit05ab2e976349eb3016fa91020e161a8782bf00a5 (patch)
treed570863bbd636ddf908bf1d875efd21e5cbf9056
parent0e1e603359c9018cea86d1716903c3ce365e529e (diff)
Compute hash without serializing to string
-rw-r--r--searchlib/abi-spec.json40
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java6
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java40
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java14
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java6
-rw-r--r--vespajlib/abi-spec.json182
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java395
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java81
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java4
46 files changed, 691 insertions, 324 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 3213b1bb2b9..ced2517ff9f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1235,8 +1235,9 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
+ "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
+ "public int hashCode()",
+ "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
},
@@ -1295,6 +1296,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.ComparisonNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1327,7 +1329,8 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public java.lang.String sourceString()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)"
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1345,7 +1348,8 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
+ "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1360,13 +1364,13 @@
],
"methods": [
"public void <init>()",
- "public int hashCode()",
- "public final boolean equals(java.lang.Object)",
- "public final java.lang.String toString()",
- "public final java.lang.StringBuilder toString(com.yahoo.searchlib.rankingexpression.rule.SerializationContext)",
"public abstract java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public abstract com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)"
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
+ "public final java.lang.StringBuilder toString(com.yahoo.searchlib.rankingexpression.rule.SerializationContext)",
+ "public final boolean equals(java.lang.Object)",
+ "public abstract int hashCode()",
+ "public final java.lang.String toString()"
],
"fields": []
},
@@ -1438,6 +1442,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.FunctionNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1472,12 +1477,13 @@
],
"methods": [
"public void <init>(com.yahoo.tensor.TensorType, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode$LongListToDoubleLambda asLongListToDoubleOperator()",
"public java.util.List children()",
"public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)",
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode$LongListToDoubleLambda asLongListToDoubleOperator()"
+ "public int hashCode()"
],
"fields": []
},
@@ -1500,6 +1506,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.IfNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1518,7 +1525,8 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public java.util.function.DoubleUnaryOperator asDoubleUnaryOperator()",
- "public java.util.function.DoubleBinaryOperator asDoubleBinaryOperator()"
+ "public java.util.function.DoubleBinaryOperator asDoubleBinaryOperator()",
+ "public int hashCode()"
],
"fields": []
},
@@ -1534,7 +1542,8 @@
"public java.lang.String getValue()",
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)"
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1552,6 +1561,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.NegativeNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1570,6 +1580,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.NotNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1644,6 +1655,7 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode setChildren(java.util.List)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
],
"fields": []
@@ -1663,6 +1675,7 @@
"public java.util.Optional asScalarFunction()",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
],
"fields": []
@@ -1685,7 +1698,8 @@
"public static java.util.Map wrapScalars(java.util.Map)",
"public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.util.List, java.lang.String, java.util.List, java.util.Map)",
"public static java.util.List wrapScalars(com.yahoo.tensor.TensorType, java.util.List, java.util.List)",
- "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
+ "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public int hashCode()"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 6490c69894f..207603c5038 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -97,6 +97,7 @@ public abstract class Value {
@Override
public abstract boolean equals(Object other);
+ /** Returns a hash which only depends on the content of this value. */
@Override
public abstract int hashCode();
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
index 13db51c1363..c8b20e774b5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java
@@ -11,7 +11,9 @@ import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.Arrays;
import java.util.Deque;
+import java.util.Objects;
/**
* An optimized version of a sum of consecutive decision trees.
@@ -42,8 +44,12 @@ public class GBDTForestNode extends ExpressionNode {
}
/** Returns (optimized sum of condition trees) */
+ @Override
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
return string.append("(optimized sum of condition trees of size ").append(values.length*8).append(" bytes)");
}
+ @Override
+ public int hashCode() { return Objects.hash("gbdtForest", Arrays.hashCode(values)); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
index 6c6166c2869..949e1f026f7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java
@@ -11,7 +11,9 @@ import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
+import java.util.Arrays;
import java.util.Deque;
+import java.util.Objects;
/**
* An optimized version of a decision tree.
@@ -105,4 +107,8 @@ public final class GBDTNode extends ExpressionNode {
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
return string.append("(optimized condition tree)");
}
+
+ @Override
+ public int hashCode() { return Objects.hash("gbdtNode", Arrays.hashCode(values)); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 55fe66be69e..580f42e67cb 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -13,6 +13,7 @@ import java.util.ArrayList;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
+import java.util.Objects;
/**
* A binary mathematical operation
@@ -115,6 +116,16 @@ public final class ArithmeticNode extends CompositeNode {
lhs.value = rhs.op.evaluate(lhs.value, rhs.value);
}
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> newChildren) {
+ if (children.size() != newChildren.size())
+ throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size());
+ return new ArithmeticNode(newChildren, operators);
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash(children, operators); }
+
public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) {
if ( ! (left instanceof ArithmeticNode)) return new ArithmeticNode(left, op, right);
@@ -140,12 +151,5 @@ public final class ArithmeticNode extends CompositeNode {
}
}
- @Override
- public CompositeNode setChildren(List<ExpressionNode> newChildren) {
- if (children.size() != newChildren.size())
- throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size());
- return new ArithmeticNode(newChildren, operators);
- }
-
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
index 600d3b8d408..e726a351f74 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* A node which returns the outcome of a comparison.
@@ -62,4 +63,7 @@ public class ComparisonNode extends BooleanNode {
return new ComparisonNode(children.get(0), operator, children.get(1));
}
+ @Override
+ public int hashCode() { return Objects.hash(operator, conditions); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index ffbeec37c78..46e833197f9 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
+import java.util.Objects;
/**
* A node which holds a constant (frozen) value.
@@ -55,4 +56,7 @@ public final class ConstantNode extends ExpressionNode {
return value;
}
+ @Override
+ public int hashCode() { return Objects.hash("constantNode", value); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
index 9d389a4f6e9..64a1f42a7ba 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* This class represents another expression enclosed in braces.
@@ -60,4 +61,7 @@ public final class EmbracedNode extends CompositeNode {
return new EmbracedNode(newChildren.get(0));
}
+ @Override
+ public int hashCode() { return Objects.hash("embraced", value); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
index 8e00be3f056..51067930dd0 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java
@@ -18,31 +18,13 @@ import java.util.Deque;
*/
public abstract class ExpressionNode implements Serializable {
- @Override
- public int hashCode() {
- return toString().hashCode();
- }
-
- @Override
- public final boolean equals(Object obj) {
- return obj instanceof ExpressionNode && toString().equals(obj.toString());
- }
-
- @Override
- public final String toString() {
- return toString(new SerializationContext()).toString();
- }
- public final StringBuilder toString(SerializationContext context) {
- return toString(new StringBuilder(), context, null, null);
- }
-
/**
- * Returns a script instance of this based on the supplied script functions.
+ * Returns this in serialized form.
*
* @param builder the StringBuilder that will be appended to
* @param context the serialization context
* @param path the call path to this, used for cycle detection, or null if this is a root
- * @param parent the parent node of this, or null if it a root
+ * @param parent the parent node of this, or null if it is a root
* @return the main script, referring to script instances.
*/
public abstract StringBuilder toString(StringBuilder builder, SerializationContext context, Deque<String> path, CompositeNode parent);
@@ -63,4 +45,22 @@ public abstract class ExpressionNode implements Serializable {
*/
public abstract Value evaluate(Context context);
+ public final StringBuilder toString(SerializationContext context) {
+ return toString(new StringBuilder(), context, null, null);
+ }
+
+ @Override
+ public final boolean equals(Object obj) {
+ return obj instanceof ExpressionNode && toString().equals(obj.toString());
+ }
+
+ /** Returns a hashcode computed from the data in this */
+ @Override
+ public abstract int hashCode();
+
+ @Override
+ public final String toString() {
+ return toString(new SerializationContext()).toString();
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index d32cfb51f95..5e8bfc245a7 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -14,6 +14,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* Invocation of a native function.
@@ -108,4 +109,7 @@ public final class FunctionNode extends CompositeNode {
return new FunctionNode(function, children.get(0), children.get(1));
}
+ @Override
+ public int hashCode() { return Objects.hash("functionNode", function, arguments); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
index 7ff3a71d036..8d858341976 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* A tensor generating function, whose arguments are determined by a tensor type
@@ -31,6 +32,11 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
this.generator = generator;
}
+ /** Returns this as an operator which converts a list of integers into a double. */
+ public LongListToDoubleLambda asLongListToDoubleOperator() {
+ return new LongListToDoubleLambda();
+ }
+
@Override
public List<ExpressionNode> children() {
return Collections.singletonList(generator);
@@ -57,12 +63,8 @@ public class GeneratorLambdaFunctionNode extends CompositeNode {
return generator.evaluate(context);
}
- /**
- * Returns this as an operator which converts a list of integers into a double
- */
- public LongListToDoubleLambda asLongListToDoubleOperator() {
- return new LongListToDoubleLambda();
- }
+ @Override
+ public int hashCode() { return Objects.hash("generator", type, generator); }
private class LongListToDoubleLambda implements java.util.function.Function<List<Long>, Double> {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 02d437e83bf..6f46222c1d8 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* A conditional branch of a ranking expression.
@@ -17,8 +18,9 @@ import java.util.List;
* @author bratseth
*/
public final class IfNode extends CompositeNode {
- /** [condition, trueExpression, falseExpression]*/
- private final List<ExpressionNode> asList;
+
+ /** [condition, trueExpression, falseExpression] */
+ private final List<ExpressionNode> arguments;
private final Double trueProbability;
public IfNode(ExpressionNode condition, ExpressionNode trueExpression, ExpressionNode falseExpression) {
@@ -39,19 +41,19 @@ public final class IfNode extends CompositeNode {
if (trueProbability != null && ( trueProbability < 0.0 || trueProbability > 1.0) )
throw new IllegalArgumentException("trueProbability must be a between 0.0 and 1.0, not " + trueProbability);
this.trueProbability = trueProbability;
- this.asList = List.of(condition, trueExpression, falseExpression);
+ this.arguments = List.of(condition, trueExpression, falseExpression);
}
@Override
public List<ExpressionNode> children() {
- return asList;
+ return arguments;
}
- public ExpressionNode getCondition() { return asList.get(0); }
+ public ExpressionNode getCondition() { return arguments.get(0); }
- public ExpressionNode getTrueExpression() { return asList.get(1); }
+ public ExpressionNode getTrueExpression() { return arguments.get(1); }
- public ExpressionNode getFalseExpression() { return asList.get(2); }
+ public ExpressionNode getFalseExpression() { return arguments.get(2); }
/** The average probability that the condition of this node will evaluate to true, or null if not known */
public Double getTrueProbability() { return trueProbability; }
@@ -95,4 +97,7 @@ public final class IfNode extends CompositeNode {
return new IfNode(children.get(0), children.get(1), children.get(2));
}
+ @Override
+ public int hashCode() { return Objects.hash("if", arguments, trueProbability); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index a2b86360923..9f07f146264 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -13,6 +13,7 @@ import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
+import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.DoubleBinaryOperator;
@@ -162,6 +163,8 @@ public class LambdaFunctionNode extends CompositeNode {
return Set.of();
}
+ @Override
+ public int hashCode() { return Objects.hash("lambdaFunction", arguments, functionExpression); }
private class DoubleUnaryLambda implements DoubleUnaryOperator {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
index 34c8664c0cf..fec643e81df 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Deque;
+import java.util.Objects;
/**
* An opaque name in a ranking expression. This is used to represent names passed to the context
@@ -41,4 +42,7 @@ public final class NameNode extends ExpressionNode {
throw new RuntimeException("Name nodes should never be evaluated");
}
+ @Override
+ public int hashCode() { return Objects.hash("name", name); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 9516f38a155..8d2cf7b6387 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* A node which flips the sign of the value produced from the nested expression
@@ -54,4 +55,7 @@ public class NegativeNode extends CompositeNode {
return new NegativeNode(children.get(0));
}
+ @Override
+ public int hashCode() { return Objects.hash("negative", value); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
index 8b5ae256038..ac3566c26af 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
/**
* A node which flips the logical value produced from the nested expression.
@@ -55,5 +56,8 @@ public class NotNode extends BooleanNode {
return new NotNode(children.get(0));
}
+ @Override
+ public int hashCode() { return Objects.hash("not", value); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index 31f3013b756..5da2fbfe624 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;
+import java.util.Objects;
import java.util.function.Predicate;
/**
@@ -100,4 +101,7 @@ public class SetMembershipNode extends BooleanNode {
return new SetMembershipNode(children.get(0), children.subList(1, children.size()));
}
+ @Override
+ public int hashCode() { return Objects.hash("setMembership", testValue, setValues); }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index ce5832027b7..7b68ad7e2af 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -144,6 +144,9 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionScalarFunction(node);
}
+ @Override
+ public int hashCode() { return function.hashCode(); }
+
private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -251,6 +254,9 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public int hashCode() { return expression.hashCode(); }
+
+ @Override
public String toString(ToStringContext<Reference> c) {
ToStringContext<Reference> outermost = c;
while (outermost.parent() != null)
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index a30ee055538..4e25d8ab0e0 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1232,6 +1232,7 @@
"public static java.lang.String toStandardString(com.yahoo.tensor.Tensor)",
"public static java.lang.String contentToString(com.yahoo.tensor.Tensor)",
"public abstract boolean equals(java.lang.Object)",
+ "public abstract int hashCode()",
"public static boolean equals(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor)",
"public static boolean approxEquals(double, double, double)",
"public static boolean approxEquals(double, double)",
@@ -1563,7 +1564,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1580,7 +1582,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1597,7 +1600,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1614,7 +1618,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1644,6 +1649,7 @@
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
],
@@ -1663,7 +1669,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1678,7 +1685,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1711,7 +1719,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1730,7 +1739,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1748,6 +1758,7 @@
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
],
@@ -1764,7 +1775,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1779,7 +1791,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1799,7 +1812,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1815,7 +1829,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1832,9 +1847,10 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1861,7 +1877,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1876,7 +1893,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1919,7 +1937,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)"
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1937,7 +1956,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public final com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public boolean canOptimize(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1957,7 +1977,8 @@
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -1990,7 +2011,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2005,7 +2027,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2020,7 +2043,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2035,7 +2059,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2050,7 +2075,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2065,7 +2091,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2081,6 +2108,7 @@
"public void <init>(double)",
"public java.lang.Double apply(java.util.List)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public bridge synthetic java.lang.Object apply(java.lang.Object)"
],
"fields": []
@@ -2096,7 +2124,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2111,7 +2140,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2127,7 +2157,8 @@
"public void <init>()",
"public void <init>(double)",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2142,7 +2173,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2157,6 +2189,7 @@
"methods": [
"public java.lang.Double apply(java.util.List)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public bridge synthetic java.lang.Object apply(java.lang.Object)"
],
"fields": []
@@ -2173,6 +2206,7 @@
"public void <init>()",
"public double applyAsDouble(double)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public static double erf(double)"
],
"fields": []
@@ -2188,7 +2222,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2203,7 +2238,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2218,7 +2254,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2234,7 +2271,8 @@
"public void <init>()",
"public static double hamming(double, double)",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2250,7 +2288,8 @@
"public void <init>()",
"public void <init>(double)",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2265,7 +2304,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2280,7 +2320,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2295,7 +2336,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2310,7 +2352,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2325,7 +2368,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2340,7 +2384,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2355,7 +2400,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2370,7 +2416,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2386,6 +2433,7 @@
"public void <init>()",
"public java.lang.Double apply(java.util.List)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public bridge synthetic java.lang.Object apply(java.lang.Object)"
],
"fields": []
@@ -2401,7 +2449,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2416,7 +2465,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2431,7 +2481,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2447,7 +2498,8 @@
"public void <init>()",
"public void <init>(double, double)",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2462,7 +2514,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2477,7 +2530,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2492,7 +2546,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2507,7 +2562,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2522,7 +2578,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2537,7 +2594,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double, double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2552,6 +2610,7 @@
"methods": [
"public java.lang.Double apply(java.util.List)",
"public java.lang.String toString()",
+ "public int hashCode()",
"public bridge synthetic java.lang.Object apply(java.lang.Object)"
],
"fields": []
@@ -2567,7 +2626,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2582,7 +2642,8 @@
"methods": [
"public void <init>()",
"public double applyAsDouble(double)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2658,7 +2719,8 @@
"public java.util.Optional dimension()",
"public java.util.Optional label()",
"public java.util.Optional index()",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public int hashCode()"
],
"fields": []
},
@@ -2676,6 +2738,7 @@
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()",
"public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)"
],
"fields": []
@@ -2692,7 +2755,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
@@ -2713,7 +2777,8 @@
"public final com.yahoo.tensor.Tensor evaluate()",
"public abstract java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public java.util.Optional asScalarFunction()",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public abstract int hashCode()"
],
"fields": []
},
@@ -2759,7 +2824,8 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public int hashCode()"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index c4588b79fa9..ca396ae5bf2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -355,6 +355,10 @@ public interface Tensor {
@Override
boolean equals(Object o);
+ /** Returns a hash computed deterministically from the content of this tensor */
+ @Override
+ int hashCode();
+
/**
* Implement here to make this work across implementations.
* Implementations must override equals and call this because this is an interface and cannot override equals.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index dbc8396d701..8a9a85d343c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.Optional;
/**
@@ -62,6 +63,9 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
return name;
}
+ @Override
+ public int hashCode() { return Objects.hash("variableTensor", name, requiredType); }
+
private void verifyType(TensorType givenType) {
if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 55dd8a7bc8a..d2762ad762d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -52,4 +52,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("argmax", argument, dimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index f1f0b9d67b0..baedf41bcb8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -52,4 +52,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("argmin", argument, dimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index 09f84e6747e..176847cec93 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -111,4 +111,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("cellcast", argument, valueType); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 6d4b15be991..abf0d89c2b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -31,6 +31,191 @@ import java.util.stream.Collectors;
*/
public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+ enum DimType { common, separate, concat }
+
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
+ private final String dimension;
+
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
+ Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
+ Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
+ Objects.requireNonNull(dimension, "The dimension cannot be null");
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 2)
+ throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("concat", argumentA, argumentB, dimension); }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
+ return oldEvaluate(a, b);
+ }
+ var helper = new Helper(a, b, dimension);
+ return helper.result;
+ }
+
+ private Tensor oldEvaluate(Tensor a, Tensor b) {
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
+
+ IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
+ IndexedTensor bIndexed = (IndexedTensor) b;
+ DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+
+ Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ int[] aToIndexes = mapIndexes(a.type(), concatType);
+ int[] bToIndexes = mapIndexes(b.type(), concatType);
+ concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
+ concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
+ return builder.build();
+ }
+
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
+ Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
+ IndexedTensor.SubspaceIterator iaSubspace = ia.next();
+ TensorAddress aAddress = iaSubspace.address();
+ for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
+ IndexedTensor.SubspaceIterator ibSubspace = ib.next();
+ while (ibSubspace.hasNext()) {
+ Tensor.Cell bCell = ibSubspace.next();
+ TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
+ concatType, offset, dimension);
+ if (combinedAddress == null) continue; // incompatible
+
+ builder.cell(combinedAddress, bCell.getValue());
+ }
+ iaSubspace.reset();
+ }
+ }
+ }
+
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
+ Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
+ if ( dimension.isPresent() ) {
+ if ( ! dimension.get().isIndexed())
+ throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
+ "' requires that dimension to be indexed or absent, " +
+ "but got a tensor with type " + tensor.type());
+ return tensor;
+ }
+ else { // extend tensor with this dimension
+ if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
+ throw new IllegalArgumentException("Concat requires an indexed tensor, " +
+ "but got a tensor with type " + tensor.type());
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
+ return tensor.multiply(unitTensor);
+ }
+
+ }
+
+ /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
+ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < concatSizes.dimensions(); i++) {
+ String currentDimension = concatType.dimensions().get(i).name();
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
+ if (currentDimension.equals(concatDimension))
+ concatSizes.set(i, aSize + bSize);
+ else if (aSize != 0 && bSize != 0 && aSize!=bSize )
+ concatSizes.set(i, Math.min(aSize, bSize));
+ else
+ concatSizes.set(i, Math.max(aSize, bSize));
+ }
+ return concatSizes.build();
+ }
+
+ /**
+ * Combine two addresses, adding the offset to the concat dimension
+ *
+ * @return the combined address or null if the addresses are incompatible
+ * (in some other dimension than the concat dimension)
+ */
+ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
+ Arrays.fill(combinedLabels, -1);
+ int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
+ mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ if ( ! compatible) return null;
+ return TensorAddress.of(combinedLabels);
+ }
+
+ /**
+ * Returns the an array having one entry in order for each dimension of fromType
+ * containing the index at which toType contains the same dimension name.
+ * That is, if the returned array contains n at index i then
+ * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
+ * If some dimension in fromType is not present in toType, the corresponding index will be -1
+ */
+ // TODO: Stolen from join
+ private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ int[] toIndexes = new int[fromType.dimensions().size()];
+ for (int i = 0; i < fromType.dimensions().size(); i++)
+ toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
+ return toIndexes;
+ }
+
+ /**
+ * Maps the content in the given list to the given array, using the given index map.
+ *
+ * @return true if the mapping was successful, false if one of the destination positions was
+ * occupied by a different value
+ */
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
+ for (int i = 0; i < from.size(); i++) {
+ int toIndex = indexMap[i];
+ if (concatDimension == toIndex) {
+ to[toIndex] = from.numericLabel(i) + concatOffset;
+ }
+ else {
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
+ }
+ }
+ return true;
+ }
+
static class CellVector {
ArrayList<Double> values = new ArrayList<>();
void setValue(int ccDimIndex, double value) {
@@ -57,8 +242,6 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
- enum DimType { common, separate, concat }
-
static class SplitHow {
List<DimType> handleDims = new ArrayList<>();
long numCommon() { return handleDims.stream().filter(t -> (t == DimType.common)).count(); }
@@ -76,7 +259,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
enum CombineHow { left, right, both, concat }
List<CombineHow> combineHow = new ArrayList<>();
-
+
void aOnly(String dimName) {
if (dimName.equals(concatDimension)) {
splitInfoA.handleDims.add(DimType.concat);
@@ -160,8 +343,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
static int concatDimensionSize(CellVectorMapMap data) {
Set<Integer> sizes = new HashSet<>();
data.map.forEach((m, cvmap) ->
- cvmap.map.forEach((e, vector) ->
- sizes.add(vector.values.size())));
+ cvmap.map.forEach((e, vector) ->
+ sizes.add(vector.values.size())));
if (sizes.isEmpty()) {
return 1;
}
@@ -207,17 +390,17 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
var lhs = entry.getValue();
var rhs = b.map.get(common);
lhs.map.forEach((leftOnly, leftCells) -> {
- rhs.map.forEach((rightOnly, rightCells) -> {
- for (int i = 0; i < leftCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i);
- builder.cell(addr, leftCells.values.get(i));
- }
- for (int i = 0; i < rightCells.values.size(); i++) {
- TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
- builder.cell(addr, rightCells.values.get(i));
- }
- });
+ rhs.map.forEach((rightOnly, rightCells) -> {
+ for (int i = 0; i < leftCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i);
+ builder.cell(addr, leftCells.values.get(i));
+ }
+ for (int i = 0; i < rightCells.values.size(); i++) {
+ TensorAddress addr = combine(common, leftOnly, rightOnly, i + aConcatSize);
+ builder.cell(addr, rightCells.values.get(i));
+ }
});
+ });
}
}
return builder.build();
@@ -240,7 +423,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
commonLabels[commonIdx++] = addr.label(i);
break;
case separate:
- separateLabels[separateIdx++] = addr.label(i);
+ separateLabels[separateIdx++] = addr.label(i);
break;
case concat:
ccDimIndex = addr.numericLabel(i);
@@ -257,184 +440,4 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
}
- private final TensorFunction<NAMETYPE> argumentA, argumentB;
- private final String dimension;
-
- public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
- Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
- Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
- Objects.requireNonNull(dimension, "The dimension cannot be null");
- this.argumentA = argumentA;
- this.argumentB = argumentB;
- this.dimension = dimension;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
-
- @Override
- public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 2)
- throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat<>(arguments.get(0), arguments.get(1), dimension);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
- return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
- }
-
- @Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
- }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor a = argumentA.evaluate(context);
- Tensor b = argumentB.evaluate(context);
- if (a instanceof IndexedTensor && b instanceof IndexedTensor) {
- return oldEvaluate(a, b);
- }
- var helper = new Helper(a, b, dimension);
- return helper.result;
- }
-
- private Tensor oldEvaluate(Tensor a, Tensor b) {
- TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
-
- a = ensureIndexedDimension(dimension, a, concatType.valueType());
- b = ensureIndexedDimension(dimension, b, concatType.valueType());
-
- IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
- IndexedTensor bIndexed = (IndexedTensor) b;
- DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
-
- Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
- int[] aToIndexes = mapIndexes(a.type(), concatType);
- int[] bToIndexes = mapIndexes(b.type(), concatType);
- concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
- concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
- return builder.build();
- }
-
- private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
- int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
- Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
- IndexedTensor.SubspaceIterator iaSubspace = ia.next();
- TensorAddress aAddress = iaSubspace.address();
- for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
- IndexedTensor.SubspaceIterator ibSubspace = ib.next();
- while (ibSubspace.hasNext()) {
- Tensor.Cell bCell = ibSubspace.next();
- TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
- concatType, offset, dimension);
- if (combinedAddress == null) continue; // incompatible
-
- builder.cell(combinedAddress, bCell.getValue());
- }
- iaSubspace.reset();
- }
- }
- }
-
- private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
- Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
- if ( dimension.isPresent() ) {
- if ( ! dimension.get().isIndexed())
- throw new IllegalArgumentException("Concat in dimension '" + dimensionName +
- "' requires that dimension to be indexed or absent, " +
- "but got a tensor with type " + tensor.type());
- return tensor;
- }
- else { // extend tensor with this dimension
- if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
- throw new IllegalArgumentException("Concat requires an indexed tensor, " +
- "but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
- .indexed(dimensionName, 1)
- .build())
- .cell(1,0)
- .build();
- return tensor.multiply(unitTensor);
- }
-
- }
-
- /** Returns the concrete (not type) dimension sizes resulting from combining a and b */
- private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
- for (int i = 0; i < concatSizes.dimensions(); i++) {
- String currentDimension = concatType.dimensions().get(i).name();
- long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
- long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
- if (currentDimension.equals(concatDimension))
- concatSizes.set(i, aSize + bSize);
- else if (aSize != 0 && bSize != 0 && aSize!=bSize )
- concatSizes.set(i, Math.min(aSize, bSize));
- else
- concatSizes.set(i, Math.max(aSize, bSize));
- }
- return concatSizes.build();
- }
-
- /**
- * Combine two addresses, adding the offset to the concat dimension
- *
- * @return the combined address or null if the addresses are incompatible
- * (in some other dimension than the concat dimension)
- */
- private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, long concatOffset, String concatDimension) {
- long[] combinedLabels = new long[concatType.dimensions().size()];
- Arrays.fill(combinedLabels, -1);
- int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
- mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
- boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
- if ( ! compatible) return null;
- return TensorAddress.of(combinedLabels);
- }
-
- /**
- * Returns the an array having one entry in order for each dimension of fromType
- * containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
- * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
- * If some dimension in fromType is not present in toType, the corresponding index will be -1
- */
- // TODO: Stolen from join
- private int[] mapIndexes(TensorType fromType, TensorType toType) {
- int[] toIndexes = new int[fromType.dimensions().size()];
- for (int i = 0; i < fromType.dimensions().size(); i++)
- toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
- return toIndexes;
- }
-
- /**
- * Maps the content in the given list to the given array, using the given index map.
- *
- * @return true if the mapping was successful, false if one of the destination positions was
- * occupied by a different value
- */
- private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
- for (int i = 0; i < from.size(); i++) {
- int toIndex = indexMap[i];
- if (concatDimension == toIndex) {
- to[toIndex] = from.numericLabel(i) + concatOffset;
- }
- else {
- if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
- to[toIndex] = from.numericLabel(i);
- }
- }
- return true;
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index a0fd9272f54..92a72dfd280 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* A function which returns a constant tensor.
@@ -49,4 +50,9 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
@Override
public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); }
+ @Override
+ public int hashCode() {
+ return Objects.hash("constant", constant.hashCode());
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 92d89ec68f7..7218375de89 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -40,13 +41,16 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP
return new Generate<>(type, diagFunction);
}
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::name);
+ }
+
@Override
public String toString(ToStringContext<NAMETYPE> context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
- private Stream<String> dimensionNames() {
- return type.dimensions().stream().map(TensorType.Dimension::name);
- }
+ @Override
+ public int hashCode() { return Objects.hash("diag", type, diagFunction); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 46992115c23..c402a1bde5b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -13,6 +13,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
@@ -45,13 +46,13 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
TensorType type() { return type; }
+ abstract String contentToString(ToStringContext<NAMETYPE> context);
+
@Override
public String toString(ToStringContext<NAMETYPE> context) {
return type().toString() + ":" + contentToString(context);
}
- abstract String contentToString(ToStringContext<NAMETYPE> context);
-
/** Creates a dynamic tensor function. The cell addresses must match the type. */
public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
return new MappedDynamicTensor<>(type, cells);
@@ -98,6 +99,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("mappedDynamicTensor", type(), cells); }
+
}
private static class IndexedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> {
@@ -141,6 +145,9 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("indexedDynamicTensor", type(), cells); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
index c049e5d41da..eee037c8dba 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* The <i>expand</i> tensor function returns a tensor with a new dimension of
@@ -45,4 +46,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "expand(" + argument.toString(context) + ", " + dimensionName + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("expand", argument, dimensionName); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 54e83fa472f..3ad3e1114cc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -126,6 +126,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
return boundGenerator.toString(new GenerateToStringContext(context));
}
+ @Override
+ public int hashCode() { return Objects.hash("generate", type, freeGenerator, boundGenerator); }
+
/**
* A context for generating all the values of a tensor produced by evaluating Generate.
* This returns all the current index values as variables and falls back to delivering from the given
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 52bef482fb4..4ec5b196dbc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -80,6 +80,9 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
@Override
+ public int hashCode() { return Objects.hash("join", argumentA, argumentB, combinator); }
+
+ @Override
public TensorType type(TypeContext<NAMETYPE> context) {
return outputType(argumentA.type(context), argumentB.type(context));
}
@@ -356,7 +359,6 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
return builder.build();
}
-
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index f47202d1b9f..38cc95ac6b2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -43,4 +44,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("l1_normalize", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 8f4e2f466d4..4a676449657 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -45,4 +46,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("l2_normalize", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 46772d8cbff..68645546be9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -75,4 +75,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE
return "map(" + argument.toString(context) + ", " + mapper + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("map", argument, mapper); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 8ac6d711c48..3239ab1a70c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.Name;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -49,4 +50,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("matmul", argument1, argument2, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index adc84225a63..2b9dc709e0e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -70,11 +70,6 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
- }
-
- @Override
public TensorType type(TypeContext<NAMETYPE> context) {
return outputType(argumentA.type(context), argumentB.type(context));
}
@@ -87,6 +82,15 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return evaluate(a, b, mergedType, merger);
}
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("merge", argumentA, argumentB, merger); }
+
static Tensor evaluate(Tensor a, Tensor b, TensorType mergedType, DoubleBinaryOperator combinator) {
// Choose merge algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 18c5db8e3a7..34b8eba3e67 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -42,6 +43,9 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("random", type); }
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 45b827db900..7053eeb0a1c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
@@ -50,4 +51,9 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("range", type, rangeFunction);
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 8841cff15e9..96465de6c0f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -107,6 +107,11 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return evaluate(this.argument.evaluate(context), dimensions, aggregator);
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("reduce", argument, dimensions, aggregator);
+ }
+
static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
@@ -191,6 +196,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
/** Resets the aggregator */
public abstract void reset();
+ /** Returns a hash of this aggregator which only depends on its identity */
+ @Override
+ public abstract int hashCode();
+
}
private static class AvgAggregator extends ValueAggregator {
@@ -214,6 +223,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
valueCount = 0;
valueSum = 0.0;
}
+
+ @Override
+ public int hashCode() { return "avgAggregator".hashCode(); }
+
}
private static class CountAggregator extends ValueAggregator {
@@ -234,6 +247,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueCount = 0;
}
+
+ @Override
+ public int hashCode() { return "countAggregator".hashCode(); }
+
}
private static class MaxAggregator extends ValueAggregator {
@@ -255,6 +272,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
maxValue = Double.NEGATIVE_INFINITY;
}
+
+ @Override
+ public int hashCode() { return "maxAggregator".hashCode(); }
+
}
private static class MedianAggregator extends ValueAggregator {
@@ -288,6 +309,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
values = new ArrayList<>();
}
+ @Override
+ public int hashCode() { return "medianAggregator".hashCode(); }
+
}
private static class MinAggregator extends ValueAggregator {
@@ -310,6 +334,9 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
minValue = Double.POSITIVE_INFINITY;
}
+ @Override
+ public int hashCode() { return "minAggregator".hashCode(); }
+
}
private static class ProdAggregator extends ValueAggregator {
@@ -330,6 +357,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueProd = 1.0;
}
+
+ @Override
+ public int hashCode() { return "prodAggregator".hashCode(); }
+
}
private static class SumAggregator extends ValueAggregator {
@@ -350,6 +381,10 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
public void reset() {
valueSum = 0.0;
}
+
+ @Override
+ public int hashCode() { return "sumAggregator".hashCode(); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 7505355beed..ccb437ef5a7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Arrays;
import java.util.List;
+import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
import java.util.stream.Collectors;
@@ -322,6 +323,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
Reduce.commaSeparated(dimensions) + ")";
}
+ @Override
+ public int hashCode() {
+ return Objects.hash("reduce_join", argumentA, argumentB, combinator, aggregator, dimensions);
+ }
+
private static class MultiDimensionIterator {
private final long[] bounds;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index a434ecba5cc..023e91e424f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -127,12 +127,6 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return TensorAddress.of(reorderedLabels);
}
- @Override
- public String toString(ToStringContext<NAMETYPE> context) {
- return "rename(" + argument.toString(context) + ", " +
- toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
- }
-
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
@@ -144,4 +138,13 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return b.toString();
}
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "rename(" + argument.toString(context) + ", " +
+ toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("rename", argument, fromDimensions, toDimensions); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 517f6683cbf..2639e153923 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import java.util.Comparator;
import java.util.List;
+import java.util.Objects;
import java.util.PriorityQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
@@ -75,6 +76,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left + right; }
@Override
public String toString() { return "f(a,b)(a + b)"; }
+ @Override
+ public int hashCode() { return "add".hashCode(); }
}
public static class Equal implements DoubleBinaryOperator {
@@ -82,6 +85,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a==b)"; }
+ @Override
+ public int hashCode() { return "equal".hashCode(); }
}
public static class Greater implements DoubleBinaryOperator {
@@ -89,6 +94,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a > b)"; }
+ @Override
+ public int hashCode() { return "greater".hashCode(); }
}
public static class Less implements DoubleBinaryOperator {
@@ -96,6 +103,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
@Override
public String toString() { return "f(a,b)(a < b)"; }
+ @Override
+ public int hashCode() { return "less".hashCode(); }
}
public static class Max implements DoubleBinaryOperator {
@@ -103,6 +112,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@Override
public String toString() { return "f(a,b)(max(a, b))"; }
+ @Override
+ public int hashCode() { return "max".hashCode(); }
}
public static class Min implements DoubleBinaryOperator {
@@ -110,6 +121,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.min(left, right); }
@Override
public String toString() { return "f(a,b)(min(a, b))"; }
+ @Override
+ public int hashCode() { return "min".hashCode(); }
}
public static class Mean implements DoubleBinaryOperator {
@@ -117,6 +130,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return (left + right) / 2; }
@Override
public String toString() { return "f(a,b)((a + b) / 2)"; }
+ @Override
+ public int hashCode() { return "mean".hashCode(); }
}
public static class Multiply implements DoubleBinaryOperator {
@@ -124,6 +139,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
+ @Override
+ public int hashCode() { return "multiply".hashCode(); }
}
public static class Pow implements DoubleBinaryOperator {
@@ -131,6 +148,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
@Override
public String toString() { return "f(a,b)(pow(a, b))"; }
+ @Override
+ public int hashCode() { return "pow".hashCode(); }
}
public static class Divide implements DoubleBinaryOperator {
@@ -138,6 +157,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left / right; }
@Override
public String toString() { return "f(a,b)(a / b)"; }
+ @Override
+ public int hashCode() { return "divide".hashCode(); }
}
public static class SquaredDifference implements DoubleBinaryOperator {
@@ -145,6 +166,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return (left - right) * (left - right); }
@Override
public String toString() { return "f(a,b)((a-b) * (a-b))"; }
+ @Override
+ public int hashCode() { return "squareddifference".hashCode(); }
}
public static class Subtract implements DoubleBinaryOperator {
@@ -152,6 +175,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return left - right; }
@Override
public String toString() { return "f(a,b)(a - b)"; }
+ @Override
+ public int hashCode() { return "subtract".hashCode(); }
}
@@ -172,6 +197,8 @@ public class ScalarFunctions {
public double applyAsDouble(double left, double right) { return hamming(left, right); }
@Override
public String toString() { return "f(a,b)(hamming(a,b))"; }
+ @Override
+ public int hashCode() { return "hamming".hashCode(); }
}
@@ -182,6 +209,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.abs(operand); }
@Override
public String toString() { return "f(a)(fabs(a))"; }
+ @Override
+ public int hashCode() { return "abs".hashCode(); }
}
public static class Acos implements DoubleUnaryOperator {
@@ -189,6 +218,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
public String toString() { return "f(a)(acos(a))"; }
+ @Override
+ public int hashCode() { return "acos".hashCode(); }
}
public static class Asin implements DoubleUnaryOperator {
@@ -196,6 +227,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.asin(operand); }
@Override
public String toString() { return "f(a)(asin(a))"; }
+ @Override
+ public int hashCode() { return "asin".hashCode(); }
}
public static class Atan implements DoubleUnaryOperator {
@@ -203,6 +236,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.atan(operand); }
@Override
public String toString() { return "f(a)(atan(a))"; }
+ @Override
+ public int hashCode() { return "atan".hashCode(); }
}
public static class Ceil implements DoubleUnaryOperator {
@@ -210,6 +245,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.ceil(operand); }
@Override
public String toString() { return "f(a)(ceil(a))"; }
+ @Override
+ public int hashCode() { return "ceil".hashCode(); }
}
public static class Cos implements DoubleUnaryOperator {
@@ -217,6 +254,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.cos(operand); }
@Override
public String toString() { return "f(a)(cos(a))"; }
+ @Override
+ public int hashCode() { return "cos".hashCode(); }
}
public static class Elu implements DoubleUnaryOperator {
@@ -231,6 +270,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return operand < 0 ? alpha * (Math.exp(operand) - 1) : operand; }
@Override
public String toString() { return "f(a)(if(a < 0, " + alpha + " * (exp(a)-1), a))"; }
+ @Override
+ public int hashCode() { return Objects.hash("elu", alpha); }
}
public static class Exp implements DoubleUnaryOperator {
@@ -238,6 +279,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
public String toString() { return "f(a)(exp(a))"; }
+ @Override
+ public int hashCode() { return "exp".hashCode(); }
}
public static class Floor implements DoubleUnaryOperator {
@@ -245,6 +288,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.floor(operand); }
@Override
public String toString() { return "f(a)(floor(a))"; }
+ @Override
+ public int hashCode() { return "floor".hashCode(); }
}
public static class Log implements DoubleUnaryOperator {
@@ -252,6 +297,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.log(operand); }
@Override
public String toString() { return "f(a)(log(a))"; }
+ @Override
+ public int hashCode() { return "log".hashCode(); }
}
public static class Neg implements DoubleUnaryOperator {
@@ -259,6 +306,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return -operand; }
@Override
public String toString() { return "f(a)(-a)"; }
+ @Override
+ public int hashCode() { return "neg".hashCode(); }
}
public static class Reciprocal implements DoubleUnaryOperator {
@@ -266,6 +315,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / operand; }
@Override
public String toString() { return "f(a)(1 / a)"; }
+ @Override
+ public int hashCode() { return "reciprocal".hashCode(); }
}
public static class Relu implements DoubleUnaryOperator {
@@ -273,6 +324,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.max(operand, 0); }
@Override
public String toString() { return "f(a)(max(0, a))"; }
+ @Override
+ public int hashCode() { return "relu".hashCode(); }
}
public static class Selu implements DoubleUnaryOperator {
@@ -290,6 +343,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return scale * (operand >= 0.0 ? operand : alpha * (Math.exp(operand)-1)); }
@Override
public String toString() { return "f(a)(" + scale + " * if(a >= 0, a, " + alpha + " * (exp(a) - 1)))"; }
+ @Override
+ public int hashCode() { return Objects.hash("selu", scale, alpha); }
}
public static class LeakyRelu implements DoubleUnaryOperator {
@@ -304,6 +359,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.max(alpha * operand, operand); }
@Override
public String toString() { return "f(a)(max(" + alpha + " * a, a))"; }
+ @Override
+ public int hashCode() { return Objects.hash("leakyrelu", alpha); }
}
public static class Sin implements DoubleUnaryOperator {
@@ -311,6 +368,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.sin(operand); }
@Override
public String toString() { return "f(a)(sin(a))"; }
+ @Override
+ public int hashCode() { return "sin".hashCode(); }
}
public static class Rsqrt implements DoubleUnaryOperator {
@@ -318,6 +377,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@Override
public String toString() { return "f(a)(1.0 / sqrt(a))"; }
+ @Override
+ public int hashCode() { return "rsqrt".hashCode(); }
}
public static class Sigmoid implements DoubleUnaryOperator {
@@ -325,6 +386,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return 1.0 / (1.0 + Math.exp(-operand)); }
@Override
public String toString() { return "f(a)(1 / (1 + exp(-a)))"; }
+ @Override
+ public int hashCode() { return "sigmoid".hashCode(); }
}
public static class Sqrt implements DoubleUnaryOperator {
@@ -332,6 +395,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
@Override
public String toString() { return "f(a)(sqrt(a))"; }
+ @Override
+ public int hashCode() { return "sqrt".hashCode(); }
}
public static class Square implements DoubleUnaryOperator {
@@ -339,6 +404,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return operand * operand; }
@Override
public String toString() { return "f(a)(a * a)"; }
+ @Override
+ public int hashCode() { return "square".hashCode(); }
}
public static class Tan implements DoubleUnaryOperator {
@@ -346,6 +413,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.tan(operand); }
@Override
public String toString() { return "f(a)(tan(a))"; }
+ @Override
+ public int hashCode() { return "tan".hashCode(); }
}
public static class Tanh implements DoubleUnaryOperator {
@@ -353,6 +422,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return Math.tanh(operand); }
@Override
public String toString() { return "f(a)(tanh(a))"; }
+ @Override
+ public int hashCode() { return "tanh".hashCode(); }
}
public static class Erf implements DoubleUnaryOperator {
@@ -410,6 +481,8 @@ public class ScalarFunctions {
public double applyAsDouble(double operand) { return erf(operand); }
@Override
public String toString() { return "f(a)(erf(a))"; }
+ @Override
+ public int hashCode() { return "erf".hashCode(); }
static final double nearZeroMultiplier = 2.0 / Math.sqrt(Math.PI);
@@ -464,6 +537,8 @@ public class ScalarFunctions {
}
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("equal", argumentNames); }
}
public static class Random implements Function<List<Long>, Double> {
@@ -473,6 +548,8 @@ public class ScalarFunctions {
}
@Override
public String toString() { return "random"; }
+ @Override
+ public int hashCode() { return "random".hashCode(); }
}
public static class SumElements implements Function<List<Long>, Double> {
@@ -492,6 +569,8 @@ public class ScalarFunctions {
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
+ @Override
+ public int hashCode() { return Objects.hash("sum", argumentNames); }
}
public static class Constant implements Function<List<Long>, Double> {
@@ -506,6 +585,8 @@ public class ScalarFunctions {
}
@Override
public String toString() { return Double.toString(value); }
+ @Override
+ public int hashCode() { return Objects.hash("constant", value); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index e3464255fac..39bddc3a3cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -166,6 +166,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return b.toString();
}
+ @Override
+ public int hashCode() { return Objects.hash("slice", argument, subspaceAddress); }
+
public static class DimensionValue<NAMETYPE extends Name> {
private final Optional<String> dimension;
@@ -255,6 +258,10 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return index.toString(context);
}
+ @Override
+ public int hashCode() { return Objects.hash(dimension, label, index); }
+
+
}
private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
@@ -273,6 +280,9 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public String toString() { return String.valueOf(value); }
+ @Override
+ public int hashCode() { return Objects.hash("constantIntegerFunction", value); }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index 9ea9040831b..df8cd6d39cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -50,4 +51,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("softmax", argument, dimension); }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 1e1d1d3b5b9..503f414d8eb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -68,4 +68,8 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
@Override
public String toString() { return toString(ToStringContext.empty()); }
+ /** Returns a hashcode computed from the data in this */
+ @Override
+ public abstract int hashCode();
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 0223ad4d588..bd4fc7b8336 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -5,6 +5,7 @@ import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.Name;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -51,4 +52,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
dimension + ")";
}
+ @Override
+ public int hashCode() { return Objects.hash("xwplusb", x, w, b, dimension); }
+
}