summary | refs | log | tree | commit | diff | stats
path: root/container-search
diff options
context:
space:
mode:
author: Bjørn Christian Seime <bjorncs@verizonmedia.com>, 2022-03-24 17:43:27 +0100
committer: Bjørn Christian Seime <bjorncs@verizonmedia.com>, 2022-03-24 17:43:27 +0100
commit: f892050296c4ab4c231e20cec23285608cb4163c (patch)
tree: 1260edf7a4b4e37844c9fef5697ee355e1dd23d0 /container-search
parent: 51346c0606f7d859baf5b6aee1395e1926821d0a (diff)
Add 'grouping.globalMaxGroups' query parameter
Diffstat (limited to 'container-search')
-rw-r--r-- container-search/abi-spec.json 3
-rw-r--r-- container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java 8
-rw-r--r-- container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java 15
-rw-r--r-- container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java 1
-rw-r--r-- container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java 75
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java 7
-rw-r--r-- container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java 73
7 files changed, 172 insertions, 10 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 6249988a5ee..94baaf4ef52 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -2369,6 +2369,8 @@
"public void setDefaultMaxHits(int)",
"public java.util.OptionalInt defaultMaxGroups()",
"public void setDefaultMaxGroups(int)",
+ "public java.util.OptionalLong globalMaxGroups()",
+ "public void setGlobalMaxGroups(long)",
"public static com.yahoo.search.grouping.GroupingRequest newInstance(com.yahoo.search.Query)",
"public java.lang.String toString()"
],
@@ -6624,6 +6626,7 @@
"fields": [
"public static final com.yahoo.processing.request.CompoundName MAX_OFFSET",
"public static final com.yahoo.processing.request.CompoundName MAX_HITS",
+ "public static final com.yahoo.processing.request.CompoundName GROUPING_GLOBAL_MAX_GROUPS",
"public static final com.yahoo.search.query.profile.types.QueryProfileType argumentType"
]
},
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java
index 0de4e36eae5..ee78e41d0d8 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingQueryParser.java
@@ -5,15 +5,16 @@ import com.yahoo.api.annotations.Beta;
import com.yahoo.component.chain.dependencies.After;
import com.yahoo.component.chain.dependencies.Before;
import com.yahoo.component.chain.dependencies.Provides;
+import com.yahoo.processing.IllegalInputException;
import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.grouping.request.GroupingOperation;
import com.yahoo.search.query.Select;
+import com.yahoo.search.query.properties.DefaultProperties;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.search.searchchain.PhaseNames;
-import com.yahoo.processing.IllegalInputException;
import java.util.Collections;
import java.util.LinkedHashMap;
@@ -45,6 +46,10 @@ public class GroupingQueryParser extends Searcher {
@Override
public Result search(Query query, Execution execution) {
try {
+ if (query.getHttpRequest().getProperty(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS.toString()) != null) {
+ throw new IllegalInputException(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS + " must be specified in a query profile.");
+ }
+
String reqParam = query.properties().getString(PARAM_REQUEST);
if (reqParam == null) return execution.search(query);
@@ -57,6 +62,7 @@ public class GroupingQueryParser extends Searcher {
grpRequest.continuations().addAll(continuations);
grpRequest.setDefaultMaxGroups(query.properties().getInteger(PARAM_DEFAULT_MAX_GROUPS, -1));
grpRequest.setDefaultMaxHits(query.properties().getInteger(PARAM_DEFAULT_MAX_HITS, -1));
+ grpRequest.setGlobalMaxGroups(query.properties().getLong(DefaultProperties.GROUPING_GLOBAL_MAX_GROUPS));
}
return execution.search(query);
}
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java b/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java
index 0c163aaacae..9f5deb482db 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/GroupingRequest.java
@@ -15,6 +15,7 @@ import com.yahoo.search.result.Hit;
import java.util.ArrayList;
import java.util.List;
import java.util.OptionalInt;
+import java.util.OptionalLong;
import java.util.TimeZone;
/**
@@ -34,6 +35,7 @@ public class GroupingRequest {
private TimeZone timeZone;
private int defaultMaxHits = -1;
private int defaultMaxGroups = -1;
+ private long globalMaxGroups = -1;
private GroupingRequest(Select parent) {
this.parent = parent;
@@ -44,18 +46,20 @@ public class GroupingRequest {
GroupingOperation root,
TimeZone timeZone,
int defaultMaxHits,
- int defaultMaxGroups) {
+ int defaultMaxGroups,
+ long globalMaxGroups) {
this.parent = parent;
continuations.forEach(item -> this.continuations.add(item.copy()));
this.root = root != null ? root.copy(null) : null;
this.timeZone = timeZone;
this.defaultMaxHits = defaultMaxHits;
this.defaultMaxGroups = defaultMaxGroups;
+ this.globalMaxGroups = globalMaxGroups;
}
/**
 * Returns a deep copy of this.
 * All current settings of this request, including the global max groups value,
 * are handed to the copying constructor together with the given parent.
 *
 * @param parentOfCopy the parent passed on to the new instance
 */
public GroupingRequest copy(Select parentOfCopy) {
- return new GroupingRequest(parentOfCopy, continuations, root, timeZone, defaultMaxHits, defaultMaxGroups);
+ return new GroupingRequest(parentOfCopy, continuations, root, timeZone, defaultMaxHits, defaultMaxGroups, globalMaxGroups);
}
/**
@@ -153,6 +157,13 @@ public class GroupingRequest {
/** Stores the given value as the default max groups setting of this request. */
@Beta
public void setDefaultMaxGroups(int v) {
    defaultMaxGroups = v;
}
+ /**
+  * Returns the global max groups value of this request.
+  *
+  * @return the stored value, or empty when the stored value is negative
+  */
+ @Beta
+ public OptionalLong globalMaxGroups() {
+     if (globalMaxGroups < 0) {
+         return OptionalLong.empty();
+     }
+     return OptionalLong.of(globalMaxGroups);
+ }
+
+ /** Stores the given value as the global max groups setting; a negative value makes {@code globalMaxGroups()} return empty. */
+ @Beta
+ public void setGlobalMaxGroups(long v) {
+     globalMaxGroups = v;
+ }
+
/**
* Creates a new grouping request and adds it to the query.getSelect().getGrouping() list
*
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
index c09502110b1..c7000c0cdcf 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
@@ -152,6 +152,7 @@ public class GroupingExecutor extends Searcher {
builder.addContinuations(req.continuations());
req.defaultMaxGroups().ifPresent(builder::setDefaultMaxGroups);
req.defaultMaxHits().ifPresent(builder::setDefaultMaxHits);
+ req.globalMaxGroups().ifPresent(builder::setGlobalMaxGroups);
builder.build();
RequestContext ctx = new RequestContext(req, builder.getTransform());
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java
index 78452b36cd5..b013e87fb24 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/RequestBuilder.java
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.vespa;
+import com.yahoo.processing.IllegalInputException;
import com.yahoo.search.grouping.Continuation;
import com.yahoo.search.grouping.GroupingRequest;
import com.yahoo.search.grouping.request.AllOperation;
@@ -8,10 +9,22 @@ import com.yahoo.search.grouping.request.EachOperation;
import com.yahoo.search.grouping.request.GroupingExpression;
import com.yahoo.search.grouping.request.GroupingOperation;
import com.yahoo.search.grouping.request.NegFunction;
-import com.yahoo.searchlib.aggregation.*;
+import com.yahoo.searchlib.aggregation.AggregationResult;
+import com.yahoo.searchlib.aggregation.ExpressionCountAggregationResult;
+import com.yahoo.searchlib.aggregation.Group;
+import com.yahoo.searchlib.aggregation.Grouping;
+import com.yahoo.searchlib.aggregation.GroupingLevel;
+import com.yahoo.searchlib.aggregation.HitsAggregationResult;
import com.yahoo.searchlib.expression.ExpressionNode;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.OptionalLong;
+import java.util.Stack;
+import java.util.TimeZone;
+
+import static java.util.stream.Collectors.toList;
/**
* This class implements the necessary logic to build a list of {@link Grouping} objects from an instance of {@link
@@ -29,6 +42,8 @@ class RequestBuilder {
private int tag = 0;
private int defaultMaxHits = -1;
private int defaultMaxGroups = -1;
+ private long globalMaxGroups = -1;
+ private long totalGroupsAndSummaries = -1;
/**
* Constructs a new instance of this class.
@@ -128,6 +143,7 @@ class RequestBuilder {
}
}
pruneRequests();
+ validateGlobalMax();
}
public RequestBuilder addContinuations(Iterable<Continuation> continuations) {
@@ -144,6 +160,12 @@ class RequestBuilder {
/**
 * Stores the default max hits value used by this builder.
 *
 * @return this, to allow chaining
 */
public RequestBuilder setDefaultMaxHits(int v) {
    defaultMaxHits = v;
    return this;
}
+ /**
+  * Stores the global max groups limit; a negative value makes the global max validation a no-op.
+  *
+  * @return this, to allow chaining
+  */
+ public RequestBuilder setGlobalMaxGroups(long v) {
+     globalMaxGroups = v;
+     return this;
+ }
+
+ /**
+  * Returns the total computed by the global max validation.
+  *
+  * @return the computed total, or empty while the stored value is the -1 sentinel
+  */
+ OptionalLong totalGroupsAndSummaries() {
+     if (totalGroupsAndSummaries == -1) {
+         return OptionalLong.empty();
+     }
+     return OptionalLong.of(totalGroupsAndSummaries);
+ }
+
private void processRequestNode(BuildFrame frame) {
int level = frame.astNode.getLevel();
if (level > 2) {
@@ -377,6 +399,55 @@ class RequestBuilder {
}
}
+ /**
+  * Computes the theoretical total number of groups and summaries over all requests and
+  * rejects the query if that total exceeds {@code globalMaxGroups}.
+  *
+  * For every grouping level the running product of the level maxes is added to the total;
+  * every hits aggregation of a level adds that product multiplied by its own max.
+  * Does nothing (and leaves the recorded total unset) when {@code globalMaxGroups} is negative.
+  *
+  * The accumulation is done with saturating {@code long} arithmetic: with {@code int}
+  * accumulators the product of a few large max() values wraps around and could slip
+  * below the limit.
+  *
+  * @throws IllegalInputException if a level or hits aggregation has no positive max, or if the
+  *                               total exceeds the limit
+  */
+ private void validateGlobalMax() {
+     this.totalGroupsAndSummaries = -1;
+     if (globalMaxGroups < 0) return;
+     long total = 0;
+     for (Grouping grp : requestList) {
+         long levelMultiplier = 1;
+         for (GroupingLevel lvl : grp.getLevels()) {
+             levelMultiplier = saturatedMultiply(levelMultiplier, validateSummaryMax(lvl));
+             total = saturatedAdd(total, levelMultiplier);
+             for (HitsAggregationResult har : hitsAggregationResult(lvl)) {
+                 total = saturatedAdd(total, saturatedMultiply(levelMultiplier, validateSummaryMax(har)));
+             }
+         }
+     }
+     if (total > globalMaxGroups)
+         throw new IllegalInputException(String.format(
+                 "The theoretical total number of groups and summaries in grouping query exceeds " +
+                 "'grouping.globalMaxGroups' ( %d > %d ). " +
+                 "Either restrict group/summary counts with max() or disable 'grouping.globalMaxGroups'. " +
+                 "See https://docs.vespa.ai/en/grouping.html for details.",
+                 total, globalMaxGroups));
+     this.totalGroupsAndSummaries = total;
+ }
+
+ /** Returns a * b for non-negative operands, clamped to Long.MAX_VALUE instead of overflowing. */
+ private static long saturatedMultiply(long a, long b) {
+     if (a == 0 || b == 0) return 0;
+     return a > Long.MAX_VALUE / b ? Long.MAX_VALUE : a * b;
+ }
+
+ /** Returns a + b for non-negative operands, clamped to Long.MAX_VALUE instead of overflowing. */
+ private static long saturatedAdd(long a, long b) {
+     return a > Long.MAX_VALUE - b ? Long.MAX_VALUE : a + b;
+ }
+
+ /**
+  * Looks up the max registered for the tag of the level's group prototype and
+  * requires it to be positive.
+  *
+  * @param lvl the grouping level to check
+  * @return the positive max for the level
+  * @throws IllegalInputException if the max is zero or negative
+  */
+ private int validateSummaryMax(GroupingLevel lvl) {
+     int limit = transform.getMax(lvl.getGroupPrototype().getTag());
+     if (limit > 0) {
+         return limit;
+     }
+     throw new IllegalInputException(
+             "Cannot return unbounded number of groups when 'grouping.globalMaxGroups' is enabled. " +
+             "Either restrict group count with max() or disable 'grouping.globalMaxGroups'. " +
+             "See https://docs.vespa.ai/en/grouping.html for details.");
+ }
+
+ /**
+  * Looks up the max registered for the tag of the given hits aggregation and
+  * requires it to be positive.
+  *
+  * @param res the hits aggregation to check
+  * @return the positive max for the aggregation
+  * @throws IllegalInputException if the max is zero or negative
+  */
+ private int validateSummaryMax(HitsAggregationResult res) {
+     int limit = transform.getMax(res.getTag());
+     if (limit > 0) {
+         return limit;
+     }
+     throw new IllegalInputException(
+             "Cannot return unbounded number of summaries when 'grouping.globalMaxGroups' is enabled. " +
+             "Either restrict summary count with max() or disable 'grouping.globalMaxGroups'. " +
+             "See https://docs.vespa.ai/en/grouping.html for details.");
+ }
+
+ /**
+  * Collects, in order, the aggregation results of the level's group prototype
+  * that are {@link HitsAggregationResult} instances.
+  *
+  * @param lvl the grouping level to inspect
+  * @return a new list holding the matching aggregation results
+  */
+ private List<HitsAggregationResult> hitsAggregationResult(GroupingLevel lvl) {
+     List<HitsAggregationResult> hits = new ArrayList<>();
+     for (var candidate : lvl.getGroupPrototype().getAggregationResults()) {
+         if (candidate instanceof HitsAggregationResult) {
+             hits.add((HitsAggregationResult) candidate);
+         }
+     }
+     return hits;
+ }
+
private static class BuildFrame {
final Grouping grouping;
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java
index b38e3070e28..f25b61c3265 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/DefaultProperties.java
@@ -1,10 +1,12 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query.properties;
+import com.yahoo.api.annotations.Beta;
import com.yahoo.processing.request.CompoundName;
import com.yahoo.search.query.Properties;
import com.yahoo.search.query.profile.types.FieldDescription;
import com.yahoo.search.query.profile.types.QueryProfileType;
+import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry;
import java.util.Map;
@@ -17,7 +19,7 @@ public final class DefaultProperties extends Properties {
public static final CompoundName MAX_OFFSET = new CompoundName("maxOffset");
public static final CompoundName MAX_HITS = new CompoundName("maxHits");
-
+ @Beta public static final CompoundName GROUPING_GLOBAL_MAX_GROUPS = new CompoundName("grouping.globalMaxGroups");
public static final QueryProfileType argumentType = new QueryProfileType("DefaultProperties");
@@ -26,6 +28,7 @@ public final class DefaultProperties extends Properties {
argumentType.addField(new FieldDescription(MAX_OFFSET.toString(), "integer"));
argumentType.addField(new FieldDescription(MAX_HITS.toString(), "integer"));
+ argumentType.addField(new FieldDescription(GROUPING_GLOBAL_MAX_GROUPS.toString(), "long"), new QueryProfileTypeRegistry());
argumentType.freeze();
}
@@ -36,6 +39,8 @@ public final class DefaultProperties extends Properties {
return 1000;
} else if (MAX_HITS.equals(name)) {
return 400;
+ } else if (GROUPING_GLOBAL_MAX_GROUPS.equals(name)) {
+ return -1; // TODO Vespa 8: use default from Vespa 8 release notes
} else {
return super.get(name, context, substitution);
}
diff --git a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java
index c1b9e74757b..ccf11d82541 100644
--- a/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/grouping/vespa/RequestBuilderTestCase.java
@@ -1,13 +1,38 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.grouping.vespa;
+import com.yahoo.processing.IllegalInputException;
import com.yahoo.search.grouping.Continuation;
-import com.yahoo.search.grouping.request.*;
-import com.yahoo.searchlib.aggregation.*;
-import com.yahoo.searchlib.expression.*;
+import com.yahoo.search.grouping.request.AllOperation;
+import com.yahoo.search.grouping.request.AttributeValue;
+import com.yahoo.search.grouping.request.EachOperation;
+import com.yahoo.search.grouping.request.GroupingOperation;
+import com.yahoo.search.grouping.request.SummaryValue;
+import com.yahoo.searchlib.aggregation.AggregationResult;
+import com.yahoo.searchlib.aggregation.CountAggregationResult;
+import com.yahoo.searchlib.aggregation.ExpressionCountAggregationResult;
+import com.yahoo.searchlib.aggregation.Group;
+import com.yahoo.searchlib.aggregation.Grouping;
+import com.yahoo.searchlib.aggregation.GroupingLevel;
+import com.yahoo.searchlib.aggregation.HitsAggregationResult;
+import com.yahoo.searchlib.aggregation.SumAggregationResult;
+import com.yahoo.searchlib.expression.AddFunctionNode;
+import com.yahoo.searchlib.expression.AttributeMapLookupNode;
+import com.yahoo.searchlib.expression.AttributeNode;
+import com.yahoo.searchlib.expression.ConstantNode;
+import com.yahoo.searchlib.expression.ExpressionNode;
+import com.yahoo.searchlib.expression.StrCatFunctionNode;
+import com.yahoo.searchlib.expression.StringResultNode;
+import com.yahoo.searchlib.expression.TimeStampFunctionNode;
+import com.yahoo.searchlib.expression.ToStringFunctionNode;
+import org.assertj.core.api.Assertions;
import org.junit.Test;
-import java.util.*;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.TimeZone;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -742,6 +767,46 @@ public class RequestBuilderTestCase {
}
}
+ // Checks the total reported by RequestBuilder.totalGroupsAndSummaries() for a set of queries.
+ // Expected values are spelled out as sums of products of the max() values in each query,
+ // one term per nesting level, so the arithmetic can be read against the query text.
+ @Test
+ public void require_that_total_groups_and_summaries_calculation_is_correct() {
+ // Single level: max(5) only.
+ assertTotalGroupsAndSummaries(5, "all(group(a) max(5) each(output(count())))");
+ // One level with max(5) plus a nested max(7) summary output under each of them.
+ assertTotalGroupsAndSummaries(5+5*7, "all(group(a) max(5) each(max(7) each(output(summary()))))");
+ // Three nested levels (3, 5, 7) plus a max(11) summary output at the innermost level.
+ assertTotalGroupsAndSummaries(3+3*5+3*5*7+3*5*7*11,
+ "all( group(a) max(3) each(output(count())" +
+ " all(group(b) max(5) each(output(count())" +
+ " all(group(c) max(7) each(max(11) output(count())" +
+ " each(output(summary()))))))))");
+ // Two sibling groupings with identical limits: each contributes 3+3*5.
+ assertTotalGroupsAndSummaries(2*(3+3*5),
+ "all(" +
+ " all(group(a) max(3) each(output(count()) max(5) each(output(summary())))) " +
+ " all(group(b) max(3) each(output(count()) max(5) each(output(summary())))))");
+ }
+
+ // Checks that build() rejects queries once a global max is set. Three cases:
+ // a total above the limit, a group level without max(), and a summary output without max().
+ @Test
+ public void require_that_unbounded_queries_fails_when_global_max_is_enabled() {
+ // Bounded query whose total (5) exceeds the limit (4).
+ assertQueryFailsOnGlobalMax(4, "all(group(a) max(5) each(output(count())))", "5 > 4");
+ // No max() on the group level; rejected even with the largest possible limit.
+ assertQueryFailsOnGlobalMax(Long.MAX_VALUE, "all(group(a) each(output(count())))", "unbounded number of groups");
+ // No max() on the summary output; rejected even with the largest possible limit.
+ assertQueryFailsOnGlobalMax(Long.MAX_VALUE, "all(group(a) max(5) each(each(output(summary()))))", "unbounded number of summaries");
+ }
+
+ /**
+  * Builds the given grouping query with the global max set to Long.MAX_VALUE and
+  * asserts the total reported by the builder.
+  *
+  * @param expected the expected total of groups and summaries
+  * @param query the grouping expression to parse and build
+  */
+ private static void assertTotalGroupsAndSummaries(long expected, String query) {
+     RequestBuilder builder = new RequestBuilder(0);
+     builder.setRootOperation(GroupingOperation.fromString(query));
+     builder.setGlobalMaxGroups(Long.MAX_VALUE);
+     builder.build();
+     long actual = builder.totalGroupsAndSummaries().orElseThrow();
+     assertEquals(expected, actual);
+ }
+
+ /**
+  * Builds the given grouping query with the given global max and asserts that build()
+  * throws an IllegalInputException whose message contains the given substring.
+  *
+  * @param globalMax the global max groups limit to set on the builder
+  * @param query the grouping expression to parse and build
+  * @param errorSubstring text the exception message must contain
+  */
+ private static void assertQueryFailsOnGlobalMax(long globalMax, String query, String errorSubstring) {
+     RequestBuilder builder = new RequestBuilder(0);
+     builder.setRootOperation(GroupingOperation.fromString(query));
+     builder.setGlobalMaxGroups(globalMax);
+     try {
+         builder.build();
+     } catch (IllegalInputException e) {
+         Assertions.assertThat(e.getMessage()).contains(errorSubstring);
+         return;
+     }
+     // Reached only when build() completed without throwing.
+     fail();
+ }
+
private static CompositeContinuation newComposite(EncodableContinuation... conts) {
CompositeContinuation ret = new CompositeContinuation();
for (EncodableContinuation cont : conts) {