summaryrefslogtreecommitdiffstats
path: root/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
diff options
context:
space:
mode:
Diffstat (limited to 'predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java')
-rw-r--r--predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java134
1 files changed, 134 insertions, 0 deletions
diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
new file mode 100644
index 00000000000..d19357cd8ab
--- /dev/null
+++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java
@@ -0,0 +1,134 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.search.predicate.index;
+
+import com.google.common.collect.MinMaxPriorityQueue;
+import com.gs.collections.api.tuple.primitive.ObjectLongPair;
+import com.gs.collections.impl.map.mutable.primitive.ObjectIntHashMap;
+import com.gs.collections.impl.map.mutable.primitive.ObjectLongHashMap;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Counts the number of posting lists per document id.
+ * Caches the most expensive posting list in a bit vector.
+ *
+ * @author bjorncs
+ */
+public class CachedPostingListCounter {
+ // Only use bit vector for counting if the documents covered is more than the threshold (relative to nDocuments)
+ private static final double THRESHOLD_USE_BIT_VECTOR = 1;
+
+ private final int nDocuments;
+ private final ObjectLongHashMap<int[]> frequency = new ObjectLongHashMap<>();
+ private final ObjectIntHashMap<int[]> postingListMapping;
+ private final int[] bitVector;
+
+ public CachedPostingListCounter(int nDocuments) {
+ this.nDocuments = nDocuments;
+ this.postingListMapping = new ObjectIntHashMap<>();
+ this.bitVector = new int[0];
+ }
+
+ private CachedPostingListCounter(ObjectIntHashMap<int[]> postingListMapping, int[] bitVector) {
+ this.nDocuments = bitVector.length;
+ this.postingListMapping = postingListMapping;
+ this.bitVector = bitVector;
+ }
+
+ public synchronized void registerUsage(List<PostingList> postingLists) {
+ for (PostingList postingList : postingLists) {
+ frequency.updateValue(postingList.getDocIds(), 0, v -> v + 1);
+ }
+ }
+
+ public void countPostingListsPerDocument(List<PostingList> postingLists, byte[] nPostingListsForDocument) {
+ Arrays.fill(nPostingListsForDocument, (byte) 0);
+ List<int[]> nonCachedPostingLists = new ArrayList<>(postingLists.size());
+ List<int[]> cachedPostingLists = new ArrayList<>(postingLists.size());
+ long nDocumentsCachedPostingLists = 0;
+ int postingListBitmap = 0;
+ for (PostingList postingList : postingLists) {
+ int[] docIds = postingList.getDocIds();
+ int index = postingListMapping.getIfAbsent(docIds, -1);
+ if (index >= 0) {
+ cachedPostingLists.add(docIds);
+ postingListBitmap |= (1 << index);
+ nDocumentsCachedPostingLists += docIds.length;
+ } else {
+ nonCachedPostingLists.add(docIds);
+ }
+ }
+ if (postingListBitmap != 0) {
+ if (nDocumentsCachedPostingLists > nDocuments * THRESHOLD_USE_BIT_VECTOR) {
+ countUsingBitVector(nPostingListsForDocument, postingListBitmap);
+ } else {
+ nonCachedPostingLists.addAll(cachedPostingLists);
+ }
+ }
+ if (!nonCachedPostingLists.isEmpty()) {
+ countUsingDocIdIteration(nPostingListsForDocument, nonCachedPostingLists);
+ }
+ }
+
+ private void countUsingBitVector(byte[] nPostingListsForDocument, int postingListBitmap) {
+ for (int docId = 0; docId < nDocuments; docId++) {
+ nPostingListsForDocument[docId] += Integer.bitCount(bitVector[docId] & postingListBitmap);
+ }
+ }
+
+ private static void countUsingDocIdIteration(byte[] nPostingListsForDocument, List<int[]> nonCachedPostingLists) {
+ for (int[] docIds : nonCachedPostingLists) {
+ for (int docId : docIds) {
+ ++nPostingListsForDocument[docId];
+ }
+ }
+ }
+
+ public CachedPostingListCounter rebuildCache() {
+ MinMaxPriorityQueue<Entry> mostExpensive = MinMaxPriorityQueue
+ .maximumSize(32).expectedSize(32).create();
+ synchronized (this) {
+ for (ObjectLongPair<int[]> p : frequency.keyValuesView()) {
+ mostExpensive.add(new Entry(p.getOne(), p.getTwo()));
+ }
+ }
+ ObjectIntHashMap<int[]> postingListMapping = new ObjectIntHashMap<>();
+ int[] bitVector = new int[nDocuments];
+ int length = mostExpensive.size();
+ for (int i = 0; i < length; i++) {
+ Entry e = mostExpensive.removeFirst();
+ int[] docIds = e.docIds;
+ postingListMapping.put(docIds, i);
+ for (int docId : docIds) {
+ bitVector[docId] |= (1 << i);
+ }
+ }
+ return new CachedPostingListCounter(postingListMapping, bitVector);
+ }
+
+ int[] getBitVector() {
+ return bitVector;
+ }
+
+ ObjectIntHashMap<int[]> getPostingListMapping() {
+ return postingListMapping;
+ }
+
+ private static class Entry implements Comparable<Entry> {
+ public final int[] docIds;
+ public final double cost;
+
+ private Entry(int[] docIds, long frequency) {
+ this.docIds = docIds;
+ this.cost = docIds.length * (double) frequency;
+ assert cost > 0;
+ }
+
+ @Override
+ public int compareTo(Entry o) {
+ return -Double.compare(cost, o.cost);
+ }
+ }
+}