summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-28 11:26:44 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-28 14:32:44 +0000
commitd15365e34d7dbe85d166ba2d9029ade52c722a00 (patch)
tree06eb658eca6cfad5548e754e6c88422efe2695f3 /container-search
parent03a8cc6920bab8ea27ad4b5ca749618951495cd6 (diff)
use GlobalPhaseData from proxy
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java48
1 files changed, 10 insertions, 38 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
index 87213362acd..638bc13fb29 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
@@ -23,8 +23,6 @@ public class GlobalPhaseRanker {
private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName());
private final RankProfilesEvaluatorFactory factory;
- private final Set<String> skipProcessing = new HashSet<>();
- private final Map<String, Supplier<FunctionEvaluator>> scorers = new HashMap<>();
@Inject
public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) {
@@ -33,11 +31,13 @@ public class GlobalPhaseRanker {
}
public void process(Query query, Result result, String schema) {
- var functionEvaluatorSource = underlying(query, schema);
- if (functionEvaluatorSource == null) {
+ var proxy = factory.proxyForSchema(schema);
+ String rankProfile = query.getRanking().getProfile();
+ var data = proxy.getGlobalPhaseData(rankProfile);
+ if (data == null)
return;
- }
- var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments());
+ var functionEvaluatorSource = data.functionEvaluatorSource();
+ var prepared = findFromQuery(query, data.needInputs());
Supplier<Evaluator> supplier = () -> {
var evaluator = functionEvaluatorSource.get();
var simple = new SimpleEvaluator(evaluator);
@@ -46,9 +46,10 @@ public class GlobalPhaseRanker {
}
return simple;
};
- // TODO need to get rerank-count somehow
- int rerank = 7;
- ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank);
+ int rerankCount = data.rerankCount();
+ if (rerankCount < 0)
+ rerankCount = 100;
+ ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount);
}
record NameAndValue(String name, Tensor value) { }
@@ -86,33 +87,4 @@ public class GlobalPhaseRanker {
return result;
}
- private Supplier<FunctionEvaluator> underlying(Query query, String schema) {
- String rankProfile = query.getRanking().getProfile();
- String key = schema + " with rank profile " + rankProfile;
- if (skipProcessing.contains(key)) {
- return null;
- }
- Supplier<FunctionEvaluator> supplier = scorers.get(key);
- if (supplier != null) {
- return supplier;
- }
- try {
- var proxy = factory.proxyForSchema(schema);
- var model = proxy.modelForRankProfile(rankProfile);
- supplier = () -> model.evaluatorOf("globalphase");
- if (supplier.get() == null) {
- supplier = null;
- }
- } catch (IllegalArgumentException e) {
- logger.info("no global-phase for " + key + " because: " + e.getMessage());
- supplier = null;
- }
- if (supplier == null) {
- skipProcessing.add(key);
- } else {
- scorers.put(key, supplier);
- }
- return supplier;
- }
-
}