diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-20 12:25:57 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-20 12:25:57 +0100 |
commit | 414cc25cdfc47500377b7e9d5717889107be325f (patch) | |
tree | 4538cdb2b39c8f4e451bd2619318f8a6a15c90ee /model-integration/src/main | |
parent | 063dcfef8881dbb10775f7f2983a86ccc9b7b9da (diff) |
Add XGBoost if-inversion for missing features
Diffstat (limited to 'model-integration/src/main')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java index 9de07eed475..c41a114a970 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java @@ -16,7 +16,6 @@ import com.fasterxml.jackson.databind.ObjectMapper; class XGBoostParser { private List<XGBoostTree> xgboostTrees; - private boolean doIfInversion = false; /** * Constructor stores parsed JSON trees. @@ -32,7 +31,6 @@ class XGBoostParser { for (JsonNode treeNode : forestNode) { this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class)); } - doIfInversion = filePath.endsWith("if_inversion.json"); } /** @@ -71,9 +69,12 @@ class XGBoostParser { trueExp = treeToRankExp(node.getChildren().get(1)); falseExp = treeToRankExp(node.getChildren().get(0)); } - String condition = node.getSplit() + " < " + node.getSplit_condition(); - if (doIfInversion && node.getMissing() == node.getYes()) { + String condition; + if (node.getMissing() == node.getYes()) { + // Note: this is for handling missing features, as the backend handles comparison with NaN as false. condition = "!(" + node.getSplit() + " >= " + node.getSplit_condition() + ")"; + } else { + condition = node.getSplit() + " < " + node.getSplit_condition(); } return "if (" + condition + ", " + trueExp + ", " + falseExp + ")"; } |