aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-06-13 07:12:23 +0200
committerGitHub <noreply@github.com>2019-06-13 07:12:23 +0200
commit0e20abc33aa88066aedadd43b09353d115d5928b (patch)
tree159df75408a4eb188b13a532fb42b8896083cfc9 /searchlib/src/main
parent17b6704b20a073a4961baefd1be58dd48012bec4 (diff)
Revert "Revert "Require constant() for large constants and fix a type resolving bug""
Diffstat (limited to 'searchlib/src/main')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java6
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java93
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java12
6 files changed, 121 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
index c4f3a75f2f8..2aedec2109b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java
@@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.transform.TensorMaxMinTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Join;
@@ -67,6 +68,11 @@ public final class FunctionNode extends CompositeNode {
@Override
public TensorType type(TypeContext<Reference> context) {
+ // Check if this node should be interpreted as tensor reduce, as this impacts the type
+ ExpressionNode thisTransformed = TensorMaxMinTransformer.transformFunctionNode(this, context);
+ if (thisTransformed != this)
+ return thisTransformed.type(context);
+
if (arguments.expressions().size() == 0)
return TensorType.empty;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 28dc623be72..92c6d6f8638 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -85,7 +85,9 @@ public final class IfNode extends CompositeNode {
return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() ->
new IllegalArgumentException("An if expression must produce compatible types in both " +
"alternatives, but the 'true' type is " + trueType + " while the " +
- "'false' type is " + falseType)
+ "'false' type is " + falseType +
+ "\n'true' branch: " + trueExpression +
+ "\n'false' branch: " + falseExpression)
);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index eb8d2229a6d..e15ce158e83 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -95,7 +95,13 @@ public final class ReferenceNode extends CompositeNode {
@Override
public TensorType type(TypeContext<Reference> context) {
- TensorType type = context.getType(reference);
+ TensorType type = null;
+ try {
+ type = context.getType(reference);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(reference + " is invalid", e);
+ }
if (type == null)
throw new IllegalArgumentException("Unknown feature '" + toString() + "'");
return type;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
index 22d314bcb28..31567ba120b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java
@@ -10,7 +10,7 @@ import java.util.List;
/**
* Superclass of expression transformers. The scope (lifetime) of a transformer instance is a single compilation
- * of alle the expressions in one rank profile.
+ * of all the expressions in one rank profile.
*
* @author bratseth
*/
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
new file mode 100644
index 00000000000..979c5b0f88c
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -0,0 +1,93 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.transform;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.NameNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Reduce;
+
+import java.util.Optional;
+
+/**
+ * Transforms min(tensor,dim) and max(tensor,dim) to
+ * reduce(tensor,min/max,dim). This is necessary as the backend does
+ * not recognize these forms of min and max.
+ *
+ * @author lesters
+ */
+public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends ExpressionTransformer<CONTEXT> {
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, CONTEXT context) {
+ if (node instanceof CompositeNode) {
+ node = transformChildren((CompositeNode) node, context);
+ }
+ if (node instanceof FunctionNode) {
+ node = transformFunctionNode((FunctionNode) node, context.types());
+ }
+ return node;
+ }
+
+ public static ExpressionNode transformFunctionNode(FunctionNode node, TypeContext<Reference> context) {
+ switch (node.getFunction()) {
+ case min:
+ case max:
+ return transformMaxAndMinFunctionNode(node, context);
+ }
+ return node;
+ }
+
+ /**
+ * Transforms max and min functions if the first
+ * argument returns a tensor type and the second argument is a valid
+ * dimension in the tensor.
+ */
+ private static ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, TypeContext<Reference> context) {
+ if (node.children().size() != 2) {
+ return node;
+ }
+ ExpressionNode arg1 = node.children().get(0);
+ Optional<String> dimension = dimensionName(node.children().get(1));
+ if (dimension.isPresent()) {
+ TensorType type = arg1.type(context);
+ if (type.dimension(dimension.get()).isPresent()) {
+ return replaceMaxAndMinFunction(node);
+ }
+ }
+ return node;
+ }
+
+ private static Optional<String> dimensionName(ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (reference.isIdentifier())
+ return Optional.of(reference.name());
+ else
+ return Optional.empty();
+ }
+ else if (node instanceof NameNode) {
+ return Optional.of(((NameNode)node).getValue());
+ }
+ else {
+ return Optional.empty();
+ }
+ }
+
+ private static ExpressionNode replaceMaxAndMinFunction(FunctionNode node) {
+ ExpressionNode arg1 = node.children().get(0);
+ ExpressionNode arg2 = node.children().get(1);
+
+ TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
+ Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
+ String dimension = ((ReferenceNode) arg2).getName();
+
+ return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
index 7485ce69f98..0113a650277 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java
@@ -1,7 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Map;
@@ -13,11 +15,19 @@ import java.util.Map;
public class TransformContext {
private final Map<String, Value> constants;
+ private final TypeContext<Reference> types;
- public TransformContext(Map<String, Value> constants) {
+ public TransformContext(Map<String, Value> constants, TypeContext<Reference> types) {
this.constants = constants;
+ this.types = types;
}
public Map<String, Value> constants() { return constants; }
+ /**
+ * Returns the types known in this context. We may have type information for references
+ * for which no value is available
+ */
+ public TypeContext<Reference> types() { return types; }
+
}