diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-28 11:20:12 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-28 14:32:44 +0000 |
commit | 03a8cc6920bab8ea27ad4b5ca749618951495cd6 (patch) | |
tree | 42f5b796e8e1cc27f261b8130a2f3195c00420fc /container-search/src/main | |
parent | 98167b26a57e9413bf2d4ea71e99228a825432b9 (diff) |
extract GlobalPhaseData from rank-profiles config
Diffstat (limited to 'container-search/src/main')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java index ccb9b9837fe..2057b50f0aa 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java @@ -14,6 +14,12 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; +import java.util.HashMap; +import java.util.Map; +import java.util.List; +import java.util.function.Supplier; +import java.util.logging.Logger; + /** * proxy for model-evaluation components * @author arnej @@ -22,6 +28,7 @@ import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; public class RankProfilesEvaluator extends AbstractComponent { private final ModelsEvaluator evaluator; + private static final Logger logger = Logger.getLogger(RankProfilesEvaluator.class.getName()); @Inject public RankProfilesEvaluator( @@ -37,6 +44,7 @@ public class RankProfilesEvaluator extends AbstractComponent { expressionsConfig, onnxModelsConfig, fileAcquirer); + extractGlobalPhaseData(rankProfilesConfig); } public Model modelForRankProfile(String rankProfile) { @@ -50,4 +58,49 @@ public class RankProfilesEvaluator extends AbstractComponent { public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) { return modelForRankProfile(rankProfile).evaluatorOf(functionName); } + + static record GlobalPhaseData(Supplier<FunctionEvaluator> functionEvaluatorSource, + int rerankCount, + List<String> needInputs) {} + + private Map<String, GlobalPhaseData> profilesWithGlobalPhase = new HashMap<>(); + + GlobalPhaseData getGlobalPhaseData(String rankProfile) { + return profilesWithGlobalPhase.get(rankProfile); + } + + private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) { + for (var rp : rankProfilesConfig.rankprofile()) { + String name = rp.name(); + Supplier<FunctionEvaluator> functionEvaluatorSource = null; + int rerankCount = -1; + List<String> needInputs = null; + + for (var prop : rp.fef().property()) { + if (prop.name().equals("vespa.globalphase.rerankcount")) { + try { + rerankCount = Integer.valueOf(prop.value()); + } catch (NumberFormatException e) { + logger.warning("bad vespa.globalphase.rerankcount '" + prop.value() + + "' for rank profile " + name + ": " + e.getMessage()); + } + } + if (prop.name().equals("vespa.rank.globalphase")) { + try { + var model = modelForRankProfile(name); + functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); + var evaluator = functionEvaluatorSource.get(); + needInputs = List.copyOf(evaluator.function().arguments()); + } catch (IllegalArgumentException e) { + logger.warning("failed setting up global-phase for " + name + " because: " + e.getMessage()); + functionEvaluatorSource = null; + } + } + } + if (functionEvaluatorSource != null && needInputs != null) { + profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, rerankCount, needInputs)); + } + } + } + } |