diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-11-09 10:06:25 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-11-10 13:01:57 +0000 |
commit | 14874f78191132e7cbfa6b36e0121b5376948953 (patch) | |
tree | dbcb62c2cd1a6821fe6bab37466b70abad74ee73 /searchlib/src/main | |
parent | 747e349890214fed758192b0d74a11338c26eef8 (diff) |
add a new Java application for evaluating tensor conformance tests
* based on TensorConformanceTest unit test class
* reads JSON into Slime structure
* annotates with actual results from vespajlib evaluation
Diffstat (limited to 'searchlib/src/main')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/tensor/EvaluateTensorConformance.java | 160 |
1 files changed, 160 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/tensor/EvaluateTensorConformance.java b/searchlib/src/main/java/com/yahoo/searchlib/tensor/EvaluateTensorConformance.java new file mode 100644 index 00000000000..281533d6478 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/tensor/EvaluateTensorConformance.java @@ -0,0 +1,160 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.tensor; + +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.JsonFormat; +import com.yahoo.slime.ObjectTraverser; +import com.yahoo.slime.Slime; +import com.yahoo.slime.SlimeUtils; + +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; + +import java.io.BufferedOutputStream; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStream; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +public class EvaluateTensorConformance { + + public static void main(String[] args) { + var app = new EvaluateTensorConformance(); + app.evaluateStdIn(); + } + + OutputStream outStream = new BufferedOutputStream(System.out); + + void evaluateStdIn() { + int count = 0; + try (BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) { + String test = br.readLine(); + while (test != null) { + boolean success = testCase(test, count); + if (!success) { + System.err.println("FAILED testcase "+count); + } + ++count; + test = br.readLine(); + } + } catch (IOException e) { + System.err.println(count + " FAILED : " + e.toString()); + } + } + + void output(Slime result) { + try { + new JsonFormat(true).encode(outStream, result); + outStream.write('\n'); + outStream.flush(); + } catch (IOException e) { + System.err.println("FAILED writing output: "+e); + System.exit(1); + } + } + + private boolean testCase(String test, int count) { + boolean okAndEqual = false; + try { + Slime input = SlimeUtils.jsonToSlime(test); + Slime result = new Slime(); + var top = result.setObject(); + SlimeUtils.copyObject(input.get(), top); + var num_tests = input.get().field("num_tests"); + if (input.get().field("num_tests").valid()) { + long expect = input.get().field("num_tests").asLong(); + okAndEqual = (expect == count); + } else if (input.get().field("expression").valid()) { + Tensor expect = getTensor(input.get().field("result").field("expect").asString()); + String expression = input.get().field("expression").asString(); + MapContext context = getInput(input.get().field("inputs")); + Tensor actual = evaluate(expression, context); + okAndEqual = Tensor.equals(actual, expect); + if (!okAndEqual) { + System.err.println(count + " : Tensors not equal. Actual: " + actual.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } else if (! actual.type().valueType().equals(expect.type().valueType())) { + System.err.println(count + " : Tensor cell value types not equal. Actual: " + actual.type() + " Expected: " + expect.type() + " -> expression \"" + expression + "\""); + okAndEqual = false; + } + top.field("result").setData("vespajlib", TypedBinaryFormat.encode(actual)); + } else { + System.err.println(count + " : Invalid input >>>"+test+"<<<"); + } + output(result); + } catch (Exception e) { + System.err.println(count + " : " + e.toString()); + } + return okAndEqual; + } + + private Tensor evaluate(String expression, MapContext context) throws ParseException { + Value value = new RankingExpression(expression).evaluate(context); + if (!(value instanceof TensorValue)) { + throw new IllegalArgumentException("Result is not a tensor"); + } + return ((TensorValue)value).asTensor(); + } + + private MapContext getInput(Inspector inputs) { + MapContext context = new MapContext(); + inputs.traverse(new ObjectTraverser() { + public void field(String name, Inspector contents) { + String value = contents.asString(); + Tensor tensor = getTensor(value); + context.put(name, new TensorValue(tensor)); + } + }); + return context; + } + + private Tensor getTensor(String binaryRepresentation) { + byte[] bin = getBytes(binaryRepresentation); + return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin)); + } + + private byte[] getBytes(String binaryRepresentation) { + return parseHexValue(binaryRepresentation.substring(2)); + } + + private byte[] parseHexValue(String s) { + final int len = s.length(); + byte[] bytes = new byte[len/2]; + for (int i = 0; i < len; i += 2) { + int c1 = hexValue(s.charAt(i)) << 4; + int c2 = hexValue(s.charAt(i + 1)); + bytes[i/2] = (byte)(c1 + c2); + } + return bytes; + } + + private int hexValue(Character c) { + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } else if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } else if (c >= '0' && c <= '9') { + return c - '0'; + } + throw new IllegalArgumentException("Hex contains illegal characters"); + } + +} + |