summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-14 18:41:49 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-14 18:41:49 +0100
commit416f596b150ec159717bfd2f9b2ef70e4d4cd3dd (patch)
treefd78cf0541670dd50e2dc3256c5b9755ced8f73e /model-evaluation
parenta289581cbf94ff6997356110b54bd6993e956b9e (diff)
Support direct tensor rendering
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java18
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java6
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java4
4 files changed, 18 insertions, 14 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index ef04b6641e5..1bcd6363d2d 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -91,15 +91,15 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
}
Tensor result = evaluator.evaluate();
-
- Optional<String> format = property(request, "format.tensors");
- if (format.isPresent() && format.get().equalsIgnoreCase("long")) {
- return new Response(200, JsonFormat.encode(result));
- }
- else if (format.isPresent() && format.get().equalsIgnoreCase("string")) {
- return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8));
- }
- return new Response(200, JsonFormat.encodeShortForm(result));
+ return switch (property(request, "format.tensors").orElse("short").toLowerCase()) {
+ case "short" -> new Response(200, JsonFormat.encode(result, true, false));
+ case "long" -> new Response(200, JsonFormat.encode(result, false, false));
+ case "short-value" -> new Response(200, JsonFormat.encode(result, true, true));
+ case "long-value" -> new Response(200, JsonFormat.encode(result, false, true));
+ case "string" -> new Response(200, result.toString(true, true).getBytes(StandardCharsets.UTF_8));
+ case "string-long " -> new Response(200, result.toString(true, false ).getBytes(StandardCharsets.UTF_8));
+ default -> new ErrorResponse(400, "Unknown tensor format '" + property(request, "format.tensors") + "'");
+ };
}
private HttpResponse listAllModels(HttpRequest request) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 00531e373ee..5fabfca8737 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -25,7 +25,11 @@ class HandlerTester {
return s -> true;
}
private static Predicate<String> matchString(String expected) {
- return s -> expected.equals(s);
+ return s -> {
+ // System.out.println("Expected: " + expected);
+ // System.out.println("Actual: " + s);
+ return expected.equals(s);
+ };
}
public static Predicate<String> matchJson(String... expectedJson) {
var jExp = String.join("\n", expectedJson).replaceAll("'", "\"");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index c0e5dd9ccda..50dbecaffce 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -107,7 +107,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
- String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
+ String expected = "{\"type\":\"tensor()\",\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
handler.assertResponse(url, properties, 200, expected);
}
@@ -196,7 +196,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensorShortForm());
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
handler.assertResponse(url, properties, 200, expected);
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 29795fbcd95..86f56e14e2d 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -83,7 +83,7 @@ public class OnnxEvaluationHandlerTest {
properties.put("input2", "tensor<float>(d0[1]):[3]");
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul
+ String expected = "{\"type\":\"tensor<float>(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul
handler.assertResponse(url, properties, 200, expected);
}
@@ -94,7 +94,7 @@ public class OnnxEvaluationHandlerTest {
properties.put("input2", "tensor<float>(d0[1]):[3]");
properties.put("format.tensors", "long");
String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval";
- String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add
+ String expected = "{\"type\":\"tensor<float>(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add
handler.assertResponse(url, properties, 200, expected);
}