diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2020-05-09 00:40:36 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2020-05-09 00:40:36 +0000 |
commit | 9fa23dc9f5530702d9f281b62e04ef8f4c3793c5 (patch) | |
tree | 570f8b46bccdb6584f4c9bedde939fe72ed743e0 /searchlib | |
parent | 1435d93b3d22f5c3819761246c564d6669cac54b (diff) |
Add protection to avoid going out of bounds when handling an empty bitvector.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/common/bitvector/bitvector_test.cpp | 16 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/common/bitvector.cpp | 51 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/common/bitvector.h | 4 |
3 files changed, 48 insertions, 23 deletions
diff --git a/searchlib/src/tests/common/bitvector/bitvector_test.cpp b/searchlib/src/tests/common/bitvector/bitvector_test.cpp index 499db526dbe..f53d15d6dc4 100644 --- a/searchlib/src/tests/common/bitvector/bitvector_test.cpp +++ b/searchlib/src/tests/common/bitvector/bitvector_test.cpp @@ -325,14 +325,19 @@ verifyThatLongerWithShorterWorksAsZeroPadded(uint32_t offset, uint32_t sz1, uint BitVector::UP bSmall = createEveryNthBitSet(3, 0, offset + sz1); BitVector::UP bLarger = createEveryNthBitSet(3, 0, offset + sz2); + BitVector::UP bEmpty = createEveryNthBitSet(3, 0, 0); bLarger->clearInterval(offset + sz1, offset + sz2); EXPECT_EQUAL(bSmall->countTrueBits(), bLarger->countTrueBits()); BitVector::UP aLarger2 = BitVector::create(*aLarger, aLarger->getStartIndex(), aLarger->size()); + BitVector::UP aLarger3 = BitVector::create(*aLarger, aLarger->getStartIndex(), aLarger->size()); EXPECT_TRUE(*aLarger == *aLarger2); + EXPECT_TRUE(*aLarger == *aLarger3); func(*aLarger, *bLarger); func(*aLarger2, *bSmall); + func(*aLarger3, *bEmpty); EXPECT_TRUE(*aLarger == *aLarger2); + //EXPECT_TRUE(*aLarger == *aLarger3); } TEST("requireThatAndWorks") { @@ -360,6 +365,17 @@ TEST("requireThatAndNotWorks") { } } +TEST("test that empty bitvectors does not crash") { + BitVector::UP empty = BitVector::create(0); + EXPECT_EQUAL(0u, empty->countTrueBits()); + EXPECT_EQUAL(0u, empty->countInterval(0, 100)); + empty->setInterval(0,17); + EXPECT_EQUAL(0u, empty->countInterval(0, 100)); + empty->clearInterval(0,17); + EXPECT_EQUAL(0u, empty->countInterval(0, 100)); + empty->notSelf(); + EXPECT_EQUAL(0u, empty->countInterval(0, 100)); +} TEST("requireThatNotWorks") { for (uint32_t offset(0); offset < 100; offset++) { diff --git a/searchlib/src/vespa/searchlib/common/bitvector.cpp b/searchlib/src/vespa/searchlib/common/bitvector.cpp index fc11f9ba032..2ffdefc81cd 100644 --- a/searchlib/src/vespa/searchlib/common/bitvector.cpp +++ b/searchlib/src/vespa/searchlib/common/bitvector.cpp @@ -91,9 +91,10 @@ BitVector::clearInterval(Index start, Index end) } void -BitVector::clearIntervalNoInvalidation(Index start, Index end) +BitVector::clearIntervalNoInvalidation(Index start_in, Index end) { - if (start >= end) { return; } + Index start = std::max(start_in, getStartIndex()); + if (start >= end || end == 0 || size() == 0) { return; } Index last = std::min(end, size()) - 1; Index startw = wordNum(start); @@ -109,9 +110,10 @@ BitVector::clearIntervalNoInvalidation(Index start, Index end) } void -BitVector::setInterval(Index start, Index end) +BitVector::setInterval(Index start_in, Index end) { - if (start >= end) { return; } + Index start = std::max(start_in, getStartIndex()); + if (start >= end || end == 0 || size() == 0) { return; } Index last = std::min(end, size()) - 1; Index startw = wordNum(start); @@ -131,14 +133,14 @@ BitVector::setInterval(Index start, Index end) BitVector::Index BitVector::count() const { - // Subtract by one to compensate for guard bit return countInterval(getStartIndex(), size()); } BitVector::Index -BitVector::countInterval(Index start, Index end) const +BitVector::countInterval(Index start_in, Index end) const { - if (start >= end) return 0; + Index start = std::max(start_in, getStartIndex()); + if (start >= end || end == 0 || size() == 0) { return 0; } Index last = std::min(end, size()) - 1; // Count bits in range [start..end> @@ -180,12 +182,14 @@ BitVector::orWith(const BitVector & right) verifyInclusiveStart(*this, right); if (right.size() < size()) { - ssize_t commonBytes = numActiveBytes(getStartIndex(), right.size()) - sizeof(Word); - if (commonBytes > 0) { - IAccelrated::getAccelrator().orBit(getActiveStart(), right.getWordIndex(getStartIndex()), commonBytes); + if (right.size() > 0) { + ssize_t commonBytes = numActiveBytes(getStartIndex(), right.size()) - sizeof(Word); + if (commonBytes > 0) { + IAccelrated::getAccelrator().orBit(getActiveStart(), right.getWordIndex(getStartIndex()), commonBytes); + } + Index last(right.size() - 1); + getWordIndex(last)[0] |= (right.getWordIndex(last)[0] & ~endBits(last)); } - Index last(right.size() - 1); - getWordIndex(last)[0] |= (right.getWordIndex(last)[0] & ~endBits(last)); } else { IAccelrated::getAccelrator().orBit(getActiveStart(), right.getWordIndex(getStartIndex()), getActiveBytes()); } @@ -196,11 +200,12 @@ BitVector::orWith(const BitVector & right) void BitVector::repairEnds() { - if (size() == 0) return; - Index start(getStartIndex()); - Index last(size() - 1); - getWordIndex(start)[0] &= ~startBits(start); - getWordIndex(last)[0] &= ~endBits(last); + if (size() != 0) { + Index start(getStartIndex()); + Index last(size() - 1); + getWordIndex(start)[0] &= ~startBits(start); + getWordIndex(last)[0] &= ~endBits(last); + } setGuardBit(); } @@ -227,12 +232,14 @@ BitVector::andNotWith(const BitVector& right) verifyInclusiveStart(*this, right); if (right.size() < size()) { - ssize_t commonBytes = numActiveBytes(getStartIndex(), right.size()) - sizeof(Word); - if (commonBytes > 0) { - IAccelrated::getAccelrator().andNotBit(getActiveStart(), right.getWordIndex(getStartIndex()), commonBytes); + if (right.size() > 0) { + ssize_t commonBytes = numActiveBytes(getStartIndex(), right.size()) - sizeof(Word); + if (commonBytes > 0) { + IAccelrated::getAccelrator().andNotBit(getActiveStart(), right.getWordIndex(getStartIndex()), commonBytes); + } + Index last(right.size() - 1); + getWordIndex(last)[0] &= ~(right.getWordIndex(last)[0] & ~endBits(last)); } - Index last(right.size() - 1); - getWordIndex(last)[0] &= ~(right.getWordIndex(last)[0] & ~endBits(last)); } else { IAccelrated::getAccelrator().andNotBit(getActiveStart(), right.getWordIndex(getStartIndex()), getActiveBytes()); } diff --git a/searchlib/src/vespa/searchlib/common/bitvector.h b/searchlib/src/vespa/searchlib/common/bitvector.h index 0ddd9d001aa..d842779d4f9 100644 --- a/searchlib/src/vespa/searchlib/common/bitvector.h +++ b/searchlib/src/vespa/searchlib/common/bitvector.h @@ -282,7 +282,9 @@ private: Index getActiveSize() const { return size() - getStartIndex(); } size_t getActiveBytes() const { return numActiveBytes(getStartIndex(), size()); } size_t numActiveWords() const { return numActiveWords(getStartIndex(), size()); } - static size_t numActiveWords(Index start, Index end) { return (numWords(end) - wordNum(start)); } + static size_t numActiveWords(Index start, Index end) { + return (end >= start) ? (numWords(end) - wordNum(start)) : 0; + } static Index invalidCount() { return std::numeric_limits<Index>::max(); } void setGuardBit() { setBit(size()); } void incNumBits() { |