aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
author Arne Juul <arnej@yahooinc.com> 2023-02-27 14:34:36 +0000
committer Arne Juul <arnej@yahooinc.com> 2023-02-27 14:42:58 +0000
commit f723c2ca9252e6aab036a3ee21a24603f6fffb85 (patch)
tree 41a8951434d24a61a7c45fcaddf260cbd4ff1791 /container-search
parent c095c1061088bae7e6f5f26ec595325fc79be43c (diff)
split out rerankHits and add adjust of scores just like second-phase adjustment
Diffstat (limited to 'container-search')
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java | 36
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java | 91
2 files changed, 92 insertions, 35 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java
index 9810f612e5c..be4ba5444fe 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java
@@ -48,7 +48,7 @@ public class GlobalPhaseHelper {
};
// TODO need to get rerank-count somehow
int rerank = 7;
- rerankHits(query, result, new HitRescorer(supplier), rerank);
+ ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank);
}
record NameAndValue(String name, Tensor value) { }
@@ -86,40 +86,6 @@ public class GlobalPhaseHelper {
return result;
}
- void rerankHits(Query query, Result result, HitRescorer hitRescorer, int rerank) {
- double worstRerankedScore = Double.MAX_VALUE;
- double worstRerankedOldScore = Double.MAX_VALUE;
- // TODO consider doing recursive iteration instead of deepIterator
- for (var iterator = result.hits().deepIterator(); iterator.hasNext();) {
- Hit hit = iterator.next();
- if (hit.isMeta() || hit instanceof HitGroup) {
- continue;
- }
- // what about hits inside grouping results?
- if (rerank > 0) {
- double oldScore = hit.getRelevance().getScore();
- boolean didRerank = hitRescorer.rescoreHit(hit);
- if (didRerank) {
- double newScore = hit.getRelevance().getScore();
- if (oldScore < worstRerankedOldScore) worstRerankedOldScore = oldScore;
- if (newScore < worstRerankedScore) worstRerankedScore = newScore;
- --rerank;
- } else {
- // failed to rescore this hit, what should we do?
- hit.setRelevance(-Double.MAX_VALUE);
- }
- } else {
- // too low quality
- if (worstRerankedOldScore > worstRerankedScore) {
- double penalty = worstRerankedOldScore - worstRerankedScore;
- double oldScore = hit.getRelevance().getScore();
- hit.setRelevance(oldScore - penalty);
- }
- }
- }
- result.hits().sort();
- }
-
private Supplier<FunctionEvaluator> underlying(Query query, String schema) {
String rankProfile = query.getRanking().getProfile();
String key = schema + " with rank profile " + rankProfile;
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java
new file mode 100644
index 00000000000..e700276b151
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java
@@ -0,0 +1,91 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.ranking;
+
+import com.yahoo.search.Result;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.result.HitGroup;
+
+import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.List;
+import java.util.logging.Logger;
+
+public class ResultReranker {
+
+ private static final Logger logger = Logger.getLogger(ResultReranker.class.getName());
+
+ // scale and adjust the score according to the range
+ // of the original and final score values to avoid that
+ // a score from the backend is larger than finalScores_low
+ static class Ranges {
+ private double initialScores_high = -Double.MAX_VALUE;
+ private double initialScores_low = Double.MAX_VALUE;
+ private double finalScores_high = -Double.MAX_VALUE;
+ private double finalScores_low = Double.MAX_VALUE;
+
+ boolean valid() {
+ return (initialScores_high >= initialScores_low
+ &&
+ finalScores_high >= finalScores_low);
+ }
+ void withInitialScore(double score) {
+ if (score < initialScores_low) initialScores_low = score;
+ if (score > initialScores_high) initialScores_high = score;
+ }
+ void withFinalScore(double score) {
+ if (score < finalScores_low) finalScores_low = score;
+ if (score > finalScores_high) finalScores_high = score;
+ }
+ private double initialRange() {
+ double r = initialScores_high - initialScores_low;
+ if (r < 1.0) r = 1.0;
+ return r;
+ }
+ private double finalRange() {
+ double r = finalScores_high - finalScores_low;
+ if (r < 1.0) r = 1.0;
+ return r;
+ }
+ double scale() { return finalRange() / initialRange(); }
+ double bias() { return finalScores_low - initialScores_low * scale(); }
+ }
+
+ static void rerankHits(Result result, HitRescorer hitRescorer, int rerankCount) {
+ List<Hit> hitsToRescore = new ArrayList<>();
+ // consider doing recursive iteration explicitly instead of using deepIterator?
+ for (var iterator = result.hits().deepIterator(); iterator.hasNext();) {
+ Hit hit = iterator.next();
+ if (hit.isMeta() || hit instanceof HitGroup) {
+ continue;
+ }
+ // what about hits inside grouping results?
+ // they are inside GroupingListHit, we won't recurse into it; so we won't see them.
+ hitsToRescore.add(hit);
+ }
+ // we can't be 100% certain that hits were sorted according to relevance:
+ hitsToRescore.sort(Comparator.naturalOrder());
+ var ranges = new Ranges();
+ for (var iterator = hitsToRescore.iterator(); rerankCount > 0 && iterator.hasNext(); ) {
+ Hit hit = iterator.next();
+ double oldScore = hit.getRelevance().getScore();
+ boolean didRerank = hitRescorer.rescoreHit(hit);
+ if (didRerank) {
+ ranges.withInitialScore(oldScore);
+ ranges.withFinalScore(hit.getRelevance().getScore());
+ --rerankCount;
+ iterator.remove();
+ }
+ }
+ // if any hits are left in the list, they need rescaling:
+ if (ranges.valid()) {
+ double scale = ranges.scale();
+ double bias = ranges.bias();
+ for (Hit hit : hitsToRescore) {
+ double oldScore = hit.getRelevance().getScore();
+ hit.setRelevance(oldScore * scale + bias);
+ }
+ }
+ result.hits().sort();
+ }
+
+}