aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/test')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java69
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java76
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java93
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java137
-rw-r--r--model-evaluation/src/test/resources/config/models/onnx-models.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx24
-rwxr-xr-xmodel-evaluation/src/test/resources/config/onnx/models/add_mul.py30
-rw-r--r--model-evaluation/src/test/resources/config/onnx/models/one_layer.onnxbin0 -> 299 bytes
-rwxr-xr-xmodel-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py38
-rw-r--r--model-evaluation/src/test/resources/config/onnx/onnx-models.cfg16
-rw-r--r--model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg17
-rw-r--r--model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg0
16 files changed, 441 insertions, 69 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index bacdb52a201..d252594e729 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import java.io.IOException;
@@ -45,8 +46,10 @@ public class ModelTester {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"), MockFileAcquirer.returnFile(null))
- .importFrom(config, constantsConfig);
+ .importFrom(config, constantsConfig, onnxModelsConfig);
}
public ExpressionFunction assertFunction(String name, String expression, Model model) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 6fcf76d2815..dce033c79b0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.yolean.Exceptions;
import org.junit.Test;
@@ -131,7 +132,9 @@ public class ModelsEvaluatorTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
- return new ModelsEvaluator(config, constantsConfig, MockFileAcquirer.returnFile(null));
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null));
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
new file mode 100644
index 00000000000..1d55fdf9e6a
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -0,0 +1,69 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class OnnxEvaluatorTest {
+
+ private static final double delta = 0.00000000001;
+
+ @Test
+ public void testOnnxEvaluation() {
+ ModelsEvaluator models = createModels("src/test/resources/config/onnx/");
+
+ assertTrue(models.models().containsKey("add_mul"));
+ assertTrue(models.models().containsKey("one_layer"));
+
+ FunctionEvaluator function = models.evaluatorOf("add_mul", "output1");
+ function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
+ function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
+ assertEquals(6.0, function.evaluate().sum().asDouble(), delta);
+
+ function = models.evaluatorOf("add_mul", "output2");
+ function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
+ function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
+ assertEquals(5.0, function.evaluate().sum().asDouble(), delta);
+
+ function = models.evaluatorOf("one_layer");
+ function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]"));
+ assertEquals(function.evaluate(), Tensor.from("tensor<float>(d0[2],d1[1]):[0.63931,0.67574]"));
+ }
+
+ private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value()));
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
new file mode 100644
index 00000000000..0da7f2ed096
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -0,0 +1,76 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.handler;
+
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.container.jdisc.HttpResponse;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.JsonFormat;
+
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.Collections;
+import java.util.Map;
+import java.util.concurrent.Executors;
+
+import static org.junit.Assert.assertEquals;
+
+class HandlerTester {
+
+ private final ModelsEvaluationHandler handler;
+
+ HandlerTester(ModelsEvaluator models) {
+ this.handler = new ModelsEvaluationHandler(models, Executors.newSingleThreadExecutor());
+ }
+
+ void assertResponse(String url, int expectedCode) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, (String)null);
+ }
+
+ void assertResponse(String url, int expectedCode, String expectedResult) {
+ assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
+ }
+
+ void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
+ HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
+ HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
+ assertResponse(getRequest, expectedCode, expectedResult);
+ assertResponse(postRequest, expectedCode, expectedResult);
+ }
+
+ void assertResponse(String url, Map<String, String> properties, int expectedCode, Tensor expectedResult) {
+ HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
+ assertResponse(getRequest, expectedCode, expectedResult);
+ }
+
+ void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
+ HttpResponse response = handler.handle(request);
+ assertEquals("application/json", response.getContentType());
+ assertEquals(expectedCode, response.getStatus());
+ if (expectedResult != null) {
+ assertEquals(expectedResult, getContents(response));
+ }
+ }
+
+ void assertResponse(HttpRequest request, int expectedCode, Tensor expectedResult) {
+ HttpResponse response = handler.handle(request);
+ assertEquals("application/json", response.getContentType());
+ assertEquals(expectedCode, response.getStatus());
+ if (expectedResult != null) {
+ String contents = getContents(response);
+ Tensor result = JsonFormat.decode(expectedResult.type(), contents.getBytes(StandardCharsets.UTF_8));
+ assertEquals(expectedResult, result);
+ }
+ }
+
+ private String getContents(HttpResponse response) {
+ try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
+ response.render(stream);
+ return stream.toString();
+ } catch (IOException e) {
+ throw new Error(e);
+ }
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index c9e49d3be02..a69a220e532 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -5,51 +5,41 @@ import ai.vespa.models.evaluation.ModelTester;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
-import com.yahoo.container.jdisc.HttpRequest;
-import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import org.junit.BeforeClass;
import org.junit.Test;
-import java.io.ByteArrayOutputStream;
-import java.io.IOException;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
-import java.util.concurrent.Executor;
-import java.util.concurrent.Executors;
-
-import static org.junit.Assert.assertEquals;
public class ModelsEvaluationHandlerTest {
- private static ModelsEvaluationHandler handler;
+ private static HandlerTester handler;
@BeforeClass
static public void setUp() {
- Executor executor = Executors.newSingleThreadExecutor();
- ModelsEvaluator models = createModels("src/test/resources/config/models/");
- handler = new ModelsEvaluationHandler(models, executor);
+ handler = new HandlerTester(createModels("src/test/resources/config/models/"));
}
@Test
public void testUnknownAPI() {
- assertResponse("http://localhost/wrong-api-binding", 404);
+ handler.assertResponse("http://localhost/wrong-api-binding", 404);
}
@Test
public void testUnknownVersion() {
- assertResponse("http://localhost/model-evaluation/v0", 404);
+ handler.assertResponse("http://localhost/model-evaluation/v0", 404);
}
@Test
public void testNonExistingModel() {
- assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404);
+ handler.assertResponse("http://localhost/model-evaluation/v1/non-existing-model", 404);
}
@Test
@@ -57,14 +47,14 @@ public class ModelsEvaluationHandlerTest {
String url = "http://localhost/model-evaluation/v1";
String expected =
"{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testXgBoostEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval"; // only has a single function
String expected = "{\"cells\":[{\"address\":{},\"value\":-4.376589999999999}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
@@ -77,7 +67,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -90,14 +80,14 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/xgboost_2_2/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":-7.936679999999999}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testLightGBMEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":1.9130086820218188}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
@@ -110,7 +100,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":2.054697758469921}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -123,35 +113,35 @@ public class ModelsEvaluationHandlerTest {
properties.put("non-existing-binding", "-1");
String url = "http://localhost/model-evaluation/v1/lightgbm_regression/eval";
String expected = "{\"cells\":[{\"address\":{},\"value\":2.0745534018208094}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSoftmaxDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_softmax";
String expected = "{\"model\":\"mnist_softmax\",\"functions\":[{\"function\":\"default.add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/";
String expected = "{\"model\":\"mnist_softmax\",\"function\":\"default.add\",\"info\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval\",\"arguments\":[{\"name\":\"Placeholder\",\"type\":\"tensor(d0[],d1[784])\"}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
- assertResponse(url, 400, expected);
+ handler.assertResponse(url, 400, expected);
}
@Test
public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
- assertResponse(url, 400, expected);
+ handler.assertResponse(url, 400, expected);
}
@Test
@@ -160,7 +150,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
@@ -169,28 +159,28 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, properties, 200, expected);
+ handler.assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedTypeDetails() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/";
String expected = "{\"model\":\"mnist_saved\",\"function\":\"serving_default.y\",\"info\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}";
- assertResponse(url, 200, expected);
+ handler.assertResponse(url, 200, expected);
}
@Test
public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
- assertResponse(url, 404, expected);
+ handler.assertResponse(url, 404, expected);
}
@Test
@@ -199,40 +189,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("input", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval";
String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}";
- assertResponse(url, properties, 200, expected);
- }
-
- static private void assertResponse(String url, int expectedCode) {
- assertResponse(url, Collections.emptyMap(), expectedCode, null);
- }
-
- static private void assertResponse(String url, int expectedCode, String expectedResult) {
- assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
- }
-
- static private void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
- HttpRequest getRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.GET, null, properties);
- HttpRequest postRequest = HttpRequest.createTestRequest(url, com.yahoo.jdisc.http.HttpRequest.Method.POST, null, properties);
- assertResponse(getRequest, expectedCode, expectedResult);
- assertResponse(postRequest, expectedCode, expectedResult);
- }
-
- static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
- HttpResponse response = handler.handle(request);
- assertEquals("application/json", response.getContentType());
- if (expectedResult != null) {
- assertEquals(expectedResult, getContents(response));
- }
- assertEquals(expectedCode, response.getStatus());
- }
-
- static private String getContents(HttpResponse response) {
- try (ByteArrayOutputStream stream = new ByteArrayOutputStream()) {
- response.render(stream);
- return stream.toString();
- } catch (IOException e) {
- throw new Error(e);
- }
+ handler.assertResponse(url, properties, 200, expected);
}
static private ModelsEvaluator createModels(String path) {
@@ -241,10 +198,12 @@ public class ModelsEvaluationHandlerTest {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
ModelTester.RankProfilesConfigImporterWithMockedConstants importer =
new ModelTester.RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"),
MockFileAcquirer.returnFile(null));
- return new ModelsEvaluator(importer.importFrom(config, constantsConfig));
+ return new ModelsEvaluator(importer.importFrom(config, constantsConfig, onnxModelsConfig));
}
private String inputTensor() {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
new file mode 100644
index 00000000000..6cfda4d8ce8
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -0,0 +1,137 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.handler;
+
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
+import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+
+public class OnnxEvaluationHandlerTest {
+
+ private static HandlerTester handler;
+
+ @BeforeClass
+ static public void setUp() {
+ handler = new HandlerTester(createModels("src/test/resources/config/onnx/"));
+ }
+
+ @Test
+ public void testListModels() {
+ String url = "http://localhost/model-evaluation/v1";
+ String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," +
+ "\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," +
+ "\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testModelInfo() {
+ String url = "http://localhost/model-evaluation/v1/add_mul";
+ String expected = "{\"model\":\"add_mul\",\"functions\":[" +
+ "{\"function\":\"output1\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output1\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"onnxModel(add_mul).output1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
+ "]}," +
+ "{\"function\":\"output2\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/add_mul/output2\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"onnxModel(add_mul).output2\",\"type\":\"tensor<float>(d0[1])\"}," +
+ "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
+ "]}]}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testEvaluationWithoutSpecifyingOutput() {
+ String url = "http://localhost/model-evaluation/v1/add_mul/eval";
+ String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}";
+ handler.assertResponse(url, 404, expected);
+ }
+
+ @Test
+ public void testEvaluationWithoutBindings() {
+ String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
+ String expected = "{\"error\":\"Argument 'input2' must be bound to a value of type tensor<float>(d0[1])\"}";
+ handler.assertResponse(url, 400, expected);
+ }
+
+ @Test
+ public void testEvaluationOutput1() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input1", "tensor<float>(d0[1]):[2]");
+ properties.put("input2", "tensor<float>(d0[1]):[3]");
+ String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}"; // output1 is a mul
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testEvaluationOutput2() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input1", "tensor<float>(d0[1]):[2]");
+ properties.put("input2", "tensor<float>(d0[1]):[3]");
+ String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}"; // output2 is an add
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
+ public void testBatchDimensionModelInfo() {
+ String url = "http://localhost/model-evaluation/v1/one_layer";
+ String expected = "{\"model\":\"one_layer\",\"functions\":[" +
+ "{\"function\":\"output\"," +
+ "\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," +
+ "\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," +
+ "\"arguments\":[" +
+ "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}," +
+ "{\"name\":\"onnxModel(one_layer)\",\"type\":\"tensor<float>(d0[],d1[1])\"}" +
+ "]}]}";
+ handler.assertResponse(url, 200, expected);
+ }
+
+ @Test
+ public void testBatchDimensionEvaluation() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("input", "tensor<float>(d0[],d1[3]):{{d0:0,d1:0}:0.1,{d0:0,d1:1}:0.2,{d0:0,d1:2}:0.3,{d0:1,d1:0}:0.4,{d0:1,d1:1}:0.5,{d0:1,d1:2}:0.6}");
+ String url = "http://localhost/model-evaluation/v1/one_layer/eval"; // output not specified
+ Tensor expected = Tensor.from("tensor<float>(d0[2],d1[1]):[0.6393113,0.67574286]");
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ static private ModelsEvaluator createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ OnnxModelsConfig onnxModelsConfig = new ConfigGetter<>(new FileSource(configDir.append("onnx-models.cfg").toFile()),
+ OnnxModelsConfig.class).getConfig("");
+
+ Map<String, File> fileMap = new HashMap<>();
+ for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
+ fileMap.put(onnxModel.fileref().value(), new File(path + onnxModel.fileref().value()));
+ }
+ FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);
+
+ return new ModelsEvaluator(config, constantsConfig, onnxModelsConfig, fileAcquirer);
+ }
+
+}
diff --git a/model-evaluation/src/test/resources/config/models/onnx-models.cfg b/model-evaluation/src/test/resources/config/models/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/models/onnx-models.cfg
diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx
new file mode 100644
index 00000000000..ab054d112e9
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.onnx
@@ -0,0 +1,24 @@
+
+add_mul.py:£
+
+input1
+input2output1"Mul
+
+input1
+input2output2"Addadd_mulZ
+input1
+
+
+Z
+input2
+
+
+b
+output1
+
+
+b
+output2
+
+
+B \ No newline at end of file
diff --git a/model-evaluation/src/test/resources/config/onnx/models/add_mul.py b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py
new file mode 100755
index 00000000000..3a4522042e8
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/add_mul.py
@@ -0,0 +1,30 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+INPUT_1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [1])
+INPUT_2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, [1])
+OUTPUT_1 = helper.make_tensor_value_info('output1', TensorProto.FLOAT, [1])
+OUTPUT_2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, [1])
+
+nodes = [
+ helper.make_node(
+ 'Mul',
+ ['input1', 'input2'],
+ ['output1'],
+ ),
+ helper.make_node(
+ 'Add',
+ ['input1', 'input2'],
+ ['output2'],
+ ),
+]
+graph_def = helper.make_graph(
+ nodes,
+ 'add_mul',
+ [INPUT_1, INPUT_2],
+ [OUTPUT_1, OUTPUT_2],
+)
+model_def = helper.make_model(graph_def, producer_name='add_mul.py', opset_imports=[onnx.OperatorSetIdProto(version=12)])
+onnx.save(model_def, 'add_mul.onnx')
diff --git a/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx
new file mode 100644
index 00000000000..dc9f664b943
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/one_layer.onnx
Binary files differ
diff --git a/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py
new file mode 100755
index 00000000000..1296d84e180
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/models/pytorch_one_layer.py
@@ -0,0 +1,38 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import torch
+import torch.onnx
+
+
+class MyModel(torch.nn.Module):
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self.linear = torch.nn.Linear(in_features=3, out_features=1)
+ self.logistic = torch.nn.Sigmoid()
+
+ def forward(self, vec):
+ return self.logistic(self.linear(vec))
+
+
+def main():
+ model = MyModel()
+
+ # Omit training - just export randomly initialized network
+
+ data = torch.FloatTensor([[0.1, 0.2, 0.3],[0.4, 0.5, 0.6]])
+ torch.onnx.export(model,
+ data,
+ "one_layer.onnx",
+ input_names = ["input"],
+ output_names = ["output"],
+ dynamic_axes = {
+ "input": {0: "batch"},
+ "output": {0: "batch"},
+ },
+ opset_version=12)
+
+
+if __name__ == "__main__":
+ main()
+
+
diff --git a/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg
new file mode 100644
index 00000000000..9ad9c7f6a07
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/onnx-models.cfg
@@ -0,0 +1,16 @@
+model[0].name "add_mul"
+model[0].fileref "models/add_mul.onnx"
+model[0].input[0].name "input1"
+model[0].input[0].source "input1"
+model[0].input[1].name "input2"
+model[0].input[1].source "input2"
+model[0].output[0].name "output1"
+model[0].output[0].as "output1"
+model[0].output[1].name "output2"
+model[0].output[1].as "output2"
+model[1].name "one_layer"
+model[1].fileref "models/one_layer.onnx"
+model[1].input[0].name "input"
+model[1].input[0].source "input"
+model[1].output[0].name "output"
+model[1].output[0].as "output"
diff --git a/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg
new file mode 100644
index 00000000000..047b7c3c77b
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/rank-profiles.cfg
@@ -0,0 +1,17 @@
+rankprofile[0].name "add_mul"
+rankprofile[0].fef.property[0].name "rankingExpression(output1).rankingScript"
+rankprofile[0].fef.property[0].value "onnxModel(add_mul).output1"
+rankprofile[0].fef.property[1].name "rankingExpression(output1).type"
+rankprofile[0].fef.property[1].value "tensor<float>(d0[1])"
+rankprofile[0].fef.property[2].name "rankingExpression(output2).rankingScript"
+rankprofile[0].fef.property[2].value "onnxModel(add_mul).output2"
+rankprofile[0].fef.property[3].name "rankingExpression(output2).type"
+rankprofile[0].fef.property[3].value "tensor<float>(d0[1])"
+rankprofile[1].name "one_layer"
+rankprofile[1].fef.property[0].name "rankingExpression(output).rankingScript"
+rankprofile[1].fef.property[0].value "onnxModel(one_layer)"
+rankprofile[1].fef.property[1].name "rankingExpression(output).type"
+rankprofile[1].fef.property[1].value "tensor<float>(d0[],d1[1])"
+rankprofile[2].name "no_model"
+rankprofile[2].fef.property[0].name "rankingExpression(output).rankingScript"
+rankprofile[2].fef.property[0].value "onnxModel(no_model)"
diff --git a/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/onnx/ranking-constants.cfg
diff --git a/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/rankexpression/onnx-models.cfg
diff --git a/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/smallconstant/onnx-models.cfg