diff options
4 files changed, 177 insertions, 3 deletions
diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 36e6fa1ffda..c669903c3da 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -34,6 +34,16 @@ <artifactId>vespajlib</artifactId> <version>${project.version}</version> </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + <scope>test</scope> + </dependency> </dependencies> <build> <plugins> diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java new file mode 100644 index 00000000000..27aaeb776e4 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -0,0 +1,154 @@ +package com.yahoo.searchlib.tensor; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import org.junit.Assert; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +public class TensorConformanceTest { + + private static String testPath = "eval/src/apps/tensor_conformance/test_spec.json"; + + @Test + public void testConformance() throws IOException { + File testSpec = new File(testPath); + if (!testSpec.exists()) { + testSpec = new File("../" + testPath); + } + int count = 0; + List<Integer> failList = new ArrayList<>(); + + try(BufferedReader br = new BufferedReader(new FileReader(testSpec))) { + String test = br.readLine(); + while (test != null) { + boolean success = testCase(test, count); + if (!success) { + failList.add(count); + } + test = br.readLine(); + count++; + } + } + if (failList.size() > 0) { + System.out.println("Conformance test fails:"); + System.out.println(failList); + } + + // Disable this for now: + //assertEquals(0, failList.size()); + } + + private boolean testCase(String test, int count) throws IOException { + try { + ObjectMapper mapper = new ObjectMapper(); + JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { + Assert.assertEquals(node.get("num_tests").asInt(), count); + } else if (node.has("expression")) { + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return Tensor.equals(result, expect); + } + } catch (Exception e) { + System.out.println(count + " : " + e.toString()); + } + return false; + } + + private Tensor evaluate(String expression, MapContext context) throws ParseException { + Value value = new RankingExpression(expression).evaluate(context); + if (!(value instanceof TensorValue)) { + throw new IllegalArgumentException("Result is not a tensor"); + } + return ((TensorValue)value).asTensor(); + } + + private MapContext getInput(JsonNode inputs) { + MapContext context = new MapContext(); + for (Iterator<String> i = inputs.fieldNames(); i.hasNext(); ) { + String name = i.next(); + String value = inputs.get(name).asText(); + Tensor tensor = getTensor(value); + context.put(name, new TensorValue(tensor)); + } + return context; + } + + private Tensor getTensor(String binaryRepresentation) { + byte[] bin = getBytes(binaryRepresentation); + return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin)); + } + + private byte[] getBytes(String binaryRepresentation) { + return parseHexValue(binaryRepresentation.substring(2)); + } + + private byte[] parseHexValue(String s) { + final int len = s.length(); + byte[] bytes = new byte[len/2]; + for (int i = 0; i < len; i += 2) { + int c1 = hexValue(s.charAt(i)) << 4; + int c2 = hexValue(s.charAt(i + 1)); + bytes[i/2] = (byte)(c1 + c2); + } + return bytes; + } + + private int hexValue(Character c) { + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } else if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } else if (c >= '0' && c <= '9') { + return c - '0'; + } + throw new IllegalArgumentException("Hex contains illegal characters"); + } + + private static String valueType(Value value) { + if (value instanceof StringValue) { + return "string"; + } + if (value instanceof BooleanValue) { + return "boolean"; + } + if (value instanceof DoubleCompatibleValue) { + return "double"; + } + if (value instanceof TensorValue) { + return ((TensorValue)value).asTensor().type().toString(); + } + return "unknown"; + } + + +} + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6750c99bf98..c207dabca3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -125,8 +125,12 @@ public class IndexedTensor implements Tensor { if (indexes.length == 0) return 0; // for speed int valueIndex = 0; - for (int i = 0; i < indexes.length; i++) + for (int i = 0; i < indexes.length; i++) { + if (indexes[i] >= sizes.size(i)) { + throw new IndexOutOfBoundsException(); + } valueIndex += productOfDimensionsAfter(i, sizes) * indexes[i]; + } return valueIndex; } @@ -134,8 +138,12 @@ public class IndexedTensor implements Tensor { if (address.isEmpty()) return 0; int valueIndex = 0; - for (int i = 0; i < address.size(); i++) + for (int i = 0; i < address.size(); i++) { + if (address.intLabel(i) >= sizes.size(i)) { + throw new IndexOutOfBoundsException(); + } valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i); + } return valueIndex; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 75bdb6fb15d..8fc80e3b440 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -259,7 +259,9 @@ public interface Tensor { if ( a.size() != b.size()) return false; for (Iterator<Cell> aIterator = a.cellIterator(); aIterator.hasNext(); ) { Cell aCell = aIterator.next(); - if ( ! aCell.getValue().equals(b.get(aCell.getKey()))) return false; + double aValue = aCell.getValue(); + double bValue = b.get(aCell.getKey()); + if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision } return true; } |