aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorBjørn Christian Seime <bjorncs@yahooinc.com>2023-03-08 09:20:44 +0100
committerGitHub <noreply@github.com>2023-03-08 09:20:44 +0100
commit4de52ed5557f5d16d05e39296a1405223ccd8e54 (patch)
tree95a6ccb34361a04ac5498660f21c5a66540d6831 /container-search
parent34b28580764915adb3fe0e8d5539727eb6f59706 (diff)
parent4b55e59b0bdda6559a40addf0ede434ab955dc07 (diff)
Merge pull request #26347 from vespa-engine/bjorncs/global-phase
Only setup `RankProfilesEvaluator` for schemas with 'global-phase'
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java19
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java11
2 files changed, 10 insertions, 20 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
index b72f81f1439..2c6ab9e9367 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
@@ -1,24 +1,16 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;
-import ai.vespa.models.evaluation.FunctionEvaluator;
-import ai.vespa.models.evaluation.Model;
import com.yahoo.component.annotation.Inject;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData;
-import com.yahoo.search.result.Hit;
-import com.yahoo.search.result.HitGroup;
import com.yahoo.tensor.Tensor;
import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
-import java.util.Map;
-import java.util.Set;
-import java.util.logging.Logger;
import java.util.function.Supplier;
+import java.util.logging.Logger;
public class GlobalPhaseRanker {
@@ -32,12 +24,11 @@ public class GlobalPhaseRanker {
}
public void process(Query query, Result result, String schema) {
- var proxy = factory.proxyForSchema(schema);
String rankProfile = query.getRanking().getProfile();
- var optData = proxy.getGlobalPhaseData(rankProfile);
- if (optData.isEmpty())
- return;
- GlobalPhaseData data = optData.get();
+ GlobalPhaseData data = factory.evaluatorForSchema(schema)
+ .flatMap(evaluator -> evaluator.getGlobalPhaseData(rankProfile))
+ .orElse(null);
+ if (data == null) return;
var functionEvaluatorSource = data.functionEvaluatorSource();
var prepared = findFromQuery(query, data.needInputs());
Supplier<Evaluator> supplier = () -> {
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java
index edb05ed9788..33f2fb74da5 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java
@@ -6,6 +6,8 @@ import com.yahoo.api.annotations.Beta;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.provider.ComponentRegistry;
+import java.util.Optional;
+
/**
* factory for model-evaluation proxies
* @author arnej
@@ -20,14 +22,11 @@ public class RankProfilesEvaluatorFactory {
this.registry = registry;
}
- public RankProfilesEvaluator proxyForSchema(String schemaName) {
- var component = registry.getComponent("ranking-expression-evaluator." + schemaName);
- if (component == null) {
- throw new IllegalArgumentException("ranking expression evaluator for schema '" + schemaName + "' not found");
- }
- return component;
+ public Optional<RankProfilesEvaluator> evaluatorForSchema(String schemaName) {
+ return Optional.ofNullable(registry.getComponent("ranking-expression-evaluator." + schemaName));
}
+ @Override
public String toString() {
var buf = new StringBuilder();
buf.append(this.getClass().getName()).append(" containing: [");