aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-10-11 09:44:10 +0200
committerLester Solbakken <lesters@oath.com>2019-10-11 09:44:10 +0200
commit8c8932cb055dcf39102b7b93d73bdbe6650af0c2 (patch)
treea053bee4364ee977bd629e7f71b21a7c79a54f54 /model-integration/src/test/java/ai/vespa/rankingexpression/importer
parente5ba4ddccfa45b440b6803eb1665c3b2e6f19be9 (diff)
Test xgboost java evaluation with missing values
Diffstat (limited to 'model-integration/src/test/java/ai/vespa/rankingexpression/importer')
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java86
1 files changed, 86 insertions, 0 deletions
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
new file mode 100644
index 00000000000..78137352512
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportEvaluationTestCase.java
@@ -0,0 +1,86 @@
+package ai.vespa.rankingexpression.importer.xgboost;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestNode;
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class XGBoostImportEvaluationTestCase {
+
+ @Test
+ public void testXGBoostEvaluation() {
+ RankingExpression expression = new XGBoostImporter()
+ .importModel("xgb", "src/test/models/xgboost/xgboost.test.if_inversion.json")
+ .expressions().get("xgb");
+
+ ArrayContext context = new ArrayContext(expression, DoubleValue.NaN);
+
+ assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+ assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+ assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+ assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ assertTrue(expression.getRoot() instanceof GBDTForestNode);
+
+// assertXGBoostEvaluation(1.0, expression, features(context, "f1", 0.0, "f2", 0.0));
+// assertXGBoostEvaluation(2.0, expression, features(context, "f1", 0.0, "f2", 1.0));
+// assertXGBoostEvaluation(3.0, expression, features(context, "f1", 1.0, "f2", 0.0));
+// assertXGBoostEvaluation(4.0, expression, features(context, "f1", 1.0, "f2", 1.0));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", 0.0));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", 1.0));
+ assertXGBoostEvaluation(7.0, expression, features(context, "f2", 0.0));
+ assertXGBoostEvaluation(9.0, expression, features(context, "f2", 1.0));
+ assertXGBoostEvaluation(11.0, expression, features(context));
+ assertXGBoostEvaluation(5.0, expression, features(context, "f1", Tensor.from(0.0)));
+ assertXGBoostEvaluation(6.0, expression, features(context, "f1", Tensor.from(1.0)));
+ }
+
+ private ArrayContext features(ArrayContext context) {
+ return context.clone();
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1) {
+ context = context.clone();
+ context.put(f1, v1);
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, Tensor v1) {
+ context = context.clone();
+ context.put(f1, new TensorValue(v1));
+ return context;
+ }
+
+ private ArrayContext features(ArrayContext context, String f1, double v1, String f2, double v2) {
+ context = context.clone();
+ context.put(f1, v1);
+ context.put(f2, v2);
+ return context;
+ }
+
+ private void assertXGBoostEvaluation(double expected, RankingExpression expr, Context context) {
+ assertEquals(expected, expr.evaluate(context).asDouble(), 1e-9);
+ }
+
+}