summaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java28
1 files changed, 28 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
new file mode 100644
index 00000000000..48c7f5bee19
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -0,0 +1,28 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author bratseth
+ */
+public class XGBoostImportTestCase {
+
+ @Test
+ public void testXGBoost() {
+ ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json");
+ assertTrue("All inputs are scalar", model.inputs().isEmpty());
+ assertEquals(1, model.expressions().size());
+ System.out.println(model.expressions().keySet());
+ RankingExpression expression = model.expressions().get("test");
+ assertNotNull(expression);
+ assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)",
+ expression.getRoot().toString());
+ }
+
+}