diff options
4 files changed, 75 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index 1ec6ea4693b..c8d90e8c4e8 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -8,6 +8,8 @@ import com.yahoo.searchlib.rankingexpression.parser.RankingExpressionParser; import com.yahoo.searchlib.rankingexpression.parser.TokenMgrError; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.File; import java.io.FileNotFoundException; @@ -265,6 +267,16 @@ public class RankingExpression implements Serializable { } /** + * Validates the type correctness of the given expression with the given context and + * returns the type this expression will produce from the given type context + * + * @throws IllegalArgumentException if this expression is not type correct in this context + */ + public TensorType type(TypeContext context) { + return root.type(context); + } + + /** * Returns the value of evaluating this expression over the given context. * * @param context The variable bindings to use for this evaluation. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 076df327044..4f0ebc1c7e5 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -49,7 +49,7 @@ public final class IfNode extends CompositeNode { @Override public List<ExpressionNode> children() { - List<ExpressionNode> children = new ArrayList<ExpressionNode>(4); + List<ExpressionNode> children = new ArrayList<>(4); children.add(condition); children.add(trueExpression); children.add(falseExpression); @@ -78,11 +78,13 @@ public final class IfNode extends CompositeNode { public TensorType type(TypeContext context) { TensorType trueType = trueExpression.type(context); TensorType falseType = falseExpression.type(context); - if ( ! trueType.equals(falseType)) - throw new IllegalArgumentException("An if expression must produce a value of the same type in both " + - "alternatives, but the 'true' type is " + trueType + " while the " + - "'false' type is " + falseType); - return trueType; + + // Types of each branch must be compatible; the resulting type is the most general + if (trueType.isAssignableTo(falseType)) return falseType; + if (falseType.isAssignableTo(trueType)) return trueType; + throw new IllegalArgumentException("An if expression must produce compatible types in both " + + "alternatives, but the 'true' type is " + trueType + " while the " + + "'false' type is " + falseType); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 3aa2d144f1f..6c7643b37b3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java new file mode 100644 index 00000000000..d1ea0fcf2e4 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.searchlib.rankingexpression.evaluation; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class TypeResolutionTestCase { + + @Test + public void testTypeResolution() { + TypeMapContext context = new TypeMapContext(); + context.setType("query('x1')", TensorType.fromSpec("tensor(x[])")); + context.setType("query('x2')", TensorType.fromSpec("tensor(x[10])")); + context.setType("query('y1')", TensorType.fromSpec("tensor(y[])")); + + assertType("tensor(x[])", "query(x1)", context); + assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context); + assertIncompatibleType("if (1>0, query(x1), query(y1))", context); + } + + private void assertType(String type, String expression, TypeContext context) { + try { + assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context)); + } + catch (ParseException e) { + throw new RuntimeException(e); + } + } + + private void assertIncompatibleType(String expression, TypeContext context) { + try { + new RankingExpression(expression).type(context); + fail("Expected type incompatibility exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[]) while the 'false' type is tensor(y[])", + expected.getMessage()); + } + catch (ParseException e) { + throw new RuntimeException(e); + } + } + +} |