diff options
author | Lester Solbakken <lesters@oath.com> | 2018-09-20 14:23:30 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-09-20 14:23:30 +0200 |
commit | 561814bd8d3e374fbdfc22823cb87a178d6e5838 (patch) | |
tree | 69ca7c46e289515052a906747f1577b632437e78 | |
parent | 0c1420e13a92e9383406cdc1a106c23f5448ff98 (diff) |
Move most of logic to model evaluation and better discovery information
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java | 113 | ||||
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java | 10 |
2 files changed, 51 insertions, 72 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index e224d4a50dc..683a1f345d8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.charset.Charset; +import java.util.Arrays; import java.util.Optional; import java.util.concurrent.Executor; @@ -49,46 +50,27 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { if ( ! modelName.isPresent()) { return listAllModels(request); } - if ( ! modelsEvaluator.models().containsKey(modelName.get())) { - throw new IllegalArgumentException("no model with name '" + modelName.get() + "' found"); - } - - Model model = modelsEvaluator.models().get(modelName.get()); - - // The following logic follows from the spec, in that signature and - // output are optional if the model only has a single function. + Model model = modelsEvaluator.requireModel(modelName.get()); - if (path.segments() == 3) { - if (model.functions().size() > 1) { - return listModelDetails(request, modelName.get()); - } - return listTypeDetails(request, modelName.get()); - } - - if (path.segments() == 4) { - if ( ! path.segment(3).get().equalsIgnoreCase(EVALUATE)) { - return listTypeDetails(request, modelName.get(), path.segment(3).get()); - } - if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) { - return listTypeDetails(request, modelName.get(), path.segment(3).get()); // model has a function "eval" - } - if (model.functions().size() <= 1) { - return evaluateModel(request, modelName.get()); - } - throw new IllegalArgumentException("attempt to evaluate model without specifying function"); - } - - if (path.segments() == 5) { - if (path.segment(4).get().equalsIgnoreCase(EVALUATE)) { - return evaluateModel(request, modelName.get(), path.segment(3).get()); - } + Optional<Integer> evalSegment = path.lastIndexOf(EVALUATE); + String[] function = path.range(3, evalSegment); + if (evalSegment.isPresent()) { + return evaluateModel(request, model, function); } + return listModelInformation(request, model, function); } catch (IllegalArgumentException e) { return new ErrorResponse(404, e.getMessage()); } + } - return new ErrorResponse(404, "unrecognized request"); + private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { + FunctionEvaluator evaluator = model.evaluatorOf(function); + for (String bindingName : evaluator.context().names()) { + property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); + } + Tensor result = evaluator.evaluate(); + return new Response(200, JsonFormat.encode(result)); } private HttpResponse listAllModels(HttpRequest request) { @@ -100,28 +82,33 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); } - private HttpResponse listModelDetails(HttpRequest request, String modelName) { - Model model = modelsEvaluator.models().get(modelName); + private HttpResponse listModelInformation(HttpRequest request, Model model, String[] function) { Slime slime = new Slime(); Cursor root = slime.setObject(); - for (ExpressionFunction func : model.functions()) { - root.setString(func.getName(), baseUrl(request) + modelName + "/" + func.getName()); + root.setString("model", model.name()); + if (function.length == 0) { + listFunctions(request, model, root); + } else { + listFunctionDetails(request, model, function, root); } return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); } - private HttpResponse listTypeDetails(HttpRequest request, String modelName) { - return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName)); - } - - private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) { - return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput)); + private void listFunctions(HttpRequest request, Model model, Cursor cursor) { + Cursor functions = cursor.setArray("functions"); + for (ExpressionFunction func : model.functions()) { + Cursor function = functions.addObject(); + listFunctionDetails(request, model, new String[] { func.getName() }, function); + } } - private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) { - Slime slime = new Slime(); - Cursor root = slime.setObject(); - Cursor bindings = root.setArray("bindings"); + private void listFunctionDetails(HttpRequest request, Model model, String[] function, Cursor cursor) { + String compactedFunction = String.join(".", function); + FunctionEvaluator evaluator = model.evaluatorOf(function); + cursor.setString("function", compactedFunction); + cursor.setString("info", baseUrl(request) + model.name() + "/" + compactedFunction); + cursor.setString("eval", baseUrl(request) + model.name() + "/" + compactedFunction + "/" + EVALUATE); + Cursor bindings = cursor.setArray("bindings"); for (String bindingName : evaluator.context().names()) { // TODO: Use an API which exposes only the external binding names instead of this if (bindingName.startsWith("constant(")) { @@ -131,26 +118,9 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { continue; } Cursor binding = bindings.addObject(); - binding.setString("name", bindingName); + binding.setString("binding", bindingName); binding.setString("type", ""); // TODO: implement type information when available } - return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); - } - - private HttpResponse evaluateModel(HttpRequest request, String modelName) { - return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName)); - } - - private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) { - return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput)); - } - - private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) { - for (String bindingName : evaluator.context().names()) { - property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); - } - Tensor result = evaluator.evaluate(); - return new Response(200, JsonFormat.encode(result)); } private Optional<String> property(HttpRequest request, String name) { @@ -180,8 +150,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { return (index < 0 || index >= segments.length) ? Optional.empty() : Optional.of(segments[index]); } - int segments() { - return segments.length; + Optional<Integer> lastIndexOf(String segment) { + for (int i = segments.length - 1; i >= 0; --i) { + if (segments[i].equalsIgnoreCase(segment)) { + return Optional.of(i); + } + } + return Optional.empty(); + } + + public String[] range(int start, Optional<Integer> end) { + return Arrays.copyOfRange(segments, start, end.isPresent() ? end.get() : segments.length); } private static String[] splitPath(HttpRequest request) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 5f045a2feb4..6726f117c05 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -80,14 +80,14 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSoftmaxDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; - String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; // only has a single function + String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}]}"; assertResponse(url, 200, expected); } @Test public void testMnistSoftmaxTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/"; - String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; + String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}"; assertResponse(url, 200, expected); } @@ -126,21 +126,21 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; - String expected = "{\"imported_ml_macro_mnist_saved_dnn_hidden1_add\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"serving_default.y\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\"}"; + String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; assertResponse(url, 200, expected); } @Test public void testMnistSavedTypeDetails() { String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/"; - String expected = "{\"bindings\":[{\"name\":\"input\",\"type\":\"\"}]}"; + String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}"; assertResponse(url, 200, expected); } @Test public void testMnistSavedEvaluateDefaultFunctionShouldFail() { String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; - String expected = "{\"error\":\"attempt to evaluate model without specifying function\"}"; + String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_macro_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; assertResponse(url, 404, expected); } |