aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-16 11:43:45 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-16 11:43:45 +0100
commit3f07bf2d9e6eae85c50aa8734694273c983f959b (patch)
treef528075cb0e877423d9d2e26d4f6925f6ff9784c /model-evaluation
parent416f596b150ec159717bfd2f9b2ef70e4d4cd3dd (diff)
Test direct rendering
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java22
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java86
2 files changed, 88 insertions, 20 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 5fabfca8737..6c4dd886f4b 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -6,6 +6,7 @@ import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.text.JSON;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -26,11 +27,18 @@ class HandlerTester {
}
private static Predicate<String> matchString(String expected) {
return s -> {
- // System.out.println("Expected: " + expected);
- // System.out.println("Actual: " + s);
+ //System.out.println("Expected: " + expected);
+ //System.out.println("Actual: " + s);
return expected.equals(s);
};
}
+ private static Predicate<String> matchJsonString(String expected) {
+ return s -> {
+ //System.out.println("Expected: " + expected);
+ //System.out.println("Actual: " + s);
+ return JSON.canonical(expected).equals(JSON.canonical(s));
+ };
+ }
public static Predicate<String> matchJson(String... expectedJson) {
var jExp = String.join("\n", expectedJson).replaceAll("'", "\"");
var expected = jsonToSlime(jExp);
@@ -72,6 +80,10 @@ class HandlerTester {
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
+ checkResponse(url, properties, expectedCode, matchJsonString(expectedResult), headers);
+ }
+
+ void assertStringResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
checkResponse(url, properties, expectedCode, matchString(expectedResult), headers);
}
@@ -91,15 +103,11 @@ class HandlerTester {
assertResponse(getRequest, expectedCode, expectedResult);
}
- void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
- checkResponse(request, expectedCode, matchString(expectedResult));
- }
-
void checkResponse(HttpRequest request, int expectedCode, Predicate<String> check) {
HttpResponse response = handler.handle(request);
assertEquals("application/json", response.getContentType());
- assertEquals(expectedCode, response.getStatus());
assertEquals(true, check.test(getContents(response)));
+ assertEquals(expectedCode, response.getStatus());
}
void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 50dbecaffce..9b2b793212b 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -191,22 +191,82 @@ public class ModelsEvaluationHandlerTest {
}
@Test
- public void testMnistSoftmaxEvaluateSpecificFunctionWithBindingsShortForm() {
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format.tensors", "short");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected =
+ """
+ {
+ "type":"tensor(d0[],d1[10])",
+ "values":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]
+ }
+ """;
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithLongOutput() {
Map<String, String> properties = new HashMap<>();
properties.put("Placeholder", inputTensorShortForm());
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ String expected =
+ """
+ {
+ "type":"tensor(d0[],d1[10])",
+ "cells":[
+ {"address":{"d0":"0","d1":"0"},"value":-0.3546536862850189},
+ {"address":{"d0":"0","d1":"1"},"value":0.3759574592113495},
+ {"address":{"d0":"0","d1":"2"},"value":0.06054411828517914},
+ {"address":{"d0":"0","d1":"3"},"value":-0.251544713973999},
+ {"address":{"d0":"0","d1":"4"},"value":0.017951013520359993},
+ {"address":{"d0":"0","d1":"5"},"value":1.2899067401885986},
+ {"address":{"d0":"0","d1":"6"},"value":-0.10389615595340729},
+ {"address":{"d0":"0","d1":"7"},"value":0.6367976665496826},
+ {"address":{"d0":"0","d1":"8"},"value":-1.4136744737625122},
+ {"address":{"d0":"0","d1":"9"},"value":-0.2573896050453186}
+ ]
+ }
+ """;
handler.assertResponse(url, properties, 200, expected);
}
@Test
- public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortDirectOutput() {
Map<String, String> properties = new HashMap<>();
properties.put("Placeholder", inputTensorShortForm());
- properties.put("format.tensors", "short");
+ properties.put("format.tensors", "short-value");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"values\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
+ String expected =
+ """
+ [[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]
+ """;
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithLongDirectOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format.tensors", "long-value");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected =
+ """
+ [
+ {"address":{"d0":"0","d1":"0"},"value":-0.3546536862850189},
+ {"address":{"d0":"0","d1":"1"},"value":0.3759574592113495},
+ {"address":{"d0":"0","d1":"2"},"value":0.06054411828517914},
+ {"address":{"d0":"0","d1":"3"},"value":-0.251544713973999},
+ {"address":{"d0":"0","d1":"4"},"value":0.017951013520359993},
+ {"address":{"d0":"0","d1":"5"},"value":1.2899067401885986},
+ {"address":{"d0":"0","d1":"6"},"value":-0.10389615595340729},
+ {"address":{"d0":"0","d1":"7"},"value":0.6367976665496826},
+ {"address":{"d0":"0","d1":"8"},"value":-1.4136744737625122},
+ {"address":{"d0":"0","d1":"9"},"value":-0.2573896050453186}
+ ]
+ """;
handler.assertResponse(url, properties, 200, expected);
}
@@ -251,14 +311,14 @@ public class ModelsEvaluationHandlerTest {
Map<String, String> properties = new HashMap<>();
properties.put("format.tensors", "string");
String url = "http://localhost/model-evaluation/v1/vespa_model/";
- handler.assertResponse(url + "test_mapped/eval", properties, 200,
- "tensor(d0{}):{a:1.0, b:2.0}");
- handler.assertResponse(url + "test_indexed/eval", properties, 200,
- "tensor(d0[2],d1[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]");
- handler.assertResponse(url + "test_mixed/eval", properties, 200,
- "tensor(x{},y[3]):{a:[1.0, 2.0, 3.0], b:[4.0, 5.0, 6.0]}");
- handler.assertResponse(url + "test_mixed_2/eval", properties, 200,
- "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}");
+ handler.assertStringResponse(url + "test_mapped/eval", properties, 200,
+ "tensor(d0{}):{a:1.0, b:2.0}", Map.of());
+ handler.assertStringResponse(url + "test_indexed/eval", properties, 200,
+ "tensor(d0[2],d1[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]", Map.of());
+ handler.assertStringResponse(url + "test_mixed/eval", properties, 200,
+ "tensor(x{},y[3]):{a:[1.0, 2.0, 3.0], b:[4.0, 5.0, 6.0]}", Map.of());
+ handler.assertStringResponse(url + "test_mixed_2/eval", properties, 200,
+ "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}", Map.of());
}
@Test