diff options
author | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:44:10 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-10-11 09:44:10 +0200 |
commit | 8c8932cb055dcf39102b7b93d73bdbe6650af0c2 (patch) | |
tree | a053bee4364ee977bd629e7f71b21a7c79a54f54 /model-integration | |
parent | e5ba4ddccfa45b440b6803eb1665c3b2e6f19be9 (diff) |
Test xgboost java evaluation with missing values
Diffstat (limited to 'model-integration')
2 files changed, 112 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java new file mode 100644 index 00000000000..78137352512 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java @@ -0,0 +1,86 @@ +package ai.vespa.rankingexpression.importer.xgboost; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode; +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author lesters + */ +public class XGBoostImportEvaluationTestCase { + + @Test + public void testXGBoostEvaluation() { + RankingExpression expression = new XGBoostImporter() + .importModel("xgb", "src/test/models/xgboost/xgboost.test.if_inversion.json") + .expressions().get("xgb"); + + ArrayContext context = new ArrayContext(expression, DoubleValue.NaN); + + assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0)); + assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0)); + assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0)); + assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0)); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0)); + assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0)); + assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0)); + assertXGBoostEvaluation(11.0, expression, features(context)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0))); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0))); + + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + assertTrue(expression.getRoot() instanceof GBDTForestNode); + +// assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0)); +// assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0)); +// assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0)); +// assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0)); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0)); + assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0)); + assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0)); + assertXGBoostEvaluation(11.0, expression, features(context)); + assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0))); + assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0))); + } + + private ArrayContext features(ArrayContext context) { + return context.clone(); + } + + private ArrayContext features(ArrayContext context, String f1, double v1) { + context = context.clone(); + context.put(f1, v1); + return context; + } + + private ArrayContext features(ArrayContext context, String f1, Tensor v1) { + context = context.clone(); + context.put(f1, new TensorValue(v1)); + return context; + } + + private ArrayContext features(ArrayContext context, String f1, double v1, String f2, double v2) { + context = context.clone(); + context.put(f1, v1); + context.put(f2, v2); + return context; + } + + private void assertXGBoostEvaluation(double expected, RankingExpression expr, Context context) { + assertEquals(expected, expr.evaluate(context).asDouble(), 1e-9); + } + +} diff --git a/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json new file mode 100644 index 00000000000..8994d89787e --- /dev/null +++ b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json @@ -0,0 +1,26 @@ +[ + { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 0.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [ + { "nodeid": 3, "leaf": 1.0 }, + { "nodeid": 4, "leaf": 2.0 } + ]}, + { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 5, "children": [ + { "nodeid": 5, "leaf": 3.0 }, + { "nodeid": 6, "leaf": 4.0 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "leaf": 0.0 }, + { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "leaf": 4.0 }, + { "nodeid": 4, "leaf": 5.0 } + ]} + ]}, + { "nodeid": 0, "depth": 0, "split": "f2", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [ + { "nodeid": 1, "leaf": 0.0 }, + { "nodeid": 2, "depth": 1, "split": "f1", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [ + { "nodeid": 3, "leaf": 4.0 }, + { "nodeid": 4, "leaf": 3.0 } + ]} + ]} +]
\ No newline at end of file |