diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java | 50 | ||||
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java | 44 |
2 files changed, 56 insertions, 38 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 87213362acd..b72f81f1439 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -6,6 +6,7 @@ import ai.vespa.models.evaluation.Model; import com.yahoo.component.annotation.Inject; import com.yahoo.search.Query; import com.yahoo.search.Result; +import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData; import com.yahoo.search.result.Hit; import com.yahoo.search.result.HitGroup; import com.yahoo.tensor.Tensor; @@ -23,8 +24,6 @@ public class GlobalPhaseRanker { private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName()); private final RankProfilesEvaluatorFactory factory; - private final Set<String> skipProcessing = new HashSet<>(); - private final Map<String, Supplier<FunctionEvaluator>> scorers = new HashMap<>(); @Inject public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) { @@ -33,11 +32,14 @@ public class GlobalPhaseRanker { } public void process(Query query, Result result, String schema) { - var functionEvaluatorSource = underlying(query, schema); - if (functionEvaluatorSource == null) { + var proxy = factory.proxyForSchema(schema); + String rankProfile = query.getRanking().getProfile(); + var optData = proxy.getGlobalPhaseData(rankProfile); + if (optData.isEmpty()) return; - } - var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments()); + GlobalPhaseData data = optData.get(); + var functionEvaluatorSource = data.functionEvaluatorSource(); + var prepared = findFromQuery(query, data.needInputs()); Supplier<Evaluator> supplier = () -> { var evaluator = functionEvaluatorSource.get(); var simple = new SimpleEvaluator(evaluator); @@ -46,9 +48,10 @@ public class GlobalPhaseRanker { } return simple; }; - // TODO need to get rerank-count somehow - int rerank = 7; - ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank); + int rerankCount = data.rerankCount(); + if (rerankCount < 0) + rerankCount = 100; + ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount); } record NameAndValue(String name, Tensor value) { } @@ -86,33 +89,4 @@ public class GlobalPhaseRanker { return result; } - private Supplier<FunctionEvaluator> underlying(Query query, String schema) { - String rankProfile = query.getRanking().getProfile(); - String key = schema + " with rank profile " + rankProfile; - if (skipProcessing.contains(key)) { - return null; - } - Supplier<FunctionEvaluator> supplier = scorers.get(key); - if (supplier != null) { - return supplier; - } - try { - var proxy = factory.proxyForSchema(schema); - var model = proxy.modelForRankProfile(rankProfile); - supplier = () -> model.evaluatorOf("globalphase"); - if (supplier.get() == null) { - supplier = null; - } - } catch (IllegalArgumentException e) { - logger.info("no global-phase for " + key + " because: " + e.getMessage()); - supplier = null; - } - if (supplier == null) { - skipProcessing.add(key); - } else { - scorers.put(key, supplier); - } - return supplier; - } - } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java index ccb9b9837fe..2ca91a3ea91 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java @@ -14,6 +14,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; +import java.util.logging.Logger; + /** * proxy for model-evaluation components * @author arnej @@ -22,6 +29,7 @@ import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; public class RankProfilesEvaluator extends AbstractComponent { private final ModelsEvaluator evaluator; + private static final Logger logger = Logger.getLogger(RankProfilesEvaluator.class.getName()); @Inject public RankProfilesEvaluator( @@ -37,6 +45,7 @@ public class RankProfilesEvaluator extends AbstractComponent { expressionsConfig, onnxModelsConfig, fileAcquirer); + extractGlobalPhaseData(rankProfilesConfig); } public Model modelForRankProfile(String rankProfile) { @@ -50,4 +59,39 @@ public class RankProfilesEvaluator extends AbstractComponent { public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) { return modelForRankProfile(rankProfile).evaluatorOf(functionName); } + + static record GlobalPhaseData(Supplier<FunctionEvaluator> functionEvaluatorSource, + int rerankCount, + List<String> needInputs) {} + + private Map<String, GlobalPhaseData> profilesWithGlobalPhase = new HashMap<>(); + + Optional<GlobalPhaseData> getGlobalPhaseData(String rankProfile) { + return Optional.ofNullable(profilesWithGlobalPhase.get(rankProfile)); + } + + private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) { + for (var rp : rankProfilesConfig.rankprofile()) { + String name = rp.name(); + Supplier<FunctionEvaluator> functionEvaluatorSource = null; + int rerankCount = -1; + List<String> needInputs = null; + + for (var prop : rp.fef().property()) { + if (prop.name().equals("vespa.globalphase.rerankcount")) { + rerankCount = Integer.valueOf(prop.value()); + } + if (prop.name().equals("vespa.rank.globalphase")) { + var model = modelForRankProfile(name); + functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); + var evaluator = functionEvaluatorSource.get(); + needInputs = List.copyOf(evaluator.function().arguments()); + } + } + if (functionEvaluatorSource != null && needInputs != null) { + profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, rerankCount, needInputs)); + } + } + } + } |