summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-09-20 14:23:30 +0200
committerLester Solbakken <lesters@oath.com>2018-09-20 14:23:30 +0200
commit561814bd8d3e374fbdfc22823cb87a178d6e5838 (patch)
tree69ca7c46e289515052a906747f1577b632437e78 /model-evaluation
parent0c1420e13a92e9383406cdc1a106c23f5448ff98 (diff)
Move most of logic to model evaluation and better discovery information
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java113
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java10
2 files changed, 51 insertions, 72 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index e224d4a50dc..683a1f345d8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -16,6 +16,7 @@ import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
+import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.Executor;
@@ -49,46 +50,27 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
if ( ! modelName.isPresent()) {
return listAllModels(request);
}
- if ( ! modelsEvaluator.models().containsKey(modelName.get())) {
- throw new IllegalArgumentException("no model with name '" + modelName.get() + "' found");
- }
-
- Model model = modelsEvaluator.models().get(modelName.get());
-
- // The following logic follows from the spec, in that signature and
- // output are optional if the model only has a single function.
+ Model model = modelsEvaluator.requireModel(modelName.get());
- if (path.segments() == 3) {
- if (model.functions().size() > 1) {
- return listModelDetails(request, modelName.get());
- }
- return listTypeDetails(request, modelName.get());
- }
-
- if (path.segments() == 4) {
- if ( ! path.segment(3).get().equalsIgnoreCase(EVALUATE)) {
- return listTypeDetails(request, modelName.get(), path.segment(3).get());
- }
- if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) {
- return listTypeDetails(request, modelName.get(), path.segment(3).get()); // model has a function "eval"
- }
- if (model.functions().size() <= 1) {
- return evaluateModel(request, modelName.get());
- }
- throw new IllegalArgumentException("attempt to evaluate model without specifying function");
- }
-
- if (path.segments() == 5) {
- if (path.segment(4).get().equalsIgnoreCase(EVALUATE)) {
- return evaluateModel(request, modelName.get(), path.segment(3).get());
- }
+ Optional<Integer> evalSegment = path.lastIndexOf(EVALUATE);
+ String[] function = path.range(3, evalSegment);
+ if (evalSegment.isPresent()) {
+ return evaluateModel(request, model, function);
}
+ return listModelInformation(request, model, function);
} catch (IllegalArgumentException e) {
return new ErrorResponse(404, e.getMessage());
}
+ }
- return new ErrorResponse(404, "unrecognized request");
+ private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
+ FunctionEvaluator evaluator = model.evaluatorOf(function);
+ for (String bindingName : evaluator.context().names()) {
+ property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s)));
+ }
+ Tensor result = evaluator.evaluate();
+ return new Response(200, JsonFormat.encode(result));
}
private HttpResponse listAllModels(HttpRequest request) {
@@ -100,28 +82,33 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
}
- private HttpResponse listModelDetails(HttpRequest request, String modelName) {
- Model model = modelsEvaluator.models().get(modelName);
+ private HttpResponse listModelInformation(HttpRequest request, Model model, String[] function) {
Slime slime = new Slime();
Cursor root = slime.setObject();
- for (ExpressionFunction func : model.functions()) {
- root.setString(func.getName(), baseUrl(request) + modelName + "/" + func.getName());
+ root.setString("model", model.name());
+ if (function.length == 0) {
+ listFunctions(request, model, root);
+ } else {
+ listFunctionDetails(request, model, function, root);
}
return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
}
- private HttpResponse listTypeDetails(HttpRequest request, String modelName) {
- return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName));
- }
-
- private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) {
- return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
+ private void listFunctions(HttpRequest request, Model model, Cursor cursor) {
+ Cursor functions = cursor.setArray("functions");
+ for (ExpressionFunction func : model.functions()) {
+ Cursor function = functions.addObject();
+ listFunctionDetails(request, model, new String[] { func.getName() }, function);
+ }
}
- private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) {
- Slime slime = new Slime();
- Cursor root = slime.setObject();
- Cursor bindings = root.setArray("bindings");
+ private void listFunctionDetails(HttpRequest request, Model model, String[] function, Cursor cursor) {
+ String compactedFunction = String.join(".", function);
+ FunctionEvaluator evaluator = model.evaluatorOf(function);
+ cursor.setString("function", compactedFunction);
+ cursor.setString("info", baseUrl(request) + model.name() + "/" + compactedFunction);
+ cursor.setString("eval", baseUrl(request) + model.name() + "/" + compactedFunction + "/" + EVALUATE);
+ Cursor bindings = cursor.setArray("bindings");
for (String bindingName : evaluator.context().names()) {
// TODO: Use an API which exposes only the external binding names instead of this
if (bindingName.startsWith("constant(")) {
@@ -131,26 +118,9 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
continue;
}
Cursor binding = bindings.addObject();
- binding.setString("name", bindingName);
+ binding.setString("binding", bindingName);
binding.setString("type", ""); // TODO: implement type information when available
}
- return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
- }
-
- private HttpResponse evaluateModel(HttpRequest request, String modelName) {
- return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName));
- }
-
- private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) {
- return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
- }
-
- private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) {
- for (String bindingName : evaluator.context().names()) {
- property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s)));
- }
- Tensor result = evaluator.evaluate();
- return new Response(200, JsonFormat.encode(result));
}
private Optional<String> property(HttpRequest request, String name) {
@@ -180,8 +150,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
return (index < 0 || index >= segments.length) ? Optional.empty() : Optional.of(segments[index]);
}
- int segments() {
- return segments.length;
+ Optional<Integer> lastIndexOf(String segment) {
+ for (int i = segments.length - 1; i >= 0; --i) {
+ if (segments[i].equalsIgnoreCase(segment)) {
+ return Optional.of(i);
+ }
+ }
+ return Optional.empty();
+ }
+
+ public String[] range(int start, Optional<Integer> end) {
+ return Arrays.copyOfRange(segments, start, end.isPresent() ? end.get() : segments.length);
}
private static String[] splitPath(HttpRequest request) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 5f045a2feb4..6726f117c05 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -80,14 +80,14 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
- String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; // only has a single function
+ String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}]}";
assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/";
- String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}";
+ String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"bindings\":[{\"binding\":\"Placeholder\",\"type\":\"\"}]}";
assertResponse(url, 200, expected);
}
@@ -126,21 +126,21 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
- String expected = "{\"imported_ml_macro_mnist_saved_dnn_hidden1_add\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"serving_default.y\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\"}";
+ String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/";
- String expected = "{\"bindings\":[{\"name\":\"input\",\"type\":\"\"}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}";
assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
- String expected = "{\"error\":\"attempt to evaluate model without specifying function\"}";
+ String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_macro_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
assertResponse(url, 404, expected);
}