diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java | 106 |
1 files changed, 49 insertions, 57 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java index 8d24acdf141..2e9edd6de3a 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.search.Result; @@ -14,80 +14,72 @@ class ResultReranker { private static final Logger logger = Logger.getLogger(ResultReranker.class.getName()); - // scale and adjust the score according to the range - // of the original and final score values to avoid that - // a score from the backend is larger than finalScores_low - static class Ranges { - private double initialScores_high = -Double.MAX_VALUE; - private double initialScores_low = Double.MAX_VALUE; - private double finalScores_high = -Double.MAX_VALUE; - private double finalScores_low = Double.MAX_VALUE; + private final HitRescorer hitRescorer; + private final int rerankCount; + private final List<WrappedHit> hitsToRescore = new ArrayList<>(); + private final RangeAdjuster ranges = new RangeAdjuster(); - boolean rescaleNeeded() { - return (initialScores_low > finalScores_low - && - initialScores_high >= initialScores_low - && - finalScores_high >= finalScores_low); - } - void withInitialScore(double score) { - if (score < initialScores_low) initialScores_low = score; - if (score > initialScores_high) initialScores_high = score; - } - void withFinalScore(double score) { - if (score < finalScores_low) finalScores_low = score; - if (score > finalScores_high) finalScores_high = score; - } - private double initialRange() { - double r = initialScores_high - initialScores_low; - if (r < 1.0) r = 1.0; - return r; - } - private double finalRange() { - double r = finalScores_high - finalScores_low; - if (r < 1.0) r = 1.0; - return r; - } - double scale() { return finalRange() / initialRange(); } - double bias() { return finalScores_low - initialScores_low * scale(); } + ResultReranker(HitRescorer hitRescorer, int rerankCount) { + this.hitRescorer = hitRescorer; + this.rerankCount = rerankCount; } - static void rerankHits(Result result, HitRescorer hitRescorer, int rerankCount) { - List<Hit> hitsToRescore = new ArrayList<>(); - // consider doing recursive iteration explicitly instead of using deepIterator? + void rerankHits(Result result) { + gatherHits(result); + runPreProcessing(); + hitRescorer.runNormalizers(); + runProcessing(); + runPostProcessing(); + result.hits().sort(); + } + + private void gatherHits(Result result) { for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { Hit hit = iterator.next(); if (hit.isMeta() || hit instanceof HitGroup) { continue; } // what about hits inside grouping results? - // they are inside GroupingListHit, we won't recurse into it; so we won't see them. - hitsToRescore.add(hit); + // they did not show up here during manual testing. + var wrapped = WrappedHit.from(hit); + if (wrapped != null) hitsToRescore.add(wrapped); } + } + + private void runPreProcessing() { // we can't be 100% certain that hits were sorted according to relevance: hitsToRescore.sort(Comparator.naturalOrder()); - var ranges = new Ranges(); - for (var iterator = hitsToRescore.iterator(); rerankCount > 0 && iterator.hasNext(); ) { - Hit hit = iterator.next(); - double oldScore = hit.getRelevance().getScore(); - boolean didRerank = hitRescorer.rescoreHit(hit); - if (didRerank) { - ranges.withInitialScore(oldScore); - ranges.withFinalScore(hit.getRelevance().getScore()); - --rerankCount; - iterator.remove(); - } + int count = 0; + for (WrappedHit hit : hitsToRescore) { + if (count == rerankCount) break; + hitRescorer.preprocess(hit); + ++count; } + } + + private void runProcessing() { + int count = 0; + for (var iterator = hitsToRescore.iterator(); count < rerankCount && iterator.hasNext(); ) { + WrappedHit wrapped = iterator.next(); + double oldScore = wrapped.getScore(); + double newScore = hitRescorer.rescoreHit(wrapped); + ranges.withInitialScore(oldScore); + ranges.withFinalScore(newScore); + ++count; + iterator.remove(); + } + } + + private void runPostProcessing() { // if any hits are left in the list, they may need rescaling: - if (ranges.rescaleNeeded()) { + if (ranges.rescaleNeeded() && ! hitsToRescore.isEmpty()) { double scale = ranges.scale(); double bias = ranges.bias(); - for (Hit hit : hitsToRescore) { - double oldScore = hit.getRelevance().getScore(); - hit.setRelevance(oldScore * scale + bias); + for (WrappedHit wrapped : hitsToRescore) { + double oldScore = wrapped.getScore(); + wrapped.setScore(oldScore * scale + bias); } } - result.hits().sort(); } } |