diff options
author | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2022-03-24 17:43:27 +0100 |
---|---|---|
committer | Bjørn Christian Seime <bjorncs@verizonmedia.com> | 2022-03-24 17:43:27 +0100 |
commit | f892050296c4ab4c231e20cec23285608cb4163c (patch) | |
tree | 1260edf7a4b4e37844c9fef5697ee355e1dd23d0 /container-search | |
parent | 51346c0606f7d859baf5b6aee1395e1926821d0a (diff) |
Add 'grouping.globalMaxGroups' query parameter
Diffstat (limited to 'container-search')
7 files changed, 172 insertions, 10 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 6249988a5ee..94baaf4ef52 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -2369,6 +2369,8 @@ "public void setDefaultMaxHits(int)", "public java.util.OptionalInt defaultMaxGroups()", "public void setDefaultMaxGroups(int)", + "public java.util.OptionalLong globalMaxGroups()", + "public void setGlobalMaxGroups(long)", "public static com.yahoo.search.grouping.GroupingRequest newInstance(com.yahoo.search.Query)", "public java.lang.String toString()" ], @@ -6624,6 +6626,7 @@ "fields": [ "public static final com.yahoo.processing.request.CompoundName MAX_OFFSET", "public static final com.yahoo.processing.request.CompoundName MAX_HITS", + "public static final com.yahoo.processing.request.CompoundName GROUPING_GLOBAL_MAX_GROUPS", "public static final com.yahoo.search.query.profile.types.QueryProfileType argumentType" ] }, diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java index 0de4e36eae5..ee78e41d0d8 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java @@ -5,15 +5,16 @@ import com.yahoo.api.annotations.Beta; import com.yahoo.component.chain.dependencies.After; import com.yahoo.component.chain.dependencies.Before; import com.yahoo.component.chain.dependencies.Provides; +import com.yahoo.processing.IllegalInputException; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.Result; import com.yahoo.search.Searcher; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.query.Select; +import com.yahoo.search.query.properties.DefaultProperties; import com.yahoo.search.searchchain.Execution; import com.yahoo.search.searchchain.PhaseNames; -import com.yahoo.processing.IllegalInputException; import java.util.Collections; import java.util.LinkedHashMap; @@ -45,6 +46,10 @@ public class GroupingQueryParser extends Searcher { @Override public Result search(Query query, Execution execution) { try { + if (query.getHttpRequest().getProperty(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS.toString()) != null) { + throw new IllegalInputException(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS + " must be specified in a query profile."); + } + String reqParam = query.properties().getString(PARAM_REQUEST); if (reqParam == null) return execution.search(query); @@ -57,6 +62,7 @@ public class GroupingQueryParser extends Searcher { grpRequest.continuations().addAll(continuations); grpRequest.setDefaultMaxGroups(query.properties().getInteger(PARAM_DEFAULT_MAX_GROUPS, -1)); grpRequest.setDefaultMaxHits(query.properties().getInteger(PARAM_DEFAULT_MAX_HITS, -1)); + grpRequest.setGlobalMaxGroups(query.properties().getLong(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS)); } return execution.search(query); } diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java index 0c163aaacae..9f5deb482db 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java @@ -15,6 +15,7 @@ import com.yahoo.search.result.Hit; import java.util.ArrayList; import java.util.List; import java.util.OptionalInt; +import java.util.OptionalLong; import java.util.TimeZone; /** @@ -34,6 +35,7 @@ public class GroupingRequest { private TimeZone timeZone; private int defaultMaxHits = -1; private int defaultMaxGroups = -1; + private long globalMaxGroups = -1; private GroupingRequest(Select parent) { this.parent = parent; @@ -44,18 +46,20 @@ public class GroupingRequest { GroupingOperation root, TimeZone timeZone, int defaultMaxHits, - int defaultMaxGroups) { + int defaultMaxGroups, + long globalMaxGroups) { this.parent = parent; continuations.forEach(item -> this.continuations.add(item.copy())); this.root = root != null ? root.copy(null) : null; this.timeZone = timeZone; this.defaultMaxHits = defaultMaxHits; this.defaultMaxGroups = defaultMaxGroups; + this.globalMaxGroups = globalMaxGroups; } /** Returns a deep copy of this */ public GroupingRequest copy(Select parentOfCopy) { - return new GroupingRequest(parentOfCopy, continuations, root, timeZone, defaultMaxHits, defaultMaxGroups); + return new GroupingRequest(parentOfCopy, continuations, root, timeZone, defaultMaxHits, defaultMaxGroups, globalMaxGroups); } /** @@ -153,6 +157,13 @@ public class GroupingRequest { @Beta public void setDefaultMaxGroups(int v) { this.defaultMaxGroups = v; } + @Beta + public OptionalLong globalMaxGroups() { + return globalMaxGroups >= 0 ? OptionalLong.of(globalMaxGroups) : OptionalLong.empty(); + } + + @Beta public void setGlobalMaxGroups(long v) { this.globalMaxGroups = v; } + /** * Creates a new grouping request and adds it to the query.getSelect().getGrouping() list * diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java index c09502110b1..c7000c0cdcf 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java @@ -152,6 +152,7 @@ public class GroupingExecutor extends Searcher { builder.addContinuations(req.continuations()); req.defaultMaxGroups().ifPresent(builder::setDefaultMaxGroups); req.defaultMaxHits().ifPresent(builder::setDefaultMaxHits); + req.globalMaxGroups().ifPresent(builder::setGlobalMaxGroups); builder.build(); RequestContext ctx = new RequestContext(req, builder.getTransform()); diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java index 78452b36cd5..b013e87fb24 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.grouping.vespa; +import com.yahoo.processing.IllegalInputException; import com.yahoo.search.grouping.Continuation; import com.yahoo.search.grouping.GroupingRequest; import com.yahoo.search.grouping.request.AllOperation; @@ -8,10 +9,22 @@ import com.yahoo.search.grouping.request.EachOperation; import com.yahoo.search.grouping.request.GroupingExpression; import com.yahoo.search.grouping.request.GroupingOperation; import com.yahoo.search.grouping.request.NegFunction; -import com.yahoo.searchlib.aggregation.*; +import com.yahoo.searchlib.aggregation.AggregationResult; +import com.yahoo.searchlib.aggregation.ExpressionCountAggregationResult; +import com.yahoo.searchlib.aggregation.Group; +import com.yahoo.searchlib.aggregation.Grouping; +import com.yahoo.searchlib.aggregation.GroupingLevel; +import com.yahoo.searchlib.aggregation.HitsAggregationResult; import com.yahoo.searchlib.expression.ExpressionNode; -import java.util.*; +import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; +import java.util.OptionalLong; +import java.util.Stack; +import java.util.TimeZone; + +import static java.util.stream.Collectors.toList; /** * This class implements the necessary logic to build a list of {@link Grouping} objects from an instance of {@link @@ -29,6 +42,8 @@ class RequestBuilder { private int tag = 0; private int defaultMaxHits = -1; private int defaultMaxGroups = -1; + private long globalMaxGroups = -1; + private long totalGroupsAndSummaries = -1; /** * Constructs a new instance of this class. @@ -128,6 +143,7 @@ class RequestBuilder { } } pruneRequests(); + validateGlobalMax(); } public RequestBuilder addContinuations(Iterable<Continuation> continuations) { @@ -144,6 +160,12 @@ class RequestBuilder { public RequestBuilder setDefaultMaxHits(int v) { this.defaultMaxHits = v; return this; } + public RequestBuilder setGlobalMaxGroups(long v) { this.globalMaxGroups = v; return this; } + + OptionalLong totalGroupsAndSummaries() { + return totalGroupsAndSummaries != -1 ? OptionalLong.of(totalGroupsAndSummaries) : OptionalLong.empty(); + } + private void processRequestNode(BuildFrame frame) { int level = frame.astNode.getLevel(); if (level > 2) { @@ -377,6 +399,55 @@ class RequestBuilder { } } + private void validateGlobalMax() { + this.totalGroupsAndSummaries = -1; + if (globalMaxGroups < 0) return; + int totalGroupsAndSummaries = 0; + for (Grouping grp : requestList) { + int levelMultiplier = 1; + for (GroupingLevel lvl : grp.getLevels()) { + totalGroupsAndSummaries += (levelMultiplier *= validateSummaryMax(lvl)); + var hars = hitsAggregationResult(lvl); + for (HitsAggregationResult har : hars) { + totalGroupsAndSummaries += (levelMultiplier * validateSummaryMax(har)); + } + } + } + if (totalGroupsAndSummaries > globalMaxGroups) + throw new IllegalInputException(String.format( + "The theoretical total number of groups and summaries in grouping query exceeds " + + "'grouping.globalMaxGroups' ( %d > %d ). " + + "Either restrict group/summary counts with max() or disable 'grouping.globalMaxGroups'. " + + "See https://docs.vespa.ai/en/grouping.html for details.", + totalGroupsAndSummaries, globalMaxGroups)); + this.totalGroupsAndSummaries = totalGroupsAndSummaries; + } + + private int validateSummaryMax(GroupingLevel lvl) { + int max = transform.getMax(lvl.getGroupPrototype().getTag()); + if (max <= 0) throw new IllegalInputException( + "Cannot return unbounded number of groups when 'grouping.globalMaxGroups' is enabled. " + + "Either restrict group count with max() or disable 'grouping.globalMaxGroups'. " + + "See https://docs.vespa.ai/en/grouping.html for details."); + return max; + } + + private int validateSummaryMax(HitsAggregationResult res) { + int max = transform.getMax(res.getTag()); + if (max <= 0) throw new IllegalInputException( + "Cannot return unbounded number of summaries when 'grouping.globalMaxGroups' is enabled. " + + "Either restrict summary count with max() or disable 'grouping.globalMaxGroups'. " + + "See https://docs.vespa.ai/en/grouping.html for details."); + return max; + } + + private List<HitsAggregationResult> hitsAggregationResult(GroupingLevel lvl) { + return lvl.getGroupPrototype().getAggregationResults().stream() + .filter(ar -> ar instanceof HitsAggregationResult) + .map(ar -> (HitsAggregationResult) ar) + .collect(toList()); + } + private static class BuildFrame { final Grouping grouping; diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java index b38e3070e28..f25b61c3265 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java @@ -1,10 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query.properties; +import com.yahoo.api.annotations.Beta; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.query.Properties; import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import java.util.Map; @@ -17,7 +19,7 @@ public final class DefaultProperties extends Properties { public static final CompoundName MAX_OFFSET = new CompoundName("maxOffset"); public static final CompoundName MAX_HITS = new CompoundName("maxHits"); - + @Beta public static final CompoundName GROUPING_GLOBAL_MAX_GROUPS = new CompoundName("grouping.globalMaxGroups"); public static final QueryProfileType argumentType = new QueryProfileType("DefaultProperties"); @@ -26,6 +28,7 @@ public final class DefaultProperties extends Properties { argumentType.addField(new FieldDescription(MAX_OFFSET.toString(), "integer")); argumentType.addField(new FieldDescription(MAX_HITS.toString(), "integer")); + argumentType.addField(new FieldDescription(GROUPING_GLOBAL_MAX_GROUPS.toString(), "long"), new QueryProfileTypeRegistry()); argumentType.freeze(); } @@ -36,6 +39,8 @@ public final class DefaultProperties extends Properties { return 1000; } else if (MAX_HITS.equals(name)) { return 400; + } else if (GROUPING_GLOBAL_MAX_GROUPS.equals(name)) { + return -1; // TODO Vespa 8: use default from Vespa 8 release notes } else { return super.get(name, context, substitution); } diff --git a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java index c1b9e74757b..ccf11d82541 100644 --- a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java @@ -1,13 +1,38 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.grouping.vespa; +import com.yahoo.processing.IllegalInputException; import com.yahoo.search.grouping.Continuation; -import com.yahoo.search.grouping.request.*; -import com.yahoo.searchlib.aggregation.*; -import com.yahoo.searchlib.expression.*; +import com.yahoo.search.grouping.request.AllOperation; +import com.yahoo.search.grouping.request.AttributeValue; +import com.yahoo.search.grouping.request.EachOperation; +import com.yahoo.search.grouping.request.GroupingOperation; +import com.yahoo.search.grouping.request.SummaryValue; +import com.yahoo.searchlib.aggregation.AggregationResult; +import com.yahoo.searchlib.aggregation.CountAggregationResult; +import com.yahoo.searchlib.aggregation.ExpressionCountAggregationResult; +import com.yahoo.searchlib.aggregation.Group; +import com.yahoo.searchlib.aggregation.Grouping; +import com.yahoo.searchlib.aggregation.GroupingLevel; +import com.yahoo.searchlib.aggregation.HitsAggregationResult; +import com.yahoo.searchlib.aggregation.SumAggregationResult; +import com.yahoo.searchlib.expression.AddFunctionNode; +import com.yahoo.searchlib.expression.AttributeMapLookupNode; +import com.yahoo.searchlib.expression.AttributeNode; +import com.yahoo.searchlib.expression.ConstantNode; +import com.yahoo.searchlib.expression.ExpressionNode; +import com.yahoo.searchlib.expression.StrCatFunctionNode; +import com.yahoo.searchlib.expression.StringResultNode; +import com.yahoo.searchlib.expression.TimeStampFunctionNode; +import com.yahoo.searchlib.expression.ToStringFunctionNode; +import org.assertj.core.api.Assertions; import org.junit.Test; -import java.util.*; +import java.util.Arrays; +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; +import java.util.TimeZone; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -742,6 +767,46 @@ public class RequestBuilderTestCase { } } + @Test + public void require_that_total_groups_and_summaries_calculation_is_correct() { + assertTotalGroupsAndSummaries(5, "all(group(a) max(5) each(output(count())))"); + assertTotalGroupsAndSummaries(5+5*7, "all(group(a) max(5) each(max(7) each(output(summary()))))"); + assertTotalGroupsAndSummaries(3+3*5+3*5*7+3*5*7*11, + "all( group(a) max(3) each(output(count())" + + " all(group(b) max(5) each(output(count())" + + " all(group(c) max(7) each(max(11) output(count())" + + " each(output(summary()))))))))"); + assertTotalGroupsAndSummaries(2*(3+3*5), + "all(" + + " all(group(a) max(3) each(output(count()) max(5) each(output(summary())))) " + + " all(group(b) max(3) each(output(count()) max(5) each(output(summary())))))"); + } + + @Test + public void require_that_unbounded_queries_fails_when_global_max_is_enabled() { + assertQueryFailsOnGlobalMax(4, "all(group(a) max(5) each(output(count())))", "5 > 4"); + assertQueryFailsOnGlobalMax(Long.MAX_VALUE, "all(group(a) each(output(count())))", "unbounded number of groups"); + assertQueryFailsOnGlobalMax(Long.MAX_VALUE, "all(group(a) max(5) each(each(output(summary()))))", "unbounded number of summaries"); + } + + private static void assertTotalGroupsAndSummaries(long expected, String query) { + RequestBuilder builder = new RequestBuilder(0) + .setRootOperation(GroupingOperation.fromString(query)).setGlobalMaxGroups(Long.MAX_VALUE); + builder.build(); + assertEquals(expected, builder.totalGroupsAndSummaries().orElseThrow()); + } + + private static void assertQueryFailsOnGlobalMax(long globalMax, String query, String errorSubstring) { + RequestBuilder builder = new RequestBuilder(0) + .setRootOperation(GroupingOperation.fromString(query)).setGlobalMaxGroups(globalMax); + try { + builder.build(); + fail(); + } catch (IllegalInputException e) { + Assertions.assertThat(e.getMessage()).contains(errorSubstring); + } + } + private static CompositeContinuation newComposite(EncodableContinuation... conts) { CompositeContinuation ret = new CompositeContinuation(); for (EncodableContinuation cont : conts) { |