From d07b20d7655d74f0460abc5dfceb7039bb8ec371 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 15 Apr 2020 13:32:56 +0000 Subject: Add query control of top-k-probability. --- .../java/com/yahoo/search/dispatch/Dispatcher.java | 5 ++++ .../search/dispatch/InterleavedSearchInvoker.java | 10 +++++++- .../com/yahoo/search/dispatch/TopKEstimator.java | 23 +++++++++++++----- .../dispatch/searchcluster/SearchCluster.java | 3 +++ .../dispatch/InterleavedSearchInvokerTest.java | 27 ++++++++++++++++++++++ .../com/yahoo/search/dispatch/MockInvoker.java | 2 ++ .../yahoo/search/dispatch/TopKEstimatorTest.java | 9 ++++++++ 7 files changed, 72 insertions(+), 7 deletions(-) (limited to 'container-search/src') diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java index 9c46d194fb3..3c26612e8e1 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/Dispatcher.java @@ -50,6 +50,7 @@ public class Dispatcher extends AbstractComponent { public static final String DISPATCH = "dispatch"; private static final String INTERNAL = "internal"; private static final String PROTOBUF = "protobuf"; + private static final String TOP_K_PROBABILITY = "top-k-probability"; private static final String INTERNAL_METRIC = "dispatch_internal"; @@ -58,6 +59,9 @@ public class Dispatcher extends AbstractComponent { /** If enabled, search queries will use protobuf rpc */ public static final CompoundName dispatchProtobuf = CompoundName.fromComponents(DISPATCH, PROTOBUF); + /** If set will control computation of how many hits will be fetched from each partition.*/ + public static final CompoundName topKProbability = CompoundName.fromComponents(DISPATCH, TOP_K_PROBABILITY); + /** A model of the search cluster this dispatches to */ private final SearchCluster searchCluster; private final ClusterMonitor clusterMonitor; @@ -79,6 +83,7 @@ public class Dispatcher extends AbstractComponent { argumentType.setBuiltin(true); argumentType.addField(new FieldDescription(INTERNAL, FieldType.booleanType)); argumentType.addField(new FieldDescription(PROTOBUF, FieldType.booleanType)); + argumentType.addField(new FieldDescription(TOP_K_PROBABILITY, FieldType.doubleType)); argumentType.freeze(); } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index bae1eb03e5f..e62848a7f9e 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -81,7 +81,12 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM int originalHits = query.getHits(); int originalOffset = query.getOffset(); - query.setHits(searchCluster.estimateHitsToFetch(query.getHits() + query.getOffset(), invokers.size())); + int neededHits = originalHits + originalOffset; + Double topkProbabilityOverrride = query.properties().getDouble(Dispatcher.topKProbability); + int q = (topkProbabilityOverrride != null) + ? searchCluster.estimateHitsToFetch(neededHits, invokers.size(), topkProbabilityOverrride) + : searchCluster.estimateHitsToFetch(neededHits, invokers.size()); + query.setHits(q); query.setOffset(0); for (SearchInvoker invoker : invokers) { @@ -321,4 +326,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM protected LinkedBlockingQueue newQueue() { return new LinkedBlockingQueue<>(); } + + // For testing + Collection invokers() { return invokers; } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java b/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java index 374f919e2bb..2a84481fdf3 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/TopKEstimator.java @@ -9,22 +9,33 @@ import org.apache.commons.math3.distribution.TDistribution; */ public class TopKEstimator { private final TDistribution studentT; - private final double p; + private final double defaultP; private final boolean estimate; - public TopKEstimator(double freedom, double wantedprobability) { + private static boolean needEstimate(double p) { + return (0.0 < p) && (p < 1.0); + } + public TopKEstimator(double freedom, double defaultProbability) { this.studentT = new TDistribution(null, freedom); - p = wantedprobability; - estimate = (0.0 < p) && (p < 1.0); + defaultP = defaultProbability; + estimate = needEstimate(defaultP); } - double estimateExactK(double k, double n) { + double estimateExactK(double k, double n, double p) { double variance = k * 1/n * (1 - 1/n); double p_inverse = 1 - (1 - p)/n; return k/n + studentT.inverseCumulativeProbability(p_inverse) * Math.sqrt(variance); } + double estimateExactK(double k, double n) { + return estimateExactK(k, n, defaultP); + } public int estimateK(int k, int n) { return (estimate && n > 1) - ? (int)Math.ceil(estimateExactK(k, n)) + ? (int)Math.ceil(estimateExactK(k, n, defaultP)) + : k; + } + public int estimateK(int k, int n, double p) { + return (needEstimate(p) && (n > 1)) + ? (int)Math.ceil(estimateExactK(k, n, p)) : k; } } diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java index 5acafb9e0a5..f31fd666ae9 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java @@ -246,6 +246,9 @@ public class SearchCluster implements NodeManager { public int estimateHitsToFetch(int wantedHits, int numPartitions) { return hitEstimator.estimateK(wantedHits, numPartitions); } + public int estimateHitsToFetch(int wantedHits, int numPartitions, double topKProbability) { + return hitEstimator.estimateK(wantedHits, numPartitions, topKProbability); + } public boolean hasInformationAboutAllNodes() { return nodesByHost.values().stream().allMatch(node -> node.isWorking() != null); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java index 27685426cf8..e16f09a58ab 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -204,6 +204,33 @@ public class InterleavedSearchInvokerTest { private static final List A5Aux = Arrays.asList(-1.0,11.0,8.5,7.5,-7.0,3.0,2.0); private static final List B5Aux = Arrays.asList(9.0,8.0,-3.0,7.0,6.0,1.0, -1.0); + private void validateThatTopKProbabilityOverrideTakesEffect(Double topKProbability, int expectedK) throws IOException { + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); + query.setHits(8); + query.properties().set(Dispatcher.topKProbability, topKProbability); + SearchInvoker [] invokers = invoker.invokers().toArray(new SearchInvoker[0]); + Result result = invoker.search(query, null); + assertEquals(2, invokers.length); + assertEquals(expectedK, ((MockInvoker)invokers[0]).hitsRequested); + assertEquals(8, result.hits().size()); + assertEquals(11.0, result.hits().get(0).getRelevance().getScore(), DELTA); + assertEquals(9.0, result.hits().get(1).getRelevance().getScore(), DELTA); + assertEquals(8.5, result.hits().get(2).getRelevance().getScore(), DELTA); + assertEquals(8.0, result.hits().get(3).getRelevance().getScore(), DELTA); + assertEquals(7.5, result.hits().get(4).getRelevance().getScore(), DELTA); + assertEquals(7.0, result.hits().get(5).getRelevance().getScore(), DELTA); + assertEquals(6.0, result.hits().get(6).getRelevance().getScore(), DELTA); + assertEquals(3.0, result.hits().get(7).getRelevance().getScore(), DELTA); + assertEquals(0, result.getQuery().getOffset()); + assertEquals(8, result.getQuery().getHits()); + } + + @Test + public void requireThatTopKProbabilityOverrideTakesEffect() throws IOException { + validateThatTopKProbabilityOverrideTakesEffect(null, 8); + validateThatTopKProbabilityOverrideTakesEffect(0.8, 6); + } + @Test public void requireThatMergeOfConcreteHitsObeySorting() throws IOException { InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java b/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java index c5fbda7c2f5..c159293d7d9 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/MockInvoker.java @@ -17,6 +17,7 @@ class MockInvoker extends SearchInvoker { private final Coverage coverage; private Query query; private List hits; + int hitsRequested; protected MockInvoker(int key, Coverage coverage) { super(Optional.of(new Node(key, "?", 0))); @@ -35,6 +36,7 @@ class MockInvoker extends SearchInvoker { @Override protected void sendSearchRequest(Query query) throws IOException { this.query = query; + hitsRequested = query.getHits(); } @Override diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java index 6ef28119c23..0d742f8c739 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/TopKEstimatorTest.java @@ -14,5 +14,14 @@ public class TopKEstimatorTest { assertEquals(38, estimator.estimateK(200, 10)); assertEquals(23.815737601023095, estimator.estimateExactK(200, 20), 0.0); assertEquals(24, estimator.estimateK(200, 20)); + + assertEquals(37.96328109101396, estimator.estimateExactK(200, 10, 0.999), 0.0); + assertEquals(38, estimator.estimateK(200, 10, 0.999)); + assertEquals(34.36212304875885, estimator.estimateExactK(200, 10, 0.99), 0.0); + assertEquals(35, estimator.estimateK(200, 10, 0.99)); + assertEquals(41.44244358524574, estimator.estimateExactK(200, 10, 0.9999), 0.0); + assertEquals(42, estimator.estimateK(200, 10, 0.9999)); + assertEquals(44.909040374464155, estimator.estimateExactK(200, 10, 0.99999), 0.0); + assertEquals(45, estimator.estimateK(200, 10, 0.99999)); } } -- cgit v1.2.3