aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java78
1 files changed, 42 insertions, 36 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
index cce6b42d323..fee4f5b4160 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
@@ -1,57 +1,63 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
-import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER;
+import com.yahoo.tensor.Tensor;
+import java.util.List;
import java.util.function.Supplier;
import java.util.logging.Logger;
class HitRescorer {
private static final Logger logger = Logger.getLogger(HitRescorer.class.getName());
-
- private final Supplier<Evaluator> evaluatorSource;
- public HitRescorer(Supplier<Evaluator> evaluatorSource) {
- this.evaluatorSource = evaluatorSource;
+ private final Supplier<Evaluator> mainEvalSrc;
+ private final List<String> mainFromMF;
+ private final List<NormalizerContext> normalizers;
+
+ public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) {
+ this.mainEvalSrc = mainEvalSrc;
+ this.mainFromMF = mainFromMF;
+ this.normalizers = normalizers;
}
- boolean rescoreHit(Hit hit) {
- var features = hit.getField("matchfeatures");
- if (features instanceof FeatureData matchFeatures) {
- var scorer = evaluatorSource.get();
- for (String argName : scorer.needInputs()) {
- var asTensor = matchFeatures.getTensor(argName);
- if (asTensor == null) {
- asTensor = matchFeatures.getTensor(alternate(argName));
- }
- if (asTensor != null) {
- scorer.bind(argName, asTensor);
- } else {
- logger.warning("Missing match-feature for Evaluator argument: " + argName);
- return false;
- }
- }
- double newScore = scorer.evaluateScore();
- hit.setRelevance(newScore);
- return true;
- } else {
- logger.warning("Hit without match-features: " + hit);
- return false;
+ void preprocess(WrappedHit wrapped) {
+ for (var n : normalizers) {
+ var scorer = n.evalSource().get();
+ double val = evalScorer(wrapped, scorer, n.fromMF());
+ wrapped.setIdx(n.normalizer().addInput(val));
+ }
+ }
+
+ void runNormalizers() {
+ for (var n : normalizers) {
+ n.normalizer().normalize();
}
}
- private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "(";
- private static final String RE_SUFFIX = ")";
- private static final int RE_PRE_LEN = RE_PREFIX.length();
- private static final int RE_SUF_LEN = RE_SUFFIX.length();
+ double rescoreHit(WrappedHit wrapped) {
+ var scorer = mainEvalSrc.get();
+ for (var n : normalizers) {
+ double normalizedValue = n.normalizer().getOutput(wrapped.getIdx());
+ scorer.bind(n.name(), Tensor.from(normalizedValue));
+ }
+ double newScore = evalScorer(wrapped, scorer, mainFromMF);
+ wrapped.setScore(newScore);
+ return newScore;
+ }
- static String alternate(String argName) {
- if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) {
- return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN);
+ private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<String> fromMF) {
+ for (String argName : fromMF) {
+ var asTensor = wrapped.getTensor(argName);
+ if (asTensor != null) {
+ scorer.bind(argName, asTensor);
+ } else {
+ logger.warning("Missing match-feature for Evaluator argument: " + argName);
+ return 0.0;
+ }
}
- return argName;
+ return scorer.evaluateScore();
}
}