summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java20
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java28
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java8
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java32
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj39
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java64
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java54
-rw-r--r--searchlib/src/tests/rankingexpression/rankingexpressionlist4
13 files changed, 298 insertions, 63 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index 0ed2bdd6331..ea750295423 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -44,6 +44,26 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
+ public Value and(Value value) {
+ return new BooleanValue(asBoolean() && value.asBoolean());
+ }
+
+ @Override
+ public Value or(Value value) {
+ return new BooleanValue(asBoolean() || value.asBoolean());
+ }
+
+ @Override
+ public Value not() {
+ return new BooleanValue(!asBoolean());
+ }
+
+ @Override
+ public Value power(Value value) {
+ return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
+ }
+
+ @Override
public Value compare(TruthOperator operator, Value value) {
return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 5374a9d3ce6..ac8aba6a617 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -54,22 +54,42 @@ public class StringValue extends Value {
@Override
public Value subtract(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support subtraction");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
}
@Override
public Value multiply(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support multiplication");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
}
@Override
public Value divide(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support division");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
}
@Override
public Value modulo(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support modulo");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ }
+
+ @Override
+ public Value and(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
+ }
+
+ @Override
+ public Value or(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
+ }
+
+ @Override
+ public Value not() {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
+ }
+
+ @Override
+ public Value power(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b283603e713..49c3ccb7b01 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -89,6 +89,34 @@ public class TensorValue extends Value {
return new TensorValue(value.map((value) -> value % argument.asDouble()));
}
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value power(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.pow(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble())));
+ }
private Tensor asTensor(Value value, String operationName) {
if ( ! (value instanceof TensorValue))
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index f42082321b3..b2ccbe572d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -43,6 +43,14 @@ public abstract class Value {
public abstract Value modulo(Value value);
+ public abstract Value and(Value value);
+
+ public abstract Value or(Value value);
+
+ public abstract Value not();
+
+ public abstract Value power(Value value);
+
/** Perform the comparison specified by the operator between this value and the given value */
public abstract Value compare(TruthOperator operator, Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 91d8abec1be..518a15bcc87 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -77,7 +77,7 @@ public final class ArithmeticNode extends CompositeNode {
Iterator<ExpressionNode> child = children.iterator();
Deque<ValueItem> stack = new ArrayDeque<>();
- stack.push(new ValueItem(ArithmeticOperator.PLUS, child.next().evaluate(context)));
+ stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context)));
for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
ArithmeticOperator op = it.next();
if (!stack.isEmpty()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
index 2187a96ba4d..a715490e95a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
@@ -14,22 +14,31 @@ import java.util.List;
*/
public enum ArithmeticOperator {
- PLUS(0, "+") { public Value evaluate(Value x, Value y) {
+ OR(0, "||") { public Value evaluate(Value x, Value y) {
+ return x.or(y);
+ }},
+ AND(1, "&&") { public Value evaluate(Value x, Value y) {
+ return x.and(y);
+ }},
+ PLUS(2, "+") { public Value evaluate(Value x, Value y) {
return x.add(y);
}},
- MINUS(1, "-") { public Value evaluate(Value x, Value y) {
+ MINUS(3, "-") { public Value evaluate(Value x, Value y) {
return x.subtract(y);
}},
- MULTIPLY(2, "*") { public Value evaluate(Value x, Value y) {
+ MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) {
return x.multiply(y);
}},
- DIVIDE(3, "/") { public Value evaluate(Value x, Value y) {
+ DIVIDE(5, "/") { public Value evaluate(Value x, Value y) {
return x.divide(y);
}},
- MODULO(4, "%") { public Value evaluate(Value x, Value y) {
+ MODULO(6, "%") { public Value evaluate(Value x, Value y) {
return x.modulo(y);
+ }},
+ POWER(7, "^") { public Value evaluate(Value x, Value y) {
+ return x.power(y);
}};
-
+
/** A list of all the operators in this in order of decreasing precedence */
public static final List<ArithmeticOperator> operatorsByPrecedence = operatorsByPrecedence();
@@ -55,11 +64,14 @@ public enum ArithmeticOperator {
private static List<ArithmeticOperator> operatorsByPrecedence() {
List<ArithmeticOperator> operators = new ArrayList<>();
+ operators.add(POWER);
operators.add(MODULO);
operators.add(DIVIDE);
operators.add(MULTIPLY);
operators.add(MINUS);
operators.add(PLUS);
+ operators.add(AND);
+ operators.add(OR);
return Collections.unmodifiableList(operators);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
new file mode 100644
index 00000000000..8c459a032bd
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -0,0 +1,50 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which flips the logical value produced from the nested expression.
+ *
+ * @author lesters
+ */
+public class NotNode extends BooleanNode {
+
+ private final ExpressionNode value;
+
+ public NotNode(ExpressionNode value) {
+ this.value = value;
+ }
+
+ public ExpressionNode getValue() {
+ return value;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(value);
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "!" + value.toString(context, path, parent);
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return value.evaluate(context).not();
+ }
+
+ @Override
+ public NotNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size());
+ return new NotNode(children.get(0));
+ }
+
+}
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index f8e44f1087c..f6b1a1a8979 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.Predicate;
/**
* A node which returns true or false depending on a set membership test
@@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
Value value = testValue.evaluate(context);
+ if (value instanceof TensorValue) {
+ return evaluateTensor(((TensorValue) value).asTensor(), context);
+ }
+ return evaluateValue(value, context);
+ }
+
+ private Value evaluateValue(Value value, Context context) {
+ return new BooleanValue(testMembership(value::equals, context));
+ }
+
+ private Value evaluateTensor(Tensor tensor, Context context) {
+ return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0));
+ }
+
+ private boolean contains(double value, Context context) {
+ return testMembership((setValue) -> setValue.asDouble() == value, context);
+ }
+
+ private boolean testMembership(Predicate<Value> test, Context context) {
for (ExpressionNode setValue : setValues) {
- if (setValue.evaluate(context).equals(value))
- return new BooleanValue(true);
+ if (test.test(setValue.evaluate(context)))
+ return true;
}
- return new BooleanValue(false);
+ return false;
}
@Override
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 01fed00202c..7821ab88b86 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -66,6 +66,7 @@ TOKEN :
<MUL: "*"> |
<DOT: "."> |
<MOD: "%"> |
+ <POWOP: "^"> |
<DOLLAR: "$"> |
<COMMA: ","> |
@@ -86,6 +87,10 @@ TOKEN :
<IN: "in"> |
<F: "f"> |
+ <NOT: "!"> |
+ <AND: "&&"> |
+ <OR: "||"> |
+
<ABS: "abs"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
@@ -200,11 +205,14 @@ ExpressionNode arithmeticExpression() :
ArithmeticOperator arithmetic() : { }
{
- ( <ADD> { return ArithmeticOperator.PLUS; } |
- <SUB> { return ArithmeticOperator.MINUS; } |
- <DIV> { return ArithmeticOperator.DIVIDE; } |
- <MUL> { return ArithmeticOperator.MULTIPLY; } |
- <MOD> { return ArithmeticOperator.MODULO; } )
+ ( <ADD> { return ArithmeticOperator.PLUS; } |
+ <SUB> { return ArithmeticOperator.MINUS; } |
+ <DIV> { return ArithmeticOperator.DIVIDE; } |
+ <MUL> { return ArithmeticOperator.MULTIPLY; } |
+ <MOD> { return ArithmeticOperator.MODULO; } |
+ <AND> { return ArithmeticOperator.AND; } |
+ <OR> { return ArithmeticOperator.OR; } |
+ <POWOP> { return ArithmeticOperator.POWER; } )
{ return null; }
}
@@ -224,16 +232,23 @@ ExpressionNode value() :
{
ExpressionNode ret;
boolean neg = false;
+ boolean not = false;
}
{
- ( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
- ( ret = constantPrimitive() |
- LOOKAHEAD(2) ret = ifExpression() |
- LOOKAHEAD(4) ret = function() |
- ret = feature() |
- ret = queryFeature() |
+ (
+ [ <NOT> { not = true; } ]
+ [ LOOKAHEAD(2) <SUB> { neg = true; } ]
+ ( ret = constantPrimitive() |
+ LOOKAHEAD(2) ret = ifExpression() |
+ LOOKAHEAD(4) ret = function() |
+ ret = feature() |
+ ret = queryFeature() |
( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) )
- { return neg ? new NegativeNode(ret) : ret; }
+ {
+ ret = not ? new NotNode(ret) : ret;
+ ret = neg ? new NegativeNode(ret) : ret;
+ return ret;
+ }
}
IfNode ifExpression() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 5d357777657..82e5d0cfe5b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -29,6 +29,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0.75, "0.5 + 0.25");
tester.assertEvaluates(0.75, "one_half + a_quarter");
tester.assertEvaluates(1.25, "0.5 - 0.25 + one");
+ tester.assertEvaluates(9.0, "3 ^ 2");
// String
tester.assertEvaluates(1, "if(\"a\"==\"a\",1,0)");
@@ -37,6 +38,9 @@ public class EvaluationTestCase {
tester.assertEvaluates(26, "2*3+4*5");
tester.assertEvaluates(1, "2/6+4/6");
tester.assertEvaluates(2 * 3 * 4 + 3 * 4 * 5 - 4 * 200 / 10, "2*3*4+3*4*5-4*200/10");
+ tester.assertEvaluates(3, "1 + 10 % 6 / 2");
+ tester.assertEvaluates(10.0, "3 ^ 2 + 1");
+ tester.assertEvaluates(18.0, "2 * 3 ^ 2");
// Conditionals
tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10");
@@ -89,6 +93,38 @@ public class EvaluationTestCase {
}
@Test
+ public void testBooleanEvaluation() {
+ EvaluationTester tester = new EvaluationTester();
+
+ // and
+ tester.assertEvaluates(false, "0 && 0");
+ tester.assertEvaluates(false, "0 && 1");
+ tester.assertEvaluates(false, "1 && 0");
+ tester.assertEvaluates(true, "1 && 1");
+ tester.assertEvaluates(true, "1 && 2");
+ tester.assertEvaluates(true, "1 && 0.1");
+
+ // or
+ tester.assertEvaluates(false, "0 || 0");
+ tester.assertEvaluates(true, "0 || 0.1");
+ tester.assertEvaluates(true, "0 || 1");
+ tester.assertEvaluates(true, "1 || 0");
+ tester.assertEvaluates(true, "1 || 1");
+
+ // not
+ tester.assertEvaluates(true, "!0");
+ tester.assertEvaluates(false, "!1");
+ tester.assertEvaluates(false, "!2");
+ tester.assertEvaluates(true, "!0 && 1");
+
+ // precedence
+ tester.assertEvaluates(0, "2 * (0 && 1)");
+ tester.assertEvaluates(2, "2 * (1 && 1)");
+ tester.assertEvaluates(true, "2 + 0 && 1");
+ tester.assertEvaluates(true, "1 && 0 + 2");
+ }
+
+ @Test
public void testTensorEvaluation() {
EvaluationTester tester = new EvaluationTester();
tester.assertEvaluates("{}", "tensor0", "{}");
@@ -107,6 +143,16 @@ public class EvaluationTestCase {
"min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }",
"max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
+ // operators
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
@@ -122,8 +168,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }");
- tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }");
+ tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }");
+ tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }");
tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }");
@@ -201,6 +248,16 @@ public class EvaluationTestCase {
"max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }",
"min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }",
+ "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }");
tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }",
"tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }");
tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }",
@@ -225,8 +282,13 @@ public class EvaluationTestCase {
"tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
"tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
+ "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }",
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
+ "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+
// TODO
// argmax
// argmin
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index d67c9dfd9dc..ee2b1c147e3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -58,10 +58,18 @@ public class EvaluationTester {
return assertEvaluates(value, expressionString, defaultContext);
}
+ public RankingExpression assertEvaluates(boolean value, String expressionString) {
+ return assertEvaluates(value, expressionString, defaultContext);
+ }
+
public RankingExpression assertEvaluates(double value, String expressionString, Context context) {
return assertEvaluates(new DoubleValue(value), expressionString, context, "");
}
+ public RankingExpression assertEvaluates(boolean value, String expressionString, Context context) {
+ return assertEvaluates(new BooleanValue(value), expressionString, context, "");
+ }
+
public RankingExpression assertEvaluates(Value value, String expressionString, Context context, String explanation) {
try {
RankingExpression expression = new RankingExpression(expressionString);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index 27aaeb776e4..dde9d4bf21e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -18,6 +18,7 @@ import org.junit.Test;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
@@ -51,32 +52,32 @@ public class TensorConformanceTest {
count++;
}
}
- if (failList.size() > 0) {
- System.out.println("Conformance test fails:");
- System.out.println(failList);
- }
-
- // Disable this for now:
- //assertEquals(0, failList.size());
+ assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size());
}
- private boolean testCase(String test, int count) throws IOException {
+ private boolean testCase(String test, int count) {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
- } else if (node.has("expression")) {
- String expression = node.get("expression").asText();
- MapContext context = getInput(node.get("inputs"));
- Tensor expect = getTensor(node.get("result").get("expect").asText());
- Tensor result = evaluate(expression, context);
- boolean equals = Tensor.equals(result, expect);
- if (!equals) {
- System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\"");
- }
- return Tensor.equals(result, expect);
+ return true;
+ }
+ if (!node.has("expression")) {
+ return true; // ignore
}
+
+ String expression = node.get("expression").asText();
+ MapContext context = getInput(node.get("inputs"));
+ Tensor expect = getTensor(node.get("result").get("expect").asText());
+ Tensor result = evaluate(expression, context);
+ boolean equals = Tensor.equals(result, expect);
+ if (!equals) {
+ System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\"");
+ }
+ return equals;
+
} catch (Exception e) {
System.out.println(count + " : " + e.toString());
}
@@ -133,22 +134,5 @@ public class TensorConformanceTest {
throw new IllegalArgumentException("Hex contains illegal characters");
}
- private static String valueType(Value value) {
- if (value instanceof StringValue) {
- return "string";
- }
- if (value instanceof BooleanValue) {
- return "boolean";
- }
- if (value instanceof DoubleCompatibleValue) {
- return "double";
- }
- if (value instanceof TensorValue) {
- return ((TensorValue)value).asTensor().type().toString();
- }
- return "unknown";
- }
-
-
}
diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist
index 327f2b161cd..77b2294c668 100644
--- a/searchlib/src/tests/rankingexpression/rankingexpressionlist
+++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist
@@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4)
"1008\x1977"
"100819\x77"
if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3)
+10 % 3
+1 && 0 || 1
+!a && (a || a)
+10 ^ 3