aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Arne Juul <arnej@yahooinc.com> 2023-02-24 14:41:45 +0000
committer Arne Juul <arnej@yahooinc.com> 2023-02-24 14:41:45 +0000
commit 7d2993aad755d93ebba49c967fb51962884486df (patch)
tree ee125fb63e03eedef4766a7263b66d3a6be3dded
parent 36c0d79e750e1c7b32dcbc6e294a462a8d691bae (diff)
add new components for global-phase handling
-rw-r--r-- container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java 31
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java 14
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java 152
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java 56
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java 40
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java 53
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java 48
-rw-r--r-- container-search/src/main/java/com/yahoo/search/ranking/package-info.java 6
-rw-r--r-- container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java 1
9 files changed, 396 insertions, 5 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java
index 729aebf2fc2..a18250fbcfe 100644
--- a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java
@@ -20,6 +20,7 @@ import com.yahoo.search.Searcher;
import com.yahoo.search.config.ClusterConfig;
import com.yahoo.search.dispatch.Dispatcher;
import com.yahoo.search.query.ParameterParser;
+import com.yahoo.search.ranking.GlobalPhaseHelper;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.schema.SchemaInfo;
import com.yahoo.search.searchchain.Execution;
@@ -64,6 +65,7 @@ public class ClusterSearcher extends Searcher {
private final VespaBackEndSearcher server;
private final Executor executor;
+ private final GlobalPhaseHelper globalPhaseHelper;
@Inject
public ClusterSearcher(ComponentId id,
@@ -73,10 +75,12 @@ public class ClusterSearcher extends Searcher {
DocumentdbInfoConfig documentDbConfig,
SchemaInfo schemaInfo,
ComponentRegistry<Dispatcher> dispatchers,
+ GlobalPhaseHelper globalPhaseHelper,
VipStatus vipStatus,
VespaDocumentAccess access) {
super(id);
this.executor = executor;
+ this.globalPhaseHelper = globalPhaseHelper;
int searchClusterIndex = clusterConfig.clusterId();
searchClusterName = clusterConfig.clusterName();
QrSearchersConfig.Searchcluster searchClusterConfig = getSearchClusterConfigFromClusterName(qrsConfig, searchClusterName);
@@ -159,7 +163,9 @@ public class ClusterSearcher extends Searcher {
maxQueryCacheTimeout = DEFAULT_MAX_QUERY_CACHE_TIMEOUT;
server = searcher;
this.executor = executor;
+ this.globalPhaseHelper = null;
}
+
/** Do not use, for internal testing purposes only. **/
ClusterSearcher(Set<String> schemas) {
this(schemas, null, null);
@@ -169,7 +175,7 @@ public class ClusterSearcher extends Searcher {
public void fill(com.yahoo.search.Result result, String summaryClass, Execution execution) {
Query query = result.getQuery();
- VespaBackEndSearcher searcher = server;
+ Searcher searcher = server;
if (searcher != null) {
if (query.getTimeLeft() > 0) {
searcher.fill(result, summaryClass, execution);
@@ -190,7 +196,7 @@ public class ClusterSearcher extends Searcher {
public Result search(Query query, Execution execution) {
validateQueryTimeout(query);
validateQueryCache(query);
- VespaBackEndSearcher searcher = server;
+ Searcher searcher = server;
if (searcher == null) {
return new Result(query, ErrorMessage.createNoBackendsInService("Could not search"));
}
@@ -228,8 +234,23 @@ public class ClusterSearcher extends Searcher {
} else {
String docType = schemas.iterator().next();
query.getModel().setRestrict(docType);
- return searcher.search(query, execution);
+ return perSchemaSearch(searcher, query, execution);
+ }
+ }
+
+ private Result perSchemaSearch(Searcher searcher, Query query, Execution execution) {
+ Set<String> restrict = query.getModel().getRestrict();
+ if (restrict.size() != 1) {
+ throw new IllegalStateException("perSchemaSearch must always be called with 1 schema, got: " + restrict.size());
+ }
+ for (String schema : restrict) {
+ Result result = searcher.search(query, execution);
+ if (globalPhaseHelper != null) {
+ globalPhaseHelper.process(query, result, schema);
+ }
+ return result;
}
+ return null;
}
private static void processResult(Query query, FutureTask<Result> task, Result mergedResult) {
@@ -248,12 +269,12 @@ public class ClusterSearcher extends Searcher {
Set<String> schemas = resolveSchemas(query, execution.context().getIndexFacts());
List<Query> queries = createQueries(query, schemas);
if (queries.size() == 1) {
- return searcher.search(queries.get(0), execution);
+ return perSchemaSearch(searcher, queries.get(0), execution);
} else {
Result mergedResult = new Result(query);
List<FutureTask<Result>> pending = new ArrayList<>(queries.size());
for (Query q : queries) {
- FutureTask<Result> task = new FutureTask<>(() -> searcher.search(q, execution));
+ FutureTask<Result> task = new FutureTask<>(() -> perSchemaSearch(searcher, q, execution));
try {
executor.execute(task);
pending.add(task);
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java
new file mode 100644
index 00000000000..d2edb776c92
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java
@@ -0,0 +1,14 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.ranking;
+
+import com.yahoo.tensor.Tensor;
+
+import java.util.Collection;
+
+/**
+ * Package-private contract for computing a score from named tensor inputs:
+ * callers bind inputs by name and then ask for the score.
+ */
+interface Evaluator {
+ /** Returns the names of the inputs the caller is still expected to bind. */
+ Collection<String> needInputs();
+
+ /**
+  * Supplies the value for one named input.
+  *
+  * @param name  the input name, as reported by needInputs()
+  * @param value the tensor to use for that input
+  * @return an Evaluator, so calls can be chained
+  */
+ Evaluator bind(String name, Tensor value);
+
+ /** Evaluates with the inputs bound so far and returns the result as a double. */
+ double evaluateScore();
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java
new file mode 100644
index 00000000000..9810f612e5c
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java
@@ -0,0 +1,152 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.ranking;
+
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.Model;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.search.Query;
+import com.yahoo.search.Result;
+import com.yahoo.search.result.Hit;
+import com.yahoo.search.result.HitGroup;
+import com.yahoo.tensor.Tensor;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+
+/**
+ * Rescores hits with the "globalphase" function of the query's rank profile, one schema
+ * at a time, and re-sorts the result afterwards.
+ * <p>
+ * ClusterSearcher calls {@link #process} from tasks it hands to an executor, so several
+ * threads may reach the per-profile caches at once; they are kept in concurrent collections.
+ */
+public class GlobalPhaseHelper {
+
+    private static final Logger logger = Logger.getLogger(GlobalPhaseHelper.class.getName());
+
+    // TODO need to get rerank-count somehow
+    private static final int DEFAULT_RERANK_COUNT = 7;
+
+    private final RankingExpressionEvaluatorFactory factory;
+
+    // keys (schema + rank profile) already found to have no usable global-phase function
+    private final Set<String> skipProcessing = ConcurrentHashMap.newKeySet();
+
+    // evaluator sources for keys where a global-phase function was found
+    private final Map<String, Supplier<FunctionEvaluator>> scorers = new ConcurrentHashMap<>();
+
+    @Inject
+    public GlobalPhaseHelper(RankingExpressionEvaluatorFactory factory) {
+        this.factory = factory;
+        logger.info("using factory: " + factory);
+    }
+
+    /**
+     * Rescores hits of the result with the global-phase function for this schema and the
+     * query's rank profile. Does nothing when no such function is available.
+     *
+     * @param query  supplies the rank profile name and the query-side inputs
+     * @param result the result whose hits are rescored and re-sorted in place
+     * @param schema name of the schema the result was produced for
+     */
+    public void process(Query query, Result result, String schema) {
+        var functionEvaluatorSource = underlying(query, schema);
+        if (functionEvaluatorSource == null) {
+            return;
+        }
+        // query-side inputs are the same for every hit, so resolve them once up front
+        var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments());
+        Supplier<Evaluator> supplier = () -> {
+            var evaluator = functionEvaluatorSource.get();
+            var simple = new SimpleEvaluator(evaluator);
+            for (var entry : prepared) {
+                simple.bind(entry.name(), entry.value());
+            }
+            return simple;
+        };
+        rerankHits(query, result, new HitRescorer(supplier), DEFAULT_RERANK_COUNT);
+    }
+
+    /** A function input name paired with its value; the value is null for inputs that need no binding. */
+    record NameAndValue(String name, Tensor value) { }
+
+    /**
+     * Collects the function inputs that can be resolved without looking at a hit.
+     * Do this only once per query.
+     *
+     * @param query      the query whose rank features and rank properties are consulted
+     * @param needInputs the argument names of the function to be evaluated
+     * @return one entry per resolved input; constant(...) inputs get a null value, which
+     *         SimpleEvaluator.bind treats as "handled, nothing to bind"
+     */
+    List<NameAndValue> findFromQuery(Query query, List<String> needInputs) {
+        List<NameAndValue> result = new ArrayList<>();
+        var ranking = query.getRanking();
+        var rankFeatures = ranking.getFeatures();
+        var rankProps = ranking.getProperties().asMap();
+        for (String needed : needInputs) {
+            var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed);
+            if (optRef.isEmpty()) continue;
+            var ref = optRef.get();
+            if (ref.name().equals("constant")) {
+                // XXX in theory, we should be able to avoid this
+                result.add(new NameAndValue(needed, null));
+                continue;
+            }
+            if (ref.isSimple() && ref.name().equals("query")) {
+                String queryFeatureName = ref.simpleArgument().get();
+                // searchers are recommended to place query features here:
+                var feature = rankFeatures.getTensor(queryFeatureName);
+                if (feature.isPresent()) {
+                    result.add(new NameAndValue(needed, feature.get()));
+                } else {
+                    // but other ways of setting query features end up in the properties:
+                    var objList = rankProps.get(queryFeatureName);
+                    if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) {
+                        result.add(new NameAndValue(needed, t));
+                    }
+                }
+            }
+        }
+        return result;
+    }
+
+    /**
+     * Rescores up to {@code rerank} hits, in deep-iteration order, and then sorts the result.
+     * Meta hits and hit groups are skipped. A hit that cannot be rescored while the budget
+     * is still open gets the lowest possible relevance. Once the budget is used up, the
+     * remaining hits are shifted down by the gap between the lowest old and the lowest new
+     * score among the rescored hits, when that gap is positive.
+     *
+     * @param query       currently unused
+     * @param result      the result whose hits are modified and sorted in place
+     * @param hitRescorer computes the new relevance of a single hit
+     * @param rerank      maximum number of hits to rescore
+     */
+    void rerankHits(Query query, Result result, HitRescorer hitRescorer, int rerank) {
+        double worstRerankedScore = Double.MAX_VALUE;
+        double worstRerankedOldScore = Double.MAX_VALUE;
+        // TODO consider doing recursive iteration instead of deepIterator
+        for (var iterator = result.hits().deepIterator(); iterator.hasNext();) {
+            Hit hit = iterator.next();
+            if (hit.isMeta() || hit instanceof HitGroup) {
+                continue;
+            }
+            // what about hits inside grouping results?
+            if (rerank > 0) {
+                double oldScore = hit.getRelevance().getScore();
+                boolean didRerank = hitRescorer.rescoreHit(hit);
+                if (didRerank) {
+                    double newScore = hit.getRelevance().getScore();
+                    if (oldScore < worstRerankedOldScore) worstRerankedOldScore = oldScore;
+                    if (newScore < worstRerankedScore) worstRerankedScore = newScore;
+                    --rerank;
+                } else {
+                    // failed to rescore this hit, what should we do?
+                    hit.setRelevance(-Double.MAX_VALUE);
+                }
+            } else {
+                // too low quality
+                if (worstRerankedOldScore > worstRerankedScore) {
+                    double penalty = worstRerankedOldScore - worstRerankedScore;
+                    double oldScore = hit.getRelevance().getScore();
+                    hit.setRelevance(oldScore - penalty);
+                }
+            }
+        }
+        result.hits().sort();
+    }
+
+    /**
+     * Returns a source of evaluators for the "globalphase" function of the given schema and
+     * the query's rank profile, or null when there is none. Both outcomes are cached per
+     * schema and rank profile. Two threads may resolve the same key at the same time; they
+     * reach the same outcome, so the duplicate work is harmless.
+     */
+    private Supplier<FunctionEvaluator> underlying(Query query, String schema) {
+        String rankProfile = query.getRanking().getProfile();
+        String key = schema + " with rank profile " + rankProfile;
+        if (skipProcessing.contains(key)) {
+            return null;
+        }
+        Supplier<FunctionEvaluator> supplier = scorers.get(key);
+        if (supplier != null) {
+            return supplier;
+        }
+        try {
+            var proxy = factory.proxyForSchema(schema);
+            var model = proxy.modelForRankProfile(rankProfile);
+            supplier = () -> model.evaluatorOf("globalphase");
+            // probe once so a missing function is detected here rather than per query
+            if (supplier.get() == null) {
+                supplier = null;
+            }
+        } catch (IllegalArgumentException e) {
+            logger.info("no global-phase for " + key + " because: " + e.getMessage());
+            supplier = null;
+        }
+        if (supplier == null) {
+            skipProcessing.add(key);
+        } else {
+            scorers.put(key, supplier);
+        }
+        return supplier;
+    }
+
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
new file mode 100644
index 00000000000..b92f9da2e2a
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java
@@ -0,0 +1,56 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.ranking;
+
+import com.yahoo.search.result.FeatureData;
+import com.yahoo.search.result.Hit;
+
+import java.util.function.Supplier;
+import java.util.logging.Logger;
+
+/**
+ * Recomputes the relevance of a single hit from its "matchfeatures" field, using an
+ * Evaluator freshly obtained from the configured source for each hit.
+ */
+public class HitRescorer {
+
+    private static final Logger logger = Logger.getLogger(HitRescorer.class.getName());
+
+    private static final String RE_PREFIX = "rankingExpression(";
+    private static final String RE_SUFFIX = ")";
+    private static final int RE_PRE_LEN = RE_PREFIX.length();
+    private static final int RE_SUF_LEN = RE_SUFFIX.length();
+
+    private final Supplier<Evaluator> evaluatorSource;
+
+    public HitRescorer(Supplier<Evaluator> evaluatorSource) {
+        this.evaluatorSource = evaluatorSource;
+    }
+
+    /**
+     * Binds every input the evaluator still needs from the hit's match features,
+     * evaluates, and stores the score as the hit's relevance.
+     *
+     * @param hit the hit to rescore
+     * @return true if the relevance was replaced; false (after logging a warning) if the hit
+     *         carries no match features or one of the needed inputs is missing from them
+     */
+    boolean rescoreHit(Hit hit) {
+        if (!(hit.getField("matchfeatures") instanceof FeatureData matchFeatures)) {
+            logger.warning("Hit without match-features: " + hit);
+            return false;
+        }
+        Evaluator scorer = evaluatorSource.get();
+        for (String argName : scorer.needInputs()) {
+            var tensor = matchFeatures.getTensor(argName);
+            if (tensor == null) {
+                // retry under the name without the rankingExpression(...) wrapper
+                tensor = matchFeatures.getTensor(alternate(argName));
+            }
+            if (tensor == null) {
+                logger.warning("Missing match-feature for Evaluator argument: " + argName);
+                return false;
+            }
+            scorer.bind(argName, tensor);
+        }
+        hit.setRelevance(scorer.evaluateScore());
+        return true;
+    }
+
+    /**
+     * Strips a surrounding "rankingExpression(" ... ")" from the name.
+     *
+     * @return the inner name, or argName unchanged when it is not wrapped that way
+     */
+    static String alternate(String argName) {
+        boolean wrapped = argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX);
+        return wrapped
+                ? argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN)
+                : argName;
+    }
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java
new file mode 100644
index 00000000000..8ec3fc919db
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.search.ranking;
+
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.component.provider.ComponentRegistry;
+
+/**
+ * Looks up the model-evaluation proxy registered for a schema.
+ *
+ * @author arnej
+ */
+@Beta
+public class RankingExpressionEvaluatorFactory {
+
+    private final ComponentRegistry<RankingExpressionEvaluatorProxy> registry;
+
+    @Inject
+    public RankingExpressionEvaluatorFactory(ComponentRegistry<RankingExpressionEvaluatorProxy> registry) {
+        this.registry = registry;
+    }
+
+    /**
+     * Returns the proxy registered under the id "ranking-expression-evaluator." followed by
+     * the schema name.
+     *
+     * @throws IllegalArgumentException if the registry has no component with that id
+     */
+    public RankingExpressionEvaluatorProxy proxyForSchema(String schemaName) {
+        var proxy = registry.getComponent("ranking-expression-evaluator." + schemaName);
+        if (proxy != null) {
+            return proxy;
+        }
+        throw new IllegalArgumentException("ranking expression evaluator for schema '" + schemaName + "' not found");
+    }
+
+    /** Returns the class name followed by the ids of all components in the registry. */
+    @Override
+    public String toString() {
+        var text = new StringBuilder(getClass().getName());
+        text.append(" containing: [");
+        registry.allComponentsById().keySet().forEach(id -> text.append(" ").append(id.toString()));
+        text.append(" ]");
+        return text.toString();
+    }
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java
new file mode 100644
index 00000000000..b4ee33263f1
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java
@@ -0,0 +1,53 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package com.yahoo.search.ranking;
+
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import ai.vespa.models.evaluation.Model;
+import ai.vespa.models.evaluation.ModelsEvaluator;
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.component.AbstractComponent;
+import com.yahoo.component.annotation.Inject;
+import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
+
+/**
+ * Component wrapping a ModelsEvaluator that is built from the injected ranking configs.
+ *
+ * @author arnej
+ */
+@Beta
+public class RankingExpressionEvaluatorProxy extends AbstractComponent {
+
+    private final ModelsEvaluator evaluator;
+
+    @Inject
+    public RankingExpressionEvaluatorProxy(RankProfilesConfig rankProfilesConfig,
+                                           RankingConstantsConfig constantsConfig,
+                                           RankingExpressionsConfig expressionsConfig,
+                                           OnnxModelsConfig onnxModelsConfig,
+                                           FileAcquirer fileAcquirer) {
+        evaluator = new ModelsEvaluator(rankProfilesConfig, constantsConfig, expressionsConfig,
+                                        onnxModelsConfig, fileAcquirer);
+    }
+
+    /**
+     * Returns the model stored under the given rank profile name.
+     *
+     * @throws IllegalArgumentException if the evaluator has no model by that name
+     */
+    public Model modelForRankProfile(String rankProfile) {
+        var model = evaluator.models().get(rankProfile);
+        if (model != null) {
+            return model;
+        }
+        throw new IllegalArgumentException("unknown rankprofile: " + rankProfile);
+    }
+
+    /**
+     * Returns an evaluator for the named function of the given rank profile's model.
+     *
+     * @throws IllegalArgumentException if the rank profile is unknown
+     */
+    public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) {
+        Model model = modelForRankProfile(rankProfile);
+        return model.evaluatorOf(functionName);
+    }
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java
new file mode 100644
index 00000000000..39abacc17f1
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.ranking;
+
+import ai.vespa.models.evaluation.FunctionEvaluator;
+import com.yahoo.search.result.FeatureData;
+import com.yahoo.search.result.Hit;
+import com.yahoo.tensor.Tensor;
+
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Evaluator backed by a single FunctionEvaluator. Keeps track of which of the function's
+ * arguments have not yet been passed to bind().
+ */
+public class SimpleEvaluator implements Evaluator {
+
+    private final FunctionEvaluator evaluator;
+    private final Set<String> neededInputs;
+
+    /** Wraps the given evaluator; initially all arguments of its function count as needed. */
+    public SimpleEvaluator(FunctionEvaluator prototype) {
+        evaluator = prototype;
+        neededInputs = new HashSet<>(prototype.function().arguments());
+    }
+
+    /** Returns a snapshot of the argument names not yet passed to bind(). */
+    @Override
+    public Collection<String> needInputs() {
+        return List.copyOf(neededInputs);
+    }
+
+    /**
+     * Marks the named input as handled, and binds it on the wrapped evaluator unless the
+     * value is null.
+     *
+     * @return this, for chaining
+     */
+    @Override
+    public SimpleEvaluator bind(String name, Tensor value) {
+        if (value != null) {
+            evaluator.bind(name, value);
+        }
+        neededInputs.remove(name);
+        return this;
+    }
+
+    @Override
+    public double evaluateScore() {
+        return evaluator.evaluate().asDouble();
+    }
+
+    /** Shows the wrapped function followed by the inputs still needed, each in braces. */
+    @Override
+    public String toString() {
+        var text = new StringBuilder("SimpleEvaluator(");
+        text.append(evaluator.function().toString()).append(")[");
+        neededInputs.forEach(arg -> text.append("{").append(arg).append("}"));
+        return text.append("]").toString();
+    }
+}
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/package-info.java b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java
new file mode 100644
index 00000000000..a86a5c1e52f
--- /dev/null
+++ b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java
@@ -0,0 +1,6 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+@ExportPackage
+package com.yahoo.search.ranking;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java
index 5df8d2e5444..06ae9923dae 100644
--- a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java
+++ b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java
@@ -464,6 +464,7 @@ public class ClusterSearcherTestCase {
documentDbConfig.build(),
new SchemaInfo(List.of(schema.build()), Map.of()),
dispatchers,
+ null,
vipStatus,
null);
}