diff options
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java')
-rw-r--r-- | model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java new file mode 100644 index 00000000000..48c7f5bee19 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -0,0 +1,28 @@ +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.ImportedModel; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author bratseth + */ +public class XGBoostImportTestCase { + + @Test + public void testXGBoost() { + ImportedModel model = new XGBoostImporter().importModel("test", "src/test/models/xgboost/xgboost.2.2.json"); + assertTrue("All inputs are scalar", model.inputs().isEmpty()); + assertEquals(1, model.expressions().size()); + System.out.println(model.expressions().keySet()); + RankingExpression expression = model.expressions().get("test"); + assertNotNull(expression); + assertEquals("if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)", + expression.getRoot().toString()); + } + +} |