aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-28 11:20:12 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-28 14:32:44 +0000
commit03a8cc6920bab8ea27ad4b5ca749618951495cd6 (patch)
tree42f5b796e8e1cc27f261b8130a2f3195c00420fc /container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
parent98167b26a57e9413bf2d4ea71e99228a825432b9 (diff)
extract GlobalPhaseData from rank-profiles config
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java53
1 files changed, 53 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
index ccb9b9837fe..2057b50f0aa 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
@@ -14,6 +14,12 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.List;
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+
/**
* proxy for model-evaluation components
* @author arnej
@@ -22,6 +28,7 @@ import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
public class RankProfilesEvaluator extends AbstractComponent {
private final ModelsEvaluator evaluator;
+ private static final Logger logger = Logger.getLogger(RankProfilesEvaluator.class.getName());
@Inject
public RankProfilesEvaluator(
@@ -37,6 +44,7 @@ public class RankProfilesEvaluator extends AbstractComponent {
expressionsConfig,
onnxModelsConfig,
fileAcquirer);
+ extractGlobalPhaseData(rankProfilesConfig);
}
public Model modelForRankProfile(String rankProfile) {
@@ -50,4 +58,49 @@ public class RankProfilesEvaluator extends AbstractComponent {
public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) {
return modelForRankProfile(rankProfile).evaluatorOf(functionName);
}
+
+ static record GlobalPhaseData(Supplier<FunctionEvaluator> functionEvaluatorSource,
+ int rerankCount,
+ List<String> needInputs) {}
+
+ private Map<String, GlobalPhaseData> profilesWithGlobalPhase = new HashMap<>();
+
+ GlobalPhaseData getGlobalPhaseData(String rankProfile) {
+ return profilesWithGlobalPhase.get(rankProfile);
+ }
+
+ private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) {
+ for (var rp : rankProfilesConfig.rankprofile()) {
+ String name = rp.name();
+ Supplier<FunctionEvaluator> functionEvaluatorSource = null;
+ int rerankCount = -1;
+ List<String> needInputs = null;
+
+ for (var prop : rp.fef().property()) {
+ if (prop.name().equals("vespa.globalphase.rerankcount")) {
+ try {
+ rerankCount = Integer.valueOf(prop.value());
+ } catch (NumberFormatException e) {
+ logger.warning("bad vespa.globalphase.rerankcount '" + prop.value() +
+ "' for rank profile " + name + ": " + e.getMessage());
+ }
+ }
+ if (prop.name().equals("vespa.rank.globalphase")) {
+ try {
+ var model = modelForRankProfile(name);
+ functionEvaluatorSource = () -> model.evaluatorOf("globalphase");
+ var evaluator = functionEvaluatorSource.get();
+ needInputs = List.copyOf(evaluator.function().arguments());
+ } catch (IllegalArgumentException e) {
+ logger.warning("failed setting up global-phase for " + name + " because: " + e.getMessage());
+ functionEvaluatorSource = null;
+ }
+ }
+ }
+ if (functionEvaluatorSource != null && needInputs != null) {
+ profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, rerankCount, needInputs));
+ }
+ }
+ }
+
}