diff options
| field | value | date |
|---|---|---|
| author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-06-06 16:11:43 +0200 |
| committer | GitHub <noreply@github.com> | 2024-06-06 16:11:43 +0200 |
| commit | 8246c9b3761fe5c346963cc01d6669f6bd4197d9 (patch) | |
| tree | 21b818d7c47e5560d277b389ebe9cfbfcb5bf8e1 | |
| parent | 88be7917fae655f1be7e06b382446ae481064915 (diff) | |
| parent | cdc0f19249a33457168348e0e5911690bf063a44 (diff) | |
Merge pull request #31467 from vespa-engine/balder/support-diversity-at-rankprofile-level
Balder/support diversity at rankprofile level
8 files changed, 191 insertions, 161 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 5cfb99cc76a..ed1a4e98b49 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -73,7 +73,8 @@ public class RankProfile implements Cloneable { /** The resolved inherited profiles, or null when not resolved. */ private List<RankProfile> inherited; - private MatchPhaseSettings matchPhaseSettings = null; + private MatchPhaseSettings matchPhase = null; + private DiversitySettings diversity = null; protected Set<RankSetting> rankSettings = new java.util.LinkedHashSet<>(); @@ -225,7 +226,7 @@ public class RankProfile implements Cloneable { public boolean useSignificanceModel() { if (useSignificanceModel != null) return useSignificanceModel; - return uniquelyInherited(p -> p.useSignificanceModel(), "use-model") + return uniquelyInherited(RankProfile::useSignificanceModel, "use-model") .orElse(false); // Disabled by default } @@ -307,20 +308,28 @@ public class RankProfile implements Cloneable { return false; } - public void setMatchPhaseSettings(MatchPhaseSettings settings) { + public void setMatchPhase(MatchPhaseSettings settings) { settings.checkValid(); - this.matchPhaseSettings = settings; + this.matchPhase = settings; } - public MatchPhaseSettings getMatchPhaseSettings() { - if (matchPhaseSettings != null) return matchPhaseSettings; - return uniquelyInherited(p -> p.getMatchPhaseSettings(), "match phase settings").orElse(null); + public MatchPhaseSettings getMatchPhase() { + if (matchPhase != null) return matchPhase; + return uniquelyInherited(RankProfile::getMatchPhase, "match phase settings").orElse(null); + } + public void setDiversity(DiversitySettings value) { + value.checkValid(); + diversity = value; + } + public DiversitySettings getDiversity() { + if (diversity != null) return diversity; + return 
uniquelyInherited(RankProfile::getDiversity, "diversity settings").orElse(null); } /** Returns the uniquely determined property, where non-empty is defined as non-null */ private <T> Optional<T> uniquelyInherited(Function<RankProfile, T> propertyRetriever, String propertyDescription) { - return uniquelyInherited(propertyRetriever, p -> p != null, propertyDescription); + return uniquelyInherited(propertyRetriever, Objects::nonNull, propertyDescription); } /** @@ -335,8 +344,8 @@ public class RankProfile implements Cloneable { Predicate<T> nonEmptyValueFilter, String propertyDescription) { Set<T> uniqueProperties = inherited().stream() - .map(p -> propertyRetriever.apply(p)) - .filter(p -> nonEmptyValueFilter.test(p)) + .map(propertyRetriever) + .filter(nonEmptyValueFilter) .collect(Collectors.toSet()); if (uniqueProperties.isEmpty()) return Optional.empty(); if (uniqueProperties.size() == 1) return uniqueProperties.stream().findAny(); @@ -495,7 +504,7 @@ public class RankProfile implements Cloneable { public RankingExpressionFunction getFirstPhase() { if (firstPhaseRanking != null) return firstPhaseRanking; - return uniquelyInherited(p -> p.getFirstPhase(), "first-phase expression").orElse(null); + return uniquelyInherited(RankProfile::getFirstPhase, "first-phase expression").orElse(null); } void setFirstPhaseRanking(RankingExpression rankingExpression) { @@ -522,7 +531,7 @@ public class RankProfile implements Cloneable { public RankingExpressionFunction getSecondPhase() { if (secondPhaseRanking != null) return secondPhaseRanking; - return uniquelyInherited(p -> p.getSecondPhase(), "second-phase expression").orElse(null); + return uniquelyInherited(RankProfile::getSecondPhase, "second-phase expression").orElse(null); } public void setSecondPhaseRanking(String expression) { @@ -542,7 +551,7 @@ public class RankProfile implements Cloneable { public RankingExpressionFunction getGlobalPhase() { if (globalPhaseRanking != null) return globalPhaseRanking; - return 
uniquelyInherited(p -> p.getGlobalPhase(), "global-phase expression").orElse(null); + return uniquelyInherited(RankProfile::getGlobalPhase, "global-phase expression").orElse(null); } public void setGlobalPhaseRanking(String expression) { @@ -601,7 +610,7 @@ public class RankProfile implements Cloneable { return Collections.unmodifiableSet(combined); } if (summaryFeatures != null) return Collections.unmodifiableSet(summaryFeatures); - return uniquelyInherited(p -> p.getSummaryFeatures(), f -> ! f.isEmpty(), "summary features") + return uniquelyInherited(RankProfile::getSummaryFeatures, f -> ! f.isEmpty(), "summary features") .orElse(Set.of()); } @@ -618,13 +627,13 @@ public class RankProfile implements Cloneable { return Collections.unmodifiableSet(combined); } if (matchFeatures != null) return Collections.unmodifiableSet(matchFeatures); - return uniquelyInherited(p -> p.getMatchFeatures(), f -> ! f.isEmpty(), "match features") + return uniquelyInherited(RankProfile::getMatchFeatures, f -> ! f.isEmpty(), "match features") .orElse(Set.of()); } public Set<ReferenceNode> getHiddenMatchFeatures() { if (hiddenMatchFeatures != null) return Collections.unmodifiableSet(hiddenMatchFeatures); - return uniquelyInherited(p -> p.getHiddenMatchFeatures(), f -> ! f.isEmpty(), "hidden match features") + return uniquelyInherited(RankProfile::getHiddenMatchFeatures, f -> ! f.isEmpty(), "hidden match features") .orElse(Set.of()); } @@ -662,7 +671,7 @@ public class RankProfile implements Cloneable { /** Returns a read-only view of the rank features to use in this profile. This is never null */ public Set<ReferenceNode> getRankFeatures() { if (rankFeatures != null) return Collections.unmodifiableSet(rankFeatures); - return uniquelyInherited(p -> p.getRankFeatures(), f -> ! f.isEmpty(), "summary-features") + return uniquelyInherited(RankProfile::getRankFeatures, f -> ! 
f.isEmpty(), "summary-features") .orElse(Set.of()); } @@ -693,7 +702,7 @@ public class RankProfile implements Cloneable { if (rankProperties.isEmpty() && inherited().isEmpty()) return Map.of(); if (inherited().isEmpty()) return Collections.unmodifiableMap(rankProperties); - var inheritedProperties = uniquelyInherited(p -> p.getRankPropertyMap(), m -> ! m.isEmpty(), "rank-properties") + var inheritedProperties = uniquelyInherited(RankProfile::getRankPropertyMap, m -> ! m.isEmpty(), "rank-properties") .orElse(Map.of()); if (rankProperties.isEmpty()) return inheritedProperties; @@ -735,21 +744,21 @@ public class RankProfile implements Cloneable { public int getRerankCount() { if (rerankCount >= 0) return rerankCount; - return uniquelyInherited(p -> p.getRerankCount(), c -> c >= 0, "rerank-count").orElse(-1); + return uniquelyInherited(RankProfile::getRerankCount, c -> c >= 0, "rerank-count").orElse(-1); } public void setGlobalPhaseRerankCount(int count) { this.globalPhaseRerankCount = count; } public int getGlobalPhaseRerankCount() { if (globalPhaseRerankCount >= 0) return globalPhaseRerankCount; - return uniquelyInherited(p -> p.getGlobalPhaseRerankCount(), c -> c >= 0, "global-phase rerank-count").orElse(-1); + return uniquelyInherited(RankProfile::getGlobalPhaseRerankCount, c -> c >= 0, "global-phase rerank-count").orElse(-1); } public void setNumThreadsPerSearch(int numThreads) { this.numThreadsPerSearch = numThreads; } public int getNumThreadsPerSearch() { if (numThreadsPerSearch >= 0) return numThreadsPerSearch; - return uniquelyInherited(p -> p.getNumThreadsPerSearch(), n -> n >= 0, "num-threads-per-search") + return uniquelyInherited(RankProfile::getNumThreadsPerSearch, n -> n >= 0, "num-threads-per-search") .orElse(-1); } @@ -757,14 +766,14 @@ public class RankProfile implements Cloneable { public int getMinHitsPerThread() { if (minHitsPerThread >= 0) return minHitsPerThread; - return uniquelyInherited(p -> p.getMinHitsPerThread(), n -> n >= 0, 
"min-hits-per-search").orElse(-1); + return uniquelyInherited(RankProfile::getMinHitsPerThread, n -> n >= 0, "min-hits-per-search").orElse(-1); } public void setNumSearchPartitions(int numSearchPartitions) { this.numSearchPartitions = numSearchPartitions; } public int getNumSearchPartitions() { if (numSearchPartitions >= 0) return numSearchPartitions; - return uniquelyInherited(p -> p.getNumSearchPartitions(), n -> n >= 0, "num-search-partitions").orElse(-1); + return uniquelyInherited(RankProfile::getNumSearchPartitions, n -> n >= 0, "num-search-partitions").orElse(-1); } public void setTermwiseLimit(double termwiseLimit) { this.termwiseLimit = termwiseLimit; } @@ -774,7 +783,7 @@ public class RankProfile implements Cloneable { public OptionalDouble getTermwiseLimit() { if (termwiseLimit != null) return OptionalDouble.of(termwiseLimit); - return uniquelyInherited(p -> p.getTermwiseLimit(), l -> l.isPresent(), "termwise-limit") + return uniquelyInherited(RankProfile::getTermwiseLimit, OptionalDouble::isPresent, "termwise-limit") .orElse(OptionalDouble.empty()); } @@ -782,21 +791,21 @@ public class RankProfile implements Cloneable { if (postFilterThreshold != null) { return OptionalDouble.of(postFilterThreshold); } - return uniquelyInherited(p -> p.getPostFilterThreshold(), l -> l.isPresent(), "post-filter-threshold").orElse(OptionalDouble.empty()); + return uniquelyInherited(RankProfile::getPostFilterThreshold, OptionalDouble::isPresent, "post-filter-threshold").orElse(OptionalDouble.empty()); } public OptionalDouble getApproximateThreshold() { if (approximateThreshold != null) { return OptionalDouble.of(approximateThreshold); } - return uniquelyInherited(p -> p.getApproximateThreshold(), l -> l.isPresent(), "approximate-threshold").orElse(OptionalDouble.empty()); + return uniquelyInherited(RankProfile::getApproximateThreshold, OptionalDouble::isPresent, "approximate-threshold").orElse(OptionalDouble.empty()); } public OptionalDouble 
getTargetHitsMaxAdjustmentFactor() { if (targetHitsMaxAdjustmentFactor != null) { return OptionalDouble.of(targetHitsMaxAdjustmentFactor); } - return uniquelyInherited(p -> p.getTargetHitsMaxAdjustmentFactor(), l -> l.isPresent(), "target-hits-max-adjustment-factor").orElse(OptionalDouble.empty()); + return uniquelyInherited(RankProfile::getTargetHitsMaxAdjustmentFactor, OptionalDouble::isPresent, "target-hits-max-adjustment-factor").orElse(OptionalDouble.empty()); } /** Whether we should ignore the default rank features. Set to null to use inherited */ @@ -806,21 +815,21 @@ public class RankProfile implements Cloneable { public Boolean getIgnoreDefaultRankFeatures() { if (ignoreDefaultRankFeatures != null) return ignoreDefaultRankFeatures; - return uniquelyInherited(p -> p.getIgnoreDefaultRankFeatures(), "ignore-default-rank-features").orElse(false); + return uniquelyInherited(RankProfile::getIgnoreDefaultRankFeatures, "ignore-default-rank-features").orElse(false); } public void setKeepRankCount(int rerankArraySize) { this.keepRankCount = rerankArraySize; } public int getKeepRankCount() { if (keepRankCount >= 0) return keepRankCount; - return uniquelyInherited(p -> p.getKeepRankCount(), c -> c >= 0, "keep-rank-count").orElse(-1); + return uniquelyInherited(RankProfile::getKeepRankCount, c -> c >= 0, "keep-rank-count").orElse(-1); } public void setRankScoreDropLimit(double rankScoreDropLimit) { this.rankScoreDropLimit = rankScoreDropLimit; } public double getRankScoreDropLimit() { if (rankScoreDropLimit > -Double.MAX_VALUE) return rankScoreDropLimit; - return uniquelyInherited(p -> p.getRankScoreDropLimit(), c -> c > -Double.MAX_VALUE, "rank.score-drop-limit") + return uniquelyInherited(RankProfile::getRankScoreDropLimit, c -> c > -Double.MAX_VALUE, "rank.score-drop-limit") .orElse(rankScoreDropLimit); } @@ -830,7 +839,7 @@ public class RankProfile implements Cloneable { if (secondPhaseRankScoreDropLimit > -Double.MAX_VALUE) { return secondPhaseRankScoreDropLimit; 
} - return uniquelyInherited(p -> p.getSecondPhaseRankScoreDropLimit(), c -> c > -Double.MAX_VALUE, "second-phase rank-score-drop-limit") + return uniquelyInherited(RankProfile::getSecondPhaseRankScoreDropLimit, c -> c > -Double.MAX_VALUE, "second-phase rank-score-drop-limit") .orElse(secondPhaseRankScoreDropLimit); } @@ -958,7 +967,7 @@ public class RankProfile implements Cloneable { } private boolean needToUpdateFunctionCache() { - if (inherited().stream().anyMatch(profile -> profile.needToUpdateFunctionCache())) return true; + if (inherited().stream().anyMatch(RankProfile::needToUpdateFunctionCache)) return true; return allFunctionsCached == null; } @@ -966,7 +975,7 @@ public class RankProfile implements Cloneable { /** Returns all filter fields in this profile and any profile it inherits. */ public Set<String> allFilterFields() { - Set<String> inheritedFilterFields = uniquelyInherited(p -> p.allFilterFields(), fields -> ! fields.isEmpty(), + Set<String> inheritedFilterFields = uniquelyInherited(RankProfile::allFilterFields, fields -> ! fields.isEmpty(), "filter fields").orElse(Set.of()); if (inheritedFilterFields.isEmpty()) return Collections.unmodifiableSet(filterFields); @@ -977,7 +986,7 @@ public class RankProfile implements Cloneable { } private ExpressionFunction parseRankingExpression(String name, List<String> arguments, String expression) throws ParseException { - if (expression.trim().length() == 0) + if (expression.trim().isEmpty()) throw new ParseException("Encountered an empty ranking expression in " + name() + ", " + name + "."); try (Reader rankingExpressionReader = openRankingExpressionReader(name, expression.trim())) { @@ -1019,7 +1028,8 @@ public class RankProfile implements Cloneable { try { RankProfile clone = (RankProfile)super.clone(); clone.rankSettings = new LinkedHashSet<>(this.rankSettings); - clone.matchPhaseSettings = this.matchPhaseSettings; // hmm? + clone.matchPhase = this.matchPhase; // hmm? 
+ clone.diversity = this.diversity; clone.summaryFeatures = summaryFeatures != null ? new LinkedHashSet<>(this.summaryFeatures) : null; clone.matchFeatures = matchFeatures != null ? new LinkedHashSet<>(this.matchFeatures) : null; clone.rankFeatures = rankFeatures != null ? new LinkedHashSet<>(this.rankFeatures) : null; @@ -1200,7 +1210,7 @@ public class RankProfile implements Cloneable { private Map<Reference, TensorType> featureTypes() { Map<Reference, TensorType> featureTypes = inputs().values().stream() - .collect(Collectors.toMap(input -> input.name(), + .collect(Collectors.toMap(Input::name, input -> input.type().tensorType())); allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes)); @@ -1517,15 +1527,9 @@ public class RankProfile implements Cloneable { private boolean ascending = false; private int maxHits = 0; // try to get this many hits before degrading the match phase private double maxFilterCoverage = 0.2; // Max coverage of original corpus that will trigger the filter. 
- private DiversitySettings diversity = null; private double evaluationPoint = 0.20; private double prePostFilterTippingPoint = 1.0; - public void setDiversity(DiversitySettings value) { - value.checkValid(); - diversity = value; - } - public void setAscending(boolean value) { ascending = value; } public void setAttribute(String value) { attribute = value; } public void setMaxHits(int value) { maxHits = value; } @@ -1537,7 +1541,6 @@ public class RankProfile implements Cloneable { public String getAttribute() { return attribute; } public int getMaxHits() { return maxHits; } public double getMaxFilterCoverage() { return maxFilterCoverage; } - public DiversitySettings getDiversity() { return diversity; } public double getEvaluationPoint() { return evaluationPoint; } public double getPrePostFilterTippingPoint() { return prePostFilterTippingPoint; } @@ -1701,7 +1704,7 @@ public class RankProfile implements Cloneable { } - public static record RankFeatureNormalizer(Reference original, String name, String input, String algo, double kparam) { + public record RankFeatureNormalizer(Reference original, String name, String input, String algo, double kparam) { @Override public String toString() { return "normalizer{name=" + name + ",input=" + input + ",algo=" + algo + ",k=" + kparam + "}"; @@ -1722,7 +1725,7 @@ public class RankProfile implements Cloneable { } } - private List<RankFeatureNormalizer> featureNormalizers = new ArrayList<>(); + private final List<RankFeatureNormalizer> featureNormalizers = new ArrayList<>(); public Map<String, RankFeatureNormalizer> getFeatureNormalizers() { Map<String, RankFeatureNormalizer> all = new LinkedHashMap<>(); diff --git a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java index e05225226b6..b540add4be2 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java +++ 
b/config-model/src/main/java/com/yahoo/schema/derived/RawRankProfile.java @@ -159,6 +159,7 @@ public class RawRankProfile { private final boolean ignoreDefaultRankFeatures; private final RankProfile.MatchPhaseSettings matchPhaseSettings; + private final RankProfile.DiversitySettings diversitySettings; private final int rerankCount; private final int keepRankCount; private final int numThreadsPerSearch; @@ -208,7 +209,8 @@ public class RawRankProfile { rankFeatures = compiled.getRankFeatures(); rerankCount = compiled.getRerankCount(); globalPhaseRerankCount = compiled.getGlobalPhaseRerankCount(); - matchPhaseSettings = compiled.getMatchPhaseSettings(); + matchPhaseSettings = compiled.getMatchPhase(); + diversitySettings = compiled.getDiversity(); numThreadsPerSearch = compiled.getNumThreadsPerSearch(); minHitsPerThread = compiled.getMinHitsPerThread(); numSearchPartitions = compiled.getNumSearchPartitions(); @@ -488,13 +490,12 @@ public class RawRankProfile { properties.add(new Pair<>("vespa.matchphase.degradation.maxfiltercoverage", matchPhaseSettings.getMaxFilterCoverage() + "")); properties.add(new Pair<>("vespa.matchphase.degradation.samplepercentage", matchPhaseSettings.getEvaluationPoint() + "")); properties.add(new Pair<>("vespa.matchphase.degradation.postfiltermultiplier", matchPhaseSettings.getPrePostFilterTippingPoint() + "")); - RankProfile.DiversitySettings diversitySettings = matchPhaseSettings.getDiversity(); - if (diversitySettings != null) { - properties.add(new Pair<>("vespa.matchphase.diversity.attribute", diversitySettings.getAttribute())); - properties.add(new Pair<>("vespa.matchphase.diversity.mingroups", String.valueOf(diversitySettings.getMinGroups()))); - properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.factor", String.valueOf(diversitySettings.getCutoffFactor()))); - properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.strategy", String.valueOf(diversitySettings.getCutoffStrategy()))); - } + } + if (diversitySettings != 
null) { + properties.add(new Pair<>("vespa.matchphase.diversity.attribute", diversitySettings.getAttribute())); + properties.add(new Pair<>("vespa.matchphase.diversity.mingroups", String.valueOf(diversitySettings.getMinGroups()))); + properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.factor", String.valueOf(diversitySettings.getCutoffFactor()))); + properties.add(new Pair<>("vespa.matchphase.diversity.cutoff.strategy", String.valueOf(diversitySettings.getCutoffStrategy()))); } if (rerankCount > -1) { properties.add(new Pair<>("vespa.hitcollector.heapsize", rerankCount + "")); diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java b/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java index 78f2c8a85ef..ff78a4a3b60 100644 --- a/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java +++ b/config-model/src/main/java/com/yahoo/schema/parser/ConvertParsedRanking.java @@ -38,8 +38,8 @@ public class ConvertParsedRanking { for (String name : parsed.getInherited()) profile.inherit(name); - parsed.isStrict().ifPresent(value -> profile.setStrict(value)); - parsed.isUseSignificanceModel().ifPresent(value -> profile.setUseSignificanceModel(value)); + parsed.isStrict().ifPresent(profile::setStrict); + parsed.isUseSignificanceModel().ifPresent(profile::setUseSignificanceModel); for (var constant : parsed.getConstants().values()) profile.add(constant); @@ -58,41 +58,26 @@ public class ConvertParsedRanking { profile.addFunction(name, parameters, expression, inline); } - parsed.getRankScoreDropLimit().ifPresent - (value -> profile.setRankScoreDropLimit(value)); - parsed.getSecondPhaseRankScoreDropLimit().ifPresent - (value -> profile.setSecondPhaseRankScoreDropLimit(value)); - parsed.getTermwiseLimit().ifPresent - (value -> profile.setTermwiseLimit(value)); - parsed.getPostFilterThreshold().ifPresent - (value -> profile.setPostFilterThreshold(value)); - 
parsed.getApproximateThreshold().ifPresent - (value -> profile.setApproximateThreshold(value)); - parsed.getTargetHitsMaxAdjustmentFactor().ifPresent - (value -> profile.setTargetHitsMaxAdjustmentFactor(value)); - parsed.getKeepRankCount().ifPresent - (value -> profile.setKeepRankCount(value)); - parsed.getMinHitsPerThread().ifPresent - (value -> profile.setMinHitsPerThread(value)); - parsed.getNumSearchPartitions().ifPresent - (value -> profile.setNumSearchPartitions(value)); - parsed.getNumThreadsPerSearch().ifPresent - (value -> profile.setNumThreadsPerSearch(value)); - parsed.getReRankCount().ifPresent - (value -> profile.setRerankCount(value)); - - parsed.getMatchPhaseSettings().ifPresent - (value -> profile.setMatchPhaseSettings(value)); - - parsed.getFirstPhaseExpression().ifPresent - (value -> profile.setFirstPhaseRanking(value)); - parsed.getSecondPhaseExpression().ifPresent - (value -> profile.setSecondPhaseRanking(value)); - - parsed.getGlobalPhaseExpression().ifPresent - (value -> profile.setGlobalPhaseRanking(value)); - parsed.getGlobalPhaseRerankCount().ifPresent - (value -> profile.setGlobalPhaseRerankCount(value)); + parsed.getRankScoreDropLimit().ifPresent(profile::setRankScoreDropLimit); + parsed.getSecondPhaseRankScoreDropLimit().ifPresent(profile::setSecondPhaseRankScoreDropLimit); + parsed.getTermwiseLimit().ifPresent(profile::setTermwiseLimit); + parsed.getPostFilterThreshold().ifPresent(profile::setPostFilterThreshold); + parsed.getApproximateThreshold().ifPresent(profile::setApproximateThreshold); + parsed.getTargetHitsMaxAdjustmentFactor().ifPresent(profile::setTargetHitsMaxAdjustmentFactor); + parsed.getKeepRankCount().ifPresent(profile::setKeepRankCount); + parsed.getMinHitsPerThread().ifPresent(profile::setMinHitsPerThread); + parsed.getNumSearchPartitions().ifPresent(profile::setNumSearchPartitions); + parsed.getNumThreadsPerSearch().ifPresent(profile::setNumThreadsPerSearch); + 
parsed.getReRankCount().ifPresent(profile::setRerankCount); + + parsed.getMatchPhase().ifPresent(profile::setMatchPhase); + parsed.getDiversity().ifPresent(profile::setDiversity); + + parsed.getFirstPhaseExpression().ifPresent(profile::setFirstPhaseRanking); + parsed.getSecondPhaseExpression().ifPresent(profile::setSecondPhaseRanking); + + parsed.getGlobalPhaseExpression().ifPresent(profile::setGlobalPhaseRanking); + parsed.getGlobalPhaseRerankCount().ifPresent(profile::setGlobalPhaseRerankCount); for (var value : parsed.getMatchFeatures()) { profile.addMatchFeatures(value); @@ -104,10 +89,8 @@ public class ConvertParsedRanking { profile.addSummaryFeatures(value); } - parsed.getInheritedMatchFeatures().ifPresent - (value -> profile.setInheritedMatchFeatures(value)); - parsed.getInheritedSummaryFeatures().ifPresent - (value -> profile.setInheritedSummaryFeatures(value)); + parsed.getInheritedMatchFeatures().ifPresent(profile::setInheritedMatchFeatures); + parsed.getInheritedSummaryFeatures().ifPresent(profile::setInheritedSummaryFeatures); if (parsed.getIgnoreDefaultRankFeatures()) { profile.setIgnoreDefaultRankFeatures(true); } diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java b/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java index 6a800bf354f..2a117a4af4b 100644 --- a/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/parser/ParsedRankProfile.java @@ -4,6 +4,7 @@ package com.yahoo.schema.parser; import com.yahoo.schema.OnnxModel; import com.yahoo.schema.RankProfile; import com.yahoo.schema.RankProfile.MatchPhaseSettings; +import com.yahoo.schema.RankProfile.DiversitySettings; import com.yahoo.schema.RankProfile.MutateOperation; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.Reference; @@ -39,7 +40,8 @@ class ParsedRankProfile extends ParsedBlock { private Integer 
numSearchPartitions = null; private Integer numThreadsPerSearch = null; private Integer reRankCount = null; - private MatchPhaseSettings matchPhaseSettings = null; + private MatchPhaseSettings matchPhase = null; + private DiversitySettings diversity = null; private String firstPhaseExpression = null; private String inheritedSummaryFeatures = null; private String inheritedMatchFeatures = null; @@ -78,7 +80,8 @@ class ParsedRankProfile extends ParsedBlock { Optional<Integer> getNumSearchPartitions() { return Optional.ofNullable(this.numSearchPartitions); } Optional<Integer> getNumThreadsPerSearch() { return Optional.ofNullable(this.numThreadsPerSearch); } Optional<Integer> getReRankCount() { return Optional.ofNullable(this.reRankCount); } - Optional<MatchPhaseSettings> getMatchPhaseSettings() { return Optional.ofNullable(this.matchPhaseSettings); } + Optional<MatchPhaseSettings> getMatchPhase() { return Optional.ofNullable(this.matchPhase); } + Optional<DiversitySettings> getDiversity() { return Optional.ofNullable(this.diversity); } Optional<String> getFirstPhaseExpression() { return Optional.ofNullable(this.firstPhaseExpression); } Optional<String> getInheritedMatchFeatures() { return Optional.ofNullable(this.inheritedMatchFeatures); } List<ParsedRankFunction> getFunctions() { return List.copyOf(functions.values()); } @@ -173,9 +176,13 @@ class ParsedRankProfile extends ParsedBlock { this.keepRankCount = count; } - void setMatchPhaseSettings(MatchPhaseSettings settings) { - verifyThat(matchPhaseSettings == null, "already has match-phase"); - this.matchPhaseSettings = settings; + void setMatchPhase(MatchPhaseSettings settings) { + verifyThat(matchPhase == null, "already has match-phase"); + this.matchPhase = settings; + } + void setDiversity(DiversitySettings settings) { + verifyThat(diversity == null, "already has diversity"); + this.diversity = settings; } void setMinHitsPerThread(int minHits) { diff --git 
a/config-model/src/main/java/com/yahoo/schema/processing/DiversitySettingsValidator.java b/config-model/src/main/java/com/yahoo/schema/processing/DiversitySettingsValidator.java index 5c06ce25184..24972167732 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/DiversitySettingsValidator.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/DiversitySettingsValidator.java @@ -23,8 +23,8 @@ public class DiversitySettingsValidator extends Processor { if (documentsOnly) return; for (RankProfile rankProfile : rankProfileRegistry.rankProfilesOf(schema)) { - if (rankProfile.getMatchPhaseSettings() != null && rankProfile.getMatchPhaseSettings().getDiversity() != null) { - validate(rankProfile, rankProfile.getMatchPhaseSettings().getDiversity()); + if (rankProfile.getDiversity() != null) { + validate(rankProfile, rankProfile.getDiversity()); } } } diff --git a/config-model/src/main/java/com/yahoo/schema/processing/MatchPhaseSettingsValidator.java b/config-model/src/main/java/com/yahoo/schema/processing/MatchPhaseSettingsValidator.java index f3a8f7cee18..d29820e0d51 100644 --- a/config-model/src/main/java/com/yahoo/schema/processing/MatchPhaseSettingsValidator.java +++ b/config-model/src/main/java/com/yahoo/schema/processing/MatchPhaseSettingsValidator.java @@ -25,7 +25,7 @@ public class MatchPhaseSettingsValidator extends Processor { if (documentsOnly) return; for (RankProfile rankProfile : rankProfileRegistry.rankProfilesOf(schema)) { - RankProfile.MatchPhaseSettings settings = rankProfile.getMatchPhaseSettings(); + RankProfile.MatchPhaseSettings settings = rankProfile.getMatchPhase(); if (settings != null) { validateMatchPhaseSettings(rankProfile, settings); } diff --git a/config-model/src/main/javacc/SchemaParser.jj b/config-model/src/main/javacc/SchemaParser.jj index d87c57574d3..c9eff88764f 100644 --- a/config-model/src/main/javacc/SchemaParser.jj +++ b/config-model/src/main/javacc/SchemaParser.jj @@ -1756,6 +1756,7 @@ void 
rankProfileItem(ParsedSchema schema, ParsedRankProfile profile) : { } | fieldRankFilter(profile) | firstPhase(profile) | matchPhase(profile) + | diversity(profile) | function(profile) | mutate(profile) | ignoreRankFeatures(profile) @@ -1875,14 +1876,14 @@ void matchPhase(ParsedRankProfile profile) : MatchPhaseSettings settings = new MatchPhaseSettings(); } { - <MATCH_PHASE> lbrace() (matchPhaseItem(settings) (<NL>)*)* <RBRACE> + <MATCH_PHASE> lbrace() (matchPhaseItem(profile, settings) (<NL>)*)* <RBRACE> { settings.checkValid(); - profile.setMatchPhaseSettings(settings); + profile.setMatchPhase(settings); } } -void matchPhaseItem(MatchPhaseSettings settings) : +void matchPhaseItem(ParsedRankProfile profile, MatchPhaseSettings settings) : { String str; int num; @@ -1891,7 +1892,7 @@ void matchPhaseItem(MatchPhaseSettings settings) : } { ( <ATTRIBUTE> <COLON> str = identifier() { settings.setAttribute(str); } - | diversity(settings) + | diversityDeprecated(profile) | <ORDER> <COLON> ( <ASCENDING> { settings.setAscending(true); } | <DESCENDING> { settings.setAscending(false); } ) | <MAX_HITS> <COLON> num = integer() { settings.setMaxHits(num); } @@ -1906,7 +1907,7 @@ void matchPhaseItem(MatchPhaseSettings settings) : * * @param profile The rank profile to modify. */ -void diversity(MatchPhaseSettings profile) : +void diversity(ParsedRankProfile profile) : { DiversitySettings settings = new DiversitySettings(); } @@ -1917,6 +1918,18 @@ void diversity(MatchPhaseSettings profile) : } } +void diversityDeprecated(ParsedRankProfile profile) : +{ + DiversitySettings settings = new DiversitySettings(); +} +{ + <DIVERSITY> lbrace() (diversityItem(settings) (<NL>)*)* <RBRACE> + { + profile.setDiversity(settings); + deployLogger.logApplicationPackage(Level.WARNING, "'diversity is deprecated inside 'match-phase'. 
Specify it at 'rank-profile' level."); + } +} + void diversityItem(DiversitySettings settings) : { String str; diff --git a/config-model/src/test/java/com/yahoo/schema/DiversityTestCase.java b/config-model/src/test/java/com/yahoo/schema/DiversityTestCase.java index df71e30c7d9..4026341464f 100644 --- a/config-model/src/test/java/com/yahoo/schema/DiversityTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/DiversityTestCase.java @@ -3,49 +3,58 @@ package com.yahoo.schema; import com.yahoo.search.query.ranking.Diversity; import com.yahoo.schema.parser.ParseException; +import com.yahoo.vespa.model.test.utils.DeployLoggerStub; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; /** * @author baldersheim */ public class DiversityTestCase { - @Test - void testDiversity() throws ParseException { + private static void verifyDiversity(DeployLoggerStub logger, boolean atRankProfile, boolean atMatchPhase) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); + ApplicationBuilder builder = new ApplicationBuilder(logger, rankProfileRegistry); + String diversitySpec = """ + diversity { + attribute: b + min-groups: 74 + cutoff-factor: 17.3 + cutoff-strategy: strict + } + """; builder.addSchema( - "search test {\n" + - " document test { \n" + - " field a type int { \n" + - " indexing: attribute \n" + - " attribute: fast-search\n" + - " }\n" + - " field b type int {\n" + - " indexing: attribute \n" + - " }\n" + - " }\n" + - " \n" + - " rank-profile parent {\n" + - " match-phase {\n" + - " diversity {\n" + - " attribute: b\n" + - " min-groups: 74\n" + - " cutoff-factor: 17.3\n" + - " cutoff-strategy: strict" + - " }\n" + - " attribute: a\n" + - " max-hits: 120\n" + - " max-filter-coverage: 0.065" + - " 
}\n" + - " }\n" + - "}\n"); + """ + search test { + document test { + field a type int { + indexing: attribute + attribute: fast-search + } + field b type int { + indexing: attribute + } + } + rank-profile parent { + match-phase {""" + + (atMatchPhase ? diversitySpec : "") + + """ + attribute: a + max-hits: 120 + max-filter-coverage: 0.065 + }""" + + (atRankProfile ? diversitySpec : "") + + """ + } + } + """); builder.build(true); Schema s = builder.getSchema(); - RankProfile.MatchPhaseSettings matchPhase = rankProfileRegistry.get(s, "parent").getMatchPhaseSettings(); - RankProfile.DiversitySettings diversity = matchPhase.getDiversity(); + RankProfile parent = rankProfileRegistry.get(s, "parent"); + RankProfile.MatchPhaseSettings matchPhase = parent.getMatchPhase(); + RankProfile.DiversitySettings diversity = parent.getDiversity(); assertEquals("b", diversity.getAttribute()); assertEquals(74, diversity.getMinGroups()); assertEquals(17.3, diversity.getCutoffFactor(), 1e-16); @@ -54,6 +63,21 @@ public class DiversityTestCase { assertEquals("a", matchPhase.getAttribute()); assertEquals(0.065, matchPhase.getMaxFilterCoverage(), 1e-16); } + @Test + void testDiversity() throws ParseException { + DeployLoggerStub logger = new DeployLoggerStub(); + verifyDiversity(logger, true, false); + assertTrue(logger.entries.isEmpty()); + verifyDiversity(logger, false, true); + assertEquals(1, logger.entries.size()); + assertEquals("'diversity is deprecated inside 'match-phase'. 
Specify it at 'rank-profile' level.", logger.entries.get(0).message); + try { + verifyDiversity(logger, true, true); + fail("Should throw."); + } catch (Exception e) { + assertEquals("rank-profile 'parent' error: already has diversity", e.getMessage()); + } + } private static String getMessagePrefix() { return "In search definition 'test', rank-profile 'parent': diversity attribute 'b' "; @@ -82,30 +106,29 @@ public class DiversityTestCase { assertEquals(getMessagePrefix() + "must be single value numeric, or enumerated attribute, but it is 'Array<int>'", e.getMessage()); } } - private ApplicationBuilder getSearchBuilder(String diversity) throws ParseException { - RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); - builder.addSchema( - "search test {\n" + - " document test { \n" + - " field a type int { \n" + - " indexing: attribute \n" + - " attribute: fast-search\n" + - " }\n" + - diversity + - " }\n" + - " \n" + - " rank-profile parent {\n" + - " match-phase {\n" + - " diversity {\n" + - " attribute: b\n" + - " min-groups: 74\n" + - " }\n" + - " attribute: a\n" + - " max-hits: 120\n" + - " }\n" + - " }\n" + - "}\n"); + private ApplicationBuilder getSearchBuilder(String diversityField) throws ParseException { + ApplicationBuilder builder = new ApplicationBuilder(new RankProfileRegistry()); + builder.addSchema(""" + search test { + document test { + field a type int { + indexing: attribute + attribute: fast-search + }""" + + diversityField + + """ + } + rank-profile parent { + match-phase { + attribute: a + max-hits: 120 + } + diversity { + attribute: b + min-groups: 74 + } + } + }"""); return builder; } } |