From 35547f0a1a70593dc3c75f2ebaf3ff0b2101f406 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Thu, 15 Dec 2022 11:43:11 +0000 Subject: make it possible to check for equivalent JSON --- .../ai/vespa/models/handler/HandlerTester.java | 46 +++++++++++++++++++--- 1 file changed, 40 insertions(+), 6 deletions(-) (limited to 'model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java') diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 3b16be311a0..00531e373ee 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -12,25 +12,53 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.Executors; +import java.util.function.Predicate; import static org.junit.Assert.assertEquals; +import static com.yahoo.slime.SlimeUtils.jsonToSlime; class HandlerTester { private final ModelsEvaluationHandler handler; + private static Predicate nop() { + return s -> true; + } + private static Predicate matchString(String expected) { + return s -> expected.equals(s); + } + public static Predicate matchJson(String... expectedJson) { + var jExp = String.join("\n", expectedJson).replaceAll("'", "\""); + var expected = jsonToSlime(jExp); + return s -> { + var got = jsonToSlime(s); + boolean result = got.equalTo(expected); + if (!result) { + System.err.println("got:"); + System.err.println(got); + System.err.println("expected:"); + System.err.println(expected); + } + return result; + }; + } + HandlerTester(ModelsEvaluator models) { this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor()); } void assertResponse(String url, int expectedCode) { - assertResponse(url, Map.of(), expectedCode, (String)null); + checkResponse(url, expectedCode, nop()); } void assertResponse(String url, int expectedCode, String expectedResult) { assertResponse(url, Map.of(), expectedCode, expectedResult); } + void checkResponse(String url, int expectedCode, Predicate check) { + checkResponse(url, Map.of(), expectedCode, check, Map.of()); + } + void assertResponse(String url, int expectedCode, String expectedResult, Map headers) { assertResponse(url, Map.of(), expectedCode, expectedResult, headers); } @@ -40,14 +68,18 @@ class HandlerTester { } void assertResponse(String url, Map properties, int expectedCode, String expectedResult, Map headers) { + checkResponse(url, properties, expectedCode, matchString(expectedResult), headers); + } + + void checkResponse(String url, Map properties, int expectedCode, Predicate check, Map headers) { HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); if (headers.size() > 0) { headers.forEach((k,v) -> getRequest.getJDiscRequest().headers().add(k, v)); headers.forEach((k,v) -> postRequest.getJDiscRequest().headers().add(k, v)); } - assertResponse(getRequest, expectedCode, expectedResult); - assertResponse(postRequest, expectedCode, expectedResult); + checkResponse(getRequest, expectedCode, check); + checkResponse(postRequest, expectedCode, check); } void assertResponse(String url, Map properties, int expectedCode, Tensor expectedResult) { @@ -56,12 +88,14 @@ class HandlerTester { } void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { + checkResponse(request, expectedCode, matchString(expectedResult)); + } + + void checkResponse(HttpRequest request, int expectedCode, Predicate check) { HttpResponse response = handler.handle(request); assertEquals("application/json", response.getContentType()); assertEquals(expectedCode, response.getStatus()); - if (expectedResult != null) { - assertEquals(expectedResult, getContents(response)); - } + assertEquals(true, check.test(getContents(response))); } void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) { -- cgit v1.2.3