summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-03-01 14:25:18 +0100
committerGitHub <noreply@github.com>2023-03-01 14:25:18 +0100
commit0c766a6506aa82299c2cd1d166fcc3522a962f2f (patch)
treea38e3958faaf731385037e04abbca3c3e92dee17 /container-search
parentcdc86f437afbb34cabc9f05db951bef6ad206121 (diff)
parentb354096f220e76d28776f40bb3b89d3a6c3911cf (diff)
Merge pull request #26229 from vespa-engine/arnej/global-phase-data-from-proxy
Arnej/global phase data from proxy
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java50
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java44
2 files changed, 56 insertions, 38 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
index 87213362acd..b72f81f1439 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
@@ -6,6 +6,7 @@ import ai.vespa.models.evaluation.Model;
import com.yahoo.component.annotation.Inject;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
+import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData;
import com.yahoo.search.result.Hit;
import com.yahoo.search.result.HitGroup;
import com.yahoo.tensor.Tensor;
@@ -23,8 +24,6 @@ public class GlobalPhaseRanker {
private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName());
private final RankProfilesEvaluatorFactory factory;
- private final Set<String> skipProcessing = new HashSet<>();
- private final Map<String, Supplier<FunctionEvaluator>> scorers = new HashMap<>();
@Inject
public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) {
@@ -33,11 +32,14 @@ public class GlobalPhaseRanker {
}
public void process(Query query, Result result, String schema) {
- var functionEvaluatorSource = underlying(query, schema);
- if (functionEvaluatorSource == null) {
+ var proxy = factory.proxyForSchema(schema);
+ String rankProfile = query.getRanking().getProfile();
+ var optData = proxy.getGlobalPhaseData(rankProfile);
+ if (optData.isEmpty())
return;
- }
- var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments());
+ GlobalPhaseData data = optData.get();
+ var functionEvaluatorSource = data.functionEvaluatorSource();
+ var prepared = findFromQuery(query, data.needInputs());
Supplier<Evaluator> supplier = () -> {
var evaluator = functionEvaluatorSource.get();
var simple = new SimpleEvaluator(evaluator);
@@ -46,9 +48,10 @@ public class GlobalPhaseRanker {
}
return simple;
};
- // TODO need to get rerank-count somehow
- int rerank = 7;
- ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank);
+ int rerankCount = data.rerankCount();
+ if (rerankCount < 0)
+ rerankCount = 100;
+ ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount);
}
record NameAndValue(String name, Tensor value) { }
@@ -86,33 +89,4 @@ public class GlobalPhaseRanker {
return result;
}
- private Supplier<FunctionEvaluator> underlying(Query query, String schema) {
- String rankProfile = query.getRanking().getProfile();
- String key = schema + " with rank profile " + rankProfile;
- if (skipProcessing.contains(key)) {
- return null;
- }
- Supplier<FunctionEvaluator> supplier = scorers.get(key);
- if (supplier != null) {
- return supplier;
- }
- try {
- var proxy = factory.proxyForSchema(schema);
- var model = proxy.modelForRankProfile(rankProfile);
- supplier = () -> model.evaluatorOf("globalphase");
- if (supplier.get() == null) {
- supplier = null;
- }
- } catch (IllegalArgumentException e) {
- logger.info("no global-phase for " + key + " because: " + e.getMessage());
- supplier = null;
- }
- if (supplier == null) {
- skipProcessing.add(key);
- } else {
- scorers.put(key, supplier);
- }
- return supplier;
- }
-
}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
index ccb9b9837fe..2ca91a3ea91 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java
@@ -14,6 +14,13 @@ import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+
/**
* proxy for model-evaluation components
* @author arnej
@@ -22,6 +29,7 @@ import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
public class RankProfilesEvaluator extends AbstractComponent {
private final ModelsEvaluator evaluator;
+ private static final Logger logger = Logger.getLogger(RankProfilesEvaluator.class.getName());
@Inject
public RankProfilesEvaluator(
@@ -37,6 +45,7 @@ public class RankProfilesEvaluator extends AbstractComponent {
expressionsConfig,
onnxModelsConfig,
fileAcquirer);
+ extractGlobalPhaseData(rankProfilesConfig);
}
public Model modelForRankProfile(String rankProfile) {
@@ -50,4 +59,39 @@ public class RankProfilesEvaluator extends AbstractComponent {
public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) {
return modelForRankProfile(rankProfile).evaluatorOf(functionName);
}
+
+ static record GlobalPhaseData(Supplier<FunctionEvaluator> functionEvaluatorSource,
+ int rerankCount,
+ List<String> needInputs) {}
+
+ private Map<String, GlobalPhaseData> profilesWithGlobalPhase = new HashMap<>();
+
+ Optional<GlobalPhaseData> getGlobalPhaseData(String rankProfile) {
+ return Optional.ofNullable(profilesWithGlobalPhase.get(rankProfile));
+ }
+
+ private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) {
+ for (var rp : rankProfilesConfig.rankprofile()) {
+ String name = rp.name();
+ Supplier<FunctionEvaluator> functionEvaluatorSource = null;
+ int rerankCount = -1;
+ List<String> needInputs = null;
+
+ for (var prop : rp.fef().property()) {
+ if (prop.name().equals("vespa.globalphase.rerankcount")) {
+ rerankCount = Integer.valueOf(prop.value());
+ }
+ if (prop.name().equals("vespa.rank.globalphase")) {
+ var model = modelForRankProfile(name);
+ functionEvaluatorSource = () -> model.evaluatorOf("globalphase");
+ var evaluator = functionEvaluatorSource.get();
+ needInputs = List.copyOf(evaluator.function().arguments());
+ }
+ }
+ if (functionEvaluatorSource != null && needInputs != null) {
+ profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, rerankCount, needInputs));
+ }
+ }
+ }
+
}