diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-01 15:26:11 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-01 15:26:11 +0200 |
commit | 8286b1d9394a3f89c08b0d193d65d44e937be017 (patch) | |
tree | ae0f97745055aba39feae4ba59b9a3594a1a9b01 /model-evaluation | |
parent | b8ef0eedfd29827a98d561d65b4c657ecbadf243 (diff) |
Add short form output option to model-evaluation REST API
Diffstat (limited to 'model-evaluation')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java | 6 | ||||
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java | 10 |
2 files changed, 16 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index a0744128a11..bbd9962be77 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -10,6 +10,7 @@ import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; @@ -87,6 +88,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } } Tensor result = evaluator.evaluate(); + + Optional<String> format = property(request, "format"); + if (format.isPresent() && format.get().equalsIgnoreCase("short") && result instanceof IndexedTensor) { + return new Response(200, JsonFormat.encodeShortForm((IndexedTensor) result)); + } return new Response(200, JsonFormat.encode(result)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index df89919a76e..8034be6bb22 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -183,6 +183,16 @@ public class ModelsEvaluationHandlerTest { } @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() { + Map<String, String> properties = new HashMap<>(); + properties.put("Placeholder", inputTensorShortForm()); + properties.put("format", "short"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"value\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}"; + handler.assertResponse(url, properties, 200, expected); + } + + @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; |