aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
blob: fee4f5b4160fc0d89ba6b9a5de9c30cbe53bb3e7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;

import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.tensor.Tensor;

import java.util.List;
import java.util.function.Supplier;
import java.util.logging.Logger;

/**
 * Computes a new score for hits by running a "main" {@link Evaluator} whose
 * arguments are bound from tensors found on the hit (looked up by name) and
 * from the outputs of a set of normalizers.
 *
 * <p>The methods feed each other in this sequence: {@link #preprocess} adds one
 * input per normalizer for a hit, {@link #runNormalizers} normalizes what was
 * collected, and {@link #rescoreHit} reads the normalized outputs back.
 *
 * <p>NOTE(review): only one index is stored per hit, and it is overwritten for
 * every normalizer during {@link #preprocess}. This presumably relies on all
 * normalizers handing out the same index for the same hit - confirm against
 * the normalizer implementation.
 */
class HitRescorer {

    private static final Logger logger = Logger.getLogger(HitRescorer.class.getName());

    private final Supplier<Evaluator> mainEvalSrc;
    private final List<String> mainFromMF;
    private final List<NormalizerContext> normalizers;

    /**
     * @param mainEvalSrc source of the evaluator used to produce the final score
     * @param mainFromMF  names of the main evaluator's arguments that are looked up on the hit
     * @param normalizers contexts for the normalizers whose outputs are bound into the main evaluator
     */
    public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) {
        this.normalizers = normalizers;
        this.mainFromMF = mainFromMF;
        this.mainEvalSrc = mainEvalSrc;
    }

    /**
     * Evaluates each normalizer's own scorer for the given hit, adds the value as
     * an input to that normalizer, and stores the index returned by the normalizer
     * on the hit.
     *
     * @param wrapped the hit to collect normalizer inputs for
     */
    void preprocess(WrappedHit wrapped) {
        for (NormalizerContext ctx : normalizers) {
            double input = evalScorer(wrapped, ctx.evalSource().get(), ctx.fromMF());
            var index = ctx.normalizer().addInput(input);
            // Overwritten on every iteration; the index from the last normalizer is what remains.
            wrapped.setIdx(index);
        }
    }

    /** Invokes {@code normalize()} on every configured normalizer, in list order. */
    void runNormalizers() {
        for (NormalizerContext ctx : normalizers) {
            var normalizer = ctx.normalizer();
            normalizer.normalize();
        }
    }

    /**
     * Binds each normalizer's output for this hit into a main evaluator, evaluates
     * it, and stores the result as the hit's score.
     *
     * @param wrapped the hit to rescore
     * @return the score that was set on the hit
     */
    double rescoreHit(WrappedHit wrapped) {
        var scorer = mainEvalSrc.get();
        for (NormalizerContext ctx : normalizers) {
            var normalizer = ctx.normalizer();
            double normalized = normalizer.getOutput(wrapped.getIdx());
            var inputName = ctx.name();
            scorer.bind(inputName, Tensor.from(normalized));
        }
        double rescored = evalScorer(wrapped, scorer, mainFromMF);
        wrapped.setScore(rescored);
        return rescored;
    }

    /**
     * Binds every named argument from the hit's tensors into the scorer and
     * evaluates it.
     *
     * @param wrapped the hit supplying tensors by name
     * @param scorer  the evaluator to bind into and evaluate
     * @param fromMF  names of the arguments to fetch from the hit
     * @return the evaluated score, or 0.0 (after logging a warning) as soon as one
     *         of the named tensors is absent from the hit
     */
    private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<String> fromMF) {
        for (String featureName : fromMF) {
            var tensor = wrapped.getTensor(featureName);
            if (tensor == null) {
                // Give up on the first missing argument; arguments bound so far stay bound.
                logger.warning("Missing match-feature for Evaluator argument: " + featureName);
                return 0.0;
            }
            scorer.bind(featureName, tensor);
        }
        return scorer.evaluateScore();
    }
}