From 0fded3444d2214b3280f7119a64395065bd30593 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 24 Nov 2021 12:48:24 +0100 Subject: Update stateless eval REST API for changes in multiple output functions --- .../models/handler/ModelsEvaluationHandler.java | 43 +++++++++++++++++----- .../handler/ModelsEvaluationHandlerTest.java | 7 ---- .../models/handler/OnnxEvaluationHandlerTest.java | 17 ++++++++- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index b0e2be26f8a..d5ae1bbf591 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -4,6 +4,9 @@ package ai.vespa.models.handler; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; @@ -15,6 +18,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.net.URI; @@ -65,14 +69,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } return listModelInformation(request, model, function); - } catch (IllegalArgumentException e) { + } catch (IllegalArgumentException | IOException e) { return new ErrorResponse(404, Exceptions.toMessageString(e)); } catch (IllegalStateException e) { // On missing bindings return new ErrorResponse(400, Exceptions.toMessageString(e)); } } - private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { + private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) throws IOException { FunctionEvaluator evaluator = model.evaluatorOf(function); property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue))); @@ -87,16 +91,37 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } } } - Tensor result = evaluator.evaluate(); + String format = property(request, "format.tensors").orElse("default"); + if (evaluator.outputs().size() > 1) { + evaluator.evaluate(); + return new Response(200, encodeMultipleResults(evaluator, format)); + } + return new Response(200, encodeSingleResult(evaluator.evaluate(), format)); + } - Optional format = property(request, "format.tensors"); - if (format.isPresent() && format.get().equalsIgnoreCase("short")) { - return new Response(200, JsonFormat.encodeShortForm(result)); + private byte[] encodeMultipleResults(FunctionEvaluator evaluator, String format) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + for (String output : evaluator.outputs()) { + g.writeFieldName(output); + g.writeRawValue(new String(encodeSingleResult(evaluator.result(output), format))); } - else if (format.isPresent() && format.get().equalsIgnoreCase("string")) { - return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8)); + g.writeEndObject(); + g.close(); + return out.toByteArray(); + } + + private byte[] encodeSingleResult(Tensor tensor, String format) { + if (format != null) { + if (format.equalsIgnoreCase("short")) { + return JsonFormat.encodeShortForm(tensor); + } + if (format.equalsIgnoreCase("string")) { + return tensor.toString().getBytes(StandardCharsets.UTF_8); + } } - return new Response(200, JsonFormat.encode(result)); + return JsonFormat.encode(tensor); } private HttpResponse listAllModels(HttpRequest request) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 33e56d5d465..8c7be4e7be9 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -205,13 +205,6 @@ public class ModelsEvaluationHandlerTest { handler.assertResponse(url, 200, expected); } - @Test - public void testMnistSavedEvaluateDefaultFunctionShouldFail() { - String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; - String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; - handler.assertResponse(url, 404, expected); - } - @Test public void testVespaModelShortOutput() { Map properties = new HashMap<>(); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 6014bd7c7ef..74715ad96a2 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -61,8 +61,8 @@ public class OnnxEvaluationHandlerTest { @Test public void testEvaluationWithoutSpecifyingOutput() { String url = "http://localhost/model-evaluation/v1/add_mul/eval"; - String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}"; - handler.assertResponse(url, 404, expected); + String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor(d0[1])\"}"; + handler.assertResponse(url, 400, expected); } @Test @@ -92,6 +92,19 @@ public class OnnxEvaluationHandlerTest { handler.assertResponse(url, properties, 200, expected); } + @Test + public void testEvaluateAllOutputs() { + Map properties = new HashMap<>(); + properties.put("input1", "tensor(d0[1]):[2]"); + properties.put("input2", "tensor(d0[1]):[3]"); + String url = "http://localhost/model-evaluation/v1/add_mul/eval"; // remember to add to discovery! + String expected = "{" + + "\"output1\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}," + // output1 is a mul + "\"output2\":{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}" + // output1 is an add + "}"; + handler.assertResponse(url, properties, 200, expected); + } + @Test public void testBatchDimensionModelInfo() { String url = "http://localhost/model-evaluation/v1/one_layer"; -- cgit v1.2.3