| author | Dimitry Andric <dim@FreeBSD.org> | 2015-01-18 16:17:27 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2015-01-18 16:17:27 +0000 |
| commit | 67c32a98315f785a9ec9d531c1f571a0196c7463 (patch) | |
| tree | 4abb9cbeecc7901726dd0b4a37369596c852e9ef /lib/Analysis | |
| parent | 9f61947910e6ab40de38e6b4034751ef1513200f (diff) | |
Diffstat (limited to 'lib/Analysis')
51 files changed, 5854 insertions, 1841 deletions
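The largest change in this import is a mechanical API migration: everywhere a `Location` used to carry only the `!tbaa` metadata node, it now carries an `AAMDNodes` bundle (TBAA plus the scoped `alias.scope`/`noalias` tags) gathered with `Instruction::getAAMetadata`. A minimal sketch of the new pattern as it recurs throughout the hunks below (the `locationFor` helper name is mine, not part of the patch):

```cpp
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

// Old style, removed by this commit:
//   MDNode *TBAATag = LI->getMetadata(LLVMContext::MD_tbaa);
//   return AliasAnalysis::Location(Ptr, Size, TBAATag);
//
// New style: collect every kind of AA metadata in one bundle so that
// scoped-noalias tags travel with the query, not just TBAA.
static AliasAnalysis::Location locationFor(AliasAnalysis &AA,
                                           const LoadInst *LI) {
  AAMDNodes AATags;
  LI->getAAMetadata(AATags); // fills the TBAA, alias.scope and noalias slots
  return AliasAnalysis::Location(LI->getPointerOperand(),
                                 AA.getTypeStoreSize(LI->getType()),
                                 AATags);
}
```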
diff --git a/lib/Analysis/AliasAnalysis.cpp b/lib/Analysis/AliasAnalysis.cpp index 8b8106b950a0..6eea8177418a 100644 --- a/lib/Analysis/AliasAnalysis.cpp +++ b/lib/Analysis/AliasAnalysis.cpp @@ -196,12 +196,11 @@ AliasAnalysis::getModRefInfo(ImmutableCallSite CS1, ImmutableCallSite CS2) { if (!Arg->getType()->isPointerTy()) continue; ModRefResult ArgMask; - Location CS1Loc = - getArgLocation(CS1, (unsigned) std::distance(CS1.arg_begin(), I), - ArgMask); - // ArgMask indicates what CS1 might do to CS1Loc; if CS1 might Mod - // CS1Loc, then we care about either a Mod or a Ref by CS2. If CS1 - // might Ref, then we care only about a Mod by CS2. + Location CS1Loc = getArgLocation( + CS1, (unsigned)std::distance(CS1.arg_begin(), I), ArgMask); + // ArgMask indicates what CS1 might do to CS1Loc; if CS1 might Mod + // CS1Loc, then we care about either a Mod or a Ref by CS2. If CS1 + // might Ref, then we care only about a Mod by CS2. ModRefResult ArgR = getModRefInfo(CS2, CS1Loc); if (((ArgMask & Mod) != NoModRef && (ArgR & ModRef) != NoModRef) || ((ArgMask & Ref) != NoModRef && (ArgR & Mod) != NoModRef)) @@ -252,61 +251,73 @@ AliasAnalysis::getModRefBehavior(const Function *F) { //===----------------------------------------------------------------------===// AliasAnalysis::Location AliasAnalysis::getLocation(const LoadInst *LI) { + AAMDNodes AATags; + LI->getAAMetadata(AATags); + return Location(LI->getPointerOperand(), - getTypeStoreSize(LI->getType()), - LI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(LI->getType()), AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const StoreInst *SI) { + AAMDNodes AATags; + SI->getAAMetadata(AATags); + return Location(SI->getPointerOperand(), - getTypeStoreSize(SI->getValueOperand()->getType()), - SI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(SI->getValueOperand()->getType()), AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const VAArgInst *VI) { - return Location(VI->getPointerOperand(), - UnknownSize, - VI->getMetadata(LLVMContext::MD_tbaa)); + AAMDNodes AATags; + VI->getAAMetadata(AATags); + + return Location(VI->getPointerOperand(), UnknownSize, AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const AtomicCmpXchgInst *CXI) { + AAMDNodes AATags; + CXI->getAAMetadata(AATags); + return Location(CXI->getPointerOperand(), getTypeStoreSize(CXI->getCompareOperand()->getType()), - CXI->getMetadata(LLVMContext::MD_tbaa)); + AATags); } AliasAnalysis::Location AliasAnalysis::getLocation(const AtomicRMWInst *RMWI) { + AAMDNodes AATags; + RMWI->getAAMetadata(AATags); + return Location(RMWI->getPointerOperand(), - getTypeStoreSize(RMWI->getValOperand()->getType()), - RMWI->getMetadata(LLVMContext::MD_tbaa)); + getTypeStoreSize(RMWI->getValOperand()->getType()), AATags); } -AliasAnalysis::Location +AliasAnalysis::Location AliasAnalysis::getLocationForSource(const MemTransferInst *MTI) { uint64_t Size = UnknownSize; if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) Size = C->getValue().getZExtValue(); - // memcpy/memmove can have TBAA tags. For memcpy, they apply + // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. 
- MDNode *TBAATag = MTI->getMetadata(LLVMContext::MD_tbaa); + AAMDNodes AATags; + MTI->getAAMetadata(AATags); - return Location(MTI->getRawSource(), Size, TBAATag); + return Location(MTI->getRawSource(), Size, AATags); } -AliasAnalysis::Location +AliasAnalysis::Location AliasAnalysis::getLocationForDest(const MemIntrinsic *MTI) { uint64_t Size = UnknownSize; if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) Size = C->getValue().getZExtValue(); - // memcpy/memmove can have TBAA tags. For memcpy, they apply + // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. - MDNode *TBAATag = MTI->getMetadata(LLVMContext::MD_tbaa); - - return Location(MTI->getRawDest(), Size, TBAATag); + AAMDNodes AATags; + MTI->getAAMetadata(AATags); + + return Location(MTI->getRawDest(), Size, AATags); } @@ -428,7 +439,7 @@ AliasAnalysis::callCapturesBefore(const Instruction *I, // assume that the call could touch the pointer, even though it doesn't // escape. if (isNoAlias(AliasAnalysis::Location(*CI), - AliasAnalysis::Location(Object))) + AliasAnalysis::Location(Object))) continue; if (CS.doesNotAccessMemory(ArgNo)) continue; @@ -472,21 +483,22 @@ uint64_t AliasAnalysis::getTypeStoreSize(Type *Ty) { } /// canBasicBlockModify - Return true if it is possible for execution of the -/// specified basic block to modify the value pointed to by Ptr. +/// specified basic block to modify the location Loc. /// bool AliasAnalysis::canBasicBlockModify(const BasicBlock &BB, const Location &Loc) { - return canInstructionRangeModify(BB.front(), BB.back(), Loc); + return canInstructionRangeModRef(BB.front(), BB.back(), Loc, Mod); } -/// canInstructionRangeModify - Return true if it is possible for the execution -/// of the specified instructions to modify the value pointed to by Ptr. The -/// instructions to consider are all of the instructions in the range of [I1,I2] -/// INCLUSIVE. I1 and I2 must be in the same basic block. -/// -bool AliasAnalysis::canInstructionRangeModify(const Instruction &I1, +/// canInstructionRangeModRef - Return true if it is possible for the +/// execution of the specified instructions to mod\ref (according to the +/// mode) the location Loc. The instructions to consider are all +/// of the instructions in the range of [I1,I2] INCLUSIVE. +/// I1 and I2 must be in the same basic block. +bool AliasAnalysis::canInstructionRangeModRef(const Instruction &I1, const Instruction &I2, - const Location &Loc) { + const Location &Loc, + const ModRefResult Mode) { assert(I1.getParent() == I2.getParent() && "Instructions not in same basic block!"); BasicBlock::const_iterator I = &I1; @@ -494,7 +506,7 @@ bool AliasAnalysis::canInstructionRangeModify(const Instruction &I1, ++E; // Convert from inclusive to exclusive range. 
for (; I != E; ++I) // Check every instruction in range - if (getModRefInfo(I, Loc) & Mod) + if (getModRefInfo(I, Loc) & Mode) return true; return false; } @@ -545,4 +557,3 @@ bool llvm::isIdentifiedFunctionLocal(const Value *V) { return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); } - diff --git a/lib/Analysis/AliasAnalysisEvaluator.cpp b/lib/Analysis/AliasAnalysisEvaluator.cpp index d9fa5a500f36..fe4bd4cc2109 100644 --- a/lib/Analysis/AliasAnalysisEvaluator.cpp +++ b/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -43,7 +43,7 @@ static cl::opt<bool> PrintMod("print-mod", cl::ReallyHidden); static cl::opt<bool> PrintRef("print-ref", cl::ReallyHidden); static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); -static cl::opt<bool> EvalTBAA("evaluate-tbaa", cl::ReallyHidden); +static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden); namespace { class AAEval : public FunctionPass { @@ -153,9 +153,9 @@ bool AAEval::runOnFunction(Function &F) { for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I) { if (I->getType()->isPointerTy()) // Add all pointer instructions. Pointers.insert(&*I); - if (EvalTBAA && isa<LoadInst>(&*I)) + if (EvalAAMD && isa<LoadInst>(&*I)) Loads.insert(&*I); - if (EvalTBAA && isa<StoreInst>(&*I)) + if (EvalAAMD && isa<StoreInst>(&*I)) Stores.insert(&*I); Instruction &Inst = *I; if (CallSite CS = cast<Value>(&Inst)) { @@ -213,7 +213,7 @@ bool AAEval::runOnFunction(Function &F) { } } - if (EvalTBAA) { + if (EvalAAMD) { // iterate over all pairs of load, store for (SetVector<Value *>::iterator I1 = Loads.begin(), E = Loads.end(); I1 != E; ++I1) { diff --git a/lib/Analysis/AliasSetTracker.cpp b/lib/Analysis/AliasSetTracker.cpp index 75b6be06aa8e..45442b025629 100644 --- a/lib/Analysis/AliasSetTracker.cpp +++ b/lib/Analysis/AliasSetTracker.cpp @@ -47,10 +47,10 @@ void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { // If the pointers are not a must-alias pair, this set becomes a may alias. if (AA.alias(AliasAnalysis::Location(L->getValue(), L->getSize(), - L->getTBAAInfo()), + L->getAAInfo()), AliasAnalysis::Location(R->getValue(), R->getSize(), - R->getTBAAInfo())) + R->getAAInfo())) != AliasAnalysis::MustAlias) AliasTy = MayAlias; } @@ -97,7 +97,7 @@ void AliasSet::removeFromTracker(AliasSetTracker &AST) { } void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, - uint64_t Size, const MDNode *TBAAInfo, + uint64_t Size, const AAMDNodes &AAInfo, bool KnownMustAlias) { assert(!Entry.hasAliasSet() && "Entry already in set!"); @@ -107,17 +107,17 @@ void AliasSet::addPointer(AliasSetTracker &AST, PointerRec &Entry, AliasAnalysis &AA = AST.getAliasAnalysis(); AliasAnalysis::AliasResult Result = AA.alias(AliasAnalysis::Location(P->getValue(), P->getSize(), - P->getTBAAInfo()), - AliasAnalysis::Location(Entry.getValue(), Size, TBAAInfo)); + P->getAAInfo()), + AliasAnalysis::Location(Entry.getValue(), Size, AAInfo)); if (Result != AliasAnalysis::MustAlias) AliasTy = MayAlias; else // First entry of must alias must have maximum size! - P->updateSizeAndTBAAInfo(Size, TBAAInfo); + P->updateSizeAndAAInfo(Size, AAInfo); assert(Result != AliasAnalysis::NoAlias && "Cannot be part of must set!"); } Entry.setAliasSet(this); - Entry.updateSizeAndTBAAInfo(Size, TBAAInfo); + Entry.updateSizeAndAAInfo(Size, AAInfo); // Add it to the end of the list... 
assert(*PtrListEnd == nullptr && "End of list is not null?"); @@ -147,7 +147,7 @@ void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) { /// alias one of the members in the set. /// bool AliasSet::aliasesPointer(const Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo, + const AAMDNodes &AAInfo, AliasAnalysis &AA) const { if (AliasTy == MustAlias) { assert(UnknownInsts.empty() && "Illegal must alias set!"); @@ -158,23 +158,23 @@ bool AliasSet::aliasesPointer(const Value *Ptr, uint64_t Size, assert(SomePtr && "Empty must-alias set??"); return AA.alias(AliasAnalysis::Location(SomePtr->getValue(), SomePtr->getSize(), - SomePtr->getTBAAInfo()), - AliasAnalysis::Location(Ptr, Size, TBAAInfo)); + SomePtr->getAAInfo()), + AliasAnalysis::Location(Ptr, Size, AAInfo)); } // If this is a may-alias set, we have to check all of the pointers in the set // to be sure it doesn't alias the set... for (iterator I = begin(), E = end(); I != E; ++I) - if (AA.alias(AliasAnalysis::Location(Ptr, Size, TBAAInfo), + if (AA.alias(AliasAnalysis::Location(Ptr, Size, AAInfo), AliasAnalysis::Location(I.getPointer(), I.getSize(), - I.getTBAAInfo()))) + I.getAAInfo()))) return true; // Check the unknown instructions... if (!UnknownInsts.empty()) { for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) if (AA.getModRefInfo(UnknownInsts[i], - AliasAnalysis::Location(Ptr, Size, TBAAInfo)) != + AliasAnalysis::Location(Ptr, Size, AAInfo)) != AliasAnalysis::NoModRef) return true; } @@ -197,7 +197,7 @@ bool AliasSet::aliasesUnknownInst(Instruction *Inst, AliasAnalysis &AA) const { for (iterator I = begin(), E = end(); I != E; ++I) if (AA.getModRefInfo(Inst, AliasAnalysis::Location(I.getPointer(), I.getSize(), - I.getTBAAInfo())) != + I.getAAInfo())) != AliasAnalysis::NoModRef) return true; @@ -223,11 +223,11 @@ void AliasSetTracker::clear() { /// AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo) { + const AAMDNodes &AAInfo) { AliasSet *FoundSet = nullptr; for (iterator I = begin(), E = end(); I != E;) { iterator Cur = I++; - if (Cur->Forward || !Cur->aliasesPointer(Ptr, Size, TBAAInfo, AA)) continue; + if (Cur->Forward || !Cur->aliasesPointer(Ptr, Size, AAInfo, AA)) continue; if (!FoundSet) { // If this is the first alias set ptr can go into. FoundSet = Cur; // Remember it. @@ -243,14 +243,19 @@ AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, /// this alias set, false otherwise. This does not modify the AST object or /// alias sets. bool AliasSetTracker::containsPointer(Value *Ptr, uint64_t Size, - const MDNode *TBAAInfo) const { + const AAMDNodes &AAInfo) const { for (const_iterator I = begin(), E = end(); I != E; ++I) - if (!I->Forward && I->aliasesPointer(Ptr, Size, TBAAInfo, AA)) + if (!I->Forward && I->aliasesPointer(Ptr, Size, AAInfo, AA)) return true; return false; } - +bool AliasSetTracker::containsUnknown(Instruction *Inst) const { + for (const_iterator I = begin(), E = end(); I != E; ++I) + if (!I->Forward && I->aliasesUnknownInst(Inst, AA)) + return true; + return false; +} AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { AliasSet *FoundSet = nullptr; @@ -272,67 +277,75 @@ AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { /// getAliasSetForPointer - Return the alias set that the specified pointer /// lives in. 
AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, uint64_t Size, - const MDNode *TBAAInfo, + const AAMDNodes &AAInfo, bool *New) { AliasSet::PointerRec &Entry = getEntryFor(Pointer); // Check to see if the pointer is already known. if (Entry.hasAliasSet()) { - Entry.updateSizeAndTBAAInfo(Size, TBAAInfo); + Entry.updateSizeAndAAInfo(Size, AAInfo); // Return the set! return *Entry.getAliasSet(*this)->getForwardedTarget(*this); } - if (AliasSet *AS = findAliasSetForPointer(Pointer, Size, TBAAInfo)) { + if (AliasSet *AS = findAliasSetForPointer(Pointer, Size, AAInfo)) { // Add it to the alias set it aliases. - AS->addPointer(*this, Entry, Size, TBAAInfo); + AS->addPointer(*this, Entry, Size, AAInfo); return *AS; } if (New) *New = true; // Otherwise create a new alias set to hold the loaded pointer. AliasSets.push_back(new AliasSet()); - AliasSets.back().addPointer(*this, Entry, Size, TBAAInfo); + AliasSets.back().addPointer(*this, Entry, Size, AAInfo); return AliasSets.back(); } -bool AliasSetTracker::add(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { +bool AliasSetTracker::add(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { bool NewPtr; - addPointer(Ptr, Size, TBAAInfo, AliasSet::NoModRef, NewPtr); + addPointer(Ptr, Size, AAInfo, AliasSet::NoModRef, NewPtr); return NewPtr; } bool AliasSetTracker::add(LoadInst *LI) { if (LI->getOrdering() > Monotonic) return addUnknown(LI); + + AAMDNodes AAInfo; + LI->getAAMetadata(AAInfo); + AliasSet::AccessType ATy = AliasSet::Refs; bool NewPtr; AliasSet &AS = addPointer(LI->getOperand(0), AA.getTypeStoreSize(LI->getType()), - LI->getMetadata(LLVMContext::MD_tbaa), - ATy, NewPtr); + AAInfo, ATy, NewPtr); if (LI->isVolatile()) AS.setVolatile(); return NewPtr; } bool AliasSetTracker::add(StoreInst *SI) { if (SI->getOrdering() > Monotonic) return addUnknown(SI); + + AAMDNodes AAInfo; + SI->getAAMetadata(AAInfo); + AliasSet::AccessType ATy = AliasSet::Mods; bool NewPtr; Value *Val = SI->getOperand(0); AliasSet &AS = addPointer(SI->getOperand(1), AA.getTypeStoreSize(Val->getType()), - SI->getMetadata(LLVMContext::MD_tbaa), - ATy, NewPtr); + AAInfo, ATy, NewPtr); if (SI->isVolatile()) AS.setVolatile(); return NewPtr; } bool AliasSetTracker::add(VAArgInst *VAAI) { + AAMDNodes AAInfo; + VAAI->getAAMetadata(AAInfo); + bool NewPtr; addPointer(VAAI->getOperand(0), AliasAnalysis::UnknownSize, - VAAI->getMetadata(LLVMContext::MD_tbaa), - AliasSet::ModRef, NewPtr); + AAInfo, AliasSet::ModRef, NewPtr); return NewPtr; } @@ -390,7 +403,7 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { bool X; for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { AliasSet &NewAS = addPointer(ASI.getPointer(), ASI.getSize(), - ASI.getTBAAInfo(), + ASI.getAAInfo(), (AliasSet::AccessType)AS.AccessTy, X); if (AS.isVolatile()) NewAS.setVolatile(); } @@ -429,8 +442,8 @@ void AliasSetTracker::remove(AliasSet &AS) { } bool -AliasSetTracker::remove(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { - AliasSet *AS = findAliasSetForPointer(Ptr, Size, TBAAInfo); +AliasSetTracker::remove(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { + AliasSet *AS = findAliasSetForPointer(Ptr, Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -438,8 +451,11 @@ AliasSetTracker::remove(Value *Ptr, uint64_t Size, const MDNode *TBAAInfo) { bool AliasSetTracker::remove(LoadInst *LI) { uint64_t Size = AA.getTypeStoreSize(LI->getType()); - const MDNode *TBAAInfo = LI->getMetadata(LLVMContext::MD_tbaa); - AliasSet *AS = 
findAliasSetForPointer(LI->getOperand(0), Size, TBAAInfo); + + AAMDNodes AAInfo; + LI->getAAMetadata(AAInfo); + + AliasSet *AS = findAliasSetForPointer(LI->getOperand(0), Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -447,17 +463,22 @@ bool AliasSetTracker::remove(LoadInst *LI) { bool AliasSetTracker::remove(StoreInst *SI) { uint64_t Size = AA.getTypeStoreSize(SI->getOperand(0)->getType()); - const MDNode *TBAAInfo = SI->getMetadata(LLVMContext::MD_tbaa); - AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size, TBAAInfo); + + AAMDNodes AAInfo; + SI->getAAMetadata(AAInfo); + + AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size, AAInfo); if (!AS) return false; remove(*AS); return true; } bool AliasSetTracker::remove(VAArgInst *VAAI) { + AAMDNodes AAInfo; + VAAI->getAAMetadata(AAInfo); + AliasSet *AS = findAliasSetForPointer(VAAI->getOperand(0), - AliasAnalysis::UnknownSize, - VAAI->getMetadata(LLVMContext::MD_tbaa)); + AliasAnalysis::UnknownSize, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -546,7 +567,7 @@ void AliasSetTracker::copyValue(Value *From, Value *To) { I = PointerMap.find_as(From); AliasSet *AS = I->second->getAliasSet(*this); AS->addPointer(*this, Entry, I->second->getSize(), - I->second->getTBAAInfo(), + I->second->getAAInfo(), true); } diff --git a/lib/Analysis/Analysis.cpp b/lib/Analysis/Analysis.cpp index c58012f3a8c3..f64bf0ec2c47 100644 --- a/lib/Analysis/Analysis.cpp +++ b/lib/Analysis/Analysis.cpp @@ -34,6 +34,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeCFGPrinterPass(Registry); initializeCFGOnlyViewerPass(Registry); initializeCFGOnlyPrinterPass(Registry); + initializeCFLAliasAnalysisPass(Registry); initializeDependenceAnalysisPass(Registry); initializeDelinearizationPass(Registry); initializeDominanceFrontierPass(Registry); @@ -66,6 +67,7 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeScalarEvolutionAliasAnalysisPass(Registry); initializeTargetTransformInfoAnalysisGroup(Registry); initializeTypeBasedAliasAnalysisPass(Registry); + initializeScopedNoAliasAAPass(Registry); } void LLVMInitializeAnalysis(LLVMPassRegistryRef R) { diff --git a/lib/Analysis/AssumptionCache.cpp b/lib/Analysis/AssumptionCache.cpp new file mode 100644 index 000000000000..da5ba18fc43b --- /dev/null +++ b/lib/Analysis/AssumptionCache.cpp @@ -0,0 +1,125 @@ +//===- AssumptionCache.cpp - Cache finding @llvm.assume calls -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that keeps track of @llvm.assume intrinsics in +// the functions of a module. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +using namespace llvm; +using namespace llvm::PatternMatch; + +void AssumptionCache::scanFunction() { + assert(!Scanned && "Tried to scan the function twice!"); + assert(AssumeHandles.empty() && "Already have assumes when scanning!"); + + // Go through all instructions in all blocks, add all calls to @llvm.assume + // to this cache. 
+ for (BasicBlock &B : F) + for (Instruction &II : B) + if (match(&II, m_Intrinsic<Intrinsic::assume>())) + AssumeHandles.push_back(&II); + + // Mark the scan as complete. + Scanned = true; +} + +void AssumptionCache::registerAssumption(CallInst *CI) { + assert(match(CI, m_Intrinsic<Intrinsic::assume>()) && + "Registered call does not call @llvm.assume"); + + // If we haven't scanned the function yet, just drop this assumption. It will + // be found when we scan later. + if (!Scanned) + return; + + AssumeHandles.push_back(CI); + +#ifndef NDEBUG + assert(CI->getParent() && + "Cannot register @llvm.assume call not in a basic block"); + assert(&F == CI->getParent()->getParent() && + "Cannot register @llvm.assume call not in this function"); + + // We expect the number of assumptions to be small, so in an asserts build + // check that we don't accumulate duplicates and that all assumptions point + // to the same function. + SmallPtrSet<Value *, 16> AssumptionSet; + for (auto &VH : AssumeHandles) { + if (!VH) + continue; + + assert(&F == cast<Instruction>(VH)->getParent()->getParent() && + "Cached assumption not inside this function!"); + assert(match(cast<CallInst>(VH), m_Intrinsic<Intrinsic::assume>()) && + "Cached something other than a call to @llvm.assume!"); + assert(AssumptionSet.insert(VH).second && + "Cache contains multiple copies of a call!"); + } +#endif +} + +void AssumptionCacheTracker::FunctionCallbackVH::deleted() { + auto I = ACT->AssumptionCaches.find_as(cast<Function>(getValPtr())); + if (I != ACT->AssumptionCaches.end()) + ACT->AssumptionCaches.erase(I); + // 'this' now dangles! +} + +AssumptionCache &AssumptionCacheTracker::getAssumptionCache(Function &F) { + // We probe the function map twice to try and avoid creating a value handle + // around the function in common cases. This makes insertion a bit slower, + // but if we have to insert we're going to scan the whole function so that + // shouldn't matter. + auto I = AssumptionCaches.find_as(&F); + if (I != AssumptionCaches.end()) + return *I->second; + + // Ok, build a new cache by scanning the function, insert it and the value + // handle into our map, and return the newly populated cache. 
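A hedged sketch of how a client consumes this cache (assumed usage patterned on the pass plumbing in this file; `F` and the surrounding pass are illustrative, not code from this commit):

```cpp
// Inside some FunctionPass::runOnFunction(Function &F) whose
// getAnalysisUsage() declared AU.addRequired<AssumptionCacheTracker>():
AssumptionCache &AC =
    getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
for (auto &VH : AC.assumptions()) {
  if (!VH)
    continue; // the value handle went null when its call was deleted
  CallInst *Assume = cast<CallInst>(VH);
  // Each Assume is a live call to @llvm.assume in F, found once at scan
  // time rather than by rescanning the function on every query.
  (void)Assume;
}
```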
+ auto IP = AssumptionCaches.insert(std::make_pair( + FunctionCallbackVH(&F, this), llvm::make_unique<AssumptionCache>(F))); + assert(IP.second && "Scanning function already in the map?"); + return *IP.first->second; +} + +void AssumptionCacheTracker::verifyAnalysis() const { +#ifndef NDEBUG + SmallPtrSet<const CallInst *, 4> AssumptionSet; + for (const auto &I : AssumptionCaches) { + for (auto &VH : I.second->assumptions()) + if (VH) + AssumptionSet.insert(cast<CallInst>(VH)); + + for (const BasicBlock &B : cast<Function>(*I.first)) + for (const Instruction &II : B) + if (match(&II, m_Intrinsic<Intrinsic::assume>())) + assert(AssumptionSet.count(cast<CallInst>(&II)) && + "Assumption in scanned function not in cache"); + } +#endif +} + +AssumptionCacheTracker::AssumptionCacheTracker() : ImmutablePass(ID) { + initializeAssumptionCacheTrackerPass(*PassRegistry::getPassRegistry()); +} + +AssumptionCacheTracker::~AssumptionCacheTracker() {} + +INITIALIZE_PASS(AssumptionCacheTracker, "assumption-cache-tracker", + "Assumption Cache Tracker", false, true) +char AssumptionCacheTracker::ID = 0; diff --git a/lib/Analysis/BasicAliasAnalysis.cpp b/lib/Analysis/BasicAliasAnalysis.cpp index 38ec52d6b9ac..a9efc5a9f734 100644 --- a/lib/Analysis/BasicAliasAnalysis.cpp +++ b/lib/Analysis/BasicAliasAnalysis.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -194,7 +195,8 @@ namespace { /// represented in the result. static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, ExtensionKind &Extension, - const DataLayout &DL, unsigned Depth) { + const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, DominatorTree *DT) { assert(V->getType()->isIntegerTy() && "Not an integer value"); // Limit our recursion depth. @@ -204,6 +206,14 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, return V; } + if (ConstantInt *Const = dyn_cast<ConstantInt>(V)) { + // if it's a constant, just convert it to an offset + // and remove the variable. + Offset += Const->getValue(); + assert(Scale == 0 && "Constant values don't have a scale"); + return V; + } + if (BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) { if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) { switch (BOp->getOpcode()) { @@ -211,23 +221,24 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, case Instruction::Or: // X|C == X+C if all the bits in C are unset in X. Otherwise we can't // analyze it. - if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL)) + if (!MaskedValueIsZero(BOp->getOperand(0), RHSC->getValue(), &DL, 0, AC, + BOp, DT)) break; // FALL THROUGH. 
case Instruction::Add: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth + 1, AC, DT); Offset += RHSC->getValue(); return V; case Instruction::Mul: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth + 1, AC, DT); Offset *= RHSC->getValue(); Scale *= RHSC->getValue(); return V; case Instruction::Shl: V = GetLinearExpression(BOp->getOperand(0), Scale, Offset, Extension, - DL, Depth+1); + DL, Depth + 1, AC, DT); Offset <<= RHSC->getValue().getLimitedValue(); Scale <<= RHSC->getValue().getLimitedValue(); return V; @@ -247,10 +258,13 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, Offset = Offset.trunc(SmallWidth); Extension = isa<SExtInst>(V) ? EK_SignExt : EK_ZeroExt; - Value *Result = GetLinearExpression(CastOp, Scale, Offset, Extension, - DL, Depth+1); + Value *Result = GetLinearExpression(CastOp, Scale, Offset, Extension, DL, + Depth + 1, AC, DT); Scale = Scale.zext(OldWidth); - Offset = Offset.zext(OldWidth); + + // We have to sign-extend even if Extension == EK_ZeroExt as we can't + // decompose a sign extension (i.e. zext(x - 1) != zext(x) - zext(-1)). + Offset = Offset.sext(OldWidth); return Result; } @@ -278,7 +292,8 @@ static Value *GetLinearExpression(Value *V, APInt &Scale, APInt &Offset, static const Value * DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, SmallVectorImpl<VariableGEPIndex> &VarIndices, - bool &MaxLookupReached, const DataLayout *DL) { + bool &MaxLookupReached, const DataLayout *DL, + AssumptionCache *AC, DominatorTree *DT) { // Limit recursion depth to limit compile time in crazy cases. unsigned MaxLookup = MaxLookupSearchDepth; MaxLookupReached = false; @@ -309,7 +324,10 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // If it's not a GEP, hand it off to SimplifyInstruction to see if it // can come up with something. This matches what GetUnderlyingObject does. if (const Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Get a DominatorTree and use it here. + // TODO: Get a DominatorTree and AssumptionCache and use them here + // (these are both now available in this function, but this should be + // updated when GetUnderlyingObject is updated). TLI should be + // provided also. if (const Value *Simplified = SimplifyInstruction(const_cast<Instruction *>(I), DL)) { V = Simplified; @@ -368,7 +386,7 @@ DecomposeGEPExpression(const Value *V, int64_t &BaseOffs, // Use GetLinearExpression to decompose the index into a C1*V+C2 form. APInt IndexScale(Width, 0), IndexOffset(Width, 0); Index = GetLinearExpression(Index, IndexScale, IndexOffset, Extension, - *DL, 0); + *DL, 0, AC, DT); // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. 
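The `Offset = Offset.sext(OldWidth)` change above is easy to read past, so here is a concrete instance of the inequality the new comment cites, worked in plain C++ (my example, not from the patch). The point is that `GetLinearExpression` computes `Scale*V + Offset` in the narrow type, where the arithmetic wraps, so the accumulated offset cannot simply be zero-extended afterwards:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // i8 -> i16 with x = 1, decomposing (x - 1) as 1*x + (-1).
  uint8_t x = 1;
  uint16_t truth = (uint16_t)(uint8_t)(x - 1);             // zext(x - 1) == 0
  uint16_t zext_off = (uint16_t)x + (uint16_t)(uint8_t)-1; // 1 + 255 == 256
  uint16_t sext_off =                                      // 1 + 0xFFFF == 0
      (uint16_t)((uint16_t)x + (uint16_t)(int16_t)(int8_t)-1);
  assert(truth != zext_off); // zero-extending the offset is wrong here
  assert(truth == sext_off); // sign-extending reproduces the wrapped value
  return 0;
}
```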
@@ -449,6 +467,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); } @@ -456,8 +475,8 @@ namespace { assert(AliasCache.empty() && "AliasCache must be cleared after use!"); assert(notDifferentParent(LocA.Ptr, LocB.Ptr) && "BasicAliasAnalysis doesn't support interprocedural queries."); - AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.TBAATag, - LocB.Ptr, LocB.Size, LocB.TBAATag); + AliasResult Alias = aliasCheck(LocA.Ptr, LocA.Size, LocA.AATags, + LocB.Ptr, LocB.Size, LocB.AATags); // AliasCache rarely has more than 1 or 2 elements, always use // shrink_and_clear so it quickly returns to the inline capacity of the // SmallDenseMap if it ever grows larger. @@ -471,10 +490,7 @@ namespace { const Location &Loc) override; ModRefResult getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) override { - // The AliasAnalysis base class has some smarts, lets use them. - return AliasAnalysis::getModRefInfo(CS1, CS2); - } + ImmutableCallSite CS2) override; /// pointsToConstantMemory - Chase pointers until we find a (constant /// global) or not. @@ -544,28 +560,28 @@ namespace { // aliasGEP - Provide a bunch of ad-hoc rules to disambiguate a GEP // instruction against another. AliasResult aliasGEP(const GEPOperator *V1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + const AAMDNodes &V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo, + const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2); // aliasPHI - Provide a bunch of ad-hoc rules to disambiguate a PHI // instruction against another. AliasResult aliasPHI(const PHINode *PN, uint64_t PNSize, - const MDNode *PNTBAAInfo, + const AAMDNodes &PNAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo); + const AAMDNodes &V2AAInfo); /// aliasSelect - Disambiguate a Select instruction against another value. 
AliasResult aliasSelect(const SelectInst *SI, uint64_t SISize, - const MDNode *SITBAAInfo, + const AAMDNodes &SIAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo); + const AAMDNodes &V2AAInfo); AliasResult aliasCheck(const Value *V1, uint64_t V1Size, - const MDNode *V1TBAATag, + AAMDNodes V1AATag, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAATag); + AAMDNodes V2AATag); }; } // End of anonymous namespace @@ -574,6 +590,7 @@ char BasicAliasAnalysis::ID = 0; INITIALIZE_AG_PASS_BEGIN(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", false, true, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_AG_PASS_END(BasicAliasAnalysis, AliasAnalysis, "basicaa", "Basic Alias Analysis (stateless AA impl)", @@ -596,7 +613,7 @@ BasicAliasAnalysis::pointsToConstantMemory(const Location &Loc, bool OrLocal) { Worklist.push_back(Loc.Ptr); do { const Value *V = GetUnderlyingObject(Worklist.pop_back_val(), DL); - if (!Visited.insert(V)) { + if (!Visited.insert(V).second) { Visited.clear(); return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); } @@ -788,6 +805,14 @@ BasicAliasAnalysis::getArgLocation(ImmutableCallSite CS, unsigned ArgIdx, return Loc; } +static bool isAssumeIntrinsic(ImmutableCallSite CS) { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction()); + if (II && II->getIntrinsicID() == Intrinsic::assume) + return true; + + return false; +} + /// getModRefInfo - Check to see if the specified callsite can clobber the /// specified memory object. Since we only look at local properties of this /// function, we really can't say much about this query. We do, however, use @@ -840,10 +865,29 @@ BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS, return NoModRef; } + // While the assume intrinsic is marked as arbitrarily writing so that + // proper control dependencies will be maintained, it never aliases any + // particular memory location. + if (isAssumeIntrinsic(CS)) + return NoModRef; + // The AliasAnalysis base class has some smarts, lets use them. return AliasAnalysis::getModRefInfo(CS, Loc); } +AliasAnalysis::ModRefResult +BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS1, + ImmutableCallSite CS2) { + // While the assume intrinsic is marked as arbitrarily writing so that + // proper control dependencies will be maintained, it never aliases any + // particular memory location. + if (isAssumeIntrinsic(CS1) || isAssumeIntrinsic(CS2)) + return NoModRef; + + // The AliasAnalysis base class has some smarts, lets use them. + return AliasAnalysis::getModRefInfo(CS1, CS2); +} + /// aliasGEP - Provide a bunch of ad-hoc rules to disambiguate a GEP instruction /// against another pointer. We know that V1 is a GEP, but we don't know /// anything about V2. UnderlyingV1 is GetUnderlyingObject(GEP1, DL), @@ -851,30 +895,50 @@ BasicAliasAnalysis::getModRefInfo(ImmutableCallSite CS, /// AliasAnalysis::AliasResult BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + const AAMDNodes &V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo, + const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2) { int64_t GEP1BaseOffset; bool GEP1MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP1VariableIndices; + // We have to get two AssumptionCaches here because GEP1 and V2 may be from + // different functions. + // FIXME: This really doesn't make any sense. 
We get a dominator tree below + // that can only refer to a single function. But this function (aliasGEP) is + // a method on an immutable pass that can be called when there *isn't* + // a single function. The old pass management layer makes this "work", but + // this isn't really a clean solution. + AssumptionCacheTracker &ACT = getAnalysis<AssumptionCacheTracker>(); + AssumptionCache *AC1 = nullptr, *AC2 = nullptr; + if (auto *GEP1I = dyn_cast<Instruction>(GEP1)) + AC1 = &ACT.getAssumptionCache( + const_cast<Function &>(*GEP1I->getParent()->getParent())); + if (auto *I2 = dyn_cast<Instruction>(V2)) + AC2 = &ACT.getAssumptionCache( + const_cast<Function &>(*I2->getParent()->getParent())); + + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DominatorTree *DT = DTWP ? &DTWP->getDomTree() : nullptr; + // If we have two gep instructions with must-alias or not-alias'ing base // pointers, figure out if the indexes to the GEP tell us anything about the // derived pointer. if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) { // Do the base pointers alias? - AliasResult BaseAlias = aliasCheck(UnderlyingV1, UnknownSize, nullptr, - UnderlyingV2, UnknownSize, nullptr); + AliasResult BaseAlias = aliasCheck(UnderlyingV1, UnknownSize, AAMDNodes(), + UnderlyingV2, UnknownSize, AAMDNodes()); // Check for geps of non-aliasing underlying pointers where the offsets are // identical. if ((BaseAlias == MayAlias) && V1Size == V2Size) { // Do the base pointers alias assuming type and size. AliasResult PreciseBaseAlias = aliasCheck(UnderlyingV1, V1Size, - V1TBAAInfo, UnderlyingV2, - V2Size, V2TBAAInfo); + V1AAInfo, UnderlyingV2, + V2Size, V2AAInfo); if (PreciseBaseAlias == NoAlias) { // See if the computed offset from the common pointer tells us about the // relation of the resulting pointer. @@ -882,11 +946,11 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, bool GEP2MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = - DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, + GEP2MaxLookupReached, DL, AC2, DT); const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, + GEP1MaxLookupReached, DL, AC1, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. if (GEP1BasePtr != UnderlyingV1 || GEP2BasePtr != UnderlyingV2) { @@ -914,15 +978,15 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // exactly, see if the computed offset from the common pointer tells us // about the relation of the resulting pointer. 
const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, + GEP1MaxLookupReached, DL, AC1, DT); int64_t GEP2BaseOffset; bool GEP2MaxLookupReached; SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; const Value *GEP2BasePtr = - DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL); + DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, + GEP2MaxLookupReached, DL, AC2, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. @@ -949,8 +1013,8 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (V1Size == UnknownSize && V2Size == UnknownSize) return MayAlias; - AliasResult R = aliasCheck(UnderlyingV1, UnknownSize, nullptr, - V2, V2Size, V2TBAAInfo); + AliasResult R = aliasCheck(UnderlyingV1, UnknownSize, AAMDNodes(), + V2, V2Size, V2AAInfo); if (R != MustAlias) // If V2 may alias GEP base pointer, conservatively returns MayAlias. // If V2 is known not to alias GEP base pointer, then the two values @@ -960,8 +1024,8 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, return R; const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL); + DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, + GEP1MaxLookupReached, DL, AC1, DT); // DecomposeGEPExpression and GetUnderlyingObject should return the // same result except when DecomposeGEPExpression has no DataLayout. @@ -1012,12 +1076,43 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, } } - // Try to distinguish something like &A[i][1] against &A[42][0]. - // Grab the least significant bit set in any of the scales. if (!GEP1VariableIndices.empty()) { uint64_t Modulo = 0; - for (unsigned i = 0, e = GEP1VariableIndices.size(); i != e; ++i) - Modulo |= (uint64_t)GEP1VariableIndices[i].Scale; + bool AllPositive = true; + for (unsigned i = 0, e = GEP1VariableIndices.size(); i != e; ++i) { + + // Try to distinguish something like &A[i][1] against &A[42][0]. + // Grab the least significant bit set in any of the scales. We + // don't need std::abs here (even if the scale's negative) as we'll + // be ^'ing Modulo with itself later. + Modulo |= (uint64_t) GEP1VariableIndices[i].Scale; + + if (AllPositive) { + // If the Value could change between cycles, then any reasoning about + // the Value this cycle may not hold in the next cycle. We'll just + // give up if we can't determine conditions that hold for every cycle: + const Value *V = GEP1VariableIndices[i].V; + + bool SignKnownZero, SignKnownOne; + ComputeSignBit(const_cast<Value *>(V), SignKnownZero, SignKnownOne, DL, + 0, AC1, nullptr, DT); + + // Zero-extension widens the variable, and so forces the sign + // bit to zero. + bool IsZExt = GEP1VariableIndices[i].Extension == EK_ZeroExt; + SignKnownZero |= IsZExt; + SignKnownOne &= !IsZExt; + + // If the variable begins with a zero then we know it's + // positive, regardless of whether the value is signed or + // unsigned. 
+ int64_t Scale = GEP1VariableIndices[i].Scale; + AllPositive = + (SignKnownZero && Scale >= 0) || + (SignKnownOne && Scale < 0); + } + } + Modulo = Modulo ^ (Modulo & (Modulo - 1)); // We can compute the difference between the two addresses @@ -1027,6 +1122,12 @@ BasicAliasAnalysis::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (V1Size != UnknownSize && V2Size != UnknownSize && ModOffset >= V2Size && V1Size <= Modulo - ModOffset) return NoAlias; + + // If we know all the variables are positive, then GEP1 >= GEP1BasePtr. + // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers + // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr. + if (AllPositive && GEP1BaseOffset > 0 && V2Size <= (uint64_t) GEP1BaseOffset) + return NoAlias; } // Statically, we can see that the base objects are the same, but the @@ -1056,33 +1157,33 @@ MergeAliasResults(AliasAnalysis::AliasResult A, AliasAnalysis::AliasResult B) { /// instruction against another. AliasAnalysis::AliasResult BasicAliasAnalysis::aliasSelect(const SelectInst *SI, uint64_t SISize, - const MDNode *SITBAAInfo, + const AAMDNodes &SIAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + const AAMDNodes &V2AAInfo) { // If the values are Selects with the same condition, we can do a more precise // check: just check for aliases between the values on corresponding arms. if (const SelectInst *SI2 = dyn_cast<SelectInst>(V2)) if (SI->getCondition() == SI2->getCondition()) { AliasResult Alias = - aliasCheck(SI->getTrueValue(), SISize, SITBAAInfo, - SI2->getTrueValue(), V2Size, V2TBAAInfo); + aliasCheck(SI->getTrueValue(), SISize, SIAAInfo, + SI2->getTrueValue(), V2Size, V2AAInfo); if (Alias == MayAlias) return MayAlias; AliasResult ThisAlias = - aliasCheck(SI->getFalseValue(), SISize, SITBAAInfo, - SI2->getFalseValue(), V2Size, V2TBAAInfo); + aliasCheck(SI->getFalseValue(), SISize, SIAAInfo, + SI2->getFalseValue(), V2Size, V2AAInfo); return MergeAliasResults(ThisAlias, Alias); } // If both arms of the Select node NoAlias or MustAlias V2, then returns // NoAlias / MustAlias. Otherwise, returns MayAlias. AliasResult Alias = - aliasCheck(V2, V2Size, V2TBAAInfo, SI->getTrueValue(), SISize, SITBAAInfo); + aliasCheck(V2, V2Size, V2AAInfo, SI->getTrueValue(), SISize, SIAAInfo); if (Alias == MayAlias) return MayAlias; AliasResult ThisAlias = - aliasCheck(V2, V2Size, V2TBAAInfo, SI->getFalseValue(), SISize, SITBAAInfo); + aliasCheck(V2, V2Size, V2AAInfo, SI->getFalseValue(), SISize, SIAAInfo); return MergeAliasResults(ThisAlias, Alias); } @@ -1090,9 +1191,9 @@ BasicAliasAnalysis::aliasSelect(const SelectInst *SI, uint64_t SISize, // against another. AliasAnalysis::AliasResult BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, - const MDNode *PNTBAAInfo, + const AAMDNodes &PNAAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + const AAMDNodes &V2AAInfo) { // Track phi nodes we have visited. We use this information when we determine // value equivalence. VisitedPhiBBs.insert(PN->getParent()); @@ -1102,8 +1203,8 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // on corresponding edges. 
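Stepping back to the `Modulo = Modulo ^ (Modulo & (Modulo - 1))` line a few hunks up: `x & (x - 1)` clears the lowest set bit of `x`, so xor'ing that result back against `x` leaves exactly the lowest set bit. Since `Modulo` is the OR of all the scales, that bit is the largest power of two dividing every scale. A tiny self-contained check (mine, not from the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  uint64_t Modulo = 0;
  Modulo |= 12; // scale 12 == 0b1100
  Modulo |= 6;  // scale 6  == 0b0110, so Modulo == 0b1110
  Modulo = Modulo ^ (Modulo & (Modulo - 1));
  assert(Modulo == 2); // lowest set bit: every scale is a multiple of 2
  return 0;
}
```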
if (const PHINode *PN2 = dyn_cast<PHINode>(V2)) if (PN2->getParent() == PN->getParent()) { - LocPair Locs(Location(PN, PNSize, PNTBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + LocPair Locs(Location(PN, PNSize, PNAAInfo), + Location(V2, V2Size, V2AAInfo)); if (PN > V2) std::swap(Locs.first, Locs.second); // Analyse the PHIs' inputs under the assumption that the PHIs are @@ -1121,9 +1222,9 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { AliasResult ThisAlias = - aliasCheck(PN->getIncomingValue(i), PNSize, PNTBAAInfo, + aliasCheck(PN->getIncomingValue(i), PNSize, PNAAInfo, PN2->getIncomingValueForBlock(PN->getIncomingBlock(i)), - V2Size, V2TBAAInfo); + V2Size, V2AAInfo); Alias = MergeAliasResults(ThisAlias, Alias); if (Alias == MayAlias) break; @@ -1146,12 +1247,12 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // sides are PHI nodes. In which case, this is O(m x n) time where 'm' // and 'n' are the number of PHI sources. return MayAlias; - if (UniqueSrc.insert(PV1)) + if (UniqueSrc.insert(PV1).second) V1Srcs.push_back(PV1); } - AliasResult Alias = aliasCheck(V2, V2Size, V2TBAAInfo, - V1Srcs[0], PNSize, PNTBAAInfo); + AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, + V1Srcs[0], PNSize, PNAAInfo); // Early exit if the check of the first PHI source against V2 is MayAlias. // Other results are not possible. if (Alias == MayAlias) @@ -1162,8 +1263,8 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, for (unsigned i = 1, e = V1Srcs.size(); i != e; ++i) { Value *V = V1Srcs[i]; - AliasResult ThisAlias = aliasCheck(V2, V2Size, V2TBAAInfo, - V, PNSize, PNTBAAInfo); + AliasResult ThisAlias = aliasCheck(V2, V2Size, V2AAInfo, + V, PNSize, PNAAInfo); Alias = MergeAliasResults(ThisAlias, Alias); if (Alias == MayAlias) break; @@ -1177,9 +1278,9 @@ BasicAliasAnalysis::aliasPHI(const PHINode *PN, uint64_t PNSize, // AliasAnalysis::AliasResult BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, - const MDNode *V1TBAAInfo, + AAMDNodes V1AAInfo, const Value *V2, uint64_t V2Size, - const MDNode *V2TBAAInfo) { + AAMDNodes V2AAInfo) { // If either of the memory references is empty, it doesn't matter what the // pointer values are. if (V1Size == 0 || V2Size == 0) @@ -1259,8 +1360,8 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, // Check the cache before climbing up use-def chains. This also terminates // otherwise infinitely recursive queries. 
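The hunk below also shows why the cache key is built the way it is: the pair of `Location`s is canonicalized by pointer order (`if (V1 > V2) std::swap(...)`) so that the query (A, B) and its mirror (B, A) land on the same entry. In miniature, with stand-in types (illustrative, not the patch's):

```cpp
#include <map>
#include <utility>

using Loc = const void *;      // stand-in for AliasAnalysis::Location
using LocPair = std::pair<Loc, Loc>;

int &cacheSlot(std::map<LocPair, int> &Cache, Loc A, Loc B) {
  LocPair Key(A, B);
  if (Key.first > Key.second)  // canonicalize, as aliasCheck does
    std::swap(Key.first, Key.second);
  return Cache[Key];           // (A,B) and (B,A) share this slot
}
```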
- LocPair Locs(Location(V1, V1Size, V1TBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + LocPair Locs(Location(V1, V1Size, V1AAInfo), + Location(V2, V2Size, V2AAInfo)); if (V1 > V2) std::swap(Locs.first, Locs.second); std::pair<AliasCacheTy::iterator, bool> Pair = @@ -1274,32 +1375,32 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, std::swap(V1, V2); std::swap(V1Size, V2Size); std::swap(O1, O2); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const GEPOperator *GV1 = dyn_cast<GEPOperator>(V1)) { - AliasResult Result = aliasGEP(GV1, V1Size, V1TBAAInfo, V2, V2Size, V2TBAAInfo, O1, O2); + AliasResult Result = aliasGEP(GV1, V1Size, V1AAInfo, V2, V2Size, V2AAInfo, O1, O2); if (Result != MayAlias) return AliasCache[Locs] = Result; } if (isa<PHINode>(V2) && !isa<PHINode>(V1)) { std::swap(V1, V2); std::swap(V1Size, V2Size); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const PHINode *PN = dyn_cast<PHINode>(V1)) { - AliasResult Result = aliasPHI(PN, V1Size, V1TBAAInfo, - V2, V2Size, V2TBAAInfo); + AliasResult Result = aliasPHI(PN, V1Size, V1AAInfo, + V2, V2Size, V2AAInfo); if (Result != MayAlias) return AliasCache[Locs] = Result; } if (isa<SelectInst>(V2) && !isa<SelectInst>(V1)) { std::swap(V1, V2); std::swap(V1Size, V2Size); - std::swap(V1TBAAInfo, V2TBAAInfo); + std::swap(V1AAInfo, V2AAInfo); } if (const SelectInst *S1 = dyn_cast<SelectInst>(V1)) { - AliasResult Result = aliasSelect(S1, V1Size, V1TBAAInfo, - V2, V2Size, V2TBAAInfo); + AliasResult Result = aliasSelect(S1, V1Size, V1AAInfo, + V2, V2Size, V2AAInfo); if (Result != MayAlias) return AliasCache[Locs] = Result; } @@ -1312,8 +1413,8 @@ BasicAliasAnalysis::aliasCheck(const Value *V1, uint64_t V1Size, return AliasCache[Locs] = PartialAlias; AliasResult Result = - AliasAnalysis::alias(Location(V1, V1Size, V1TBAAInfo), - Location(V2, V2Size, V2TBAAInfo)); + AliasAnalysis::alias(Location(V1, V1Size, V1AAInfo), + Location(V2, V2Size, V2AAInfo)); return AliasCache[Locs] = Result; } @@ -1338,10 +1439,8 @@ bool BasicAliasAnalysis::isValueEqualInPotentialCycles(const Value *V, // Make sure that the visited phis cannot reach the Value. This ensures that // the Values cannot come from different iterations of a potential cycle the // phi nodes could be involved in. - for (SmallPtrSet<const BasicBlock *, 8>::iterator PI = VisitedPhiBBs.begin(), - PE = VisitedPhiBBs.end(); - PI != PE; ++PI) - if (isPotentiallyReachable((*PI)->begin(), Inst, DT, LI)) + for (auto *P : VisitedPhiBBs) + if (isPotentiallyReachable(P->begin(), Inst, DT, LI)) return false; return true; diff --git a/lib/Analysis/BlockFrequencyInfoImpl.cpp b/lib/Analysis/BlockFrequencyInfoImpl.cpp index 723332d480f4..278073cd6c52 100644 --- a/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -614,7 +614,8 @@ static void findIrreducibleHeaders( break; } } - assert(Headers.size() >= 2 && "Should be irreducible"); + assert(Headers.size() >= 2 && + "Expected irreducible CFG; -loop-info is likely invalid"); if (Headers.size() == InSCC.size()) { // Every block is a header. 
std::sort(Headers.begin(), Headers.end()); diff --git a/lib/Analysis/BranchProbabilityInfo.cpp b/lib/Analysis/BranchProbabilityInfo.cpp index bbd875059522..2b39d47e5faf 100644 --- a/lib/Analysis/BranchProbabilityInfo.cpp +++ b/lib/Analysis/BranchProbabilityInfo.cpp @@ -196,7 +196,8 @@ bool BranchProbabilityInfo::calcMetadataWeights(BasicBlock *BB) { SmallVector<uint32_t, 2> Weights; Weights.reserve(TI->getNumSuccessors()); for (unsigned i = 1, e = WeightsNode->getNumOperands(); i != e; ++i) { - ConstantInt *Weight = dyn_cast<ConstantInt>(WeightsNode->getOperand(i)); + ConstantInt *Weight = + mdconst::dyn_extract<ConstantInt>(WeightsNode->getOperand(i)); if (!Weight) return false; Weights.push_back( diff --git a/lib/Analysis/CFG.cpp b/lib/Analysis/CFG.cpp index 8ef5302717f8..8ecd70b5d716 100644 --- a/lib/Analysis/CFG.cpp +++ b/lib/Analysis/CFG.cpp @@ -27,7 +27,7 @@ using namespace llvm; void llvm::FindFunctionBackedges(const Function &F, SmallVectorImpl<std::pair<const BasicBlock*,const BasicBlock*> > &Result) { const BasicBlock *BB = &F.getEntryBlock(); - if (succ_begin(BB) == succ_end(BB)) + if (succ_empty(BB)) return; SmallPtrSet<const BasicBlock*, 8> Visited; @@ -45,7 +45,7 @@ void llvm::FindFunctionBackedges(const Function &F, bool FoundNew = false; while (I != succ_end(ParentBB)) { BB = *I++; - if (Visited.insert(BB)) { + if (Visited.insert(BB).second) { FoundNew = true; break; } @@ -141,7 +141,7 @@ static bool isPotentiallyReachableInner(SmallVectorImpl<BasicBlock *> &Worklist, SmallSet<const BasicBlock*, 64> Visited; do { BasicBlock *BB = Worklist.pop_back_val(); - if (!Visited.insert(BB)) + if (!Visited.insert(BB).second) continue; if (BB == StopBB) return true; diff --git a/lib/Analysis/CFGPrinter.cpp b/lib/Analysis/CFGPrinter.cpp index c2c19d6a69ef..89787f826bec 100644 --- a/lib/Analysis/CFGPrinter.cpp +++ b/lib/Analysis/CFGPrinter.cpp @@ -79,11 +79,11 @@ namespace { bool runOnFunction(Function &F) override { std::string Filename = "cfg." + F.getName().str() + ".dot"; errs() << "Writing '" << Filename << "'..."; - - std::string ErrorInfo; - raw_fd_ostream File(Filename.c_str(), ErrorInfo, sys::fs::F_Text); - if (ErrorInfo.empty()) + std::error_code EC; + raw_fd_ostream File(Filename, EC, sys::fs::F_Text); + + if (!EC) WriteGraph(File, (const Function*)&F); else errs() << " error opening file for writing!"; @@ -114,10 +114,10 @@ namespace { std::string Filename = "cfg." + F.getName().str() + ".dot"; errs() << "Writing '" << Filename << "'..."; - std::string ErrorInfo; - raw_fd_ostream File(Filename.c_str(), ErrorInfo, sys::fs::F_Text); - - if (ErrorInfo.empty()) + std::error_code EC; + raw_fd_ostream File(Filename, EC, sys::fs::F_Text); + + if (!EC) WriteGraph(File, (const Function*)&F, true); else errs() << " error opening file for writing!"; diff --git a/lib/Analysis/CFLAliasAnalysis.cpp b/lib/Analysis/CFLAliasAnalysis.cpp new file mode 100644 index 000000000000..88bb84a80020 --- /dev/null +++ b/lib/Analysis/CFLAliasAnalysis.cpp @@ -0,0 +1,1013 @@ +//===- CFLAliasAnalysis.cpp - CFL-Based Alias Analysis Implementation ------==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a CFL-based context-insensitive alias analysis +// algorithm. It does not depend on types. 
The algorithm is a mixture of the one +// described in "Demand-driven alias analysis for C" by Xin Zheng and Radu +// Rugina, and "Fast algorithms for Dyck-CFL-reachability with applications to +// Alias Analysis" by Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the +// papers, we build a graph of the uses of a variable, where each node is a +// memory location, and each edge is an action that happened on that memory +// location. The "actions" can be one of Dereference, Reference, Assign, or +// Assign. +// +// Two variables are considered as aliasing iff you can reach one value's node +// from the other value's node and the language formed by concatenating all of +// the edge labels (actions) conforms to a context-free grammar. +// +// Because this algorithm requires a graph search on each query, we execute the +// algorithm outlined in "Fast algorithms..." (mentioned above) +// in order to transform the graph into sets of variables that may alias in +// ~nlogn time (n = number of variables.), which makes queries take constant +// time. +//===----------------------------------------------------------------------===// + +#include "StratifiedSets.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" +#include <algorithm> +#include <cassert> +#include <forward_list> +#include <tuple> + +using namespace llvm; + +// Try to go from a Value* to a Function*. Never returns nullptr. +static Optional<Function *> parentFunctionOfValue(Value *); + +// Returns possible functions called by the Inst* into the given +// SmallVectorImpl. Returns true if targets found, false otherwise. +// This is templated because InvokeInst/CallInst give us the same +// set of functions that we care about, and I don't like repeating +// myself. +template <typename Inst> +static bool getPossibleTargets(Inst *, SmallVectorImpl<Function *> &); + +// Some instructions need to have their users tracked. Instructions like +// `add` require you to get the users of the Instruction* itself, other +// instructions like `store` require you to get the users of the first +// operand. This function gets the "proper" value to track for each +// type of instruction we support. +static Optional<Value *> getTargetValue(Instruction *); + +// There are certain instructions (i.e. FenceInst, etc.) that we ignore. +// This notes that we should ignore those. +static bool hasUsefulEdges(Instruction *); + +const StratifiedIndex StratifiedLink::SetSentinel = + std::numeric_limits<StratifiedIndex>::max(); + +namespace { +// StratifiedInfo Attribute things. 
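Before the attribute constants below, a compressed restatement of the edge scheme the file comment describes, using the patch's own examples (the stand-in types are mine; the real `Edge`/`EdgeType` definitions follow shortly):

```cpp
// Each instruction contributes labeled edges between the values it
// touches; two values may alias when some path between their nodes,
// read as a string of labels, is accepted by the Dyck-style grammar
// (matched Reference/Dereference pairs, with Assigns in between).
enum class EdgeType { Assign, Dereference, Reference };

struct Edge {
  const void *From; // stand-in for Value *
  const void *To;
  EdgeType Weight;
};
// From the comments below:
//   %b = getelementptr %a, 0 : (%b Assign %a); Assign is its own inverse
//   %b = load %a             : (%b Dereference %a), inverse (%a Reference %b)
```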
+typedef unsigned StratifiedAttr; +LLVM_CONSTEXPR unsigned MaxStratifiedAttrIndex = NumStratifiedAttrs; +LLVM_CONSTEXPR unsigned AttrAllIndex = 0; +LLVM_CONSTEXPR unsigned AttrGlobalIndex = 1; +LLVM_CONSTEXPR unsigned AttrFirstArgIndex = 2; +LLVM_CONSTEXPR unsigned AttrLastArgIndex = MaxStratifiedAttrIndex; +LLVM_CONSTEXPR unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; + +LLVM_CONSTEXPR StratifiedAttr AttrNone = 0; +LLVM_CONSTEXPR StratifiedAttr AttrAll = ~AttrNone; + +// \brief StratifiedSets call for knowledge of "direction", so this is how we +// represent that locally. +enum class Level { Same, Above, Below }; + +// \brief Edges can be one of four "weights" -- each weight must have an inverse +// weight (Assign has Assign; Reference has Dereference). +enum class EdgeType { + // The weight assigned when assigning from or to a value. For example, in: + // %b = getelementptr %a, 0 + // ...The relationships are %b assign %a, and %a assign %b. This used to be + // two edges, but having a distinction bought us nothing. + Assign, + + // The edge used when we have an edge going from some handle to a Value. + // Examples of this include: + // %b = load %a (%b Dereference %a) + // %b = extractelement %a, 0 (%a Dereference %b) + Dereference, + + // The edge used when our edge goes from a value to a handle that may have + // contained it at some point. Examples: + // %b = load %a (%a Reference %b) + // %b = extractelement %a, 0 (%b Reference %a) + Reference +}; + +// \brief Encodes the notion of a "use" +struct Edge { + // \brief Which value the edge is coming from + Value *From; + + // \brief Which value the edge is pointing to + Value *To; + + // \brief Edge weight + EdgeType Weight; + + // \brief Whether we aliased any external values along the way that may be + // invisible to the analysis (i.e. landingpad for exceptions, calls for + // interprocedural analysis, etc.) + StratifiedAttrs AdditionalAttrs; + + Edge(Value *From, Value *To, EdgeType W, StratifiedAttrs A) + : From(From), To(To), Weight(W), AdditionalAttrs(A) {} +}; + +// \brief Information we have about a function and would like to keep around +struct FunctionInfo { + StratifiedSets<Value *> Sets; + // Lots of functions have < 4 returns. Adjust as necessary. + SmallVector<Value *, 4> ReturnedValues; + + FunctionInfo(StratifiedSets<Value *> &&S, + SmallVector<Value *, 4> &&RV) + : Sets(std::move(S)), ReturnedValues(std::move(RV)) {} +}; + +struct CFLAliasAnalysis; + +struct FunctionHandle : public CallbackVH { + FunctionHandle(Function *Fn, CFLAliasAnalysis *CFLAA) + : CallbackVH(Fn), CFLAA(CFLAA) { + assert(Fn != nullptr); + assert(CFLAA != nullptr); + } + + virtual ~FunctionHandle() {} + + void deleted() override { removeSelfFromCache(); } + void allUsesReplacedWith(Value *) override { removeSelfFromCache(); } + +private: + CFLAliasAnalysis *CFLAA; + + void removeSelfFromCache(); +}; + +struct CFLAliasAnalysis : public ImmutablePass, public AliasAnalysis { +private: + /// \brief Cached mapping of Functions to their StratifiedSets. + /// If a function's sets are currently being built, it is marked + /// in the cache as an Optional without a value. This way, if we + /// have any kind of recursion, it is discernable from a function + /// that simply has empty sets. 
+ DenseMap<Function *, Optional<FunctionInfo>> Cache; + std::forward_list<FunctionHandle> Handles; + +public: + static char ID; + + CFLAliasAnalysis() : ImmutablePass(ID) { + initializeCFLAliasAnalysisPass(*PassRegistry::getPassRegistry()); + } + + virtual ~CFLAliasAnalysis() {} + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AliasAnalysis::getAnalysisUsage(AU); + } + + void *getAdjustedAnalysisPointer(const void *ID) override { + if (ID == &AliasAnalysis::ID) + return (AliasAnalysis *)this; + return this; + } + + /// \brief Inserts the given Function into the cache. + void scan(Function *Fn); + + void evict(Function *Fn) { Cache.erase(Fn); } + + /// \brief Ensures that the given function is available in the cache. + /// Returns the appropriate entry from the cache. + const Optional<FunctionInfo> &ensureCached(Function *Fn) { + auto Iter = Cache.find(Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; + } + + AliasResult query(const Location &LocA, const Location &LocB); + + AliasResult alias(const Location &LocA, const Location &LocB) override { + if (LocA.Ptr == LocB.Ptr) { + if (LocA.Size == LocB.Size) { + return MustAlias; + } else { + return PartialAlias; + } + } + + // Comparisons between global variables and other constants should be + // handled by BasicAA. + if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr)) { + return MayAlias; + } + + return query(LocA, LocB); + } + + void initializePass() override { InitializeAliasAnalysis(this); } +}; + +void FunctionHandle::removeSelfFromCache() { + assert(CFLAA != nullptr); + auto *Val = getValPtr(); + CFLAA->evict(cast<Function>(Val)); + setValPtr(nullptr); +} + +// \brief Gets the edges our graph should have, based on an Instruction* +class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> { + CFLAliasAnalysis &AA; + SmallVectorImpl<Edge> &Output; + +public: + GetEdgesVisitor(CFLAliasAnalysis &AA, SmallVectorImpl<Edge> &Output) + : AA(AA), Output(Output) {} + + void visitInstruction(Instruction &) { + llvm_unreachable("Unsupported instruction encountered"); + } + + void visitCastInst(CastInst &Inst) { + Output.push_back(Edge(&Inst, Inst.getOperand(0), EdgeType::Assign, + AttrNone)); + } + + void visitBinaryOperator(BinaryOperator &Inst) { + auto *Op1 = Inst.getOperand(0); + auto *Op2 = Inst.getOperand(1); + Output.push_back(Edge(&Inst, Op1, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, Op2, EdgeType::Assign, AttrNone)); + } + + void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getNewValOperand(); + Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); + } + + void visitAtomicRMWInst(AtomicRMWInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getValOperand(); + Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); + } + + void visitPHINode(PHINode &Inst) { + for (unsigned I = 0, E = Inst.getNumIncomingValues(); I != E; ++I) { + Value *Val = Inst.getIncomingValue(I); + Output.push_back(Edge(&Inst, Val, EdgeType::Assign, AttrNone)); + } + } + + void visitGetElementPtrInst(GetElementPtrInst &Inst) { + auto *Op = Inst.getPointerOperand(); + Output.push_back(Edge(&Inst, Op, EdgeType::Assign, AttrNone)); + for (auto I = Inst.idx_begin(), E = Inst.idx_end(); I != E; ++I) + Output.push_back(Edge(&Inst, *I, EdgeType::Assign, AttrNone)); + } + + void visitSelectInst(SelectInst &Inst) { 
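+    // Both arms (and, conservatively, the condition) may flow into the
+    // result, so a select is modeled like the phi above: one Assign edge
+    // per operand.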
+    auto *Condition = Inst.getCondition();
+    Output.push_back(Edge(&Inst, Condition, EdgeType::Assign, AttrNone));
+    auto *TrueVal = Inst.getTrueValue();
+    Output.push_back(Edge(&Inst, TrueVal, EdgeType::Assign, AttrNone));
+    auto *FalseVal = Inst.getFalseValue();
+    Output.push_back(Edge(&Inst, FalseVal, EdgeType::Assign, AttrNone));
+  }
+
+  void visitAllocaInst(AllocaInst &) {}
+
+  void visitLoadInst(LoadInst &Inst) {
+    auto *Ptr = Inst.getPointerOperand();
+    auto *Val = &Inst;
+    Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone));
+  }
+
+  void visitStoreInst(StoreInst &Inst) {
+    auto *Ptr = Inst.getPointerOperand();
+    auto *Val = Inst.getValueOperand();
+    Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone));
+  }
+
+  void visitVAArgInst(VAArgInst &Inst) {
+    // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it does
+    // two things:
+    // 1. Loads a value from *((T*)*Ptr).
+    // 2. Increments (stores to) *Ptr by some target-specific amount.
+    // For now, we'll handle this like a landingpad instruction (by placing the
+    // result in its own group, and having that group alias externals).
+    auto *Val = &Inst;
+    Output.push_back(Edge(Val, Val, EdgeType::Assign, AttrAll));
+  }
+
+  static bool isFunctionExternal(Function *Fn) {
+    return Fn->isDeclaration() || !Fn->hasLocalLinkage();
+  }
+
+  // Gets whether the sets at Index1 are above, below, or equal to the sets at
+  // Index2. Returns None if they are not in the same set chain.
+  static Optional<Level> getIndexRelation(const StratifiedSets<Value *> &Sets,
+                                          StratifiedIndex Index1,
+                                          StratifiedIndex Index2) {
+    if (Index1 == Index2)
+      return Level::Same;
+
+    const auto *Current = &Sets.getLink(Index1);
+    while (Current->hasBelow()) {
+      if (Current->Below == Index2)
+        return Level::Below;
+      Current = &Sets.getLink(Current->Below);
+    }
+
+    Current = &Sets.getLink(Index1);
+    while (Current->hasAbove()) {
+      if (Current->Above == Index2)
+        return Level::Above;
+      Current = &Sets.getLink(Current->Above);
+    }
+
+    return NoneType();
+  }
+
+  bool
+  tryInterproceduralAnalysis(const SmallVectorImpl<Function *> &Fns,
+                             Value *FuncValue,
+                             const iterator_range<User::op_iterator> &Args) {
+    const unsigned ExpectedMaxArgs = 8;
+    const unsigned MaxSupportedArgs = 50;
+    assert(Fns.size() > 0);
+
+    // I put this here to give us an upper bound on time taken by IPA. Is it
+    // really (realistically) needed? Keep in mind that we do have an n^2 algo.
+    if (std::distance(Args.begin(), Args.end()) > (int) MaxSupportedArgs)
+      return false;
+
+    // Exit early if we'll fail anyway
+    for (auto *Fn : Fns) {
+      if (isFunctionExternal(Fn) || Fn->isVarArg())
+        return false;
+      auto &MaybeInfo = AA.ensureCached(Fn);
+      if (!MaybeInfo.hasValue())
+        return false;
+    }
+
+    SmallVector<Value *, ExpectedMaxArgs> Arguments(Args.begin(), Args.end());
+    SmallVector<StratifiedInfo, ExpectedMaxArgs> Parameters;
+    for (auto *Fn : Fns) {
+      auto &Info = *AA.ensureCached(Fn);
+      auto &Sets = Info.Sets;
+      auto &RetVals = Info.ReturnedValues;
+
+      Parameters.clear();
+      for (auto &Param : Fn->args()) {
+        auto MaybeInfo = Sets.find(&Param);
+        // Did a new parameter somehow get added to the function, or slip by?
+ if (!MaybeInfo.hasValue()) + return false; + Parameters.push_back(*MaybeInfo); + } + + // Adding an edge from argument -> return value for each parameter that + // may alias the return value + for (unsigned I = 0, E = Parameters.size(); I != E; ++I) { + auto &ParamInfo = Parameters[I]; + auto &ArgVal = Arguments[I]; + bool AddEdge = false; + StratifiedAttrs Externals; + for (unsigned X = 0, XE = RetVals.size(); X != XE; ++X) { + auto MaybeInfo = Sets.find(RetVals[X]); + if (!MaybeInfo.hasValue()) + return false; + + auto &RetInfo = *MaybeInfo; + auto RetAttrs = Sets.getLink(RetInfo.Index).Attrs; + auto ParamAttrs = Sets.getLink(ParamInfo.Index).Attrs; + auto MaybeRelation = + getIndexRelation(Sets, ParamInfo.Index, RetInfo.Index); + if (MaybeRelation.hasValue()) { + AddEdge = true; + Externals |= RetAttrs | ParamAttrs; + } + } + if (AddEdge) + Output.push_back(Edge(FuncValue, ArgVal, EdgeType::Assign, + StratifiedAttrs().flip())); + } + + if (Parameters.size() != Arguments.size()) + return false; + + // Adding edges between arguments for arguments that may end up aliasing + // each other. This is necessary for functions such as + // void foo(int** a, int** b) { *a = *b; } + // (Technically, the proper sets for this would be those below + // Arguments[I] and Arguments[X], but our algorithm will produce + // extremely similar, and equally correct, results either way) + for (unsigned I = 0, E = Arguments.size(); I != E; ++I) { + auto &MainVal = Arguments[I]; + auto &MainInfo = Parameters[I]; + auto &MainAttrs = Sets.getLink(MainInfo.Index).Attrs; + for (unsigned X = I + 1; X != E; ++X) { + auto &SubInfo = Parameters[X]; + auto &SubVal = Arguments[X]; + auto &SubAttrs = Sets.getLink(SubInfo.Index).Attrs; + auto MaybeRelation = + getIndexRelation(Sets, MainInfo.Index, SubInfo.Index); + + if (!MaybeRelation.hasValue()) + continue; + + auto NewAttrs = SubAttrs | MainAttrs; + Output.push_back(Edge(MainVal, SubVal, EdgeType::Assign, NewAttrs)); + } + } + } + return true; + } + + template <typename InstT> void visitCallLikeInst(InstT &Inst) { + SmallVector<Function *, 4> Targets; + if (getPossibleTargets(&Inst, Targets)) { + if (tryInterproceduralAnalysis(Targets, &Inst, Inst.arg_operands())) + return; + // Cleanup from interprocedural analysis + Output.clear(); + } + + for (Value *V : Inst.arg_operands()) + Output.push_back(Edge(&Inst, V, EdgeType::Assign, AttrAll)); + } + + void visitCallInst(CallInst &Inst) { visitCallLikeInst(Inst); } + + void visitInvokeInst(InvokeInst &Inst) { visitCallLikeInst(Inst); } + + // Because vectors/aggregates are immutable and unaddressable, + // there's nothing we can do to coax a value out of them, other + // than calling Extract{Element,Value}. We can effectively treat + // them as pointers to arbitrary memory locations we can store in + // and load from. + void visitExtractElementInst(ExtractElementInst &Inst) { + auto *Ptr = Inst.getVectorOperand(); + auto *Val = &Inst; + Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone)); + } + + void visitInsertElementInst(InsertElementInst &Inst) { + auto *Vec = Inst.getOperand(0); + auto *Val = Inst.getOperand(1); + Output.push_back(Edge(&Inst, Vec, EdgeType::Assign, AttrNone)); + Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone)); + } + + void visitLandingPadInst(LandingPadInst &Inst) { + // Exceptions come from "nowhere", from our analysis' perspective. 
+    // So we place the instruction in its own group, noting that said group
+    // may alias externals.
+    Output.push_back(Edge(&Inst, &Inst, EdgeType::Assign, AttrAll));
+  }
+
+  void visitInsertValueInst(InsertValueInst &Inst) {
+    auto *Agg = Inst.getOperand(0);
+    auto *Val = Inst.getOperand(1);
+    Output.push_back(Edge(&Inst, Agg, EdgeType::Assign, AttrNone));
+    Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone));
+  }
+
+  void visitExtractValueInst(ExtractValueInst &Inst) {
+    auto *Ptr = Inst.getAggregateOperand();
+    Output.push_back(Edge(&Inst, Ptr, EdgeType::Reference, AttrNone));
+  }
+
+  void visitShuffleVectorInst(ShuffleVectorInst &Inst) {
+    auto *From1 = Inst.getOperand(0);
+    auto *From2 = Inst.getOperand(1);
+    Output.push_back(Edge(&Inst, From1, EdgeType::Assign, AttrNone));
+    Output.push_back(Edge(&Inst, From2, EdgeType::Assign, AttrNone));
+  }
+};
+
+// For a given instruction, we need to know which Value* to get the
+// users of in order to build our graph. In some cases (e.g. add),
+// we simply need the Instruction*. In other cases (e.g. store),
+// finding the users of the Instruction* is useless; we need to find
+// the users of the first operand. This handles determining which
+// value to follow for us.
+//
+// Note: we *need* to keep this in sync with GetEdgesVisitor. Add
+// something to GetEdgesVisitor, add it here -- remove something from
+// GetEdgesVisitor, remove it here.
+class GetTargetValueVisitor
+    : public InstVisitor<GetTargetValueVisitor, Value *> {
+public:
+  Value *visitInstruction(Instruction &Inst) { return &Inst; }
+
+  Value *visitStoreInst(StoreInst &Inst) { return Inst.getPointerOperand(); }
+
+  Value *visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) {
+    return Inst.getPointerOperand();
+  }
+
+  Value *visitAtomicRMWInst(AtomicRMWInst &Inst) {
+    return Inst.getPointerOperand();
+  }
+
+  Value *visitInsertElementInst(InsertElementInst &Inst) {
+    return Inst.getOperand(0);
+  }
+
+  Value *visitInsertValueInst(InsertValueInst &Inst) {
+    return Inst.getAggregateOperand();
+  }
+};
+
+// Set building requires a weighted bidirectional graph.
+template <typename EdgeTypeT> class WeightedBidirectionalGraph {
+public:
+  typedef std::size_t Node;
+
+private:
+  const static Node StartNode = Node(0);
+
+  struct Edge {
+    EdgeTypeT Weight;
+    Node Other;
+
+    Edge(const EdgeTypeT &W, const Node &N)
+        : Weight(W), Other(N) {}
+
+    bool operator==(const Edge &E) const {
+      return Weight == E.Weight && Other == E.Other;
+    }
+
+    bool operator!=(const Edge &E) const { return !operator==(E); }
+  };
+
+  struct NodeImpl {
+    std::vector<Edge> Edges;
+  };
+
+  std::vector<NodeImpl> NodeImpls;
+
+  bool inbounds(Node NodeIndex) const { return NodeIndex < NodeImpls.size(); }
+
+  const NodeImpl &getNode(Node N) const { return NodeImpls[N]; }
+  NodeImpl &getNode(Node N) { return NodeImpls[N]; }
+
+public:
+  // ----- Various Edge iterators for the graph ----- //
+
+  // \brief Iterator for edges. Because this graph is bidirectional, we don't
+  // allow modification of the edges using this iterator. Additionally, the
+  // iterator becomes invalid if you add edges to or from the node you're
+  // getting the edges of.
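+  //
+  // A typical traversal via edgesFor() below looks roughly like:
+  //   for (const auto &EdgeTuple : Graph.edgesFor(N)) {
+  //     auto Weight = std::get<0>(EdgeTuple);    // EdgeTypeT
+  //     auto OtherNode = std::get<1>(EdgeTuple); // Node
+  //   }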
+ struct EdgeIterator : public std::iterator<std::forward_iterator_tag, + std::tuple<EdgeTypeT, Node *>> { + EdgeIterator(const typename std::vector<Edge>::const_iterator &Iter) + : Current(Iter) {} + + EdgeIterator(NodeImpl &Impl) : Current(Impl.begin()) {} + + EdgeIterator &operator++() { + ++Current; + return *this; + } + + EdgeIterator operator++(int) { + EdgeIterator Copy(Current); + operator++(); + return Copy; + } + + std::tuple<EdgeTypeT, Node> &operator*() { + Store = std::make_tuple(Current->Weight, Current->Other); + return Store; + } + + bool operator==(const EdgeIterator &Other) const { + return Current == Other.Current; + } + + bool operator!=(const EdgeIterator &Other) const { + return !operator==(Other); + } + + private: + typename std::vector<Edge>::const_iterator Current; + std::tuple<EdgeTypeT, Node> Store; + }; + + // Wrapper for EdgeIterator with begin()/end() calls. + struct EdgeIterable { + EdgeIterable(const std::vector<Edge> &Edges) + : BeginIter(Edges.begin()), EndIter(Edges.end()) {} + + EdgeIterator begin() { return EdgeIterator(BeginIter); } + + EdgeIterator end() { return EdgeIterator(EndIter); } + + private: + typename std::vector<Edge>::const_iterator BeginIter; + typename std::vector<Edge>::const_iterator EndIter; + }; + + // ----- Actual graph-related things ----- // + + WeightedBidirectionalGraph() {} + + WeightedBidirectionalGraph(WeightedBidirectionalGraph<EdgeTypeT> &&Other) + : NodeImpls(std::move(Other.NodeImpls)) {} + + WeightedBidirectionalGraph<EdgeTypeT> & + operator=(WeightedBidirectionalGraph<EdgeTypeT> &&Other) { + NodeImpls = std::move(Other.NodeImpls); + return *this; + } + + Node addNode() { + auto Index = NodeImpls.size(); + auto NewNode = Node(Index); + NodeImpls.push_back(NodeImpl()); + return NewNode; + } + + void addEdge(Node From, Node To, const EdgeTypeT &Weight, + const EdgeTypeT &ReverseWeight) { + assert(inbounds(From)); + assert(inbounds(To)); + auto &FromNode = getNode(From); + auto &ToNode = getNode(To); + FromNode.Edges.push_back(Edge(Weight, To)); + ToNode.Edges.push_back(Edge(ReverseWeight, From)); + } + + EdgeIterable edgesFor(const Node &N) const { + const auto &Node = getNode(N); + return EdgeIterable(Node.Edges); + } + + bool empty() const { return NodeImpls.empty(); } + std::size_t size() const { return NodeImpls.size(); } + + // \brief Gets an arbitrary node in the graph as a starting point for + // traversal. + Node getEntryNode() { + assert(inbounds(StartNode)); + return StartNode; + } +}; + +typedef WeightedBidirectionalGraph<std::pair<EdgeType, StratifiedAttrs>> GraphT; +typedef DenseMap<Value *, GraphT::Node> NodeMapT; +} + +// -- Setting up/registering CFLAA pass -- // +char CFLAliasAnalysis::ID = 0; + +INITIALIZE_AG_PASS(CFLAliasAnalysis, AliasAnalysis, "cfl-aa", + "CFL-Based AA implementation", false, true, false) + +ImmutablePass *llvm::createCFLAliasAnalysisPass() { + return new CFLAliasAnalysis(); +} + +//===----------------------------------------------------------------------===// +// Function declarations that require types defined in the namespace above +//===----------------------------------------------------------------------===// + +// Given an argument number, returns the appropriate Attr index to set. +static StratifiedAttr argNumberToAttrIndex(StratifiedAttr); + +// Given a Value, potentially return which AttrIndex it maps to. +static Optional<StratifiedAttr> valueToAttrIndex(Value *Val); + +// Gets the inverse of a given EdgeType. 
+static EdgeType flipWeight(EdgeType); + +// Gets edges of the given Instruction*, writing them to the SmallVector*. +static void argsToEdges(CFLAliasAnalysis &, Instruction *, + SmallVectorImpl<Edge> &); + +// Gets the "Level" that one should travel in StratifiedSets +// given an EdgeType. +static Level directionOfEdgeType(EdgeType); + +// Builds the graph needed for constructing the StratifiedSets for the +// given function +static void buildGraphFrom(CFLAliasAnalysis &, Function *, + SmallVectorImpl<Value *> &, NodeMapT &, GraphT &); + +// Builds the graph + StratifiedSets for a function. +static FunctionInfo buildSetsFrom(CFLAliasAnalysis &, Function *); + +static Optional<Function *> parentFunctionOfValue(Value *Val) { + if (auto *Inst = dyn_cast<Instruction>(Val)) { + auto *Bb = Inst->getParent(); + return Bb->getParent(); + } + + if (auto *Arg = dyn_cast<Argument>(Val)) + return Arg->getParent(); + return NoneType(); +} + +template <typename Inst> +static bool getPossibleTargets(Inst *Call, + SmallVectorImpl<Function *> &Output) { + if (auto *Fn = Call->getCalledFunction()) { + Output.push_back(Fn); + return true; + } + + // TODO: If the call is indirect, we might be able to enumerate all potential + // targets of the call and return them, rather than just failing. + return false; +} + +static Optional<Value *> getTargetValue(Instruction *Inst) { + GetTargetValueVisitor V; + return V.visit(Inst); +} + +static bool hasUsefulEdges(Instruction *Inst) { + bool IsNonInvokeTerminator = + isa<TerminatorInst>(Inst) && !isa<InvokeInst>(Inst); + return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) && !IsNonInvokeTerminator; +} + +static Optional<StratifiedAttr> valueToAttrIndex(Value *Val) { + if (isa<GlobalValue>(Val)) + return AttrGlobalIndex; + + if (auto *Arg = dyn_cast<Argument>(Val)) + if (!Arg->hasNoAliasAttr()) + return argNumberToAttrIndex(Arg->getArgNo()); + return NoneType(); +} + +static StratifiedAttr argNumberToAttrIndex(unsigned ArgNum) { + if (ArgNum > AttrMaxNumArgs) + return AttrAllIndex; + return ArgNum + AttrFirstArgIndex; +} + +static EdgeType flipWeight(EdgeType Initial) { + switch (Initial) { + case EdgeType::Assign: + return EdgeType::Assign; + case EdgeType::Dereference: + return EdgeType::Reference; + case EdgeType::Reference: + return EdgeType::Dereference; + } + llvm_unreachable("Incomplete coverage of EdgeType enum"); +} + +static void argsToEdges(CFLAliasAnalysis &Analysis, Instruction *Inst, + SmallVectorImpl<Edge> &Output) { + GetEdgesVisitor v(Analysis, Output); + v.visit(Inst); +} + +static Level directionOfEdgeType(EdgeType Weight) { + switch (Weight) { + case EdgeType::Reference: + return Level::Above; + case EdgeType::Dereference: + return Level::Below; + case EdgeType::Assign: + return Level::Same; + } + llvm_unreachable("Incomplete switch coverage"); +} + +// Aside: We may remove graph construction entirely, because it doesn't really +// buy us much that we don't already have. 
I'd like to add interprocedural +// analysis prior to this however, in case that somehow requires the graph +// produced by this for efficient execution +static void buildGraphFrom(CFLAliasAnalysis &Analysis, Function *Fn, + SmallVectorImpl<Value *> &ReturnedValues, + NodeMapT &Map, GraphT &Graph) { + const auto findOrInsertNode = [&Map, &Graph](Value *Val) { + auto Pair = Map.insert(std::make_pair(Val, GraphT::Node())); + auto &Iter = Pair.first; + if (Pair.second) { + auto NewNode = Graph.addNode(); + Iter->second = NewNode; + } + return Iter->second; + }; + + SmallVector<Edge, 8> Edges; + for (auto &Bb : Fn->getBasicBlockList()) { + for (auto &Inst : Bb.getInstList()) { + // We don't want the edges of most "return" instructions, but we *do* want + // to know what can be returned. + if (auto *Ret = dyn_cast<ReturnInst>(&Inst)) + ReturnedValues.push_back(Ret); + + if (!hasUsefulEdges(&Inst)) + continue; + + Edges.clear(); + argsToEdges(Analysis, &Inst, Edges); + + // In the case of an unused alloca (or similar), edges may be empty. Note + // that it exists so we can potentially answer NoAlias. + if (Edges.empty()) { + auto MaybeVal = getTargetValue(&Inst); + assert(MaybeVal.hasValue()); + auto *Target = *MaybeVal; + findOrInsertNode(Target); + continue; + } + + for (const Edge &E : Edges) { + auto To = findOrInsertNode(E.To); + auto From = findOrInsertNode(E.From); + auto FlippedWeight = flipWeight(E.Weight); + auto Attrs = E.AdditionalAttrs; + Graph.addEdge(From, To, std::make_pair(E.Weight, Attrs), + std::make_pair(FlippedWeight, Attrs)); + } + } + } +} + +static FunctionInfo buildSetsFrom(CFLAliasAnalysis &Analysis, Function *Fn) { + NodeMapT Map; + GraphT Graph; + SmallVector<Value *, 4> ReturnedValues; + + buildGraphFrom(Analysis, Fn, ReturnedValues, Map, Graph); + + DenseMap<GraphT::Node, Value *> NodeValueMap; + NodeValueMap.resize(Map.size()); + for (const auto &Pair : Map) + NodeValueMap.insert(std::make_pair(Pair.second, Pair.first)); + + const auto findValueOrDie = [&NodeValueMap](GraphT::Node Node) { + auto ValIter = NodeValueMap.find(Node); + assert(ValIter != NodeValueMap.end()); + return ValIter->second; + }; + + StratifiedSetsBuilder<Value *> Builder; + + SmallVector<GraphT::Node, 16> Worklist; + for (auto &Pair : Map) { + Worklist.clear(); + + auto *Value = Pair.first; + Builder.add(Value); + auto InitialNode = Pair.second; + Worklist.push_back(InitialNode); + while (!Worklist.empty()) { + auto Node = Worklist.pop_back_val(); + auto *CurValue = findValueOrDie(Node); + if (isa<Constant>(CurValue) && !isa<GlobalValue>(CurValue)) + continue; + + for (const auto &EdgeTuple : Graph.edgesFor(Node)) { + auto Weight = std::get<0>(EdgeTuple); + auto Label = Weight.first; + auto &OtherNode = std::get<1>(EdgeTuple); + auto *OtherValue = findValueOrDie(OtherNode); + + if (isa<Constant>(OtherValue) && !isa<GlobalValue>(OtherValue)) + continue; + + bool Added; + switch (directionOfEdgeType(Label)) { + case Level::Above: + Added = Builder.addAbove(CurValue, OtherValue); + break; + case Level::Below: + Added = Builder.addBelow(CurValue, OtherValue); + break; + case Level::Same: + Added = Builder.addWith(CurValue, OtherValue); + break; + } + + if (Added) { + auto Aliasing = Weight.second; + if (auto MaybeCurIndex = valueToAttrIndex(CurValue)) + Aliasing.set(*MaybeCurIndex); + if (auto MaybeOtherIndex = valueToAttrIndex(OtherValue)) + Aliasing.set(*MaybeOtherIndex); + Builder.noteAttributes(CurValue, Aliasing); + Builder.noteAttributes(OtherValue, Aliasing); + 
Worklist.push_back(OtherNode);
+        }
+      }
+    }
+  }
+
+  // There are times when we end up with parameters not in our graph (e.g. a
+  // parameter used only as the condition of a branch). Other bits of code
+  // depend on things that were present during construction being present in
+  // the graph. So, we add all present arguments here.
+  for (auto &Arg : Fn->args()) {
+    Builder.add(&Arg);
+  }
+
+  return FunctionInfo(Builder.build(), std::move(ReturnedValues));
+}
+
+void CFLAliasAnalysis::scan(Function *Fn) {
+  auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>()));
+  (void)InsertPair;
+  assert(InsertPair.second &&
+         "Trying to scan a function that has already been cached");
+
+  FunctionInfo Info(buildSetsFrom(*this, Fn));
+  Cache[Fn] = std::move(Info);
+  Handles.push_front(FunctionHandle(Fn, this));
+}
+
+AliasAnalysis::AliasResult
+CFLAliasAnalysis::query(const AliasAnalysis::Location &LocA,
+                        const AliasAnalysis::Location &LocB) {
+  auto *ValA = const_cast<Value *>(LocA.Ptr);
+  auto *ValB = const_cast<Value *>(LocB.Ptr);
+
+  Function *Fn = nullptr;
+  auto MaybeFnA = parentFunctionOfValue(ValA);
+  auto MaybeFnB = parentFunctionOfValue(ValB);
+  if (!MaybeFnA.hasValue() && !MaybeFnB.hasValue()) {
+    llvm_unreachable("Don't know how to extract the parent function "
+                     "from values A or B");
+  }
+
+  if (MaybeFnA.hasValue()) {
+    Fn = *MaybeFnA;
+    assert((!MaybeFnB.hasValue() || *MaybeFnB == *MaybeFnA) &&
+           "Interprocedural queries not supported");
+  } else {
+    Fn = *MaybeFnB;
+  }
+
+  assert(Fn != nullptr);
+  auto &MaybeInfo = ensureCached(Fn);
+  assert(MaybeInfo.hasValue());
+
+  auto &Sets = MaybeInfo->Sets;
+  auto MaybeA = Sets.find(ValA);
+  if (!MaybeA.hasValue())
+    return AliasAnalysis::MayAlias;
+
+  auto MaybeB = Sets.find(ValB);
+  if (!MaybeB.hasValue())
+    return AliasAnalysis::MayAlias;
+
+  auto SetA = *MaybeA;
+  auto SetB = *MaybeB;
+
+  if (SetA.Index == SetB.Index)
+    return AliasAnalysis::PartialAlias;
+
+  auto AttrsA = Sets.getLink(SetA.Index).Attrs;
+  auto AttrsB = Sets.getLink(SetB.Index).Attrs;
+  // Stratified set attributes are used as markers to signify whether a member
+  // of a StratifiedSet (or a member of a set above the current set) has
+  // interacted with either arguments or globals. "Interacted with" means its
+  // value may be different depending on the value of an argument or global.
+  // The thought behind this is that, because arguments and globals may alias
+  // each other, if AttrsA and AttrsB have touched args/globals, we must
+  // conservatively say that they alias. However, if at least one of the sets
+  // has no values that could legally be altered by changing the value of an
+  // argument or global, then we don't have to be as conservative.
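+  // For instance, if ValA is derived from argument #0 and ValB from a
+  // global, both attribute sets are non-empty; since that argument and the
+  // global may themselves alias, MayAlias is the only safe answer. If at
+  // least one of the two sets never touched an argument or global, the
+  // NoAlias below is justified.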
+ if (AttrsA.any() && AttrsB.any()) + return AliasAnalysis::MayAlias; + + return AliasAnalysis::NoAlias; +} diff --git a/lib/Analysis/CGSCCPassManager.cpp b/lib/Analysis/CGSCCPassManager.cpp index 5d1d8a9c6e06..4a03002e510b 100644 --- a/lib/Analysis/CGSCCPassManager.cpp +++ b/lib/Analysis/CGSCCPassManager.cpp @@ -13,105 +13,10 @@ using namespace llvm; -static cl::opt<bool> -DebugPM("debug-cgscc-pass-manager", cl::Hidden, - cl::desc("Print CGSCC pass management debugging information")); - -PreservedAnalyses CGSCCPassManager::run(LazyCallGraph::SCC *C, - CGSCCAnalysisManager *AM) { - PreservedAnalyses PA = PreservedAnalyses::all(); - - if (DebugPM) - dbgs() << "Starting CGSCC pass manager run.\n"; - - for (unsigned Idx = 0, Size = Passes.size(); Idx != Size; ++Idx) { - if (DebugPM) - dbgs() << "Running CGSCC pass: " << Passes[Idx]->name() << "\n"; - - PreservedAnalyses PassPA = Passes[Idx]->run(C, AM); - if (AM) - AM->invalidate(C, PassPA); - PA.intersect(std::move(PassPA)); - } - - if (DebugPM) - dbgs() << "Finished CGSCC pass manager run.\n"; - - return PA; -} - -bool CGSCCAnalysisManager::empty() const { - assert(CGSCCAnalysisResults.empty() == CGSCCAnalysisResultLists.empty() && - "The storage and index of analysis results disagree on how many there " - "are!"); - return CGSCCAnalysisResults.empty(); -} - -void CGSCCAnalysisManager::clear() { - CGSCCAnalysisResults.clear(); - CGSCCAnalysisResultLists.clear(); -} - -CGSCCAnalysisManager::ResultConceptT & -CGSCCAnalysisManager::getResultImpl(void *PassID, LazyCallGraph::SCC *C) { - CGSCCAnalysisResultMapT::iterator RI; - bool Inserted; - std::tie(RI, Inserted) = CGSCCAnalysisResults.insert(std::make_pair( - std::make_pair(PassID, C), CGSCCAnalysisResultListT::iterator())); - - // If we don't have a cached result for this function, look up the pass and - // run it to produce a result, which we then add to the cache. - if (Inserted) { - CGSCCAnalysisResultListT &ResultList = CGSCCAnalysisResultLists[C]; - ResultList.emplace_back(PassID, lookupPass(PassID).run(C, this)); - RI->second = std::prev(ResultList.end()); - } - - return *RI->second->second; -} - -CGSCCAnalysisManager::ResultConceptT * -CGSCCAnalysisManager::getCachedResultImpl(void *PassID, - LazyCallGraph::SCC *C) const { - CGSCCAnalysisResultMapT::const_iterator RI = - CGSCCAnalysisResults.find(std::make_pair(PassID, C)); - return RI == CGSCCAnalysisResults.end() ? nullptr : &*RI->second->second; -} - -void CGSCCAnalysisManager::invalidateImpl(void *PassID, LazyCallGraph::SCC *C) { - CGSCCAnalysisResultMapT::iterator RI = - CGSCCAnalysisResults.find(std::make_pair(PassID, C)); - if (RI == CGSCCAnalysisResults.end()) - return; - - CGSCCAnalysisResultLists[C].erase(RI->second); -} - -void CGSCCAnalysisManager::invalidateImpl(LazyCallGraph::SCC *C, - const PreservedAnalyses &PA) { - // Clear all the invalidated results associated specifically with this - // function. 
- SmallVector<void *, 8> InvalidatedPassIDs; - CGSCCAnalysisResultListT &ResultsList = CGSCCAnalysisResultLists[C]; - for (CGSCCAnalysisResultListT::iterator I = ResultsList.begin(), - E = ResultsList.end(); - I != E;) - if (I->second->invalidate(C, PA)) { - InvalidatedPassIDs.push_back(I->first); - I = ResultsList.erase(I); - } else { - ++I; - } - while (!InvalidatedPassIDs.empty()) - CGSCCAnalysisResults.erase( - std::make_pair(InvalidatedPassIDs.pop_back_val(), C)); - CGSCCAnalysisResultLists.erase(C); -} - char CGSCCAnalysisManagerModuleProxy::PassID; CGSCCAnalysisManagerModuleProxy::Result -CGSCCAnalysisManagerModuleProxy::run(Module *M) { +CGSCCAnalysisManagerModuleProxy::run(Module &M) { assert(CGAM->empty() && "CGSCC analyses ran prior to the module proxy!"); return Result(*CGAM); } @@ -123,7 +28,7 @@ CGSCCAnalysisManagerModuleProxy::Result::~Result() { } bool CGSCCAnalysisManagerModuleProxy::Result::invalidate( - Module *M, const PreservedAnalyses &PA) { + Module &M, const PreservedAnalyses &PA) { // If this proxy isn't marked as preserved, then we can't even invalidate // individual CGSCC analyses, there may be an invalid set of SCC objects in // the cache making it impossible to incrementally preserve them. @@ -140,7 +45,7 @@ char ModuleAnalysisManagerCGSCCProxy::PassID; char FunctionAnalysisManagerCGSCCProxy::PassID; FunctionAnalysisManagerCGSCCProxy::Result -FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC *C) { +FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC &C) { assert(FAM->empty() && "Function analyses ran prior to the CGSCC proxy!"); return Result(*FAM); } @@ -152,7 +57,7 @@ FunctionAnalysisManagerCGSCCProxy::Result::~Result() { } bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate( - LazyCallGraph::SCC *C, const PreservedAnalyses &PA) { + LazyCallGraph::SCC &C, const PreservedAnalyses &PA) { // If this proxy isn't marked as preserved, then we can't even invalidate // individual function analyses, there may be an invalid set of Function // objects in the cache making it impossible to incrementally preserve them. 
diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index d1632fd26ace..fbd7cf59d04b 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -5,12 +5,14 @@ add_llvm_library(LLVMAnalysis AliasDebugger.cpp AliasSetTracker.cpp Analysis.cpp + AssumptionCache.cpp BasicAliasAnalysis.cpp BlockFrequencyInfo.cpp BlockFrequencyInfoImpl.cpp BranchProbabilityInfo.cpp CFG.cpp CFGPrinter.cpp + CFLAliasAnalysis.cpp CGSCCPassManager.cpp CaptureTracking.cpp CostModel.cpp @@ -20,6 +22,7 @@ add_llvm_library(LLVMAnalysis DependenceAnalysis.cpp DomPrinter.cpp DominanceFrontier.cpp + FunctionTargetTransformInfo.cpp IVUsers.cpp InstCount.cpp InstructionSimplify.cpp @@ -53,6 +56,7 @@ add_llvm_library(LLVMAnalysis TargetTransformInfo.cpp Trace.cpp TypeBasedAliasAnalysis.cpp + ScopedNoAliasAA.cpp ValueTracking.cpp ) diff --git a/lib/Analysis/CaptureTracking.cpp b/lib/Analysis/CaptureTracking.cpp index f2f8877af1a2..5a5475417951 100644 --- a/lib/Analysis/CaptureTracking.cpp +++ b/lib/Analysis/CaptureTracking.cpp @@ -19,8 +19,8 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" @@ -239,7 +239,7 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { if (Count++ >= Threshold) return Tracker->tooManyUses(); - if (Visited.insert(&UU)) + if (Visited.insert(&UU).second) if (Tracker->shouldExplore(&UU)) Worklist.push_back(&UU); } diff --git a/lib/Analysis/CodeMetrics.cpp b/lib/Analysis/CodeMetrics.cpp index 4c8a093684f6..fa5683c6820d 100644 --- a/lib/Analysis/CodeMetrics.cpp +++ b/lib/Analysis/CodeMetrics.cpp @@ -11,23 +11,113 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "code-metrics" using namespace llvm; +static void completeEphemeralValues(SmallVector<const Value *, 16> &WorkSet, + SmallPtrSetImpl<const Value*> &EphValues) { + SmallPtrSet<const Value *, 32> Visited; + + // Make sure that all of the items in WorkSet are in our EphValues set. + EphValues.insert(WorkSet.begin(), WorkSet.end()); + + // Note: We don't speculate PHIs here, so we'll miss instruction chains kept + // alive only by ephemeral values. + + while (!WorkSet.empty()) { + const Value *V = WorkSet.front(); + WorkSet.erase(WorkSet.begin()); + + if (!Visited.insert(V).second) + continue; + + // If all uses of this value are ephemeral, then so is this value. + bool FoundNEUse = false; + for (const User *I : V->users()) + if (!EphValues.count(I)) { + FoundNEUse = true; + break; + } + + if (FoundNEUse) + continue; + + EphValues.insert(V); + DEBUG(dbgs() << "Ephemeral Value: " << *V << "\n"); + + if (const User *U = dyn_cast<User>(V)) + for (const Value *J : U->operands()) { + if (isSafeToSpeculativelyExecute(J)) + WorkSet.push_back(J); + } + } +} + +// Find all ephemeral values. 
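+//
+// For example (an illustrative snippet), %cmp below is ephemeral: its only
+// use is feeding the assumption, so it disappears with the assume and should
+// not count toward size metrics:
+//   %cmp = icmp ugt i32 %n, 7
+//   call void @llvm.assume(i1 %cmp)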
+void CodeMetrics::collectEphemeralValues(
+    const Loop *L, AssumptionCache *AC,
+    SmallPtrSetImpl<const Value *> &EphValues) {
+  SmallVector<const Value *, 16> WorkSet;
+
+  for (auto &AssumeVH : AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    Instruction *I = cast<Instruction>(AssumeVH);
+
+    // Filter out call sites outside of the loop so we don't do a function's
+    // worth of work for each of its loops (and, in the common case, ephemeral
+    // values in the loop are likely due to @llvm.assume calls in the loop).
+    if (!L->contains(I->getParent()))
+      continue;
+
+    WorkSet.push_back(I);
+  }
+
+  completeEphemeralValues(WorkSet, EphValues);
+}
+
+void CodeMetrics::collectEphemeralValues(
+    const Function *F, AssumptionCache *AC,
+    SmallPtrSetImpl<const Value *> &EphValues) {
+  SmallVector<const Value *, 16> WorkSet;
+
+  for (auto &AssumeVH : AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    Instruction *I = cast<Instruction>(AssumeVH);
+    assert(I->getParent()->getParent() == F &&
+           "Found assumption for the wrong function!");
+    WorkSet.push_back(I);
+  }
+
+  completeEphemeralValues(WorkSet, EphValues);
+}
+
 /// analyzeBasicBlock - Fill in the current structure with information gleaned
 /// from the specified block.
 void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB,
-                                    const TargetTransformInfo &TTI) {
+                                    const TargetTransformInfo &TTI,
+                                    SmallPtrSetImpl<const Value*> &EphValues) {
   ++NumBlocks;
   unsigned NumInstsBeforeThisBB = NumInsts;
   for (BasicBlock::const_iterator II = BB->begin(), E = BB->end();
        II != E; ++II) {
+    // Skip ephemeral values.
+    if (EphValues.count(II))
+      continue;
+
     // Special handling for calls.
     if (isa<CallInst>(II) || isa<InvokeInst>(II)) {
       ImmutableCallSite CS(cast<Instruction>(II));
diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp
index 8dc94219027f..fd8f2aee2db5 100644
--- a/lib/Analysis/ConstantFolding.cpp
+++ b/lib/Analysis/ConstantFolding.cpp
@@ -47,15 +47,16 @@ using namespace llvm;
 // Constant Folding internal helper functions
 //===----------------------------------------------------------------------===//
 
-/// FoldBitCast - Constant fold bitcast, symbolically evaluating it with
-/// DataLayout. This always returns a non-null constant, but it may be a
+/// Constant fold bitcast, symbolically evaluating it with DataLayout.
+/// This always returns a non-null constant, but it may be a
 /// ConstantExpr if unfoldable.
 static Constant *FoldBitCast(Constant *C, Type *DestTy,
                              const DataLayout &TD) {
   // Catch the obvious splat cases.
   if (C->isNullValue() && !DestTy->isX86_MMXTy())
     return Constant::getNullValue(DestTy);
-  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy())
+  if (C->isAllOnesValue() && !DestTy->isX86_MMXTy() &&
+      !DestTy->isPtrOrPtrVectorTy()) // Don't get ones for ptr types!
     return Constant::getAllOnesValue(DestTy);
 
   // Handle a vector->integer cast.
@@ -197,7 +198,7 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy,
 
     // Handle: bitcast (<2 x i64> <i64 0, i64 1> to <4 x i32>)
     unsigned Ratio = NumDstElt/NumSrcElt;
-    unsigned DstBitSize = DstEltTy->getPrimitiveSizeInBits();
+    unsigned DstBitSize = TD.getTypeSizeInBits(DstEltTy);
 
     // Loop over each source value, expanding into multiple results.
     for (unsigned i = 0; i != NumSrcElt; ++i) {
@@ -213,6 +214,15 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy,
                                   ConstantInt::get(Src->getType(), ShiftAmt));
         ShiftAmt += isLittleEndian ?
DstBitSize : -DstBitSize; + // Truncate the element to an integer with the same pointer size and + // convert the element back to a pointer using a inttoptr. + if (DstEltTy->isPointerTy()) { + IntegerType *DstIntTy = Type::getIntNTy(C->getContext(), DstBitSize); + Constant *CE = ConstantExpr::getTrunc(Elt, DstIntTy); + Result.push_back(ConstantExpr::getIntToPtr(CE, DstEltTy)); + continue; + } + // Truncate and remember this piece. Result.push_back(ConstantExpr::getTrunc(Elt, DstEltTy)); } @@ -222,9 +232,8 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, } -/// IsConstantOffsetFromGlobal - If this constant is actually a constant offset -/// from a global, return the global and the constant. Because of -/// constantexprs, this function is recursive. +/// If this constant is a constant offset from a global, return the global and +/// the constant. Because of constantexprs, this function is recursive. static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, APInt &Offset, const DataLayout &TD) { // Trivial case, constant is the global. @@ -264,10 +273,10 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, return true; } -/// ReadDataFromGlobal - Recursive helper to read bits out of global. C is the -/// constant being copied out of. ByteOffset is an offset into C. CurPtr is the -/// pointer to copy results into and BytesLeft is the number of bytes left in -/// the CurPtr buffer. TD is the target data. +/// Recursive helper to read bits out of global. C is the constant being copied +/// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy +/// results into and BytesLeft is the number of bytes left in +/// the CurPtr buffer. TD is the target data. static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, unsigned BytesLeft, const DataLayout &TD) { @@ -518,9 +527,8 @@ static Constant *ConstantFoldLoadThroughBitcast(ConstantExpr *CE, return nullptr; } -/// ConstantFoldLoadFromConstPtr - Return the value that a load from C would -/// produce if it is constant and determinable. If this is not determinable, -/// return null. +/// Return the value that a load from C would produce if it is constant and +/// determinable. If this is not determinable, return null. Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, const DataLayout *TD) { // First, try the easy cases: @@ -610,7 +618,7 @@ static Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout *TD){ return nullptr; } -/// SymbolicallyEvaluateBinop - One of Op0/Op1 is a constant expression. +/// One of Op0/Op1 is a constant expression. /// Attempt to symbolically evaluate the result of a binary operator merging /// these together. If target data info is available, it is provided as DL, /// otherwise DL is null. @@ -667,9 +675,8 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, return nullptr; } -/// CastGEPIndices - If array indices are not pointer-sized integers, -/// explicitly cast them so that they aren't implicitly casted by the -/// getelementptr. +/// If array indices are not pointer-sized integers, explicitly cast them so +/// that they aren't implicitly casted by the getelementptr. static Constant *CastGEPIndices(ArrayRef<Constant *> Ops, Type *ResultTy, const DataLayout *TD, const TargetLibraryInfo *TLI) { @@ -724,8 +731,7 @@ static Constant* StripPtrCastKeepAS(Constant* Ptr) { return Ptr; } -/// SymbolicallyEvaluateGEP - If we can symbolically evaluate the specified GEP -/// constant expression, do so. 
+/// If we can symbolically evaluate the GEP constant expression, do so.
 static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
                                          Type *ResultTy,
                                          const DataLayout *TD,
                                          const TargetLibraryInfo *TLI) {
@@ -887,7 +893,7 @@ static Constant *SymbolicallyEvaluateGEP(ArrayRef<Constant *> Ops,
 // Constant Folding public APIs
 //===----------------------------------------------------------------------===//
 
-/// ConstantFoldInstruction - Try to constant fold the specified instruction.
+/// Try to constant fold the specified instruction.
 /// If successful, the constant result is returned, if not, null is returned.
 /// Note that this fails if not all of the operands are constant. Otherwise,
 /// this function can only fail when attempting to fold instructions like loads
@@ -967,7 +973,7 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I,
 static Constant *
 ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD,
                                    const TargetLibraryInfo *TLI,
-                                   SmallPtrSet<ConstantExpr *, 4> &FoldedOps) {
+                                   SmallPtrSetImpl<ConstantExpr *> &FoldedOps) {
   SmallVector<Constant *, 8> Ops;
   for (User::const_op_iterator i = CE->op_begin(), e = CE->op_end(); i != e;
        ++i) {
@@ -975,7 +981,7 @@ ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD,
     // Recursively fold the ConstantExpr's operands. If we have already folded
     // a ConstantExpr, we don't have to process it again.
     if (ConstantExpr *NewCE = dyn_cast<ConstantExpr>(NewC)) {
-      if (FoldedOps.insert(NewCE))
+      if (FoldedOps.insert(NewCE).second)
         NewC = ConstantFoldConstantExpressionImpl(NewCE, TD, TLI, FoldedOps);
     }
     Ops.push_back(NewC);
@@ -987,7 +993,7 @@ ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout *TD,
   return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(), Ops, TD, TLI);
 }
 
-/// ConstantFoldConstantExpression - Attempt to fold the constant expression
+/// Attempt to fold the constant expression
 /// using the specified DataLayout. If successful, the constant result is
 /// returned, if not, null is returned.
 Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE,
@@ -997,7 +1003,7 @@ Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE,
   return ConstantFoldConstantExpressionImpl(CE, TD, TLI, FoldedOps);
 }
 
-/// ConstantFoldInstOperands - Attempt to constant fold an instruction with the
+/// Attempt to constant fold an instruction with the
 /// specified opcode and operands. If successful, the constant result is
 /// returned, if not, null is returned. Note that this function can fail when
 /// attempting to fold instructions like loads and stores, which have no
@@ -1102,10 +1108,9 @@ Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy,
   }
 }
 
-/// ConstantFoldCompareInstOperands - Attempt to constant fold a compare
+/// Attempt to constant fold a compare
 /// instruction (icmp/fcmp) with the specified operands. If it fails, it
 /// returns a constant expression of the specified operands.
-///
 Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
                                                 Constant *Ops0, Constant *Ops1,
                                                 const DataLayout *TD,
@@ -1192,9 +1197,9 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
 }
 
 
-/// ConstantFoldLoadThroughGEPConstantExpr - Given a constant and a
-/// getelementptr constantexpr, return the constant value being addressed by the
-/// constant expression, or null if something is funny and we can't decide.
+/// Given a constant and a getelementptr constantexpr, return the constant value +/// being addressed by the constant expression, or null if something is funny +/// and we can't decide. Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, ConstantExpr *CE) { if (!CE->getOperand(1)->isNullValue()) @@ -1210,10 +1215,9 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C, return C; } -/// ConstantFoldLoadThroughGEPIndices - Given a constant and getelementptr -/// indices (with an *implied* zero pointer index that is not in the list), -/// return the constant value being addressed by a virtual load, or null if -/// something is funny and we can't decide. +/// Given a constant and getelementptr indices (with an *implied* zero pointer +/// index that is not in the list), return the constant value being addressed by +/// a virtual load, or null if something is funny and we can't decide. Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, ArrayRef<Constant*> Indices) { // Loop over all of the operands, tracking down which value we are @@ -1231,11 +1235,12 @@ Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, // Constant Folding for Calls // -/// canConstantFoldCallTo - Return true if its even possible to fold a call to -/// the specified function. +/// Return true if it's even possible to fold a call to the specified function. bool llvm::canConstantFoldCallTo(const Function *F) { switch (F->getIntrinsicID()) { case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: case Intrinsic::log: case Intrinsic::log2: case Intrinsic::log10: @@ -1321,7 +1326,7 @@ static Constant *GetConstantFoldFPValue(double V, Type *Ty) { } namespace { -/// llvm_fenv_clearexcept - Clear the floating-point exception state. +/// Clear the floating-point exception state. static inline void llvm_fenv_clearexcept() { #if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT feclearexcept(FE_ALL_EXCEPT); @@ -1329,7 +1334,7 @@ static inline void llvm_fenv_clearexcept() { errno = 0; } -/// llvm_fenv_testexcept - Test if a floating-point exception was raised. +/// Test if a floating-point exception was raised. static inline bool llvm_fenv_testexcept() { int errno_val = errno; if (errno_val == ERANGE || errno_val == EDOM) @@ -1366,14 +1371,13 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), return GetConstantFoldFPValue(V, Ty); } -/// ConstantFoldConvertToInt - Attempt to an SSE floating point to integer -/// conversion of a constant floating point. If roundTowardZero is false, the -/// default IEEE rounding is used (toward nearest, ties to even). This matches -/// the behavior of the non-truncating SSE instructions in the default rounding -/// mode. The desired integer type Ty is used to select how many bits are -/// available for the result. Returns null if the conversion cannot be -/// performed, otherwise returns the Constant value resulting from the -/// conversion. +/// Attempt to fold an SSE floating point to integer conversion of a constant +/// floating point. If roundTowardZero is false, the default IEEE rounding is +/// used (toward nearest, ties to even). This matches the behavior of the +/// non-truncating SSE instructions in the default rounding mode. The desired +/// integer type Ty is used to select how many bits are available for the +/// result. Returns null if the conversion cannot be performed, otherwise +/// returns the Constant value resulting from the conversion. 
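+/// For example, 1.5 converts to 2 under the default rounding (ties to even)
+/// but to 1 when roundTowardZero is true, while 2.5 converts to 2 either
+/// way, since the tie goes to the even neighbor.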
 static Constant *ConstantFoldConvertToInt(const APFloat &Val,
                                           bool roundTowardZero, Type *Ty) {
   // All of these conversion intrinsics form an integer of at most 64bits.
@@ -1520,8 +1524,14 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID,
             (Ty->isHalfTy() || Ty->isFloatTy() || Ty->isDoubleTy())) {
           if (V >= -0.0)
             return ConstantFoldFP(sqrt, V, Ty);
-          else // Undefined
-            return Constant::getNullValue(Ty);
+          else {
+            // Unlike the sqrt definitions in C/C++, POSIX, and IEEE-754 - which
+            // all guarantee or favor returning NaN - the square root of a
+            // negative number is not defined for the LLVM sqrt intrinsic.
+            // This is because the intrinsic should only be emitted in place of
+            // libm's sqrt function when using "no-nans-fp-math".
+            return UndefValue::get(Ty);
+          }
         }
         break;
       case 's':
@@ -1627,6 +1637,19 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID,
         V1.copySign(V2);
         return ConstantFP::get(Ty->getContext(), V1);
       }
+
+      if (IntrinsicID == Intrinsic::minnum) {
+        const APFloat &C1 = Op1->getValueAPF();
+        const APFloat &C2 = Op2->getValueAPF();
+        return ConstantFP::get(Ty->getContext(), minnum(C1, C2));
+      }
+
+      if (IntrinsicID == Intrinsic::maxnum) {
+        const APFloat &C1 = Op1->getValueAPF();
+        const APFloat &C2 = Op2->getValueAPF();
+        return ConstantFP::get(Ty->getContext(), maxnum(C1, C2));
+      }
+
       if (!TLI)
         return nullptr;
       if (Name == "pow" && TLI->has(LibFunc::pow))
@@ -1762,7 +1785,7 @@ static Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID,
   return ConstantVector::get(Result);
 }
 
-/// ConstantFoldCall - Attempt to constant fold a call to the specified function
+/// Attempt to constant fold a call to the specified function
 /// with the specified arguments, returning null if unsuccessful.
 Constant *
 llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
diff --git a/lib/Analysis/DependenceAnalysis.cpp b/lib/Analysis/DependenceAnalysis.cpp
index d0784f1e678d..092df5c15cde 100644
--- a/lib/Analysis/DependenceAnalysis.cpp
+++ b/lib/Analysis/DependenceAnalysis.cpp
@@ -163,16 +163,15 @@ void dumpExampleDependence(raw_ostream &OS, Function *F,
            DstI != DstE; ++DstI) {
         if (isa<StoreInst>(*DstI) || isa<LoadInst>(*DstI)) {
           OS << "da analyze - ";
-          if (Dependence *D = DA->depends(&*SrcI, &*DstI, true)) {
+          if (auto D = DA->depends(&*SrcI, &*DstI, true)) {
             D->dump(OS);
             for (unsigned Level = 1; Level <= D->getLevels(); Level++) {
               if (D->isSplitable(Level)) {
                 OS << "da analyze - split level = " << Level;
-                OS << ", iteration = " << *DA->getSplitIteration(D, Level);
+                OS << ", iteration = " << *DA->getSplitIteration(*D, Level);
                 OS << "!\n";
               }
             }
-            delete D;
           }
           else
             OS << "none!\n";
@@ -782,6 +781,25 @@ void DependenceAnalysis::collectCommonLoops(const SCEV *Expression,
   }
 }
 
+void DependenceAnalysis::unifySubscriptType(Subscript *Pair) {
+  const SCEV *Src = Pair->Src;
+  const SCEV *Dst = Pair->Dst;
+  IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
+  IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
+  if (SrcTy == nullptr || DstTy == nullptr) {
+    assert(SrcTy == DstTy && "This function only unifies integer types and "
+                             "expects Src and Dst to share the same type "
+                             "otherwise.");
+    return;
+  }
+  if (SrcTy->getBitWidth() > DstTy->getBitWidth()) {
+    // Sign-extend Dst to typeof(Src) if typeof(Src) is wider than typeof(Dst).
+ Pair->Dst = SE->getSignExtendExpr(Dst, SrcTy); + } else if (SrcTy->getBitWidth() < DstTy->getBitWidth()) { + // Sign-extend Src to typeof(Dst) if typeof(Dst) is wider than typeof(Src). + Pair->Src = SE->getSignExtendExpr(Src, DstTy); + } +} // removeMatchingExtensions - Examines a subscript pair. // If the source and destination are identically sign (or zero) @@ -794,9 +812,11 @@ void DependenceAnalysis::removeMatchingExtensions(Subscript *Pair) { (isa<SCEVSignExtendExpr>(Src) && isa<SCEVSignExtendExpr>(Dst))) { const SCEVCastExpr *SrcCast = cast<SCEVCastExpr>(Src); const SCEVCastExpr *DstCast = cast<SCEVCastExpr>(Dst); - if (SrcCast->getType() == DstCast->getType()) { - Pair->Src = SrcCast->getOperand(); - Pair->Dst = DstCast->getOperand(); + const SCEV *SrcCastOp = SrcCast->getOperand(); + const SCEV *DstCastOp = DstCast->getOperand(); + if (SrcCastOp->getType() == DstCastOp->getType()) { + Pair->Src = SrcCastOp; + Pair->Dst = DstCastOp; } } } @@ -2957,15 +2977,11 @@ const SCEV *DependenceAnalysis::addToCoefficient(const SCEV *Expr, AddRec->getNoWrapFlags()); } if (SE->isLoopInvariant(AddRec, TargetLoop)) - return SE->getAddRecExpr(AddRec, - Value, - TargetLoop, - SCEV::FlagAnyWrap); - return SE->getAddRecExpr(addToCoefficient(AddRec->getStart(), - TargetLoop, Value), - AddRec->getStepRecurrence(*SE), - AddRec->getLoop(), - AddRec->getNoWrapFlags()); + return SE->getAddRecExpr(AddRec, Value, TargetLoop, SCEV::FlagAnyWrap); + return SE->getAddRecExpr( + addToCoefficient(AddRec->getStart(), TargetLoop, Value), + AddRec->getStepRecurrence(*SE), AddRec->getLoop(), + AddRec->getNoWrapFlags()); } @@ -3183,7 +3199,7 @@ void DependenceAnalysis::updateDirection(Dependence::DVEntry &Level, bool DependenceAnalysis::tryDelinearize(const SCEV *SrcSCEV, const SCEV *DstSCEV, SmallVectorImpl<Subscript> &Pair, - const SCEV *ElementSize) const { + const SCEV *ElementSize) { const SCEVUnknown *SrcBase = dyn_cast<SCEVUnknown>(SE->getPointerBase(SrcSCEV)); const SCEVUnknown *DstBase = @@ -3238,6 +3254,7 @@ bool DependenceAnalysis::tryDelinearize(const SCEV *SrcSCEV, for (int i = 0; i < size; ++i) { Pair[i].Src = SrcSubscripts[i]; Pair[i].Dst = DstSubscripts[i]; + unifySubscriptType(&Pair[i]); // FIXME: we should record the bounds SrcSizes[i] and DstSizes[i] that the // delinearization has found, and add these constraints to the dependence @@ -3277,9 +3294,9 @@ static void dumpSmallBitVector(SmallBitVector &BV) { // // Care is required to keep the routine below, getSplitIteration(), // up to date with respect to this routine. -Dependence *DependenceAnalysis::depends(Instruction *Src, - Instruction *Dst, - bool PossiblyLoopIndependent) { +std::unique_ptr<Dependence> +DependenceAnalysis::depends(Instruction *Src, Instruction *Dst, + bool PossiblyLoopIndependent) { if (Src == Dst) PossiblyLoopIndependent = false; @@ -3291,7 +3308,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, if (!isLoadOrStore(Src) || !isLoadOrStore(Dst)) { // can only analyze simple loads and stores, i.e., no calls, invokes, etc. DEBUG(dbgs() << "can only handle simple loads and stores\n"); - return new Dependence(Src, Dst); + return make_unique<Dependence>(Src, Dst); } Value *SrcPtr = getPointerOperand(Src); @@ -3302,7 +3319,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, case AliasAnalysis::PartialAlias: // cannot analyse objects if we don't understand their aliasing. 
DEBUG(dbgs() << "can't analyze may or partial alias\n"); - return new Dependence(Src, Dst); + return make_unique<Dependence>(Src, Dst); case AliasAnalysis::NoAlias: // If the objects noalias, they are distinct, accesses are independent. DEBUG(dbgs() << "no alias\n"); @@ -3346,6 +3363,7 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, ++SrcIdx, ++DstIdx, ++P) { Pair[P].Src = SE->getSCEV(*SrcIdx); Pair[P].Dst = SE->getSCEV(*DstIdx); + unifySubscriptType(&Pair[P]); } } else { @@ -3675,9 +3693,9 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, return nullptr; } - FullDependence *Final = new FullDependence(Result); + auto Final = make_unique<FullDependence>(Result); Result.DV = nullptr; - return Final; + return std::move(Final); } @@ -3729,13 +3747,12 @@ Dependence *DependenceAnalysis::depends(Instruction *Src, // // breaks the dependence and allows us to vectorize/parallelize // both loops. -const SCEV *DependenceAnalysis::getSplitIteration(const Dependence *Dep, +const SCEV *DependenceAnalysis::getSplitIteration(const Dependence &Dep, unsigned SplitLevel) { - assert(Dep && "expected a pointer to a Dependence"); - assert(Dep->isSplitable(SplitLevel) && + assert(Dep.isSplitable(SplitLevel) && "Dep should be splitable at SplitLevel"); - Instruction *Src = Dep->getSrc(); - Instruction *Dst = Dep->getDst(); + Instruction *Src = Dep.getSrc(); + Instruction *Dst = Dep.getDst(); assert(Src->mayReadFromMemory() || Src->mayWriteToMemory()); assert(Dst->mayReadFromMemory() || Dst->mayWriteToMemory()); assert(isLoadOrStore(Src)); diff --git a/lib/Analysis/FunctionTargetTransformInfo.cpp b/lib/Analysis/FunctionTargetTransformInfo.cpp new file mode 100644 index 000000000000..a686bec4e73f --- /dev/null +++ b/lib/Analysis/FunctionTargetTransformInfo.cpp @@ -0,0 +1,50 @@ +//===- llvm/Analysis/FunctionTargetTransformInfo.h --------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass wraps a TargetTransformInfo in a FunctionPass so that it can +// forward along the current Function so that we can make target specific +// decisions based on the particular subtarget specified for each Function. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/InitializePasses.h" +#include "llvm/Analysis/FunctionTargetTransformInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "function-tti" +static const char ftti_name[] = "Function TargetTransformInfo"; +INITIALIZE_PASS_BEGIN(FunctionTargetTransformInfo, "function_tti", ftti_name, false, true) +INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_END(FunctionTargetTransformInfo, "function_tti", ftti_name, false, true) +char FunctionTargetTransformInfo::ID = 0; + +namespace llvm { +FunctionPass *createFunctionTargetTransformInfoPass() { + return new FunctionTargetTransformInfo(); +} +} + +FunctionTargetTransformInfo::FunctionTargetTransformInfo() + : FunctionPass(ID), Fn(nullptr), TTI(nullptr) { + initializeFunctionTargetTransformInfoPass(*PassRegistry::getPassRegistry()); +} + +void FunctionTargetTransformInfo::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<TargetTransformInfo>(); +} + +void FunctionTargetTransformInfo::releaseMemory() {} + +bool FunctionTargetTransformInfo::runOnFunction(Function &F) { + Fn = &F; + TTI = &getAnalysis<TargetTransformInfo>(); + return false; +} diff --git a/lib/Analysis/IPA/CMakeLists.txt b/lib/Analysis/IPA/CMakeLists.txt index 67b413577980..6095136d60a1 100644 --- a/lib/Analysis/IPA/CMakeLists.txt +++ b/lib/Analysis/IPA/CMakeLists.txt @@ -2,7 +2,6 @@ add_llvm_library(LLVMipa CallGraph.cpp CallGraphSCCPass.cpp CallPrinter.cpp - FindUsedTypes.cpp GlobalsModRef.cpp IPA.cpp InlineCost.cpp diff --git a/lib/Analysis/IPA/CallGraph.cpp b/lib/Analysis/IPA/CallGraph.cpp index dfabb0ad063d..67cf7f86e072 100644 --- a/lib/Analysis/IPA/CallGraph.cpp +++ b/lib/Analysis/IPA/CallGraph.cpp @@ -282,6 +282,3 @@ void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); } #endif - -// Enuse that users of CallGraph.h also link with this file -DEFINING_FILE_FOR(CallGraph) diff --git a/lib/Analysis/IPA/CallGraphSCCPass.cpp b/lib/Analysis/IPA/CallGraphSCCPass.cpp index c27edbfa2ff5..ded1de79a124 100644 --- a/lib/Analysis/IPA/CallGraphSCCPass.cpp +++ b/lib/Analysis/IPA/CallGraphSCCPass.cpp @@ -21,8 +21,8 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LegacyPassManagers.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManagers.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" @@ -243,7 +243,14 @@ bool CGPassManager::RefreshCallGraph(CallGraphSCC &CurSCC, assert(!CallSites.count(I->first) && "Call site occurs in node multiple times"); - CallSites.insert(std::make_pair(I->first, I->second)); + + CallSite CS(I->first); + if (CS) { + Function *Callee = CS.getCalledFunction(); + // Ignore intrinsics because they're not really function calls. 
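      // (Illustrative note: e.g. calls to @llvm.dbg.value or @llvm.assume are
      // intrinsics, so they never become real edges in the call graph.)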
+ if (!Callee || !(Callee->isIntrinsic())) + CallSites.insert(std::make_pair(I->first, I->second)); + } ++I; } diff --git a/lib/Analysis/IPA/FindUsedTypes.cpp b/lib/Analysis/IPA/FindUsedTypes.cpp deleted file mode 100644 index b37344b044ce..000000000000 --- a/lib/Analysis/IPA/FindUsedTypes.cpp +++ /dev/null @@ -1,100 +0,0 @@ -//===- FindUsedTypes.cpp - Find all Types used by a module ----------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This pass is used to seek out all of the types in use by the program. Note -// that this analysis explicitly does not include types only used by the symbol -// table. -// -//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/FindUsedTypes.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/Module.h" -#include "llvm/Support/raw_ostream.h" -using namespace llvm; - -char FindUsedTypes::ID = 0; -INITIALIZE_PASS(FindUsedTypes, "print-used-types", - "Find Used Types", false, true) - -// IncorporateType - Incorporate one type and all of its subtypes into the -// collection of used types. -// -void FindUsedTypes::IncorporateType(Type *Ty) { - // If ty doesn't already exist in the used types map, add it now, otherwise - // return. - if (!UsedTypes.insert(Ty)) return; // Already contain Ty. - - // Make sure to add any types this type references now. - // - for (Type::subtype_iterator I = Ty->subtype_begin(), E = Ty->subtype_end(); - I != E; ++I) - IncorporateType(*I); -} - -void FindUsedTypes::IncorporateValue(const Value *V) { - IncorporateType(V->getType()); - - // If this is a constant, it could be using other types... - if (const Constant *C = dyn_cast<Constant>(V)) { - if (!isa<GlobalValue>(C)) - for (User::const_op_iterator OI = C->op_begin(), OE = C->op_end(); - OI != OE; ++OI) - IncorporateValue(*OI); - } -} - - -// run - This incorporates all types used by the specified module -// -bool FindUsedTypes::runOnModule(Module &m) { - UsedTypes.clear(); // reset if run multiple times... - - // Loop over global variables, incorporating their types - for (Module::const_global_iterator I = m.global_begin(), E = m.global_end(); - I != E; ++I) { - IncorporateType(I->getType()); - if (I->hasInitializer()) - IncorporateValue(I->getInitializer()); - } - - for (Module::iterator MI = m.begin(), ME = m.end(); MI != ME; ++MI) { - IncorporateType(MI->getType()); - const Function &F = *MI; - - // Loop over all of the instructions in the function, adding their return - // type as well as the types of their operands. - // - for (const_inst_iterator II = inst_begin(F), IE = inst_end(F); - II != IE; ++II) { - const Instruction &I = *II; - - IncorporateType(I.getType()); // Incorporate the type of the instruction - for (User::const_op_iterator OI = I.op_begin(), OE = I.op_end(); - OI != OE; ++OI) - IncorporateValue(*OI); // Insert inst operand types as well - } - } - - return false; -} - -// Print the types found in the module. If the optional Module parameter is -// passed in, then the types are printed symbolically if possible, using the -// symbol table from the module. 
-// -void FindUsedTypes::print(raw_ostream &OS, const Module *M) const { - OS << "Types in use by this module:\n"; - for (SetVector<Type *>::const_iterator I = UsedTypes.begin(), - E = UsedTypes.end(); I != E; ++I) { - OS << " " << **I << '\n'; - } -} diff --git a/lib/Analysis/IPA/IPA.cpp b/lib/Analysis/IPA/IPA.cpp index b26c052de67e..806bfb81b6d5 100644 --- a/lib/Analysis/IPA/IPA.cpp +++ b/lib/Analysis/IPA/IPA.cpp @@ -22,7 +22,6 @@ void llvm::initializeIPA(PassRegistry &Registry) { initializeCallGraphWrapperPassPass(Registry); initializeCallGraphPrinterPass(Registry); initializeCallGraphViewerPass(Registry); - initializeFindUsedTypesPass(Registry); initializeGlobalsModRefPass(Registry); } diff --git a/lib/Analysis/IPA/InlineCost.cpp b/lib/Analysis/IPA/InlineCost.cpp index 8807529cabac..58ac5d33409e 100644 --- a/lib/Analysis/IPA/InlineCost.cpp +++ b/lib/Analysis/IPA/InlineCost.cpp @@ -17,6 +17,8 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -49,6 +51,9 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { /// The TargetTransformInfo available for this compilation. const TargetTransformInfo &TTI; + /// The cache of @llvm.assume intrinsics. + AssumptionCache &AC; + // The called function. Function &F; @@ -104,7 +109,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { ConstantInt *stripAndComputeInBoundsConstantOffsets(Value *&V); // Custom analysis routines. - bool analyzeBlock(BasicBlock *BB); + bool analyzeBlock(BasicBlock *BB, SmallPtrSetImpl<const Value *> &EphValues); // Disable several entry points to the visitor so we don't accidentally use // them by declaring but not defining them here. @@ -141,8 +146,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { public: CallAnalyzer(const DataLayout *DL, const TargetTransformInfo &TTI, - Function &Callee, int Threshold) - : DL(DL), TTI(TTI), F(Callee), Threshold(Threshold), Cost(0), + AssumptionCache &AC, Function &Callee, int Threshold) + : DL(DL), TTI(TTI), AC(AC), F(Callee), Threshold(Threshold), Cost(0), IsCallerRecursive(false), IsRecursiveCall(false), ExposesReturnsTwice(false), HasDynamicAlloca(false), ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false), @@ -778,7 +783,7 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { // during devirtualization and so we want to give it a hefty bonus for // inlining, but cap that bonus in the event that inlining wouldn't pan // out. Pretend to inline the function, with a custom threshold. - CallAnalyzer CA(DL, TTI, *F, InlineConstants::IndirectCallThreshold); + CallAnalyzer CA(DL, TTI, AC, *F, InlineConstants::IndirectCallThreshold); if (CA.analyzeCall(CS)) { // We were able to inline the indirect call! Subtract the cost from the // bonus we want to apply, but don't go below zero. @@ -881,7 +886,8 @@ bool CallAnalyzer::visitInstruction(Instruction &I) { /// aborts early if the threshold has been exceeded or an impossible to inline /// construct has been detected. It returns false if inlining is no longer /// viable, and true if inlining remains viable. 
-bool CallAnalyzer::analyzeBlock(BasicBlock *BB) { +bool CallAnalyzer::analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { // FIXME: Currently, the number of instructions in a function regardless of // our ability to simplify them during inline to constants or dead code, @@ -893,6 +899,10 @@ bool CallAnalyzer::analyzeBlock(BasicBlock *BB) { if (isa<DbgInfoIntrinsic>(I)) continue; + // Skip ephemeral values. + if (EphValues.count(I)) + continue; + ++NumInstructions; if (isa<ExtractElementInst>(I) || I->getType()->isVectorTy()) ++NumVectorInstructions; @@ -967,7 +977,7 @@ ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) { break; } assert(V->getType()->isPointerTy() && "Unexpected operand type!"); - } while (Visited.insert(V)); + } while (Visited.insert(V).second); Type *IntPtrTy = DL->getIntPtrType(V->getContext()); return cast<ConstantInt>(ConstantInt::get(IntPtrTy, Offset)); @@ -1096,6 +1106,12 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { NumConstantOffsetPtrArgs = ConstantOffsetPtrs.size(); NumAllocaArgs = SROAArgValues.size(); + // FIXME: If a caller has multiple calls to a callee, we end up recomputing + // the ephemeral values multiple times (and they're completely determined by + // the callee, so this is purely duplicate work). + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(&F, &AC, EphValues); + // The worklist of live basic blocks in the callee *after* inlining. We avoid // adding basic blocks of the callee which can be proven to be dead for this // particular call site in order to get more accurate cost estimates. This @@ -1129,7 +1145,7 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // Analyze the cost of this block. If we blow through the threshold, this // returns false, and we can bail on out. - if (!analyzeBlock(BB)) { + if (!analyzeBlock(BB, EphValues)) { if (IsRecursiveCall || ExposesReturnsTwice || HasDynamicAlloca || HasIndirectBr) return false; @@ -1217,6 +1233,7 @@ void CallAnalyzer::dump() { INITIALIZE_PASS_BEGIN(InlineCostAnalysis, "inline-cost", "Inline Cost Analysis", true, true) INITIALIZE_AG_DEPENDENCY(TargetTransformInfo) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(InlineCostAnalysis, "inline-cost", "Inline Cost Analysis", true, true) @@ -1228,12 +1245,14 @@ InlineCostAnalysis::~InlineCostAnalysis() {} void InlineCostAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfo>(); CallGraphSCCPass::getAnalysisUsage(AU); } bool InlineCostAnalysis::runOnSCC(CallGraphSCC &SCC) { TTI = &getAnalysis<TargetTransformInfo>(); + ACT = &getAnalysis<AssumptionCacheTracker>(); return false; } @@ -1290,7 +1309,8 @@ InlineCost InlineCostAnalysis::getInlineCost(CallSite CS, Function *Callee, DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName() << "...\n"); - CallAnalyzer CA(Callee->getDataLayout(), *TTI, *Callee, Threshold); + CallAnalyzer CA(Callee->getDataLayout(), *TTI, + ACT->getAssumptionCache(*Callee), *Callee, Threshold); bool ShouldInline = CA.analyzeCall(CS); DEBUG(CA.dump()); diff --git a/lib/Analysis/IVUsers.cpp b/lib/Analysis/IVUsers.cpp index 24655aa002c8..6b5f37002b80 100644 --- a/lib/Analysis/IVUsers.cpp +++ b/lib/Analysis/IVUsers.cpp @@ -84,7 +84,7 @@ static bool isInteresting(const SCEV *S, const Instruction *I, const Loop *L, /// form. 
static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, const LoopInfo *LI, - SmallPtrSet<Loop*,16> &SimpleLoopNests) { + SmallPtrSetImpl<Loop*> &SimpleLoopNests) { Loop *NearestLoop = nullptr; for (DomTreeNode *Rung = DT->getNode(BB); Rung; Rung = Rung->getIDom()) { @@ -112,10 +112,10 @@ static bool isSimplifiedLoopNest(BasicBlock *BB, const DominatorTree *DT, /// reducible SCEV, recursively add its users to the IVUsesByStride set and /// return true. Otherwise, return false. bool IVUsers::AddUsersImpl(Instruction *I, - SmallPtrSet<Loop*,16> &SimpleLoopNests) { + SmallPtrSetImpl<Loop*> &SimpleLoopNests) { // Add this IV user to the Processed set before returning false to ensure that // all IV users are members of the set. See IVUsers::isIVUserOrOperand. - if (!Processed.insert(I)) + if (!Processed.insert(I).second) return true; // Instruction already handled. if (!SE->isSCEVable(I->getType())) @@ -145,7 +145,7 @@ bool IVUsers::AddUsersImpl(Instruction *I, SmallPtrSet<Instruction *, 4> UniqueUsers; for (Use &U : I->uses()) { Instruction *User = cast<Instruction>(U.getUser()); - if (!UniqueUsers.insert(User)) + if (!UniqueUsers.insert(User).second) continue; // Do not infinitely recurse on PHI nodes. diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 7a820a58f6f0..3fbbd7cbd22a 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -20,6 +20,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" @@ -31,6 +32,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/ValueHandle.h" +#include <algorithm> using namespace llvm; using namespace llvm::PatternMatch; @@ -41,14 +43,20 @@ enum { RecursionLimit = 3 }; STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumReassoc, "Number of reassociations"); +namespace { struct Query { const DataLayout *DL; const TargetLibraryInfo *TLI; const DominatorTree *DT; + AssumptionCache *AC; + const Instruction *CxtI; Query(const DataLayout *DL, const TargetLibraryInfo *tli, - const DominatorTree *dt) : DL(DL), TLI(tli), DT(dt) {} + const DominatorTree *dt, AssumptionCache *ac = nullptr, + const Instruction *cxti = nullptr) + : DL(DL), TLI(tli), DT(dt), AC(ac), CxtI(cxti) {} }; +} // end anonymous namespace static Value *SimplifyAndInst(Value *, Value *, const Query &, unsigned); static Value *SimplifyBinOp(unsigned, Value *, Value *, const Query &, @@ -575,8 +583,9 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -624,7 +633,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout *DL, } assert(V->getType()->getScalarType()->isPointerTy() && "Unexpected operand type!"); - } while (Visited.insert(V)); + } while (Visited.insert(V).second); Constant *OffsetIntPtr = ConstantInt::get(IntPtrTy, Offset); if (V->getType()->isVectorTy()) @@ -676,6 +685,10 @@ 
static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); + // 0 - X -> 0 if the sub is NUW. + if (isNUW && match(Op0, m_Zero())) + return Op0; + // (X + Y) - Z -> X + (Y - Z) or Y + (X - Z) if everything simplifies. // For example, (X + Y) - Y -> X; (Y + X) - Y -> X Value *X = nullptr, *Y = nullptr, *Z = Op1; @@ -769,8 +782,9 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifySubInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -946,29 +960,38 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, } Value *llvm::SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFAddInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DataLayout *DL, + const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyFAddInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } Value *llvm::SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, - const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFSubInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DataLayout *DL, + const TargetLibraryInfo *TLI, + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyFSubInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } -Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, - FastMathFlags FMF, +Value *llvm::SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFMulInst(Op0, Op1, FMF, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyFMulInst(Op0, Op1, FMF, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyMulInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyMulInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } /// SimplifyDiv - Given operands for an SDiv or UDiv, see if we can @@ -988,6 +1011,10 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (match(Op1, m_Undef())) return Op1; + // X / 0 -> undef, we don't need to preserve faults! 
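  // (In LLVM IR, integer division by zero is undefined behavior, so the
  // optimizer is free to choose any result value; undef is the most
  // permissive choice.)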
+ if (match(Op1, m_Zero())) + return UndefValue::get(Op1->getType()); + // undef / X -> 0 if (match(Op0, m_Undef())) return Constant::getNullValue(Op0->getType()); @@ -1028,6 +1055,16 @@ static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, (!isSigned && match(Op0, m_URem(m_Value(), m_Specific(Op1))))) return Constant::getNullValue(Op0->getType()); + // (X /u C1) /u C2 -> 0 if C1 * C2 overflow + ConstantInt *C1, *C2; + if (!isSigned && match(Op0, m_UDiv(m_Value(X), m_ConstantInt(C1))) && + match(Op1, m_ConstantInt(C2))) { + bool Overflow; + C1->getValue().umul_ov(C2->getValue(), Overflow); + if (Overflow) + return Constant::getNullValue(Op0->getType()); + } + // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) @@ -1055,8 +1092,10 @@ static Value *SimplifySDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifySDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifySDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } /// SimplifyUDivInst - Given operands for a UDiv, see if we can @@ -1071,8 +1110,10 @@ static Value *SimplifyUDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyUDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyUDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, @@ -1090,8 +1131,10 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFDivInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyFDivInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } /// SimplifyRem - Given operands for an SRem or URem, see if we can @@ -1133,6 +1176,13 @@ static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Op0 == Op1) return Constant::getNullValue(Op0->getType()); + // (X % Y) % Y -> X % Y + if ((Opcode == Instruction::SRem && + match(Op0, m_SRem(m_Value(), m_Specific(Op1)))) || + (Opcode == Instruction::URem && + match(Op0, m_URem(m_Value(), m_Specific(Op1))))) + return Op0; + // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. 
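  // (Illustrative example: both arms of 'srem (select %c, i32 8, i32 12), 4'
  // simplify to 0, so the whole srem folds to 0 without deciding %c.)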
  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1))
@@ -1160,8 +1210,10 @@ static Value *SimplifySRemInst(Value *Op0, Value *Op1, const Query &Q,

Value *llvm::SimplifySRemInst(Value *Op0, Value *Op1, const DataLayout *DL,
                              const TargetLibraryInfo *TLI,
-                              const DominatorTree *DT) {
-  return ::SimplifySRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit);
+                              const DominatorTree *DT, AssumptionCache *AC,
+                              const Instruction *CxtI) {
+  return ::SimplifySRemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI),
+                            RecursionLimit);
}

/// SimplifyURemInst - Given operands for a URem, see if we can
@@ -1176,8 +1228,10 @@ static Value *SimplifyURemInst(Value *Op0, Value *Op1, const Query &Q,

Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const DataLayout *DL,
                              const TargetLibraryInfo *TLI,
-                              const DominatorTree *DT) {
-  return ::SimplifyURemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit);
+                              const DominatorTree *DT, AssumptionCache *AC,
+                              const Instruction *CxtI) {
+  return ::SimplifyURemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI),
+                            RecursionLimit);
}

static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &,
@@ -1195,8 +1249,10 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, const Query &,

Value *llvm::SimplifyFRemInst(Value *Op0, Value *Op1, const DataLayout *DL,
                              const TargetLibraryInfo *TLI,
-                              const DominatorTree *DT) {
-  return ::SimplifyFRemInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit);
+                              const DominatorTree *DT, AssumptionCache *AC,
+                              const Instruction *CxtI) {
+  return ::SimplifyFRemInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI),
+                            RecursionLimit);
}

/// isUndefShift - Returns true if a shift by \c Amount always yields undef.
@@ -1264,6 +1320,37 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1,
  return nullptr;
}

+/// \brief Given operands for an LShr or AShr, see if we can
+/// fold the result. If not, this returns null.
+static Value *SimplifyRightShift(unsigned Opcode, Value *Op0, Value *Op1,
+                                 bool isExact, const Query &Q,
+                                 unsigned MaxRecurse) {
+  if (Value *V = SimplifyShift(Opcode, Op0, Op1, Q, MaxRecurse))
+    return V;
+
+  // X >> X -> 0
+  if (Op0 == Op1)
+    return Constant::getNullValue(Op0->getType());
+
+  // undef >> X -> 0
+  // undef >> X -> undef (if it's exact)
+  if (match(Op0, m_Undef()))
+    return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+
+  // The low bit cannot be shifted out of an exact shift if it is set.
+  if (isExact) {
+    unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+    APInt Op0KnownZero(BitWidth, 0);
+    APInt Op0KnownOne(BitWidth, 0);
+    computeKnownBits(Op0, Op0KnownZero, Op0KnownOne, Q.DL, /*Depth=*/0, Q.AC,
+                     Q.CxtI, Q.DT);
+    if (Op0KnownOne[0])
+      return Op0;
+  }
+
+  return nullptr;
+}
+
/// SimplifyShlInst - Given operands for an Shl, see if we can
/// fold the result. If not, this returns null.
static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
@@ -1272,8 +1359,9 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
    return V;

  // undef << X -> 0
+  // undef << X -> undef (if it's NSW/NUW)
  if (match(Op0, m_Undef()))
-    return Constant::getNullValue(Op0->getType());
+    return isNSW || isNUW ?
Op0 : Constant::getNullValue(Op0->getType()); // (X >> A) << A -> X Value *X; @@ -1284,8 +1372,9 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyShlInst(Op0, Op1, isNSW, isNUW, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -1293,21 +1382,13 @@ Value *llvm::SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, /// fold the result. If not, this returns null. static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Instruction::LShr, Op0, Op1, Q, MaxRecurse)) - return V; - - // X >> X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); - - // undef >>l X -> 0 - if (match(Op0, m_Undef())) - return Constant::getNullValue(Op0->getType()); + if (Value *V = SimplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q, + MaxRecurse)) + return V; // (X << A) >> A -> X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && - cast<OverflowingBinaryOperator>(Op0)->hasNoUnsignedWrap()) + if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1)))) return X; return nullptr; @@ -1316,8 +1397,9 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyLShrInst(Op0, Op1, isExact, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyLShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -1325,29 +1407,21 @@ Value *llvm::SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, /// fold the result. If not, this returns null. static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, const Query &Q, unsigned MaxRecurse) { - if (Value *V = SimplifyShift(Instruction::AShr, Op0, Op1, Q, MaxRecurse)) + if (Value *V = SimplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q, + MaxRecurse)) return V; - // X >> X -> 0 - if (Op0 == Op1) - return Constant::getNullValue(Op0->getType()); - // all ones >>a X -> all ones if (match(Op0, m_AllOnes())) return Op0; - // undef >>a X -> all ones - if (match(Op0, m_Undef())) - return Constant::getAllOnesValue(Op0->getType()); - // (X << A) >> A -> X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1))) && - cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap()) + if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) return X; // Arithmetic shifting an all-sign-bit value is a no-op. 
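  // (Illustrative: an i8 known to be 0 or -1 has eight sign bits, and 'ashr'
  // of either value by any in-range amount returns the value unchanged.)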
-  unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL);
+  unsigned NumSignBits = ComputeNumSignBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
  if (NumSignBits == Op0->getType()->getScalarSizeInBits())
    return Op0;
@@ -1357,11 +1431,105 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,

Value *llvm::SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
                              const DataLayout *DL,
                              const TargetLibraryInfo *TLI,
-                              const DominatorTree *DT) {
-  return ::SimplifyAShrInst(Op0, Op1, isExact, Query (DL, TLI, DT),
+                              const DominatorTree *DT, AssumptionCache *AC,
+                              const Instruction *CxtI) {
+  return ::SimplifyAShrInst(Op0, Op1, isExact, Query(DL, TLI, DT, AC, CxtI),
                            RecursionLimit);
}

+static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
+                                         ICmpInst *UnsignedICmp, bool IsAnd) {
+  Value *X, *Y;
+
+  ICmpInst::Predicate EqPred;
+  if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(Y), m_Zero())) ||
+      !ICmpInst::isEquality(EqPred))
+    return nullptr;
+
+  ICmpInst::Predicate UnsignedPred;
+  if (match(UnsignedICmp, m_ICmp(UnsignedPred, m_Value(X), m_Specific(Y))) &&
+      ICmpInst::isUnsigned(UnsignedPred))
+    ;
+  else if (match(UnsignedICmp,
+                 m_ICmp(UnsignedPred, m_Value(Y), m_Specific(X))) &&
+           ICmpInst::isUnsigned(UnsignedPred))
+    UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+  else
+    return nullptr;
+
+  // X < Y && Y != 0  -->  X < Y
+  // X < Y || Y != 0  -->  Y != 0
+  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE)
+    return IsAnd ? UnsignedICmp : ZeroICmp;
+
+  // X >= Y || Y != 0  -->  true
+  // X >= Y || Y == 0  -->  X >= Y
+  if (UnsignedPred == ICmpInst::ICMP_UGE && !IsAnd) {
+    if (EqPred == ICmpInst::ICMP_NE)
+      return getTrue(UnsignedICmp->getType());
+    return UnsignedICmp;
+  }
+
+  // X < Y && Y == 0  -->  false
+  if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_EQ &&
+      IsAnd)
+    return getFalse(UnsignedICmp->getType());
+
+  return nullptr;
+}
+
+// Simplify (and (icmp ...) (icmp ...)) to false when we can tell that the
+// range of possible values cannot be satisfied.
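// (Hypothetical instance, with CI1 = 1 and CI2 = 3: 'X + 1 <u 3' restricts X
// to {-1, 0, 1}, while 'X >s 1' requires X >= 2, so the conjunction can never
// hold and folds to false.)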
+static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + ConstantInt *CI1, *CI2; + Value *V; + + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) + return X; + + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), + m_ConstantInt(CI2)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + return nullptr; + + Type *ITy = Op0->getType(); + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt &CI1V = CI1->getValue(); + const APInt &CI2V = CI2->getValue(); + const APInt Delta = CI2V - CI1V; + if (CI1V.isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT) + return getFalse(ITy); + if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW) + return getFalse(ITy); + } + } + if (CI1V.getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGT) + return getFalse(ITy); + } + + return nullptr; +} + /// SimplifyAndInst - Given operands for an And, see if we can /// fold the result. If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, @@ -1412,12 +1580,21 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, // A & (-A) = A if A is a power of two or zero. if (match(Op0, m_Neg(m_Specific(Op1))) || match(Op1, m_Neg(m_Specific(Op0)))) { - if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op0, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) return Op0; - if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/true)) + if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) return Op1; } + if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { + if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { + if (Value *V = SimplifyAndOfICmps(ICILHS, ICIRHS)) + return V; + if (Value *V = SimplifyAndOfICmps(ICIRHS, ICILHS)) + return V; + } + } + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, MaxRecurse)) @@ -1452,8 +1629,62 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyAndInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyAndInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyAndInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); +} + +// Simplify (or (icmp ...) (icmp ...)) to true when we can tell that the union +// contains all possible values. 
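// (Hypothetical instance, with CI1 = 1 and CI2 = 3: 'X + 1 >=u 3' excludes
// only X in {-1, 0, 1}, and 'X <=s 1' covers those three values and more, so
// the disjunction is always true.)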
+static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + ICmpInst::Predicate Pred0, Pred1; + ConstantInt *CI1, *CI2; + Value *V; + + if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/false)) + return X; + + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), + m_ConstantInt(CI2)))) + return nullptr; + + if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) + return nullptr; + + Type *ITy = Op0->getType(); + + auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + bool isNSW = AddInst->hasNoSignedWrap(); + bool isNUW = AddInst->hasNoUnsignedWrap(); + + const APInt &CI1V = CI1->getValue(); + const APInt &CI2V = CI2->getValue(); + const APInt Delta = CI2V - CI1V; + if (CI1V.isStrictlyPositive()) { + if (Delta == 2) { + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + if (Delta == 1) { + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE) + return getTrue(ITy); + if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW) + return getTrue(ITy); + } + } + if (CI1V.getBoolValue() && isNUW) { + if (Delta == 2) + if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Delta == 1) + if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; } /// SimplifyOrInst - Given operands for an Or, see if we can @@ -1513,6 +1744,15 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, (A == Op0 || B == Op0)) return Constant::getAllOnesValue(Op0->getType()); + if (auto *ICILHS = dyn_cast<ICmpInst>(Op0)) { + if (auto *ICIRHS = dyn_cast<ICmpInst>(Op1)) { + if (Value *V = SimplifyOrOfICmps(ICILHS, ICIRHS)) + return V; + if (Value *V = SimplifyOrOfICmps(ICIRHS, ICILHS)) + return V; + } + } + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Or, Op0, Op1, Q, MaxRecurse)) @@ -1545,18 +1785,22 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, if ((C2->getValue() & (C2->getValue() + 1)) == 0 && // C2 == 0+1+ match(A, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. - if (V1 == B && MaskedValueIsZero(V2, C2->getValue())) + if (V1 == B && + MaskedValueIsZero(V2, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return A; - if (V2 == B && MaskedValueIsZero(V1, C2->getValue())) + if (V2 == B && + MaskedValueIsZero(V1, C2->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return A; } // Or commutes, try both ways. if ((C1->getValue() & (C1->getValue() + 1)) == 0 && match(B, m_Add(m_Value(V1), m_Value(V2)))) { // Add commutes, try both ways. 
- if (V1 == A && MaskedValueIsZero(V2, C1->getValue())) + if (V1 == A && + MaskedValueIsZero(V2, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; - if (V2 == A && MaskedValueIsZero(V1, C1->getValue())) + if (V2 == A && + MaskedValueIsZero(V1, C1->getValue(), Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return B; } } @@ -1573,8 +1817,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyOrInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyOrInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } /// SimplifyXorInst - Given operands for a Xor, see if we can @@ -1628,8 +1874,10 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, Value *llvm::SimplifyXorInst(Value *Op0, Value *Op1, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyXorInst(Op0, Op1, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyXorInst(Op0, Op1, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } static Type *GetCompareTy(Value *Op) { @@ -1804,6 +2052,50 @@ static Constant *computePointerICmp(const DataLayout *DL, return ConstantExpr::getICmp(Pred, ConstantExpr::getAdd(LHSOffset, LHSNoBound), ConstantExpr::getAdd(RHSOffset, RHSNoBound)); + + // If one side of the equality comparison must come from a noalias call + // (meaning a system memory allocation function), and the other side must + // come from a pointer that cannot overlap with dynamically-allocated + // memory within the lifetime of the current function (allocas, byval + // arguments, globals), then determine the comparison result here. + SmallVector<Value *, 8> LHSUObjs, RHSUObjs; + GetUnderlyingObjects(LHS, LHSUObjs, DL); + GetUnderlyingObjects(RHS, RHSUObjs, DL); + + // Is the set of underlying objects all noalias calls? + auto IsNAC = [](SmallVectorImpl<Value *> &Objects) { + return std::all_of(Objects.begin(), Objects.end(), + [](Value *V){ return isNoAliasCall(V); }); + }; + + // Is the set of underlying objects all things which must be disjoint from + // noalias calls. For allocas, we consider only static ones (dynamic + // allocas might be transformed into calls to malloc not simultaneously + // live with the compared-to allocation). For globals, we exclude symbols + // that might be resolve lazily to symbols in another dynamically-loaded + // library (and, thus, could be malloc'ed by the implementation). + auto IsAllocDisjoint = [](SmallVectorImpl<Value *> &Objects) { + return std::all_of(Objects.begin(), Objects.end(), + [](Value *V){ + if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) + return AI->getParent() && AI->getParent()->getParent() && + AI->isStaticAlloca(); + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) + return (GV->hasLocalLinkage() || + GV->hasHiddenVisibility() || + GV->hasProtectedVisibility() || + GV->hasUnnamedAddr()) && + !GV->isThreadLocal(); + if (const Argument *A = dyn_cast<Argument>(V)) + return A->hasByValAttr(); + return false; + }); + }; + + if ((IsNAC(LHSUObjs) && IsAllocDisjoint(RHSUObjs)) || + (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) + return ConstantInt::get(GetCompareTy(LHS), + !CmpInst::isTrueWhenEqual(Pred)); } // Otherwise, fail. 
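// A hedged sketch of the noalias-call rule added above (the IR names are
// illustrative, not taken from the patch): a pointer returned by a noalias
// call can never compare equal to a pointer that provably never comes from
// the allocator, e.g.
//
//   %m = call noalias i8* @malloc(i64 1)
//   %a = alloca i8                         ; static alloca
//   %c = icmp eq i8* %m, %a                ; folds to false
//
// since static allocas, byval arguments, and suitably-visible globals cannot
// overlap memory handed out by the allocator during this function's lifetime.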
@@ -1883,40 +2175,46 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); if (LHSKnownNonNegative) return getFalse(ITy); break; case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (LHSKnownNegative) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(ITy); break; case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); if (LHSKnownNonNegative) return getTrue(ITy); break; case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL); + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (LHSKnownNegative) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL)) + if (LHSKnownNonNegative && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; } @@ -1981,6 +2279,22 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Upper = Upper + 1; assert(Upper != Lower && "Upper part of range has wrapped!"); } + } else if (match(LHS, m_NUWShl(m_ConstantInt(CI2), m_Value()))) { + // 'shl nuw CI2, x' produces [CI2, CI2 << CLZ(CI2)] + Lower = CI2->getValue(); + Upper = Lower.shl(Lower.countLeadingZeros()) + 1; + } else if (match(LHS, m_NSWShl(m_ConstantInt(CI2), m_Value()))) { + if (CI2->isNegative()) { + // 'shl nsw CI2, x' produces [CI2 << CLO(CI2)-1, CI2] + unsigned ShiftAmount = CI2->getValue().countLeadingOnes() - 1; + Lower = CI2->getValue().shl(ShiftAmount); + Upper = CI2->getValue() + 1; + } else { + // 'shl nsw CI2, x' produces [CI2, CI2 << CLZ(CI2)-1] + unsigned ShiftAmount = CI2->getValue().countLeadingZeros() - 1; + Lower = CI2->getValue(); + Upper = CI2->getValue().shl(ShiftAmount) + 1; + } } else if (match(LHS, m_LShr(m_Value(), m_ConstantInt(CI2)))) { // 'lshr x, CI2' produces [0, UINT_MAX >> CI2]. APInt NegOne = APInt::getAllOnesValue(Width); @@ -2189,25 +2503,6 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // If a bit is known to be zero for A and known to be one for B, - // then A and B cannot be equal. 
- if (ICmpInst::isEquality(Pred)) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - uint32_t BitWidth = CI->getBitWidth(); - APInt LHSKnownZero(BitWidth, 0); - APInt LHSKnownOne(BitWidth, 0); - computeKnownBits(LHS, LHSKnownZero, LHSKnownOne); - APInt RHSKnownZero(BitWidth, 0); - APInt RHSKnownOne(BitWidth, 0); - computeKnownBits(RHS, RHSKnownZero, RHSKnownOne); - if (((LHSKnownOne & RHSKnownZero) != 0) || - ((LHSKnownZero & RHSKnownOne) != 0)) - return (Pred == ICmpInst::ICMP_EQ) - ? ConstantInt::getFalse(CI->getContext()) - : ConstantInt::getTrue(CI->getContext()); - } - } - // Special logic for binary operators. BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); @@ -2271,6 +2566,40 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } + // icmp pred (or X, Y), X + if (LBO && match(LBO, m_CombineOr(m_Or(m_Value(), m_Specific(RHS)), + m_Or(m_Specific(RHS), m_Value())))) { + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + } + // icmp pred X, (or X, Y) + if (RBO && match(RBO, m_CombineOr(m_Or(m_Value(), m_Specific(LHS)), + m_Or(m_Specific(LHS), m_Value())))) { + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + } + + // icmp pred (and X, Y), X + if (LBO && match(LBO, m_CombineOr(m_And(m_Value(), m_Specific(RHS)), + m_And(m_Specific(RHS), m_Value())))) { + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + // icmp pred X, (and X, Y) + if (RBO && match(RBO, m_CombineOr(m_And(m_Value(), m_Specific(LHS)), + m_And(m_Specific(LHS), m_Value())))) { + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + } + // 0 - (zext X) pred C if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) { if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { @@ -2301,7 +2630,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2311,7 +2641,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2330,7 +2661,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2340,7 +2672,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getTrue(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL); + ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); if (!KnownNonNegative) break; // fall-through @@ -2360,6 +2693,41 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, 
return getTrue(ITy); } + // handle: + // CI2 << X == CI + // CI2 << X != CI + // + // where CI2 is a power of 2 and CI isn't + if (auto *CI = dyn_cast<ConstantInt>(RHS)) { + const APInt *CI2Val, *CIVal = &CI->getValue(); + if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && + CI2Val->isPowerOf2()) { + if (!CIVal->isPowerOf2()) { + // CI2 << X can equal zero in some circumstances, + // this simplification is unsafe if CI is zero. + // + // We know it is safe if: + // - The shift is nsw, we can't shift out the one bit. + // - The shift is nuw, we can't shift out the one bit. + // - CI2 is one + // - CI isn't zero + if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() || + *CI2Val == 1 || !CI->isZero()) { + if (Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + if (CIVal->isSignBit() && *CI2Val == 1) { + if (Pred == ICmpInst::ICMP_UGT) + return ConstantInt::getFalse(RHS->getContext()); + if (Pred == ICmpInst::ICMP_ULE) + return ConstantInt::getTrue(RHS->getContext()); + } + } + } + if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && LBO->getOperand(1) == RBO->getOperand(1)) { switch (LBO->getOpcode()) { @@ -2607,6 +2975,23 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } + // If a bit is known to be zero for A and known to be one for B, + // then A and B cannot be equal. + if (ICmpInst::isEquality(Pred)) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { + uint32_t BitWidth = CI->getBitWidth(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, Q.DL, /*Depth=*/0, Q.AC, + Q.CxtI, Q.DT); + const APInt &RHSVal = CI->getValue(); + if (((LHSKnownZero & RHSVal) != 0) || ((LHSKnownOne & ~RHSVal) != 0)) + return Pred == ICmpInst::ICMP_EQ + ? ConstantInt::getFalse(CI->getContext()) + : ConstantInt::getTrue(CI->getContext()); + } + } + // If the comparison is with the result of a select instruction, check whether // comparing with either branch of the select always yields the same value. 
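  // (Illustrative example: 'icmp ult (select %c, i32 3, i32 5), 10' is true
  // on both arms, so it folds to true no matter what %c is.)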
if (isa<SelectInst>(LHS) || isa<SelectInst>(RHS)) @@ -2625,8 +3010,9 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyICmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + Instruction *CxtI) { + return ::SimplifyICmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -2722,8 +3108,9 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyFCmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -2755,15 +3142,71 @@ static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, if (isa<UndefValue>(FalseVal)) // select C, X, undef -> X return TrueVal; + const auto *ICI = dyn_cast<ICmpInst>(CondVal); + unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits(); + if (ICI && BitWidth) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + APInt MinSignedValue = APInt::getSignBit(BitWidth); + Value *X; + const APInt *Y; + bool TrueWhenUnset; + bool IsBitTest = false; + if (ICmpInst::isEquality(Pred) && + match(ICI->getOperand(0), m_And(m_Value(X), m_APInt(Y))) && + match(ICI->getOperand(1), m_Zero())) { + IsBitTest = true; + TrueWhenUnset = Pred == ICmpInst::ICMP_EQ; + } else if (Pred == ICmpInst::ICMP_SLT && + match(ICI->getOperand(1), m_Zero())) { + X = ICI->getOperand(0); + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = false; + } else if (Pred == ICmpInst::ICMP_SGT && + match(ICI->getOperand(1), m_AllOnes())) { + X = ICI->getOperand(0); + Y = &MinSignedValue; + IsBitTest = true; + TrueWhenUnset = true; + } + if (IsBitTest) { + const APInt *C; + // (X & Y) == 0 ? X & ~Y : X --> X + // (X & Y) != 0 ? X & ~Y : X --> X & ~Y + if (FalseVal == X && match(TrueVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + // (X & Y) == 0 ? X : X & ~Y --> X & ~Y + // (X & Y) != 0 ? X : X & ~Y --> X + if (TrueVal == X && match(FalseVal, m_And(m_Specific(X), m_APInt(C))) && + *Y == ~*C) + return TrueWhenUnset ? FalseVal : TrueVal; + + if (Y->isPowerOf2()) { + // (X & Y) == 0 ? X | Y : X --> X | Y + // (X & Y) != 0 ? X | Y : X --> X + if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? TrueVal : FalseVal; + // (X & Y) == 0 ? X : X | Y --> X + // (X & Y) != 0 ? X : X | Y --> X | Y + if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && + *Y == *C) + return TrueWhenUnset ? 
TrueVal : FalseVal; + } + } + } + return nullptr; } Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifySelectInst(Cond, TrueVal, FalseVal, Query (DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifySelectInst(Cond, TrueVal, FalseVal, + Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } /// SimplifyGEPInst - Given operands for an GetElementPtrInst, see if we can @@ -2771,29 +3214,72 @@ Value *llvm::SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, static Value *SimplifyGEPInst(ArrayRef<Value *> Ops, const Query &Q, unsigned) { // The type of the GEP pointer operand. PointerType *PtrTy = cast<PointerType>(Ops[0]->getType()->getScalarType()); + unsigned AS = PtrTy->getAddressSpace(); // getelementptr P -> P. if (Ops.size() == 1) return Ops[0]; - if (isa<UndefValue>(Ops[0])) { - // Compute the (pointer) type returned by the GEP instruction. - Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, Ops.slice(1)); - Type *GEPTy = PointerType::get(LastType, PtrTy->getAddressSpace()); - if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) - GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + // Compute the (pointer) type returned by the GEP instruction. + Type *LastType = GetElementPtrInst::getIndexedType(PtrTy, Ops.slice(1)); + Type *GEPTy = PointerType::get(LastType, AS); + if (VectorType *VT = dyn_cast<VectorType>(Ops[0]->getType())) + GEPTy = VectorType::get(GEPTy, VT->getNumElements()); + + if (isa<UndefValue>(Ops[0])) return UndefValue::get(GEPTy); - } if (Ops.size() == 2) { // getelementptr P, 0 -> P. if (match(Ops[1], m_Zero())) return Ops[0]; - // getelementptr P, N -> P if P points to a type of zero size. - if (Q.DL) { - Type *Ty = PtrTy->getElementType(); - if (Ty->isSized() && Q.DL->getTypeAllocSize(Ty) == 0) + + Type *Ty = PtrTy->getElementType(); + if (Q.DL && Ty->isSized()) { + Value *P; + uint64_t C; + uint64_t TyAllocSize = Q.DL->getTypeAllocSize(Ty); + // getelementptr P, N -> P if P points to a type of zero size. + if (TyAllocSize == 0) return Ops[0]; + + // The following transforms are only safe if the ptrtoint cast + // doesn't truncate the pointers. + if (Ops[1]->getType()->getScalarSizeInBits() == + Q.DL->getPointerSizeInBits(AS)) { + auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { + if (match(P, m_Zero())) + return Constant::getNullValue(GEPTy); + Value *Temp; + if (match(P, m_PtrToInt(m_Value(Temp)))) + if (Temp->getType() == GEPTy) + return Temp; + return nullptr; + }; + + // getelementptr V, (sub P, V) -> P if P points to a type of size 1. + if (TyAllocSize == 1 && + match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (ashr (sub P, V), C) -> Q + // if P points to a type of size 1 << C. + if (match(Ops[1], + m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_ConstantInt(C))) && + TyAllocSize == 1ULL << C) + if (Value *R = PtrToIntOrZero(P)) + return R; + + // getelementptr V, (sdiv (sub P, V), C) -> Q + // if P points to a type of size C. 
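      // (Hypothetical instance: with %Q pointing at 4-byte elements and
      //  P = ptrtoint %Q, 'getelementptr %V, (sdiv (sub P, ptrtoint %V), 4)'
      //  folds back to %Q.)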
+ if (match(Ops[1], + m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), + m_SpecificInt(TyAllocSize)))) + if (Value *R = PtrToIntOrZero(P)) + return R; + } } } @@ -2807,8 +3293,9 @@ static Value *SimplifyGEPInst(ArrayRef<Value *> Ops, const Query &Q, unsigned) { Value *llvm::SimplifyGEPInst(ArrayRef<Value *> Ops, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyGEPInst(Ops, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyGEPInst(Ops, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } /// SimplifyInsertValueInst - Given operands for an InsertValueInst, see if we @@ -2840,12 +3327,11 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return nullptr; } -Value *llvm::SimplifyInsertValueInst(Value *Agg, Value *Val, - ArrayRef<unsigned> Idxs, - const DataLayout *DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query (DL, TLI, DT), +Value *llvm::SimplifyInsertValueInst( + Value *Agg, Value *Val, ArrayRef<unsigned> Idxs, const DataLayout *DL, + const TargetLibraryInfo *TLI, const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyInsertValueInst(Agg, Val, Idxs, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -2892,8 +3378,10 @@ static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { Value *llvm::SimplifyTruncInst(Value *Op, Type *Ty, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyTruncInst(Op, Ty, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyTruncInst(Op, Ty, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } //=== Helper functions for higher up the class hierarchy. 
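// A minimal usage sketch for the extended entry points (assumes an existing
// Function &F, DataLayout *DL, and DominatorTree *DT; the names here are
// illustrative, not part of the patch):
//
//   AssumptionCache AC(F);
//   for (inst_iterator I = inst_begin(F), E = inst_end(F); I != E; ++I)
//     if (Value *V = SimplifyInstruction(&*I, DL, /*TLI=*/nullptr, DT, &AC))
//       I->replaceAllUsesWith(V);
//
// Threading the AssumptionCache (and a context instruction) through Query
// lets the known-bits and sign-bit queries above honor @llvm.assume
// conditions that dominate the instruction being simplified.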
@@ -2965,8 +3453,10 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, Value *llvm::SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyBinOp(Opcode, LHS, RHS, Query (DL, TLI, DT), RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyBinOp(Opcode, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), + RecursionLimit); } /// SimplifyCmpInst - Given operands for a CmpInst, see if we can @@ -2980,8 +3470,9 @@ static Value *SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, Value *llvm::SimplifyCmpInst(unsigned Predicate, Value *LHS, Value *RHS, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCmpInst(Predicate, LHS, RHS, Query (DL, TLI, DT), + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyCmpInst(Predicate, LHS, RHS, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } @@ -3055,24 +3546,25 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, Value *llvm::SimplifyCall(Value *V, User::op_iterator ArgBegin, User::op_iterator ArgEnd, const DataLayout *DL, - const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT), + const TargetLibraryInfo *TLI, const DominatorTree *DT, + AssumptionCache *AC, const Instruction *CxtI) { + return ::SimplifyCall(V, ArgBegin, ArgEnd, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } Value *llvm::SimplifyCall(Value *V, ArrayRef<Value *> Args, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return ::SimplifyCall(V, Args.begin(), Args.end(), Query(DL, TLI, DT), - RecursionLimit); + const DominatorTree *DT, AssumptionCache *AC, + const Instruction *CxtI) { + return ::SimplifyCall(V, Args.begin(), Args.end(), + Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } /// SimplifyInstruction - See if we can compute a simplified version of this /// instruction. If not, this returns null. 
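Every wrapper in these hunks changes the same way: the public Simplify* entry points gain an AssumptionCache and a context-instruction parameter and forward both through the internal Query aggregate. A simplified sketch of that parameter-object pattern, with forward-declared stand-ins rather than LLVM's actual definitions:

// Sketch only: the Query-style bundle these wrappers construct.
struct DataLayout;
struct TargetLibraryInfo;
struct DominatorTree;
struct AssumptionCache;
struct Instruction;

struct Query {
  const DataLayout *DL;
  const TargetLibraryInfo *TLI;
  const DominatorTree *DT;
  AssumptionCache *AC;     // newly threaded: cache of @llvm.assume calls
  const Instruction *CxtI; // newly threaded: where the query is being made
  Query(const DataLayout *DL, const TargetLibraryInfo *TLI,
        const DominatorTree *DT, AssumptionCache *AC = nullptr,
        const Instruction *CxtI = nullptr)
      : DL(DL), TLI(TLI), DT(DT), AC(AC), CxtI(CxtI) {}
};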
Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, AssumptionCache *AC) { Value *Result; switch (I->getOpcode()) { @@ -3081,109 +3573,122 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, break; case Instruction::FAdd: Result = SimplifyFAddInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AC, I); break; case Instruction::Add: Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, + TLI, DT, AC, I); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AC, I); break; case Instruction::Sub: Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, + TLI, DT, AC, I); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), - I->getFastMathFlags(), DL, TLI, DT); + I->getFastMathFlags(), DL, TLI, DT, AC, I); break; case Instruction::Mul: - Result = SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = + SimplifyMulInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::SDiv: - Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::UDiv: - Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyUDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::FDiv: - Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFDivInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::SRem: - Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifySRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::URem: - Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyURemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::FRem: - Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = SimplifyFRemInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, + AC, I); break; case Instruction::Shl: Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), - DL, TLI, DT); + cast<BinaryOperator>(I)->hasNoUnsignedWrap(), DL, + TLI, DT, AC, I); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, + AC, I); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), - DL, TLI, DT); + cast<BinaryOperator>(I)->isExact(), DL, TLI, DT, + AC, I); break; case Instruction::And: - Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); 
+ Result = + SimplifyAndInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::Or: - Result = SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = + SimplifyOrInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::Xor: - Result = SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = + SimplifyXorInst(I->getOperand(0), I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::ICmp: - Result = SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = + SimplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), I->getOperand(0), + I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::FCmp: - Result = SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), - I->getOperand(0), I->getOperand(1), DL, TLI, DT); + Result = + SimplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), I->getOperand(0), + I->getOperand(1), DL, TLI, DT, AC, I); break; case Instruction::Select: Result = SimplifySelectInst(I->getOperand(0), I->getOperand(1), - I->getOperand(2), DL, TLI, DT); + I->getOperand(2), DL, TLI, DT, AC, I); break; case Instruction::GetElementPtr: { SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); - Result = SimplifyGEPInst(Ops, DL, TLI, DT); + Result = SimplifyGEPInst(Ops, DL, TLI, DT, AC, I); break; } case Instruction::InsertValue: { InsertValueInst *IV = cast<InsertValueInst>(I); Result = SimplifyInsertValueInst(IV->getAggregateOperand(), IV->getInsertedValueOperand(), - IV->getIndices(), DL, TLI, DT); + IV->getIndices(), DL, TLI, DT, AC, I); break; } case Instruction::PHI: - Result = SimplifyPHINode(cast<PHINode>(I), Query (DL, TLI, DT)); + Result = SimplifyPHINode(cast<PHINode>(I), Query(DL, TLI, DT, AC, I)); break; case Instruction::Call: { CallSite CS(cast<CallInst>(I)); - Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), - DL, TLI, DT); + Result = SimplifyCall(CS.getCalledValue(), CS.arg_begin(), CS.arg_end(), DL, + TLI, DT, AC, I); break; } case Instruction::Trunc: - Result = SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT); + Result = + SimplifyTruncInst(I->getOperand(0), I->getType(), DL, TLI, DT, AC, I); break; } @@ -3207,7 +3712,8 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout *DL, static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionCache *AC) { bool Simplified = false; SmallSetVector<Instruction *, 8> Worklist; @@ -3234,7 +3740,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, I = Worklist[Idx]; // See if this instruction simplifies. 
- SimpleV = SimplifyInstruction(I, DL, TLI, DT); + SimpleV = SimplifyInstruction(I, DL, TLI, DT, AC); if (!SimpleV) continue; @@ -3257,18 +3763,19 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, return Simplified; } -bool llvm::recursivelySimplifyInstruction(Instruction *I, - const DataLayout *DL, +bool llvm::recursivelySimplifyInstruction(Instruction *I, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { - return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT); + const DominatorTree *DT, + AssumptionCache *AC) { + return replaceAndRecursivelySimplifyImpl(I, nullptr, DL, TLI, DT, AC); } bool llvm::replaceAndRecursivelySimplify(Instruction *I, Value *SimpleV, const DataLayout *DL, const TargetLibraryInfo *TLI, - const DominatorTree *DT) { + const DominatorTree *DT, + AssumptionCache *AC) { assert(I != SimpleV && "replaceAndRecursivelySimplify(X,X) is not valid!"); assert(SimpleV && "Must provide a simplified value."); - return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT); + return replaceAndRecursivelySimplifyImpl(I, SimpleV, DL, TLI, DT, AC); } diff --git a/lib/Analysis/JumpInstrTableInfo.cpp b/lib/Analysis/JumpInstrTableInfo.cpp index b5b426533ff3..7aae2a5592e9 100644 --- a/lib/Analysis/JumpInstrTableInfo.cpp +++ b/lib/Analysis/JumpInstrTableInfo.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/Passes.h" #include "llvm/IR/Function.h" #include "llvm/IR/Type.h" +#include "llvm/Support/MathExtras.h" using namespace llvm; @@ -28,7 +29,21 @@ ImmutablePass *llvm::createJumpInstrTableInfoPass() { return new JumpInstrTableInfo(); } -JumpInstrTableInfo::JumpInstrTableInfo() : ImmutablePass(ID), Tables() { +ModulePass *llvm::createJumpInstrTableInfoPass(unsigned Bound) { + // This cast is always safe, since Bound is always in a subset of uint64_t. + uint64_t B = static_cast<uint64_t>(Bound); + return new JumpInstrTableInfo(B); +} + +JumpInstrTableInfo::JumpInstrTableInfo(uint64_t ByteAlign) + : ImmutablePass(ID), Tables(), ByteAlignment(ByteAlign) { + if (!llvm::isPowerOf2_64(ByteAlign)) { + // Note that we don't explicitly handle overflow here, since we handle the 0 + // case explicitly when a caller actually tries to create jumptable entries, + // and this is the return value on overflow. 
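For reference, llvm::NextPowerOf2 returns the next power of two strictly greater than its argument, so a non-power-of-two alignment is rounded up, and an argument at or above 2^63 wraps to 0, which is the overflow value the comment above defers to callers. A stand-in mirroring the bit-smearing behavior (assumed here from MathExtras.h, not quoted from it):

#include <cassert>
#include <cstdint>

uint64_t nextPowerOf2(uint64_t A) { // stand-in for llvm::NextPowerOf2
  A |= (A >> 1); A |= (A >> 2); A |= (A >> 4);
  A |= (A >> 8); A |= (A >> 16); A |= (A >> 32);
  return A + 1; // all-ones input (after smearing 2^63+) wraps to 0
}

int main() {
  assert(nextPowerOf2(7) == 8);
  assert(nextPowerOf2(12) == 16);
  return 0;
}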
+ ByteAlignment = llvm::NextPowerOf2(ByteAlign); + } + initializeJumpInstrTableInfoPass(*PassRegistry::getPassRegistry()); } diff --git a/lib/Analysis/LazyCallGraph.cpp b/lib/Analysis/LazyCallGraph.cpp index e0736162a77a..c8d0410c1e0f 100644 --- a/lib/Analysis/LazyCallGraph.cpp +++ b/lib/Analysis/LazyCallGraph.cpp @@ -48,7 +48,7 @@ static void findCallees( } for (Value *Op : C->operand_values()) - if (Visited.insert(cast<Constant>(Op))) + if (Visited.insert(cast<Constant>(Op)).second) Worklist.push_back(cast<Constant>(Op)); } } @@ -66,7 +66,7 @@ LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) for (Instruction &I : BB) for (Value *Op : I.operand_values()) if (Constant *C = dyn_cast<Constant>(Op)) - if (Visited.insert(C)) + if (Visited.insert(C).second) Worklist.push_back(C); // We've collected all the constant (and thus potentially function or @@ -113,7 +113,7 @@ LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { SmallPtrSet<Constant *, 16> Visited; for (GlobalVariable &GV : M.globals()) if (GV.hasInitializer()) - if (Visited.insert(GV.getInitializer())) + if (Visited.insert(GV.getInitializer()).second) Worklist.push_back(GV.getInitializer()); DEBUG(dbgs() << " Adding functions referenced by global initializers to the " @@ -688,7 +688,7 @@ static void printNodes(raw_ostream &OS, LazyCallGraph::Node &N, SmallPtrSetImpl<LazyCallGraph::Node *> &Printed) { // Recurse depth first through the nodes. for (LazyCallGraph::Node &ChildN : N) - if (Printed.insert(&ChildN)) + if (Printed.insert(&ChildN).second) printNodes(OS, ChildN, Printed); OS << " Call edges in function: " << N.getFunction().getName() << "\n"; @@ -708,21 +708,20 @@ static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &SCC) { OS << "\n"; } -PreservedAnalyses LazyCallGraphPrinterPass::run(Module *M, +PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M, ModuleAnalysisManager *AM) { LazyCallGraph &G = AM->getResult<LazyCallGraphAnalysis>(M); - OS << "Printing the call graph for module: " << M->getModuleIdentifier() + OS << "Printing the call graph for module: " << M.getModuleIdentifier() << "\n\n"; SmallPtrSet<LazyCallGraph::Node *, 16> Printed; for (LazyCallGraph::Node &N : G) - if (Printed.insert(&N)) + if (Printed.insert(&N).second) printNodes(OS, N, Printed); for (LazyCallGraph::SCC &SCC : G.postorder_sccs()) printSCC(OS, SCC); return PreservedAnalyses::all(); - } diff --git a/lib/Analysis/LazyValueInfo.cpp b/lib/Analysis/LazyValueInfo.cpp index 9f919f7644a3..4de56f195cb4 100644 --- a/lib/Analysis/LazyValueInfo.cpp +++ b/lib/Analysis/LazyValueInfo.cpp @@ -1,4 +1,4 @@ -//===- LazyValueInfo.cpp - Value constraint analysis ----------------------===// +//===- LazyValueInfo.cpp - Value constraint analysis ------------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -15,12 +15,14 @@ #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" @@ -38,6 +40,7 @@ using namespace PatternMatch; char LazyValueInfo::ID = 0; INITIALIZE_PASS_BEGIN(LazyValueInfo, "lazy-value-info", "Lazy Value Information Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) 
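The mechanical .second changes in LazyCallGraph.cpp track an LLVM API change in which SmallPtrSet::insert began returning a std::pair, so the "was it newly inserted" answer moved into the pair's bool. A self-contained illustration with std::set, which follows the same convention:

// Sketch of the insert().second idiom adopted throughout this commit.
#include <set>
#include <vector>

void visitOnce(const std::vector<int> &Ops) {
  std::set<int> Visited;
  std::vector<int> Worklist;
  for (int Op : Ops)
    if (Visited.insert(Op).second) // true only on first insertion
      Worklist.push_back(Op);      // so each operand is queued once
}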
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_END(LazyValueInfo, "lazy-value-info", "Lazy Value Information Analysis", false, true) @@ -51,8 +54,7 @@ namespace llvm { // LVILatticeVal //===----------------------------------------------------------------------===// -/// LVILatticeVal - This is the information tracked by LazyValueInfo for each -/// value. +/// This is the information tracked by LazyValueInfo for each value. /// /// FIXME: This is basically just for bringup, this can be made a lot more rich /// in the future. @@ -60,19 +62,19 @@ namespace llvm { namespace { class LVILatticeVal { enum LatticeValueTy { - /// undefined - This Value has no known value yet. + /// This Value has no known value yet. undefined, - /// constant - This Value has a specific constant value. + /// This Value has a specific constant value. constant, - /// notconstant - This Value is known to not have the specified value. + + /// This Value is known to not have the specified value. notconstant, - /// constantrange - The Value falls within this range. + /// The Value falls within this range. constantrange, - /// overdefined - This value is not known to be constant, and we know that - /// it has a value. + /// This value is not known to be constant, and we know that it has a value. overdefined }; @@ -125,7 +127,7 @@ public: return Range; } - /// markOverdefined - Return true if this is a change in status. + /// Return true if this is a change in status. bool markOverdefined() { if (isOverdefined()) return false; @@ -133,7 +135,7 @@ public: return true; } - /// markConstant - Return true if this is a change in status. + /// Return true if this is a change in status. bool markConstant(Constant *V) { assert(V && "Marking constant with NULL"); if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) @@ -149,7 +151,7 @@ public: return true; } - /// markNotConstant - Return true if this is a change in status. + /// Return true if this is a change in status. bool markNotConstant(Constant *V) { assert(V && "Marking constant with NULL"); if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) @@ -167,7 +169,7 @@ public: return true; } - /// markConstantRange - Return true if this is a change in status. + /// Return true if this is a change in status. bool markConstantRange(const ConstantRange NewR) { if (isConstantRange()) { if (NewR.isEmptySet()) @@ -187,7 +189,7 @@ public: return true; } - /// mergeIn - Merge the specified lattice value into this one, updating this + /// Merge the specified lattice value into this one, updating this /// one and returning true if anything changed. bool mergeIn(const LVILatticeVal &RHS) { if (RHS.isUndefined() || isOverdefined()) return false; @@ -295,8 +297,7 @@ raw_ostream &operator<<(raw_ostream &OS, const LVILatticeVal &Val) { //===----------------------------------------------------------------------===// namespace { - /// LVIValueHandle - A callback value handle updates the cache when - /// values are erased. + /// A callback value handle updates the cache when values are erased. class LazyValueInfoCache; struct LVIValueHandle : public CallbackVH { LazyValueInfoCache *Parent; @@ -312,59 +313,67 @@ namespace { } namespace { - /// LazyValueInfoCache - This is the cache kept by LazyValueInfo which + /// This is the cache kept by LazyValueInfo which /// maintains information about queries across the clients' queries. class LazyValueInfoCache { - /// ValueCacheEntryTy - This is all of the cached block information for - /// exactly one Value*. 
The entries are sorted by the BasicBlock* of the + /// This is all of the cached block information for exactly one Value*. + /// The entries are sorted by the BasicBlock* of the /// entries, allowing us to do a lookup with a binary search. typedef std::map<AssertingVH<BasicBlock>, LVILatticeVal> ValueCacheEntryTy; - /// ValueCache - This is all of the cached information for all values, + /// This is all of the cached information for all values, /// mapped from Value* to key information. std::map<LVIValueHandle, ValueCacheEntryTy> ValueCache; - /// OverDefinedCache - This tracks, on a per-block basis, the set of - /// values that are over-defined at the end of that block. This is required + /// This tracks, on a per-block basis, the set of values that are + /// over-defined at the end of that block. This is required /// for cache updating. typedef std::pair<AssertingVH<BasicBlock>, Value*> OverDefinedPairTy; DenseSet<OverDefinedPairTy> OverDefinedCache; - /// SeenBlocks - Keep track of all blocks that we have ever seen, so we + /// Keep track of all blocks that we have ever seen, so we /// don't spend time removing unused blocks from our caches. DenseSet<AssertingVH<BasicBlock> > SeenBlocks; - /// BlockValueStack - This stack holds the state of the value solver - /// during a query. It basically emulates the callstack of the naive + /// This stack holds the state of the value solver during a query. + /// It basically emulates the callstack of the naive /// recursive value lookup process. std::stack<std::pair<BasicBlock*, Value*> > BlockValueStack; + + /// Keeps track of which block-value pairs are in BlockValueStack. + DenseSet<std::pair<BasicBlock*, Value*> > BlockValueSet; + + /// Push BV onto BlockValueStack unless it's already in there. + /// Returns true on success. + bool pushBlockValue(const std::pair<BasicBlock *, Value *> &BV) { + if (BlockValueSet.count(BV)) + return false; // It's already in the stack. + + BlockValueStack.push(BV); + BlockValueSet.insert(BV); + return true; + } + + /// A pointer to the cache of @llvm.assume calls. + AssumptionCache *AC; + /// An optional DL pointer. + const DataLayout *DL; + /// An optional DT pointer. + DominatorTree *DT; friend struct LVIValueHandle; - - /// OverDefinedCacheUpdater - A helper object that ensures that the - /// OverDefinedCache is updated whenever solveBlockValue returns. - struct OverDefinedCacheUpdater { - LazyValueInfoCache *Parent; - Value *Val; - BasicBlock *BB; - LVILatticeVal &BBLV; - - OverDefinedCacheUpdater(Value *V, BasicBlock *B, LVILatticeVal &LV, - LazyValueInfoCache *P) - : Parent(P), Val(V), BB(B), BBLV(LV) { } - - bool markResult(bool changed) { - if (changed && BBLV.isOverdefined()) - Parent->OverDefinedCache.insert(std::make_pair(BB, Val)); - return changed; - } - }; - + void insertResult(Value *Val, BasicBlock *BB, const LVILatticeVal &Result) { + SeenBlocks.insert(BB); + lookup(Val)[BB] = Result; + if (Result.isOverdefined()) + OverDefinedCache.insert(std::make_pair(BB, Val)); + } LVILatticeVal getBlockValue(Value *Val, BasicBlock *BB); bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T, - LVILatticeVal &Result); + LVILatticeVal &Result, + Instruction *CxtI = nullptr); bool hasBlockValue(Value *Val, BasicBlock *BB); // These methods process one work item and may add more. 
A false value @@ -377,6 +386,8 @@ namespace { PHINode *PN, BasicBlock *BB); bool solveBlockValueConstantRange(LVILatticeVal &BBLV, Instruction *BBI, BasicBlock *BB); + void mergeAssumeBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, + Instruction *BBI); void solve(); @@ -385,20 +396,26 @@ namespace { } public: - /// getValueInBlock - This is the query interface to determine the lattice + /// This is the query interface to determine the lattice /// value for the specified Value* at the end of the specified block. - LVILatticeVal getValueInBlock(Value *V, BasicBlock *BB); + LVILatticeVal getValueInBlock(Value *V, BasicBlock *BB, + Instruction *CxtI = nullptr); + + /// This is the query interface to determine the lattice + /// value for the specified Value* at the specified instruction (generally + /// from an assume intrinsic). + LVILatticeVal getValueAt(Value *V, Instruction *CxtI); - /// getValueOnEdge - This is the query interface to determine the lattice + /// This is the query interface to determine the lattice /// value for the specified Value* that is true on the specified edge. - LVILatticeVal getValueOnEdge(Value *V, BasicBlock *FromBB,BasicBlock *ToBB); + LVILatticeVal getValueOnEdge(Value *V, BasicBlock *FromBB,BasicBlock *ToBB, + Instruction *CxtI = nullptr); - /// threadEdge - This is the update interface to inform the cache that an - /// edge from PredBB to OldSucc has been threaded to be from PredBB to - /// NewSucc. + /// This is the update interface to inform the cache that an edge from + /// PredBB to OldSucc has been threaded to be from PredBB to NewSucc. void threadEdge(BasicBlock *PredBB,BasicBlock *OldSucc,BasicBlock *NewSucc); - /// eraseBlock - This is part of the update interface to inform the cache + /// This is part of the update interface to inform the cache /// that a block has been deleted. void eraseBlock(BasicBlock *BB); @@ -408,6 +425,10 @@ namespace { ValueCache.clear(); OverDefinedCache.clear(); } + + LazyValueInfoCache(AssumptionCache *AC, const DataLayout *DL = nullptr, + DominatorTree *DT = nullptr) + : AC(AC), DL(DL), DT(DT) {} }; } // end anonymous namespace @@ -415,17 +436,11 @@ void LVIValueHandle::deleted() { typedef std::pair<AssertingVH<BasicBlock>, Value*> OverDefinedPairTy; SmallVector<OverDefinedPairTy, 4> ToErase; - for (DenseSet<OverDefinedPairTy>::iterator - I = Parent->OverDefinedCache.begin(), - E = Parent->OverDefinedCache.end(); - I != E; ++I) { - if (I->second == getValPtr()) - ToErase.push_back(*I); - } - - for (SmallVectorImpl<OverDefinedPairTy>::iterator I = ToErase.begin(), - E = ToErase.end(); I != E; ++I) - Parent->OverDefinedCache.erase(*I); + for (const OverDefinedPairTy &P : Parent->OverDefinedCache) + if (P.second == getValPtr()) + ToErase.push_back(P); + for (const OverDefinedPairTy &P : ToErase) + Parent->OverDefinedCache.erase(P); // This erasure deallocates *this, so it MUST happen after we're done // using any and all members of *this. 
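pushBlockValue and the companion BlockValueSet above form a duplicate-free worklist: a (block, value) pair enters the stack at most once, which is what lets solve() make progress on cyclic dependencies instead of pushing the same work item forever. A minimal sketch of the idiom with stand-in types in place of BasicBlock* and Value*:

// Sketch only: the stack-plus-membership-set worklist introduced here.
#include <set>
#include <stack>
#include <utility>

using WorkItem = std::pair<int, int>; // (Block, Val) stand-in

struct BlockValueWorklist {
  std::stack<WorkItem> Stack;
  std::set<WorkItem> InStack;

  // Mirrors pushBlockValue: returns false if the item is already queued,
  // so a dependency cycle cannot grow the stack indefinitely.
  bool push(const WorkItem &BV) {
    if (!InStack.insert(BV).second)
      return false;
    Stack.push(BV);
    return true;
  }

  // Mirrors the pop in solve(): the two containers stay in sync.
  void pop() {
    InStack.erase(Stack.top());
    Stack.pop();
  }
};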
@@ -440,15 +455,11 @@ void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { SeenBlocks.erase(I); SmallVector<OverDefinedPairTy, 4> ToErase; - for (DenseSet<OverDefinedPairTy>::iterator I = OverDefinedCache.begin(), - E = OverDefinedCache.end(); I != E; ++I) { - if (I->first == BB) - ToErase.push_back(*I); - } - - for (SmallVectorImpl<OverDefinedPairTy>::iterator I = ToErase.begin(), - E = ToErase.end(); I != E; ++I) - OverDefinedCache.erase(*I); + for (const OverDefinedPairTy& P : OverDefinedCache) + if (P.first == BB) + ToErase.push_back(P); + for (const OverDefinedPairTy &P : ToErase) + OverDefinedCache.erase(P); for (std::map<LVIValueHandle, ValueCacheEntryTy>::iterator I = ValueCache.begin(), E = ValueCache.end(); I != E; ++I) @@ -458,9 +469,18 @@ void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { void LazyValueInfoCache::solve() { while (!BlockValueStack.empty()) { std::pair<BasicBlock*, Value*> &e = BlockValueStack.top(); + assert(BlockValueSet.count(e) && "Stack value should be in BlockValueSet!"); + if (solveBlockValue(e.second, e.first)) { - assert(BlockValueStack.top() == e); + // The work item was completely processed. + assert(BlockValueStack.top() == e && "Nothing should have been pushed!"); + assert(lookup(e.second).count(e.first) && "Result should be in cache!"); + BlockValueStack.pop(); + BlockValueSet.erase(e); + } else { + // More work needs to be done before revisiting. + assert(BlockValueStack.top() != e && "Stack should have been pushed!"); } } } @@ -490,44 +510,40 @@ bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { if (isa<Constant>(Val)) return true; - ValueCacheEntryTy &Cache = lookup(Val); - SeenBlocks.insert(BB); - LVILatticeVal &BBLV = Cache[BB]; - - // OverDefinedCacheUpdater is a helper object that will update - // the OverDefinedCache for us when this method exits. Make sure to - // call markResult on it as we exist, passing a bool to indicate if the - // cache needs updating, i.e. if we have solve a new value or not. - OverDefinedCacheUpdater ODCacheUpdater(Val, BB, BBLV, this); - - // If we've already computed this block's value, return it. - if (!BBLV.isUndefined()) { - DEBUG(dbgs() << " reuse BB '" << BB->getName() << "' val=" << BBLV <<'\n'); - - // Since we're reusing a cached value here, we don't need to update the - // OverDefinedCahce. The cache will have been properly updated - // whenever the cached value was inserted. - ODCacheUpdater.markResult(false); + if (lookup(Val).count(BB)) { + // If we have a cached value, use that. + DEBUG(dbgs() << " reuse BB '" << BB->getName() + << "' val=" << lookup(Val)[BB] << '\n'); + + // Since we're reusing a cached value, we don't need to update the + // OverDefinedCache. The cache will have been properly updated whenever the + // cached value was inserted. return true; } - // Otherwise, this is the first time we're seeing this block. Reset the - // lattice value to overdefined, so that cycles will terminate and be - // conservatively correct. - BBLV.markOverdefined(); + // Hold off inserting this value into the Cache in case we have to return + // false and come back later. 
+ LVILatticeVal Res; Instruction *BBI = dyn_cast<Instruction>(Val); if (!BBI || BBI->getParent() != BB) { - return ODCacheUpdater.markResult(solveBlockValueNonLocal(BBLV, Val, BB)); + if (!solveBlockValueNonLocal(Res, Val, BB)) + return false; + insertResult(Val, BB, Res); + return true; } if (PHINode *PN = dyn_cast<PHINode>(BBI)) { - return ODCacheUpdater.markResult(solveBlockValuePHINode(BBLV, PN, BB)); + if (!solveBlockValuePHINode(Res, PN, BB)) + return false; + insertResult(Val, BB, Res); + return true; } if (AllocaInst *AI = dyn_cast<AllocaInst>(BBI)) { - BBLV = LVILatticeVal::getNot(ConstantPointerNull::get(AI->getType())); - return ODCacheUpdater.markResult(true); + Res = LVILatticeVal::getNot(ConstantPointerNull::get(AI->getType())); + insertResult(Val, BB, Res); + return true; } // We can only analyze the definitions of certain classes of instructions @@ -537,8 +553,9 @@ bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { !BBI->getType()->isIntegerTy()) { DEBUG(dbgs() << " compute BB '" << BB->getName() << "' - overdefined because inst def found.\n"); - BBLV.markOverdefined(); - return ODCacheUpdater.markResult(true); + Res.markOverdefined(); + insertResult(Val, BB, Res); + return true; } // FIXME: We're currently limited to binops with a constant RHS. This should @@ -548,11 +565,15 @@ bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { DEBUG(dbgs() << " compute BB '" << BB->getName() << "' - overdefined because inst def found.\n"); - BBLV.markOverdefined(); - return ODCacheUpdater.markResult(true); + Res.markOverdefined(); + insertResult(Val, BB, Res); + return true; } - return ODCacheUpdater.markResult(solveBlockValueConstantRange(BBLV, BBI, BB)); + if (!solveBlockValueConstantRange(Res, BBI, BB)) + return false; + insertResult(Val, BB, Res); + return true; } static bool InstructionDereferencesPointer(Instruction *I, Value *Ptr) { @@ -597,9 +618,8 @@ bool LazyValueInfoCache::solveBlockValueNonLocal(LVILatticeVal &BBLV, // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge // inside InstructionDereferencesPointer either. if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, nullptr, 1)) { - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); - BI != BE; ++BI) { - if (InstructionDereferencesPointer(BI, UnderlyingVal)) { + for (Instruction &I : *BB) { + if (InstructionDereferencesPointer(&I, UnderlyingVal)) { NotNull = true; break; } @@ -669,7 +689,10 @@ bool LazyValueInfoCache::solveBlockValuePHINode(LVILatticeVal &BBLV, BasicBlock *PhiBB = PN->getIncomingBlock(i); Value *PhiVal = PN->getIncomingValue(i); LVILatticeVal EdgeResult; - EdgesMissing |= !getEdgeValue(PhiVal, PhiBB, BB, EdgeResult); + // Note that we can provide PN as the context value to getEdgeValue, even + // though the results will be cached, because PN is the value being used as + // the cache key in the caller. + EdgesMissing |= !getEdgeValue(PhiVal, PhiBB, BB, EdgeResult, PN); if (EdgesMissing) continue; @@ -694,16 +717,53 @@ bool LazyValueInfoCache::solveBlockValuePHINode(LVILatticeVal &BBLV, return true; } +static bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, + LVILatticeVal &Result, + bool isTrueDest = true); + +// If we can determine a constant range for the value Val in the context +// provided by the instruction BBI, then merge it into BBLV. If we did find a +// constant range, return true. +void LazyValueInfoCache::mergeAssumeBlockValueConstantRange(Value *Val, + LVILatticeVal &BBLV, + Instruction *BBI) { + BBI = BBI ? 
BBI : dyn_cast<Instruction>(Val); + if (!BBI) + return; + + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *I = cast<CallInst>(AssumeVH); + if (!isValidAssumeForContext(I, BBI, DL, DT)) + continue; + + Value *C = I->getArgOperand(0); + if (ICmpInst *ICI = dyn_cast<ICmpInst>(C)) { + LVILatticeVal Result; + if (getValueFromFromCondition(Val, ICI, Result)) { + if (BBLV.isOverdefined()) + BBLV = Result; + else + BBLV.mergeIn(Result); + } + } + } +} + bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, Instruction *BBI, BasicBlock *BB) { // Figure out the range of the LHS. If that fails, bail. if (!hasBlockValue(BBI->getOperand(0), BB)) { - BlockValueStack.push(std::make_pair(BB, BBI->getOperand(0))); - return false; + if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) + return false; + BBLV.markOverdefined(); + return true; } LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); + mergeAssumeBlockValueConstantRange(BBI->getOperand(0), LHSVal, BBI); if (!LHSVal.isConstantRange()) { BBLV.markOverdefined(); return true; @@ -775,6 +835,47 @@ bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, return true; } +bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, + LVILatticeVal &Result, bool isTrueDest) { + if (ICI && isa<Constant>(ICI->getOperand(1))) { + if (ICI->isEquality() && ICI->getOperand(0) == Val) { + // We know that V has the RHS constant if this is a true SETEQ or + // false SETNE. + if (isTrueDest == (ICI->getPredicate() == ICmpInst::ICMP_EQ)) + Result = LVILatticeVal::get(cast<Constant>(ICI->getOperand(1))); + else + Result = LVILatticeVal::getNot(cast<Constant>(ICI->getOperand(1))); + return true; + } + + // Recognize the range checking idiom that InstCombine produces. + // (X-C1) u< C2 --> [C1, C1+C2) + ConstantInt *NegOffset = nullptr; + if (ICI->getPredicate() == ICmpInst::ICMP_ULT) + match(ICI->getOperand(0), m_Add(m_Specific(Val), + m_ConstantInt(NegOffset))); + + ConstantInt *CI = dyn_cast<ConstantInt>(ICI->getOperand(1)); + if (CI && (ICI->getOperand(0) == Val || NegOffset)) { + // Calculate the range of values that would satisfy the comparison. + ConstantRange CmpRange(CI->getValue()); + ConstantRange TrueValues = + ConstantRange::makeICmpRegion(ICI->getPredicate(), CmpRange); + + if (NegOffset) // Apply the offset from above. + TrueValues = TrueValues.subtract(NegOffset->getValue()); + + // If we're interested in the false dest, invert the condition. + if (!isTrueDest) TrueValues = TrueValues.inverse(); + + Result = LVILatticeVal::getRange(TrueValues); + return true; + } + } + + return false; +} + /// \brief Compute the value of Val on the edge BBFrom -> BBTo. Returns false if /// Val is not constrained on the edge. static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, @@ -783,7 +884,7 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, // know that v != 0. if (BranchInst *BI = dyn_cast<BranchInst>(BBFrom->getTerminator())) { // If this is a conditional branch and only one successor goes to BBTo, then - // we maybe able to infer something from the condition. + // we may be able to infer something from the condition. if (BI->isConditional() && BI->getSuccessor(0) != BI->getSuccessor(1)) { bool isTrueDest = BI->getSuccessor(0) == BBTo; @@ -800,42 +901,9 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, // If the condition of the branch is an equality comparison, we may be // able to infer the value. 
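The range-checking idiom recognized in getValueFromFromCondition above, (X-C1) u< C2 --> [C1, C1+C2), can be verified exhaustively at a small bit width. A worked illustration with C1 = 5 and C2 = 10 (values invented for the example), showing that the subtract-and-compare form and the half-open range agree even across unsigned wraparound:

#include <cassert>
#include <cstdint>

// (X - 5) u< 10 holds exactly when X lies in [5, 15): this is
// makeICmpRegion(ULT, 10) shifted back by the matched offset.
bool asRangeCheckIdiom(uint8_t X) { return uint8_t(X - 5) < 10; }
bool asHalfOpenRange(uint8_t X) { return X >= 5 && X < 15; }

int main() {
  for (unsigned X = 0; X < 256; ++X)
    assert(asRangeCheckIdiom(uint8_t(X)) == asHalfOpenRange(uint8_t(X)));
  return 0;
}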
- ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition()); - if (ICI && isa<Constant>(ICI->getOperand(1))) { - if (ICI->isEquality() && ICI->getOperand(0) == Val) { - // We know that V has the RHS constant if this is a true SETEQ or - // false SETNE. - if (isTrueDest == (ICI->getPredicate() == ICmpInst::ICMP_EQ)) - Result = LVILatticeVal::get(cast<Constant>(ICI->getOperand(1))); - else - Result = LVILatticeVal::getNot(cast<Constant>(ICI->getOperand(1))); - return true; - } - - // Recognize the range checking idiom that InstCombine produces. - // (X-C1) u< C2 --> [C1, C1+C2) - ConstantInt *NegOffset = nullptr; - if (ICI->getPredicate() == ICmpInst::ICMP_ULT) - match(ICI->getOperand(0), m_Add(m_Specific(Val), - m_ConstantInt(NegOffset))); - - ConstantInt *CI = dyn_cast<ConstantInt>(ICI->getOperand(1)); - if (CI && (ICI->getOperand(0) == Val || NegOffset)) { - // Calculate the range of values that would satisfy the comparison. - ConstantRange CmpRange(CI->getValue()); - ConstantRange TrueValues = - ConstantRange::makeICmpRegion(ICI->getPredicate(), CmpRange); - - if (NegOffset) // Apply the offset from above. - TrueValues = TrueValues.subtract(NegOffset->getValue()); - - // If we're interested in the false dest, invert the condition. - if (!isTrueDest) TrueValues = TrueValues.inverse(); - - Result = LVILatticeVal::getRange(TrueValues); + if (ICmpInst *ICI = dyn_cast<ICmpInst>(BI->getCondition())) + if (getValueFromFromCondition(Val, ICI, Result, isTrueDest)) return true; - } - } } } @@ -849,8 +917,7 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, unsigned BitWidth = Val->getType()->getIntegerBitWidth(); ConstantRange EdgesVals(BitWidth, DefaultCase/*isFullSet*/); - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) { + for (SwitchInst::CaseIt i : SI->cases()) { ConstantRange EdgeVal(i.getCaseValue()->getValue()); if (DefaultCase) { // It is possible that the default destination is the destination of @@ -866,10 +933,11 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, return false; } -/// \brief Compute the value of Val on the edge BBFrom -> BBTo, or the value at -/// the basic block if the edge does not constraint Val. +/// \brief Compute the value of Val on the edge BBFrom -> BBTo or the value at +/// the basic block if the edge does not constrain Val. bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, - BasicBlock *BBTo, LVILatticeVal &Result) { + BasicBlock *BBTo, LVILatticeVal &Result, + Instruction *CxtI) { // If already a constant, there is nothing to compute. if (Constant *VC = dyn_cast<Constant>(Val)) { Result = LVILatticeVal::get(VC); @@ -878,19 +946,25 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, if (getEdgeValueLocal(Val, BBFrom, BBTo, Result)) { if (!Result.isConstantRange() || - Result.getConstantRange().getSingleElement()) + Result.getConstantRange().getSingleElement()) return true; // FIXME: this check should be moved to the beginning of the function when // LVI better supports recursive values. Even for the single value case, we // can intersect to detect dead code (an empty range). if (!hasBlockValue(Val, BBFrom)) { - BlockValueStack.push(std::make_pair(BBFrom, Val)); - return false; + if (pushBlockValue(std::make_pair(BBFrom, Val))) + return false; + Result.markOverdefined(); + return true; } // Try to intersect ranges of the BB and the constraint on the edge. 
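The intersection mentioned here, performed with ConstantRange in the code that follows, can only narrow the result: the value must satisfy both the block-level fact and the edge condition. A toy half-open-interval version of the operation (the real ConstantRange also handles wrapping ranges, which this sketch ignores):

#include <algorithm>
#include <cassert>

// Toy half-open interval [Lo, Hi) standing in for a non-wrapping range.
struct Interval { long Lo, Hi; };

Interval intersect(Interval A, Interval B) {
  long Lo = std::max(A.Lo, B.Lo);
  long Hi = std::min(A.Hi, B.Hi);
  if (Lo > Hi) Hi = Lo; // empty range: no value satisfies both facts
  return Interval{Lo, Hi};
}

int main() {
  Interval R = intersect(Interval{0, 100}, Interval{40, 300});
  assert(R.Lo == 40 && R.Hi == 100);
  return 0;
}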
LVILatticeVal InBlock = getBlockValue(Val, BBFrom); + mergeAssumeBlockValueConstantRange(Val, InBlock, BBFrom->getTerminator()); + // See note on the use of the CxtI with mergeAssumeBlockValueConstantRange, + // and caching, below. + mergeAssumeBlockValueConstantRange(Val, InBlock, CxtI); if (!InBlock.isConstantRange()) return true; @@ -901,36 +975,64 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, } if (!hasBlockValue(Val, BBFrom)) { - BlockValueStack.push(std::make_pair(BBFrom, Val)); - return false; + if (pushBlockValue(std::make_pair(BBFrom, Val))) + return false; + Result.markOverdefined(); + return true; } - // if we couldn't compute the value on the edge, use the value from the BB + // If we couldn't compute the value on the edge, use the value from the BB. Result = getBlockValue(Val, BBFrom); + mergeAssumeBlockValueConstantRange(Val, Result, BBFrom->getTerminator()); + // We can use the context instruction (generically the ultimate instruction + // the calling pass is trying to simplify) here, even though the result of + // this function is generally cached when called from the solve* functions + // (and that cached result might be used with queries using a different + // context instruction), because when this function is called from the solve* + // functions, the context instruction is not provided. When called from + // LazyValueInfoCache::getValueOnEdge, the context instruction is provided, + // but then the result is not cached. + mergeAssumeBlockValueConstantRange(Val, Result, CxtI); return true; } -LVILatticeVal LazyValueInfoCache::getValueInBlock(Value *V, BasicBlock *BB) { +LVILatticeVal LazyValueInfoCache::getValueInBlock(Value *V, BasicBlock *BB, + Instruction *CxtI) { DEBUG(dbgs() << "LVI Getting block end value " << *V << " at '" << BB->getName() << "'\n"); - BlockValueStack.push(std::make_pair(BB, V)); + assert(BlockValueStack.empty() && BlockValueSet.empty()); + pushBlockValue(std::make_pair(BB, V)); + solve(); LVILatticeVal Result = getBlockValue(V, BB); + mergeAssumeBlockValueConstantRange(V, Result, CxtI); + + DEBUG(dbgs() << " Result = " << Result << "\n"); + return Result; +} + +LVILatticeVal LazyValueInfoCache::getValueAt(Value *V, Instruction *CxtI) { + DEBUG(dbgs() << "LVI Getting value " << *V << " at '" + << CxtI->getName() << "'\n"); + + LVILatticeVal Result; + mergeAssumeBlockValueConstantRange(V, Result, CxtI); DEBUG(dbgs() << " Result = " << Result << "\n"); return Result; } LVILatticeVal LazyValueInfoCache:: -getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB) { +getValueOnEdge(Value *V, BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { DEBUG(dbgs() << "LVI Getting edge value " << *V << " from '" << FromBB->getName() << "' to '" << ToBB->getName() << "'\n"); LVILatticeVal Result; - if (!getEdgeValue(V, FromBB, ToBB, Result)) { + if (!getEdgeValue(V, FromBB, ToBB, Result, CxtI)) { solve(); - bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result); + bool WasFastQuery = getEdgeValue(V, FromBB, ToBB, Result, CxtI); (void)WasFastQuery; assert(WasFastQuery && "More work to do after problem solved?"); } @@ -947,7 +1049,7 @@ void LazyValueInfoCache::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, // we clear their entries from the cache, and allow lazy updating to recompute // them when needed. 
- // The updating process is fairly simple: we need to dropped cached info + // The updating process is fairly simple: we need to drop cached info // for all values that were marked overdefined in OldSucc, and for those same // values in any successor of OldSucc (except NewSucc) in which they were // also marked overdefined. @@ -955,11 +1057,9 @@ void LazyValueInfoCache::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, worklist.push_back(OldSucc); DenseSet<Value*> ClearSet; - for (DenseSet<OverDefinedPairTy>::iterator I = OverDefinedCache.begin(), - E = OverDefinedCache.end(); I != E; ++I) { - if (I->first == OldSucc) - ClearSet.insert(I->second); - } + for (OverDefinedPairTy &P : OverDefinedCache) + if (P.first == OldSucc) + ClearSet.insert(P.second); // Use a worklist to perform a depth-first search of OldSucc's successors. // NOTE: We do not need a visited list since any blocks we have already @@ -973,15 +1073,14 @@ void LazyValueInfoCache::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, if (ToUpdate == NewSucc) continue; bool changed = false; - for (DenseSet<Value*>::iterator I = ClearSet.begin(), E = ClearSet.end(); - I != E; ++I) { + for (Value *V : ClearSet) { // If a value was marked overdefined in OldSucc, and is here too... DenseSet<OverDefinedPairTy>::iterator OI = - OverDefinedCache.find(std::make_pair(ToUpdate, *I)); + OverDefinedCache.find(std::make_pair(ToUpdate, V)); if (OI == OverDefinedCache.end()) continue; // Remove it from the caches. - ValueCacheEntryTy &Entry = ValueCache[LVIValueHandle(*I, this)]; + ValueCacheEntryTy &Entry = ValueCache[LVIValueHandle(V, this)]; ValueCacheEntryTy::iterator CI = Entry.find(ToUpdate); assert(CI != Entry.end() && "Couldn't find entry to update?"); @@ -1003,41 +1102,53 @@ void LazyValueInfoCache::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, // LazyValueInfo Impl //===----------------------------------------------------------------------===// -/// getCache - This lazily constructs the LazyValueInfoCache. -static LazyValueInfoCache &getCache(void *&PImpl) { +/// This lazily constructs the LazyValueInfoCache. +static LazyValueInfoCache &getCache(void *&PImpl, AssumptionCache *AC, + const DataLayout *DL = nullptr, + DominatorTree *DT = nullptr) { if (!PImpl) - PImpl = new LazyValueInfoCache(); + PImpl = new LazyValueInfoCache(AC, DL, DT); return *static_cast<LazyValueInfoCache*>(PImpl); } bool LazyValueInfo::runOnFunction(Function &F) { - if (PImpl) - getCache(PImpl).clear(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + + DominatorTreeWrapperPass *DTWP = + getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + DT = DTWP ? &DTWP->getDomTree() : nullptr; DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; + TLI = &getAnalysis<TargetLibraryInfo>(); + if (PImpl) + getCache(PImpl, AC, DL, DT).clear(); + // Fully lazy. return false; } void LazyValueInfo::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); } void LazyValueInfo::releaseMemory() { // If the cache was allocated, free it. 
if (PImpl) { - delete &getCache(PImpl); + delete &getCache(PImpl, AC); PImpl = nullptr; } } -Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB) { - LVILatticeVal Result = getCache(PImpl).getValueInBlock(V, BB); - +Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AC, DL, DT).getValueInBlock(V, BB, CxtI); + if (Result.isConstant()) return Result.getConstant(); if (Result.isConstantRange()) { @@ -1048,12 +1159,14 @@ Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB) { return nullptr; } -/// getConstantOnEdge - Determine whether the specified value is known to be a +/// Determine whether the specified value is known to be a /// constant on the specified edge. Return null if not. Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, - BasicBlock *ToBB) { - LVILatticeVal Result = getCache(PImpl).getValueOnEdge(V, FromBB, ToBB); - + BasicBlock *ToBB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AC, DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + if (Result.isConstant()) return Result.getConstant(); if (Result.isConstantRange()) { @@ -1064,51 +1177,47 @@ Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, return nullptr; } -/// getPredicateOnEdge - Determine whether the specified value comparison -/// with a constant is known to be true or false on the specified CFG edge. -/// Pred is a CmpInst predicate. -LazyValueInfo::Tristate -LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, - BasicBlock *FromBB, BasicBlock *ToBB) { - LVILatticeVal Result = getCache(PImpl).getValueOnEdge(V, FromBB, ToBB); - +static LazyValueInfo::Tristate +getPredicateResult(unsigned Pred, Constant *C, LVILatticeVal &Result, + const DataLayout *DL, TargetLibraryInfo *TLI) { + // If we know the value is a constant, evaluate the conditional. Constant *Res = nullptr; if (Result.isConstant()) { Res = ConstantFoldCompareInstOperands(Pred, Result.getConstant(), C, DL, TLI); if (ConstantInt *ResCI = dyn_cast<ConstantInt>(Res)) - return ResCI->isZero() ? False : True; - return Unknown; + return ResCI->isZero() ? LazyValueInfo::False : LazyValueInfo::True; + return LazyValueInfo::Unknown; } if (Result.isConstantRange()) { ConstantInt *CI = dyn_cast<ConstantInt>(C); - if (!CI) return Unknown; + if (!CI) return LazyValueInfo::Unknown; ConstantRange CR = Result.getConstantRange(); if (Pred == ICmpInst::ICMP_EQ) { if (!CR.contains(CI->getValue())) - return False; + return LazyValueInfo::False; if (CR.isSingleElement() && CR.contains(CI->getValue())) - return True; + return LazyValueInfo::True; } else if (Pred == ICmpInst::ICMP_NE) { if (!CR.contains(CI->getValue())) - return True; + return LazyValueInfo::True; if (CR.isSingleElement() && CR.contains(CI->getValue())) - return False; + return LazyValueInfo::False; } // Handle more complex predicates. 
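The equality cases of getPredicateResult above reduce to containment tests on the computed range: a predicate is definitely false when the constant lies outside the range, definitely true when the range collapses to exactly that constant, and unknown otherwise. A stand-in restatement with plain intervals (names invented; LLVM uses ConstantRange):

#include <cstdint>

enum Tristate { False, True, Unknown };

// Half-open [Lo, Hi) stand-in for a non-wrapping ConstantRange.
struct Interval {
  uint64_t Lo, Hi;
  bool contains(uint64_t V) const { return Lo <= V && V < Hi; }
  bool isSingleElement() const { return Lo + 1 == Hi; }
};

Tristate evalEq(const Interval &CR, uint64_t C) {
  if (!CR.contains(C))
    return False; // the value can never equal C
  if (CR.isSingleElement())
    return True;  // the value is exactly C
  return Unknown; // C is one of several possibilities
}

Tristate evalNe(const Interval &CR, uint64_t C) {
  Tristate Eq = evalEq(CR, C);
  return Eq == Unknown ? Unknown : (Eq == True ? False : True);
}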
ConstantRange TrueValues = ICmpInst::makeConstantRange((ICmpInst::Predicate)Pred, CI->getValue()); if (TrueValues.contains(CR)) - return True; + return LazyValueInfo::True; if (TrueValues.inverse().contains(CR)) - return False; - return Unknown; + return LazyValueInfo::False; + return LazyValueInfo::Unknown; } if (Result.isNotConstant()) { @@ -1120,26 +1229,48 @@ LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, Result.getNotConstant(), C, DL, TLI); if (Res->isNullValue()) - return False; + return LazyValueInfo::False; } else if (Pred == ICmpInst::ICMP_NE) { // !C1 != C -> true iff C1 == C. Res = ConstantFoldCompareInstOperands(ICmpInst::ICMP_NE, Result.getNotConstant(), C, DL, TLI); if (Res->isNullValue()) - return True; + return LazyValueInfo::True; } - return Unknown; + return LazyValueInfo::Unknown; } - return Unknown; + return LazyValueInfo::Unknown; +} + +/// Determine whether the specified value comparison with a constant is known to +/// be true or false on the specified CFG edge. Pred is a CmpInst predicate. +LazyValueInfo::Tristate +LazyValueInfo::getPredicateOnEdge(unsigned Pred, Value *V, Constant *C, + BasicBlock *FromBB, BasicBlock *ToBB, + Instruction *CxtI) { + LVILatticeVal Result = + getCache(PImpl, AC, DL, DT).getValueOnEdge(V, FromBB, ToBB, CxtI); + + return getPredicateResult(Pred, C, Result, DL, TLI); +} + +LazyValueInfo::Tristate +LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, + Instruction *CxtI) { + LVILatticeVal Result = getCache(PImpl, AC, DL, DT).getValueAt(V, CxtI); + + return getPredicateResult(Pred, C, Result, DL, TLI); } void LazyValueInfo::threadEdge(BasicBlock *PredBB, BasicBlock *OldSucc, BasicBlock *NewSucc) { - if (PImpl) getCache(PImpl).threadEdge(PredBB, OldSucc, NewSucc); + if (PImpl) + getCache(PImpl, AC, DL, DT).threadEdge(PredBB, OldSucc, NewSucc); } void LazyValueInfo::eraseBlock(BasicBlock *BB) { - if (PImpl) getCache(PImpl).eraseBlock(BB); + if (PImpl) + getCache(PImpl, AC, DL, DT).eraseBlock(BB); } diff --git a/lib/Analysis/LibCallSemantics.cpp b/lib/Analysis/LibCallSemantics.cpp index 7d4e254a111d..23639e71232d 100644 --- a/lib/Analysis/LibCallSemantics.cpp +++ b/lib/Analysis/LibCallSemantics.cpp @@ -18,7 +18,7 @@ #include "llvm/IR/Function.h" using namespace llvm; -/// getMap - This impl pointer in ~LibCallInfo is actually a StringMap. This +/// This impl pointer in ~LibCallInfo is actually a StringMap. This /// helper does the cast. static StringMap<const LibCallFunctionInfo*> *getMap(void *Ptr) { return static_cast<StringMap<const LibCallFunctionInfo*> *>(Ptr); @@ -38,7 +38,7 @@ const LibCallLocationInfo &LibCallInfo::getLocationInfo(unsigned LocID) const { } -/// getFunctionInfo - Return the LibCallFunctionInfo object corresponding to +/// Return the LibCallFunctionInfo object corresponding to /// the specified function if we have it. If not, return null. 
const LibCallFunctionInfo * LibCallInfo::getFunctionInfo(const Function *F) const { diff --git a/lib/Analysis/Lint.cpp b/lib/Analysis/Lint.cpp index b14f3292e900..b5c7245030fc 100644 --- a/lib/Analysis/Lint.cpp +++ b/lib/Analysis/Lint.cpp @@ -37,6 +37,7 @@ #include "llvm/Analysis/Lint.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -96,11 +97,12 @@ namespace { Value *findValue(Value *V, bool OffsetOk) const; Value *findValueImpl(Value *V, bool OffsetOk, - SmallPtrSet<Value *, 4> &Visited) const; + SmallPtrSetImpl<Value *> &Visited) const; public: Module *Mod; AliasAnalysis *AA; + AssumptionCache *AC; DominatorTree *DT; const DataLayout *DL; TargetLibraryInfo *TLI; @@ -118,6 +120,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesAll(); AU.addRequired<AliasAnalysis>(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfo>(); AU.addRequired<DominatorTreeWrapperPass>(); } @@ -151,6 +154,7 @@ namespace { char Lint::ID = 0; INITIALIZE_PASS_BEGIN(Lint, "lint", "Statically lint-checks LLVM IR", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) @@ -175,6 +179,7 @@ INITIALIZE_PASS_END(Lint, "lint", "Statically lint-checks LLVM IR", bool Lint::runOnFunction(Function &F) { Mod = F.getParent(); AA = &getAnalysis<AliasAnalysis>(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -504,7 +509,8 @@ void Lint::visitShl(BinaryOperator &I) { "Undefined result: Shift count out of range", &I); } -static bool isZero(Value *V, const DataLayout *DL) { +static bool isZero(Value *V, const DataLayout *DL, DominatorTree *DT, + AssumptionCache *AC) { // Assume undef could be zero. if (isa<UndefValue>(V)) return true; @@ -513,7 +519,8 @@ static bool isZero(Value *V, const DataLayout *DL) { if (!VecTy) { unsigned BitWidth = V->getType()->getIntegerBitWidth(); APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL); + computeKnownBits(V, KnownZero, KnownOne, DL, 0, AC, + dyn_cast<Instruction>(V), DT); return KnownZero.isAllOnesValue(); } @@ -543,22 +550,22 @@ static bool isZero(Value *V, const DataLayout *DL) { } void Lint::visitSDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AC), "Undefined behavior: Division by zero", &I); } void Lint::visitUDiv(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AC), "Undefined behavior: Division by zero", &I); } void Lint::visitSRem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AC), "Undefined behavior: Division by zero", &I); } void Lint::visitURem(BinaryOperator &I) { - Assert1(!isZero(I.getOperand(1), DL), + Assert1(!isZero(I.getOperand(1), DL, DT, AC), "Undefined behavior: Division by zero", &I); } @@ -622,9 +629,9 @@ Value *Lint::findValue(Value *V, bool OffsetOk) const { /// findValueImpl - Implementation helper for findValue. 
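Lint's isZero above only reports a provably zero divisor when computeKnownBits shows every bit known to be zero, i.e. when the KnownZero mask is all ones. A stand-in restatement of that final test (the KnownBits struct below is invented for illustration, not LLVM's API of this era):

#include <cstdint>

struct KnownBits { uint64_t Zero, One; }; // bit i known-0 / known-1 masks

// Mirrors "return KnownZero.isAllOnesValue()" for widths up to 64 bits.
bool isProvablyZero(KnownBits K, unsigned BitWidth) {
  uint64_t Mask = BitWidth >= 64 ? ~0ULL : ((1ULL << BitWidth) - 1);
  return (K.Zero & Mask) == Mask; // every bit known zero => value is zero
}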
Value *Lint::findValueImpl(Value *V, bool OffsetOk, - SmallPtrSet<Value *, 4> &Visited) const { + SmallPtrSetImpl<Value *> &Visited) const { // Detect self-referential values. - if (!Visited.insert(V)) + if (!Visited.insert(V).second) return UndefValue::get(V->getType()); // TODO: Look through sext or zext cast, when the result is known to @@ -638,7 +645,8 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, BasicBlock *BB = L->getParent(); SmallPtrSet<BasicBlock *, 4> VisitedBlocks; for (;;) { - if (!VisitedBlocks.insert(BB)) break; + if (!VisitedBlocks.insert(BB).second) + break; if (Value *U = FindAvailableLoadedValue(L->getPointerOperand(), BB, BBI, 6, AA)) return findValueImpl(U, OffsetOk, Visited); @@ -678,7 +686,7 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, // As a last resort, try SimplifyInstruction or constant folding. if (Instruction *Inst = dyn_cast<Instruction>(V)) { - if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT)) + if (Value *W = SimplifyInstruction(Inst, DL, TLI, DT, AC)) return findValueImpl(W, OffsetOk, Visited); } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { if (Value *W = ConstantFoldConstantExpression(CE, DL, TLI)) diff --git a/lib/Analysis/Loads.cpp b/lib/Analysis/Loads.cpp index 005d309894c1..5042eb9777ac 100644 --- a/lib/Analysis/Loads.cpp +++ b/lib/Analysis/Loads.cpp @@ -22,25 +22,29 @@ #include "llvm/IR/Operator.h" using namespace llvm; -/// AreEquivalentAddressValues - Test if A and B will obviously have the same -/// value. This includes recognizing that %t0 and %t1 will have the same +/// \brief Test if A and B will obviously have the same value. +/// +/// This includes recognizing that %t0 and %t1 will have the same /// value in code like this: +/// \code /// %t0 = getelementptr \@a, 0, 3 /// store i32 0, i32* %t0 /// %t1 = getelementptr \@a, 0, 3 /// %t2 = load i32* %t1 +/// \endcode /// static bool AreEquivalentAddressValues(const Value *A, const Value *B) { // Test if the values are trivially equivalent. - if (A == B) return true; + if (A == B) + return true; // Test if the values come from identical arithmetic instructions. // Use isIdenticalToWhenDefined instead of isIdenticalTo because // this function is only used when one address use dominates the // other, which means that they'll always either have the same // value or one of them will have an undefined value. - if (isa<BinaryOperator>(A) || isa<CastInst>(A) || - isa<PHINode>(A) || isa<GetElementPtrInst>(A)) + if (isa<BinaryOperator>(A) || isa<CastInst>(A) || isa<PHINode>(A) || + isa<GetElementPtrInst>(A)) if (const Instruction *BI = dyn_cast<Instruction>(B)) if (cast<Instruction>(A)->isIdenticalToWhenDefined(BI)) return true; @@ -49,15 +53,19 @@ static bool AreEquivalentAddressValues(const Value *A, const Value *B) { return false; } -/// isSafeToLoadUnconditionally - Return true if we know that executing a load -/// from this value cannot trap. If it is not obviously safe to load from the -/// specified pointer, we do a quick local scan of the basic block containing -/// ScanFrom, to determine if the address is already accessed. +/// \brief Check if executing a load of this pointer value cannot trap. +/// +/// If it is not obviously safe to load from the specified pointer, we do +/// a quick local scan of the basic block containing \c ScanFrom, to determine +/// if the address is already accessed. +/// +/// This uses the pointee type to determine how many bytes need to be safe to +/// load from the pointer. 
bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, - unsigned Align, const DataLayout *TD) { + unsigned Align, const DataLayout *DL) { int64_t ByteOffset = 0; Value *Base = V; - Base = GetPointerBaseWithConstantOffset(V, ByteOffset, TD); + Base = GetPointerBaseWithConstantOffset(V, ByteOffset, DL); if (ByteOffset < 0) // out of bounds return false; @@ -69,26 +77,29 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, BaseType = AI->getAllocatedType(); BaseAlign = AI->getAlignment(); } else if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) { - // Global variables are safe to load from but their size cannot be - // guaranteed if they are overridden. + // Global variables are not necessarily safe to load from if they are + // overridden. Their size may change or they may be weak and require a test + // to determine if they were in fact provided. if (!GV->mayBeOverridden()) { BaseType = GV->getType()->getElementType(); BaseAlign = GV->getAlignment(); } } - if (BaseType && BaseType->isSized()) { - if (TD && BaseAlign == 0) - BaseAlign = TD->getPrefTypeAlignment(BaseType); + PointerType *AddrTy = cast<PointerType>(V->getType()); + uint64_t LoadSize = DL ? DL->getTypeStoreSize(AddrTy->getElementType()) : 0; - if (Align <= BaseAlign) { - if (!TD) - return true; // Loading directly from an alloca or global is OK. + // If we found a base allocated type from either an alloca or global variable, + // try to see if we are definitively within the allocated region. We need to + // know the size of the base type and the loaded type to do anything in this + // case, so only try this when we have the DataLayout available. + if (BaseType && BaseType->isSized() && DL) { + if (BaseAlign == 0) + BaseAlign = DL->getPrefTypeAlignment(BaseType); + if (Align <= BaseAlign) { // Check if the load is within the bounds of the underlying object. - PointerType *AddrTy = cast<PointerType>(V->getType()); - uint64_t LoadSize = TD->getTypeStoreSize(AddrTy->getElementType()); - if (ByteOffset + LoadSize <= TD->getTypeAllocSize(BaseType) && + if (ByteOffset + LoadSize <= DL->getTypeAllocSize(BaseType) && (Align == 0 || (ByteOffset % Align) == 0)) return true; } @@ -101,6 +112,10 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, // the load entirely). BasicBlock::iterator BBI = ScanFrom, E = ScanFrom->getParent()->begin(); + // We can at least always strip pointer casts even though we can't use the + // base here. + V = V->stripPointerCasts(); + while (BBI != E) { --BBI; @@ -110,46 +125,67 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, !isa<DbgInfoIntrinsic>(BBI)) return false; - if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) { - if (AreEquivalentAddressValues(LI->getOperand(0), V)) return true; - } else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) { - if (AreEquivalentAddressValues(SI->getOperand(1), V)) return true; - } + Value *AccessedPtr; + if (LoadInst *LI = dyn_cast<LoadInst>(BBI)) + AccessedPtr = LI->getPointerOperand(); + else if (StoreInst *SI = dyn_cast<StoreInst>(BBI)) + AccessedPtr = SI->getPointerOperand(); + else + continue; + + // Handle trivial cases even w/o DataLayout or other work. 
+ if (AccessedPtr == V) + return true; + + if (!DL) + continue; + + auto *AccessedTy = cast<PointerType>(AccessedPtr->getType()); + if (AreEquivalentAddressValues(AccessedPtr->stripPointerCasts(), V) && + LoadSize <= DL->getTypeStoreSize(AccessedTy->getElementType())) + return true; } return false; } -/// FindAvailableLoadedValue - Scan the ScanBB block backwards (starting at the -/// instruction before ScanFrom) checking to see if we have the value at the +/// \brief Scan the ScanBB block backwards to see if we have the value at the /// memory address *Ptr locally available within a small number of instructions. -/// If the value is available, return it. /// -/// If not, return the iterator for the last validated instruction that the -/// value would be live through. If we scanned the entire block and didn't find -/// something that invalidates *Ptr or provides it, ScanFrom would be left at -/// begin() and this returns null. ScanFrom could also be left +/// The scan starts from \c ScanFrom. \c MaxInstsToScan specifies the maximum +/// instructions to scan in the block. If it is set to \c 0, it will scan the whole +/// block. /// -/// MaxInstsToScan specifies the maximum instructions to scan in the block. If -/// it is set to 0, it will scan the whole block. You can also optionally -/// specify an alias analysis implementation, which makes this more precise. +/// If the value is available, this function returns it. If not, it returns the +/// iterator for the last validated instruction that the value would be live +/// through. If we scanned the entire block and didn't find something that +/// invalidates \c *Ptr or provides it, \c ScanFrom is left at the last +/// instruction processed and this returns null. /// -/// If TBAATag is non-null and a load or store is found, the TBAA tag from the -/// load or store is recorded there. If there is no TBAA tag or if no access -/// is found, it is left unmodified. +/// You can also optionally specify an alias analysis implementation, which +/// makes this more precise. +/// +/// If \c AATags is non-null and a load or store is found, the AA tags from the +/// load or store are recorded there. If there are no AA tags or if no access is +/// found, it is left unmodified. Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, - AliasAnalysis *AA, - MDNode **TBAATag) { - if (MaxInstsToScan == 0) MaxInstsToScan = ~0U; - - // If we're using alias analysis to disambiguate get the size of *Ptr. - uint64_t AccessSize = 0; - if (AA) { - Type *AccessTy = cast<PointerType>(Ptr->getType())->getElementType(); - AccessSize = AA->getTypeStoreSize(AccessTy); - } - + AliasAnalysis *AA, AAMDNodes *AATags) { + if (MaxInstsToScan == 0) + MaxInstsToScan = ~0U; + + Type *AccessTy = cast<PointerType>(Ptr->getType())->getElementType(); + + // Try to get the DataLayout for this module. This may be null, in which case + // the optimizations will be limited. + const DataLayout *DL = ScanBB->getDataLayout(); + + // Try to get the store size for the type. + uint64_t AccessSize = DL ? DL->getTypeStoreSize(AccessTy) + : AA ? AA->getTypeStoreSize(AccessTy) : 0; + + Value *StrippedPtr = Ptr->stripPointerCasts(); + while (ScanFrom != ScanBB->begin()) { // We must ignore debug info directives when counting (otherwise they // would affect codegen). 
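What the stripPointerCasts handling threaded through these scans buys, on a hypothetical block (illustrative IR, not from this patch):

//   %p = alloca i32
//   %f = bitcast i32* %p to float*
//   store float 1.0, float* %f
//   %v = load i32* %p
// Scanning up from %v, the store's pointer strips back to %p; with
// DataLayout available the 4-byte store covers the 4-byte load (and the
// stored float is bitcastable to i32), so the earlier access is recognized
// even though the two instructions use different pointer types.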
@@ -159,62 +195,72 @@ Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, // Restore ScanFrom to expected value in case next test succeeds ScanFrom++; - + // Don't scan huge blocks. - if (MaxInstsToScan-- == 0) return nullptr; - + if (MaxInstsToScan-- == 0) + return nullptr; + --ScanFrom; // If this is a load of Ptr, the loaded value is available. // (This is true even if the load is volatile or atomic, although // those cases are unlikely.) if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) - if (AreEquivalentAddressValues(LI->getOperand(0), Ptr)) { - if (TBAATag) *TBAATag = LI->getMetadata(LLVMContext::MD_tbaa); + if (AreEquivalentAddressValues( + LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) { + if (AATags) + LI->getAAMetadata(*AATags); return LI; } - + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + Value *StorePtr = SI->getPointerOperand()->stripPointerCasts(); // If this is a store through Ptr, the value is available! // (This is true even if the store is volatile or atomic, although // those cases are unlikely.) - if (AreEquivalentAddressValues(SI->getOperand(1), Ptr)) { - if (TBAATag) *TBAATag = SI->getMetadata(LLVMContext::MD_tbaa); + if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && + CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(), + AccessTy, DL)) { + if (AATags) + SI->getAAMetadata(*AATags); return SI->getOperand(0); } - - // If Ptr is an alloca and this is a store to a different alloca, ignore - // the store. This is a trivial form of alias analysis that is important - // for reg2mem'd code. - if ((isa<AllocaInst>(Ptr) || isa<GlobalVariable>(Ptr)) && - (isa<AllocaInst>(SI->getOperand(1)) || - isa<GlobalVariable>(SI->getOperand(1)))) + + // If both StrippedPtr and StorePtr reach all the way to an alloca or + // global and they are different, ignore the store. This is a trivial form + // of alias analysis that is important for reg2mem'd code. + if ((isa<AllocaInst>(StrippedPtr) || isa<GlobalVariable>(StrippedPtr)) && + (isa<AllocaInst>(StorePtr) || isa<GlobalVariable>(StorePtr)) && + StrippedPtr != StorePtr) continue; - + // If we have alias analysis and it says the store won't modify the loaded // value, ignore the store. if (AA && - (AA->getModRefInfo(SI, Ptr, AccessSize) & AliasAnalysis::Mod) == 0) + (AA->getModRefInfo(SI, StrippedPtr, AccessSize) & + AliasAnalysis::Mod) == 0) continue; - + // Otherwise the store may or may not alias the pointer; bail out. ++ScanFrom; return nullptr; } - + // If this is some other instruction that may clobber Ptr, bail out. if (Inst->mayWriteToMemory()) { // If alias analysis claims that it really won't modify the load, // ignore it. if (AA && - (AA->getModRefInfo(Inst, Ptr, AccessSize) & AliasAnalysis::Mod) == 0) + (AA->getModRefInfo(Inst, StrippedPtr, AccessSize) & + AliasAnalysis::Mod) == 0) continue; - + // May modify the pointer, bail out. ++ScanFrom; return nullptr; } } - + // Got to the start of the block, we didn't find it, but are done for this // block.
return nullptr; diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index 46c0eaabe1a3..b1f62c437326 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -24,6 +24,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -307,7 +308,8 @@ bool Loop::isAnnotatedParallel() const { // directly or indirectly through another list metadata (in case of // nested parallel loops). The loop identifier metadata refers to // itself so we can check both cases with the same routine. - MDNode *loopIdMD = II->getMetadata("llvm.mem.parallel_loop_access"); + MDNode *loopIdMD = + II->getMetadata(LLVMContext::MD_mem_parallel_loop_access); if (!loopIdMD) return false; diff --git a/lib/Analysis/LoopPass.cpp b/lib/Analysis/LoopPass.cpp index 7bd866e73e10..190abc73fbb4 100644 --- a/lib/Analysis/LoopPass.cpp +++ b/lib/Analysis/LoopPass.cpp @@ -76,6 +76,9 @@ void LPPassManager::deleteLoopFromQueue(Loop *L) { LI->updateUnloop(L); + // Notify passes that the loop is being deleted. + deleteSimpleAnalysisLoop(L); + // If L is current loop then skip rest of the passes and let // runOnFunction remove L from LQ. Otherwise, remove L from LQ now // and continue applying other passes on CurrentLoop. @@ -164,6 +167,14 @@ void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) { } } +/// Invoke deleteAnalysisLoop hook for all passes. +void LPPassManager::deleteSimpleAnalysisLoop(Loop *L) { + for (unsigned Index = 0; Index < getNumContainedPasses(); ++Index) { + LoopPass *LP = getContainedPass(Index); + LP->deleteAnalysisLoop(L); + } +} + // Recurse through all subloops and all loops into LQ. static void addLoopIntoQueue(Loop *L, std::deque<Loop *> &LQ) { diff --git a/lib/Analysis/MemDepPrinter.cpp b/lib/Analysis/MemDepPrinter.cpp index 10da3d5d6187..ffc9fe64cc52 100644 --- a/lib/Analysis/MemDepPrinter.cpp +++ b/lib/Analysis/MemDepPrinter.cpp @@ -92,7 +92,6 @@ const char *const MemDepPrinter::DepTypeStr[] bool MemDepPrinter::runOnFunction(Function &F) { this->F = &F; - AliasAnalysis &AA = getAnalysis<AliasAnalysis>(); MemoryDependenceAnalysis &MDA = getAnalysis<MemoryDependenceAnalysis>(); // All this code uses non-const interfaces because MemDep is not @@ -119,30 +118,9 @@ bool MemDepPrinter::runOnFunction(Function &F) { } } else { SmallVector<NonLocalDepResult, 4> NLDI; - if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - if (!LI->isUnordered()) { - // FIXME: Handle atomic/volatile loads. - Deps[Inst].insert(std::make_pair(getInstTypePair(nullptr, Unknown), - static_cast<BasicBlock *>(nullptr))); - continue; - } - AliasAnalysis::Location Loc = AA.getLocation(LI); - MDA.getNonLocalPointerDependency(Loc, true, LI->getParent(), NLDI); - } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - if (!SI->isUnordered()) { - // FIXME: Handle atomic/volatile stores. 
- Deps[Inst].insert(std::make_pair(getInstTypePair(nullptr, Unknown), - static_cast<BasicBlock *>(nullptr))); - continue; - } - AliasAnalysis::Location Loc = AA.getLocation(SI); - MDA.getNonLocalPointerDependency(Loc, false, SI->getParent(), NLDI); - } else if (VAArgInst *VI = dyn_cast<VAArgInst>(Inst)) { - AliasAnalysis::Location Loc = AA.getLocation(VI); - MDA.getNonLocalPointerDependency(Loc, false, VI->getParent(), NLDI); - } else { - llvm_unreachable("Unknown memory instruction!"); - } + assert( (isa<LoadInst>(Inst) || isa<StoreInst>(Inst) || + isa<VAArgInst>(Inst)) && "Unknown memory instruction!"); + MDA.getNonLocalPointerDependency(Inst, NLDI); DepSet &InstDeps = Deps[Inst]; for (SmallVectorImpl<NonLocalDepResult>::const_iterator diff --git a/lib/Analysis/MemoryBuiltins.cpp b/lib/Analysis/MemoryBuiltins.cpp index 64d339f564d7..08b41fee4459 100644 --- a/lib/Analysis/MemoryBuiltins.cpp +++ b/lib/Analysis/MemoryBuiltins.cpp @@ -332,7 +332,11 @@ const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) { TLIFn == LibFunc::ZdlPv || // operator delete(void*) TLIFn == LibFunc::ZdaPv) // operator delete[](void*) ExpectedNumParams = 1; - else if (TLIFn == LibFunc::ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) + else if (TLIFn == LibFunc::ZdlPvj || // delete(void*, uint) + TLIFn == LibFunc::ZdlPvm || // delete(void*, ulong) + TLIFn == LibFunc::ZdlPvRKSt9nothrow_t || // delete(void*, nothrow) + TLIFn == LibFunc::ZdaPvj || // delete[](void*, uint) + TLIFn == LibFunc::ZdaPvm || // delete[](void*, ulong) TLIFn == LibFunc::ZdaPvRKSt9nothrow_t) // delete[](void*, nothrow) ExpectedNumParams = 2; else @@ -412,7 +416,7 @@ SizeOffsetType ObjectSizeOffsetVisitor::compute(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) { // If we have already seen this instruction, bail out. Cycles can happen in // unreachable code after constant propagation. - if (!SeenInsts.insert(I)) + if (!SeenInsts.insert(I).second) return unknown(); if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) @@ -648,7 +652,7 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { // Record the pointers that were handled in this run, so that they can be // cleaned later if something fails. We also use this set to break cycles that // can occur in dead code. - if (!SeenVals.insert(V)) { + if (!SeenVals.insert(V).second) { Result = unknown(); } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { Result = visitGEPOperator(*GEP); diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp index 9eaf10958694..c505aa488170 100644 --- a/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -18,6 +18,7 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/PHITransAddr.h" @@ -48,13 +49,17 @@ STATISTIC(NumCacheCompleteNonLocalPtr, "Number of block queries that were completely cached"); // Limit for the number of instructions to scan in a block. -static const int BlockScanLimit = 100; +static const unsigned int BlockScanLimit = 100; + +// Limit on the number of memdep results to process. +static const unsigned int NumResultsLimit = 100; char MemoryDependenceAnalysis::ID = 0; // Register this pass... 
INITIALIZE_PASS_BEGIN(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_AG_DEPENDENCY(AliasAnalysis) INITIALIZE_PASS_END(MemoryDependenceAnalysis, "memdep", "Memory Dependence Analysis", false, true) @@ -83,11 +88,13 @@ void MemoryDependenceAnalysis::releaseMemory() { /// void MemoryDependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequiredTransitive<AliasAnalysis>(); } -bool MemoryDependenceAnalysis::runOnFunction(Function &) { +bool MemoryDependenceAnalysis::runOnFunction(Function &F) { AA = &getAnalysis<AliasAnalysis>(); + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; DominatorTreeWrapperPass *DTWP = @@ -158,29 +165,32 @@ AliasAnalysis::ModRefResult GetLocation(const Instruction *Inst, return AliasAnalysis::Mod; } - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + AAMDNodes AAInfo; + switch (II->getIntrinsicID()) { case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: case Intrinsic::invariant_start: + II->getAAMetadata(AAInfo); Loc = AliasAnalysis::Location(II->getArgOperand(1), cast<ConstantInt>(II->getArgOperand(0)) - ->getZExtValue(), - II->getMetadata(LLVMContext::MD_tbaa)); + ->getZExtValue(), AAInfo); // These intrinsics don't really modify the memory, but returning Mod // will allow them to be handled conservatively. return AliasAnalysis::Mod; case Intrinsic::invariant_end: + II->getAAMetadata(AAInfo); Loc = AliasAnalysis::Location(II->getArgOperand(2), cast<ConstantInt>(II->getArgOperand(1)) - ->getZExtValue(), - II->getMetadata(LLVMContext::MD_tbaa)); + ->getZExtValue(), AAInfo); // These intrinsics don't really modify the memory, but returning Mod // will allow them to be handled conservatively. return AliasAnalysis::Mod; default: break; } + } // Otherwise, just do the coarse-grained thing that always works. if (Inst->mayWriteToMemory()) @@ -367,6 +377,36 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, int64_t MemLocOffset = 0; unsigned Limit = BlockScanLimit; bool isInvariantLoad = false; + + // We must be careful with atomic accesses, as they may allow another thread + // to touch this location, clobbering it. We are conservative: if the + // QueryInst is not a simple (non-atomic) memory access, we automatically + // return getClobber. + // If it is simple, we know based on the results of + // "Compiler testing via a theory of sound optimisations in the C11/C++11 + // memory model" in PLDI 2013 that a non-atomic location can only be + // clobbered between a pair of a release and an acquire action, with no + // access to the location in between. + // Here is an example giving the general intuition behind this rule. + // In the following code: + // store x 0; + // release action; [1] + // acquire action; [4] + // %val = load x; + // It is unsafe to replace %val by 0 because another thread may be running: + // acquire action; [2] + // store x 42; + // release action; [3] + // with synchronization from 1 to 2 and from 3 to 4, resulting in %val + // being 42. A key property of this program however is that if either + // 1 or 4 were missing, there would be a race between the store of 42 + // and either the store of 0 or the load (making the whole program racy).
+ // The paper mentioned above shows that the same property is respected + // by every program that can detect any optimisation of that kind: either + // it is racy (undefined) or there is a release followed by an acquire + // between the pair of accesses under consideration. + bool HasSeenAcquire = false; + if (isLoad && QueryInst) { LoadInst *LI = dyn_cast<LoadInst>(QueryInst); if (LI && LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr) @@ -404,10 +444,37 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, // Values depend on loads if the pointers are must aliased. This means that // a load depends on another must aliased load from the same value. + // One exception is atomic loads: a value can depend on an atomic load that it + // does not alias with when this atomic load indicates that another thread may + // be accessing the location. if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { // Atomic loads have complications involved. + // A Monotonic (or higher) load is OK if the query inst is itself not atomic. + // An Acquire (or higher) load sets the HasSeenAcquire flag, so that any + // release store will know to return getClobber. // FIXME: This is overly conservative. - if (!LI->isUnordered()) + if (!LI->isUnordered()) { + if (!QueryInst) + return MemDepResult::getClobber(LI); + if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { + if (!QueryLI->isSimple()) + return MemDepResult::getClobber(LI); + } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { + if (!QuerySI->isSimple()) + return MemDepResult::getClobber(LI); + } else if (QueryInst->mayReadOrWriteMemory()) { + return MemDepResult::getClobber(LI); + } + + if (isAtLeastAcquire(LI->getOrdering())) + HasSeenAcquire = true; + } + + // FIXME: this is overly conservative. + // While volatile accesses cannot be eliminated, they do not have to clobber + // non-aliasing locations, as normal accesses can for example be reordered + // with volatile accesses. + if (LI->isVolatile()) return MemDepResult::getClobber(LI); AliasAnalysis::Location LoadLoc = AA->getLocation(LI); @@ -466,8 +533,32 @@ getPointerDependencyFrom(const AliasAnalysis::Location &MemLoc, bool isLoad, if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { // Atomic stores have complications involved. + // A Monotonic store is OK if the query inst is itself not atomic. + // A Release (or higher) store further requires that no acquire load + // has been seen. // FIXME: This is overly conservative. - if (!SI->isUnordered()) + if (!SI->isUnordered()) { + if (!QueryInst) + return MemDepResult::getClobber(SI); + if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { + if (!QueryLI->isSimple()) + return MemDepResult::getClobber(SI); + } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { + if (!QuerySI->isSimple()) + return MemDepResult::getClobber(SI); + } else if (QueryInst->mayReadOrWriteMemory()) { + return MemDepResult::getClobber(SI); + } + + if (HasSeenAcquire && isAtLeastRelease(SI->getOrdering())) + return MemDepResult::getClobber(SI); + } + + // FIXME: this is overly conservative. + // While volatile accesses cannot be eliminated, they do not have to clobber + // non-aliasing locations, as normal accesses can for example be reordered + // with volatile accesses.
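The same intuition in C++11 terms, as a minimal sketch (all names assumed, not part of this patch): the final read of x cannot be folded to 0 precisely because deleting either the release [1] or the acquire [4] would turn the program into a racy, undefined one.

#include <atomic>

std::atomic<bool> published(false), replied(false);
int x;

void threadA() {
  x = 0;
  published.store(true, std::memory_order_release); // release action [1]
  while (!replied.load(std::memory_order_acquire))  // acquire action [4]
    ;
  int val = x; // may legitimately observe 42; folding this to 0 is unsound
  (void)val;
}

void threadB() {
  while (!published.load(std::memory_order_acquire)) // acquire action [2]
    ;
  x = 42;
  replied.store(true, std::memory_order_release);    // release action [3]
}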
+ if (SI->isVolatile()) return MemDepResult::getClobber(SI); // If alias analysis can tell that this store is guaranteed to not modify @@ -685,7 +776,7 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { DirtyBlocks.pop_back(); // Already processed this block? - if (!Visited.insert(DirtyBB)) + if (!Visited.insert(DirtyBB).second) continue; // Do a binary search to see if we already have an entry for this block in @@ -768,14 +859,65 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { /// own block. /// void MemoryDependenceAnalysis:: -getNonLocalPointerDependency(const AliasAnalysis::Location &Loc, bool isLoad, - BasicBlock *FromBB, +getNonLocalPointerDependency(Instruction *QueryInst, SmallVectorImpl<NonLocalDepResult> &Result) { + + auto getLocation = [](AliasAnalysis *AA, Instruction *Inst) { + if (auto *I = dyn_cast<LoadInst>(Inst)) + return AA->getLocation(I); + else if (auto *I = dyn_cast<StoreInst>(Inst)) + return AA->getLocation(I); + else if (auto *I = dyn_cast<VAArgInst>(Inst)) + return AA->getLocation(I); + else if (auto *I = dyn_cast<AtomicCmpXchgInst>(Inst)) + return AA->getLocation(I); + else if (auto *I = dyn_cast<AtomicRMWInst>(Inst)) + return AA->getLocation(I); + else + llvm_unreachable("unsupported memory instruction"); + }; + + const AliasAnalysis::Location Loc = getLocation(AA, QueryInst); + bool isLoad = isa<LoadInst>(QueryInst); + BasicBlock *FromBB = QueryInst->getParent(); + assert(FromBB); + assert(Loc.Ptr->getType()->isPointerTy() && "Can't get pointer deps of a non-pointer!"); Result.clear(); + + // This routine does not expect to deal with volatile instructions. + // Doing so would require piping through the QueryInst all the way through. + // TODO: volatiles can't be elided, but they can be reordered with other + // non-volatile accesses. + auto isVolatile = [](Instruction *Inst) { + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return LI->isVolatile(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return SI->isVolatile(); + } + return false; + }; + // We currently give up on any instruction which is ordered, but we do handle + // atomic instructions which are unordered. + // TODO: Handle ordered instructions + auto isOrdered = [](Instruction *Inst) { + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return !LI->isUnordered(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return !SI->isUnordered(); + } + return false; + }; + if (isVolatile(QueryInst) || isOrdered(QueryInst)) { + Result.push_back(NonLocalDepResult(FromBB, + MemDepResult::getUnknown(), + const_cast<Value *>(Loc.Ptr))); + return; + } - PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL); + + PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, AC); // This is the set of blocks we've inspected, and the pointer we consider in // each block. Because of critical edges, we currently bail out if querying @@ -861,7 +1003,7 @@ GetNonLocalInfoForBlock(const AliasAnalysis::Location &Loc, return Dep; } -/// SortNonLocalDepInfoCache - Sort the a NonLocalDepInfo cache, given a certain +/// SortNonLocalDepInfoCache - Sort the NonLocalDepInfo cache, given a certain /// number of elements in the array that are already properly ordered. This is /// optimized for the case when only a few entries are added. static void @@ -922,10 +1064,10 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, // Set up a temporary NLPI value. 
If the map doesn't yet have an entry for // CacheKey, this value will be inserted as the associated value. Otherwise, // it'll be ignored, and we'll have to check to see if the cached size and - // tbaa tag are consistent with the current query. + // aa tags are consistent with the current query. NonLocalPointerInfo InitialNLPI; InitialNLPI.Size = Loc.Size; - InitialNLPI.TBAATag = Loc.TBAATag; + InitialNLPI.AATags = Loc.AATags; // Get the NLPI for CacheKey, inserting one into the map if it doesn't // already have one. @@ -955,21 +1097,21 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, SkipFirstBlock); } - // If the query's TBAATag is inconsistent with the cached one, + // If the query's AATags are inconsistent with the cached one, // conservatively throw out the cached data and restart the query with // no tag if needed. - if (CacheInfo->TBAATag != Loc.TBAATag) { - if (CacheInfo->TBAATag) { + if (CacheInfo->AATags != Loc.AATags) { + if (CacheInfo->AATags) { CacheInfo->Pair = BBSkipFirstBlockPair(); - CacheInfo->TBAATag = nullptr; + CacheInfo->AATags = AAMDNodes(); for (NonLocalDepInfo::iterator DI = CacheInfo->NonLocalDeps.begin(), DE = CacheInfo->NonLocalDeps.end(); DI != DE; ++DI) if (Instruction *Inst = DI->getResult().getInst()) RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); CacheInfo->NonLocalDeps.clear(); } - if (Loc.TBAATag) - return getNonLocalPointerDepFromBB(Pointer, Loc.getWithoutTBAATag(), + if (Loc.AATags) + return getNonLocalPointerDepFromBB(Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result, Visited, SkipFirstBlock); } @@ -1045,6 +1187,24 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, while (!Worklist.empty()) { BasicBlock *BB = Worklist.pop_back_val(); + // If we do process a large number of blocks it becomes very expensive and + // likely it isn't worth worrying about + if (Result.size() > NumResultsLimit) { + Worklist.clear(); + // Sort it now (if needed) so that recursive invocations of + // getNonLocalPointerDepFromBB and other routines that could reuse the + // cache value will only see properly sorted cache arrays. + if (Cache && NumSortedEntries != Cache->size()) { + SortNonLocalDepInfoCache(*Cache, NumSortedEntries); + } + // Since we bail out, the "Cache" set won't contain all of the + // results for the query. This is ok (we can still use it to accelerate + // specific block queries) but we can't do the fastpath "return all + // results from the set". Clear out the indicator for this. + CacheInfo->Pair = BBSkipFirstBlockPair(); + return true; + } + // Skip the first block if we have it. if (!SkipFirstBlock) { // Analyze the dependency of *Pointer in FromBB. See if we already have @@ -1251,7 +1411,7 @@ getNonLocalPointerDepFromBB(const PHITransAddr &Pointer, if (I->getBB() != BB) continue; - assert(I->getResult().isNonLocal() && + assert((I->getResult().isNonLocal() || !DT->isReachableFromEntry(BB)) && "Should only be here with transparent block"); I->setResult(MemDepResult::getUnknown()); Result.push_back(NonLocalDepResult(I->getBB(), I->getResult(), @@ -1369,14 +1529,11 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst); if (ReverseDepIt != ReverseLocalDeps.end()) { - SmallPtrSet<Instruction*, 4> &ReverseDeps = ReverseDepIt->second; // RemInst can't be the terminator if it has local stuff depending on it. 
- assert(!ReverseDeps.empty() && !isa<TerminatorInst>(RemInst) && + assert(!ReverseDepIt->second.empty() && !isa<TerminatorInst>(RemInst) && "Nothing can locally depend on a terminator"); - for (SmallPtrSet<Instruction*, 4>::iterator I = ReverseDeps.begin(), - E = ReverseDeps.end(); I != E; ++I) { - Instruction *InstDependingOnRemInst = *I; + for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) { assert(InstDependingOnRemInst != RemInst && "Already removed our local dep info"); @@ -1402,12 +1559,10 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseDepIt = ReverseNonLocalDeps.find(RemInst); if (ReverseDepIt != ReverseNonLocalDeps.end()) { - SmallPtrSet<Instruction*, 4> &Set = ReverseDepIt->second; - for (SmallPtrSet<Instruction*, 4>::iterator I = Set.begin(), E = Set.end(); - I != E; ++I) { - assert(*I != RemInst && "Already removed NonLocalDep info for RemInst"); + for (Instruction *I : ReverseDepIt->second) { + assert(I != RemInst && "Already removed NonLocalDep info for RemInst"); - PerInstNLInfo &INLD = NonLocalDeps[*I]; + PerInstNLInfo &INLD = NonLocalDeps[I]; // The information is now dirty! INLD.second = true; @@ -1419,7 +1574,7 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { DI->setResult(NewDirtyVal); if (Instruction *NextI = NewDirtyVal.getInst()) - ReverseDepsToAdd.push_back(std::make_pair(NextI, *I)); + ReverseDepsToAdd.push_back(std::make_pair(NextI, I)); } } @@ -1438,12 +1593,9 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = ReverseNonLocalPtrDeps.find(RemInst); if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { - SmallPtrSet<ValueIsLoadPair, 4> &Set = ReversePtrDepIt->second; SmallVector<std::pair<Instruction*, ValueIsLoadPair>,8> ReversePtrDepsToAdd; - for (SmallPtrSet<ValueIsLoadPair, 4>::iterator I = Set.begin(), - E = Set.end(); I != E; ++I) { - ValueIsLoadPair P = *I; + for (ValueIsLoadPair P : ReversePtrDepIt->second) { assert(P.getPointer() != RemInst && "Already removed NonLocalPointerDeps info for RemInst"); @@ -1484,8 +1636,10 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { DEBUG(verifyRemoved(RemInst)); } /// verifyRemoved - Verify that the specified instruction does not occur -/// in our internal data structures. +/// in our internal data structures. This function verifies by asserting in +/// debug builds. 
void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { +#ifndef NDEBUG for (LocalDepMapType::const_iterator I = LocalDeps.begin(), E = LocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); @@ -1514,18 +1668,16 @@ void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { for (ReverseDepMapType::const_iterator I = ReverseLocalDeps.begin(), E = ReverseLocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); - for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), - EE = I->second.end(); II != EE; ++II) - assert(*II != D && "Inst occurs in data structures"); + for (Instruction *Inst : I->second) + assert(Inst != D && "Inst occurs in data structures"); } for (ReverseDepMapType::const_iterator I = ReverseNonLocalDeps.begin(), E = ReverseNonLocalDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in data structures"); - for (SmallPtrSet<Instruction*, 4>::const_iterator II = I->second.begin(), - EE = I->second.end(); II != EE; ++II) - assert(*II != D && "Inst occurs in data structures"); + for (Instruction *Inst : I->second) + assert(Inst != D && "Inst occurs in data structures"); } for (ReverseNonLocalPtrDepTy::const_iterator @@ -1533,11 +1685,10 @@ void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { E = ReverseNonLocalPtrDeps.end(); I != E; ++I) { assert(I->first != D && "Inst occurs in rev NLPD map"); - for (SmallPtrSet<ValueIsLoadPair, 4>::const_iterator II = I->second.begin(), - E = I->second.end(); II != E; ++II) - assert(*II != ValueIsLoadPair(D, false) && - *II != ValueIsLoadPair(D, true) && + for (ValueIsLoadPair P : I->second) + assert(P != ValueIsLoadPair(D, false) && + P != ValueIsLoadPair(D, true) && "Inst occurs in ReverseNonLocalPtrDeps map"); } - +#endif } diff --git a/lib/Analysis/NoAliasAnalysis.cpp b/lib/Analysis/NoAliasAnalysis.cpp index 139fa38b8a94..c214d3cdf17a 100644 --- a/lib/Analysis/NoAliasAnalysis.cpp +++ b/lib/Analysis/NoAliasAnalysis.cpp @@ -57,8 +57,9 @@ namespace { Location getArgLocation(ImmutableCallSite CS, unsigned ArgIdx, ModRefResult &Mask) override { Mask = ModRef; - return Location(CS.getArgument(ArgIdx), UnknownSize, - CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)); + AAMDNodes AATags; + CS->getAAMetadata(AATags); + return Location(CS.getArgument(ArgIdx), UnknownSize, AATags); } ModRefResult getModRefInfo(ImmutableCallSite CS, diff --git a/lib/Analysis/PHITransAddr.cpp b/lib/Analysis/PHITransAddr.cpp index bfe86425119b..a534418d199e 100644 --- a/lib/Analysis/PHITransAddr.cpp +++ b/lib/Analysis/PHITransAddr.cpp @@ -228,7 +228,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, return GEP; // Simplify the GEP to handle 'gep x, 0' -> x etc. - if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT)) { + if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT, AC)) { for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); @@ -283,7 +283,7 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, } // See if the add simplifies away. - if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT)) { + if (Value *Res = SimplifyAddInst(LHS, RHS, isNSW, isNUW, DL, TLI, DT, AC)) { // If we simplified the operands, the LHS is no longer an input, but Res // is. 
RemoveInstInputs(LHS, InstInputs); @@ -369,7 +369,7 @@ InsertPHITranslatedSubExpr(Value *InVal, BasicBlock *CurBB, SmallVectorImpl<Instruction*> &NewInsts) { // See if we have a version of this value already available and dominating // PredBB. If so, there is no need to insert a new instance of it. - PHITransAddr Tmp(InVal, DL); + PHITransAddr Tmp(InVal, DL, AC); if (!Tmp.PHITranslateValue(CurBB, PredBB, &DT)) return Tmp.getAddr(); diff --git a/lib/Analysis/PtrUseVisitor.cpp b/lib/Analysis/PtrUseVisitor.cpp index 1b0f359e6366..68c7535ea594 100644 --- a/lib/Analysis/PtrUseVisitor.cpp +++ b/lib/Analysis/PtrUseVisitor.cpp @@ -17,7 +17,7 @@ using namespace llvm; void detail::PtrUseVisitorBase::enqueueUsers(Instruction &I) { for (Use &U : I.uses()) { - if (VisitedUses.insert(&U)) { + if (VisitedUses.insert(&U).second) { UseToVisit NewU = { UseToVisit::UseAndIsOffsetKnownPair(&U, IsOffsetKnown), Offset diff --git a/lib/Analysis/RegionInfo.cpp b/lib/Analysis/RegionInfo.cpp index 08ebf0d85723..8cd85348fdcc 100644 --- a/lib/Analysis/RegionInfo.cpp +++ b/lib/Analysis/RegionInfo.cpp @@ -10,10 +10,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/RegionInfo.h" -#include "llvm/Analysis/RegionInfoImpl.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/RegionInfoImpl.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/lib/Analysis/RegionPass.cpp b/lib/Analysis/RegionPass.cpp index de34b727a5a0..6fa7b2e0e30c 100644 --- a/lib/Analysis/RegionPass.cpp +++ b/lib/Analysis/RegionPass.cpp @@ -15,9 +15,8 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/RegionPass.h" #include "llvm/Analysis/RegionIterator.h" -#include "llvm/Support/Timer.h" - #include "llvm/Support/Debug.h" +#include "llvm/Support/Timer.h" using namespace llvm; #define DEBUG_TYPE "regionpassmgr" diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 06dbde58c108..86852d634ff2 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolution.cpp - Scalar Evolution Analysis ----------*- C++ -*-===// +//===- ScalarEvolution.cpp - Scalar Evolution Analysis --------------------===// // // The LLVM Compiler Infrastructure // @@ -59,9 +59,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" @@ -78,6 +80,7 @@ #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -113,6 +116,7 @@ VerifySCEV("verify-scev", INITIALIZE_PASS_BEGIN(ScalarEvolution, "scalar-evolution", "Scalar Evolution Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfo) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfo) @@ -671,7 +675,254 @@ static void 
GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, } } +namespace { +struct FindSCEVSize { + int Size; + FindSCEVSize() : Size(0) {} + + bool follow(const SCEV *S) { + ++Size; + // Keep looking at all operands of S. + return true; + } + bool isDone() const { + return false; + } +}; +} + +// Returns the size of the SCEV S. +static inline int sizeOfSCEV(const SCEV *S) { + FindSCEVSize F; + SCEVTraversal<FindSCEVSize> ST(F); + ST.visitAll(S); + return F.Size; +} + +namespace { + +struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> { +public: + // Computes the Quotient and Remainder of the division of Numerator by + // Denominator. + static void divide(ScalarEvolution &SE, const SCEV *Numerator, + const SCEV *Denominator, const SCEV **Quotient, + const SCEV **Remainder) { + assert(Numerator && Denominator && "Uninitialized SCEV"); + + SCEVDivision D(SE, Numerator, Denominator); + + // Check for the trivial case here to avoid having to check for it in the + // rest of the code. + if (Numerator == Denominator) { + *Quotient = D.One; + *Remainder = D.Zero; + return; + } + + if (Numerator->isZero()) { + *Quotient = D.Zero; + *Remainder = D.Zero; + return; + } + + // Split the Denominator when it is a product. + if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { + const SCEV *Q, *R; + *Quotient = Numerator; + for (const SCEV *Op : T->operands()) { + divide(SE, *Quotient, Op, &Q, &R); + *Quotient = Q; + + // Bail out when the Numerator is not divisible by one of the terms of + // the Denominator. + if (!R->isZero()) { + *Quotient = D.Zero; + *Remainder = Numerator; + return; + } + } + *Remainder = D.Zero; + return; + } + + D.visit(Numerator); + *Quotient = D.Quotient; + *Remainder = D.Remainder; + } + + // Except in the trivial case described above, we do not know how to divide + // Expr by Denominator for the following functions with empty implementation. 
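As a concrete illustration of the recursion (a hypothetical caller, names assumed; the visitor cases follow below), splitting 4*n + 8 by 4, where SE is a ScalarEvolution instance and N is some SCEV for n:

  // Build the numerator 4*n + 8 ...
  const SCEV *Four = SE.getConstant(N->getType(), 4);
  const SCEV *Numerator = SE.getAddExpr(
      SE.getMulExpr(Four, N), SE.getConstant(N->getType(), 8));
  // ... and divide: visitAddExpr divides each add operand, and
  // visitMulExpr cancels the matching factor of 4.
  const SCEV *Quotient, *Remainder;
  SCEVDivision::divide(SE, Numerator, Four, &Quotient, &Remainder);
  // Quotient is now (n + 2) and Remainder is zero.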
+ void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} + void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} + void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} + void visitUDivExpr(const SCEVUDivExpr *Numerator) {} + void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} + void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} + void visitUnknown(const SCEVUnknown *Numerator) {} + void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} + + void visitConstant(const SCEVConstant *Numerator) { + if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { + APInt NumeratorVal = Numerator->getValue()->getValue(); + APInt DenominatorVal = D->getValue()->getValue(); + uint32_t NumeratorBW = NumeratorVal.getBitWidth(); + uint32_t DenominatorBW = DenominatorVal.getBitWidth(); + + if (NumeratorBW > DenominatorBW) + DenominatorVal = DenominatorVal.sext(NumeratorBW); + else if (NumeratorBW < DenominatorBW) + NumeratorVal = NumeratorVal.sext(DenominatorBW); + + APInt QuotientVal(NumeratorVal.getBitWidth(), 0); + APInt RemainderVal(NumeratorVal.getBitWidth(), 0); + APInt::sdivrem(NumeratorVal, DenominatorVal, QuotientVal, RemainderVal); + Quotient = SE.getConstant(QuotientVal); + Remainder = SE.getConstant(RemainderVal); + return; + } + } + + void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { + const SCEV *StartQ, *StartR, *StepQ, *StepR; + assert(Numerator->isAffine() && "Numerator should be affine"); + divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); + divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); + Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), + Numerator->getNoWrapFlags()); + } + + void visitAddExpr(const SCEVAddExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs, Rs; + Type *Ty = Denominator->getType(); + + for (const SCEV *Op : Numerator->operands()) { + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + + // Bail out if types do not match. + if (Ty != Q->getType() || Ty != R->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + Qs.push_back(Q); + Rs.push_back(R); + } + + if (Qs.size() == 1) { + Quotient = Qs[0]; + Remainder = Rs[0]; + return; + } + + Quotient = SE.getAddExpr(Qs); + Remainder = SE.getAddExpr(Rs); + } + + void visitMulExpr(const SCEVMulExpr *Numerator) { + SmallVector<const SCEV *, 2> Qs; + Type *Ty = Denominator->getType(); + + bool FoundDenominatorTerm = false; + for (const SCEV *Op : Numerator->operands()) { + // Bail out if types do not match. + if (Ty != Op->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + if (FoundDenominatorTerm) { + Qs.push_back(Op); + continue; + } + + // Check whether Denominator divides one of the product operands. + const SCEV *Q, *R; + divide(SE, Op, Denominator, &Q, &R); + if (!R->isZero()) { + Qs.push_back(Op); + continue; + } + + // Bail out if types do not match. + if (Ty != Q->getType()) { + Quotient = Zero; + Remainder = Numerator; + return; + } + + FoundDenominatorTerm = true; + Qs.push_back(Q); + } + + if (FoundDenominatorTerm) { + Remainder = Zero; + if (Qs.size() == 1) + Quotient = Qs[0]; + else + Quotient = SE.getMulExpr(Qs); + return; + } + + if (!isa<SCEVUnknown>(Denominator)) { + Quotient = Zero; + Remainder = Numerator; + return; + } + // The Remainder is obtained by replacing Denominator by 0 in Numerator. 
+ ValueToValueMap RewriteMap; + RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = + cast<SCEVConstant>(Zero)->getValue(); + Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + + if (Remainder->isZero()) { + // The Quotient is obtained by replacing Denominator by 1 in Numerator. + RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = + cast<SCEVConstant>(One)->getValue(); + Quotient = + SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); + return; + } + + // Quotient is (Numerator - Remainder) divided by Denominator. + const SCEV *Q, *R; + const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); + if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { + // This SCEV does not seem to simplify: fail the division here. + Quotient = Zero; + Remainder = Numerator; + return; + } + divide(SE, Diff, Denominator, &Q, &R); + assert(R == Zero && + "(Numerator - Remainder) should evenly divide Denominator"); + Quotient = Q; + } + +private: + SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, + const SCEV *Denominator) + : SE(S), Denominator(Denominator) { + Zero = SE.getConstant(Denominator->getType(), 0); + One = SE.getConstant(Denominator->getType(), 1); + + // By default, we don't know how to divide Expr by Denominator. + // Providing the default here simplifies the rest of the code. + Quotient = Zero; + Remainder = Numerator; + } + + ScalarEvolution &SE; + const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; +}; + +} //===----------------------------------------------------------------------===// // Simple SCEV method implementations @@ -1486,6 +1737,36 @@ namespace { }; } +// We're trying to construct a SCEV of type `Type' with `Ops' as operands and +// `OldFlags' as can't-wrap behavior. Infer a more aggressive set of +// can't-overflow flags for the operation if possible. +static SCEV::NoWrapFlags +StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, + const SmallVectorImpl<const SCEV *> &Ops, + SCEV::NoWrapFlags OldFlags) { + using namespace std::placeholders; + + bool CanAnalyze = + Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; + (void)CanAnalyze; + assert(CanAnalyze && "don't call from other places!"); + + int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; + SCEV::NoWrapFlags SignOrUnsignWrap = + ScalarEvolution::maskFlags(OldFlags, SignOrUnsignMask); + + // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. + auto IsKnownNonNegative = + std::bind(std::mem_fn(&ScalarEvolution::isKnownNonNegative), SE, _1); + + if (SignOrUnsignWrap == SCEV::FlagNSW && + std::all_of(Ops.begin(), Ops.end(), IsKnownNonNegative)) + return ScalarEvolution::setFlags(OldFlags, + (SCEV::NoWrapFlags)SignOrUnsignMask); + + return OldFlags; +} + /// getAddExpr - Get a canonical add expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, @@ -1501,20 +1782,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, "SCEVAddExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. 
- int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1889,6 +2157,25 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { return r; } +/// Determine if any of the operands in this SCEV are a constant or if +/// any of the add or multiply expressions in this SCEV contain a constant. +static bool containsConstantSomewhere(const SCEV *StartExpr) { + SmallVector<const SCEV *, 4> Ops; + Ops.push_back(StartExpr); + while (!Ops.empty()) { + const SCEV *CurrentExpr = Ops.pop_back_val(); + if (isa<SCEVConstant>(*CurrentExpr)) + return true; + + if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) { + const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr); + for (const SCEV *Operand : CurrentNAry->operands()) + Ops.push_back(Operand); + } + } + return false; +} + /// getMulExpr - Get a canonical multiply expression, or something simpler if /// possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, @@ -1904,20 +2191,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, "SCEVMulExpr operand types don't match!"); #endif - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl<const SCEV *>::const_iterator I = Ops.begin(), - E = Ops.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags); // Sort by complexity, this groups all similar expression types together. GroupByComplexity(Ops, LI); @@ -1928,11 +2202,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // C1*(C2+V) -> C1*C2 + C1*V if (Ops.size() == 2) - if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) - if (Add->getNumOperands() == 2 && - isa<SCEVConstant>(Add->getOperand(0))) - return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), - getMulExpr(LHSC, Add->getOperand(1))); + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) + // If any of Add's ops are Adds or Muls with a constant, + // apply this transformation as well. + if (Add->getNumOperands() == 2) + if (containsConstantSomewhere(Add)) + return getAddExpr(getMulExpr(LHSC, Add->getOperand(0)), + getMulExpr(LHSC, Add->getOperand(1))); ++Idx; while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) { @@ -2061,71 +2337,66 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // Okay, if there weren't any loop invariants to be folded, check to see if // there are multiple AddRec's with the same loop induction variable being // multiplied together. If so, we can fold them. 
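A concrete instance of the fold, ahead of the general formula below: squaring the canonical induction variable.

// {0,+,1}<L> * {0,+,1}<L> folds to the single recurrence {0,+,1,+,2}<L>:
// its values 0, 1, 4, 9, ... are i*i, with first differences 1, 3, 5, ...
// and a constant second difference of 2.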
+ + // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> + // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ + // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z + // ]]],+,...up to x=2n}. + // Note that the arguments to choose() are always integers with values + // known at compile time, never SCEV objects. + // + // The implementation avoids pointless extra computations when the two + // addrec's are of different length (mathematically, it's equivalent to + // an infinite stream of zeros on the right). + bool OpsModified = false; for (unsigned OtherIdx = Idx+1; - OtherIdx < Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); + OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); ++OtherIdx) { - if (AddRecLoop != cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) + const SCEVAddRecExpr *OtherAddRec = + dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); + if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) continue; - // {A1,+,A2,+,...,+,An}<L> * {B1,+,B2,+,...,+,Bn}<L> - // = {x=1 in [ sum y=x..2x [ sum z=max(y-x, y-n)..min(x,n) [ - // choose(x, 2x)*choose(2x-y, x-z)*A_{y-z}*B_z - // ]]],+,...up to x=2n}. - // Note that the arguments to choose() are always integers with values - // known at compile time, never SCEV objects. - // - // The implementation avoids pointless extra computations when the two - // addrec's are of different length (mathematically, it's equivalent to - // an infinite stream of zeros on the right). - bool OpsModified = false; - for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]); - ++OtherIdx) { - const SCEVAddRecExpr *OtherAddRec = - dyn_cast<SCEVAddRecExpr>(Ops[OtherIdx]); - if (!OtherAddRec || OtherAddRec->getLoop() != AddRecLoop) - continue; - - bool Overflow = false; - Type *Ty = AddRec->getType(); - bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; - SmallVector<const SCEV*, 7> AddRecOps; - for (int x = 0, xe = AddRec->getNumOperands() + - OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getConstant(Ty, 0); - for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { - uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); - for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), - ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); - z < ze && !Overflow; ++z) { - uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); - uint64_t Coeff; - if (LargerThan64Bits) - Coeff = umul_ov(Coeff1, Coeff2, Overflow); - else - Coeff = Coeff1*Coeff2; - const SCEV *CoeffTerm = getConstant(Ty, Coeff); - const SCEV *Term1 = AddRec->getOperand(y-z); - const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); - } + bool Overflow = false; + Type *Ty = AddRec->getType(); + bool LargerThan64Bits = getTypeSizeInBits(Ty) > 64; + SmallVector<const SCEV*, 7> AddRecOps; + for (int x = 0, xe = AddRec->getNumOperands() + + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { + const SCEV *Term = getConstant(Ty, 0); + for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { + uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); + for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), + ze = std::min(x+1, (int)OtherAddRec->getNumOperands()); + z < ze && !Overflow; ++z) { + uint64_t Coeff2 = Choose(2*x - y, x-z, Overflow); + uint64_t Coeff; + if (LargerThan64Bits) + Coeff = umul_ov(Coeff1, Coeff2, Overflow); + else + Coeff = Coeff1*Coeff2; + const SCEV *CoeffTerm = getConstant(Ty, Coeff); + const SCEV *Term1 = AddRec->getOperand(y-z); + const SCEV *Term2 = 
OtherAddRec->getOperand(z); + Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1,Term2)); } - AddRecOps.push_back(Term); - } - if (!Overflow) { - const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), - SCEV::FlagAnyWrap); - if (Ops.size() == 2) return NewAddRec; - Ops[Idx] = NewAddRec; - Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; - OpsModified = true; - AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); - if (!AddRec) - break; } + AddRecOps.push_back(Term); + } + if (!Overflow) { + const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), + SCEV::FlagAnyWrap); + if (Ops.size() == 2) return NewAddRec; + Ops[Idx] = NewAddRec; + Ops.erase(Ops.begin() + OtherIdx); --OtherIdx; + OpsModified = true; + AddRec = dyn_cast<SCEVAddRecExpr>(NewAddRec); + if (!AddRec) + break; } - if (OpsModified) - return getMulExpr(Ops); } + if (OpsModified) + return getMulExpr(Ops); // Otherwise couldn't fold anything into this recurrence. Move onto the // next one. @@ -2386,20 +2657,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, // meaningful BE count at this point (and if we don't, we'd be stuck // with a SCEVCouldNotCompute as the cached BE count). - // If FlagNSW is true and all the operands are non-negative, infer FlagNUW. - // And vice-versa. - int SignOrUnsignMask = SCEV::FlagNUW | SCEV::FlagNSW; - SCEV::NoWrapFlags SignOrUnsignWrap = maskFlags(Flags, SignOrUnsignMask); - if (SignOrUnsignWrap && (SignOrUnsignWrap != SignOrUnsignMask)) { - bool All = true; - for (SmallVectorImpl<const SCEV *>::const_iterator I = Operands.begin(), - E = Operands.end(); I != E; ++I) - if (!isKnownNonNegative(*I)) { - All = false; - break; - } - if (All) Flags = setFlags(Flags, (SCEV::NoWrapFlags)SignOrUnsignMask); - } + Flags = StrengthenNoWrapFlags(this, scAddRecExpr, Operands, Flags); // Canonicalize nested AddRecs in by nesting them in order of loop depth. if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) { @@ -3082,7 +3340,8 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { Visited.insert(PN); while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -3263,7 +3522,7 @@ const SCEV *ScalarEvolution::createNodeForPHI(PHINode *PN) { // PHI's incoming blocks are in a different loop, in which case doing so // risks breaking LCSSA form. Instcombine would normally zap these, but // it doesn't have DominatorTree information, so it may miss cases. - if (Value *V = SimplifyInstruction(PN, DL, TLI, DT)) + if (Value *V = SimplifyInstruction(PN, DL, TLI, DT, AC)) if (LI->replacementPreservesLCSSAForm(PN, V)) return getSCEV(V); @@ -3395,7 +3654,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { // For a SCEVUnknown, ask ValueTracking. unsigned BitWidth = getTypeSizeInBits(U->getType()); APInt Zeros(BitWidth, 0), Ones(BitWidth, 0); - computeKnownBits(U->getValue(), Zeros, Ones); + computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT); return Zeros.countTrailingOnes(); } @@ -3403,6 +3662,33 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } +/// GetRangeFromMetadata - Helper method to assign a range to V from +/// metadata present in the IR. 
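For example, given a hypothetical load annotated with two intervals (illustrative IR, not from this patch):

//   %v = load i8* %p, !range !0
//   !0 = !{i8 0, i8 2, i8 8, i8 10}
// The half-open ranges [0,2) and [8,10) are folded by unionWith into the
// smallest single ConstantRange containing both, here [0,10), which the
// range computations below then intersect with whatever ValueTracking
// derives.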
+static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) {
+      ConstantRange TotalRange(
+          cast<IntegerType>(I->getType())->getBitWidth(), false);
+
+      unsigned NumRanges = MD->getNumOperands() / 2;
+      assert(NumRanges >= 1);
+
+      for (unsigned i = 0; i < NumRanges; ++i) {
+        ConstantInt *Lower =
+            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 0));
+        ConstantInt *Upper =
+            mdconst::extract<ConstantInt>(MD->getOperand(2 * i + 1));
+        ConstantRange Range(Lower->getValue(), Upper->getValue());
+        TotalRange = TotalRange.unionWith(Range);
+      }
+
+      return TotalRange;
+    }
+  }
+
+  return None;
+}
+
 /// getUnsignedRange - Determine the unsigned range for a particular SCEV.
 ///
 ConstantRange
@@ -3532,9 +3818,14 @@ ScalarEvolution::getUnsignedRange(const SCEV *S) {
   }
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // Check if the IR explicitly contains !range metadata.
+    Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+    if (MDRange.hasValue())
+      ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
     // For a SCEVUnknown, ask ValueTracking.
     APInt Zeros(BitWidth, 0), Ones(BitWidth, 0);
-    computeKnownBits(U->getValue(), Zeros, Ones, DL);
+    computeKnownBits(U->getValue(), Zeros, Ones, DL, 0, AC, nullptr, DT);
     if (Ones == ~Zeros + 1)
       return setUnsignedRange(U, ConservativeResult);
     return setUnsignedRange(U,
@@ -3683,10 +3974,15 @@ ScalarEvolution::getSignedRange(const SCEV *S) {
   }
 
   if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
+    // Check if the IR explicitly contains !range metadata.
+    Optional<ConstantRange> MDRange = GetRangeFromMetadata(U->getValue());
+    if (MDRange.hasValue())
+      ConservativeResult = ConservativeResult.intersectWith(MDRange.getValue());
+
     // For a SCEVUnknown, ask ValueTracking.
     if (!U->getValue()->getType()->isIntegerTy() && !DL)
       return setSignedRange(U, ConservativeResult);
-    unsigned NS = ComputeNumSignBits(U->getValue(), DL);
+    unsigned NS = ComputeNumSignBits(U->getValue(), DL, 0, AC, nullptr, DT);
     if (NS <= 1)
       return setSignedRange(U, ConservativeResult);
     return setSignedRange(U, ConservativeResult.intersectWith(
@@ -3793,7 +4089,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       unsigned TZ = A.countTrailingZeros();
       unsigned BitWidth = A.getBitWidth();
       APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0);
-      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL);
+      computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, 0, AC,
+                       nullptr, DT);
       APInt EffectiveMask =
           APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ);
@@ -4070,6 +4367,14 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 // Iteration Count Computation Code
 //
 
+unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {
+  if (BasicBlock *ExitingBB = L->getExitingBlock())
+    return getSmallConstantTripCount(L, ExitingBB);
+
+  // No trip count information for multiple exits.
+  return 0;
+}
+
 /// getSmallConstantTripCount - Returns the maximum trip count of this loop as a
 /// normal unsigned value. Returns 0 if the trip count is unknown or not
 /// constant. Will also return 0 if the maximum trip count is very large (>=
@@ -4080,19 +4385,13 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
 /// before taking the branch. For loops with multiple exits, it may not be the
 /// number of times that the loop header executes because the loop may exit
 /// prematurely via another branch.
-/// -/// FIXME: We conservatively call getBackedgeTakenCount(L) instead of -/// getExitCount(L, ExitingBlock) to compute a safe trip count considering all -/// loop exits. getExitCount() may return an exact count for this branch -/// assuming no-signed-wrap. The number of well-defined iterations may actually -/// be higher than this trip count if this exit test is skipped and the loop -/// exits via a different branch. Ideally, getExitCount() would know whether it -/// depends on a NSW assumption, and we would only fall back to a conservative -/// trip count in that case. -unsigned ScalarEvolution:: -getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { +unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); const SCEVConstant *ExitCount = - dyn_cast<SCEVConstant>(getBackedgeTakenCount(L)); + dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); if (!ExitCount) return 0; @@ -4106,6 +4405,14 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { return ((unsigned)ExitConst->getZExtValue()) + 1; } +unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { + if (BasicBlock *ExitingBB = L->getExitingBlock()) + return getSmallConstantTripMultiple(L, ExitingBB); + + // No trip multiple information for multiple exits. + return 0; +} + /// getSmallConstantTripMultiple - Returns the largest constant divisor of the /// trip count of this loop as a normal unsigned value, if possible. This /// means that the actual trip count is always a multiple of the returned @@ -4118,9 +4425,13 @@ getSmallConstantTripCount(Loop *L, BasicBlock * /*ExitingBlock*/) { /// /// As explained in the comments for getSmallConstantTripCount, this assumes /// that control exits the loop via ExitingBlock. 
-unsigned ScalarEvolution:: -getSmallConstantTripMultiple(Loop *L, BasicBlock * /*ExitingBlock*/) { - const SCEV *ExitCount = getBackedgeTakenCount(L); +unsigned +ScalarEvolution::getSmallConstantTripMultiple(Loop *L, + BasicBlock *ExitingBlock) { + assert(ExitingBlock && "Must pass a non-null exiting block!"); + assert(L->isLoopExiting(ExitingBlock) && + "Exiting block must actually branch out of the loop!"); + const SCEV *ExitCount = getExitCount(L, ExitingBlock); if (ExitCount == getCouldNotCompute()) return 1; @@ -4230,7 +4541,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4282,7 +4594,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4316,7 +4629,8 @@ void ScalarEvolution::forgetValue(Value *V) { SmallPtrSet<Instruction *, 8> Visited; while (!Worklist.empty()) { I = Worklist.pop_back_val(); - if (!Visited.insert(I)) continue; + if (!Visited.insert(I).second) + continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -4467,20 +4781,12 @@ ScalarEvolution::ComputeBackedgeTakenCount(const Loop *L) { // non-exiting iterations. Partition the loop exits into two kinds: // LoopMustExits and LoopMayExits. // - // A LoopMustExit meets two requirements: - // - // (a) Its ExitLimit.MustExit flag must be set which indicates that the exit - // test condition cannot be skipped (the tested variable has unit stride or - // the test is less-than or greater-than, rather than a strict inequality). - // - // (b) It must dominate the loop latch, hence must be tested on every loop - // iteration. - // - // If any computable LoopMustExit is found, then MaxBECount is the minimum - // EL.Max of computable LoopMustExits. Otherwise, MaxBECount is - // conservatively the maximum EL.Max, where CouldNotCompute is considered - // greater than any computable EL.Max. - if (EL.MustExit && EL.Max != getCouldNotCompute() && Latch && + // If the exit dominates the loop latch, it is a LoopMustExit otherwise it + // is a LoopMayExit. If any computable LoopMustExit is found, then + // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, + // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is + // considered greater than any computable EL.Max. + if (EL.Max != getCouldNotCompute() && Latch && DT->dominates(ExitBB, Latch)) { if (!MustExitMaxBECount) MustExitMaxBECount = EL.Max; @@ -4567,18 +4873,19 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } + bool IsOnlyExit = (L->getExitingBlock() != nullptr); TerminatorInst *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. 
return ComputeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) return ComputeExitLimitFromSingleExitSwitch(L, SI, Exit, - /*IsSubExpr=*/false); + /*ControlsExit=*/IsOnlyExit); return getCouldNotCompute(); } @@ -4587,28 +4894,27 @@ ScalarEvolution::ComputeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { /// backedge of the specified loop will execute if its exit condition /// were a conditional branch of ExitCond, TBB, and FBB. /// -/// @param IsSubExpr is true if ExitCond does not directly control the exit -/// branch. In this case, we cannot assume that the loop only exits when the -/// condition is true and cannot infer that failing to meet the condition prior -/// to integer wraparound results in undefined behavior. +/// @param ControlsExit is true if ExitCond directly controls the exit +/// branch. In this case, we can assume that the loop exits only if the +/// condition is true and can infer that failing to meet the condition prior to +/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be true for the loop to continue executing. // Choose the less conservative count. @@ -4623,7 +4929,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be true at the same time for the loop to exit. // For now, be conservative. @@ -4632,21 +4937,19 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = ComputeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); ExitLimit EL1 = ComputeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - IsSubExpr || EitherMayExit); + ControlsExit && !EitherMayExit); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); - bool MustExit = false; if (EitherMayExit) { // Both conditions must be false for the loop to continue executing. // Choose the less conservative count. 
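To make the combining rule above concrete: when either operand of the `and`/`or` can take the exit on its own (EitherMayExit), the loop leaves as soon as the first limit is reached, so the two ExitLimits are combined with an unsigned minimum. A standalone illustration of that arithmetic (a hypothetical loop, not part of this patch):

#include <algorithm>
#include <cassert>

// The branch tests two conditions at once; the loop exits as soon as
// either i < n or i < m fails, so it runs exactly min(n, m) iterations,
// matching the umin combination of the two individual exit limits.
unsigned tripCount(unsigned n, unsigned m) {
  unsigned Iters = 0;
  for (unsigned i = 0; i < n && i < m; ++i)
    ++Iters;
  return Iters;
}

int main() {
  assert(tripCount(4, 9) == std::min(4u, 9u));
  assert(tripCount(7, 2) == 2);
}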
@@ -4661,7 +4964,6 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; else MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); - MustExit = EL0.MustExit || EL1.MustExit; } else { // Both conditions must be false at the same time for the loop to exit. // For now, be conservative. @@ -4670,17 +4972,16 @@ ScalarEvolution::ComputeExitLimitFromCond(const Loop *L, MaxBECount = EL0.Max; if (EL0.Exact == EL1.Exact) BECount = EL0.Exact; - MustExit = EL0.MustExit && EL1.MustExit; } - return ExitLimit(BECount, MaxBECount, MustExit); + return ExitLimit(BECount, MaxBECount); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) - return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, IsSubExpr); + return ComputeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -4707,7 +5008,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool IsSubExpr) { + bool ControlsExit) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -4759,7 +5060,7 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4772,14 +5073,14 @@ ScalarEvolution::ComputeExitLimitFromICmp(const Loop *L, case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, IsSubExpr); + ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); if (EL.hasAnyInfo()) return EL; break; } @@ -4801,7 +5102,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, SwitchInst *Switch, BasicBlock *ExitingBlock, - bool IsSubExpr) { + bool ControlsExit) { assert(!L->contains(ExitingBlock) && "Not an exiting block!"); // Give up if the exit is the default dest of a switch. @@ -4814,7 +5115,7 @@ ScalarEvolution::ComputeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, IsSubExpr); + ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5687,7 +5988,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { /// effectively V != 0. We know and take advantage of the fact that this /// expression only being used in a comparison by zero context. 
 ScalarEvolution::ExitLimit
-ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) {
+ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) {
   // If the value is a constant
   if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {
     // If the value is already zero, the branch will execute zero times.
@@ -5781,38 +6082,34 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool IsSubExpr) {
     else
       MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
                                          : -CR.getUnsignedMin());
-    return ExitLimit(Distance, MaxBECount, /*MustExit=*/true);
-  }
-
-  // If the recurrence is known not to wraparound, unsigned divide computes the
-  // back edge count. (Ideally we would have an "isexact" bit for udiv). We know
-  // that the value will either become zero (and thus the loop terminates), that
-  // the loop will terminate through some other exit condition first, or that
-  // the loop has undefined behavior.  This means we can't "miss" the exit
-  // value, even with nonunit stride, and exit later via the same branch. Note
-  // that we can skip this exit if loop later exits via a different
-  // branch. Hence MustExit=false.
-  //
-  // This is only valid for expressions that directly compute the loop exit. It
-  // is invalid for subexpressions in which the loop may exit through this
-  // branch even if this subexpression is false. In that case, the trip count
-  // computed by this udiv could be smaller than the number of well-defined
-  // iterations.
-  if (!IsSubExpr && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
+    return ExitLimit(Distance, MaxBECount);
+  }
+
+  // As a special case, handle the instance where Step is a positive power of
+  // two. In this case, determining whether Step divides Distance evenly can be
+  // done by counting and comparing the number of trailing zeros of Step and
+  // Distance.
+  if (!CountDown) {
+    const APInt &StepV = StepC->getValue()->getValue();
+    // StepV.isPowerOf2() returns true if StepV is a positive power of two. It
+    // also returns true if StepV is maximally negative (e.g., INT_MIN), but
+    // that case is not handled as this code is guarded by !CountDown.
+    if (StepV.isPowerOf2() &&
+        GetMinTrailingZeros(Distance) >= StepV.countTrailingZeros())
+      return getUDivExactExpr(Distance, Step);
+  }
+
+  // If the condition controls loop exit (the loop exits only if the expression
+  // is true) and the addition is no-wrap we can use unsigned divide to
+  // compute the backedge count.  In this case, the step may not divide the
+  // distance, but we don't care because if the condition is "missed" the loop
+  // will have undefined behavior due to wrapping.
+  if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) {
     const SCEV *Exact =
-      getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
-    return ExitLimit(Exact, Exact, /*MustExit=*/false);
+        getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
+    return ExitLimit(Exact, Exact);
   }
 
-  // If Step is a power of two that evenly divides Start we know that the loop
-  // will always terminate.  Start may not be a constant so we just have the
-  // number of trailing zeros available.  This is safe even in presence of
-  // overflow as the recurrence will overflow to exactly 0.
-  const APInt &StepV = StepC->getValue()->getValue();
-  if (StepV.isPowerOf2() &&
-      GetMinTrailingZeros(getNegativeSCEV(Start)) >= StepV.countTrailingZeros())
-    return getUDivExactExpr(Distance, CountDown ?
getNegativeSCEV(Step) : Step); - // Then, try to solve the above equation provided that Start is constant. if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) return SolveLinEquationWithOverflow(StepC->getValue()->getValue(), @@ -6309,19 +6606,33 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + BasicBlock *Latch = L->getLoopLatch(); if (!Latch) return false; BranchInst *LoopContinuePredicate = dyn_cast<BranchInst>(Latch->getTerminator()); - if (!LoopContinuePredicate || - LoopContinuePredicate->isUnconditional()) - return false; + if (LoopContinuePredicate && LoopContinuePredicate->isConditional() && + isImpliedCond(Pred, LHS, RHS, + LoopContinuePredicate->getCondition(), + LoopContinuePredicate->getSuccessor(0) != L->getHeader())) + return true; + + // Check conditions due to any @llvm.assume intrinsics. + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast<CallInst>(AssumeVH); + if (!DT->dominates(CI, Latch->getTerminator())) + continue; - return isImpliedCond(Pred, LHS, RHS, - LoopContinuePredicate->getCondition(), - LoopContinuePredicate->getSuccessor(0) != L->getHeader()); + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + + return false; } /// isLoopEntryGuardedByCond - Test whether entry to the loop is protected @@ -6335,6 +6646,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; + if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors // leading to the original header. @@ -6355,6 +6668,18 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, return true; } + // Check conditions due to any @llvm.assume intrinsics. + for (auto &AssumeVH : AC->assumptions()) { + if (!AssumeVH) + continue; + auto *CI = cast<CallInst>(AssumeVH); + if (!DT->dominates(CI, L->getHeader())) + continue; + + if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false)) + return true; + } + return false; } @@ -6469,6 +6794,66 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, RHS, LHS, FoundLHS, FoundRHS); } + // Check if we can make progress by sharpening ranges. + if (FoundPred == ICmpInst::ICMP_NE && + (isa<SCEVConstant>(FoundLHS) || isa<SCEVConstant>(FoundRHS))) { + + const SCEVConstant *C = nullptr; + const SCEV *V = nullptr; + + if (isa<SCEVConstant>(FoundLHS)) { + C = cast<SCEVConstant>(FoundLHS); + V = FoundRHS; + } else { + C = cast<SCEVConstant>(FoundRHS); + V = FoundLHS; + } + + // The guarding predicate tells us that C != V. If the known range + // of V is [C, t), we can sharpen the range to [C + 1, t). The + // range we consider has to correspond to same signedness as the + // predicate we're interested in folding. + + APInt Min = ICmpInst::isSigned(Pred) ? + getSignedRange(V).getSignedMin() : getUnsignedRange(V).getUnsignedMin(); + + if (Min == C->getValue()->getValue()) { + // Given (V >= Min && V != Min) we conclude V >= (Min + 1). + // This is true even if (Min + 1) wraps around -- in case of + // wraparound, (Min + 1) < Min, so (V >= Min => V >= (Min + 1)). + + APInt SharperMin = Min + 1; + + switch (Pred) { + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: + // We know V `Pred` SharperMin. 
If this implies LHS `Pred`
+        // RHS, we're done.
+        if (isImpliedCondOperands(Pred, LHS, RHS, V,
+                                  getConstant(SharperMin)))
+          return true;
+
+      case ICmpInst::ICMP_SGT:
+      case ICmpInst::ICMP_UGT:
+        // We know from the range information that (V `Pred` Min ||
+        // V == Min).  We know from the guarding condition that !(V
+        // == Min).  This gives us
+        //
+        //       V `Pred` Min || V == Min && !(V == Min)
+        //    => V `Pred` Min
+        //
+        // If V `Pred` Min implies LHS `Pred` RHS, we're done.
+
+        if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+          return true;
+
+      default:
+        // No change
+        break;
+      }
+    }
+  }
+
   // Check whether the actual condition is beyond sufficient.
   if (FoundPred == ICmpInst::ICMP_EQ)
     if (ICmpInst::isTrueWhenEqual(Pred))
@@ -6498,6 +6883,85 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                      getNotSCEV(FoundLHS));
 }
 
+
+/// If Expr computes ~A, return A else return nullptr
+static const SCEV *MatchNotExpr(const SCEV *Expr) {
+  const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
+  if (!Add || Add->getNumOperands() != 2) return nullptr;
+
+  const SCEVConstant *AddLHS = dyn_cast<SCEVConstant>(Add->getOperand(0));
+  if (!(AddLHS && AddLHS->getValue()->getValue().isAllOnesValue()))
+    return nullptr;
+
+  const SCEVMulExpr *AddRHS = dyn_cast<SCEVMulExpr>(Add->getOperand(1));
+  if (!AddRHS || AddRHS->getNumOperands() != 2) return nullptr;
+
+  const SCEVConstant *MulLHS = dyn_cast<SCEVConstant>(AddRHS->getOperand(0));
+  if (!(MulLHS && MulLHS->getValue()->getValue().isAllOnesValue()))
+    return nullptr;
+
+  return AddRHS->getOperand(1);
+}
+
+
+/// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
+                              const SCEV *Candidate) {
+  const MaxExprType *MaxExpr = dyn_cast<MaxExprType>(MaybeMaxExpr);
+  if (!MaxExpr) return false;
+
+  auto It = std::find(MaxExpr->op_begin(), MaxExpr->op_end(), Candidate);
+  return It != MaxExpr->op_end();
+}
+
+
+/// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
+template<typename MaxExprType>
+static bool IsMinConsistingOf(ScalarEvolution &SE,
+                              const SCEV *MaybeMinExpr,
+                              const SCEV *Candidate) {
+  const SCEV *MaybeMaxExpr = MatchNotExpr(MaybeMinExpr);
+  if (!MaybeMaxExpr)
+    return false;
+
+  return IsMaxConsistingOf<MaxExprType>(MaybeMaxExpr, SE.getNotSCEV(Candidate));
+}
+
+
+/// Is LHS `Pred` RHS true by virtue of LHS or RHS being a Min or Max
+/// expression?
+static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  switch (Pred) {
+  default:
+    return false;
+
+  case ICmpInst::ICMP_SGE:
+    std::swap(LHS, RHS);
+    // fall through
+  case ICmpInst::ICMP_SLE:
+    return
+        // min(A, ...) <= A
+        IsMinConsistingOf<SCEVSMaxExpr>(SE, LHS, RHS) ||
+        // A <= max(A, ...)
+        IsMaxConsistingOf<SCEVSMaxExpr>(RHS, LHS);
+
+  case ICmpInst::ICMP_UGE:
+    std::swap(LHS, RHS);
+    // fall through
+  case ICmpInst::ICMP_ULE:
+    return
+        // min(A, ...) <= A
+        IsMinConsistingOf<SCEVUMaxExpr>(SE, LHS, RHS) ||
+        // A <= max(A, ...)
+        IsMaxConsistingOf<SCEVUMaxExpr>(RHS, LHS);
+  }
+
+  llvm_unreachable("covered switch fell through?!");
+}
+
 /// isImpliedCondOperandsHelper - Test whether the condition described by
 /// Pred, LHS, and RHS is true whenever the condition described by Pred,
 /// FoundLHS, and FoundRHS is true.
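A note on MatchNotExpr and IsMinConsistingOf above: SCEV has no dedicated min expression, so a minimum is encoded as the bitwise-not of a maximum of bitwise-nots, with not(X) itself represented as (-1) + (-1)*X; MatchNotExpr recognizes exactly that pattern. The integer identity this relies on can be checked in isolation (a standalone sketch, not LLVM API):

#include <algorithm>
#include <cassert>

// ~x == -x - 1 is strictly order-reversing on signed integers, so taking
// the maximum of the complements and complementing the result yields the
// minimum: min(a, b) == ~max(~a, ~b).
int sminViaSmax(int a, int b) { return ~std::max(~a, ~b); }

int main() {
  assert(sminViaSmax(3, 7) == 3);
  assert(sminViaSmax(-5, 2) == -5);
  assert(sminViaSmax(-4, -9) == -9);
}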
@@ -6506,6 +6970,12 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, const SCEV *FoundRHS) { + auto IsKnownPredicateFull = + [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { + return isKnownPredicateWithRanges(Pred, LHS, RHS) || + IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS); + }; + switch (Pred) { default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); case ICmpInst::ICMP_EQ: @@ -6515,26 +6985,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, break; case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SLE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_SGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_SLE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_SGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_SLE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_ULE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_UGE, RHS, FoundRHS)) return true; break; case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: - if (isKnownPredicateWithRanges(ICmpInst::ICMP_UGE, LHS, FoundLHS) && - isKnownPredicateWithRanges(ICmpInst::ICMP_ULE, RHS, FoundRHS)) + if (IsKnownPredicateFull(ICmpInst::ICMP_UGE, LHS, FoundLHS) && + IsKnownPredicateFull(ICmpInst::ICMP_ULE, RHS, FoundRHS)) return true; break; } @@ -6614,13 +7084,13 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, /// specified less-than comparison will execute. If not computable, return /// CouldNotCompute. /// -/// @param IsSubExpr is true when the LHS < RHS condition does not directly -/// control the branch. In this case, we can only compute an iteration count for -/// a subexpression that cannot overflow before evaluating true. +/// @param ControlsExit is true when the LHS < RHS condition directly controls +/// the branch (loops exits only if condition is true). In this case, we can use +/// NoWrapFlags to skip overflow checks. ScalarEvolution::ExitLimit ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6631,7 +7101,7 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = IV->getStepRecurrence(*this); @@ -6651,9 +7121,19 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) - End = IsSigned ? 
getSMaxExpr(RHS, Start) - : getUMaxExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa<SCEVConstant>(Diff)) { + APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue(); + if (D.isNegative()) + End = Start; + } else + End = IsSigned ? getSMaxExpr(RHS, Start) + : getUMaxExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); @@ -6684,13 +7164,13 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } ScalarEvolution::ExitLimit ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool IsSubExpr) { + bool ControlsExit) { // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); @@ -6701,7 +7181,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (!IV || IV->getLoop() != L || !IV->isAffine()) return getCouldNotCompute(); - bool NoWrap = !IsSubExpr && + bool NoWrap = ControlsExit && IV->getNoWrapFlags(IsSigned ? SCEV::FlagNSW : SCEV::FlagNUW); const SCEV *Stride = getNegativeSCEV(IV->getStepRecurrence(*this)); @@ -6722,9 +7202,19 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) - End = IsSigned ? getSMinExpr(RHS, Start) - : getUMinExpr(RHS, Start); + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { + const SCEV *Diff = getMinusSCEV(RHS, Start); + // If we have NoWrap set, then we can assume that the increment won't + // overflow, in which case if RHS - Start is a constant, we don't need to + // do a max operation since we can just figure it out statically + if (NoWrap && isa<SCEVConstant>(Diff)) { + APInt D = dyn_cast<const SCEVConstant>(Diff)->getValue()->getValue(); + if (!D.isNegative()) + End = Start; + } else + End = IsSigned ? 
getSMinExpr(RHS, Start) + : getUMinExpr(RHS, Start); + } const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -6756,7 +7246,7 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount, /*MustExit=*/true); + return ExitLimit(BECount, MaxBECount); } /// getNumIterationsInRange - Return the number of iterations of this loop that @@ -6984,268 +7474,6 @@ void SCEVAddRecExpr::collectParametricTerms( }); } -static const APInt srem(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::srem(A, B); -} - -static const APInt sdiv(const SCEVConstant *C1, const SCEVConstant *C2) { - APInt A = C1->getValue()->getValue(); - APInt B = C2->getValue()->getValue(); - uint32_t ABW = A.getBitWidth(); - uint32_t BBW = B.getBitWidth(); - - if (ABW > BBW) - B = B.sext(ABW); - else if (ABW < BBW) - A = A.sext(BBW); - - return APIntOps::sdiv(A, B); -} - -namespace { -struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} - - bool follow(const SCEV *S) { - ++Size; - // Keep looking at all operands of S. - return true; - } - bool isDone() const { - return false; - } -}; -} - -// Returns the size of the SCEV S. -static inline int sizeOfSCEV(const SCEV *S) { - FindSCEVSize F; - SCEVTraversal<FindSCEVSize> ST(F); - ST.visitAll(S); - return F.Size; -} - -namespace { - -struct SCEVDivision : public SCEVVisitor<SCEVDivision, void> { -public: - // Computes the Quotient and Remainder of the division of Numerator by - // Denominator. - static void divide(ScalarEvolution &SE, const SCEV *Numerator, - const SCEV *Denominator, const SCEV **Quotient, - const SCEV **Remainder) { - assert(Numerator && Denominator && "Uninitialized SCEV"); - - SCEVDivision D(SE, Numerator, Denominator); - - // Check for the trivial case here to avoid having to check for it in the - // rest of the code. - if (Numerator == Denominator) { - *Quotient = D.One; - *Remainder = D.Zero; - return; - } - - if (Numerator->isZero()) { - *Quotient = D.Zero; - *Remainder = D.Zero; - return; - } - - // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { - const SCEV *Q, *R; - *Quotient = Numerator; - for (const SCEV *Op : T->operands()) { - divide(SE, *Quotient, Op, &Q, &R); - *Quotient = Q; - - // Bail out when the Numerator is not divisible by one of the terms of - // the Denominator. - if (!R->isZero()) { - *Quotient = D.Zero; - *Remainder = Numerator; - return; - } - } - *Remainder = D.Zero; - return; - } - - D.visit(Numerator); - *Quotient = D.Quotient; - *Remainder = D.Remainder; - } - - SCEVDivision(ScalarEvolution &S, const SCEV *Numerator, const SCEV *Denominator) - : SE(S), Denominator(Denominator) { - Zero = SE.getConstant(Denominator->getType(), 0); - One = SE.getConstant(Denominator->getType(), 1); - - // By default, we don't know how to divide Expr by Denominator. - // Providing the default here simplifies the rest of the code. - Quotient = Zero; - Remainder = Numerator; - } - - // Except in the trivial case described above, we do not know how to divide - // Expr by Denominator for the following functions with empty implementation. 
- void visitTruncateExpr(const SCEVTruncateExpr *Numerator) {} - void visitZeroExtendExpr(const SCEVZeroExtendExpr *Numerator) {} - void visitSignExtendExpr(const SCEVSignExtendExpr *Numerator) {} - void visitUDivExpr(const SCEVUDivExpr *Numerator) {} - void visitSMaxExpr(const SCEVSMaxExpr *Numerator) {} - void visitUMaxExpr(const SCEVUMaxExpr *Numerator) {} - void visitUnknown(const SCEVUnknown *Numerator) {} - void visitCouldNotCompute(const SCEVCouldNotCompute *Numerator) {} - - void visitConstant(const SCEVConstant *Numerator) { - if (const SCEVConstant *D = dyn_cast<SCEVConstant>(Denominator)) { - Quotient = SE.getConstant(sdiv(Numerator, D)); - Remainder = SE.getConstant(srem(Numerator, D)); - return; - } - } - - void visitAddRecExpr(const SCEVAddRecExpr *Numerator) { - const SCEV *StartQ, *StartR, *StepQ, *StepR; - assert(Numerator->isAffine() && "Numerator should be affine"); - divide(SE, Numerator->getStart(), Denominator, &StartQ, &StartR); - divide(SE, Numerator->getStepRecurrence(SE), Denominator, &StepQ, &StepR); - Quotient = SE.getAddRecExpr(StartQ, StepQ, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - Remainder = SE.getAddRecExpr(StartR, StepR, Numerator->getLoop(), - Numerator->getNoWrapFlags()); - } - - void visitAddExpr(const SCEVAddExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs, Rs; - Type *Ty = Denominator->getType(); - - for (const SCEV *Op : Numerator->operands()) { - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - - // Bail out if types do not match. - if (Ty != Q->getType() || Ty != R->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - Qs.push_back(Q); - Rs.push_back(R); - } - - if (Qs.size() == 1) { - Quotient = Qs[0]; - Remainder = Rs[0]; - return; - } - - Quotient = SE.getAddExpr(Qs); - Remainder = SE.getAddExpr(Rs); - } - - void visitMulExpr(const SCEVMulExpr *Numerator) { - SmallVector<const SCEV *, 2> Qs; - Type *Ty = Denominator->getType(); - - bool FoundDenominatorTerm = false; - for (const SCEV *Op : Numerator->operands()) { - // Bail out if types do not match. - if (Ty != Op->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - if (FoundDenominatorTerm) { - Qs.push_back(Op); - continue; - } - - // Check whether Denominator divides one of the product operands. - const SCEV *Q, *R; - divide(SE, Op, Denominator, &Q, &R); - if (!R->isZero()) { - Qs.push_back(Op); - continue; - } - - // Bail out if types do not match. - if (Ty != Q->getType()) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - FoundDenominatorTerm = true; - Qs.push_back(Q); - } - - if (FoundDenominatorTerm) { - Remainder = Zero; - if (Qs.size() == 1) - Quotient = Qs[0]; - else - Quotient = SE.getMulExpr(Qs); - return; - } - - if (!isa<SCEVUnknown>(Denominator)) { - Quotient = Zero; - Remainder = Numerator; - return; - } - - // The Remainder is obtained by replacing Denominator by 0 in Numerator. - ValueToValueMap RewriteMap; - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(Zero)->getValue(); - Remainder = SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - - if (Remainder->isZero()) { - // The Quotient is obtained by replacing Denominator by 1 in Numerator. - RewriteMap[cast<SCEVUnknown>(Denominator)->getValue()] = - cast<SCEVConstant>(One)->getValue(); - Quotient = - SCEVParameterRewriter::rewrite(Numerator, SE, RewriteMap, true); - return; - } - - // Quotient is (Numerator - Remainder) divided by Denominator. 
- const SCEV *Q, *R; - const SCEV *Diff = SE.getMinusSCEV(Numerator, Remainder); - if (sizeOfSCEV(Diff) > sizeOfSCEV(Numerator)) { - // This SCEV does not seem to simplify: fail the division here. - Quotient = Zero; - Remainder = Numerator; - return; - } - divide(SE, Diff, Denominator, &Q, &R); - assert(R == Zero && - "(Numerator - Remainder) should evenly divide Denominator"); - Quotient = Q; - } - -private: - ScalarEvolution &SE; - const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; -}; -} - static bool findArrayDimensionsRec(ScalarEvolution &SE, SmallVectorImpl<const SCEV *> &Terms, SmallVectorImpl<const SCEV *> &Sizes) { @@ -7609,7 +7837,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { // that until everything else is done. if (U == Old) continue; - if (!Visited.insert(U)) + if (!Visited.insert(U).second) continue; if (PHINode *PN = dyn_cast<PHINode>(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); @@ -7638,6 +7866,7 @@ ScalarEvolution::ScalarEvolution() bool ScalarEvolution::runOnFunction(Function &F) { this->F = &F; + AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); LI = &getAnalysis<LoopInfo>(); DataLayoutPass *DLP = getAnalysisIfAvailable<DataLayoutPass>(); DL = DLP ? &DLP->getDataLayout() : nullptr; @@ -7678,6 +7907,7 @@ void ScalarEvolution::releaseMemory() { void ScalarEvolution::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); AU.addRequiredTransitive<LoopInfo>(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfo>(); diff --git a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp index 6933f74150fd..5c339eecdd62 100644 --- a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp +++ b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -162,10 +162,10 @@ ScalarEvolutionAliasAnalysis::alias(const Location &LocA, if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr)) if (alias(Location(AO ? AO : LocA.Ptr, AO ? +UnknownSize : LocA.Size, - AO ? nullptr : LocA.TBAATag), + AO ? AAMDNodes() : LocA.AATags), Location(BO ? BO : LocB.Ptr, BO ? +UnknownSize : LocB.Size, - BO ? nullptr : LocB.TBAATag)) == NoAlias) + BO ? AAMDNodes() : LocB.AATags)) == NoAlias) return NoAlias; // Forward the query to the next analysis. diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 968c619a48dd..59f19a002ecc 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1443,7 +1443,7 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { Constant *One = ConstantInt::get(Ty, 1); for (pred_iterator HPI = HPB; HPI != HPE; ++HPI) { BasicBlock *HP = *HPI; - if (!PredSeen.insert(HP)) { + if (!PredSeen.insert(HP).second) { // There must be an incoming value for each predecessor, even the // duplicates! CanonicalIV->addIncoming(CanonicalIV->getIncomingValueForBlock(HP), HP); @@ -1711,7 +1711,7 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, // Fold constant phis. They may be congruent to other constant phis and // would confuse the logic below that expects proper IVs. 
-    if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT)) {
+    if (Value *V = SimplifyInstruction(Phi, SE.DL, SE.TLI, SE.DT, SE.AC)) {
       Phi->replaceAllUsesWith(V);
       DeadInsts.push_back(Phi);
       ++NumElim;
diff --git a/lib/Analysis/ScalarEvolutionNormalization.cpp b/lib/Analysis/ScalarEvolutionNormalization.cpp
index 3ccefb01101d..b238fe43cc60 100644
--- a/lib/Analysis/ScalarEvolutionNormalization.cpp
+++ b/lib/Analysis/ScalarEvolutionNormalization.cpp
@@ -126,7 +126,7 @@ TransformImpl(const SCEV *S, Instruction *User, Value *OperandValToReplace) {
   // Normalized form:   {-2,+,1,+,2}
   // Denormalized form: {1,+,3,+,2}
   //
-  // However, denormalization would use the a different step expression than
+  // However, denormalization would use a different step expression than
   // normalization (see getPostIncExpr), generating the wrong final
   // expression: {-2,+,1,+,2} + {1,+,2} => {-1,+,3,+,2}
   if (AR->isAffine() &&
diff --git a/lib/Analysis/ScopedNoAliasAA.cpp b/lib/Analysis/ScopedNoAliasAA.cpp
new file mode 100644
index 000000000000..c6ea3af3d0c3
--- /dev/null
+++ b/lib/Analysis/ScopedNoAliasAA.cpp
@@ -0,0 +1,245 @@
+//===- ScopedNoAliasAA.cpp - Scoped No-Alias Alias Analysis ---------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the ScopedNoAlias alias-analysis pass, which implements
+// metadata-based scoped no-alias support.
+//
+// Alias-analysis scopes are defined by an id (which can be a string or some
+// other metadata node), a domain node, and an optional descriptive string.
+// A domain is defined by an id (which can be a string or some other metadata
+// node), and an optional descriptive string.
+//
+// !dom0 = metadata !{ metadata !"domain of foo()" }
+// !scope1 = metadata !{ metadata !scope1, metadata !dom0, metadata !"scope 1" }
+// !scope2 = metadata !{ metadata !scope2, metadata !dom0, metadata !"scope 2" }
+//
+// Loads and stores can be tagged with an alias-analysis scope, and also, with
+// a noalias tag for a specific scope:
+//
+// ... = load %ptr1, !alias.scope !{ !scope1 }
+// ... = load %ptr2, !alias.scope !{ !scope1, !scope2 }, !noalias !{ !scope1 }
+//
+// When evaluating an aliasing query, if one of the instructions is associated
+// with a set of noalias scopes in some domain that is a superset of the alias
+// scopes in that domain of some other instruction, then the two memory
+// accesses are assumed not to alias.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/Passes.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+using namespace llvm;
+
+// A handy option for disabling scoped no-alias functionality. The same effect
+// can also be achieved by stripping the associated metadata tags from IR, but
+// this option is sometimes more convenient.
+static cl::opt<bool>
+EnableScopedNoAlias("enable-scoped-noalias", cl::init(true));
+
+namespace {
+/// AliasScopeNode - This is a simple wrapper around an MDNode which provides
+/// a higher-level interface by hiding the details of how alias analysis
+/// information is encoded in its operands.
+class AliasScopeNode { + const MDNode *Node; + +public: + AliasScopeNode() : Node(0) {} + explicit AliasScopeNode(const MDNode *N) : Node(N) {} + + /// getNode - Get the MDNode for this AliasScopeNode. + const MDNode *getNode() const { return Node; } + + /// getDomain - Get the MDNode for this AliasScopeNode's domain. + const MDNode *getDomain() const { + if (Node->getNumOperands() < 2) + return nullptr; + return dyn_cast_or_null<MDNode>(Node->getOperand(1)); + } +}; + +/// ScopedNoAliasAA - This is a simple alias analysis +/// implementation that uses scoped-noalias metadata to answer queries. +class ScopedNoAliasAA : public ImmutablePass, public AliasAnalysis { +public: + static char ID; // Class identification, replacement for typeinfo + ScopedNoAliasAA() : ImmutablePass(ID) { + initializeScopedNoAliasAAPass(*PassRegistry::getPassRegistry()); + } + + void initializePass() override { InitializeAliasAnalysis(this); } + + /// getAdjustedAnalysisPointer - This method is used when a pass implements + /// an analysis interface through multiple inheritance. If needed, it + /// should override this to adjust the this pointer as needed for the + /// specified pass info. + void *getAdjustedAnalysisPointer(const void *PI) override { + if (PI == &AliasAnalysis::ID) + return (AliasAnalysis*)this; + return this; + } + +protected: + bool mayAliasInScopes(const MDNode *Scopes, const MDNode *NoAlias) const; + void collectMDInDomain(const MDNode *List, const MDNode *Domain, + SmallPtrSetImpl<const MDNode *> &Nodes) const; + +private: + void getAnalysisUsage(AnalysisUsage &AU) const override; + AliasResult alias(const Location &LocA, const Location &LocB) override; + bool pointsToConstantMemory(const Location &Loc, bool OrLocal) override; + ModRefBehavior getModRefBehavior(ImmutableCallSite CS) override; + ModRefBehavior getModRefBehavior(const Function *F) override; + ModRefResult getModRefInfo(ImmutableCallSite CS, + const Location &Loc) override; + ModRefResult getModRefInfo(ImmutableCallSite CS1, + ImmutableCallSite CS2) override; +}; +} // End of anonymous namespace + +// Register this pass... +char ScopedNoAliasAA::ID = 0; +INITIALIZE_AG_PASS(ScopedNoAliasAA, AliasAnalysis, "scoped-noalias", + "Scoped NoAlias Alias Analysis", false, true, false) + +ImmutablePass *llvm::createScopedNoAliasAAPass() { + return new ScopedNoAliasAA(); +} + +void +ScopedNoAliasAA::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AliasAnalysis::getAnalysisUsage(AU); +} + +void +ScopedNoAliasAA::collectMDInDomain(const MDNode *List, const MDNode *Domain, + SmallPtrSetImpl<const MDNode *> &Nodes) const { + for (unsigned i = 0, ie = List->getNumOperands(); i != ie; ++i) + if (const MDNode *MD = dyn_cast<MDNode>(List->getOperand(i))) + if (AliasScopeNode(MD).getDomain() == Domain) + Nodes.insert(MD); +} + +bool +ScopedNoAliasAA::mayAliasInScopes(const MDNode *Scopes, + const MDNode *NoAlias) const { + if (!Scopes || !NoAlias) + return true; + + // Collect the set of scope domains relevant to the noalias scopes. + SmallPtrSet<const MDNode *, 16> Domains; + for (unsigned i = 0, ie = NoAlias->getNumOperands(); i != ie; ++i) + if (const MDNode *NAMD = dyn_cast<MDNode>(NoAlias->getOperand(i))) + if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain()) + Domains.insert(Domain); + + // We alias unless, for some domain, the set of noalias scopes in that domain + // is a superset of the set of alias scopes in that domain. 
+ for (const MDNode *Domain : Domains) { + SmallPtrSet<const MDNode *, 16> NANodes, ScopeNodes; + collectMDInDomain(NoAlias, Domain, NANodes); + collectMDInDomain(Scopes, Domain, ScopeNodes); + if (!ScopeNodes.size()) + continue; + + // To not alias, all of the nodes in ScopeNodes must be in NANodes. + bool FoundAll = true; + for (const MDNode *SMD : ScopeNodes) + if (!NANodes.count(SMD)) { + FoundAll = false; + break; + } + + if (FoundAll) + return false; + } + + return true; +} + +AliasAnalysis::AliasResult +ScopedNoAliasAA::alias(const Location &LocA, const Location &LocB) { + if (!EnableScopedNoAlias) + return AliasAnalysis::alias(LocA, LocB); + + // Get the attached MDNodes. + const MDNode *AScopes = LocA.AATags.Scope, + *BScopes = LocB.AATags.Scope; + + const MDNode *ANoAlias = LocA.AATags.NoAlias, + *BNoAlias = LocB.AATags.NoAlias; + + if (!mayAliasInScopes(AScopes, BNoAlias)) + return NoAlias; + + if (!mayAliasInScopes(BScopes, ANoAlias)) + return NoAlias; + + // If they may alias, chain to the next AliasAnalysis. + return AliasAnalysis::alias(LocA, LocB); +} + +bool ScopedNoAliasAA::pointsToConstantMemory(const Location &Loc, + bool OrLocal) { + return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); +} + +AliasAnalysis::ModRefBehavior +ScopedNoAliasAA::getModRefBehavior(ImmutableCallSite CS) { + return AliasAnalysis::getModRefBehavior(CS); +} + +AliasAnalysis::ModRefBehavior +ScopedNoAliasAA::getModRefBehavior(const Function *F) { + return AliasAnalysis::getModRefBehavior(F); +} + +AliasAnalysis::ModRefResult +ScopedNoAliasAA::getModRefInfo(ImmutableCallSite CS, const Location &Loc) { + if (!EnableScopedNoAlias) + return AliasAnalysis::getModRefInfo(CS, Loc); + + if (!mayAliasInScopes(Loc.AATags.Scope, CS.getInstruction()->getMetadata( + LLVMContext::MD_noalias))) + return NoModRef; + + if (!mayAliasInScopes( + CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + Loc.AATags.NoAlias)) + return NoModRef; + + return AliasAnalysis::getModRefInfo(CS, Loc); +} + +AliasAnalysis::ModRefResult +ScopedNoAliasAA::getModRefInfo(ImmutableCallSite CS1, ImmutableCallSite CS2) { + if (!EnableScopedNoAlias) + return AliasAnalysis::getModRefInfo(CS1, CS2); + + if (!mayAliasInScopes( + CS1.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + CS2.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + return NoModRef; + + if (!mayAliasInScopes( + CS2.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), + CS1.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + return NoModRef; + + return AliasAnalysis::getModRefInfo(CS1, CS2); +} + diff --git a/lib/Analysis/StratifiedSets.h b/lib/Analysis/StratifiedSets.h new file mode 100644 index 000000000000..fd3fbc0d86ad --- /dev/null +++ b/lib/Analysis/StratifiedSets.h @@ -0,0 +1,692 @@ +//===- StratifiedSets.h - Abstract stratified sets implementation. --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ADT_STRATIFIEDSETS_H
+#define LLVM_ADT_STRATIFIEDSETS_H
+
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Compiler.h"
+#include <bitset>
+#include <cassert>
+#include <cmath>
+#include <limits>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace llvm {
+// \brief An index into Stratified Sets.
+typedef unsigned StratifiedIndex;
+// NOTE: ^ This can't be a short -- bootstrapping clang has a case where
+// ~1M sets exist.
+
+// \brief Container of information related to a value in a StratifiedSet.
+struct StratifiedInfo {
+  StratifiedIndex Index;
+  // For field sensitivity, etc. we can tack attributes on to this struct.
+};
+
+// The number of attributes that StratifiedAttrs should contain. Attributes are
+// described below, and 32 was an arbitrary choice because it fits nicely in 32
+// bits (because we use a bitset for StratifiedAttrs).
+static const unsigned NumStratifiedAttrs = 32;
+
+// These are attributes that the users of StratifiedSets/StratifiedSetBuilders
+// may use for various purposes. These also have the special property that
+// they are merged down. So, if set A is above set B, and one decides to set an
+// attribute in set A, then the attribute will automatically be set in set B.
+typedef std::bitset<NumStratifiedAttrs> StratifiedAttrs;
+
+// \brief A "link" between two StratifiedSets.
+struct StratifiedLink {
+  // \brief This is a value used to signify "does not exist" where
+  // the StratifiedIndex type is used. This is used instead of
+  // Optional<StratifiedIndex> because Optional<StratifiedIndex> would
+  // eat up a considerable amount of extra memory, after struct
+  // padding/alignment is taken into account.
+  static const StratifiedIndex SetSentinel;
+
+  // \brief The index for the set "above" current
+  StratifiedIndex Above;
+
+  // \brief The link for the set "below" current
+  StratifiedIndex Below;
+
+  // \brief Attributes for these StratifiedSets.
+  StratifiedAttrs Attrs;
+
+  StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {}
+
+  bool hasBelow() const { return Below != SetSentinel; }
+  bool hasAbove() const { return Above != SetSentinel; }
+
+  void clearBelow() { Below = SetSentinel; }
+  void clearAbove() { Above = SetSentinel; }
+};
+
+// \brief These are stratified sets, as described in "Fast algorithms for
+// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M
+// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets
+// of Value*s. If two Value*s are in the same set, or if both sets have
+// overlapping attributes, then the Value*s are said to alias.
+//
+// Sets may be related by position, meaning that one set may be considered as
+// above or below another. In CFL Alias Analysis, this gives us an indication
+// of how two variables are related; if the set of variable A is below a set
+// containing variable B, then at some point, a variable that has interacted
+// with B (or B itself) was either used in order to extract the variable A, or
+// was used as storage of variable A.
+//
+// Sets may also have attributes (as noted above). These attributes are
+// generally used for noting whether a variable in the set has interacted with
+// a variable whose origins we don't quite know (i.e.
globals/arguments), or if
+// the variable may have had operations performed on it (modified in a function
+// call). All attributes that exist in a set A must exist in all sets marked as
+// below set A.
+template <typename T> class StratifiedSets {
+public:
+  StratifiedSets() {}
+
+  StratifiedSets(DenseMap<T, StratifiedInfo> Map,
+                 std::vector<StratifiedLink> Links)
+      : Values(std::move(Map)), Links(std::move(Links)) {}
+
+  StratifiedSets(StratifiedSets<T> &&Other) { *this = std::move(Other); }
+
+  StratifiedSets &operator=(StratifiedSets<T> &&Other) {
+    Values = std::move(Other.Values);
+    Links = std::move(Other.Links);
+    return *this;
+  }
+
+  Optional<StratifiedInfo> find(const T &Elem) const {
+    auto Iter = Values.find(Elem);
+    if (Iter == Values.end()) {
+      return NoneType();
+    }
+    return Iter->second;
+  }
+
+  const StratifiedLink &getLink(StratifiedIndex Index) const {
+    assert(inbounds(Index));
+    return Links[Index];
+  }
+
+private:
+  DenseMap<T, StratifiedInfo> Values;
+  std::vector<StratifiedLink> Links;
+
+  bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); }
+};
+
+// \brief Generic Builder class that produces StratifiedSets instances.
+//
+// The goal of this builder is to efficiently produce correct StratifiedSets
+// instances. To this end, we use a few tricks:
+//   > Set chains (A method for linking sets together)
+//   > Set remaps (A method for marking a set as an alias [irony?] of another)
+//
+// ==== Set chains ====
+// This builder has a notion of some value A being above, below, or with some
+// other value B:
+//   > The `A above B` relationship implies that there is a reference edge
+//     going from A to B. Namely, it notes that A can store anything in B's set.
+//   > The `A below B` relationship is the opposite of `A above B`. It implies
+//     that there's a dereference edge going from A to B.
+//   > The `A with B` relationship states that there's an assignment edge going
+//     from A to B, and that A and B should be treated as equals.
+//
+// As an example, take the following code snippet:
+//
+// %a = alloca i32, align 4
+// %ap = alloca i32*, align 8
+// %app = alloca i32**, align 8
+// store %a, %ap
+// store %ap, %app
+// %aw = getelementptr %ap, 0
+//
+// Given this, the following relations exist:
+//   - %a below %ap & %ap above %a
+//   - %ap below %app & %app above %ap
+//   - %aw with %ap & %ap with %aw
+//
+// These relations produce the following sets:
+//   [{%a}, {%ap, %aw}, {%app}]
+//
+// ...Which states that the only MayAlias relationship in the above program is
+// between %ap and %aw.
+//
+// Life gets more complicated when we actually have logic in our programs. So,
+// we either must remove this logic from our programs, or make concessions for
+// it in our AA algorithms. In this case, we have decided to select the latter
+// option.
+//
+// First complication: Conditionals
+// Motivation:
+//   %ad = alloca int, align 4
+//   %a = alloca int*, align 8
+//   %b = alloca int*, align 8
+//   %bp = alloca int**, align 8
+//   %c = call i1 @SomeFunc()
+//   %k = select %c, %a, %b
+//   store %ad, %a
+//   store %b, %bp
+//
+// %k has 'with' edges to both %a and %b, which ordinarily would not be linked
+// together. So, we merge the set that contains %a with the set that contains
+// %b. We then recursively merge the set above %a with the set above %b, and
+// the set below %a with the set below %b, etc. Ultimately, the sets for this
+// program would end up like: {%ad}, {%a, %b, %k}, {%bp}, where {%ad} is below
+// {%a, %b, %k}, which in turn is below {%bp}.
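The recursive merge just described can be sketched in a few lines (hypothetical types and a deliberately naive union; the real builder below also has to handle remapping and attribute propagation):

#include <vector>

// Each set optionally links to the set above and below it; -1 means no link.
struct ChainLink { int Above = -1, Below = -1; };

// Union the sets numbered A and B, then recursively union the corresponding
// neighbors above and below them, mirroring the conditional example above.
void mergeChains(std::vector<ChainLink> &Chains, std::vector<int> &SetOf,
                 int A, int B) {
  if (A == B)
    return;
  for (int &S : SetOf) // naive union: renumber B's members into A
    if (S == B)
      S = A;
  if (Chains[A].Above != -1 && Chains[B].Above != -1)
    mergeChains(Chains, SetOf, Chains[A].Above, Chains[B].Above);
  if (Chains[A].Below != -1 && Chains[B].Below != -1)
    mergeChains(Chains, SetOf, Chains[A].Below, Chains[B].Below);
}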
+//
+// Second complication: Arbitrary casts
+// Motivation:
+//   %ip = alloca int*, align 8
+//   %ipp = alloca int**, align 8
+//   %i = bitcast %ipp to int
+//   store %ip, %ipp
+//   store %i, %ip
+//
+// This is impossible to construct with any of the rules above, because a set
+// containing both %i and %ipp is supposed to exist, the set with %i is
+// supposed to be below the set with %ip, and the set with %ip is supposed to
+// be below the set with %ipp. Because we don't allow circular relationships
+// like this, we merge all concerned sets into one. So, the above code would
+// generate a single StratifiedSet: {%ip, %ipp, %i}.
+//
+// ==== Set remaps ====
+// More of an implementation detail than anything -- when merging sets, we need
+// to update the numbers of all of the elements mapped to those sets. Rather
+// than doing this at each merge, we note in the BuilderLink structure that a
+// remap has occurred, and use this information so we can defer renumbering set
+// elements until build time.
+template <typename T> class StratifiedSetsBuilder {
+  // \brief Represents a Stratified Set, with information about the Stratified
+  // Set above it, the set below it, and whether the current set has been
+  // remapped to another.
+  struct BuilderLink {
+    const StratifiedIndex Number;
+
+    BuilderLink(StratifiedIndex N) : Number(N) {
+      Remap = StratifiedLink::SetSentinel;
+    }
+
+    bool hasAbove() const {
+      assert(!isRemapped());
+      return Link.hasAbove();
+    }
+
+    bool hasBelow() const {
+      assert(!isRemapped());
+      return Link.hasBelow();
+    }
+
+    void setBelow(StratifiedIndex I) {
+      assert(!isRemapped());
+      Link.Below = I;
+    }
+
+    void setAbove(StratifiedIndex I) {
+      assert(!isRemapped());
+      Link.Above = I;
+    }
+
+    void clearBelow() {
+      assert(!isRemapped());
+      Link.clearBelow();
+    }
+
+    void clearAbove() {
+      assert(!isRemapped());
+      Link.clearAbove();
+    }
+
+    StratifiedIndex getBelow() const {
+      assert(!isRemapped());
+      assert(hasBelow());
+      return Link.Below;
+    }
+
+    StratifiedIndex getAbove() const {
+      assert(!isRemapped());
+      assert(hasAbove());
+      return Link.Above;
+    }
+
+    StratifiedAttrs &getAttrs() {
+      assert(!isRemapped());
+      return Link.Attrs;
+    }
+
+    void setAttr(unsigned index) {
+      assert(!isRemapped());
+      assert(index < NumStratifiedAttrs);
+      Link.Attrs.set(index);
+    }
+
+    void setAttrs(const StratifiedAttrs &other) {
+      assert(!isRemapped());
+      Link.Attrs |= other;
+    }
+
+    bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; }
+
+    // \brief For initial remapping to another set.
+    void remapTo(StratifiedIndex Other) {
+      assert(!isRemapped());
+      Remap = Other;
+    }
+
+    StratifiedIndex getRemapIndex() const {
+      assert(isRemapped());
+      return Remap;
+    }
+
+    // \brief Should only be called when we're already remapped.
+    void updateRemap(StratifiedIndex Other) {
+      assert(isRemapped());
+      Remap = Other;
+    }
+
+    // \brief Prefer the above functions to calling things directly on what's
+    // returned from this -- they guard against unexpected calls when the
+    // current BuilderLink is remapped.
+    const StratifiedLink &getLink() const { return Link; }
+
+  private:
+    StratifiedLink Link;
+    StratifiedIndex Remap;
+  };
+
+  // \brief This function performs all of the set unioning/value renumbering
+  // that we've been putting off, and generates a vector<StratifiedLink> that
+  // may be placed in a StratifiedSets instance.
+  void finalizeSets(std::vector<StratifiedLink> &StratLinks) {
+    DenseMap<StratifiedIndex, StratifiedIndex> Remaps;
+    for (auto &Link : Links) {
+      if (Link.isRemapped()) {
+        continue;
+      }
+
+      StratifiedIndex Number = StratLinks.size();
+      Remaps.insert(std::make_pair(Link.Number, Number));
+      StratLinks.push_back(Link.getLink());
+    }
+
+    for (auto &Link : StratLinks) {
+      if (Link.hasAbove()) {
+        auto &Above = linksAt(Link.Above);
+        auto Iter = Remaps.find(Above.Number);
+        assert(Iter != Remaps.end());
+        Link.Above = Iter->second;
+      }
+
+      if (Link.hasBelow()) {
+        auto &Below = linksAt(Link.Below);
+        auto Iter = Remaps.find(Below.Number);
+        assert(Iter != Remaps.end());
+        Link.Below = Iter->second;
+      }
+    }
+
+    for (auto &Pair : Values) {
+      auto &Info = Pair.second;
+      auto &Link = linksAt(Info.Index);
+      auto Iter = Remaps.find(Link.Number);
+      assert(Iter != Remaps.end());
+      Info.Index = Iter->second;
+    }
+  }
+
+  // \brief There's a guarantee in StratifiedLink that all bits set in a
+  // Link.Attrs will be set in the Link.Attrs of all links "below" it.
+  static void propagateAttrs(std::vector<StratifiedLink> &Links) {
+    const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) {
+      const auto *Link = &Links[Idx];
+      while (Link->hasAbove()) {
+        Idx = Link->Above;
+        Link = &Links[Idx];
+      }
+      return Idx;
+    };
+
+    SmallSet<StratifiedIndex, 16> Visited;
+    for (unsigned I = 0, E = Links.size(); I < E; ++I) {
+      auto CurrentIndex = getHighestParentAbove(I);
+      if (!Visited.insert(CurrentIndex).second) {
+        continue;
+      }
+
+      while (Links[CurrentIndex].hasBelow()) {
+        auto &CurrentBits = Links[CurrentIndex].Attrs;
+        auto NextIndex = Links[CurrentIndex].Below;
+        auto &NextBits = Links[NextIndex].Attrs;
+        NextBits |= CurrentBits;
+        CurrentIndex = NextIndex;
+      }
+    }
+  }
+
+public:
+  // \brief Builds a StratifiedSet from the information we've been given since
+  // either construction or the prior build() call.
+  StratifiedSets<T> build() {
+    std::vector<StratifiedLink> StratLinks;
+    finalizeSets(StratLinks);
+    propagateAttrs(StratLinks);
+    Links.clear();
+    return StratifiedSets<T>(std::move(Values), std::move(StratLinks));
+  }
+
+  std::size_t size() const { return Values.size(); }
+  std::size_t numSets() const { return Links.size(); }
+
+  bool has(const T &Elem) const { return get(Elem).hasValue(); }
+
+  bool add(const T &Main) {
+    if (get(Main).hasValue())
+      return false;
+
+    auto NewIndex = getNewUnlinkedIndex();
+    return addAtMerging(Main, NewIndex);
+  }
+
+  // \brief Restructures the stratified sets as necessary to place "ToAdd" in a
+  // set above "Main". There are some cases where this is not possible (see
+  // above), so we merge them such that ToAdd and Main are in the same set.
+  bool addAbove(const T &Main, const T &ToAdd) {
+    assert(has(Main));
+    auto Index = *indexOf(Main);
+    if (!linksAt(Index).hasAbove())
+      addLinkAbove(Index);
+
+    auto Above = linksAt(Index).getAbove();
+    return addAtMerging(ToAdd, Above);
+  }
+
+  // \brief Restructures the stratified sets as necessary to place "ToAdd" in a
+  // set below "Main". There are some cases where this is not possible (see
+  // above), so we merge them such that ToAdd and Main are in the same set.
+ bool addBelow(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto Index = *indexOf(Main); + if (!linksAt(Index).hasBelow()) + addLinkBelow(Index); + + auto Below = linksAt(Index).getBelow(); + return addAtMerging(ToAdd, Below); + } + + bool addWith(const T &Main, const T &ToAdd) { + assert(has(Main)); + auto MainIndex = *indexOf(Main); + return addAtMerging(ToAdd, MainIndex); + } + + void noteAttribute(const T &Main, unsigned AttrNum) { + assert(has(Main)); + assert(AttrNum < StratifiedLink::SetSentinel); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + Link.setAttr(AttrNum); + } + + void noteAttributes(const T &Main, const StratifiedAttrs &NewAttrs) { + assert(has(Main)); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + Link.setAttrs(NewAttrs); + } + + StratifiedAttrs getAttributes(const T &Main) { + assert(has(Main)); + auto *Info = *get(Main); + auto *Link = &linksAt(Info->Index); + auto Attrs = Link->getAttrs(); + while (Link->hasAbove()) { + Link = &linksAt(Link->getAbove()); + Attrs |= Link->getAttrs(); + } + + return Attrs; + } + + bool getAttribute(const T &Main, unsigned AttrNum) { + assert(AttrNum < StratifiedLink::SetSentinel); + auto Attrs = getAttributes(Main); + return Attrs[AttrNum]; + } + + // \brief Gets the attributes that have been applied to the set that Main + // belongs to. It ignores attributes in any sets above the one that Main + // resides in. + StratifiedAttrs getRawAttributes(const T &Main) { + assert(has(Main)); + auto *Info = *get(Main); + auto &Link = linksAt(Info->Index); + return Link.getAttrs(); + } + + // \brief Gets an attribute from the attributes that have been applied to the + // set that Main belongs to. It ignores attributes in any sets above the one + // that Main resides in. + bool getRawAttribute(const T &Main, unsigned AttrNum) { + assert(AttrNum < StratifiedLink::SetSentinel); + auto Attrs = getRawAttributes(Main); + return Attrs[AttrNum]; + } + +private: + DenseMap<T, StratifiedInfo> Values; + std::vector<BuilderLink> Links; + + // \brief Adds the given element at the given index, merging sets if + // necessary. + bool addAtMerging(const T &ToAdd, StratifiedIndex Index) { + StratifiedInfo Info = {Index}; + auto Pair = Values.insert(std::make_pair(ToAdd, Info)); + if (Pair.second) + return true; + + auto &Iter = Pair.first; + auto &IterSet = linksAt(Iter->second.Index); + auto &ReqSet = linksAt(Index); + + // Failed to add where we wanted to. Merge the sets. + if (&IterSet != &ReqSet) + merge(IterSet.Number, ReqSet.Number); + + return false; + } + + // \brief Gets the BuilderLink at the given index, taking set remapping into + // account. + BuilderLink &linksAt(StratifiedIndex Index) { + auto *Start = &Links[Index]; + if (!Start->isRemapped()) + return *Start; + + auto *Current = Start; + while (Current->isRemapped()) + Current = &Links[Current->getRemapIndex()]; + + auto NewRemap = Current->Number; + + // Run through everything that has yet to be updated, and update them to + // remap to NewRemap + Current = Start; + while (Current->isRemapped()) { + auto *Next = &Links[Current->getRemapIndex()]; + Current->updateRemap(NewRemap); + Current = Next; + } + + return *Current; + } + + // \brief Merges two sets into one another. 
Assumes that these sets are not
+  // already one and the same.
+  void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+    assert(inbounds(Idx1) && inbounds(Idx2));
+    assert(&linksAt(Idx1) != &linksAt(Idx2) &&
+           "Merging a set into itself is not allowed");
+
+    // CASE 1: If the set at `Idx1` is above or below `Idx2`, we need to merge
+    // both the given sets, and all sets between them, into one.
+    if (tryMergeUpwards(Idx1, Idx2))
+      return;
+
+    if (tryMergeUpwards(Idx2, Idx1))
+      return;
+
+    // CASE 2: The set at `Idx1` is not in the same chain as the set at `Idx2`.
+    // We therefore need to merge the two chains together.
+    mergeDirect(Idx1, Idx2);
+  }
+
+  // \brief Merges two sets assuming that the set at `Idx1` is unreachable from
+  // traversing above or below the set at `Idx2`.
+  void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) {
+    assert(inbounds(Idx1) && inbounds(Idx2));
+
+    auto *LinksInto = &linksAt(Idx1);
+    auto *LinksFrom = &linksAt(Idx2);
+    // Merging everything above LinksInto then proceeding to merge everything
+    // below LinksInto becomes problematic, so we go as far "up" as possible!
+    while (LinksInto->hasAbove() && LinksFrom->hasAbove()) {
+      LinksInto = &linksAt(LinksInto->getAbove());
+      LinksFrom = &linksAt(LinksFrom->getAbove());
+    }
+
+    if (LinksFrom->hasAbove()) {
+      LinksInto->setAbove(LinksFrom->getAbove());
+      auto &NewAbove = linksAt(LinksInto->getAbove());
+      NewAbove.setBelow(LinksInto->Number);
+    }
+
+    // Merging strategy:
+    //  > If neither has links below, stop.
+    //  > If only `LinksInto` has links below, stop.
+    //  > If only `LinksFrom` has links below, reset `LinksInto.Below` to
+    //    match `LinksFrom.Below`
+    //  > If both have links below, deal with those next.
+    while (LinksInto->hasBelow() && LinksFrom->hasBelow()) {
+      auto &FromAttrs = LinksFrom->getAttrs();
+      LinksInto->setAttrs(FromAttrs);
+
+      // Remap needs to happen after getBelow(), but before
+      // assignment of LinksFrom.
+      auto *NewLinksFrom = &linksAt(LinksFrom->getBelow());
+      LinksFrom->remapTo(LinksInto->Number);
+      LinksFrom = NewLinksFrom;
+      LinksInto = &linksAt(LinksInto->getBelow());
+    }
+
+    if (LinksFrom->hasBelow()) {
+      LinksInto->setBelow(LinksFrom->getBelow());
+      auto &NewBelow = linksAt(LinksInto->getBelow());
+      NewBelow.setAbove(LinksInto->Number);
+    }
+
+    LinksFrom->remapTo(LinksInto->Number);
+  }
+
+  // \brief Checks whether LowerIndex is at a level lower than UpperIndex.
+  // If so, it will merge LowerIndex with UpperIndex (and all of the sets
+  // between) and return true. Otherwise, it will return false.
+ bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) { + assert(inbounds(LowerIndex) && inbounds(UpperIndex)); + auto *Lower = &linksAt(LowerIndex); + auto *Upper = &linksAt(UpperIndex); + if (Lower == Upper) + return true; + + SmallVector<BuilderLink *, 8> Found; + auto *Current = Lower; + auto Attrs = Current->getAttrs(); + while (Current->hasAbove() && Current != Upper) { + Found.push_back(Current); + Attrs |= Current->getAttrs(); + Current = &linksAt(Current->getAbove()); + } + + if (Current != Upper) + return false; + + Upper->setAttrs(Attrs); + + if (Lower->hasBelow()) { + auto NewBelowIndex = Lower->getBelow(); + Upper->setBelow(NewBelowIndex); + auto &NewBelow = linksAt(NewBelowIndex); + NewBelow.setAbove(UpperIndex); + } else { + Upper->clearBelow(); + } + + for (const auto &Ptr : Found) + Ptr->remapTo(Upper->Number); + + return true; + } + + Optional<const StratifiedInfo *> get(const T &Val) const { + auto Result = Values.find(Val); + if (Result == Values.end()) + return NoneType(); + return &Result->second; + } + + Optional<StratifiedInfo *> get(const T &Val) { + auto Result = Values.find(Val); + if (Result == Values.end()) + return NoneType(); + return &Result->second; + } + + Optional<StratifiedIndex> indexOf(const T &Val) { + auto MaybeVal = get(Val); + if (!MaybeVal.hasValue()) + return NoneType(); + auto *Info = *MaybeVal; + auto &Link = linksAt(Info->Index); + return Link.Number; + } + + StratifiedIndex addLinkBelow(StratifiedIndex Set) { + auto At = addLinks(); + Links[Set].setBelow(At); + Links[At].setAbove(Set); + return At; + } + + StratifiedIndex addLinkAbove(StratifiedIndex Set) { + auto At = addLinks(); + Links[At].setBelow(Set); + Links[Set].setAbove(At); + return At; + } + + StratifiedIndex getNewUnlinkedIndex() { return addLinks(); } + + StratifiedIndex addLinks() { + auto Link = Links.size(); + Links.push_back(BuilderLink(Link)); + return Link; + } + + bool inbounds(StratifiedIndex N) const { return N < Links.size(); } +}; +} +#endif // LLVM_ADT_STRATIFIEDSETS_H diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index cdb0b79fd7ee..ef3909b326f9 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -87,9 +87,10 @@ bool TargetTransformInfo::isLoweredToCall(const Function *F) const { return PrevTTI->isLoweredToCall(F); } -void TargetTransformInfo::getUnrollingPreferences(Loop *L, - UnrollingPreferences &UP) const { - PrevTTI->getUnrollingPreferences(L, UP); +void +TargetTransformInfo::getUnrollingPreferences(const Function *F, Loop *L, + UnrollingPreferences &UP) const { + PrevTTI->getUnrollingPreferences(F, L, UP); } bool TargetTransformInfo::isLegalAddImmediate(int64_t Imm) const { @@ -100,6 +101,17 @@ bool TargetTransformInfo::isLegalICmpImmediate(int64_t Imm) const { return PrevTTI->isLegalICmpImmediate(Imm); } +bool TargetTransformInfo::isLegalMaskedLoad(Type *DataType, + int Consecutive) const { + return false; +} + +bool TargetTransformInfo::isLegalMaskedStore(Type *DataType, + int Consecutive) const { + return false; +} + + bool TargetTransformInfo::isLegalAddressingMode(Type *Ty, GlobalValue *BaseGV, int64_t BaseOffset, bool HasBaseReg, @@ -167,15 +179,16 @@ unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { return PrevTTI->getRegisterBitWidth(Vector); } -unsigned TargetTransformInfo::getMaximumUnrollFactor() const { - return PrevTTI->getMaximumUnrollFactor(); +unsigned TargetTransformInfo::getMaxInterleaveFactor() const { + return 
PrevTTI->getMaxInterleaveFactor(); } -unsigned TargetTransformInfo::getArithmeticInstrCost(unsigned Opcode, - Type *Ty, - OperandValueKind Op1Info, - OperandValueKind Op2Info) const { - return PrevTTI->getArithmeticInstrCost(Opcode, Ty, Op1Info, Op2Info); +unsigned TargetTransformInfo::getArithmeticInstrCost( + unsigned Opcode, Type *Ty, OperandValueKind Op1Info, + OperandValueKind Op2Info, OperandValueProperties Opd1PropInfo, + OperandValueProperties Opd2PropInfo) const { + return PrevTTI->getArithmeticInstrCost(Opcode, Ty, Op1Info, Op2Info, + Opd1PropInfo, Opd2PropInfo); } unsigned TargetTransformInfo::getShuffleCost(ShuffleKind Kind, Type *Tp, @@ -206,7 +219,6 @@ unsigned TargetTransformInfo::getMemoryOpCost(unsigned Opcode, Type *Src, unsigned Alignment, unsigned AddressSpace) const { return PrevTTI->getMemoryOpCost(Opcode, Src, Alignment, AddressSpace); - ; } unsigned @@ -230,6 +242,11 @@ unsigned TargetTransformInfo::getReductionCost(unsigned Opcode, Type *Ty, return PrevTTI->getReductionCost(Opcode, Ty, IsPairwise); } +unsigned TargetTransformInfo::getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) + const { + return PrevTTI->getCostOfKeepingLiveOverCall(Tys); +} + namespace { struct NoTTI final : ImmutablePass, TargetTransformInfo { @@ -239,7 +256,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { initializeNoTTIPass(*PassRegistry::getPassRegistry()); } - virtual void initializePass() override { + void initializePass() override { // Note that this subclass is special, and must *not* call initializeTTI as // it does not chain. TopTTI = this; @@ -248,7 +265,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { DL = DLP ? &DLP->getDataLayout() : nullptr; } - virtual void getAnalysisUsage(AnalysisUsage &AU) const override { + void getAnalysisUsage(AnalysisUsage &AU) const override { // Note that this subclass is special, and must *not* call // TTI::getAnalysisUsage as it breaks the recursion. } @@ -257,7 +274,7 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { static char ID; /// Provide necessary pointer adjustments for the two base classes. - virtual void *getAdjustedAnalysisPointer(const void *ID) override { + void *getAdjustedAnalysisPointer(const void *ID) override { if (ID == &TargetTransformInfo::ID) return (TargetTransformInfo*)this; return this; @@ -385,6 +402,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { // FIXME: This is wrong for libc intrinsics. return TCC_Basic; + case Intrinsic::annotation: + case Intrinsic::assume: case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::invariant_start: @@ -394,6 +413,10 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { case Intrinsic::objectsize: case Intrinsic::ptr_annotation: case Intrinsic::var_annotation: + case Intrinsic::experimental_gc_result_int: + case Intrinsic::experimental_gc_result_float: + case Intrinsic::experimental_gc_result_ptr: + case Intrinsic::experimental_gc_relocate: // These intrinsics don't actually represent code after lowering. return TCC_Free; } @@ -466,6 +489,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { // These will all likely lower to a single selection DAG node. 
if (Name == "copysign" || Name == "copysignf" || Name == "copysignl" || Name == "fabs" || Name == "fabsf" || Name == "fabsl" || Name == "sin" || + Name == "fmin" || Name == "fminf" || Name == "fminl" || + Name == "fmax" || Name == "fmaxf" || Name == "fmaxl" || Name == "sinf" || Name == "sinl" || Name == "cos" || Name == "cosf" || Name == "cosl" || Name == "sqrt" || Name == "sqrtf" || Name == "sqrtl") return false; @@ -480,8 +505,8 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { return true; } - void getUnrollingPreferences(Loop *, UnrollingPreferences &) const override { - } + void getUnrollingPreferences(const Function *, Loop *, + UnrollingPreferences &) const override {} bool isLegalAddImmediate(int64_t Imm) const override { return false; @@ -558,12 +583,13 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { return 32; } - unsigned getMaximumUnrollFactor() const override { + unsigned getMaxInterleaveFactor() const override { return 1; } unsigned getArithmeticInstrCost(unsigned Opcode, Type *Ty, OperandValueKind, - OperandValueKind) const override { + OperandValueKind, OperandValueProperties, + OperandValueProperties) const override { return 1; } @@ -612,6 +638,11 @@ struct NoTTI final : ImmutablePass, TargetTransformInfo { unsigned getReductionCost(unsigned, Type *, bool) const override { return 1; } + + unsigned getCostOfKeepingLiveOverCall(ArrayRef<Type*> Tys) const override { + return 0; + } + }; } // end anonymous namespace diff --git a/lib/Analysis/TypeBasedAliasAnalysis.cpp b/lib/Analysis/TypeBasedAliasAnalysis.cpp index f36f6f8a8768..085ce920139b 100644 --- a/lib/Analysis/TypeBasedAliasAnalysis.cpp +++ b/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -167,7 +167,7 @@ namespace { bool TypeIsImmutable() const { if (Node->getNumOperands() < 3) return false; - ConstantInt *CI = dyn_cast<ConstantInt>(Node->getOperand(2)); + ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(2)); if (!CI) return false; return CI->getValue()[0]; @@ -194,7 +194,7 @@ namespace { return dyn_cast_or_null<MDNode>(Node->getOperand(1)); } uint64_t getOffset() const { - return cast<ConstantInt>(Node->getOperand(2))->getZExtValue(); + return mdconst::extract<ConstantInt>(Node->getOperand(2))->getZExtValue(); } /// TypeIsImmutable - Test if this TBAAStructTagNode represents a type for /// objects which are not modified (by any means) in the context where this @@ -202,7 +202,7 @@ namespace { bool TypeIsImmutable() const { if (Node->getNumOperands() < 4) return false; - ConstantInt *CI = dyn_cast<ConstantInt>(Node->getOperand(3)); + ConstantInt *CI = mdconst::dyn_extract<ConstantInt>(Node->getOperand(3)); if (!CI) return false; return CI->getValue()[0]; @@ -233,8 +233,10 @@ namespace { // Fast path for a scalar type node and a struct type node with a single // field. if (Node->getNumOperands() <= 3) { - uint64_t Cur = Node->getNumOperands() == 2 ? 0 : - cast<ConstantInt>(Node->getOperand(2))->getZExtValue(); + uint64_t Cur = Node->getNumOperands() == 2 + ? 0 + : mdconst::extract<ConstantInt>(Node->getOperand(2)) + ->getZExtValue(); Offset -= Cur; MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(1)); if (!P) @@ -246,8 +248,8 @@ namespace { // the current offset is bigger than the given offset. 
unsigned TheIdx = 0; for (unsigned Idx = 1; Idx < Node->getNumOperands(); Idx += 2) { - uint64_t Cur = cast<ConstantInt>(Node->getOperand(Idx + 1))-> - getZExtValue(); + uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(Idx + 1)) + ->getZExtValue(); if (Cur > Offset) { assert(Idx >= 3 && "TBAAStructTypeNode::getParent should have an offset match!"); @@ -258,8 +260,8 @@ namespace { // Move along the last field. if (TheIdx == 0) TheIdx = Node->getNumOperands() - 2; - uint64_t Cur = cast<ConstantInt>(Node->getOperand(TheIdx + 1))-> - getZExtValue(); + uint64_t Cur = mdconst::extract<ConstantInt>(Node->getOperand(TheIdx + 1)) + ->getZExtValue(); Offset -= Cur; MDNode *P = dyn_cast_or_null<MDNode>(Node->getOperand(TheIdx)); if (!P) @@ -454,9 +456,9 @@ TypeBasedAliasAnalysis::alias(const Location &LocA, // Get the attached MDNodes. If either value lacks a tbaa MDNode, we must // be conservative. - const MDNode *AM = LocA.TBAATag; + const MDNode *AM = LocA.AATags.TBAA; if (!AM) return AliasAnalysis::alias(LocA, LocB); - const MDNode *BM = LocB.TBAATag; + const MDNode *BM = LocB.AATags.TBAA; if (!BM) return AliasAnalysis::alias(LocA, LocB); // If they may alias, chain to the next AliasAnalysis. @@ -472,7 +474,7 @@ bool TypeBasedAliasAnalysis::pointsToConstantMemory(const Location &Loc, if (!EnableTBAA) return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); - const MDNode *M = Loc.TBAATag; + const MDNode *M = Loc.AATags.TBAA; if (!M) return AliasAnalysis::pointsToConstantMemory(Loc, OrLocal); // If this is an "immutable" type, we can assume the pointer is pointing @@ -513,9 +515,9 @@ TypeBasedAliasAnalysis::getModRefInfo(ImmutableCallSite CS, if (!EnableTBAA) return AliasAnalysis::getModRefInfo(CS, Loc); - if (const MDNode *L = Loc.TBAATag) + if (const MDNode *L = Loc.AATags.TBAA) if (const MDNode *M = - CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(L, M)) return NoModRef; @@ -529,9 +531,9 @@ TypeBasedAliasAnalysis::getModRefInfo(ImmutableCallSite CS1, return AliasAnalysis::getModRefInfo(CS1, CS2); if (const MDNode *M1 = - CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (const MDNode *M2 = - CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(M1, M2)) return NoModRef; @@ -608,6 +610,28 @@ MDNode *MDNode::getMostGenericTBAA(MDNode *A, MDNode *B) { return nullptr; // We need to convert from a type node to a tag node. 
Type *Int64 = IntegerType::get(A->getContext(), 64); - Value *Ops[3] = { Ret, Ret, ConstantInt::get(Int64, 0) }; + Metadata *Ops[3] = {Ret, Ret, + ConstantAsMetadata::get(ConstantInt::get(Int64, 0))}; return MDNode::get(A->getContext(), Ops); } + +void Instruction::getAAMetadata(AAMDNodes &N, bool Merge) const { + if (Merge) + N.TBAA = + MDNode::getMostGenericTBAA(N.TBAA, getMetadata(LLVMContext::MD_tbaa)); + else + N.TBAA = getMetadata(LLVMContext::MD_tbaa); + + if (Merge) + N.Scope = + MDNode::intersect(N.Scope, getMetadata(LLVMContext::MD_alias_scope)); + else + N.Scope = getMetadata(LLVMContext::MD_alias_scope); + + if (Merge) + N.NoAlias = + MDNode::intersect(N.NoAlias, getMetadata(LLVMContext::MD_noalias)); + else + N.NoAlias = getMetadata(LLVMContext::MD_noalias); +} + diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index b6a3c36b91fe..5d909177c078 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -14,12 +14,14 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" @@ -37,8 +39,8 @@ using namespace llvm::PatternMatch; const unsigned MaxDepth = 6; -/// getBitWidth - Returns the bitwidth of the given scalar or pointer type (if -/// unknown returns 0). For vector types, returns the element type's bitwidth. +/// Returns the bitwidth of the given scalar or pointer type (if unknown returns +/// 0). For vector types, returns the element type's bitwidth. static unsigned getBitWidth(Type *Ty, const DataLayout *TD) { if (unsigned BitWidth = Ty->getScalarSizeInBits()) return BitWidth; @@ -46,10 +48,123 @@ static unsigned getBitWidth(Type *Ty, const DataLayout *TD) { return TD ? TD->getPointerTypeSizeInBits(Ty) : 0; } +// Many of these functions have internal versions that take an assumption +// exclusion set. This is because of the potential for mutual recursion to +// cause computeKnownBits to repeatedly visit the same assume intrinsic. The +// classic case of this is assume(x = y), which will attempt to determine +// bits in x from bits in y, which will attempt to determine bits in y from +// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call +// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and +// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so on. +typedef SmallPtrSet<const Value *, 8> ExclInvsSet; + +namespace { +// Simplifying using an assume can only be done in a particular control-flow +// context (the context instruction provides that context). If an assume and +// the context instruction are not in the same block then the DT helps in +// figuring out if we can use it. 
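+//
+// A hypothetical illustration (not code from this patch):
+//
+//   %c = icmp eq i32 %x, %y
+//   call void @llvm.assume(i1 %c)
+//   %z = add i32 %x, 1        ; the "context instruction"
+//
+// When computing the known bits of %x at %z, the assume lets us derive bits
+// of %x from bits of %y -- but %y's bits may, through the same assume, be
+// derived from %x again. The exclusion set carried by Query below is what
+// breaks that cycle.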
+struct Query { + ExclInvsSet ExclInvs; + AssumptionCache *AC; + const Instruction *CxtI; + const DominatorTree *DT; + + Query(AssumptionCache *AC = nullptr, const Instruction *CxtI = nullptr, + const DominatorTree *DT = nullptr) + : AC(AC), CxtI(CxtI), DT(DT) {} + + Query(const Query &Q, const Value *NewExcl) + : ExclInvs(Q.ExclInvs), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT) { + ExclInvs.insert(NewExcl); + } +}; +} // end anonymous namespace + +// Given the provided Value and, potentially, a context instruction, return +// the preferred context instruction (if any). +static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { + // If we've been provided with a context instruction, then use that (provided + // it has been inserted). + if (CxtI && CxtI->getParent()) + return CxtI; + + // If the value is really an already-inserted instruction, then use that. + CxtI = dyn_cast<Instruction>(V); + if (CxtI && CxtI->getParent()) + return CxtI; + + return nullptr; +} + +static void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + ::computeKnownBits(V, KnownZero, KnownOne, TD, Depth, + Query(AC, safeCxtI(V, CxtI), DT)); +} + +static void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + ::ComputeSignBit(V, KnownZero, KnownOne, TD, Depth, + Query(AC, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q); + +bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, + Query(AC, safeCxtI(V, CxtI), DT)); +} + +static bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + return ::isKnownNonZero(V, TD, Depth, Query(AC, safeCxtI(V, CxtI), DT)); +} + +static bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q); + +bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, const DataLayout *TD, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT) { + return ::MaskedValueIsZero(V, Mask, TD, Depth, + Query(AC, safeCxtI(V, CxtI), DT)); +} + +static unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q); + +unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + return ::ComputeNumSignBits(V, TD, Depth, Query(AC, safeCxtI(V, CxtI), DT)); +} + static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { if (!Add) { if (ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { 
// We know that the top bits of C-X are clear if X contains less bits @@ -60,7 +175,7 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); // NLZ can't be BitWidth with no sign bit APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); - llvm::computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); // If all of the MaskV bits are known to be zero, then we know the // output top bits are zero, because we now know that the output is @@ -76,55 +191,51 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, unsigned BitWidth = KnownZero.getBitWidth(); - // If one of the operands has trailing zeros, then the bits that the - // other operand has in those bit positions will be preserved in the - // result. For an add, this works with either operand. For a subtract, - // this only works if the known zeros are in the right operand. + // If an initial sequence of bits in the result is not needed, the + // corresponding bits in the operands are not needed. APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - llvm::computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1); - unsigned LHSKnownZeroOut = LHSKnownZero.countTrailingOnes(); - - llvm::computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1); - unsigned RHSKnownZeroOut = KnownZero2.countTrailingOnes(); - - // Determine which operand has more trailing zeros, and use that - // many bits from the other operand. - if (LHSKnownZeroOut > RHSKnownZeroOut) { - if (Add) { - APInt Mask = APInt::getLowBitsSet(BitWidth, LHSKnownZeroOut); - KnownZero |= KnownZero2 & Mask; - KnownOne |= KnownOne2 & Mask; - } else { - // If the known zeros are in the left operand for a subtract, - // fall back to the minimum known zeros in both operands. - KnownZero |= APInt::getLowBitsSet(BitWidth, - std::min(LHSKnownZeroOut, - RHSKnownZeroOut)); - } - } else if (RHSKnownZeroOut >= LHSKnownZeroOut) { - APInt Mask = APInt::getLowBitsSet(BitWidth, RHSKnownZeroOut); - KnownZero |= LHSKnownZero & Mask; - KnownOne |= LHSKnownOne & Mask; + computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, TD, Depth+1, Q); + computeKnownBits(Op1, KnownZero2, KnownOne2, TD, Depth+1, Q); + + // Carry in a 1 for a subtract, rather than a 0. + APInt CarryIn(BitWidth, 0); + if (!Add) { + // Sum = LHS + ~RHS + 1 + std::swap(KnownZero2, KnownOne2); + CarryIn.setBit(0); } + APInt PossibleSumZero = ~LHSKnownZero + ~KnownZero2 + CarryIn; + APInt PossibleSumOne = LHSKnownOne + KnownOne2 + CarryIn; + + // Compute known bits of the carry. + APInt CarryKnownZero = ~(PossibleSumZero ^ LHSKnownZero ^ KnownZero2); + APInt CarryKnownOne = PossibleSumOne ^ LHSKnownOne ^ KnownOne2; + + // Compute set of known bits (where all three relevant bits are known). + APInt LHSKnown = LHSKnownZero | LHSKnownOne; + APInt RHSKnown = KnownZero2 | KnownOne2; + APInt CarryKnown = CarryKnownZero | CarryKnownOne; + APInt Known = LHSKnown & RHSKnown & CarryKnown; + + assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && + "known bits of sum differ"); + + // Compute known bits of the result. + KnownZero = ~PossibleSumOne & Known; + KnownOne = PossibleSumOne & Known; + // Are we still trying to solve for the sign bit? 
- if (!KnownZero.isNegative() && !KnownOne.isNegative()) { + if (!Known.isNegative()) { if (NSW) { - if (Add) { - // Adding two positive numbers can't wrap into negative - if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); - // and adding two negative numbers can't wrap into positive. - else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); - } else { - // Subtracting a negative number from a positive one can't wrap - if (LHSKnownZero.isNegative() && KnownOne2.isNegative()) - KnownZero |= APInt::getSignBit(BitWidth); - // neither can subtracting a positive number from a negative one. - else if (LHSKnownOne.isNegative() && KnownZero2.isNegative()) - KnownOne |= APInt::getSignBit(BitWidth); - } + // Adding two non-negative numbers, or subtracting a negative number from + // a non-negative one, can't wrap into negative. + if (LHSKnownZero.isNegative() && KnownZero2.isNegative()) + KnownZero |= APInt::getSignBit(BitWidth); + // Adding two negative numbers, or subtracting a non-negative number from + // a negative one, can't wrap into non-negative. + else if (LHSKnownOne.isNegative() && KnownOne2.isNegative()) + KnownOne |= APInt::getSignBit(BitWidth); } } } @@ -132,10 +243,11 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout *TD, unsigned Depth) { + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = KnownZero.getBitWidth(); - computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(Op1, KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(Op0, KnownZero2, KnownOne2, TD, Depth+1, Q); bool isKnownNegative = false; bool isKnownNonNegative = false; @@ -156,9 +268,9 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, // negative or zero. if (!isKnownNonNegative) isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && - isKnownNonZero(Op0, TD, Depth)) || + isKnownNonZero(Op0, TD, Depth, Q)) || (isKnownNegativeOp0 && isKnownNonNegativeOp1 && - isKnownNonZero(Op1, TD, Depth)); + isKnownNonZero(Op1, TD, Depth, Q)); } } @@ -198,8 +310,10 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, // Use the high end of the ranges to find leading zeros. 
unsigned MinLeadingZeros = BitWidth; for (unsigned i = 0; i < NumRanges; ++i) { - ConstantInt *Lower = cast<ConstantInt>(Ranges.getOperand(2*i + 0)); - ConstantInt *Upper = cast<ConstantInt>(Ranges.getOperand(2*i + 1)); + ConstantInt *Lower = + mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract<ConstantInt>(Ranges.getOperand(2 * i + 1)); ConstantRange Range(Lower->getValue(), Upper->getValue()); if (Range.isWrappedSet()) MinLeadingZeros = 0; // -1 has no zeros @@ -210,6 +324,414 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, KnownZero = APInt::getHighBitsSet(BitWidth, MinLeadingZeros); } +static bool isEphemeralValueOf(Instruction *I, const Value *E) { + SmallVector<const Value *, 16> WorkSet(1, I); + SmallPtrSet<const Value *, 32> Visited; + SmallPtrSet<const Value *, 16> EphValues; + + while (!WorkSet.empty()) { + const Value *V = WorkSet.pop_back_val(); + if (!Visited.insert(V).second) + continue; + + // If all uses of this value are ephemeral, then so is this value. + bool FoundNEUse = false; + for (const User *I : V->users()) + if (!EphValues.count(I)) { + FoundNEUse = true; + break; + } + + if (!FoundNEUse) { + if (V == E) + return true; + + EphValues.insert(V); + if (const User *U = dyn_cast<User>(V)) + for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); + J != JE; ++J) { + if (isSafeToSpeculativelyExecute(*J)) + WorkSet.push_back(*J); + } + } + } + + return false; +} + +// Is this an intrinsic that cannot be speculated but also cannot trap? +static bool isAssumeLikeIntrinsic(const Instruction *I) { + if (const CallInst *CI = dyn_cast<CallInst>(I)) + if (Function *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + default: break; + // FIXME: This list is repeated from NoTTI::getIntrinsicCost. + case Intrinsic::assume: + case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + case Intrinsic::invariant_start: + case Intrinsic::invariant_end: + case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: + case Intrinsic::objectsize: + case Intrinsic::ptr_annotation: + case Intrinsic::var_annotation: + return true; + } + + return false; +} + +static bool isValidAssumeForContext(Value *V, const Query &Q, + const DataLayout *DL) { + Instruction *Inv = cast<Instruction>(V); + + // There are two restrictions on the use of an assume: + // 1. The assume must dominate the context (or the control flow must + // reach the assume whenever it reaches the context). + // 2. The context must not be in the assume's set of ephemeral values + // (otherwise we will use the assume to prove that the condition + // feeding the assume is trivially true, thus causing the removal of + // the assume). + + if (Q.DT) { + if (Q.DT->dominates(Inv, Q.CxtI)) { + return true; + } else if (Inv->getParent() == Q.CxtI->getParent()) { + // The context comes first, but they're both in the same block. Make sure + // there is nothing in between that might interrupt the control flow. + for (BasicBlock::const_iterator I = + std::next(BasicBlock::const_iterator(Q.CxtI)), + IE(Inv); I != IE; ++I) + if (!isSafeToSpeculativelyExecute(I, DL) && + !isAssumeLikeIntrinsic(I)) + return false; + + return !isEphemeralValueOf(Inv, Q.CxtI); + } + + return false; + } + + // When we don't have a DT, we do a limited search... 
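+  // Concretely, we handle only two cases below: the assume lives in the
+  // single predecessor of the context's block, or the assume and the context
+  // share a block and nothing between them may interrupt control flow.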
+  if (Inv->getParent() == Q.CxtI->getParent()->getSinglePredecessor()) {
+    return true;
+  } else if (Inv->getParent() == Q.CxtI->getParent()) {
+    // Search forward from the assume until we reach the context (or the end
+    // of the block); the common case is that the assume will come first.
+    for (BasicBlock::iterator I = std::next(BasicBlock::iterator(Inv)),
+         IE = Inv->getParent()->end(); I != IE; ++I)
+      if (I == Q.CxtI)
+        return true;
+
+    // The context must come first...
+    for (BasicBlock::const_iterator I =
+           std::next(BasicBlock::const_iterator(Q.CxtI)),
+         IE(Inv); I != IE; ++I)
+      if (!isSafeToSpeculativelyExecute(I, DL) &&
+          !isAssumeLikeIntrinsic(I))
+        return false;
+
+    return !isEphemeralValueOf(Inv, Q.CxtI);
+  }
+
+  return false;
+}
+
+bool llvm::isValidAssumeForContext(const Instruction *I,
+                                   const Instruction *CxtI,
+                                   const DataLayout *DL,
+                                   const DominatorTree *DT) {
+  return ::isValidAssumeForContext(const_cast<Instruction*>(I),
+                                   Query(nullptr, CxtI, DT), DL);
+}
+
+template<typename LHS, typename RHS>
+inline match_combine_or<CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>,
+                        CmpClass_match<RHS, LHS, ICmpInst, ICmpInst::Predicate>>
+m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
+  return m_CombineOr(m_ICmp(Pred, L, R), m_ICmp(Pred, R, L));
+}
+
+template<typename LHS, typename RHS>
+inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::And>,
+                        BinaryOp_match<RHS, LHS, Instruction::And>>
+m_c_And(const LHS &L, const RHS &R) {
+  return m_CombineOr(m_And(L, R), m_And(R, L));
+}
+
+template<typename LHS, typename RHS>
+inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Or>,
+                        BinaryOp_match<RHS, LHS, Instruction::Or>>
+m_c_Or(const LHS &L, const RHS &R) {
+  return m_CombineOr(m_Or(L, R), m_Or(R, L));
+}
+
+template<typename LHS, typename RHS>
+inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Xor>,
+                        BinaryOp_match<RHS, LHS, Instruction::Xor>>
+m_c_Xor(const LHS &L, const RHS &R) {
+  return m_CombineOr(m_Xor(L, R), m_Xor(R, L));
+}
+
+static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero,
+                                       APInt &KnownOne,
+                                       const DataLayout *DL,
+                                       unsigned Depth, const Query &Q) {
+  // Use of assumptions is context-sensitive. If we don't have a context, we
+  // cannot use them!
+  if (!Q.AC || !Q.CxtI)
+    return;
+
+  unsigned BitWidth = KnownZero.getBitWidth();
+
+  for (auto &AssumeVH : Q.AC->assumptions()) {
+    if (!AssumeVH)
+      continue;
+    CallInst *I = cast<CallInst>(AssumeVH);
+    assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() &&
+           "Got assumption for the wrong function!");
+    if (Q.ExclInvs.count(I))
+      continue;
+
+    // Warning: This loop can end up being somewhat performance sensitive.
+    // We're running this loop once for each value queried, resulting in a
+    // runtime of ~O(#assumes * #values).
+
+    assert(isa<IntrinsicInst>(I) &&
+           dyn_cast<IntrinsicInst>(I)->getIntrinsicID() == Intrinsic::assume &&
+           "must be an assume intrinsic");
+
+    Value *Arg = I->getArgOperand(0);
+
+    if (Arg == V &&
+        isValidAssumeForContext(I, Q, DL)) {
+      assert(BitWidth == 1 && "assume operand is not i1?");
+      KnownZero.clearAllBits();
+      KnownOne.setAllBits();
+      return;
+    }
+
+    // The remaining tests are all recursive, so bail out if we hit the limit.
+ if (Depth == MaxDepth) + continue; + + Value *A, *B; + auto m_V = m_CombineOr(m_Specific(V), + m_CombineOr(m_PtrToInt(m_Specific(V)), + m_BitCast(m_Specific(V)))); + + CmpInst::Predicate Pred; + ConstantInt *C; + // assume(v = a) + if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + KnownZero |= RHSKnownZero; + KnownOne |= RHSKnownOne; + // assume(v & b = a) + } else if (match(Arg, m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // known bits from the RHS to V. + KnownZero |= RHSKnownZero & MaskKnownOne; + KnownOne |= RHSKnownOne & MaskKnownOne; + // assume(~(v & b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in the mask that are known to be one, we can propagate + // inverted known bits from the RHS to V. + KnownZero |= RHSKnownOne & MaskKnownOne; + KnownOne |= RHSKnownZero & MaskKnownOne; + // assume(v | b = a) + } else if (match(Arg, m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. + KnownZero |= RHSKnownZero & BKnownZero; + KnownOne |= RHSKnownOne & BKnownZero; + // assume(~(v | b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. 
+ KnownZero |= RHSKnownOne & BKnownZero; + KnownOne |= RHSKnownZero & BKnownZero; + // assume(v ^ b = a) + } else if (match(Arg, m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate known + // bits from the RHS to V. For those bits in B that are known to be one, + // we can propagate inverted known bits from the RHS to V. + KnownZero |= RHSKnownZero & BKnownZero; + KnownOne |= RHSKnownOne & BKnownZero; + KnownZero |= RHSKnownOne & BKnownOne; + KnownOne |= RHSKnownZero & BKnownOne; + // assume(~(v ^ b) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); + computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + + // For those bits in B that are known to be zero, we can propagate + // inverted known bits from the RHS to V. For those bits in B that are + // known to be one, we can propagate known bits from the RHS to V. + KnownZero |= RHSKnownOne & BKnownZero; + KnownOne |= RHSKnownZero & BKnownZero; + KnownZero |= RHSKnownZero & BKnownOne; + KnownOne |= RHSKnownOne & BKnownOne; + // assume(v << c = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. + KnownZero |= RHSKnownZero.lshr(C->getZExtValue()); + KnownOne |= RHSKnownOne.lshr(C->getZExtValue()); + // assume(~(v << c) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + KnownZero |= RHSKnownOne.lshr(C->getZExtValue()); + KnownOne |= RHSKnownZero.lshr(C->getZExtValue()); + // assume(v >> c = a) + } else if (match(Arg, + m_c_ICmp(Pred, m_CombineOr(m_LShr(m_V, m_ConstantInt(C)), + m_AShr(m_V, + m_ConstantInt(C))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them to known + // bits in V shifted to the right by C. 
+ KnownZero |= RHSKnownZero << C->getZExtValue(); + KnownOne |= RHSKnownOne << C->getZExtValue(); + // assume(~(v >> c) = a) + } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_CombineOr( + m_LShr(m_V, m_ConstantInt(C)), + m_AShr(m_V, m_ConstantInt(C)))), + m_Value(A))) && + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + // For those bits in RHS that are known, we can propagate them inverted + // to known bits in V shifted to the right by C. + KnownZero |= RHSKnownOne << C->getZExtValue(); + KnownOne |= RHSKnownZero << C->getZExtValue(); + // assume(v >=_s c) where c is non-negative + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SGE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownZero.isNegative()) { + // We know that the sign bit is zero. + KnownZero |= APInt::getSignBit(BitWidth); + } + // assume(v >_s c) where c is at least -1. + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SGT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownOne.isAllOnesValue() || RHSKnownZero.isNegative()) { + // We know that the sign bit is zero. + KnownZero |= APInt::getSignBit(BitWidth); + } + // assume(v <=_s c) where c is negative + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SLE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownOne.isNegative()) { + // We know that the sign bit is one. + KnownOne |= APInt::getSignBit(BitWidth); + } + // assume(v <_s c) where c is non-positive + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_SLT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + if (RHSKnownZero.isAllOnesValue() || RHSKnownOne.isNegative()) { + // We know that the sign bit is one. + KnownOne |= APInt::getSignBit(BitWidth); + } + // assume(v <=_u c) + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_ULE && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + // Whatever high bits in c are zero are known to be zero. + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + // assume(v <_u c) + } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && + Pred == ICmpInst::ICMP_ULT && + isValidAssumeForContext(I, Q, DL)) { + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + + // Whatever high bits in c are zero are known to be zero (if c is a power + // of 2, then one more). 
+ if (isKnownToBeAPowerOfTwo(A, false, Depth+1, Query(Q, I))) + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()+1); + else + KnownZero |= + APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); + } + } +} + /// Determine which bits of V are known to be either zero or one and return /// them in the KnownZero/KnownOne bit sets. /// @@ -225,8 +747,9 @@ void llvm::computeKnownBitsFromRangeMetadata(const MDNode &Ranges, /// where V is a vector, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - const DataLayout *TD, unsigned Depth) { +void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); unsigned BitWidth = KnownZero.getBitWidth(); @@ -272,10 +795,10 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, } // The address of an aligned GlobalValue has trailing zeros. - if (GlobalValue *GV = dyn_cast<GlobalValue>(V)) { - unsigned Align = GV->getAlignment(); + if (auto *GO = dyn_cast<GlobalObject>(V)) { + unsigned Align = GO->getAlignment(); if (Align == 0 && TD) { - if (GlobalVariable *GVar = dyn_cast<GlobalVariable>(GV)) { + if (auto *GVar = dyn_cast<GlobalVariable>(GO)) { Type *ObjectType = GVar->getType()->getElementType(); if (ObjectType->isSized()) { // If the object is defined in the current Module, we'll be giving @@ -296,25 +819,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownOne.clearAllBits(); return; } - // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has - // the bits of its aliasee. - if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->mayBeOverridden()) { - KnownZero.clearAllBits(); KnownOne.clearAllBits(); - } else { - computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth+1); - } - return; - } if (Argument *A = dyn_cast<Argument>(V)) { - unsigned Align = 0; + unsigned Align = A->getType()->isPointerTy() ? A->getParamAlignment() : 0; - if (A->hasByValOrInAllocaAttr()) { - // Get alignment information off byval/inalloca arguments if specified in - // the IR. - Align = A->getParamAlignment(); - } else if (TD && A->hasStructRetAttr()) { + if (!Align && TD && A->hasStructRetAttr()) { // An sret parameter has at least the ABI alignment of the return type. Type *EltTy = cast<PointerType>(A->getType())->getElementType(); if (EltTy->isSized()) @@ -323,14 +832,34 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (Align) KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); + else + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); + + // Don't give up yet... there might be an assumption that provides more + // information... + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); return; } // Start out not knowing anything. KnownZero.clearAllBits(); KnownOne.clearAllBits(); + // Limit search depth. + // All recursive calls that increase depth must come after this. if (Depth == MaxDepth) - return; // Limit search depth. + return; + + // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has + // the bits of its aliasee. 
+ if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { + if (!GA->mayBeOverridden()) + computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, TD, Depth + 1, Q); + return; + } + + // Check whether a nearby assume intrinsic can determine some known bits. + computeKnownBitsFromAssume(V, KnownZero, KnownOne, TD, Depth, Q); Operator *I = dyn_cast<Operator>(V); if (!I) return; @@ -344,8 +873,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-1 bits are only known if set in both the LHS & RHS. KnownOne &= KnownOne2; @@ -354,8 +883,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are only known if clear in both the LHS & RHS. KnownZero &= KnownZero2; @@ -364,8 +893,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); // Output known-0 bits are known if clear or set in both the LHS & RHS. APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); @@ -377,19 +906,20 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Instruction::Mul: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownZero, KnownOne, KnownZero2, KnownOne2, TD, + Depth, Q); break; } case Instruction::UDiv: { // For the purposes of computing leading zeros we can conservatively // treat a udiv as a logical right shift by the power of 2 known to // be less than the denominator. 
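// Aside: a standalone sketch, outside the patch, of the And/Or/Xor transfer
// functions used in the cases above, on plain 64-bit masks. KnownZero and
// KnownOne are disjoint by invariant; the struct and names are illustrative.
#include <cassert>
#include <cstdint>

struct Known64 { uint64_t Zero, One; };

Known64 knownAnd(Known64 A, Known64 B) {
  // A result bit is one only if set in both, zero if clear in either.
  return {A.Zero | B.Zero, A.One & B.One};
}
Known64 knownOr(Known64 A, Known64 B) {
  // A result bit is zero only if clear in both, one if set in either.
  return {A.Zero & B.Zero, A.One | B.One};
}
Known64 knownXor(Known64 A, Known64 B) {
  // Bits known equal give zero; bits known different give one.
  return {(A.Zero & B.Zero) | (A.One & B.One),
          (A.Zero & B.One) | (A.One & B.Zero)};
}

int main() {
  Known64 A{0xF0, 0x0F}, B{0x0F, 0xF0}; // low bytes fully known: 0x0F, 0xF0
  assert(knownAnd(A, B).One == 0x00);
  assert(knownOr(A, B).One == 0xFF);
  assert(knownXor(A, B).One == 0xFF);
  return 0;
}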
- computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned LeadZ = KnownZero2.countLeadingOnes(); KnownOne2.clearAllBits(); KnownZero2.clearAllBits(); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); if (RHSUnknownLeadingOnes != BitWidth) LeadZ = std::min(BitWidth, @@ -399,9 +929,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; } case Instruction::Select: - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, - Depth+1); + computeKnownBits(I->getOperand(2), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); // Only known if known in both the LHS and RHS. KnownOne &= KnownOne2; @@ -437,7 +966,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, assert(SrcBitWidth && "SrcBitWidth can't be zero"); KnownZero = KnownZero.zextOrTrunc(SrcBitWidth); KnownOne = KnownOne.zextOrTrunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zextOrTrunc(BitWidth); KnownOne = KnownOne.zextOrTrunc(BitWidth); // Any top bits are known to be zero. @@ -451,7 +980,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); break; } break; @@ -462,7 +991,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = KnownZero.zext(BitWidth); KnownOne = KnownOne.zext(BitWidth); @@ -478,11 +1007,10 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 if (ConstantInt *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero <<= ShiftAmt; KnownOne <<= ShiftAmt; KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); // low bits known 0 - break; } break; case Instruction::LShr: @@ -492,12 +1020,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth); // Unsigned shift right. - computeKnownBits(I->getOperand(0), KnownZero,KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); // high bits known zero. 
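// Aside: a standalone sketch, outside the patch, of the Shl/LShr transfer
// functions above: known bits move with the value and the vacated positions
// become known zero. Assumes 0 < Amt < 64; all names are illustrative.
#include <cassert>
#include <cstdint>

struct Known64 { uint64_t Zero, One; };

Known64 knownShl(Known64 K, unsigned Amt) {
  // Low Amt bits are freshly zero-filled.
  return {(K.Zero << Amt) | ((1ULL << Amt) - 1), K.One << Amt};
}

Known64 knownLShr(Known64 K, unsigned Amt) {
  // High Amt bits are freshly zero-filled.
  return {(K.Zero >> Amt) | ~(~0ULL >> Amt), K.One >> Amt};
}

int main() {
  Known64 K{~0xFFULL, 0xFF};           // value fully known to be 0xFF
  assert(knownShl(K, 4).One == 0xFF0); // 0xFF << 4
  assert(knownLShr(K, 4).One == 0x0F); // 0xFF >> 4
  return 0;
}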
KnownZero |= APInt::getHighBitsSet(BitWidth, ShiftAmt); - break; } break; case Instruction::AShr: @@ -507,7 +1034,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); // Signed shift right. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); @@ -516,21 +1043,20 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, KnownZero |= HighBits; else if (KnownOne[BitWidth-ShiftAmt-1]) // New bits are known one. KnownOne |= HighBits; - break; } break; case Instruction::Sub: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::Add: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, KnownZero, KnownOne, KnownZero2, KnownOne2, TD, - Depth); + Depth, Q); break; } case Instruction::SRem: @@ -538,7 +1064,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { APInt LowBits = RA - 1; - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, TD, + Depth+1, Q); // The low bits of the first operand are unchanged by the srem. KnownZero = KnownZero2 & LowBits; @@ -563,7 +1090,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, TD, - Depth+1); + Depth+1, Q); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) KnownZero.setBit(BitWidth - 1); @@ -576,7 +1103,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, if (RA.isPowerOf2()) { APInt LowBits = (RA - 1); computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, - Depth+1); + Depth+1, Q); KnownZero |= ~LowBits; KnownOne &= LowBits; break; @@ -585,8 +1112,8 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Since the result is less than or equal to either operand, any leading // zero bits in either operand must also exist in the result. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, TD, Depth+1, Q); unsigned Leaders = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); @@ -610,7 +1137,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // to determine if we can prove known low zero bits. APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, TD, - Depth+1); + Depth+1, Q); unsigned TrailZ = LocalKnownZero.countTrailingOnes(); gep_type_iterator GTI = gep_type_begin(I); @@ -646,7 +1173,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); uint64_t TypeSize = TD ? 
TD->getTypeAllocSize(IndexedTy) : 1; LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1); + computeKnownBits(Index, LocalKnownZero, LocalKnownOne, TD, Depth+1, Q); TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + LocalKnownZero.countTrailingOnes())); @@ -688,11 +1215,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, break; // Ok, we have a PHI of the form L op= R. Check for low // zero bits. - computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1); + computeKnownBits(R, KnownZero2, KnownOne2, TD, Depth+1, Q); // We need to take the minimum number of known bits APInt KnownZero3(KnownZero), KnownOne3(KnownOne); - computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1); + computeKnownBits(L, KnownZero3, KnownOne3, TD, Depth+1, Q); KnownZero = APInt::getLowBitsSet(BitWidth, std::min(KnownZero2.countTrailingOnes(), @@ -724,7 +1251,7 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // Recurse, but cap the recursion to one level, because we don't // want to waste time spinning around in loops. computeKnownBits(P->getIncomingValue(i), KnownZero2, KnownOne2, TD, - MaxDepth-1); + MaxDepth-1, Q); KnownZero &= KnownZero2; KnownOne &= KnownOne2; // If all bits have been ruled out, there's no need to check @@ -776,19 +1303,19 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, case Intrinsic::sadd_with_overflow: computeKnownBitsAddSub(true, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: computeKnownBitsAddSub(false, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, TD, Depth); + KnownOne, KnownZero2, KnownOne2, TD, Depth, Q); break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, KnownOne, - KnownZero2, KnownOne2, TD, Depth); + KnownZero2, KnownOne2, TD, Depth, Q); break; } } @@ -798,10 +1325,11 @@ void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); } -/// ComputeSignBit - Determine whether the sign bit is known to be zero or -/// one. Convenience wrapper around computeKnownBits. -void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout *TD, unsigned Depth) { +/// Determine whether the sign bit is known to be zero or one. +/// Convenience wrapper around computeKnownBits. +void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, + const DataLayout *TD, unsigned Depth, + const Query &Q) { unsigned BitWidth = getBitWidth(V->getType(), TD); if (!BitWidth) { KnownZero = false; @@ -810,16 +1338,17 @@ void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, } APInt ZeroBits(BitWidth, 0); APInt OneBits(BitWidth, 0); - computeKnownBits(V, ZeroBits, OneBits, TD, Depth); + computeKnownBits(V, ZeroBits, OneBits, TD, Depth, Q); KnownOne = OneBits[BitWidth - 1]; KnownZero = ZeroBits[BitWidth - 1]; } -/// isKnownToBeAPowerOfTwo - Return true if the given value is known to have exactly one +/// Return true if the given value is known to have exactly one /// bit set when defined. For vectors return true if every element is known to -/// be a power of two when defined. 
Supports values with integer or pointer +/// be a power of two when defined. Supports values with integer or pointer /// types and vectors of integers. -bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { +bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return OrZero; @@ -846,19 +1375,20 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // A shift of a power of two is a power of two or zero. if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || match(V, m_Shr(m_Value(X), m_Value())))) - return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth); + return isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q); if (ZExtInst *ZI = dyn_cast<ZExtInst>(V)) - return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); if (SelectInst *SI = dyn_cast<SelectInst>(V)) - return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth) && - isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth); + return + isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && + isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { // A power of two and'd with anything is a power of two or zero. - if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth) || - isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth)) + if (isKnownToBeAPowerOfTwo(X, /*OrZero*/true, Depth, Q) || + isKnownToBeAPowerOfTwo(Y, /*OrZero*/true, Depth, Q)) return true; // X & (-X) is always a power of two or zero. if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) @@ -873,19 +1403,19 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) - if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) return true; if (match(Y, m_And(m_Specific(X), m_Value())) || match(Y, m_And(m_Value(), m_Specific(X)))) - if (isKnownToBeAPowerOfTwo(X, OrZero, Depth)) + if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) return true; unsigned BitWidth = V->getType()->getScalarSizeInBits(); APInt LHSZeroBits(BitWidth, 0), LHSOneBits(BitWidth, 0); - computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth); + computeKnownBits(X, LHSZeroBits, LHSOneBits, nullptr, Depth, Q); APInt RHSZeroBits(BitWidth, 0), RHSOneBits(BitWidth, 0); - computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth); + computeKnownBits(Y, RHSZeroBits, RHSOneBits, nullptr, Depth, Q); // If i8 V is a power of two or zero: // ZeroBits: 1 1 1 0 1 1 1 1 // ~ZeroBits: 0 0 0 1 0 0 0 0 @@ -902,7 +1432,8 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { // copying a sign bit (sdiv int_min, 2). if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { - return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, Depth); + return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, + Depth, Q); } return false; @@ -915,7 +1446,7 @@ bool llvm::isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth) { /// /// Currently this routine does not support vector GEPs. 
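// Aside: a standalone, exhaustive 16-bit check, outside the patch, of the
// "X & (-X)" rule above: two's-complement negation flips everything above
// the lowest set bit, so the AND keeps only that bit and the result is
// always a power of two or zero.
#include <cassert>
#include <cstdint>

bool isPowerOfTwoOrZero(uint16_t V) { return (V & (V - 1)) == 0; }

int main() {
  for (uint32_t X = 0; X <= 0xFFFF; ++X) {
    uint16_t V = (uint16_t)X & (uint16_t)(0 - X); // X & -X, modulo 2^16
    assert(isPowerOfTwoOrZero(V));
  }
  return 0;
}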
static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, - unsigned Depth) { + unsigned Depth, const Query &Q) { if (!GEP->isInBounds() || GEP->getPointerAddressSpace() != 0) return false; @@ -924,7 +1455,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, // If the base pointer is non-null, we cannot walk to a null address with an // inbounds GEP in address space zero. - if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth)) + if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth, Q)) return true; // Past this, if we don't have DataLayout, we can't do much. @@ -967,18 +1498,38 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout *DL, if (Depth++ >= MaxDepth) continue; - if (isKnownNonZero(GTI.getOperand(), DL, Depth)) + if (isKnownNonZero(GTI.getOperand(), DL, Depth, Q)) return true; } return false; } -/// isKnownNonZero - Return true if the given value is known to be non-zero -/// when defined. For vectors return true if every element is known to be -/// non-zero when defined. Supports values with integer or pointer type and -/// vectors of integers. -bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { +/// Does the 'Range' metadata (which must be a valid MD_range operand list) +/// ensure that the value it's attached to is never Value? 'RangeType' is +/// the type of the value described by the range. +static bool rangeMetadataExcludesValue(MDNode* Ranges, + const APInt& Value) { + const unsigned NumRanges = Ranges->getNumOperands() / 2; + assert(NumRanges >= 1); + for (unsigned i = 0; i < NumRanges; ++i) { + ConstantInt *Lower = + mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 0)); + ConstantInt *Upper = + mdconst::extract<ConstantInt>(Ranges->getOperand(2 * i + 1)); + ConstantRange Range(Lower->getValue(), Upper->getValue()); + if (Range.contains(Value)) + return false; + } + return true; +} + +/// Return true if the given value is known to be non-zero when defined. +/// For vectors return true if every element is known to be non-zero when +/// defined. Supports values with integer or pointer type and vectors of +/// integers. +bool isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth, + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return false; @@ -989,6 +1540,18 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { return false; } + if (Instruction* I = dyn_cast<Instruction>(V)) { + if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { + // If the possible ranges don't contain zero, then the value is + // definitely non-zero. + if (IntegerType* Ty = dyn_cast<IntegerType>(V->getType())) { + const APInt ZeroValue(Ty->getBitWidth(), 0); + if (rangeMetadataExcludesValue(Ranges, ZeroValue)) + return true; + } + } + } + // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth++ >= MaxDepth) return false; @@ -998,7 +1561,7 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { if (isKnownNonNull(V)) return true; if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) - if (isGEPKnownNonNull(GEP, TD, Depth)) + if (isGEPKnownNonNull(GEP, TD, Depth, Q)) return true; } @@ -1007,11 +1570,12 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // X | Y != 0 if X != 0 or Y != 0.
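// Aside: a standalone sketch, outside the patch, of the range-metadata test
// above, assuming simple non-wrapping half-open [Lo, Hi) ranges. LLVM's
// ConstantRange can also wrap around the end of the bit width, which this
// sketch deliberately ignores; the names here are illustrative.
#include <cassert>
#include <cstdint>
#include <vector>

struct Range { uint64_t Lo, Hi; }; // half-open: Lo <= x < Hi

// True if no range in the list can contain 'Value'.
bool rangesExcludeValue(const std::vector<Range> &Ranges, uint64_t Value) {
  for (const Range &R : Ranges)
    if (Value >= R.Lo && Value < R.Hi)
      return false;
  return true;
}

int main() {
  // e.g. !range metadata {1, 4, 10, 12}: the value is in [1,4) or [10,12).
  std::vector<Range> Rs = {{1, 4}, {10, 12}};
  assert(rangesExcludeValue(Rs, 0)); // zero excluded: value known non-zero
  assert(!rangesExcludeValue(Rs, 2));
  return 0;
}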
Value *X = nullptr, *Y = nullptr; if (match(V, m_Or(m_Value(X), m_Value(Y)))) - return isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q); // ext X != 0 if X != 0. if (isa<SExtInst>(V) || isa<ZExtInst>(V)) - return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth); + return isKnownNonZero(cast<Instruction>(V)->getOperand(0), TD, Depth, Q); // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined // if the lowest bit is shifted off the end. @@ -1019,11 +1583,11 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shl nuw can't remove any non-zero bits. OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); if (BO->hasNoUnsignedWrap()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if (KnownOne[0]) return true; } @@ -1033,28 +1597,29 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // shr exact can only shift out zero bits. PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); if (BO->isExact()) - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); bool XKnownNonNegative, XKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); if (XKnownNegative) return true; } // div exact can only produce a zero if the dividend is zero. else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { - return isKnownNonZero(X, TD, Depth); + return isKnownNonZero(X, TD, Depth, Q); } // X + Y. else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { bool XKnownNonNegative, XKnownNegative; bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, TD, Depth, Q); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, TD, Depth, Q); // If X and Y are both non-negative (as signed values) then their sum is not // zero unless both X and Y are zero. if (XKnownNonNegative && YKnownNonNegative) - if (isKnownNonZero(X, TD, Depth) || isKnownNonZero(Y, TD, Depth)) + if (isKnownNonZero(X, TD, Depth, Q) || + isKnownNonZero(Y, TD, Depth, Q)) return true; // If X and Y are both negative (as signed values) then their sum is not @@ -1065,20 +1630,22 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { APInt Mask = APInt::getSignedMaxValue(BitWidth); // The sign bit of X is set. If some other bit is set then X is not equal // to INT_MIN. - computeKnownBits(X, KnownZero, KnownOne, TD, Depth); + computeKnownBits(X, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; // The sign bit of Y is set. If some other bit is set then Y is not equal // to INT_MIN. - computeKnownBits(Y, KnownZero, KnownOne, TD, Depth); + computeKnownBits(Y, KnownZero, KnownOne, TD, Depth, Q); if ((KnownOne & Mask) != 0) return true; } // The sum of a non-negative number and a power of two is not zero. 
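// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the
// claim above: if X and Y are both non-negative as signed values (top bit
// clear), then X + Y <= 2*(2^7 - 1) < 2^8, so the sum can wrap to zero
// only when X == Y == 0.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 128; ++X)     // sign bit of X clear
    for (uint32_t Y = 0; Y < 128; ++Y) { // sign bit of Y clear
      uint8_t Sum = (uint8_t)(X + Y);
      assert(Sum != 0 || (X == 0 && Y == 0));
    }
  return 0;
}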
- if (XKnownNonNegative && isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth)) + if (XKnownNonNegative && + isKnownToBeAPowerOfTwo(Y, /*OrZero*/false, Depth, Q)) return true; - if (YKnownNonNegative && isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth)) + if (YKnownNonNegative && + isKnownToBeAPowerOfTwo(X, /*OrZero*/false, Depth, Q)) return true; } // X * Y. @@ -1087,51 +1654,53 @@ bool llvm::isKnownNonZero(Value *V, const DataLayout *TD, unsigned Depth) { // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && - isKnownNonZero(X, TD, Depth) && isKnownNonZero(Y, TD, Depth)) + isKnownNonZero(X, TD, Depth, Q) && + isKnownNonZero(Y, TD, Depth, Q)) return true; } // (C ? X : Y) != 0 if X != 0 and Y != 0. else if (SelectInst *SI = dyn_cast<SelectInst>(V)) { - if (isKnownNonZero(SI->getTrueValue(), TD, Depth) && - isKnownNonZero(SI->getFalseValue(), TD, Depth)) + if (isKnownNonZero(SI->getTrueValue(), TD, Depth, Q) && + isKnownNonZero(SI->getFalseValue(), TD, Depth, Q)) return true; } if (!BitWidth) return false; APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return KnownOne != 0; } -/// MaskedValueIsZero - Return true if 'V & Mask' is known to be zero. We use -/// this predicate to simplify operations downstream. Mask is known to be zero -/// for bits that V cannot have. +/// Return true if 'V & Mask' is known to be zero. We use this predicate to +/// simplify operations downstream. Mask is known to be zero for bits that V +/// cannot have. /// /// This function is defined on values with integer type, values with pointer /// type (but only if TD is non-null), and vectors of integers. In the case /// where V is a vector, the mask, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, - const DataLayout *TD, unsigned Depth) { +bool MaskedValueIsZero(Value *V, const APInt &Mask, + const DataLayout *TD, unsigned Depth, + const Query &Q) { APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); return (KnownZero & Mask) == Mask; } -/// ComputeNumSignBits - Return the number of times the sign bit of the -/// register is replicated into the other bits. We know that at least 1 bit -/// is always equal to the sign bit (itself), but other cases can give us -/// information. For example, immediately after an "ashr X, 2", we know that -/// the top 3 bits are all equal to each other, so we return 3. +/// Return the number of times the sign bit of the register is replicated into +/// the other bits. We know that at least 1 bit is always equal to the sign bit +/// (itself), but other cases can give us information. For example, immediately +/// after an "ashr X, 2", we know that the top 3 bits are all equal to each +/// other, so we return 3. /// /// 'Op' must have a scalar integer type. 
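// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the
// "non-negative plus power of two is not zero" rule above: with X in
// [0, 2^7) and Y a power of two (0x80 included), X + Y is at least 1 and at
// most 0xFF, so the 8-bit sum can never wrap to zero.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 128; ++X)        // X known non-negative
    for (uint32_t Y = 1; Y <= 128; Y <<= 1) // Y a power of two
      assert((uint8_t)(X + Y) != 0);
  return 0;
}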
/// -unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, - unsigned Depth) { +unsigned ComputeNumSignBits(Value *V, const DataLayout *TD, + unsigned Depth, const Query &Q) { assert((TD || V->getType()->isIntOrIntVectorTy()) && "ComputeNumSignBits requires a DataLayout object to operate " "on non-integer values!"); @@ -1152,10 +1721,10 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, default: break; case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); - return ComputeNumSignBits(U->getOperand(0), TD, Depth+1) + Tmp; + return ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q) + Tmp; case Instruction::AShr: { - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { @@ -1168,7 +1737,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); Tmp2 = ShAmt->getZExtValue(); if (Tmp2 >= TyBits || // Bad shift. Tmp2 >= Tmp) break; // Shifted all sign bits out. @@ -1180,9 +1749,9 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, case Instruction::Or: case Instruction::Xor: // NOT is handled here. // Logical binary ops preserve the number of sign bits at the worst. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp != 1) { - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); FirstAnswer = std::min(Tmp, Tmp2); // We computed what we know about the sign bits as our first // answer. Now proceed to the generic code that uses @@ -1191,22 +1760,22 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, break; case Instruction::Select: - Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. - Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(2), TD, Depth+1, Q); return std::min(Tmp, Tmp2); case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. // Special case decrementing a value (ADD X, -1): - if (ConstantInt *CRHS = dyn_cast<ConstantInt>(U->getOperand(1))) + if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) if (CRHS->isAllOnesValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. 
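// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the
// AShr rule above: "ashr X, C" adds C sign bits (capped at the bit width).
// numSignBits counts how many high bits equal the sign bit (at least 1);
// the names are illustrative, not LLVM API.
#include <algorithm>
#include <cassert>
#include <cstdint>

unsigned numSignBits(int8_t V) {
  unsigned N = 1;
  for (int B = 6; B >= 0 && ((V >> B) & 1) == ((V >> 7) & 1); --B)
    ++N;
  return N;
}

int main() {
  for (int X = -128; X <= 127; ++X)
    for (unsigned C = 0; C < 8; ++C) {
      int8_t Shifted = (int8_t)(X >> C); // arithmetic shift on int
      unsigned Expect = std::min(8u, numSignBits((int8_t)X) + C);
      assert(numSignBits(Shifted) >= Expect);
    }
  return 0;
}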
@@ -1219,19 +1788,19 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, return Tmp; } - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; return std::min(Tmp, Tmp2)-1; case Instruction::Sub: - Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1); + Tmp2 = ComputeNumSignBits(U->getOperand(1), TD, Depth+1, Q); if (Tmp2 == 1) return 1; // Handle NEG. - if (ConstantInt *CLHS = dyn_cast<ConstantInt>(U->getOperand(0))) + if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) if (CLHS->isNullValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1); + computeKnownBits(U->getOperand(1), KnownZero, KnownOne, TD, Depth+1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) @@ -1247,22 +1816,26 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1); + Tmp = ComputeNumSignBits(U->getOperand(0), TD, Depth+1, Q); if (Tmp == 1) return 1; // Early out. return std::min(Tmp, Tmp2)-1; case Instruction::PHI: { PHINode *PN = cast<PHINode>(U); + unsigned NumIncomingValues = PN->getNumIncomingValues(); // Don't analyze large in-degree PHIs. - if (PN->getNumIncomingValues() > 4) break; + if (NumIncomingValues > 4) break; + // Unreachable blocks may have zero-operand PHI nodes. + if (NumIncomingValues == 0) break; // Take the minimum of all incoming values. This can't infinitely loop // because of our depth threshold. - Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1); - for (unsigned i = 1, e = PN->getNumIncomingValues(); i != e; ++i) { + Tmp = ComputeNumSignBits(PN->getIncomingValue(0), TD, Depth+1, Q); + for (unsigned i = 1, e = NumIncomingValues; i != e; ++i) { if (Tmp == 1) return Tmp; Tmp = std::min(Tmp, - ComputeNumSignBits(PN->getIncomingValue(i), TD, Depth+1)); + ComputeNumSignBits(PN->getIncomingValue(i), TD, + Depth+1, Q)); } return Tmp; } @@ -1277,7 +1850,7 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, // use this information. APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); APInt Mask; - computeKnownBits(V, KnownZero, KnownOne, TD, Depth); + computeKnownBits(V, KnownZero, KnownOne, TD, Depth, Q); if (KnownZero.isNegative()) { // sign bit is 0 Mask = KnownZero; @@ -1297,9 +1870,9 @@ unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout *TD, return std::max(FirstAnswer, std::min(TyBits, Mask.countLeadingZeros())); } -/// ComputeMultiple - This function computes the integer multiple of Base that -/// equals V. If successful, it returns true and returns the multiple in -/// Multiple. If unsuccessful, it returns false. It looks +/// This function computes the integer multiple of Base that equals V. +/// If successful, it returns true and returns the multiple in +/// Multiple. If unsuccessful, it returns false. It looks /// through SExt instructions only if LookThroughSExt is true. 
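// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the
// Add/Sub rule above: addition feeds at most one carry into the sign-bit
// run, so the result keeps at least min(numSignBits(X), numSignBits(Y)) - 1
// sign bits. numSignBits is the same illustrative helper as in the previous
// sketch, repeated so this block stands alone.
#include <algorithm>
#include <cassert>
#include <cstdint>

unsigned numSignBits(int8_t V) {
  unsigned N = 1;
  for (int B = 6; B >= 0 && ((V >> B) & 1) == ((V >> 7) & 1); --B)
    ++N;
  return N;
}

int main() {
  for (int X = -128; X <= 127; ++X)
    for (int Y = -128; Y <= 127; ++Y) {
      unsigned Lo = std::min(numSignBits((int8_t)X), numSignBits((int8_t)Y));
      assert(numSignBits((int8_t)(X + Y)) >= (Lo > 1 ? Lo - 1 : 1));
    }
  return 0;
}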
bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, bool LookThroughSExt, unsigned Depth) { @@ -1417,8 +1990,8 @@ bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, return false; } -/// CannotBeNegativeZero - Return true if we can prove that the specified FP -/// value is never equal to -0.0. +/// Return true if we can prove that the specified FP value is never equal to +/// -0.0. /// /// NOTE: this function will need to be revisited when we support non-default /// rounding modes! @@ -1471,8 +2044,8 @@ bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) { return false; } -/// isBytewiseValue - If the specified value can be set by repeating the same -/// byte in memory, return the i8 value that it is represented with. This is +/// If the specified value can be set by repeating the same byte in memory, +/// return the i8 value that it is represented with. This is /// true for all i8 values obviously, but is also true for i32 0, i32 -1, /// i16 0xF0F0, double 0.0 etc. If the value can't be handled with a repeated /// byte store (e.g. i16 0x1234), return null. @@ -1620,7 +2193,7 @@ static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range, return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore); } -/// FindInsertedValue - Given an aggregrate and an sequence of indices, see if +/// Given an aggregate and a sequence of indices, see if /// the scalar value indexed is already around as a register, for example if it /// were inserted directly into the aggregate. /// @@ -1710,9 +2283,8 @@ Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range, return nullptr; } -/// GetPointerBaseWithConstantOffset - Analyze the specified pointer to see if -/// it can be expressed as a base pointer plus a constant offset. Return the -/// base and offset to the caller. +/// Analyze the specified pointer to see if it can be expressed as a base +/// pointer plus a constant offset. Return the base and offset to the caller. Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout *DL) { // Without DataLayout, conservatively assume 64-bit offsets, which is @@ -1749,9 +2321,9 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, } -/// getConstantStringInfo - This function computes the length of a -/// null-terminated C string pointed to by V. If successful, it returns true -/// and returns the string in Str. If unsuccessful, it returns false. +/// This function computes the length of a null-terminated C string pointed to +/// by V. If successful, it returns true and returns the string in Str. +/// If unsuccessful, it returns false. bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, uint64_t Offset, bool TrimAtNul) { assert(V); @@ -1835,16 +2407,16 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, // nodes. // TODO: See if we can integrate these two together. -/// GetStringLengthH - If we can compute the length of the string pointed to by +/// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. -static uint64_t GetStringLengthH(Value *V, SmallPtrSet<PHINode*, 32> &PHIs) { +static uint64_t GetStringLengthH(Value *V, SmallPtrSetImpl<PHINode*> &PHIs) { // Look through noop bitcast instructions. V = V->stripPointerCasts(); // If this is a PHI node, there are two cases: either we have already seen it // or we haven't.
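// Aside: a standalone sketch, outside the patch, of the cycle-detection
// idiom in the PHI handling below. The -/+ lines change
// "!PHIs.insert(PN)" to "!PHIs.insert(PN).second" because insert now
// reports freshness through a std::set-style pair; the bool is true only
// for a new element, so false means the PHI was already visited.
#include <cassert>
#include <set>

int main() {
  std::set<const void *> Visited;
  int PN = 0;                          // stand-in for a PHINode*
  assert(Visited.insert(&PN).second);  // first visit: recurse into it
  assert(!Visited.insert(&PN).second); // revisit: a cycle, bail out
  return 0;
}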
if (PHINode *PN = dyn_cast<PHINode>(V)) { - if (!PHIs.insert(PN)) + if (!PHIs.insert(PN).second) return ~0ULL; // already in the set. // If it was new, see if all the input strings are the same length. @@ -1884,7 +2456,7 @@ static uint64_t GetStringLengthH(Value *V, SmallPtrSet<PHINode*, 32> &PHIs) { return StrData.size()+1; } -/// GetStringLength - If we can compute the length of the string pointed to by +/// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. uint64_t llvm::GetStringLength(Value *V) { if (!V->getType()->isPointerTy()) return 0; @@ -1913,7 +2485,7 @@ llvm::GetUnderlyingObject(Value *V, const DataLayout *TD, unsigned MaxLookup) { } else { // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Acquire a DominatorTree and use it. + // TODO: Acquire a DominatorTree and AssumptionCache and use them. if (Value *Simplified = SimplifyInstruction(I, TD, nullptr)) { V = Simplified; continue; @@ -1938,7 +2510,7 @@ llvm::GetUnderlyingObjects(Value *V, Value *P = Worklist.pop_back_val(); P = GetUnderlyingObject(P, TD, MaxLookup); - if (!Visited.insert(P)) + if (!Visited.insert(P).second) continue; if (SelectInst *SI = dyn_cast<SelectInst>(P)) { @@ -1957,9 +2529,7 @@ llvm::GetUnderlyingObjects(Value *V, } while (!Worklist.empty()); } -/// onlyUsedByLifetimeMarkers - Return true if the only users of this pointer -/// are lifetime markers. -/// +/// Return true if the only users of this pointer are lifetime markers. bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { for (const User *U : V->users()) { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); @@ -2022,41 +2592,44 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, return LI->getPointerOperand()->isDereferenceablePointer(TD); } case Instruction::Call: { - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { - switch (II->getIntrinsicID()) { - // These synthetic intrinsics have no side-effects and just mark - // information about their operands. - // FIXME: There are other no-op synthetic instructions that potentially - // should be considered at least *safe* to speculate... - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: - return true; - - case Intrinsic::bswap: - case Intrinsic::ctlz: - case Intrinsic::ctpop: - case Intrinsic::cttz: - case Intrinsic::objectsize: - case Intrinsic::sadd_with_overflow: - case Intrinsic::smul_with_overflow: - case Intrinsic::ssub_with_overflow: - case Intrinsic::uadd_with_overflow: - case Intrinsic::umul_with_overflow: - case Intrinsic::usub_with_overflow: - return true; - // Sqrt should be OK, since the llvm sqrt intrinsic isn't defined to set - // errno like libm sqrt would. - case Intrinsic::sqrt: - case Intrinsic::fma: - case Intrinsic::fmuladd: - return true; - // TODO: some fp intrinsics are marked as having the same error handling - // as libm. They're safe to speculate when they won't error. - // TODO: are convert_{from,to}_fp16 safe? - // TODO: can we list target-specific intrinsics here? - default: break; - } - } + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + // These synthetic intrinsics have no side-effects and just mark + // information about their operands. + // FIXME: There are other no-op synthetic instructions that potentially + // should be considered at least *safe* to speculate... 
+ case Intrinsic::dbg_declare: + case Intrinsic::dbg_value: + return true; + + case Intrinsic::bswap: + case Intrinsic::ctlz: + case Intrinsic::ctpop: + case Intrinsic::cttz: + case Intrinsic::objectsize: + case Intrinsic::sadd_with_overflow: + case Intrinsic::smul_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::uadd_with_overflow: + case Intrinsic::umul_with_overflow: + case Intrinsic::usub_with_overflow: + return true; + // Sqrt should be OK, since the llvm sqrt intrinsic isn't defined to set + // errno like libm sqrt would. + case Intrinsic::sqrt: + case Intrinsic::fma: + case Intrinsic::fmuladd: + case Intrinsic::fabs: + case Intrinsic::minnum: + case Intrinsic::maxnum: + return true; + // TODO: some fp intrinsics are marked as having the same error handling + // as libm. They're safe to speculate when they won't error. + // TODO: are convert_{from,to}_fp16 safe? + // TODO: can we list target-specific intrinsics here? + default: break; + } + } return false; // The called function could have undefined behavior or // side-effects, even if marked readnone nounwind. } @@ -2079,8 +2652,7 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, } } -/// isKnownNonNull - Return true if we know that the specified value is never -/// null. +/// Return true if we know that the specified value is never null. bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { // Alloca never returns null, malloc might. if (isa<AllocaInst>(V)) return true; @@ -2093,6 +2665,10 @@ bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) return !GV->hasExternalWeakLinkage(); + // A Load tagged w/nonnull metadata is never null. + if (const LoadInst *LI = dyn_cast<LoadInst>(V)) + return LI->getMetadata(LLVMContext::MD_nonnull); + if (ImmutableCallSite CS = V) if (CS.isReturnNonNull()) return true; @@ -2103,3 +2679,82 @@ bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { return false; } + +OverflowResult llvm::computeOverflowForUnsignedMul(Value *LHS, Value *RHS, + const DataLayout *DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading zero bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + APInt LHSKnownZero(BitWidth, 0); + APInt LHSKnownOne(BitWidth, 0); + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(LHS, LHSKnownZero, LHSKnownOne, DL, /*Depth=*/0, AC, CxtI, + DT); + computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, DL, /*Depth=*/0, AC, CxtI, + DT); + // Note that underestimating the number of zero bits gives a more + // conservative answer. + unsigned ZeroBits = LHSKnownZero.countLeadingOnes() + + RHSKnownZero.countLeadingOnes(); + // First handle the easy case: if we have enough zero bits there's + // definitely no overflow. + if (ZeroBits >= BitWidth) + return OverflowResult::NeverOverflows; + + // Get the largest possible values for each operand. 
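// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the
// leading-zeros argument above, with fully known operands so that
// countLeadingOnes(KnownZero) is simply clz: if clz(L) + clz(R) >= 8, then
// L < 2^(8-clzL) and R < 2^(8-clzR), so L*R < 2^(16 - clzL - clzR) <= 2^8
// and the product fits in 8 bits. clz8 is an illustrative helper.
#include <cassert>
#include <cstdint>

unsigned clz8(uint8_t V) {
  unsigned N = 0;
  for (int B = 7; B >= 0 && !((V >> B) & 1); --B)
    ++N;
  return N;
}

int main() {
  for (uint32_t L = 0; L <= 0xFF; ++L)
    for (uint32_t R = 0; R <= 0xFF; ++R)
      if (clz8((uint8_t)L) + clz8((uint8_t)R) >= 8)
        assert(L * R <= 0xFF); // product fits: never overflows
  return 0;
}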
+ APInt LHSMax = ~LHSKnownZero; + APInt RHSMax = ~RHSKnownZero; + + // We know the multiply operation doesn't overflow if the maximum values for + // each operand will not overflow after we multiply them together. + bool MaxOverflow; + LHSMax.umul_ov(RHSMax, MaxOverflow); + if (!MaxOverflow) + return OverflowResult::NeverOverflows; + + // We know it always overflows if multiplying the smallest possible values for + // the operands also results in overflow. + bool MinOverflow; + LHSKnownOne.umul_ov(RHSKnownOne, MinOverflow); + if (MinOverflow) + return OverflowResult::AlwaysOverflows; + + return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForUnsignedAdd(Value *LHS, Value *RHS, + const DataLayout *DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + bool LHSKnownNonNegative, LHSKnownNegative; + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, DL, /*Depth=*/0, + AC, CxtI, DT); + if (LHSKnownNonNegative || LHSKnownNegative) { + bool RHSKnownNonNegative, RHSKnownNegative; + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, DL, /*Depth=*/0, + AC, CxtI, DT); + + if (LHSKnownNegative && RHSKnownNegative) { + // The sign bit is set in both operands: this MUST overflow. + return OverflowResult::AlwaysOverflows; + } + + if (LHSKnownNonNegative && RHSKnownNonNegative) { + // The sign bit is clear in both operands: this CANNOT overflow. + return OverflowResult::NeverOverflows; + } + } + + return OverflowResult::MayOverflow; +}
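// Aside: a standalone, exhaustive 8-bit check, outside the patch, of the two
// certain cases in computeOverflowForUnsignedAdd above: if both operands are
// known negative as signed values (top bit set), the unsigned sum is at
// least 2^8 and always overflows; if both are known non-negative (top bit
// clear), the sum is at most 2^8 - 2 and never overflows. Anything in
// between stays MayOverflow.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t L = 0; L <= 0xFF; ++L)
    for (uint32_t R = 0; R <= 0xFF; ++R) {
      bool Overflows = (L + R) > 0xFF;
      if ((L & 0x80) && (R & 0x80))
        assert(Overflows); // AlwaysOverflows
      if (!(L & 0x80) && !(R & 0x80))
        assert(!Overflows); // NeverOverflows
    }
  return 0;
}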