summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-05-31 17:55:21 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-05-31 17:55:21 +0200
commit986c2da2986a2fc0de4895a8107c85e4d0f37fd3 (patch)
tree3d7934b9feb062b9d1d48f7d4f88734ab8fecd9b /model-integration/src/test
parent470e70ea9fe12681bf0427497cf470ac76b9eb95 (diff)
Support native Vespa standalone models
Diffstat (limited to 'model-integration/src/test')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java58
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java1
-rw-r--r--model-integration/src/test/models/vespa/empty.model2
-rw-r--r--model-integration/src/test/models/vespa/example.model10
-rw-r--r--model-integration/src/test/models/vespa/misnamed.model3
5 files changed, 73 insertions, 1 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
new file mode 100644
index 00000000000..4f9fb9c070a
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -0,0 +1,58 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.vespa;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class VespaImportTestCase {
+
+ @Test
+ public void testExample() {
+ ImportedModel model = importModel("example");
+
+ assertEquals(1, model.inputs().size());
+ assertEquals("tensor(name{},x[10])", model.inputs().get("input1").toString());
+
+ assertEquals("var1 * var2", model.expressions().get("foo").getRoot().toString());
+ }
+
+ @Test
+ public void testEmpty() {
+ ImportedModel model = importModel("empty");
+ assertTrue(model.expressions().isEmpty());
+ assertTrue(model.functions().isEmpty());
+ assertTrue(model.inputs().isEmpty());
+ assertTrue(model.largeConstants().isEmpty());
+ assertTrue(model.smallConstants().isEmpty());
+ }
+
+ @Test
+ public void testWrongName() {
+ try {
+ importModel("misnamed");
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Model 'expectedname' must be saved in a file named 'expectedname.model'", e.getMessage());
+ }
+ }
+
+ private ImportedModel importModel(String name) {
+ String modelPath = "src/test/models/vespa/" + name + ".model";
+
+ VespaImporter importer = new VespaImporter();
+ assertTrue(importer.canImport(modelPath));
+ ImportedModel model = new VespaImporter().importModel(name, modelPath);
+ assertEquals(name, model.name());
+ assertEquals(modelPath, model.source());
+ return model;
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 965d5eb8577..67a3b17255c 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -18,7 +18,6 @@ public class XGBoostImportTestCase {
ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json");
assertTrue("All inputs are scalar", model.inputs().isEmpty());
assertEquals(1, model.expressions().size());
- System.out.println(model.expressions().keySet());
RankingExpression expression = model.expressions().get("test");
assertNotNull(expression);
assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
diff --git a/model-integration/src/test/models/vespa/empty.model b/model-integration/src/test/models/vespa/empty.model
new file mode 100644
index 00000000000..f5381b2ba93
--- /dev/null
+++ b/model-integration/src/test/models/vespa/empty.model
@@ -0,0 +1,2 @@
+model empty {
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model
new file mode 100644
index 00000000000..19598690aad
--- /dev/null
+++ b/model-integration/src/test/models/vespa/example.model
@@ -0,0 +1,10 @@
+model example {
+
+ input1: tensor(name{}, x[10])
+
+
+ function foo() {
+ expression: var1 * var2
+ }
+
+} \ No newline at end of file
diff --git a/model-integration/src/test/models/vespa/misnamed.model b/model-integration/src/test/models/vespa/misnamed.model
new file mode 100644
index 00000000000..44bfa5e380d
--- /dev/null
+++ b/model-integration/src/test/models/vespa/misnamed.model
@@ -0,0 +1,3 @@
+model expectedname {
+
+} \ No newline at end of file