aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/gbdt/GbdtModel.java
blob: 038ce3d4bb77c4bda4039734d8ca88b58b106fd2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.gbdt;

import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import org.w3c.dom.Element;
import org.w3c.dom.Node;
import org.xml.sax.SAXException;

import javax.xml.parsers.ParserConfigurationException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;

/**
 * A GBDT model held as an unmodifiable list of trees, which can be read from XML
 * and rendered as a single ranking expression string.
 *
 * @author Simon Thoresen Hult
 */
public class GbdtModel {

    /** The trees of this model, normalized by {@link #asForest(List)}; never modified after construction. */
    private final List<TreeNode> trees;

    /**
     * Creates a model from the given trees.
     *
     * @param trees the root node of each tree; the list is copied, so later changes to it do not affect this model
     * @throws UnsupportedOperationException if a root is neither a {@code FeatureNode} nor a {@code ResponseNode}
     */
    public GbdtModel(List<TreeNode> trees) {
        this.trees = asForest(trees);
    }

    /** Returns the trees of this model as an unmodifiable list. */
    public List<TreeNode> trees() {
        return trees;
    }

    /**
     * Renders this model as a ranking expression: the expression of each tree joined by {@code " +\n"}
     * and terminated by a newline.
     *
     * As a side effect, a warning is written to standard error when the model lacks sample information.
     *
     * @return the ranking expression text
     */
    public String toRankingExpression() {
        if ( ! hasSampleInformation()) {
            System.err.println("The model nodes does not have the 'nSamples' attribute. " +
                               "For optimal runtime performance use an 'ext' model which has this information.");
        }
        StringBuilder ret = new StringBuilder();
        for (TreeNode tree : trees) {
            if (ret.length() > 0) {
                ret.append(" +\n");
            }
            ret.append(tree.toRankingExpression());
        }
        ret.append("\n");
        return ret.toString();
    }

    /**
     * Returns whether this model has sample information.
     * Only the first tree is checked, as files either have this for all nodes or for none.
     */
    private boolean hasSampleInformation() {
        if (trees.isEmpty()) {
            return true; // no trees, so there is nothing to warn about
        }
        // NOTE(review): samples() is assumed to return an Optional, consistent with the Optional handed to
        // ResponseNode in asForest(). A plain null check on an Optional is always true, which made the
        // warning in toRankingExpression() unreachable; test for presence instead.
        return trees.get(0).samples().isPresent();
    }

    /**
     * Parses the given XML text and builds a model from it.
     *
     * @param xml the XML document as a string
     * @return the model described by the document
     * @throws IllegalArgumentException if the forest contains no {@code Tree} elements
     */
    public static GbdtModel fromXml(String xml) throws ParserConfigurationException, IOException, SAXException {
        return fromDom(XmlHelper.parseXml(xml));
    }

    /**
     * Parses the given XML file and builds a model from it.
     *
     * @param fileName the name of the file to read
     * @return the model described by the file
     * @throws IllegalArgumentException if the forest contains no {@code Tree} elements
     */
    public static GbdtModel fromXmlFile(String fileName) throws ParserConfigurationException, IOException, SAXException {
        return fromDom(XmlHelper.parseXmlFile(fileName));
    }

    /**
     * Builds a model from a DOM node containing a {@code DecisionTree} element, which in turn
     * contains a {@code Forest} element with {@code Tree} children.
     *
     * {@code Tree} elements without any child elements are skipped.
     *
     * @param doc the DOM node to read from
     * @return the model described by the node
     * @throws IllegalArgumentException if the forest contains no {@code Tree} elements
     */
    public static GbdtModel fromDom(Node doc) {
        Element dtree = XmlHelper.getSingleElement(doc, "DecisionTree");
        Element forest = XmlHelper.getSingleElement(dtree, "Forest");
        List<Element> trees = XmlHelper.getChildElements(forest, "Tree");
        if (trees.isEmpty()) {
            throw new IllegalArgumentException("Forest has no trees.");
        }
        List<TreeNode> model = new ArrayList<>();
        for (Node tree : trees) {
            if (XmlHelper.getChildElements(tree, null).isEmpty()) {
                continue; // an empty Tree element contributes nothing to the model
            }
            model.add(TreeNode.fromDom(XmlHelper.getSingleElement(tree, null)));
        }
        return new GbdtModel(model);
    }

    /**
     * Copies the given roots into an unmodifiable list where every root is a {@code FeatureNode}.
     * A root which is a bare {@code ResponseNode} is wrapped in a {@code NumericFeatureNode} on
     * {@code value(0)} which keeps the original node and its sample count.
     *
     * @throws UnsupportedOperationException if a root is of any other node type
     */
    private static List<TreeNode> asForest(List<TreeNode> in) {
        List<TreeNode> out = new ArrayList<>(in.size());
        for (TreeNode node : in) {
            if (node instanceof FeatureNode) {
                out.add(node);
            } else if (node instanceof ResponseNode) { // TODO: We should stop this silliness ...
                // Presumably the condition on value(0) always selects the original response and the
                // zero-valued ResponseNode is never chosen; confirm against NumericFeatureNode.
                out.add(new NumericFeatureNode("value(0)", new DoubleValue(1), node.samples(), node,
                                               new ResponseNode(0, Optional.of(0))));
            } else {
                throw new UnsupportedOperationException(node.getClass().getName());
            }
        }
        return Collections.unmodifiableList(out);
    }
}