diff options
-rw-r--r-- | searchlib/src/vespa/searchlib/attribute/attributeiterators.hpp | 87 |
1 files changed, 77 insertions, 10 deletions
diff --git a/searchlib/src/vespa/searchlib/attribute/attributeiterators.hpp b/searchlib/src/vespa/searchlib/attribute/attributeiterators.hpp index 17df4628606..d1226a07703 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributeiterators.hpp +++ b/searchlib/src/vespa/searchlib/attribute/attributeiterators.hpp @@ -110,12 +110,67 @@ AttributePostingListIteratorT<PL>::doSeek(uint32_t docId) } } +namespace { + +template <typename> struct is_tree_iterator; + +template <typename P> +struct is_tree_iterator<DocIdIterator<P>> { + static constexpr bool value = false; +}; + +template <typename P> +struct is_tree_iterator<DocIdMinMaxIterator<P>> { + static constexpr bool value = false; +}; + +template <typename KeyT, typename DataT, typename AggrT, typename CompareT, typename TraitsT> +struct is_tree_iterator<vespalib::btree::BTreeConstIterator<KeyT, DataT, AggrT, CompareT, TraitsT>> { + static constexpr bool value = true; +}; + +template <typename PL> +inline constexpr bool is_tree_iterator_v = is_tree_iterator<PL>::value; + +template <typename PL> +void get_hits_helper(BitVector& result, PL& iterator, uint32_t end_id) +{ + auto end_itr = iterator; + if (end_itr.valid() && end_itr.getKey() < end_id) { + end_itr.seek(end_id); + } + iterator.foreach_key_range(end_itr, [&](uint32_t key) { result.setBit(key); }); + iterator = end_itr; +} + +template <typename PL> +void or_hits_helper(BitVector& result, PL& iterator, uint32_t end_id) +{ + auto end_itr = iterator; + if (end_itr.valid() && end_itr.getKey() < end_id) { + end_itr.seek(end_id); + } + iterator.foreach_key_range(end_itr, [&](uint32_t key) + { + if (!result.testBit(key)) { + result.setBit(key); + } + }); + iterator = end_itr; +} + +} + template <typename PL> std::unique_ptr<BitVector> AttributePostingListIteratorT<PL>::get_hits(uint32_t begin_id) { BitVector::UP result(BitVector::create(begin_id, getEndId())); - for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { - result->setBit(_iterator.getKey()); + if constexpr (is_tree_iterator_v<PL>) { + get_hits_helper(*result, _iterator, getEndId()); + } else { + for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { + result->setBit(_iterator.getKey()); + } } result->invalidateCachedCount(); return result; @@ -125,9 +180,13 @@ template <typename PL> void AttributePostingListIteratorT<PL>::or_hits_into(BitVector & result, uint32_t begin_id) { (void) begin_id; - for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { - if ( ! result.testBit(_iterator.getKey()) ) { - result.setBit(_iterator.getKey()); + if constexpr (is_tree_iterator_v<PL>) { + or_hits_helper(result, _iterator, getEndId()); + } else { + for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { + if ( ! result.testBit(_iterator.getKey()) ) { + result.setBit(_iterator.getKey()); + } } } result.invalidateCachedCount(); @@ -143,8 +202,12 @@ template <typename PL> std::unique_ptr<BitVector> FilterAttributePostingListIteratorT<PL>::get_hits(uint32_t begin_id) { BitVector::UP result(BitVector::create(begin_id, getEndId())); - for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { - result->setBit(_iterator.getKey()); + if constexpr (is_tree_iterator_v<PL>) { + get_hits_helper(*result, _iterator, getEndId()); + } else { + for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { + result->setBit(_iterator.getKey()); + } } result->invalidateCachedCount(); return result; @@ -154,9 +217,13 @@ template <typename PL> void FilterAttributePostingListIteratorT<PL>::or_hits_into(BitVector & result, uint32_t begin_id) { (void) begin_id; - for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { - if ( ! result.testBit(_iterator.getKey()) ) { - result.setBit(_iterator.getKey()); + if constexpr (is_tree_iterator_v<PL>) { + or_hits_helper(result, _iterator, getEndId()); + } else { + for (; _iterator.valid() && _iterator.getKey() < getEndId(); ++_iterator) { + if ( ! result.testBit(_iterator.getKey()) ) { + result.setBit(_iterator.getKey()); + } } } result.invalidateCachedCount(); |