From d15365e34d7dbe85d166ba2d9029ade52c722a00 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 28 Feb 2023 11:26:44 +0000 Subject: use GlobalPhaseData from proxy --- .../yahoo/search/ranking/GlobalPhaseRanker.java | 48 +++++----------------- 1 file changed, 10 insertions(+), 38 deletions(-) diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 87213362acd..638bc13fb29 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -23,8 +23,6 @@ public class GlobalPhaseRanker { private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName()); private final RankProfilesEvaluatorFactory factory; - private final Set skipProcessing = new HashSet<>(); - private final Map> scorers = new HashMap<>(); @Inject public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) { @@ -33,11 +31,13 @@ public class GlobalPhaseRanker { } public void process(Query query, Result result, String schema) { - var functionEvaluatorSource = underlying(query, schema); - if (functionEvaluatorSource == null) { + var proxy = factory.proxyForSchema(schema); + String rankProfile = query.getRanking().getProfile(); + var data = proxy.getGlobalPhaseData(rankProfile); + if (data == null) return; - } - var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments()); + var functionEvaluatorSource = data.functionEvaluatorSource(); + var prepared = findFromQuery(query, data.needInputs()); Supplier supplier = () -> { var evaluator = functionEvaluatorSource.get(); var simple = new SimpleEvaluator(evaluator); @@ -46,9 +46,10 @@ public class GlobalPhaseRanker { } return simple; }; - // TODO need to get rerank-count somehow - int rerank = 7; - ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank); + int rerankCount = data.rerankCount(); + if (rerankCount < 0) + rerankCount = 100; + ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount); } record NameAndValue(String name, Tensor value) { } @@ -86,33 +87,4 @@ public class GlobalPhaseRanker { return result; } - private Supplier underlying(Query query, String schema) { - String rankProfile = query.getRanking().getProfile(); - String key = schema + " with rank profile " + rankProfile; - if (skipProcessing.contains(key)) { - return null; - } - Supplier supplier = scorers.get(key); - if (supplier != null) { - return supplier; - } - try { - var proxy = factory.proxyForSchema(schema); - var model = proxy.modelForRankProfile(rankProfile); - supplier = () -> model.evaluatorOf("globalphase"); - if (supplier.get() == null) { - supplier = null; - } - } catch (IllegalArgumentException e) { - logger.info("no global-phase for " + key + " because: " + e.getMessage()); - supplier = null; - } - if (supplier == null) { - skipProcessing.add(key); - } else { - scorers.put(key, supplier); - } - return supplier; - } - } -- cgit v1.2.3