diff options
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java')
-rw-r--r-- | model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java | 201 |
1 files changed, 201 insertions, 0 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java new file mode 100644 index 00000000000..9966fd3d88e --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -0,0 +1,201 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.handler; + +import ai.vespa.models.evaluation.ModelTester; +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.container.jdisc.HttpResponse; +import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; +import com.yahoo.path.Path; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertEquals; + +public class ModelsEvaluationHandlerTest { + + private static ModelsEvaluationHandler handler; + + @BeforeClass + static public void setUp() { + Executor executor = Executors.newSingleThreadExecutor(); + ModelsEvaluator models = createModels("src/test/resources/config/models/"); + handler = new ModelsEvaluationHandler(models, executor); + } + + @Test + public void testUnknownAPI() { + assertResponse("http://localhost/wrong-api-binding", 400); + } + + @Test + public void testUnknownVersion() { + assertResponse("http://localhost/model-evaluation/v0", 400); + } + + @Test + public void testNonExistingModel() { + assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 400); + } + + @Test + public void testListModels() { + String url = "http://localhost/model-evaluation/v1"; + String expected = "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"}"; + assertResponse(url, 200, expected); + } + + @Test + public void testXgBoostEvaluationWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function + String expected = "{\"cells\":[{\"address\":{},\"value\":-8.17695}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testXgBoostEvaluationWithBindings() { + Map<String, String> properties = new HashMap<>(); + properties.put("f29", "-1.0"); + properties.put("f56", "0.2"); + properties.put("f60", "0.3"); + properties.put("f109", "0.4"); + properties.put("non-existing-binding", "-1"); + String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; + String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSoftmaxDetails() { + String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax"; + String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; // only has a single function + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxTypeDetails() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/"; + String expected = "{\"bindings\":[{\"name\":\"Placeholder\",\"type\":\"\"}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() { + Map<String, String> properties = new HashMap<>(); + properties.put("Placeholder", "{1.0}"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() { + Map<String, String> properties = new HashMap<>(); + properties.put("Placeholder", "{1.0}"); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + assertResponse(url, properties, 200, expected); + } + + @Test + public void testMnistSavedDetails() { + String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; + String expected = "{\"imported_ml_macro_mnist_saved_dnn_hidden1_add\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"serving_default.y\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\"}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSavedTypeDetails() { + String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/"; + String expected = "{\"bindings\":[{\"name\":\"input\",\"type\":\"\"}]}"; + assertResponse(url, 200, expected); + } + + @Test + public void testMnistSavedEvaluateDefaultFunctionShouldFail() { + String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; + String expected = "{\"error\":\"attempt to evaluate model without specifying function\"}"; + assertResponse(url, 400, expected); + } + + @Test + public void testMnistSavedEvaluateSpecificFunction() { + Map<String, String> properties = new HashMap<>(); + properties.put("input", "-1.0"); + String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; + String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}"; + assertResponse(url, properties, 200, expected); + } + + static private void assertResponse(String url, int expectedCode) { + assertResponse(url, Collections.emptyMap(), expectedCode, null); + } + + static private void assertResponse(String url, int expectedCode, String expectedResult) { + assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); + } + + static private void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) { + HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties); + HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties); + assertResponse(getRequest, expectedCode, expectedResult); + assertResponse(postRequest, expectedCode, expectedResult); + } + + static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { + HttpResponse response = handler.handle(request); + assertEquals("application/json", response.getContentType()); + assertEquals(expectedCode, response.getStatus()); + if (expectedResult != null) { + assertEquals(expectedResult, getContents(response)); + } + } + + static private String getContents(HttpResponse response) { + try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) { + response.render(stream); + return stream.toString(); + } catch (IOException e) { + throw new Error(e); + } + } + + static private ModelsEvaluator createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + ModelTester.RankProfilesConfigImporterWithMockedConstants importer = + new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), + MockFileAcquirer.returnFile(null)); + return new ModelsEvaluator(importer.importFrom(config, constantsConfig)); + } + +} |