From cc89f8b15b0ba1c1301702920d2fb7471e513d9c Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Mon, 17 Sep 2018 16:15:00 +0200 Subject: Add model evaluation rest api handler --- model-evaluation/pom.xml | 12 ++ .../vespa/models/evaluation/FunctionEvaluator.java | 2 +- .../vespa/models/evaluation/LazyArrayContext.java | 2 +- .../models/handler/ModelsEvaluationHandler.java | 198 ++++++++++++++++++-- .../ai/vespa/models/evaluation/ModelTester.java | 2 +- .../handler/ModelsEvaluationHandlerTest.java | 201 +++++++++++++++++++++ 6 files changed, 399 insertions(+), 18 deletions(-) create mode 100644 model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java (limited to 'model-evaluation') diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index 328d475c501..7c7410df833 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -72,6 +72,18 @@ guava provided + + com.yahoo.vespa + jdisc_http_service + ${project.version} + provided + + + com.yahoo.vespa + jdisc_jetty + ${project.version} + test + diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index e08b9f77d15..1412936d4a0 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -56,6 +56,6 @@ public class FunctionEvaluator { return function.getBody().evaluate(context).asTensor(); } - LazyArrayContext context() { return context; } + public LazyArrayContext context() { return context; } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index beaa36b898f..c7d0cbd8f30 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -26,7 +26,7 @@ import java.util.Set; * * @author bratseth */ -final class LazyArrayContext extends Context implements ContextIndex { +public final class LazyArrayContext extends Context implements ContextIndex { private final IndexedBindings indexedBindings; diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 78c46864d7b..48011c7e6b6 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -1,47 +1,209 @@ package ai.vespa.models.handler; +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; import ai.vespa.models.evaluation.ModelsEvaluator; import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; -import com.yahoo.container.jdisc.LoggingRequestHandler; +import com.yahoo.container.jdisc.ThreadedHttpRequestHandler; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.JsonFormat; import java.io.IOException; import java.io.OutputStream; +import java.net.URI; +import java.nio.charset.Charset; +import java.util.Optional; +import java.util.concurrent.Executor; -public class ModelsEvaluationHandler extends LoggingRequestHandler { +public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { + + public static final String API_ROOT = "model-evaluation"; + public static final String VERSION_V1 = "v1"; + public static final String EVALUATE = "eval"; private final ModelsEvaluator modelsEvaluator; - public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Context context) { - super(context); + public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) { + super(executor); this.modelsEvaluator = modelsEvaluator; } @Override public HttpResponse handle(HttpRequest request) { - Tensor result = modelsEvaluator.evaluatorOf(property("model", "serving_default", request), - request.getProperty("function")) - .evaluate(); - return new RawResponse(JsonFormat.encode(result)); + Path path = new Path(request); + Optional apiName = path.segment(0); + Optional version = path.segment(1); + Optional modelName = path.segment(2); + + if ( ! apiName.isPresent() || ! apiName.get().equalsIgnoreCase(API_ROOT)) { + return new ErrorResponse(400, "unknown API"); + } + if ( ! version.isPresent() || ! version.get().equalsIgnoreCase(VERSION_V1)) { + return new ErrorResponse(400, "unknown API version"); + } + if ( ! modelName.isPresent()) { + return listAllModels(request); + } + if ( ! modelsEvaluator.models().containsKey(modelName.get())) { + return new ErrorResponse(400, "no model with name '" + modelName.get() + "' found"); + } + + Model model = modelsEvaluator.models().get(modelName.get()); + + // The following logic follows from the spec, in that signature and + // output are optional if the model only has a single function. + + if (path.segments() == 3) { + if (model.functions().size() > 1) { + return listModelDetails(request, modelName.get()); + } + return listTypeDetails(request, modelName.get()); + } + + if (path.segments() == 4) { + if ( ! path.segment(3).get().equalsIgnoreCase(EVALUATE)) { + return listTypeDetails(request, modelName.get(), path.segment(3).get()); + } + if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) { + return listTypeDetails(request, modelName.get(), path.segment(3).get()); // model has a function "eval" + } + if (model.functions().size() <= 1) { + return evaluateModel(request, modelName.get()); + } + return new ErrorResponse(400, "attempt to evaluate model without specifying function"); + } + + if (path.segments() == 5) { + if (path.segment(4).get().equalsIgnoreCase(EVALUATE)) { + return evaluateModel(request, modelName.get(), path.segment(3).get()); + } + } + + return new ErrorResponse(400, "unrecognized request"); + } + + private HttpResponse listAllModels(HttpRequest request) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + for (String modelName: modelsEvaluator.models().keySet()) { + root.setString(modelName, baseUrl(request) + modelName); + } + return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); + } + + private HttpResponse listModelDetails(HttpRequest request, String modelName) { + Model model = modelsEvaluator.models().get(modelName); + Slime slime = new Slime(); + Cursor root = slime.setObject(); + for (ExpressionFunction func : model.functions()) { + root.setString(func.getName(), baseUrl(request) + modelName + "/" + func.getName()); + } + return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); + } + + private HttpResponse listTypeDetails(HttpRequest request, String modelName) { + return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName)); + } + + private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) { + return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput)); + } + + private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) { + Slime slime = new Slime(); + Cursor root = slime.setObject(); + Cursor bindings = root.setArray("bindings"); + for (String bindingName : evaluator.context().names()) { + if (bindingName.startsWith("constant(")) { + continue; + } + if (bindingName.startsWith("rankingExpression(")) { + continue; + } + Cursor binding = bindings.addObject(); + binding.setString("name", bindingName); + binding.setString("type", ""); // todo: implement type information when available + } + return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); + } + + private HttpResponse evaluateModel(HttpRequest request, String modelName) { + return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName)); + } + + private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) { + return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput)); + } + + private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) { + for (String bindingName : evaluator.context().names()) { + property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); + } + Tensor result = evaluator.evaluate(); + return new Response(200, JsonFormat.encode(result)); + } + + private Optional property(HttpRequest request, String name) { + return Optional.ofNullable(request.getProperty(name)); } - private String property(String name, String defaultValue, HttpRequest request) { - String value = request.getProperty(name); - if (value == null) return defaultValue; - return value; + private String baseUrl(HttpRequest request) { + URI uri = request.getUri(); + StringBuilder sb = new StringBuilder(); + sb.append(uri.getScheme()).append("://").append(uri.getHost()); + if (uri.getPort() >= 0) { + sb.append(":").append(uri.getPort()); + } + sb.append("/").append(API_ROOT).append("/").append(VERSION_V1).append("/"); + return sb.toString(); } - private static class RawResponse extends HttpResponse { + private static class Path { + + private final String[] segments; + + public Path(HttpRequest httpRequest) { + segments = splitPath(httpRequest); + } + + Optional segment(int index) { + return (index < 0 || index >= segments.length) ? Optional.empty() : Optional.of(segments[index]); + } + + int segments() { + return segments.length; + } + + private static String[] splitPath(HttpRequest request) { + String path = request.getUri().getPath().toLowerCase(); + if (path.startsWith("/")) { + path = path.substring("/".length()); + } + if (path.endsWith("/")) { + path = path.substring(0, path.length() - 1); + } + return path.split("/"); + } + + } + + private static class Response extends HttpResponse { private final byte[] data; - RawResponse(byte[] data) { - super(200); + Response(int code, byte[] data) { + super(code); this.data = data; } + Response(int code, String data) { + this(code, data.getBytes(Charset.forName(DEFAULT_CHARACTER_ENCODING))); + } + @Override public String getContentType() { return "application/json"; @@ -53,5 +215,11 @@ public class ModelsEvaluationHandler extends LoggingRequestHandler { } } + private static class ErrorResponse extends Response { + ErrorResponse(int code, String data) { + super(code, "{\"error\":\"" + data + "\"}"); + } + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index 0aceaccc3e0..9a3e59aed80 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -65,7 +65,7 @@ public class ModelTester { } /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */ - private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { + public static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName()); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java new file mode 100644 index 00000000000..9966fd3d88e --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -0,0 +1,201 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.handler; + +import ai.vespa.models.evaluation.ModelTester; +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertEquals; + +public class ModelsEvaluationHandlerTest { + + private static ModelsEvaluationHandler handler; + + @BeforeClass + static public void setUp() { + Executor executor = Executors.newSingleThreadExecutor(); + ModelsEvaluator models = createModels("src/test/resources/config/models/"); + handler = new ModelsEvaluationHandler(models, executor); + } + + @Test + public void testUnknownAPI() { + assertResponse("http://localhost/wrong-api-binding", 400); + } + + @Test + public void testUnknownVersion() { + assertResponse("http://localhost/model-evaluation/v0", 400); + } + + @Test + public void testNonExistingModel() { + assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 400); + } + + @Test + public void testListModels() { + String url = "http://localhost/model-evaluation/v1"; + String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}"; + assertResponse(url, 200, expected); + } + + @Test + public void testXgBoostEvaluationWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function + String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testXgBoostEvaluationWithBindings() { + Map properties = new HashMap<>(); + properties.put("f29", "-1.0"); + properties.put("f56", "0.2"); + properties.put("f60", "0.3"); + properties.put("f109", "0.4"); + properties.put("non-existing-binding", "-1"); + String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSoftmaxDetails() { + String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; + String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; // only has a single function + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxTypeDetails() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/"; + String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() { + Map properties = new HashMap<>(); + properties.put("Placeholder", "{1.0}"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() { + Map properties = new HashMap<>(); + properties.put("Placeholder", "{1.0}"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSavedDetails() { + String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; + String expected = "{\"imported_ml_macro_mnist_saved_dnn_hidden1_add\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"serving_default.y\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\"}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSavedTypeDetails() { + String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/"; + String expected = "{\"bindings\":[{\"name\":\"input\",\"type\":\"\"}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSavedEvaluateDefaultFunctionShouldFail() { + String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; + String expected = "{\"error\":\"attempt to evaluate model without specifying function\"}"; + assertResponse(url, 400, expected); + } + + @Test + public void testMnistSavedEvaluateSpecificFunction() { + Map properties = new HashMap<>(); + properties.put("input", "-1.0"); + String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}"; + assertResponse(url, properties, 200, expected); + } + + static private void assertResponse(String url, int expectedCode) { + assertResponse(url, Collections.emptyMap(), expectedCode, null); + } + + static private void assertResponse(String url, int expectedCode, String expectedResult) { + assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); + } + + static private void assertResponse(String url, Map properties, int expectedCode, String expectedResult) { + HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); + HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); + assertResponse(getRequest, expectedCode, expectedResult); + assertResponse(postRequest, expectedCode, expectedResult); + } + + static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { + HttpResponse response = handler.handle(request); + assertEquals("application/json", response.getContentType()); + assertEquals(expectedCode, response.getStatus()); + if (expectedResult != null) { + assertEquals(expectedResult, getContents(response)); + } + } + + static private String getContents(HttpResponse response) { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + response.render(stream); + return stream.toString(); + } catch (IOException e) { + throw new Error(e); + } + } + + static private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + ModelTester.RankProfilesConfigImporterWithMockedConstants importer = + new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), + MockFileAcquirer.returnFile(null)); + return new ModelsEvaluator(importer.importFrom(config, constantsConfig)); + } + +} -- cgit v1.2.3