From 416f596b150ec159717bfd2f9b2ef70e4d4cd3dd Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Sat, 14 Jan 2023 18:41:49 +0100 Subject: Support direct tensor rendering --- .../vespa/models/handler/ModelsEvaluationHandler.java | 18 +++++++++--------- .../java/ai/vespa/models/handler/HandlerTester.java | 6 +++++- .../models/handler/ModelsEvaluationHandlerTest.java | 4 ++-- .../models/handler/OnnxEvaluationHandlerTest.java | 4 ++-- 4 files changed, 18 insertions(+), 14 deletions(-) (limited to 'model-evaluation') diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index ef04b6641e5..1bcd6363d2d 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -91,15 +91,15 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } } Tensor result = evaluator.evaluate(); - - Optional format = property(request, "format.tensors"); - if (format.isPresent() && format.get().equalsIgnoreCase("long")) { - return new Response(200, JsonFormat.encode(result)); - } - else if (format.isPresent() && format.get().equalsIgnoreCase("string")) { - return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8)); - } - return new Response(200, JsonFormat.encodeShortForm(result)); + return switch (property(request, "format.tensors").orElse("short").toLowerCase()) { + case "short" -> new Response(200, JsonFormat.encode(result, true, false)); + case "long" -> new Response(200, JsonFormat.encode(result, false, false)); + case "short-value" -> new Response(200, JsonFormat.encode(result, true, true)); + case "long-value" -> new Response(200, JsonFormat.encode(result, false, true)); + case "string" -> new Response(200, result.toString(true, true).getBytes(StandardCharsets.UTF_8)); + case "string-long " -> new Response(200, result.toString(true, false ).getBytes(StandardCharsets.UTF_8)); + default -> new ErrorResponse(400, "Unknown tensor format '" + property(request, "format.tensors") + "'"); + }; } private HttpResponse listAllModels(HttpRequest request) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 00531e373ee..5fabfca8737 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -25,7 +25,11 @@ class HandlerTester { return s -> true; } private static Predicate matchString(String expected) { - return s -> expected.equals(s); + return s -> { + // System.out.println("Expected: " + expected); + // System.out.println("Actual: " + s); + return expected.equals(s); + }; } public static Predicate matchJson(String... expectedJson) { var jExp = String.join("\n", expectedJson).replaceAll("'", "\""); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index c0e5dd9ccda..50dbecaffce 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -107,7 +107,7 @@ public class ModelsEvaluationHandlerTest { properties.put("non-existing-binding", "-1"); properties.put("format.tensors", "long"); String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; - String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; + String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; handler.assertResponse(url, properties, 200, expected); } @@ -196,7 +196,7 @@ public class ModelsEvaluationHandlerTest { properties.put("Placeholder", inputTensorShortForm()); properties.put("format.tensors", "long"); String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; handler.assertResponse(url, properties, 200, expected); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 29795fbcd95..86f56e14e2d 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -83,7 +83,7 @@ public class OnnxEvaluationHandlerTest { properties.put("input2", "tensor(d0[1]):[3]"); properties.put("format.tensors", "long"); String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval"; - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul + String expected = "{\"type\":\"tensor(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul handler.assertResponse(url, properties, 200, expected); } @@ -94,7 +94,7 @@ public class OnnxEvaluationHandlerTest { properties.put("input2", "tensor(d0[1]):[3]"); properties.put("format.tensors", "long"); String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval"; - String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add + String expected = "{\"type\":\"tensor(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add handler.assertResponse(url, properties, 200, expected); } -- cgit v1.2.3