summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@yahoo-inc.com>2017-11-16 12:26:32 +0100
committerLester Solbakken <lesters@yahoo-inc.com>2017-11-16 12:26:32 +0100
commitd248daea0a53004b7f15fb36393504d182171f01 (patch)
tree22c977b63fa9df59f0de075ea52ec6feb46ce3d7 /searchlib
parent7a6772d914b0de0ecf683f1233c349e34067ec37 (diff)
Enable Java tensor conformance test
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java28
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java54
-rw-r--r--searchlib/src/tests/rankingexpression/rankingexpressionlist4
3 files changed, 50 insertions, 36 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 8e34f35245d..82e5d0cfe5b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -143,6 +143,16 @@ public class EvaluationTestCase {
"min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }",
"max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
+ // operators
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
@@ -158,8 +168,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }");
- tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }");
+ tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }");
+ tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }");
tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }");
@@ -237,6 +248,16 @@ public class EvaluationTestCase {
"max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }",
"min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }",
+ "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }");
tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }",
"tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }");
tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }",
@@ -261,8 +282,13 @@ public class EvaluationTestCase {
"tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
"tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
+ "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }",
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
+ "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+
// TODO
// argmax
// argmin
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
index 27aaeb776e4..dde9d4bf21e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -18,6 +18,7 @@ import org.junit.Test;
import java.io.BufferedReader;
import java.io.File;
+import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
@@ -51,32 +52,32 @@ public class TensorConformanceTest {
count++;
}
}
- if (failList.size() > 0) {
- System.out.println("Conformance test fails:");
- System.out.println(failList);
- }
-
- // Disable this for now:
- //assertEquals(0, failList.size());
+ assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size());
}
- private boolean testCase(String test, int count) throws IOException {
+ private boolean testCase(String test, int count) {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode node = mapper.readTree(test);
+
if (node.has("num_tests")) {
Assert.assertEquals(node.get("num_tests").asInt(), count);
- } else if (node.has("expression")) {
- String expression = node.get("expression").asText();
- MapContext context = getInput(node.get("inputs"));
- Tensor expect = getTensor(node.get("result").get("expect").asText());
- Tensor result = evaluate(expression, context);
- boolean equals = Tensor.equals(result, expect);
- if (!equals) {
- System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\"");
- }
- return Tensor.equals(result, expect);
+ return true;
+ }
+ if (!node.has("expression")) {
+ return true; // ignore
}
+
+ String expression = node.get("expression").asText();
+ MapContext context = getInput(node.get("inputs"));
+ Tensor expect = getTensor(node.get("result").get("expect").asText());
+ Tensor result = evaluate(expression, context);
+ boolean equals = Tensor.equals(result, expect);
+ if (!equals) {
+ System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\"");
+ }
+ return equals;
+
} catch (Exception e) {
System.out.println(count + " : " + e.toString());
}
@@ -133,22 +134,5 @@ public class TensorConformanceTest {
throw new IllegalArgumentException("Hex contains illegal characters");
}
- private static String valueType(Value value) {
- if (value instanceof StringValue) {
- return "string";
- }
- if (value instanceof BooleanValue) {
- return "boolean";
- }
- if (value instanceof DoubleCompatibleValue) {
- return "double";
- }
- if (value instanceof TensorValue) {
- return ((TensorValue)value).asTensor().type().toString();
- }
- return "unknown";
- }
-
-
}
diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist
index 327f2b161cd..77b2294c668 100644
--- a/searchlib/src/tests/rankingexpression/rankingexpressionlist
+++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist
@@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4)
"1008\x1977"
"100819\x77"
if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3)
+10 % 3
+1 && 0 || 1
+!a && (a || a)
+10 ^ 3