From 238064fb7136aedecbff4f37c0a48f3b0152d32a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 9 Nov 2017 13:34:19 +0100 Subject: Add tensor conformance test in Java --- searchlib/pom.xml | 10 ++ .../searchlib/tensor/TensorConformanceTest.java | 154 +++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java (limited to 'searchlib') diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 36e6fa1ffda..c669903c3da 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -34,6 +34,16 @@ vespajlib ${project.version} + + com.fasterxml.jackson.core + jackson-core + test + + + com.fasterxml.jackson.core + jackson-databind + test + diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java new file mode 100644 index 00000000000..27aaeb776e4 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -0,0 +1,154 @@ +package com.yahoo.searchlib.tensor; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import org.junit.Assert; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +public class TensorConformanceTest { + + private static String testPath = "eval/src/apps/tensor_conformance/test_spec.json"; + + @Test + public void testConformance() throws IOException { + File testSpec = new File(testPath); + if (!testSpec.exists()) { + testSpec = new File("../" + testPath); + } + int count = 0; + List failList = new ArrayList<>(); + + try(BufferedReader br = new BufferedReader(new FileReader(testSpec))) { + String test = br.readLine(); + while (test != null) { + boolean success = testCase(test, count); + if (!success) { + failList.add(count); + } + test = br.readLine(); + count++; + } + } + if (failList.size() > 0) { + System.out.println("Conformance test fails:"); + System.out.println(failList); + } + + // Disable this for now: + //assertEquals(0, failList.size()); + } + + private boolean testCase(String test, int count) throws IOException { + try { + ObjectMapper mapper = new ObjectMapper(); + JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { + Assert.assertEquals(node.get("num_tests").asInt(), count); + } else if (node.has("expression")) { + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return Tensor.equals(result, expect); + } + } catch (Exception e) { + System.out.println(count + " : " + e.toString()); + } + return false; + } + + private Tensor evaluate(String expression, MapContext context) throws ParseException { + Value value = new RankingExpression(expression).evaluate(context); + if (!(value instanceof TensorValue)) { + throw new IllegalArgumentException("Result is not a tensor"); + } + return ((TensorValue)value).asTensor(); + } + + private MapContext getInput(JsonNode inputs) { + MapContext context = new MapContext(); + for (Iterator i = inputs.fieldNames(); i.hasNext(); ) { + String name = i.next(); + String value = inputs.get(name).asText(); + Tensor tensor = getTensor(value); + context.put(name, new TensorValue(tensor)); + } + return context; + } + + private Tensor getTensor(String binaryRepresentation) { + byte[] bin = getBytes(binaryRepresentation); + return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin)); + } + + private byte[] getBytes(String binaryRepresentation) { + return parseHexValue(binaryRepresentation.substring(2)); + } + + private byte[] parseHexValue(String s) { + final int len = s.length(); + byte[] bytes = new byte[len/2]; + for (int i = 0; i < len; i += 2) { + int c1 = hexValue(s.charAt(i)) << 4; + int c2 = hexValue(s.charAt(i + 1)); + bytes[i/2] = (byte)(c1 + c2); + } + return bytes; + } + + private int hexValue(Character c) { + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } else if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } else if (c >= '0' && c <= '9') { + return c - '0'; + } + throw new IllegalArgumentException("Hex contains illegal characters"); + } + + private static String valueType(Value value) { + if (value instanceof StringValue) { + return "string"; + } + if (value instanceof BooleanValue) { + return "boolean"; + } + if (value instanceof DoubleCompatibleValue) { + return "double"; + } + if (value instanceof TensorValue) { + return ((TensorValue)value).asTensor().type().toString(); + } + return "unknown"; + } + + +} + -- cgit v1.2.3