aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostParser.java
blob: cfb7c5b76bd447f69e780b8e1673432e673464d1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.xgboost;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
 * Reads a JSON file containing XGBoost trees and converts the trees to a Vespa ranking expression.
 *
 * @author grace-lam
 */
class XGBoostParser {

    /** Text placed between the expressions of two consecutive trees; the trees are summed. */
    private static final String TREE_SEPARATOR = " + \n";

    private final List<XGBoostTree> xgboostTrees;

    /**
     * Constructor stores parsed JSON trees: every element found when iterating the root JSON node
     * is bound to one {@link XGBoostTree}.
     *
     * @param filePath XGBoost JSON input file.
     * @throws JsonProcessingException Fails JSON parsing.
     * @throws IOException             Fails file reading.
     */
    XGBoostParser(String filePath) throws JsonProcessingException, IOException {
        this.xgboostTrees = new ArrayList<>();
        ObjectMapper mapper = new ObjectMapper();
        JsonNode forestNode = mapper.readTree(new File(filePath));
        for (JsonNode treeNode : forestNode) {
            this.xgboostTrees.add(mapper.treeToValue(treeNode, XGBoostTree.class));
        }
    }

    /**
     * Converts parsed JSON trees to Vespa ranking expressions.
     *
     * The expressions of the individual trees are joined with {@code " + \n"}; when no trees were
     * parsed the result is the empty string.
     *
     * @return Vespa ranking expressions.
     * @throws IllegalArgumentException if a non-leaf node does not have exactly two children.
     */
    String toRankingExpression() {
        List<String> treeExpressions = new ArrayList<>(xgboostTrees.size());
        for (XGBoostTree tree : xgboostTrees) {
            treeExpressions.add(treeToRankExp(tree));
        }
        return String.join(TREE_SEPARATOR, treeExpressions);
    }

    /**
     * Recursive helper function for toRankingExpression().
     *
     * A leaf becomes its value; a split node becomes {@code if (condition, trueExp, falseExp)}
     * where {@code trueExp} is the expression of the child whose node id equals the node's
     * "yes" id and {@code falseExp} is the expression of the other child.
     *
     * @param node XGBoost tree node to convert.
     * @return Vespa ranking expression for input node.
     * @throws IllegalArgumentException if a non-leaf node does not have exactly two children.
     */
    private String treeToRankExp(XGBoostTree node) {
        if (node.isLeaf()) {
            return Double.toString(node.getLeaf());
        }

        // Validate explicitly: a Java 'assert' is skipped unless the JVM runs with -ea, so a
        // malformed node would otherwise fail with an unrelated exception or be silently truncated.
        if (node.getChildren() == null || node.getChildren().size() != 2) {
            throw new IllegalArgumentException(
                    "XGBoost split node " + node.getNodeid() + " must have exactly 2 children, but has "
                    + (node.getChildren() == null ? "none" : String.valueOf(node.getChildren().size())));
        }

        XGBoostTree firstChild = node.getChildren().get(0);
        XGBoostTree secondChild = node.getChildren().get(1);

        // The child whose id matches 'yes' is taken when the split condition holds.
        boolean firstIsYes = node.getYes() == firstChild.getNodeid();
        String trueExp = treeToRankExp(firstIsYes ? firstChild : secondChild);
        String falseExp = treeToRankExp(firstIsYes ? secondChild : firstChild);

        String condition;
        if (node.getMissing() == node.getYes()) {
            // Note: this is for handling missing features, as the backend handles comparison with NaN as false.
            // Negating ">=" therefore sends a missing (NaN) feature down the "yes" branch.
            condition = "!(" + node.getSplit() + " >= " + node.getSplit_condition() + ")";
        } else {
            condition = node.getSplit() + " < " + node.getSplit_condition();
        }
        return "if (" + condition + ", " + trueExp + ", " + falseExp + ")";
    }

}