summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-01 12:53:32 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-01 12:53:32 +0100
commit237dfb95f61f572bcc45cf98fcb2c1b3af473cac (patch)
tree3efc15c5c054aba38e718e27ec4c70fe04ea5263 /searchlib
parent97a57faf30866ff14d2bb35b5b58ba4e88e64c9f (diff)
Allow type generalizations in if
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java12
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java1
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java54
4 files changed, 75 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index 1ec6ea4693b..c8d90e8c4e8 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -8,6 +8,8 @@ import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParser;
import com.yahoo.searchlib.rankingexpression.parser.TokenMgrError;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.io.File;
import java.io.FileNotFoundException;
@@ -265,6 +267,16 @@ public class RankingExpression implements Serializable {
}
/**
+ * Validates the type correctness of the given expression with the given context and
+ * returns the type this expression will produce from the given type context
+ *
+ * @throws IllegalArgumentException if this expression is not type correct in this context
+ */
+ public TensorType type(TypeContext context) {
+ return root.type(context);
+ }
+
+ /**
* Returns the value of evaluating this expression over the given context.
*
* @param context The variable bindings to use for this evaluation.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
index 076df327044..4f0ebc1c7e5 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java
@@ -49,7 +49,7 @@ public final class IfNode extends CompositeNode {
@Override
public List<ExpressionNode> children() {
- List<ExpressionNode> children = new ArrayList<ExpressionNode>(4);
+ List<ExpressionNode> children = new ArrayList<>(4);
children.add(condition);
children.add(trueExpression);
children.add(falseExpression);
@@ -78,11 +78,13 @@ public final class IfNode extends CompositeNode {
public TensorType type(TypeContext context) {
TensorType trueType = trueExpression.type(context);
TensorType falseType = falseExpression.type(context);
- if ( ! trueType.equals(falseType))
- throw new IllegalArgumentException("An if expression must produce a value of the same type in both " +
- "alternatives, but the 'true' type is " + trueType + " while the " +
- "'false' type is " + falseType);
- return trueType;
+
+ // Types of each branch must be compatible; the resulting type is the most general
+ if (trueType.isAssignableTo(falseType)) return falseType;
+ if (falseType.isAssignableTo(trueType)) return trueType;
+ throw new IllegalArgumentException("An if expression must produce compatible types in both " +
+ "alternatives, but the 'true' type is " + trueType + " while the " +
+ "'false' type is " + falseType);
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 3aa2d144f1f..6c7643b37b3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
new file mode 100644
index 00000000000..d1ea0fcf2e4
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.searchlib.rankingexpression.evaluation;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class TypeResolutionTestCase {
+
+ @Test
+ public void testTypeResolution() {
+ TypeMapContext context = new TypeMapContext();
+ context.setType("query('x1')", TensorType.fromSpec("tensor(x[])"));
+ context.setType("query('x2')", TensorType.fromSpec("tensor(x[10])"));
+ context.setType("query('y1')", TensorType.fromSpec("tensor(y[])"));
+
+ assertType("tensor(x[])", "query(x1)", context);
+ assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context);
+ assertIncompatibleType("if (1>0, query(x1), query(y1))", context);
+ }
+
+ private void assertType(String type, String expression, TypeContext context) {
+ try {
+ assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ private void assertIncompatibleType(String expression, TypeContext context) {
+ try {
+ new RankingExpression(expression).type(context);
+ fail("Expected type incompatibility exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[]) while the 'false' type is tensor(y[])",
+ expected.getMessage());
+ }
+ catch (ParseException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+}