diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/ranking')
19 files changed, 700 insertions, 198 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java new file mode 100644 index 00000000000..e83a308d99c --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/DummyEvaluator.java @@ -0,0 +1,38 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import java.util.Collection; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +class DummyEvaluator implements Evaluator { + + private final String input; + private Tensor result = null; + + DummyEvaluator(String input) { + this.input = input; + } + + @Override + public Evaluator bind(String name, Tensor value) { + result = value; + return this; + } + + @Override + public double evaluateScore() { + return result.asDouble(); + } + + @Override + public String toString() { + return "DummyEvaluator(" + input + ")"; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java index d2edb776c92..83f9d0e2704 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/Evaluator.java @@ -1,13 +1,11 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.tensor.Tensor; -import java.util.Collection; +import java.util.List; interface Evaluator { - Collection<String> needInputs(); - Evaluator bind(String name, Tensor value); double evaluateScore(); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java new file mode 100644 index 00000000000..df9c509dd82 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/FunEvalSpec.java @@ -0,0 +1,7 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.List; +import java.util.function.Supplier; + +record FunEvalSpec(Supplier<Evaluator> evalSource, List<String> fromQuery, List<String> fromMF) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java index 2aa9fd32795..91acc883803 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseRanker.java @@ -1,11 +1,10 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.component.annotation.Inject; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.query.Sorting; -import com.yahoo.search.ranking.RankProfilesEvaluator.GlobalPhaseData; import com.yahoo.search.result.ErrorMessage; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; @@ -14,10 +13,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.data.access.helpers.MatchFeatureData; import com.yahoo.data.access.helpers.MatchFeatureFilter; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; -import java.util.Optional; +import java.util.*; import java.util.function.Supplier; import java.util.logging.Logger; @@ -32,9 +28,14 @@ public class GlobalPhaseRanker { logger.fine(() -> "Using factory: " + factory); } + public int getRerankCount(Query query, String schema) { + var setup = globalPhaseSetupFor(query, schema).orElse(null); + return resolveRerankCount(setup, query); + } + public Optional<ErrorMessage> validateNoSorting(Query query, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return Optional.empty(); + var setup = globalPhaseSetupFor(query, schema).orElse(null); + if (setup == null) return Optional.empty(); var sorting = query.getRanking().getSorting(); if (sorting == null || sorting.fieldOrders() == null) return Optional.empty(); for (var fieldOrder : sorting.fieldOrders()) { @@ -46,27 +47,42 @@ public class GlobalPhaseRanker { return Optional.empty(); } + static void rerankHitsImpl(GlobalPhaseSetup setup, Query query, Result result) { + var mainSpec = setup.globalPhaseEvalSpec; + var mainSrc = withQueryPrep(mainSpec.evalSource(), mainSpec.fromQuery(), setup.defaultValues, query); + int rerankCount = resolveRerankCount(setup, query); + var normalizers = new ArrayList<NormalizerContext>(); + for (var nSetup : setup.normalizers) { + var normSpec = nSetup.inputEvalSpec(); + var normEvalSrc = withQueryPrep(normSpec.evalSource(), normSpec.fromQuery(), setup.defaultValues, query); + normalizers.add(new NormalizerContext(nSetup.name(), nSetup.supplier().get(), normEvalSrc, normSpec.fromMF())); + } + var rescorer = new HitRescorer(mainSrc, mainSpec.fromMF(), normalizers); + var reranker = new ResultReranker(rescorer, rerankCount); + reranker.rerankHits(result); + hideImplicitMatchFeatures(result, setup.matchFeaturesToHide); + } + public void rerankHits(Query query, Result result, String schema) { - var data = globalPhaseDataFor(query, schema).orElse(null); - if (data == null) return; - var functionEvaluatorSource = data.functionEvaluatorSource(); - var prepared = findFromQuery(query, data.needInputs()); + var setup = globalPhaseSetupFor(query, schema); + if (setup.isPresent()) { + rerankHitsImpl(setup.get(), query, result); + } + } + + static Supplier<Evaluator> withQueryPrep(Supplier<Evaluator> evalSource, List<String> queryFeatures, Map<String, Tensor> defaultValues, Query query) { + var prepared = PreparedInput.findFromQuery(query, queryFeatures, defaultValues); Supplier<Evaluator> supplier = () -> { - var evaluator = functionEvaluatorSource.get(); - var simple = new SimpleEvaluator(evaluator); + var evaluator = evalSource.get(); for (var entry : prepared) { - simple.bind(entry.name(), entry.value()); + evaluator.bind(entry.name(), entry.value()); } - return simple; + return evaluator; }; - int rerankCount = data.rerankCount(); - if (rerankCount < 0) - rerankCount = 100; - ResultReranker.rerankHits(result, new HitRescorer(supplier), rerankCount); - hideImplicitMatchFeatures(result, data.matchFeaturesToHide()); + return supplier; } - private void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) { + private static void hideImplicitMatchFeatures(Result result, Collection<String> namesToHide) { if (namesToHide.size() == 0) return; var filter = new MatchFeatureFilter(namesToHide); for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { @@ -80,51 +96,27 @@ public class GlobalPhaseRanker { if (newValue.fieldCount() == 0) { hit.removeField("matchfeatures"); } else { - hit.setField("matchfeatures", newValue); + hit.setField("matchfeatures", new FeatureData(newValue)); } } } } } - private Optional<GlobalPhaseData> globalPhaseDataFor(Query query, String schema) { + private Optional<GlobalPhaseSetup> globalPhaseSetupFor(Query query, String schema) { return factory.evaluatorForSchema(schema) - .flatMap(evaluator -> evaluator.getGlobalPhaseData(query.getRanking().getProfile())); + .flatMap(evaluator -> evaluator.getGlobalPhaseSetup(query.getRanking().getProfile())); } - record NameAndValue(String name, Tensor value) { } - - /* do this only once per query: */ - List<NameAndValue> findFromQuery(Query query, List<String> needInputs) { - List<NameAndValue> result = new ArrayList<>(); - var ranking = query.getRanking(); - var rankFeatures = ranking.getFeatures(); - var rankProps = ranking.getProperties().asMap(); - for (String needed : needInputs) { - var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(needed); - if (optRef.isEmpty()) continue; - var ref = optRef.get(); - if (ref.name().equals("constant")) { - // XXX in theory, we should be able to avoid this - result.add(new NameAndValue(needed, null)); - continue; - } - if (ref.isSimple() && ref.name().equals("query")) { - String queryFeatureName = ref.simpleArgument().get(); - // searchers are recommended to place query features here: - var feature = rankFeatures.getTensor(queryFeatureName); - if (feature.isPresent()) { - result.add(new NameAndValue(needed, feature.get())); - } else { - // but other ways of setting query features end up in the properties: - var objList = rankProps.get(queryFeatureName); - if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { - result.add(new NameAndValue(needed, t)); - } - } - } + private static int resolveRerankCount(GlobalPhaseSetup setup, Query query) { + if (setup == null) { + // there is no global-phase at all (ignore override) + return 0; } - return result; + Integer override = query.getRanking().getGlobalPhase().getRerankCount(); + if (override != null) { + return override; + } + return setup.rerankCount; } - } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java new file mode 100644 index 00000000000..7340e9e2a5e --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/GlobalPhaseSetup.java @@ -0,0 +1,218 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import ai.vespa.models.evaluation.FunctionEvaluator; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; + +import java.util.*; +import java.util.function.Supplier; + +class GlobalPhaseSetup { + + final FunEvalSpec globalPhaseEvalSpec; + final int rerankCount; + final Collection<String> matchFeaturesToHide; + final List<NormalizerSetup> normalizers; + final Map<String, Tensor> defaultValues; + + GlobalPhaseSetup(FunEvalSpec globalPhaseEvalSpec, + final int rerankCount, + Collection<String> matchFeaturesToHide, + List<NormalizerSetup> normalizers, + Map<String, Tensor> defaultValues) + { + this.globalPhaseEvalSpec = globalPhaseEvalSpec; + this.rerankCount = rerankCount; + this.matchFeaturesToHide = matchFeaturesToHide; + this.normalizers = normalizers; + this.defaultValues = defaultValues; + } + + static class DefaultQueryFeatureExtractor { + final String baseName; + final String qfName; + TensorType type = null; + Tensor value = null; + DefaultQueryFeatureExtractor(String unwrappedQueryFeature) { + baseName = unwrappedQueryFeature; + qfName = "query(" + baseName + ")"; + } + List<String> lookingFor() { + return List.of(qfName, "vespa.type.query." + baseName); + } + void accept(String key, String propValue) { + if (key.equals(qfName)) { + this.value = Tensor.from(propValue); + } else { + this.type = TensorType.fromSpec(propValue); + } + } + Tensor extract() { + if (value != null) { + return value; + } + if (type != null) { + return Tensor.Builder.of(type).build(); + } + return Tensor.from(0.0); + } + } + + static private Map<String, Tensor> extraDefaultQueryFeatureValues(RankProfilesConfig.Rankprofile rp, + List<String> fromQuery, + List<NormalizerSetup> normalizers) + { + Map<String, DefaultQueryFeatureExtractor> extractors = new HashMap<>(); + for (String fn : fromQuery) { + extractors.put(fn, new DefaultQueryFeatureExtractor(fn)); + } + for (var n : normalizers) { + for (String fn : n.inputEvalSpec().fromQuery()) { + extractors.put(fn, new DefaultQueryFeatureExtractor(fn)); + } + } + Map<String, DefaultQueryFeatureExtractor> targets = new HashMap<>(); + for (var extractor : extractors.values()) { + for (String key : extractor.lookingFor()) { + var old = targets.put(key, extractor); + if (old != null) { + throw new IllegalStateException("Multiple targets for key: " + key); + } + } + } + for (var prop : rp.fef().property()) { + var extractor = targets.get(prop.name()); + if (extractor != null) { + extractor.accept(prop.name(), prop.value()); + } + } + Map<String, Tensor> defaultValues = new HashMap<>(); + for (var extractor : extractors.values()) { + defaultValues.put(extractor.qfName, extractor.extract()); + } + return defaultValues; + } + + static GlobalPhaseSetup maybeMakeSetup(RankProfilesConfig.Rankprofile rp, RankProfilesEvaluator modelEvaluator) { + var model = modelEvaluator.modelForRankProfile(rp.name()); + Map<String, RankProfilesConfig.Rankprofile.Normalizer> availableNormalizers = new HashMap<>(); + for (var n : rp.normalizer()) { + availableNormalizers.put(n.name(), n); + } + Supplier<FunctionEvaluator> functionEvaluatorSource = null; + int rerankCount = -1; + Set<String> namesToHide = new HashSet<>(); + Set<String> matchFeatures = new HashSet<>(); + Map<String, String> renameFeatures = new HashMap<>(); + String toRename = null; + for (var prop : rp.fef().property()) { + if (prop.name().equals("vespa.globalphase.rerankcount")) { + rerankCount = Integer.valueOf(prop.value()); + } + if (prop.name().equals("vespa.rank.globalphase")) { + functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); + } + if (prop.name().equals("vespa.hidden.matchfeature")) { + namesToHide.add(prop.value()); + } + if (prop.name().equals("vespa.match.feature")) { + matchFeatures.add(prop.value()); + } + if (prop.name().equals("vespa.feature.rename")) { + if (toRename == null) { + toRename = prop.value(); + } else { + renameFeatures.put(toRename, prop.value()); + toRename = null; + } + } + } + for (var entry : renameFeatures.entrySet()) { + String old = entry.getKey(); + if (matchFeatures.contains(old)) { + matchFeatures.remove(old); + matchFeatures.add(entry.getValue()); + } + } + if (rerankCount < 0) { + rerankCount = 100; + } + if (functionEvaluatorSource != null) { + var evaluator = functionEvaluatorSource.get(); + var allInputs = List.copyOf(evaluator.function().arguments()); + List<String> fromMF = new ArrayList<>(); + List<String> fromQuery = new ArrayList<>(); + List<NormalizerSetup> normalizers = new ArrayList<>(); + for (var input : allInputs) { + String queryFeatureName = asQueryFeature(input); + if (queryFeatureName != null) { + fromQuery.add(queryFeatureName); + } else if (availableNormalizers.containsKey(input)) { + var cfg = availableNormalizers.get(input); + String normInput = cfg.input(); + if (matchFeatures.contains(normInput)) { + Supplier<Evaluator> normSource = () -> new DummyEvaluator(normInput); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSource, List.of(normInput), rerankCount)); + } else { + Supplier<FunctionEvaluator> normSource = () -> model.evaluatorOf(normInput); + var normInputs = List.copyOf(normSource.get().function().arguments()); + var normSupplier = SimpleEvaluator.wrap(normSource); + normalizers.add(makeNormalizerSetup(cfg, matchFeatures, normSupplier, normInputs, rerankCount)); + } + } else if (matchFeatures.contains(input) || matchFeatures.contains(WrappedHit.alternate(input))) { + fromMF.add(input); + } else { + throw new IllegalArgumentException("Bad config, missing global-phase input: " + input); + } + } + Supplier<Evaluator> supplier = SimpleEvaluator.wrap(functionEvaluatorSource); + var gfun = new FunEvalSpec(supplier, fromQuery, fromMF); + var defaultValues = extraDefaultQueryFeatureValues(rp, fromQuery, normalizers); + return new GlobalPhaseSetup(gfun, rerankCount, namesToHide, normalizers, defaultValues); + } + return null; + } + + private static NormalizerSetup makeNormalizerSetup(RankProfilesConfig.Rankprofile.Normalizer cfg, + Set<String> matchFeatures, + Supplier<Evaluator> evalSupplier, + List<String> normInputs, + int rerankCount) + { + List<String> fromQuery = new ArrayList<>(); + List<String> fromMF = new ArrayList<>(); + for (var input : normInputs) { + String queryFeatureName = asQueryFeature(input); + if (queryFeatureName != null) { + fromQuery.add(queryFeatureName); + } else if (matchFeatures.contains(input) || matchFeatures.contains(WrappedHit.alternate(input))) { + fromMF.add(input); + } else { + throw new IllegalArgumentException("Bad config, missing normalizer input: " + input); + } + } + var fun = new FunEvalSpec(evalSupplier, fromQuery, fromMF); + return new NormalizerSetup(cfg.name(), makeNormalizerSupplier(cfg, rerankCount), fun); + } + + private static Supplier<Normalizer> makeNormalizerSupplier(RankProfilesConfig.Rankprofile.Normalizer cfg, int rerankCount) { + return switch (cfg.algo()) { + case LINEAR -> () -> new LinearNormalizer(rerankCount); + case RRANK -> () -> new ReciprocalRankNormalizer(rerankCount, cfg.kparam()); + }; + } + + static String asQueryFeature(String input) { + var optRef = com.yahoo.searchlib.rankingexpression.Reference.simple(input); + if (optRef.isPresent()) { + var ref = optRef.get(); + if (ref.isSimple() && ref.name().equals("query")) { + return ref.simpleArgument().get(); + } + } + return null; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java index cce6b42d323..fee4f5b4160 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/HitRescorer.java @@ -1,57 +1,63 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.search.result.FeatureData; import com.yahoo.search.result.Hit; -import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; +import com.yahoo.tensor.Tensor; +import java.util.List; import java.util.function.Supplier; import java.util.logging.Logger; class HitRescorer { private static final Logger logger = Logger.getLogger(HitRescorer.class.getName()); - - private final Supplier<Evaluator> evaluatorSource; - public HitRescorer(Supplier<Evaluator> evaluatorSource) { - this.evaluatorSource = evaluatorSource; + private final Supplier<Evaluator> mainEvalSrc; + private final List<String> mainFromMF; + private final List<NormalizerContext> normalizers; + + public HitRescorer(Supplier<Evaluator> mainEvalSrc, List<String> mainFromMF, List<NormalizerContext> normalizers) { + this.mainEvalSrc = mainEvalSrc; + this.mainFromMF = mainFromMF; + this.normalizers = normalizers; } - boolean rescoreHit(Hit hit) { - var features = hit.getField("matchfeatures"); - if (features instanceof FeatureData matchFeatures) { - var scorer = evaluatorSource.get(); - for (String argName : scorer.needInputs()) { - var asTensor = matchFeatures.getTensor(argName); - if (asTensor == null) { - asTensor = matchFeatures.getTensor(alternate(argName)); - } - if (asTensor != null) { - scorer.bind(argName, asTensor); - } else { - logger.warning("Missing match-feature for Evaluator argument: " + argName); - return false; - } - } - double newScore = scorer.evaluateScore(); - hit.setRelevance(newScore); - return true; - } else { - logger.warning("Hit without match-features: " + hit); - return false; + void preprocess(WrappedHit wrapped) { + for (var n : normalizers) { + var scorer = n.evalSource().get(); + double val = evalScorer(wrapped, scorer, n.fromMF()); + wrapped.setIdx(n.normalizer().addInput(val)); + } + } + + void runNormalizers() { + for (var n : normalizers) { + n.normalizer().normalize(); } } - private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; - private static final String RE_SUFFIX = ")"; - private static final int RE_PRE_LEN = RE_PREFIX.length(); - private static final int RE_SUF_LEN = RE_SUFFIX.length(); + double rescoreHit(WrappedHit wrapped) { + var scorer = mainEvalSrc.get(); + for (var n : normalizers) { + double normalizedValue = n.normalizer().getOutput(wrapped.getIdx()); + scorer.bind(n.name(), Tensor.from(normalizedValue)); + } + double newScore = evalScorer(wrapped, scorer, mainFromMF); + wrapped.setScore(newScore); + return newScore; + } - static String alternate(String argName) { - if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { - return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + private static double evalScorer(WrappedHit wrapped, Evaluator scorer, List<String> fromMF) { + for (String argName : fromMF) { + var asTensor = wrapped.getTensor(argName); + if (asTensor != null) { + scorer.bind(argName, asTensor); + } else { + logger.warning("Missing match-feature for Evaluator argument: " + argName); + return 0.0; + } } - return argName; + return scorer.evaluateScore(); } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java new file mode 100644 index 00000000000..2cdba9d6361 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/LinearNormalizer.java @@ -0,0 +1,33 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +class LinearNormalizer extends Normalizer { + + LinearNormalizer(int maxSize) { + super(maxSize); + } + + void normalize() { + double min = Float.MAX_VALUE; + double max = -Float.MAX_VALUE; + for (int i = 0; i < size; i++) { + double val = data[i]; + if (val < Float.MAX_VALUE && val > -Float.MAX_VALUE) { + min = Math.min(min, data[i]); + max = Math.max(max, data[i]); + } + } + double scale = 0.0; + double midpoint = 0.0; + if (max > min) { + scale = 1.0 / (max - min); + midpoint = (min + max) * 0.5; + } + for (int i = 0; i < size; i++) { + double old = data[i]; + data[i] = 0.5 + scale * (old - midpoint); + } + } + + String normalizing() { return "linear"; } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java new file mode 100644 index 00000000000..eb81d0555b3 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java @@ -0,0 +1,23 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +abstract class Normalizer { + + protected final double[] data; + protected int size = 0; + + Normalizer(int maxSize) { + this.data = new double[maxSize]; + } + + int addInput(double value) { + data[size] = value; + return size++; + } + + double getOutput(int index) { return data[index]; } + + abstract void normalize(); + + abstract String normalizing(); +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java new file mode 100644 index 00000000000..9438b5ea824 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerContext.java @@ -0,0 +1,7 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.List; +import java.util.function.Supplier; + +record NormalizerContext(String name, Normalizer normalizer, Supplier<Evaluator> evalSource, List<String> fromMF) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java new file mode 100644 index 00000000000..32fbb3190fc --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/NormalizerSetup.java @@ -0,0 +1,6 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.function.Supplier; + +record NormalizerSetup(String name, Supplier<Normalizer> supplier, FunEvalSpec inputEvalSpec) {} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java new file mode 100644 index 00000000000..914635fef59 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -0,0 +1,49 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.query.Sorting; +import com.yahoo.search.result.ErrorMessage; +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.search.result.HitGroup; +import com.yahoo.tensor.Tensor; +import com.yahoo.data.access.helpers.MatchFeatureData; +import com.yahoo.data.access.helpers.MatchFeatureFilter; + +import java.util.*; +import java.util.function.Supplier; +import java.util.logging.Logger; + +record PreparedInput(String name, Tensor value) { + + static List<PreparedInput> findFromQuery(Query query, Collection<String> queryFeatures, Map<String, Tensor> defaultValues) { + List<PreparedInput> result = new ArrayList<>(); + var ranking = query.getRanking(); + var rankFeatures = ranking.getFeatures(); + var rankProps = ranking.getProperties(); + for (String queryFeatureName : queryFeatures) { + String needed = "query(" + queryFeatureName + ")"; + // after prepare() the query tensor ends up here: + var feature = rankProps.getAsTensor(queryFeatureName); + if (feature.isEmpty()) { + // searchers are recommended to place query features here: + feature = rankFeatures.getTensor(needed); + } + if (feature.isEmpty()) { + var t = defaultValues.get(needed); + if (t != null) { + feature = Optional.of(t); + } + } + if (feature.isEmpty()) { + throw new IllegalArgumentException("missing query feature: " + queryFeatureName); + } + result.add(new PreparedInput(needed, feature.get())); + } + return result; + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java b/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java new file mode 100644 index 00000000000..6881eece620 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/RangeAdjuster.java @@ -0,0 +1,40 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +// scale and adjust the score according to the range +// of the original and final score values to avoid that +// a score from the backend is larger than finalScores_low +class RangeAdjuster { + private double initialScores_high = -Double.MAX_VALUE; + private double initialScores_low = Double.MAX_VALUE; + private double finalScores_high = -Double.MAX_VALUE; + private double finalScores_low = Double.MAX_VALUE; + + boolean rescaleNeeded() { + return (initialScores_low > finalScores_low + && + initialScores_high >= initialScores_low + && + finalScores_high >= finalScores_low); + } + void withInitialScore(double score) { + if (score < initialScores_low) initialScores_low = score; + if (score > initialScores_high) initialScores_high = score; + } + void withFinalScore(double score) { + if (score < finalScores_low) finalScores_low = score; + if (score > finalScores_high) finalScores_high = score; + } + private double initialRange() { + double r = initialScores_high - initialScores_low; + if (r < 1.0) r = 1.0; + return r; + } + private double finalRange() { + double r = finalScores_high - finalScores_low; + if (r < 1.0) r = 1.0; + return r; + } + double scale() { return finalRange() / initialRange(); } + double bias() { return finalScores_low - initialScores_low * scale(); } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java index a89f0a5c3ea..0ebb98af60e 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluator.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; @@ -63,40 +63,17 @@ public class RankProfilesEvaluator extends AbstractComponent { return modelForRankProfile(rankProfile).evaluatorOf(functionName); } - static record GlobalPhaseData(Supplier<FunctionEvaluator> functionEvaluatorSource, - Collection<String> matchFeaturesToHide, - int rerankCount, - List<String> needInputs) {} + private Map<String, GlobalPhaseSetup> profilesWithGlobalPhase = new HashMap<>(); - private Map<String, GlobalPhaseData> profilesWithGlobalPhase = new HashMap<>(); - - Optional<GlobalPhaseData> getGlobalPhaseData(String rankProfile) { + Optional<GlobalPhaseSetup> getGlobalPhaseSetup(String rankProfile) { return Optional.ofNullable(profilesWithGlobalPhase.get(rankProfile)); } private void extractGlobalPhaseData(RankProfilesConfig rankProfilesConfig) { for (var rp : rankProfilesConfig.rankprofile()) { - String name = rp.name(); - Supplier<FunctionEvaluator> functionEvaluatorSource = null; - int rerankCount = -1; - List<String> needInputs = null; - Set<String> namesToHide = new HashSet<>(); - for (var prop : rp.fef().property()) { - if (prop.name().equals("vespa.globalphase.rerankcount")) { - rerankCount = Integer.valueOf(prop.value()); - } - if (prop.name().equals("vespa.rank.globalphase")) { - var model = modelForRankProfile(name); - functionEvaluatorSource = () -> model.evaluatorOf("globalphase"); - var evaluator = functionEvaluatorSource.get(); - needInputs = List.copyOf(evaluator.function().arguments()); - } - if (prop.name().equals("vespa.hidden.matchfeature")) { - namesToHide.add(prop.value()); - } - } - if (functionEvaluatorSource != null && needInputs != null) { - profilesWithGlobalPhase.put(name, new GlobalPhaseData(functionEvaluatorSource, namesToHide, rerankCount, needInputs)); + var setup = GlobalPhaseSetup.maybeMakeSetup(rp, this); + if (setup != null) { + profilesWithGlobalPhase.put(rp.name(), setup); } } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java index 33f2fb74da5..ee866d6c67a 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/RankProfilesEvaluatorFactory.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java new file mode 100644 index 00000000000..fca920a7a65 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/ReciprocalRankNormalizer.java @@ -0,0 +1,34 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import java.util.Arrays; + +class ReciprocalRankNormalizer extends Normalizer { + + private final double k; + + ReciprocalRankNormalizer(int maxSize, double k) { + super(maxSize); + this.k = k; + } + + static record IdxScore(int index, double score) {} + + void normalize() { + if (size < 1) return; + IdxScore[] temp = new IdxScore[size]; + for (int i = 0; i < size; i++) { + double val = data[i]; + if (Double.isNaN(val)) val = Double.NEGATIVE_INFINITY; + temp[i] = new IdxScore(i, val); + } + Arrays.sort(temp, (a, b) -> Double.compare(b.score, a.score)); + for (int i = 0; i < size; i++) { + int idx = temp[i].index; + double old = data[idx]; + data[idx] = 1.0 / (k + 1.0 + i); + } + } + + String normalizing() { return "reciprocal-rank{k:" + k + "}"; } +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java index 8d24acdf141..2e9edd6de3a 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/ResultReranker.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import com.yahoo.search.Result; @@ -14,80 +14,72 @@ class ResultReranker { private static final Logger logger = Logger.getLogger(ResultReranker.class.getName()); - // scale and adjust the score according to the range - // of the original and final score values to avoid that - // a score from the backend is larger than finalScores_low - static class Ranges { - private double initialScores_high = -Double.MAX_VALUE; - private double initialScores_low = Double.MAX_VALUE; - private double finalScores_high = -Double.MAX_VALUE; - private double finalScores_low = Double.MAX_VALUE; + private final HitRescorer hitRescorer; + private final int rerankCount; + private final List<WrappedHit> hitsToRescore = new ArrayList<>(); + private final RangeAdjuster ranges = new RangeAdjuster(); - boolean rescaleNeeded() { - return (initialScores_low > finalScores_low - && - initialScores_high >= initialScores_low - && - finalScores_high >= finalScores_low); - } - void withInitialScore(double score) { - if (score < initialScores_low) initialScores_low = score; - if (score > initialScores_high) initialScores_high = score; - } - void withFinalScore(double score) { - if (score < finalScores_low) finalScores_low = score; - if (score > finalScores_high) finalScores_high = score; - } - private double initialRange() { - double r = initialScores_high - initialScores_low; - if (r < 1.0) r = 1.0; - return r; - } - private double finalRange() { - double r = finalScores_high - finalScores_low; - if (r < 1.0) r = 1.0; - return r; - } - double scale() { return finalRange() / initialRange(); } - double bias() { return finalScores_low - initialScores_low * scale(); } + ResultReranker(HitRescorer hitRescorer, int rerankCount) { + this.hitRescorer = hitRescorer; + this.rerankCount = rerankCount; } - static void rerankHits(Result result, HitRescorer hitRescorer, int rerankCount) { - List<Hit> hitsToRescore = new ArrayList<>(); - // consider doing recursive iteration explicitly instead of using deepIterator? + void rerankHits(Result result) { + gatherHits(result); + runPreProcessing(); + hitRescorer.runNormalizers(); + runProcessing(); + runPostProcessing(); + result.hits().sort(); + } + + private void gatherHits(Result result) { for (var iterator = result.hits().deepIterator(); iterator.hasNext();) { Hit hit = iterator.next(); if (hit.isMeta() || hit instanceof HitGroup) { continue; } // what about hits inside grouping results? - // they are inside GroupingListHit, we won't recurse into it; so we won't see them. - hitsToRescore.add(hit); + // they did not show up here during manual testing. + var wrapped = WrappedHit.from(hit); + if (wrapped != null) hitsToRescore.add(wrapped); } + } + + private void runPreProcessing() { // we can't be 100% certain that hits were sorted according to relevance: hitsToRescore.sort(Comparator.naturalOrder()); - var ranges = new Ranges(); - for (var iterator = hitsToRescore.iterator(); rerankCount > 0 && iterator.hasNext(); ) { - Hit hit = iterator.next(); - double oldScore = hit.getRelevance().getScore(); - boolean didRerank = hitRescorer.rescoreHit(hit); - if (didRerank) { - ranges.withInitialScore(oldScore); - ranges.withFinalScore(hit.getRelevance().getScore()); - --rerankCount; - iterator.remove(); - } + int count = 0; + for (WrappedHit hit : hitsToRescore) { + if (count == rerankCount) break; + hitRescorer.preprocess(hit); + ++count; } + } + + private void runProcessing() { + int count = 0; + for (var iterator = hitsToRescore.iterator(); count < rerankCount && iterator.hasNext(); ) { + WrappedHit wrapped = iterator.next(); + double oldScore = wrapped.getScore(); + double newScore = hitRescorer.rescoreHit(wrapped); + ranges.withInitialScore(oldScore); + ranges.withFinalScore(newScore); + ++count; + iterator.remove(); + } + } + + private void runPostProcessing() { // if any hits are left in the list, they may need rescaling: - if (ranges.rescaleNeeded()) { + if (ranges.rescaleNeeded() && ! hitsToRescore.isEmpty()) { double scale = ranges.scale(); double bias = ranges.bias(); - for (Hit hit : hitsToRescore) { - double oldScore = hit.getRelevance().getScore(); - hit.setRelevance(oldScore * scale + bias); + for (WrappedHit wrapped : hitsToRescore) { + double oldScore = wrapped.getScore(); + wrapped.setScore(oldScore * scale + bias); } } - result.hits().sort(); } } diff --git a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java index f247eab1649..548576e3a15 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/SimpleEvaluator.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.ranking; import ai.vespa.models.evaluation.FunctionEvaluator; @@ -10,24 +10,23 @@ import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Set; +import java.util.function.Supplier; class SimpleEvaluator implements Evaluator { private final FunctionEvaluator evaluator; - private final Set<String> neededInputs; - - public SimpleEvaluator(FunctionEvaluator prototype) { - this.evaluator = prototype; - this.neededInputs = new HashSet<String>(prototype.function().arguments()); + + static Supplier<Evaluator> wrap(Supplier<FunctionEvaluator> supplier) { + return () -> new SimpleEvaluator(supplier.get()); } - @Override - public Collection<String> needInputs() { return List.copyOf(neededInputs); } + SimpleEvaluator(FunctionEvaluator prototype) { + this.evaluator = prototype; + } @Override - public SimpleEvaluator bind(String name, Tensor value) { - if (value != null) evaluator.bind(name, value); - neededInputs.remove(name); + public Evaluator bind(String name, Tensor value) { + evaluator.bind(name, value); return this; } @@ -42,7 +41,7 @@ class SimpleEvaluator implements Evaluator { buf.append("SimpleEvaluator("); buf.append(evaluator.function().toString()); buf.append(")["); - for (String arg : neededInputs) { + for (String arg : evaluator.function().arguments()) { buf.append("{").append(arg).append("}"); } buf.append("]"); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java b/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java new file mode 100644 index 00000000000..7c33b836e33 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/ranking/WrappedHit.java @@ -0,0 +1,83 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.ranking; + +import com.yahoo.search.result.FeatureData; +import com.yahoo.search.result.Hit; +import com.yahoo.tensor.Tensor; + +import static com.yahoo.searchlib.rankingexpression.Reference.RANKING_EXPRESSION_WRAPPER; + +import java.util.logging.Logger; + +class WrappedHit implements Comparable<WrappedHit> { + + private static final Logger logger = Logger.getLogger(WrappedHit.class.getName()); + private final Hit hit; + private final FeatureData matchFeatures; + private int idx = -1; + + private WrappedHit(Hit hit, FeatureData matchFeatures) { + this.hit = hit; + this.matchFeatures = matchFeatures; + } + + static WrappedHit from(Hit hit) { + if (hit.getField("matchfeatures") instanceof FeatureData mf) { + return new WrappedHit(hit, mf); + } else { + return null; + } + } + + double getScore() { + return hit.getRelevance().getScore(); + } + + void setScore(double value) { + hit.setRelevance(value); + } + + int getIdx() { + if (idx < 0) { + throw new IllegalStateException("Missing index"); + } + return idx; + } + + void setIdx(int value) { + if (idx == value) { + return; + } else if (idx < 0) { + idx = value; + } else { + throw new IllegalArgumentException("Cannot re-assign index " + idx + " -> " + value); + } + } + + public int compareTo(WrappedHit other) { + return hit.compareTo(other.hit); + } + + Tensor getTensor(String argName) { + var asTensor = matchFeatures.getTensor(argName); + if (asTensor == null) { + asTensor = matchFeatures.getTensor(alternate(argName)); + } + return asTensor; + } + + private static final String RE_PREFIX = RANKING_EXPRESSION_WRAPPER + "("; + private static final String RE_SUFFIX = ")"; + private static final int RE_PRE_LEN = RE_PREFIX.length(); + private static final int RE_SUF_LEN = RE_SUFFIX.length(); + + // rankingExpression(foo) <-> foo + static String alternate(String argName) { + if (argName.startsWith(RE_PREFIX) && argName.endsWith(RE_SUFFIX)) { + return argName.substring(RE_PRE_LEN, argName.length() - RE_SUF_LEN); + } else { + return RE_PREFIX + argName + RE_SUFFIX; + } + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/ranking/package-info.java b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java index a86a5c1e52f..ab2acb6dd95 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/package-info.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/package-info.java @@ -1,4 +1,4 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. @ExportPackage package com.yahoo.search.ranking; |