aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java
blob: 31a676e4c8e5245fe13ffc941250e4bf7efae032 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;

import ai.vespa.models.evaluation.FunctionEvaluator;

import com.yahoo.vespa.config.search.RankProfilesConfig;

import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Map;
import java.util.HashMap;
import java.util.function.Supplier;

/**
 * Holds what is needed to evaluate the "globalphase" function of one rank profile:
 * the evaluation spec for the function itself, the rerank count, the match feature
 * names listed as hidden, and the setups for any normalizers the function takes as input.
 *
 * Instances are created by {@link #maybeMakeSetup}; the collections passed to the
 * constructor are stored as-is (no defensive copies are made).
 */
class GlobalPhaseSetup {

    /** Rerank count used when the profile has no "vespa.globalphase.rerankcount" property, or a negative one. */
    private static final int DEFAULT_RERANK_COUNT = 100;

    final FunEvalSpec globalPhaseEvalSpec;
    final int rerankCount;
    final Collection<String> matchFeaturesToHide;
    final List<NormalizerSetup> normalizers;

    GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec,
                     final int rerankCount,
                     Collection<String> matchFeaturesToHide,
                     List<NormalizerSetup> normalizers)
    {
        this.globalPhaseEvalSpec = globalPhaseEvalSpec;
        this.rerankCount = rerankCount;
        this.matchFeaturesToHide = matchFeaturesToHide;
        this.normalizers = normalizers;
    }

    /**
     * Builds the global-phase setup for a rank profile from its fef properties.
     *
     * @param rp the rank profile config to read properties and normalizers from
     * @param modelEvaluator used to look up the model for {@code rp.name()}
     * @return the setup, or {@code null} when the profile has no "vespa.rank.globalphase" property
     * @throws NumberFormatException if the "vespa.globalphase.rerankcount" value is not an integer
     * @throws IllegalArgumentException if an input of the global-phase function (or of a normalizer's
     *         input function) is neither a query feature, a configured normalizer, nor a match feature
     */
    static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) {
        var model = modelEvaluator.modelForRankProfile(rp.name());
        Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>();
        for (var n : rp.normalizer()) {
            availableNormalizers.put(n.name(), n);
        }
        boolean hasGlobalPhase = false;
        int configuredRerankCount = -1;
        Set<String> namesToHide = new HashSet<>();
        Set<String> matchFeatures = new HashSet<>();
        Map<String, String> renameFeatures = new HashMap<>();
        // "vespa.feature.rename" properties are consumed in pairs: first the old name, then the new one.
        String toRename = null;
        for (var prop : rp.fef().property()) {
            switch (prop.name()) {
                case "vespa.globalphase.rerankcount" -> configuredRerankCount = Integer.parseInt(prop.value());
                case "vespa.rank.globalphase" -> hasGlobalPhase = true; // only presence matters; the value is not read
                case "vespa.hidden.matchfeature" -> namesToHide.add(prop.value());
                case "vespa.match.feature" -> matchFeatures.add(prop.value());
                case "vespa.feature.rename" -> {
                    if (toRename == null) {
                        toRename = prop.value();
                    } else {
                        renameFeatures.put(toRename, prop.value());
                        toRename = null;
                    }
                }
                default -> { } // other properties are not relevant here
            }
        }
        // A trailing, unpaired rename value (toRename != null here) is silently ignored.
        if (!hasGlobalPhase) {
            return null;
        }
        // Match features are looked up by their renamed names from here on.
        for (var rename : renameFeatures.entrySet()) {
            if (matchFeatures.remove(rename.getKey())) {
                matchFeatures.add(rename.getValue());
            }
        }
        final int rerankCount = (configuredRerankCount < 0) ? DEFAULT_RERANK_COUNT : configuredRerankCount;

        Supplier<FunctionEvaluator> functionEvaluatorSource = () -> model.evaluatorOf("globalphase");
        // One evaluator is created up front only to inspect which inputs the function needs.
        var allInputs = List.copyOf(functionEvaluatorSource.get().function().arguments());
        List<String> fromMF = new ArrayList<>();
        List<String> fromQuery = new ArrayList<>();
        List<NormalizerSetup> normalizers = new ArrayList<>();
        for (var input : allInputs) {
            // Classification order matters: query feature first, then normalizer, then match feature.
            String queryFeatureName = asQueryFeature(input);
            if (queryFeatureName != null) {
                fromQuery.add(queryFeatureName);
                continue;
            }
            var cfg = availableNormalizers.get(input);
            if (cfg != null) {
                String normInput = cfg.input();
                if (matchFeatures.contains(normInput)) {
                    // The normalizer input is itself a match feature, so no function evaluation is set up for it.
                    Supplier<Evaluator> normSource = () -> new DummyEvaluator(normInput);
                    normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSource, List.of(normInput), rerankCount));
                } else {
                    // Otherwise the normalizer input is looked up as a function in the model.
                    Supplier<FunctionEvaluator> normSource = () -> model.evaluatorOf(normInput);
                    var normInputs = List.copyOf(normSource.get().function().arguments());
                    var normSupplier = SimpleEvaluator.wrap(normSource);
                    normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSupplier, normInputs, rerankCount));
                }
            } else if (matchFeatures.contains(input)) {
                fromMF.add(input);
            } else {
                throw new IllegalArgumentException("Bad config, missing global-phase input: " + input);
            }
        }
        Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource);
        var gfun = new FunEvalSpec(supplier, fromQuery, fromMF);
        return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers);
    }

    /**
     * Creates the setup for one normalizer, splitting the inputs of its source
     * into query features and match features.
     *
     * @param cfg the normalizer config (name and algorithm are read from it)
     * @param matchFeatures the (renamed) match feature names available
     * @param evalSupplier supplies the evaluator producing the value to be normalized
     * @param normInputs the input names required by that evaluator
     * @param rerankCount passed on to the normalizer being created
     * @throws IllegalArgumentException if an input is neither a query feature nor a match feature
     */
    private static NormalizerSetup makeNormalizerSetup(RankProfilesConfig.Rankprofile.Normalizer cfg,
                                                       Set<String> matchFeatures,
                                                       Supplier<Evaluator> evalSupplier,
                                                       List<String> normInputs,
                                                       int rerankCount)
    {
        List<String> fromQuery = new ArrayList<>();
        List<String> fromMF = new ArrayList<>();
        for (var input : normInputs) {
            String queryFeatureName = asQueryFeature(input);
            if (queryFeatureName != null) {
                fromQuery.add(queryFeatureName);
            } else if (matchFeatures.contains(input)) {
                fromMF.add(input);
            } else {
                throw new IllegalArgumentException("Bad config, missing normalizer input: " + input);
            }
        }
        var fun = new FunEvalSpec(evalSupplier, fromQuery, fromMF);
        return new NormalizerSetup(cfg.name(), makeNormalizerSupplier(cfg, rerankCount), fun);
    }

    /** Returns a supplier creating a new normalizer of the configured algorithm on each call. */
    private static Supplier<Normalizer> makeNormalizerSupplier(RankProfilesConfig.Rankprofile.Normalizer cfg, int rerankCount) {
        return switch (cfg.algo()) {
            case LINEAR -> () -> new LinearNormalizer(rerankCount);
            case RRANK -> () -> new ReciprocalRankNormalizer(rerankCount, cfg.kparam());
        };
    }

    /**
     * Returns the simple argument of the given input if it parses as a simple reference
     * named "query" (presumably the textual form {@code query(name)} - see Reference.simple),
     * or {@code null} for any other input.
     */
    static String asQueryFeature(String input) {
        var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(input);
        if (optRef.isPresent()) {
            var ref = optRef.get();
            if (ref.isSimple() && ref.name().equals("query")) {
                return ref.simpleArgument().get();
            }
        }
        return null;
    }
}