diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java | 108 |
1 files changed, 50 insertions, 58 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 2aa9fd32795..91acc883803 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -1,11 +1,10 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.component.annotation.Inject; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.query.Sorting; -import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; @@ -14,10 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; @@ -32,9 +28,14 @@ public class GlobalPhaseRanker { logger.fine(() -> "Using factory: " + factory); } + public int getRerankCount(Query query, String schema) { + var setup = globalPhaseSetupFor(query, schema).orElse(null); + return resolveRerankCount(setup, query); + } + public Optional<ErrorMessage> validateNoSorting(Query query, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return Optional.empty(); + var setup = globalPhaseSetupFor(query, schema).orElse(null); + if (setup == null) return Optional.empty(); var sorting = query.getRanking().getSorting(); if (sorting == null || sorting.fieldOrders() == null) return Optional.empty(); for (var fieldOrder : sorting.fieldOrders()) { @@ -46,27 +47,42 @@ public class GlobalPhaseRanker { return Optional.empty(); } + static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) { + var mainSpec = setup.globalPhaseEvalSpec; + var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query); + int rerankCount = resolveRerankCount(setup, query); + var normalizers = new ArrayList<NormalizerContext>(); + for (var nSetup : setup.normalizers) { + var normSpec = nSetup.inputEvalSpec(); + var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query); + normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF())); + } + var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers); + var reranker = new ResultReranker(rescorer, rerankCount); + reranker.rerankHits(result); + hideImplicitMatchFeatures(result, setup.matchFeaturesToHide); + } + public void rerankHits(Query query, Result result, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return; - var functionEvaluatorSource = data.functionEvaluatorSource(); - var prepared = findFromQuery(query, data.needInputs()); + var setup = globalPhaseSetupFor(query, schema); + if (setup.isPresent()) { + rerankHitsImpl(setup.get(), query, result); + } + } + + static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Map<String, Tensor> defaultValues, Query query) { + var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues); Supplier<Evaluator> supplier = () -> { - var evaluator = functionEvaluatorSource.get(); - var simple = new SimpleEvaluator(evaluator); + var evaluator = evalSource.get(); for (var entry : prepared) { - simple.bind(entry.name(), entry.value()); + evaluator.bind(entry.name(), entry.value()); } - return simple; + return evaluator; }; - int rerankCount = data.rerankCount(); - if (rerankCount < 0) - rerankCount = 100; - ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount); - hideImplicitMatchFeatures(result, data.matchFeaturesToHide()); + return supplier; } - private void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) { + private static void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) { if (namesToHide.size() == 0) return; var filter = new MatchFeatureFilter(namesToHide); for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { @@ -80,51 +96,27 @@ public class GlobalPhaseRanker { if (newValue.fieldCount() == 0) { hit.removeField("matchfeatures"); } else { - hit.setField("matchfeatures", newValue); + hit.setField("matchfeatures", new FeatureData(newValue)); } } } } } - private Optional<GlobalPhaseData> globalPhaseDataFor(Query query, String schema) { + private Optional<GlobalPhaseSetup> globalPhaseSetupFor(Query query, String schema) { return factory.evaluatorForSchema(schema) - .flatMap(evaluator -> evaluator.getGlobalPhaseData(query.getRanking().getProfile())); + .flatMap(evaluator -> evaluator.getGlobalPhaseSetup(query.getRanking().getProfile())); } - record NameAndValue(String name, Tensor value) { } - - /* do this only once per query: */ - List<NameAndValue> findFromQuery(Query query, List<String> needInputs) { - List<NameAndValue> result = new ArrayList<>(); - var ranking = query.getRanking(); - var rankFeatures = ranking.getFeatures(); - var rankProps = ranking.getProperties().asMap(); - for (String needed : needInputs) { - var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed); - if (optRef.isEmpty()) continue; - var ref = optRef.get(); - if (ref.name().equals("constant")) { - // XXX in theory, we should be able to avoid this - result.add(new NameAndValue(needed, null)); - continue; - } - if (ref.isSimple() && ref.name().equals("query")) { - String queryFeatureName = ref.simpleArgument().get(); - // searchers are recommended to place query features here: - var feature = rankFeatures.getTensor(queryFeatureName); - if (feature.isPresent()) { - result.add(new NameAndValue(needed, feature.get())); - } else { - // but other ways of setting query features end up in the properties: - var objList = rankProps.get(queryFeatureName); - if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { - result.add(new NameAndValue(needed, t)); - } - } - } + private static int resolveRerankCount(GlobalPhaseSetup setup, Query query) { + if (setup == null) { + // there is no global-phase at all (ignore override) + return 0; } - return result; + Integer override = query.getRanking().getGlobalPhase().getRerankCount(); + if (override != null) { + return override; + } + return setup.rerankCount; } - } |