aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/lightgbm/LightGBMTestBase.java
blob: 05bddd2ffaf295a42819e982a146e6894d463049 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.lightgbm;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;

import static org.junit.Assert.assertEquals;

/**
 * Shared helpers for LightGBM importer tests: importing a model as a ranking expression,
 * building feature contexts, and asserting on evaluation results.
 *
 * @author lesters
 */
class LightGBMTestBase {

    /** Name the model is imported under, which is also the key its expression is looked up by. */
    private static final String MODEL_NAME = "lightgbm";

    /** Absolute tolerance used when no explicit delta is given to {@code assertEvaluation}. */
    private static final double DEFAULT_DELTA = 1e-6;

    /**
     * Imports the LightGBM model at the given path and returns its ranking expression.
     *
     * @param path path of the model to import
     * @return the expression the importer registered under the name {@code "lightgbm"}
     * @throws AssertionError if the import produced no expression under that name, so the
     *         test fails here with a descriptive message rather than handing {@code null}
     *         to the caller
     */
    RankingExpression importModel(String path) {
        RankingExpression expression =
                new LightGBMImporter().importModel(MODEL_NAME, path).expressions().get(MODEL_NAME);
        if (expression == null) {
            throw new AssertionError("No expression named '" + MODEL_NAME + "' was imported from " + path);
        }
        return expression;
    }

    /**
     * Asserts that evaluating {@code expr} against the given features yields {@code expected},
     * within an absolute tolerance of 1e-6.
     */
    void assertEvaluation(double expected, RankingExpression expr, TestFeatures features) {
        assertEvaluation(expected, expr, features, DEFAULT_DELTA);
    }

    /**
     * Asserts that evaluating {@code expr} against the given features yields {@code expected},
     * within the given absolute tolerance.
     *
     * @param expected the expected numeric result
     * @param expr     the expression to evaluate
     * @param features the feature values to evaluate against
     * @param delta    maximum allowed absolute difference between expected and actual
     */
    void assertEvaluation(double expected, RankingExpression expr, TestFeatures features, double delta) {
        assertEquals(expected, expr.evaluate(features.context).asDouble(), delta);
    }

    /**
     * Creates a feature builder backed by a clone of {@code context}. Values added through the
     * returned object are put into the clone, not into the {@code context} object passed in.
     */
    TestFeatures features(ArrayContext context) {
        return new TestFeatures(context.clone());
    }

    /** Fluent holder of feature values; each {@code add} returns {@code this} for chaining. */
    static class TestFeatures {

        private final ArrayContext context;

        TestFeatures(ArrayContext context) {
            this.context = context;
        }

        /** Binds a numeric feature value. */
        TestFeatures add(String name, double value) {
            context.put(name, value);
            return this;
        }

        /** Binds a string feature value, wrapped as a {@link StringValue}. */
        TestFeatures add(String name, String value) {
            context.put(name, new StringValue(value));
            return this;
        }
    }

}