diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java | 43 |
1 files changed, 34 insertions, 9 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index b0e2be26f8a..d5ae1bbf591 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -4,6 +4,9 @@ package ai.vespa.models.handler; import ai.vespa.models.evaluation.FunctionEvaluator; import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; +import com.fasterxml.jackson.core.JsonEncoding; +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; @@ -15,6 +18,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; +import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.OutputStream; import java.net.URI; @@ -65,14 +69,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } return listModelInformation(request, model, function); - } catch (IllegalArgumentException e) { + } catch (IllegalArgumentException | IOException e) { return new ErrorResponse(404, Exceptions.toMessageString(e)); } catch (IllegalStateException e) { // On missing bindings return new ErrorResponse(400, Exceptions.toMessageString(e)); } } - private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { + private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) throws IOException { FunctionEvaluator evaluator = model.evaluatorOf(function); property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue))); @@ -87,16 +91,37 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { } } } - Tensor result = evaluator.evaluate(); + String format = property(request, "format.tensors").orElse("default"); + if (evaluator.outputs().size() > 1) { + evaluator.evaluate(); + return new Response(200, encodeMultipleResults(evaluator, format)); + } + return new Response(200, encodeSingleResult(evaluator.evaluate(), format)); + } - Optional<String> format = property(request, "format.tensors"); - if (format.isPresent() && format.get().equalsIgnoreCase("short")) { - return new Response(200, JsonFormat.encodeShortForm(result)); + private byte[] encodeMultipleResults(FunctionEvaluator evaluator, String format) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8); + g.writeStartObject(); + for (String output : evaluator.outputs()) { + g.writeFieldName(output); + g.writeRawValue(new String(encodeSingleResult(evaluator.result(output), format))); } - else if (format.isPresent() && format.get().equalsIgnoreCase("string")) { - return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8)); + g.writeEndObject(); + g.close(); + return out.toByteArray(); + } + + private byte[] encodeSingleResult(Tensor tensor, String format) { + if (format != null) { + if (format.equalsIgnoreCase("short")) { + return JsonFormat.encodeShortForm(tensor); + } + if (format.equalsIgnoreCase("string")) { + return tensor.toString().getBytes(StandardCharsets.UTF_8); + } } - return new Response(200, JsonFormat.encode(result)); + return JsonFormat.encode(tensor); } private HttpResponse listAllModels(HttpRequest request) { |