From 60a9df99b067ddbfa72bc0a6d0cccbcd986cf45a Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 10 Oct 2023 13:03:46 +0000 Subject: restructure with normalizers in global-phase --- .../com/yahoo/search/ranking/DummyEvaluator.java | 38 +++++ .../java/com/yahoo/search/ranking/Evaluator.java | 4 +- .../java/com/yahoo/search/ranking/FunEvalCtx.java | 7 + .../yahoo/search/ranking/GlobalPhaseRanker.java | 79 ++++------- .../com/yahoo/search/ranking/GlobalPhaseSetup.java | 153 +++++++++++++++++++++ .../java/com/yahoo/search/ranking/HitRescorer.java | 76 +++++----- .../yahoo/search/ranking/NormalizerContext.java | 7 + .../com/yahoo/search/ranking/NormalizerSetup.java | 6 + .../com/yahoo/search/ranking/PreparedInput.java | 49 +++++++ .../com/yahoo/search/ranking/RangeAdjuster.java | 40 ++++++ .../search/ranking/RankProfilesEvaluator.java | 35 ++--- .../com/yahoo/search/ranking/ResultReranker.java | 98 +++++++------ .../com/yahoo/search/ranking/SimpleEvaluator.java | 19 ++- .../java/com/yahoo/search/ranking/WrappedHit.java | 91 ++++++++++++ 14 files changed, 522 insertions(+), 180 deletions(-) create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/FunEvalCtx.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java create mode 100644 container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java (limited to 'container-search') diff --git a/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java new file mode 100644 index 00000000000..e83a308d99c --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +class DummyEvaluator implements Evaluator { + + private final String input; + private Tensor result = null; + + DummyEvaluator(String input) { + this.input = input; + } + + @Override + public Evaluator bind(String name, Tensor value) { + result = value; + return this; + } + + @Override + public double evaluateScore() { + return result.asDouble(); + } + + @Override + public String toString() { + return "DummyEvaluator(" + input + ")"; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java index 9dab252e9a4..83f9d0e2704 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java @@ -3,11 +3,9 @@ package com.yahoo.search.ranking; import com.yahoo.tensor.Tensor; -import java.util.Collection; +import java.util.List; interface Evaluator { - Collection needInputs(); - Evaluator bind(String name, Tensor value); double evaluateScore(); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/FunEvalCtx.java b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalCtx.java new file mode 100644 index 00000000000..c8b8810f368 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalCtx.java @@ -0,0 +1,7 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.List; +import java.util.function.Supplier; + +record FunEvalCtx(Supplier evalSrc, List fromQuery, List fromMF) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 01ea5e3ebd5..130cf720684 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -5,7 +5,6 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.query.Sorting; -import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; @@ -33,8 +32,8 @@ public class GlobalPhaseRanker { } public Optional validateNoSorting(Query query, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return Optional.empty(); + var setup = globalPhaseSetupFor(query, schema).orElse(null); + if (setup == null) return Optional.empty(); var sorting = query.getRanking().getSorting(); if (sorting == null || sorting.fieldOrders() == null) return Optional.empty(); for (var fieldOrder : sorting.fieldOrders()) { @@ -47,23 +46,32 @@ public class GlobalPhaseRanker { } public void rerankHits(Query query, Result result, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return; - var functionEvaluatorSource = data.functionEvaluatorSource(); - var prepared = findFromQuery(query, data.needInputs()); + var setup = globalPhaseSetupFor(query, schema).orElse(null); + if (setup == null) return; + var mainSrc = withQueryPrep(setup.globalPhaseEvalCtx, query); + var mainMF = setup.globalPhaseEvalCtx.fromMF(); + int rerankCount = setup.rerankCount; + var normalizers = new ArrayList(); + for (var nSetup : setup.normalizers) { + var normEvalSrc = withQueryPrep(nSetup.evalCtx(), query); + normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, nSetup.evalCtx().fromMF())); + } + var rescorer = new HitRescorer(mainSrc, mainMF, normalizers); + var reranker = new ResultReranker(rescorer, rerankCount); + reranker.rerankHits(result); + hideImplicitMatchFeatures(result, setup.matchFeaturesToHide); + } + + static Supplier withQueryPrep(FunEvalCtx evalCtx, Query query) { + var prepared = PreparedInput.findFromQuery(query, evalCtx.fromQuery()); Supplier supplier = () -> { - var evaluator = functionEvaluatorSource.get(); - var simple = new SimpleEvaluator(evaluator); + var result = evalCtx.evalSrc().get(); for (var entry : prepared) { - simple.bind(entry.name(), entry.value()); + result.bind(entry.name(), entry.value()); } - return simple; + return result; }; - int rerankCount = data.rerankCount(); - if (rerankCount < 0) - rerankCount = 100; - ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount); - hideImplicitMatchFeatures(result, data.matchFeaturesToHide()); + return supplier; } private void hideImplicitMatchFeatures(Result result, Collection namesToHide) { @@ -87,44 +95,9 @@ public class GlobalPhaseRanker { } } - private Optional globalPhaseDataFor(Query query, String schema) { + private Optional globalPhaseSetupFor(Query query, String schema) { return factory.evaluatorForSchema(schema) - .flatMap(evaluator -> evaluator.getGlobalPhaseData(query.getRanking().getProfile())); - } - - record NameAndValue(String name, Tensor value) { } - - /* do this only once per query: */ - List findFromQuery(Query query, List needInputs) { - List result = new ArrayList<>(); - var ranking = query.getRanking(); - var rankFeatures = ranking.getFeatures(); - var rankProps = ranking.getProperties().asMap(); - for (String needed : needInputs) { - var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed); - if (optRef.isEmpty()) continue; - var ref = optRef.get(); - if (ref.name().equals("constant")) { - // XXX in theory, we should be able to avoid this - result.add(new NameAndValue(needed, null)); - continue; - } - if (ref.isSimple() && ref.name().equals("query")) { - String queryFeatureName = ref.simpleArgument().get(); - // searchers are recommended to place query features here: - var feature = rankFeatures.getTensor(queryFeatureName); - if (feature.isPresent()) { - result.add(new NameAndValue(needed, feature.get())); - } else { - // but other ways of setting query features end up in the properties: - var objList = rankProps.get(queryFeatureName); - if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { - result.add(new NameAndValue(needed, t)); - } - } - } - } - return result; + .flatMap(evaluator -> evaluator.getGlobalPhaseSetup(query.getRanking().getProfile())); } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java new file mode 100644 index 00000000000..e9335ee8123 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -0,0 +1,153 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; + +import com.yahoo.vespa.config.search.RankProfilesConfig; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.Map; +import java.util.HashMap; +import java.util.function.Supplier; + +class GlobalPhaseSetup { + + final FunEvalCtx globalPhaseEvalCtx; + final int rerankCount; + final Collection matchFeaturesToHide; + final List normalizers; + + GlobalPhaseSetup(FunEvalCtx globalPhase, + final int rerankCount, + Collection matchFeaturesToHide, + List normalizers) + { + this.globalPhaseEvalCtx = globalPhase; + this.rerankCount = rerankCount; + this.matchFeaturesToHide = matchFeaturesToHide; + this.normalizers = normalizers; + } + + static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { + var model = modelEvaluator.modelForRankProfile(rp.name()); + Map availableNormalizers = new HashMap<>(); + for (var n : rp.normalizer()) { + availableNormalizers.put(n.name(), n); + } + Supplier functionEvaluatorSource = null; + int rerankCount = -1; + Set namesToHide = new HashSet<>(); + Set matchFeatures = new HashSet<>(); + Map renameFeatures = new HashMap<>(); + String toRename = null; + for (var prop : rp.fef().property()) { + if (prop.name().equals("vespa.globalphase.rerankcount")) { + rerankCount = Integer.valueOf(prop.value()); + } + if (prop.name().equals("vespa.rank.globalphase")) { + functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); + } + if (prop.name().equals("vespa.hidden.matchfeature")) { + namesToHide.add(prop.value()); + } + if (prop.name().equals("vespa.match.feature")) { + matchFeatures.add(prop.value()); + } + if (prop.name().equals("vespa.feature.rename")) { + if (toRename == null) { + toRename = prop.value(); + } else { + renameFeatures.put(toRename, prop.value()); + toRename = null; + } + } + } + for (var entry : renameFeatures.entrySet()) { + String old = entry.getKey(); + if (matchFeatures.contains(old)) { + matchFeatures.remove(old); + matchFeatures.add(entry.getValue()); + } + } + if (rerankCount < 0) { + rerankCount = 100; + } + if (functionEvaluatorSource != null) { + var evaluator = functionEvaluatorSource.get(); + var allInputs = List.copyOf(evaluator.function().arguments()); + List fromMF = new ArrayList<>(); + List fromQuery = new ArrayList<>(); + List normalizers = new ArrayList<>(); + for (var input : allInputs) { + String queryFeatureName = asQueryFeature(input); + if (queryFeatureName != null) { + fromQuery.add(queryFeatureName); + } else if (availableNormalizers.containsKey(input)) { + var cfg = availableNormalizers.get(input); + String normInput = cfg.input(); + if (matchFeatures.contains(normInput)) { + Supplier normSource = () -> new DummyEvaluator(normInput); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSource, List.of(normInput), rerankCount)); + } else { + Supplier normSource = () -> model.evaluatorOf(normInput); + var normInputs = List.copyOf(normSource.get().function().arguments()); + var normSupplier = SimpleEvaluator.wrap(normSource); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSupplier, normInputs, rerankCount)); + } + } else if (matchFeatures.contains(input)) { + fromMF.add(input); + } else { + throw new IllegalArgumentException("Bad config, missing global-phase input: " + input); + } + } + Supplier supplier = SimpleEvaluator.wrap(functionEvaluatorSource); + var gfun = new FunEvalCtx(supplier, fromQuery, fromMF); + return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers); + } + return null; + } + + private static NormalizerSetup makeNormalizerSetup(RankProfilesConfig.Rankprofile.Normalizer cfg, + Set matchFeatures, + Supplier evalSupplier, + List normInputs, + int rerankCount) + { + List fromQuery = new ArrayList<>(); + List fromMF = new ArrayList<>(); + for (var input : normInputs) { + String queryFeatureName = asQueryFeature(input); + if (queryFeatureName != null) { + fromQuery.add(queryFeatureName); + } else if (matchFeatures.contains(input)) { + fromMF.add(input); + } else { + throw new IllegalArgumentException("Bad config, missing normalizer input: " + input); + } + } + var fun = new FunEvalCtx(evalSupplier, fromQuery, fromMF); + return new NormalizerSetup(cfg.name(), makeNormalizerSupplier(cfg, rerankCount), fun); + } + + private static Supplier makeNormalizerSupplier(RankProfilesConfig.Rankprofile.Normalizer cfg, int rerankCount) { + return switch (cfg.algo()) { + case LINEAR -> () -> new LinearNormalizer(rerankCount); + case RRANK -> () -> new ReciprocalRankNormalizer(rerankCount, cfg.kparam()); + }; + } + + static String asQueryFeature(String input) { + var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(input); + if (optRef.isPresent()) { + var ref = optRef.get(); + if (ref.isSimple() && ref.name().equals("query")) { + return ref.simpleArgument().get(); + } + } + return null; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java index f6519158e88..7336713da64 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -3,55 +3,61 @@ package com.yahoo.search.ranking; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; -import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; +import com.yahoo.tensor.Tensor; +import java.util.List; import java.util.function.Supplier; import java.util.logging.Logger; class HitRescorer { private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); - - private final Supplier evaluatorSource; - public HitRescorer(Supplier evaluatorSource) { - this.evaluatorSource = evaluatorSource; + private final Supplier mainEvalSrc; + private final List mainFromMF; + private final List normalizers; + + public HitRescorer(Supplier mainEvalSrc, List mainFromMF, List normalizers) { + this.mainEvalSrc = mainEvalSrc; + this.mainFromMF = mainFromMF; + this.normalizers = normalizers; } - boolean rescoreHit(Hit hit) { - var features = hit.getField("matchfeatures"); - if (features instanceof FeatureData matchFeatures) { - var scorer = evaluatorSource.get(); - for (String argName : scorer.needInputs()) { - var asTensor = matchFeatures.getTensor(argName); - if (asTensor == null) { - asTensor = matchFeatures.getTensor(alternate(argName)); - } - if (asTensor != null) { - scorer.bind(argName, asTensor); - } else { - logger.warning("Missing match-feature for Evaluator argument: " + argName); - return false; - } - } - double newScore = scorer.evaluateScore(); - hit.setRelevance(newScore); - return true; - } else { - logger.warning("Hit without match-features: " + hit); - return false; + void preprocess(WrappedHit wrapped) { + for (var n : normalizers) { + var scorer = n.evalSource().get(); + double val = evalScorer(wrapped, scorer, n.fromMF()); + wrapped.setIdx(n.normalizer().addInput(val)); + } + } + + void runNormalizers() { + for (var n : normalizers) { + n.normalizer().normalize(); } } - private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; - private static final String RE_SUFFIX = ")"; - private static final int RE_PRE_LEN = RE_PREFIX.length(); - private static final int RE_SUF_LEN = RE_SUFFIX.length(); + boolean rescoreHit(WrappedHit wrapped) { + var scorer = mainEvalSrc.get(); + for (var n : normalizers) { + double normalizedValue = n.normalizer().getOutput(wrapped.getIdx()); + scorer.bind(n.name(), Tensor.from(normalizedValue)); + } + double newScore = evalScorer(wrapped, scorer, mainFromMF); + wrapped.setScore(newScore); + return true; + } - static String alternate(String argName) { - if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { - return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List fromMF) { + for (String argName : fromMF) { + var asTensor = wrapped.getTensor(argName); + if (asTensor != null) { + scorer.bind(argName, asTensor); + } else { + logger.warning("Missing match-feature for Evaluator argument: " + argName); + return 0.0; + } } - return argName; + return scorer.evaluateScore(); } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java new file mode 100644 index 00000000000..9438b5ea824 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java @@ -0,0 +1,7 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.List; +import java.util.function.Supplier; + +record NormalizerContext(String name, Normalizer normalizer, Supplier evalSource, List fromMF) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java new file mode 100644 index 00000000000..6e93a73b6ed --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java @@ -0,0 +1,6 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.function.Supplier; + +record NormalizerSetup(String name, Supplier supplier, FunEvalCtx evalCtx) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java new file mode 100644 index 00000000000..5ab2d7160f9 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -0,0 +1,49 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.query.Sorting; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; +import com.yahoo.tensor.Tensor; +import com.yahoo.data.access.helpers.MatchFeatureData; +import com.yahoo.data.access.helpers.MatchFeatureFilter; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.logging.Logger; + +record PreparedInput(String name, Tensor value) { + + static List findFromQuery(Query query, Collection queryFeatures) { + List result = new ArrayList<>(); + var ranking = query.getRanking(); + var rankFeatures = ranking.getFeatures(); + var rankProps = ranking.getProperties().asMap(); + for (String queryFeatureName : queryFeatures) { + String needed = "query(" + queryFeatureName + ")"; + // searchers are recommended to place query features here: + var feature = rankFeatures.getTensor(queryFeatureName); + if (feature.isPresent()) { + result.add(new PreparedInput(needed, feature.get())); + } else { + // but other ways of setting query features end up in the properties: + var objList = rankProps.get(queryFeatureName); + if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { + result.add(new PreparedInput(needed, t)); + } else { + throw new IllegalArgumentException("missing query feature: " + queryFeatureName); + } + } + } + return result; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java b/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java new file mode 100644 index 00000000000..6881eece620 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +// scale and adjust the score according to the range +// of the original and final score values to avoid that +// a score from the backend is larger than finalScores_low +class RangeAdjuster { + private double initialScores_high = -Double.MAX_VALUE; + private double initialScores_low = Double.MAX_VALUE; + private double finalScores_high = -Double.MAX_VALUE; + private double finalScores_low = Double.MAX_VALUE; + + boolean rescaleNeeded() { + return (initialScores_low > finalScores_low + && + initialScores_high >= initialScores_low + && + finalScores_high >= finalScores_low); + } + void withInitialScore(double score) { + if (score < initialScores_low) initialScores_low = score; + if (score > initialScores_high) initialScores_high = score; + } + void withFinalScore(double score) { + if (score < finalScores_low) finalScores_low = score; + if (score > finalScores_high) finalScores_high = score; + } + private double initialRange() { + double r = initialScores_high - initialScores_low; + if (r < 1.0) r = 1.0; + return r; + } + private double finalRange() { + double r = finalScores_high - finalScores_low; + if (r < 1.0) r = 1.0; + return r; + } + double scale() { return finalRange() / initialRange(); } + double bias() { return finalScores_low - initialScores_low * scale(); } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java index 353c88d374a..b6b4c4080b4 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java @@ -63,40 +63,21 @@ public class RankProfilesEvaluator extends AbstractComponent { return modelForRankProfile(rankProfile).evaluatorOf(functionName); } - static record GlobalPhaseData(Supplier functionEvaluatorSource, - Collection matchFeaturesToHide, - int rerankCount, - List needInputs) {} + Supplier getSupplier(String rankProfile, String functionName) { + return () -> new SimpleEvaluator(evaluatorForFunction(rankProfile, functionName)); + } - private Map profilesWithGlobalPhase = new HashMap<>(); + private Map profilesWithGlobalPhase = new HashMap<>(); - Optional getGlobalPhaseData(String rankProfile) { + Optional getGlobalPhaseSetup(String rankProfile) { return Optional.ofNullable(profilesWithGlobalPhase.get(rankProfile)); } private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) { for (var rp : rankProfilesConfig.rankprofile()) { - String name = rp.name(); - Supplier functionEvaluatorSource = null; - int rerankCount = -1; - List needInputs = null; - Set namesToHide = new HashSet<>(); - for (var prop : rp.fef().property()) { - if (prop.name().equals("vespa.globalphase.rerankcount")) { - rerankCount = Integer.valueOf(prop.value()); - } - if (prop.name().equals("vespa.rank.globalphase")) { - var model = modelForRankProfile(name); - functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); - var evaluator = functionEvaluatorSource.get(); - needInputs = List.copyOf(evaluator.function().arguments()); - } - if (prop.name().equals("vespa.hidden.matchfeature")) { - namesToHide.add(prop.value()); - } - } - if (functionEvaluatorSource != null && needInputs != null) { - profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, namesToHide, rerankCount, needInputs)); + var setup = GlobalPhaseSetup.maybeMakeSetup(rp, this); + if (setup != null) { + profilesWithGlobalPhase.put(rp.name(), setup); } } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java index d92068cd8d9..2d54e58daa2 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java @@ -14,80 +14,74 @@ class ResultReranker { private static final Logger logger = Logger.getLogger(ResultReranker.class.getName()); - // scale and adjust the score according to the range - // of the original and final score values to avoid that - // a score from the backend is larger than finalScores_low - static class Ranges { - private double initialScores_high = -Double.MAX_VALUE; - private double initialScores_low = Double.MAX_VALUE; - private double finalScores_high = -Double.MAX_VALUE; - private double finalScores_low = Double.MAX_VALUE; + private final HitRescorer hitRescorer; + private final int rerankCount; + private final List hitsToRescore = new ArrayList<>(); + private final RangeAdjuster ranges = new RangeAdjuster(); - boolean rescaleNeeded() { - return (initialScores_low > finalScores_low - && - initialScores_high >= initialScores_low - && - finalScores_high >= finalScores_low); - } - void withInitialScore(double score) { - if (score < initialScores_low) initialScores_low = score; - if (score > initialScores_high) initialScores_high = score; - } - void withFinalScore(double score) { - if (score < finalScores_low) finalScores_low = score; - if (score > finalScores_high) finalScores_high = score; - } - private double initialRange() { - double r = initialScores_high - initialScores_low; - if (r < 1.0) r = 1.0; - return r; - } - private double finalRange() { - double r = finalScores_high - finalScores_low; - if (r < 1.0) r = 1.0; - return r; - } - double scale() { return finalRange() / initialRange(); } - double bias() { return finalScores_low - initialScores_low * scale(); } + ResultReranker(HitRescorer hitRescorer, int rerankCount) { + this.hitRescorer = hitRescorer; + this.rerankCount = rerankCount; } - static void rerankHits(Result result, HitRescorer hitRescorer, int rerankCount) { - List hitsToRescore = new ArrayList<>(); - // consider doing recursive iteration explicitly instead of using deepIterator? + void rerankHits(Result result) { + gatherHits(result); + runPreProcessing(); + hitRescorer.runNormalizers(); + runProcessing(); + runPostProcessing(); + result.hits().sort(); + } + + private void gatherHits(Result result) { for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { Hit hit = iterator.next(); if (hit.isMeta() || hit instanceof HitGroup) { continue; } // what about hits inside grouping results? - // they are inside GroupingListHit, we won't recurse into it; so we won't see them. - hitsToRescore.add(hit); + // they did not show up here during manual testing. + var wrapped = WrappedHit.from(hit); + if (wrapped != null) hitsToRescore.add(wrapped); } + } + + private void runPreProcessing() { // we can't be 100% certain that hits were sorted according to relevance: hitsToRescore.sort(Comparator.naturalOrder()); - var ranges = new Ranges(); - for (var iterator = hitsToRescore.iterator(); rerankCount > 0 && iterator.hasNext(); ) { - Hit hit = iterator.next(); - double oldScore = hit.getRelevance().getScore(); - boolean didRerank = hitRescorer.rescoreHit(hit); + int count = 0; + for (WrappedHit hit : hitsToRescore) { + if (count == rerankCount) break; + hitRescorer.preprocess(hit); + ++count; + } + } + + private void runProcessing() { + int count = 0; + for (var iterator = hitsToRescore.iterator(); count < rerankCount && iterator.hasNext(); ) { + WrappedHit wrapped = iterator.next(); + double oldScore = wrapped.getScore(); + boolean didRerank = hitRescorer.rescoreHit(wrapped); if (didRerank) { ranges.withInitialScore(oldScore); - ranges.withFinalScore(hit.getRelevance().getScore()); - --rerankCount; + ranges.withFinalScore(wrapped.getScore()); + ++count; iterator.remove(); } } + } + + private void runPostProcessing() { // if any hits are left in the list, they may need rescaling: - if (ranges.rescaleNeeded()) { + if (ranges.rescaleNeeded() && ! hitsToRescore.isEmpty()) { double scale = ranges.scale(); double bias = ranges.bias(); - for (Hit hit : hitsToRescore) { - double oldScore = hit.getRelevance().getScore(); - hit.setRelevance(oldScore * scale + bias); + for (WrappedHit wrapped : hitsToRescore) { + double oldScore = wrapped.getScore(); + wrapped.setScore(oldScore * scale + bias); } } - result.hits().sort(); } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java index a42024c80a1..f2943c18960 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java @@ -10,24 +10,23 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.Supplier; class SimpleEvaluator implements Evaluator { private final FunctionEvaluator evaluator; - private final Set neededInputs; - - public SimpleEvaluator(FunctionEvaluator prototype) { - this.evaluator = prototype; - this.neededInputs = new HashSet(prototype.function().arguments()); + + static Supplier wrap(Supplier supplier) { + return () -> new SimpleEvaluator(supplier.get()); } - @Override - public Collection needInputs() { return List.copyOf(neededInputs); } + SimpleEvaluator(FunctionEvaluator prototype) { + this.evaluator = prototype; + } @Override - public SimpleEvaluator bind(String name, Tensor value) { + public Evaluator bind(String name, Tensor value) { if (value != null) evaluator.bind(name, value); - neededInputs.remove(name); return this; } @@ -42,7 +41,7 @@ class SimpleEvaluator implements Evaluator { buf.append("SimpleEvaluator("); buf.append(evaluator.function().toString()); buf.append(")["); - for (String arg : neededInputs) { + for (String arg : evaluator.function().arguments()) { buf.append("{").append(arg).append("}"); } buf.append("]"); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java b/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java new file mode 100644 index 00000000000..3b564eec3aa --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java @@ -0,0 +1,91 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; + +import java.util.logging.Logger; + +class WrappedHit implements Comparable { + + private static final Logger logger = Logger.getLogger(WrappedHit.class.getName()); + private final Hit hit; + private final FeatureData matchFeatures; + private int idx = -1; + + private WrappedHit(Hit hit, FeatureData matchFeatures) { + this.hit = hit; + this.matchFeatures = matchFeatures; + } + + static WrappedHit from(Hit hit) { + if (hit.getField("matchfeatures") instanceof FeatureData mf) { + return new WrappedHit(hit, mf); + } else { + return null; + } + } + + double getScore() { + return hit.getRelevance().getScore(); + } + + void setScore(double value) { + hit.setRelevance(value); + } + + int getIdx() { + if (idx < 0) { + throw new IllegalStateException("Missing index"); + } + return idx; + } + + void setIdx(int value) { + if (idx == value) { + return; + } else if (idx < 0) { + idx = value; + } else { + throw new IllegalArgumentException("Cannot re-assign index " + idx + " -> " + value); + } + } + + public int compareTo(WrappedHit other) { + return hit.compareTo(other.hit); + } + + Tensor getTensor(String argName) { + var asTensor = matchFeatures.getTensor(argName); + if (asTensor == null) { + asTensor = matchFeatures.getTensor(alternate(argName)); + } + return asTensor; + } + + Double getDouble(String argName) { + Double arg = matchFeatures.getDouble(argName); + if (arg == null) { + arg = matchFeatures.getDouble(alternate(argName)); + } + return arg; + } + + private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; + private static final String RE_SUFFIX = ")"; + private static final int RE_PRE_LEN = RE_PREFIX.length(); + private static final int RE_SUF_LEN = RE_SUFFIX.length(); + + // rankingExpression(foo) <-> foo + static String alternate(String argName) { + if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { + return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + } else { + return RE_PREFIX + argName + RE_SUFFIX; + } + } + +} -- cgit v1.2.3