diff options
Diffstat (limited to 'searchlib/src')
13 files changed, 298 insertions, 63 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 0ed2bdd6331..ea750295423 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -44,6 +44,26 @@ public abstract class DoubleCompatibleValue extends Value { } @Override + public Value and(Value value) { + return new BooleanValue(asBoolean() && value.asBoolean()); + } + + @Override + public Value or(Value value) { + return new BooleanValue(asBoolean() || value.asBoolean()); + } + + @Override + public Value not() { + return new BooleanValue(!asBoolean()); + } + + @Override + public Value power(Value value) { + return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); + } + + @Override public Value compare(TruthOperator operator, Value value) { return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 5374a9d3ce6..ac8aba6a617 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -54,22 +54,42 @@ public class StringValue extends Value { @Override public Value subtract(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support subtraction"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); } @Override public Value multiply(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support multiplication"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); } @Override public Value divide(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support division"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); } @Override public Value modulo(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support modulo"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + } + + @Override + public Value and(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); + } + + @Override + public Value or(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); + } + + @Override + public Value not() { + throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); + } + + @Override + public Value power(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b283603e713..49c3ccb7b01 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -89,6 +89,34 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> value % argument.asDouble())); } + @Override + public Value and(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value or(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value not() { + return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); + } + + @Override + public Value power(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.pow(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble()))); + } private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index f42082321b3..b2ccbe572d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -43,6 +43,14 @@ public abstract class Value { public abstract Value modulo(Value value); + public abstract Value and(Value value); + + public abstract Value or(Value value); + + public abstract Value not(); + + public abstract Value power(Value value); + /** Perform the comparison specified by the operator between this value and the given value */ public abstract Value compare(TruthOperator operator, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index 91d8abec1be..518a15bcc87 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -77,7 +77,7 @@ public final class ArithmeticNode extends CompositeNode { Iterator<ExpressionNode> child = children.iterator(); Deque<ValueItem> stack = new ArrayDeque<>(); - stack.push(new ValueItem(ArithmeticOperator.PLUS, child.next().evaluate(context))); + stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context))); for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) { ArithmeticOperator op = it.next(); if (!stack.isEmpty()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java index 2187a96ba4d..a715490e95a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java @@ -14,22 +14,31 @@ import java.util.List; */ public enum ArithmeticOperator { - PLUS(0, "+") { public Value evaluate(Value x, Value y) { + OR(0, "||") { public Value evaluate(Value x, Value y) { + return x.or(y); + }}, + AND(1, "&&") { public Value evaluate(Value x, Value y) { + return x.and(y); + }}, + PLUS(2, "+") { public Value evaluate(Value x, Value y) { return x.add(y); }}, - MINUS(1, "-") { public Value evaluate(Value x, Value y) { + MINUS(3, "-") { public Value evaluate(Value x, Value y) { return x.subtract(y); }}, - MULTIPLY(2, "*") { public Value evaluate(Value x, Value y) { + MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) { return x.multiply(y); }}, - DIVIDE(3, "/") { public Value evaluate(Value x, Value y) { + DIVIDE(5, "/") { public Value evaluate(Value x, Value y) { return x.divide(y); }}, - MODULO(4, "%") { public Value evaluate(Value x, Value y) { + MODULO(6, "%") { public Value evaluate(Value x, Value y) { return x.modulo(y); + }}, + POWER(7, "^") { public Value evaluate(Value x, Value y) { + return x.power(y); }}; - + /** A list of all the operators in this in order of decreasing precedence */ public static final List<ArithmeticOperator> operatorsByPrecedence = operatorsByPrecedence(); @@ -55,11 +64,14 @@ public enum ArithmeticOperator { private static List<ArithmeticOperator> operatorsByPrecedence() { List<ArithmeticOperator> operators = new ArrayList<>(); + operators.add(POWER); operators.add(MODULO); operators.add(DIVIDE); operators.add(MULTIPLY); operators.add(MINUS); operators.add(PLUS); + operators.add(AND); + operators.add(OR); return Collections.unmodifiableList(operators); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java new file mode 100644 index 00000000000..8c459a032bd --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -0,0 +1,50 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; + +/** + * A node which flips the logical value produced from the nested expression. + * + * @author lesters + */ +public class NotNode extends BooleanNode { + + private final ExpressionNode value; + + public NotNode(ExpressionNode value) { + this.value = value; + } + + public ExpressionNode getValue() { + return value; + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(value); + } + + @Override + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + return "!" + value.toString(context, path, parent); + } + + @Override + public Value evaluate(Context context) { + return value.evaluate(context).not(); + } + + @Override + public NotNode setChildren(List<ExpressionNode> children) { + if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size()); + return new NotNode(children.get(0)); + } + +} + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index f8e44f1087c..f6b1a1a8979 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; -import java.util.*; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.function.Predicate; /** * A node which returns true or false depending on a set membership test @@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode { @Override public Value evaluate(Context context) { Value value = testValue.evaluate(context); + if (value instanceof TensorValue) { + return evaluateTensor(((TensorValue) value).asTensor(), context); + } + return evaluateValue(value, context); + } + + private Value evaluateValue(Value value, Context context) { + return new BooleanValue(testMembership(value::equals, context)); + } + + private Value evaluateTensor(Tensor tensor, Context context) { + return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0)); + } + + private boolean contains(double value, Context context) { + return testMembership((setValue) -> setValue.asDouble() == value, context); + } + + private boolean testMembership(Predicate<Value> test, Context context) { for (ExpressionNode setValue : setValues) { - if (setValue.evaluate(context).equals(value)) - return new BooleanValue(true); + if (test.test(setValue.evaluate(context))) + return true; } - return new BooleanValue(false); + return false; } @Override diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 01fed00202c..7821ab88b86 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -66,6 +66,7 @@ TOKEN : <MUL: "*"> | <DOT: "."> | <MOD: "%"> | + <POWOP: "^"> | <DOLLAR: "$"> | <COMMA: ","> | @@ -86,6 +87,10 @@ TOKEN : <IN: "in"> | <F: "f"> | + <NOT: "!"> | + <AND: "&&"> | + <OR: "||"> | + <ABS: "abs"> | <ACOS: "acos"> | <ASIN: "asin"> | @@ -200,11 +205,14 @@ ExpressionNode arithmeticExpression() : ArithmeticOperator arithmetic() : { } { - ( <ADD> { return ArithmeticOperator.PLUS; } | - <SUB> { return ArithmeticOperator.MINUS; } | - <DIV> { return ArithmeticOperator.DIVIDE; } | - <MUL> { return ArithmeticOperator.MULTIPLY; } | - <MOD> { return ArithmeticOperator.MODULO; } ) + ( <ADD> { return ArithmeticOperator.PLUS; } | + <SUB> { return ArithmeticOperator.MINUS; } | + <DIV> { return ArithmeticOperator.DIVIDE; } | + <MUL> { return ArithmeticOperator.MULTIPLY; } | + <MOD> { return ArithmeticOperator.MODULO; } | + <AND> { return ArithmeticOperator.AND; } | + <OR> { return ArithmeticOperator.OR; } | + <POWOP> { return ArithmeticOperator.POWER; } ) { return null; } } @@ -224,16 +232,23 @@ ExpressionNode value() : { ExpressionNode ret; boolean neg = false; + boolean not = false; } { - ( [ LOOKAHEAD(2) <SUB> { neg = true; } ] - ( ret = constantPrimitive() | - LOOKAHEAD(2) ret = ifExpression() | - LOOKAHEAD(4) ret = function() | - ret = feature() | - ret = queryFeature() | + ( + [ <NOT> { not = true; } ] + [ LOOKAHEAD(2) <SUB> { neg = true; } ] + ( ret = constantPrimitive() | + LOOKAHEAD(2) ret = ifExpression() | + LOOKAHEAD(4) ret = function() | + ret = feature() | + ret = queryFeature() | ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) ) - { return neg ? new NegativeNode(ret) : ret; } + { + ret = not ? new NotNode(ret) : ret; + ret = neg ? new NegativeNode(ret) : ret; + return ret; + } } IfNode ifExpression() : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 5d357777657..82e5d0cfe5b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -29,6 +29,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0.75, "0.5 + 0.25"); tester.assertEvaluates(0.75, "one_half + a_quarter"); tester.assertEvaluates(1.25, "0.5 - 0.25 + one"); + tester.assertEvaluates(9.0, "3 ^ 2"); // String tester.assertEvaluates(1, "if(\"a\"==\"a\",1,0)"); @@ -37,6 +38,9 @@ public class EvaluationTestCase { tester.assertEvaluates(26, "2*3+4*5"); tester.assertEvaluates(1, "2/6+4/6"); tester.assertEvaluates(2 * 3 * 4 + 3 * 4 * 5 - 4 * 200 / 10, "2*3*4+3*4*5-4*200/10"); + tester.assertEvaluates(3, "1 + 10 % 6 / 2"); + tester.assertEvaluates(10.0, "3 ^ 2 + 1"); + tester.assertEvaluates(18.0, "2 * 3 ^ 2"); // Conditionals tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10"); @@ -89,6 +93,38 @@ public class EvaluationTestCase { } @Test + public void testBooleanEvaluation() { + EvaluationTester tester = new EvaluationTester(); + + // and + tester.assertEvaluates(false, "0 && 0"); + tester.assertEvaluates(false, "0 && 1"); + tester.assertEvaluates(false, "1 && 0"); + tester.assertEvaluates(true, "1 && 1"); + tester.assertEvaluates(true, "1 && 2"); + tester.assertEvaluates(true, "1 && 0.1"); + + // or + tester.assertEvaluates(false, "0 || 0"); + tester.assertEvaluates(true, "0 || 0.1"); + tester.assertEvaluates(true, "0 || 1"); + tester.assertEvaluates(true, "1 || 0"); + tester.assertEvaluates(true, "1 || 1"); + + // not + tester.assertEvaluates(true, "!0"); + tester.assertEvaluates(false, "!1"); + tester.assertEvaluates(false, "!2"); + tester.assertEvaluates(true, "!0 && 1"); + + // precedence + tester.assertEvaluates(0, "2 * (0 && 1)"); + tester.assertEvaluates(2, "2 * (1 && 1)"); + tester.assertEvaluates(true, "2 + 0 && 1"); + tester.assertEvaluates(true, "1 && 0 + 2"); + } + + @Test public void testTensorEvaluation() { EvaluationTester tester = new EvaluationTester(); tester.assertEvaluates("{}", "tensor0", "{}"); @@ -107,6 +143,16 @@ public class EvaluationTestCase { "min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }", "max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); + // operators + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); @@ -122,8 +168,9 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }"); - tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }"); + tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }"); + tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }"); tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }"); @@ -201,6 +248,16 @@ public class EvaluationTestCase { "max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }", + "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }"); tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }", "tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }"); tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }", @@ -225,8 +282,13 @@ public class EvaluationTestCase { "tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", "tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", + "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }", "tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }", + "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }"); + // TODO // argmax // argmin diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index d67c9dfd9dc..ee2b1c147e3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -58,10 +58,18 @@ public class EvaluationTester { return assertEvaluates(value, expressionString, defaultContext); } + public RankingExpression assertEvaluates(boolean value, String expressionString) { + return assertEvaluates(value, expressionString, defaultContext); + } + public RankingExpression assertEvaluates(double value, String expressionString, Context context) { return assertEvaluates(new DoubleValue(value), expressionString, context, ""); } + public RankingExpression assertEvaluates(boolean value, String expressionString, Context context) { + return assertEvaluates(new BooleanValue(value), expressionString, context, ""); + } + public RankingExpression assertEvaluates(Value value, String expressionString, Context context, String explanation) { try { RankingExpression expression = new RankingExpression(expressionString); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index 27aaeb776e4..dde9d4bf21e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -18,6 +18,7 @@ import org.junit.Test; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; @@ -51,32 +52,32 @@ public class TensorConformanceTest { count++; } } - if (failList.size() > 0) { - System.out.println("Conformance test fails:"); - System.out.println(failList); - } - - // Disable this for now: - //assertEquals(0, failList.size()); + assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size()); } - private boolean testCase(String test, int count) throws IOException { + private boolean testCase(String test, int count) { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); - } else if (node.has("expression")) { - String expression = node.get("expression").asText(); - MapContext context = getInput(node.get("inputs")); - Tensor expect = getTensor(node.get("result").get("expect").asText()); - Tensor result = evaluate(expression, context); - boolean equals = Tensor.equals(result, expect); - if (!equals) { - System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); - } - return Tensor.equals(result, expect); + return true; + } + if (!node.has("expression")) { + return true; // ignore } + + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return equals; + } catch (Exception e) { System.out.println(count + " : " + e.toString()); } @@ -133,22 +134,5 @@ public class TensorConformanceTest { throw new IllegalArgumentException("Hex contains illegal characters"); } - private static String valueType(Value value) { - if (value instanceof StringValue) { - return "string"; - } - if (value instanceof BooleanValue) { - return "boolean"; - } - if (value instanceof DoubleCompatibleValue) { - return "double"; - } - if (value instanceof TensorValue) { - return ((TensorValue)value).asTensor().type().toString(); - } - return "unknown"; - } - - } diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist index 327f2b161cd..77b2294c668 100644 --- a/searchlib/src/tests/rankingexpression/rankingexpressionlist +++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist @@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4) "1008\x1977" "100819\x77" if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3) +10 % 3 +1 && 0 || 1 +!a && (a || a) +10 ^ 3 |