blob: cfb7c5b76bd447f69e780b8e1673432e673464d1 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
/**
 * Reads an XGBoost model dumped as JSON and converts it to a Vespa ranking expression.
 *
 * <p>Each top-level element of the JSON document is bound to one {@link XGBoostTree}. The
 * resulting ranking expression is the sum of the expressions generated for the individual trees.
 *
 * @author grace-lam
 */
class XGBoostParser {

    /** The parsed trees, in the order they appear in the input file. */
    private final List<XGBoostTree> xgboostTrees;

    /**
     * Constructor stores parsed JSON trees.
     *
     * @param filePath XGBoost JSON input file.
     * @throws JsonProcessingException Fails JSON parsing.
     * @throws IOException Fails file reading.
     */
    XGBoostParser(String filePath) throws JsonProcessingException, IOException {
        this.xgboostTrees = new ArrayList<>();
        ObjectMapper mapper = new ObjectMapper();
        JsonNode forestNode = mapper.readTree(new File(filePath));
        for (JsonNode treeNode : forestNode) {
            this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
        }
    }

    /**
     * Converts parsed JSON trees to Vespa ranking expressions.
     *
     * <p>The per-tree expressions are joined with {@code " + \n"}, so the result is one
     * expression summing all trees. If no trees were parsed, the result is the empty string.
     *
     * @return Vespa ranking expressions.
     * @throws IllegalArgumentException if any split node does not have exactly two children.
     */
    String toRankingExpression() {
        List<String> treeExpressions = new ArrayList<>(xgboostTrees.size());
        for (XGBoostTree tree : xgboostTrees) {
            treeExpressions.add(treeToRankExp(tree));
        }
        return String.join(" + \n", treeExpressions);
    }

    /**
     * Recursive helper function for toRankingExpression().
     *
     * <p>A leaf becomes its numeric value. A split node becomes
     * {@code if (condition, yesBranch, noBranch)}, where the "yes" branch is the child whose node
     * id equals {@code node.getYes()}.
     *
     * @param node XGBoost tree node to convert.
     * @return Vespa ranking expression for input node.
     * @throws IllegalArgumentException if a split node does not have exactly two children.
     */
    private String treeToRankExp(XGBoostTree node) {
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf());
        }

        // The trees come from an external file, so validate their shape explicitly: an assert is
        // skipped unless the JVM runs with -ea, which would let malformed input through.
        if (node.getChildren() == null || node.getChildren().size() != 2) {
            String found = node.getChildren() == null
                    ? "none"
                    : Integer.toString(node.getChildren().size());
            throw new IllegalArgumentException("XGBoost split node " + node.getNodeid()
                    + " must have exactly 2 children, but has " + found);
        }

        // The children are not guaranteed to be listed in yes/no order, so match on node id.
        String trueExp;
        String falseExp;
        if (node.getYes() == node.getChildren().get(0).getNodeid()) {
            trueExp = treeToRankExp(node.getChildren().get(0));
            falseExp = treeToRankExp(node.getChildren().get(1));
        } else {
            trueExp = treeToRankExp(node.getChildren().get(1));
            falseExp = treeToRankExp(node.getChildren().get(0));
        }

        String condition;
        if (node.getMissing() == node.getYes()) {
            // Note: this is for handling missing features, as the backend handles comparison with NaN as false.
            // Negating ">=" makes a missing (NaN) feature evaluate to true, selecting the "yes" branch.
            condition = "!(" + node.getSplit() + " >= " + node.getSplit_condition() + ")";
        } else {
            // Here a missing (NaN) feature makes "<" false, selecting the "no" branch.
            condition = node.getSplit() + " < " + node.getSplit_condition();
        }
        return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
    }

}
|