aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java108
1 files changed, 50 insertions, 58 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
index 2aa9fd32795..91acc883803 100644
--- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
+++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java
@@ -1,11 +1,10 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.ranking;
import com.yahoo.component.annotation.Inject;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.query.Sorting;
-import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.result.FeatureData;
import com.yahoo.search.result.Hit;
@@ -14,10 +13,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.data.access.helpers.MatchFeatureData;
import com.yahoo.data.access.helpers.MatchFeatureFilter;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
import java.util.function.Supplier;
import java.util.logging.Logger;
@@ -32,9 +28,14 @@ public class GlobalPhaseRanker {
logger.fine(() -> "Using factory: " + factory);
}
+ public int getRerankCount(Query query, String schema) {
+ var setup = globalPhaseSetupFor(query, schema).orElse(null);
+ return resolveRerankCount(setup, query);
+ }
+
public Optional<ErrorMessage> validateNoSorting(Query query, String schema) {
- var data = globalPhaseDataFor(query, schema).orElse(null);
- if (data == null) return Optional.empty();
+ var setup = globalPhaseSetupFor(query, schema).orElse(null);
+ if (setup == null) return Optional.empty();
var sorting = query.getRanking().getSorting();
if (sorting == null || sorting.fieldOrders() == null) return Optional.empty();
for (var fieldOrder : sorting.fieldOrders()) {
@@ -46,27 +47,42 @@ public class GlobalPhaseRanker {
return Optional.empty();
}
+ static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) {
+ var mainSpec = setup.globalPhaseEvalSpec;
+ var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query);
+ int rerankCount = resolveRerankCount(setup, query);
+ var normalizers = new ArrayList<NormalizerContext>();
+ for (var nSetup : setup.normalizers) {
+ var normSpec = nSetup.inputEvalSpec();
+ var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query);
+ normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF()));
+ }
+ var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers);
+ var reranker = new ResultReranker(rescorer, rerankCount);
+ reranker.rerankHits(result);
+ hideImplicitMatchFeatures(result, setup.matchFeaturesToHide);
+ }
+
public void rerankHits(Query query, Result result, String schema) {
- var data = globalPhaseDataFor(query, schema).orElse(null);
- if (data == null) return;
- var functionEvaluatorSource = data.functionEvaluatorSource();
- var prepared = findFromQuery(query, data.needInputs());
+ var setup = globalPhaseSetupFor(query, schema);
+ if (setup.isPresent()) {
+ rerankHitsImpl(setup.get(), query, result);
+ }
+ }
+
+ static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Map<String, Tensor> defaultValues, Query query) {
+ var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues);
Supplier<Evaluator> supplier = () -> {
- var evaluator = functionEvaluatorSource.get();
- var simple = new SimpleEvaluator(evaluator);
+ var evaluator = evalSource.get();
for (var entry : prepared) {
- simple.bind(entry.name(), entry.value());
+ evaluator.bind(entry.name(), entry.value());
}
- return simple;
+ return evaluator;
};
- int rerankCount = data.rerankCount();
- if (rerankCount < 0)
- rerankCount = 100;
- ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount);
- hideImplicitMatchFeatures(result, data.matchFeaturesToHide());
+ return supplier;
}
- private void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) {
+ private static void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) {
if (namesToHide.size() == 0) return;
var filter = new MatchFeatureFilter(namesToHide);
for (var iterator = result.hits().deepIterator(); iterator.hasNext();) {
@@ -80,51 +96,27 @@ public class GlobalPhaseRanker {
if (newValue.fieldCount() == 0) {
hit.removeField("matchfeatures");
} else {
- hit.setField("matchfeatures", newValue);
+ hit.setField("matchfeatures", new FeatureData(newValue));
}
}
}
}
}
- private Optional<GlobalPhaseData> globalPhaseDataFor(Query query, String schema) {
+ private Optional<GlobalPhaseSetup> globalPhaseSetupFor(Query query, String schema) {
return factory.evaluatorForSchema(schema)
- .flatMap(evaluator -> evaluator.getGlobalPhaseData(query.getRanking().getProfile()));
+ .flatMap(evaluator -> evaluator.getGlobalPhaseSetup(query.getRanking().getProfile()));
}
- record NameAndValue(String name, Tensor value) { }
-
- /* do this only once per query: */
- List<NameAndValue> findFromQuery(Query query, List<String> needInputs) {
- List<NameAndValue> result = new ArrayList<>();
- var ranking = query.getRanking();
- var rankFeatures = ranking.getFeatures();
- var rankProps = ranking.getProperties().asMap();
- for (String needed : needInputs) {
- var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed);
- if (optRef.isEmpty()) continue;
- var ref = optRef.get();
- if (ref.name().equals("constant")) {
- // XXX in theory, we should be able to avoid this
- result.add(new NameAndValue(needed, null));
- continue;
- }
- if (ref.isSimple() && ref.name().equals("query")) {
- String queryFeatureName = ref.simpleArgument().get();
- // searchers are recommended to place query features here:
- var feature = rankFeatures.getTensor(queryFeatureName);
- if (feature.isPresent()) {
- result.add(new NameAndValue(needed, feature.get()));
- } else {
- // but other ways of setting query features end up in the properties:
- var objList = rankProps.get(queryFeatureName);
- if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) {
- result.add(new NameAndValue(needed, t));
- }
- }
- }
+ private static int resolveRerankCount(GlobalPhaseSetup setup, Query query) {
+ if (setup == null) {
+ // there is no global-phase at all (ignore override)
+ return 0;
}
- return result;
+ Integer override = query.getRanking().getGlobalPhase().getRerankCount();
+ if (override != null) {
+ return override;
+ }
+ return setup.rerankCount;
}
-
}