summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java198
1 files changed, 183 insertions, 15 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 78c46864d7b..48011c7e6b6 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -1,47 +1,209 @@
package ai.vespa.models.handler;
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
-import com.yahoo.container.jdisc.LoggingRequestHandler;
+import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.JsonFormat;
import java.io.IOException;
import java.io.OutputStream;
+import java.net.URI;
+import java.nio.charset.Charset;
+import java.util.Optional;
+import java.util.concurrent.Executor;
-public class ModelsEvaluationHandler extends LoggingRequestHandler {
+public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
+
+ /** First path segment of every URL served by this handler; also used when building links in baseUrl. */
+ public static final String API_ROOT = "model-evaluation";
+ /** Second path segment; handle rejects any other version with a 400 response. */
+ public static final String VERSION_V1 = "v1";
+ /** Final path segment that asks for an evaluation instead of a listing. */
+ public static final String EVALUATE = "eval";
private final ModelsEvaluator modelsEvaluator;
+ /**
+  * Creates a handler that serves the models held by the given evaluator.
+  *
+  * @param modelsEvaluator source of the models and function evaluators used to answer requests
+  * @param executor passed on unchanged to the ThreadedHttpRequestHandler superclass constructor
+  */
- public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Context context) {
- super(context);
+ public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) {
+ super(executor);
this.modelsEvaluator = modelsEvaluator;
}
+    /**
+     * Serves a request under /model-evaluation/v1/ as either a listing or an evaluation.
+     *
+     * Path layout (the fixed segments are matched case-insensitively):
+     *   /model-evaluation/v1                         lists all models
+     *   /model-evaluation/v1/{model}                 lists the functions of the model when it has more
+     *                                                than one, otherwise the bindings of its function
+     *   /model-evaluation/v1/{model}/{function}      lists the bindings of the function
+     *   /model-evaluation/v1/{model}/eval            evaluates a model having at most one function
+     *                                                (or lists bindings if a function is named "eval")
+     *   /model-evaluation/v1/{model}/{function}/eval evaluates the function
+     *
+     * @param request the incoming request; on evaluation its properties supply the input bindings
+     * @return a JSON response; status 400 with an "error" field for anything that cannot be served
+     */
@Override
public HttpResponse handle(HttpRequest request) {
- Tensor result = modelsEvaluator.evaluatorOf(property("model", "serving_default", request),
- request.getProperty("function"))
- .evaluate();
- return new RawResponse(JsonFormat.encode(result));
+        try {
+            return route(request);
+        } catch (IllegalArgumentException e) {
+            // NOTE(review): evaluatorOf and Tensor.from are assumed to report unknown function names
+            // and unparsable tensor strings with IllegalArgumentException - confirm. Such input is
+            // the client's mistake, so answer 400 instead of letting the exception escape.
+            String message = e.getMessage() != null ? e.getMessage() : "invalid request";
+            return new ErrorResponse(400, message);
+        }
+    }
+
+    /** Maps the request path onto a listing or an evaluation, following the layout documented on handle. */
+    private HttpResponse route(HttpRequest request) {
+        Path path = new Path(request);
+        Optional<String> apiName = path.segment(0);
+        Optional<String> version = path.segment(1);
+        Optional<String> modelName = path.segment(2);
+
+        if ( ! apiName.isPresent() || ! apiName.get().equalsIgnoreCase(API_ROOT)) {
+            return new ErrorResponse(400, "unknown API");
+        }
+        if ( ! version.isPresent() || ! version.get().equalsIgnoreCase(VERSION_V1)) {
+            return new ErrorResponse(400, "unknown API version");
+        }
+        if ( ! modelName.isPresent()) {
+            return listAllModels(request);
+        }
+
+        // Single lookup; a missing (or null) entry is reported instead of failing later on model.functions().
+        String name = modelName.get();
+        Model model = modelsEvaluator.models().get(name);
+        if (model == null) {
+            return new ErrorResponse(400, "no model with name '" + name + "' found");
+        }
+
+        // The following logic follows from the spec, in that signature and
+        // output are optional if the model only has a single function.
+
+        if (path.segments() == 3) {
+            if (model.functions().size() > 1) {
+                return listModelDetails(request, name);
+            }
+            return listTypeDetails(request, name);
+        }
+
+        if (path.segments() == 4) {
+            String last = path.segment(3).get();
+            if ( ! last.equalsIgnoreCase(EVALUATE)) {
+                return listTypeDetails(request, name, last);
+            }
+            if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) {
+                return listTypeDetails(request, name, last); // model has a function "eval"
+            }
+            if (model.functions().size() <= 1) {
+                return evaluateModel(request, name);
+            }
+            return new ErrorResponse(400, "attempt to evaluate model without specifying function");
+        }
+
+        if (path.segments() == 5) {
+            if (path.segment(4).get().equalsIgnoreCase(EVALUATE)) {
+                return evaluateModel(request, name, path.segment(3).get());
+            }
+        }
+
+        return new ErrorResponse(400, "unrecognized request");
+    }
+
+    /** Responds with a JSON object that maps every known model name to the URL of that model. */
+    private HttpResponse listAllModels(HttpRequest request) {
+        Slime json = new Slime();
+        Cursor links = json.setObject();
+        modelsEvaluator.models().keySet()
+                       .forEach(name -> links.setString(name, baseUrl(request) + name));
+        byte[] body = com.yahoo.slime.JsonFormat.toJsonBytes(json);
+        return new Response(200, body);
+    }
+
+    /** Responds with a JSON object that maps each function name of the given model to the URL of that function. */
+    private HttpResponse listModelDetails(HttpRequest request, String modelName) {
+        Model target = modelsEvaluator.models().get(modelName);
+        Slime json = new Slime();
+        Cursor links = json.setObject();
+        for (ExpressionFunction function : target.functions()) {
+            String url = baseUrl(request) + modelName + "/" + function.getName();
+            links.setString(function.getName(), url);
+        }
+        byte[] body = com.yahoo.slime.JsonFormat.toJsonBytes(json);
+        return new Response(200, body);
+    }
+
+    /** Lists the bindings of the evaluator that modelsEvaluator returns for the model name alone. */
+    private HttpResponse listTypeDetails(HttpRequest request, String modelName) {
+        FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf(modelName);
+        return listTypeDetails(request, evaluator);
+    }
+
+    /** Lists the bindings of the evaluator that modelsEvaluator returns for the model and function name. */
+    private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) {
+        FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf(modelName, signatureAndOutput);
+        return listTypeDetails(request, evaluator);
+    }
+
+    /**
+     * Responds with {"bindings": [{"name": ..., "type": ""}, ...]} for the names in the evaluator context,
+     * leaving out names that start with "constant(" or "rankingExpression(".
+     */
+    private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) {
+        Slime json = new Slime();
+        Cursor bindings = json.setObject().setArray("bindings");
+        for (String name : evaluator.context().names()) {
+            boolean skipped = name.startsWith("constant(") || name.startsWith("rankingExpression(");
+            if ( ! skipped) {
+                Cursor entry = bindings.addObject();
+                entry.setString("name", name);
+                entry.setString("type", ""); // todo: implement type information when available
+            }
+        }
+        return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(json));
+    }
+
+    /** Evaluates the evaluator that modelsEvaluator returns for the model name alone. */
+    private HttpResponse evaluateModel(HttpRequest request, String modelName) {
+        FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf(modelName);
+        return evaluateModel(request, evaluator);
+    }
+
+    /** Evaluates the evaluator that modelsEvaluator returns for the model and function name. */
+    private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) {
+        FunctionEvaluator evaluator = modelsEvaluator.evaluatorOf(modelName, signatureAndOutput);
+        return evaluateModel(request, evaluator);
+    }
+
+    /**
+     * Binds every context name that the request carries a property for (parsed with Tensor.from),
+     * evaluates, and responds with the JSON encoding of the resulting tensor.
+     */
+    private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) {
+        for (String name : evaluator.context().names()) {
+            Optional<String> value = property(request, name);
+            if (value.isPresent()) {
+                evaluator.bind(name, Tensor.from(value.get()));
+            }
+        }
+        return new Response(200, JsonFormat.encode(evaluator.evaluate()));
+    }
+
+    /** Returns the request property with the given name, or empty when getProperty yields null. */
+    private Optional<String> property(HttpRequest request, String name) {
+        String value = request.getProperty(name);
+        return value == null ? Optional.empty() : Optional.of(value);
}
- private String property(String name, String defaultValue, HttpRequest request) {
- String value = request.getProperty(name);
- if (value == null) return defaultValue;
- return value;
+    /**
+     * Builds "scheme://host[:port]/model-evaluation/v1/" from the request URI; the port is
+     * included only when the URI carries one (getPort() is non-negative).
+     */
+    private String baseUrl(HttpRequest request) {
+        URI uri = request.getUri();
+        String origin = uri.getScheme() + "://" + uri.getHost();
+        int port = uri.getPort();
+        if (port >= 0) {
+            origin = origin + ":" + port;
+        }
+        return origin + "/" + API_ROOT + "/" + VERSION_V1 + "/";
}
- private static class RawResponse extends HttpResponse {
+    /**
+     * The path of a request URI, lower-cased, stripped of one leading and one trailing "/",
+     * and split on "/". An empty path yields a single empty segment.
+     */
+    private static class Path {
+
+        private final String[] segments;
+
+        public Path(HttpRequest httpRequest) {
+            segments = splitPath(httpRequest);
+        }
+
+        /** Returns the segment at the given index, or empty when the index is out of range. */
+        Optional<String> segment(int index) {
+            return (index < 0 || index >= segments.length) ? Optional.empty() : Optional.of(segments[index]);
+        }
+
+        /** Returns the number of segments in this path. */
+        int segments() {
+            return segments.length;
+        }
+
+        private static String[] splitPath(HttpRequest request) {
+            // Locale.ROOT keeps the lower-casing independent of the JVM default locale; with the
+            // no-argument overload e.g. a Turkish default locale maps 'I' to a dotless 'i'.
+            // NOTE(review): lower-casing also applies to the model and function segments, which
+            // handle looks up by exact key - confirm that those names are always lower-case.
+            String path = request.getUri().getPath().toLowerCase(java.util.Locale.ROOT);
+            if (path.startsWith("/")) {
+                path = path.substring("/".length());
+            }
+            if (path.endsWith("/")) {
+                path = path.substring(0, path.length() - 1);
+            }
+            return path.split("/");
+        }
+
+    }
+
+ private static class Response extends HttpResponse {
private final byte[] data;
- RawResponse(byte[] data) {
- super(200);
+ Response(int code, byte[] data) {
+ super(code);
this.data = data;
}
+ Response(int code, String data) {
+ this(code, data.getBytes(Charset.forName(DEFAULT_CHARACTER_ENCODING)));
+ }
+
@Override
public String getContentType() {
return "application/json";
@@ -53,5 +215,11 @@ public class ModelsEvaluationHandler extends LoggingRequestHandler {
}
}
+    /** A JSON response of the form {"error": message} with the given status code. */
+    private static class ErrorResponse extends Response {
+
+        /**
+         * @param code the HTTP status code of the response
+         * @param data the error message; may contain client-supplied text such as a model name
+         */
+        ErrorResponse(int code, String data) {
+            super(code, toJson(data));
+        }
+
+        /**
+         * Serializes the message through Slime so that quotes, backslashes and control characters
+         * are escaped; plain concatenation would emit malformed (or injected) JSON for such input.
+         */
+        private static byte[] toJson(String message) {
+            Slime slime = new Slime();
+            slime.setObject().setString("error", message);
+            return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
+        }
+
+    }
+
}