diff options
Diffstat (limited to 'model-evaluation/src/test/java')
3 files changed, 69 insertions, 29 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 3b16be311a0..00531e373ee 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -12,25 +12,53 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.Executors; +import java.util.function.Predicate; import static org.junit.Assert.assertEquals; +import static com.yahoo.slime.SlimeUtils.jsonToSlime; class HandlerTester { private final ModelsEvaluationHandler handler; + private static Predicate<String> nop() { + return s -> true; + } + private static Predicate<String> matchString(String expected) { + return s -> expected.equals(s); + } + public static Predicate<String> matchJson(String... expectedJson) { + var jExp = String.join("\n", expectedJson).replaceAll("'", "\""); + var expected = jsonToSlime(jExp); + return s -> { + var got = jsonToSlime(s); + boolean result = got.equalTo(expected); + if (!result) { + System.err.println("got:"); + System.err.println(got); + System.err.println("expected:"); + System.err.println(expected); + } + return result; + }; + } + HandlerTester(ModelsEvaluator models) { this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor()); } void assertResponse(String url, int expectedCode) { - assertResponse(url, Map.of(), expectedCode, (String)null); + checkResponse(url, expectedCode, nop()); } void assertResponse(String url, int expectedCode, String expectedResult) { assertResponse(url, Map.of(), expectedCode, expectedResult); } + void checkResponse(String url, int expectedCode, Predicate<String> check) { + checkResponse(url, Map.of(), expectedCode, check, Map.of()); + } + void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) { assertResponse(url, Map.of(), expectedCode, expectedResult, headers); } @@ -40,14 +68,18 @@ class HandlerTester { } void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) { + checkResponse(url, properties, expectedCode, matchString(expectedResult), headers); + } + + void checkResponse(String url, Map<String, String> properties, int expectedCode, Predicate<String> check, Map<String, String> headers) { HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); if (headers.size() > 0) { headers.forEach((k,v) -> getRequest.getJDiscRequest().headers().add(k, v)); headers.forEach((k,v) -> postRequest.getJDiscRequest().headers().add(k, v)); } - assertResponse(getRequest, expectedCode, expectedResult); - assertResponse(postRequest, expectedCode, expectedResult); + checkResponse(getRequest, expectedCode, check); + checkResponse(postRequest, expectedCode, check); } void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) { @@ -56,12 +88,14 @@ class HandlerTester { } void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { + checkResponse(request, expectedCode, matchString(expectedResult)); + } + + void checkResponse(HttpRequest request, int expectedCode, Predicate<String> check) { HttpResponse response = handler.handle(request); assertEquals("application/json", response.getContentType()); assertEquals(expectedCode, response.getStatus()); - if (expectedResult != null) { - assertEquals(expectedResult, getContents(response)); - } + assertEquals(true, check.test(getContents(response))); } void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index c52bf66626a..c0e5dd9ccda 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -14,7 +14,6 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import java.util.HashMap; @@ -262,7 +261,6 @@ public class ModelsEvaluationHandlerTest { "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}"); } - @Ignore @Test public void testMnistSavedEvaluateSpecificFunction() { assumeTrue(OnnxEvaluator.isRuntimeAvailable()); @@ -270,7 +268,17 @@ public class ModelsEvaluationHandlerTest { properties.put("input", inputTensor()); properties.put("format.tensors", "long"); String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}"; + Tensor expected = Tensor.from("tensor(d0[1],d1[10]):{"+ + "{d0:0,d1:0}:-0.6319251673007533,"+ + "{d0:0,d1:1}:-0.0007577770600619843,"+ + "{d0:0,d1:2}:-0.010707969042025622,"+ + "{d0:0,d1:3}:-0.6344759233540788,"+ + "{d0:0,d1:4}:-0.17529455385847528,"+ + "{d0:0,d1:5}:0.7490809723192187,"+ + "{d0:0,d1:6}:-0.022790284182901716,"+ + "{d0:0,d1:7}:0.26799240657608936,"+ + "{d0:0,d1:8}:-0.3152438845465862,"+ + "{d0:0,d1:9}:0.05949304847735276}"); handler.assertResponse(url, properties, 200, expected); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index cc3f2863068..29795fbcd95 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -12,7 +12,6 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; import org.junit.BeforeClass; -import org.junit.Ignore; import org.junit.Test; import java.io.File; @@ -32,36 +31,35 @@ public class OnnxEvaluationHandlerTest { handler = new HandlerTester(createModels()); } - @Ignore @Test public void testListModels() { String url = "http://localhost/model-evaluation/v1"; String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," + "\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," + "\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}"; - handler.assertResponse(url, 200, expected); + handler.checkResponse(url, 200, HandlerTester.matchJson(expected)); } - @Ignore @Test public void testModelInfo() { String url = "http://localhost/model-evaluation/v1/add_mul"; - String expected = "{\"model\":\"add_mul\",\"functions\":[" + - "{\"function\":\"output1\"," + - "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output1\"," + - "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," + - "\"arguments\":[" + - "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + - "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + - "]}," + - "{\"function\":\"output2\"," + - "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output2\"," + - "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," + - "\"arguments\":[" + - "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + - "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + - "]}]}"; - handler.assertResponse(url, 200, expected); + var check = HandlerTester.matchJson( + "{'model':'add_mul','functions':[", + " {'function':'output1',", + " 'info':'http://localhost/model-evaluation/v1/add_mul/output1',", + " 'eval':'http://localhost/model-evaluation/v1/add_mul/output1/eval',", + " 'arguments':[", + " {'name':'input1','type':'tensor<float>(d0[1])'},", + " {'name':'input2','type':'tensor<float>(d0[1])'}", + " ]},", + " {'function':'output2',", + " 'info':'http://localhost/model-evaluation/v1/add_mul/output2',", + " 'eval':'http://localhost/model-evaluation/v1/add_mul/output2/eval',", + " 'arguments':[", + " {'name':'input1','type':'tensor<float>(d0[1])'},", + " {'name':'input2','type':'tensor<float>(d0[1])'}", + " ]}]}"); + handler.checkResponse(url, 200, check); } @Test |