about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Lester Solbakken <lesters@oath.com> 2021-11-24 12:48:24 +0100
committer: Lester Solbakken <lesters@oath.com> 2021-11-24 12:48:24 +0100
commit: 0fded3444d2214b3280f7119a64395065bd30593 (patch)
tree: 6d9f54262294d00f3a69e975361642826dc0a69f
parent: a40eea2df1d64d8586768f1122da90a5756bef10 (diff)
Update stateless eval REST API for changes in multiple output functions [lesters/stateless-onnx-eval-once]
-rw-r--r-- model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java 43
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java 7
-rw-r--r-- model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java 17
3 files changed, 49 insertions, 18 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index b0e2be26f8a..d5ae1bbf591 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -4,6 +4,9 @@ package ai.vespa.models.handler;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
@@ -15,6 +18,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
@@ -65,14 +69,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
return listModelInformation(request, model, function);
- } catch (IllegalArgumentException e) {
+ } catch (IllegalArgumentException | IOException e) {
return new ErrorResponse(404, Exceptions.toMessageString(e));
} catch (IllegalStateException e) { // On missing bindings
return new ErrorResponse(400, Exceptions.toMessageString(e));
}
}
- private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
+ private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) throws IOException {
FunctionEvaluator evaluator = model.evaluatorOf(function);
property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue)));
@@ -87,16 +91,37 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
}
}
- Tensor result = evaluator.evaluate();
+ String format = property(request, "format.tensors").orElse("default");
+ if (evaluator.outputs().size() > 1) {
+ evaluator.evaluate();
+ return new Response(200, encodeMultipleResults(evaluator, format));
+ }
+ return new Response(200, encodeSingleResult(evaluator.evaluate(), format));
+ }
- Optional<String> format = property(request, "format.tensors");
- if (format.isPresent() && format.get().equalsIgnoreCase("short")) {
- return new Response(200, JsonFormat.encodeShortForm(result));
+ private byte[] encodeMultipleResults(FunctionEvaluator evaluator, String format) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
+ g.writeStartObject();
+ for (String output : evaluator.outputs()) {
+ g.writeFieldName(output);
+ g.writeRawValue(new String(encodeSingleResult(evaluator.result(output), format)));
}
- else if (format.isPresent() && format.get().equalsIgnoreCase("string")) {
- return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8));
+ g.writeEndObject();
+ g.close();
+ return out.toByteArray();
+ }
+
+ private byte[] encodeSingleResult(Tensor tensor, String format) {
+ if (format != null) {
+ if (format.equalsIgnoreCase("short")) {
+ return JsonFormat.encodeShortForm(tensor);
+ }
+ if (format.equalsIgnoreCase("string")) {
+ return tensor.toString().getBytes(StandardCharsets.UTF_8);
+ }
}
- return new Response(200, JsonFormat.encode(result));
+ return JsonFormat.encode(tensor);
}
private HttpResponse listAllModels(HttpRequest request) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 33e56d5d465..8c7be4e7be9 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -206,13 +206,6 @@ public class ModelsEvaluationHandlerTest {
}
@Test
- public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
- String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
- String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
- handler.assertResponse(url, 404, expected);
- }
-
- @Test
public void testVespaModelShortOutput() {
Map<String, String> properties = new HashMap<>();
properties.put("format.tensors", "short");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 6014bd7c7ef..74715ad96a2 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -61,8 +61,8 @@ public class OnnxEvaluationHandlerTest {
@Test
public void testEvaluationWithoutSpecifyingOutput() {
String url = "http://localhost/model-evaluation/v1/add_mul/eval";
- String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}";
- handler.assertResponse(url, 404, expected);
+ String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor<float>(d0[1])\"}";
+ handler.assertResponse(url, 400, expected);
}
@Test
@@ -93,6 +93,19 @@ public class OnnxEvaluationHandlerTest {
}
@Test
+ public void testEvaluateAllOutputs() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input1", "tensor<float>(d0[1]):[2]");
+ properties.put("input2", "tensor<float>(d0[1]):[3]");
+ String url = "http://localhost/model-evaluation/v1/add_mul/eval"; // remember to add to discovery!
+ String expected = "{" +
+ "\"output1\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}," + // output1 is a mul
+ "\"output2\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}" + // output2 is an add
+ "}";
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testBatchDimensionModelInfo() {
String url = "http://localhost/model-evaluation/v1/one_layer";
String expected = "{\"model\":\"one_layer\",\"functions\":[" +