aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java43
1 files changed, 34 insertions, 9 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index b0e2be26f8a..d5ae1bbf591 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -4,6 +4,9 @@ package ai.vespa.models.handler;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.fasterxml.jackson.core.JsonEncoding;
+import com.fasterxml.jackson.core.JsonFactory;
+import com.fasterxml.jackson.core.JsonGenerator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
@@ -15,6 +18,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
+import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
@@ -65,14 +69,14 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
return listModelInformation(request, model, function);
- } catch (IllegalArgumentException e) {
+ } catch (IllegalArgumentException | IOException e) {
return new ErrorResponse(404, Exceptions.toMessageString(e));
} catch (IllegalStateException e) { // On missing bindings
return new ErrorResponse(400, Exceptions.toMessageString(e));
}
}
- private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
+ private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) throws IOException {
FunctionEvaluator evaluator = model.evaluatorOf(function);
property(request, missingValueKey).ifPresent(missingValue -> evaluator.setMissingValue(Tensor.from(missingValue)));
@@ -87,16 +91,37 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
}
}
}
- Tensor result = evaluator.evaluate();
+ String format = property(request, "format.tensors").orElse("default");
+ if (evaluator.outputs().size() > 1) {
+ evaluator.evaluate();
+ return new Response(200, encodeMultipleResults(evaluator, format));
+ }
+ return new Response(200, encodeSingleResult(evaluator.evaluate(), format));
+ }
- Optional<String> format = property(request, "format.tensors");
- if (format.isPresent() && format.get().equalsIgnoreCase("short")) {
- return new Response(200, JsonFormat.encodeShortForm(result));
+ private byte[] encodeMultipleResults(FunctionEvaluator evaluator, String format) throws IOException {
+ ByteArrayOutputStream out = new ByteArrayOutputStream();
+ JsonGenerator g = new JsonFactory().createGenerator(out, JsonEncoding.UTF8);
+ g.writeStartObject();
+ for (String output : evaluator.outputs()) {
+ g.writeFieldName(output);
+ g.writeRawValue(new String(encodeSingleResult(evaluator.result(output), format)));
}
- else if (format.isPresent() && format.get().equalsIgnoreCase("string")) {
- return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8));
+ g.writeEndObject();
+ g.close();
+ return out.toByteArray();
+ }
+
+ private byte[] encodeSingleResult(Tensor tensor, String format) {
+ if (format != null) {
+ if (format.equalsIgnoreCase("short")) {
+ return JsonFormat.encodeShortForm(tensor);
+ }
+ if (format.equalsIgnoreCase("string")) {
+ return tensor.toString().getBytes(StandardCharsets.UTF_8);
+ }
}
- return new Response(200, JsonFormat.encode(result));
+ return JsonFormat.encode(tensor);
}
private HttpResponse listAllModels(HttpRequest request) {