diff options
Diffstat (limited to 'predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java')
-rw-r--r-- | predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java | 134 |
1 files changed, 134 insertions, 0 deletions
diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java new file mode 100644 index 00000000000..d19357cd8ab --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java @@ -0,0 +1,134 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.google.common.collect.MinMaxPriorityQueue; +import com.gs.collections.api.tuple.primitive.ObjectLongPair; +import com.gs.collections.impl.map.mutable.primitive.ObjectIntHashMap; +import com.gs.collections.impl.map.mutable.primitive.ObjectLongHashMap; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Counts the number of posting lists per document id. + * Caches the most expensive posting list in a bit vector. + * + * @author bjorncs + */ +public class CachedPostingListCounter { + // Only use bit vector for counting if the documents covered is more than the threshold (relative to nDocuments) + private static final double THRESHOLD_USE_BIT_VECTOR = 1; + + private final int nDocuments; + private final ObjectLongHashMap<int[]> frequency = new ObjectLongHashMap<>(); + private final ObjectIntHashMap<int[]> postingListMapping; + private final int[] bitVector; + + public CachedPostingListCounter(int nDocuments) { + this.nDocuments = nDocuments; + this.postingListMapping = new ObjectIntHashMap<>(); + this.bitVector = new int[0]; + } + + private CachedPostingListCounter(ObjectIntHashMap<int[]> postingListMapping, int[] bitVector) { + this.nDocuments = bitVector.length; + this.postingListMapping = postingListMapping; + this.bitVector = bitVector; + } + + public synchronized void registerUsage(List<PostingList> postingLists) { + for (PostingList postingList : postingLists) { + frequency.updateValue(postingList.getDocIds(), 0, v -> v + 1); + } + } + + public void countPostingListsPerDocument(List<PostingList> postingLists, byte[] nPostingListsForDocument) { + Arrays.fill(nPostingListsForDocument, (byte) 0); + List<int[]> nonCachedPostingLists = new ArrayList<>(postingLists.size()); + List<int[]> cachedPostingLists = new ArrayList<>(postingLists.size()); + long nDocumentsCachedPostingLists = 0; + int postingListBitmap = 0; + for (PostingList postingList : postingLists) { + int[] docIds = postingList.getDocIds(); + int index = postingListMapping.getIfAbsent(docIds, -1); + if (index >= 0) { + cachedPostingLists.add(docIds); + postingListBitmap |= (1 << index); + nDocumentsCachedPostingLists += docIds.length; + } else { + nonCachedPostingLists.add(docIds); + } + } + if (postingListBitmap != 0) { + if (nDocumentsCachedPostingLists > nDocuments * THRESHOLD_USE_BIT_VECTOR) { + countUsingBitVector(nPostingListsForDocument, postingListBitmap); + } else { + nonCachedPostingLists.addAll(cachedPostingLists); + } + } + if (!nonCachedPostingLists.isEmpty()) { + countUsingDocIdIteration(nPostingListsForDocument, nonCachedPostingLists); + } + } + + private void countUsingBitVector(byte[] nPostingListsForDocument, int postingListBitmap) { + for (int docId = 0; docId < nDocuments; docId++) { + nPostingListsForDocument[docId] += Integer.bitCount(bitVector[docId] & postingListBitmap); + } + } + + private static void countUsingDocIdIteration(byte[] nPostingListsForDocument, List<int[]> nonCachedPostingLists) { + for (int[] docIds : nonCachedPostingLists) { + for (int docId : docIds) { + ++nPostingListsForDocument[docId]; + } + } + } + + public CachedPostingListCounter rebuildCache() { + MinMaxPriorityQueue<Entry> mostExpensive = MinMaxPriorityQueue + .maximumSize(32).expectedSize(32).create(); + synchronized (this) { + for (ObjectLongPair<int[]> p : frequency.keyValuesView()) { + mostExpensive.add(new Entry(p.getOne(), p.getTwo())); + } + } + ObjectIntHashMap<int[]> postingListMapping = new ObjectIntHashMap<>(); + int[] bitVector = new int[nDocuments]; + int length = mostExpensive.size(); + for (int i = 0; i < length; i++) { + Entry e = mostExpensive.removeFirst(); + int[] docIds = e.docIds; + postingListMapping.put(docIds, i); + for (int docId : docIds) { + bitVector[docId] |= (1 << i); + } + } + return new CachedPostingListCounter(postingListMapping, bitVector); + } + + int[] getBitVector() { + return bitVector; + } + + ObjectIntHashMap<int[]> getPostingListMapping() { + return postingListMapping; + } + + private static class Entry implements Comparable<Entry> { + public final int[] docIds; + public final double cost; + + private Entry(int[] docIds, long frequency) { + this.docIds = docIds; + this.cost = docIds.length * (double) frequency; + assert cost > 0; + } + + @Override + public int compareTo(Entry o) { + return -Double.compare(cost, o.cost); + } + } +} |