diff options
Diffstat (limited to 'searchlib')
3 files changed, 50 insertions, 36 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 8e34f35245d..82e5d0cfe5b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -143,6 +143,16 @@ public class EvaluationTestCase { "min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }", "max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); + // operators + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); @@ -158,8 +168,9 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }"); - tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }"); + tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }"); + tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }"); tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }"); @@ -237,6 +248,16 @@ public class EvaluationTestCase { "max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }", + "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }"); tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }", "tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }"); tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }", @@ -261,8 +282,13 @@ public class EvaluationTestCase { "tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", "tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", + "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }", "tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }", + "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }"); + // TODO // argmax // argmin diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index 27aaeb776e4..dde9d4bf21e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -18,6 +18,7 @@ import org.junit.Test; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; @@ -51,32 +52,32 @@ public class TensorConformanceTest { count++; } } - if (failList.size() > 0) { - System.out.println("Conformance test fails:"); - System.out.println(failList); - } - - // Disable this for now: - //assertEquals(0, failList.size()); + assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size()); } - private boolean testCase(String test, int count) throws IOException { + private boolean testCase(String test, int count) { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); - } else if (node.has("expression")) { - String expression = node.get("expression").asText(); - MapContext context = getInput(node.get("inputs")); - Tensor expect = getTensor(node.get("result").get("expect").asText()); - Tensor result = evaluate(expression, context); - boolean equals = Tensor.equals(result, expect); - if (!equals) { - System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); - } - return Tensor.equals(result, expect); + return true; + } + if (!node.has("expression")) { + return true; // ignore } + + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return equals; + } catch (Exception e) { System.out.println(count + " : " + e.toString()); } @@ -133,22 +134,5 @@ public class TensorConformanceTest { throw new IllegalArgumentException("Hex contains illegal characters"); } - private static String valueType(Value value) { - if (value instanceof StringValue) { - return "string"; - } - if (value instanceof BooleanValue) { - return "boolean"; - } - if (value instanceof DoubleCompatibleValue) { - return "double"; - } - if (value instanceof TensorValue) { - return ((TensorValue)value).asTensor().type().toString(); - } - return "unknown"; - } - - } diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist index 327f2b161cd..77b2294c668 100644 --- a/searchlib/src/tests/rankingexpression/rankingexpressionlist +++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist @@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4) "1008\x1977" "100819\x77" if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3) +10 % 3 +1 && 0 || 1 +!a && (a || a) +10 ^ 3 |