diff options
Diffstat (limited to 'container-search/src/main/java')
3 files changed, 20 insertions, 22 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 829d0c268e5..91acc883803 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -13,10 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; @@ -52,12 +49,12 @@ public class GlobalPhaseRanker { static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) { var mainSpec = setup.globalPhaseEvalSpec; - var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), query); + var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query); int rerankCount = resolveRerankCount(setup, query); var normalizers = new ArrayList<NormalizerContext>(); for (var nSetup : setup.normalizers) { var normSpec = nSetup.inputEvalSpec(); - var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), query); + var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query); normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF())); } var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers); @@ -73,8 +70,8 @@ public class GlobalPhaseRanker { } } - static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Query query) { - var prepared = PreparedInput.findFromQuery(query, queryFeatures); + static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Map<String, Tensor> defaultValues, Query query) { + var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues); Supplier<Evaluator> supplier = () -> { var evaluator = evalSource.get(); for (var entry : prepared) { diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java index 31a676e4c8e..e5cd09d3a18 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -3,15 +3,10 @@ package com.yahoo.search.ranking; import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.tensor.Tensor; import com.yahoo.vespa.config.search.RankProfilesConfig; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import java.util.Map; -import java.util.HashMap; +import java.util.*; import java.util.function.Supplier; class GlobalPhaseSetup { @@ -20,16 +15,19 @@ class GlobalPhaseSetup { final int rerankCount; final Collection<String> matchFeaturesToHide; final List<NormalizerSetup> normalizers; + final Map<String, Tensor> defaultValues; GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec, final int rerankCount, Collection<String> matchFeaturesToHide, - List<NormalizerSetup> normalizers) + List<NormalizerSetup> normalizers, + Map<String, Tensor> defaultValues) { this.globalPhaseEvalSpec = globalPhaseEvalSpec; this.rerankCount = rerankCount; this.matchFeaturesToHide = matchFeaturesToHide; this.normalizers = normalizers; + this.defaultValues = defaultValues; } static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { @@ -106,7 +104,7 @@ class GlobalPhaseSetup { } Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource); var gfun = new FunEvalSpec(supplier, fromQuery, fromMF); - return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers); + return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, Collections.emptyMap()); } return null; } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java index 5491724cc08..914635fef59 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -13,16 +13,13 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; record PreparedInput(String name, Tensor value) { - static List<PreparedInput> findFromQuery(Query query, Collection<String> queryFeatures) { + static List<PreparedInput> findFromQuery(Query query, Collection<String> queryFeatures, Map<String, Tensor> defaultValues) { List<PreparedInput> result = new ArrayList<>(); var ranking = query.getRanking(); var rankFeatures = ranking.getFeatures(); @@ -36,6 +33,12 @@ record PreparedInput(String name, Tensor value) { feature = rankFeatures.getTensor(needed); } if (feature.isEmpty()) { + var t = defaultValues.get(needed); + if (t != null) { + feature = Optional.of(t); + } + } + if (feature.isEmpty()) { throw new IllegalArgumentException("missing query feature: " + queryFeatureName); } result.add(new PreparedInput(needed, feature.get())); |