aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-09-01 15:26:11 +0200
committerLester Solbakken <lesters@oath.com>2021-09-01 15:26:11 +0200
commit8286b1d9394a3f89c08b0d193d65d44e937be017 (patch)
treeae0f97745055aba39feae4ba59b9a3594a1a9b01 /model-evaluation
parentb8ef0eedfd29827a98d561d65b4c657ecbadf243 (diff)
Add short form output option to model-evaluation REST API
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java6
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java10
2 files changed, 16 insertions, 0 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index a0744128a11..bbd9962be77 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -10,6 +10,7 @@ import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Slime;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
@@ -87,6 +88,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
}
Tensor result = evaluator.evaluate();
+
+ Optional<String> format = property(request, "format");
+ if (format.isPresent() && format.get().equalsIgnoreCase("short") && result instanceof IndexedTensor) {
+ return new Response(200, JsonFormat.encodeShortForm((IndexedTensor) result));
+ }
return new Response(200, JsonFormat.encode(result));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index df89919a76e..8034be6bb22 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -183,6 +183,16 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithShortOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ properties.put("format", "short");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"value\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";