summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-10-11 09:44:10 +0200
committerLester Solbakken <lesters@oath.com>2019-10-11 09:44:10 +0200
commit8c8932cb055dcf39102b7b93d73bdbe6650af0c2 (patch)
treea053bee4364ee977bd629e7f71b21a7c79a54f54 /model-integration
parente5ba4ddccfa45b440b6803eb1665c3b2e6f19be9 (diff)
Test xgboost java evaluation with missing values
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java86
-rw-r--r--model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json26
2 files changed, 112 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
new file mode 100644
index 00000000000..78137352512
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
@@ -0,0 +1,86 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class XGBoostImportEvaluationTestCase {
+
+ @Test
+ public void testXGBoostEvaluation() {
+ RankingExpression expression = new XGBoostImporter()
+ .importModel("xgb", "src/test/models/xgboost/xgboost.test.if_inversion.json")
+ .expressions().get("xgb");
+
+ ArrayContext context = new ArrayContext(expression, DoubleValue.NaN);
+
+ assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+ assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+ assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+ assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ assertTrue(expression.getRoot() instanceof GBDTForestNode);
+
+// assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+// assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+// assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+// assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+ }
+
+ private ArrayContext features(ArrayContext context) {
+ return context.clone();
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1) {
+ context = context.clone();
+ context.put(f1, v1);
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, Tensor v1) {
+ context = context.clone();
+ context.put(f1, new TensorValue(v1));
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1, String f2, double v2) {
+ context = context.clone();
+ context.put(f1, v1);
+ context.put(f2, v2);
+ return context;
+ }
+
+ private void assertXGBoostEvaluation(double expected, RankingExpression expr, Context context) {
+ assertEquals(expected, expr.evaluate(context).asDouble(), 1e-9);
+ }
+
+}
diff --git a/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json
new file mode 100644
index 00000000000..8994d89787e
--- /dev/null
+++ b/model-integration/src/test/models/xgboost/xgboost.test.if_inversion.json
@@ -0,0 +1,26 @@
+[
+ { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 0.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 3, "children": [
+ { "nodeid": 3, "leaf": 1.0 },
+ { "nodeid": 4, "leaf": 2.0 }
+ ]},
+ { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 5, "no": 6, "missing": 5, "children": [
+ { "nodeid": 5, "leaf": 3.0 },
+ { "nodeid": 6, "leaf": 4.0 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f1", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "leaf": 0.0 },
+ { "nodeid": 2, "depth": 1, "split": "f2", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [
+ { "nodeid": 3, "leaf": 4.0 },
+ { "nodeid": 4, "leaf": 5.0 }
+ ]}
+ ]},
+ { "nodeid": 0, "depth": 0, "split": "f2", "split_condition": 1.5, "yes": 1, "no": 2, "missing": 2, "children": [
+ { "nodeid": 1, "leaf": 0.0 },
+ { "nodeid": 2, "depth": 1, "split": "f1", "split_condition": 0.5, "yes": 3, "no": 4, "missing": 4, "children": [
+ { "nodeid": 3, "leaf": 4.0 },
+ { "nodeid": 4, "leaf": 3.0 }
+ ]}
+ ]}
+] \ No newline at end of file