aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2022-12-15 11:43:11 +0000
committerArne Juul <arnej@yahooinc.com>2022-12-15 11:46:40 +0000
commit35547f0a1a70593dc3c75f2ebaf3ff0b2101f406 (patch)
treec6229dba18d9ce2fd262c464326fd30b34e30585 /model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
parentb5994fafd8a92746cd4543ba8bd33175a377e291 (diff)
make it possible to check for equivalent JSON
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java46
1 files changed, 40 insertions, 6 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 3b16be311a0..00531e373ee 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -12,25 +12,53 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.concurrent.Executors;
+import java.util.function.Predicate;
import static org.junit.Assert.assertEquals;
+import static com.yahoo.slime.SlimeUtils.jsonToSlime;
class HandlerTester {
private final ModelsEvaluationHandler handler;
+ private static Predicate<String> nop() {
+ return s -> true;
+ }
+ private static Predicate<String> matchString(String expected) {
+ return s -> expected.equals(s);
+ }
+ public static Predicate<String> matchJson(String... expectedJson) {
+ var jExp = String.join("\n", expectedJson).replaceAll("'", "\"");
+ var expected = jsonToSlime(jExp);
+ return s -> {
+ var got = jsonToSlime(s);
+ boolean result = got.equalTo(expected);
+ if (!result) {
+ System.err.println("got:");
+ System.err.println(got);
+ System.err.println("expected:");
+ System.err.println(expected);
+ }
+ return result;
+ };
+ }
+
HandlerTester(ModelsEvaluator models) {
this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor());
}
void assertResponse(String url, int expectedCode) {
- assertResponse(url, Map.of(), expectedCode, (String)null);
+ checkResponse(url, expectedCode, nop());
}
void assertResponse(String url, int expectedCode, String expectedResult) {
assertResponse(url, Map.of(), expectedCode, expectedResult);
}
+ void checkResponse(String url, int expectedCode, Predicate<String> check) {
+ checkResponse(url, Map.of(), expectedCode, check, Map.of());
+ }
+
void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) {
assertResponse(url, Map.of(), expectedCode, expectedResult, headers);
}
@@ -40,14 +68,18 @@ class HandlerTester {
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
+ checkResponse(url, properties, expectedCode, matchString(expectedResult), headers);
+ }
+
+ void checkResponse(String url, Map<String, String> properties, int expectedCode, Predicate<String> check, Map<String, String> headers) {
HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
if (headers.size() > 0) {
headers.forEach((k,v) -> getRequest.getJDiscRequest().headers().add(k, v));
headers.forEach((k,v) -> postRequest.getJDiscRequest().headers().add(k, v));
}
- assertResponse(getRequest, expectedCode, expectedResult);
- assertResponse(postRequest, expectedCode, expectedResult);
+ checkResponse(getRequest, expectedCode, check);
+ checkResponse(postRequest, expectedCode, check);
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) {
@@ -56,12 +88,14 @@ class HandlerTester {
}
void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
+ checkResponse(request, expectedCode, matchString(expectedResult));
+ }
+
+ void checkResponse(HttpRequest request, int expectedCode, Predicate<String> check) {
HttpResponse response = handler.handle(request);
assertEquals("application/json", response.getContentType());
assertEquals(expectedCode, response.getStatus());
- if (expectedResult != null) {
- assertEquals(expectedResult, getContents(response));
- }
+ assertEquals(true, check.test(getContents(response)));
}
void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {