summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-20 12:25:57 +0100
committerLester Solbakken <lesters@oath.com>2019-11-20 12:25:57 +0100
commit414cc25cdfc47500377b7e9d5717889107be325f (patch)
tree4538cdb2b39c8f4e451bd2619318f8a6a15c90ee /model-integration/src/main
parent063dcfef8881dbb10775f7f2983a86ccc9b7b9da (diff)
Add XGBoost if-inversion for missing features
Diffstat (limited to 'model-integration/src/main')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java9
1 files changed, 5 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
index 9de07eed475..c41a114a970 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
@@ -16,7 +16,6 @@ import com.fasterxml.jackson.databind.ObjectMapper;
class XGBoostParser {
private List<XGBoostTree> xgboostTrees;
- private boolean doIfInversion = false;
/**
* Constructor stores parsed JSON trees.
@@ -32,7 +31,6 @@ class XGBoostParser {
for (JsonNode treeNode : forestNode) {
this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
}
- doIfInversion = filePath.endsWith("if_inversion.json");
}
/**
@@ -71,9 +69,12 @@ class XGBoostParser {
trueExp = treeToRankExp(node.getChildren().get(1));
falseExp = treeToRankExp(node.getChildren().get(0));
}
- String condition = node.getSplit() + " < " + node.getSplit_condition();
- if (doIfInversion && node.getMissing() == node.getYes()) {
+ String condition;
+ if (node.getMissing() == node.getYes()) {
+ // Note: this is for handling missing features, as the backend handles comparison with NaN as false.
condition = "!(" + node.getSplit() + " >= " + node.getSplit_condition() + ")";
+ } else {
+ condition = node.getSplit() + " < " + node.getSplit_condition();
}
return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
}