aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-18 16:29:37 -0700
committerGitHub <noreply@github.com>2018-09-18 16:29:37 -0700
commita280933161001785f900d381357e6802ee274ada (patch)
tree178f5a9c637df755e1ca2930119b3f1e39662b23
parentb38fc011d4a5cdcdbdb5d71ca77252502957fa92 (diff)
parenta172868b098df9cf7e49a177544b59529202b71d (diff)
Merge pull request #6981 from vespa-engine/lesters/add-model-eval-rest-api
Lesters/add model eval rest api
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java20
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java5
-rw-r--r--model-evaluation/pom.xml12
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java202
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/package-info.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java2
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java201
9 files changed, 431 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
index 09990c7b9de..11736256d1b 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerModelEvaluation.java
@@ -2,9 +2,12 @@
package com.yahoo.vespa.model.container;
import ai.vespa.models.evaluation.ModelsEvaluator;
+import ai.vespa.models.handler.ModelsEvaluationHandler;
+import com.yahoo.osgi.provider.model.ComponentModel;
import com.yahoo.searchdefinition.derived.RankProfileList;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.model.container.component.Handler;
import java.util.List;
import java.util.Objects;
@@ -16,12 +19,17 @@ import java.util.Objects;
*/
public class ContainerModelEvaluation implements RankProfilesConfig.Producer, RankingConstantsConfig.Producer {
+ private final static String EVALUATOR_NAME = ModelsEvaluator.class.getName();
+ private final static String REST_HANDLER_NAME = ModelsEvaluationHandler.class.getName();
+ private final static String BUNDLE_NAME = "model-evaluation";
+
/** Global rank profiles, aka models */
private final RankProfileList rankProfileList;
public ContainerModelEvaluation(ContainerCluster cluster, RankProfileList rankProfileList) {
this.rankProfileList = Objects.requireNonNull(rankProfileList, "rankProfileList cannot be null");
- cluster.addSimpleComponent(ModelsEvaluator.class.getName(), null, "model-evaluation");
+ cluster.addSimpleComponent(EVALUATOR_NAME, null, BUNDLE_NAME);
+ cluster.addComponent(ContainerModelEvaluation.getHandler());
}
public void prepare(List<Container> containers) {
@@ -38,4 +46,14 @@ public class ContainerModelEvaluation implements RankProfilesConfig.Producer, Ra
rankProfileList.getConfig(builder);
}
+ public static Handler<?> getHandler() {
+ Handler<?> handler = new Handler<>(new ComponentModel(REST_HANDLER_NAME, null, BUNDLE_NAME));
+ String binding = ModelsEvaluationHandler.API_ROOT + "/" + ModelsEvaluationHandler.VERSION_V1;
+ handler.addServerBindings("http://*/" + binding,
+ "https://*/" + binding,
+ "http://*/" + binding + "/*",
+ "https://*/" + binding + "/*");
+ return handler;
+ }
+
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index b7b3fc99e20..9e26caf2cb4 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.model.ml;
import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import ai.vespa.models.evaluation.RankProfilesConfigImporter;
+import ai.vespa.models.handler.ModelsEvaluationHandler;
import com.yahoo.component.ComponentId;
import com.yahoo.config.FileReference;
import com.yahoo.config.application.api.ApplicationPackage;
@@ -80,6 +81,10 @@ public class ModelEvaluationTest {
ContainerCluster cluster = model.getContainerClusters().get("container");
assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
+ assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluationHandler.class.getName())));
+ assertTrue(cluster.getHandlers().stream()
+ .anyMatch(h -> h.getComponentId().toString().equals(ModelsEvaluationHandler.class.getName())));
+
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml
index 328d475c501..7c7410df833 100644
--- a/model-evaluation/pom.xml
+++ b/model-evaluation/pom.xml
@@ -72,6 +72,18 @@
<artifactId>guava</artifactId>
<scope>provided</scope>
</dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>jdisc_http_service</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>jdisc_jetty</artifactId>
+ <version>${project.version}</version>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index e08b9f77d15..1412936d4a0 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -56,6 +56,6 @@ public class FunctionEvaluator {
return function.getBody().evaluate(context).asTensor();
}
- LazyArrayContext context() { return context; }
+ public LazyArrayContext context() { return context; }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index beaa36b898f..c7d0cbd8f30 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -26,7 +26,7 @@ import java.util.Set;
*
* @author bratseth
*/
-final class LazyArrayContext extends Context implements ContextIndex {
+public final class LazyArrayContext extends Context implements ContextIndex {
private final IndexedBindings indexedBindings;
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 78c46864d7b..1c995c255f5 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -1,47 +1,213 @@
package ai.vespa.models.handler;
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.Model;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
-import com.yahoo.container.jdisc.LoggingRequestHandler;
+import com.yahoo.container.jdisc.ThreadedHttpRequestHandler;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.JsonFormat;
import java.io.IOException;
import java.io.OutputStream;
+import java.net.URI;
+import java.nio.charset.Charset;
+import java.util.Optional;
+import java.util.concurrent.Executor;
-public class ModelsEvaluationHandler extends LoggingRequestHandler {
+public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
+
+ public static final String API_ROOT = "model-evaluation";
+ public static final String VERSION_V1 = "v1";
+ public static final String EVALUATE = "eval";
private final ModelsEvaluator modelsEvaluator;
- public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Context context) {
- super(context);
+ public ModelsEvaluationHandler(ModelsEvaluator modelsEvaluator, Executor executor) {
+ super(executor);
this.modelsEvaluator = modelsEvaluator;
}
@Override
public HttpResponse handle(HttpRequest request) {
- Tensor result = modelsEvaluator.evaluatorOf(property("model", "serving_default", request),
- request.getProperty("function"))
- .evaluate();
- return new RawResponse(JsonFormat.encode(result));
+ Path path = new Path(request);
+ Optional<String> apiName = path.segment(0);
+ Optional<String> version = path.segment(1);
+ Optional<String> modelName = path.segment(2);
+
+ if ( ! apiName.isPresent() || ! apiName.get().equalsIgnoreCase(API_ROOT)) {
+ return new ErrorResponse(404, "unknown API");
+ }
+ if ( ! version.isPresent() || ! version.get().equalsIgnoreCase(VERSION_V1)) {
+ return new ErrorResponse(404, "unknown API version");
+ }
+ if ( ! modelName.isPresent()) {
+ return listAllModels(request);
+ }
+ if ( ! modelsEvaluator.models().containsKey(modelName.get())) {
+ // TODO: Replace by catching IllegalArgumentException and passing that error message
+ return new ErrorResponse(404, "no model with name '" + modelName.get() + "' found");
+ }
+
+ Model model = modelsEvaluator.models().get(modelName.get());
+
+ // The following logic follows from the spec, in that signature and
+ // output are optional if the model only has a single function.
+ // TODO: Try to avoid recreating that logic here
+
+ if (path.segments() == 3) {
+ if (model.functions().size() > 1) {
+ return listModelDetails(request, modelName.get());
+ }
+ return listTypeDetails(request, modelName.get());
+ }
+
+ if (path.segments() == 4) {
+ if ( ! path.segment(3).get().equalsIgnoreCase(EVALUATE)) {
+ return listTypeDetails(request, modelName.get(), path.segment(3).get());
+ }
+ if (model.functions().stream().anyMatch(f -> f.getName().equalsIgnoreCase(EVALUATE))) {
+ return listTypeDetails(request, modelName.get(), path.segment(3).get()); // model has a function "eval"
+ }
+ if (model.functions().size() <= 1) {
+ return evaluateModel(request, modelName.get());
+ }
+ // TODO: Replace by catching IllegalArgumentException and passing that error message
+ return new ErrorResponse(404, "attempt to evaluate model without specifying function");
+ }
+
+ if (path.segments() == 5) {
+ if (path.segment(4).get().equalsIgnoreCase(EVALUATE)) {
+ return evaluateModel(request, modelName.get(), path.segment(3).get());
+ }
+ }
+
+ return new ErrorResponse(404, "unrecognized request");
+ }
+
+ private HttpResponse listAllModels(HttpRequest request) {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ for (String modelName: modelsEvaluator.models().keySet()) {
+ root.setString(modelName, baseUrl(request) + modelName);
+ }
+ return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
+ }
+
+ private HttpResponse listModelDetails(HttpRequest request, String modelName) {
+ Model model = modelsEvaluator.models().get(modelName);
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ for (ExpressionFunction func : model.functions()) {
+ root.setString(func.getName(), baseUrl(request) + modelName + "/" + func.getName());
+ }
+ return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
+ }
+
+ private HttpResponse listTypeDetails(HttpRequest request, String modelName) {
+ return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName));
+ }
+
+ private HttpResponse listTypeDetails(HttpRequest request, String modelName, String signatureAndOutput) {
+ return listTypeDetails(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
+ }
+
+ private HttpResponse listTypeDetails(HttpRequest request, FunctionEvaluator evaluator) {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ Cursor bindings = root.setArray("bindings");
+ for (String bindingName : evaluator.context().names()) {
+ // TODO: Use an API which exposes only the external binding names instead of this
+ if (bindingName.startsWith("constant(")) {
+ continue;
+ }
+ if (bindingName.startsWith("rankingExpression(")) {
+ continue;
+ }
+ Cursor binding = bindings.addObject();
+ binding.setString("name", bindingName);
+ binding.setString("type", ""); // TODO: implement type information when available
+ }
+ return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
+ }
+
+ private HttpResponse evaluateModel(HttpRequest request, String modelName) {
+ return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName));
+ }
+
+ private HttpResponse evaluateModel(HttpRequest request, String modelName, String signatureAndOutput) {
+ return evaluateModel(request, modelsEvaluator.evaluatorOf(modelName, signatureAndOutput));
+ }
+
+ private HttpResponse evaluateModel(HttpRequest request, FunctionEvaluator evaluator) {
+ for (String bindingName : evaluator.context().names()) {
+ property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s)));
+ }
+ Tensor result = evaluator.evaluate();
+ return new Response(200, JsonFormat.encode(result));
+ }
+
+ private Optional<String> property(HttpRequest request, String name) {
+ return Optional.ofNullable(request.getProperty(name));
}
- private String property(String name, String defaultValue, HttpRequest request) {
- String value = request.getProperty(name);
- if (value == null) return defaultValue;
- return value;
+ private String baseUrl(HttpRequest request) {
+ URI uri = request.getUri();
+ StringBuilder sb = new StringBuilder();
+ sb.append(uri.getScheme()).append("://").append(uri.getHost());
+ if (uri.getPort() >= 0) {
+ sb.append(":").append(uri.getPort());
+ }
+ sb.append("/").append(API_ROOT).append("/").append(VERSION_V1).append("/");
+ return sb.toString();
}
- private static class RawResponse extends HttpResponse {
+ private static class Path {
+
+ private final String[] segments;
+
+ public Path(HttpRequest httpRequest) {
+ segments = splitPath(httpRequest);
+ }
+
+ Optional<String> segment(int index) {
+ return (index < 0 || index >= segments.length) ? Optional.empty() : Optional.of(segments[index]);
+ }
+
+ int segments() {
+ return segments.length;
+ }
+
+ private static String[] splitPath(HttpRequest request) {
+ String path = request.getUri().getPath().toLowerCase();
+ if (path.startsWith("/")) {
+ path = path.substring("/".length());
+ }
+ if (path.endsWith("/")) {
+ path = path.substring(0, path.length() - 1);
+ }
+ return path.split("/");
+ }
+
+ }
+
+ private static class Response extends HttpResponse {
private final byte[] data;
- RawResponse(byte[] data) {
- super(200);
+ Response(int code, byte[] data) {
+ super(code);
this.data = data;
}
+ Response(int code, String data) {
+ this(code, data.getBytes(Charset.forName(DEFAULT_CHARACTER_ENCODING)));
+ }
+
@Override
public String getContentType() {
return "application/json";
@@ -53,5 +219,11 @@ public class ModelsEvaluationHandler extends LoggingRequestHandler {
}
}
+ private static class ErrorResponse extends Response {
+ ErrorResponse(int code, String data) {
+ super(code, "{\"error\":\"" + data + "\"}");
+ }
+ }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/package-info.java b/model-evaluation/src/main/java/ai/vespa/models/handler/package-info.java
new file mode 100644
index 00000000000..7978abf2632
--- /dev/null
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/package-info.java
@@ -0,0 +1,4 @@
+@ExportPackage
+package ai.vespa.models.handler;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 0aceaccc3e0..9a3e59aed80 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -65,7 +65,7 @@ public class ModelTester {
}
/** Allows us to provide canned tensor constants during import since file distribution does not work in tests */
- private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {
+ public static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {
private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
new file mode 100644
index 00000000000..5f045a2feb4
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -0,0 +1,201 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.handler;
+
+import ai.vespa.models.evaluation.ModelTester;
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.container.jdisc.HttpResponse;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.path.Path;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Executors;
+
+import static org.junit.Assert.assertEquals;
+
+public class ModelsEvaluationHandlerTest {
+
+ private static ModelsEvaluationHandler handler;
+
+ @BeforeClass
+ static public void setUp() {
+ Executor executor = Executors.newSingleThreadExecutor();
+ ModelsEvaluator models = createModels("src/test/resources/config/models/");
+ handler = new ModelsEvaluationHandler(models, executor);
+ }
+
+ @Test
+ public void testUnknownAPI() {
+ assertResponse("http://localhost/wrong-api-binding", 404);
+ }
+
+ @Test
+ public void testUnknownVersion() {
+ assertResponse("http://localhost/model-evaluation/v0", 404);
+ }
+
+ @Test
+ public void testNonExistingModel() {
+ assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404);
+ }
+
+ @Test
+ public void testListModels() {
+ String url = "http://localhost/model-evaluation/v1";
+ String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testXgBoostEvaluationWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function
+ String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testXgBoostEvaluationWithBindings() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("f29", "-1.0");
+ properties.put("f56", "0.2");
+ properties.put("f60", "0.3");
+ properties.put("f109", "0.4");
+ properties.put("non-existing-binding", "-1");
+ String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
+ String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxDetails() {
+ String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
+ String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; // only has a single function
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxTypeDetails() {
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/";
+ String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", "{1.0}");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", "{1.0}");
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testMnistSavedDetails() {
+ String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
+ String expected = "{\"imported_ml_macro_mnist_saved_dnn_hidden1_add\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"serving_default.y\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\"}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSavedTypeDetails() {
+ String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/";
+ String expected = "{\"bindings\":[{\"name\":\"input\",\"type\":\"\"}]}";
+ assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
+ String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
+ String expected = "{\"error\":\"attempt to evaluate model without specifying function\"}";
+ assertResponse(url, 404, expected);
+ }
+
+ @Test
+ public void testMnistSavedEvaluateSpecificFunction() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input", "-1.0");
+ String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}";
+ assertResponse(url, properties, 200, expected);
+ }
+
+ static private void assertResponse(String url, int expectedCode) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, null);
+ }
+
+ static private void assertResponse(String url, int expectedCode, String expectedResult) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
+ }
+
+ static private void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
+ HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
+ HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
+ assertResponse(getRequest, expectedCode, expectedResult);
+ assertResponse(postRequest, expectedCode, expectedResult);
+ }
+
+ static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
+ HttpResponse response = handler.handle(request);
+ assertEquals("application/json", response.getContentType());
+ assertEquals(expectedCode, response.getStatus());
+ if (expectedResult != null) {
+ assertEquals(expectedResult, getContents(response));
+ }
+ }
+
+ static private String getContents(HttpResponse response) {
+ try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
+ response.render(stream);
+ return stream.toString();
+ } catch (IOException e) {
+ throw new Error(e);
+ }
+ }
+
+ static private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ ModelTester.RankProfilesConfigImporterWithMockedConstants importer =
+ new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"),
+ MockFileAcquirer.returnFile(null));
+ return new ModelsEvaluator(importer.importFrom(config, constantsConfig));
+ }
+
+}