aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
blob: cce6b42d32340e1400af05765f7a6720684d0478 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;

import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER;

import java.util.function.Supplier;
import java.util.logging.Logger;

/**
 * Recomputes the relevance of search hits by running an {@link Evaluator} over the
 * match-features carried by each hit.
 */
class HitRescorer {

    private static final Logger logger = Logger.getLogger(HitRescorer.class.getName());

    /** Queried once per {@link #rescoreHit(Hit)} call for the evaluator to bind and run. */
    private final Supplier<Evaluator> evaluatorSource;

    /**
     * @param evaluatorSource supplier asked for an evaluator each time a hit is rescored
     */
    public HitRescorer(Supplier<Evaluator> evaluatorSource) {
        this.evaluatorSource = evaluatorSource;
    }

    /**
     * Replaces the relevance of {@code hit} with a score computed from its match-features.
     *
     * <p>The hit's {@code "matchfeatures"} field must hold a {@link FeatureData}. Every input
     * reported by the evaluator's {@code needInputs()} is looked up there. The lookup tries the
     * input's own name first, then the name returned by {@link #alternate(String)}. If the field
     * is not a {@code FeatureData}, or any input cannot be found, a warning is logged and the
     * hit's relevance is not touched.
     *
     * @param hit the hit to rescore
     * @return {@code true} if the relevance was updated, {@code false} otherwise
     */
    boolean rescoreHit(Hit hit) {
        var rawFeatures = hit.getField("matchfeatures");
        if (!(rawFeatures instanceof FeatureData matchFeatures)) {
            logger.warning("Hit without match-features: " + hit);
            return false;
        }

        var evaluator = evaluatorSource.get();
        for (String inputName : evaluator.needInputs()) {
            var value = matchFeatures.getTensor(inputName);
            if (value == null) {
                // Fall back to the name with any ranking-expression wrapper stripped.
                value = matchFeatures.getTensor(alternate(inputName));
            }
            if (value == null) {
                // Abort on the first unresolvable input; the hit keeps its old relevance.
                logger.warning("Missing match-feature for Evaluator argument: " + inputName);
                return false;
            }
            evaluator.bind(inputName, value);
        }

        double score = evaluator.evaluateScore();
        hit.setRelevance(score);
        return true;
    }

    /** Opening text of a wrapped ranking-expression name. */
    private static final String WRAP_OPEN = RANKING_EXPRESSION_WRAPPER + "(";
    /** Closing text of a wrapped ranking-expression name. */
    private static final String WRAP_CLOSE = ")";
    private static final int WRAP_OPEN_LEN = WRAP_OPEN.length();
    private static final int WRAP_CLOSE_LEN = WRAP_CLOSE.length();

    /**
     * Strips a ranking-expression wrapper from a feature name, if present.
     *
     * @param argName the feature/argument name to inspect
     * @return the text between the parentheses when {@code argName} starts with
     *         {@code RANKING_EXPRESSION_WRAPPER + "("} and ends with {@code ")"};
     *         otherwise {@code argName} unchanged
     */
    static String alternate(String argName) {
        boolean wrapped = argName.startsWith(WRAP_OPEN) && argName.endsWith(WRAP_CLOSE);
        return wrapped
                ? argName.substring(WRAP_OPEN_LEN, argName.length() - WRAP_CLOSE_LEN)
                : argName;
    }
}