blob: 40b199b34531d21a7d3855bd51d802cc6237a2c0 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.gbdt;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import static org.junit.Assert.*;
/**
 * Tests for {@link GbdtModel}: building a model from an XML file, converting it to a
 * ranking expression, and rejecting XML that does not describe a valid model.
 *
 * @author Simon Thoresen Hult
 */
public class GbdtModelTestCase {

    /**
     * Loads the model in {@code src/test/files/gbdt.xml}, checks that it contains 10 trees,
     * and checks that its ranking expression equals the content of
     * {@code src/test/files/gbdt.expression} (both sides trimmed before comparison).
     */
    @Test
    public void requireThatFactoryMethodWorks() throws Exception {
        GbdtModel model = GbdtModel.fromXmlFile("src/test/files/gbdt.xml");
        assertEquals(10, model.trees().size());

        String exp = model.toRankingExpression();
        assertEquals(readFile("src/test/files/gbdt.expression").trim(), exp.trim());
        // NOTE(review): the RankingExpression constructor is presumably what validates the
        // generated string here (it would throw on unparsable input) - confirm.
        assertNotNull(new RankingExpression(exp));
    }

    /**
     * Feeds a series of malformed documents to {@link GbdtModel#fromXml} and requires each
     * of them to be rejected with an {@link IllegalArgumentException}.
     */
    @Test
    public void requireThatIllegalXmlThrowsException() throws Exception {
        assertIllegalXml("<Unknown />");
        assertIllegalXml("<DecisionTree />");
        assertIllegalXml("<DecisionTree>" +
                         " <Unknown />" +
                         "</DecisionTree>");
        assertIllegalXml("<DecisionTree>" +
                         " <Forest />" +
                         "</DecisionTree>");
        assertIllegalXml("<DecisionTree>" +
                         " <Forest>" +
                         " <Unknown />" +
                         " </Forest>" +
                         "</DecisionTree>");
    }

    /**
     * Asserts that {@link GbdtModel#fromXml} throws {@link IllegalArgumentException} for the
     * given document. Any other exception type propagates to the caller and fails the test.
     *
     * @param xml the XML document expected to be rejected
     * @throws Exception if {@code fromXml} fails with anything other than the expected exception
     */
    private static void assertIllegalXml(String xml) throws Exception {
        try {
            GbdtModel.fromXml(xml);
            // fail() raises an AssertionError, which the catch clause below does not intercept.
            fail("Expected IllegalArgumentException for xml: " + xml);
        } catch (IllegalArgumentException expected) {
            // Expected: rejection of the document is exactly what this helper asserts.
        }
    }

    /**
     * Reads a text file line by line and returns its content with every line terminated
     * by a single {@code "\n"}, regardless of the line separator used in the file.
     *
     * @param file path of the file to read
     * @return the file content, one {@code "\n"} appended after each line
     * @throws IOException if the file cannot be opened or read
     */
    private static String readFile(String file) throws IOException {
        StringBuilder ret = new StringBuilder();
        // try-with-resources closes the reader on both the normal and the exceptional path.
        try (BufferedReader in = new BufferedReader(new FileReader(file))) {
            String line;
            while ((line = in.readLine()) != null) {
                ret.append(line).append("\n");
            }
        }
        return ret.toString();
    }
}
|