summaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/ranking/GlobalPhaseSetupTest.java
blob: 082531a97dd89c3df6c18e222436591a9bb785d0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;

import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;

import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Tests for {@code GlobalPhaseSetup.maybeMakeSetup}, driven by rank-profiles config files
 * stored in per-scenario sub-directories of {@code src/test/resources/config/}.
 */
public class GlobalPhaseSetupTest {

    /** Root directory of the test config; each scenario lives in its own sub-directory. */
    private static final String CONFIG_DIR = "src/test/resources/config/";

    /**
     * Reads the rank-profiles config for one test scenario.
     *
     * @param subDir scenario sub-directory under {@link #CONFIG_DIR}
     * @return the config read from {@code <CONFIG_DIR><subDir>/rank-profiles.cfg}
     */
    @SuppressWarnings("deprecation") // NOTE(review): presumably silences a deprecation warning from the ConfigGetter call
    RankProfilesConfig readConfig(String subDir) {
        String cfgId = "file:" + CONFIG_DIR + subDir + "/rank-profiles.cfg";
        return ConfigGetter.getConfig(RankProfilesConfig.class, cfgId);
    }

    /**
     * Builds the global-phase setup for a scenario whose config holds exactly one rank profile.
     * Asserts both that single-profile precondition and that a setup was actually produced.
     *
     * @param subDir scenario sub-directory under {@link #CONFIG_DIR}
     * @return the non-null setup for the scenario's only rank profile
     */
    private GlobalPhaseSetup setupFor(String subDir) {
        RankProfilesConfig rpCfg = readConfig(subDir);
        assertEquals(1, rpCfg.rankprofile().size());
        RankProfilesEvaluator rpEvaluator = createEvaluator(rpCfg);
        var setup = GlobalPhaseSetup.maybeMakeSetup(rpCfg.rankprofile().get(0), rpEvaluator);
        assertNotNull(setup);
        return setup;
    }

    @Test void mediumAdvancedSetup() {
        var setup = setupFor("medium");
        assertEquals(42, setup.rerankCount);
        assertEquals(0, setup.normalizers.size());
        assertEquals(9, setup.matchFeaturesToHide.size());
        assertEquals(1, setup.globalPhaseEvalSpec.fromQuery().size());
        assertEquals(9, setup.globalPhaseEvalSpec.fromMF().size());
    }

    @Test void queryFeaturesWithDefaults() {
        var setup = setupFor("qf_defaults");
        assertEquals(0, setup.normalizers.size());
        assertEquals(0, setup.matchFeaturesToHide.size());
        assertEquals(5, setup.globalPhaseEvalSpec.fromQuery().size());
        assertEquals(2, setup.globalPhaseEvalSpec.fromMF().size());
        assertEquals(5, setup.defaultValues.size());
        // The *_no_def features are expected to resolve to zero/empty tensors of their type;
        // the *_has_def values presumably come from defaults declared in the scenario config.
        assertEquals(Tensor.from(0.0), setup.defaultValues.get("query(w_no_def)"));
        assertEquals(Tensor.from(1.0), setup.defaultValues.get("query(w_has_def)"));
        assertEquals(Tensor.from("tensor(m{}):{}"), setup.defaultValues.get("query(m_no_def)"));
        assertEquals(Tensor.from("tensor(v[3]):[0,0,0]"), setup.defaultValues.get("query(v_no_def)"));
        assertEquals(Tensor.from("tensor(v[3]):[2,0.25,1.5]"), setup.defaultValues.get("query(v_has_def)"));
    }

    @Test void withNormalizers() {
        var setup = setupFor("with_normalizers");
        // Sort a copy by name so the expected order below is deterministic without
        // mutating (or depending on the mutability of) the setup under test.
        var nList = new ArrayList<>(setup.normalizers);
        nList.sort((a, b) -> a.name().compareTo(b.name()));
        assertEquals(7, nList.size());

        var n = nList.get(0);
        assertEquals("normalize@2974853441@linear", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("funmf"), n.inputEvalSpec().fromMF());
        assertEquals("linear", n.supplier().get().normalizing());

        n = nList.get(1);
        assertEquals("normalize@3414032797@rrank", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("attribute(year)"), n.inputEvalSpec().fromMF());
        assertEquals("reciprocal-rank{k:60.0}", n.supplier().get().normalizing());

        n = nList.get(2);
        assertEquals("normalize@3551296680@linear", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("nativeRank"), n.inputEvalSpec().fromMF());
        assertEquals("linear", n.supplier().get().normalizing());

        n = nList.get(3);
        assertEquals("normalize@4280591309@rrank", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("bm25(myabstract)"), n.inputEvalSpec().fromMF());
        assertEquals("reciprocal-rank{k:42.0}", n.supplier().get().normalizing());

        n = nList.get(4);
        assertEquals("normalize@4370385022@linear", n.name());
        assertEquals(List.of("myweight"), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("attribute(foo1)"), n.inputEvalSpec().fromMF());
        assertEquals("linear", n.supplier().get().normalizing());

        n = nList.get(5);
        assertEquals("normalize@4640646880@linear", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("attribute(foo1)"), n.inputEvalSpec().fromMF());
        assertEquals("linear", n.supplier().get().normalizing());

        n = nList.get(6);
        assertEquals("normalize@6283155534@linear", n.name());
        assertEquals(List.of(), n.inputEvalSpec().fromQuery());
        assertEquals(List.of("bm25(mytitle)"), n.inputEvalSpec().fromMF());
        assertEquals("linear", n.supplier().get().normalizing());
    }

    /**
     * Creates an evaluator for the given rank profiles using empty constants, expressions
     * and ONNX-model configs, plus a mock file acquirer built with {@code returnFile(null)}.
     *
     * @param config the rank-profiles config to evaluate
     * @return a new evaluator over {@code config}
     */
    private RankProfilesEvaluator createEvaluator(RankProfilesConfig config) {
        RankingConstantsConfig constantsConfig = new RankingConstantsConfig.Builder().build();
        RankingExpressionsConfig expressionsConfig = new RankingExpressionsConfig.Builder().build();
        OnnxModelsConfig onnxModelsConfig = new OnnxModelsConfig.Builder().build();
        return new RankProfilesEvaluator(config, constantsConfig, expressionsConfig, onnxModelsConfig, MockFileAcquirer.returnFile(null));
    }
}