diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java | 78 |
1 files changed, 42 insertions, 36 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java index cce6b42d323..fee4f5b4160 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -1,57 +1,63 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; -import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; +import com.yahoo.tensor.Tensor; +import java.util.List; import java.util.function.Supplier; import java.util.logging.Logger; class HitRescorer { private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); - - private final Supplier<Evaluator> evaluatorSource; - public HitRescorer(Supplier<Evaluator> evaluatorSource) { - this.evaluatorSource = evaluatorSource; + private final Supplier<Evaluator> mainEvalSrc; + private final List<String> mainFromMF; + private final List<NormalizerContext> normalizers; + + public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) { + this.mainEvalSrc = mainEvalSrc; + this.mainFromMF = mainFromMF; + this.normalizers = normalizers; } - boolean rescoreHit(Hit hit) { - var features = hit.getField("matchfeatures"); - if (features instanceof FeatureData matchFeatures) { - var scorer = evaluatorSource.get(); - for (String argName : scorer.needInputs()) { - var asTensor = matchFeatures.getTensor(argName); - if (asTensor == null) { - asTensor = matchFeatures.getTensor(alternate(argName)); - } - if (asTensor != null) { - scorer.bind(argName, asTensor); - } else { - logger.warning("Missing match-feature for Evaluator argument: " + argName); - return false; - } - } - double newScore = scorer.evaluateScore(); - hit.setRelevance(newScore); - return true; - } else { - logger.warning("Hit without match-features: " + hit); - return false; + void preprocess(WrappedHit wrapped) { + for (var n : normalizers) { + var scorer = n.evalSource().get(); + double val = evalScorer(wrapped, scorer, n.fromMF()); + wrapped.setIdx(n.normalizer().addInput(val)); + } + } + + void runNormalizers() { + for (var n : normalizers) { + n.normalizer().normalize(); } } - private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; - private static final String RE_SUFFIX = ")"; - private static final int RE_PRE_LEN = RE_PREFIX.length(); - private static final int RE_SUF_LEN = RE_SUFFIX.length(); + double rescoreHit(WrappedHit wrapped) { + var scorer = mainEvalSrc.get(); + for (var n : normalizers) { + double normalizedValue = n.normalizer().getOutput(wrapped.getIdx()); + scorer.bind(n.name(), Tensor.from(normalizedValue)); + } + double newScore = evalScorer(wrapped, scorer, mainFromMF); + wrapped.setScore(newScore); + return newScore; + } - static String alternate(String argName) { - if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { - return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<String> fromMF) { + for (String argName : fromMF) { + var asTensor = wrapped.getTensor(argName); + if (asTensor != null) { + scorer.bind(argName, asTensor); + } else { + logger.warning("Missing match-feature for Evaluator argument: " + argName); + return 0.0; + } } - return argName; + return scorer.evaluateScore(); } } |