diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2020-04-16 02:28:52 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-16 02:28:52 +0200 |
commit | b2519d490b2bb1d4e28e9c6c5e3ed72ee16b5469 (patch) | |
tree | c3a4fa47a30163d46bc01b9c97c539da9593cca8 /container-search | |
parent | bdb570a9e21410108bbb56f183bad1603c45c1fc (diff) | |
parent | cc436f402118300a5ffba223480cd63da2345008 (diff) |
Merge pull request #12918 from vespa-engine/balder/top-k-probability
Introduce top-k-probability and use it to fetch correct proper amount…
Diffstat (limited to 'container-search')
9 files changed, 129 insertions, 1 deletions
diff --git a/container-search/pom.xml b/container-search/pom.xml index 84ee5b2bc65..6fa32947869 100644 --- a/container-search/pom.xml +++ b/container-search/pom.xml @@ -132,6 +132,11 @@ <scope>compile</scope> </dependency> <dependency> + <groupId>org.apache.commons</groupId> + <artifactId>commons-math3</artifactId> + <scope>compile</scope> + </dependency> + <dependency> <groupId>javax.xml.bind</groupId> <artifactId>jaxb-api</artifactId> <scope>test</scope> diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 9b42ce03e6d..626cf087aca 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -51,6 +51,7 @@ public class Dispatcher extends AbstractComponent { public static final String DISPATCH = "dispatch"; private static final String INTERNAL = "internal"; private static final String PROTOBUF = "protobuf"; + private static final String TOP_K_PROBABILITY = "topKProbability"; private static final String INTERNAL_METRIC = "dispatch_internal"; @@ -59,6 +60,9 @@ public class Dispatcher extends AbstractComponent { /** If enabled, search queries will use protobuf rpc */ public static final CompoundName dispatchProtobuf = CompoundName.fromComponents(DISPATCH, PROTOBUF); + /** If set will control computation of how many hits will be fetched from each partition.*/ + public static final CompoundName topKProbability = CompoundName.fromComponents(DISPATCH, TOP_K_PROBABILITY); + /** A model of the search cluster this dispatches to */ private final SearchCluster searchCluster; private final ClusterMonitor clusterMonitor; @@ -80,6 +84,7 @@ public class Dispatcher extends AbstractComponent { argumentType.setBuiltin(true); argumentType.addField(new FieldDescription(INTERNAL, FieldType.booleanType)); argumentType.addField(new FieldDescription(PROTOBUF, FieldType.booleanType)); + argumentType.addField(new FieldDescription(TOP_K_PROBABILITY, FieldType.doubleType)); argumentType.freeze(); } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index cec3e94d551..e62848a7f9e 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -81,7 +81,12 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM int originalHits = query.getHits(); int originalOffset = query.getOffset(); - query.setHits(query.getHits() + query.getOffset()); + int neededHits = originalHits + originalOffset; + Double topkProbabilityOverrride = query.properties().getDouble(Dispatcher.topKProbability); + int q = (topkProbabilityOverrride != null) + ? searchCluster.estimateHitsToFetch(neededHits, invokers.size(), topkProbabilityOverrride) + : searchCluster.estimateHitsToFetch(neededHits, invokers.size()); + query.setHits(q); query.setOffset(0); for (SearchInvoker invoker : invokers) { @@ -321,4 +326,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM protected LinkedBlockingQueue<SearchInvoker> newQueue() { return new LinkedBlockingQueue<>(); } + + // For testing + Collection<SearchInvoker> invokers() { return invokers; } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java b/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java new file mode 100644 index 00000000000..8003d9c6744 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java @@ -0,0 +1,42 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import org.apache.commons.math3.distribution.TDistribution; + +/** + * Use StudentT distribution and estimate how many hits you need from each partition + * to to get the globally top-k documents with the desired probability + * @author baldersheim + */ +public class TopKEstimator { + private final TDistribution studentT; + private final double defaultP; + private final boolean estimate; + + private static boolean needEstimate(double p) { + return (0.0 < p) && (p < 1.0); + } + public TopKEstimator(double freedom, double defaultProbability) { + this.studentT = new TDistribution(null, freedom); + defaultP = defaultProbability; + estimate = needEstimate(defaultP); + } + double estimateExactK(double k, double n, double p) { + double variance = k * 1/n * (1 - 1/n); + double p_inverse = 1 - (1 - p)/n; + return k/n + studentT.inverseCumulativeProbability(p_inverse) * Math.sqrt(variance); + } + double estimateExactK(double k, double n) { + return estimateExactK(k, n, defaultP); + } + public int estimateK(int k, int n) { + return (estimate && n > 1) + ? (int)Math.ceil(estimateExactK(k, n, defaultP)) + : k; + } + public int estimateK(int k, int n, double p) { + return (needEstimate(p) && (n > 1)) + ? (int)Math.ceil(estimateExactK(k, n, p)) + : k; + } +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java index 27b4472e324..7dfc03fd2d7 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java @@ -10,6 +10,7 @@ import com.yahoo.net.HostName; import com.yahoo.prelude.Pong; import com.yahoo.search.cluster.ClusterMonitor; import com.yahoo.search.cluster.NodeManager; +import com.yahoo.search.dispatch.TopKEstimator; import com.yahoo.vespa.config.search.DispatchConfig; import java.util.LinkedHashMap; @@ -38,6 +39,7 @@ public class SearchCluster implements NodeManager<Node> { private final ImmutableList<Group> orderedGroups; private final VipStatus vipStatus; private final PingFactory pingFactory; + private final TopKEstimator hitEstimator; private long nextLogTime = 0; /** @@ -76,6 +78,7 @@ public class SearchCluster implements NodeManager<Node> { for (Node node : nodes) nodesByHostBuilder.put(node.hostname(), node); this.nodesByHost = nodesByHostBuilder.build(); + hitEstimator = new TopKEstimator(30.0, dispatchConfig.topKProbability()); this.localCorpusDispatchTarget = findLocalCorpusDispatchTarget(HostName.getLocalhost(), size, @@ -240,6 +243,13 @@ public class SearchCluster implements NodeManager<Node> { vipStatus.removeFromRotation(clusterId); } + public int estimateHitsToFetch(int wantedHits, int numPartitions) { + return hitEstimator.estimateK(wantedHits, numPartitions); + } + public int estimateHitsToFetch(int wantedHits, int numPartitions, double topKProbability) { + return hitEstimator.estimateK(wantedHits, numPartitions, topKProbability); + } + public boolean hasInformationAboutAllNodes() { return nodesByHost.values().stream().allMatch(node -> node.isWorking() != null); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java index 27685426cf8..e16f09a58ab 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -204,6 +204,33 @@ public class InterleavedSearchInvokerTest { private static final List<Double> A5Aux = Arrays.asList(-1.0,11.0,8.5,7.5,-7.0,3.0,2.0); private static final List<Double> B5Aux = Arrays.asList(9.0,8.0,-3.0,7.0,6.0,1.0, -1.0); + private void validateThatTopKProbabilityOverrideTakesEffect(Double topKProbability, int expectedK) throws IOException { + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); + query.setHits(8); + query.properties().set(Dispatcher.topKProbability, topKProbability); + SearchInvoker [] invokers = invoker.invokers().toArray(new SearchInvoker[0]); + Result result = invoker.search(query, null); + assertEquals(2, invokers.length); + assertEquals(expectedK, ((MockInvoker)invokers[0]).hitsRequested); + assertEquals(8, result.hits().size()); + assertEquals(11.0, result.hits().get(0).getRelevance().getScore(), DELTA); + assertEquals(9.0, result.hits().get(1).getRelevance().getScore(), DELTA); + assertEquals(8.5, result.hits().get(2).getRelevance().getScore(), DELTA); + assertEquals(8.0, result.hits().get(3).getRelevance().getScore(), DELTA); + assertEquals(7.5, result.hits().get(4).getRelevance().getScore(), DELTA); + assertEquals(7.0, result.hits().get(5).getRelevance().getScore(), DELTA); + assertEquals(6.0, result.hits().get(6).getRelevance().getScore(), DELTA); + assertEquals(3.0, result.hits().get(7).getRelevance().getScore(), DELTA); + assertEquals(0, result.getQuery().getOffset()); + assertEquals(8, result.getQuery().getHits()); + } + + @Test + public void requireThatTopKProbabilityOverrideTakesEffect() throws IOException { + validateThatTopKProbabilityOverrideTakesEffect(null, 8); + validateThatTopKProbabilityOverrideTakesEffect(0.8, 6); + } + @Test public void requireThatMergeOfConcreteHitsObeySorting() throws IOException { InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java index c5fbda7c2f5..c159293d7d9 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java @@ -17,6 +17,7 @@ class MockInvoker extends SearchInvoker { private final Coverage coverage; private Query query; private List<Hit> hits; + int hitsRequested; protected MockInvoker(int key, Coverage coverage) { super(Optional.of(new Node(key, "?", 0))); @@ -35,6 +36,7 @@ class MockInvoker extends SearchInvoker { @Override protected void sendSearchRequest(Query query) throws IOException { this.query = query; + hitsRequested = query.getHits(); } @Override diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java new file mode 100644 index 00000000000..c14e4f984f1 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java @@ -0,0 +1,28 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class TopKEstimatorTest { + @Test + public void requireHitsAreEstimatedAccordingToPartitionsAndProbability() { + TopKEstimator estimator = new TopKEstimator(30, 0.999); + assertEquals(91.97368471911312, estimator.estimateExactK(200, 3), 0.0); + assertEquals(92, estimator.estimateK(200, 3)); + assertEquals(37.96328109101396, estimator.estimateExactK(200, 10), 0.0); + assertEquals(38, estimator.estimateK(200, 10)); + assertEquals(23.815737601023095, estimator.estimateExactK(200, 20), 0.0); + assertEquals(24, estimator.estimateK(200, 20)); + + assertEquals(37.96328109101396, estimator.estimateExactK(200, 10, 0.999), 0.0); + assertEquals(38, estimator.estimateK(200, 10, 0.999)); + assertEquals(34.36212304875885, estimator.estimateExactK(200, 10, 0.99), 0.0); + assertEquals(35, estimator.estimateK(200, 10, 0.99)); + assertEquals(41.44244358524574, estimator.estimateExactK(200, 10, 0.9999), 0.0); + assertEquals(42, estimator.estimateK(200, 10, 0.9999)); + assertEquals(44.909040374464155, estimator.estimateExactK(200, 10, 0.99999), 0.0); + assertEquals(45, estimator.estimateK(200, 10, 0.99999)); + } +} diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java index ad281aeda7d..09024150a9a 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java @@ -8,6 +8,7 @@ import com.yahoo.net.HostName; import com.yahoo.prelude.Pong; import com.yahoo.search.cluster.ClusterMonitor; import com.yahoo.search.dispatch.MockSearchCluster; +import com.yahoo.search.dispatch.TopKEstimator; import com.yahoo.search.result.ErrorMessage; import org.junit.Test; |