aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/gbdt/GbdtModelTestCase.java
blob: cb6dc6247c9aae0020d92cadbc03740f9c92da60 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.gbdt;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;

import static org.junit.Assert.*;

/**
 * Tests for {@link GbdtModel}: loading models from XML and converting them to ranking expressions.
 *
 * @author Simon Thoresen Hult
 */
public class GbdtModelTestCase {

    @Test
    public void requireThatFactoryMethodWorks() throws Exception {
        GbdtModel model = GbdtModel.fromXmlFile("src/test/files/gbdt.xml");
        assertEquals(10, model.trees().size());
        String exp = model.toRankingExpression();
        assertEquals(readFile("src/test/files/gbdt.expression").trim(), exp.trim());
        // A constructor never returns null; the meaningful check here is that constructing a
        // RankingExpression from the generated text completes without throwing.
        assertNotNull(new RankingExpression(exp));
    }

    @Test
    public void requireThatIllegalXmlThrowsException() throws Exception {
        // Unknown root element.
        assertIllegalXml("<Unknown />");
        // DecisionTree without any children.
        assertIllegalXml("<DecisionTree />");
        // DecisionTree with an unknown child element.
        assertIllegalXml("<DecisionTree>" +
                         "    <Unknown />" +
                         "</DecisionTree>");
        // DecisionTree with an empty Forest.
        assertIllegalXml("<DecisionTree>" +
                         "    <Forest />" +
                         "</DecisionTree>");
        // Forest with an unknown child element.
        assertIllegalXml("<DecisionTree>" +
                         "    <Forest>" +
                         "        <Unknown />" +
                         "    </Forest>" +
                         "</DecisionTree>");
    }

    /**
     * Asserts that {@link GbdtModel#fromXml(String)} rejects the given XML by throwing an
     * {@link IllegalArgumentException}. Any other exception propagates and fails the test.
     *
     * @param xml the XML document that is expected to be rejected
     */
    private static void assertIllegalXml(String xml) throws Exception {
        try {
            GbdtModel.fromXml(xml);
            fail("Expected IllegalArgumentException for XML: " + xml);
        } catch (IllegalArgumentException expected) {
            // Expected outcome: the illegal XML was rejected.
        }
    }

    /**
     * Reads a text file as UTF-8 and returns its lines, each followed by {@code '\n'}
     * (so line terminators are normalized and the result ends with a newline when non-empty).
     *
     * @param file path of the file to read
     * @return the file content with normalized line endings
     * @throws IOException if the file cannot be opened or read
     */
    private static String readFile(String file) throws IOException {
        StringBuilder ret = new StringBuilder();
        // try-with-resources guarantees the file handle is released even if reading fails;
        // the explicit charset keeps the result independent of the platform default encoding.
        try (BufferedReader in = new BufferedReader(
                new InputStreamReader(new FileInputStream(file), StandardCharsets.UTF_8))) {
            String line;
            while ((line = in.readLine()) != null) {
                ret.append(line).append('\n');
            }
        }
        return ret.toString();
    }
}