diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-02-24 14:41:45 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-02-24 14:41:45 +0000 |
commit | 7d2993aad755d93ebba49c967fb51962884486df (patch) | |
tree | ee125fb63e03eedef4766a7263b66d3a6be3dded | |
parent | 36c0d79e750e1c7b32dcbc6e294a462a8d691bae (diff) |
add new components for global-phase handling
9 files changed, 396 insertions, 5 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java index 729aebf2fc2..a18250fbcfe 100644 --- a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java @@ -20,6 +20,7 @@ import com.yahoo.search.Searcher; import com.yahoo.search.config.ClusterConfig; import com.yahoo.search.dispatch.Dispatcher; import com.yahoo.search.query.ParameterParser; +import com.yahoo.search.ranking.GlobalPhaseHelper; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.searchchain.Execution; @@ -64,6 +65,7 @@ public class ClusterSearcher extends Searcher { private final VespaBackEndSearcher server; private final Executor executor; + private final GlobalPhaseHelper globalPhaseHelper; @Inject public ClusterSearcher(ComponentId id, @@ -73,10 +75,12 @@ public class ClusterSearcher extends Searcher { DocumentdbInfoConfig documentDbConfig, SchemaInfo schemaInfo, ComponentRegistry<Dispatcher> dispatchers, + GlobalPhaseHelper globalPhaseHelper, VipStatus vipStatus, VespaDocumentAccess access) { super(id); this.executor = executor; + this.globalPhaseHelper = globalPhaseHelper; int searchClusterIndex = clusterConfig.clusterId(); searchClusterName = clusterConfig.clusterName(); QrSearchersConfig.Searchcluster searchClusterConfig = getSearchClusterConfigFromClusterName(qrsConfig, searchClusterName); @@ -159,7 +163,9 @@ public class ClusterSearcher extends Searcher { maxQueryCacheTimeout = DEFAULT_MAX_QUERY_CACHE_TIMEOUT; server = searcher; this.executor = executor; + this.globalPhaseHelper = null; } + /** Do not use, for internal testing purposes only. **/ ClusterSearcher(Set<String> schemas) { this(schemas, null, null); @@ -169,7 +175,7 @@ public class ClusterSearcher extends Searcher { public void fill(com.yahoo.search.Result result, String summaryClass, Execution execution) { Query query = result.getQuery(); - VespaBackEndSearcher searcher = server; + Searcher searcher = server; if (searcher != null) { if (query.getTimeLeft() > 0) { searcher.fill(result, summaryClass, execution); @@ -190,7 +196,7 @@ public class ClusterSearcher extends Searcher { public Result search(Query query, Execution execution) { validateQueryTimeout(query); validateQueryCache(query); - VespaBackEndSearcher searcher = server; + Searcher searcher = server; if (searcher == null) { return new Result(query, ErrorMessage.createNoBackendsInService("Could not search")); } @@ -228,8 +234,23 @@ public class ClusterSearcher extends Searcher { } else { String docType = schemas.iterator().next(); query.getModel().setRestrict(docType); - return searcher.search(query, execution); + return perSchemaSearch(searcher, query, execution); + } + } + + private Result perSchemaSearch(Searcher searcher, Query query, Execution execution) { + Set<String> restrict = query.getModel().getRestrict(); + if (restrict.size() != 1) { + throw new IllegalStateException("perSchemaSearch must always be called with 1 schema, got: " + restrict.size()); + } + for (String schema : restrict) { + Result result = searcher.search(query, execution); + if (globalPhaseHelper != null) { + globalPhaseHelper.process(query, result, schema); + } + return result; } + return null; } private static void processResult(Query query, FutureTask<Result> task, Result mergedResult) { @@ -248,12 +269,12 @@ public class ClusterSearcher extends Searcher { Set<String> schemas = resolveSchemas(query, execution.context().getIndexFacts()); List<Query> queries = createQueries(query, schemas); if (queries.size() == 1) { - return searcher.search(queries.get(0), execution); + return perSchemaSearch(searcher, queries.get(0), execution); } else { Result mergedResult = new Result(query); List<FutureTask<Result>> pending = new ArrayList<>(queries.size()); for (Query q : queries) { - FutureTask<Result> task = new FutureTask<>(() -> searcher.search(q, execution)); + FutureTask<Result> task = new FutureTask<>(() -> perSchemaSearch(searcher, q, execution)); try { executor.execute(task); pending.add(task); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java new file mode 100644 index 00000000000..d2edb776c92 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java @@ -0,0 +1,14 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.tensor.Tensor; + +import java.util.Collection; + +interface Evaluator { + Collection<String> needInputs(); + + Evaluator bind(String name, Tensor value); + + double evaluateScore(); +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java new file mode 100644 index 00000000000..9810f612e5c --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseHelper.java @@ -0,0 +1,152 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; +import com.yahoo.component.annotation.Inject; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; +import com.yahoo.tensor.Tensor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import java.util.function.Supplier; + +public class GlobalPhaseHelper { + + private static final Logger logger = Logger.getLogger(GlobalPhaseHelper.class.getName()); + private final RankingExpressionEvaluatorFactory factory; + private final Set<String> skipProcessing = new HashSet<>(); + private final Map<String, Supplier<FunctionEvaluator>> scorers = new HashMap<>(); + + @Inject + public GlobalPhaseHelper(RankingExpressionEvaluatorFactory factory) { + this.factory = factory; + logger.info("using factory: " + factory); + } + + public void process(Query query, Result result, String schema) { + var functionEvaluatorSource = underlying(query, schema); + if (functionEvaluatorSource == null) { + return; + } + var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments()); + Supplier<Evaluator> supplier = () -> { + var evaluator = functionEvaluatorSource.get(); + var simple = new SimpleEvaluator(evaluator); + for (var entry : prepared) { + simple.bind(entry.name(), entry.value()); + } + return simple; + }; + // TODO need to get rerank-count somehow + int rerank = 7; + rerankHits(query, result, new HitRescorer(supplier), rerank); + } + + record NameAndValue(String name, Tensor value) { } + + /* do this only once per query: */ + List<NameAndValue> findFromQuery(Query query, List<String> needInputs) { + List<NameAndValue> result = new ArrayList<>(); + var ranking = query.getRanking(); + var rankFeatures = ranking.getFeatures(); + var rankProps = ranking.getProperties().asMap(); + for (String needed : needInputs) { + var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed); + if (optRef.isEmpty()) continue; + var ref = optRef.get(); + if (ref.name().equals("constant")) { + // XXX in theory, we should be able to avoid this + result.add(new NameAndValue(needed, null)); + continue; + } + if (ref.isSimple() && ref.name().equals("query")) { + String queryFeatureName = ref.simpleArgument().get(); + // searchers are recommended to place query features here: + var feature = rankFeatures.getTensor(queryFeatureName); + if (feature.isPresent()) { + result.add(new NameAndValue(needed, feature.get())); + } else { + // but other ways of setting query features end up in the properties: + var objList = rankProps.get(queryFeatureName); + if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { + result.add(new NameAndValue(needed, t)); + } + } + } + } + return result; + } + + void rerankHits(Query query, Result result, HitRescorer hitRescorer, int rerank) { + double worstRerankedScore = Double.MAX_VALUE; + double worstRerankedOldScore = Double.MAX_VALUE; + // TODO consider doing recursive iteration instead of deepIterator + for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { + Hit hit = iterator.next(); + if (hit.isMeta() || hit instanceof HitGroup) { + continue; + } + // what about hits inside grouping results? + if (rerank > 0) { + double oldScore = hit.getRelevance().getScore(); + boolean didRerank = hitRescorer.rescoreHit(hit); + if (didRerank) { + double newScore = hit.getRelevance().getScore(); + if (oldScore < worstRerankedOldScore) worstRerankedOldScore = oldScore; + if (newScore < worstRerankedScore) worstRerankedScore = newScore; + --rerank; + } else { + // failed to rescore this hit, what should we do? + hit.setRelevance(-Double.MAX_VALUE); + } + } else { + // too low quality + if (worstRerankedOldScore > worstRerankedScore) { + double penalty = worstRerankedOldScore - worstRerankedScore; + double oldScore = hit.getRelevance().getScore(); + hit.setRelevance(oldScore - penalty); + } + } + } + result.hits().sort(); + } + + private Supplier<FunctionEvaluator> underlying(Query query, String schema) { + String rankProfile = query.getRanking().getProfile(); + String key = schema + " with rank profile " + rankProfile; + if (skipProcessing.contains(key)) { + return null; + } + Supplier<FunctionEvaluator> supplier = scorers.get(key); + if (supplier != null) { + return supplier; + } + try { + var proxy = factory.proxyForSchema(schema); + var model = proxy.modelForRankProfile(rankProfile); + supplier = () -> model.evaluatorOf("globalphase"); + if (supplier.get() == null) { + supplier = null; + } + } catch (IllegalArgumentException e) { + logger.info("no global-phase for " + key + " because: " + e.getMessage()); + supplier = null; + } + if (supplier == null) { + skipProcessing.add(key); + } else { + scorers.put(key, supplier); + } + return supplier; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java new file mode 100644 index 00000000000..b92f9da2e2a --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -0,0 +1,56 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; + +import java.util.function.Supplier; +import java.util.logging.Logger; + +public class HitRescorer { + + private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); + + private final Supplier<Evaluator> evaluatorSource; + + public HitRescorer(Supplier<Evaluator> evaluatorSource) { + this.evaluatorSource = evaluatorSource; + } + + boolean rescoreHit(Hit hit) { + var features = hit.getField("matchfeatures"); + if (features instanceof FeatureData matchFeatures) { + var scorer = evaluatorSource.get(); + for (String argName : scorer.needInputs()) { + var asTensor = matchFeatures.getTensor(argName); + if (asTensor == null) { + asTensor = matchFeatures.getTensor(alternate(argName)); + } + if (asTensor != null) { + scorer.bind(argName, asTensor); + } else { + logger.warning("Missing match-feature for Evaluator argument: " + argName); + return false; + } + } + double newScore = scorer.evaluateScore(); + hit.setRelevance(newScore); + return true; + } else { + logger.warning("Hit without match-features: " + hit); + return false; + } + } + + private static final String RE_PREFIX = "rankingExpression("; + private static final String RE_SUFFIX = ")"; + private static final int RE_PRE_LEN = RE_PREFIX.length(); + private static final int RE_SUF_LEN = RE_SUFFIX.length(); + + static String alternate(String argName) { + if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { + return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + } + return argName; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java new file mode 100644 index 00000000000..8ec3fc919db --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorFactory.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.ranking; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; + +/** + * factory for model-evaluation proxies + * @author arnej + */ +@Beta +public class RankingExpressionEvaluatorFactory { + + private final ComponentRegistry<RankingExpressionEvaluatorProxy> registry; + + @Inject + public RankingExpressionEvaluatorFactory(ComponentRegistry<RankingExpressionEvaluatorProxy> registry) { + this.registry = registry; + } + + public RankingExpressionEvaluatorProxy proxyForSchema(String schemaName) { + var component = registry.getComponent("ranking-expression-evaluator." + schemaName); + if (component == null) { + throw new IllegalArgumentException("ranking expression evaluator for schema '" + schemaName + "' not found"); + } + return component; + } + + public String toString() { + var buf = new StringBuilder(); + buf.append(this.getClass().getName()).append(" containing: ["); + for (var id : registry.allComponentsById().keySet()) { + buf.append(" ").append(id.toString()); + } + buf.append(" ]"); + return buf.toString(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java new file mode 100644 index 00000000000..b4ee33263f1 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankingExpressionEvaluatorProxy.java @@ -0,0 +1,53 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; + +/** + * proxy for model-evaluation components + * @author arnej + */ +@Beta +public class RankingExpressionEvaluatorProxy extends AbstractComponent { + + private final ModelsEvaluator evaluator; + + @Inject + public RankingExpressionEvaluatorProxy( + RankProfilesConfig rankProfilesConfig, + RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, + OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer) + { + this.evaluator = new ModelsEvaluator( + rankProfilesConfig, + constantsConfig, + expressionsConfig, + onnxModelsConfig, + fileAcquirer); + } + + public Model modelForRankProfile(String rankProfile) { + var m = evaluator.models().get(rankProfile); + if (m == null) { + throw new IllegalArgumentException("unknown rankprofile: " + rankProfile); + } + return m; + } + + public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) { + return modelForRankProfile(rankProfile).evaluatorOf(functionName); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java new file mode 100644 index 00000000000..39abacc17f1 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +public class SimpleEvaluator implements Evaluator { + + private final FunctionEvaluator evaluator; + private final Set<String> neededInputs; + + + public SimpleEvaluator(FunctionEvaluator prototype) { + this.evaluator = prototype; + this.neededInputs = new HashSet<String>(prototype.function().arguments()); + } + + public Collection<String> needInputs() { return List.copyOf(neededInputs); } + + public SimpleEvaluator bind(String name, Tensor value) { + if (value != null) evaluator.bind(name, value); + neededInputs.remove(name); + return this; + } + + public double evaluateScore() { + return evaluator.evaluate().asDouble(); + } + + public String toString() { + var buf = new StringBuilder(); + buf.append("SimpleEvaluator("); + buf.append(evaluator.function().toString()); + buf.append(")["); + for (String arg : neededInputs) { + buf.append("{").append(arg).append("}"); + } + buf.append("]"); + return buf.toString(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/package-info.java b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java new file mode 100644 index 00000000000..a86a5c1e52f --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java @@ -0,0 +1,6 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.search.ranking; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java index 5df8d2e5444..06ae9923dae 100644 --- a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java @@ -464,6 +464,7 @@ public class ClusterSearcherTestCase { documentDbConfig.build(), new SchemaInfo(List.of(schema.build()), Map.of()), dispatchers, + null, vipStatus, null); } |