diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-02-28 14:20:53 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-28 14:20:53 +0100 |
commit | 00de2a92c9cdc1056f718674b055a9639fe64b3b (patch) | |
tree | 66ad9d07b5292d5961353d6454c78ebda6ad1922 /container-search | |
parent | 05a7e7a75feae44c6ca9ed27d4d3570873603702 (diff) | |
parent | 18aa303fe0959786838aa63ec8f0aca092be2d99 (diff) |
Merge pull request #26179 from vespa-engine/arnej/add-new-components-4
add new components for global-phase handling
Diffstat (limited to 'container-search')
10 files changed, 454 insertions, 5 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java index 729aebf2fc2..7787d7d7702 100644 --- a/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/cluster/ClusterSearcher.java @@ -20,6 +20,7 @@ import com.yahoo.search.Searcher; import com.yahoo.search.config.ClusterConfig; import com.yahoo.search.dispatch.Dispatcher; import com.yahoo.search.query.ParameterParser; +import com.yahoo.search.ranking.GlobalPhaseRanker; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.searchchain.Execution; @@ -64,6 +65,7 @@ public class ClusterSearcher extends Searcher { private final VespaBackEndSearcher server; private final Executor executor; + private final GlobalPhaseRanker globalPhaseHelper; @Inject public ClusterSearcher(ComponentId id, @@ -73,10 +75,12 @@ public class ClusterSearcher extends Searcher { DocumentdbInfoConfig documentDbConfig, SchemaInfo schemaInfo, ComponentRegistry<Dispatcher> dispatchers, + GlobalPhaseRanker globalPhaseHelper, VipStatus vipStatus, VespaDocumentAccess access) { super(id); this.executor = executor; + this.globalPhaseHelper = globalPhaseHelper; int searchClusterIndex = clusterConfig.clusterId(); searchClusterName = clusterConfig.clusterName(); QrSearchersConfig.Searchcluster searchClusterConfig = getSearchClusterConfigFromClusterName(qrsConfig, searchClusterName); @@ -159,7 +163,9 @@ public class ClusterSearcher extends Searcher { maxQueryCacheTimeout = DEFAULT_MAX_QUERY_CACHE_TIMEOUT; server = searcher; this.executor = executor; + this.globalPhaseHelper = null; } + /** Do not use, for internal testing purposes only. **/ ClusterSearcher(Set<String> schemas) { this(schemas, null, null); @@ -169,7 +175,7 @@ public class ClusterSearcher extends Searcher { public void fill(com.yahoo.search.Result result, String summaryClass, Execution execution) { Query query = result.getQuery(); - VespaBackEndSearcher searcher = server; + Searcher searcher = server; if (searcher != null) { if (query.getTimeLeft() > 0) { searcher.fill(result, summaryClass, execution); @@ -190,7 +196,7 @@ public class ClusterSearcher extends Searcher { public Result search(Query query, Execution execution) { validateQueryTimeout(query); validateQueryCache(query); - VespaBackEndSearcher searcher = server; + Searcher searcher = server; if (searcher == null) { return new Result(query, ErrorMessage.createNoBackendsInService("Could not search")); } @@ -228,8 +234,21 @@ public class ClusterSearcher extends Searcher { } else { String docType = schemas.iterator().next(); query.getModel().setRestrict(docType); - return searcher.search(query, execution); + return perSchemaSearch(searcher, query, execution); + } + } + + private Result perSchemaSearch(Searcher searcher, Query query, Execution execution) { + Set<String> restrict = query.getModel().getRestrict(); + if (restrict.size() != 1) { + throw new IllegalStateException("perSchemaSearch must always be called with 1 schema, got: " + restrict.size()); + } + String schema = restrict.iterator().next(); + Result result = searcher.search(query, execution); + if (globalPhaseHelper != null) { + globalPhaseHelper.process(query, result, schema); } + return result; } private static void processResult(Query query, FutureTask<Result> task, Result mergedResult) { @@ -248,12 +267,12 @@ public class ClusterSearcher extends Searcher { Set<String> schemas = resolveSchemas(query, execution.context().getIndexFacts()); List<Query> queries = createQueries(query, schemas); if (queries.size() == 1) { - return searcher.search(queries.get(0), execution); + return perSchemaSearch(searcher, queries.get(0), execution); } else { Result mergedResult = new Result(query); List<FutureTask<Result>> pending = new ArrayList<>(queries.size()); for (Query q : queries) { - FutureTask<Result> task = new FutureTask<>(() -> searcher.search(q, execution)); + FutureTask<Result> task = new FutureTask<>(() -> perSchemaSearch(searcher, q, execution)); try { executor.execute(task); pending.add(task); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java new file mode 100644 index 00000000000..d2edb776c92 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java @@ -0,0 +1,14 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.tensor.Tensor; + +import java.util.Collection; + +interface Evaluator { + Collection<String> needInputs(); + + Evaluator bind(String name, Tensor value); + + double evaluateScore(); +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java new file mode 100644 index 00000000000..87213362acd --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -0,0 +1,118 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; +import com.yahoo.component.annotation.Inject; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; +import com.yahoo.tensor.Tensor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.logging.Logger; +import java.util.function.Supplier; + +public class GlobalPhaseRanker { + + private static final Logger logger = Logger.getLogger(GlobalPhaseRanker.class.getName()); + private final RankProfilesEvaluatorFactory factory; + private final Set<String> skipProcessing = new HashSet<>(); + private final Map<String, Supplier<FunctionEvaluator>> scorers = new HashMap<>(); + + @Inject + public GlobalPhaseRanker(RankProfilesEvaluatorFactory factory) { + this.factory = factory; + logger.info("using factory: " + factory); + } + + public void process(Query query, Result result, String schema) { + var functionEvaluatorSource = underlying(query, schema); + if (functionEvaluatorSource == null) { + return; + } + var prepared = findFromQuery(query, functionEvaluatorSource.get().function().arguments()); + Supplier<Evaluator> supplier = () -> { + var evaluator = functionEvaluatorSource.get(); + var simple = new SimpleEvaluator(evaluator); + for (var entry : prepared) { + simple.bind(entry.name(), entry.value()); + } + return simple; + }; + // TODO need to get rerank-count somehow + int rerank = 7; + ResultReranker.rerankHits(result, new HitRescorer(supplier), rerank); + } + + record NameAndValue(String name, Tensor value) { } + + /* do this only once per query: */ + List<NameAndValue> findFromQuery(Query query, List<String> needInputs) { + List<NameAndValue> result = new ArrayList<>(); + var ranking = query.getRanking(); + var rankFeatures = ranking.getFeatures(); + var rankProps = ranking.getProperties().asMap(); + for (String needed : needInputs) { + var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed); + if (optRef.isEmpty()) continue; + var ref = optRef.get(); + if (ref.name().equals("constant")) { + // XXX in theory, we should be able to avoid this + result.add(new NameAndValue(needed, null)); + continue; + } + if (ref.isSimple() && ref.name().equals("query")) { + String queryFeatureName = ref.simpleArgument().get(); + // searchers are recommended to place query features here: + var feature = rankFeatures.getTensor(queryFeatureName); + if (feature.isPresent()) { + result.add(new NameAndValue(needed, feature.get())); + } else { + // but other ways of setting query features end up in the properties: + var objList = rankProps.get(queryFeatureName); + if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { + result.add(new NameAndValue(needed, t)); + } + } + } + } + return result; + } + + private Supplier<FunctionEvaluator> underlying(Query query, String schema) { + String rankProfile = query.getRanking().getProfile(); + String key = schema + " with rank profile " + rankProfile; + if (skipProcessing.contains(key)) { + return null; + } + Supplier<FunctionEvaluator> supplier = scorers.get(key); + if (supplier != null) { + return supplier; + } + try { + var proxy = factory.proxyForSchema(schema); + var model = proxy.modelForRankProfile(rankProfile); + supplier = () -> model.evaluatorOf("globalphase"); + if (supplier.get() == null) { + supplier = null; + } + } catch (IllegalArgumentException e) { + logger.info("no global-phase for " + key + " because: " + e.getMessage()); + supplier = null; + } + if (supplier == null) { + skipProcessing.add(key); + } else { + scorers.put(key, supplier); + } + return supplier; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java new file mode 100644 index 00000000000..ebdbbb693f1 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -0,0 +1,56 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; + +import java.util.function.Supplier; +import java.util.logging.Logger; + +class HitRescorer { + + private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); + + private final Supplier<Evaluator> evaluatorSource; + + public HitRescorer(Supplier<Evaluator> evaluatorSource) { + this.evaluatorSource = evaluatorSource; + } + + boolean rescoreHit(Hit hit) { + var features = hit.getField("matchfeatures"); + if (features instanceof FeatureData matchFeatures) { + var scorer = evaluatorSource.get(); + for (String argName : scorer.needInputs()) { + var asTensor = matchFeatures.getTensor(argName); + if (asTensor == null) { + asTensor = matchFeatures.getTensor(alternate(argName)); + } + if (asTensor != null) { + scorer.bind(argName, asTensor); + } else { + logger.warning("Missing match-feature for Evaluator argument: " + argName); + return false; + } + } + double newScore = scorer.evaluateScore(); + hit.setRelevance(newScore); + return true; + } else { + logger.warning("Hit without match-features: " + hit); + return false; + } + } + + private static final String RE_PREFIX = "rankingExpression("; + private static final String RE_SUFFIX = ")"; + private static final int RE_PRE_LEN = RE_PREFIX.length(); + private static final int RE_SUF_LEN = RE_SUFFIX.length(); + + static String alternate(String argName) { + if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { + return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + } + return argName; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java new file mode 100644 index 00000000000..ccb9b9837fe --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java @@ -0,0 +1,53 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import ai.vespa.models.evaluation.Model; +import ai.vespa.models.evaluation.ModelsEvaluator; +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.OnnxModelsConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.config.search.core.RankingExpressionsConfig; + +/** + * proxy for model-evaluation components + * @author arnej + */ +@Beta +public class RankProfilesEvaluator extends AbstractComponent { + + private final ModelsEvaluator evaluator; + + @Inject + public RankProfilesEvaluator( + RankProfilesConfig rankProfilesConfig, + RankingConstantsConfig constantsConfig, + RankingExpressionsConfig expressionsConfig, + OnnxModelsConfig onnxModelsConfig, + FileAcquirer fileAcquirer) + { + this.evaluator = new ModelsEvaluator( + rankProfilesConfig, + constantsConfig, + expressionsConfig, + onnxModelsConfig, + fileAcquirer); + } + + public Model modelForRankProfile(String rankProfile) { + var m = evaluator.models().get(rankProfile); + if (m == null) { + throw new IllegalArgumentException("unknown rankprofile: " + rankProfile); + } + return m; + } + + public FunctionEvaluator evaluatorForFunction(String rankProfile, String functionName) { + return modelForRankProfile(rankProfile).evaluatorOf(functionName); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java new file mode 100644 index 00000000000..edb05ed9788 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.ranking; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; + +/** + * factory for model-evaluation proxies + * @author arnej + */ +@Beta +public class RankProfilesEvaluatorFactory { + + private final ComponentRegistry<RankProfilesEvaluator> registry; + + @Inject + public RankProfilesEvaluatorFactory(ComponentRegistry<RankProfilesEvaluator> registry) { + this.registry = registry; + } + + public RankProfilesEvaluator proxyForSchema(String schemaName) { + var component = registry.getComponent("ranking-expression-evaluator." + schemaName); + if (component == null) { + throw new IllegalArgumentException("ranking expression evaluator for schema '" + schemaName + "' not found"); + } + return component; + } + + public String toString() { + var buf = new StringBuilder(); + buf.append(this.getClass().getName()).append(" containing: ["); + for (var id : registry.allComponentsById().keySet()) { + buf.append(" ").append(id.toString()); + } + buf.append(" ]"); + return buf.toString(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java new file mode 100644 index 00000000000..11b3fa7390a --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java @@ -0,0 +1,91 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.search.Result; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.logging.Logger; + +class ResultReranker { + + private static final Logger logger = Logger.getLogger(ResultReranker.class.getName()); + + // scale and adjust the score according to the range + // of the original and final score values to avoid that + // a score from the backend is larger than finalScores_low + static class Ranges { + private double initialScores_high = -Double.MAX_VALUE; + private double initialScores_low = Double.MAX_VALUE; + private double finalScores_high = -Double.MAX_VALUE; + private double finalScores_low = Double.MAX_VALUE; + + boolean valid() { + return (initialScores_high >= initialScores_low + && + finalScores_high >= finalScores_low); + } + void withInitialScore(double score) { + if (score < initialScores_low) initialScores_low = score; + if (score > initialScores_high) initialScores_high = score; + } + void withFinalScore(double score) { + if (score < finalScores_low) finalScores_low = score; + if (score > finalScores_high) finalScores_high = score; + } + private double initialRange() { + double r = initialScores_high - initialScores_low; + if (r < 1.0) r = 1.0; + return r; + } + private double finalRange() { + double r = finalScores_high - finalScores_low; + if (r < 1.0) r = 1.0; + return r; + } + double scale() { return finalRange() / initialRange(); } + double bias() { return finalScores_low - initialScores_low * scale(); } + } + + static void rerankHits(Result result, HitRescorer hitRescorer, int rerankCount) { + List<Hit> hitsToRescore = new ArrayList<>(); + // consider doing recursive iteration explicitly instead of using deepIterator? + for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { + Hit hit = iterator.next(); + if (hit.isMeta() || hit instanceof HitGroup) { + continue; + } + // what about hits inside grouping results? + // they are inside GroupingListHit, we won't recurse into it; so we won't see them. + hitsToRescore.add(hit); + } + // we can't be 100% certain that hits were sorted according to relevance: + hitsToRescore.sort(Comparator.naturalOrder()); + var ranges = new Ranges(); + for (var iterator = hitsToRescore.iterator(); rerankCount > 0 && iterator.hasNext(); ) { + Hit hit = iterator.next(); + double oldScore = hit.getRelevance().getScore(); + boolean didRerank = hitRescorer.rescoreHit(hit); + if (didRerank) { + ranges.withInitialScore(oldScore); + ranges.withFinalScore(hit.getRelevance().getScore()); + --rerankCount; + iterator.remove(); + } + } + // if any hits are left in the list, they need rescaling: + if (ranges.valid()) { + double scale = ranges.scale(); + double bias = ranges.bias(); + for (Hit hit : hitsToRescore) { + double oldScore = hit.getRelevance().getScore(); + hit.setRelevance(oldScore * scale + bias); + } + } + result.hits().sort(); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java new file mode 100644 index 00000000000..f247eab1649 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java @@ -0,0 +1,51 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +class SimpleEvaluator implements Evaluator { + + private final FunctionEvaluator evaluator; + private final Set<String> neededInputs; + + public SimpleEvaluator(FunctionEvaluator prototype) { + this.evaluator = prototype; + this.neededInputs = new HashSet<String>(prototype.function().arguments()); + } + + @Override + public Collection<String> needInputs() { return List.copyOf(neededInputs); } + + @Override + public SimpleEvaluator bind(String name, Tensor value) { + if (value != null) evaluator.bind(name, value); + neededInputs.remove(name); + return this; + } + + @Override + public double evaluateScore() { + return evaluator.evaluate().asDouble(); + } + + @Override + public String toString() { + var buf = new StringBuilder(); + buf.append("SimpleEvaluator("); + buf.append(evaluator.function().toString()); + buf.append(")["); + for (String arg : neededInputs) { + buf.append("{").append(arg).append("}"); + } + buf.append("]"); + return buf.toString(); + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/package-info.java b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java new file mode 100644 index 00000000000..a86a5c1e52f --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java @@ -0,0 +1,6 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.search.ranking; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java index 5df8d2e5444..06ae9923dae 100644 --- a/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/cluster/ClusterSearcherTestCase.java @@ -464,6 +464,7 @@ public class ClusterSearcherTestCase { documentDbConfig.build(), new SchemaInfo(List.of(schema.build()), Map.of()), dispatchers, + null, vipStatus, null); } |