diff options
Diffstat (limited to 'include/llvm/ADT/SmallPtrSet.h')
-rw-r--r-- | include/llvm/ADT/SmallPtrSet.h | 57 |
1 files changed, 30 insertions, 27 deletions
diff --git a/include/llvm/ADT/SmallPtrSet.h b/include/llvm/ADT/SmallPtrSet.h index 49feb9da897a2..196ab6338047c 100644 --- a/include/llvm/ADT/SmallPtrSet.h +++ b/include/llvm/ADT/SmallPtrSet.h @@ -18,6 +18,7 @@ #include "llvm/Config/abi-breaking.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/PointerLikeTypeTraits.h" +#include "llvm/Support/type_traits.h" #include <cassert> #include <cstddef> #include <cstring> @@ -166,8 +167,8 @@ protected: const void *const *P = find_imp(Ptr); if (P == EndPointer()) return false; - - const void ** Loc = const_cast<const void **>(P); + + const void **Loc = const_cast<const void **>(P); assert(*Loc == Ptr && "broken find!"); *Loc = getTombstoneMarker(); NumTombstones++; @@ -193,7 +194,7 @@ protected: return Bucket; return EndPointer(); } - + private: bool isSmall() const { return CurArray == SmallArray; } @@ -259,11 +260,10 @@ protected: } #if LLVM_ENABLE_ABI_BREAKING_CHECKS void RetreatIfNotValid() { - --Bucket; - assert(Bucket <= End); + assert(Bucket >= End); while (Bucket != End && - (*Bucket == SmallPtrSetImplBase::getEmptyMarker() || - *Bucket == SmallPtrSetImplBase::getTombstoneMarker())) { + (Bucket[-1] == SmallPtrSetImplBase::getEmptyMarker() || + Bucket[-1] == SmallPtrSetImplBase::getTombstoneMarker())) { --Bucket; } } @@ -288,6 +288,12 @@ public: // Most methods provided by baseclass. const PtrTy operator*() const { +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + if (ReverseIterate<bool>::value) { + assert(Bucket > End); + return PtrTraits::getFromVoidPointer(const_cast<void *>(Bucket[-1])); + } +#endif assert(Bucket < End); return PtrTraits::getFromVoidPointer(const_cast<void*>(*Bucket)); } @@ -295,6 +301,7 @@ public: inline SmallPtrSetIterator& operator++() { // Preincrement #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) { + --Bucket; RetreatIfNotValid(); return *this; } @@ -343,7 +350,9 @@ struct RoundUpToPowerOfTwo { /// to avoid encoding a particular small size in the interface boundary. template <typename PtrType> class SmallPtrSetImpl : public SmallPtrSetImplBase { + using ConstPtrType = typename add_const_past_pointer<PtrType>::type; typedef PointerLikeTypeTraits<PtrType> PtrTraits; + typedef PointerLikeTypeTraits<ConstPtrType> ConstPtrTraits; protected: // Constructors that forward to the base. @@ -367,7 +376,7 @@ public: /// the element equal to Ptr. std::pair<iterator, bool> insert(PtrType Ptr) { auto p = insert_imp(PtrTraits::getAsVoidPointer(Ptr)); - return std::make_pair(iterator(p.first, EndPointer()), p.second); + return std::make_pair(makeIterator(p.first), p.second); } /// erase - If the set contains the specified pointer, remove it and return @@ -375,14 +384,10 @@ public: bool erase(PtrType Ptr) { return erase_imp(PtrTraits::getAsVoidPointer(Ptr)); } - /// count - Return 1 if the specified pointer is in the set, 0 otherwise. - size_type count(PtrType Ptr) const { - return find(Ptr) != endPtr() ? 1 : 0; - } - iterator find(PtrType Ptr) const { - auto *P = find_imp(PtrTraits::getAsVoidPointer(Ptr)); - return iterator(P, EndPointer()); + size_type count(ConstPtrType Ptr) const { return find(Ptr) != end() ? 1 : 0; } + iterator find(ConstPtrType Ptr) const { + return makeIterator(find_imp(ConstPtrTraits::getAsVoidPointer(Ptr))); } template <typename IterT> @@ -395,25 +400,23 @@ public: insert(IL.begin(), IL.end()); } - inline iterator begin() const { + iterator begin() const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) - return endPtr(); + return makeIterator(EndPointer() - 1); #endif - return iterator(CurArray, EndPointer()); + return makeIterator(CurArray); } - inline iterator end() const { + iterator end() const { return makeIterator(EndPointer()); } + +private: + /// Create an iterator that dereferences to same place as the given pointer. + iterator makeIterator(const void *const *P) const { #if LLVM_ENABLE_ABI_BREAKING_CHECKS if (ReverseIterate<bool>::value) - return iterator(CurArray, CurArray); + return iterator(P == EndPointer() ? CurArray : P + 1, CurArray); #endif - return endPtr(); - } - -private: - inline iterator endPtr() const { - const void *const *End = EndPointer(); - return iterator(End, End); + return iterator(P, EndPointer()); } }; |