blob: fee4f5b4160fc0d89ba6b9a5de9c30cbe53bb3e7 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
import com.yahoo.tensor.Tensor;
import java.util.List;
import java.util.function.Supplier;
import java.util.logging.Logger;
/**
 * Re-scores hits in two passes.
 *
 * <p>First, {@link #preprocess} is called per hit. It computes one input value per normalizer and
 * hands it to that normalizer. After {@link #runNormalizers} has run, {@link #rescoreHit} binds each
 * normalizer's output for the hit, plus the configured match-features, into the main evaluator.
 * It then stores the evaluated value as the hit's new score.
 */
class HitRescorer {

    private static final Logger logger = Logger.getLogger(HitRescorer.class.getName());

    // Source of the evaluator producing the final score; get() is called once per rescored hit.
    private final Supplier<Evaluator> mainEvalSrc;
    // Match-feature names bound as arguments of the main evaluator.
    private final List<String> mainFromMF;
    // Normalizers whose per-hit outputs are bound into the main evaluator under their names.
    private final List<NormalizerContext> normalizers;

    public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) {
        this.normalizers = normalizers;
        this.mainFromMF = mainFromMF;
        this.mainEvalSrc = mainEvalSrc;
    }

    /**
     * Feeds this hit into every normalizer.
     *
     * <p>For each normalizer, a fresh evaluator is obtained from its source and evaluated against the
     * hit's match-features. The result is added as an input to the normalizer, and the index that
     * {@code addInput} returns is stored on the hit.
     *
     * <p>NOTE(review): the stored index is overwritten once per normalizer, and only the last one
     * survives. {@link #rescoreHit} later uses that single index for every normalizer. This assumes
     * all normalizers hand out the same index for a given hit (presumably because every hit is fed
     * to every normalizer in the same order). Confirm this invariant.
     */
    void preprocess(WrappedHit wrapped) {
        for (NormalizerContext ctx : normalizers) {
            var evaluator = ctx.evalSource().get();
            double input = evalScorer(wrapped, evaluator, ctx.fromMF());
            var slot = ctx.normalizer().addInput(input);
            wrapped.setIdx(slot);
        }
    }

    /** Asks every normalizer, in list order, to normalize the inputs it has collected. */
    void runNormalizers() {
        normalizers.forEach(ctx -> ctx.normalizer().normalize());
    }

    /**
     * Computes and stores a new score for the hit.
     *
     * <p>Each normalizer's output at the hit's stored index is bound, wrapped in a tensor, under the
     * normalizer's name. Then the main match-feature arguments are bound and the evaluator is run.
     *
     * @return the new score, which is also written to the hit via {@code setScore}; this is 0.0 if
     *     a required match-feature was missing
     */
    double rescoreHit(WrappedHit wrapped) {
        var evaluator = mainEvalSrc.get();
        for (NormalizerContext ctx : normalizers) {
            var normalizer = ctx.normalizer();
            double normalized = normalizer.getOutput(wrapped.getIdx());
            evaluator.bind(ctx.name(), Tensor.from(normalized));
        }
        double rescored = evalScorer(wrapped, evaluator, mainFromMF);
        wrapped.setScore(rescored);
        return rescored;
    }

    /**
     * Binds each named match-feature of the hit into the evaluator and evaluates it.
     *
     * <p>If the hit has no tensor for some name ({@code getTensor} returns null), a warning is
     * logged and 0.0 is returned without evaluating. Arguments bound before the missing one stay
     * bound on the evaluator.
     */
    private static double evalScorer(WrappedHit hit, Evaluator evaluator, List<String> argNames) {
        for (String argName : argNames) {
            var value = hit.getTensor(argName);
            if (value == null) {
                logger.warning("Missing match-feature for Evaluator argument: " + argName);
                return 0.0;
            }
            evaluator.bind(argName, value);
        }
        return evaluator.evaluateScore();
    }
}
|