author     Dimitry Andric <dim@FreeBSD.org>    2016-07-23 20:41:05 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2016-07-23 20:41:05 +0000
commit     01095a5d43bbfde13731688ddcf6048ebb8b7721
tree       4def12e759965de927d963ac65840d663ef9d1ea  /lib/Analysis
parent     f0f4822ed4b66e3579e92a89f368f8fb860e218e
Diffstat (limited to 'lib/Analysis')
80 files changed, 12021 insertions, 6639 deletions
diff --git a/lib/Analysis/AliasAnalysis.cpp b/lib/Analysis/AliasAnalysis.cpp index 35f2e97622fa..f931b6fc6523 100644 --- a/lib/Analysis/AliasAnalysis.cpp +++ b/lib/Analysis/AliasAnalysis.cpp @@ -27,7 +27,8 @@ #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/CFLAliasAnalysis.h" +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "llvm/Analysis/CFLSteensAliasAnalysis.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ObjCARCAliasAnalysis.h" @@ -52,18 +53,11 @@ using namespace llvm; static cl::opt<bool> DisableBasicAA("disable-basicaa", cl::Hidden, cl::init(false)); -AAResults::AAResults(AAResults &&Arg) : AAs(std::move(Arg.AAs)) { +AAResults::AAResults(AAResults &&Arg) : TLI(Arg.TLI), AAs(std::move(Arg.AAs)) { for (auto &AA : AAs) AA->setAAResults(this); } -AAResults &AAResults::operator=(AAResults &&Arg) { - AAs = std::move(Arg.AAs); - for (auto &AA : AAs) - AA->setAAResults(this); - return *this; -} - AAResults::~AAResults() { // FIXME; It would be nice to at least clear out the pointers back to this // aggregation here, but we end up with non-nesting lifetimes in the legacy @@ -116,7 +110,10 @@ ModRefInfo AAResults::getModRefInfo(Instruction *I, ImmutableCallSite Call) { // We may have two calls if (auto CS = ImmutableCallSite(I)) { // Check if the two calls modify the same memory - return getModRefInfo(Call, CS); + return getModRefInfo(CS, Call); + } else if (I->isFenceLike()) { + // If this is a fence, just return MRI_ModRef. + return MRI_ModRef; } else { // Otherwise, check if the call modifies or references the // location this memory access defines. The best we can say @@ -141,6 +138,46 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, return Result; } + // Try to refine the mod-ref info further using other API entry points to the + // aggregate set of AA results. + auto MRB = getModRefBehavior(CS); + if (MRB == FMRB_DoesNotAccessMemory) + return MRI_NoModRef; + + if (onlyReadsMemory(MRB)) + Result = ModRefInfo(Result & MRI_Ref); + else if (doesNotReadMemory(MRB)) + Result = ModRefInfo(Result & MRI_Mod); + + if (onlyAccessesArgPointees(MRB)) { + bool DoesAlias = false; + ModRefInfo AllArgsMask = MRI_NoModRef; + if (doesAccessArgPointees(MRB)) { + for (auto AI = CS.arg_begin(), AE = CS.arg_end(); AI != AE; ++AI) { + const Value *Arg = *AI; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned ArgIdx = std::distance(CS.arg_begin(), AI); + MemoryLocation ArgLoc = MemoryLocation::getForArgument(CS, ArgIdx, TLI); + AliasResult ArgAlias = alias(ArgLoc, Loc); + if (ArgAlias != NoAlias) { + ModRefInfo ArgMask = getArgModRefInfo(CS, ArgIdx); + DoesAlias = true; + AllArgsMask = ModRefInfo(AllArgsMask | ArgMask); + } + } + } + if (!DoesAlias) + return MRI_NoModRef; + Result = ModRefInfo(Result & AllArgsMask); + } + + // If Loc is a constant memory location, the call definitely could not + // modify the memory location. + if ((Result & MRI_Mod) && + pointsToConstantMemory(Loc, /*OrLocal*/ false)) + Result = ModRefInfo(Result & ~MRI_Mod); + return Result; } @@ -156,6 +193,90 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, return Result; } + // Try to refine the mod-ref info further using other API entry points to the + // aggregate set of AA results. + + // If CS1 or CS2 are readnone, they don't interact. 
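// [Editorial aside -- illustrative sketch, not part of this commit. The
// refinement added above treats ModRefInfo as a two-bit lattice and narrows
// a conservative answer by intersecting bit masks; a self-contained model
// with simplified names (the real enumerators live in
// llvm/Analysis/AliasAnalysis.h):]
#include <cassert>
enum SketchModRef {
  SMR_NoModRef = 0,
  SMR_Ref = 1,
  SMR_Mod = 2,
  SMR_ModRef = SMR_Ref | SMR_Mod
};
// A callee that only reads memory cannot contribute a Mod effect, so a
// conservative SMR_ModRef tightens to SMR_Ref; a write-only callee tightens
// to SMR_Mod. Each independently known fact is applied as one AND.
inline SketchModRef refine(SketchModRef Result, bool OnlyReads, bool OnlyWrites) {
  if (OnlyReads)
    Result = SketchModRef(Result & SMR_Ref);
  if (OnlyWrites)
    Result = SketchModRef(Result & SMR_Mod);
  return Result;
}
static void sketchModRefCheck() {
  assert(refine(SMR_ModRef, /*OnlyReads=*/true, /*OnlyWrites=*/false) == SMR_Ref);
}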
+ auto CS1B = getModRefBehavior(CS1); + if (CS1B == FMRB_DoesNotAccessMemory) + return MRI_NoModRef; + + auto CS2B = getModRefBehavior(CS2); + if (CS2B == FMRB_DoesNotAccessMemory) + return MRI_NoModRef; + + // If they both only read from memory, there is no dependence. + if (onlyReadsMemory(CS1B) && onlyReadsMemory(CS2B)) + return MRI_NoModRef; + + // If CS1 only reads memory, the only dependence on CS2 can be + // from CS1 reading memory written by CS2. + if (onlyReadsMemory(CS1B)) + Result = ModRefInfo(Result & MRI_Ref); + else if (doesNotReadMemory(CS1B)) + Result = ModRefInfo(Result & MRI_Mod); + + // If CS2 only access memory through arguments, accumulate the mod/ref + // information from CS1's references to the memory referenced by + // CS2's arguments. + if (onlyAccessesArgPointees(CS2B)) { + ModRefInfo R = MRI_NoModRef; + if (doesAccessArgPointees(CS2B)) { + for (auto I = CS2.arg_begin(), E = CS2.arg_end(); I != E; ++I) { + const Value *Arg = *I; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned CS2ArgIdx = std::distance(CS2.arg_begin(), I); + auto CS2ArgLoc = MemoryLocation::getForArgument(CS2, CS2ArgIdx, TLI); + + // ArgMask indicates what CS2 might do to CS2ArgLoc, and the dependence + // of CS1 on that location is the inverse. + ModRefInfo ArgMask = getArgModRefInfo(CS2, CS2ArgIdx); + if (ArgMask == MRI_Mod) + ArgMask = MRI_ModRef; + else if (ArgMask == MRI_Ref) + ArgMask = MRI_Mod; + + ArgMask = ModRefInfo(ArgMask & getModRefInfo(CS1, CS2ArgLoc)); + + R = ModRefInfo((R | ArgMask) & Result); + if (R == Result) + break; + } + } + return R; + } + + // If CS1 only accesses memory through arguments, check if CS2 references + // any of the memory referenced by CS1's arguments. If not, return NoModRef. + if (onlyAccessesArgPointees(CS1B)) { + ModRefInfo R = MRI_NoModRef; + if (doesAccessArgPointees(CS1B)) { + for (auto I = CS1.arg_begin(), E = CS1.arg_end(); I != E; ++I) { + const Value *Arg = *I; + if (!Arg->getType()->isPointerTy()) + continue; + unsigned CS1ArgIdx = std::distance(CS1.arg_begin(), I); + auto CS1ArgLoc = MemoryLocation::getForArgument(CS1, CS1ArgIdx, TLI); + + // ArgMask indicates what CS1 might do to CS1ArgLoc; if CS1 might Mod + // CS1ArgLoc, then we care about either a Mod or a Ref by CS2. If CS1 + // might Ref, then we care only about a Mod by CS2. + ModRefInfo ArgMask = getArgModRefInfo(CS1, CS1ArgIdx); + ModRefInfo ArgR = getModRefInfo(CS2, CS1ArgLoc); + if (((ArgMask & MRI_Mod) != MRI_NoModRef && + (ArgR & MRI_ModRef) != MRI_NoModRef) || + ((ArgMask & MRI_Ref) != MRI_NoModRef && + (ArgR & MRI_Mod) != MRI_NoModRef)) + R = ModRefInfo((R | ArgMask) & Result); + + if (R == Result) + break; + } + } + return R; + } + return Result; } @@ -276,7 +397,7 @@ ModRefInfo AAResults::getModRefInfo(const CatchReturnInst *CatchRet, ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX, const MemoryLocation &Loc) { // Acquire/Release cmpxchg has properties that matter for arbitrary addresses. - if (CX->getSuccessOrdering() > Monotonic) + if (isStrongerThanMonotonic(CX->getSuccessOrdering())) return MRI_ModRef; // If the cmpxchg address does not alias the location, it does not access it. @@ -289,7 +410,7 @@ ModRefInfo AAResults::getModRefInfo(const AtomicCmpXchgInst *CX, ModRefInfo AAResults::getModRefInfo(const AtomicRMWInst *RMW, const MemoryLocation &Loc) { // Acquire/Release atomicrmw has properties that matter for arbitrary addresses. 
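// [Editorial aside -- illustrative sketch, not part of this commit. The
// change below swaps a raw enum comparison (Ordering > Monotonic) for the
// semantic predicate isStrongerThanMonotonic(); the question is about
// ordering strength, not enumerator value. A simplified local model of the
// predicate (the real one ships with LLVM's AtomicOrdering support header):]
enum class SketchOrdering {
  NotAtomic, Unordered, Monotonic,
  Acquire, Release, AcquireRelease, SequentiallyConsistent
};
// Orderings at least as strong as Acquire or Release constrain surrounding
// memory operations, so such a cmpxchg/atomicrmw can matter for arbitrary
// addresses, not just the address it touches.
inline bool sketchIsStrongerThanMonotonic(SketchOrdering AO) {
  return AO == SketchOrdering::Acquire || AO == SketchOrdering::Release ||
         AO == SketchOrdering::AcquireRelease ||
         AO == SketchOrdering::SequentiallyConsistent;
}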
- if (RMW->getOrdering() > Monotonic) + if (isStrongerThanMonotonic(RMW->getOrdering())) return MRI_ModRef; // If the atomicrmw address does not alias the location, it does not access it. @@ -332,7 +453,7 @@ ModRefInfo AAResults::callCapturesBefore(const Instruction *I, unsigned ArgNo = 0; ModRefInfo R = MRI_NoModRef; - for (ImmutableCallSite::arg_iterator CI = CS.arg_begin(), CE = CS.arg_end(); + for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); CI != CE; ++CI, ++ArgNo) { // Only look at the no-capture or byval pointer arguments. If this // pointer were passed to arguments that were neither of these, then it @@ -390,6 +511,9 @@ bool AAResults::canInstructionRangeModRef(const Instruction &I1, // Provide a definition for the root virtual destructor. AAResults::Concept::~Concept() {} +// Provide a definition for the static object used to identify passes. +char AAManager::PassID; + namespace { /// A wrapper pass for external alias analyses. This just squirrels away the /// callback used to run any analyses and register their results. @@ -432,7 +556,8 @@ char AAResultsWrapperPass::ID = 0; INITIALIZE_PASS_BEGIN(AAResultsWrapperPass, "aa", "Function Alias Analysis Results", false, true) INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) -INITIALIZE_PASS_DEPENDENCY(CFLAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLAndersAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(CFLSteensAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(ExternalAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass) @@ -461,7 +586,8 @@ bool AAResultsWrapperPass::runOnFunction(Function &F) { // unregistering themselves with them. We need to carefully tear down the // previous object first, in this case replacing it with an empty one, before // registering new results. - AAR.reset(new AAResults()); + AAR.reset( + new AAResults(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI())); // BasicAA is always available for function analyses. Also, we add it first // so that it can trump TBAA results when it proves MustAlias. 
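// [Editorial aside -- illustrative sketch, not part of this commit. The
// aggregation pattern behind AAResults: each registered analysis gets a
// chance to answer, and the first definitive (non-MayAlias) answer wins, so
// registration order matters -- BasicAA goes first so a proven MustAlias can
// trump TBAA. A hypothetical, heavily simplified chaining:]
#include <functional>
#include <vector>
enum SketchAlias { SA_MayAlias = 0, SA_NoAlias, SA_MustAlias };
struct SketchAggregator {
  // Queries are tried in registration order.
  std::vector<std::function<SketchAlias(const void *, const void *)>> AAs;
  SketchAlias alias(const void *A, const void *B) {
    for (auto &AA : AAs)
      if (SketchAlias R = AA(A, B)) // nonzero == definitive answer
        return R;
    return SA_MayAlias;             // nobody knew better
  }
};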
@@ -482,7 +608,9 @@ bool AAResultsWrapperPass::runOnFunction(Function &F) { AAR->addAAResult(WrapperPass->getResult()); if (auto *WrapperPass = getAnalysisIfAvailable<SCEVAAWrapperPass>()) AAR->addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = getAnalysisIfAvailable<CFLAAWrapperPass>()) + if (auto *WrapperPass = getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) + AAR->addAAResult(WrapperPass->getResult()); + if (auto *WrapperPass = getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) AAR->addAAResult(WrapperPass->getResult()); // If available, run an external AA providing callback over the results as @@ -498,6 +626,7 @@ bool AAResultsWrapperPass::runOnFunction(Function &F) { void AAResultsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<BasicAAWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); // We also need to mark all the alias analysis passes we will potentially // probe in runOnFunction as used here to ensure the legacy pass manager @@ -508,12 +637,13 @@ void AAResultsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); AU.addUsedIfAvailable<SCEVAAWrapperPass>(); - AU.addUsedIfAvailable<CFLAAWrapperPass>(); + AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); + AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); } AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, BasicAAResult &BAR) { - AAResults AAR; + AAResults AAR(P.getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()); // Add in our explicitly constructed BasicAA results. if (!DisableBasicAA) @@ -530,38 +660,26 @@ AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, AAR.addAAResult(WrapperPass->getResult()); if (auto *WrapperPass = P.getAnalysisIfAvailable<GlobalsAAWrapperPass>()) AAR.addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = P.getAnalysisIfAvailable<SCEVAAWrapperPass>()) + if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLAndersAAWrapperPass>()) AAR.addAAResult(WrapperPass->getResult()); - if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLAAWrapperPass>()) + if (auto *WrapperPass = P.getAnalysisIfAvailable<CFLSteensAAWrapperPass>()) AAR.addAAResult(WrapperPass->getResult()); return AAR; } -/// isNoAliasCall - Return true if this pointer is returned by a noalias -/// function. bool llvm::isNoAliasCall(const Value *V) { if (auto CS = ImmutableCallSite(V)) return CS.paramHasAttr(0, Attribute::NoAlias); return false; } -/// isNoAliasArgument - Return true if this is an argument with the noalias -/// attribute. -bool llvm::isNoAliasArgument(const Value *V) -{ +bool llvm::isNoAliasArgument(const Value *V) { if (const Argument *A = dyn_cast<Argument>(V)) return A->hasNoAliasAttr(); return false; } -/// isIdentifiedObject - Return true if this pointer refers to a distinct and -/// identifiable object. This returns true for: -/// Global Variables and Functions (but not Global Aliases) -/// Allocas and Mallocs -/// ByVal and NoAlias Arguments -/// NoAlias returns -/// bool llvm::isIdentifiedObject(const Value *V) { if (isa<AllocaInst>(V)) return true; @@ -574,12 +692,19 @@ bool llvm::isIdentifiedObject(const Value *V) { return false; } -/// isIdentifiedFunctionLocal - Return true if V is umabigously identified -/// at the function-level. Different IdentifiedFunctionLocals can't alias. 
-/// Further, an IdentifiedFunctionLocal can not alias with any function -/// arguments other than itself, which is not necessarily true for -/// IdentifiedObjects. -bool llvm::isIdentifiedFunctionLocal(const Value *V) -{ +bool llvm::isIdentifiedFunctionLocal(const Value *V) { return isa<AllocaInst>(V) || isNoAliasCall(V) || isNoAliasArgument(V); } + +void llvm::getAAResultsAnalysisUsage(AnalysisUsage &AU) { + // This function needs to be in sync with llvm::createLegacyPMAAResults -- if + // more alias analyses are added to llvm::createLegacyPMAAResults, they need + // to be added here also. + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addUsedIfAvailable<ScopedNoAliasAAWrapperPass>(); + AU.addUsedIfAvailable<TypeBasedAAWrapperPass>(); + AU.addUsedIfAvailable<objcarc::ObjCARCAAWrapperPass>(); + AU.addUsedIfAvailable<GlobalsAAWrapperPass>(); + AU.addUsedIfAvailable<CFLAndersAAWrapperPass>(); + AU.addUsedIfAvailable<CFLSteensAAWrapperPass>(); +} diff --git a/lib/Analysis/AliasAnalysisEvaluator.cpp b/lib/Analysis/AliasAnalysisEvaluator.cpp index 12917b650e5e..baf8f3f881db 100644 --- a/lib/Analysis/AliasAnalysisEvaluator.cpp +++ b/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -6,18 +6,8 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// -// -// This file implements a simple N^2 alias analysis accuracy evaluator. -// Basically, for each function in the program, it simply queries to see how the -// alias analysis implementation answers alias queries between each pair of -// pointers in the function. -// -// This is inspired and adapted from code by: Naveen Neelakantam, Francesco -// Spadini, and Wojciech Stryjewski. -// -//===----------------------------------------------------------------------===// -#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/AliasAnalysisEvaluator.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/IR/Constants.h" @@ -47,51 +37,9 @@ static cl::opt<bool> PrintModRef("print-modref", cl::ReallyHidden); static cl::opt<bool> EvalAAMD("evaluate-aa-metadata", cl::ReallyHidden); -namespace { - class AAEval : public FunctionPass { - unsigned NoAliasCount, MayAliasCount, PartialAliasCount, MustAliasCount; - unsigned NoModRefCount, ModCount, RefCount, ModRefCount; - - public: - static char ID; // Pass identification, replacement for typeid - AAEval() : FunctionPass(ID) { - initializeAAEvalPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<AAResultsWrapperPass>(); - AU.setPreservesAll(); - } - - bool doInitialization(Module &M) override { - NoAliasCount = MayAliasCount = PartialAliasCount = MustAliasCount = 0; - NoModRefCount = ModCount = RefCount = ModRefCount = 0; - - if (PrintAll) { - PrintNoAlias = PrintMayAlias = true; - PrintPartialAlias = PrintMustAlias = true; - PrintNoModRef = PrintMod = PrintRef = PrintModRef = true; - } - return false; - } - - bool runOnFunction(Function &F) override; - bool doFinalization(Module &M) override; - }; -} - -char AAEval::ID = 0; -INITIALIZE_PASS_BEGIN(AAEval, "aa-eval", - "Exhaustive Alias Analysis Precision Evaluator", false, true) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(AAEval, "aa-eval", - "Exhaustive Alias Analysis Precision Evaluator", false, true) - -FunctionPass *llvm::createAAEvalPass() { return new AAEval(); } - static void PrintResults(const char *Msg, bool P, const Value *V1, const Value *V2, const 
Module *M) { - if (P) { + if (PrintAll || P) { std::string o1, o2; { raw_string_ostream os1(o1), os2(o2); @@ -110,7 +58,7 @@ static void PrintResults(const char *Msg, bool P, const Value *V1, static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I, Value *Ptr, Module *M) { - if (P) { + if (PrintAll || P) { errs() << " " << Msg << ": Ptr: "; Ptr->printAsOperand(errs(), true, M); errs() << "\t<->" << *I << '\n'; @@ -120,7 +68,7 @@ PrintModRefResults(const char *Msg, bool P, Instruction *I, Value *Ptr, static inline void PrintModRefResults(const char *Msg, bool P, CallSite CSA, CallSite CSB, Module *M) { - if (P) { + if (PrintAll || P) { errs() << " " << Msg << ": " << *CSA.getInstruction() << " <-> " << *CSB.getInstruction() << '\n'; } @@ -129,7 +77,7 @@ PrintModRefResults(const char *Msg, bool P, CallSite CSA, CallSite CSB, static inline void PrintLoadStoreResults(const char *Msg, bool P, const Value *V1, const Value *V2, const Module *M) { - if (P) { + if (PrintAll || P) { errs() << " " << Msg << ": " << *V1 << " <-> " << *V2 << '\n'; } @@ -140,9 +88,15 @@ static inline bool isInterestingPointer(Value *V) { && !isa<ConstantPointerNull>(V); } -bool AAEval::runOnFunction(Function &F) { +PreservedAnalyses AAEvaluator::run(Function &F, AnalysisManager<Function> &AM) { + runInternal(F, AM.getResult<AAManager>(F)); + return PreservedAnalyses::all(); +} + +void AAEvaluator::runInternal(Function &F, AAResults &AA) { const DataLayout &DL = F.getParent()->getDataLayout(); - AliasAnalysis &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + + ++FunctionCount; SetVector<Value *> Pointers; SmallSetVector<CallSite, 16> CallSites; @@ -180,8 +134,8 @@ bool AAEval::runOnFunction(Function &F) { } } - if (PrintNoAlias || PrintMayAlias || PrintPartialAlias || PrintMustAlias || - PrintNoModRef || PrintMod || PrintRef || PrintModRef) + if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias || + PrintMustAlias || PrintNoModRef || PrintMod || PrintRef || PrintModRef) errs() << "Function: " << F.getName() << ": " << Pointers.size() << " pointers, " << CallSites.size() << " call sites\n"; @@ -221,29 +175,27 @@ bool AAEval::runOnFunction(Function &F) { if (EvalAAMD) { // iterate over all pairs of load, store - for (SetVector<Value *>::iterator I1 = Loads.begin(), E = Loads.end(); - I1 != E; ++I1) { - for (SetVector<Value *>::iterator I2 = Stores.begin(), E2 = Stores.end(); - I2 != E2; ++I2) { - switch (AA.alias(MemoryLocation::get(cast<LoadInst>(*I1)), - MemoryLocation::get(cast<StoreInst>(*I2)))) { + for (Value *Load : Loads) { + for (Value *Store : Stores) { + switch (AA.alias(MemoryLocation::get(cast<LoadInst>(Load)), + MemoryLocation::get(cast<StoreInst>(Store)))) { case NoAlias: - PrintLoadStoreResults("NoAlias", PrintNoAlias, *I1, *I2, + PrintLoadStoreResults("NoAlias", PrintNoAlias, Load, Store, F.getParent()); ++NoAliasCount; break; case MayAlias: - PrintLoadStoreResults("MayAlias", PrintMayAlias, *I1, *I2, + PrintLoadStoreResults("MayAlias", PrintMayAlias, Load, Store, F.getParent()); ++MayAliasCount; break; case PartialAlias: - PrintLoadStoreResults("PartialAlias", PrintPartialAlias, *I1, *I2, + PrintLoadStoreResults("PartialAlias", PrintPartialAlias, Load, Store, F.getParent()); ++PartialAliasCount; break; case MustAlias: - PrintLoadStoreResults("MustAlias", PrintMustAlias, *I1, *I2, + PrintLoadStoreResults("MustAlias", PrintMustAlias, Load, Store, F.getParent()); ++MustAliasCount; break; @@ -283,30 +235,31 @@ bool AAEval::runOnFunction(Function &F) { } // Mod/ref 
alias analysis: compare all pairs of calls and values - for (auto C = CallSites.begin(), Ce = CallSites.end(); C != Ce; ++C) { - Instruction *I = C->getInstruction(); + for (CallSite C : CallSites) { + Instruction *I = C.getInstruction(); - for (SetVector<Value *>::iterator V = Pointers.begin(), Ve = Pointers.end(); - V != Ve; ++V) { + for (auto Pointer : Pointers) { uint64_t Size = MemoryLocation::UnknownSize; - Type *ElTy = cast<PointerType>((*V)->getType())->getElementType(); + Type *ElTy = cast<PointerType>(Pointer->getType())->getElementType(); if (ElTy->isSized()) Size = DL.getTypeStoreSize(ElTy); - switch (AA.getModRefInfo(*C, *V, Size)) { + switch (AA.getModRefInfo(C, Pointer, Size)) { case MRI_NoModRef: - PrintModRefResults("NoModRef", PrintNoModRef, I, *V, F.getParent()); + PrintModRefResults("NoModRef", PrintNoModRef, I, Pointer, + F.getParent()); ++NoModRefCount; break; case MRI_Mod: - PrintModRefResults("Just Mod", PrintMod, I, *V, F.getParent()); + PrintModRefResults("Just Mod", PrintMod, I, Pointer, F.getParent()); ++ModCount; break; case MRI_Ref: - PrintModRefResults("Just Ref", PrintRef, I, *V, F.getParent()); + PrintModRefResults("Just Ref", PrintRef, I, Pointer, F.getParent()); ++RefCount; break; case MRI_ModRef: - PrintModRefResults("Both ModRef", PrintModRef, I, *V, F.getParent()); + PrintModRefResults("Both ModRef", PrintModRef, I, Pointer, + F.getParent()); ++ModRefCount; break; } @@ -338,17 +291,18 @@ bool AAEval::runOnFunction(Function &F) { } } } - - return false; } -static void PrintPercent(unsigned Num, unsigned Sum) { - errs() << "(" << Num*100ULL/Sum << "." - << ((Num*1000ULL/Sum) % 10) << "%)\n"; +static void PrintPercent(int64_t Num, int64_t Sum) { + errs() << "(" << Num * 100LL / Sum << "." << ((Num * 1000LL / Sum) % 10) + << "%)\n"; } -bool AAEval::doFinalization(Module &M) { - unsigned AliasSum = +AAEvaluator::~AAEvaluator() { + if (FunctionCount == 0) + return; + + int64_t AliasSum = NoAliasCount + MayAliasCount + PartialAliasCount + MustAliasCount; errs() << "===== Alias Analysis Evaluator Report =====\n"; if (AliasSum == 0) { @@ -371,7 +325,7 @@ bool AAEval::doFinalization(Module &M) { } // Display the summary for mod/ref analysis - unsigned ModRefSum = NoModRefCount + ModCount + RefCount + ModRefCount; + int64_t ModRefSum = NoModRefCount + ModCount + RefCount + ModRefCount; if (ModRefSum == 0) { errs() << " Alias Analysis Mod/Ref Evaluator Summary: no " "mod/ref!\n"; @@ -390,6 +344,46 @@ bool AAEval::doFinalization(Module &M) { << ModCount * 100 / ModRefSum << "%/" << RefCount * 100 / ModRefSum << "%/" << ModRefCount * 100 / ModRefSum << "%\n"; } +} - return false; +namespace llvm { +class AAEvalLegacyPass : public FunctionPass { + std::unique_ptr<AAEvaluator> P; + +public: + static char ID; // Pass identification, replacement for typeid + AAEvalLegacyPass() : FunctionPass(ID) { + initializeAAEvalLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AAResultsWrapperPass>(); + AU.setPreservesAll(); + } + + bool doInitialization(Module &M) override { + P.reset(new AAEvaluator()); + return false; + } + + bool runOnFunction(Function &F) override { + P->runInternal(F, getAnalysis<AAResultsWrapperPass>().getAAResults()); + return false; + } + bool doFinalization(Module &M) override { + P.reset(); + return false; + } +}; } + +char AAEvalLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(AAEvalLegacyPass, "aa-eval", + "Exhaustive Alias Analysis Precision Evaluator", false, + true) 
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(AAEvalLegacyPass, "aa-eval", + "Exhaustive Alias Analysis Precision Evaluator", false, + true) + +FunctionPass *llvm::createAAEvalPass() { return new AAEvalLegacyPass(); } diff --git a/lib/Analysis/AliasAnalysisSummary.cpp b/lib/Analysis/AliasAnalysisSummary.cpp new file mode 100644 index 000000000000..f3f13df283db --- /dev/null +++ b/lib/Analysis/AliasAnalysisSummary.cpp @@ -0,0 +1,105 @@ +#include "AliasAnalysisSummary.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Compiler.h" + +namespace llvm { +namespace cflaa { + +namespace { +LLVM_CONSTEXPR unsigned AttrEscapedIndex = 0; +LLVM_CONSTEXPR unsigned AttrUnknownIndex = 1; +LLVM_CONSTEXPR unsigned AttrGlobalIndex = 2; +LLVM_CONSTEXPR unsigned AttrCallerIndex = 3; +LLVM_CONSTEXPR unsigned AttrFirstArgIndex = 4; +LLVM_CONSTEXPR unsigned AttrLastArgIndex = NumAliasAttrs; +LLVM_CONSTEXPR unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; + +// NOTE: These aren't AliasAttrs because bitsets don't have a constexpr +// ctor for some versions of MSVC that we support. We could maybe refactor, +// but... +using AliasAttr = unsigned; +LLVM_CONSTEXPR AliasAttr AttrNone = 0; +LLVM_CONSTEXPR AliasAttr AttrEscaped = 1 << AttrEscapedIndex; +LLVM_CONSTEXPR AliasAttr AttrUnknown = 1 << AttrUnknownIndex; +LLVM_CONSTEXPR AliasAttr AttrGlobal = 1 << AttrGlobalIndex; +LLVM_CONSTEXPR AliasAttr AttrCaller = 1 << AttrCallerIndex; +LLVM_CONSTEXPR AliasAttr ExternalAttrMask = + AttrEscaped | AttrUnknown | AttrGlobal; +} + +AliasAttrs getAttrNone() { return AttrNone; } + +AliasAttrs getAttrUnknown() { return AttrUnknown; } +bool hasUnknownAttr(AliasAttrs Attr) { return Attr.test(AttrUnknownIndex); } + +AliasAttrs getAttrCaller() { return AttrCaller; } +bool hasCallerAttr(AliasAttrs Attr) { return Attr.test(AttrCaller); } +bool hasUnknownOrCallerAttr(AliasAttrs Attr) { + return Attr.test(AttrUnknownIndex) || Attr.test(AttrCallerIndex); +} + +AliasAttrs getAttrEscaped() { return AttrEscaped; } +bool hasEscapedAttr(AliasAttrs Attr) { return Attr.test(AttrEscapedIndex); } + +static AliasAttr argNumberToAttr(unsigned ArgNum) { + if (ArgNum >= AttrMaxNumArgs) + return AttrUnknown; + // N.B. MSVC complains if we use `1U` here, since AliasAttr' ctor takes + // an unsigned long long. + return AliasAttr(1ULL << (ArgNum + AttrFirstArgIndex)); +} + +AliasAttrs getGlobalOrArgAttrFromValue(const Value &Val) { + if (isa<GlobalValue>(Val)) + return AttrGlobal; + + if (auto *Arg = dyn_cast<Argument>(&Val)) + // Only pointer arguments should have the argument attribute, + // because things can't escape through scalars without us seeing a + // cast, and thus, interaction with them doesn't matter. + if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy()) + return argNumberToAttr(Arg->getArgNo()); + return AttrNone; +} + +bool isGlobalOrArgAttr(AliasAttrs Attr) { + return Attr.reset(AttrEscapedIndex) + .reset(AttrUnknownIndex) + .reset(AttrCallerIndex) + .any(); +} + +AliasAttrs getExternallyVisibleAttrs(AliasAttrs Attr) { + return Attr & AliasAttrs(ExternalAttrMask); +} + +Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue IValue, + CallSite CS) { + auto Index = IValue.Index; + auto Value = (Index == 0) ? 
CS.getInstruction() : CS.getArgument(Index - 1); + if (Value->getType()->isPointerTy()) + return InstantiatedValue{Value, IValue.DerefLevel}; + return None; +} + +Optional<InstantiatedRelation> +instantiateExternalRelation(ExternalRelation ERelation, CallSite CS) { + auto From = instantiateInterfaceValue(ERelation.From, CS); + if (!From) + return None; + auto To = instantiateInterfaceValue(ERelation.To, CS); + if (!To) + return None; + return InstantiatedRelation{*From, *To}; +} + +Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute EAttr, + CallSite CS) { + auto Value = instantiateInterfaceValue(EAttr.IValue, CS); + if (!Value) + return None; + return InstantiatedAttr{*Value, EAttr.Attr}; +} +} +} diff --git a/lib/Analysis/AliasAnalysisSummary.h b/lib/Analysis/AliasAnalysisSummary.h new file mode 100644 index 000000000000..43c0d4cb14f9 --- /dev/null +++ b/lib/Analysis/AliasAnalysisSummary.h @@ -0,0 +1,211 @@ +//=====- CFLSummary.h - Abstract stratified sets implementation. --------=====// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// \file +/// This file defines various utility types and functions useful to +/// summary-based alias analysis. +/// +/// Summary-based analysis, also known as bottom-up analysis, is a style of +/// interprocedrual static analysis that tries to analyze the callees before the +/// callers get analyzed. The key idea of summary-based analysis is to first +/// process each function indepedently, outline its behavior in a condensed +/// summary, and then instantiate the summary at the callsite when the said +/// function is called elsewhere. This is often in contrast to another style +/// called top-down analysis, in which callers are always analyzed first before +/// the callees. +/// +/// In a summary-based analysis, functions must be examined independently and +/// out-of-context. We have no information on the state of the memory, the +/// arguments, the global values, and anything else external to the function. To +/// carry out the analysis conservative assumptions have to be made about those +/// external states. In exchange for the potential loss of precision, the +/// summary we obtain this way is highly reusable, which makes the analysis +/// easier to scale to large programs even if carried out context-sensitively. +/// +/// Currently, all CFL-based alias analyses adopt the summary-based approach +/// and therefore heavily rely on this header. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H +#define LLVM_ANALYSIS_ALIASANALYSISSUMMARY_H + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/CallSite.h" +#include <bitset> + +namespace llvm { +namespace cflaa { + +//===----------------------------------------------------------------------===// +// AliasAttr related stuffs +//===----------------------------------------------------------------------===// + +/// The number of attributes that AliasAttr should contain. Attributes are +/// described below, and 32 was an arbitrary choice because it fits nicely in 32 +/// bits (because we use a bitset for AliasAttr). 
+static const unsigned NumAliasAttrs = 32; + +/// These are attributes that an alias analysis can use to mark certain special +/// properties of a given pointer. Refer to the related functions below to see +/// what kinds of attributes are currently defined. +typedef std::bitset<NumAliasAttrs> AliasAttrs; + +/// Attr represent whether the said pointer comes from an unknown source +/// (such as opaque memory or an integer cast). +AliasAttrs getAttrNone(); + +/// AttrUnknown represent whether the said pointer comes from a source not known +/// to alias analyses (such as opaque memory or an integer cast). +AliasAttrs getAttrUnknown(); +bool hasUnknownAttr(AliasAttrs); + +/// AttrCaller represent whether the said pointer comes from a source not known +/// to the current function but known to the caller. Values pointed to by the +/// arguments of the current function have this attribute set +AliasAttrs getAttrCaller(); +bool hasCallerAttr(AliasAttrs); +bool hasUnknownOrCallerAttr(AliasAttrs); + +/// AttrEscaped represent whether the said pointer comes from a known source but +/// escapes to the unknown world (e.g. casted to an integer, or passed as an +/// argument to opaque function). Unlike non-escaped pointers, escaped ones may +/// alias pointers coming from unknown sources. +AliasAttrs getAttrEscaped(); +bool hasEscapedAttr(AliasAttrs); + +/// AttrGlobal represent whether the said pointer is a global value. +/// AttrArg represent whether the said pointer is an argument, and if so, what +/// index the argument has. +AliasAttrs getGlobalOrArgAttrFromValue(const Value &); +bool isGlobalOrArgAttr(AliasAttrs); + +/// Given an AliasAttrs, return a new AliasAttrs that only contains attributes +/// meaningful to the caller. This function is primarily used for +/// interprocedural analysis +/// Currently, externally visible AliasAttrs include AttrUnknown, AttrGlobal, +/// and AttrEscaped +AliasAttrs getExternallyVisibleAttrs(AliasAttrs); + +//===----------------------------------------------------------------------===// +// Function summary related stuffs +//===----------------------------------------------------------------------===// + +/// The maximum number of arguments we can put into a summary. +LLVM_CONSTEXPR static unsigned MaxSupportedArgsInSummary = 50; + +/// We use InterfaceValue to describe parameters/return value, as well as +/// potential memory locations that are pointed to by parameters/return value, +/// of a function. +/// Index is an integer which represents a single parameter or a return value. +/// When the index is 0, it refers to the return value. Non-zero index i refers +/// to the i-th parameter. +/// DerefLevel indicates the number of dereferences one must perform on the +/// parameter/return value to get this InterfaceValue. +struct InterfaceValue { + unsigned Index; + unsigned DerefLevel; +}; + +inline bool operator==(InterfaceValue LHS, InterfaceValue RHS) { + return LHS.Index == RHS.Index && LHS.DerefLevel == RHS.DerefLevel; +} +inline bool operator!=(InterfaceValue LHS, InterfaceValue RHS) { + return !(LHS == RHS); +} + +/// We use ExternalRelation to describe an externally visible aliasing relations +/// between parameters/return value of a function. +struct ExternalRelation { + InterfaceValue From, To; +}; + +/// We use ExternalAttribute to describe an externally visible AliasAttrs +/// for parameters/return value. 
+struct ExternalAttribute { + InterfaceValue IValue; + AliasAttrs Attr; +}; + +/// AliasSummary is just a collection of ExternalRelation and ExternalAttribute +struct AliasSummary { + // RetParamRelations is a collection of ExternalRelations. + SmallVector<ExternalRelation, 8> RetParamRelations; + + // RetParamAttributes is a collection of ExternalAttributes. + SmallVector<ExternalAttribute, 8> RetParamAttributes; +}; + +/// This is the result of instantiating InterfaceValue at a particular callsite +struct InstantiatedValue { + Value *Val; + unsigned DerefLevel; +}; +Optional<InstantiatedValue> instantiateInterfaceValue(InterfaceValue, CallSite); + +inline bool operator==(InstantiatedValue LHS, InstantiatedValue RHS) { + return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; +} +inline bool operator!=(InstantiatedValue LHS, InstantiatedValue RHS) { + return !(LHS == RHS); +} +inline bool operator<(InstantiatedValue LHS, InstantiatedValue RHS) { + return std::less<Value *>()(LHS.Val, RHS.Val) || + (LHS.Val == RHS.Val && LHS.DerefLevel < RHS.DerefLevel); +} +inline bool operator>(InstantiatedValue LHS, InstantiatedValue RHS) { + return RHS < LHS; +} +inline bool operator<=(InstantiatedValue LHS, InstantiatedValue RHS) { + return !(RHS < LHS); +} +inline bool operator>=(InstantiatedValue LHS, InstantiatedValue RHS) { + return !(LHS < RHS); +} + +/// This is the result of instantiating ExternalRelation at a particular +/// callsite +struct InstantiatedRelation { + InstantiatedValue From, To; +}; +Optional<InstantiatedRelation> instantiateExternalRelation(ExternalRelation, + CallSite); + +/// This is the result of instantiating ExternalAttribute at a particular +/// callsite +struct InstantiatedAttr { + InstantiatedValue IValue; + AliasAttrs Attr; +}; +Optional<InstantiatedAttr> instantiateExternalAttribute(ExternalAttribute, + CallSite); +} + +template <> struct DenseMapInfo<cflaa::InstantiatedValue> { + static inline cflaa::InstantiatedValue getEmptyKey() { + return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getEmptyKey(), + DenseMapInfo<unsigned>::getEmptyKey()}; + } + static inline cflaa::InstantiatedValue getTombstoneKey() { + return cflaa::InstantiatedValue{DenseMapInfo<Value *>::getTombstoneKey(), + DenseMapInfo<unsigned>::getTombstoneKey()}; + } + static unsigned getHashValue(const cflaa::InstantiatedValue &IV) { + return DenseMapInfo<std::pair<Value *, unsigned>>::getHashValue( + std::make_pair(IV.Val, IV.DerefLevel)); + } + static bool isEqual(const cflaa::InstantiatedValue &LHS, + const cflaa::InstantiatedValue &RHS) { + return LHS.Val == RHS.Val && LHS.DerefLevel == RHS.DerefLevel; + } +}; +} + +#endif diff --git a/lib/Analysis/AliasSetTracker.cpp b/lib/Analysis/AliasSetTracker.cpp index 3094049b3cc3..d349ac51a9b9 100644 --- a/lib/Analysis/AliasSetTracker.cpp +++ b/lib/Analysis/AliasSetTracker.cpp @@ -208,13 +208,12 @@ void AliasSetTracker::clear() { } -/// findAliasSetForPointer - Given a pointer, find the one alias set to put the -/// instruction referring to the pointer into. If there are multiple alias sets -/// that may alias the pointer, merge them together and return the unified set. -/// -AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, - uint64_t Size, - const AAMDNodes &AAInfo) { +/// mergeAliasSetsForPointer - Given a pointer, merge all alias sets that may +/// alias the pointer. Return the unified set, or nullptr if no set that aliases +/// the pointer was found. 
+AliasSet *AliasSetTracker::mergeAliasSetsForPointer(const Value *Ptr, + uint64_t Size, + const AAMDNodes &AAInfo) { AliasSet *FoundSet = nullptr; for (iterator I = begin(), E = end(); I != E;) { iterator Cur = I++; @@ -235,15 +234,15 @@ AliasSet *AliasSetTracker::findAliasSetForPointer(const Value *Ptr, /// alias sets. bool AliasSetTracker::containsPointer(const Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) const { - for (const_iterator I = begin(), E = end(); I != E; ++I) - if (!I->Forward && I->aliasesPointer(Ptr, Size, AAInfo, AA)) + for (const AliasSet &AS : *this) + if (!AS.Forward && AS.aliasesPointer(Ptr, Size, AAInfo, AA)) return true; return false; } bool AliasSetTracker::containsUnknown(const Instruction *Inst) const { - for (const_iterator I = begin(), E = end(); I != E; ++I) - if (!I->Forward && I->aliasesUnknownInst(Inst, AA)) + for (const AliasSet &AS : *this) + if (!AS.Forward && AS.aliasesUnknownInst(Inst, AA)) return true; return false; } @@ -274,12 +273,18 @@ AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, uint64_t Size, // Check to see if the pointer is already known. if (Entry.hasAliasSet()) { - Entry.updateSizeAndAAInfo(Size, AAInfo); + // If the size changed, we may need to merge several alias sets. + // Note that we can *not* return the result of mergeAliasSetsForPointer + // due to a quirk of alias analysis behavior. Since alias(undef, undef) + // is NoAlias, mergeAliasSetsForPointer(undef, ...) will not find the + // the right set for undef, even if it exists. + if (Entry.updateSizeAndAAInfo(Size, AAInfo)) + mergeAliasSetsForPointer(Pointer, Size, AAInfo); // Return the set! return *Entry.getAliasSet(*this)->getForwardedTarget(*this); } - if (AliasSet *AS = findAliasSetForPointer(Pointer, Size, AAInfo)) { + if (AliasSet *AS = mergeAliasSetsForPointer(Pointer, Size, AAInfo)) { // Add it to the alias set it aliases. AS->addPointer(*this, Entry, Size, AAInfo); return *AS; @@ -300,7 +305,7 @@ bool AliasSetTracker::add(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { bool AliasSetTracker::add(LoadInst *LI) { - if (LI->getOrdering() > Monotonic) return addUnknown(LI); + if (isStrongerThanMonotonic(LI->getOrdering())) return addUnknown(LI); AAMDNodes AAInfo; LI->getAAMetadata(AAInfo); @@ -316,7 +321,7 @@ bool AliasSetTracker::add(LoadInst *LI) { } bool AliasSetTracker::add(StoreInst *SI) { - if (SI->getOrdering() > Monotonic) return addUnknown(SI); + if (isStrongerThanMonotonic(SI->getOrdering())) return addUnknown(SI); AAMDNodes AAInfo; SI->getAAMetadata(AAInfo); @@ -342,6 +347,24 @@ bool AliasSetTracker::add(VAArgInst *VAAI) { return NewPtr; } +bool AliasSetTracker::add(MemSetInst *MSI) { + AAMDNodes AAInfo; + MSI->getAAMetadata(AAInfo); + + bool NewPtr; + uint64_t Len; + + if (ConstantInt *C = dyn_cast<ConstantInt>(MSI->getLength())) + Len = C->getZExtValue(); + else + Len = MemoryLocation::UnknownSize; + + AliasSet &AS = + addPointer(MSI->getRawDest(), Len, AAInfo, AliasSet::ModAccess, NewPtr); + if (MSI->isVolatile()) + AS.setVolatile(); + return NewPtr; +} bool AliasSetTracker::addUnknown(Instruction *Inst) { if (isa<DbgInfoIntrinsic>(Inst)) @@ -368,7 +391,10 @@ bool AliasSetTracker::add(Instruction *I) { return add(SI); if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) return add(VAAI); + if (MemSetInst *MSI = dyn_cast<MemSetInst>(I)) + return add(MSI); return addUnknown(I); + // FIXME: add support of memcpy and memmove. 
} void AliasSetTracker::add(BasicBlock &BB) { @@ -383,10 +409,9 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { // Loop over all of the alias sets in AST, adding the pointers contained // therein into the current alias sets. This can cause alias sets to be // merged together in the current AST. - for (const_iterator I = AST.begin(), E = AST.end(); I != E; ++I) { - if (I->Forward) continue; // Ignore forwarding alias sets - - AliasSet &AS = const_cast<AliasSet&>(*I); + for (const AliasSet &AS : AST) { + if (AS.Forward) + continue; // Ignore forwarding alias sets // If there are any call sites in the alias set, add them to this AST. for (unsigned i = 0, e = AS.UnknownInsts.size(); i != e; ++i) @@ -436,7 +461,7 @@ void AliasSetTracker::remove(AliasSet &AS) { bool AliasSetTracker::remove(Value *Ptr, uint64_t Size, const AAMDNodes &AAInfo) { - AliasSet *AS = findAliasSetForPointer(Ptr, Size, AAInfo); + AliasSet *AS = mergeAliasSetsForPointer(Ptr, Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -449,7 +474,7 @@ bool AliasSetTracker::remove(LoadInst *LI) { AAMDNodes AAInfo; LI->getAAMetadata(AAInfo); - AliasSet *AS = findAliasSetForPointer(LI->getOperand(0), Size, AAInfo); + AliasSet *AS = mergeAliasSetsForPointer(LI->getOperand(0), Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -462,7 +487,7 @@ bool AliasSetTracker::remove(StoreInst *SI) { AAMDNodes AAInfo; SI->getAAMetadata(AAInfo); - AliasSet *AS = findAliasSetForPointer(SI->getOperand(1), Size, AAInfo); + AliasSet *AS = mergeAliasSetsForPointer(SI->getOperand(1), Size, AAInfo); if (!AS) return false; remove(*AS); return true; @@ -472,13 +497,30 @@ bool AliasSetTracker::remove(VAArgInst *VAAI) { AAMDNodes AAInfo; VAAI->getAAMetadata(AAInfo); - AliasSet *AS = findAliasSetForPointer(VAAI->getOperand(0), - MemoryLocation::UnknownSize, AAInfo); + AliasSet *AS = mergeAliasSetsForPointer(VAAI->getOperand(0), + MemoryLocation::UnknownSize, AAInfo); if (!AS) return false; remove(*AS); return true; } +bool AliasSetTracker::remove(MemSetInst *MSI) { + AAMDNodes AAInfo; + MSI->getAAMetadata(AAInfo); + uint64_t Len; + + if (ConstantInt *C = dyn_cast<ConstantInt>(MSI->getLength())) + Len = C->getZExtValue(); + else + Len = MemoryLocation::UnknownSize; + + AliasSet *AS = mergeAliasSetsForPointer(MSI->getRawDest(), Len, AAInfo); + if (!AS) + return false; + remove(*AS); + return true; +} + bool AliasSetTracker::removeUnknown(Instruction *I) { if (!I->mayReadOrWriteMemory()) return false; // doesn't alias anything @@ -497,7 +539,10 @@ bool AliasSetTracker::remove(Instruction *I) { return remove(SI); if (VAArgInst *VAAI = dyn_cast<VAArgInst>(I)) return remove(VAAI); + if (MemSetInst *MSI = dyn_cast<MemSetInst>(I)) + return remove(MSI); return removeUnknown(I); + // FIXME: add support of memcpy and memmove. 
} @@ -602,14 +647,14 @@ void AliasSet::print(raw_ostream &OS) const { void AliasSetTracker::print(raw_ostream &OS) const { OS << "Alias Set Tracker: " << AliasSets.size() << " alias sets for " << PointerMap.size() << " pointer values.\n"; - for (const_iterator I = begin(), E = end(); I != E; ++I) - I->print(OS); + for (const AliasSet &AS : *this) + AS.print(OS); OS << "\n"; } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void AliasSet::dump() const { print(dbgs()); } -void AliasSetTracker::dump() const { print(dbgs()); } +LLVM_DUMP_METHOD void AliasSet::dump() const { print(dbgs()); } +LLVM_DUMP_METHOD void AliasSetTracker::dump() const { print(dbgs()); } #endif //===----------------------------------------------------------------------===// diff --git a/lib/Analysis/Analysis.cpp b/lib/Analysis/Analysis.cpp index 9c1ac000be2c..c04447ca58c9 100644 --- a/lib/Analysis/Analysis.cpp +++ b/lib/Analysis/Analysis.cpp @@ -20,25 +20,27 @@ using namespace llvm; /// initializeAnalysis - Initialize all passes linked into the Analysis library. void llvm::initializeAnalysis(PassRegistry &Registry) { - initializeAAEvalPass(Registry); + initializeAAEvalLegacyPassPass(Registry); initializeAliasSetPrinterPass(Registry); initializeBasicAAWrapperPassPass(Registry); initializeBlockFrequencyInfoWrapperPassPass(Registry); initializeBranchProbabilityInfoWrapperPassPass(Registry); initializeCallGraphWrapperPassPass(Registry); - initializeCallGraphPrinterPass(Registry); + initializeCallGraphDOTPrinterPass(Registry); + initializeCallGraphPrinterLegacyPassPass(Registry); initializeCallGraphViewerPass(Registry); initializeCostModelAnalysisPass(Registry); initializeCFGViewerPass(Registry); initializeCFGPrinterPass(Registry); initializeCFGOnlyViewerPass(Registry); initializeCFGOnlyPrinterPass(Registry); - initializeCFLAAWrapperPassPass(Registry); - initializeDependenceAnalysisPass(Registry); + initializeCFLAndersAAWrapperPassPass(Registry); + initializeCFLSteensAAWrapperPassPass(Registry); + initializeDependenceAnalysisWrapperPassPass(Registry); initializeDelinearizationPass(Registry); - initializeDemandedBitsPass(Registry); + initializeDemandedBitsWrapperPassPass(Registry); initializeDivergenceAnalysisPass(Registry); - initializeDominanceFrontierPass(Registry); + initializeDominanceFrontierWrapperPassPass(Registry); initializeDomViewerPass(Registry); initializeDomPrinterPass(Registry); initializeDomOnlyViewerPass(Registry); @@ -49,18 +51,21 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializePostDomOnlyPrinterPass(Registry); initializeAAResultsWrapperPassPass(Registry); initializeGlobalsAAWrapperPassPass(Registry); - initializeIVUsersPass(Registry); + initializeIVUsersWrapperPassPass(Registry); initializeInstCountPass(Registry); initializeIntervalPartitionPass(Registry); - initializeLazyValueInfoPass(Registry); + initializeLazyBlockFrequencyInfoPassPass(Registry); + initializeLazyValueInfoWrapperPassPass(Registry); initializeLintPass(Registry); initializeLoopInfoWrapperPassPass(Registry); initializeMemDepPrinterPass(Registry); initializeMemDerefPrinterPass(Registry); - initializeMemoryDependenceAnalysisPass(Registry); + initializeMemoryDependenceWrapperPassPass(Registry); initializeModuleDebugInfoPrinterPass(Registry); + initializeModuleSummaryIndexWrapperPassPass(Registry); initializeObjCARCAAWrapperPassPass(Registry); - initializePostDominatorTreePass(Registry); + initializeOptimizationRemarkEmitterWrapperPassPass(Registry); + initializePostDominatorTreeWrapperPassPass(Registry); 
initializeRegionInfoPassPass(Registry); initializeRegionViewerPass(Registry); initializeRegionPrinterPass(Registry); diff --git a/lib/Analysis/AssumptionCache.cpp b/lib/Analysis/AssumptionCache.cpp index f468a43ef0b8..ca71644757f0 100644 --- a/lib/Analysis/AssumptionCache.cpp +++ b/lib/Analysis/AssumptionCache.cpp @@ -77,8 +77,8 @@ void AssumptionCache::registerAssumption(CallInst *CI) { char AssumptionAnalysis::PassID; PreservedAnalyses AssumptionPrinterPass::run(Function &F, - AnalysisManager<Function> *AM) { - AssumptionCache &AC = AM->getResult<AssumptionAnalysis>(F); + AnalysisManager<Function> &AM) { + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); OS << "Cached assumptions for function: " << F.getName() << "\n"; for (auto &VH : AC.assumptions()) diff --git a/lib/Analysis/BasicAliasAnalysis.cpp b/lib/Analysis/BasicAliasAnalysis.cpp index c3d280350b90..43d5c3ccf907 100644 --- a/lib/Analysis/BasicAliasAnalysis.cpp +++ b/lib/Analysis/BasicAliasAnalysis.cpp @@ -37,22 +37,23 @@ #include "llvm/Pass.h" #include "llvm/Support/ErrorHandling.h" #include <algorithm> + +#define DEBUG_TYPE "basicaa" + using namespace llvm; /// Enable analysis of recursive PHI nodes. static cl::opt<bool> EnableRecPhiAnalysis("basicaa-recphi", cl::Hidden, cl::init(false)); - /// SearchLimitReached / SearchTimes shows how often the limit of /// to decompose GEPs is reached. It will affect the precision /// of basic alias analysis. -#define DEBUG_TYPE "basicaa" STATISTIC(SearchLimitReached, "Number of times the limit to " "decompose GEPs is reached"); STATISTIC(SearchTimes, "Number of times a GEP is decomposed"); /// Cutoff after which to stop analysing a set of phi nodes potentially involved -/// in a cycle. Because we are analysing 'through' phi nodes we need to be +/// in a cycle. Because we are analysing 'through' phi nodes, we need to be /// careful with value equivalence. We use reachability to make sure a value /// cannot be involved in a cycle. const unsigned MaxNumPhiBBsValueReachabilityCheck = 20; @@ -83,7 +84,7 @@ static bool isNonEscapingLocalObject(const Value *V) { // inside the function. if (const Argument *A = dyn_cast<Argument>(V)) if (A->hasByValAttr() || A->hasNoAliasAttr()) - // Note even if the argument is marked nocapture we still need to check + // Note even if the argument is marked nocapture, we still need to check // for copies made inside the function. The nocapture attribute only // specifies that there are no copies made that outlive the function. return !PointerMayBeCaptured(V, false, /*StoreCaptures=*/true); @@ -106,7 +107,7 @@ static bool isEscapeSource(const Value *V) { return false; } -/// Returns the size of the object specified by V, or UnknownSize if unknown. +/// Returns the size of the object specified by V or UnknownSize if unknown. static uint64_t getObjectSize(const Value *V, const DataLayout &DL, const TargetLibraryInfo &TLI, bool RoundToAlign = false) { @@ -173,7 +174,7 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, /// /// Returns the scale and offset values as APInts and return V as a Value*, and /// return whether we looked through any sign or zero extends. The incoming -/// Value is known to have IntegerType and it may already be sign or zero +/// Value is known to have IntegerType, and it may already be sign or zero /// extended. 
/// /// Note that this looks through extends, so the high bits may not be @@ -192,8 +193,8 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, } if (const ConstantInt *Const = dyn_cast<ConstantInt>(V)) { - // if it's a constant, just convert it to an offset and remove the variable. - // If we've been called recursively the Offset bit width will be greater + // If it's a constant, just convert it to an offset and remove the variable. + // If we've been called recursively, the Offset bit width will be greater // than the constant's (the Offset's always as wide as the outermost call), // so we'll zext here and process any extension in the isa<SExtInst> & // isa<ZExtInst> cases below. @@ -205,8 +206,8 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, if (const BinaryOperator *BOp = dyn_cast<BinaryOperator>(V)) { if (ConstantInt *RHSC = dyn_cast<ConstantInt>(BOp->getOperand(1))) { - // If we've been called recursively then Offset and Scale will be wider - // that the BOp operands. We'll always zext it here as we'll process sign + // If we've been called recursively, then Offset and Scale will be wider + // than the BOp operands. We'll always zext it here as we'll process sign // extensions below (see the isa<SExtInst> / isa<ZExtInst> cases). APInt RHS = RHSC->getValue().zextOrSelf(Offset.getBitWidth()); @@ -319,6 +320,16 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, return V; } +/// To ensure a pointer offset fits in an integer of size PointerSize +/// (in bits) when that size is smaller than 64. This is an issue in +/// particular for 32b programs with negative indices that rely on two's +/// complement wrap-arounds for precise alias information. +static int64_t adjustToPointerSize(int64_t Offset, unsigned PointerSize) { + assert(PointerSize <= 64 && "Invalid PointerSize!"); + unsigned ShiftBits = 64 - PointerSize; + return (int64_t)((uint64_t)Offset << ShiftBits) >> ShiftBits; +} + /// If V is a symbolic pointer expression, decompose it into a base pointer /// with a constant offset and a number of scaled symbolic offsets. /// @@ -332,28 +343,29 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, /// GetUnderlyingObject and DecomposeGEPExpression must use the same search /// depth (MaxLookupSearchDepth). When DataLayout not is around, it just looks /// through pointer casts. -/*static*/ const Value *BasicAAResult::DecomposeGEPExpression( - const Value *V, int64_t &BaseOffs, - SmallVectorImpl<VariableGEPIndex> &VarIndices, bool &MaxLookupReached, - const DataLayout &DL, AssumptionCache *AC, DominatorTree *DT) { +bool BasicAAResult::DecomposeGEPExpression(const Value *V, + DecomposedGEP &Decomposed, const DataLayout &DL, AssumptionCache *AC, + DominatorTree *DT) { // Limit recursion depth to limit compile time in crazy cases. unsigned MaxLookup = MaxLookupSearchDepth; - MaxLookupReached = false; SearchTimes++; - BaseOffs = 0; + Decomposed.StructOffset = 0; + Decomposed.OtherOffset = 0; + Decomposed.VarIndices.clear(); do { // See if this is a bitcast or GEP. const Operator *Op = dyn_cast<Operator>(V); if (!Op) { // The only non-operator case we can handle are GlobalAliases. 
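// [Editorial aside -- illustrative sketch, not part of this commit. The
// adjustToPointerSize() helper added in this hunk sign-extends a 64-bit
// offset down to the pointer width by shifting left and arithmetically
// shifting back; worked numbers for a 32-bit address space:]
#include <cassert>
#include <cstdint>
inline int64_t sketchAdjust(int64_t Offset, unsigned PointerSize) {
  unsigned ShiftBits = 64 - PointerSize; // 32 for a 32-bit pointer
  return (int64_t)((uint64_t)Offset << ShiftBits) >> ShiftBits;
}
static void sketchAdjustCheck() {
  assert(sketchAdjust(0xFFFFFFFFLL, 32) == -1);  // 2^32 - 1 is really index -1
  assert(sketchAdjust(0x100000000LL, 32) == 0);  // a full wrap-around vanishes
  assert(sketchAdjust(42, 32) == 42);            // small offsets are untouched
}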
if (const GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (!GA->mayBeOverridden()) { + if (!GA->isInterposable()) { V = GA->getAliasee(); continue; } } - return V; + Decomposed.Base = V; + return false; } if (Op->getOpcode() == Instruction::BitCast || @@ -364,6 +376,12 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, const GEPOperator *GEPOp = dyn_cast<GEPOperator>(Op); if (!GEPOp) { + if (auto CS = ImmutableCallSite(V)) + if (const Value *RV = CS.getReturnedArgOperand()) { + V = RV; + continue; + } + // If it's not a GEP, hand it off to SimplifyInstruction to see if it // can come up with something. This matches what GetUnderlyingObject does. if (const Instruction *I = dyn_cast<Instruction>(V)) @@ -377,16 +395,20 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, continue; } - return V; + Decomposed.Base = V; + return false; } // Don't attempt to analyze GEPs over unsized objects. - if (!GEPOp->getOperand(0)->getType()->getPointerElementType()->isSized()) - return V; + if (!GEPOp->getSourceElementType()->isSized()) { + Decomposed.Base = V; + return false; + } unsigned AS = GEPOp->getPointerAddressSpace(); // Walk the indices of the GEP, accumulating them into BaseOff/VarIndices. gep_type_iterator GTI = gep_type_begin(GEPOp); + unsigned PointerSize = DL.getPointerSizeInBits(AS); for (User::const_op_iterator I = GEPOp->op_begin() + 1, E = GEPOp->op_end(); I != E; ++I) { const Value *Index = *I; @@ -397,7 +419,8 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, if (FieldNo == 0) continue; - BaseOffs += DL.getStructLayout(STy)->getElementOffset(FieldNo); + Decomposed.StructOffset += + DL.getStructLayout(STy)->getElementOffset(FieldNo); continue; } @@ -405,7 +428,8 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, if (const ConstantInt *CIdx = dyn_cast<ConstantInt>(Index)) { if (CIdx->isZero()) continue; - BaseOffs += DL.getTypeAllocSize(*GTI) * CIdx->getSExtValue(); + Decomposed.OtherOffset += + DL.getTypeAllocSize(*GTI) * CIdx->getSExtValue(); continue; } @@ -415,7 +439,6 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, // If the integer type is smaller than the pointer size, it is implicitly // sign extended to pointer size. unsigned Width = Index->getType()->getIntegerBitWidth(); - unsigned PointerSize = DL.getPointerSizeInBits(AS); if (PointerSize > Width) SExtBits += PointerSize - Width; @@ -427,44 +450,48 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. - BaseOffs += IndexOffset.getSExtValue() * Scale; + Decomposed.OtherOffset += IndexOffset.getSExtValue() * Scale; Scale *= IndexScale.getSExtValue(); // If we already had an occurrence of this index variable, merge this // scale into it. For example, we want to handle: // A[x][x] -> x*16 + x*4 -> x*20 // This also ensures that 'x' only appears in the index list once. 
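// [Editorial aside -- worked numbers for the A[x][x] comment above, assuming
// int A[n][4], so a row is 16 bytes and an element 4 bytes:
//   outer index seen first:  VarIndices = { {x, Scale=16} }
//   inner index seen next:   Scale = 4 + 16 = 20, old entry erased,
//                            VarIndices = { {x, Scale=20} }   // x*20 total
// which is exactly the x*16 + x*4 -> x*20 folding described above.]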
- for (unsigned i = 0, e = VarIndices.size(); i != e; ++i) { - if (VarIndices[i].V == Index && VarIndices[i].ZExtBits == ZExtBits && - VarIndices[i].SExtBits == SExtBits) { - Scale += VarIndices[i].Scale; - VarIndices.erase(VarIndices.begin() + i); + for (unsigned i = 0, e = Decomposed.VarIndices.size(); i != e; ++i) { + if (Decomposed.VarIndices[i].V == Index && + Decomposed.VarIndices[i].ZExtBits == ZExtBits && + Decomposed.VarIndices[i].SExtBits == SExtBits) { + Scale += Decomposed.VarIndices[i].Scale; + Decomposed.VarIndices.erase(Decomposed.VarIndices.begin() + i); break; } } // Make sure that we have a scale that makes sense for this target's // pointer size. - if (unsigned ShiftBits = 64 - PointerSize) { - Scale <<= ShiftBits; - Scale = (int64_t)Scale >> ShiftBits; - } + Scale = adjustToPointerSize(Scale, PointerSize); if (Scale) { VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, static_cast<int64_t>(Scale)}; - VarIndices.push_back(Entry); + Decomposed.VarIndices.push_back(Entry); } } + // Take care of wrap-arounds + Decomposed.StructOffset = + adjustToPointerSize(Decomposed.StructOffset, PointerSize); + Decomposed.OtherOffset = + adjustToPointerSize(Decomposed.OtherOffset, PointerSize); + // Analyze the base pointer next. V = GEPOp->getOperand(0); } while (--MaxLookup); // If the chain of expressions is too deep, just return early. - MaxLookupReached = true; + Decomposed.Base = V; SearchLimitReached++; - return V; + return true; } /// Returns whether the given pointer value points to memory that is local to @@ -530,22 +557,6 @@ bool BasicAAResult::pointsToConstantMemory(const MemoryLocation &Loc, return Worklist.empty(); } -// FIXME: This code is duplicated with MemoryLocation and should be hoisted to -// some common utility location. -static bool isMemsetPattern16(const Function *MS, - const TargetLibraryInfo &TLI) { - if (TLI.has(LibFunc::memset_pattern16) && - MS->getName() == "memset_pattern16") { - FunctionType *MemsetType = MS->getFunctionType(); - if (!MemsetType->isVarArg() && MemsetType->getNumParams() == 3 && - isa<PointerType>(MemsetType->getParamType(0)) && - isa<PointerType>(MemsetType->getParamType(1)) && - isa<IntegerType>(MemsetType->getParamType(2))) - return true; - } - return false; -} - /// Returns the behavior when calling the given call site. FunctionModRefBehavior BasicAAResult::getModRefBehavior(ImmutableCallSite CS) { if (CS.doesNotAccessMemory()) @@ -558,12 +569,21 @@ FunctionModRefBehavior BasicAAResult::getModRefBehavior(ImmutableCallSite CS) { // than that. if (CS.onlyReadsMemory()) Min = FMRB_OnlyReadsMemory; + else if (CS.doesNotReadMemory()) + Min = FMRB_DoesNotReadMemory; if (CS.onlyAccessesArgMemory()) Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); - // The AAResultBase base class has some smarts, lets use them. - return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); + // If CS has operand bundles then aliasing attributes from the function it + // calls do not directly apply to the CallSite. This can be made more + // precise in the future. + if (!CS.hasOperandBundles()) + if (const Function *F = CS.getCalledFunction()) + Min = + FunctionModRefBehavior(Min & getBestAAResults().getModRefBehavior(F)); + + return Min; } /// Returns the behavior when calling the given function. For use when the call @@ -578,41 +598,30 @@ FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function *F) { // If the function declares it only reads memory, go with that. 
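// [Editorial aside -- illustrative sketch, not part of this commit.
// FunctionModRefBehavior packs a "which locations" mask together with a
// mod/ref mask, so independently known facts combine with bitwise AND, as
// the Min & ... updates here do. A simplified encoding in the spirit of the
// real enum:]
enum SketchFMRB : unsigned {
  SFMRB_DoesNotAccessMemory = 0,
  SFMRB_OnlyAccessesArgumentPointees = 4 | 3, // arg pointees, Mod|Ref
  SFMRB_OnlyReadsMemory = 32 | 4 | 1,         // anywhere, Ref only
  SFMRB_DoesNotReadMemory = 32 | 4 | 2,       // anywhere, Mod only
  SFMRB_UnknownModRef = 32 | 4 | 3            // anywhere, Mod|Ref
};
// A function that is both readonly and argmemonly intersects to
// "only reads argument pointees": (32|4|1) & (4|3) == 4|1.
inline SketchFMRB sketchCombine(SketchFMRB A, SketchFMRB B) {
  return SketchFMRB(A & B);
}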
if (F->onlyReadsMemory()) Min = FMRB_OnlyReadsMemory; + else if (F->doesNotReadMemory()) + Min = FMRB_DoesNotReadMemory; if (F->onlyAccessesArgMemory()) Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); - // Otherwise be conservative. - return FunctionModRefBehavior(AAResultBase::getModRefBehavior(F) & Min); + return Min; } -/// Returns true if this is a writeonly (i.e Mod only) parameter. Currently, -/// we don't have a writeonly attribute, so this only knows about builtin -/// intrinsics and target library functions. We could consider adding a -/// writeonly attribute in the future and moving all of these facts to either -/// Intrinsics.td or InferFunctionAttr.cpp +/// Returns true if this is a writeonly (i.e., Mod only) parameter. static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, const TargetLibraryInfo &TLI) { - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) - switch (II->getIntrinsicID()) { - default: - break; - case Intrinsic::memset: - case Intrinsic::memcpy: - case Intrinsic::memmove: - // We don't currently have a writeonly attribute. All other properties - // of these intrinsics are nicely described via attributes in - // Intrinsics.td and handled generically. - if (ArgIdx == 0) - return true; - } + if (CS.paramHasAttr(ArgIdx + 1, Attribute::WriteOnly)) + return true; // We can bound the aliasing properties of memset_pattern16 just as we can // for memcpy/memset. This is particularly important because the // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 - // whenever possible. Note that all but the missing writeonly attribute are - // handled via InferFunctionAttr. - if (CS.getCalledFunction() && isMemsetPattern16(CS.getCalledFunction(), TLI)) + // whenever possible. + // FIXME: Consider handling this in InferFunctionAttr.cpp together with other + // attributes. + LibFunc::Func F; + if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && + F == LibFunc::memset_pattern16 && TLI.has(F)) if (ArgIdx == 0) return true; @@ -626,8 +635,7 @@ static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, ModRefInfo BasicAAResult::getArgModRefInfo(ImmutableCallSite CS, unsigned ArgIdx) { - // Emulate the missing writeonly attribute by checking for known builtin - // intrinsics and target library functions. + // Check for known builtin intrinsics and target library functions. if (isWriteOnlyParam(CS, ArgIdx, TLI)) return MRI_Mod; @@ -640,9 +648,9 @@ ModRefInfo BasicAAResult::getArgModRefInfo(ImmutableCallSite CS, return AAResultBase::getArgModRefInfo(CS, ArgIdx); } -static bool isAssumeIntrinsic(ImmutableCallSite CS) { +static bool isIntrinsicCall(ImmutableCallSite CS, Intrinsic::ID IID) { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction()); - return II && II->getIntrinsicID() == Intrinsic::assume; + return II && II->getIntrinsicID() == IID; } #ifndef NDEBUG @@ -717,14 +725,14 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, if (!isa<Constant>(Object) && CS.getInstruction() != Object && isNonEscapingLocalObject(Object)) { bool PassedAsArg = false; - unsigned ArgNo = 0; - for (ImmutableCallSite::arg_iterator CI = CS.arg_begin(), CE = CS.arg_end(); - CI != CE; ++CI, ++ArgNo) { + unsigned OperandNo = 0; + for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); + CI != CE; ++CI, ++OperandNo) { // Only look at the no-capture or byval pointer arguments.
If this // pointer were passed to arguments that were neither of these, then it // couldn't be no-capture. if (!(*CI)->getType()->isPointerTy() || - (!CS.doesNotCapture(ArgNo) && !CS.isByValArgument(ArgNo))) + (!CS.doesNotCapture(OperandNo) && !CS.isByValArgument(OperandNo))) continue; // If this is a no-capture pointer argument, see if we can tell that it @@ -743,12 +751,36 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, return MRI_NoModRef; } + // If the CallSite is to malloc or calloc, we can assume that it doesn't + // modify any IR visible value. This is only valid because we assume these + // routines do not read values visible in the IR. TODO: Consider special + // casing realloc and strdup routines which access only their arguments as + // well. Or alternatively, replace all of this with inaccessiblememonly once + // that's implemented fully. + auto *Inst = CS.getInstruction(); + if (isMallocLikeFn(Inst, &TLI) || isCallocLikeFn(Inst, &TLI)) { + // Be conservative if the accessed pointer may alias the allocation - + // fall back to the generic handling below. + if (getBestAAResults().alias(MemoryLocation(Inst), Loc) == NoAlias) + return MRI_NoModRef; + } + // While the assume intrinsic is marked as arbitrarily writing so that // proper control dependencies will be maintained, it never aliases any // particular memory location. - if (isAssumeIntrinsic(CS)) + if (isIntrinsicCall(CS, Intrinsic::assume)) return MRI_NoModRef; + // Like assumes, guard intrinsics are also marked as arbitrarily writing so + // that proper control dependencies are maintained but they never mod any + // particular memory location. + // + // *Unlike* assumes, guard intrinsics are modeled as reading memory since the + // heap state at the point the guard is issued needs to be consistent in case + // the guard invokes the "deopt" continuation. + if (isIntrinsicCall(CS, Intrinsic::experimental_guard)) + return MRI_Ref; + // The AAResultBase base class has some smarts, let's use them. return AAResultBase::getModRefInfo(CS, Loc); } @@ -758,9 +790,27 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS1, // While the assume intrinsic is marked as arbitrarily writing so that // proper control dependencies will be maintained, it never aliases any // particular memory location. - if (isAssumeIntrinsic(CS1) || isAssumeIntrinsic(CS2)) + if (isIntrinsicCall(CS1, Intrinsic::assume) || + isIntrinsicCall(CS2, Intrinsic::assume)) return MRI_NoModRef; + // Like assumes, guard intrinsics are also marked as arbitrarily writing so + // that proper control dependencies are maintained but they never mod any + // particular memory location. + // + // *Unlike* assumes, guard intrinsics are modeled as reading memory since the + // heap state at the point the guard is issued needs to be consistent in case + // the guard invokes the "deopt" continuation. + + // NB! This function is *not* commutative, so we special-case two + // possibilities for guard intrinsics. + + if (isIntrinsicCall(CS1, Intrinsic::experimental_guard)) + return getModRefBehavior(CS2) & MRI_Mod ? MRI_Ref : MRI_NoModRef; + + if (isIntrinsicCall(CS2, Intrinsic::experimental_guard)) + return getModRefBehavior(CS1) & MRI_Mod ? MRI_Mod : MRI_NoModRef; + // The AAResultBase base class has some smarts, let's use them.
return AAResultBase::getModRefInfo(CS1, CS2); } @@ -773,7 +823,10 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, uint64_t V2Size, const DataLayout &DL) { - assert(GEP1->getPointerOperand() == GEP2->getPointerOperand() && + assert(GEP1->getPointerOperand()->stripPointerCasts() == + GEP2->getPointerOperand()->stripPointerCasts() && + GEP1->getPointerOperand()->getType() == + GEP2->getPointerOperand()->getType() && "Expected GEPs with the same pointer operand"); // Try to determine whether GEP1 and GEP2 index through arrays, into structs, @@ -796,7 +849,7 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, // If the last (struct) indices are constants and are equal, the other indices // might also be dynamically equal, so the GEPs can alias. - if (C1 && C2 && C1 == C2) + if (C1 && C2 && C1->getSExtValue() == C2->getSExtValue()) return MayAlias; // Find the last-indexed type of the GEP, i.e., the type you'd get if @@ -895,6 +948,67 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, return MayAlias; } +// If we have (a) a GEP and (b) a pointer based on an alloca, and the +// beginning of the object the GEP points to would have a negative offset with +// respect to the alloca, then the GEP cannot alias pointer (b). +// Note that the pointer based on the alloca may not be a GEP. For +// example, it may be the alloca itself. +// The same applies if (b) is based on a GlobalVariable. Note that just being +// based on isIdentifiedObject() is not enough - we need an identified object +// that does not permit access to negative offsets. For example, a negative +// offset from a noalias argument or call can be inbounds w.r.t. the actual +// underlying object. +// +// For example, consider: +// +// struct { int f0, int f1, ...} foo; +// foo alloca; +// foo* random = bar(alloca); +// int *f0 = &alloca.f0; +// int *f1 = &random->f1; +// +// Which is lowered, approximately, to: +// +// %alloca = alloca %struct.foo +// %random = call %struct.foo* @random(%struct.foo* %alloca) +// %f0 = getelementptr inbounds %struct, %struct.foo* %alloca, i32 0, i32 0 +// %f1 = getelementptr inbounds %struct, %struct.foo* %random, i32 0, i32 1 +// +// Assume %f1 and %f0 alias. Then %f1 would point into the object allocated +// by %alloca. Since the %f1 GEP is inbounds, that means %random must also +// point into the same object. But since %f0 points to the beginning of %alloca, +// the highest %f1 can be is (%alloca + 3). This means %random cannot be higher +// than (%alloca - 1), and so is not inbounds, a contradiction. +bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp, +    const DecomposedGEP &DecompGEP, const DecomposedGEP &DecompObject, +    uint64_t ObjectAccessSize) { + // If the object access size is unknown, or the GEP isn't inbounds, bail. + if (ObjectAccessSize == MemoryLocation::UnknownSize || !GEPOp->isInBounds()) + return false; + + // We need the object to be an alloca or a global variable, and want to know + // the offset of the pointer from the object precisely, so no variable + // indices are allowed. + if (!(isa<AllocaInst>(DecompObject.Base) || + isa<GlobalVariable>(DecompObject.Base)) || + !DecompObject.VarIndices.empty()) + return false; + + int64_t ObjectBaseOffset = DecompObject.StructOffset + + DecompObject.OtherOffset; + + // If the GEP has no variable indices, we know the precise offset + // from the base and can use it.
If the GEP has variable indices, we're in + // a bit more trouble: we can't count on the constant offsets that come + // from non-struct sources, since these can be "rewound" by a negative + // variable offset. So use only offsets that came from structs. + int64_t GEPBaseOffset = DecompGEP.StructOffset; + if (DecompGEP.VarIndices.empty()) + GEPBaseOffset += DecompGEP.OtherOffset; + + return (GEPBaseOffset >= ObjectBaseOffset + (int64_t)ObjectAccessSize); +} + /// Provides a bunch of ad-hoc rules to disambiguate a GEP instruction against /// another pointer. /// @@ -906,14 +1020,34 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, uint64_t V2Size, const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2) { - int64_t GEP1BaseOffset; - bool GEP1MaxLookupReached; - SmallVector<VariableGEPIndex, 4> GEP1VariableIndices; - + DecomposedGEP DecompGEP1, DecompGEP2; + bool GEP1MaxLookupReached = + DecomposeGEPExpression(GEP1, DecompGEP1, DL, &AC, DT); + bool GEP2MaxLookupReached = + DecomposeGEPExpression(V2, DecompGEP2, DL, &AC, DT); + + int64_t GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; + int64_t GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; + + assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 && + "DecomposeGEPExpression returned a result different from " + "GetUnderlyingObject"); + + // If the GEP's offset relative to its base is such that the base would + // fall below the start of the object underlying V2, then the GEP and V2 + // cannot alias. + if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && + isGEPBaseAtNegativeOffset(GEP1, DecompGEP1, DecompGEP2, V2Size)) + return NoAlias; // If we have two gep instructions with must-alias or not-alias'ing base // pointers, figure out if the indexes to the GEP tell us anything about the // derived pointer. if (const GEPOperator *GEP2 = dyn_cast<GEPOperator>(V2)) { + // Check for the GEP base being at a negative offset, this time in the other + // direction. + if (!GEP1MaxLookupReached && !GEP2MaxLookupReached && + isGEPBaseAtNegativeOffset(GEP2, DecompGEP2, DecompGEP1, V1Size)) + return NoAlias; // Do the base pointers alias? AliasResult BaseAlias = aliasCheck(UnderlyingV1, MemoryLocation::UnknownSize, AAMDNodes(), @@ -928,31 +1062,14 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (PreciseBaseAlias == NoAlias) { // See if the computed offset from the common pointer tells us about the // relation of the resulting pointer. - int64_t GEP2BaseOffset; - bool GEP2MaxLookupReached; - SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; - const Value *GEP2BasePtr = - DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL, &AC, DT); - const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL, &AC, DT); - // DecomposeGEPExpression and GetUnderlyingObject should return the - // same result except when DecomposeGEPExpression has no DataLayout. - // FIXME: They always have a DataLayout so this should become an - // assert. - if (GEP1BasePtr != UnderlyingV1 || GEP2BasePtr != UnderlyingV2) { - return MayAlias; - } // If the max search depth is reached the result is undefined if (GEP2MaxLookupReached || GEP1MaxLookupReached) return MayAlias; // Same offsets. 
if (GEP1BaseOffset == GEP2BaseOffset && - GEP1VariableIndices == GEP2VariableIndices) + DecompGEP1.VarIndices == DecompGEP2.VarIndices) return NoAlias; - GEP1VariableIndices.clear(); } } @@ -964,42 +1081,27 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // Otherwise, we have a MustAlias. Since the base pointers alias each other // exactly, see if the computed offset from the common pointer tells us // about the relation of the resulting pointer. - const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL, &AC, DT); - - int64_t GEP2BaseOffset; - bool GEP2MaxLookupReached; - SmallVector<VariableGEPIndex, 4> GEP2VariableIndices; - const Value *GEP2BasePtr = - DecomposeGEPExpression(GEP2, GEP2BaseOffset, GEP2VariableIndices, - GEP2MaxLookupReached, DL, &AC, DT); - - // DecomposeGEPExpression and GetUnderlyingObject should return the - // same result except when DecomposeGEPExpression has no DataLayout. - // FIXME: They always have a DataLayout so this should become an assert. - if (GEP1BasePtr != UnderlyingV1 || GEP2BasePtr != UnderlyingV2) { - return MayAlias; - } - // If we know the two GEPs are based off of the exact same pointer (and not // just the same underlying object), see if that tells us anything about // the resulting pointers. - if (GEP1->getPointerOperand() == GEP2->getPointerOperand()) { + if (GEP1->getPointerOperand()->stripPointerCasts() == + GEP2->getPointerOperand()->stripPointerCasts() && + GEP1->getPointerOperand()->getType() == + GEP2->getPointerOperand()->getType()) { AliasResult R = aliasSameBasePointerGEPs(GEP1, V1Size, GEP2, V2Size, DL); // If we couldn't find anything interesting, don't abandon just yet. if (R != MayAlias) return R; } - // If the max search depth is reached the result is undefined + // If the max search depth is reached, the result is undefined if (GEP2MaxLookupReached || GEP1MaxLookupReached) return MayAlias; // Subtract the GEP2 pointer from the GEP1 pointer to find out their // symbolic difference. GEP1BaseOffset -= GEP2BaseOffset; - GetIndexDifference(GEP1VariableIndices, GEP2VariableIndices); + GetIndexDifference(DecompGEP1.VarIndices, DecompGEP2.VarIndices); } else { // Check to see if these two pointers are related by the getelementptr @@ -1021,16 +1123,6 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // with the first operand of the getelementptr". return R; - const Value *GEP1BasePtr = - DecomposeGEPExpression(GEP1, GEP1BaseOffset, GEP1VariableIndices, - GEP1MaxLookupReached, DL, &AC, DT); - - // DecomposeGEPExpression and GetUnderlyingObject should return the - // same result except when DecomposeGEPExpression has no DataLayout. - // FIXME: They always have a DataLayout so this should become an assert. - if (GEP1BasePtr != UnderlyingV1) { - return MayAlias; - } // If the max search depth is reached, the result is undefined if (GEP1MaxLookupReached) return MayAlias; @@ -1038,18 +1130,18 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // In the two-GEP case, if there is no difference in the offsets of the // computed pointers, the resultant pointers are a must alias. This - // hapens when we have two lexically identical GEP's (for example). + // happens when we have two lexically identical GEP's (for example). // // In the other case, if we have getelementptr <ptr>, 0, 0, 0, 0, ... and V2 // must-aliases the GEP, the end result is a must alias also.
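The constant-offset disambiguation that continues below can be summarized by this standalone model (an illustration only, not the patch's code; sizes are in bytes, and UINT64_MAX stands in for MemoryLocation::UnknownSize):

#include <cstdint>

enum Result { NoAlias, MayAlias, PartialAlias, MustAlias };

// Two accesses into the same object: P1 = Base + Off and P2 = Base, with
// access sizes V1Size and V2Size.
static Result checkConstantOffset(int64_t Off, uint64_t V1Size,
                                  uint64_t V2Size) {
  if (Off == 0)
    return MustAlias; // no offset difference and no variable indices
  if (Off > 0) {
    // P1 starts Off bytes past P2: the accesses overlap only if P2's
    // access reaches past P1's start.
    if (V2Size != UINT64_MAX)
      return (uint64_t)Off < V2Size ? PartialAlias : NoAlias;
  } else {
    if (V1Size != UINT64_MAX)
      return (uint64_t)-Off < V1Size ? PartialAlias : NoAlias;
  }
  return MayAlias; // unknown size on the relevant side
}

int main() {
  // A 4-byte access at Base+8 cannot overlap a 4-byte access at Base.
  return checkConstantOffset(8, 4, 4) == NoAlias ? 0 : 1;
}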
- if (GEP1BaseOffset == 0 && GEP1VariableIndices.empty()) + if (GEP1BaseOffset == 0 && DecompGEP1.VarIndices.empty()) return MustAlias; // If there is a constant difference between the pointers, but the difference // is less than the size of the associated memory object, then we know // that the objects are partially overlapping. If the difference is // greater, we know they do not overlap. - if (GEP1BaseOffset != 0 && GEP1VariableIndices.empty()) { + if (GEP1BaseOffset != 0 && DecompGEP1.VarIndices.empty()) { if (GEP1BaseOffset >= 0) { if (V2Size != MemoryLocation::UnknownSize) { if ((uint64_t)GEP1BaseOffset < V2Size) @@ -1074,22 +1166,22 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, } } - if (!GEP1VariableIndices.empty()) { + if (!DecompGEP1.VarIndices.empty()) { uint64_t Modulo = 0; bool AllPositive = true; - for (unsigned i = 0, e = GEP1VariableIndices.size(); i != e; ++i) { + for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) { // Try to distinguish something like &A[i][1] against &A[42][0]. // Grab the least significant bit set in any of the scales. We // don't need std::abs here (even if the scale's negative) as we'll // be ^'ing Modulo with itself later. - Modulo |= (uint64_t)GEP1VariableIndices[i].Scale; + Modulo |= (uint64_t)DecompGEP1.VarIndices[i].Scale; if (AllPositive) { // If the Value could change between cycles, then any reasoning about // the Value this cycle may not hold in the next cycle. We'll just // give up if we can't determine conditions that hold for every cycle: - const Value *V = GEP1VariableIndices[i].V; + const Value *V = DecompGEP1.VarIndices[i].V; bool SignKnownZero, SignKnownOne; ComputeSignBit(const_cast<Value *>(V), SignKnownZero, SignKnownOne, DL, @@ -1097,14 +1189,14 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, // Zero-extension widens the variable, and so forces the sign // bit to zero. - bool IsZExt = GEP1VariableIndices[i].ZExtBits > 0 || isa<ZExtInst>(V); + bool IsZExt = DecompGEP1.VarIndices[i].ZExtBits > 0 || isa<ZExtInst>(V); SignKnownZero |= IsZExt; SignKnownOne &= !IsZExt; // If the variable begins with a zero then we know it's // positive, regardless of whether the value is signed or // unsigned. - int64_t Scale = GEP1VariableIndices[i].Scale; + int64_t Scale = DecompGEP1.VarIndices[i].Scale; AllPositive = (SignKnownZero && Scale >= 0) || (SignKnownOne && Scale < 0); } @@ -1127,7 +1219,7 @@ AliasResult BasicAAResult::aliasGEP(const GEPOperator *GEP1, uint64_t V1Size, if (AllPositive && GEP1BaseOffset > 0 && V2Size <= (uint64_t)GEP1BaseOffset) return NoAlias; - if (constantOffsetHeuristic(GEP1VariableIndices, V1Size, V2Size, + if (constantOffsetHeuristic(DecompGEP1.VarIndices, V1Size, V2Size, GEP1BaseOffset, &AC, DT)) return NoAlias; } @@ -1312,7 +1404,7 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, uint64_t V1Size, return NoAlias; // Are we checking for alias of the same value? - // Because we look 'through' phi nodes we could look at "Value" pointers from + // Because we look 'through' phi nodes, we could look at "Value" pointers from // different iterations. We must therefore make sure that this is not the // case. 
The function isValueEqualInPotentialCycles ensures that this cannot // happen by looking at the visited phi nodes and making sure they cannot @@ -1337,7 +1429,7 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, uint64_t V1Size, return NoAlias; if (O1 != O2) { - // If V1/V2 point to two different objects we know that we have no alias. + // If V1/V2 point to two different objects, we know that we have no alias. if (isIdentifiedObject(O1) && isIdentifiedObject(O2)) return NoAlias; @@ -1430,8 +1522,7 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, uint64_t V1Size, } // If both pointers are pointing into the same object and one of them - // accesses is accessing the entire object, then the accesses must - // overlap in some way. + // accesses the entire object, then the accesses must overlap in some way. if (O1 == O2) if ((V1Size != MemoryLocation::UnknownSize && isObjectSize(O1, V1Size, DL, TLI)) || @@ -1544,7 +1635,8 @@ bool BasicAAResult::constantOffsetHeuristic( unsigned V0ZExtBits = 0, V0SExtBits = 0, V1ZExtBits = 0, V1SExtBits = 0; const Value *V0 = GetLinearExpression(Var0.V, V0Scale, V0Offset, V0ZExtBits, V0SExtBits, DL, 0, AC, DT, NSW, NUW); - NSW = true, NUW = true; + NSW = true; + NUW = true; const Value *V1 = GetLinearExpression(Var1.V, V1Scale, V1Offset, V1ZExtBits, V1SExtBits, DL, 0, AC, DT, NSW, NUW); @@ -1577,12 +1669,12 @@ bool BasicAAResult::constantOffsetHeuristic( char BasicAA::PassID; -BasicAAResult BasicAA::run(Function &F, AnalysisManager<Function> *AM) { +BasicAAResult BasicAA::run(Function &F, AnalysisManager<Function> &AM) { return BasicAAResult(F.getParent()->getDataLayout(), - AM->getResult<TargetLibraryAnalysis>(F), - AM->getResult<AssumptionAnalysis>(F), - AM->getCachedResult<DominatorTreeAnalysis>(F), - AM->getCachedResult<LoopAnalysis>(F)); + AM.getResult<TargetLibraryAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F), + &AM.getResult<DominatorTreeAnalysis>(F), + AM.getCachedResult<LoopAnalysis>(F)); } BasicAAWrapperPass::BasicAAWrapperPass() : FunctionPass(ID) { @@ -1595,6 +1687,7 @@ void BasicAAWrapperPass::anchor() {} INITIALIZE_PASS_BEGIN(BasicAAWrapperPass, "basicaa", "Basic Alias Analysis (stateless AA impl)", true, true) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(BasicAAWrapperPass, "basicaa", "Basic Alias Analysis (stateless AA impl)", true, true) @@ -1606,12 +1699,11 @@ FunctionPass *llvm::createBasicAAWrapperPass() { bool BasicAAWrapperPass::runOnFunction(Function &F) { auto &ACT = getAnalysis<AssumptionCacheTracker>(); auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); - auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto &DTWP = getAnalysis<DominatorTreeWrapperPass>(); auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); Result.reset(new BasicAAResult(F.getParent()->getDataLayout(), TLIWP.getTLI(), - ACT.getAssumptionCache(F), - DTWP ? &DTWP->getDomTree() : nullptr, + ACT.getAssumptionCache(F), &DTWP.getDomTree(), LIWP ? 
&LIWP->getLoopInfo() : nullptr)); return false; @@ -1620,6 +1712,7 @@ bool BasicAAWrapperPass::runOnFunction(Function &F) { void BasicAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } diff --git a/lib/Analysis/BlockFrequencyInfo.cpp b/lib/Analysis/BlockFrequencyInfo.cpp index 90b7a339a0fe..1dd8f4fdfcfe 100644 --- a/lib/Analysis/BlockFrequencyInfo.cpp +++ b/lib/Analysis/BlockFrequencyInfo.cpp @@ -27,24 +27,34 @@ using namespace llvm; #define DEBUG_TYPE "block-freq" #ifndef NDEBUG -enum GVDAGType { - GVDT_None, - GVDT_Fraction, - GVDT_Integer -}; - -static cl::opt<GVDAGType> -ViewBlockFreqPropagationDAG("view-block-freq-propagation-dags", cl::Hidden, - cl::desc("Pop up a window to show a dag displaying how block " - "frequencies propagation through the CFG."), - cl::values( - clEnumValN(GVDT_None, "none", - "do not display graphs."), - clEnumValN(GVDT_Fraction, "fraction", "display a graph using the " - "fractional block frequency representation."), - clEnumValN(GVDT_Integer, "integer", "display a graph using the raw " - "integer fractional block frequency representation."), - clEnumValEnd)); +static cl::opt<GVDAGType> ViewBlockFreqPropagationDAG( + "view-block-freq-propagation-dags", cl::Hidden, + cl::desc("Pop up a window to show a dag displaying how block " + "frequencies propagation through the CFG."), + cl::values(clEnumValN(GVDT_None, "none", "do not display graphs."), + clEnumValN(GVDT_Fraction, "fraction", + "display a graph using the " + "fractional block frequency representation."), + clEnumValN(GVDT_Integer, "integer", + "display a graph using the raw " + "integer fractional block frequency representation."), + clEnumValN(GVDT_Count, "count", "display a graph using the real " + "profile count if available."), + clEnumValEnd)); + +cl::opt<std::string> + ViewBlockFreqFuncName("view-bfi-func-name", cl::Hidden, + cl::desc("The option to specify " + "the name of the function " + "whose CFG will be displayed.")); + +cl::opt<unsigned> + ViewHotFreqPercent("view-hot-freq-percent", cl::init(10), cl::Hidden, + cl::desc("An integer in percent used to specify " + "the hot blocks/edges to be displayed " + "in red: a block or edge whose frequency " + "is no less than the max frequency of the " + "function multiplied by this percent.")); namespace llvm { @@ -71,34 +81,31 @@ struct GraphTraits<BlockFrequencyInfo *> { } }; -template<> -struct DOTGraphTraits<BlockFrequencyInfo*> : public DefaultDOTGraphTraits { - explicit DOTGraphTraits(bool isSimple=false) : - DefaultDOTGraphTraits(isSimple) {} +typedef BFIDOTGraphTraitsBase<BlockFrequencyInfo, BranchProbabilityInfo> + BFIDOTGTraitsBase; - static std::string getGraphName(const BlockFrequencyInfo *G) { - return G->getFunction()->getName(); - } +template <> +struct DOTGraphTraits<BlockFrequencyInfo *> : public BFIDOTGTraitsBase { + explicit DOTGraphTraits(bool isSimple = false) + : BFIDOTGTraitsBase(isSimple) {} std::string getNodeLabel(const BasicBlock *Node, const BlockFrequencyInfo *Graph) { - std::string Result; - raw_string_ostream OS(Result); - - OS << Node->getName() << ":"; - switch (ViewBlockFreqPropagationDAG) { - case GVDT_Fraction: - Graph->printBlockFreq(OS, Node); - break; - case GVDT_Integer: - OS << Graph->getBlockFreq(Node).getFrequency(); - break; - case GVDT_None: - llvm_unreachable("If we are not supposed to render a graph we should " - "never reach this point."); 
- } - - return Result; + + return BFIDOTGTraitsBase::getNodeLabel(Node, Graph, + ViewBlockFreqPropagationDAG); + } + + std::string getNodeAttributes(const BasicBlock *Node, + const BlockFrequencyInfo *Graph) { + return BFIDOTGTraitsBase::getNodeAttributes(Node, Graph, + ViewHotFreqPercent); + } + + std::string getEdgeAttributes(const BasicBlock *Node, EdgeIter EI, + const BlockFrequencyInfo *BFI) { + return BFIDOTGTraitsBase::getEdgeAttributes(Node, EI, BFI, BFI->getBPI(), + ViewHotFreqPercent); } }; @@ -113,6 +120,21 @@ BlockFrequencyInfo::BlockFrequencyInfo(const Function &F, calculate(F, BPI, LI); } +BlockFrequencyInfo::BlockFrequencyInfo(BlockFrequencyInfo &&Arg) + : BFI(std::move(Arg.BFI)) {} + +BlockFrequencyInfo &BlockFrequencyInfo::operator=(BlockFrequencyInfo &&RHS) { + releaseMemory(); + BFI = std::move(RHS.BFI); + return *this; +} + +// Explicitly define the default constructor otherwise it would be implicitly +// defined at the first ODR-use which is the BFI member in the +// LazyBlockFrequencyInfo header. The dtor needs the BlockFrequencyInfoImpl +// template instantiated which is not available in the header. +BlockFrequencyInfo::~BlockFrequencyInfo() {} + void BlockFrequencyInfo::calculate(const Function &F, const BranchProbabilityInfo &BPI, const LoopInfo &LI) { @@ -120,8 +142,11 @@ void BlockFrequencyInfo::calculate(const Function &F, BFI.reset(new ImplType); BFI->calculate(F, BPI, LI); #ifndef NDEBUG - if (ViewBlockFreqPropagationDAG != GVDT_None) + if (ViewBlockFreqPropagationDAG != GVDT_None && + (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { view(); + } #endif } @@ -129,8 +154,15 @@ BlockFrequency BlockFrequencyInfo::getBlockFreq(const BasicBlock *BB) const { return BFI ? BFI->getBlockFreq(BB) : 0; } -void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, - uint64_t Freq) { +Optional<uint64_t> +BlockFrequencyInfo::getBlockProfileCount(const BasicBlock *BB) const { + if (!BFI) + return None; + + return BFI->getBlockProfileCount(*getFunction(), BB); +} + +void BlockFrequencyInfo::setBlockFreq(const BasicBlock *BB, uint64_t Freq) { assert(BFI && "Expected analysis to be available"); BFI->setBlockFreq(BB, Freq); } @@ -151,6 +183,10 @@ const Function *BlockFrequencyInfo::getFunction() const { return BFI ? BFI->getFunction() : nullptr; } +const BranchProbabilityInfo *BlockFrequencyInfo::getBPI() const { + return BFI ? &BFI->getBPI() : nullptr; +} + raw_ostream &BlockFrequencyInfo:: printBlockFreq(raw_ostream &OS, const BlockFrequency Freq) const { return BFI ? 
BFI->printBlockFreq(OS, Freq) : OS; @@ -211,3 +247,21 @@ bool BlockFrequencyInfoWrapperPass::runOnFunction(Function &F) { BFI.calculate(F, BPI, LI); return false; } + +char BlockFrequencyAnalysis::PassID; +BlockFrequencyInfo BlockFrequencyAnalysis::run(Function &F, + AnalysisManager<Function> &AM) { + BlockFrequencyInfo BFI; + BFI.calculate(F, AM.getResult<BranchProbabilityAnalysis>(F), + AM.getResult<LoopAnalysis>(F)); + return BFI; +} + +PreservedAnalyses +BlockFrequencyPrinterPass::run(Function &F, AnalysisManager<Function> &AM) { + OS << "Printing analysis results of BFI for function " + << "'" << F.getName() << "':" + << "\n"; + AM.getResult<BlockFrequencyAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} diff --git a/lib/Analysis/BlockFrequencyInfoImpl.cpp b/lib/Analysis/BlockFrequencyInfoImpl.cpp index 48e23af2690a..90bc249bcb39 100644 --- a/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -13,6 +13,7 @@ #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/IR/Function.h" #include "llvm/Support/raw_ostream.h" #include <numeric> @@ -27,7 +28,7 @@ ScaledNumber<uint64_t> BlockMass::toScaled() const { return ScaledNumber<uint64_t>(getMass() + 1, -64); } -void BlockMass::dump() const { print(dbgs()); } +LLVM_DUMP_METHOD void BlockMass::dump() const { print(dbgs()); } static char getHexDigit(int N) { assert(N < 16); @@ -35,6 +36,7 @@ static char getHexDigit(int N) { return '0' + N; return 'a' + N - 10; } + raw_ostream &BlockMass::print(raw_ostream &OS) const { for (int Digits = 0; Digits < 16; ++Digits) OS << getHexDigit(Mass >> (60 - Digits * 4) & 0xf); @@ -78,7 +80,7 @@ struct DitheringDistributer { BlockMass takeMass(uint32_t Weight); }; -} // end namespace +} // end anonymous namespace DitheringDistributer::DitheringDistributer(Distribution &Dist, const BlockMass &Mass) { @@ -130,6 +132,7 @@ static void combineWeight(Weight &W, const Weight &OtherW) { else W.Amount += OtherW.Amount; } + static void combineWeightsBySorting(WeightList &Weights) { // Sort so edges to the same node are adjacent. std::sort(Weights.begin(), Weights.end(), @@ -149,8 +152,8 @@ static void combineWeightsBySorting(WeightList &Weights) { // Erase extra entries. Weights.erase(O, Weights.end()); - return; } + static void combineWeightsByHashing(WeightList &Weights) { // Collect weights into a DenseMap. typedef DenseMap<BlockNode::IndexType, Weight> HashTable; @@ -168,6 +171,7 @@ static void combineWeightsByHashing(WeightList &Weights) { for (const auto &I : Combined) Weights.push_back(I.second); } + static void combineWeights(WeightList &Weights) { // Use a hash table for many successors to keep this linear. if (Weights.size() > 128) { @@ -177,6 +181,7 @@ static void combineWeights(WeightList &Weights) { combineWeightsBySorting(Weights); } + static uint64_t shiftRightAndRound(uint64_t N, int Shift) { assert(Shift >= 0); assert(Shift < 64); @@ -184,6 +189,7 @@ static uint64_t shiftRightAndRound(uint64_t N, int Shift) { return N; return (N >> Shift) + (UINT64_C(1) & N >> (Shift - 1)); } + void Distribution::normalize() { // Early exit for termination nodes. if (Weights.empty()) @@ -345,7 +351,7 @@ void BlockFrequencyInfoImplBase::computeLoopScale(LoopData &Loop) { // replaced to be the maximum of all computed scales plus 1. This would // appropriately describe the loop as having a large scale, without skewing // the final frequency computation. 
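To make the loop-scale computation above concrete, here is a worked floating-point model (an illustration only; the code that follows uses Scaled64 fixed-point arithmetic, and the 4096 below mirrors Scaled64(1, 12), i.e. 1 * 2^12):

#include <cstdio>

// Model: block mass is a fraction in [0, 1]; the loop scale is the expected
// trip count, 1 / ExitMass, with ExitMass == HeadMass - BackedgeMass.
static double loopScale(double HeadMass, double BackedgeMass) {
  const double InfiniteLoopScale = 4096.0; // stands in for Scaled64(1, 12)
  double ExitMass = HeadMass - BackedgeMass;
  return ExitMass <= 0.0 ? InfiniteLoopScale : 1.0 / ExitMass;
}

int main() {
  // A loop that keeps 7/8 of its mass on the backedge runs about 8 times.
  std::printf("%g\n", loopScale(1.0, 0.875)); // prints 8
  // A loop whose backedge keeps all the mass gets the arbitrary large scale.
  std::printf("%g\n", loopScale(1.0, 1.0));   // prints 4096
  return 0;
}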
- const Scaled64 InifiniteLoopScale(1, 12); + const Scaled64 InfiniteLoopScale(1, 12); // LoopScale == 1 / ExitMass // ExitMass == HeadMass - BackedgeMass @@ -358,7 +364,7 @@ void BlockFrequencyInfoImplBase::computeLoopScale(LoopData &Loop) { // its exit mass will be zero. In this case, use an arbitrary scale for the // loop scale. Loop.Scale = - ExitMass.isEmpty() ? InifiniteLoopScale : ExitMass.toScaled().inverse(); + ExitMass.isEmpty() ? InfiniteLoopScale : ExitMass.toScaled().inverse(); DEBUG(dbgs() << " - exit-mass = " << ExitMass << " (" << BlockMass::getFull() << " - " << TotalBackedgeMass << ")\n" @@ -523,6 +529,22 @@ BlockFrequencyInfoImplBase::getBlockFreq(const BlockNode &Node) const { return 0; return Freqs[Node.Index].Integer; } + +Optional<uint64_t> +BlockFrequencyInfoImplBase::getBlockProfileCount(const Function &F, + const BlockNode &Node) const { + auto EntryCount = F.getEntryCount(); + if (!EntryCount) + return None; + // Use a 128-bit APInt to do the arithmetic to avoid overflow. + APInt BlockCount(128, EntryCount.getValue()); + APInt BlockFreq(128, getBlockFreq(Node).getFrequency()); + APInt EntryFreq(128, getEntryFreq()); + BlockCount *= BlockFreq; + BlockCount = BlockCount.udiv(EntryFreq); + return BlockCount.getLimitedValue(); +} + Scaled64 BlockFrequencyInfoImplBase::getFloatingBlockFreq(const BlockNode &Node) const { if (!Node.isValid()) @@ -541,6 +563,7 @@ std::string BlockFrequencyInfoImplBase::getBlockName(const BlockNode &Node) const { return std::string(); } + std::string BlockFrequencyInfoImplBase::getLoopName(const LoopData &Loop) const { return getBlockName(Loop.getHeader()) + (Loop.isIrreducible() ? "**" : "*"); @@ -568,6 +591,7 @@ void IrreducibleGraph::addNodesInLoop(const BFIBase::LoopData &OuterLoop) { addNode(N); indexNodes(); } + void IrreducibleGraph::addNodesInFunction() { Start = 0; for (uint32_t Index = 0; Index < BFI.Working.size(); ++Index) @@ -575,10 +599,12 @@ void IrreducibleGraph::addNodesInFunction() { addNode(Index); indexNodes(); } + void IrreducibleGraph::indexNodes() { for (auto &I : Nodes) Lookup[I.Node.Index] = &I; } + void IrreducibleGraph::addEdge(IrrNode &Irr, const BlockNode &Succ, const BFIBase::LoopData *OuterLoop) { if (OuterLoop && OuterLoop->isHeader(Succ)) @@ -605,7 +631,7 @@ template <> struct GraphTraits<IrreducibleGraph> { static ChildIteratorType child_begin(NodeType *N) { return N->succ_begin(); } static ChildIteratorType child_end(NodeType *N) { return N->succ_end(); } }; -} +} // end namespace llvm /// \brief Find extra irreducible headers. /// diff --git a/lib/Analysis/BranchProbabilityInfo.cpp b/lib/Analysis/BranchProbabilityInfo.cpp index cf0cc8da6ef8..d802552d4e29 100644 --- a/lib/Analysis/BranchProbabilityInfo.cpp +++ b/lib/Analysis/BranchProbabilityInfo.cpp @@ -112,10 +112,15 @@ static const uint32_t IH_NONTAKEN_WEIGHT = 1; /// /// Predict that a successor which leads necessarily to an /// unreachable-terminated block is extremely unlikely. -bool BranchProbabilityInfo::calcUnreachableHeuristics(BasicBlock *BB) { - TerminatorInst *TI = BB->getTerminator(); +bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 0) { - if (isa<UnreachableInst>(TI)) + if (isa<UnreachableInst>(TI) || + // If this block is terminated by a call to + // @llvm.experimental.deoptimize, then treat it like an unreachable since + // the @llvm.experimental.deoptimize call is expected to practically + // never execute.
+ BB->getTerminatingDeoptimizeCall()) PostDominatedByUnreachable.insert(BB); return false; } @@ -123,7 +128,7 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(BasicBlock *BB) { SmallVector<unsigned, 4> UnreachableEdges; SmallVector<unsigned, 4> ReachableEdges; - for (succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { if (PostDominatedByUnreachable.count(*I)) UnreachableEdges.push_back(I.getSuccessorIndex()); else @@ -174,8 +179,8 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(BasicBlock *BB) { // Propagate existing explicit probabilities from either profile data or // 'expect' intrinsic processing. -bool BranchProbabilityInfo::calcMetadataWeights(BasicBlock *BB) { - TerminatorInst *TI = BB->getTerminator(); +bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 1) return false; if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) @@ -244,15 +249,15 @@ bool BranchProbabilityInfo::calcMetadataWeights(BasicBlock *BB) { /// /// Return true if we could compute the weights for cold edges. /// Return false otherwise. -bool BranchProbabilityInfo::calcColdCallHeuristics(BasicBlock *BB) { - TerminatorInst *TI = BB->getTerminator(); +bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 0) return false; // Determine which successors are post-dominated by a cold block. SmallVector<unsigned, 4> ColdEdges; SmallVector<unsigned, 4> NormalEdges; - for (succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) if (PostDominatedByColdCall.count(*I)) ColdEdges.push_back(I.getSuccessorIndex()); else @@ -266,8 +271,8 @@ bool BranchProbabilityInfo::calcColdCallHeuristics(BasicBlock *BB) { // Otherwise, if the block itself contains a call to a cold function, add it to the // set of blocks postdominated by a cold call. assert(!PostDominatedByColdCall.count(BB)); - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - if (CallInst *CI = dyn_cast<CallInst>(I)) + for (BasicBlock::const_iterator I = BB->begin(), E = BB->end(); I != E; ++I) + if (const CallInst *CI = dyn_cast<CallInst>(I)) if (CI->hasFnAttr(Attribute::Cold)) { PostDominatedByColdCall.insert(BB); break; @@ -302,8 +307,8 @@ bool BranchProbabilityInfo::calcColdCallHeuristics(BasicBlock *BB) { // Calculate Edge Weights using "Pointer Heuristics". Predict that a comparison // between two pointers, or a pointer and NULL, will fail. -bool BranchProbabilityInfo::calcPointerHeuristics(BasicBlock *BB) { - BranchInst * BI = dyn_cast<BranchInst>(BB->getTerminator()); +bool BranchProbabilityInfo::calcPointerHeuristics(const BasicBlock *BB) { + const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); if (!BI || !BI->isConditional()) return false; @@ -337,7 +342,7 @@ bool BranchProbabilityInfo::calcPointerHeuristics(BasicBlock *BB) { // Calculate Edge Weights using "Loop Branch Heuristics". Predict backedges // as taken, exiting edges as not-taken.
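A standalone sketch of the normalization calcLoopBranchHeuristics (below) performs: each edge category gets a fixed weight, and that weight is split evenly among the edges in the category (the weight values here are invented placeholders, not LLVM's LBH_* constants):

#include <cstdio>

int main() {
  const unsigned TakenWeight = 124, NotTakenWeight = 4; // hypothetical values
  // Suppose a loop header has one backedge and two exiting edges.
  const unsigned NumBack = 1, NumExit = 2;
  unsigned Denom = (NumBack ? TakenWeight : 0) + (NumExit ? NotTakenWeight : 0);
  // Split each category's share evenly among its edges.
  double BackProb = (double)TakenWeight / Denom / NumBack;
  double ExitProb = (double)NotTakenWeight / Denom / NumExit;
  // 1*BackProb + 2*ExitProb == 1, so the successor probabilities sum to one.
  std::printf("backedge %.5f, each exit %.5f\n", BackProb, ExitProb);
  return 0;
}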
-bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB, +bool BranchProbabilityInfo::calcLoopBranchHeuristics(const BasicBlock *BB, const LoopInfo &LI) { Loop *L = LI.getLoopFor(BB); if (!L) @@ -347,7 +352,7 @@ bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB, SmallVector<unsigned, 8> ExitingEdges; SmallVector<unsigned, 8> InEdges; // Edges from header to the loop. - for (succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { if (!L->contains(*I)) ExitingEdges.push_back(I.getSuccessorIndex()); else if (L->getHeader() == *I) @@ -361,7 +366,9 @@ bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB, // Collect the sum of probabilities of back-edges/in-edges/exiting-edges, and // normalize them so that they sum up to one. - SmallVector<BranchProbability, 4> Probs(3, BranchProbability::getZero()); + BranchProbability Probs[] = {BranchProbability::getZero(), + BranchProbability::getZero(), + BranchProbability::getZero()}; unsigned Denom = (BackEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + (InEdges.empty() ? 0 : LBH_TAKEN_WEIGHT) + (ExitingEdges.empty() ? 0 : LBH_NONTAKEN_WEIGHT); @@ -393,8 +400,8 @@ bool BranchProbabilityInfo::calcLoopBranchHeuristics(BasicBlock *BB, return true; } -bool BranchProbabilityInfo::calcZeroHeuristics(BasicBlock *BB) { - BranchInst * BI = dyn_cast<BranchInst>(BB->getTerminator()); +bool BranchProbabilityInfo::calcZeroHeuristics(const BasicBlock *BB) { + const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); if (!BI || !BI->isConditional()) return false; @@ -476,8 +483,8 @@ bool BranchProbabilityInfo::calcZeroHeuristics(BasicBlock *BB) { return true; } -bool BranchProbabilityInfo::calcFloatingPointHeuristics(BasicBlock *BB) { - BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); +bool BranchProbabilityInfo::calcFloatingPointHeuristics(const BasicBlock *BB) { + const BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator()); if (!BI || !BI->isConditional()) return false; @@ -513,8 +520,8 @@ bool BranchProbabilityInfo::calcFloatingPointHeuristics(BasicBlock *BB) { return true; } -bool BranchProbabilityInfo::calcInvokeHeuristics(BasicBlock *BB) { - InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()); +bool BranchProbabilityInfo::calcInvokeHeuristics(const BasicBlock *BB) { + const InvokeInst *II = dyn_cast<InvokeInst>(BB->getTerminator()); if (!II) return false; @@ -549,12 +556,13 @@ isEdgeHot(const BasicBlock *Src, const BasicBlock *Dst) const { return getEdgeProbability(Src, Dst) > BranchProbability(4, 5); } -BasicBlock *BranchProbabilityInfo::getHotSucc(BasicBlock *BB) const { +const BasicBlock * +BranchProbabilityInfo::getHotSucc(const BasicBlock *BB) const { auto MaxProb = BranchProbability::getZero(); - BasicBlock *MaxSucc = nullptr; + const BasicBlock *MaxSucc = nullptr; - for (succ_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { - BasicBlock *Succ = *I; + for (succ_const_iterator I = succ_begin(BB), E = succ_end(BB); I != E; ++I) { + const BasicBlock *Succ = *I; auto Prob = getEdgeProbability(BB, Succ); if (Prob > MaxProb) { MaxProb = Prob; @@ -616,6 +624,7 @@ void BranchProbabilityInfo::setEdgeProbability(const BasicBlock *Src, unsigned IndexInSuccessors, BranchProbability Prob) { Probs[std::make_pair(Src, IndexInSuccessors)] = Prob; + Handles.insert(BasicBlockCallbackVH(Src, this)); DEBUG(dbgs() << "set edge " << Src->getName() << " -> " << IndexInSuccessors << " successor probability to " 
<< Prob << "\n"); } @@ -633,7 +642,15 @@ BranchProbabilityInfo::printEdgeProbability(raw_ostream &OS, return OS; } -void BranchProbabilityInfo::calculate(Function &F, const LoopInfo& LI) { +void BranchProbabilityInfo::eraseBlock(const BasicBlock *BB) { + for (auto I = Probs.begin(), E = Probs.end(); I != E; ++I) { + auto Key = I->first; + if (Key.first == BB) + Probs.erase(Key); + } +} + +void BranchProbabilityInfo::calculate(const Function &F, const LoopInfo &LI) { DEBUG(dbgs() << "---- Branch Probability Info : " << F.getName() << " ----\n\n"); LastF = &F; // Store the last function we ran on for printing. @@ -683,3 +700,20 @@ void BranchProbabilityInfoWrapperPass::print(raw_ostream &OS, const Module *) const { BPI.print(OS); } + +char BranchProbabilityAnalysis::PassID; +BranchProbabilityInfo +BranchProbabilityAnalysis::run(Function &F, AnalysisManager<Function> &AM) { + BranchProbabilityInfo BPI; + BPI.calculate(F, AM.getResult<LoopAnalysis>(F)); + return BPI; +} + +PreservedAnalyses +BranchProbabilityPrinterPass::run(Function &F, AnalysisManager<Function> &AM) { + OS << "Printing analysis results of BPI for function " + << "'" << F.getName() << "':" + << "\n"; + AM.getResult<BranchProbabilityAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} diff --git a/lib/Analysis/CFG.cpp b/lib/Analysis/CFG.cpp index 0dfd57d3cb6b..a319be8092f9 100644 --- a/lib/Analysis/CFG.cpp +++ b/lib/Analysis/CFG.cpp @@ -138,7 +138,7 @@ bool llvm::isPotentiallyReachableFromMany( // Limit the number of blocks we visit. The goal is to avoid run-away compile // times on large CFGs without hampering sensible code. Arbitrarily chosen. unsigned Limit = 32; - SmallSet<const BasicBlock*, 64> Visited; + SmallPtrSet<const BasicBlock*, 32> Visited; do { BasicBlock *BB = Worklist.pop_back_val(); if (!Visited.insert(BB).second) diff --git a/lib/Analysis/CFLAliasAnalysis.cpp b/lib/Analysis/CFLAliasAnalysis.cpp deleted file mode 100644 index 4843ed6587a8..000000000000 --- a/lib/Analysis/CFLAliasAnalysis.cpp +++ /dev/null @@ -1,1119 +0,0 @@ -//===- CFLAliasAnalysis.cpp - CFL-Based Alias Analysis Implementation ------==// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file implements a CFL-based context-insensitive alias analysis -// algorithm. It does not depend on types. The algorithm is a mixture of the one -// described in "Demand-driven alias analysis for C" by Xin Zheng and Radu -// Rugina, and "Fast algorithms for Dyck-CFL-reachability with applications to -// Alias Analysis" by Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the -// papers, we build a graph of the uses of a variable, where each node is a -// memory location, and each edge is an action that happened on that memory -// location. The "actions" can be one of Dereference, Reference, or Assign. -// -// Two variables are considered as aliasing iff you can reach one value's node -// from the other value's node and the language formed by concatenating all of -// the edge labels (actions) conforms to a context-free grammar. -// -// Because this algorithm requires a graph search on each query, we execute the -// algorithm outlined in "Fast algorithms..." (mentioned above) -// in order to transform the graph into sets of variables that may alias in -// ~nlogn time (n = number of variables.), which makes queries take constant -// time. 
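For context on the deleted implementation that follows, the set-building idea sketched in the header comment above reduces may-alias queries to set membership. A minimal union-find illustration (invented for this note; the real code used StratifiedSets, which also track direction and attribute information):

#include <numeric>
#include <vector>

struct AliasSets {
  std::vector<int> Parent; // one entry per tracked value
  explicit AliasSets(int N) : Parent(N) {
    std::iota(Parent.begin(), Parent.end(), 0);
  }
  // Find the set representative, compressing paths as we go.
  int find(int X) { return Parent[X] == X ? X : Parent[X] = find(Parent[X]); }
  // Merge the sets of A and B, e.g. on an assignment edge between them.
  void unify(int A, int B) { Parent[find(A)] = find(B); }
  // After preprocessing, a may-alias query is a near-constant-time check.
  bool mayAlias(int A, int B) { return find(A) == find(B); }
};

int main() {
  AliasSets S(3);
  S.unify(0, 1); // values 0 and 1 were connected by an assignment
  return (S.mayAlias(0, 1) && !S.mayAlias(0, 2)) ? 0 : 1;
}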
-//===----------------------------------------------------------------------===// - -#include "llvm/Analysis/CFLAliasAnalysis.h" -#include "StratifiedSets.h" -#include "llvm/ADT/BitVector.h" -#include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Function.h" -#include "llvm/IR/InstVisitor.h" -#include "llvm/IR/Instructions.h" -#include "llvm/Pass.h" -#include "llvm/Support/Allocator.h" -#include "llvm/Support/Compiler.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/raw_ostream.h" -#include <algorithm> -#include <cassert> -#include <memory> -#include <tuple> - -using namespace llvm; - -#define DEBUG_TYPE "cfl-aa" - -CFLAAResult::CFLAAResult(const TargetLibraryInfo &TLI) : AAResultBase(TLI) {} -CFLAAResult::CFLAAResult(CFLAAResult &&Arg) : AAResultBase(std::move(Arg)) {} - -// \brief Information we have about a function and would like to keep around -struct CFLAAResult::FunctionInfo { - StratifiedSets<Value *> Sets; - // Lots of functions have < 4 returns. Adjust as necessary. - SmallVector<Value *, 4> ReturnedValues; - - FunctionInfo(StratifiedSets<Value *> &&S, SmallVector<Value *, 4> &&RV) - : Sets(std::move(S)), ReturnedValues(std::move(RV)) {} -}; - -// Try to go from a Value* to a Function*. Never returns nullptr. -static Optional<Function *> parentFunctionOfValue(Value *); - -// Returns possible functions called by the Inst* into the given -// SmallVectorImpl. Returns true if targets found, false otherwise. -// This is templated because InvokeInst/CallInst give us the same -// set of functions that we care about, and I don't like repeating -// myself. -template <typename Inst> -static bool getPossibleTargets(Inst *, SmallVectorImpl<Function *> &); - -// Some instructions need to have their users tracked. Instructions like -// `add` require you to get the users of the Instruction* itself, other -// instructions like `store` require you to get the users of the first -// operand. This function gets the "proper" value to track for each -// type of instruction we support. -static Optional<Value *> getTargetValue(Instruction *); - -// There are certain instructions (i.e. FenceInst, etc.) that we ignore. -// This notes that we should ignore those. -static bool hasUsefulEdges(Instruction *); - -const StratifiedIndex StratifiedLink::SetSentinel = - std::numeric_limits<StratifiedIndex>::max(); - -namespace { -// StratifiedInfo Attribute things. -typedef unsigned StratifiedAttr; -LLVM_CONSTEXPR unsigned MaxStratifiedAttrIndex = NumStratifiedAttrs; -LLVM_CONSTEXPR unsigned AttrAllIndex = 0; -LLVM_CONSTEXPR unsigned AttrGlobalIndex = 1; -LLVM_CONSTEXPR unsigned AttrUnknownIndex = 2; -LLVM_CONSTEXPR unsigned AttrFirstArgIndex = 3; -LLVM_CONSTEXPR unsigned AttrLastArgIndex = MaxStratifiedAttrIndex; -LLVM_CONSTEXPR unsigned AttrMaxNumArgs = AttrLastArgIndex - AttrFirstArgIndex; - -LLVM_CONSTEXPR StratifiedAttr AttrNone = 0; -LLVM_CONSTEXPR StratifiedAttr AttrUnknown = 1 << AttrUnknownIndex; -LLVM_CONSTEXPR StratifiedAttr AttrAll = ~AttrNone; - -// \brief StratifiedSets call for knowledge of "direction", so this is how we -// represent that locally. -enum class Level { Same, Above, Below }; - -// \brief Edges can be one of four "weights" -- each weight must have an inverse -// weight (Assign has Assign; Reference has Dereference). -enum class EdgeType { - // The weight assigned when assigning from or to a value. 
For example, in: - // %b = getelementptr %a, 0 - // ...The relationships are %b assign %a, and %a assign %b. This used to be - // two edges, but having a distinction bought us nothing. - Assign, - - // The edge used when we have an edge going from some handle to a Value. - // Examples of this include: - // %b = load %a (%b Dereference %a) - // %b = extractelement %a, 0 (%a Dereference %b) - Dereference, - - // The edge used when our edge goes from a value to a handle that may have - // contained it at some point. Examples: - // %b = load %a (%a Reference %b) - // %b = extractelement %a, 0 (%b Reference %a) - Reference -}; - -// \brief Encodes the notion of a "use" -struct Edge { - // \brief Which value the edge is coming from - Value *From; - - // \brief Which value the edge is pointing to - Value *To; - - // \brief Edge weight - EdgeType Weight; - - // \brief Whether we aliased any external values along the way that may be - // invisible to the analysis (i.e. landingpad for exceptions, calls for - // interprocedural analysis, etc.) - StratifiedAttrs AdditionalAttrs; - - Edge(Value *From, Value *To, EdgeType W, StratifiedAttrs A) - : From(From), To(To), Weight(W), AdditionalAttrs(A) {} -}; - -// \brief Gets the edges our graph should have, based on an Instruction* -class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> { - CFLAAResult &AA; - SmallVectorImpl<Edge> &Output; - -public: - GetEdgesVisitor(CFLAAResult &AA, SmallVectorImpl<Edge> &Output) - : AA(AA), Output(Output) {} - - void visitInstruction(Instruction &) { - llvm_unreachable("Unsupported instruction encountered"); - } - - void visitPtrToIntInst(PtrToIntInst &Inst) { - auto *Ptr = Inst.getOperand(0); - Output.push_back(Edge(Ptr, Ptr, EdgeType::Assign, AttrUnknown)); - } - - void visitIntToPtrInst(IntToPtrInst &Inst) { - auto *Ptr = &Inst; - Output.push_back(Edge(Ptr, Ptr, EdgeType::Assign, AttrUnknown)); - } - - void visitCastInst(CastInst &Inst) { - Output.push_back( - Edge(&Inst, Inst.getOperand(0), EdgeType::Assign, AttrNone)); - } - - void visitBinaryOperator(BinaryOperator &Inst) { - auto *Op1 = Inst.getOperand(0); - auto *Op2 = Inst.getOperand(1); - Output.push_back(Edge(&Inst, Op1, EdgeType::Assign, AttrNone)); - Output.push_back(Edge(&Inst, Op2, EdgeType::Assign, AttrNone)); - } - - void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { - auto *Ptr = Inst.getPointerOperand(); - auto *Val = Inst.getNewValOperand(); - Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); - } - - void visitAtomicRMWInst(AtomicRMWInst &Inst) { - auto *Ptr = Inst.getPointerOperand(); - auto *Val = Inst.getValOperand(); - Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); - } - - void visitPHINode(PHINode &Inst) { - for (Value *Val : Inst.incoming_values()) { - Output.push_back(Edge(&Inst, Val, EdgeType::Assign, AttrNone)); - } - } - - void visitGetElementPtrInst(GetElementPtrInst &Inst) { - auto *Op = Inst.getPointerOperand(); - Output.push_back(Edge(&Inst, Op, EdgeType::Assign, AttrNone)); - for (auto I = Inst.idx_begin(), E = Inst.idx_end(); I != E; ++I) - Output.push_back(Edge(&Inst, *I, EdgeType::Assign, AttrNone)); - } - - void visitSelectInst(SelectInst &Inst) { - // Condition is not processed here (The actual statement producing - // the condition result is processed elsewhere). For select, the - // condition is evaluated, but not loaded, stored, or assigned - // simply as a result of being the condition of a select. 
- - auto *TrueVal = Inst.getTrueValue(); - Output.push_back(Edge(&Inst, TrueVal, EdgeType::Assign, AttrNone)); - auto *FalseVal = Inst.getFalseValue(); - Output.push_back(Edge(&Inst, FalseVal, EdgeType::Assign, AttrNone)); - } - - void visitAllocaInst(AllocaInst &) {} - - void visitLoadInst(LoadInst &Inst) { - auto *Ptr = Inst.getPointerOperand(); - auto *Val = &Inst; - Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone)); - } - - void visitStoreInst(StoreInst &Inst) { - auto *Ptr = Inst.getPointerOperand(); - auto *Val = Inst.getValueOperand(); - Output.push_back(Edge(Ptr, Val, EdgeType::Dereference, AttrNone)); - } - - void visitVAArgInst(VAArgInst &Inst) { - // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it does - // two things: - // 1. Loads a value from *((T*)*Ptr). - // 2. Increments (stores to) *Ptr by some target-specific amount. - // For now, we'll handle this like a landingpad instruction (by placing the - // result in its own group, and having that group alias externals). - auto *Val = &Inst; - Output.push_back(Edge(Val, Val, EdgeType::Assign, AttrAll)); - } - - static bool isFunctionExternal(Function *Fn) { - return Fn->isDeclaration() || !Fn->hasLocalLinkage(); - } - - // Gets whether the sets at Index1 above, below, or equal to the sets at - // Index2. Returns None if they are not in the same set chain. - static Optional<Level> getIndexRelation(const StratifiedSets<Value *> &Sets, - StratifiedIndex Index1, - StratifiedIndex Index2) { - if (Index1 == Index2) - return Level::Same; - - const auto *Current = &Sets.getLink(Index1); - while (Current->hasBelow()) { - if (Current->Below == Index2) - return Level::Below; - Current = &Sets.getLink(Current->Below); - } - - Current = &Sets.getLink(Index1); - while (Current->hasAbove()) { - if (Current->Above == Index2) - return Level::Above; - Current = &Sets.getLink(Current->Above); - } - - return NoneType(); - } - - bool - tryInterproceduralAnalysis(const SmallVectorImpl<Function *> &Fns, - Value *FuncValue, - const iterator_range<User::op_iterator> &Args) { - const unsigned ExpectedMaxArgs = 8; - const unsigned MaxSupportedArgs = 50; - assert(Fns.size() > 0); - - // I put this here to give us an upper bound on time taken by IPA. Is it - // really (realistically) needed? Keep in mind that we do have an n^2 algo. - if (std::distance(Args.begin(), Args.end()) > (int)MaxSupportedArgs) - return false; - - // Exit early if we'll fail anyway - for (auto *Fn : Fns) { - if (isFunctionExternal(Fn) || Fn->isVarArg()) - return false; - auto &MaybeInfo = AA.ensureCached(Fn); - if (!MaybeInfo.hasValue()) - return false; - } - - SmallVector<Value *, ExpectedMaxArgs> Arguments(Args.begin(), Args.end()); - SmallVector<StratifiedInfo, ExpectedMaxArgs> Parameters; - for (auto *Fn : Fns) { - auto &Info = *AA.ensureCached(Fn); - auto &Sets = Info.Sets; - auto &RetVals = Info.ReturnedValues; - - Parameters.clear(); - for (auto &Param : Fn->args()) { - auto MaybeInfo = Sets.find(&Param); - // Did a new parameter somehow get added to the function/slip by? 
- if (!MaybeInfo.hasValue()) - return false; - Parameters.push_back(*MaybeInfo); - } - - // Adding an edge from argument -> return value for each parameter that - // may alias the return value - for (unsigned I = 0, E = Parameters.size(); I != E; ++I) { - auto &ParamInfo = Parameters[I]; - auto &ArgVal = Arguments[I]; - bool AddEdge = false; - StratifiedAttrs Externals; - for (unsigned X = 0, XE = RetVals.size(); X != XE; ++X) { - auto MaybeInfo = Sets.find(RetVals[X]); - if (!MaybeInfo.hasValue()) - return false; - - auto &RetInfo = *MaybeInfo; - auto RetAttrs = Sets.getLink(RetInfo.Index).Attrs; - auto ParamAttrs = Sets.getLink(ParamInfo.Index).Attrs; - auto MaybeRelation = - getIndexRelation(Sets, ParamInfo.Index, RetInfo.Index); - if (MaybeRelation.hasValue()) { - AddEdge = true; - Externals |= RetAttrs | ParamAttrs; - } - } - if (AddEdge) - Output.push_back(Edge(FuncValue, ArgVal, EdgeType::Assign, - StratifiedAttrs().flip())); - } - - if (Parameters.size() != Arguments.size()) - return false; - - // Adding edges between arguments for arguments that may end up aliasing - // each other. This is necessary for functions such as - // void foo(int** a, int** b) { *a = *b; } - // (Technically, the proper sets for this would be those below - // Arguments[I] and Arguments[X], but our algorithm will produce - // extremely similar, and equally correct, results either way) - for (unsigned I = 0, E = Arguments.size(); I != E; ++I) { - auto &MainVal = Arguments[I]; - auto &MainInfo = Parameters[I]; - auto &MainAttrs = Sets.getLink(MainInfo.Index).Attrs; - for (unsigned X = I + 1; X != E; ++X) { - auto &SubInfo = Parameters[X]; - auto &SubVal = Arguments[X]; - auto &SubAttrs = Sets.getLink(SubInfo.Index).Attrs; - auto MaybeRelation = - getIndexRelation(Sets, MainInfo.Index, SubInfo.Index); - - if (!MaybeRelation.hasValue()) - continue; - - auto NewAttrs = SubAttrs | MainAttrs; - Output.push_back(Edge(MainVal, SubVal, EdgeType::Assign, NewAttrs)); - } - } - } - return true; - } - - template <typename InstT> void visitCallLikeInst(InstT &Inst) { - // TODO: Add support for noalias args/all the other fun function attributes - // that we can tack on. - SmallVector<Function *, 4> Targets; - if (getPossibleTargets(&Inst, Targets)) { - if (tryInterproceduralAnalysis(Targets, &Inst, Inst.arg_operands())) - return; - // Cleanup from interprocedural analysis - Output.clear(); - } - - // Because the function is opaque, we need to note that anything - // could have happened to the arguments, and that the result could alias - // just about anything, too. - // The goal of the loop is in part to unify many Values into one set, so we - // don't care if the function is void there. - for (Value *V : Inst.arg_operands()) - Output.push_back(Edge(&Inst, V, EdgeType::Assign, AttrAll)); - if (Inst.getNumArgOperands() == 0 && - Inst.getType() != Type::getVoidTy(Inst.getContext())) - Output.push_back(Edge(&Inst, &Inst, EdgeType::Assign, AttrAll)); - } - - void visitCallInst(CallInst &Inst) { visitCallLikeInst(Inst); } - - void visitInvokeInst(InvokeInst &Inst) { visitCallLikeInst(Inst); } - - // Because vectors/aggregates are immutable and unaddressable, - // there's nothing we can do to coax a value out of them, other - // than calling Extract{Element,Value}. We can effectively treat - // them as pointers to arbitrary memory locations we can store in - // and load from. 
- void visitExtractElementInst(ExtractElementInst &Inst) { - auto *Ptr = Inst.getVectorOperand(); - auto *Val = &Inst; - Output.push_back(Edge(Val, Ptr, EdgeType::Reference, AttrNone)); - } - - void visitInsertElementInst(InsertElementInst &Inst) { - auto *Vec = Inst.getOperand(0); - auto *Val = Inst.getOperand(1); - Output.push_back(Edge(&Inst, Vec, EdgeType::Assign, AttrNone)); - Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone)); - } - - void visitLandingPadInst(LandingPadInst &Inst) { - // Exceptions come from "nowhere", from our analysis' perspective. - // So we place the instruction its own group, noting that said group may - // alias externals - Output.push_back(Edge(&Inst, &Inst, EdgeType::Assign, AttrAll)); - } - - void visitInsertValueInst(InsertValueInst &Inst) { - auto *Agg = Inst.getOperand(0); - auto *Val = Inst.getOperand(1); - Output.push_back(Edge(&Inst, Agg, EdgeType::Assign, AttrNone)); - Output.push_back(Edge(&Inst, Val, EdgeType::Dereference, AttrNone)); - } - - void visitExtractValueInst(ExtractValueInst &Inst) { - auto *Ptr = Inst.getAggregateOperand(); - Output.push_back(Edge(&Inst, Ptr, EdgeType::Reference, AttrNone)); - } - - void visitShuffleVectorInst(ShuffleVectorInst &Inst) { - auto *From1 = Inst.getOperand(0); - auto *From2 = Inst.getOperand(1); - Output.push_back(Edge(&Inst, From1, EdgeType::Assign, AttrNone)); - Output.push_back(Edge(&Inst, From2, EdgeType::Assign, AttrNone)); - } - - void visitConstantExpr(ConstantExpr *CE) { - switch (CE->getOpcode()) { - default: - llvm_unreachable("Unknown instruction type encountered!"); -// Build the switch statement using the Instruction.def file. -#define HANDLE_INST(NUM, OPCODE, CLASS) \ - case Instruction::OPCODE: \ - visit##OPCODE(*(CLASS *)CE); \ - break; -#include "llvm/IR/Instruction.def" - } - } -}; - -// For a given instruction, we need to know which Value* to get the -// users of in order to build our graph. In some cases (i.e. add), -// we simply need the Instruction*. In other cases (i.e. store), -// finding the users of the Instruction* is useless; we need to find -// the users of the first operand. This handles determining which -// value to follow for us. -// -// Note: we *need* to keep this in sync with GetEdgesVisitor. Add -// something to GetEdgesVisitor, add it here -- remove something from -// GetEdgesVisitor, remove it here. -class GetTargetValueVisitor - : public InstVisitor<GetTargetValueVisitor, Value *> { -public: - Value *visitInstruction(Instruction &Inst) { return &Inst; } - - Value *visitStoreInst(StoreInst &Inst) { return Inst.getPointerOperand(); } - - Value *visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { - return Inst.getPointerOperand(); - } - - Value *visitAtomicRMWInst(AtomicRMWInst &Inst) { - return Inst.getPointerOperand(); - } - - Value *visitInsertElementInst(InsertElementInst &Inst) { - return Inst.getOperand(0); - } - - Value *visitInsertValueInst(InsertValueInst &Inst) { - return Inst.getAggregateOperand(); - } -}; - -// Set building requires a weighted bidirectional graph. 
-template <typename EdgeTypeT> class WeightedBidirectionalGraph { -public: - typedef std::size_t Node; - -private: - const static Node StartNode = Node(0); - - struct Edge { - EdgeTypeT Weight; - Node Other; - - Edge(const EdgeTypeT &W, const Node &N) : Weight(W), Other(N) {} - - bool operator==(const Edge &E) const { - return Weight == E.Weight && Other == E.Other; - } - - bool operator!=(const Edge &E) const { return !operator==(E); } - }; - - struct NodeImpl { - std::vector<Edge> Edges; - }; - - std::vector<NodeImpl> NodeImpls; - - bool inbounds(Node NodeIndex) const { return NodeIndex < NodeImpls.size(); } - - const NodeImpl &getNode(Node N) const { return NodeImpls[N]; } - NodeImpl &getNode(Node N) { return NodeImpls[N]; } - -public: - // ----- Various Edge iterators for the graph ----- // - - // \brief Iterator for edges. Because this graph is bidirected, we don't - // allow modification of the edges using this iterator. Additionally, the - // iterator becomes invalid if you add edges to or from the node you're - // getting the edges of. - struct EdgeIterator : public std::iterator<std::forward_iterator_tag, - std::tuple<EdgeTypeT, Node *>> { - EdgeIterator(const typename std::vector<Edge>::const_iterator &Iter) - : Current(Iter) {} - - EdgeIterator(NodeImpl &Impl) : Current(Impl.begin()) {} - - EdgeIterator &operator++() { - ++Current; - return *this; - } - - EdgeIterator operator++(int) { - EdgeIterator Copy(Current); - operator++(); - return Copy; - } - - std::tuple<EdgeTypeT, Node> &operator*() { - Store = std::make_tuple(Current->Weight, Current->Other); - return Store; - } - - bool operator==(const EdgeIterator &Other) const { - return Current == Other.Current; - } - - bool operator!=(const EdgeIterator &Other) const { - return !operator==(Other); - } - - private: - typename std::vector<Edge>::const_iterator Current; - std::tuple<EdgeTypeT, Node> Store; - }; - - // Wrapper for EdgeIterator with begin()/end() calls. - struct EdgeIterable { - EdgeIterable(const std::vector<Edge> &Edges) - : BeginIter(Edges.begin()), EndIter(Edges.end()) {} - - EdgeIterator begin() { return EdgeIterator(BeginIter); } - - EdgeIterator end() { return EdgeIterator(EndIter); } - - private: - typename std::vector<Edge>::const_iterator BeginIter; - typename std::vector<Edge>::const_iterator EndIter; - }; - - // ----- Actual graph-related things ----- // - - WeightedBidirectionalGraph() {} - - WeightedBidirectionalGraph(WeightedBidirectionalGraph<EdgeTypeT> &&Other) - : NodeImpls(std::move(Other.NodeImpls)) {} - - WeightedBidirectionalGraph<EdgeTypeT> & - operator=(WeightedBidirectionalGraph<EdgeTypeT> &&Other) { - NodeImpls = std::move(Other.NodeImpls); - return *this; - } - - Node addNode() { - auto Index = NodeImpls.size(); - auto NewNode = Node(Index); - NodeImpls.push_back(NodeImpl()); - return NewNode; - } - - void addEdge(Node From, Node To, const EdgeTypeT &Weight, - const EdgeTypeT &ReverseWeight) { - assert(inbounds(From)); - assert(inbounds(To)); - auto &FromNode = getNode(From); - auto &ToNode = getNode(To); - FromNode.Edges.push_back(Edge(Weight, To)); - ToNode.Edges.push_back(Edge(ReverseWeight, From)); - } - - EdgeIterable edgesFor(const Node &N) const { - const auto &Node = getNode(N); - return EdgeIterable(Node.Edges); - } - - bool empty() const { return NodeImpls.empty(); } - std::size_t size() const { return NodeImpls.size(); } - - // \brief Gets an arbitrary node in the graph as a starting point for - // traversal. 
- Node getEntryNode() { - assert(inbounds(StartNode)); - return StartNode; - } -}; - -typedef WeightedBidirectionalGraph<std::pair<EdgeType, StratifiedAttrs>> GraphT; -typedef DenseMap<Value *, GraphT::Node> NodeMapT; -} - -//===----------------------------------------------------------------------===// -// Function declarations that require types defined in the namespace above -//===----------------------------------------------------------------------===// - -// Given an argument number, returns the appropriate Attr index to set. -static StratifiedAttr argNumberToAttrIndex(StratifiedAttr); - -// Given a Value, potentially return which AttrIndex it maps to. -static Optional<StratifiedAttr> valueToAttrIndex(Value *Val); - -// Gets the inverse of a given EdgeType. -static EdgeType flipWeight(EdgeType); - -// Gets edges of the given Instruction*, writing them to the SmallVector*. -static void argsToEdges(CFLAAResult &, Instruction *, SmallVectorImpl<Edge> &); - -// Gets edges of the given ConstantExpr*, writing them to the SmallVector*. -static void argsToEdges(CFLAAResult &, ConstantExpr *, SmallVectorImpl<Edge> &); - -// Gets the "Level" that one should travel in StratifiedSets -// given an EdgeType. -static Level directionOfEdgeType(EdgeType); - -// Builds the graph needed for constructing the StratifiedSets for the -// given function -static void buildGraphFrom(CFLAAResult &, Function *, - SmallVectorImpl<Value *> &, NodeMapT &, GraphT &); - -// Gets the edges of a ConstantExpr as if it was an Instruction. This -// function also acts on any nested ConstantExprs, adding the edges -// of those to the given SmallVector as well. -static void constexprToEdges(CFLAAResult &, ConstantExpr &, - SmallVectorImpl<Edge> &); - -// Given an Instruction, this will add it to the graph, along with any -// Instructions that are potentially only available from said Instruction -// For example, given the following line: -// %0 = load i16* getelementptr ([1 x i16]* @a, 0, 0), align 2 -// addInstructionToGraph would add both the `load` and `getelementptr` -// instructions to the graph appropriately. -static void addInstructionToGraph(CFLAAResult &, Instruction &, - SmallVectorImpl<Value *> &, NodeMapT &, - GraphT &); - -// Notes whether it would be pointless to add the given Value to our sets. -static bool canSkipAddingToSets(Value *Val); - -static Optional<Function *> parentFunctionOfValue(Value *Val) { - if (auto *Inst = dyn_cast<Instruction>(Val)) { - auto *Bb = Inst->getParent(); - return Bb->getParent(); - } - - if (auto *Arg = dyn_cast<Argument>(Val)) - return Arg->getParent(); - return NoneType(); -} - -template <typename Inst> -static bool getPossibleTargets(Inst *Call, - SmallVectorImpl<Function *> &Output) { - if (auto *Fn = Call->getCalledFunction()) { - Output.push_back(Fn); - return true; - } - - // TODO: If the call is indirect, we might be able to enumerate all potential - // targets of the call and return them, rather than just failing. - return false; -} - -static Optional<Value *> getTargetValue(Instruction *Inst) { - GetTargetValueVisitor V; - return V.visit(Inst); -} - -static bool hasUsefulEdges(Instruction *Inst) { - bool IsNonInvokeTerminator = - isa<TerminatorInst>(Inst) && !isa<InvokeInst>(Inst); - return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) && !IsNonInvokeTerminator; -} - -static bool hasUsefulEdges(ConstantExpr *CE) { - // ConstantExpr doesn't have terminators, invokes, or fences, so only needs - // to check for compares. 
- return CE->getOpcode() != Instruction::ICmp && - CE->getOpcode() != Instruction::FCmp; -} - -static Optional<StratifiedAttr> valueToAttrIndex(Value *Val) { - if (isa<GlobalValue>(Val)) - return AttrGlobalIndex; - - if (auto *Arg = dyn_cast<Argument>(Val)) - // Only pointer arguments should have the argument attribute, - // because things can't escape through scalars without us seeing a - // cast, and thus, interaction with them doesn't matter. - if (!Arg->hasNoAliasAttr() && Arg->getType()->isPointerTy()) - return argNumberToAttrIndex(Arg->getArgNo()); - return NoneType(); -} - -static StratifiedAttr argNumberToAttrIndex(unsigned ArgNum) { - if (ArgNum >= AttrMaxNumArgs) - return AttrAllIndex; - return ArgNum + AttrFirstArgIndex; -} - -static EdgeType flipWeight(EdgeType Initial) { - switch (Initial) { - case EdgeType::Assign: - return EdgeType::Assign; - case EdgeType::Dereference: - return EdgeType::Reference; - case EdgeType::Reference: - return EdgeType::Dereference; - } - llvm_unreachable("Incomplete coverage of EdgeType enum"); -} - -static void argsToEdges(CFLAAResult &Analysis, Instruction *Inst, - SmallVectorImpl<Edge> &Output) { - assert(hasUsefulEdges(Inst) && - "Expected instructions to have 'useful' edges"); - GetEdgesVisitor v(Analysis, Output); - v.visit(Inst); -} - -static void argsToEdges(CFLAAResult &Analysis, ConstantExpr *CE, - SmallVectorImpl<Edge> &Output) { - assert(hasUsefulEdges(CE) && "Expected constant expr to have 'useful' edges"); - GetEdgesVisitor v(Analysis, Output); - v.visitConstantExpr(CE); -} - -static Level directionOfEdgeType(EdgeType Weight) { - switch (Weight) { - case EdgeType::Reference: - return Level::Above; - case EdgeType::Dereference: - return Level::Below; - case EdgeType::Assign: - return Level::Same; - } - llvm_unreachable("Incomplete switch coverage"); -} - -static void constexprToEdges(CFLAAResult &Analysis, - ConstantExpr &CExprToCollapse, - SmallVectorImpl<Edge> &Results) { - SmallVector<ConstantExpr *, 4> Worklist; - Worklist.push_back(&CExprToCollapse); - - SmallVector<Edge, 8> ConstexprEdges; - SmallPtrSet<ConstantExpr *, 4> Visited; - while (!Worklist.empty()) { - auto *CExpr = Worklist.pop_back_val(); - - if (!hasUsefulEdges(CExpr)) - continue; - - ConstexprEdges.clear(); - argsToEdges(Analysis, CExpr, ConstexprEdges); - for (auto &Edge : ConstexprEdges) { - if (auto *Nested = dyn_cast<ConstantExpr>(Edge.From)) - if (Visited.insert(Nested).second) - Worklist.push_back(Nested); - - if (auto *Nested = dyn_cast<ConstantExpr>(Edge.To)) - if (Visited.insert(Nested).second) - Worklist.push_back(Nested); - } - - Results.append(ConstexprEdges.begin(), ConstexprEdges.end()); - } -} - -static void addInstructionToGraph(CFLAAResult &Analysis, Instruction &Inst, - SmallVectorImpl<Value *> &ReturnedValues, - NodeMapT &Map, GraphT &Graph) { - const auto findOrInsertNode = [&Map, &Graph](Value *Val) { - auto Pair = Map.insert(std::make_pair(Val, GraphT::Node())); - auto &Iter = Pair.first; - if (Pair.second) { - auto NewNode = Graph.addNode(); - Iter->second = NewNode; - } - return Iter->second; - }; - - // We don't want the edges of most "return" instructions, but we *do* want - // to know what can be returned. - if (isa<ReturnInst>(&Inst)) - ReturnedValues.push_back(&Inst); - - if (!hasUsefulEdges(&Inst)) - return; - - SmallVector<Edge, 8> Edges; - argsToEdges(Analysis, &Inst, Edges); - - // In the case of an unused alloca (or similar), edges may be empty. Note - // that it exists so we can potentially answer NoAlias. 
- if (Edges.empty()) { - auto MaybeVal = getTargetValue(&Inst); - assert(MaybeVal.hasValue()); - auto *Target = *MaybeVal; - findOrInsertNode(Target); - return; - } - - const auto addEdgeToGraph = [&Graph, &findOrInsertNode](const Edge &E) { - auto To = findOrInsertNode(E.To); - auto From = findOrInsertNode(E.From); - auto FlippedWeight = flipWeight(E.Weight); - auto Attrs = E.AdditionalAttrs; - Graph.addEdge(From, To, std::make_pair(E.Weight, Attrs), - std::make_pair(FlippedWeight, Attrs)); - }; - - SmallVector<ConstantExpr *, 4> ConstantExprs; - for (const Edge &E : Edges) { - addEdgeToGraph(E); - if (auto *Constexpr = dyn_cast<ConstantExpr>(E.To)) - ConstantExprs.push_back(Constexpr); - if (auto *Constexpr = dyn_cast<ConstantExpr>(E.From)) - ConstantExprs.push_back(Constexpr); - } - - for (ConstantExpr *CE : ConstantExprs) { - Edges.clear(); - constexprToEdges(Analysis, *CE, Edges); - std::for_each(Edges.begin(), Edges.end(), addEdgeToGraph); - } -} - -// Aside: We may remove graph construction entirely, because it doesn't really -// buy us much that we don't already have. I'd like to add interprocedural -// analysis prior to this however, in case that somehow requires the graph -// produced by this for efficient execution -static void buildGraphFrom(CFLAAResult &Analysis, Function *Fn, - SmallVectorImpl<Value *> &ReturnedValues, - NodeMapT &Map, GraphT &Graph) { - for (auto &Bb : Fn->getBasicBlockList()) - for (auto &Inst : Bb.getInstList()) - addInstructionToGraph(Analysis, Inst, ReturnedValues, Map, Graph); -} - -static bool canSkipAddingToSets(Value *Val) { - // Constants can share instances, which may falsely unify multiple - // sets, e.g. in - // store i32* null, i32** %ptr1 - // store i32* null, i32** %ptr2 - // clearly ptr1 and ptr2 should not be unified into the same set, so - // we should filter out the (potentially shared) instance to - // i32* null. - if (isa<Constant>(Val)) { - bool Container = isa<ConstantVector>(Val) || isa<ConstantArray>(Val) || - isa<ConstantStruct>(Val); - // TODO: Because all of these things are constant, we can determine whether - // the data is *actually* mutable at graph building time. This will probably - // come for free/cheap with offset awareness. - bool CanStoreMutableData = - isa<GlobalValue>(Val) || isa<ConstantExpr>(Val) || Container; - return !CanStoreMutableData; - } - - return false; -} - -// Builds the graph + StratifiedSets for a function. 
-CFLAAResult::FunctionInfo CFLAAResult::buildSetsFrom(Function *Fn) { - NodeMapT Map; - GraphT Graph; - SmallVector<Value *, 4> ReturnedValues; - - buildGraphFrom(*this, Fn, ReturnedValues, Map, Graph); - - DenseMap<GraphT::Node, Value *> NodeValueMap; - NodeValueMap.resize(Map.size()); - for (const auto &Pair : Map) - NodeValueMap.insert(std::make_pair(Pair.second, Pair.first)); - - const auto findValueOrDie = [&NodeValueMap](GraphT::Node Node) { - auto ValIter = NodeValueMap.find(Node); - assert(ValIter != NodeValueMap.end()); - return ValIter->second; - }; - - StratifiedSetsBuilder<Value *> Builder; - - SmallVector<GraphT::Node, 16> Worklist; - for (auto &Pair : Map) { - Worklist.clear(); - - auto *Value = Pair.first; - Builder.add(Value); - auto InitialNode = Pair.second; - Worklist.push_back(InitialNode); - while (!Worklist.empty()) { - auto Node = Worklist.pop_back_val(); - auto *CurValue = findValueOrDie(Node); - if (canSkipAddingToSets(CurValue)) - continue; - - for (const auto &EdgeTuple : Graph.edgesFor(Node)) { - auto Weight = std::get<0>(EdgeTuple); - auto Label = Weight.first; - auto &OtherNode = std::get<1>(EdgeTuple); - auto *OtherValue = findValueOrDie(OtherNode); - - if (canSkipAddingToSets(OtherValue)) - continue; - - bool Added; - switch (directionOfEdgeType(Label)) { - case Level::Above: - Added = Builder.addAbove(CurValue, OtherValue); - break; - case Level::Below: - Added = Builder.addBelow(CurValue, OtherValue); - break; - case Level::Same: - Added = Builder.addWith(CurValue, OtherValue); - break; - } - - auto Aliasing = Weight.second; - if (auto MaybeCurIndex = valueToAttrIndex(CurValue)) - Aliasing.set(*MaybeCurIndex); - if (auto MaybeOtherIndex = valueToAttrIndex(OtherValue)) - Aliasing.set(*MaybeOtherIndex); - Builder.noteAttributes(CurValue, Aliasing); - Builder.noteAttributes(OtherValue, Aliasing); - - if (Added) - Worklist.push_back(OtherNode); - } - } - } - - // There are times when we end up with parameters not in our graph (i.e. if - // it's only used as the condition of a branch). Other bits of code depend on - // things that were present during construction being present in the graph. - // So, we add all present arguments here. - for (auto &Arg : Fn->args()) { - if (!Builder.add(&Arg)) - continue; - - auto Attrs = valueToAttrIndex(&Arg); - if (Attrs.hasValue()) - Builder.noteAttributes(&Arg, *Attrs); - } - - return FunctionInfo(Builder.build(), std::move(ReturnedValues)); -} - -void CFLAAResult::scan(Function *Fn) { - auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>())); - (void)InsertPair; - assert(InsertPair.second && - "Trying to scan a function that has already been cached"); - - FunctionInfo Info(buildSetsFrom(Fn)); - Cache[Fn] = std::move(Info); - Handles.push_front(FunctionHandle(Fn, this)); -} - -void CFLAAResult::evict(Function *Fn) { Cache.erase(Fn); } - -/// \brief Ensures that the given function is available in the cache. -/// Returns the appropriate entry from the cache. 
-const Optional<CFLAAResult::FunctionInfo> &
-CFLAAResult::ensureCached(Function *Fn) {
-  auto Iter = Cache.find(Fn);
-  if (Iter == Cache.end()) {
-    scan(Fn);
-    Iter = Cache.find(Fn);
-    assert(Iter != Cache.end());
-    assert(Iter->second.hasValue());
-  }
-  return Iter->second;
-}
-
-AliasResult CFLAAResult::query(const MemoryLocation &LocA,
-                               const MemoryLocation &LocB) {
-  auto *ValA = const_cast<Value *>(LocA.Ptr);
-  auto *ValB = const_cast<Value *>(LocB.Ptr);
-
-  Function *Fn = nullptr;
-  auto MaybeFnA = parentFunctionOfValue(ValA);
-  auto MaybeFnB = parentFunctionOfValue(ValB);
-  if (!MaybeFnA.hasValue() && !MaybeFnB.hasValue()) {
-    // The only times this is known to happen are when globals + InlineAsm
-    // are involved
-    DEBUG(dbgs() << "CFLAA: could not extract parent function information.\n");
-    return MayAlias;
-  }
-
-  if (MaybeFnA.hasValue()) {
-    Fn = *MaybeFnA;
-    assert((!MaybeFnB.hasValue() || *MaybeFnB == *MaybeFnA) &&
-           "Interprocedural queries not supported");
-  } else {
-    Fn = *MaybeFnB;
-  }
-
-  assert(Fn != nullptr);
-  auto &MaybeInfo = ensureCached(Fn);
-  assert(MaybeInfo.hasValue());
-
-  auto &Sets = MaybeInfo->Sets;
-  auto MaybeA = Sets.find(ValA);
-  if (!MaybeA.hasValue())
-    return MayAlias;
-
-  auto MaybeB = Sets.find(ValB);
-  if (!MaybeB.hasValue())
-    return MayAlias;
-
-  auto SetA = *MaybeA;
-  auto SetB = *MaybeB;
-  auto AttrsA = Sets.getLink(SetA.Index).Attrs;
-  auto AttrsB = Sets.getLink(SetB.Index).Attrs;
-
-  // Stratified set attributes are used as markers to signify whether a member
-  // of a StratifiedSet (or a member of a set above the current set) has
-  // interacted with either arguments or globals. "Interacted with" means that
-  // its value may be different depending on the value of an argument or
-  // global. The thought behind this is that, because arguments and globals
-  // may alias each other, if AttrsA and AttrsB have touched args/globals,
-  // we must conservatively say that they alias. However, if at least one of
-  // the sets has no values that could legally be altered by changing the value
-  // of an argument or global, then we don't have to be as conservative.
-  if (AttrsA.any() && AttrsB.any())
-    return MayAlias;
-
-  // We currently unify things even if the accesses to them may not be in
-  // bounds, so we can't return partial alias here because we don't
-  // know whether the pointer is really within the object or not.
-  // IE Given an out of bounds GEP and an alloca'd pointer, we may
-  // unify the two. We can't return partial alias for this case.
-  // Since we do not currently track enough information to differentiate the
-  // two cases, we conservatively return MayAlias when the sets are equal.
-
-  if (SetA.Index == SetB.Index)
-    return MayAlias;
-
-  return NoAlias;
-}
-
-CFLAAResult CFLAA::run(Function &F, AnalysisManager<Function> *AM) {
-  return CFLAAResult(AM->getResult<TargetLibraryAnalysis>(F));
-}
-
-char CFLAA::PassID;
-
-char CFLAAWrapperPass::ID = 0;
-INITIALIZE_PASS_BEGIN(CFLAAWrapperPass, "cfl-aa", "CFL-Based Alias Analysis",
-                      false, true)
-INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
-INITIALIZE_PASS_END(CFLAAWrapperPass, "cfl-aa", "CFL-Based Alias Analysis",
-                    false, true)
-
-ImmutablePass *llvm::createCFLAAWrapperPass() { return new CFLAAWrapperPass(); }
-
-CFLAAWrapperPass::CFLAAWrapperPass() : ImmutablePass(ID) {
-  initializeCFLAAWrapperPassPass(*PassRegistry::getPassRegistry());
-}
-
-bool CFLAAWrapperPass::doInitialization(Module &M) {
-  Result.reset(
-      new CFLAAResult(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI()));
-  return false;
-}
-
-bool CFLAAWrapperPass::doFinalization(Module &M) {
-  Result.reset();
-  return false;
-}
-
-void CFLAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
-  AU.setPreservesAll();
-  AU.addRequired<TargetLibraryInfoWrapperPass>();
-}
diff --git a/lib/Analysis/CFLAndersAliasAnalysis.cpp b/lib/Analysis/CFLAndersAliasAnalysis.cpp
new file mode 100644
index 000000000000..7d5bd94133a7
--- /dev/null
+++ b/lib/Analysis/CFLAndersAliasAnalysis.cpp
@@ -0,0 +1,584 @@
+//- CFLAndersAliasAnalysis.cpp - Inclusion-based Alias Analysis -----*- C++-*-//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// differs from CFLSteensAliasAnalysis in its inclusion-based nature while
+// CFLSteensAliasAnalysis is unification-based. This pass has worse performance
+// than CFLSteensAliasAnalysis (the worst case complexity of
+// CFLAndersAliasAnalysis is cubic, while the worst case complexity of
+// CFLSteensAliasAnalysis is almost linear), but it is able to yield more
+// precise analysis results. The precision of this analysis is roughly the same
+// as that of a one-level context-sensitive Andersen's algorithm.
+//
+// The algorithm used here is based on the recursive state machine matching
+// scheme proposed in "Demand-driven alias analysis for C" by Xin Zheng and
+// Radu Rugina. The general idea is to extend the traditional transitive
+// closure algorithm to perform CFL matching along the way: instead of
+// recording "whether X is reachable from Y", we keep track of "whether X is
+// reachable from Y at state Z", where the "state" field indicates where we are
+// in the CFL matching process. To understand the matching better, it is
+// advisable to have the state machine shown in Figure 3 of the paper available
+// when reading the code: all we do here is to selectively expand the
+// transitive closure by discarding edges that are not recognized by the state
+// machine.
+//
+// There is one difference between our current implementation and the one
+// described in the paper: our algorithm eagerly computes all alias pairs after
+// the CFLGraph is built, while in the paper the authors did the computation in
+// a demand-driven fashion.
We did not implement the demand-driven algorithm due +// to the additional coding complexity and higher memory profile, but if we +// found it necessary we may switch to it eventually. +// +//===----------------------------------------------------------------------===// + +// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and +// CFLAndersAA is interprocedural. This is *technically* A Bad Thing, because +// FunctionPasses are only allowed to inspect the Function that they're being +// run on. Realistically, this likely isn't a problem until we allow +// FunctionPasses to run concurrently. + +#include "llvm/Analysis/CFLAndersAliasAnalysis.h" +#include "CFLGraph.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/Pass.h" + +using namespace llvm; +using namespace llvm::cflaa; + +#define DEBUG_TYPE "cfl-anders-aa" + +CFLAndersAAResult::CFLAndersAAResult(const TargetLibraryInfo &TLI) : TLI(TLI) {} +CFLAndersAAResult::CFLAndersAAResult(CFLAndersAAResult &&RHS) + : AAResultBase(std::move(RHS)), TLI(RHS.TLI) {} +CFLAndersAAResult::~CFLAndersAAResult() {} + +static const Function *parentFunctionOfValue(const Value *Val) { + if (auto *Inst = dyn_cast<Instruction>(Val)) { + auto *Bb = Inst->getParent(); + return Bb->getParent(); + } + + if (auto *Arg = dyn_cast<Argument>(Val)) + return Arg->getParent(); + return nullptr; +} + +namespace { + +enum class MatchState : uint8_t { + FlowFrom = 0, // S1 in the paper + FlowFromMemAlias, // S2 in the paper + FlowTo, // S3 in the paper + FlowToMemAlias // S4 in the paper +}; + +// We use ReachabilitySet to keep track of value aliases (The nonterminal "V" in +// the paper) during the analysis. +class ReachabilitySet { + typedef std::bitset<4> StateSet; + typedef DenseMap<InstantiatedValue, StateSet> ValueStateMap; + typedef DenseMap<InstantiatedValue, ValueStateMap> ValueReachMap; + ValueReachMap ReachMap; + +public: + typedef ValueStateMap::const_iterator const_valuestate_iterator; + typedef ValueReachMap::const_iterator const_value_iterator; + + // Insert edge 'From->To' at state 'State' + bool insert(InstantiatedValue From, InstantiatedValue To, MatchState State) { + auto &States = ReachMap[To][From]; + auto Idx = static_cast<size_t>(State); + if (!States.test(Idx)) { + States.set(Idx); + return true; + } + return false; + } + + // Return the set of all ('From', 'State') pair for a given node 'To' + iterator_range<const_valuestate_iterator> + reachableValueAliases(InstantiatedValue V) const { + auto Itr = ReachMap.find(V); + if (Itr == ReachMap.end()) + return make_range<const_valuestate_iterator>(const_valuestate_iterator(), + const_valuestate_iterator()); + return make_range<const_valuestate_iterator>(Itr->second.begin(), + Itr->second.end()); + } + + iterator_range<const_value_iterator> value_mappings() const { + return make_range<const_value_iterator>(ReachMap.begin(), ReachMap.end()); + } +}; + +// We use AliasMemSet to keep track of all memory aliases (the nonterminal "M" +// in the paper) during the analysis. 
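+// As a purely illustrative sketch (hypothetical IR): once %p and %q are found
+// to be value aliases, their pointees (%p,1) and (%q,1) become memory aliases
+// and are recorded here. Given
+//   store i32* %x, i32** %p
+//   %y = load i32*, i32** %q
+// that recorded pair is exactly what lets the matching continue one
+// dereference level below and conclude that %x and %y are value aliases.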
+class AliasMemSet { + typedef DenseSet<InstantiatedValue> MemSet; + typedef DenseMap<InstantiatedValue, MemSet> MemMapType; + MemMapType MemMap; + +public: + typedef MemSet::const_iterator const_mem_iterator; + + bool insert(InstantiatedValue LHS, InstantiatedValue RHS) { + // Top-level values can never be memory aliases because one cannot take the + // addresses of them + assert(LHS.DerefLevel > 0 && RHS.DerefLevel > 0); + return MemMap[LHS].insert(RHS).second; + } + + const MemSet *getMemoryAliases(InstantiatedValue V) const { + auto Itr = MemMap.find(V); + if (Itr == MemMap.end()) + return nullptr; + return &Itr->second; + } +}; + +// We use AliasAttrMap to keep track of the AliasAttr of each node. +class AliasAttrMap { + typedef DenseMap<InstantiatedValue, AliasAttrs> MapType; + MapType AttrMap; + +public: + typedef MapType::const_iterator const_iterator; + + bool add(InstantiatedValue V, AliasAttrs Attr) { + if (Attr.none()) + return false; + auto &OldAttr = AttrMap[V]; + auto NewAttr = OldAttr | Attr; + if (OldAttr == NewAttr) + return false; + OldAttr = NewAttr; + return true; + } + + AliasAttrs getAttrs(InstantiatedValue V) const { + AliasAttrs Attr; + auto Itr = AttrMap.find(V); + if (Itr != AttrMap.end()) + Attr = Itr->second; + return Attr; + } + + iterator_range<const_iterator> mappings() const { + return make_range<const_iterator>(AttrMap.begin(), AttrMap.end()); + } +}; + +struct WorkListItem { + InstantiatedValue From; + InstantiatedValue To; + MatchState State; +}; +} + +class CFLAndersAAResult::FunctionInfo { + /// Map a value to other values that may alias it + /// Since the alias relation is symmetric, to save some space we assume values + /// are properly ordered: if a and b alias each other, and a < b, then b is in + /// AliasMap[a] but not vice versa. + DenseMap<const Value *, std::vector<const Value *>> AliasMap; + + /// Map a value to its corresponding AliasAttrs + DenseMap<const Value *, AliasAttrs> AttrMap; + + /// Summary of externally visible effects. 
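+  /// For example, a hypothetical callee
+  ///   static int *passthrough(int *P) { return P; }
+  /// would be summarized as "the return value may alias parameter 0", letting
+  /// callers instantiate that relation at each call site without reanalysis.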
+ AliasSummary Summary; + + AliasAttrs getAttrs(const Value *) const; + +public: + FunctionInfo(const ReachabilitySet &, AliasAttrMap); + + bool mayAlias(const Value *LHS, const Value *RHS) const; + const AliasSummary &getAliasSummary() const { return Summary; } +}; + +CFLAndersAAResult::FunctionInfo::FunctionInfo(const ReachabilitySet &ReachSet, + AliasAttrMap AMap) { + // Populate AttrMap + for (const auto &Mapping : AMap.mappings()) { + auto IVal = Mapping.first; + + // AttrMap only cares about top-level values + if (IVal.DerefLevel == 0) + AttrMap[IVal.Val] = Mapping.second; + } + + // Populate AliasMap + for (const auto &OuterMapping : ReachSet.value_mappings()) { + // AliasMap only cares about top-level values + if (OuterMapping.first.DerefLevel > 0) + continue; + + auto Val = OuterMapping.first.Val; + auto &AliasList = AliasMap[Val]; + for (const auto &InnerMapping : OuterMapping.second) { + // Again, AliasMap only cares about top-level values + if (InnerMapping.first.DerefLevel == 0) + AliasList.push_back(InnerMapping.first.Val); + } + + // Sort AliasList for faster lookup + std::sort(AliasList.begin(), AliasList.end(), std::less<const Value *>()); + } + + // TODO: Populate function summary here +} + +AliasAttrs CFLAndersAAResult::FunctionInfo::getAttrs(const Value *V) const { + assert(V != nullptr); + + AliasAttrs Attr; + auto Itr = AttrMap.find(V); + if (Itr != AttrMap.end()) + Attr = Itr->second; + return Attr; +} + +bool CFLAndersAAResult::FunctionInfo::mayAlias(const Value *LHS, + const Value *RHS) const { + assert(LHS && RHS); + + auto Itr = AliasMap.find(LHS); + if (Itr != AliasMap.end()) { + if (std::binary_search(Itr->second.begin(), Itr->second.end(), RHS, + std::less<const Value *>())) + return true; + } + + // Even if LHS and RHS are not reachable, they may still alias due to their + // AliasAttrs + auto AttrsA = getAttrs(LHS); + auto AttrsB = getAttrs(RHS); + + if (AttrsA.none() || AttrsB.none()) + return false; + if (hasUnknownOrCallerAttr(AttrsA) || hasUnknownOrCallerAttr(AttrsB)) + return true; + if (isGlobalOrArgAttr(AttrsA) && isGlobalOrArgAttr(AttrsB)) + return true; + return false; +} + +static void propagate(InstantiatedValue From, InstantiatedValue To, + MatchState State, ReachabilitySet &ReachSet, + std::vector<WorkListItem> &WorkList) { + if (From == To) + return; + if (ReachSet.insert(From, To, State)) + WorkList.push_back(WorkListItem{From, To, State}); +} + +static void initializeWorkList(std::vector<WorkListItem> &WorkList, + ReachabilitySet &ReachSet, + const CFLGraph &Graph) { + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + auto &ValueInfo = Mapping.second; + assert(ValueInfo.getNumLevels() > 0); + + // Insert all immediate assignment neighbors to the worklist + for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { + auto Src = InstantiatedValue{Val, I}; + // If there's an assignment edge from X to Y, it means Y is reachable from + // X at S2 and X is reachable from Y at S1 + for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) { + propagate(Edge.Other, Src, MatchState::FlowFrom, ReachSet, WorkList); + propagate(Src, Edge.Other, MatchState::FlowTo, ReachSet, WorkList); + } + } + } +} + +static Optional<InstantiatedValue> getNodeBelow(const CFLGraph &Graph, + InstantiatedValue V) { + auto NodeBelow = InstantiatedValue{V.Val, V.DerefLevel + 1}; + if (Graph.getNode(NodeBelow)) + return NodeBelow; + return None; +} + +static void processWorkListItem(const WorkListItem &Item, const CFLGraph &Graph, + 
                                ReachabilitySet &ReachSet, AliasMemSet &MemSet,
+                                std::vector<WorkListItem> &WorkList) {
+  auto FromNode = Item.From;
+  auto ToNode = Item.To;
+
+  auto NodeInfo = Graph.getNode(ToNode);
+  assert(NodeInfo != nullptr);
+
+  // TODO: propagate field offsets
+
+  // FIXME: Here is a neat trick we can do: since both ReachSet and MemSet hold
+  // relations that are symmetric, we could actually cut the storage by half by
+  // sorting FromNode and ToNode before insertion happens.
+
+  // The newly added value alias pair may potentially generate more memory
+  // alias pairs. Check for them here.
+  auto FromNodeBelow = getNodeBelow(Graph, FromNode);
+  auto ToNodeBelow = getNodeBelow(Graph, ToNode);
+  if (FromNodeBelow && ToNodeBelow &&
+      MemSet.insert(*FromNodeBelow, *ToNodeBelow)) {
+    propagate(*FromNodeBelow, *ToNodeBelow, MatchState::FlowFromMemAlias,
+              ReachSet, WorkList);
+    for (const auto &Mapping : ReachSet.reachableValueAliases(*FromNodeBelow)) {
+      auto Src = Mapping.first;
+      if (Mapping.second.test(static_cast<size_t>(MatchState::FlowFrom)))
+        propagate(Src, *ToNodeBelow, MatchState::FlowFromMemAlias, ReachSet,
+                  WorkList);
+      if (Mapping.second.test(static_cast<size_t>(MatchState::FlowTo)))
+        propagate(Src, *ToNodeBelow, MatchState::FlowToMemAlias, ReachSet,
+                  WorkList);
+    }
+  }
+
+  // This is the core of the state machine walking algorithm. We expand
+  // ReachSet based on which state we are at (which in turn dictates what edges
+  // we should examine).
+  // From a high-level point of view, the state machine here guarantees two
+  // properties:
+  // - If *X and *Y are memory aliases, then X and Y are value aliases.
+  // - If Y is an alias of X, then reverse assignment edges (if any) should
+  //   precede any assignment edges on the path from X to Y.
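+  // For a concrete (purely illustrative) walk, given the hypothetical IR
+  //   store i32* %x, i32** %p
+  //   %y = load i32*, i32** %p
+  // the graph has assignment edges (%x,0) -> (%p,1) and (%p,1) -> (%y,0);
+  // starting from (%x,0) at FlowTo we reach (%p,1), and the second edge keeps
+  // us at FlowTo into (%y,0), so %x and %y are recorded as value aliases.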
+ switch (Item.State) { + case MatchState::FlowFrom: { + for (const auto &RevAssignEdge : NodeInfo->ReverseEdges) + propagate(FromNode, RevAssignEdge.Other, MatchState::FlowFrom, ReachSet, + WorkList); + for (const auto &AssignEdge : NodeInfo->Edges) + propagate(FromNode, AssignEdge.Other, MatchState::FlowTo, ReachSet, + WorkList); + if (auto AliasSet = MemSet.getMemoryAliases(ToNode)) { + for (const auto &MemAlias : *AliasSet) + propagate(FromNode, MemAlias, MatchState::FlowFromMemAlias, ReachSet, + WorkList); + } + break; + } + case MatchState::FlowFromMemAlias: { + for (const auto &RevAssignEdge : NodeInfo->ReverseEdges) + propagate(FromNode, RevAssignEdge.Other, MatchState::FlowFrom, ReachSet, + WorkList); + for (const auto &AssignEdge : NodeInfo->Edges) + propagate(FromNode, AssignEdge.Other, MatchState::FlowTo, ReachSet, + WorkList); + break; + } + case MatchState::FlowTo: { + for (const auto &AssignEdge : NodeInfo->Edges) + propagate(FromNode, AssignEdge.Other, MatchState::FlowTo, ReachSet, + WorkList); + if (auto AliasSet = MemSet.getMemoryAliases(ToNode)) { + for (const auto &MemAlias : *AliasSet) + propagate(FromNode, MemAlias, MatchState::FlowToMemAlias, ReachSet, + WorkList); + } + break; + } + case MatchState::FlowToMemAlias: { + for (const auto &AssignEdge : NodeInfo->Edges) + propagate(FromNode, AssignEdge.Other, MatchState::FlowTo, ReachSet, + WorkList); + break; + } + } +} + +static AliasAttrMap buildAttrMap(const CFLGraph &Graph, + const ReachabilitySet &ReachSet) { + AliasAttrMap AttrMap; + std::vector<InstantiatedValue> WorkList, NextList; + + // Initialize each node with its original AliasAttrs in CFLGraph + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + auto &ValueInfo = Mapping.second; + for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { + auto Node = InstantiatedValue{Val, I}; + AttrMap.add(Node, ValueInfo.getNodeInfoAtLevel(I).Attr); + WorkList.push_back(Node); + } + } + + while (!WorkList.empty()) { + for (const auto &Dst : WorkList) { + auto DstAttr = AttrMap.getAttrs(Dst); + if (DstAttr.none()) + continue; + + // Propagate attr on the same level + for (const auto &Mapping : ReachSet.reachableValueAliases(Dst)) { + auto Src = Mapping.first; + if (AttrMap.add(Src, DstAttr)) + NextList.push_back(Src); + } + + // Propagate attr to the levels below + auto DstBelow = getNodeBelow(Graph, Dst); + while (DstBelow) { + if (AttrMap.add(*DstBelow, DstAttr)) { + NextList.push_back(*DstBelow); + break; + } + DstBelow = getNodeBelow(Graph, *DstBelow); + } + } + WorkList.swap(NextList); + NextList.clear(); + } + + return AttrMap; +} + +CFLAndersAAResult::FunctionInfo +CFLAndersAAResult::buildInfoFrom(const Function &Fn) { + CFLGraphBuilder<CFLAndersAAResult> GraphBuilder( + *this, TLI, + // Cast away the constness here due to GraphBuilder's API requirement + const_cast<Function &>(Fn)); + auto &Graph = GraphBuilder.getCFLGraph(); + + ReachabilitySet ReachSet; + AliasMemSet MemSet; + + std::vector<WorkListItem> WorkList, NextList; + initializeWorkList(WorkList, ReachSet, Graph); + // TODO: make sure we don't stop before the fix point is reached + while (!WorkList.empty()) { + for (const auto &Item : WorkList) + processWorkListItem(Item, Graph, ReachSet, MemSet, NextList); + + NextList.swap(WorkList); + NextList.clear(); + } + + // Now that we have all the reachability info, propagate AliasAttrs according + // to it + auto IValueAttrMap = buildAttrMap(Graph, ReachSet); + + return FunctionInfo(ReachSet, std::move(IValueAttrMap)); 
+} + +void CFLAndersAAResult::scan(const Function &Fn) { + auto InsertPair = Cache.insert(std::make_pair(&Fn, Optional<FunctionInfo>())); + (void)InsertPair; + assert(InsertPair.second && + "Trying to scan a function that has already been cached"); + + // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call + // may get evaluated after operator[], potentially triggering a DenseMap + // resize and invalidating the reference returned by operator[] + auto FunInfo = buildInfoFrom(Fn); + Cache[&Fn] = std::move(FunInfo); + Handles.push_front(FunctionHandle(const_cast<Function *>(&Fn), this)); +} + +void CFLAndersAAResult::evict(const Function &Fn) { Cache.erase(&Fn); } + +const Optional<CFLAndersAAResult::FunctionInfo> & +CFLAndersAAResult::ensureCached(const Function &Fn) { + auto Iter = Cache.find(&Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(&Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; +} + +const AliasSummary *CFLAndersAAResult::getAliasSummary(const Function &Fn) { + auto &FunInfo = ensureCached(Fn); + if (FunInfo.hasValue()) + return &FunInfo->getAliasSummary(); + else + return nullptr; +} + +AliasResult CFLAndersAAResult::query(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + auto *ValA = LocA.Ptr; + auto *ValB = LocB.Ptr; + + if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) + return NoAlias; + + auto *Fn = parentFunctionOfValue(ValA); + if (!Fn) { + Fn = parentFunctionOfValue(ValB); + if (!Fn) { + // The only times this is known to happen are when globals + InlineAsm are + // involved + DEBUG(dbgs() + << "CFLAndersAA: could not extract parent function information.\n"); + return MayAlias; + } + } else { + assert(!parentFunctionOfValue(ValB) || parentFunctionOfValue(ValB) == Fn); + } + + assert(Fn != nullptr); + auto &FunInfo = ensureCached(*Fn); + + // AliasMap lookup + if (FunInfo->mayAlias(ValA, ValB)) + return MayAlias; + return NoAlias; +} + +AliasResult CFLAndersAAResult::alias(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + if (LocA.Ptr == LocB.Ptr) + return LocA.Size == LocB.Size ? MustAlias : PartialAlias; + + // Comparisons between global variables and other constants should be + // handled by BasicAA. + // CFLAndersAA may report NoAlias when comparing a GlobalValue and + // ConstantExpr, but every query needs to have at least one Value tied to a + // Function, and neither GlobalValues nor ConstantExprs are. 
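+  // For instance, a query on two Constants such as a hypothetical @g and
+  // getelementptr (@g, ...) is therefore delegated to the fallback below.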
+  if (isa<Constant>(LocA.Ptr) && isa<Constant>(LocB.Ptr))
+    return AAResultBase::alias(LocA, LocB);
+
+  AliasResult QueryResult = query(LocA, LocB);
+  if (QueryResult == MayAlias)
+    return AAResultBase::alias(LocA, LocB);
+
+  return QueryResult;
+}
+
+char CFLAndersAA::PassID;
+
+CFLAndersAAResult CFLAndersAA::run(Function &F, AnalysisManager<Function> &AM) {
+  return CFLAndersAAResult(AM.getResult<TargetLibraryAnalysis>(F));
+}
+
+char CFLAndersAAWrapperPass::ID = 0;
+INITIALIZE_PASS(CFLAndersAAWrapperPass, "cfl-anders-aa",
+                "Inclusion-Based CFL Alias Analysis", false, true)
+
+ImmutablePass *llvm::createCFLAndersAAWrapperPass() {
+  return new CFLAndersAAWrapperPass();
+}
+
+CFLAndersAAWrapperPass::CFLAndersAAWrapperPass() : ImmutablePass(ID) {
+  initializeCFLAndersAAWrapperPassPass(*PassRegistry::getPassRegistry());
+}
+
+void CFLAndersAAWrapperPass::initializePass() {
+  auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>();
+  Result.reset(new CFLAndersAAResult(TLIWP.getTLI()));
+}
+
+void CFLAndersAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<TargetLibraryInfoWrapperPass>();
+}
diff --git a/lib/Analysis/CFLGraph.h b/lib/Analysis/CFLGraph.h
new file mode 100644
index 000000000000..bc6e794d0b2a
--- /dev/null
+++ b/lib/Analysis/CFLGraph.h
@@ -0,0 +1,544 @@
+//======- CFLGraph.h - The CFL graph for alias analysis ----------------======//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines CFLGraph, an auxiliary data structure used by CFL-based
+/// alias analysis.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_CFLGRAPH_H
+#define LLVM_ANALYSIS_CFLGRAPH_H
+
+#include "AliasAnalysisSummary.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/Instructions.h"
+
+namespace llvm {
+namespace cflaa {
+
+/// \brief The Program Expression Graph (PEG) of CFL analysis
+/// CFLGraph is an auxiliary data structure used by CFL-based alias analysis to
+/// describe flow-insensitive pointer-related behaviors. Given an LLVM
+/// function, the main purpose of this graph is to abstract away unrelated
+/// facts and translate the rest into a form that can be easily digested by
+/// CFL analyses. Each node in the graph is an InstantiatedValue, and each edge
+/// represents a pointer assignment between InstantiatedValues. Pointer
+/// references/dereferences are not explicitly stored in the graph: we
+/// implicitly assume that each node (X, I) has a dereference edge to (X, I+1)
+/// and a reference edge to (X, I-1).
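+/// As a small illustration (hypothetical IR), after
+///   %x = load i32*, i32** %p
+/// the graph contains the nodes (%p, 0), (%p, 1) and (%x, 0) plus an
+/// assignment edge (%p, 1) -> (%x, 0); the dereference relation between
+/// (%p, 0) and (%p, 1) remains implicit.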
+class CFLGraph { +public: + typedef InstantiatedValue Node; + + struct Edge { + Node Other; + }; + + typedef std::vector<Edge> EdgeList; + + struct NodeInfo { + EdgeList Edges, ReverseEdges; + AliasAttrs Attr; + }; + + class ValueInfo { + std::vector<NodeInfo> Levels; + + public: + bool addNodeToLevel(unsigned Level) { + auto NumLevels = Levels.size(); + if (NumLevels > Level) + return false; + Levels.resize(Level + 1); + return true; + } + + NodeInfo &getNodeInfoAtLevel(unsigned Level) { + assert(Level < Levels.size()); + return Levels[Level]; + } + const NodeInfo &getNodeInfoAtLevel(unsigned Level) const { + assert(Level < Levels.size()); + return Levels[Level]; + } + + unsigned getNumLevels() const { return Levels.size(); } + }; + +private: + typedef DenseMap<Value *, ValueInfo> ValueMap; + ValueMap ValueImpls; + + NodeInfo *getNode(Node N) { + auto Itr = ValueImpls.find(N.Val); + if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel) + return nullptr; + return &Itr->second.getNodeInfoAtLevel(N.DerefLevel); + } + +public: + typedef ValueMap::const_iterator const_value_iterator; + + bool addNode(Node N, AliasAttrs Attr = AliasAttrs()) { + assert(N.Val != nullptr); + auto &ValInfo = ValueImpls[N.Val]; + auto Changed = ValInfo.addNodeToLevel(N.DerefLevel); + ValInfo.getNodeInfoAtLevel(N.DerefLevel).Attr |= Attr; + return Changed; + } + + void addAttr(Node N, AliasAttrs Attr) { + auto *Info = getNode(N); + assert(Info != nullptr); + Info->Attr |= Attr; + } + + void addEdge(Node From, Node To, int64_t Offset = 0) { + auto *FromInfo = getNode(From); + assert(FromInfo != nullptr); + auto *ToInfo = getNode(To); + assert(ToInfo != nullptr); + + FromInfo->Edges.push_back(Edge{To}); + ToInfo->ReverseEdges.push_back(Edge{From}); + } + + const NodeInfo *getNode(Node N) const { + auto Itr = ValueImpls.find(N.Val); + if (Itr == ValueImpls.end() || Itr->second.getNumLevels() <= N.DerefLevel) + return nullptr; + return &Itr->second.getNodeInfoAtLevel(N.DerefLevel); + } + + AliasAttrs attrFor(Node N) const { + auto *Info = getNode(N); + assert(Info != nullptr); + return Info->Attr; + } + + iterator_range<const_value_iterator> value_mappings() const { + return make_range<const_value_iterator>(ValueImpls.begin(), + ValueImpls.end()); + } +}; + +///\brief A builder class used to create CFLGraph instance from a given function +/// The CFL-AA that uses this builder must provide its own type as a template +/// argument. This is necessary for interprocedural processing: CFLGraphBuilder +/// needs a way of obtaining the summary of other functions when callinsts are +/// encountered. +/// As a result, we expect the said CFL-AA to expose a getAliasSummary() public +/// member function that takes a Function& and returns the corresponding summary +/// as a const AliasSummary*. +template <typename CFLAA> class CFLGraphBuilder { + // Input of the builder + CFLAA &Analysis; + const TargetLibraryInfo &TLI; + + // Output of the builder + CFLGraph Graph; + SmallVector<Value *, 4> ReturnedValues; + + // Helper class + /// Gets the edges our graph should have, based on an Instruction* + class GetEdgesVisitor : public InstVisitor<GetEdgesVisitor, void> { + CFLAA &AA; + const TargetLibraryInfo &TLI; + + CFLGraph &Graph; + SmallVectorImpl<Value *> &ReturnValues; + + static bool hasUsefulEdges(ConstantExpr *CE) { + // ConstantExpr doesn't have terminators, invokes, or fences, so only + // needs + // to check for compares. 
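+      // (For instance, a hypothetical icmp eq (i32* @a, i32* @b) only
+      // produces an i1, so a compare cannot introduce pointer edges.)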
+ return CE->getOpcode() != Instruction::ICmp && + CE->getOpcode() != Instruction::FCmp; + } + + // Returns possible functions called by CS into the given SmallVectorImpl. + // Returns true if targets found, false otherwise. + static bool getPossibleTargets(CallSite CS, + SmallVectorImpl<Function *> &Output) { + if (auto *Fn = CS.getCalledFunction()) { + Output.push_back(Fn); + return true; + } + + // TODO: If the call is indirect, we might be able to enumerate all + // potential + // targets of the call and return them, rather than just failing. + return false; + } + + void addNode(Value *Val, AliasAttrs Attr = AliasAttrs()) { + assert(Val != nullptr && Val->getType()->isPointerTy()); + if (auto GVal = dyn_cast<GlobalValue>(Val)) { + if (Graph.addNode(InstantiatedValue{GVal, 0}, + getGlobalOrArgAttrFromValue(*GVal))) + Graph.addNode(InstantiatedValue{GVal, 1}, getAttrUnknown()); + } else if (auto CExpr = dyn_cast<ConstantExpr>(Val)) { + if (hasUsefulEdges(CExpr)) { + if (Graph.addNode(InstantiatedValue{CExpr, 0})) + visitConstantExpr(CExpr); + } + } else + Graph.addNode(InstantiatedValue{Val, 0}, Attr); + } + + void addAssignEdge(Value *From, Value *To, int64_t Offset = 0) { + assert(From != nullptr && To != nullptr); + if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy()) + return; + addNode(From); + if (To != From) { + addNode(To); + Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 0}, + Offset); + } + } + + void addDerefEdge(Value *From, Value *To, bool IsRead) { + assert(From != nullptr && To != nullptr); + if (!From->getType()->isPointerTy() || !To->getType()->isPointerTy()) + return; + addNode(From); + addNode(To); + if (IsRead) { + Graph.addNode(InstantiatedValue{From, 1}); + Graph.addEdge(InstantiatedValue{From, 1}, InstantiatedValue{To, 0}); + } else { + Graph.addNode(InstantiatedValue{To, 1}); + Graph.addEdge(InstantiatedValue{From, 0}, InstantiatedValue{To, 1}); + } + } + + void addLoadEdge(Value *From, Value *To) { addDerefEdge(From, To, true); } + void addStoreEdge(Value *From, Value *To) { addDerefEdge(From, To, false); } + + public: + GetEdgesVisitor(CFLGraphBuilder &Builder) + : AA(Builder.Analysis), TLI(Builder.TLI), Graph(Builder.Graph), + ReturnValues(Builder.ReturnedValues) {} + + void visitInstruction(Instruction &) { + llvm_unreachable("Unsupported instruction encountered"); + } + + void visitReturnInst(ReturnInst &Inst) { + if (auto RetVal = Inst.getReturnValue()) { + if (RetVal->getType()->isPointerTy()) { + addNode(RetVal); + ReturnValues.push_back(RetVal); + } + } + } + + void visitPtrToIntInst(PtrToIntInst &Inst) { + auto *Ptr = Inst.getOperand(0); + addNode(Ptr, getAttrEscaped()); + } + + void visitIntToPtrInst(IntToPtrInst &Inst) { + auto *Ptr = &Inst; + addNode(Ptr, getAttrUnknown()); + } + + void visitCastInst(CastInst &Inst) { + auto *Src = Inst.getOperand(0); + addAssignEdge(Src, &Inst); + } + + void visitBinaryOperator(BinaryOperator &Inst) { + auto *Op1 = Inst.getOperand(0); + auto *Op2 = Inst.getOperand(1); + addAssignEdge(Op1, &Inst); + addAssignEdge(Op2, &Inst); + } + + void visitAtomicCmpXchgInst(AtomicCmpXchgInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getNewValOperand(); + addStoreEdge(Val, Ptr); + } + + void visitAtomicRMWInst(AtomicRMWInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getValOperand(); + addStoreEdge(Val, Ptr); + } + + void visitPHINode(PHINode &Inst) { + for (Value *Val : Inst.incoming_values()) + addAssignEdge(Val, &Inst); + } + + void 
visitGetElementPtrInst(GetElementPtrInst &Inst) { + auto *Op = Inst.getPointerOperand(); + addAssignEdge(Op, &Inst); + } + + void visitSelectInst(SelectInst &Inst) { + // Condition is not processed here (The actual statement producing + // the condition result is processed elsewhere). For select, the + // condition is evaluated, but not loaded, stored, or assigned + // simply as a result of being the condition of a select. + + auto *TrueVal = Inst.getTrueValue(); + auto *FalseVal = Inst.getFalseValue(); + addAssignEdge(TrueVal, &Inst); + addAssignEdge(FalseVal, &Inst); + } + + void visitAllocaInst(AllocaInst &Inst) { addNode(&Inst); } + + void visitLoadInst(LoadInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = &Inst; + addLoadEdge(Ptr, Val); + } + + void visitStoreInst(StoreInst &Inst) { + auto *Ptr = Inst.getPointerOperand(); + auto *Val = Inst.getValueOperand(); + addStoreEdge(Val, Ptr); + } + + void visitVAArgInst(VAArgInst &Inst) { + // We can't fully model va_arg here. For *Ptr = Inst.getOperand(0), it + // does + // two things: + // 1. Loads a value from *((T*)*Ptr). + // 2. Increments (stores to) *Ptr by some target-specific amount. + // For now, we'll handle this like a landingpad instruction (by placing + // the + // result in its own group, and having that group alias externals). + addNode(&Inst, getAttrUnknown()); + } + + static bool isFunctionExternal(Function *Fn) { + return !Fn->hasExactDefinition(); + } + + bool tryInterproceduralAnalysis(CallSite CS, + const SmallVectorImpl<Function *> &Fns) { + assert(Fns.size() > 0); + + if (CS.arg_size() > MaxSupportedArgsInSummary) + return false; + + // Exit early if we'll fail anyway + for (auto *Fn : Fns) { + if (isFunctionExternal(Fn) || Fn->isVarArg()) + return false; + // Fail if the caller does not provide enough arguments + assert(Fn->arg_size() <= CS.arg_size()); + if (!AA.getAliasSummary(*Fn)) + return false; + } + + for (auto *Fn : Fns) { + auto Summary = AA.getAliasSummary(*Fn); + assert(Summary != nullptr); + + auto &RetParamRelations = Summary->RetParamRelations; + for (auto &Relation : RetParamRelations) { + auto IRelation = instantiateExternalRelation(Relation, CS); + if (IRelation.hasValue()) { + Graph.addNode(IRelation->From); + Graph.addNode(IRelation->To); + Graph.addEdge(IRelation->From, IRelation->To); + } + } + + auto &RetParamAttributes = Summary->RetParamAttributes; + for (auto &Attribute : RetParamAttributes) { + auto IAttr = instantiateExternalAttribute(Attribute, CS); + if (IAttr.hasValue()) + Graph.addNode(IAttr->IValue, IAttr->Attr); + } + } + + return true; + } + + void visitCallSite(CallSite CS) { + auto Inst = CS.getInstruction(); + + // Make sure all arguments and return value are added to the graph first + for (Value *V : CS.args()) + if (V->getType()->isPointerTy()) + addNode(V); + if (Inst->getType()->isPointerTy()) + addNode(Inst); + + // Check if Inst is a call to a library function that + // allocates/deallocates + // on the heap. Those kinds of functions do not introduce any aliases. + // TODO: address other common library functions such as realloc(), + // strdup(), + // etc. + if (isMallocLikeFn(Inst, &TLI) || isCallocLikeFn(Inst, &TLI) || + isFreeCall(Inst, &TLI)) + return; + + // TODO: Add support for noalias args/all the other fun function + // attributes + // that we can tack on. 
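+      // Illustrative sketch: for a callee summarized as "the return value may
+      // alias argument 0" (e.g. a hypothetical char *id(char *p)), the
+      // tryInterproceduralAnalysis call below instantiates that relation here
+      // as an edge between the first actual argument and the call instruction
+      // itself.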
+    SmallVector<Function *, 4> Targets;
+    if (getPossibleTargets(CS, Targets))
+      if (tryInterproceduralAnalysis(CS, Targets))
+        return;
+
+    // Because the function is opaque, we need to note that anything
+    // could have happened to the arguments (unless the function is marked
+    // readonly or readnone), and that the result could alias just about
+    // anything, too (unless the result is marked noalias).
+    if (!CS.onlyReadsMemory())
+      for (Value *V : CS.args()) {
+        if (V->getType()->isPointerTy()) {
+          // The argument itself escapes.
+          Graph.addAttr(InstantiatedValue{V, 0}, getAttrEscaped());
+          // The fate of argument memory is unknown. Note that since
+          // AliasAttrs is transitive with respect to dereference, we only
+          // need to specify it for the first-level memory.
+          Graph.addNode(InstantiatedValue{V, 1}, getAttrUnknown());
+        }
+      }
+
+    if (Inst->getType()->isPointerTy()) {
+      auto *Fn = CS.getCalledFunction();
+      if (Fn == nullptr || !Fn->doesNotAlias(0))
+        // No need to call addNode() since we've added Inst at the
+        // beginning of this function and we know it is not a global.
+        Graph.addAttr(InstantiatedValue{Inst, 0}, getAttrUnknown());
+    }
+  }
+
+  /// Because vectors/aggregates are immutable and unaddressable, there's
+  /// nothing we can do to coax a value out of them, other than calling
+  /// Extract{Element,Value}. We can effectively treat them as pointers to
+  /// arbitrary memory locations we can store in and load from.
+  void visitExtractElementInst(ExtractElementInst &Inst) {
+    auto *Ptr = Inst.getVectorOperand();
+    auto *Val = &Inst;
+    addLoadEdge(Ptr, Val);
+  }
+
+  void visitInsertElementInst(InsertElementInst &Inst) {
+    auto *Vec = Inst.getOperand(0);
+    auto *Val = Inst.getOperand(1);
+    addAssignEdge(Vec, &Inst);
+    addStoreEdge(Val, &Inst);
+  }
+
+  void visitLandingPadInst(LandingPadInst &Inst) {
+    // Exceptions come from "nowhere", from our analysis' perspective.
+    // So we place the instruction in its own group, noting that said group
+    // may alias externals.
+    addNode(&Inst, getAttrUnknown());
+  }
+
+  void visitInsertValueInst(InsertValueInst &Inst) {
+    auto *Agg = Inst.getOperand(0);
+    auto *Val = Inst.getOperand(1);
+    addAssignEdge(Agg, &Inst);
+    addStoreEdge(Val, &Inst);
+  }
+
+  void visitExtractValueInst(ExtractValueInst &Inst) {
+    auto *Ptr = Inst.getAggregateOperand();
+    addLoadEdge(Ptr, &Inst);
+  }
+
+  void visitShuffleVectorInst(ShuffleVectorInst &Inst) {
+    auto *From1 = Inst.getOperand(0);
+    auto *From2 = Inst.getOperand(1);
+    addAssignEdge(From1, &Inst);
+    addAssignEdge(From2, &Inst);
+  }
+
+  void visitConstantExpr(ConstantExpr *CE) {
+    switch (CE->getOpcode()) {
+    default:
+      llvm_unreachable("Unknown instruction type encountered!");
+// Build the switch statement using the Instruction.def file.
+#define HANDLE_INST(NUM, OPCODE, CLASS)                                        \
+  case Instruction::OPCODE:                                                    \
+    this->visit##OPCODE(*(CLASS *)CE);                                         \
+    break;
+#include "llvm/IR/Instruction.def"
+    }
+  }
+};
+
+// Helper functions
+
+// Determines whether an instruction is useless to us (e.g. FenceInst)
+static bool hasUsefulEdges(Instruction *Inst) {
+  bool IsNonInvokeRetTerminator = isa<TerminatorInst>(Inst) &&
+                                  !isa<InvokeInst>(Inst) &&
+                                  !isa<ReturnInst>(Inst);
+  return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) &&
+         !IsNonInvokeRetTerminator;
+}
+
+  void addArgumentToGraph(Argument &Arg) {
+    if (Arg.getType()->isPointerTy()) {
+      Graph.addNode(InstantiatedValue{&Arg, 0},
+                    getGlobalOrArgAttrFromValue(Arg));
+      // Pointees of a formal parameter are known to the caller
+      Graph.addNode(InstantiatedValue{&Arg, 1}, getAttrCaller());
+    }
+  }
+
+  // Given an Instruction, this will add it to the graph, along with any
+  // Instructions that are potentially only available from said Instruction.
+  // For example, given the following line:
+  //   %0 = load i16* getelementptr ([1 x i16]* @a, 0, 0), align 2
+  // addInstructionToGraph would add both the `load` and `getelementptr`
+  // instructions to the graph appropriately.
+  void addInstructionToGraph(GetEdgesVisitor &Visitor, Instruction &Inst) {
+    if (!hasUsefulEdges(&Inst))
+      return;
+
+    Visitor.visit(Inst);
+  }
+
+  // Builds the graph needed for constructing the StratifiedSets for the given
+  // function.
+  void buildGraphFrom(Function &Fn) {
+    GetEdgesVisitor Visitor(*this);
+
+    for (auto &Bb : Fn.getBasicBlockList())
+      for (auto &Inst : Bb.getInstList())
+        addInstructionToGraph(Visitor, Inst);
+
+    for (auto &Arg : Fn.args())
+      addArgumentToGraph(Arg);
+  }
+
+public:
+  CFLGraphBuilder(CFLAA &Analysis, const TargetLibraryInfo &TLI, Function &Fn)
+      : Analysis(Analysis), TLI(TLI) {
+    buildGraphFrom(Fn);
+  }
+
+  const CFLGraph &getCFLGraph() const { return Graph; }
+  const SmallVector<Value *, 4> &getReturnValues() const {
+    return ReturnedValues;
+  }
+};
+}
+}
+
+#endif
diff --git a/lib/Analysis/CFLSteensAliasAnalysis.cpp b/lib/Analysis/CFLSteensAliasAnalysis.cpp
new file mode 100644
index 000000000000..d816822aaaea
--- /dev/null
+++ b/lib/Analysis/CFLSteensAliasAnalysis.cpp
@@ -0,0 +1,442 @@
+//- CFLSteensAliasAnalysis.cpp - Unification-based Alias Analysis ---*- C++-*-//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a CFL-based, summary-based alias analysis algorithm. It
+// does not depend on types. The algorithm is a mixture of the one described in
+// "Demand-driven alias analysis for C" by Xin Zheng and Radu Rugina, and "Fast
+// algorithms for Dyck-CFL-reachability with applications to Alias Analysis" by
+// Zhang Q, Lyu M R, Yuan H, and Su Z. -- to summarize the papers, we build a
+// graph of the uses of a variable, where each node is a memory location, and
+// each edge is an action that happened on that memory location. The "actions"
+// can be one of Dereference, Reference, or Assign. The precision of this
+// analysis is roughly the same as that of a one-level context-sensitive
+// Steensgaard's algorithm.
+//
+// Two variables are considered to alias iff you can reach one value's node
+// from the other value's node and the language formed by concatenating all of
+// the edge labels (actions) conforms to a context-free grammar.
+//
+// Because this algorithm requires a graph search on each query, we execute the
+// algorithm outlined in "Fast algorithms..." (mentioned above)
+// in order to transform the graph into sets of variables that may alias in
+// ~nlogn time (n = number of variables), which makes queries take constant
+// time.
+//===----------------------------------------------------------------------===//
+
+// N.B. AliasAnalysis as a whole is phrased as a FunctionPass at the moment, and
+// CFLSteensAA is interprocedural. This is *technically* A Bad Thing, because
+// FunctionPasses are only allowed to inspect the Function that they're being
+// run on. Realistically, this likely isn't a problem until we allow
+// FunctionPasses to run concurrently.
+
+#include "llvm/Analysis/CFLSteensAliasAnalysis.h"
+#include "CFLGraph.h"
+#include "StratifiedSets.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include <algorithm>
+#include <cassert>
+#include <memory>
+#include <tuple>
+
+using namespace llvm;
+using namespace llvm::cflaa;
+
+#define DEBUG_TYPE "cfl-steens-aa"
+
+CFLSteensAAResult::CFLSteensAAResult(const TargetLibraryInfo &TLI)
+    : AAResultBase(), TLI(TLI) {}
+CFLSteensAAResult::CFLSteensAAResult(CFLSteensAAResult &&Arg)
+    : AAResultBase(std::move(Arg)), TLI(Arg.TLI) {}
+CFLSteensAAResult::~CFLSteensAAResult() {}
+
+/// Information we have about a function and would like to keep around.
+class CFLSteensAAResult::FunctionInfo {
+  StratifiedSets<InstantiatedValue> Sets;
+  AliasSummary Summary;
+
+public:
+  FunctionInfo(Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+               StratifiedSets<InstantiatedValue> S);
+
+  const StratifiedSets<InstantiatedValue> &getStratifiedSets() const {
+    return Sets;
+  }
+  const AliasSummary &getAliasSummary() const { return Summary; }
+};
+
+/// Try to go from a Value* to a Function*. Returns None on failure; the
+/// wrapped Function* is otherwise never null.
+static Optional<Function *> parentFunctionOfValue(Value *);
+
+const StratifiedIndex StratifiedLink::SetSentinel =
+    std::numeric_limits<StratifiedIndex>::max();
+
+//===----------------------------------------------------------------------===//
+// Function declarations that require types defined in the namespace above
+//===----------------------------------------------------------------------===//
+
+/// Determines whether it would be pointless to add the given Value to our sets.
+static bool canSkipAddingToSets(Value *Val);
+
+static Optional<Function *> parentFunctionOfValue(Value *Val) {
+  if (auto *Inst = dyn_cast<Instruction>(Val)) {
+    auto *Bb = Inst->getParent();
+    return Bb->getParent();
+  }
+
+  if (auto *Arg = dyn_cast<Argument>(Val))
+    return Arg->getParent();
+  return None;
+}
+
+static bool canSkipAddingToSets(Value *Val) {
+  // Constants can share instances, which may falsely unify multiple
+  // sets, e.g. in
+  //   store i32* null, i32** %ptr1
+  //   store i32* null, i32** %ptr2
+  // clearly ptr1 and ptr2 should not be unified into the same set, so
+  // we should filter out the (potentially shared) instance of
+  // i32* null.
+  if (isa<Constant>(Val)) {
+    // TODO: Because all of these things are constant, we can determine whether
+    // the data is *actually* mutable at graph building time. This will
+    // probably come for free/cheap with offset awareness.
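    // (Illustrative example: a ConstantExpr such as
    //   getelementptr ([4 x i32], [4 x i32]* @g, i64 0, i64 1)
    // can point into the mutable storage of @g and must stay in the graph,
    // while a shared literal like the i32* null above never holds data of
    // its own.)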
+    bool CanStoreMutableData = isa<GlobalValue>(Val) ||
+                               isa<ConstantExpr>(Val) ||
+                               isa<ConstantAggregate>(Val);
+    return !CanStoreMutableData;
+  }
+
+  return false;
+}
+
+CFLSteensAAResult::FunctionInfo::FunctionInfo(
+    Function &Fn, const SmallVectorImpl<Value *> &RetVals,
+    StratifiedSets<InstantiatedValue> S)
+    : Sets(std::move(S)) {
+  // Historically, an arbitrary upper bound of 50 args was selected. We may
+  // want to remove this if it doesn't really matter in practice.
+  if (Fn.arg_size() > MaxSupportedArgsInSummary)
+    return;
+
+  DenseMap<StratifiedIndex, InterfaceValue> InterfaceMap;
+
+  // Our intention here is to record all InterfaceValues that share the same
+  // StratifiedIndex in RetParamRelations. For each valid InterfaceValue, we
+  // scan its StratifiedIndex here and check whether that index is present in
+  // InterfaceMap: if it is not, we add the correspondence to the map;
+  // otherwise, an aliasing relation is found and we add it to
+  // RetParamRelations.
+
+  auto AddToRetParamRelations = [&](unsigned InterfaceIndex,
+                                    StratifiedIndex SetIndex) {
+    unsigned Level = 0;
+    while (true) {
+      InterfaceValue CurrValue{InterfaceIndex, Level};
+
+      auto Itr = InterfaceMap.find(SetIndex);
+      if (Itr != InterfaceMap.end()) {
+        if (CurrValue != Itr->second)
+          Summary.RetParamRelations.push_back(
+              ExternalRelation{CurrValue, Itr->second});
+        break;
+      }
+
+      auto &Link = Sets.getLink(SetIndex);
+      InterfaceMap.insert(std::make_pair(SetIndex, CurrValue));
+      auto ExternalAttrs = getExternallyVisibleAttrs(Link.Attrs);
+      if (ExternalAttrs.any())
+        Summary.RetParamAttributes.push_back(
+            ExternalAttribute{CurrValue, ExternalAttrs});
+
+      if (!Link.hasBelow())
+        break;
+
+      ++Level;
+      SetIndex = Link.Below;
+    }
+  };
+
+  // Populate RetParamRelations for return values
+  for (auto *RetVal : RetVals) {
+    assert(RetVal != nullptr);
+    assert(RetVal->getType()->isPointerTy());
+    auto RetInfo = Sets.find(InstantiatedValue{RetVal, 0});
+    if (RetInfo.hasValue())
+      AddToRetParamRelations(0, RetInfo->Index);
+  }
+
+  // Populate RetParamRelations for parameters
+  unsigned I = 0;
+  for (auto &Param : Fn.args()) {
+    if (Param.getType()->isPointerTy()) {
+      auto ParamInfo = Sets.find(InstantiatedValue{&Param, 0});
+      if (ParamInfo.hasValue())
+        AddToRetParamRelations(I + 1, ParamInfo->Index);
+    }
+    ++I;
+  }
+}
+
+// Builds the graph + StratifiedSets for a function.
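// (In outline, assuming the builder APIs shown above: a CFLGraphBuilder walks
// the function once; each node and its dereference levels are then registered
// with a StratifiedSetsBuilder, and assign edges merge the corresponding
// sets.)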
+CFLSteensAAResult::FunctionInfo CFLSteensAAResult::buildSetsFrom(Function *Fn) { + CFLGraphBuilder<CFLSteensAAResult> GraphBuilder(*this, TLI, *Fn); + StratifiedSetsBuilder<InstantiatedValue> SetBuilder; + + // Add all CFLGraph nodes and all Dereference edges to StratifiedSets + auto &Graph = GraphBuilder.getCFLGraph(); + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + if (canSkipAddingToSets(Val)) + continue; + auto &ValueInfo = Mapping.second; + + assert(ValueInfo.getNumLevels() > 0); + SetBuilder.add(InstantiatedValue{Val, 0}); + SetBuilder.noteAttributes(InstantiatedValue{Val, 0}, + ValueInfo.getNodeInfoAtLevel(0).Attr); + for (unsigned I = 0, E = ValueInfo.getNumLevels() - 1; I < E; ++I) { + SetBuilder.add(InstantiatedValue{Val, I + 1}); + SetBuilder.noteAttributes(InstantiatedValue{Val, I + 1}, + ValueInfo.getNodeInfoAtLevel(I + 1).Attr); + SetBuilder.addBelow(InstantiatedValue{Val, I}, + InstantiatedValue{Val, I + 1}); + } + } + + // Add all assign edges to StratifiedSets + for (const auto &Mapping : Graph.value_mappings()) { + auto Val = Mapping.first; + if (canSkipAddingToSets(Val)) + continue; + auto &ValueInfo = Mapping.second; + + for (unsigned I = 0, E = ValueInfo.getNumLevels(); I < E; ++I) { + auto Src = InstantiatedValue{Val, I}; + for (auto &Edge : ValueInfo.getNodeInfoAtLevel(I).Edges) + SetBuilder.addWith(Src, Edge.Other); + } + } + + return FunctionInfo(*Fn, GraphBuilder.getReturnValues(), SetBuilder.build()); +} + +void CFLSteensAAResult::scan(Function *Fn) { + auto InsertPair = Cache.insert(std::make_pair(Fn, Optional<FunctionInfo>())); + (void)InsertPair; + assert(InsertPair.second && + "Trying to scan a function that has already been cached"); + + // Note that we can't do Cache[Fn] = buildSetsFrom(Fn) here: the function call + // may get evaluated after operator[], potentially triggering a DenseMap + // resize and invalidating the reference returned by operator[] + auto FunInfo = buildSetsFrom(Fn); + Cache[Fn] = std::move(FunInfo); + + Handles.push_front(FunctionHandle(Fn, this)); +} + +void CFLSteensAAResult::evict(Function *Fn) { Cache.erase(Fn); } + +/// Ensures that the given function is available in the cache, and returns the +/// entry. 
+const Optional<CFLSteensAAResult::FunctionInfo> & +CFLSteensAAResult::ensureCached(Function *Fn) { + auto Iter = Cache.find(Fn); + if (Iter == Cache.end()) { + scan(Fn); + Iter = Cache.find(Fn); + assert(Iter != Cache.end()); + assert(Iter->second.hasValue()); + } + return Iter->second; +} + +const AliasSummary *CFLSteensAAResult::getAliasSummary(Function &Fn) { + auto &FunInfo = ensureCached(&Fn); + if (FunInfo.hasValue()) + return &FunInfo->getAliasSummary(); + else + return nullptr; +} + +AliasResult CFLSteensAAResult::query(const MemoryLocation &LocA, + const MemoryLocation &LocB) { + auto *ValA = const_cast<Value *>(LocA.Ptr); + auto *ValB = const_cast<Value *>(LocB.Ptr); + + if (!ValA->getType()->isPointerTy() || !ValB->getType()->isPointerTy()) + return NoAlias; + + Function *Fn = nullptr; + auto MaybeFnA = parentFunctionOfValue(ValA); + auto MaybeFnB = parentFunctionOfValue(ValB); + if (!MaybeFnA.hasValue() && !MaybeFnB.hasValue()) { + // The only times this is known to happen are when globals + InlineAsm are + // involved + DEBUG(dbgs() + << "CFLSteensAA: could not extract parent function information.\n"); + return MayAlias; + } + + if (MaybeFnA.hasValue()) { + Fn = *MaybeFnA; + assert((!MaybeFnB.hasValue() || *MaybeFnB == *MaybeFnA) && + "Interprocedural queries not supported"); + } else { + Fn = *MaybeFnB; + } + + assert(Fn != nullptr); + auto &MaybeInfo = ensureCached(Fn); + assert(MaybeInfo.hasValue()); + + auto &Sets = MaybeInfo->getStratifiedSets(); + auto MaybeA = Sets.find(InstantiatedValue{ValA, 0}); + if (!MaybeA.hasValue()) + return MayAlias; + + auto MaybeB = Sets.find(InstantiatedValue{ValB, 0}); + if (!MaybeB.hasValue()) + return MayAlias; + + auto SetA = *MaybeA; + auto SetB = *MaybeB; + auto AttrsA = Sets.getLink(SetA.Index).Attrs; + auto AttrsB = Sets.getLink(SetB.Index).Attrs; + + // If both values are local (meaning the corresponding set has attribute + // AttrNone or AttrEscaped), then we know that CFLSteensAA fully models them: + // they may-alias each other if and only if they are in the same set. + // If at least one value is non-local (meaning it either is global/argument or + // it comes from unknown sources like integer cast), the situation becomes a + // bit more interesting. 
We follow three general rules described below:
+  //   - Non-local values may alias each other
+  //   - AttrNone values do not alias any non-local values
+  //   - AttrEscaped values do not alias globals/arguments, but they may alias
+  //     AttrUnknown values
+  if (SetA.Index == SetB.Index)
+    return MayAlias;
+  if (AttrsA.none() || AttrsB.none())
+    return NoAlias;
+  if (hasUnknownOrCallerAttr(AttrsA) || hasUnknownOrCallerAttr(AttrsB))
+    return MayAlias;
+  if (isGlobalOrArgAttr(AttrsA) && isGlobalOrArgAttr(AttrsB))
+    return MayAlias;
+  return NoAlias;
+}
+
+ModRefInfo CFLSteensAAResult::getArgModRefInfo(ImmutableCallSite CS,
+                                               unsigned ArgIdx) {
+  if (auto CalledFunc = CS.getCalledFunction()) {
+    auto &MaybeInfo = ensureCached(const_cast<Function *>(CalledFunc));
+    if (!MaybeInfo.hasValue())
+      return MRI_ModRef;
+    auto &RetParamAttributes = MaybeInfo->getAliasSummary().RetParamAttributes;
+    auto &RetParamRelations = MaybeInfo->getAliasSummary().RetParamRelations;
+
+    bool ArgAttributeIsWritten =
+        std::any_of(RetParamAttributes.begin(), RetParamAttributes.end(),
+                    [ArgIdx](const ExternalAttribute &ExtAttr) {
+                      return ExtAttr.IValue.Index == ArgIdx + 1;
+                    });
+    bool ArgIsAccessed =
+        std::any_of(RetParamRelations.begin(), RetParamRelations.end(),
+                    [ArgIdx](const ExternalRelation &ExtRelation) {
+                      return ExtRelation.To.Index == ArgIdx + 1 ||
+                             ExtRelation.From.Index == ArgIdx + 1;
+                    });
+
+    return (!ArgIsAccessed && !ArgAttributeIsWritten) ? MRI_NoModRef
+                                                      : MRI_ModRef;
+  }
+
+  return MRI_ModRef;
+}
+
+FunctionModRefBehavior
+CFLSteensAAResult::getModRefBehavior(ImmutableCallSite CS) {
+  // If we know the callee, try analyzing it
+  if (auto CalledFunc = CS.getCalledFunction())
+    return getModRefBehavior(CalledFunc);
+
+  // Otherwise, be conservative
+  return FMRB_UnknownModRefBehavior;
+}
+
+FunctionModRefBehavior CFLSteensAAResult::getModRefBehavior(const Function *F) {
+  assert(F != nullptr);
+
+  // TODO: Remove the const_cast
+  auto &MaybeInfo = ensureCached(const_cast<Function *>(F));
+  if (!MaybeInfo.hasValue())
+    return FMRB_UnknownModRefBehavior;
+  auto &RetParamAttributes = MaybeInfo->getAliasSummary().RetParamAttributes;
+  auto &RetParamRelations = MaybeInfo->getAliasSummary().RetParamRelations;
+
+  // First, if any argument is marked Escaped, Unknown, or Global, anything
+  // may happen to it and thus we can't draw any conclusion.
+  if (!RetParamAttributes.empty())
+    return FMRB_UnknownModRefBehavior;
+
+  // Currently we don't (and can't) distinguish reads from writes in
+  // RetParamRelations. All we can say is whether there may be memory access
+  // or not.
+  if (RetParamRelations.empty())
+    return FMRB_DoesNotAccessMemory;
+
+  // Check if something beyond argmem gets touched.
+  bool AccessArgMemoryOnly =
+      std::all_of(RetParamRelations.begin(), RetParamRelations.end(),
+                  [](const ExternalRelation &ExtRelation) {
+                    // Both DerefLevels have to be 0, since we don't know
+                    // which one is a read and which is a write.
+                    return ExtRelation.From.DerefLevel == 0 &&
+                           ExtRelation.To.DerefLevel == 0;
+                  });
+  return AccessArgMemoryOnly ?
FMRB_OnlyAccessesArgumentPointees + : FMRB_UnknownModRefBehavior; +} + +char CFLSteensAA::PassID; + +CFLSteensAAResult CFLSteensAA::run(Function &F, AnalysisManager<Function> &AM) { + return CFLSteensAAResult(AM.getResult<TargetLibraryAnalysis>(F)); +} + +char CFLSteensAAWrapperPass::ID = 0; +INITIALIZE_PASS(CFLSteensAAWrapperPass, "cfl-steens-aa", + "Unification-Based CFL Alias Analysis", false, true) + +ImmutablePass *llvm::createCFLSteensAAWrapperPass() { + return new CFLSteensAAWrapperPass(); +} + +CFLSteensAAWrapperPass::CFLSteensAAWrapperPass() : ImmutablePass(ID) { + initializeCFLSteensAAWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void CFLSteensAAWrapperPass::initializePass() { + auto &TLIWP = getAnalysis<TargetLibraryInfoWrapperPass>(); + Result.reset(new CFLSteensAAResult(TLIWP.getTLI())); +} + +void CFLSteensAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); +} diff --git a/lib/Analysis/CGSCCPassManager.cpp b/lib/Analysis/CGSCCPassManager.cpp index 4a03002e510b..f6f30bb927a5 100644 --- a/lib/Analysis/CGSCCPassManager.cpp +++ b/lib/Analysis/CGSCCPassManager.cpp @@ -8,65 +8,17 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/CGSCCPassManager.h" -#include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" using namespace llvm; -char CGSCCAnalysisManagerModuleProxy::PassID; - -CGSCCAnalysisManagerModuleProxy::Result -CGSCCAnalysisManagerModuleProxy::run(Module &M) { - assert(CGAM->empty() && "CGSCC analyses ran prior to the module proxy!"); - return Result(*CGAM); -} - -CGSCCAnalysisManagerModuleProxy::Result::~Result() { - // Clear out the analysis manager if we're being destroyed -- it means we - // didn't even see an invalidate call when we got invalidated. - CGAM->clear(); -} - -bool CGSCCAnalysisManagerModuleProxy::Result::invalidate( - Module &M, const PreservedAnalyses &PA) { - // If this proxy isn't marked as preserved, then we can't even invalidate - // individual CGSCC analyses, there may be an invalid set of SCC objects in - // the cache making it impossible to incrementally preserve them. - // Just clear the entire manager. - if (!PA.preserved(ID())) - CGAM->clear(); - - // Return false to indicate that this result is still a valid proxy. - return false; -} - -char ModuleAnalysisManagerCGSCCProxy::PassID; - -char FunctionAnalysisManagerCGSCCProxy::PassID; - -FunctionAnalysisManagerCGSCCProxy::Result -FunctionAnalysisManagerCGSCCProxy::run(LazyCallGraph::SCC &C) { - assert(FAM->empty() && "Function analyses ran prior to the CGSCC proxy!"); - return Result(*FAM); -} - -FunctionAnalysisManagerCGSCCProxy::Result::~Result() { - // Clear out the analysis manager if we're being destroyed -- it means we - // didn't even see an invalidate call when we got invalidated. - FAM->clear(); +// Explicit instantiations for the core proxy templates. 
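// (Assuming the corresponding headers declare matching extern templates,
// these definitions compile each named template once here, e.g.
// PassManager<LazyCallGraph::SCC>, and other translation units then reuse
// the instantiations emitted in this file rather than generating their own.)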
+namespace llvm { +template class PassManager<LazyCallGraph::SCC>; +template class AnalysisManager<LazyCallGraph::SCC>; +template class InnerAnalysisManagerProxy<CGSCCAnalysisManager, Module>; +template class OuterAnalysisManagerProxy<ModuleAnalysisManager, + LazyCallGraph::SCC>; +template class InnerAnalysisManagerProxy<FunctionAnalysisManager, + LazyCallGraph::SCC>; +template class OuterAnalysisManagerProxy<CGSCCAnalysisManager, Function>; } - -bool FunctionAnalysisManagerCGSCCProxy::Result::invalidate( - LazyCallGraph::SCC &C, const PreservedAnalyses &PA) { - // If this proxy isn't marked as preserved, then we can't even invalidate - // individual function analyses, there may be an invalid set of Function - // objects in the cache making it impossible to incrementally preserve them. - // Just clear the entire manager. - if (!PA.preserved(ID())) - FAM->clear(); - - // Return false to indicate that this result is still a valid proxy. - return false; -} - -char CGSCCAnalysisManagerFunctionProxy::PassID; diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 69623619a8b0..57ad437ef4fd 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_library(LLVMAnalysis AliasAnalysis.cpp AliasAnalysisEvaluator.cpp + AliasAnalysisSummary.cpp AliasSetTracker.cpp Analysis.cpp AssumptionCache.cpp @@ -10,7 +11,8 @@ add_llvm_library(LLVMAnalysis BranchProbabilityInfo.cpp CFG.cpp CFGPrinter.cpp - CFLAliasAnalysis.cpp + CFLAndersAliasAnalysis.cpp + CFLSteensAliasAnalysis.cpp CGSCCPassManager.cpp CallGraph.cpp CallGraphSCCPass.cpp @@ -28,31 +30,38 @@ add_llvm_library(LLVMAnalysis EHPersonalities.cpp GlobalsModRef.cpp IVUsers.cpp + IndirectCallPromotionAnalysis.cpp InlineCost.cpp InstCount.cpp InstructionSimplify.cpp Interval.cpp IntervalPartition.cpp IteratedDominanceFrontier.cpp + LazyBlockFrequencyInfo.cpp LazyCallGraph.cpp LazyValueInfo.cpp Lint.cpp Loads.cpp LoopAccessAnalysis.cpp + LoopUnrollAnalyzer.cpp LoopInfo.cpp LoopPass.cpp + LoopPassManager.cpp MemDepPrinter.cpp MemDerefPrinter.cpp MemoryBuiltins.cpp MemoryDependenceAnalysis.cpp MemoryLocation.cpp ModuleDebugInfoPrinter.cpp + ModuleSummaryAnalysis.cpp ObjCARCAliasAnalysis.cpp ObjCARCAnalysisUtils.cpp ObjCARCInstKind.cpp + OptimizationDiagnosticInfo.cpp OrderedBasicBlock.cpp PHITransAddr.cpp PostDominators.cpp + ProfileSummaryInfo.cpp PtrUseVisitor.cpp RegionInfo.cpp RegionPass.cpp @@ -66,6 +75,7 @@ add_llvm_library(LLVMAnalysis TargetTransformInfo.cpp Trace.cpp TypeBasedAliasAnalysis.cpp + TypeMetadataUtils.cpp ScopedNoAliasAA.cpp ValueTracking.cpp VectorUtils.cpp diff --git a/lib/Analysis/CallGraph.cpp b/lib/Analysis/CallGraph.cpp index 7cec962678e8..39cb86d2ccb1 100644 --- a/lib/Analysis/CallGraph.cpp +++ b/lib/Analysis/CallGraph.cpp @@ -80,11 +80,9 @@ void CallGraph::addToCallGraph(Function *F) { Node->addCalledFunction(CallSite(), CallsExternalNode.get()); // Look for calls by this function. - for (Function::iterator BB = F->begin(), BBE = F->end(); BB != BBE; ++BB) - for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE; - ++II) { - CallSite CS(cast<Value>(II)); - if (CS) { + for (BasicBlock &BB : *F) + for (Instruction &I : BB) { + if (auto CS = CallSite(&I)) { const Function *Callee = CS.getCalledFunction(); if (!Callee || !Intrinsic::isLeaf(Callee->getIntrinsicID())) // Indirect calls of intrinsics are not allowed so no need to check. 
@@ -111,8 +109,8 @@ void CallGraph::print(raw_ostream &OS) const { SmallVector<CallGraphNode *, 16> Nodes; Nodes.reserve(FunctionMap.size()); - for (auto I = begin(), E = end(); I != E; ++I) - Nodes.push_back(I->second.get()); + for (const auto &I : *this) + Nodes.push_back(I.second.get()); std::sort(Nodes.begin(), Nodes.end(), [](CallGraphNode *LHS, CallGraphNode *RHS) { @@ -186,9 +184,9 @@ void CallGraphNode::print(raw_ostream &OS) const { OS << "<<" << this << ">> #uses=" << getNumReferences() << '\n'; - for (const_iterator I = begin(), E = end(); I != E; ++I) { - OS << " CS<" << I->first << "> calls "; - if (Function *FI = I->second->getFunction()) + for (const auto &I : *this) { + OS << " CS<" << I.first << "> calls "; + if (Function *FI = I.second->getFunction()) OS << "function '" << FI->getName() <<"'\n"; else OS << "external node\n"; @@ -259,12 +257,19 @@ void CallGraphNode::replaceCallEdge(CallSite CS, } } +// Provide an explicit template instantiation for the static ID. +char CallGraphAnalysis::PassID; + +PreservedAnalyses CallGraphPrinterPass::run(Module &M, + AnalysisManager<Module> &AM) { + AM.getResult<CallGraphAnalysis>(M).print(OS); + return PreservedAnalyses::all(); +} + //===----------------------------------------------------------------------===// // Out-of-line definitions of CallGraphAnalysis class members. // -char CallGraphAnalysis::PassID; - //===----------------------------------------------------------------------===// // Implementations of the CallGraphWrapperPass class methods. // @@ -304,3 +309,29 @@ void CallGraphWrapperPass::print(raw_ostream &OS, const Module *) const { LLVM_DUMP_METHOD void CallGraphWrapperPass::dump() const { print(dbgs(), nullptr); } + +namespace { +struct CallGraphPrinterLegacyPass : public ModulePass { + static char ID; // Pass ID, replacement for typeid + CallGraphPrinterLegacyPass() : ModulePass(ID) { + initializeCallGraphPrinterLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + AU.addRequiredTransitive<CallGraphWrapperPass>(); + } + bool runOnModule(Module &M) override { + getAnalysis<CallGraphWrapperPass>().print(errs(), &M); + return false; + } +}; +} + +char CallGraphPrinterLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(CallGraphPrinterLegacyPass, "print-callgraph", + "Print a call graph", true, true) +INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) +INITIALIZE_PASS_END(CallGraphPrinterLegacyPass, "print-callgraph", + "Print a call graph", true, true) diff --git a/lib/Analysis/CallGraphSCCPass.cpp b/lib/Analysis/CallGraphSCCPass.cpp index 6dd1d0a066b6..69d767354785 100644 --- a/lib/Analysis/CallGraphSCCPass.cpp +++ b/lib/Analysis/CallGraphSCCPass.cpp @@ -23,6 +23,7 @@ #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManagers.h" +#include "llvm/IR/OptBisect.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" @@ -260,10 +261,10 @@ bool CGPassManager::RefreshCallGraph(CallGraphSCC &CurSCC, // Loop over all of the instructions in the function, getting the callsites. // Keep track of the number of direct/indirect calls added. 
unsigned NumDirectAdded = 0, NumIndirectAdded = 0; - - for (Function::iterator BB = F->begin(), E = F->end(); BB != E; ++BB) - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { - CallSite CS(cast<Value>(I)); + + for (BasicBlock &BB : *F) + for (Instruction &I : BB) { + CallSite CS(&I); if (!CS) continue; Function *Callee = CS.getCalledFunction(); if (Callee && Callee->isIntrinsic()) continue; @@ -444,7 +445,7 @@ bool CGPassManager::runOnModule(Module &M) { // Walk the callgraph in bottom-up SCC order. scc_iterator<CallGraph*> CGI = scc_begin(&CG); - CallGraphSCC CurSCC(&CGI); + CallGraphSCC CurSCC(CG, &CGI); while (!CGI.isAtEnd()) { // Copy the current SCC and increment past it so that the pass can hack // on the SCC if it wants to without invalidating our iterator. @@ -631,3 +632,13 @@ Pass *CallGraphSCCPass::createPrinterPass(raw_ostream &O, return new PrintCallGraphPass(Banner, O); } +bool CallGraphSCCPass::skipSCC(CallGraphSCC &SCC) const { + return !SCC.getCallGraph().getModule() + .getContext() + .getOptBisect() + .shouldRunPass(this, SCC); +} + +char DummyCGSCCPass::ID = 0; +INITIALIZE_PASS(DummyCGSCCPass, "DummyCGSCCPass", "DummyCGSCCPass", false, + false) diff --git a/lib/Analysis/CallPrinter.cpp b/lib/Analysis/CallPrinter.cpp index 68dcd3c06427..af942e9ed3e9 100644 --- a/lib/Analysis/CallPrinter.cpp +++ b/lib/Analysis/CallPrinter.cpp @@ -58,16 +58,16 @@ struct CallGraphViewer } }; -struct CallGraphPrinter : public DOTGraphTraitsModulePrinter< +struct CallGraphDOTPrinter : public DOTGraphTraitsModulePrinter< CallGraphWrapperPass, true, CallGraph *, AnalysisCallGraphWrapperPassTraits> { static char ID; - CallGraphPrinter() + CallGraphDOTPrinter() : DOTGraphTraitsModulePrinter<CallGraphWrapperPass, true, CallGraph *, AnalysisCallGraphWrapperPassTraits>( "callgraph", ID) { - initializeCallGraphPrinterPass(*PassRegistry::getPassRegistry()); + initializeCallGraphDOTPrinterPass(*PassRegistry::getPassRegistry()); } }; @@ -77,8 +77,8 @@ char CallGraphViewer::ID = 0; INITIALIZE_PASS(CallGraphViewer, "view-callgraph", "View call graph", false, false) -char CallGraphPrinter::ID = 0; -INITIALIZE_PASS(CallGraphPrinter, "dot-callgraph", +char CallGraphDOTPrinter::ID = 0; +INITIALIZE_PASS(CallGraphDOTPrinter, "dot-callgraph", "Print call graph to 'dot' file", false, false) // Create methods available outside of this file, to use them @@ -87,6 +87,6 @@ INITIALIZE_PASS(CallGraphPrinter, "dot-callgraph", ModulePass *llvm::createCallGraphViewerPass() { return new CallGraphViewer(); } -ModulePass *llvm::createCallGraphPrinterPass() { - return new CallGraphPrinter(); +ModulePass *llvm::createCallGraphDOTPrinterPass() { + return new CallGraphDOTPrinter(); } diff --git a/lib/Analysis/CaptureTracking.cpp b/lib/Analysis/CaptureTracking.cpp index 1add2fa77566..9862c3c9c270 100644 --- a/lib/Analysis/CaptureTracking.cpp +++ b/lib/Analysis/CaptureTracking.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" using namespace llvm; @@ -242,6 +243,13 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { if (CS.onlyReadsMemory() && CS.doesNotThrow() && I->getType()->isVoidTy()) break; + // Volatile operations effectively capture the memory location that they + // load and store to. + if (auto *MI = dyn_cast<MemIntrinsic>(I)) + if (MI->isVolatile()) + if (Tracker->captured(U)) + return; + // Not captured if only passed via 'nocapture' arguments. 
Note that
 // calling a function pointer does not in itself cause the pointer to
 // be captured. This is a subtle point considering that (for example)
@@ -259,18 +267,46 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) {
       break;
     }
     case Instruction::Load:
-      // Loading from a pointer does not cause it to be captured.
+      // Volatile loads make the address observable.
+      if (cast<LoadInst>(I)->isVolatile())
+        if (Tracker->captured(U))
+          return;
       break;
     case Instruction::VAArg:
      // "va-arg" from a pointer does not cause it to be captured.
      break;
    case Instruction::Store:
-      if (V == I->getOperand(0))
-        // Stored the pointer - conservatively assume it may be captured.
+      // Volatile stores make the address observable.
+      if (V == I->getOperand(0) || cast<StoreInst>(I)->isVolatile())
+        if (Tracker->captured(U))
+          return;
+      break;
+    case Instruction::AtomicRMW: {
+      // atomicrmw conceptually includes both a load and store from
+      // the same location.
+      // As with a store, the location being accessed is not captured,
+      // but the value being stored is.
+      // Volatile stores make the address observable.
+      auto *ARMWI = cast<AtomicRMWInst>(I);
+      if (ARMWI->getValOperand() == V || ARMWI->isVolatile())
        if (Tracker->captured(U))
          return;
-      // Storing to the pointee does not cause the pointer to be captured.
      break;
+    }
+    case Instruction::AtomicCmpXchg: {
+      // cmpxchg conceptually includes both a load and store from
+      // the same location.
+      // As with a store, the location being accessed is not captured,
+      // but the value being stored is.
+      // Volatile stores make the address observable.
+      auto *ACXI = cast<AtomicCmpXchgInst>(I);
+      if (ACXI->getCompareOperand() == V || ACXI->getNewValOperand() == V ||
+          ACXI->isVolatile())
+        if (Tracker->captured(U))
+          return;
+      break;
+    }
    case Instruction::BitCast:
    case Instruction::GetElementPtr:
    case Instruction::PHI:
@@ -289,7 +325,7 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) {
         Worklist.push_back(&UU);
       }
       break;
-    case Instruction::ICmp:
+    case Instruction::ICmp: {
      // Don't count comparisons of a no-alias return value against null as
      // captures. This allows us to ignore comparisons of malloc results
      // with null, for example.
@@ -298,11 +334,19 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) {
         if (CPN->getType()->getAddressSpace() == 0)
           if (isNoAliasCall(V->stripPointerCasts()))
             break;
+      // Comparison against a value loaded from a global variable: since the
+      // pointer does not escape, its value cannot have been stored in a
+      // global variable, so such a comparison reveals nothing about it.
+      unsigned OtherIndex = (I->getOperand(0) == V) ? 1 : 0;
+      auto *LI = dyn_cast<LoadInst>(I->getOperand(OtherIndex));
+      if (LI && isa<GlobalVariable>(LI->getPointerOperand()))
+        break;
      // Otherwise, be conservative. There are crazy ways to capture pointers
      // using comparisons.
      if (Tracker->captured(U))
        return;
      break;
+    }
    default:
      // Something else - be conservative and say it is captured.
      if (Tracker->captured(U))
diff --git a/lib/Analysis/CodeMetrics.cpp b/lib/Analysis/CodeMetrics.cpp
index 4090b4cd752b..ed8370498dd0 100644
--- a/lib/Analysis/CodeMetrics.cpp
+++ b/lib/Analysis/CodeMetrics.cpp
@@ -100,22 +100,21 @@ void CodeMetrics::collectEphemeralValues(
   completeEphemeralValues(WorkSet, EphValues);
 }
 
-/// analyzeBasicBlock - Fill in the current structure with information gleaned
-/// from the specified block.
+/// Fill in the current structure with information gleaned from the specified
+/// block.
void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB, const TargetTransformInfo &TTI, SmallPtrSetImpl<const Value*> &EphValues) { ++NumBlocks; unsigned NumInstsBeforeThisBB = NumInsts; - for (BasicBlock::const_iterator II = BB->begin(), E = BB->end(); - II != E; ++II) { + for (const Instruction &I : *BB) { // Skip ephemeral values. - if (EphValues.count(&*II)) + if (EphValues.count(&I)) continue; // Special handling for calls. - if (isa<CallInst>(II) || isa<InvokeInst>(II)) { - ImmutableCallSite CS(cast<Instruction>(II)); + if (isa<CallInst>(I) || isa<InvokeInst>(I)) { + ImmutableCallSite CS(&I); if (const Function *F = CS.getCalledFunction()) { // If a function is both internal and has a single use, then it is @@ -141,26 +140,29 @@ void CodeMetrics::analyzeBasicBlock(const BasicBlock *BB, } } - if (const AllocaInst *AI = dyn_cast<AllocaInst>(II)) { + if (const AllocaInst *AI = dyn_cast<AllocaInst>(&I)) { if (!AI->isStaticAlloca()) this->usesDynamicAlloca = true; } - if (isa<ExtractElementInst>(II) || II->getType()->isVectorTy()) + if (isa<ExtractElementInst>(I) || I.getType()->isVectorTy()) ++NumVectorInsts; - if (II->getType()->isTokenTy() && II->isUsedOutsideOfBlock(BB)) + if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) notDuplicatable = true; - if (const CallInst *CI = dyn_cast<CallInst>(II)) + if (const CallInst *CI = dyn_cast<CallInst>(&I)) { if (CI->cannotDuplicate()) notDuplicatable = true; + if (CI->isConvergent()) + convergent = true; + } - if (const InvokeInst *InvI = dyn_cast<InvokeInst>(II)) + if (const InvokeInst *InvI = dyn_cast<InvokeInst>(&I)) if (InvI->cannotDuplicate()) notDuplicatable = true; - NumInsts += TTI.getUserCost(&*II); + NumInsts += TTI.getUserCost(&I); } if (isa<ReturnInst>(BB->getTerminator())) diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp index ccb56631b846..6c471ab45048 100644 --- a/lib/Analysis/ConstantFolding.cpp +++ b/lib/Analysis/ConstantFolding.cpp @@ -17,6 +17,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" @@ -34,15 +35,16 @@ #include "llvm/IR/Operator.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" +#include <cassert> #include <cerrno> +#include <cfenv> #include <cmath> - -#ifdef HAVE_FENV_H -#include <fenv.h> -#endif +#include <limits> using namespace llvm; +namespace { + //===----------------------------------------------------------------------===// // Constant Folding internal helper functions //===----------------------------------------------------------------------===// @@ -50,7 +52,7 @@ using namespace llvm; /// Constant fold bitcast, symbolically evaluating it with DataLayout. /// This always returns a non-null constant, but it may be a /// ConstantExpr if unfoldable. -static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { +Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { // Catch the obvious splat cases. if (C->isNullValue() && !DestTy->isX86_MMXTy()) return Constant::getNullValue(DestTy); @@ -59,8 +61,8 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { return Constant::getAllOnesValue(DestTy); // Handle a vector->integer cast. 
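  // (Worked example for the fold below: on a little-endian target,
  //   bitcast <2 x i32> <i32 1, i32 2> to i64
  // yields i64 0x0000000200000001, since element 0 occupies the low-order
  // bits.)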
- if (IntegerType *IT = dyn_cast<IntegerType>(DestTy)) { - VectorType *VTy = dyn_cast<VectorType>(C->getType()); + if (auto *IT = dyn_cast<IntegerType>(DestTy)) { + auto *VTy = dyn_cast<VectorType>(C->getType()); if (!VTy) return ConstantExpr::getBitCast(C, DestTy); @@ -77,27 +79,30 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { C = ConstantExpr::getBitCast(C, SrcIVTy); } - ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(C); - if (!CDV) - return ConstantExpr::getBitCast(C, DestTy); - // Now that we know that the input value is a vector of integers, just shift // and insert them into our result. - unsigned BitShift = DL.getTypeAllocSizeInBits(SrcEltTy); + unsigned BitShift = DL.getTypeSizeInBits(SrcEltTy); APInt Result(IT->getBitWidth(), 0); for (unsigned i = 0; i != NumSrcElts; ++i) { - Result <<= BitShift; + Constant *Element; if (DL.isLittleEndian()) - Result |= CDV->getElementAsInteger(NumSrcElts-i-1); + Element = C->getAggregateElement(NumSrcElts-i-1); else - Result |= CDV->getElementAsInteger(i); + Element = C->getAggregateElement(i); + + auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); + if (!ElementCI) + return ConstantExpr::getBitCast(C, DestTy); + + Result <<= BitShift; + Result |= ElementCI->getValue().zextOrSelf(IT->getBitWidth()); } return ConstantInt::get(IT, Result); } // The code below only handles casts to vectors currently. - VectorType *DestVTy = dyn_cast<VectorType>(DestTy); + auto *DestVTy = dyn_cast<VectorType>(DestTy); if (!DestVTy) return ConstantExpr::getBitCast(C, DestTy); @@ -175,7 +180,7 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { Constant *Elt = Zero; unsigned ShiftAmt = isLittleEndian ? 0 : SrcBitSize*(Ratio-1); for (unsigned j = 0; j != Ratio; ++j) { - Constant *Src =dyn_cast<ConstantInt>(C->getAggregateElement(SrcElt++)); + Constant *Src = dyn_cast<ConstantInt>(C->getAggregateElement(SrcElt++)); if (!Src) // Reject constantexpr elements. return ConstantExpr::getBitCast(C, DestTy); @@ -201,7 +206,7 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { // Loop over each source value, expanding into multiple results. for (unsigned i = 0; i != NumSrcElt; ++i) { - Constant *Src = dyn_cast<ConstantInt>(C->getAggregateElement(i)); + auto *Src = dyn_cast<ConstantInt>(C->getAggregateElement(i)); if (!Src) // Reject constantexpr elements. return ConstantExpr::getBitCast(C, DestTy); @@ -230,11 +235,12 @@ static Constant *FoldBitCast(Constant *C, Type *DestTy, const DataLayout &DL) { return ConstantVector::get(Result); } +} // end anonymous namespace /// If this constant is a constant offset from a global, return the global and /// the constant. Because of constantexprs, this function is recursive. -static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, - APInt &Offset, const DataLayout &DL) { +bool llvm::IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, + APInt &Offset, const DataLayout &DL) { // Trivial case, constant is the global. if ((GV = dyn_cast<GlobalValue>(C))) { unsigned BitWidth = DL.getPointerTypeSizeInBits(GV->getType()); @@ -243,7 +249,7 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, } // Otherwise, if this isn't a constant expr, bail out. - ConstantExpr *CE = dyn_cast<ConstantExpr>(C); + auto *CE = dyn_cast<ConstantExpr>(C); if (!CE) return false; // Look through ptr->int and ptr->ptr casts. 
@@ -252,7 +258,7 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, return IsConstantOffsetFromGlobal(CE->getOperand(0), GV, Offset, DL); // i32* getelementptr ([5 x i32]* @a, i32 0, i32 5) - GEPOperator *GEP = dyn_cast<GEPOperator>(CE); + auto *GEP = dyn_cast<GEPOperator>(CE); if (!GEP) return false; @@ -271,13 +277,14 @@ static bool IsConstantOffsetFromGlobal(Constant *C, GlobalValue *&GV, return true; } +namespace { + /// Recursive helper to read bits out of global. C is the constant being copied /// out of. ByteOffset is an offset into C. CurPtr is the pointer to copy /// results into and BytesLeft is the number of bytes left in /// the CurPtr buffer. DL is the DataLayout. -static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, - unsigned char *CurPtr, unsigned BytesLeft, - const DataLayout &DL) { +bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, unsigned char *CurPtr, + unsigned BytesLeft, const DataLayout &DL) { assert(ByteOffset <= DL.getTypeAllocSize(C->getType()) && "Out of range access"); @@ -286,7 +293,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, if (isa<ConstantAggregateZero>(C) || isa<UndefValue>(C)) return true; - if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) { + if (auto *CI = dyn_cast<ConstantInt>(C)) { if (CI->getBitWidth() > 64 || (CI->getBitWidth() & 7) != 0) return false; @@ -304,7 +311,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, return true; } - if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { + if (auto *CFP = dyn_cast<ConstantFP>(C)) { if (CFP->getType()->isDoubleTy()) { C = FoldBitCast(C, Type::getInt64Ty(C->getContext()), DL); return ReadDataFromGlobal(C, ByteOffset, CurPtr, BytesLeft, DL); @@ -320,7 +327,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, return false; } - if (ConstantStruct *CS = dyn_cast<ConstantStruct>(C)) { + if (auto *CS = dyn_cast<ConstantStruct>(C)) { const StructLayout *SL = DL.getStructLayout(CS->getType()); unsigned Index = SL->getElementContainingOffset(ByteOffset); uint64_t CurEltOffset = SL->getElementOffset(Index); @@ -364,7 +371,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, uint64_t Index = ByteOffset / EltSize; uint64_t Offset = ByteOffset - Index * EltSize; uint64_t NumElts; - if (ArrayType *AT = dyn_cast<ArrayType>(C->getType())) + if (auto *AT = dyn_cast<ArrayType>(C->getType())) NumElts = AT->getNumElements(); else NumElts = C->getType()->getVectorNumElements(); @@ -386,7 +393,7 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, return true; } - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + if (auto *CE = dyn_cast<ConstantExpr>(C)) { if (CE->getOpcode() == Instruction::IntToPtr && CE->getOperand(0)->getType() == DL.getIntPtrType(CE->getType())) { return ReadDataFromGlobal(CE->getOperand(0), ByteOffset, CurPtr, @@ -398,11 +405,10 @@ static bool ReadDataFromGlobal(Constant *C, uint64_t ByteOffset, return false; } -static Constant *FoldReinterpretLoadFromConstPtr(Constant *C, - const DataLayout &DL) { - PointerType *PTy = cast<PointerType>(C->getType()); - Type *LoadTy = PTy->getElementType(); - IntegerType *IntType = dyn_cast<IntegerType>(LoadTy); +Constant *FoldReinterpretLoadFromConstPtr(Constant *C, Type *LoadTy, + const DataLayout &DL) { + auto *PTy = cast<PointerType>(C->getType()); + auto *IntType = dyn_cast<IntegerType>(LoadTy); // If this isn't an integer load we can't fold it directly. 
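  // (For instance, a load of float is folded below by reading the bytes as an
  // i32 through MapTy and bitcasting the result back to float.)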
if (!IntType) { @@ -414,19 +420,19 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C, // an actual new load. Type *MapTy; if (LoadTy->isHalfTy()) - MapTy = Type::getInt16PtrTy(C->getContext(), AS); + MapTy = Type::getInt16Ty(C->getContext()); else if (LoadTy->isFloatTy()) - MapTy = Type::getInt32PtrTy(C->getContext(), AS); + MapTy = Type::getInt32Ty(C->getContext()); else if (LoadTy->isDoubleTy()) - MapTy = Type::getInt64PtrTy(C->getContext(), AS); + MapTy = Type::getInt64Ty(C->getContext()); else if (LoadTy->isVectorTy()) { - MapTy = PointerType::getIntNPtrTy(C->getContext(), - DL.getTypeAllocSizeInBits(LoadTy), AS); + MapTy = PointerType::getIntNTy(C->getContext(), + DL.getTypeAllocSizeInBits(LoadTy)); } else return nullptr; - C = FoldBitCast(C, MapTy, DL); - if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, DL)) + C = FoldBitCast(C, MapTy->getPointerTo(AS), DL); + if (Constant *Res = FoldReinterpretLoadFromConstPtr(C, MapTy, DL)) return FoldBitCast(Res, LoadTy, DL); return nullptr; } @@ -436,28 +442,38 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C, return nullptr; GlobalValue *GVal; - APInt Offset; - if (!IsConstantOffsetFromGlobal(C, GVal, Offset, DL)) + APInt OffsetAI; + if (!IsConstantOffsetFromGlobal(C, GVal, OffsetAI, DL)) return nullptr; - GlobalVariable *GV = dyn_cast<GlobalVariable>(GVal); + auto *GV = dyn_cast<GlobalVariable>(GVal); if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer() || !GV->getInitializer()->getType()->isSized()) return nullptr; - // If we're loading off the beginning of the global, some bytes may be valid, - // but we don't try to handle this. - if (Offset.isNegative()) - return nullptr; + int64_t Offset = OffsetAI.getSExtValue(); + int64_t InitializerSize = DL.getTypeAllocSize(GV->getInitializer()->getType()); + + // If we're not accessing anything in this constant, the result is undefined. + if (Offset + BytesLoaded <= 0) + return UndefValue::get(IntType); // If we're not accessing anything in this constant, the result is undefined. - if (Offset.getZExtValue() >= - DL.getTypeAllocSize(GV->getInitializer()->getType())) + if (Offset >= InitializerSize) return UndefValue::get(IntType); unsigned char RawBytes[32] = {0}; - if (!ReadDataFromGlobal(GV->getInitializer(), Offset.getZExtValue(), RawBytes, - BytesLoaded, DL)) + unsigned char *CurPtr = RawBytes; + unsigned BytesLeft = BytesLoaded; + + // If we're loading off the beginning of the global, some bytes may be valid. 
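  // (Worked example: BytesLoaded = 4 at Offset = -2 advances CurPtr by two
  // bytes and shrinks BytesLeft to 2, so RawBytes[0..1] stay zero and the
  // global's first two bytes land in RawBytes[2..3].)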
+ if (Offset < 0) { + CurPtr += -Offset; + BytesLeft += Offset; + Offset = 0; + } + + if (!ReadDataFromGlobal(GV->getInitializer(), Offset, CurPtr, BytesLeft, DL)) return nullptr; APInt ResultVal = APInt(IntType->getBitWidth(), 0); @@ -478,14 +494,15 @@ static Constant *FoldReinterpretLoadFromConstPtr(Constant *C, return ConstantInt::get(IntType->getContext(), ResultVal); } -static Constant *ConstantFoldLoadThroughBitcast(ConstantExpr *CE, - const DataLayout &DL) { - auto *DestPtrTy = dyn_cast<PointerType>(CE->getType()); - if (!DestPtrTy) +Constant *ConstantFoldLoadThroughBitcast(ConstantExpr *CE, Type *DestTy, + const DataLayout &DL) { + auto *SrcPtr = CE->getOperand(0); + auto *SrcPtrTy = dyn_cast<PointerType>(SrcPtr->getType()); + if (!SrcPtrTy) return nullptr; - Type *DestTy = DestPtrTy->getElementType(); + Type *SrcTy = SrcPtrTy->getPointerElementType(); - Constant *C = ConstantFoldLoadFromConstPtr(CE->getOperand(0), DL); + Constant *C = ConstantFoldLoadFromConstPtr(SrcPtr, SrcTy, DL); if (!C) return nullptr; @@ -522,26 +539,26 @@ static Constant *ConstantFoldLoadThroughBitcast(ConstantExpr *CE, return nullptr; } -/// Return the value that a load from C would produce if it is constant and -/// determinable. If this is not determinable, return null. -Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, +} // end anonymous namespace + +Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, Type *Ty, const DataLayout &DL) { // First, try the easy cases: - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(C)) + if (auto *GV = dyn_cast<GlobalVariable>(C)) if (GV->isConstant() && GV->hasDefinitiveInitializer()) return GV->getInitializer(); if (auto *GA = dyn_cast<GlobalAlias>(C)) - if (GA->getAliasee() && !GA->mayBeOverridden()) - return ConstantFoldLoadFromConstPtr(GA->getAliasee(), DL); + if (GA->getAliasee() && !GA->isInterposable()) + return ConstantFoldLoadFromConstPtr(GA->getAliasee(), Ty, DL); // If the loaded value isn't a constant expr, we can't handle it. - ConstantExpr *CE = dyn_cast<ConstantExpr>(C); + auto *CE = dyn_cast<ConstantExpr>(C); if (!CE) return nullptr; if (CE->getOpcode() == Instruction::GetElementPtr) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) { + if (auto *GV = dyn_cast<GlobalVariable>(CE->getOperand(0))) { if (GV->isConstant() && GV->hasDefinitiveInitializer()) { if (Constant *V = ConstantFoldLoadThroughGEPConstantExpr(GV->getInitializer(), CE)) @@ -551,15 +568,14 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, } if (CE->getOpcode() == Instruction::BitCast) - if (Constant *LoadedC = ConstantFoldLoadThroughBitcast(CE, DL)) + if (Constant *LoadedC = ConstantFoldLoadThroughBitcast(CE, Ty, DL)) return LoadedC; // Instead of loading constant c string, use corresponding integer value // directly if string length is small enough. StringRef Str; if (getConstantStringInfo(CE, Str) && !Str.empty()) { - unsigned StrLen = Str.size(); - Type *Ty = cast<PointerType>(CE->getType())->getElementType(); + size_t StrLen = Str.size(); unsigned NumBits = Ty->getPrimitiveSizeInBits(); // Replace load with immediate integer if the result is an integer or fp // value. 
@@ -568,13 +584,13 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, APInt StrVal(NumBits, 0); APInt SingleChar(NumBits, 0); if (DL.isLittleEndian()) { - for (signed i = StrLen-1; i >= 0; i--) { - SingleChar = (uint64_t) Str[i] & UCHAR_MAX; + for (unsigned char C : reverse(Str.bytes())) { + SingleChar = static_cast<uint64_t>(C); StrVal = (StrVal << 8) | SingleChar; } } else { - for (unsigned i = 0; i < StrLen; i++) { - SingleChar = (uint64_t) Str[i] & UCHAR_MAX; + for (unsigned char C : Str.bytes()) { + SingleChar = static_cast<uint64_t>(C); StrVal = (StrVal << 8) | SingleChar; } // Append NULL at the end. @@ -591,27 +607,26 @@ Constant *llvm::ConstantFoldLoadFromConstPtr(Constant *C, // If this load comes from anywhere in a constant global, and if the global // is all undef or zero, we know what it loads. - if (GlobalVariable *GV = - dyn_cast<GlobalVariable>(GetUnderlyingObject(CE, DL))) { + if (auto *GV = dyn_cast<GlobalVariable>(GetUnderlyingObject(CE, DL))) { if (GV->isConstant() && GV->hasDefinitiveInitializer()) { - Type *ResTy = cast<PointerType>(C->getType())->getElementType(); if (GV->getInitializer()->isNullValue()) - return Constant::getNullValue(ResTy); + return Constant::getNullValue(Ty); if (isa<UndefValue>(GV->getInitializer())) - return UndefValue::get(ResTy); + return UndefValue::get(Ty); } } // Try hard to fold loads from bitcasted strange and non-type-safe things. - return FoldReinterpretLoadFromConstPtr(CE, DL); + return FoldReinterpretLoadFromConstPtr(CE, Ty, DL); } -static Constant *ConstantFoldLoadInst(const LoadInst *LI, - const DataLayout &DL) { +namespace { + +Constant *ConstantFoldLoadInst(const LoadInst *LI, const DataLayout &DL) { if (LI->isVolatile()) return nullptr; - if (Constant *C = dyn_cast<Constant>(LI->getOperand(0))) - return ConstantFoldLoadFromConstPtr(C, DL); + if (auto *C = dyn_cast<Constant>(LI->getOperand(0))) + return ConstantFoldLoadFromConstPtr(C, LI->getType(), DL); return nullptr; } @@ -620,9 +635,8 @@ static Constant *ConstantFoldLoadInst(const LoadInst *LI, /// Attempt to symbolically evaluate the result of a binary operator merging /// these together. If target data info is available, it is provided as DL, /// otherwise DL is null. -static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, - Constant *Op1, - const DataLayout &DL) { +Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, Constant *Op1, + const DataLayout &DL) { // SROA // Fold (and 0xffffffff00000000, (shl x, 32)) -> shl. @@ -674,18 +688,16 @@ static Constant *SymbolicallyEvaluateBinop(unsigned Opc, Constant *Op0, /// If array indices are not pointer-sized integers, explicitly cast them so /// that they aren't implicitly casted by the getelementptr. 
-static Constant *CastGEPIndices(Type *SrcTy, ArrayRef<Constant *> Ops, - Type *ResultTy, const DataLayout &DL, - const TargetLibraryInfo *TLI) { +Constant *CastGEPIndices(Type *SrcElemTy, ArrayRef<Constant *> Ops, + Type *ResultTy, const DataLayout &DL, + const TargetLibraryInfo *TLI) { Type *IntPtrTy = DL.getIntPtrType(ResultTy); bool Any = false; SmallVector<Constant*, 32> NewIdxs; for (unsigned i = 1, e = Ops.size(); i != e; ++i) { if ((i == 1 || - !isa<StructType>(GetElementPtrInst::getIndexedType( - cast<PointerType>(Ops[0]->getType()->getScalarType()) - ->getElementType(), + !isa<StructType>(GetElementPtrInst::getIndexedType(SrcElemTy, Ops.slice(1, i - 1)))) && Ops[i]->getType() != IntPtrTy) { Any = true; @@ -701,8 +713,8 @@ static Constant *CastGEPIndices(Type *SrcTy, ArrayRef<Constant *> Ops, if (!Any) return nullptr; - Constant *C = ConstantExpr::getGetElementPtr(SrcTy, Ops[0], NewIdxs); - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(C)) { + Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ops[0], NewIdxs); + if (auto *CE = dyn_cast<ConstantExpr>(C)) { if (Constant *Folded = ConstantFoldConstantExpression(CE, DL, TLI)) C = Folded; } @@ -711,32 +723,41 @@ static Constant *CastGEPIndices(Type *SrcTy, ArrayRef<Constant *> Ops, } /// Strip the pointer casts, but preserve the address space information. -static Constant* StripPtrCastKeepAS(Constant* Ptr) { +Constant* StripPtrCastKeepAS(Constant* Ptr, Type *&ElemTy) { assert(Ptr->getType()->isPointerTy() && "Not a pointer type"); - PointerType *OldPtrTy = cast<PointerType>(Ptr->getType()); + auto *OldPtrTy = cast<PointerType>(Ptr->getType()); Ptr = Ptr->stripPointerCasts(); - PointerType *NewPtrTy = cast<PointerType>(Ptr->getType()); + auto *NewPtrTy = cast<PointerType>(Ptr->getType()); + + ElemTy = NewPtrTy->getPointerElementType(); // Preserve the address space number of the pointer. if (NewPtrTy->getAddressSpace() != OldPtrTy->getAddressSpace()) { - NewPtrTy = NewPtrTy->getElementType()->getPointerTo( - OldPtrTy->getAddressSpace()); + NewPtrTy = ElemTy->getPointerTo(OldPtrTy->getAddressSpace()); Ptr = ConstantExpr::getPointerCast(Ptr, NewPtrTy); } return Ptr; } /// If we can symbolically evaluate the GEP constant expression, do so. 
-static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
-                                         Type *ResultTy, const DataLayout &DL,
-                                         const TargetLibraryInfo *TLI) {
+Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP,
+                                  ArrayRef<Constant *> Ops,
+                                  const DataLayout &DL,
+                                  const TargetLibraryInfo *TLI) {
+  Type *SrcElemTy = GEP->getSourceElementType();
+  Type *ResElemTy = GEP->getResultElementType();
+  Type *ResTy = GEP->getType();
+  if (!SrcElemTy->isSized())
+    return nullptr;
+
+  if (Constant *C = CastGEPIndices(SrcElemTy, Ops, ResTy, DL, TLI))
+    return C;
+
   Constant *Ptr = Ops[0];
-  if (!Ptr->getType()->getPointerElementType()->isSized() ||
-      !Ptr->getType()->isPointerTy())
+  if (!Ptr->getType()->isPointerTy())
     return nullptr;
 
   Type *IntPtrTy = DL.getIntPtrType(Ptr->getType());
-  Type *ResultElementTy = ResultTy->getPointerElementType();
 
   // If this is a constant expr gep that is effectively computing an
   // "offsetof", fold it into 'cast int Size to T*' instead of 'gep 0, 0, 12'
@@ -745,16 +766,16 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
 
   // If this is "gep i8* Ptr, (sub 0, V)", fold this as:
   // "inttoptr (sub (ptrtoint Ptr), V)"
-  if (Ops.size() == 2 && ResultElementTy->isIntegerTy(8)) {
-    ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[1]);
+  if (Ops.size() == 2 && ResElemTy->isIntegerTy(8)) {
+    auto *CE = dyn_cast<ConstantExpr>(Ops[1]);
     assert((!CE || CE->getType() == IntPtrTy) &&
            "CastGEPIndices didn't canonicalize index types!");
     if (CE && CE->getOpcode() == Instruction::Sub &&
         CE->getOperand(0)->isNullValue()) {
       Constant *Res = ConstantExpr::getPtrToInt(Ptr, CE->getType());
       Res = ConstantExpr::getSub(Res, CE->getOperand(1));
-      Res = ConstantExpr::getIntToPtr(Res, ResultTy);
-      if (ConstantExpr *ResCE = dyn_cast<ConstantExpr>(Res))
+      Res = ConstantExpr::getIntToPtr(Res, ResTy);
+      if (auto *ResCE = dyn_cast<ConstantExpr>(Res))
         Res = ConstantFoldConstantExpression(ResCE, DL, TLI);
       return Res;
     }
@@ -765,19 +786,19 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
   unsigned BitWidth = DL.getTypeSizeInBits(IntPtrTy);
   APInt Offset = APInt(BitWidth,
-                       DL.getIndexedOffset(
-                           Ptr->getType(),
+                       DL.getIndexedOffsetInType(
+                           SrcElemTy,
                            makeArrayRef((Value * const *)Ops.data() + 1, Ops.size() - 1)));
-  Ptr = StripPtrCastKeepAS(Ptr);
+  Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
 
   // If this is a GEP of a GEP, fold it all into a single GEP.
-  while (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) {
+  while (auto *GEP = dyn_cast<GEPOperator>(Ptr)) {
     SmallVector<Value *, 4> NestedOps(GEP->op_begin() + 1, GEP->op_end());
 
     // Do not try to incorporate the sub-GEP if some index is not a number.
     bool AllConstantInt = true;
-    for (unsigned i = 0, e = NestedOps.size(); i != e; ++i)
-      if (!isa<ConstantInt>(NestedOps[i])) {
+    for (Value *NestedOp : NestedOps)
+      if (!isa<ConstantInt>(NestedOp)) {
        AllConstantInt = false;
        break;
      }
@@ -785,23 +806,24 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
       break;
 
     Ptr = cast<Constant>(GEP->getOperand(0));
-    Offset += APInt(BitWidth, DL.getIndexedOffset(Ptr->getType(), NestedOps));
-    Ptr = StripPtrCastKeepAS(Ptr);
+    SrcElemTy = GEP->getSourceElementType();
+    Offset += APInt(BitWidth, DL.getIndexedOffsetInType(SrcElemTy, NestedOps));
+    Ptr = StripPtrCastKeepAS(Ptr, SrcElemTy);
   }
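An aside, for illustration only (not part of this commit): folding gep(gep(p, ...), ...) accumulates a single byte offset per level, which is what DL.getIndexedOffsetInType computes, and the descent loop that follows re-derives per-level indices from that offset by repeated division. Both directions with plain sizeof/offsetof arithmetic:

#include <cstddef>
#include <cstdio>

struct Inner { int A; double B; };           // roughly { i32, f64 } plus padding
struct Outer { char Tag; Inner Elems[4]; };  // roughly { i8, [4 x Inner] }

int main() {
  // Accumulate: gep %Outer, 0, 1, 2, 1 == &p->Elems[2].B
  size_t Off = offsetof(Outer, Elems) + 2 * sizeof(Inner) + offsetof(Inner, B);
  // Re-derive the array index and in-struct residue from the raw offset.
  size_t Rel = Off - offsetof(Outer, Elems);
  size_t Idx = Rel / sizeof(Inner);            // -> 2
  size_t Residue = Rel % sizeof(Inner);        // -> offsetof(Inner, B)
  std::printf("offset=%zu index=%zu residue=%zu\n", Off, Idx, Residue);
}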
  // If the base value for this address is a literal integer value, fold the
  // getelementptr to the resulting integer value casted to the pointer type.
   APInt BasePtr(BitWidth, 0);
-  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ptr)) {
+  if (auto *CE = dyn_cast<ConstantExpr>(Ptr)) {
     if (CE->getOpcode() == Instruction::IntToPtr) {
-      if (ConstantInt *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
+      if (auto *Base = dyn_cast<ConstantInt>(CE->getOperand(0)))
         BasePtr = Base->getValue().zextOrTrunc(BitWidth);
     }
   }
 
   if (Ptr->isNullValue() || BasePtr != 0) {
     Constant *C = ConstantInt::get(Ptr->getContext(), Offset + BasePtr);
-    return ConstantExpr::getIntToPtr(C, ResultTy);
+    return ConstantExpr::getIntToPtr(C, ResTy);
   }
 
   // Otherwise form a regular getelementptr. Recompute the indices so that
@@ -813,39 +835,49 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
 
   SmallVector<Constant *, 32> NewIdxs;
   do {
-    if (SequentialType *ATy = dyn_cast<SequentialType>(Ty)) {
-      if (ATy->isPointerTy()) {
+    if (!Ty->isStructTy()) {
+      if (Ty->isPointerTy()) {
         // The only pointer indexing we'll do is on the first index of the GEP.
         if (!NewIdxs.empty())
           break;
 
+        Ty = SrcElemTy;
+
         // Only handle pointers to sized types, not pointers to functions.
-        if (!ATy->getElementType()->isSized())
+        if (!Ty->isSized())
           return nullptr;
+      } else if (auto *ATy = dyn_cast<SequentialType>(Ty)) {
+        Ty = ATy->getElementType();
+      } else {
+        // We've reached some non-indexable type.
+        break;
       }
 
       // Determine which element of the array the offset points into.
-      APInt ElemSize(BitWidth, DL.getTypeAllocSize(ATy->getElementType()));
-      if (ElemSize == 0)
+      APInt ElemSize(BitWidth, DL.getTypeAllocSize(Ty));
+      if (ElemSize == 0) {
         // The element size is 0. This may be [0 x Ty]*, so just use a zero
         // index for this level and proceed to the next level to see if it can
         // accommodate the offset.
         NewIdxs.push_back(ConstantInt::get(IntPtrTy, 0));
-      else {
+      } else {
         // The element size is non-zero; divide the offset by the element
         // size (rounding down), to compute the index at this level.
-        APInt NewIdx = Offset.udiv(ElemSize);
+        bool Overflow;
+        APInt NewIdx = Offset.sdiv_ov(ElemSize, Overflow);
+        if (Overflow)
+          break;
         Offset -= NewIdx * ElemSize;
         NewIdxs.push_back(ConstantInt::get(IntPtrTy, NewIdx));
       }
-      Ty = ATy->getElementType();
-    } else if (StructType *STy = dyn_cast<StructType>(Ty)) {
+    } else {
+      auto *STy = cast<StructType>(Ty);
       // If we end up with an offset that isn't valid for this struct type, we
       // can't re-form this GEP in a regular form, so bail out. The pointer
       // operand likely went through casts that are necessary to make the GEP
       // sensible.
       const StructLayout &SL = *DL.getStructLayout(STy);
-      if (Offset.uge(SL.getSizeInBytes()))
+      if (Offset.isNegative() || Offset.uge(SL.getSizeInBytes()))
         break;
 
       // Determine which field of the struct the offset points into. The
@@ -856,11 +888,8 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
                                          ElIdx));
       Offset -= APInt(BitWidth, SL.getElementOffset(ElIdx));
       Ty = STy->getTypeAtIndex(ElIdx);
-    } else {
-      // We've reached some non-indexable type.
-      break;
     }
-  } while (Ty != ResultElementTy);
+  } while (Ty != ResElemTy);
 
   // If we haven't used up the entire offset by descending the static
   // type, then the offset is pointing into the middle of an indivisible
@@ -869,33 +898,78 @@ static Constant *SymbolicallyEvaluateGEP(Type *SrcTy, ArrayRef<Constant *> Ops,
     return nullptr;
 
   // Create a GEP.
- Constant *C = ConstantExpr::getGetElementPtr(SrcTy, Ptr, NewIdxs); + Constant *C = ConstantExpr::getGetElementPtr(SrcElemTy, Ptr, NewIdxs); assert(C->getType()->getPointerElementType() == Ty && "Computed GetElementPtr has unexpected type!"); // If we ended up indexing a member with a type that doesn't match // the type of what the original indices indexed, add a cast. - if (Ty != ResultElementTy) - C = FoldBitCast(C, ResultTy, DL); + if (Ty != ResElemTy) + C = FoldBitCast(C, ResTy, DL); return C; } +/// Attempt to constant fold an instruction with the +/// specified opcode and operands. If successful, the constant result is +/// returned, if not, null is returned. Note that this function can fail when +/// attempting to fold instructions like loads and stores, which have no +/// constant expression form. +/// +/// TODO: This function neither utilizes nor preserves nsw/nuw/inbounds/etc +/// information, due to only being passed an opcode and operands. Constant +/// folding using this function strips this information. +/// +Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, Type *DestTy, + unsigned Opcode, + ArrayRef<Constant *> Ops, + const DataLayout &DL, + const TargetLibraryInfo *TLI) { + // Handle easy binops first. + if (Instruction::isBinaryOp(Opcode)) + return ConstantFoldBinaryOpOperands(Opcode, Ops[0], Ops[1], DL); + + if (Instruction::isCast(Opcode)) + return ConstantFoldCastOperand(Opcode, Ops[0], DestTy, DL); + + if (auto *GEP = dyn_cast<GEPOperator>(InstOrCE)) { + if (Constant *C = SymbolicallyEvaluateGEP(GEP, Ops, DL, TLI)) + return C; + + return ConstantExpr::getGetElementPtr(GEP->getSourceElementType(), + Ops[0], Ops.slice(1)); + } + + switch (Opcode) { + default: return nullptr; + case Instruction::ICmp: + case Instruction::FCmp: llvm_unreachable("Invalid for compares"); + case Instruction::Call: + if (auto *F = dyn_cast<Function>(Ops.back())) + if (canConstantFoldCallTo(F)) + return ConstantFoldCall(F, Ops.slice(0, Ops.size() - 1), TLI); + return nullptr; + case Instruction::Select: + return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]); + case Instruction::ExtractElement: + return ConstantExpr::getExtractElement(Ops[0], Ops[1]); + case Instruction::InsertElement: + return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]); + case Instruction::ShuffleVector: + return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]); + } +} +} // end anonymous namespace //===----------------------------------------------------------------------===// // Constant Folding public APIs //===----------------------------------------------------------------------===// -/// Try to constant fold the specified instruction. -/// If successful, the constant result is returned, if not, null is returned. -/// Note that this fails if not all of the operands are constant. Otherwise, -/// this function can only fail when attempting to fold instructions like loads -/// and stores, which have no constant expression form. Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, const TargetLibraryInfo *TLI) { // Handle PHI nodes quickly here... - if (PHINode *PN = dyn_cast<PHINode>(I)) { + if (auto *PN = dyn_cast<PHINode>(I)) { Constant *CommonValue = nullptr; for (Value *Incoming : PN->incoming_values()) { @@ -906,11 +980,11 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, if (isa<UndefValue>(Incoming)) continue; // If the incoming value is not a constant, then give up. 
- Constant *C = dyn_cast<Constant>(Incoming); + auto *C = dyn_cast<Constant>(Incoming); if (!C) return nullptr; // Fold the PHI's operands. - if (ConstantExpr *NewC = dyn_cast<ConstantExpr>(C)) + if (auto *NewC = dyn_cast<ConstantExpr>(C)) C = ConstantFoldConstantExpression(NewC, DL, TLI); // If the incoming value is a different constant to // the one we saw previously, then give up. @@ -925,54 +999,55 @@ Constant *llvm::ConstantFoldInstruction(Instruction *I, const DataLayout &DL, } // Scan the operand list, checking to see if they are all constants, if so, - // hand off to ConstantFoldInstOperands. - SmallVector<Constant*, 8> Ops; - for (User::op_iterator i = I->op_begin(), e = I->op_end(); i != e; ++i) { - Constant *Op = dyn_cast<Constant>(*i); - if (!Op) - return nullptr; // All operands not constant! + // hand off to ConstantFoldInstOperandsImpl. + if (!all_of(I->operands(), [](Use &U) { return isa<Constant>(U); })) + return nullptr; + SmallVector<Constant *, 8> Ops; + for (const Use &OpU : I->operands()) { + auto *Op = cast<Constant>(&OpU); // Fold the Instruction's operands. - if (ConstantExpr *NewCE = dyn_cast<ConstantExpr>(Op)) + if (auto *NewCE = dyn_cast<ConstantExpr>(Op)) Op = ConstantFoldConstantExpression(NewCE, DL, TLI); Ops.push_back(Op); } - if (const CmpInst *CI = dyn_cast<CmpInst>(I)) + if (const auto *CI = dyn_cast<CmpInst>(I)) return ConstantFoldCompareInstOperands(CI->getPredicate(), Ops[0], Ops[1], DL, TLI); - if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + if (const auto *LI = dyn_cast<LoadInst>(I)) return ConstantFoldLoadInst(LI, DL); - if (InsertValueInst *IVI = dyn_cast<InsertValueInst>(I)) { + if (auto *IVI = dyn_cast<InsertValueInst>(I)) { return ConstantExpr::getInsertValue( cast<Constant>(IVI->getAggregateOperand()), cast<Constant>(IVI->getInsertedValueOperand()), IVI->getIndices()); } - if (ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(I)) { + if (auto *EVI = dyn_cast<ExtractValueInst>(I)) { return ConstantExpr::getExtractValue( cast<Constant>(EVI->getAggregateOperand()), EVI->getIndices()); } - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Ops, DL, TLI); + return ConstantFoldInstOperands(I, Ops, DL, TLI); } -static Constant * +namespace { + +Constant * ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout &DL, const TargetLibraryInfo *TLI, SmallPtrSetImpl<ConstantExpr *> &FoldedOps) { SmallVector<Constant *, 8> Ops; - for (User::const_op_iterator i = CE->op_begin(), e = CE->op_end(); i != e; - ++i) { - Constant *NewC = cast<Constant>(*i); + for (const Use &NewU : CE->operands()) { + auto *NewC = cast<Constant>(&NewU); // Recursively fold the ConstantExpr's operands. If we have already folded // a ConstantExpr, we don't have to process it again. - if (ConstantExpr *NewCE = dyn_cast<ConstantExpr>(NewC)) { + if (auto *NewCE = dyn_cast<ConstantExpr>(NewC)) { if (FoldedOps.insert(NewCE).second) NewC = ConstantFoldConstantExpressionImpl(NewCE, DL, TLI, FoldedOps); } @@ -982,12 +1057,13 @@ ConstantFoldConstantExpressionImpl(const ConstantExpr *CE, const DataLayout &DL, if (CE->isCompare()) return ConstantFoldCompareInstOperands(CE->getPredicate(), Ops[0], Ops[1], DL, TLI); - return ConstantFoldInstOperands(CE->getOpcode(), CE->getType(), Ops, DL, TLI); + + return ConstantFoldInstOperandsImpl(CE, CE->getType(), CE->getOpcode(), Ops, + DL, TLI); } -/// Attempt to fold the constant expression -/// using the specified DataLayout. If successful, the constant result is -/// result is returned, if not, null is returned. 
+} // end anonymous namespace + Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE, const DataLayout &DL, const TargetLibraryInfo *TLI) { @@ -995,114 +1071,22 @@ Constant *llvm::ConstantFoldConstantExpression(const ConstantExpr *CE, return ConstantFoldConstantExpressionImpl(CE, DL, TLI, FoldedOps); } -/// Attempt to constant fold an instruction with the -/// specified opcode and operands. If successful, the constant result is -/// returned, if not, null is returned. Note that this function can fail when -/// attempting to fold instructions like loads and stores, which have no -/// constant expression form. -/// -/// TODO: This function neither utilizes nor preserves nsw/nuw/inbounds/etc -/// information, due to only being passed an opcode and operands. Constant -/// folding using this function strips this information. -/// -Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy, +Constant *llvm::ConstantFoldInstOperands(Instruction *I, ArrayRef<Constant *> Ops, const DataLayout &DL, const TargetLibraryInfo *TLI) { - // Handle easy binops first. - if (Instruction::isBinaryOp(Opcode)) { - if (isa<ConstantExpr>(Ops[0]) || isa<ConstantExpr>(Ops[1])) { - if (Constant *C = SymbolicallyEvaluateBinop(Opcode, Ops[0], Ops[1], DL)) - return C; - } - - return ConstantExpr::get(Opcode, Ops[0], Ops[1]); - } - - switch (Opcode) { - default: return nullptr; - case Instruction::ICmp: - case Instruction::FCmp: llvm_unreachable("Invalid for compares"); - case Instruction::Call: - if (Function *F = dyn_cast<Function>(Ops.back())) - if (canConstantFoldCallTo(F)) - return ConstantFoldCall(F, Ops.slice(0, Ops.size() - 1), TLI); - return nullptr; - case Instruction::PtrToInt: - // If the input is a inttoptr, eliminate the pair. This requires knowing - // the width of a pointer, so it can't be done in ConstantExpr::getCast. - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) { - if (CE->getOpcode() == Instruction::IntToPtr) { - Constant *Input = CE->getOperand(0); - unsigned InWidth = Input->getType()->getScalarSizeInBits(); - unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType()); - if (PtrWidth < InWidth) { - Constant *Mask = - ConstantInt::get(CE->getContext(), - APInt::getLowBitsSet(InWidth, PtrWidth)); - Input = ConstantExpr::getAnd(Input, Mask); - } - // Do a zext or trunc to get to the dest size. - return ConstantExpr::getIntegerCast(Input, DestTy, false); - } - } - return ConstantExpr::getCast(Opcode, Ops[0], DestTy); - case Instruction::IntToPtr: - // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if - // the int size is >= the ptr size and the address spaces are the same. - // This requires knowing the width of a pointer, so it can't be done in - // ConstantExpr::getCast. 
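Numerically, the ptrtoint-of-inttoptr masking in the PtrToInt case of this removed block (it reappears below in ConstantFoldCastOperand) keeps only the low pointer-width bits of the integer. Illustration only, not part of this commit, assuming a 32-bit pointer and a 64-bit integer:

#include <cstdint>
#include <cstdio>

int main() {
  uint64_t X = 0x1122334455667788ULL;      // 64-bit input to inttoptr
  unsigned PtrWidth = 32;                  // assumed pointer width
  uint64_t Mask = (1ULL << PtrWidth) - 1;  // APInt::getLowBitsSet analogue
  uint64_t RoundTrip = X & Mask;           // effect of ptrtoint(inttoptr X)
  std::printf("0x%016llx -> 0x%016llx\n", (unsigned long long)X,
              (unsigned long long)RoundTrip);
}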
-    if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Ops[0])) {
-      if (CE->getOpcode() == Instruction::PtrToInt) {
-        Constant *SrcPtr = CE->getOperand(0);
-        unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType());
-        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
-
-        if (MidIntSize >= SrcPtrSize) {
-          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
-          if (SrcAS == DestTy->getPointerAddressSpace())
-            return FoldBitCast(CE->getOperand(0), DestTy, DL);
-        }
-      }
-    }
-
-    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
-  case Instruction::Trunc:
-  case Instruction::ZExt:
-  case Instruction::SExt:
-  case Instruction::FPTrunc:
-  case Instruction::FPExt:
-  case Instruction::UIToFP:
-  case Instruction::SIToFP:
-  case Instruction::FPToUI:
-  case Instruction::FPToSI:
-  case Instruction::AddrSpaceCast:
-    return ConstantExpr::getCast(Opcode, Ops[0], DestTy);
-  case Instruction::BitCast:
-    return FoldBitCast(Ops[0], DestTy, DL);
-  case Instruction::Select:
-    return ConstantExpr::getSelect(Ops[0], Ops[1], Ops[2]);
-  case Instruction::ExtractElement:
-    return ConstantExpr::getExtractElement(Ops[0], Ops[1]);
-  case Instruction::InsertElement:
-    return ConstantExpr::getInsertElement(Ops[0], Ops[1], Ops[2]);
-  case Instruction::ShuffleVector:
-    return ConstantExpr::getShuffleVector(Ops[0], Ops[1], Ops[2]);
-  case Instruction::GetElementPtr: {
-    Type *SrcTy = nullptr;
-    if (Constant *C = CastGEPIndices(SrcTy, Ops, DestTy, DL, TLI))
-      return C;
-    if (Constant *C = SymbolicallyEvaluateGEP(SrcTy, Ops, DestTy, DL, TLI))
-      return C;
+  return ConstantFoldInstOperandsImpl(I, I->getType(), I->getOpcode(), Ops, DL,
+                                      TLI);
+}
 
-    return ConstantExpr::getGetElementPtr(SrcTy, Ops[0], Ops.slice(1));
-  }
-  }
+Constant *llvm::ConstantFoldInstOperands(unsigned Opcode, Type *DestTy,
+                                         ArrayRef<Constant *> Ops,
+                                         const DataLayout &DL,
+                                         const TargetLibraryInfo *TLI) {
+  assert(Opcode != Instruction::GetElementPtr && "Invalid for GEPs");
+  return ConstantFoldInstOperandsImpl(nullptr, DestTy, Opcode, Ops, DL, TLI);
 }
 
-/// Attempt to constant fold a compare
-/// instruction (icmp/fcmp) with the specified operands. If it fails, it
-/// returns a constant expression of the specified operands.
 Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
                                                 Constant *Ops0, Constant *Ops1,
                                                 const DataLayout &DL,
@@ -1115,7 +1099,7 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
   // FIXME: The following comment is out of date and the DataLayout is here now.
   // ConstantExpr::getCompare cannot do this, because it doesn't have DL
   // around to know if bit truncation is happening.
-  if (ConstantExpr *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
+  if (auto *CE0 = dyn_cast<ConstantExpr>(Ops0)) {
     if (Ops1->isNullValue()) {
       if (CE0->getOpcode() == Instruction::IntToPtr) {
         Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
@@ -1139,7 +1123,7 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
       }
     }
 
-    if (ConstantExpr *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
+    if (auto *CE1 = dyn_cast<ConstantExpr>(Ops1)) {
       if (CE0->getOpcode() == CE1->getOpcode()) {
         if (CE0->getOpcode() == Instruction::IntToPtr) {
           Type *IntPtrTy = DL.getIntPtrType(CE0->getType());
@@ -1176,18 +1160,85 @@ Constant *llvm::ConstantFoldCompareInstOperands(unsigned Predicate,
                                         Predicate, CE0->getOperand(1), Ops1, DL, TLI);
       unsigned OpC = Predicate == ICmpInst::ICMP_EQ ?
           Instruction::And : Instruction::Or;
-      Constant *Ops[] = { LHS, RHS };
-      return ConstantFoldInstOperands(OpC, LHS->getType(), Ops, DL, TLI);
+      return ConstantFoldBinaryOpOperands(OpC, LHS, RHS, DL);
     }
   }
 
   return ConstantExpr::getCompare(Predicate, Ops0, Ops1);
 }
 
+Constant *llvm::ConstantFoldBinaryOpOperands(unsigned Opcode, Constant *LHS,
+                                             Constant *RHS,
+                                             const DataLayout &DL) {
+  assert(Instruction::isBinaryOp(Opcode));
+  if (isa<ConstantExpr>(LHS) || isa<ConstantExpr>(RHS))
+    if (Constant *C = SymbolicallyEvaluateBinop(Opcode, LHS, RHS, DL))
+      return C;
+
+  return ConstantExpr::get(Opcode, LHS, RHS);
+}
+
+Constant *llvm::ConstantFoldCastOperand(unsigned Opcode, Constant *C,
+                                        Type *DestTy, const DataLayout &DL) {
+  assert(Instruction::isCast(Opcode));
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Missing case");
+  case Instruction::PtrToInt:
+    // If the input is an inttoptr, eliminate the pair. This requires knowing
+    // the width of a pointer, so it can't be done in ConstantExpr::getCast.
+    if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+      if (CE->getOpcode() == Instruction::IntToPtr) {
+        Constant *Input = CE->getOperand(0);
+        unsigned InWidth = Input->getType()->getScalarSizeInBits();
+        unsigned PtrWidth = DL.getPointerTypeSizeInBits(CE->getType());
+        if (PtrWidth < InWidth) {
+          Constant *Mask =
+              ConstantInt::get(CE->getContext(),
+                               APInt::getLowBitsSet(InWidth, PtrWidth));
+          Input = ConstantExpr::getAnd(Input, Mask);
+        }
+        // Do a zext or trunc to get to the dest size.
+        return ConstantExpr::getIntegerCast(Input, DestTy, false);
+      }
+    }
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::IntToPtr:
+    // If the input is a ptrtoint, turn the pair into a ptr to ptr bitcast if
+    // the int size is >= the ptr size and the address spaces are the same.
+    // This requires knowing the width of a pointer, so it can't be done in
+    // ConstantExpr::getCast.
+    if (auto *CE = dyn_cast<ConstantExpr>(C)) {
+      if (CE->getOpcode() == Instruction::PtrToInt) {
+        Constant *SrcPtr = CE->getOperand(0);
+        unsigned SrcPtrSize = DL.getPointerTypeSizeInBits(SrcPtr->getType());
+        unsigned MidIntSize = CE->getType()->getScalarSizeInBits();
+
+        if (MidIntSize >= SrcPtrSize) {
+          unsigned SrcAS = SrcPtr->getType()->getPointerAddressSpace();
+          if (SrcAS == DestTy->getPointerAddressSpace())
+            return FoldBitCast(CE->getOperand(0), DestTy, DL);
+        }
+      }
+    }
+
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::Trunc:
+  case Instruction::ZExt:
+  case Instruction::SExt:
+  case Instruction::FPTrunc:
+  case Instruction::FPExt:
+  case Instruction::UIToFP:
+  case Instruction::SIToFP:
+  case Instruction::FPToUI:
+  case Instruction::FPToSI:
+  case Instruction::AddrSpaceCast:
+    return ConstantExpr::getCast(Opcode, C, DestTy);
+  case Instruction::BitCast:
+    return FoldBitCast(C, DestTy, DL);
+  }
+}
 
-/// Given a constant and a getelementptr constantexpr, return the constant value
-/// being addressed by the constant expression, or null if something is funny
-/// and we can't decide.
 Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
                                                        ConstantExpr *CE) {
   if (!CE->getOperand(1)->isNullValue())
@@ -1203,27 +1254,23 @@ Constant *llvm::ConstantFoldLoadThroughGEPConstantExpr(Constant *C,
   return C;
 }
 
-/// Given a constant and getelementptr indices (with an *implied* zero pointer
-/// index that is not in the list), return the constant value being addressed by
-/// a virtual load, or null if something is funny and we can't decide.
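A hypothetical caller of the two new entry points defined above; only their signatures are taken from this diff, while the helper and its name are illustrative:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Fold a binop or cast instruction whose operands are already Constants;
// anything else is left to the generic ConstantFoldInstOperands entry.
static Constant *tryFoldSimple(Instruction *I, ArrayRef<Constant *> Ops,
                               const DataLayout &DL) {
  unsigned Opc = I->getOpcode();
  if (Instruction::isBinaryOp(Opc))
    return ConstantFoldBinaryOpOperands(Opc, Ops[0], Ops[1], DL);
  if (Instruction::isCast(Opc))
    return ConstantFoldCastOperand(Opc, Ops[0], I->getType(), DL);
  return nullptr;
}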
-Constant *llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, - ArrayRef<Constant*> Indices) { +Constant * +llvm::ConstantFoldLoadThroughGEPIndices(Constant *C, + ArrayRef<Constant *> Indices) { // Loop over all of the operands, tracking down which value we are // addressing. - for (unsigned i = 0, e = Indices.size(); i != e; ++i) { - C = C->getAggregateElement(Indices[i]); + for (Constant *Index : Indices) { + C = C->getAggregateElement(Index); if (!C) return nullptr; } return C; } - //===----------------------------------------------------------------------===// // Constant Folding for Calls // -/// Return true if it's even possible to fold a call to the specified function. bool llvm::canConstantFoldCallTo(const Function *F) { switch (F->getIntrinsicID()) { case Intrinsic::fabs: @@ -1252,6 +1299,7 @@ bool llvm::canConstantFoldCallTo(const Function *F) { case Intrinsic::fmuladd: case Intrinsic::copysign: case Intrinsic::round: + case Intrinsic::masked_load: case Intrinsic::sadd_with_overflow: case Intrinsic::uadd_with_overflow: case Intrinsic::ssub_with_overflow: @@ -1260,6 +1308,7 @@ bool llvm::canConstantFoldCallTo(const Function *F) { case Intrinsic::umul_with_overflow: case Intrinsic::convert_from_fp16: case Intrinsic::convert_to_fp16: + case Intrinsic::bitreverse: case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvtss2si64: case Intrinsic::x86_sse_cvttss2si: @@ -1309,7 +1358,9 @@ bool llvm::canConstantFoldCallTo(const Function *F) { } } -static Constant *GetConstantFoldFPValue(double V, Type *Ty) { +namespace { + +Constant *GetConstantFoldFPValue(double V, Type *Ty) { if (Ty->isHalfTy()) { APFloat APF(V); bool unused; @@ -1321,12 +1372,10 @@ static Constant *GetConstantFoldFPValue(double V, Type *Ty) { if (Ty->isDoubleTy()) return ConstantFP::get(Ty->getContext(), APFloat(V)); llvm_unreachable("Can only constant fold half/float/double"); - } -namespace { /// Clear the floating-point exception state. -static inline void llvm_fenv_clearexcept() { +inline void llvm_fenv_clearexcept() { #if defined(HAVE_FENV_H) && HAVE_DECL_FE_ALL_EXCEPT feclearexcept(FE_ALL_EXCEPT); #endif @@ -1334,7 +1383,7 @@ static inline void llvm_fenv_clearexcept() { } /// Test if a floating-point exception was raised. -static inline bool llvm_fenv_testexcept() { +inline bool llvm_fenv_testexcept() { int errno_val = errno; if (errno_val == ERANGE || errno_val == EDOM) return true; @@ -1344,10 +1393,8 @@ static inline bool llvm_fenv_testexcept() { #endif return false; } -} // End namespace -static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, - Type *Ty) { +Constant *ConstantFoldFP(double (*NativeFP)(double), double V, Type *Ty) { llvm_fenv_clearexcept(); V = NativeFP(V); if (llvm_fenv_testexcept()) { @@ -1358,8 +1405,8 @@ static Constant *ConstantFoldFP(double (*NativeFP)(double), double V, return GetConstantFoldFPValue(V, Ty); } -static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), - double V, double W, Type *Ty) { +Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V, + double W, Type *Ty) { llvm_fenv_clearexcept(); V = NativeFP(V, W); if (llvm_fenv_testexcept()) { @@ -1377,8 +1424,8 @@ static Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), /// integer type Ty is used to select how many bits are available for the /// result. Returns null if the conversion cannot be performed, otherwise /// returns the Constant value resulting from the conversion. 
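Earlier in this hunk canConstantFoldCallTo gained Intrinsic::bitreverse, which the scalar folder below handles via APInt::reverseBits(). Illustration only, not part of this commit, the same operation on a plain 32-bit value:

#include <cstdint>
#include <cstdio>

// Scalar analogue of APInt::reverseBits(): bit i moves to bit 31 - i.
static uint32_t reverseBits32(uint32_t V) {
  uint32_t R = 0;
  for (int I = 0; I < 32; ++I) {
    R = (R << 1) | (V & 1);
    V >>= 1;
  }
  return R;
}

int main() {
  std::printf("0x%08x -> 0x%08x\n", 1u, reverseBits32(1u)); // 0x80000000
}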
-static Constant *ConstantFoldConvertToInt(const APFloat &Val, - bool roundTowardZero, Type *Ty) { +Constant *ConstantFoldConvertToInt(const APFloat &Val, bool roundTowardZero, + Type *Ty) { // All of these conversion intrinsics form an integer of at most 64bits. unsigned ResultWidth = Ty->getIntegerBitWidth(); assert(ResultWidth <= 64 && @@ -1396,7 +1443,7 @@ static Constant *ConstantFoldConvertToInt(const APFloat &Val, return ConstantInt::get(Ty, UIntVal, /*isSigned=*/true); } -static double getValueAsDouble(ConstantFP *Op) { +double getValueAsDouble(ConstantFP *Op) { Type *Ty = Op->getType(); if (Ty->isFloatTy()) @@ -1411,11 +1458,16 @@ static double getValueAsDouble(ConstantFP *Op) { return APF.convertToDouble(); } -static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, - Type *Ty, ArrayRef<Constant *> Operands, - const TargetLibraryInfo *TLI) { +Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, + ArrayRef<Constant *> Operands, + const TargetLibraryInfo *TLI) { if (Operands.size() == 1) { - if (ConstantFP *Op = dyn_cast<ConstantFP>(Operands[0])) { + if (isa<UndefValue>(Operands[0])) { + // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN + if (IntrinsicID == Intrinsic::cos) + return Constant::getNullValue(Ty); + } + if (auto *Op = dyn_cast<ConstantFP>(Operands[0])) { if (IntrinsicID == Intrinsic::convert_to_fp16) { APFloat Val(Op->getValueAPF()); @@ -1586,12 +1638,14 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, return nullptr; } - if (ConstantInt *Op = dyn_cast<ConstantInt>(Operands[0])) { + if (auto *Op = dyn_cast<ConstantInt>(Operands[0])) { switch (IntrinsicID) { case Intrinsic::bswap: return ConstantInt::get(Ty->getContext(), Op->getValue().byteSwap()); case Intrinsic::ctpop: return ConstantInt::get(Ty, Op->getValue().countPopulation()); + case Intrinsic::bitreverse: + return ConstantInt::get(Ty->getContext(), Op->getValue().reverseBits()); case Intrinsic::convert_from_fp16: { APFloat Val(APFloat::IEEEhalf, Op->getValue()); @@ -1614,7 +1668,7 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, // Support ConstantVector in case we have an Undef in the top. 
if (isa<ConstantVector>(Operands[0]) || isa<ConstantDataVector>(Operands[0])) { - Constant *Op = cast<Constant>(Operands[0]); + auto *Op = cast<Constant>(Operands[0]); switch (IntrinsicID) { default: break; case Intrinsic::x86_sse_cvtss2si: @@ -1646,12 +1700,12 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, } if (Operands.size() == 2) { - if (ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) { + if (auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy()) return nullptr; double Op1V = getValueAsDouble(Op1); - if (ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) { + if (auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { if (Op2->getType() != Op1->getType()) return nullptr; @@ -1661,7 +1715,7 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, } if (IntrinsicID == Intrinsic::copysign) { APFloat V1 = Op1->getValueAPF(); - APFloat V2 = Op2->getValueAPF(); + const APFloat &V2 = Op2->getValueAPF(); V1.copySign(V2); return ConstantFP::get(Ty->getContext(), V1); } @@ -1689,7 +1743,7 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, if ((Name == "atan2" && TLI->has(LibFunc::atan2)) || (Name == "atan2f" && TLI->has(LibFunc::atan2f))) return ConstantFoldBinaryFP(atan2, Op1V, Op2V, Ty); - } else if (ConstantInt *Op2C = dyn_cast<ConstantInt>(Operands[1])) { + } else if (auto *Op2C = dyn_cast<ConstantInt>(Operands[1])) { if (IntrinsicID == Intrinsic::powi && Ty->isHalfTy()) return ConstantFP::get(Ty->getContext(), APFloat((float)std::pow((float)Op1V, @@ -1706,8 +1760,8 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, return nullptr; } - if (ConstantInt *Op1 = dyn_cast<ConstantInt>(Operands[0])) { - if (ConstantInt *Op2 = dyn_cast<ConstantInt>(Operands[1])) { + if (auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) { + if (auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) { switch (IntrinsicID) { default: break; case Intrinsic::sadd_with_overflow: @@ -1764,9 +1818,9 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, if (Operands.size() != 3) return nullptr; - if (const ConstantFP *Op1 = dyn_cast<ConstantFP>(Operands[0])) { - if (const ConstantFP *Op2 = dyn_cast<ConstantFP>(Operands[1])) { - if (const ConstantFP *Op3 = dyn_cast<ConstantFP>(Operands[2])) { + if (const auto *Op1 = dyn_cast<ConstantFP>(Operands[0])) { + if (const auto *Op2 = dyn_cast<ConstantFP>(Operands[1])) { + if (const auto *Op3 = dyn_cast<ConstantFP>(Operands[2])) { switch (IntrinsicID) { default: break; case Intrinsic::fma: @@ -1788,14 +1842,53 @@ static Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, return nullptr; } -static Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, - VectorType *VTy, - ArrayRef<Constant *> Operands, - const TargetLibraryInfo *TLI) { +Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID, + VectorType *VTy, ArrayRef<Constant *> Operands, + const DataLayout &DL, + const TargetLibraryInfo *TLI) { SmallVector<Constant *, 4> Result(VTy->getNumElements()); SmallVector<Constant *, 4> Lane(Operands.size()); Type *Ty = VTy->getElementType(); + if (IntrinsicID == Intrinsic::masked_load) { + auto *SrcPtr = Operands[0]; + auto *Mask = Operands[2]; + auto *Passthru = Operands[3]; + + Constant *VecData = ConstantFoldLoadFromConstPtr(SrcPtr, VTy, DL); + + SmallVector<Constant *, 32> NewElements; + for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) 
{
+      auto *MaskElt = Mask->getAggregateElement(I);
+      if (!MaskElt)
+        break;
+      auto *PassthruElt = Passthru->getAggregateElement(I);
+      auto *VecElt = VecData ? VecData->getAggregateElement(I) : nullptr;
+      if (isa<UndefValue>(MaskElt)) {
+        if (PassthruElt)
+          NewElements.push_back(PassthruElt);
+        else if (VecElt)
+          NewElements.push_back(VecElt);
+        else
+          return nullptr;
+        // This lane is settled; don't fall through to the 0/1 checks below,
+        // which would otherwise reject the whole fold.
+        continue;
+      }
+      if (MaskElt->isNullValue()) {
+        if (!PassthruElt)
+          return nullptr;
+        NewElements.push_back(PassthruElt);
+      } else if (MaskElt->isOneValue()) {
+        if (!VecElt)
+          return nullptr;
+        NewElements.push_back(VecElt);
+      } else {
+        return nullptr;
+      }
+    }
+    if (NewElements.size() != VTy->getNumElements())
+      return nullptr;
+    return ConstantVector::get(NewElements);
+  }
+
   for (unsigned I = 0, E = VTy->getNumElements(); I != E; ++I) {
     // Gather a column of constants.
     for (unsigned J = 0, JE = Operands.size(); J != JE; ++J) {
@@ -1816,8 +1909,8 @@ static Constant *ConstantFoldVectorCall(StringRef Name, unsigned IntrinsicID,
   return ConstantVector::get(Result);
 }
 
-/// Attempt to constant fold a call to the specified function
-/// with the specified arguments, returning null if unsuccessful.
+} // end anonymous namespace
+
 Constant *
 llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
                        const TargetLibraryInfo *TLI) {
@@ -1827,8 +1920,9 @@ llvm::ConstantFoldCall(Function *F, ArrayRef<Constant *> Operands,
 
   Type *Ty = F->getReturnType();
 
-  if (VectorType *VTy = dyn_cast<VectorType>(Ty))
-    return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands, TLI);
+  if (auto *VTy = dyn_cast<VectorType>(Ty))
+    return ConstantFoldVectorCall(Name, F->getIntrinsicID(), VTy, Operands,
+                                  F->getParent()->getDataLayout(), TLI);
 
   return ConstantFoldScalarCall(Name, F->getIntrinsicID(), Ty, Operands, TLI);
 }
diff --git a/lib/Analysis/CostModel.cpp b/lib/Analysis/CostModel.cpp
index 0383cbfbbe4c..68a4bea96baa 100644
--- a/lib/Analysis/CostModel.cpp
+++ b/lib/Analysis/CostModel.cpp
@@ -504,8 +504,12 @@ unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
       for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
         Args.push_back(II->getArgOperand(J));
 
+      FastMathFlags FMF;
+      if (auto *FPMO = dyn_cast<FPMathOperator>(II))
+        FMF = FPMO->getFastMathFlags();
+
       return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
-                                        Args);
+                                        Args, FMF);
     }
     return -1;
   default:
@@ -518,16 +522,15 @@ void CostModelAnalysis::print(raw_ostream &OS, const Module*) const {
   if (!F)
     return;
 
-  for (Function::iterator B = F->begin(), BE = F->end(); B != BE; ++B) {
-    for (BasicBlock::iterator it = B->begin(), e = B->end(); it != e; ++it) {
-      Instruction *Inst = &*it;
-      unsigned Cost = getInstructionCost(Inst);
+  for (BasicBlock &B : *F) {
+    for (Instruction &Inst : B) {
+      unsigned Cost = getInstructionCost(&Inst);
       if (Cost != (unsigned)-1)
        OS << "Cost Model: Found an estimated cost of " << Cost;
      else
        OS << "Cost Model: Unknown cost";
 
-      OS << " for instruction: "<< *Inst << "\n";
+      OS << " for instruction: " << Inst << "\n";
     }
   }
 }
diff --git a/lib/Analysis/Delinearization.cpp b/lib/Analysis/Delinearization.cpp
index baee8b3b084b..dd5af9d43ef8 100644
--- a/lib/Analysis/Delinearization.cpp
+++ b/lib/Analysis/Delinearization.cpp
@@ -14,11 +14,11 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include "llvm/IR/Constants.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/Passes.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" @@ -26,7 +26,6 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" #include "llvm/Pass.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/lib/Analysis/DemandedBits.cpp b/lib/Analysis/DemandedBits.cpp index 6f92ba6289a4..a3f8b7fda08a 100644 --- a/lib/Analysis/DemandedBits.cpp +++ b/lib/Analysis/DemandedBits.cpp @@ -20,8 +20,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/DemandedBits.h" -#include "llvm/Transforms/Scalar.h" -#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -44,25 +42,29 @@ using namespace llvm; #define DEBUG_TYPE "demanded-bits" -char DemandedBits::ID = 0; -INITIALIZE_PASS_BEGIN(DemandedBits, "demanded-bits", "Demanded bits analysis", - false, false) +char DemandedBitsWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(DemandedBitsWrapperPass, "demanded-bits", + "Demanded bits analysis", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(DemandedBits, "demanded-bits", "Demanded bits analysis", - false, false) +INITIALIZE_PASS_END(DemandedBitsWrapperPass, "demanded-bits", + "Demanded bits analysis", false, false) -DemandedBits::DemandedBits() : FunctionPass(ID), F(nullptr), Analyzed(false) { - initializeDemandedBitsPass(*PassRegistry::getPassRegistry()); +DemandedBitsWrapperPass::DemandedBitsWrapperPass() : FunctionPass(ID) { + initializeDemandedBitsWrapperPassPass(*PassRegistry::getPassRegistry()); } -void DemandedBits::getAnalysisUsage(AnalysisUsage &AU) const { +void DemandedBitsWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesCFG(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesAll(); } +void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const { + DB->print(OS); +} + static bool isAlwaysLive(Instruction *I) { return isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) || I->isEHPad() || I->mayHaveSideEffects(); @@ -86,13 +88,13 @@ void DemandedBits::determineLiveOperandBits( KnownZero = APInt(BitWidth, 0); KnownOne = APInt(BitWidth, 0); computeKnownBits(const_cast<Value *>(V1), KnownZero, KnownOne, DL, 0, - AC, UserI, DT); + &AC, UserI, &DT); if (V2) { KnownZero2 = APInt(BitWidth, 0); KnownOne2 = APInt(BitWidth, 0); computeKnownBits(const_cast<Value *>(V2), KnownZero2, KnownOne2, DL, - 0, AC, UserI, DT); + 0, &AC, UserI, &DT); } }; @@ -245,19 +247,22 @@ void DemandedBits::determineLiveOperandBits( } } -bool DemandedBits::runOnFunction(Function& Fn) { - F = &Fn; - Analyzed = false; +bool DemandedBitsWrapperPass::runOnFunction(Function &F) { + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + DB.emplace(F, AC, DT); return false; } +void DemandedBitsWrapperPass::releaseMemory() { + DB.reset(); +} + void DemandedBits::performAnalysis() { if (Analyzed) // Analysis already completed for this function. 
return; Analyzed = true; - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(*F); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); Visited.clear(); AliveBits.clear(); @@ -265,7 +270,7 @@ void DemandedBits::performAnalysis() { SmallVector<Instruction*, 128> Worklist; // Collect the set of "root" instructions that are known live. - for (Instruction &I : instructions(*F)) { + for (Instruction &I : instructions(F)) { if (!isAlwaysLive(&I)) continue; @@ -370,16 +375,29 @@ bool DemandedBits::isInstructionDead(Instruction *I) { !isAlwaysLive(I); } -void DemandedBits::print(raw_ostream &OS, const Module *M) const { - // This is gross. But the alternative is making all the state mutable - // just because of this one debugging method. - const_cast<DemandedBits*>(this)->performAnalysis(); +void DemandedBits::print(raw_ostream &OS) { + performAnalysis(); for (auto &KV : AliveBits) { OS << "DemandedBits: 0x" << utohexstr(KV.second.getLimitedValue()) << " for " << *KV.first << "\n"; } } -FunctionPass *llvm::createDemandedBitsPass() { - return new DemandedBits(); +FunctionPass *llvm::createDemandedBitsWrapperPass() { + return new DemandedBitsWrapperPass(); +} + +char DemandedBitsAnalysis::PassID; + +DemandedBits DemandedBitsAnalysis::run(Function &F, + AnalysisManager<Function> &AM) { + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + return DemandedBits(F, AC, DT); +} + +PreservedAnalyses DemandedBitsPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + AM.getResult<DemandedBitsAnalysis>(F).print(OS); + return PreservedAnalyses::all(); } diff --git a/lib/Analysis/DependenceAnalysis.cpp b/lib/Analysis/DependenceAnalysis.cpp index 4040ad3cacd5..eb4d925fea73 100644 --- a/lib/Analysis/DependenceAnalysis.cpp +++ b/lib/Analysis/DependenceAnalysis.cpp @@ -114,36 +114,43 @@ Delinearize("da-delinearize", cl::init(false), cl::Hidden, cl::ZeroOrMore, //===----------------------------------------------------------------------===// // basics -INITIALIZE_PASS_BEGIN(DependenceAnalysis, "da", +DependenceAnalysis::Result +DependenceAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { + auto &AA = FAM.getResult<AAManager>(F); + auto &SE = FAM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = FAM.getResult<LoopAnalysis>(F); + return DependenceInfo(&F, &AA, &SE, &LI); +} + +char DependenceAnalysis::PassID; + +INITIALIZE_PASS_BEGIN(DependenceAnalysisWrapperPass, "da", "Dependence Analysis", true, true) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(DependenceAnalysis, "da", - "Dependence Analysis", true, true) +INITIALIZE_PASS_END(DependenceAnalysisWrapperPass, "da", "Dependence Analysis", + true, true) -char DependenceAnalysis::ID = 0; +char DependenceAnalysisWrapperPass::ID = 0; - -FunctionPass *llvm::createDependenceAnalysisPass() { - return new DependenceAnalysis(); +FunctionPass *llvm::createDependenceAnalysisWrapperPass() { + return new DependenceAnalysisWrapperPass(); } - -bool DependenceAnalysis::runOnFunction(Function &F) { - this->F = &F; - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); +bool DependenceAnalysisWrapperPass::runOnFunction(Function &F) { + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &SE = 
getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + info.reset(new DependenceInfo(&F, &AA, &SE, &LI)); return false; } +DependenceInfo &DependenceAnalysisWrapperPass::getDI() const { return *info; } -void DependenceAnalysis::releaseMemory() { -} - +void DependenceAnalysisWrapperPass::releaseMemory() { info.reset(); } -void DependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { +void DependenceAnalysisWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive<AAResultsWrapperPass>(); AU.addRequiredTransitive<ScalarEvolutionWrapperPass>(); @@ -155,11 +162,10 @@ void DependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { // Looks through the function, noting loads and stores. // Calls depends() on every possible pair and prints out the result. // Ignores all other instructions. -static -void dumpExampleDependence(raw_ostream &OS, Function *F, - DependenceAnalysis *DA) { - for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); - SrcI != SrcE; ++SrcI) { +static void dumpExampleDependence(raw_ostream &OS, DependenceInfo *DA) { + auto *F = DA->getFunction(); + for (inst_iterator SrcI = inst_begin(F), SrcE = inst_end(F); SrcI != SrcE; + ++SrcI) { if (isa<StoreInst>(*SrcI) || isa<LoadInst>(*SrcI)) { for (inst_iterator DstI = SrcI, DstE = inst_end(F); DstI != DstE; ++DstI) { @@ -183,9 +189,9 @@ void dumpExampleDependence(raw_ostream &OS, Function *F, } } - -void DependenceAnalysis::print(raw_ostream &OS, const Module*) const { - dumpExampleDependence(OS, F, const_cast<DependenceAnalysis *>(this)); +void DependenceAnalysisWrapperPass::print(raw_ostream &OS, + const Module *) const { + dumpExampleDependence(OS, info.get()); } //===----------------------------------------------------------------------===// @@ -286,11 +292,11 @@ bool FullDependence::isSplitable(unsigned Level) const { //===----------------------------------------------------------------------===// -// DependenceAnalysis::Constraint methods +// DependenceInfo::Constraint methods // If constraint is a point <X, Y>, returns X. // Otherwise assert. -const SCEV *DependenceAnalysis::Constraint::getX() const { +const SCEV *DependenceInfo::Constraint::getX() const { assert(Kind == Point && "Kind should be Point"); return A; } @@ -298,7 +304,7 @@ const SCEV *DependenceAnalysis::Constraint::getX() const { // If constraint is a point <X, Y>, returns Y. // Otherwise assert. -const SCEV *DependenceAnalysis::Constraint::getY() const { +const SCEV *DependenceInfo::Constraint::getY() const { assert(Kind == Point && "Kind should be Point"); return B; } @@ -306,7 +312,7 @@ const SCEV *DependenceAnalysis::Constraint::getY() const { // If constraint is a line AX + BY = C, returns A. // Otherwise assert. -const SCEV *DependenceAnalysis::Constraint::getA() const { +const SCEV *DependenceInfo::Constraint::getA() const { assert((Kind == Line || Kind == Distance) && "Kind should be Line (or Distance)"); return A; @@ -315,7 +321,7 @@ const SCEV *DependenceAnalysis::Constraint::getA() const { // If constraint is a line AX + BY = C, returns B. // Otherwise assert. -const SCEV *DependenceAnalysis::Constraint::getB() const { +const SCEV *DependenceInfo::Constraint::getB() const { assert((Kind == Line || Kind == Distance) && "Kind should be Line (or Distance)"); return B; @@ -324,7 +330,7 @@ const SCEV *DependenceAnalysis::Constraint::getB() const { // If constraint is a line AX + BY = C, returns C. // Otherwise assert. 
-const SCEV *DependenceAnalysis::Constraint::getC() const { +const SCEV *DependenceInfo::Constraint::getC() const { assert((Kind == Line || Kind == Distance) && "Kind should be Line (or Distance)"); return C; @@ -333,34 +339,29 @@ const SCEV *DependenceAnalysis::Constraint::getC() const { // If constraint is a distance, returns D. // Otherwise assert. -const SCEV *DependenceAnalysis::Constraint::getD() const { +const SCEV *DependenceInfo::Constraint::getD() const { assert(Kind == Distance && "Kind should be Distance"); return SE->getNegativeSCEV(C); } // Returns the loop associated with this constraint. -const Loop *DependenceAnalysis::Constraint::getAssociatedLoop() const { +const Loop *DependenceInfo::Constraint::getAssociatedLoop() const { assert((Kind == Distance || Kind == Line || Kind == Point) && "Kind should be Distance, Line, or Point"); return AssociatedLoop; } - -void DependenceAnalysis::Constraint::setPoint(const SCEV *X, - const SCEV *Y, - const Loop *CurLoop) { +void DependenceInfo::Constraint::setPoint(const SCEV *X, const SCEV *Y, + const Loop *CurLoop) { Kind = Point; A = X; B = Y; AssociatedLoop = CurLoop; } - -void DependenceAnalysis::Constraint::setLine(const SCEV *AA, - const SCEV *BB, - const SCEV *CC, - const Loop *CurLoop) { +void DependenceInfo::Constraint::setLine(const SCEV *AA, const SCEV *BB, + const SCEV *CC, const Loop *CurLoop) { Kind = Line; A = AA; B = BB; @@ -368,9 +369,8 @@ void DependenceAnalysis::Constraint::setLine(const SCEV *AA, AssociatedLoop = CurLoop; } - -void DependenceAnalysis::Constraint::setDistance(const SCEV *D, - const Loop *CurLoop) { +void DependenceInfo::Constraint::setDistance(const SCEV *D, + const Loop *CurLoop) { Kind = Distance; A = SE->getOne(D->getType()); B = SE->getNegativeSCEV(A); @@ -378,20 +378,16 @@ void DependenceAnalysis::Constraint::setDistance(const SCEV *D, AssociatedLoop = CurLoop; } +void DependenceInfo::Constraint::setEmpty() { Kind = Empty; } -void DependenceAnalysis::Constraint::setEmpty() { - Kind = Empty; -} - - -void DependenceAnalysis::Constraint::setAny(ScalarEvolution *NewSE) { +void DependenceInfo::Constraint::setAny(ScalarEvolution *NewSE) { SE = NewSE; Kind = Any; } // For debugging purposes. Dumps the constraint out to OS. -void DependenceAnalysis::Constraint::dump(raw_ostream &OS) const { +void DependenceInfo::Constraint::dump(raw_ostream &OS) const { if (isEmpty()) OS << " Empty\n"; else if (isAny()) @@ -416,8 +412,7 @@ void DependenceAnalysis::Constraint::dump(raw_ostream &OS) const { // Practical Dependence Testing // Goff, Kennedy, Tseng // PLDI 1991 -bool DependenceAnalysis::intersectConstraints(Constraint *X, - const Constraint *Y) { +bool DependenceInfo::intersectConstraints(Constraint *X, const Constraint *Y) { ++DeltaApplications; DEBUG(dbgs() << "\tintersect constraints\n"); DEBUG(dbgs() << "\t X ="; X->dump(dbgs())); @@ -528,7 +523,7 @@ bool DependenceAnalysis::intersectConstraints(Constraint *X, } if (const SCEVConstant *CUB = collectConstantUpperBound(X->getAssociatedLoop(), Prod1->getType())) { - APInt UpperBound = CUB->getAPInt(); + const APInt &UpperBound = CUB->getAPInt(); DEBUG(dbgs() << "\t\tupper bound = " << UpperBound << "\n"); if (Xq.sgt(UpperBound) || Yq.sgt(UpperBound)) { X->setEmpty(); @@ -569,7 +564,7 @@ bool DependenceAnalysis::intersectConstraints(Constraint *X, //===----------------------------------------------------------------------===// -// DependenceAnalysis methods +// DependenceInfo methods // For debugging purposes. Dumps a dependence to OS. 
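Stepping back from the renames: intersectConstraints above reduces to intersecting two lines A1*X + B1*Y = C1 and A2*X + B2*Y = C2 in iteration space, and the dependence survives only if the crossing point is integral (and, as the UpperBound check shows, within the loop bounds). A standalone sketch, illustration only and not part of this commit:

#include <cstdio>

int main() {
  long A1 = 1, B1 = 2, C1 = 3;   // X + 2Y = 3
  long A2 = 2, B2 = 1, C2 = 3;   // 2X + Y = 3
  long Det = A1 * B2 - A2 * B1;  // Cramer's rule denominator
  if (Det == 0) { std::puts("parallel or coincident lines"); return 0; }
  long XNum = C1 * B2 - C2 * B1, YNum = A1 * C2 - A2 * C1;
  if (XNum % Det != 0 || YNum % Det != 0)
    std::puts("non-integer intersection: no dependence");
  else
    std::printf("X = %ld, Y = %ld\n", XNum / Det, YNum / Det); // 1, 1
}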
void Dependence::dump(raw_ostream &OS) const { @@ -709,8 +704,8 @@ Value *getPointerOperand(Instruction *I) { // e - 5 // f - 6 // g - 7 = MaxLevels -void DependenceAnalysis::establishNestingLevels(const Instruction *Src, - const Instruction *Dst) { +void DependenceInfo::establishNestingLevels(const Instruction *Src, + const Instruction *Dst) { const BasicBlock *SrcBlock = Src->getParent(); const BasicBlock *DstBlock = Dst->getParent(); unsigned SrcLevel = LI->getLoopDepth(SrcBlock); @@ -739,14 +734,14 @@ void DependenceAnalysis::establishNestingLevels(const Instruction *Src, // Given one of the loops containing the source, return // its level index in our numbering scheme. -unsigned DependenceAnalysis::mapSrcLoop(const Loop *SrcLoop) const { +unsigned DependenceInfo::mapSrcLoop(const Loop *SrcLoop) const { return SrcLoop->getLoopDepth(); } // Given one of the loops containing the destination, // return its level index in our numbering scheme. -unsigned DependenceAnalysis::mapDstLoop(const Loop *DstLoop) const { +unsigned DependenceInfo::mapDstLoop(const Loop *DstLoop) const { unsigned D = DstLoop->getLoopDepth(); if (D > CommonLevels) return D - CommonLevels + SrcLevels; @@ -756,8 +751,8 @@ unsigned DependenceAnalysis::mapDstLoop(const Loop *DstLoop) const { // Returns true if Expression is loop invariant in LoopNest. -bool DependenceAnalysis::isLoopInvariant(const SCEV *Expression, - const Loop *LoopNest) const { +bool DependenceInfo::isLoopInvariant(const SCEV *Expression, + const Loop *LoopNest) const { if (!LoopNest) return true; return SE->isLoopInvariant(Expression, LoopNest) && @@ -768,9 +763,9 @@ bool DependenceAnalysis::isLoopInvariant(const SCEV *Expression, // Finds the set of loops from the LoopNest that // have a level <= CommonLevels and are referred to by the SCEV Expression. -void DependenceAnalysis::collectCommonLoops(const SCEV *Expression, - const Loop *LoopNest, - SmallBitVector &Loops) const { +void DependenceInfo::collectCommonLoops(const SCEV *Expression, + const Loop *LoopNest, + SmallBitVector &Loops) const { while (LoopNest) { unsigned Level = LoopNest->getLoopDepth(); if (Level <= CommonLevels && !SE->isLoopInvariant(Expression, LoopNest)) @@ -779,16 +774,16 @@ void DependenceAnalysis::collectCommonLoops(const SCEV *Expression, } } -void DependenceAnalysis::unifySubscriptType(ArrayRef<Subscript *> Pairs) { +void DependenceInfo::unifySubscriptType(ArrayRef<Subscript *> Pairs) { unsigned widestWidthSeen = 0; Type *widestType; // Go through each pair and find the widest bit to which we need // to extend all of them. - for (unsigned i = 0; i < Pairs.size(); i++) { - const SCEV *Src = Pairs[i]->Src; - const SCEV *Dst = Pairs[i]->Dst; + for (Subscript *Pair : Pairs) { + const SCEV *Src = Pair->Src; + const SCEV *Dst = Pair->Dst; IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType()); IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType()); if (SrcTy == nullptr || DstTy == nullptr) { @@ -811,9 +806,9 @@ void DependenceAnalysis::unifySubscriptType(ArrayRef<Subscript *> Pairs) { assert(widestWidthSeen > 0); // Now extend each pair to the widest seen. 
-  for (unsigned i = 0; i < Pairs.size(); i++) {
-    const SCEV *Src = Pairs[i]->Src;
-    const SCEV *Dst = Pairs[i]->Dst;
+  for (Subscript *Pair : Pairs) {
+    const SCEV *Src = Pair->Src;
+    const SCEV *Dst = Pair->Dst;
     IntegerType *SrcTy = dyn_cast<IntegerType>(Src->getType());
     IntegerType *DstTy = dyn_cast<IntegerType>(Dst->getType());
     if (SrcTy == nullptr || DstTy == nullptr) {
@@ -824,10 +819,10 @@ void DependenceAnalysis::unifySubscriptType(ArrayRef<Subscript *> Pairs) {
     }
     if (SrcTy->getBitWidth() < widestWidthSeen)
       // Sign-extend Src to widestType
-      Pairs[i]->Src = SE->getSignExtendExpr(Src, widestType);
+      Pair->Src = SE->getSignExtendExpr(Src, widestType);
     if (DstTy->getBitWidth() < widestWidthSeen) {
       // Sign-extend Dst to widestType
-      Pairs[i]->Dst = SE->getSignExtendExpr(Dst, widestType);
+      Pair->Dst = SE->getSignExtendExpr(Dst, widestType);
     }
   }
 }
@@ -836,7 +831,7 @@ void DependenceAnalysis::unifySubscriptType(ArrayRef<Subscript *> Pairs) {
 // If the source and destination are identically sign (or zero)
 // extended, it strips off the extension in an effort to simplify
 // the actual analysis.
-void DependenceAnalysis::removeMatchingExtensions(Subscript *Pair) {
+void DependenceInfo::removeMatchingExtensions(Subscript *Pair) {
   const SCEV *Src = Pair->Src;
   const SCEV *Dst = Pair->Dst;
   if ((isa<SCEVZeroExtendExpr>(Src) && isa<SCEVZeroExtendExpr>(Dst)) ||
@@ -855,9 +850,8 @@ void DependenceAnalysis::removeMatchingExtensions(Subscript *Pair) {
 
 // Examine the scev and return true iff it's linear.
 // Collect any loops mentioned in the set of "Loops".
-bool DependenceAnalysis::checkSrcSubscript(const SCEV *Src,
-                                           const Loop *LoopNest,
-                                           SmallBitVector &Loops) {
+bool DependenceInfo::checkSrcSubscript(const SCEV *Src, const Loop *LoopNest,
+                                       SmallBitVector &Loops) {
   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Src);
   if (!AddRec)
     return isLoopInvariant(Src, LoopNest);
@@ -881,9 +875,8 @@ bool DependenceAnalysis::checkSrcSubscript(const SCEV *Src,
 
 // Examine the scev and return true iff it's linear.
 // Collect any loops mentioned in the set of "Loops".
-bool DependenceAnalysis::checkDstSubscript(const SCEV *Dst,
-                                           const Loop *LoopNest,
-                                           SmallBitVector &Loops) {
+bool DependenceInfo::checkDstSubscript(const SCEV *Dst, const Loop *LoopNest,
+                                       SmallBitVector &Loops) {
   const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Dst);
   if (!AddRec)
     return isLoopInvariant(Dst, LoopNest);
@@ -907,10 +900,10 @@ bool DependenceAnalysis::checkDstSubscript(const SCEV *Dst,
 // Examines the subscript pair (the Src and Dst SCEVs)
 // and classifies it as either ZIV, SIV, RDIV, MIV, or Nonlinear.
 // Collects the associated loops in a set.
-DependenceAnalysis::Subscript::ClassificationKind
-DependenceAnalysis::classifyPair(const SCEV *Src, const Loop *SrcLoopNest,
-                                 const SCEV *Dst, const Loop *DstLoopNest,
-                                 SmallBitVector &Loops) {
+DependenceInfo::Subscript::ClassificationKind
+DependenceInfo::classifyPair(const SCEV *Src, const Loop *SrcLoopNest,
+                             const SCEV *Dst, const Loop *DstLoopNest,
+                             SmallBitVector &Loops) {
   SmallBitVector SrcLoops(MaxLevels + 1);
   SmallBitVector DstLoops(MaxLevels + 1);
   if (!checkSrcSubscript(Src, SrcLoopNest, SrcLoops))
@@ -942,9 +935,8 @@ DependenceAnalysis::classifyPair(const SCEV *Src, const Loop *SrcLoopNest,
 // If SCEV::isKnownPredicate can't prove the predicate,
 // we try simple subtraction, which seems to help in some cases
 // involving symbolics.
-bool DependenceAnalysis::isKnownPredicate(ICmpInst::Predicate Pred, - const SCEV *X, - const SCEV *Y) const { +bool DependenceInfo::isKnownPredicate(ICmpInst::Predicate Pred, const SCEV *X, + const SCEV *Y) const { if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE) { if ((isa<SCEVSignExtendExpr>(X) && @@ -995,8 +987,7 @@ bool DependenceAnalysis::isKnownPredicate(ICmpInst::Predicate Pred, // Truncating is safe when subscripts are known not to wrap. Cases without // nowrap flags should have been rejected earlier. // Return null if no bound available. -const SCEV *DependenceAnalysis::collectUpperBound(const Loop *L, - Type *T) const { +const SCEV *DependenceInfo::collectUpperBound(const Loop *L, Type *T) const { if (SE->hasLoopInvariantBackedgeTakenCount(L)) { const SCEV *UB = SE->getBackedgeTakenCount(L); return SE->getTruncateOrZeroExtend(UB, T); @@ -1007,9 +998,8 @@ const SCEV *DependenceAnalysis::collectUpperBound(const Loop *L, // Calls collectUpperBound(), then attempts to cast it to SCEVConstant. // If the cast fails, returns NULL. -const SCEVConstant *DependenceAnalysis::collectConstantUpperBound(const Loop *L, - Type *T - ) const { +const SCEVConstant *DependenceInfo::collectConstantUpperBound(const Loop *L, + Type *T) const { if (const SCEV *UB = collectUpperBound(L, T)) return dyn_cast<SCEVConstant>(UB); return nullptr; @@ -1026,9 +1016,8 @@ const SCEVConstant *DependenceAnalysis::collectConstantUpperBound(const Loop *L, // 3) the values might be equal, so we have to assume a dependence. // // Return true if dependence disproved. -bool DependenceAnalysis::testZIV(const SCEV *Src, - const SCEV *Dst, - FullDependence &Result) const { +bool DependenceInfo::testZIV(const SCEV *Src, const SCEV *Dst, + FullDependence &Result) const { DEBUG(dbgs() << " src = " << *Src << "\n"); DEBUG(dbgs() << " dst = " << *Dst << "\n"); ++ZIVapplications; @@ -1074,13 +1063,10 @@ bool DependenceAnalysis::testZIV(const SCEV *Src, // { > if d < 0 // // Return true if dependence disproved. -bool DependenceAnalysis::strongSIVtest(const SCEV *Coeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *CurLoop, - unsigned Level, - FullDependence &Result, - Constraint &NewConstraint) const { +bool DependenceInfo::strongSIVtest(const SCEV *Coeff, const SCEV *SrcConst, + const SCEV *DstConst, const Loop *CurLoop, + unsigned Level, FullDependence &Result, + Constraint &NewConstraint) const { DEBUG(dbgs() << "\tStrong SIV test\n"); DEBUG(dbgs() << "\t Coeff = " << *Coeff); DEBUG(dbgs() << ", " << *Coeff->getType() << "\n"); @@ -1213,14 +1199,10 @@ bool DependenceAnalysis::strongSIVtest(const SCEV *Coeff, // Can determine iteration for splitting. // // Return true if dependence disproved. 
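Before the weak-crossing variant, a worked instance of the strong SIV criterion documented above: for subscripts a*i + c1 (source) and a*i' + c2 (destination), the distance d = (c1 - c2)/a must be an integer no larger in magnitude than the iteration count. Illustration only, not part of this commit:

#include <cstdio>
#include <cstdlib>

int main() {
  long A = 2, C1 = 10, C2 = 0, UB = 12; // a, c1, c2, loop upper bound
  long Delta = C1 - C2;
  if (Delta % A != 0 || std::labs(Delta / A) > UB)
    std::puts("no dependence");
  else
    std::printf("dependence, distance = %ld\n", Delta / A); // 5
}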
-bool DependenceAnalysis::weakCrossingSIVtest(const SCEV *Coeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *CurLoop, - unsigned Level, - FullDependence &Result, - Constraint &NewConstraint, - const SCEV *&SplitIter) const { +bool DependenceInfo::weakCrossingSIVtest( + const SCEV *Coeff, const SCEV *SrcConst, const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, FullDependence &Result, + Constraint &NewConstraint, const SCEV *&SplitIter) const { DEBUG(dbgs() << "\tWeak-Crossing SIV test\n"); DEBUG(dbgs() << "\t Coeff = " << *Coeff << "\n"); DEBUG(dbgs() << "\t SrcConst = " << *SrcConst << "\n"); @@ -1256,7 +1238,7 @@ bool DependenceAnalysis::weakCrossingSIVtest(const SCEV *Coeff, } assert(SE->isKnownPositive(ConstCoeff) && "ConstCoeff should be positive"); - // compute SplitIter for use by DependenceAnalysis::getSplitIteration() + // compute SplitIter for use by DependenceInfo::getSplitIteration() SplitIter = SE->getUDivExpr( SE->getSMaxExpr(SE->getZero(Delta->getType()), Delta), SE->getMulExpr(SE->getConstant(Delta->getType(), 2), ConstCoeff)); @@ -1344,9 +1326,8 @@ bool DependenceAnalysis::weakCrossingSIVtest(const SCEV *Coeff, // Computes the GCD of AM and BM. // Also finds a solution to the equation ax - by = gcd(a, b). // Returns true if dependence disproved; i.e., gcd does not divide Delta. -static -bool findGCD(unsigned Bits, APInt AM, APInt BM, APInt Delta, - APInt &G, APInt &X, APInt &Y) { +static bool findGCD(unsigned Bits, const APInt &AM, const APInt &BM, + const APInt &Delta, APInt &G, APInt &X, APInt &Y) { APInt A0(Bits, 1, true), A1(Bits, 0, true); APInt B0(Bits, 0, true), B1(Bits, 1, true); APInt G0 = AM.abs(); @@ -1375,9 +1356,7 @@ bool findGCD(unsigned Bits, APInt AM, APInt BM, APInt Delta, return false; } - -static -APInt floorOfQuotient(APInt A, APInt B) { +static APInt floorOfQuotient(const APInt &A, const APInt &B) { APInt Q = A; // these need to be initialized APInt R = A; APInt::sdivrem(A, B, Q, R); @@ -1390,9 +1369,7 @@ APInt floorOfQuotient(APInt A, APInt B) { return Q - 1; } - -static -APInt ceilingOfQuotient(APInt A, APInt B) { +static APInt ceilingOfQuotient(const APInt &A, const APInt &B) { APInt Q = A; // these need to be initialized APInt R = A; APInt::sdivrem(A, B, Q, R); @@ -1433,14 +1410,11 @@ APInt minAPInt(APInt A, APInt B) { // in the case of the strong SIV test, can compute Distances. // // Return true if dependence disproved. 
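The two quotient helpers tidied up above round toward minus and plus infinity respectively, where C++ integer division truncates toward zero; the exact SIV test below depends on that distinction when clipping solution ranges, and findGCD supplies the extended-Euclid particular solution those ranges are clipped around. A standalone plain-integer sketch of the same rounding logic (the APInt versions in the patch are the authoritative ones):

#include <cassert>
#include <cstdint>

// Floor of A/B, rounding toward -infinity (C++ '/' truncates toward zero).
// When the remainder is nonzero and A and B differ in sign, truncation
// rounded the wrong way, so step down by one.
static int64_t floorOfQuotient(int64_t A, int64_t B) {
  int64_t Q = A / B, R = A % B;
  if (R == 0)
    return Q;
  return ((R > 0) == (B > 0)) ? Q : Q - 1;
}

// Ceiling of A/B, rounding toward +infinity.
static int64_t ceilingOfQuotient(int64_t A, int64_t B) {
  int64_t Q = A / B, R = A % B;
  if (R == 0)
    return Q;
  return ((R > 0) == (B > 0)) ? Q + 1 : Q;
}

int main() {
  assert(floorOfQuotient(-7, 2) == -4);   // truncation alone gives -3
  assert(ceilingOfQuotient(-7, 2) == -3);
  assert(floorOfQuotient(7, 2) == 3);
  assert(ceilingOfQuotient(7, 2) == 4);   // truncation alone gives 3
  return 0;
}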
-bool DependenceAnalysis::exactSIVtest(const SCEV *SrcCoeff, - const SCEV *DstCoeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *CurLoop, - unsigned Level, - FullDependence &Result, - Constraint &NewConstraint) const { +bool DependenceInfo::exactSIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff, + const SCEV *SrcConst, const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { DEBUG(dbgs() << "\tExact SIV test\n"); DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n"); DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n"); @@ -1608,8 +1582,8 @@ bool DependenceAnalysis::exactSIVtest(const SCEV *SrcCoeff, static bool isRemainderZero(const SCEVConstant *Dividend, const SCEVConstant *Divisor) { - APInt ConstDividend = Dividend->getAPInt(); - APInt ConstDivisor = Divisor->getAPInt(); + const APInt &ConstDividend = Dividend->getAPInt(); + const APInt &ConstDivisor = Divisor->getAPInt(); return ConstDividend.srem(ConstDivisor) == 0; } @@ -1645,13 +1619,12 @@ bool isRemainderZero(const SCEVConstant *Dividend, // (see also weakZeroDstSIVtest) // // Return true if dependence disproved. -bool DependenceAnalysis::weakZeroSrcSIVtest(const SCEV *DstCoeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *CurLoop, - unsigned Level, - FullDependence &Result, - Constraint &NewConstraint) const { +bool DependenceInfo::weakZeroSrcSIVtest(const SCEV *DstCoeff, + const SCEV *SrcConst, + const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { // For the WeakSIV test, it's possible the loop isn't common to // the Src and Dst loops. If it isn't, then there's no need to // record a direction. @@ -1756,13 +1729,12 @@ bool DependenceAnalysis::weakZeroSrcSIVtest(const SCEV *DstCoeff, // (see also weakZeroSrcSIVtest) // // Return true if dependence disproved. -bool DependenceAnalysis::weakZeroDstSIVtest(const SCEV *SrcCoeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *CurLoop, - unsigned Level, - FullDependence &Result, - Constraint &NewConstraint) const { +bool DependenceInfo::weakZeroDstSIVtest(const SCEV *SrcCoeff, + const SCEV *SrcConst, + const SCEV *DstConst, + const Loop *CurLoop, unsigned Level, + FullDependence &Result, + Constraint &NewConstraint) const { // For the WeakSIV test, it's possible the loop isn't common to the // Src and Dst loops. If it isn't, then there's no need to record a direction. DEBUG(dbgs() << "\tWeak-Zero (dst) SIV test\n"); @@ -1842,13 +1814,10 @@ bool DependenceAnalysis::weakZeroDstSIVtest(const SCEV *SrcCoeff, // Returns true if any possible dependence is disproved. // Marks the result as inconsistent. // Works in some cases that symbolicRDIVtest doesn't, and vice versa. 
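A worked instance of the weak-zero SIV tests above (invented numbers):

//   source subscript: 10        (coefficient of i is zero)
//   dest   subscript: 2*i
//   A dependence needs 2*i = 10, i.e. the single iteration i = 5. If the
//   loop's upper bound is 3, no such iteration exists and the dependence
//   is disproved. If 5 is in range, the only dependence sits at i == 5;
//   and when the solution lands on the first or last iteration, peeling
//   that iteration breaks the dependence entirely.

The exact RDIV test that follows plays the same gcd game as exact SIV, but with the two subscripts drawn from different loops.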
-bool DependenceAnalysis::exactRDIVtest(const SCEV *SrcCoeff, - const SCEV *DstCoeff, - const SCEV *SrcConst, - const SCEV *DstConst, - const Loop *SrcLoop, - const Loop *DstLoop, - FullDependence &Result) const { +bool DependenceInfo::exactRDIVtest(const SCEV *SrcCoeff, const SCEV *DstCoeff, + const SCEV *SrcConst, const SCEV *DstConst, + const Loop *SrcLoop, const Loop *DstLoop, + FullDependence &Result) const { DEBUG(dbgs() << "\tExact RDIV test\n"); DEBUG(dbgs() << "\t SrcCoeff = " << *SrcCoeff << " = AM\n"); DEBUG(dbgs() << "\t DstCoeff = " << *DstCoeff << " = BM\n"); @@ -1986,12 +1955,10 @@ bool DependenceAnalysis::exactRDIVtest(const SCEV *SrcCoeff, // a1*N1 <= c2 - c1 <= -a2*N2 // // return true if dependence disproved -bool DependenceAnalysis::symbolicRDIVtest(const SCEV *A1, - const SCEV *A2, - const SCEV *C1, - const SCEV *C2, - const Loop *Loop1, - const Loop *Loop2) const { +bool DependenceInfo::symbolicRDIVtest(const SCEV *A1, const SCEV *A2, + const SCEV *C1, const SCEV *C2, + const Loop *Loop1, + const Loop *Loop2) const { ++SymbolicRDIVapplications; DEBUG(dbgs() << "\ttry symbolic RDIV test\n"); DEBUG(dbgs() << "\t A1 = " << *A1); @@ -2103,12 +2070,9 @@ bool DependenceAnalysis::symbolicRDIVtest(const SCEV *A1, // they apply; they're cheaper and sometimes more precise. // // Return true if dependence disproved. -bool DependenceAnalysis::testSIV(const SCEV *Src, - const SCEV *Dst, - unsigned &Level, - FullDependence &Result, - Constraint &NewConstraint, - const SCEV *&SplitIter) const { +bool DependenceInfo::testSIV(const SCEV *Src, const SCEV *Dst, unsigned &Level, + FullDependence &Result, Constraint &NewConstraint, + const SCEV *&SplitIter) const { DEBUG(dbgs() << " src = " << *Src << "\n"); DEBUG(dbgs() << " dst = " << *Dst << "\n"); const SCEVAddRecExpr *SrcAddRec = dyn_cast<SCEVAddRecExpr>(Src); @@ -2174,9 +2138,8 @@ bool DependenceAnalysis::testSIV(const SCEV *Src, // [c1 + a1*i + a2*j][c2]. // // Return true if dependence disproved. -bool DependenceAnalysis::testRDIV(const SCEV *Src, - const SCEV *Dst, - FullDependence &Result) const { +bool DependenceInfo::testRDIV(const SCEV *Src, const SCEV *Dst, + FullDependence &Result) const { // we have 3 possible situations here: // 1) [a*i + b] and [c*j + d] // 2) [a*i + c*j + b] and [d] @@ -2241,10 +2204,9 @@ bool DependenceAnalysis::testRDIV(const SCEV *Src, // Tests the single-subscript MIV pair (Src and Dst) for dependence. // Return true if dependence disproved. // Can sometimes refine direction vectors. -bool DependenceAnalysis::testMIV(const SCEV *Src, - const SCEV *Dst, - const SmallBitVector &Loops, - FullDependence &Result) const { +bool DependenceInfo::testMIV(const SCEV *Src, const SCEV *Dst, + const SmallBitVector &Loops, + FullDependence &Result) const { DEBUG(dbgs() << " src = " << *Src << "\n"); DEBUG(dbgs() << " dst = " << *Dst << "\n"); Result.Consistent = false; @@ -2256,11 +2218,12 @@ bool DependenceAnalysis::testMIV(const SCEV *Src, // Given a product, e.g., 10*X*Y, returns the first constant operand, // in this case 10. If there is no constant part, returns NULL. 
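A worked instance of the RDIV case above (invented numbers), with the source and destination in different loops:

//   for (i = 0; i < n; ++i)  A[2*i + 2] = ...;
//   for (j = 0; j < m; ++j)  ...       = A[2*j + 1];
//   Solving 2*i + 2 = 2*j + 1 means 2*i - 2*j = -1; gcd(2, 2) = 2 does not
//   divide -1, so no integer solution exists and exactRDIVtest disproves
//   the dependence. symbolicRDIVtest instead bounds c2 - c1 against the
//   coefficient-scaled trip counts, so it can succeed where the gcd
//   argument cannot, and vice versa.

getConstantPart, whose simplified form follows, is what hands the GCD-based MIV test the constant factor of each coefficient.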
static -const SCEVConstant *getConstantPart(const SCEVMulExpr *Product) { - for (unsigned Op = 0, Ops = Product->getNumOperands(); Op < Ops; Op++) { - if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Product->getOperand(Op))) +const SCEVConstant *getConstantPart(const SCEV *Expr) { + if (const auto *Constant = dyn_cast<SCEVConstant>(Expr)) + return Constant; + else if (const auto *Product = dyn_cast<SCEVMulExpr>(Expr)) + if (const auto *Constant = dyn_cast<SCEVConstant>(Product->getOperand(0))) return Constant; - } return nullptr; } @@ -2283,9 +2246,8 @@ const SCEVConstant *getConstantPart(const SCEVMulExpr *Product) { // It occurs to me that the presence of loop-invariant variables // changes the nature of the test from "greatest common divisor" // to "a common divisor". -bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, - const SCEV *Dst, - FullDependence &Result) const { +bool DependenceInfo::gcdMIVtest(const SCEV *Src, const SCEV *Dst, + FullDependence &Result) const { DEBUG(dbgs() << "starting gcd\n"); ++GCDapplications; unsigned BitWidth = SE->getTypeSizeInBits(Src->getType()); @@ -2299,11 +2261,9 @@ bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, while (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Coefficients)) { const SCEV *Coeff = AddRec->getStepRecurrence(*SE); - const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Coeff); - if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Coeff)) - // If the coefficient is the product of a constant and other stuff, - // we can use the constant in the GCD computation. - Constant = getConstantPart(Product); + // If the coefficient is the product of a constant and other stuff, + // we can use the constant in the GCD computation. + const auto *Constant = getConstantPart(Coeff); if (!Constant) return false; APInt ConstCoeff = Constant->getAPInt(); @@ -2320,11 +2280,9 @@ bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, while (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Coefficients)) { const SCEV *Coeff = AddRec->getStepRecurrence(*SE); - const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Coeff); - if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Coeff)) - // If the coefficient is the product of a constant and other stuff, - // we can use the constant in the GCD computation. - Constant = getConstantPart(Product); + // If the coefficient is the product of a constant and other stuff, + // we can use the constant in the GCD computation. + const auto *Constant = getConstantPart(Coeff); if (!Constant) return false; APInt ConstCoeff = Constant->getAPInt(); @@ -2403,12 +2361,11 @@ bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, if (CurLoop == AddRec->getLoop()) ; // SrcCoeff == Coeff else { - if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Coeff)) - // If the coefficient is the product of a constant and other stuff, - // we can use the constant in the GCD computation. - Constant = getConstantPart(Product); - else - Constant = cast<SCEVConstant>(Coeff); + // If the coefficient is the product of a constant and other stuff, + // we can use the constant in the GCD computation. 
+ Constant = getConstantPart(Coeff); + if (!Constant) + return false; APInt ConstCoeff = Constant->getAPInt(); RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); } @@ -2421,29 +2378,24 @@ bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, if (CurLoop == AddRec->getLoop()) DstCoeff = Coeff; else { - if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Coeff)) - // If the coefficient is the product of a constant and other stuff, - // we can use the constant in the GCD computation. - Constant = getConstantPart(Product); - else - Constant = cast<SCEVConstant>(Coeff); + // If the coefficient is the product of a constant and other stuff, + // we can use the constant in the GCD computation. + Constant = getConstantPart(Coeff); + if (!Constant) + return false; APInt ConstCoeff = Constant->getAPInt(); RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); } Inner = AddRec->getStart(); } Delta = SE->getMinusSCEV(SrcCoeff, DstCoeff); - if (const SCEVMulExpr *Product = dyn_cast<SCEVMulExpr>(Delta)) - // If the coefficient is the product of a constant and other stuff, - // we can use the constant in the GCD computation. - Constant = getConstantPart(Product); - else if (isa<SCEVConstant>(Delta)) - Constant = cast<SCEVConstant>(Delta); - else { + // If the coefficient is the product of a constant and other stuff, + // we can use the constant in the GCD computation. + Constant = getConstantPart(Delta); + if (!Constant) // The difference of the two coefficients might not be a product // or constant, in which case we give up on this direction. continue; - } APInt ConstCoeff = Constant->getAPInt(); RunningGCD = APIntOps::GreatestCommonDivisor(RunningGCD, ConstCoeff.abs()); DEBUG(dbgs() << "\tRunningGCD = " << RunningGCD << "\n"); @@ -2497,10 +2449,9 @@ bool DependenceAnalysis::gcdMIVtest(const SCEV *Src, // for the lower bound, NULL denotes -inf. // // Return true if dependence disproved. -bool DependenceAnalysis::banerjeeMIVtest(const SCEV *Src, - const SCEV *Dst, - const SmallBitVector &Loops, - FullDependence &Result) const { +bool DependenceInfo::banerjeeMIVtest(const SCEV *Src, const SCEV *Dst, + const SmallBitVector &Loops, + FullDependence &Result) const { DEBUG(dbgs() << "starting Banerjee\n"); ++BanerjeeApplications; DEBUG(dbgs() << " Src = " << *Src << '\n'); @@ -2578,13 +2529,11 @@ bool DependenceAnalysis::banerjeeMIVtest(const SCEV *Src, // in the DirSet field of Bound. Returns the number of distinct // dependences discovered. If the dependence is disproved, // it will return 0. -unsigned DependenceAnalysis::exploreDirections(unsigned Level, - CoefficientInfo *A, - CoefficientInfo *B, - BoundInfo *Bound, - const SmallBitVector &Loops, - unsigned &DepthExpanded, - const SCEV *Delta) const { +unsigned DependenceInfo::exploreDirections(unsigned Level, CoefficientInfo *A, + CoefficientInfo *B, BoundInfo *Bound, + const SmallBitVector &Loops, + unsigned &DepthExpanded, + const SCEV *Delta) const { if (Level > CommonLevels) { // record result DEBUG(dbgs() << "\t["); @@ -2679,10 +2628,8 @@ unsigned DependenceAnalysis::exploreDirections(unsigned Level, // Returns true iff the current bounds are plausible. 
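A worked instance of the GCD (MIV) test above, with invented subscripts:

//   source:      A[2*i + 4*j]
//   destination: A[2*i + 4*j + 1]
//   The running GCD of the coefficients is gcd(2, 4) = 2, and the constant
//   difference between the subscripts is 1. Since 2 does not divide 1, no
//   integer (i, j) can make the two subscripts equal, and the dependence
//   is disproved. This is also why getConstantPart suffices: for a
//   coefficient like 10*X*Y the factor 10 is still a common divisor
//   whatever X and Y are, the "a common divisor" caveat noted earlier.

testBounds, defined next, is the Banerjee-side check that asks whether the bounds for the current direction setting can still contain the constant difference.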
-bool DependenceAnalysis::testBounds(unsigned char DirKind, - unsigned Level, - BoundInfo *Bound, - const SCEV *Delta) const { +bool DependenceInfo::testBounds(unsigned char DirKind, unsigned Level, + BoundInfo *Bound, const SCEV *Delta) const { Bound[Level].Direction = DirKind; if (const SCEV *LowerBound = getLowerBound(Bound)) if (isKnownPredicate(CmpInst::ICMP_SGT, LowerBound, Delta)) @@ -2709,10 +2656,8 @@ bool DependenceAnalysis::testBounds(unsigned char DirKind, // We must be careful to handle the case where the upper bound is unknown. // Note that the lower bound is always <= 0 // and the upper bound is always >= 0. -void DependenceAnalysis::findBoundsALL(CoefficientInfo *A, - CoefficientInfo *B, - BoundInfo *Bound, - unsigned K) const { +void DependenceInfo::findBoundsALL(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { Bound[K].Lower[Dependence::DVEntry::ALL] = nullptr; // Default value = -infinity. Bound[K].Upper[Dependence::DVEntry::ALL] = nullptr; // Default value = +infinity. if (Bound[K].Iterations) { @@ -2750,10 +2695,8 @@ void DependenceAnalysis::findBoundsALL(CoefficientInfo *A, // We must be careful to handle the case where the upper bound is unknown. // Note that the lower bound is always <= 0 // and the upper bound is always >= 0. -void DependenceAnalysis::findBoundsEQ(CoefficientInfo *A, - CoefficientInfo *B, - BoundInfo *Bound, - unsigned K) const { +void DependenceInfo::findBoundsEQ(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { Bound[K].Lower[Dependence::DVEntry::EQ] = nullptr; // Default value = -infinity. Bound[K].Upper[Dependence::DVEntry::EQ] = nullptr; // Default value = +infinity. if (Bound[K].Iterations) { @@ -2792,10 +2735,8 @@ void DependenceAnalysis::findBoundsEQ(CoefficientInfo *A, // UB^<_k = (A^+_k - B_k)^+ (U_k - 1) - B_k // // We must be careful to handle the case where the upper bound is unknown. -void DependenceAnalysis::findBoundsLT(CoefficientInfo *A, - CoefficientInfo *B, - BoundInfo *Bound, - unsigned K) const { +void DependenceInfo::findBoundsLT(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { Bound[K].Lower[Dependence::DVEntry::LT] = nullptr; // Default value = -infinity. Bound[K].Upper[Dependence::DVEntry::LT] = nullptr; // Default value = +infinity. if (Bound[K].Iterations) { @@ -2838,10 +2779,8 @@ void DependenceAnalysis::findBoundsLT(CoefficientInfo *A, // UB^>_k = (A_k - B^-_k)^+ (U_k - 1) + A_k // // We must be careful to handle the case where the upper bound is unknown. -void DependenceAnalysis::findBoundsGT(CoefficientInfo *A, - CoefficientInfo *B, - BoundInfo *Bound, - unsigned K) const { +void DependenceInfo::findBoundsGT(CoefficientInfo *A, CoefficientInfo *B, + BoundInfo *Bound, unsigned K) const { Bound[K].Lower[Dependence::DVEntry::GT] = nullptr; // Default value = -infinity. Bound[K].Upper[Dependence::DVEntry::GT] = nullptr; // Default value = +infinity. 
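// A worked instance of the '<' direction bounds above (invented numbers),
// using the positive/negative parts X^+ = max(X, 0) and X^- = min(X, 0):
// take A_k = 2, B_k = 3, and U_k - 1 = 9 iterations. Then
//   UB^<_k = (A^+_k - B_k)^+ * 9 - B_k = max(2 - 3, 0) * 9 - 3 = -3
// and the matching lower bound built from negative parts (assumed here by
// symmetry with the formulas shown) evaluates to
//   LB^<_k = (A^-_k - B_k)^- * 9 - B_k = (0 - 3) * 9 - 3 = -30,
// so a '<' dependence at this level needs the constant difference to lie
// in [-30, -3]; testBounds rejects the direction otherwise.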
if (Bound[K].Iterations) { @@ -2870,13 +2809,13 @@ void DependenceAnalysis::findBoundsGT(CoefficientInfo *A, // X^+ = max(X, 0) -const SCEV *DependenceAnalysis::getPositivePart(const SCEV *X) const { +const SCEV *DependenceInfo::getPositivePart(const SCEV *X) const { return SE->getSMaxExpr(X, SE->getZero(X->getType())); } // X^- = min(X, 0) -const SCEV *DependenceAnalysis::getNegativePart(const SCEV *X) const { +const SCEV *DependenceInfo::getNegativePart(const SCEV *X) const { return SE->getSMinExpr(X, SE->getZero(X->getType())); } @@ -2884,10 +2823,9 @@ const SCEV *DependenceAnalysis::getNegativePart(const SCEV *X) const { // Walks through the subscript, // collecting each coefficient, the associated loop bounds, // and recording its positive and negative parts for later use. -DependenceAnalysis::CoefficientInfo * -DependenceAnalysis::collectCoeffInfo(const SCEV *Subscript, - bool SrcFlag, - const SCEV *&Constant) const { +DependenceInfo::CoefficientInfo * +DependenceInfo::collectCoeffInfo(const SCEV *Subscript, bool SrcFlag, + const SCEV *&Constant) const { const SCEV *Zero = SE->getZero(Subscript->getType()); CoefficientInfo *CI = new CoefficientInfo[MaxLevels + 1]; for (unsigned K = 1; K <= MaxLevels; ++K) { @@ -2931,7 +2869,7 @@ DependenceAnalysis::collectCoeffInfo(const SCEV *Subscript, // computes the lower bound given the current direction settings // at each level. If the lower bound for any level is -inf, // the result is -inf. -const SCEV *DependenceAnalysis::getLowerBound(BoundInfo *Bound) const { +const SCEV *DependenceInfo::getLowerBound(BoundInfo *Bound) const { const SCEV *Sum = Bound[1].Lower[Bound[1].Direction]; for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { if (Bound[K].Lower[Bound[K].Direction]) @@ -2947,7 +2885,7 @@ const SCEV *DependenceAnalysis::getLowerBound(BoundInfo *Bound) const { // computes the upper bound given the current direction settings // at each level. If the upper bound at any level is +inf, // the result is +inf. -const SCEV *DependenceAnalysis::getUpperBound(BoundInfo *Bound) const { +const SCEV *DependenceInfo::getUpperBound(BoundInfo *Bound) const { const SCEV *Sum = Bound[1].Upper[Bound[1].Direction]; for (unsigned K = 2; Sum && K <= MaxLevels; ++K) { if (Bound[K].Upper[Bound[K].Direction]) @@ -2968,8 +2906,8 @@ const SCEV *DependenceAnalysis::getUpperBound(BoundInfo *Bound) const { // If there isn't one, return 0. // For example, given a*i + b*j + c*k, finding the coefficient // corresponding to the j loop would yield b. -const SCEV *DependenceAnalysis::findCoefficient(const SCEV *Expr, - const Loop *TargetLoop) const { +const SCEV *DependenceInfo::findCoefficient(const SCEV *Expr, + const Loop *TargetLoop) const { const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); if (!AddRec) return SE->getZero(Expr->getType()); @@ -2984,8 +2922,8 @@ const SCEV *DependenceAnalysis::findCoefficient(const SCEV *Expr, // corresponding to the specified loop. // For example, given a*i + b*j + c*k, zeroing the coefficient // corresponding to the j loop would yield a*i + c*k. -const SCEV *DependenceAnalysis::zeroCoefficient(const SCEV *Expr, - const Loop *TargetLoop) const { +const SCEV *DependenceInfo::zeroCoefficient(const SCEV *Expr, + const Loop *TargetLoop) const { const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); if (!AddRec) return Expr; // ignore @@ -3003,9 +2941,9 @@ const SCEV *DependenceAnalysis::zeroCoefficient(const SCEV *Expr, // coefficient corresponding to the specified TargetLoop. 
// For example, given a*i + b*j + c*k, adding 1 to the coefficient // corresponding to the j loop would yield a*i + (b+1)*j + c*k. -const SCEV *DependenceAnalysis::addToCoefficient(const SCEV *Expr, - const Loop *TargetLoop, - const SCEV *Value) const { +const SCEV *DependenceInfo::addToCoefficient(const SCEV *Expr, + const Loop *TargetLoop, + const SCEV *Value) const { const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(Expr); if (!AddRec) // create a new addRec return SE->getAddRecExpr(Expr, @@ -3040,11 +2978,10 @@ const SCEV *DependenceAnalysis::addToCoefficient(const SCEV *Expr, // Practical Dependence Testing // Goff, Kennedy, Tseng // PLDI 1991 -bool DependenceAnalysis::propagate(const SCEV *&Src, - const SCEV *&Dst, - SmallBitVector &Loops, - SmallVectorImpl<Constraint> &Constraints, - bool &Consistent) { +bool DependenceInfo::propagate(const SCEV *&Src, const SCEV *&Dst, + SmallBitVector &Loops, + SmallVectorImpl<Constraint> &Constraints, + bool &Consistent) { bool Result = false; for (int LI = Loops.find_first(); LI >= 0; LI = Loops.find_next(LI)) { DEBUG(dbgs() << "\t Constraint[" << LI << "] is"); @@ -3065,10 +3002,9 @@ bool DependenceAnalysis::propagate(const SCEV *&Src, // Return true if some simplification occurs. // If the simplification isn't exact (that is, if it is conservative // in terms of dependence), set consistent to false. -bool DependenceAnalysis::propagateDistance(const SCEV *&Src, - const SCEV *&Dst, - Constraint &CurConstraint, - bool &Consistent) { +bool DependenceInfo::propagateDistance(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint, + bool &Consistent) { const Loop *CurLoop = CurConstraint.getAssociatedLoop(); DEBUG(dbgs() << "\t\tSrc is " << *Src << "\n"); const SCEV *A_K = findCoefficient(Src, CurLoop); @@ -3092,10 +3028,9 @@ bool DependenceAnalysis::propagateDistance(const SCEV *&Src, // Return true if some simplification occurs. // If the simplification isn't exact (that is, if it is conservative // in terms of dependence), set consistent to false. -bool DependenceAnalysis::propagateLine(const SCEV *&Src, - const SCEV *&Dst, - Constraint &CurConstraint, - bool &Consistent) { +bool DependenceInfo::propagateLine(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint, + bool &Consistent) { const Loop *CurLoop = CurConstraint.getAssociatedLoop(); const SCEV *A = CurConstraint.getA(); const SCEV *B = CurConstraint.getB(); @@ -3167,9 +3102,8 @@ bool DependenceAnalysis::propagateLine(const SCEV *&Src, // Attempt to propagate a point // constraint into a subscript pair (Src and Dst). // Return true if some simplification occurs. -bool DependenceAnalysis::propagatePoint(const SCEV *&Src, - const SCEV *&Dst, - Constraint &CurConstraint) { +bool DependenceInfo::propagatePoint(const SCEV *&Src, const SCEV *&Dst, + Constraint &CurConstraint) { const Loop *CurLoop = CurConstraint.getAssociatedLoop(); const SCEV *A_K = findCoefficient(Src, CurLoop); const SCEV *AP_K = findCoefficient(Dst, CurLoop); @@ -3187,9 +3121,8 @@ bool DependenceAnalysis::propagatePoint(const SCEV *&Src, // Update direction vector entry based on the current constraint. 
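An invented illustration of the propagation scheme above, in the spirit of the Goff-Kennedy-Tseng reference:

//   Coupled pair 1: (i, i') proves the distance constraint i' - i = 0.
//   Coupled pair 2: (i + j, i' + j') is MIV on its own; substituting
//   i' = i reduces it to (j, j'), which the cheap SIV machinery can
//   finish. When a substitution is only conservatively correct, the
//   Consistent flag is cleared, exactly as propagateDistance and
//   propagateLine do above.

updateDirection, defined next, folds whatever constraint survives back into the direction-vector entry for that level.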
-void DependenceAnalysis::updateDirection(Dependence::DVEntry &Level, - const Constraint &CurConstraint - ) const { +void DependenceInfo::updateDirection(Dependence::DVEntry &Level, + const Constraint &CurConstraint) const { DEBUG(dbgs() << "\tUpdate direction, constraint ="); DEBUG(CurConstraint.dump(dbgs())); if (CurConstraint.isAny()) @@ -3241,10 +3174,8 @@ void DependenceAnalysis::updateDirection(Dependence::DVEntry &Level, /// source and destination array references are recurrences on a nested loop, /// this function flattens the nested recurrences into separate recurrences /// for each loop level. -bool DependenceAnalysis::tryDelinearize(Instruction *Src, - Instruction *Dst, - SmallVectorImpl<Subscript> &Pair) -{ +bool DependenceInfo::tryDelinearize(Instruction *Src, Instruction *Dst, + SmallVectorImpl<Subscript> &Pair) { Value *SrcPtr = getPointerOperand(Src); Value *DstPtr = getPointerOperand(Dst); @@ -3355,8 +3286,8 @@ static void dumpSmallBitVector(SmallBitVector &BV) { // Care is required to keep the routine below, getSplitIteration(), // up to date with respect to this routine. std::unique_ptr<Dependence> -DependenceAnalysis::depends(Instruction *Src, Instruction *Dst, - bool PossiblyLoopIndependent) { +DependenceInfo::depends(Instruction *Src, Instruction *Dst, + bool PossiblyLoopIndependent) { if (Src == Dst) PossiblyLoopIndependent = false; @@ -3811,8 +3742,8 @@ DependenceAnalysis::depends(Instruction *Src, Instruction *Dst, // // breaks the dependence and allows us to vectorize/parallelize // both loops. -const SCEV *DependenceAnalysis::getSplitIteration(const Dependence &Dep, - unsigned SplitLevel) { +const SCEV *DependenceInfo::getSplitIteration(const Dependence &Dep, + unsigned SplitLevel) { assert(Dep.isSplitable(SplitLevel) && "Dep should be splitable at SplitLevel"); Instruction *Src = Dep.getSrc(); diff --git a/lib/Analysis/DivergenceAnalysis.cpp b/lib/Analysis/DivergenceAnalysis.cpp index 5ae6d74130a7..1b36569f7a07 100644 --- a/lib/Analysis/DivergenceAnalysis.cpp +++ b/lib/Analysis/DivergenceAnalysis.cpp @@ -73,10 +73,8 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Value.h" -#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include <vector> using namespace llvm; @@ -140,14 +138,25 @@ void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) { // a2 = 2; // a = phi(a1, a2); // sync dependent on (tid < 5) BasicBlock *ThisBB = TI->getParent(); - BasicBlock *IPostDom = PDT.getNode(ThisBB)->getIDom()->getBlock(); + + // Unreachable blocks may not be in the dominator tree. + if (!DT.isReachableFromEntry(ThisBB)) + return; + + // If the function has no exit blocks or doesn't reach any exit blocks, the + // post dominator may be null. + DomTreeNode *ThisNode = PDT.getNode(ThisBB); + if (!ThisNode) + return; + + BasicBlock *IPostDom = ThisNode->getIDom()->getBlock(); if (IPostDom == nullptr) return; for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) { // A PHINode is uniform if it returns the same value no matter which path is // taken. 
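// Why the switch below from hasConstantValue() to hasConstantOrUndefValue()
// is safe (illustrative): a PHI such as
//   %p = phi i32 [ %v, %a ], [ undef, %b ]
// may be assumed to produce %v on every path, since the undef arm can be
// chosen to equal %v. Such a PHI cannot introduce divergence, so it need
// not be pushed onto the worklist, even though the stricter
// hasConstantValue() did not accept it.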
- if (!cast<PHINode>(I)->hasConstantValue() && DV.insert(&*I).second) + if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second) Worklist.push_back(&*I); } @@ -259,7 +268,7 @@ char DivergenceAnalysis::ID = 0; INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTree) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis", false, true) @@ -269,7 +278,7 @@ FunctionPass *llvm::createDivergenceAnalysisPass() { void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTree>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); AU.setPreservesAll(); } @@ -285,9 +294,10 @@ bool DivergenceAnalysis::runOnFunction(Function &F) { return false; DivergentValues.clear(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); DivergencePropagator DP(F, TTI, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - getAnalysis<PostDominatorTree>(), DivergentValues); + PDT, DivergentValues); DP.populateWithSourcesOfDivergence(); DP.propagate(); return false; diff --git a/lib/Analysis/DomPrinter.cpp b/lib/Analysis/DomPrinter.cpp index 0c880df54f8e..7acfb41500d4 100644 --- a/lib/Analysis/DomPrinter.cpp +++ b/lib/Analysis/DomPrinter.cpp @@ -111,20 +111,36 @@ struct DomOnlyViewer : public DOTGraphTraitsViewer< } }; -struct PostDomViewer - : public DOTGraphTraitsViewer<PostDominatorTree, false> { +struct PostDominatorTreeWrapperPassAnalysisGraphTraits { + static PostDominatorTree *getGraph(PostDominatorTreeWrapperPass *PDTWP) { + return &PDTWP->getPostDomTree(); + } +}; + +struct PostDomViewer : public DOTGraphTraitsViewer< + PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; PostDomViewer() : - DOTGraphTraitsViewer<PostDominatorTree, false>("postdom", ID){ + DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdom", ID){ initializePostDomViewerPass(*PassRegistry::getPassRegistry()); } }; -struct PostDomOnlyViewer - : public DOTGraphTraitsViewer<PostDominatorTree, true> { +struct PostDomOnlyViewer : public DOTGraphTraitsViewer< + PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; PostDomOnlyViewer() : - DOTGraphTraitsViewer<PostDominatorTree, true>("postdomonly", ID){ + DOTGraphTraitsViewer<PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID){ initializePostDomOnlyViewerPass(*PassRegistry::getPassRegistry()); } }; @@ -175,19 +191,31 @@ struct DomOnlyPrinter : public DOTGraphTraitsPrinter< }; struct PostDomPrinter - : public DOTGraphTraitsPrinter<PostDominatorTree, false> { + : public DOTGraphTraitsPrinter< + PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; PostDomPrinter() : - DOTGraphTraitsPrinter<PostDominatorTree, false>("postdom", ID) { + DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, false, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdom", ID) { initializePostDomPrinterPass(*PassRegistry::getPassRegistry()); } }; struct PostDomOnlyPrinter - : 
public DOTGraphTraitsPrinter<PostDominatorTree, true> { + : public DOTGraphTraitsPrinter< + PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits> { static char ID; PostDomOnlyPrinter() : - DOTGraphTraitsPrinter<PostDominatorTree, true>("postdomonly", ID) { + DOTGraphTraitsPrinter<PostDominatorTreeWrapperPass, true, + PostDominatorTree *, + PostDominatorTreeWrapperPassAnalysisGraphTraits>( + "postdomonly", ID) { initializePostDomOnlyPrinterPass(*PassRegistry::getPassRegistry()); } }; diff --git a/lib/Analysis/DominanceFrontier.cpp b/lib/Analysis/DominanceFrontier.cpp index 7ba91bc90dfc..4554374252a4 100644 --- a/lib/Analysis/DominanceFrontier.cpp +++ b/lib/Analysis/DominanceFrontier.cpp @@ -9,6 +9,7 @@ #include "llvm/Analysis/DominanceFrontier.h" #include "llvm/Analysis/DominanceFrontierImpl.h" +#include "llvm/IR/PassManager.h" using namespace llvm; @@ -17,41 +18,60 @@ template class DominanceFrontierBase<BasicBlock>; template class ForwardDominanceFrontierBase<BasicBlock>; } -char DominanceFrontier::ID = 0; +char DominanceFrontierWrapperPass::ID = 0; -INITIALIZE_PASS_BEGIN(DominanceFrontier, "domfrontier", +INITIALIZE_PASS_BEGIN(DominanceFrontierWrapperPass, "domfrontier", "Dominance Frontier Construction", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_END(DominanceFrontier, "domfrontier", +INITIALIZE_PASS_END(DominanceFrontierWrapperPass, "domfrontier", "Dominance Frontier Construction", true, true) -DominanceFrontier::DominanceFrontier() - : FunctionPass(ID), - Base() { - initializeDominanceFrontierPass(*PassRegistry::getPassRegistry()); + DominanceFrontierWrapperPass::DominanceFrontierWrapperPass() + : FunctionPass(ID), DF() { + initializeDominanceFrontierWrapperPassPass(*PassRegistry::getPassRegistry()); } -void DominanceFrontier::releaseMemory() { - Base.releaseMemory(); +void DominanceFrontierWrapperPass::releaseMemory() { + DF.releaseMemory(); } -bool DominanceFrontier::runOnFunction(Function &) { +bool DominanceFrontierWrapperPass::runOnFunction(Function &) { releaseMemory(); - Base.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); + DF.analyze(getAnalysis<DominatorTreeWrapperPass>().getDomTree()); return false; } -void DominanceFrontier::getAnalysisUsage(AnalysisUsage &AU) const { +void DominanceFrontierWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<DominatorTreeWrapperPass>(); } -void DominanceFrontier::print(raw_ostream &OS, const Module *) const { - Base.print(OS); +void DominanceFrontierWrapperPass::print(raw_ostream &OS, const Module *) const { + DF.print(OS); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void DominanceFrontier::dump() const { +LLVM_DUMP_METHOD void DominanceFrontierWrapperPass::dump() const { print(dbgs()); } #endif + +char DominanceFrontierAnalysis::PassID; + +DominanceFrontier DominanceFrontierAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + DominanceFrontier DF; + DF.analyze(AM.getResult<DominatorTreeAnalysis>(F)); + return DF; +} + +DominanceFrontierPrinterPass::DominanceFrontierPrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses +DominanceFrontierPrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + OS << "DominanceFrontier for function: " << F.getName() << "\n"; + AM.getResult<DominanceFrontierAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} diff --git a/lib/Analysis/EHPersonalities.cpp b/lib/Analysis/EHPersonalities.cpp index 
01be8b38fadd..5f951f5112e9 100644 --- a/lib/Analysis/EHPersonalities.cpp +++ b/lib/Analysis/EHPersonalities.cpp @@ -27,13 +27,16 @@ EHPersonality llvm::classifyEHPersonality(const Value *Pers) { return StringSwitch<EHPersonality>(F->getName()) .Case("__gnat_eh_personality", EHPersonality::GNU_Ada) .Case("__gxx_personality_v0", EHPersonality::GNU_CXX) + .Case("__gxx_personality_sj0", EHPersonality::GNU_CXX_SjLj) .Case("__gcc_personality_v0", EHPersonality::GNU_C) + .Case("__gcc_personality_sj0", EHPersonality::GNU_C_SjLj) .Case("__objc_personality_v0", EHPersonality::GNU_ObjC) .Case("_except_handler3", EHPersonality::MSVC_X86SEH) .Case("_except_handler4", EHPersonality::MSVC_X86SEH) .Case("__C_specific_handler", EHPersonality::MSVC_Win64SEH) .Case("__CxxFrameHandler3", EHPersonality::MSVC_CXX) .Case("ProcessCLRException", EHPersonality::CoreCLR) + .Case("rust_eh_personality", EHPersonality::Rust) .Default(EHPersonality::Unknown); } @@ -92,7 +95,7 @@ DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) { BasicBlock *SuccColor = Color; TerminatorInst *Terminator = Visiting->getTerminator(); if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) { - Value *ParentPad = CatchRet->getParentPad(); + Value *ParentPad = CatchRet->getCatchSwitchParentPad(); if (isa<ConstantTokenNone>(ParentPad)) SuccColor = EntryBlock; else diff --git a/lib/Analysis/GlobalsModRef.cpp b/lib/Analysis/GlobalsModRef.cpp index 1babb822074b..a7d1e048e133 100644 --- a/lib/Analysis/GlobalsModRef.cpp +++ b/lib/Analysis/GlobalsModRef.cpp @@ -243,13 +243,14 @@ FunctionModRefBehavior GlobalsAAResult::getModRefBehavior(ImmutableCallSite CS) { FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; - if (const Function *F = CS.getCalledFunction()) - if (FunctionInfo *FI = getFunctionInfo(F)) { - if (FI->getModRefInfo() == MRI_NoModRef) - Min = FMRB_DoesNotAccessMemory; - else if ((FI->getModRefInfo() & MRI_Mod) == 0) - Min = FMRB_OnlyReadsMemory; - } + if (!CS.hasOperandBundles()) + if (const Function *F = CS.getCalledFunction()) + if (FunctionInfo *FI = getFunctionInfo(F)) { + if (FI->getModRefInfo() == MRI_NoModRef) + Min = FMRB_DoesNotAccessMemory; + else if ((FI->getModRefInfo() & MRI_Mod) == 0) + Min = FMRB_OnlyReadsMemory; + } return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); } @@ -269,7 +270,7 @@ GlobalsAAResult::getFunctionInfo(const Function *F) { /// (really, their address passed to something nontrivial), record this fact, /// and record the functions that they are used directly in. void GlobalsAAResult::AnalyzeGlobals(Module &M) { - SmallPtrSet<Function *, 64> TrackedFunctions; + SmallPtrSet<Function *, 32> TrackedFunctions; for (Function &F : M) if (F.hasLocalLinkage()) if (!AnalyzeUsesOfPointer(&F)) { @@ -281,7 +282,7 @@ void GlobalsAAResult::AnalyzeGlobals(Module &M) { ++NumNonAddrTakenFunctions; } - SmallPtrSet<Function *, 64> Readers, Writers; + SmallPtrSet<Function *, 16> Readers, Writers; for (GlobalVariable &GV : M.globals()) if (GV.hasLocalLinkage()) { if (!AnalyzeUsesOfPointer(&GV, &Readers, @@ -310,7 +311,7 @@ void GlobalsAAResult::AnalyzeGlobals(Module &M) { ++NumNonAddrTakenGlobalVars; // If this global holds a pointer type, see if it is an indirect global. 
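// What AnalyzeIndirectGlobalMemory is hunting for below (illustrative C,
// invented names):
//   static int *P;                       /* address never taken */
//   void init(void) { P = malloc(16); }  /* only stores fresh allocations */
// Every load through P then reaches memory nothing else can name, so the
// pointed-to storage can be treated as one more private global for
// mod/ref purposes.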
- if (GV.getType()->getElementType()->isPointerTy() && + if (GV.getValueType()->isPointerTy() && AnalyzeIndirectGlobalMemory(&GV)) ++NumIndirectGlobalVars; } @@ -470,9 +471,10 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { const std::vector<CallGraphNode *> &SCC = *I; assert(!SCC.empty() && "SCC with no functions?"); - if (!SCC[0]->getFunction() || SCC[0]->getFunction()->mayBeOverridden()) { - // Calls externally or is weak - can't say anything useful. Remove any existing - // function records (may have been created when scanning globals). + if (!SCC[0]->getFunction() || !SCC[0]->getFunction()->isDefinitionExact()) { + // Calls externally or not exact - can't say anything useful. Remove any + // existing function records (may have been created when scanning + // globals). for (auto *Node : SCC) FunctionInfos.erase(Node->getFunction()); continue; @@ -496,7 +498,7 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { // Can't do better than that! } else if (F->onlyReadsMemory()) { FI.addModRefInfo(MRI_Ref); - if (!F->isIntrinsic()) + if (!F->isIntrinsic() && !F->onlyAccessesArgMemory()) // This function might call back into the module and read a global - // consider every global as possibly being read by this function. FI.setMayReadAnyGlobal(); @@ -698,7 +700,7 @@ bool GlobalsAAResult::isNonEscapingGlobalNoAlias(const GlobalValue *GV, auto *InputGVar = dyn_cast<GlobalVariable>(InputGV); if (GVar && InputGVar && !GVar->isDeclaration() && !InputGVar->isDeclaration() && - !GVar->mayBeOverridden() && !InputGVar->mayBeOverridden()) { + !GVar->isInterposable() && !InputGVar->isInterposable()) { Type *GVType = GVar->getInitializer()->getType(); Type *InputGVType = InputGVar->getInitializer()->getType(); if (GVType->isSized() && InputGVType->isSized() && @@ -863,7 +865,11 @@ ModRefInfo GlobalsAAResult::getModRefInfoForArgument(ImmutableCallSite CS, GetUnderlyingObjects(A, Objects, DL); // All objects must be identified. - if (!std::all_of(Objects.begin(), Objects.end(), isIdentifiedObject)) + if (!std::all_of(Objects.begin(), Objects.end(), isIdentifiedObject) && + // Try ::alias to see if all objects are known not to alias GV. 
+ !std::all_of(Objects.begin(), Objects.end(), [&](Value *V) { + return this->alias(MemoryLocation(V), MemoryLocation(GV)) == NoAlias; + })) return ConservativeResult; if (std::find(Objects.begin(), Objects.end(), GV) != Objects.end()) @@ -896,10 +902,10 @@ ModRefInfo GlobalsAAResult::getModRefInfo(ImmutableCallSite CS, GlobalsAAResult::GlobalsAAResult(const DataLayout &DL, const TargetLibraryInfo &TLI) - : AAResultBase(TLI), DL(DL) {} + : AAResultBase(), DL(DL), TLI(TLI) {} GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg) - : AAResultBase(std::move(Arg)), DL(Arg.DL), + : AAResultBase(std::move(Arg)), DL(Arg.DL), TLI(Arg.TLI), NonAddressTakenGlobals(std::move(Arg.NonAddressTakenGlobals)), IndirectGlobals(std::move(Arg.IndirectGlobals)), AllocsForIndirectGlobals(std::move(Arg.AllocsForIndirectGlobals)), @@ -912,6 +918,8 @@ GlobalsAAResult::GlobalsAAResult(GlobalsAAResult &&Arg) } } +GlobalsAAResult::~GlobalsAAResult() {} + /*static*/ GlobalsAAResult GlobalsAAResult::analyzeModule(Module &M, const TargetLibraryInfo &TLI, CallGraph &CG) { @@ -929,14 +937,14 @@ GlobalsAAResult::analyzeModule(Module &M, const TargetLibraryInfo &TLI, return Result; } -GlobalsAAResult GlobalsAA::run(Module &M, AnalysisManager<Module> *AM) { +char GlobalsAA::PassID; + +GlobalsAAResult GlobalsAA::run(Module &M, AnalysisManager<Module> &AM) { return GlobalsAAResult::analyzeModule(M, - AM->getResult<TargetLibraryAnalysis>(M), - AM->getResult<CallGraphAnalysis>(M)); + AM.getResult<TargetLibraryAnalysis>(M), + AM.getResult<CallGraphAnalysis>(M)); } -char GlobalsAA::PassID; - char GlobalsAAWrapperPass::ID = 0; INITIALIZE_PASS_BEGIN(GlobalsAAWrapperPass, "globals-aa", "Globals Alias Analysis", false, true) diff --git a/lib/Analysis/IVUsers.cpp b/lib/Analysis/IVUsers.cpp index e0c5d8fa5f5a..43c0ba17fe4a 100644 --- a/lib/Analysis/IVUsers.cpp +++ b/lib/Analysis/IVUsers.cpp @@ -12,11 +12,12 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Analysis/IVUsers.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/IVUsers.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Constants.h" @@ -33,19 +34,35 @@ using namespace llvm; #define DEBUG_TYPE "iv-users" -char IVUsers::ID = 0; -INITIALIZE_PASS_BEGIN(IVUsers, "iv-users", +char IVUsersAnalysis::PassID; + +IVUsers IVUsersAnalysis::run(Loop &L, AnalysisManager<Loop> &AM) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function *F = L.getHeader()->getParent(); + + return IVUsers(&L, FAM.getCachedResult<AssumptionAnalysis>(*F), + FAM.getCachedResult<LoopAnalysis>(*F), + FAM.getCachedResult<DominatorTreeAnalysis>(*F), + FAM.getCachedResult<ScalarEvolutionAnalysis>(*F)); +} + +PreservedAnalyses IVUsersPrinterPass::run(Loop &L, AnalysisManager<Loop> &AM) { + AM.getResult<IVUsersAnalysis>(L).print(OS); + return PreservedAnalyses::all(); +} + +char IVUsersWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(IVUsersWrapperPass, "iv-users", "Induction Variable Users", false, true) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_END(IVUsers, "iv-users", - "Induction Variable Users", false, true) 
+INITIALIZE_PASS_END(IVUsersWrapperPass, "iv-users", "Induction Variable Users", + false, true) -Pass *llvm::createIVUsersPass() { - return new IVUsers(); -} +Pass *llvm::createIVUsersPass() { return new IVUsersWrapperPass(); } /// isInteresting - Test whether the given expression is "interesting" when /// used by the given expression, within the context of analyzing the @@ -246,28 +263,9 @@ IVStrideUse &IVUsers::AddUser(Instruction *User, Value *Operand) { return IVUses.back(); } -IVUsers::IVUsers() - : LoopPass(ID) { - initializeIVUsersPass(*PassRegistry::getPassRegistry()); -} - -void IVUsers::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequired<LoopInfoWrapperPass>(); - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.setPreservesAll(); -} - -bool IVUsers::runOnLoop(Loop *l, LPPassManager &LPM) { - - L = l; - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( - *L->getHeader()->getParent()); - LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); - DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); - +IVUsers::IVUsers(Loop *L, AssumptionCache *AC, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE) + : L(L), AC(AC), LI(LI), DT(DT), SE(SE), IVUses() { // Collect ephemeral values so that AddUsersIfInteresting skips them. EphValues.clear(); CodeMetrics::collectEphemeralValues(L, AC, EphValues); @@ -277,34 +275,28 @@ bool IVUsers::runOnLoop(Loop *l, LPPassManager &LPM) { // this loop. If they are induction variables, inspect their uses. for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) (void)AddUsersIfInteresting(&*I); - - return false; } void IVUsers::print(raw_ostream &OS, const Module *M) const { OS << "IV Users for loop "; L->getHeader()->printAsOperand(OS, false); if (SE->hasLoopInvariantBackedgeTakenCount(L)) { - OS << " with backedge-taken count " - << *SE->getBackedgeTakenCount(L); + OS << " with backedge-taken count " << *SE->getBackedgeTakenCount(L); } OS << ":\n"; - for (ilist<IVStrideUse>::const_iterator UI = IVUses.begin(), - E = IVUses.end(); UI != E; ++UI) { + for (const IVStrideUse &IVUse : IVUses) { OS << " "; - UI->getOperandValToReplace()->printAsOperand(OS, false); - OS << " = " << *getReplacementExpr(*UI); - for (PostIncLoopSet::const_iterator - I = UI->PostIncLoops.begin(), - E = UI->PostIncLoops.end(); I != E; ++I) { + IVUse.getOperandValToReplace()->printAsOperand(OS, false); + OS << " = " << *getReplacementExpr(IVUse); + for (auto PostIncLoop : IVUse.PostIncLoops) { OS << " (post-inc with loop "; - (*I)->getHeader()->printAsOperand(OS, false); + PostIncLoop->getHeader()->printAsOperand(OS, false); OS << ")"; } OS << " in "; - if (UI->getUser()) - UI->getUser()->print(OS); + if (IVUse.getUser()) + IVUse.getUser()->print(OS); else OS << "Printing <null> User"; OS << '\n'; @@ -312,9 +304,7 @@ void IVUsers::print(raw_ostream &OS, const Module *M) const { } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void IVUsers::dump() const { - print(dbgs()); -} +LLVM_DUMP_METHOD void IVUsers::dump() const { print(dbgs()); } #endif void IVUsers::releaseMemory() { @@ -322,6 +312,35 @@ void IVUsers::releaseMemory() { IVUses.clear(); } +IVUsersWrapperPass::IVUsersWrapperPass() : LoopPass(ID) { + initializeIVUsersWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void IVUsersWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + 
AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.setPreservesAll(); +} + +bool IVUsersWrapperPass::runOnLoop(Loop *L, LPPassManager &LPM) { + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache( + *L->getHeader()->getParent()); + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + + IU.reset(new IVUsers(L, AC, LI, DT, SE)); + return false; +} + +void IVUsersWrapperPass::print(raw_ostream &OS, const Module *M) const { + IU->print(OS, M); +} + +void IVUsersWrapperPass::releaseMemory() { IU->releaseMemory(); } + /// getReplacementExpr - Return a SCEV expression which computes the /// value of the OperandValToReplace. const SCEV *IVUsers::getReplacementExpr(const IVStrideUse &IU) const { diff --git a/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/lib/Analysis/IndirectCallPromotionAnalysis.cpp new file mode 100644 index 000000000000..3da33ac71421 --- /dev/null +++ b/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -0,0 +1,109 @@ +//===-- IndirectCallPromotionAnalysis.cpp - Find promotion candidates ===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Helper methods for identifying profitable indirect call promotion +// candidates for an instruction when the indirect-call value profile metadata +// is available. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Debug.h" +#include <string> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "pgo-icall-prom-analysis" + +// The minimum call count for the direct-call target to be considered as the +// promotion candidate. +static cl::opt<unsigned> + ICPCountThreshold("icp-count-threshold", cl::Hidden, cl::ZeroOrMore, + cl::init(1000), + cl::desc("The minimum count to the direct call target " + "for the promotion")); + +// The percent threshold for the direct-call target (this call site vs the +// total call count) for it to be considered as the promotion target. +static cl::opt<unsigned> + ICPPercentThreshold("icp-percent-threshold", cl::init(33), cl::Hidden, + cl::ZeroOrMore, + cl::desc("The percentage threshold for the promotion")); + +// Set the maximum number of targets to promote for a single indirect-call +// callsite. 
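// A worked pass over the two thresholds above, plus the cap declared just
// below (invented profile): an indirect callsite with TotalCount = 10000
// and per-target counts [6000, 3000, 500].
//   6000 >= 1000 and 6000/10000 = 60% >= 33%  -> promote
//   3000 >= 1000 and 3000/4000  = 75% >= 33%  -> promote
//     (TotalCount shrinks by each accepted count, as in the candidate loop
//     further down in this file)
//   the third target is cut off by the default cap of 2 promotions, and
//   500 < 1000 would fail the count threshold anyway.
// So getProfitablePromotionCandidates would report two candidates.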
+static cl::opt<unsigned> + MaxNumPromotions("icp-max-prom", cl::init(2), cl::Hidden, cl::ZeroOrMore, + cl::desc("Max number of promotions for a single indirect " + "call callsite")); + +ICallPromotionAnalysis::ICallPromotionAnalysis() { + ValueDataArray = llvm::make_unique<InstrProfValueData[]>(MaxNumPromotions); +} + +bool ICallPromotionAnalysis::isPromotionProfitable(uint64_t Count, + uint64_t TotalCount) { + if (Count < ICPCountThreshold) + return false; + + unsigned Percentage = (Count * 100) / TotalCount; + return (Percentage >= ICPPercentThreshold); +} + +// Indirect-call promotion heuristic. The direct targets are sorted based on +// the count. Stop at the first target that is not promoted. Returns the +// number of candidates deemed profitable. +uint32_t ICallPromotionAnalysis::getProfitablePromotionCandidates( + const Instruction *Inst, uint32_t NumVals, uint64_t TotalCount) { + ArrayRef<InstrProfValueData> ValueDataRef(ValueDataArray.get(), NumVals); + + DEBUG(dbgs() << " \nWork on callsite " << *Inst << " Num_targets: " << NumVals + << "\n"); + + uint32_t I = 0; + for (; I < MaxNumPromotions && I < NumVals; I++) { + uint64_t Count = ValueDataRef[I].Count; + assert(Count <= TotalCount); + DEBUG(dbgs() << " Candidate " << I << " Count=" << Count + << " Target_func: " << ValueDataRef[I].Value << "\n"); + + if (!isPromotionProfitable(Count, TotalCount)) { + DEBUG(dbgs() << " Not promote: Cold target.\n"); + return I; + } + TotalCount -= Count; + } + return I; +} + +ArrayRef<InstrProfValueData> +ICallPromotionAnalysis::getPromotionCandidatesForInstruction( + const Instruction *I, uint32_t &NumVals, uint64_t &TotalCount, + uint32_t &NumCandidates) { + bool Res = + getValueProfDataFromInst(*I, IPVK_IndirectCallTarget, MaxNumPromotions, + ValueDataArray.get(), NumVals, TotalCount); + if (!Res) { + NumCandidates = 0; + return ArrayRef<InstrProfValueData>(); + } + NumCandidates = getProfitablePromotionCandidates(I, NumVals, TotalCount); + return ArrayRef<InstrProfValueData>(ValueDataArray.get(), NumVals); +} diff --git a/lib/Analysis/InlineCost.cpp b/lib/Analysis/InlineCost.cpp index a86a703ed9d6..dcb724abc02d 100644 --- a/lib/Analysis/InlineCost.cpp +++ b/lib/Analysis/InlineCost.cpp @@ -21,6 +21,7 @@ #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" @@ -39,6 +40,32 @@ using namespace llvm; STATISTIC(NumCallsAnalyzed, "Number of call sites analyzed"); +// Threshold to use when optsize is specified (and there is no +// -inline-threshold). +const int OptSizeThreshold = 75; + +// Threshold to use when -Oz is specified (and there is no -inline-threshold). +const int OptMinSizeThreshold = 25; + +// Threshold to use when -O[34] is specified (and there is no +// -inline-threshold). +const int OptAggressiveThreshold = 275; + +static cl::opt<int> DefaultInlineThreshold( + "inline-threshold", cl::Hidden, cl::init(225), cl::ZeroOrMore, + cl::desc("Control the amount of inlining to perform (default = 225)")); + +static cl::opt<int> HintThreshold( + "inlinehint-threshold", cl::Hidden, cl::init(325), + cl::desc("Threshold for inlining functions with inline hint")); + +// We introduce this threshold to help performance of instrumentation based +// PGO before we actually hook up inliner with analysis passes such as BPI and +// BFI. 
+static cl::opt<int> ColdThreshold( + "inlinecold-threshold", cl::Hidden, cl::init(225), + cl::desc("Threshold for inlining functions with cold attribute")); + namespace { class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { @@ -51,6 +78,9 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { /// The cache of @llvm.assume intrinsics. AssumptionCacheTracker *ACT; + /// Profile summary information. + ProfileSummaryInfo *PSI; + // The called function. Function &F; @@ -96,7 +126,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { DenseMap<Value *, int> SROAArgCosts; // Keep track of values which map to a pointer base and constant offset. - DenseMap<Value *, std::pair<Value *, APInt> > ConstantOffsetPtrs; + DenseMap<Value *, std::pair<Value *, APInt>> ConstantOffsetPtrs; // Custom simplification helper routines. bool isAllocaDerivedArg(Value *V); @@ -117,19 +147,31 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { /// attributes since these can be more precise than the ones on the callee /// itself. bool paramHasAttr(Argument *A, Attribute::AttrKind Attr); - + /// Return true if the given value is known non null within the callee if /// inlined through this particular callsite. bool isKnownNonNullInCallee(Value *V); + /// Update Threshold based on callsite properties such as callee + /// attributes and callee hotness for PGO builds. The Callee is explicitly + /// passed to support analyzing indirect calls whose target is inferred by + /// analysis. + void updateThreshold(CallSite CS, Function &Callee); + + /// Return true if size growth is allowed when inlining the callee at CS. + bool allowSizeGrowth(CallSite CS); + // Custom analysis routines. bool analyzeBlock(BasicBlock *BB, SmallPtrSetImpl<const Value *> &EphValues); // Disable several entry points to the visitor so we don't accidentally use // them by declaring but not defining them here. - void visit(Module *); void visit(Module &); - void visit(Function *); void visit(Function &); - void visit(BasicBlock *); void visit(BasicBlock &); + void visit(Module *); + void visit(Module &); + void visit(Function *); + void visit(Function &); + void visit(BasicBlock *); + void visit(BasicBlock &); // Provide base case for our instruction visit. 
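// A note on the declare-but-don't-define idiom above: the Module-,
// Function-, and BasicBlock-level visit() overloads are declared so that
// they shadow InstVisitor's versions, but are given no bodies, so any
// accidental call fails at link time instead of silently traversing a
// whole module. Only the per-instruction entry points, such as the base
// case declared next, remain usable.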
bool visitInstruction(Instruction &I); @@ -162,17 +204,19 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { public: CallAnalyzer(const TargetTransformInfo &TTI, AssumptionCacheTracker *ACT, - Function &Callee, int Threshold, CallSite CSArg) - : TTI(TTI), ACT(ACT), F(Callee), CandidateCS(CSArg), Threshold(Threshold), - Cost(0), IsCallerRecursive(false), IsRecursiveCall(false), - ExposesReturnsTwice(false), HasDynamicAlloca(false), - ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false), - HasFrameEscape(false), AllocatedSize(0), NumInstructions(0), - NumVectorInstructions(0), FiftyPercentVectorBonus(0), - TenPercentVectorBonus(0), VectorBonus(0), NumConstantArgs(0), - NumConstantOffsetPtrArgs(0), NumAllocaArgs(0), NumConstantPtrCmps(0), - NumConstantPtrDiffs(0), NumInstructionsSimplified(0), - SROACostSavings(0), SROACostSavingsLost(0) {} + ProfileSummaryInfo *PSI, Function &Callee, int Threshold, + CallSite CSArg) + : TTI(TTI), ACT(ACT), PSI(PSI), F(Callee), CandidateCS(CSArg), + Threshold(Threshold), Cost(0), IsCallerRecursive(false), + IsRecursiveCall(false), ExposesReturnsTwice(false), + HasDynamicAlloca(false), ContainsNoDuplicateCall(false), + HasReturn(false), HasIndirectBr(false), HasFrameEscape(false), + AllocatedSize(0), NumInstructions(0), NumVectorInstructions(0), + FiftyPercentVectorBonus(0), TenPercentVectorBonus(0), VectorBonus(0), + NumConstantArgs(0), NumConstantOffsetPtrArgs(0), NumAllocaArgs(0), + NumConstantPtrCmps(0), NumConstantPtrDiffs(0), + NumInstructionsSimplified(0), SROACostSavings(0), + SROACostSavingsLost(0) {} bool analyzeCall(CallSite CS); @@ -272,7 +316,8 @@ bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { OpC = dyn_cast<ConstantInt>(SimpleOp); if (!OpC) return false; - if (OpC->isZero()) continue; + if (OpC->isZero()) + continue; // Handle a struct index, which adds its field offset to the pointer. if (StructType *STy = dyn_cast<StructType>(*GTI)) { @@ -290,13 +335,14 @@ bool CallAnalyzer::accumulateGEPOffset(GEPOperator &GEP, APInt &Offset) { bool CallAnalyzer::visitAlloca(AllocaInst &I) { // Check whether inlining will turn a dynamic alloca into a static - // alloca, and handle that case. + // alloca and handle that case. if (I.isArrayAllocation()) { - if (Constant *Size = SimplifiedValues.lookup(I.getArraySize())) { - ConstantInt *AllocSize = dyn_cast<ConstantInt>(Size); - assert(AllocSize && "Allocation size not a constant int?"); + Constant *Size = SimplifiedValues.lookup(I.getArraySize()); + if (auto *AllocSize = dyn_cast_or_null<ConstantInt>(Size)) { + const DataLayout &DL = F.getParent()->getDataLayout(); Type *Ty = I.getAllocatedType(); - AllocatedSize += Ty->getPrimitiveSizeInBits() * AllocSize->getZExtValue(); + AllocatedSize = SaturatingMultiplyAdd( + AllocSize->getLimitedValue(), DL.getTypeAllocSize(Ty), AllocatedSize); return Base::visitAlloca(I); } } @@ -305,7 +351,7 @@ bool CallAnalyzer::visitAlloca(AllocaInst &I) { if (I.isStaticAlloca()) { const DataLayout &DL = F.getParent()->getDataLayout(); Type *Ty = I.getAllocatedType(); - AllocatedSize += DL.getTypeAllocSize(Ty); + AllocatedSize = SaturatingAdd(DL.getTypeAllocSize(Ty), AllocatedSize); } // We will happily inline static alloca instructions. 
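The alloca accounting in the hunk above now saturates rather than wraps: a crafted array size can no longer overflow AllocatedSize back to a small number and sneak past any size-based limit. A minimal standalone sketch of the same idiom (toy values, not the pass itself; SaturatingMultiplyAdd lives in LLVM's header-only MathExtras):

#include "llvm/Support/MathExtras.h"
#include <cstdint>
#include <iostream>

int main() {
  uint64_t AllocatedSize = 1024;
  uint64_t NumElts = UINT64_MAX / 2;  // e.g. from a hostile array alloca
  uint64_t EltSize = 16;

  // Plain accounting wraps around to a tiny number (defined behavior for
  // unsigned, but wrong for a size check); the saturating form pins at max.
  uint64_t Wrapped = NumElts * EltSize + AllocatedSize;
  uint64_t Sat = llvm::SaturatingMultiplyAdd(NumElts, EltSize, AllocatedSize);

  std::cout << Wrapped << " vs " << Sat << "\n";  // 1008 vs UINT64_MAX
  return 0;
}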
@@ -336,8 +382,8 @@ bool CallAnalyzer::visitPHI(PHINode &I) { bool CallAnalyzer::visitGetElementPtr(GetElementPtrInst &I) { Value *SROAArg; DenseMap<Value *, int>::iterator CostIt; - bool SROACandidate = lookupSROAArgAndCost(I.getPointerOperand(), - SROAArg, CostIt); + bool SROACandidate = + lookupSROAArgAndCost(I.getPointerOperand(), SROAArg, CostIt); // Try to fold GEPs of constant-offset call site argument pointers. This // requires target data and inbounds GEPs. @@ -393,8 +439,8 @@ bool CallAnalyzer::visitBitCast(BitCastInst &I) { } // Track base/offsets through casts - std::pair<Value *, APInt> BaseAndOffset - = ConstantOffsetPtrs.lookup(I.getOperand(0)); + std::pair<Value *, APInt> BaseAndOffset = + ConstantOffsetPtrs.lookup(I.getOperand(0)); // Casts don't change the offset, just wrap it up. if (BaseAndOffset.first) ConstantOffsetPtrs[&I] = BaseAndOffset; @@ -425,8 +471,8 @@ bool CallAnalyzer::visitPtrToInt(PtrToIntInst &I) { unsigned IntegerSize = I.getType()->getScalarSizeInBits(); const DataLayout &DL = F.getParent()->getDataLayout(); if (IntegerSize >= DL.getPointerSizeInBits()) { - std::pair<Value *, APInt> BaseAndOffset - = ConstantOffsetPtrs.lookup(I.getOperand(0)); + std::pair<Value *, APInt> BaseAndOffset = + ConstantOffsetPtrs.lookup(I.getOperand(0)); if (BaseAndOffset.first) ConstantOffsetPtrs[&I] = BaseAndOffset; } @@ -501,8 +547,7 @@ bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) { COp = SimplifiedValues.lookup(Operand); if (COp) { const DataLayout &DL = F.getParent()->getDataLayout(); - if (Constant *C = ConstantFoldInstOperands(I.getOpcode(), I.getType(), - COp, DL)) { + if (Constant *C = ConstantFoldInstOperands(&I, COp, DL)) { SimplifiedValues[&I] = C; return true; } @@ -516,7 +561,7 @@ bool CallAnalyzer::visitUnaryInstruction(UnaryInstruction &I) { bool CallAnalyzer::paramHasAttr(Argument *A, Attribute::AttrKind Attr) { unsigned ArgNo = A->getArgNo(); - return CandidateCS.paramHasAttr(ArgNo+1, Attr); + return CandidateCS.paramHasAttr(ArgNo + 1, Attr); } bool CallAnalyzer::isKnownNonNullInCallee(Value *V) { @@ -528,7 +573,7 @@ bool CallAnalyzer::isKnownNonNullInCallee(Value *V) { if (Argument *A = dyn_cast<Argument>(V)) if (paramHasAttr(A, Attribute::NonNull)) return true; - + // Is this an alloca in the caller? This is distinct from the attribute case // above because attributes aren't updated within the inliner itself and we // always want to catch the alloca derived case. @@ -537,10 +582,86 @@ bool CallAnalyzer::isKnownNonNullInCallee(Value *V) { // alloca-derived value and null. Note that this fires regardless of // SROA firing. return true; - + return false; } +bool CallAnalyzer::allowSizeGrowth(CallSite CS) { + // If the normal destination of the invoke or the parent block of the call + // site is unreachable-terminated, there is little point in inlining this + // unless there is literally zero cost. + // FIXME: Note that it is possible that an unreachable-terminated block has a + // hot entry. For example, in below scenario inlining hot_call_X() may be + // beneficial : + // main() { + // hot_call_1(); + // ... + // hot_call_N() + // exit(0); + // } + // For now, we are not handling this corner case here as it is rare in real + // code. In future, we should elaborate this based on BPI and BFI in more + // general threshold adjusting heuristics in updateThreshold(). 
+ Instruction *Instr = CS.getInstruction(); + if (InvokeInst *II = dyn_cast<InvokeInst>(Instr)) { + if (isa<UnreachableInst>(II->getNormalDest()->getTerminator())) + return false; + } else if (isa<UnreachableInst>(Instr->getParent()->getTerminator())) + return false; + + return true; +} + +void CallAnalyzer::updateThreshold(CallSite CS, Function &Callee) { + // If no size growth is allowed for this inlining, set Threshold to 0. + if (!allowSizeGrowth(CS)) { + Threshold = 0; + return; + } + + Function *Caller = CS.getCaller(); + if (DefaultInlineThreshold.getNumOccurrences() > 0) { + // Explicitly specified -inline-threshold overrides the threshold passed to + // CallAnalyzer's constructor. + Threshold = DefaultInlineThreshold; + } else { + // If -inline-threshold is not given, listen to the optsize and minsize + // attributes when they would decrease the threshold. + if (Caller->optForMinSize() && OptMinSizeThreshold < Threshold) + Threshold = OptMinSizeThreshold; + else if (Caller->optForSize() && OptSizeThreshold < Threshold) + Threshold = OptSizeThreshold; + } + + bool HotCallsite = false; + uint64_t TotalWeight; + if (CS.getInstruction()->extractProfTotalWeight(TotalWeight) && + PSI->isHotCount(TotalWeight)) + HotCallsite = true; + + // Listen to the inlinehint attribute or profile based hotness information + // when it would increase the threshold and the caller does not need to + // minimize its size. + bool InlineHint = Callee.hasFnAttribute(Attribute::InlineHint) || + PSI->isHotFunction(&Callee) || + HotCallsite; + if (InlineHint && HintThreshold > Threshold && !Caller->optForMinSize()) + Threshold = HintThreshold; + + bool ColdCallee = PSI->isColdFunction(&Callee); + // Command line argument for DefaultInlineThreshold will override the default + // ColdThreshold. If we have -inline-threshold but no -inlinecold-threshold, + // do not use the default cold threshold even if it is smaller. + if ((DefaultInlineThreshold.getNumOccurrences() == 0 || + ColdThreshold.getNumOccurrences() > 0) && + ColdCallee && ColdThreshold < Threshold) + Threshold = ColdThreshold; + + // Finally, take the target-specific inlining threshold multiplier into + // account. + Threshold *= TTI.getInliningThresholdMultiplier(); +} + bool CallAnalyzer::visitCmpInst(CmpInst &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); // First try to handle simplified comparisons. @@ -552,7 +673,8 @@ bool CallAnalyzer::visitCmpInst(CmpInst &I) { RHS = SimpleRHS; if (Constant *CLHS = dyn_cast<Constant>(LHS)) { if (Constant *CRHS = dyn_cast<Constant>(RHS)) - if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { + if (Constant *C = + ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { SimplifiedValues[&I] = C; return true; } @@ -713,8 +835,8 @@ bool CallAnalyzer::visitInsertValue(InsertValueInst &I) { if (!InsertedC) InsertedC = SimplifiedValues.lookup(I.getInsertedValueOperand()); if (AggC && InsertedC) { - SimplifiedValues[&I] = ConstantExpr::getInsertValue(AggC, InsertedC, - I.getIndices()); + SimplifiedValues[&I] = + ConstantExpr::getInsertValue(AggC, InsertedC, I.getIndices()); return true; } @@ -739,8 +861,8 @@ bool CallAnalyzer::simplifyCallSite(Function *F, CallSite CS) { // Try to re-map the arguments to constants.
SmallVector<Constant *, 4> ConstantArgs; ConstantArgs.reserve(CS.arg_size()); - for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); - I != E; ++I) { + for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; + ++I) { Constant *C = dyn_cast<Constant>(*I); if (!C) C = dyn_cast_or_null<Constant>(SimplifiedValues.lookup(*I)); @@ -764,8 +886,7 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { ExposesReturnsTwice = true; return false; } - if (CS.isCall() && - cast<CallInst>(CS.getInstruction())->cannotDuplicate()) + if (CS.isCall() && cast<CallInst>(CS.getInstruction())->cannotDuplicate()) ContainsNoDuplicateCall = true; if (Function *F = CS.getCalledFunction()) { @@ -780,6 +901,11 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { default: return Base::visitCallSite(CS); + case Intrinsic::load_relative: + // This is normally lowered to 4 LLVM instructions. + Cost += 3 * InlineConstants::InstrCost; + return false; + case Intrinsic::memset: case Intrinsic::memcpy: case Intrinsic::memmove: @@ -831,7 +957,8 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { // during devirtualization and so we want to give it a hefty bonus for // inlining, but cap that bonus in the event that inlining wouldn't pan // out. Pretend to inline the function, with a custom threshold. - CallAnalyzer CA(TTI, ACT, *F, InlineConstants::IndirectCallThreshold, CS); + CallAnalyzer CA(TTI, ACT, PSI, *F, InlineConstants::IndirectCallThreshold, + CS); if (CA.analyzeCall(CS)) { // We were able to inline the indirect call! Subtract the cost from the // threshold to get the bonus we want to apply, but don't go below zero. @@ -938,7 +1065,6 @@ bool CallAnalyzer::visitInstruction(Instruction &I) { return false; } - /// \brief Analyze a basic block for its contribution to the inline cost. /// /// This method walks the analyzer over every instruction in the given basic @@ -1044,7 +1170,7 @@ ConstantInt *CallAnalyzer::stripAndComputeInBoundsConstantOffsets(Value *&V) { } else if (Operator::getOpcode(V) == Instruction::BitCast) { V = cast<Operator>(V)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) break; V = GA->getAliasee(); } else { @@ -1079,6 +1205,10 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // nice to base the bonus values on something more scientific. assert(NumInstructions == 0); assert(NumVectorInstructions == 0); + + // Update the threshold based on callsite properties + updateThreshold(CS, F); + FiftyPercentVectorBonus = 3 * Threshold / 2; TenPercentVectorBonus = 3 * Threshold / 4; const DataLayout &DL = F.getParent()->getDataLayout(); @@ -1124,22 +1254,11 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // If there is only one call of the function, and it has internal linkage, // the cost of inlining it drops dramatically. - bool OnlyOneCallAndLocalLinkage = F.hasLocalLinkage() && F.hasOneUse() && - &F == CS.getCalledFunction(); + bool OnlyOneCallAndLocalLinkage = + F.hasLocalLinkage() && F.hasOneUse() && &F == CS.getCalledFunction(); if (OnlyOneCallAndLocalLinkage) Cost += InlineConstants::LastCallToStaticBonus; - // If the instruction after the call, or if the normal destination of the - // invoke is an unreachable instruction, the function is noreturn. As such, - // there is little point in inlining this unless there is literally zero - // cost. 
- Instruction *Instr = CS.getInstruction(); - if (InvokeInst *II = dyn_cast<InvokeInst>(Instr)) { - if (isa<UnreachableInst>(II->getNormalDest()->begin())) - Threshold = 0; - } else if (isa<UnreachableInst>(++BasicBlock::iterator(Instr))) - Threshold = 0; - // If this function uses the coldcc calling convention, prefer not to inline // it. if (F.getCallingConv() == CallingConv::Cold) @@ -1193,7 +1312,8 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // the ephemeral values multiple times (and they're completely determined by // the callee, so this is purely duplicate work). SmallPtrSet<const Value *, 32> EphValues; - CodeMetrics::collectEphemeralValues(&F, &ACT->getAssumptionCache(F), EphValues); + CodeMetrics::collectEphemeralValues(&F, &ACT->getAssumptionCache(F), + EphValues); // The worklist of live basic blocks in the callee *after* inlining. We avoid // adding basic blocks of the callee which can be proven to be dead for this @@ -1203,7 +1323,8 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // accomplish this, prioritizing for small iterations because we exit after // crossing our threshold, we use a small-size optimized SetVector. typedef SetVector<BasicBlock *, SmallVector<BasicBlock *, 16>, - SmallPtrSet<BasicBlock *, 16> > BBSetVector; + SmallPtrSet<BasicBlock *, 16>> + BBSetVector; BBSetVector BBWorklist; BBWorklist.insert(&F.getEntryBlock()); // Note that we *must not* cache the size, this loop grows the worklist. @@ -1228,20 +1349,8 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // Analyze the cost of this block. If we blow through the threshold, this // returns false, and we can bail on out. - if (!analyzeBlock(BB, EphValues)) { - if (IsRecursiveCall || ExposesReturnsTwice || HasDynamicAlloca || - HasIndirectBr || HasFrameEscape) - return false; - - // If the caller is a recursive function then we don't want to inline - // functions which allocate a lot of stack space because it would increase - // the caller stack usage dramatically. - if (IsCallerRecursive && - AllocatedSize > InlineConstants::TotalAllocaSizeRecursiveCaller) - return false; - - break; - } + if (!analyzeBlock(BB, EphValues)) + return false; TerminatorInst *TI = BB->getTerminator(); @@ -1250,16 +1359,16 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { if (BranchInst *BI = dyn_cast<BranchInst>(TI)) { if (BI->isConditional()) { Value *Cond = BI->getCondition(); - if (ConstantInt *SimpleCond - = dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { + if (ConstantInt *SimpleCond = + dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { BBWorklist.insert(BI->getSuccessor(SimpleCond->isZero() ? 1 : 0)); continue; } } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { Value *Cond = SI->getCondition(); - if (ConstantInt *SimpleCond - = dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { + if (ConstantInt *SimpleCond = + dyn_cast_or_null<ConstantInt>(SimplifiedValues.lookup(Cond))) { BBWorklist.insert(SI->findCaseValue(SimpleCond).getCaseSuccessor()); continue; } @@ -1296,12 +1405,12 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { else if (NumVectorInstructions <= NumInstructions / 2) Threshold -= (FiftyPercentVectorBonus - TenPercentVectorBonus); - return Cost <= std::max(0, Threshold); + return Cost < std::max(1, Threshold); } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) /// \brief Dump stats about this call's analysis. 
-void CallAnalyzer::dump() { +LLVM_DUMP_METHOD void CallAnalyzer::dump() { #define DEBUG_PRINT_STAT(x) dbgs() << " " #x ": " << x << "\n" DEBUG_PRINT_STAT(NumConstantArgs); DEBUG_PRINT_STAT(NumConstantOffsetPtrArgs); @@ -1321,7 +1430,7 @@ void CallAnalyzer::dump() { /// \brief Test that two functions either have or have not the given attribute /// at the same time. -template<typename AttrKind> +template <typename AttrKind> static bool attributeMatches(Function *F1, Function *F2, AttrKind Attr) { return F1->getFnAttribute(Attr) == F2->getFnAttribute(Attr); } @@ -1335,15 +1444,33 @@ static bool functionsHaveCompatibleAttributes(Function *Caller, AttributeFuncs::areInlineCompatible(*Caller, *Callee); } -InlineCost llvm::getInlineCost(CallSite CS, int Threshold, +InlineCost llvm::getInlineCost(CallSite CS, int DefaultThreshold, TargetTransformInfo &CalleeTTI, - AssumptionCacheTracker *ACT) { - return getInlineCost(CS, CS.getCalledFunction(), Threshold, CalleeTTI, ACT); + AssumptionCacheTracker *ACT, + ProfileSummaryInfo *PSI) { + return getInlineCost(CS, CS.getCalledFunction(), DefaultThreshold, CalleeTTI, + ACT, PSI); +} + +int llvm::computeThresholdFromOptLevels(unsigned OptLevel, + unsigned SizeOptLevel) { + if (OptLevel > 2) + return OptAggressiveThreshold; + if (SizeOptLevel == 1) // -Os + return OptSizeThreshold; + if (SizeOptLevel == 2) // -Oz + return OptMinSizeThreshold; + return DefaultInlineThreshold; } -InlineCost llvm::getInlineCost(CallSite CS, Function *Callee, int Threshold, +int llvm::getDefaultInlineThreshold() { return DefaultInlineThreshold; } + +InlineCost llvm::getInlineCost(CallSite CS, Function *Callee, + int DefaultThreshold, TargetTransformInfo &CalleeTTI, - AssumptionCacheTracker *ACT) { + AssumptionCacheTracker *ACT, + ProfileSummaryInfo *PSI) { + // Cannot inline indirect calls. if (!Callee) return llvm::InlineCost::getNever(); @@ -1365,17 +1492,18 @@ InlineCost llvm::getInlineCost(CallSite CS, Function *Callee, int Threshold, if (CS.getCaller()->hasFnAttribute(Attribute::OptimizeNone)) return llvm::InlineCost::getNever(); - // Don't inline functions which can be redefined at link-time to mean - // something else. Don't inline functions marked noinline or call sites - // marked noinline. + // Don't inline functions which can be interposed at link-time. Don't inline + // functions marked noinline or call sites marked noinline. + // Note: inlining non-exact non-interposable functions is fine, since we know + // we have *a* correct implementation of the source level function.
+ if (Callee->isInterposable() || Callee->hasFnAttribute(Attribute::NoInline) || + CS.isNoInline()) return llvm::InlineCost::getNever(); DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName() - << "...\n"); + << "...\n"); - CallAnalyzer CA(CalleeTTI, ACT, *Callee, Threshold, CS); + CallAnalyzer CA(CalleeTTI, ACT, PSI, *Callee, DefaultThreshold, CS); bool ShouldInline = CA.analyzeCall(CS); DEBUG(CA.dump()); diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 6dfe62596275..0cb2c78afb40 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" @@ -528,11 +529,8 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS, static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::Add, CLHS->getType(), Ops, - Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::Add, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. std::swap(Op0, Op1); @@ -619,10 +617,15 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V, } else if (Operator::getOpcode(V) == Instruction::BitCast) { V = cast<Operator>(V)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) break; V = GA->getAliasee(); } else { + if (auto CS = CallSite(V)) + if (Value *RV = CS.getReturnedArgOperand()) { + V = RV; + continue; + } break; } assert(V->getType()->getScalarType()->isPointerTy() && @@ -660,11 +663,8 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS, static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::Sub, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::Sub, CLHS, CRHS, Q.DL); // X - undef -> undef // undef - X -> undef @@ -787,11 +787,8 @@ Value *llvm::SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::FAdd, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::FAdd, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. 
std::swap(Op0, Op1); @@ -803,7 +800,7 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, // fadd X, 0 ==> X, when we know X is not -0 if (match(Op1, m_Zero()) && - (FMF.noSignedZeros() || CannotBeNegativeZero(Op0))) + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; // fadd [nnan ninf] X, (fsub [nnan ninf] 0, X) ==> 0 @@ -829,11 +826,8 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::FSub, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::FSub, CLHS, CRHS, Q.DL); } // fsub X, 0 ==> X @@ -842,17 +836,18 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, // fsub X, -0 ==> X, when we know X is not -0 if (match(Op1, m_NegZero()) && - (FMF.noSignedZeros() || CannotBeNegativeZero(Op0))) + (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI))) return Op0; - // fsub 0, (fsub -0.0, X) ==> X + // fsub -0.0, (fsub -0.0, X) ==> X Value *X; - if (match(Op0, m_AnyZero())) { - if (match(Op1, m_FSub(m_NegZero(), m_Value(X)))) - return X; - if (FMF.noSignedZeros() && match(Op1, m_FSub(m_AnyZero(), m_Value(X)))) - return X; - } + if (match(Op0, m_NegZero()) && match(Op1, m_FSub(m_NegZero(), m_Value(X)))) + return X; + + // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. + if (FMF.noSignedZeros() && match(Op0, m_AnyZero()) && + match(Op1, m_FSub(m_AnyZero(), m_Value(X)))) + return X; // fsub nnan x, x ==> 0.0 if (FMF.noNaNs() && Op0 == Op1) @@ -867,11 +862,8 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::FMul, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::FMul, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. std::swap(Op0, Op1); @@ -893,11 +885,8 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, static Value *SimplifyMulInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::Mul, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::Mul, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. std::swap(Op0, Op1); @@ -992,12 +981,9 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const DataLayout &DL, /// If not, this returns null. 
static Value *SimplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) { - if (Constant *C1 = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { C0, C1 }; - return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI); - } - } + if (Constant *C0 = dyn_cast<Constant>(Op0)) + if (Constant *C1 = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); bool isSigned = Opcode == Instruction::SDiv; @@ -1157,12 +1143,9 @@ Value *llvm::SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, /// If not, this returns null. static Value *SimplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) { - if (Constant *C1 = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { C0, C1 }; - return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI); - } - } + if (Constant *C0 = dyn_cast<Constant>(Op0)) + if (Constant *C1 = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); // X % undef -> undef if (match(Op1, m_Undef())) @@ -1309,12 +1292,9 @@ static bool isUndefShift(Value *Amount) { /// If not, this returns null. static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { - if (Constant *C0 = dyn_cast<Constant>(Op0)) { - if (Constant *C1 = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { C0, C1 }; - return ConstantFoldInstOperands(Opcode, C0->getType(), Ops, Q.DL, Q.TLI); - } - } + if (Constant *C0 = dyn_cast<Constant>(Op0)) + if (Constant *C1 = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Opcode, C0, C1, Q.DL); // 0 shift by X -> 0 if (match(Op0, m_Zero())) @@ -1340,6 +1320,22 @@ static Value *SimplifyShift(unsigned Opcode, Value *Op0, Value *Op1, if (Value *V = ThreadBinOpOverPHI(Opcode, Op0, Op1, Q, MaxRecurse)) return V; + // If any bits in the shift amount make that value greater than or equal to + // the number of bits in the type, the shift is undefined. + unsigned BitWidth = Op1->getType()->getScalarSizeInBits(); + APInt KnownZero(BitWidth, 0); + APInt KnownOne(BitWidth, 0); + computeKnownBits(Op1, KnownZero, KnownOne, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (KnownOne.getLimitedValue() >= BitWidth) + return UndefValue::get(Op0->getType()); + + // If all valid bits in the shift amount are known zero, the first operand is + // unchanged. + unsigned NumValidShiftBits = Log2_32_Ceil(BitWidth); + APInt ShiftAmountMask = APInt::getLowBitsSet(BitWidth, NumValidShiftBits); + if ((KnownZero & ShiftAmountMask) == ShiftAmountMask) + return Op0; + return nullptr; } @@ -1501,9 +1497,8 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp, return nullptr; } -/// Simplify (and (icmp ...) (icmp ...)) to true when we can tell that the range -/// of possible values cannot be satisfied. static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { + Type *ITy = Op0->getType(); ICmpInst::Predicate Pred0, Pred1; ConstantInt *CI1, *CI2; Value *V; @@ -1511,15 +1506,25 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) return X; + // Look for this pattern: (icmp V, C0) & (icmp V, C1).
+ const APInt *C0, *C1; + if (match(Op0, m_ICmp(Pred0, m_Value(V), m_APInt(C0))) && + match(Op1, m_ICmp(Pred1, m_Specific(V), m_APInt(C1)))) { + // Make a constant range that's the intersection of the two icmp ranges. + // If the intersection is empty, we know that the result is false. + auto Range0 = ConstantRange::makeAllowedICmpRegion(Pred0, *C0); + auto Range1 = ConstantRange::makeAllowedICmpRegion(Pred1, *C1); + if (Range0.intersectWith(Range1).isEmptySet()) + return getFalse(ITy); + } + if (!match(Op0, m_ICmp(Pred0, m_Add(m_Value(V), m_ConstantInt(CI1)), m_ConstantInt(CI2)))) - return nullptr; + return nullptr; if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Specific(CI1)))) return nullptr; - Type *ITy = Op0->getType(); - auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); bool isNSW = AddInst->hasNoSignedWrap(); bool isNUW = AddInst->hasNoUnsignedWrap(); @@ -1558,11 +1563,8 @@ static Value *SimplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::And, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::And, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. std::swap(Op0, Op1); @@ -1620,6 +1622,24 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const Query &Q, } } + // The compares may be hidden behind casts. Look through those and try the + // same folds as above. + auto *Cast0 = dyn_cast<CastInst>(Op0); + auto *Cast1 = dyn_cast<CastInst>(Op1); + if (Cast0 && Cast1 && Cast0->getOpcode() == Cast1->getOpcode() && + Cast0->getSrcTy() == Cast1->getSrcTy()) { + auto *Cmp0 = dyn_cast<ICmpInst>(Cast0->getOperand(0)); + auto *Cmp1 = dyn_cast<ICmpInst>(Cast1->getOperand(0)); + if (Cmp0 && Cmp1) { + Instruction::CastOps CastOpc = Cast0->getOpcode(); + Type *ResultType = Cast0->getType(); + if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp0, Cmp1))) + return ConstantExpr::getCast(CastOpc, V, ResultType); + if (auto *V = dyn_cast_or_null<Constant>(SimplifyAndOfICmps(Cmp1, Cmp0))) + return ConstantExpr::getCast(CastOpc, V, ResultType); + } + } + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::And, Op0, Op1, Q, MaxRecurse)) @@ -1717,11 +1737,8 @@ static Value *SimplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { static Value *SimplifyOrInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::Or, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::Or, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. 
std::swap(Op0, Op1); @@ -1853,11 +1870,8 @@ Value *llvm::SimplifyOrInst(Value *Op0, Value *Op1, const DataLayout &DL, static Value *SimplifyXorInst(Value *Op0, Value *Op1, const Query &Q, unsigned MaxRecurse) { if (Constant *CLHS = dyn_cast<Constant>(Op0)) { - if (Constant *CRHS = dyn_cast<Constant>(Op1)) { - Constant *Ops[] = { CLHS, CRHS }; - return ConstantFoldInstOperands(Instruction::Xor, CLHS->getType(), - Ops, Q.DL, Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(Op1)) + return ConstantFoldBinaryOpOperands(Instruction::Xor, CLHS, CRHS, Q.DL); // Canonicalize the constant to the RHS. std::swap(Op0, Op1); @@ -1957,16 +1971,16 @@ static Value *ExtractEquivalentCondition(Value *V, CmpInst::Predicate Pred, // If the C and C++ standards are ever made sufficiently restrictive in this // area, it may be possible to update LLVM's semantics accordingly and reinstate // this optimization. -static Constant *computePointerICmp(const DataLayout &DL, - const TargetLibraryInfo *TLI, - CmpInst::Predicate Pred, Value *LHS, - Value *RHS) { +static Constant * +computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, + const DominatorTree *DT, CmpInst::Predicate Pred, + const Instruction *CxtI, Value *LHS, Value *RHS) { // First, skip past any trivial no-ops. LHS = LHS->stripPointerCasts(); RHS = RHS->stripPointerCasts(); // A non-null pointer is not equal to a null pointer. - if (llvm::isKnownNonNull(LHS, TLI) && isa<ConstantPointerNull>(RHS) && + if (llvm::isKnownNonNull(LHS) && isa<ConstantPointerNull>(RHS) && (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)) return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); @@ -2104,7 +2118,7 @@ static Constant *computePointerICmp(const DataLayout &DL, return AI->getParent() && AI->getFunction() && AI->isStaticAlloca(); if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) return (GV->hasLocalLinkage() || GV->hasHiddenVisibility() || - GV->hasProtectedVisibility() || GV->hasUnnamedAddr()) && + GV->hasProtectedVisibility() || GV->hasGlobalUnnamedAddr()) && !GV->isThreadLocal(); if (const Argument *A = dyn_cast<Argument>(V)) return A->hasByValAttr(); @@ -2116,6 +2130,20 @@ static Constant *computePointerICmp(const DataLayout &DL, (IsNAC(RHSUObjs) && IsAllocDisjoint(LHSUObjs))) return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); + + // Fold comparisons for non-escaping pointer even if the allocation call + // cannot be elided. We cannot fold malloc comparison to null. Also, the + // dynamic allocation call could be either of the operands. + Value *MI = nullptr; + if (isAllocLikeFn(LHS, TLI) && llvm::isKnownNonNullAt(RHS, CxtI, DT)) + MI = LHS; + else if (isAllocLikeFn(RHS, TLI) && llvm::isKnownNonNullAt(LHS, CxtI, DT)) + MI = RHS; + // FIXME: We should also fold the compare when the pointer escapes, but the + // compare dominates the pointer escape + if (MI && !PointerMayBeCaptured(MI, true, true)) + return ConstantInt::get(GetCompareTy(LHS), + CmpInst::isFalseWhenEqual(Pred)); } // Otherwise, fail. 
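The new ConstantRange-based fold in SimplifyAndOfICmps above, for the (icmp V, C0) & (icmp V, C1) pattern, reduces the whole conjunction to false whenever the two allowed regions cannot intersect. A small self-contained illustration of the same computation; the driver function is hypothetical, but makeAllowedICmpRegion, intersectWith, and isEmptySet are the APIs the hunk itself uses:

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

// (V u< 5) & (V u> 10) can never both hold, so the 'and' folds to false.
static bool conjunctionIsAlwaysFalse() {
  APInt C0(32, 5), C1(32, 10);
  // Allowed region of V for 'icmp ult V, 5': [0, 5).
  ConstantRange R0 = ConstantRange::makeAllowedICmpRegion(CmpInst::ICMP_ULT, C0);
  // Allowed region of V for 'icmp ugt V, 10': values 11 .. 2^32-1.
  ConstantRange R1 = ConstantRange::makeAllowedICmpRegion(CmpInst::ICMP_UGT, C1);
  return R0.intersectWith(R1).isEmptySet(); // true
}

The companion fold in computePointerICmp at the end of this hunk is independent of ranges: it folds equality tests between a non-escaping allocation (isAllocLikeFn plus a PointerMayBeCaptured check) and a pointer known non-null at the context instruction, even when the allocation call itself cannot be removed.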
@@ -2166,24 +2194,26 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (match(RHS, m_Zero())) return LHS; break; - case ICmpInst::ICMP_UGE: + case ICmpInst::ICMP_UGE: { // X >=u 1 -> X if (match(RHS, m_One())) return LHS; - if (isImpliedCondition(RHS, LHS, Q.DL)) + if (isImpliedCondition(RHS, LHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; - case ICmpInst::ICMP_SGE: - /// For signed comparison, the values for an i1 are 0 and -1 + } + case ICmpInst::ICMP_SGE: { + /// For signed comparison, the values for an i1 are 0 and -1 /// respectively. This maps into a truth table of: /// LHS | RHS | LHS >=s RHS | LHS implies RHS /// 0 | 0 | 1 (0 >= 0) | 1 /// 0 | 1 | 1 (0 >= -1) | 1 /// 1 | 0 | 0 (-1 >= 0) | 0 /// 1 | 1 | 1 (-1 >= -1) | 1 - if (isImpliedCondition(LHS, RHS, Q.DL)) + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; + } case ICmpInst::ICMP_SLT: // X <s 0 -> X if (match(RHS, m_Zero())) @@ -2194,11 +2224,12 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (match(RHS, m_One())) return LHS; break; - case ICmpInst::ICMP_ULE: - if (isImpliedCondition(LHS, RHS, Q.DL)) + case ICmpInst::ICMP_ULE: { + if (isImpliedCondition(LHS, RHS, Q.DL).getValueOr(false)) return getTrue(ITy); break; } + } } // If we are comparing with zero then try hard since this is a common case. @@ -2300,7 +2331,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } else if (match(LHS, m_SDiv(m_Value(), m_ConstantInt(CI2)))) { APInt IntMin = APInt::getSignedMinValue(Width); APInt IntMax = APInt::getSignedMaxValue(Width); - APInt Val = CI2->getValue(); + const APInt &Val = CI2->getValue(); if (Val.isAllOnesValue()) { // 'sdiv x, -1' produces [INT_MIN + 1, INT_MAX] // where CI2 != -1 and CI2 != 0 and CI2 != 1 @@ -2581,7 +2612,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, return Pred == ICmpInst::ICMP_NE ? ConstantInt::getTrue(Ctx) : ConstantInt::getFalse(Ctx); } - + // Special logic for binary operators. BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); @@ -2645,21 +2676,48 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } - // icmp pred (or X, Y), X - if (LBO && match(LBO, m_CombineOr(m_Or(m_Value(), m_Specific(RHS)), - m_Or(m_Specific(RHS), m_Value())))) { - if (Pred == ICmpInst::ICMP_ULT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_UGE) - return getTrue(ITy); - } - // icmp pred X, (or X, Y) - if (RBO && match(RBO, m_CombineOr(m_Or(m_Value(), m_Specific(LHS)), - m_Or(m_Specific(LHS), m_Value())))) { - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); + { + Value *Y = nullptr; + // icmp pred (or X, Y), X + if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { + bool RHSKnownNonNegative, RHSKnownNegative; + bool YKnownNonNegative, YKnownNegative; + ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, Q.DL, 0, + Q.AC, Q.CxtI, Q.DT); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (RHSKnownNonNegative && YKnownNegative) + return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); + if (RHSKnownNegative || YKnownNonNegative) + return Pred == ICmpInst::ICMP_SLT ? 
getFalse(ITy) : getTrue(ITy); + } + } + // icmp pred X, (or X, Y) + if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) { + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { + bool LHSKnownNonNegative, LHSKnownNegative; + bool YKnownNonNegative, YKnownNegative; + ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, + Q.AC, Q.CxtI, Q.DT); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, + Q.CxtI, Q.DT); + if (LHSKnownNonNegative && YKnownNegative) + return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); + if (LHSKnownNegative || YKnownNonNegative) + return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); + } + } } // icmp pred (and X, Y), X @@ -2763,9 +2821,11 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } + // x >> y <=u x // x udiv y <=u x. - if (LBO && match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) { - // icmp pred (X /u Y), X + if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || + match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) { + // icmp pred (X op Y), X if (Pred == ICmpInst::ICMP_UGT) return getFalse(ITy); if (Pred == ICmpInst::ICMP_ULE) @@ -3030,7 +3090,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available. if (LHS->getType()->isPointerTy()) - if (Constant *C = computePointerICmp(Q.DL, Q.TLI, Pred, LHS, RHS)) + if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.CxtI, LHS, RHS)) return C; if (GetElementPtrInst *GLHS = dyn_cast<GetElementPtrInst>(LHS)) { @@ -3145,7 +3205,14 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } // Handle fcmp with constant RHS - if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) { + const ConstantFP *CFP = nullptr; + if (const auto *RHSC = dyn_cast<Constant>(RHS)) { + if (RHS->getType()->isVectorTy()) + CFP = dyn_cast_or_null<ConstantFP>(RHSC->getSplatValue()); + else + CFP = dyn_cast<ConstantFP>(RHSC); + } + if (CFP) { // If the constant is a nan, see if we can fold the comparison based on it. if (CFP->getValueAPF().isNaN()) { if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo" assert(FCmpInst::isUnordered(Pred) && "Comparison must be either ordered or unordered!"); // True if unordered. - return ConstantInt::getTrue(CFP->getContext()); + return ConstantInt::get(GetCompareTy(LHS), 1); } // Check whether the constant is an infinity. if (CFP->getValueAPF().isInfinity()) { switch (Pred) { case FCmpInst::FCMP_OLT: // No value is ordered and less than negative infinity. - return ConstantInt::getFalse(CFP->getContext()); + return ConstantInt::get(GetCompareTy(LHS), 0); case FCmpInst::FCMP_UGE: // All values are unordered with or at least negative infinity. - return ConstantInt::getTrue(CFP->getContext()); + return ConstantInt::get(GetCompareTy(LHS), 1); default: break; } @@ -3172,10 +3239,10 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, switch (Pred) { case FCmpInst::FCMP_OGT: // No value is ordered and greater than infinity.
- return ConstantInt::getFalse(CFP->getContext()); + return ConstantInt::get(GetCompareTy(LHS), 0); case FCmpInst::FCMP_ULE: // All values are unordered with and at most infinity. - return ConstantInt::getTrue(CFP->getContext()); + return ConstantInt::get(GetCompareTy(LHS), 1); default: break; } @@ -3184,13 +3251,13 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (CFP->getValueAPF().isZero()) { switch (Pred) { case FCmpInst::FCMP_UGE: - if (CannotBeOrderedLessThanZero(LHS)) - return ConstantInt::getTrue(CFP->getContext()); + if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return ConstantInt::get(GetCompareTy(LHS), 1); break; case FCmpInst::FCMP_OLT: // X < 0 - if (CannotBeOrderedLessThanZero(LHS)) - return ConstantInt::getFalse(CFP->getContext()); + if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return ConstantInt::get(GetCompareTy(LHS), 0); break; default: break; @@ -3295,10 +3362,9 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (LoadInst *LI = dyn_cast<LoadInst>(I)) if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(ConstOps[0], Q.DL); + return ConstantFoldLoadFromConstPtr(ConstOps[0], LI->getType(), Q.DL); - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), ConstOps, - Q.DL, Q.TLI); + return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); } } @@ -3527,13 +3593,13 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, Ops.slice(1)); } -Value *llvm::SimplifyGEPInst(ArrayRef<Value *> Ops, const DataLayout &DL, +Value *llvm::SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, + const DataLayout &DL, const TargetLibraryInfo *TLI, const DominatorTree *DT, AssumptionCache *AC, const Instruction *CxtI) { - return ::SimplifyGEPInst( - cast<PointerType>(Ops[0]->getType()->getScalarType())->getElementType(), - Ops, Query(DL, TLI, DT, AC, CxtI), RecursionLimit); + return ::SimplifyGEPInst(SrcTy, Ops, + Query(DL, TLI, DT, AC, CxtI), RecursionLimit); } /// Given operands for an InsertValueInst, see if we can fold the result. @@ -3675,7 +3741,7 @@ static Value *SimplifyPHINode(PHINode *PN, const Query &Q) { static Value *SimplifyTruncInst(Value *Op, Type *Ty, const Query &Q, unsigned) { if (Constant *C = dyn_cast<Constant>(Op)) - return ConstantFoldInstOperands(Instruction::Trunc, Ty, C, Q.DL, Q.TLI); + return ConstantFoldCastOperand(Instruction::Trunc, C, Ty, Q.DL); return nullptr; } @@ -3730,11 +3796,8 @@ static Value *SimplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS, case Instruction::Xor: return SimplifyXorInst(LHS, RHS, Q, MaxRecurse); default: if (Constant *CLHS = dyn_cast<Constant>(LHS)) - if (Constant *CRHS = dyn_cast<Constant>(RHS)) { - Constant *COps[] = {CLHS, CRHS}; - return ConstantFoldInstOperands(Opcode, LHS->getType(), COps, Q.DL, - Q.TLI); - } + if (Constant *CRHS = dyn_cast<Constant>(RHS)) + return ConstantFoldBinaryOpOperands(Opcode, CLHS, CRHS, Q.DL); // If the operation is associative, try some generic simplifications. 
if (Instruction::isAssociative(Opcode)) @@ -3825,6 +3888,78 @@ static bool IsIdempotent(Intrinsic::ID ID) { } } +static Value *SimplifyRelativeLoad(Constant *Ptr, Constant *Offset, + const DataLayout &DL) { + GlobalValue *PtrSym; + APInt PtrOffset; + if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL)) + return nullptr; + + Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext()); + Type *Int32Ty = Type::getInt32Ty(Ptr->getContext()); + Type *Int32PtrTy = Int32Ty->getPointerTo(); + Type *Int64Ty = Type::getInt64Ty(Ptr->getContext()); + + auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset); + if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64) + return nullptr; + + uint64_t OffsetInt = OffsetConstInt->getSExtValue(); + if (OffsetInt % 4 != 0) + return nullptr; + + Constant *C = ConstantExpr::getGetElementPtr( + Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy), + ConstantInt::get(Int64Ty, OffsetInt / 4)); + Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL); + if (!Loaded) + return nullptr; + + auto *LoadedCE = dyn_cast<ConstantExpr>(Loaded); + if (!LoadedCE) + return nullptr; + + if (LoadedCE->getOpcode() == Instruction::Trunc) { + LoadedCE = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); + if (!LoadedCE) + return nullptr; + } + + if (LoadedCE->getOpcode() != Instruction::Sub) + return nullptr; + + auto *LoadedLHS = dyn_cast<ConstantExpr>(LoadedCE->getOperand(0)); + if (!LoadedLHS || LoadedLHS->getOpcode() != Instruction::PtrToInt) + return nullptr; + auto *LoadedLHSPtr = LoadedLHS->getOperand(0); + + Constant *LoadedRHS = LoadedCE->getOperand(1); + GlobalValue *LoadedRHSSym; + APInt LoadedRHSOffset; + if (!IsConstantOffsetFromGlobal(LoadedRHS, LoadedRHSSym, LoadedRHSOffset, + DL) || + PtrSym != LoadedRHSSym || PtrOffset != LoadedRHSOffset) + return nullptr; + + return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); +} + +static bool maskIsAllZeroOrUndef(Value *Mask) { + auto *ConstMask = dyn_cast<Constant>(Mask); + if (!ConstMask) + return false; + if (ConstMask->isNullValue() || isa<UndefValue>(ConstMask)) + return true; + for (unsigned I = 0, E = ConstMask->getType()->getVectorNumElements(); I != E; + ++I) { + if (auto *MaskElt = ConstMask->getAggregateElement(I)) + if (MaskElt->isNullValue() || isa<UndefValue>(MaskElt)) + continue; + return false; + } + return true; +} + template <typename IterTy> static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, const Query &Q, unsigned MaxRecurse) { @@ -3865,6 +4000,20 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, if (match(RHS, m_Undef())) return Constant::getNullValue(ReturnType); } + + if (IID == Intrinsic::load_relative && isa<Constant>(LHS) && + isa<Constant>(RHS)) + return SimplifyRelativeLoad(cast<Constant>(LHS), cast<Constant>(RHS), + Q.DL); + } + + // Simplify calls to llvm.masked.load.* + if (IID == Intrinsic::masked_load) { + Value *MaskArg = ArgBegin[2]; + Value *PassthruArg = ArgBegin[3]; + // If the mask is all zeros or undef, the "passthru" argument is the result. 
+ if (maskIsAllZeroOrUndef(MaskArg)) + return PassthruArg; } // Perform idempotent optimizations @@ -3889,7 +4038,8 @@ static Value *SimplifyCall(Value *V, IterTy ArgBegin, IterTy ArgEnd, FunctionType *FTy = cast<FunctionType>(Ty); // call undef -> undef - if (isa<UndefValue>(V)) + // call null -> undef + if (isa<UndefValue>(V) || isa<ConstantPointerNull>(V)) return UndefValue::get(FTy->getReturnType()); Function *F = dyn_cast<Function>(V); @@ -4038,7 +4188,8 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, break; case Instruction::GetElementPtr: { SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); - Result = SimplifyGEPInst(Ops, DL, TLI, DT, AC, I); + Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), + Ops, DL, TLI, DT, AC, I); break; } case Instruction::InsertValue: { @@ -4092,7 +4243,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const DataLayout &DL, return Result == I ? UndefValue::get(I->getType()) : Result; } -/// \brief Implementation of recursive simplification through an instructions +/// \brief Implementation of recursive simplification through an instruction's /// uses. /// /// This is the common implementation of the recursive simplification routines. diff --git a/lib/Analysis/Interval.cpp b/lib/Analysis/Interval.cpp index e3e785ffc45f..6c10d73bcb44 100644 --- a/lib/Analysis/Interval.cpp +++ b/lib/Analysis/Interval.cpp @@ -42,17 +42,14 @@ void Interval::print(raw_ostream &OS) const { << "Interval Contents:\n"; // Print out all of the basic blocks in the interval... - for (std::vector<BasicBlock*>::const_iterator I = Nodes.begin(), - E = Nodes.end(); I != E; ++I) - OS << **I << "\n"; + for (const BasicBlock *Node : Nodes) + OS << *Node << "\n"; OS << "Interval Predecessors:\n"; - for (std::vector<BasicBlock*>::const_iterator I = Predecessors.begin(), - E = Predecessors.end(); I != E; ++I) - OS << **I << "\n"; + for (const BasicBlock *Predecessor : Predecessors) + OS << *Predecessor << "\n"; OS << "Interval Successors:\n"; - for (std::vector<BasicBlock*>::const_iterator I = Successors.begin(), - E = Successors.end(); I != E; ++I) - OS << **I << "\n"; + for (const BasicBlock *Successor : Successors) + OS << *Successor << "\n"; } diff --git a/lib/Analysis/IntervalPartition.cpp b/lib/Analysis/IntervalPartition.cpp index a0583e86d185..a4e56e0694bc 100644 --- a/lib/Analysis/IntervalPartition.cpp +++ b/lib/Analysis/IntervalPartition.cpp @@ -57,9 +57,8 @@ void IntervalPartition::addIntervalToPartition(Interval *I) { // void IntervalPartition::updatePredecessors(Interval *Int) { BasicBlock *Header = Int->getHeaderNode(); - for (Interval::succ_iterator I = Int->Successors.begin(), - E = Int->Successors.end(); I != E; ++I) - getBlockInterval(*I)->Predecessors.push_back(Header); + for (BasicBlock *Successor : Int->Successors) + getBlockInterval(Successor)->Predecessors.push_back(Header); } // IntervalPartition ctor - Build the first level interval partition for the diff --git a/lib/Analysis/IteratedDominanceFrontier.cpp b/lib/Analysis/IteratedDominanceFrontier.cpp index 9f1edd21820f..3ab6b5d60905 100644 --- a/lib/Analysis/IteratedDominanceFrontier.cpp +++ b/lib/Analysis/IteratedDominanceFrontier.cpp @@ -16,9 +16,10 @@ #include "llvm/IR/Dominators.h" #include <queue> -using namespace llvm; - -void IDFCalculator::calculate(SmallVectorImpl<BasicBlock *> &PHIBlocks) { +namespace llvm { +template <class NodeTy> +void IDFCalculator<NodeTy>::calculate( + SmallVectorImpl<BasicBlock *> &PHIBlocks) { // If we haven't computed dominator 
tree levels, do so now. if (DomLevels.empty()) { for (auto DFI = df_begin(DT.getRootNode()), DFE = df_end(DT.getRootNode()); @@ -61,8 +62,12 @@ void IDFCalculator::calculate(SmallVectorImpl<BasicBlock *> &PHIBlocks) { while (!Worklist.empty()) { DomTreeNode *Node = Worklist.pop_back_val(); BasicBlock *BB = Node->getBlock(); - - for (auto Succ : successors(BB)) { + // Succ is the successor in the direction we are calculating IDF, so it is + // a successor for IDF, and a predecessor for Reverse IDF. + for (auto SuccIter = GraphTraits<NodeTy>::child_begin(BB), + End = GraphTraits<NodeTy>::child_end(BB); + SuccIter != End; ++SuccIter) { + BasicBlock *Succ = *SuccIter; DomTreeNode *SuccNode = DT.getNode(Succ); // Quickly skip all CFG edges that are also dominator tree edges instead @@ -93,3 +98,7 @@ void IDFCalculator::calculate(SmallVectorImpl<BasicBlock *> &PHIBlocks) { } } } + +template class IDFCalculator<BasicBlock *>; +template class IDFCalculator<Inverse<BasicBlock *>>; +} diff --git a/lib/Analysis/LLVMBuild.txt b/lib/Analysis/LLVMBuild.txt index bddf1a3ac201..08af5f37700d 100644 --- a/lib/Analysis/LLVMBuild.txt +++ b/lib/Analysis/LLVMBuild.txt @@ -19,4 +19,4 @@ type = Library name = Analysis parent = Libraries -required_libraries = Core Support +required_libraries = Core Support ProfileData diff --git a/lib/Analysis/LazyBlockFrequencyInfo.cpp b/lib/Analysis/LazyBlockFrequencyInfo.cpp new file mode 100644 index 000000000000..7debfde87d2a --- /dev/null +++ b/lib/Analysis/LazyBlockFrequencyInfo.cpp @@ -0,0 +1,68 @@ +//===- LazyBlockFrequencyInfo.cpp - Lazy Block Frequency Analysis ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This is an alternative analysis pass to BlockFrequencyInfoWrapperPass. The +// difference is that with this pass the block frequencies are not computed when +// the analysis pass is executed but rather when the BFI results are explicitly +// requested by the analysis client.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/LoopInfo.h" + +using namespace llvm; + +#define DEBUG_TYPE "lazy-block-freq" + +INITIALIZE_PASS_BEGIN(LazyBlockFrequencyInfoPass, DEBUG_TYPE, + "Lazy Block Frequency Analysis", true, true) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LazyBlockFrequencyInfoPass, DEBUG_TYPE, + "Lazy Block Frequency Analysis", true, true) + +char LazyBlockFrequencyInfoPass::ID = 0; + +LazyBlockFrequencyInfoPass::LazyBlockFrequencyInfoPass() : FunctionPass(ID) { + initializeLazyBlockFrequencyInfoPassPass(*PassRegistry::getPassRegistry()); +} + +void LazyBlockFrequencyInfoPass::print(raw_ostream &OS, const Module *) const { + LBFI.getCalculated().print(OS); +} + +void LazyBlockFrequencyInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.setPreservesAll(); +} + +void LazyBlockFrequencyInfoPass::releaseMemory() { LBFI.releaseMemory(); } + +bool LazyBlockFrequencyInfoPass::runOnFunction(Function &F) { + BranchProbabilityInfo &BPI = + getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + LBFI.setAnalysis(&F, &BPI, &LI); + return false; +} + +void LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AnalysisUsage &AU) { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); + AU.addRequired<LazyBlockFrequencyInfoPass>(); + AU.addRequired<LoopInfoWrapperPass>(); +} + +void llvm::initializeLazyBFIPassPass(PassRegistry &Registry) { + INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass); + INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass); + INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass); +} diff --git a/lib/Analysis/LazyCallGraph.cpp b/lib/Analysis/LazyCallGraph.cpp index 0f0f31e62ac7..acff8529b151 100644 --- a/lib/Analysis/LazyCallGraph.cpp +++ b/lib/Analysis/LazyCallGraph.cpp @@ -14,36 +14,41 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/GraphWriter.h" using namespace llvm; #define DEBUG_TYPE "lcg" -static void findCallees( - SmallVectorImpl<Constant *> &Worklist, SmallPtrSetImpl<Constant *> &Visited, - SmallVectorImpl<PointerUnion<Function *, LazyCallGraph::Node *>> &Callees, - DenseMap<Function *, size_t> &CalleeIndexMap) { +static void addEdge(SmallVectorImpl<LazyCallGraph::Edge> &Edges, + DenseMap<Function *, int> &EdgeIndexMap, Function &F, + LazyCallGraph::Edge::Kind EK) { + // Note that we consider *any* function with a definition to be a viable + // edge. Even if the function's definition is subject to replacement by + // some other module (say, a weak definition) there may still be + // optimizations which essentially speculate based on the definition and + // a way to check that the specific definition is in fact the one being + // used. For example, this could be done by moving the weak definition to + // a strong (internal) definition and making the weak definition be an + // alias. Then a test of the address of the weak function against the new + // strong definition's address would be an effective way to determine the + // safety of optimizing a direct call edge. 
+ if (!F.isDeclaration() && + EdgeIndexMap.insert({&F, Edges.size()}).second) { + DEBUG(dbgs() << " Added callable function: " << F.getName() << "\n"); + Edges.emplace_back(LazyCallGraph::Edge(F, EK)); + } +} + +static void findReferences(SmallVectorImpl<Constant *> &Worklist, + SmallPtrSetImpl<Constant *> &Visited, + SmallVectorImpl<LazyCallGraph::Edge> &Edges, + DenseMap<Function *, int> &EdgeIndexMap) { while (!Worklist.empty()) { Constant *C = Worklist.pop_back_val(); if (Function *F = dyn_cast<Function>(C)) { - // Note that we consider *any* function with a definition to be a viable - // edge. Even if the function's definition is subject to replacement by - // some other module (say, a weak definition) there may still be - // optimizations which essentially speculate based on the definition and - // a way to check that the specific definition is in fact the one being - // used. For example, this could be done by moving the weak definition to - // a strong (internal) definition and making the weak definition be an - // alias. Then a test of the address of the weak function against the new - // strong definition's address would be an effective way to determine the - // safety of optimizing a direct call edge. - if (!F->isDeclaration() && - CalleeIndexMap.insert(std::make_pair(F, Callees.size())).second) { - DEBUG(dbgs() << " Added callable function: " << F->getName() - << "\n"); - Callees.push_back(F); - } + addEdge(Edges, EdgeIndexMap, *F, LazyCallGraph::Edge::Ref); continue; } @@ -59,42 +64,63 @@ LazyCallGraph::Node::Node(LazyCallGraph &G, Function &F) << "' to the graph.\n"); SmallVector<Constant *, 16> Worklist; + SmallPtrSet<Function *, 4> Callees; SmallPtrSet<Constant *, 16> Visited; - // Find all the potential callees in this function. First walk the - // instructions and add every operand which is a constant to the worklist. + + // Find all the potential call graph edges in this function. We track both + // actual call edges and indirect references to functions. The direct calls + // are trivially added, but to accumulate the latter we walk the instructions + // and add every operand which is a constant to the worklist to process + // afterward. for (BasicBlock &BB : F) - for (Instruction &I : BB) + for (Instruction &I : BB) { + if (auto CS = CallSite(&I)) + if (Function *Callee = CS.getCalledFunction()) + if (Callees.insert(Callee).second) { + Visited.insert(Callee); + addEdge(Edges, EdgeIndexMap, *Callee, LazyCallGraph::Edge::Call); + } + for (Value *Op : I.operand_values()) if (Constant *C = dyn_cast<Constant>(Op)) if (Visited.insert(C).second) Worklist.push_back(C); + } // We've collected all the constant (and thus potentially function or // function containing) operands to all of the instructions in the function. // Process them (recursively) collecting every function found. 
- findCallees(Worklist, Visited, Callees, CalleeIndexMap); + findReferences(Worklist, Visited, Edges, EdgeIndexMap); } -void LazyCallGraph::Node::insertEdgeInternal(Function &Callee) { - if (Node *N = G->lookup(Callee)) - return insertEdgeInternal(*N); +void LazyCallGraph::Node::insertEdgeInternal(Function &Target, Edge::Kind EK) { + if (Node *N = G->lookup(Target)) + return insertEdgeInternal(*N, EK); + + EdgeIndexMap.insert({&Target, Edges.size()}); + Edges.emplace_back(Target, EK); +} - CalleeIndexMap.insert(std::make_pair(&Callee, Callees.size())); - Callees.push_back(&Callee); +void LazyCallGraph::Node::insertEdgeInternal(Node &TargetN, Edge::Kind EK) { + EdgeIndexMap.insert({&TargetN.getFunction(), Edges.size()}); + Edges.emplace_back(TargetN, EK); } -void LazyCallGraph::Node::insertEdgeInternal(Node &CalleeN) { - CalleeIndexMap.insert(std::make_pair(&CalleeN.getFunction(), Callees.size())); - Callees.push_back(&CalleeN); +void LazyCallGraph::Node::setEdgeKind(Function &TargetF, Edge::Kind EK) { + Edges[EdgeIndexMap.find(&TargetF)->second].setKind(EK); } -void LazyCallGraph::Node::removeEdgeInternal(Function &Callee) { - auto IndexMapI = CalleeIndexMap.find(&Callee); - assert(IndexMapI != CalleeIndexMap.end() && - "Callee not in the callee set for this caller?"); +void LazyCallGraph::Node::removeEdgeInternal(Function &Target) { + auto IndexMapI = EdgeIndexMap.find(&Target); + assert(IndexMapI != EdgeIndexMap.end() && + "Target not in the edge set for this caller?"); - Callees[IndexMapI->second] = nullptr; - CalleeIndexMap.erase(IndexMapI); + Edges[IndexMapI->second] = Edge(); + EdgeIndexMap.erase(IndexMapI); +} + +void LazyCallGraph::Node::dump() const { + dbgs() << *this << '\n'; } LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { @@ -102,10 +128,10 @@ LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { << "\n"); for (Function &F : M) if (!F.isDeclaration() && !F.hasLocalLinkage()) - if (EntryIndexMap.insert(std::make_pair(&F, EntryNodes.size())).second) { + if (EntryIndexMap.insert({&F, EntryEdges.size()}).second) { DEBUG(dbgs() << " Adding '" << F.getName() << "' to entry set of the graph.\n"); - EntryNodes.push_back(&F); + EntryEdges.emplace_back(F, Edge::Ref); } // Now add entry nodes for functions reachable via initializers to globals. 
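For reference, the entry-set rule used by the constructor above can be sketched standalone: a function seeds the graph only when it has a body and is externally visible; everything else is discovered lazily through call and reference edges. Toy types again, not the LLVM API:

#include <string>
#include <unordered_map>
#include <vector>

struct Function {
  std::string Name;
  bool IsDeclaration;   // no body in this module
  bool HasLocalLinkage; // 'static'-like: only reachable via references
};

// Mirrors the constructor loop: definitions with external linkage become
// entry edges, deduplicated through the index map.
std::vector<const Function *>
collectEntryFunctions(const std::vector<Function> &Module) {
  std::unordered_map<const Function *, int> EntryIndex;
  std::vector<const Function *> EntryEdges;
  for (const Function &F : Module)
    if (!F.IsDeclaration && !F.HasLocalLinkage)
      if (EntryIndex.insert({&F, (int)EntryEdges.size()}).second)
        EntryEdges.push_back(&F);
  return EntryEdges;
}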
@@ -118,25 +144,19 @@ LazyCallGraph::LazyCallGraph(Module &M) : NextDFSNumber(0) { DEBUG(dbgs() << " Adding functions referenced by global initializers to the " "entry set.\n"); - findCallees(Worklist, Visited, EntryNodes, EntryIndexMap); + findReferences(Worklist, Visited, EntryEdges, EntryIndexMap); - for (auto &Entry : EntryNodes) { - assert(!Entry.isNull() && - "We can't have removed edges before we finish the constructor!"); - if (Function *F = Entry.dyn_cast<Function *>()) - SCCEntryNodes.push_back(F); - else - SCCEntryNodes.push_back(&Entry.get<Node *>()->getFunction()); - } + for (const Edge &E : EntryEdges) + RefSCCEntryNodes.push_back(&E.getFunction()); } LazyCallGraph::LazyCallGraph(LazyCallGraph &&G) : BPA(std::move(G.BPA)), NodeMap(std::move(G.NodeMap)), - EntryNodes(std::move(G.EntryNodes)), + EntryEdges(std::move(G.EntryEdges)), EntryIndexMap(std::move(G.EntryIndexMap)), SCCBPA(std::move(G.SCCBPA)), - SCCMap(std::move(G.SCCMap)), LeafSCCs(std::move(G.LeafSCCs)), + SCCMap(std::move(G.SCCMap)), LeafRefSCCs(std::move(G.LeafRefSCCs)), DFSStack(std::move(G.DFSStack)), - SCCEntryNodes(std::move(G.SCCEntryNodes)), + RefSCCEntryNodes(std::move(G.RefSCCEntryNodes)), NextDFSNumber(G.NextDFSNumber) { updateGraphPtrs(); } @@ -144,402 +164,1080 @@ LazyCallGraph::LazyCallGraph(LazyCallGraph &&G) LazyCallGraph &LazyCallGraph::operator=(LazyCallGraph &&G) { BPA = std::move(G.BPA); NodeMap = std::move(G.NodeMap); - EntryNodes = std::move(G.EntryNodes); + EntryEdges = std::move(G.EntryEdges); EntryIndexMap = std::move(G.EntryIndexMap); SCCBPA = std::move(G.SCCBPA); SCCMap = std::move(G.SCCMap); - LeafSCCs = std::move(G.LeafSCCs); + LeafRefSCCs = std::move(G.LeafRefSCCs); DFSStack = std::move(G.DFSStack); - SCCEntryNodes = std::move(G.SCCEntryNodes); + RefSCCEntryNodes = std::move(G.RefSCCEntryNodes); NextDFSNumber = G.NextDFSNumber; updateGraphPtrs(); return *this; } -void LazyCallGraph::SCC::insert(Node &N) { - N.DFSNumber = N.LowLink = -1; - Nodes.push_back(&N); - G->SCCMap[&N] = this; +void LazyCallGraph::SCC::dump() const { + dbgs() << *this << '\n'; } -bool LazyCallGraph::SCC::isDescendantOf(const SCC &C) const { +#ifndef NDEBUG +void LazyCallGraph::SCC::verify() { + assert(OuterRefSCC && "Can't have a null RefSCC!"); + assert(!Nodes.empty() && "Can't have an empty SCC!"); + + for (Node *N : Nodes) { + assert(N && "Can't have a null node!"); + assert(OuterRefSCC->G->lookupSCC(*N) == this && + "Node does not map to this SCC!"); + assert(N->DFSNumber == -1 && + "Must set DFS numbers to -1 when adding a node to an SCC!"); + assert(N->LowLink == -1 && + "Must set low link to -1 when adding a node to an SCC!"); + for (Edge &E : *N) + assert(E.getNode() && "Can't have an edge to a raw function!"); + } +} +#endif + +LazyCallGraph::RefSCC::RefSCC(LazyCallGraph &G) : G(&G) {} + +void LazyCallGraph::RefSCC::dump() const { + dbgs() << *this << '\n'; +} + +#ifndef NDEBUG +void LazyCallGraph::RefSCC::verify() { + assert(G && "Can't have a null graph!"); + assert(!SCCs.empty() && "Can't have an empty SCC!"); + + // Verify basic properties of the SCCs. + for (SCC *C : SCCs) { + assert(C && "Can't have a null SCC!"); + C->verify(); + assert(&C->getOuterRefSCC() == this && + "SCC doesn't think it is inside this RefSCC!"); + } + + // Check that our indices map correctly. 
+ for (auto &SCCIndexPair : SCCIndices) { + SCC *C = SCCIndexPair.first; + int i = SCCIndexPair.second; + assert(C && "Can't have a null SCC in the indices!"); + assert(SCCs[i] == C && "Index doesn't point to SCC!"); + } + + // Check that the SCCs are in fact in post-order. + for (int i = 0, Size = SCCs.size(); i < Size; ++i) { + SCC &SourceSCC = *SCCs[i]; + for (Node &N : SourceSCC) + for (Edge &E : N) { + if (!E.isCall()) + continue; + SCC &TargetSCC = *G->lookupSCC(*E.getNode()); + if (&TargetSCC.getOuterRefSCC() == this) { + assert(SCCIndices.find(&TargetSCC)->second <= i && + "Edge between SCCs violates post-order relationship."); + continue; + } + assert(TargetSCC.getOuterRefSCC().Parents.count(this) && + "Edge to a RefSCC missing us in its parent set."); + } + } +} +#endif + +bool LazyCallGraph::RefSCC::isDescendantOf(const RefSCC &C) const { // Walk up the parents of this SCC and verify that we eventually find C. - SmallVector<const SCC *, 4> AncestorWorklist; + SmallVector<const RefSCC *, 4> AncestorWorklist; AncestorWorklist.push_back(this); do { - const SCC *AncestorC = AncestorWorklist.pop_back_val(); + const RefSCC *AncestorC = AncestorWorklist.pop_back_val(); if (AncestorC->isChildOf(C)) return true; - for (const SCC *ParentC : AncestorC->ParentSCCs) + for (const RefSCC *ParentC : AncestorC->Parents) AncestorWorklist.push_back(ParentC); } while (!AncestorWorklist.empty()); return false; } -void LazyCallGraph::SCC::insertIntraSCCEdge(Node &CallerN, Node &CalleeN) { - // First insert it into the caller. - CallerN.insertEdgeInternal(CalleeN); +SmallVector<LazyCallGraph::SCC *, 1> +LazyCallGraph::RefSCC::switchInternalEdgeToCall(Node &SourceN, Node &TargetN) { + assert(!SourceN[TargetN].isCall() && "Must start with a ref edge!"); + + SmallVector<SCC *, 1> DeletedSCCs; + + SCC &SourceSCC = *G->lookupSCC(SourceN); + SCC &TargetSCC = *G->lookupSCC(TargetN); + + // If the two nodes are already part of the same SCC, we're also done as + // we've just added more connectivity. + if (&SourceSCC == &TargetSCC) { + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif + return DeletedSCCs; + } + + // At this point we leverage the postorder list of SCCs to detect when the + // insertion of an edge changes the SCC structure in any way. + // + // First and foremost, we can eliminate the need for any changes when the + // edge is toward the beginning of the postorder sequence because all edges + // flow in that direction already. Thus adding a new one cannot form a cycle. + int SourceIdx = SCCIndices[&SourceSCC]; + int TargetIdx = SCCIndices[&TargetSCC]; + if (TargetIdx < SourceIdx) { + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif + return DeletedSCCs; + } + + // When we do have an edge from an earlier SCC to a later SCC in the + // postorder sequence, all of the SCCs which may be impacted are in the + // closed range of those two within the postorder sequence. The algorithm to + // restore the state is as follows: + // + // 1) Starting from the source SCC, construct a set of SCCs which reach the + // source SCC consisting of just the source SCC. Then scan toward the + // target SCC in postorder and for each SCC, if it has an edge to an SCC + // in the set, add it to the set. 
Otherwise, the source SCC is not + // a successor, move it in the postorder sequence to immediately before + // the source SCC, shifting the source SCC and all SCCs in the set one + // position toward the target SCC. Stop scanning after processing the + // target SCC. + // 2) If the source SCC is now past the target SCC in the postorder sequence, + // and thus the new edge will flow toward the start, we are done. + // 3) Otherwise, starting from the target SCC, walk all edges which reach an + // SCC between the source and the target, and add them to the set of + // connected SCCs, then recurse through them. Once a complete set of the + // SCCs the target connects to is known, hoist the remaining SCCs between + // the source and the target to be above the target. Note that there is no + // need to process the source SCC, it is already known to connect. + // 4) At this point, all of the SCCs in the closed range between the source + // SCC and the target SCC in the postorder sequence are connected, + // including the target SCC and the source SCC. Inserting the edge from + // the source SCC to the target SCC will form a cycle out of precisely + // these SCCs. Thus we can merge all of the SCCs in this closed range into + // a single SCC. + // + // This process has various important properties: + // - Only mutates the SCCs when adding the edge actually changes the SCC + // structure. + // - Never mutates SCCs which are unaffected by the change. + // - Updates the postorder sequence to correctly satisfy the postorder + // constraint after the edge is inserted. + // - Only reorders SCCs in the closed postorder sequence from the source to + // the target, so it is easy to bound how much has changed even in the + // ordering. + // - Big-O is the number of edges in the closed postorder range of SCCs from + // source to target. + + assert(SourceIdx < TargetIdx && "Cannot have equal indices here!"); + SmallPtrSet<SCC *, 4> ConnectedSet; + + // Compute the SCCs which (transitively) reach the source. + ConnectedSet.insert(&SourceSCC); + auto IsConnected = [&](SCC &C) { + for (Node &N : C) + for (Edge &E : N.calls()) { + assert(E.getNode() && "Must have formed a node within an SCC!"); + if (ConnectedSet.count(G->lookupSCC(*E.getNode()))) + return true; + } + + return false; + }; + + for (SCC *C : + make_range(SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1)) + if (IsConnected(*C)) + ConnectedSet.insert(C); + + // Partition the SCCs in this part of the post-order sequence so only SCCs + // connecting to the source remain between it and the target. This is + // a benign partition as it preserves postorder. + auto SourceI = std::stable_partition( + SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx + 1, + [&ConnectedSet](SCC *C) { return !ConnectedSet.count(C); }); + for (int i = SourceIdx, e = TargetIdx + 1; i < e; ++i) + SCCIndices.find(SCCs[i])->second = i; +
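The partition above is only "benign" because std::stable_partition preserves the relative order of the elements on each side of the predicate, so no postorder constraint among the untouched SCCs can break. A small self-contained demonstration with illustrative indices:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Postorder positions 0..5; pretend the SCCs at positions 1 (the source)
  // and 3 are the ones found to reach the source.
  std::vector<int> PostOrder = {0, 1, 2, 3, 4, 5};
  auto Connected = [](int C) { return C == 1 || C == 3; };

  // Unconnected SCCs first, connected SCCs last; the relative order inside
  // each group is preserved, exactly what the update above relies on.
  std::stable_partition(PostOrder.begin() + 1, PostOrder.end(),
                        [&](int C) { return !Connected(C); });
  assert((PostOrder == std::vector<int>{0, 2, 4, 5, 1, 3}));
}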
+ // If the target doesn't connect to the source, then we've corrected the + // post-order and there are no cycles formed. + if (!ConnectedSet.count(&TargetSCC)) { + assert(SourceI > (SCCs.begin() + SourceIdx) && + "Must have moved the source to fix the post-order."); + assert(*std::prev(SourceI) == &TargetSCC && + "Last SCC to move should have been the target."); + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); +#ifndef NDEBUG + verify(); +#endif + return DeletedSCCs; + } + + assert(SCCs[TargetIdx] == &TargetSCC && + "Should not have moved target if connected!"); + SourceIdx = SourceI - SCCs.begin(); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif + + // See whether there are any remaining intervening SCCs between the source + // and target. If so, we need to make sure they are all reachable from the + // target. + if (SourceIdx + 1 < TargetIdx) { + // Use a normal worklist to find which SCCs the target connects to. We still + // bound the search based on the range in the postorder list we care about, + // but because this is forward connectivity we just "recurse" through the + // edges. + ConnectedSet.clear(); + ConnectedSet.insert(&TargetSCC); + SmallVector<SCC *, 4> Worklist; + Worklist.push_back(&TargetSCC); + do { + SCC &C = *Worklist.pop_back_val(); + for (Node &N : C) + for (Edge &E : N) { + assert(E.getNode() && "Must have formed a node within an SCC!"); + if (!E.isCall()) + continue; + SCC &EdgeC = *G->lookupSCC(*E.getNode()); + if (&EdgeC.getOuterRefSCC() != this) + // Not in this RefSCC... + continue; + if (SCCIndices.find(&EdgeC)->second <= SourceIdx) + // Not in the postorder sequence between source and target. + continue; + + if (ConnectedSet.insert(&EdgeC).second) + Worklist.push_back(&EdgeC); + } + } while (!Worklist.empty()); + + // Partition SCCs so that only SCCs reached from the target remain between + // the source and the target. This preserves postorder. + auto TargetI = std::stable_partition( + SCCs.begin() + SourceIdx + 1, SCCs.begin() + TargetIdx + 1, + [&ConnectedSet](SCC *C) { return ConnectedSet.count(C); }); + for (int i = SourceIdx + 1, e = TargetIdx + 1; i < e; ++i) + SCCIndices.find(SCCs[i])->second = i; + TargetIdx = std::prev(TargetI) - SCCs.begin(); + assert(SCCs[TargetIdx] == &TargetSCC && + "Should always end with the target!"); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif + } + + // At this point, we know that connecting source to target forms a cycle + // because target connects back to source, and we know that all of the SCCs + // between the source and target in the postorder sequence participate in that + // cycle. This means that we need to merge all of these SCCs into a single + // result SCC. + // + // NB: We merge into the target because all of these functions were already + // reachable from the target, meaning any SCC-wide properties deduced about it + // other than the set of functions within it will not have changed. + auto MergeRange = + make_range(SCCs.begin() + SourceIdx, SCCs.begin() + TargetIdx); + for (SCC *C : MergeRange) { + assert(C != &TargetSCC && + "We merge *into* the target and shouldn't process it here!"); + SCCIndices.erase(C); + TargetSCC.Nodes.append(C->Nodes.begin(), C->Nodes.end()); + for (Node *N : C->Nodes) + G->SCCMap[N] = &TargetSCC; + C->clear(); + DeletedSCCs.push_back(C); + } + + // Erase the merged SCCs from the list and update the indices of the + // remaining SCCs.
+ int IndexOffset = MergeRange.end() - MergeRange.begin(); + auto EraseEnd = SCCs.erase(MergeRange.begin(), MergeRange.end()); + for (SCC *C : make_range(EraseEnd, SCCs.end())) + SCCIndices[C] -= IndexOffset; - assert(G->SCCMap.lookup(&CallerN) == this && "Caller must be in this SCC."); - assert(G->SCCMap.lookup(&CalleeN) == this && "Callee must be in this SCC."); + // Now that the SCC structure is finalized, flip the kind to call. + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); - // Nothing changes about this SCC or any other. +#ifndef NDEBUG + // And we're done! Verify in debug builds that the RefSCC is coherent. + verify(); +#endif + return DeletedSCCs; +} + +void LazyCallGraph::RefSCC::switchInternalEdgeToRef(Node &SourceN, + Node &TargetN) { + assert(SourceN[TargetN].isCall() && "Must start with a call edge!"); + + SCC &SourceSCC = *G->lookupSCC(SourceN); + SCC &TargetSCC = *G->lookupSCC(TargetN); + + assert(&SourceSCC.getOuterRefSCC() == this && + "Source must be in this RefSCC."); + assert(&TargetSCC.getOuterRefSCC() == this && + "Target must be in this RefSCC."); + + // Set the edge kind. + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Ref); + + // If this call edge is just connecting two separate SCCs within this RefSCC, + // there is nothing to do. + if (&SourceSCC != &TargetSCC) { +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif + return; + } + + // Otherwise we are removing a call edge from a single SCC. This may break + // the cycle. In order to compute the new set of SCCs, we need to do a small + // DFS over the nodes within the SCC to form any sub-cycles that remain as + // distinct SCCs and compute a postorder over the resulting SCCs. + // + // However, we specially handle the target node. The target node is known to + // reach all other nodes in the original SCC by definition. This means that + // we want the old SCC to be replaced with an SCC containing that node as it + // will be the root of whatever SCC DAG results from the DFS. Assumptions + // about an SCC such as the set of functions called will continue to hold, + // etc. + + SCC &OldSCC = TargetSCC; + SmallVector<std::pair<Node *, call_edge_iterator>, 16> DFSStack; + SmallVector<Node *, 16> PendingSCCStack; + SmallVector<SCC *, 4> NewSCCs; + + // Prepare the nodes for a fresh DFS. + SmallVector<Node *, 16> Worklist; + Worklist.swap(OldSCC.Nodes); + for (Node *N : Worklist) { + N->DFSNumber = N->LowLink = 0; + G->SCCMap.erase(N); + } + + // Force the target node to be in the old SCC. This also enables us to take + // a very significant short-cut in the standard Tarjan walk to re-form SCCs + // below: whenever we build an edge that reaches the target node, we know + // that the target node eventually connects back to all other nodes in our + // walk. As a consequence, we can detect and handle participants in that + // cycle without walking all the edges that form this connection, and instead + // by relying on the fundamental guarantee coming into this operation (all + // nodes are reachable from the target due to previously forming an SCC). + TargetN.DFSNumber = TargetN.LowLink = -1; + OldSCC.Nodes.push_back(&TargetN); + G->SCCMap[&TargetN] = &OldSCC; + + // Scan down the stack and DFS across the call edges.
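The scan that follows is Tarjan's SCC algorithm with the recursion unrolled onto the explicit DFSStack. As a compact reference, here is the same scheme in its textbook recursive form over a toy adjacency list, reusing the convention that DFSNumber == 0 means unvisited (a sketch for orientation, not the patch's code):

#include <algorithm>
#include <vector>

struct TarjanSCC {
  const std::vector<std::vector<int>> &Adj;
  std::vector<int> DFSNumber, LowLink;
  std::vector<int> Pending;           // nodes awaiting SCC assignment
  std::vector<char> OnPending;
  std::vector<std::vector<int>> SCCs; // emitted in postorder
  int NextDFSNumber = 1;

  explicit TarjanSCC(const std::vector<std::vector<int>> &G)
      : Adj(G), DFSNumber(G.size(), 0), LowLink(G.size(), 0),
        OnPending(G.size(), 0) {
    for (int N = 0, E = (int)Adj.size(); N != E; ++N)
      if (DFSNumber[N] == 0)
        visit(N);
  }

  void visit(int N) {
    DFSNumber[N] = LowLink[N] = NextDFSNumber++;
    Pending.push_back(N);
    OnPending[N] = 1;
    for (int Child : Adj[N]) {
      if (DFSNumber[Child] == 0)
        visit(Child);
      // Children already folded into a finished SCC cannot affect our
      // low link; only pending (on-stack) children can.
      if (OnPending[Child])
        LowLink[N] = std::min(LowLink[N], LowLink[Child]);
    }
    // A root: everything above N on the pending stack is one SCC.
    if (LowLink[N] == DFSNumber[N]) {
      std::vector<int> SCC;
      int M;
      do {
        M = Pending.back();
        Pending.pop_back();
        OnPending[M] = 0;
        SCC.push_back(M);
      } while (M != N);
      SCCs.push_back(std::move(SCC));
    }
  }
};

On the toy graph 0->1, 1->2, 2->0, 3->0 this emits {0, 1, 2} before {3}: callee SCCs precede their callers, which is the leaf-first postorder the code below maintains.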
+ for (Node *RootN : Worklist) { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + // Skip any nodes we've already reached in the DFS. + if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; + } + + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; + + DFSStack.push_back({RootN, RootN->call_begin()}); + do { + Node *N; + call_edge_iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = N->call_end(); + while (I != E) { + Node &ChildN = *I->getNode(); + if (ChildN.DFSNumber == 0) { + // We haven't yet visited this child, so descend, pushing the current + // node onto the stack. + DFSStack.push_back({N, I}); + + assert(!G->SCCMap.count(&ChildN) && + "Found a node with 0 DFS number but already in an SCC!"); + ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; + N = &ChildN; + I = N->call_begin(); + E = N->call_end(); + continue; + } + + // Check for the child already being part of some component. + if (ChildN.DFSNumber == -1) { + if (G->lookupSCC(ChildN) == &OldSCC) { + // If the child is part of the old SCC, we know that it can reach + // every other node, so we have formed a cycle. Pull the entire DFS + // and pending stacks into it. See the comment above about setting + // up the old SCC for why we do this. + int OldSize = OldSCC.size(); + OldSCC.Nodes.push_back(N); + OldSCC.Nodes.append(PendingSCCStack.begin(), PendingSCCStack.end()); + PendingSCCStack.clear(); + while (!DFSStack.empty()) + OldSCC.Nodes.push_back(DFSStack.pop_back_val().first); + for (Node &N : make_range(OldSCC.begin() + OldSize, OldSCC.end())) { + N.DFSNumber = N.LowLink = -1; + G->SCCMap[&N] = &OldSCC; + } + N = nullptr; + break; + } + + // If the child has already been added to some child component, it + // couldn't impact the low-link of this parent because it isn't + // connected, and thus its low-link isn't relevant so skip it. + ++I; + continue; + } + + // Track the lowest linked child as the lowest link for this node. + assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); + if (ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + + // Move to the next edge. + ++I; + } + if (!N) + // Cleared the DFS early, start another round. + break; + + // We've finished processing N and its descendents, put it on our pending + // SCC stack to eventually get merged into an SCC of nodes. + PendingSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) + continue; + + // Otherwise, we've completed an SCC. Append it to our post order list of + // SCCs. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto SCCNodes = make_range( + PendingSCCStack.rbegin(), + std::find_if(PendingSCCStack.rbegin(), PendingSCCStack.rend(), + [RootDFSNumber](Node *N) { + return N->DFSNumber < RootDFSNumber; + })); + + // Form a new SCC out of these nodes and then clear them off our pending + // stack. + NewSCCs.push_back(G->createSCC(*this, SCCNodes)); + for (Node &N : *NewSCCs.back()) { + N.DFSNumber = N.LowLink = -1; + G->SCCMap[&N] = NewSCCs.back(); + } + PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); + } while (!DFSStack.empty()); + } + + // Insert the remaining SCCs before the old one. 
The old SCC can reach all + // other SCCs we form because it contains the target node of the removed edge + // of the old SCC. This means that we will have edges into all of the new + // SCCs, which means the old one must come last for postorder. + int OldIdx = SCCIndices[&OldSCC]; + SCCs.insert(SCCs.begin() + OldIdx, NewSCCs.begin(), NewSCCs.end()); + + // Update the mapping from SCC* to index to use the new SCC*s, and remove the + // old SCC from the mapping. + for (int Idx = OldIdx, Size = SCCs.size(); Idx < Size; ++Idx) + SCCIndices[SCCs[Idx]] = Idx; + +#ifndef NDEBUG + // We're done. Check the validity on our way out. + verify(); +#endif +} + +void LazyCallGraph::RefSCC::switchOutgoingEdgeToCall(Node &SourceN, + Node &TargetN) { + assert(!SourceN[TargetN].isCall() && "Must start with a ref edge!"); + + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) != this && + "Target must not be in this RefSCC."); + assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && + "Target must be a descendant of the Source."); + + // Edges between RefSCCs are the same regardless of call or ref, so we can + // just flip the edge here. + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Call); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif } -void LazyCallGraph::SCC::insertOutgoingEdge(Node &CallerN, Node &CalleeN) { +void LazyCallGraph::RefSCC::switchOutgoingEdgeToRef(Node &SourceN, + Node &TargetN) { + assert(SourceN[TargetN].isCall() && "Must start with a call edge!"); + + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) != this && + "Target must not be in this RefSCC."); + assert(G->lookupRefSCC(TargetN)->isDescendantOf(*this) && + "Target must be a descendant of the Source."); + + // Edges between RefSCCs are the same regardless of call or ref, so we can + // just flip the edge here. + SourceN.setEdgeKind(TargetN.getFunction(), Edge::Ref); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertInternalRefEdge(Node &SourceN, + Node &TargetN) { + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); + assert(G->lookupRefSCC(TargetN) == this && "Target must be in this RefSCC."); + + SourceN.insertEdgeInternal(TargetN, Edge::Ref); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif +} + +void LazyCallGraph::RefSCC::insertOutgoingEdge(Node &SourceN, Node &TargetN, + Edge::Kind EK) { // First insert it into the caller. - CallerN.insertEdgeInternal(CalleeN); + SourceN.insertEdgeInternal(TargetN, EK); - assert(G->SCCMap.lookup(&CallerN) == this && "Caller must be in this SCC."); + assert(G->lookupRefSCC(SourceN) == this && "Source must be in this RefSCC."); - SCC &CalleeC = *G->SCCMap.lookup(&CalleeN); - assert(&CalleeC != this && "Callee must not be in this SCC."); - assert(CalleeC.isDescendantOf(*this) && - "Callee must be a descendant of the Caller."); + RefSCC &TargetC = *G->lookupRefSCC(TargetN); + assert(&TargetC != this && "Target must not be in this RefSCC."); + assert(TargetC.isDescendantOf(*this) && + "Target must be a descendant of the Source."); // The only change required is to add this SCC to the parent set of the // callee. - CalleeC.ParentSCCs.insert(this); + TargetC.Parents.insert(this); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. 
+ verify(); +#endif } -SmallVector<LazyCallGraph::SCC *, 1> -LazyCallGraph::SCC::insertIncomingEdge(Node &CallerN, Node &CalleeN) { - // First insert it into the caller. - CallerN.insertEdgeInternal(CalleeN); +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::insertIncomingRefEdge(Node &SourceN, Node &TargetN) { + assert(G->lookupRefSCC(TargetN) == this && "Target must be in this SCC."); - assert(G->SCCMap.lookup(&CalleeN) == this && "Callee must be in this SCC."); + // We store the RefSCCs found to be connected in postorder so that we can use + // that when merging. We also return this to the caller to allow them to + // invalidate information pertaining to these RefSCCs. + SmallVector<RefSCC *, 1> Connected; - SCC &CallerC = *G->SCCMap.lookup(&CallerN); - assert(&CallerC != this && "Caller must not be in this SCC."); - assert(CallerC.isDescendantOf(*this) && - "Caller must be a descendant of the Callee."); + RefSCC &SourceC = *G->lookupRefSCC(SourceN); + assert(&SourceC != this && "Source must not be in this SCC."); + assert(SourceC.isDescendantOf(*this) && + "Source must be a descendant of the Target."); // The algorithm we use for merging SCCs based on the cycle introduced here - // is to walk the SCC inverted DAG formed by the parent SCC sets. The inverse - // graph has the same cycle properties as the actual DAG of the SCCs, and - // when forming SCCs lazily by a DFS, the bottom of the graph won't exist in - // many cases which should prune the search space. + // is to walk the RefSCC inverted DAG formed by the parent sets. The inverse + // graph has the same cycle properties as the actual DAG of the RefSCCs, and + // when forming RefSCCs lazily by a DFS, the bottom of the graph won't exist + // in many cases which should prune the search space. // - // FIXME: We can get this pruning behavior even after the incremental SCC + // FIXME: We can get this pruning behavior even after the incremental RefSCC // formation by leaving behind (conservative) DFS numberings in the nodes, // and pruning the search with them. These would need to be cleverly updated // during the removal of intra-SCC edges, but could be preserved // conservatively. + // + // FIXME: This operation currently creates ordering stability problems + // because we don't use stably ordered containers for the parent SCCs. - // The set of SCCs that are connected to the caller, and thus will + // The set of RefSCCs that are connected to the parent, and thus will // participate in the merged connected component. - SmallPtrSet<SCC *, 8> ConnectedSCCs; - ConnectedSCCs.insert(this); - ConnectedSCCs.insert(&CallerC); + SmallPtrSet<RefSCC *, 8> ConnectedSet; + ConnectedSet.insert(this); // We build up a DFS stack of the parents chains. - SmallVector<std::pair<SCC *, SCC::parent_iterator>, 8> DFSSCCs; - SmallPtrSet<SCC *, 8> VisitedSCCs; + SmallVector<std::pair<RefSCC *, parent_iterator>, 8> DFSStack; + SmallPtrSet<RefSCC *, 8> Visited; int ConnectedDepth = -1; - SCC *C = this; - parent_iterator I = parent_begin(), E = parent_end(); - for (;;) { + DFSStack.push_back({&SourceC, SourceC.parent_begin()}); + do { + auto DFSPair = DFSStack.pop_back_val(); + RefSCC *C = DFSPair.first; + parent_iterator I = DFSPair.second; + auto E = C->parent_end(); + while (I != E) { - SCC &ParentSCC = *I++; + RefSCC &Parent = *I++; // If we have already processed this parent SCC, skip it, and remember // whether it was connected so we don't have to check the rest of the // stack. 
This also handles when we reach a child of the 'this' SCC (the // callee) which terminates the search. - if (ConnectedSCCs.count(&ParentSCC)) { - ConnectedDepth = std::max<int>(ConnectedDepth, DFSSCCs.size()); + if (ConnectedSet.count(&Parent)) { + assert(ConnectedDepth < (int)DFSStack.size() && + "Cannot have a connected depth greater than the DFS depth!"); + ConnectedDepth = DFSStack.size(); continue; } - if (VisitedSCCs.count(&ParentSCC)) + if (Visited.count(&Parent)) continue; // We fully explore the depth-first space, adding nodes to the connected // set only as we pop them off, so "recurse" by rotating to the parent. - DFSSCCs.push_back(std::make_pair(C, I)); - C = &ParentSCC; - I = ParentSCC.parent_begin(); - E = ParentSCC.parent_end(); + DFSStack.push_back({C, I}); + C = &Parent; + I = C->parent_begin(); + E = C->parent_end(); } // If we've found a connection anywhere below this point on the stack (and // thus up the parent graph from the caller), the current node needs to be // added to the connected set now that we've processed all of its parents. - if ((int)DFSSCCs.size() == ConnectedDepth) { + if ((int)DFSStack.size() == ConnectedDepth) { --ConnectedDepth; // We're finished with this connection. - ConnectedSCCs.insert(C); + bool Inserted = ConnectedSet.insert(C).second; + (void)Inserted; + assert(Inserted && "Cannot insert a refSCC multiple times!"); + Connected.push_back(C); } else { // Otherwise remember that its parents don't ever connect. - assert(ConnectedDepth < (int)DFSSCCs.size() && + assert(ConnectedDepth < (int)DFSStack.size() && "Cannot have a connected depth greater than the DFS depth!"); - VisitedSCCs.insert(C); + Visited.insert(C); } - - if (DFSSCCs.empty()) - break; // We've walked all the parents of the caller transitively. - - // Pop off the prior node and position to unwind the depth first recursion. - std::tie(C, I) = DFSSCCs.pop_back_val(); - E = C->parent_end(); - } + } while (!DFSStack.empty()); // Now that we have identified all of the SCCs which need to be merged into // a connected set with the inserted edge, merge all of them into this SCC. - // FIXME: This operation currently creates ordering stability problems - // because we don't use stably ordered containers for the parent SCCs or the - // connected SCCs. - unsigned NewNodeBeginIdx = Nodes.size(); - for (SCC *C : ConnectedSCCs) { - if (C == this) - continue; - for (SCC *ParentC : C->ParentSCCs) - if (!ConnectedSCCs.count(ParentC)) - ParentSCCs.insert(ParentC); - C->ParentSCCs.clear(); - - for (Node *N : *C) { - for (Node &ChildN : *N) { - SCC &ChildC = *G->SCCMap.lookup(&ChildN); - if (&ChildC != C) - ChildC.ParentSCCs.erase(C); + // We walk the newly connected RefSCCs in the reverse postorder of the parent + // DAG walk above and merge in each of their SCC postorder lists. This + // ensures a merged postorder SCC list. + SmallVector<SCC *, 16> MergedSCCs; + int SCCIndex = 0; + for (RefSCC *C : reverse(Connected)) { + assert(C != this && + "This RefSCC should terminate the DFS without being reached."); + + // Merge the parents which aren't part of the merge into the our parents. + for (RefSCC *ParentC : C->Parents) + if (!ConnectedSet.count(ParentC)) + Parents.insert(ParentC); + C->Parents.clear(); + + // Walk the inner SCCs to update their up-pointer and walk all the edges to + // update any parent sets. + // FIXME: We should try to find a way to avoid this (rather expensive) edge + // walk by updating the parent sets in some other manner. 
+ for (SCC &InnerC : *C) { + InnerC.OuterRefSCC = this; + SCCIndices[&InnerC] = SCCIndex++; + for (Node &N : InnerC) { + G->SCCMap[&N] = &InnerC; + for (Edge &E : N) { + assert(E.getNode() && + "Cannot have a null node within a visited SCC!"); + RefSCC &ChildRC = *G->lookupRefSCC(*E.getNode()); + if (ConnectedSet.count(&ChildRC)) + continue; + ChildRC.Parents.erase(C); + ChildRC.Parents.insert(this); + } } - G->SCCMap[N] = this; - Nodes.push_back(N); } - C->Nodes.clear(); + + // Now merge in the SCCs. We can actually move here so try to reuse storage + // the first time through. + if (MergedSCCs.empty()) + MergedSCCs = std::move(C->SCCs); + else + MergedSCCs.append(C->SCCs.begin(), C->SCCs.end()); + C->SCCs.clear(); } - for (auto I = Nodes.begin() + NewNodeBeginIdx, E = Nodes.end(); I != E; ++I) - for (Node &ChildN : **I) { - SCC &ChildC = *G->SCCMap.lookup(&ChildN); - if (&ChildC != this) - ChildC.ParentSCCs.insert(this); - } + + // Finally append our original SCCs to the merged list and move it into + // place. + for (SCC &InnerC : *this) + SCCIndices[&InnerC] = SCCIndex++; + MergedSCCs.append(SCCs.begin(), SCCs.end()); + SCCs = std::move(MergedSCCs); + + // At this point we have a merged RefSCC with a post-order SCCs list, just + // connect the nodes to form the new edge. + SourceN.insertEdgeInternal(TargetN, Edge::Ref); + +#ifndef NDEBUG + // Check that the RefSCC is still valid. + verify(); +#endif // We return the list of SCCs which were merged so that callers can // invalidate any data they have associated with those SCCs. Note that these // SCCs are no longer in an interesting state (they are totally empty) but // the pointers will remain stable for the life of the graph itself. - return SmallVector<SCC *, 1>(ConnectedSCCs.begin(), ConnectedSCCs.end()); + return Connected; } -void LazyCallGraph::SCC::removeInterSCCEdge(Node &CallerN, Node &CalleeN) { +void LazyCallGraph::RefSCC::removeOutgoingEdge(Node &SourceN, Node &TargetN) { + assert(G->lookupRefSCC(SourceN) == this && + "The source must be a member of this RefSCC."); + + RefSCC &TargetRC = *G->lookupRefSCC(TargetN); + assert(&TargetRC != this && "The target must not be a member of this RefSCC"); + + assert(std::find(G->LeafRefSCCs.begin(), G->LeafRefSCCs.end(), this) == + G->LeafRefSCCs.end() && + "Cannot have a leaf RefSCC source."); + // First remove it from the node. 
- CallerN.removeEdgeInternal(CalleeN.getFunction()); - - assert(G->SCCMap.lookup(&CallerN) == this && - "The caller must be a member of this SCC."); - - SCC &CalleeC = *G->SCCMap.lookup(&CalleeN); - assert(&CalleeC != this && - "This API only supports the rmoval of inter-SCC edges."); - - assert(std::find(G->LeafSCCs.begin(), G->LeafSCCs.end(), this) == - G->LeafSCCs.end() && - "Cannot have a leaf SCC caller with a different SCC callee."); - - bool HasOtherCallToCalleeC = false; - bool HasOtherCallOutsideSCC = false; - for (Node *N : *this) { - for (Node &OtherCalleeN : *N) { - SCC &OtherCalleeC = *G->SCCMap.lookup(&OtherCalleeN); - if (&OtherCalleeC == &CalleeC) { - HasOtherCallToCalleeC = true; - break; + SourceN.removeEdgeInternal(TargetN.getFunction()); + + bool HasOtherEdgeToChildRC = false; + bool HasOtherChildRC = false; + for (SCC *InnerC : SCCs) { + for (Node &N : *InnerC) { + for (Edge &E : N) { + assert(E.getNode() && "Cannot have a missing node in a visited SCC!"); + RefSCC &OtherChildRC = *G->lookupRefSCC(*E.getNode()); + if (&OtherChildRC == &TargetRC) { + HasOtherEdgeToChildRC = true; + break; + } + if (&OtherChildRC != this) + HasOtherChildRC = true; } - if (&OtherCalleeC != this) - HasOtherCallOutsideSCC = true; + if (HasOtherEdgeToChildRC) + break; } - if (HasOtherCallToCalleeC) + if (HasOtherEdgeToChildRC) break; } // Because the SCCs form a DAG, deleting such an edge cannot change the set // of SCCs in the graph. However, it may cut an edge of the SCC DAG, making - // the caller no longer a parent of the callee. Walk the other call edges - // in the caller to tell. - if (!HasOtherCallToCalleeC) { - bool Removed = CalleeC.ParentSCCs.erase(this); + // the source SCC no longer connected to the target SCC. If so, we need to + // update the target SCC's map of its parents. + if (!HasOtherEdgeToChildRC) { + bool Removed = TargetRC.Parents.erase(this); (void)Removed; assert(Removed && - "Did not find the caller SCC in the callee SCC's parent list!"); + "Did not find the source SCC in the target SCC's parent list!"); // It may orphan an SCC if it is the last edge reaching it, but that does // not violate any invariants of the graph. - if (CalleeC.ParentSCCs.empty()) - DEBUG(dbgs() << "LCG: Update removing " << CallerN.getFunction().getName() - << " -> " << CalleeN.getFunction().getName() + if (TargetRC.Parents.empty()) + DEBUG(dbgs() << "LCG: Update removing " << SourceN.getFunction().getName() + << " -> " << TargetN.getFunction().getName() << " edge orphaned the callee's SCC!\n"); - } - // It may make the Caller SCC a leaf SCC. - if (!HasOtherCallOutsideSCC) - G->LeafSCCs.push_back(this); + // It may make the Source SCC a leaf SCC. + if (!HasOtherChildRC) + G->LeafRefSCCs.push_back(this); + } } -void LazyCallGraph::SCC::internalDFS( - SmallVectorImpl<std::pair<Node *, Node::iterator>> &DFSStack, - SmallVectorImpl<Node *> &PendingSCCStack, Node *N, - SmallVectorImpl<SCC *> &ResultSCCs) { - Node::iterator I = N->begin(); - N->LowLink = N->DFSNumber = 1; - int NextDFSNumber = 2; - for (;;) { - assert(N->DFSNumber != 0 && "We should always assign a DFS number " - "before processing a node."); +SmallVector<LazyCallGraph::RefSCC *, 1> +LazyCallGraph::RefSCC::removeInternalRefEdge(Node &SourceN, Node &TargetN) { + assert(!SourceN[TargetN].isCall() && + "Cannot remove a call edge, it must first be made a ref edge"); - // We simulate recursion by popping out of the nested loop and continuing. 
- Node::iterator E = N->end(); - while (I != E) { - Node &ChildN = *I; - if (SCC *ChildSCC = G->SCCMap.lookup(&ChildN)) { - // Check if we have reached a node in the new (known connected) set of - // this SCC. If so, the entire stack is necessarily in that set and we - // can re-start. - if (ChildSCC == this) { - insert(*N); - while (!PendingSCCStack.empty()) - insert(*PendingSCCStack.pop_back_val()); - while (!DFSStack.empty()) - insert(*DFSStack.pop_back_val().first); - return; - } + // First remove the actual edge. + SourceN.removeEdgeInternal(TargetN.getFunction()); - // If this child isn't currently in this SCC, no need to process it. - // However, we do need to remove this SCC from its SCC's parent set. - ChildSCC->ParentSCCs.erase(this); - ++I; - continue; - } + // We return a list of the resulting *new* RefSCCs in post-order. + SmallVector<RefSCC *, 1> Result; - if (ChildN.DFSNumber == 0) { - // Mark that we should start at this child when next this node is the - // top of the stack. We don't start at the next child to ensure this - // child's lowlink is reflected. - DFSStack.push_back(std::make_pair(N, I)); + // Direct recursion doesn't impact the SCC graph at all. + if (&SourceN == &TargetN) + return Result; + + // We build somewhat synthetic new RefSCCs by providing a postorder mapping + // for each inner SCC. We also store these associated with *nodes* rather + // than SCCs because this saves a round-trip through the node->SCC map and in + // the common case, SCCs are small. We will verify that we always give the + // same number to every node in the SCC such that these are equivalent. + const int RootPostOrderNumber = 0; + int PostOrderNumber = RootPostOrderNumber + 1; + SmallDenseMap<Node *, int> PostOrderMapping; + + // Every node in the target SCC can already reach every node in this RefSCC + // (by definition). It is the only node we know will stay inside this RefSCC. + // Everything which transitively reaches Target will also remain in the + // RefSCC. We handle this by pre-marking that the nodes in the target SCC map + // back to the root post order number. + // + // This also enables us to take a very significant short-cut in the standard + // Tarjan walk to re-form RefSCCs below: whenever we build an edge that + // references the target node, we know that the target node eventually + // references all other nodes in our walk. As a consequence, we can detect + // and handle participants in that cycle without walking all the edges that + // form the connections, and instead by relying on the fundamental guarantee + // coming into this operation. + SCC &TargetC = *G->lookupSCC(TargetN); + for (Node &N : TargetC) + PostOrderMapping[&N] = RootPostOrderNumber; + + // Reset all the other nodes to prepare for a DFS over them, and add them to + // our worklist. + SmallVector<Node *, 8> Worklist; + for (SCC *C : SCCs) { + if (C == &TargetC) + continue; - // Continue, resetting to the child node. - ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; - N = &ChildN; - I = ChildN.begin(); - E = ChildN.end(); - continue; - } + for (Node &N : *C) + N.DFSNumber = N.LowLink = 0; - // Track the lowest link of the children, if any are still in the stack. - // Any child not on the stack will have a LowLink of -1. 
- assert(ChildN.LowLink != 0 && - "Low-link must not be zero with a non-zero DFS number."); - if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink) - N->LowLink = ChildN.LowLink; - ++I; - } + Worklist.append(C->Nodes.begin(), C->Nodes.end()); + } - if (N->LowLink == N->DFSNumber) { - ResultSCCs.push_back(G->formSCC(N, PendingSCCStack)); - if (DFSStack.empty()) - return; - } else { - // At this point we know that N cannot ever be an SCC root. Its low-link - // is not its dfs-number, and we've processed all of its children. It is - // just sitting here waiting until some node further down the stack gets - // low-link == dfs-number and pops it off as well. Move it to the pending - // stack which is pulled into the next SCC to be formed. - PendingSCCStack.push_back(N); + auto MarkNodeForSCCNumber = [&PostOrderMapping](Node &N, int Number) { + N.DFSNumber = N.LowLink = -1; + PostOrderMapping[&N] = Number; + }; - assert(!DFSStack.empty() && "We shouldn't have an empty stack!"); + SmallVector<std::pair<Node *, edge_iterator>, 4> DFSStack; + SmallVector<Node *, 4> PendingRefSCCStack; + do { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingRefSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + Node *RootN = Worklist.pop_back_val(); + // Skip any nodes we've already reached in the DFS. + if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; } - N = DFSStack.back().first; - I = DFSStack.back().second; - DFSStack.pop_back(); - } -} + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; -SmallVector<LazyCallGraph::SCC *, 1> -LazyCallGraph::SCC::removeIntraSCCEdge(Node &CallerN, Node &CalleeN) { - // First remove it from the node. - CallerN.removeEdgeInternal(CalleeN.getFunction()); + DFSStack.push_back({RootN, RootN->begin()}); + do { + Node *N; + edge_iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = N->end(); + + assert(N->DFSNumber != 0 && "We should always assign a DFS number " + "before processing a node."); + + while (I != E) { + Node &ChildN = I->getNode(*G); + if (ChildN.DFSNumber == 0) { + // Mark that we should start at this child when next this node is the + // top of the stack. We don't start at the next child to ensure this + // child's lowlink is reflected. + DFSStack.push_back({N, I}); + + // Continue, resetting to the child node. + ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; + N = &ChildN; + I = ChildN.begin(); + E = ChildN.end(); + continue; + } + if (ChildN.DFSNumber == -1) { + // Check if this edge's target node connects to the deleted edge's + // target node. If so, we know that every node connected will end up + // in this RefSCC, so collapse the entire current stack into the root + // slot in our SCC numbering. See above for the motivation of + // optimizing the target connected nodes in this way. + auto PostOrderI = PostOrderMapping.find(&ChildN); + if (PostOrderI != PostOrderMapping.end() && + PostOrderI->second == RootPostOrderNumber) { + MarkNodeForSCCNumber(*N, RootPostOrderNumber); + while (!PendingRefSCCStack.empty()) + MarkNodeForSCCNumber(*PendingRefSCCStack.pop_back_val(), + RootPostOrderNumber); + while (!DFSStack.empty()) + MarkNodeForSCCNumber(*DFSStack.pop_back_val().first, + RootPostOrderNumber); + // Ensure we break all the way out of the enclosing loop. + N = nullptr; + break; + } + + // If this child isn't currently in this RefSCC, no need to process + // it. 
+ // However, we do need to remove this RefSCC from its RefSCC's parent + // set. + RefSCC &ChildRC = *G->lookupRefSCC(ChildN); + ChildRC.Parents.erase(this); + ++I; + continue; + } - // We return a list of the resulting *new* SCCs in postorder. - SmallVector<SCC *, 1> ResultSCCs; + // Track the lowest link of the children, if any are still in the stack. + // Any child not on the stack will have a LowLink of -1. + assert(ChildN.LowLink != 0 && + "Low-link must not be zero with a non-zero DFS number."); + if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + ++I; + } + if (!N) + // We short-circuited this node. + break; - // Direct recursion doesn't impact the SCC graph at all. - if (&CallerN == &CalleeN) - return ResultSCCs; + // We've finished processing N and its descendents, put it on our pending + // stack to eventually get merged into a RefSCC. + PendingRefSCCStack.push_back(N); - // The worklist is every node in the original SCC. - SmallVector<Node *, 1> Worklist; - Worklist.swap(Nodes); - for (Node *N : Worklist) { - // The nodes formerly in this SCC are no longer in any SCC. - N->DFSNumber = 0; - N->LowLink = 0; - G->SCCMap.erase(N); - } - assert(Worklist.size() > 1 && "We have to have at least two nodes to have an " - "edge between them that is within the SCC."); - - // The callee can already reach every node in this SCC (by definition). It is - // the only node we know will stay inside this SCC. Everything which - // transitively reaches Callee will also remain in the SCC. To model this we - // incrementally add any chain of nodes which reaches something in the new - // node set to the new node set. This short circuits one side of the Tarjan's - // walk. - insert(CalleeN); - - // We're going to do a full mini-Tarjan's walk using a local stack here. - SmallVector<std::pair<Node *, Node::iterator>, 4> DFSStack; - SmallVector<Node *, 4> PendingSCCStack; - do { - Node *N = Worklist.pop_back_val(); - if (N->DFSNumber == 0) - internalDFS(DFSStack, PendingSCCStack, N, ResultSCCs); + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) { + assert(!DFSStack.empty() && + "We never found a viable root for a RefSCC to pop off!"); + continue; + } + + // Otherwise, form a new RefSCC from the top of the pending node stack. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto RefSCCNodes = make_range( + PendingRefSCCStack.rbegin(), + std::find_if(PendingRefSCCStack.rbegin(), PendingRefSCCStack.rend(), + [RootDFSNumber](Node *N) { + return N->DFSNumber < RootDFSNumber; + })); + + // Mark the postorder number for these nodes and clear them off the + // stack. We'll use the postorder number to pull them into RefSCCs at the + // end. FIXME: Fuse with the loop above. + int RefSCCNumber = PostOrderNumber++; + for (Node *N : RefSCCNodes) + MarkNodeForSCCNumber(*N, RefSCCNumber); + + PendingRefSCCStack.erase(RefSCCNodes.end().base(), + PendingRefSCCStack.end()); + } while (!DFSStack.empty()); assert(DFSStack.empty() && "Didn't flush the entire DFS stack!"); - assert(PendingSCCStack.empty() && "Didn't flush all pending SCC nodes!"); + assert(PendingRefSCCStack.empty() && "Didn't flush all pending nodes!"); } while (!Worklist.empty()); - // Now we need to reconnect the current SCC to the graph. 
- bool IsLeafSCC = true; - for (Node *N : Nodes) { - for (Node &ChildN : *N) { - SCC &ChildSCC = *G->SCCMap.lookup(&ChildN); - if (&ChildSCC == this) - continue; - ChildSCC.ParentSCCs.insert(this); - IsLeafSCC = false; - } + // We now have a post-order numbering for RefSCCs and a mapping from each + // node in this RefSCC to its final RefSCC. We create each new RefSCC node + // (re-using this RefSCC node for the root) and build a radix-sort style map + // from postorder number to the RefSCC. We then append SCCs to each of these + // RefSCCs in the order they occurred in the original SCCs container. + for (int i = 1; i < PostOrderNumber; ++i) + Result.push_back(G->createRefSCC(*G)); + + for (SCC *C : SCCs) { + auto PostOrderI = PostOrderMapping.find(&*C->begin()); + assert(PostOrderI != PostOrderMapping.end() && + "Cannot have missing mappings for nodes!"); + int SCCNumber = PostOrderI->second; +#ifndef NDEBUG + for (Node &N : *C) + assert(PostOrderMapping.find(&N)->second == SCCNumber && + "Cannot have different numbers for nodes in the same SCC!"); +#endif + if (SCCNumber == 0) + // The root node is handled separately by removing the SCCs. + continue; + + RefSCC &RC = *Result[SCCNumber - 1]; + int SCCIndex = RC.SCCs.size(); + RC.SCCs.push_back(C); + SCCIndices[C] = SCCIndex; + C->OuterRefSCC = &RC; + } + + // FIXME: We re-walk the edges in each RefSCC to establish whether it is + // a leaf and connect it to the rest of the graph's parents lists. This is + // really wasteful. We should instead do this during the DFS to avoid yet + // another edge walk. + for (RefSCC *RC : Result) + G->connectRefSCC(*RC); + + // Now erase all but the root's SCCs. + SCCs.erase(std::remove_if(SCCs.begin(), SCCs.end(), + [&](SCC *C) { + return PostOrderMapping.lookup(&*C->begin()) != + RootPostOrderNumber; + }), + SCCs.end()); + +#ifndef NDEBUG + // Now we need to reconnect the current (root) SCC to the graph. We do this + // manually because we can special case our leaf handling and detect errors. + bool IsLeaf = true; +#endif + for (SCC *C : SCCs) + for (Node &N : *C) { + for (Edge &E : N) { + assert(E.getNode() && "Cannot have a missing node in a visited SCC!"); + RefSCC &ChildRC = *G->lookupRefSCC(*E.getNode()); + if (&ChildRC == this) + continue; + ChildRC.Parents.insert(this); +#ifndef NDEBUG + IsLeaf = false; +#endif + } + } #ifndef NDEBUG - if (!ResultSCCs.empty()) - assert(!IsLeafSCC && "This SCC cannot be a leaf as we have split out new " - "SCCs by removing this edge."); - if (!std::any_of(G->LeafSCCs.begin(), G->LeafSCCs.end(), - [&](SCC *C) { return C == this; })) - assert(!IsLeafSCC && "This SCC cannot be a leaf as it already had child " - "SCCs before we removed this edge."); + if (!Result.empty()) + assert(!IsLeaf && "This SCC cannot be a leaf as we have split out new " + "SCCs by removing this edge."); + if (!std::any_of(G->LeafRefSCCs.begin(), G->LeafRefSCCs.end(), - [&](SCC *C) { return C == this; })) + [&](RefSCC *C) { return C == this; })) + assert(!IsLeaf && "This SCC cannot be a leaf as it already had child " + "SCCs before we removed this edge."); #endif // If this SCC stopped being a leaf through this edge removal, remove it from - // the leaf SCC list. - if (!IsLeafSCC && !ResultSCCs.empty()) - G->LeafSCCs.erase(std::remove(G->LeafSCCs.begin(), G->LeafSCCs.end(), this), - G->LeafSCCs.end()); + // the leaf SCC list. Note that this does the right thing in the case where + // this was never a leaf.
+ // FIXME: As LeafRefSCCs could be very large, we might want to not walk the + // entire list if this RefSCC wasn't a leaf before the edge removal. + if (!Result.empty()) + G->LeafRefSCCs.erase( + std::remove(G->LeafRefSCCs.begin(), G->LeafRefSCCs.end(), this), + G->LeafRefSCCs.end()); // Return the new list of SCCs. - return ResultSCCs; + return Result; } -void LazyCallGraph::insertEdge(Node &CallerN, Function &Callee) { +void LazyCallGraph::insertEdge(Node &SourceN, Function &Target, Edge::Kind EK) { assert(SCCMap.empty() && DFSStack.empty() && "This method cannot be called after SCCs have been formed!"); - return CallerN.insertEdgeInternal(Callee); + return SourceN.insertEdgeInternal(Target, EK); } -void LazyCallGraph::removeEdge(Node &CallerN, Function &Callee) { +void LazyCallGraph::removeEdge(Node &SourceN, Function &Target) { assert(SCCMap.empty() && DFSStack.empty() && "This method cannot be called after SCCs have been formed!"); - return CallerN.removeEdgeInternal(Callee); + return SourceN.removeEdgeInternal(Target); } LazyCallGraph::Node &LazyCallGraph::insertInto(Function &F, Node *&MappedN) { @@ -550,133 +1248,266 @@ void LazyCallGraph::updateGraphPtrs() { // Process all nodes updating the graph pointers. { SmallVector<Node *, 16> Worklist; - for (auto &Entry : EntryNodes) - if (Node *EntryN = Entry.dyn_cast<Node *>()) + for (Edge &E : EntryEdges) + if (Node *EntryN = E.getNode()) Worklist.push_back(EntryN); while (!Worklist.empty()) { Node *N = Worklist.pop_back_val(); N->G = this; - for (auto &Callee : N->Callees) - if (!Callee.isNull()) - if (Node *CalleeN = Callee.dyn_cast<Node *>()) - Worklist.push_back(CalleeN); + for (Edge &E : N->Edges) + if (Node *TargetN = E.getNode()) + Worklist.push_back(TargetN); } } // Process all SCCs updating the graph pointers. { - SmallVector<SCC *, 16> Worklist(LeafSCCs.begin(), LeafSCCs.end()); + SmallVector<RefSCC *, 16> Worklist(LeafRefSCCs.begin(), LeafRefSCCs.end()); while (!Worklist.empty()) { - SCC *C = Worklist.pop_back_val(); - C->G = this; - Worklist.insert(Worklist.end(), C->ParentSCCs.begin(), - C->ParentSCCs.end()); + RefSCC &C = *Worklist.pop_back_val(); + C.G = this; + for (RefSCC &ParentC : C.parents()) + Worklist.push_back(&ParentC); } } } -LazyCallGraph::SCC *LazyCallGraph::formSCC(Node *RootN, - SmallVectorImpl<Node *> &NodeStack) { - // The tail of the stack is the new SCC. Allocate the SCC and pop the stack - // into it. - SCC *NewSCC = new (SCCBPA.Allocate()) SCC(*this); +/// Build the internal SCCs for a RefSCC from a sequence of nodes. +/// +/// Appends the SCCs to the provided vector and updates the map with their +/// indices. Both the vector and map must be empty when passed into this +/// routine. +void LazyCallGraph::buildSCCs(RefSCC &RC, node_stack_range Nodes) { + assert(RC.SCCs.empty() && "Already built SCCs!"); + assert(RC.SCCIndices.empty() && "Already mapped SCC indices!"); - while (!NodeStack.empty() && NodeStack.back()->DFSNumber > RootN->DFSNumber) { - assert(NodeStack.back()->LowLink >= RootN->LowLink && + for (Node *N : Nodes) { + assert(N->LowLink >= (*Nodes.begin())->LowLink && "We cannot have a low link in an SCC lower than its root on the " "stack!"); - NewSCC->insert(*NodeStack.pop_back_val()); + + // This node will go into the next RefSCC, clear out its DFS and low link + // as we scan. 
+ N->DFSNumber = N->LowLink = 0; } - NewSCC->insert(*RootN); - - // A final pass over all edges in the SCC (this remains linear as we only - // do this once when we build the SCC) to connect it to the parent sets of - // its children. - bool IsLeafSCC = true; - for (Node *SCCN : NewSCC->Nodes) - for (Node &SCCChildN : *SCCN) { - SCC &ChildSCC = *SCCMap.lookup(&SCCChildN); - if (&ChildSCC == NewSCC) - continue; - ChildSCC.ParentSCCs.insert(NewSCC); - IsLeafSCC = false; + + // Each RefSCC contains a DAG of the call SCCs. To build these, we do + // a direct walk of the call edges using Tarjan's algorithm. We reuse the + // internal storage as we won't need it for the outer graph's DFS any longer. + + SmallVector<std::pair<Node *, call_edge_iterator>, 16> DFSStack; + SmallVector<Node *, 16> PendingSCCStack; + + // Scan down the stack and DFS across the call edges. + for (Node *RootN : Nodes) { + assert(DFSStack.empty() && + "Cannot begin a new root with a non-empty DFS stack!"); + assert(PendingSCCStack.empty() && + "Cannot begin a new root with pending nodes for an SCC!"); + + // Skip any nodes we've already reached in the DFS. + if (RootN->DFSNumber != 0) { + assert(RootN->DFSNumber == -1 && + "Shouldn't have any mid-DFS root nodes!"); + continue; } - // For the SCCs where we fine no child SCCs, add them to the leaf list. - if (IsLeafSCC) - LeafSCCs.push_back(NewSCC); + RootN->DFSNumber = RootN->LowLink = 1; + int NextDFSNumber = 2; + + DFSStack.push_back({RootN, RootN->call_begin()}); + do { + Node *N; + call_edge_iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); + auto E = N->call_end(); + while (I != E) { + Node &ChildN = *I->getNode(); + if (ChildN.DFSNumber == 0) { + // We haven't yet visited this child, so descend, pushing the current + // node onto the stack. + DFSStack.push_back({N, I}); + + assert(!lookupSCC(ChildN) && + "Found a node with 0 DFS number but already in an SCC!"); + ChildN.DFSNumber = ChildN.LowLink = NextDFSNumber++; + N = &ChildN; + I = N->call_begin(); + E = N->call_end(); + continue; + } + + // If the child has already been added to some child component, it + // couldn't impact the low-link of this parent because it isn't + // connected, and thus its low-link isn't relevant so skip it. + if (ChildN.DFSNumber == -1) { + ++I; + continue; + } + + // Track the lowest linked child as the lowest link for this node. + assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); + if (ChildN.LowLink < N->LowLink) + N->LowLink = ChildN.LowLink; + + // Move to the next edge. + ++I; + } + + // We've finished processing N and its descendents, put it on our pending + // SCC stack to eventually get merged into an SCC of nodes. + PendingSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) + continue; + + // Otherwise, we've completed an SCC. Append it to our post order list of + // SCCs. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto SCCNodes = make_range( + PendingSCCStack.rbegin(), + std::find_if(PendingSCCStack.rbegin(), PendingSCCStack.rend(), + [RootDFSNumber](Node *N) { + return N->DFSNumber < RootDFSNumber; + })); + // Form a new SCC out of these nodes and then clear them off our pending + // stack. 
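The SCCNodes range just above and the erase that follows both lean on a reverse-iterator idiom: find_if walks down the pending stack until it passes the root's DFS number, and calling .base() on the resulting reverse iterator yields the forward iterator one past that element, so [base(), end()) is exactly the new SCC. A tiny standalone demonstration with illustrative DFS numbers:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Pending nodes in DFS order, identified here by their DFS numbers.
  std::vector<int> Pending = {1, 3, 4, 6, 7};
  int RootDFSNumber = 4;

  // Walk backward until we pass the root's DFS number; everything from
  // there up to the top of the stack belongs to the SCC being formed.
  auto REnd = std::find_if(Pending.rbegin(), Pending.rend(),
                           [&](int N) { return N < RootDFSNumber; });
  std::vector<int> SCC(REnd.base(), Pending.end());
  assert((SCC == std::vector<int>{4, 6, 7}));

  // Erasing [REnd.base(), end()) pops exactly the nodes just captured.
  Pending.erase(REnd.base(), Pending.end());
  assert((Pending == std::vector<int>{1, 3}));
}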
+ RC.SCCs.push_back(createSCC(RC, SCCNodes)); + for (Node &N : *RC.SCCs.back()) { + N.DFSNumber = N.LowLink = -1; + SCCMap[&N] = RC.SCCs.back(); + } + PendingSCCStack.erase(SCCNodes.end().base(), PendingSCCStack.end()); + } while (!DFSStack.empty()); + } + + // Wire up the SCC indices. + for (int i = 0, Size = RC.SCCs.size(); i < Size; ++i) + RC.SCCIndices[RC.SCCs[i]] = i; +} + +// FIXME: We should move callers of this to embed the parent linking and leaf +// tracking into their DFS in order to remove a full walk of all edges. +void LazyCallGraph::connectRefSCC(RefSCC &RC) { + // Walk all edges in the RefSCC (this remains linear as we only do this once + // when we build the RefSCC) to connect it to the parent sets of its + // children. + bool IsLeaf = true; + for (SCC &C : RC) + for (Node &N : C) + for (Edge &E : N) { + assert(E.getNode() && + "Cannot have a missing node in a visited part of the graph!"); + RefSCC &ChildRC = *lookupRefSCC(*E.getNode()); + if (&ChildRC == &RC) + continue; + ChildRC.Parents.insert(&RC); + IsLeaf = false; + } - return NewSCC; + // For the SCCs where we fine no child SCCs, add them to the leaf list. + if (IsLeaf) + LeafRefSCCs.push_back(&RC); } -LazyCallGraph::SCC *LazyCallGraph::getNextSCCInPostOrder() { - Node *N; - Node::iterator I; - if (!DFSStack.empty()) { - N = DFSStack.back().first; - I = DFSStack.back().second; - DFSStack.pop_back(); - } else { - // If we've handled all candidate entry nodes to the SCC forest, we're done. +LazyCallGraph::RefSCC *LazyCallGraph::getNextRefSCCInPostOrder() { + if (DFSStack.empty()) { + Node *N; do { - if (SCCEntryNodes.empty()) + // If we've handled all candidate entry nodes to the SCC forest, we're + // done. + if (RefSCCEntryNodes.empty()) return nullptr; - N = &get(*SCCEntryNodes.pop_back_val()); + N = &get(*RefSCCEntryNodes.pop_back_val()); } while (N->DFSNumber != 0); - I = N->begin(); + + // Found a new root, begin the DFS here. N->LowLink = N->DFSNumber = 1; NextDFSNumber = 2; + DFSStack.push_back({N, N->begin()}); } for (;;) { - assert(N->DFSNumber != 0 && "We should always assign a DFS number " - "before placing a node onto the stack."); + Node *N; + edge_iterator I; + std::tie(N, I) = DFSStack.pop_back_val(); - Node::iterator E = N->end(); + assert(N->DFSNumber > 0 && "We should always assign a DFS number " + "before placing a node onto the stack."); + + auto E = N->end(); while (I != E) { - Node &ChildN = *I; + Node &ChildN = I->getNode(*this); if (ChildN.DFSNumber == 0) { - // Mark that we should start at this child when next this node is the - // top of the stack. We don't start at the next child to ensure this - // child's lowlink is reflected. - DFSStack.push_back(std::make_pair(N, N->begin())); + // We haven't yet visited this child, so descend, pushing the current + // node onto the stack. + DFSStack.push_back({N, N->begin()}); - // Recurse onto this node via a tail call. assert(!SCCMap.count(&ChildN) && "Found a node with 0 DFS number but already in an SCC!"); ChildN.LowLink = ChildN.DFSNumber = NextDFSNumber++; N = &ChildN; - I = ChildN.begin(); - E = ChildN.end(); + I = N->begin(); + E = N->end(); + continue; + } + + // If the child has already been added to some child component, it + // couldn't impact the low-link of this parent because it isn't + // connected, and thus its low-link isn't relevant so skip it. + if (ChildN.DFSNumber == -1) { + ++I; continue; } - // Track the lowest link of the children, if any are still in the stack. 
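The SCC-formation loop above is Tarjan's algorithm in iterative form: an explicit DFS stack replaces recursion, and finished nodes accumulate on a pending stack until a root (where LowLink equals DFSNumber) is reached and its SCC is popped. The following is a minimal standalone sketch of the same shape over a plain adjacency list; all names are illustrative and none of it is LLVM API.

    #include <algorithm>
    #include <cstddef>
    #include <utility>
    #include <vector>

    // Iterative Tarjan, mirroring the DFSStack/PendingSCCStack structure:
    // Dfs[] is 0 for unvisited nodes and -1 once a node joins an SCC.
    std::vector<std::vector<int>>
    tarjanSCCs(const std::vector<std::vector<int>> &Adj) {
      int N = Adj.size(), NextDFS = 1;
      std::vector<int> Dfs(N, 0), Low(N, 0);
      std::vector<std::pair<int, std::size_t>> DFSStack; // (node, next edge)
      std::vector<int> Pending; // finished nodes awaiting their SCC root
      std::vector<std::vector<int>> SCCs;

      for (int Root = 0; Root < N; ++Root) {
        if (Dfs[Root] != 0)
          continue; // already reached from an earlier root
        Dfs[Root] = Low[Root] = NextDFS++;
        DFSStack.push_back({Root, 0});
        while (!DFSStack.empty()) {
          auto [U, I] = DFSStack.back();
          DFSStack.pop_back();
          while (I < Adj[U].size()) {
            int V = Adj[U][I];
            if (Dfs[V] == 0) { // unvisited child: save our frame, descend
              DFSStack.push_back({U, I});
              Dfs[V] = Low[V] = NextDFS++;
              U = V;
              I = 0;
              continue;
            }
            if (Dfs[V] != -1) // child not yet in an SCC: absorb its low-link
              Low[U] = std::min(Low[U], Low[V]);
            ++I;
          }
          Pending.push_back(U); // U is finished; it joins some SCC later
          if (Low[U] != Dfs[U])
            continue; // not an SCC root; keep unwinding the DFS stack
          int RootDFS = Dfs[U]; // save before members get marked -1
          SCCs.emplace_back(); // U is a root: pop its SCC off Pending
          while (!Pending.empty() && Dfs[Pending.back()] >= RootDFS) {
            SCCs.back().push_back(Pending.back());
            Dfs[Pending.back()] = -1;
            Pending.pop_back();
          }
        }
      }
      return SCCs; // post-order: callee SCCs come out before their callers
    }

For example, tarjanSCCs({{1}, {2}, {0}, {1}}) yields {0,1,2} first and {3} second, the same leaves-first post-order that getNextRefSCCInPostOrder produces one RefSCC at a time.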
- assert(ChildN.LowLink != 0 && - "Low-link must not be zero with a non-zero DFS number."); - if (ChildN.LowLink >= 0 && ChildN.LowLink < N->LowLink) + // Track the lowest linked child as the lowest link for this node. + assert(ChildN.LowLink > 0 && "Must have a positive low-link number!"); + if (ChildN.LowLink < N->LowLink) N->LowLink = ChildN.LowLink; + + // Move to the next edge. ++I; } - if (N->LowLink == N->DFSNumber) - // Form the new SCC out of the top of the DFS stack. - return formSCC(N, PendingSCCStack); - - // At this point we know that N cannot ever be an SCC root. Its low-link - // is not its dfs-number, and we've processed all of its children. It is - // just sitting here waiting until some node further down the stack gets - // low-link == dfs-number and pops it off as well. Move it to the pending - // stack which is pulled into the next SCC to be formed. - PendingSCCStack.push_back(N); - - assert(!DFSStack.empty() && "We never found a viable root!"); - N = DFSStack.back().first; - I = DFSStack.back().second; - DFSStack.pop_back(); + // We've finished processing N and its descendents, put it on our pending + // SCC stack to eventually get merged into an SCC of nodes. + PendingRefSCCStack.push_back(N); + + // If this node is linked to some lower entry, continue walking up the + // stack. + if (N->LowLink != N->DFSNumber) { + assert(!DFSStack.empty() && + "We never found a viable root for an SCC to pop off!"); + continue; + } + + // Otherwise, form a new RefSCC from the top of the pending node stack. + int RootDFSNumber = N->DFSNumber; + // Find the range of the node stack by walking down until we pass the + // root DFS number. + auto RefSCCNodes = node_stack_range( + PendingRefSCCStack.rbegin(), + std::find_if( + PendingRefSCCStack.rbegin(), PendingRefSCCStack.rend(), + [RootDFSNumber](Node *N) { return N->DFSNumber < RootDFSNumber; })); + // Form a new RefSCC out of these nodes and then clear them off our pending + // stack. + RefSCC *NewRC = createRefSCC(*this); + buildSCCs(*NewRC, RefSCCNodes); + connectRefSCC(*NewRC); + PendingRefSCCStack.erase(RefSCCNodes.end().base(), + PendingRefSCCStack.end()); + + // We return the new node here. This essentially suspends the DFS walk + // until another RefSCC is requested. + return NewRC; } } @@ -684,44 +1515,76 @@ char LazyCallGraphAnalysis::PassID; LazyCallGraphPrinterPass::LazyCallGraphPrinterPass(raw_ostream &OS) : OS(OS) {} -static void printNodes(raw_ostream &OS, LazyCallGraph::Node &N, - SmallPtrSetImpl<LazyCallGraph::Node *> &Printed) { - // Recurse depth first through the nodes. - for (LazyCallGraph::Node &ChildN : N) - if (Printed.insert(&ChildN).second) - printNodes(OS, ChildN, Printed); - - OS << " Call edges in function: " << N.getFunction().getName() << "\n"; - for (LazyCallGraph::iterator I = N.begin(), E = N.end(); I != E; ++I) - OS << " -> " << I->getFunction().getName() << "\n"; +static void printNode(raw_ostream &OS, LazyCallGraph::Node &N) { + OS << " Edges in function: " << N.getFunction().getName() << "\n"; + for (const LazyCallGraph::Edge &E : N) + OS << " " << (E.isCall() ? 
"call" : "ref ") << " -> " + << E.getFunction().getName() << "\n"; OS << "\n"; } -static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &SCC) { - ptrdiff_t SCCSize = std::distance(SCC.begin(), SCC.end()); - OS << " SCC with " << SCCSize << " functions:\n"; +static void printSCC(raw_ostream &OS, LazyCallGraph::SCC &C) { + ptrdiff_t Size = std::distance(C.begin(), C.end()); + OS << " SCC with " << Size << " functions:\n"; + + for (LazyCallGraph::Node &N : C) + OS << " " << N.getFunction().getName() << "\n"; +} - for (LazyCallGraph::Node *N : SCC) - OS << " " << N->getFunction().getName() << "\n"; +static void printRefSCC(raw_ostream &OS, LazyCallGraph::RefSCC &C) { + ptrdiff_t Size = std::distance(C.begin(), C.end()); + OS << " RefSCC with " << Size << " call SCCs:\n"; + + for (LazyCallGraph::SCC &InnerC : C) + printSCC(OS, InnerC); OS << "\n"; } PreservedAnalyses LazyCallGraphPrinterPass::run(Module &M, - ModuleAnalysisManager *AM) { - LazyCallGraph &G = AM->getResult<LazyCallGraphAnalysis>(M); + ModuleAnalysisManager &AM) { + LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M); OS << "Printing the call graph for module: " << M.getModuleIdentifier() << "\n\n"; - SmallPtrSet<LazyCallGraph::Node *, 16> Printed; - for (LazyCallGraph::Node &N : G) - if (Printed.insert(&N).second) - printNodes(OS, N, Printed); + for (Function &F : M) + printNode(OS, G.get(F)); + + for (LazyCallGraph::RefSCC &C : G.postorder_ref_sccs()) + printRefSCC(OS, C); + + return PreservedAnalyses::all(); +} + +LazyCallGraphDOTPrinterPass::LazyCallGraphDOTPrinterPass(raw_ostream &OS) + : OS(OS) {} + +static void printNodeDOT(raw_ostream &OS, LazyCallGraph::Node &N) { + std::string Name = "\"" + DOT::EscapeString(N.getFunction().getName()) + "\""; + + for (const LazyCallGraph::Edge &E : N) { + OS << " " << Name << " -> \"" + << DOT::EscapeString(E.getFunction().getName()) << "\""; + if (!E.isCall()) // It is a ref edge. 
+ OS << " [style=dashed,label=\"ref\"]"; + OS << ";\n"; + } + + OS << "\n"; +} + +PreservedAnalyses LazyCallGraphDOTPrinterPass::run(Module &M, + ModuleAnalysisManager &AM) { + LazyCallGraph &G = AM.getResult<LazyCallGraphAnalysis>(M); + + OS << "digraph \"" << DOT::EscapeString(M.getModuleIdentifier()) << "\" {\n"; + + for (Function &F : M) + printNodeDOT(OS, G.get(F)); - for (LazyCallGraph::SCC &SCC : G.postorder_sccs()) - printSCC(OS, SCC); + OS << "}\n"; return PreservedAnalyses::all(); } diff --git a/lib/Analysis/LazyValueInfo.cpp b/lib/Analysis/LazyValueInfo.cpp index 0d1d34e0cb4f..4d09b7ca006b 100644 --- a/lib/Analysis/LazyValueInfo.cpp +++ b/lib/Analysis/LazyValueInfo.cpp @@ -38,18 +38,19 @@ using namespace PatternMatch; #define DEBUG_TYPE "lazy-value-info" -char LazyValueInfo::ID = 0; -INITIALIZE_PASS_BEGIN(LazyValueInfo, "lazy-value-info", +char LazyValueInfoWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(LazyValueInfoWrapperPass, "lazy-value-info", "Lazy Value Information Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(LazyValueInfo, "lazy-value-info", +INITIALIZE_PASS_END(LazyValueInfoWrapperPass, "lazy-value-info", "Lazy Value Information Analysis", false, true) namespace llvm { - FunctionPass *createLazyValueInfoPass() { return new LazyValueInfo(); } + FunctionPass *createLazyValueInfoPass() { return new LazyValueInfoWrapperPass(); } } +char LazyValueAnalysis::PassID; //===----------------------------------------------------------------------===// // LVILatticeVal @@ -63,19 +64,24 @@ namespace llvm { namespace { class LVILatticeVal { enum LatticeValueTy { - /// This Value has no known value yet. + /// This Value has no known value yet. As a result, this implies the + /// producing instruction is dead. Caution: We use this as the starting + /// state in our local meet rules. In this usage, it's taken to mean + /// "nothing known yet". undefined, - /// This Value has a specific constant value. + /// This Value has a specific constant value. (For integers, constantrange + /// is used instead.) constant, - /// This Value is known to not have the specified value. + /// This Value is known to not have the specified value. (For integers, + /// constantrange is used instead.) notconstant, - /// The Value falls within this range. + /// The Value falls within this range. (Used only for integer typed values.) constantrange, - /// This value is not known to be constant, and we know that it has a value. + /// We can not precisely model the dynamic values this value might take. overdefined }; @@ -102,7 +108,7 @@ public: } static LVILatticeVal getRange(ConstantRange CR) { LVILatticeVal Res; - Res.markConstantRange(CR); + Res.markConstantRange(std::move(CR)); return Res; } static LVILatticeVal getOverdefined() { @@ -110,7 +116,7 @@ public: Res.markOverdefined(); return Res; } - + bool isUndefined() const { return Tag == undefined; } bool isConstant() const { return Tag == constant; } bool isNotConstant() const { return Tag == notconstant; } @@ -176,13 +182,13 @@ public: } /// Return true if this is a change in status. 
- bool markConstantRange(const ConstantRange NewR) { + bool markConstantRange(ConstantRange NewR) { if (isConstantRange()) { if (NewR.isEmptySet()) return markOverdefined(); bool changed = Range != NewR; - Range = NewR; + Range = std::move(NewR); return changed; } @@ -191,7 +197,7 @@ public: return markOverdefined(); Tag = constantrange; - Range = NewR; + Range = std::move(NewR); return true; } @@ -230,11 +236,6 @@ public: return markOverdefined(); } - // RHS is a ConstantRange, LHS is a non-integer Constant. - - // FIXME: consider the case where RHS is a range [1, 0) and LHS is - // a function. The correct result is to pick up RHS. - return markOverdefined(); } @@ -287,13 +288,76 @@ raw_ostream &operator<<(raw_ostream &OS, const LVILatticeVal &Val) { if (Val.isNotConstant()) return OS << "notconstant<" << *Val.getNotConstant() << '>'; - else if (Val.isConstantRange()) + if (Val.isConstantRange()) return OS << "constantrange<" << Val.getConstantRange().getLower() << ", " << Val.getConstantRange().getUpper() << '>'; return OS << "constant<" << *Val.getConstant() << '>'; } } +/// Returns true if this lattice value represents at most one possible value. +/// This is as precise as any lattice value can get while still representing +/// reachable code. +static bool hasSingleValue(const LVILatticeVal &Val) { + if (Val.isConstantRange() && + Val.getConstantRange().isSingleElement()) + // Integer constants are single element ranges + return true; + if (Val.isConstant()) + // Non integer constants + return true; + return false; +} + +/// Combine two sets of facts about the same value into a single set of +/// facts. Note that this method is not suitable for merging facts along +/// different paths in a CFG; that's what the mergeIn function is for. This +/// is for merging facts gathered about the same value at the same location +/// through two independent means. +/// Notes: +/// * This method does not promise to return the most precise possible lattice +/// value implied by A and B. It is allowed to return any lattice element +/// which is at least as strong as *either* A or B (unless our facts +/// conflict, see below). +/// * Due to unreachable code, the intersection of two lattice values could be +/// contradictory. If this happens, we return some valid lattice value so as +/// not confuse the rest of LVI. Ideally, we'd always return Undefined, but +/// we do not make this guarantee. TODO: This would be a useful enhancement. +static LVILatticeVal intersect(LVILatticeVal A, LVILatticeVal B) { + // Undefined is the strongest state. It means the value is known to be along + // an unreachable path. + if (A.isUndefined()) + return A; + if (B.isUndefined()) + return B; + + // If we gave up for one, but got a useable fact from the other, use it. + if (A.isOverdefined()) + return B; + if (B.isOverdefined()) + return A; + + // Can't get any more precise than constants. + if (hasSingleValue(A)) + return A; + if (hasSingleValue(B)) + return B; + + // Could be either constant range or not constant here. + if (!A.isConstantRange() || !B.isConstantRange()) { + // TODO: Arbitrary choice, could be improved + return A; + } + + // Intersect two constant ranges + ConstantRange Range = + A.getConstantRange().intersectWith(B.getConstantRange()); + // Note: An empty range is implicitly converted to overdefined internally. + // TODO: We could instead use Undefined here since we've proven a conflict + // and thus know this path must be unreachable. 
+ return LVILatticeVal::getRange(std::move(Range)); +} + //===----------------------------------------------------------------------===// // LazyValueInfoCache Decl //===----------------------------------------------------------------------===// @@ -354,6 +418,8 @@ namespace { if (!BlockValueSet.insert(BV).second) return false; // It's already in the stack. + DEBUG(dbgs() << "PUSH: " << *BV.second << " in " << BV.first->getName() + << "\n"); BlockValueStack.push(BV); return true; } @@ -375,30 +441,31 @@ namespace { lookup(Val)[BB] = Result; } - LVILatticeVal getBlockValue(Value *Val, BasicBlock *BB); - bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T, - LVILatticeVal &Result, - Instruction *CxtI = nullptr); - bool hasBlockValue(Value *Val, BasicBlock *BB); - - // These methods process one work item and may add more. A false value - // returned means that the work item was not completely processed and must - // be revisited after going through the new items. - bool solveBlockValue(Value *Val, BasicBlock *BB); - bool solveBlockValueNonLocal(LVILatticeVal &BBLV, - Value *Val, BasicBlock *BB); - bool solveBlockValuePHINode(LVILatticeVal &BBLV, - PHINode *PN, BasicBlock *BB); - bool solveBlockValueConstantRange(LVILatticeVal &BBLV, - Instruction *BBI, BasicBlock *BB); - void mergeAssumeBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, - Instruction *BBI); - - void solve(); - - ValueCacheEntryTy &lookup(Value *V) { - return ValueCache[LVIValueHandle(V, this)]; - } + LVILatticeVal getBlockValue(Value *Val, BasicBlock *BB); + bool getEdgeValue(Value *V, BasicBlock *F, BasicBlock *T, + LVILatticeVal &Result, Instruction *CxtI = nullptr); + bool hasBlockValue(Value *Val, BasicBlock *BB); + + // These methods process one work item and may add more. A false value + // returned means that the work item was not completely processed and must + // be revisited after going through the new items. + bool solveBlockValue(Value *Val, BasicBlock *BB); + bool solveBlockValueNonLocal(LVILatticeVal &BBLV, Value *Val, BasicBlock *BB); + bool solveBlockValuePHINode(LVILatticeVal &BBLV, PHINode *PN, BasicBlock *BB); + bool solveBlockValueSelect(LVILatticeVal &BBLV, SelectInst *S, + BasicBlock *BB); + bool solveBlockValueBinaryOp(LVILatticeVal &BBLV, Instruction *BBI, + BasicBlock *BB); + bool solveBlockValueCast(LVILatticeVal &BBLV, Instruction *BBI, + BasicBlock *BB); + void intersectAssumeBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, + Instruction *BBI); + + void solve(); + + ValueCacheEntryTy &lookup(Value *V) { + return ValueCache[LVIValueHandle(V, this)]; + } bool isOverdefined(Value *V, BasicBlock *BB) const { auto ODI = OverDefinedCache.find(BB); @@ -427,7 +494,7 @@ namespace { return lookup(V)[BB]; } - + public: /// This is the query interface to determine the lattice /// value for the specified Value* at the end of the specified block. 
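Because intersect combines two facts gathered about the same value by independent means, it may keep the stronger of its inputs outright, unlike the path-merging mergeIn. A compact model of the rules above, using a closed integer interval in place of ConstantRange (types and names invented for illustration):

    #include <algorithm>
    #include <cstdint>

    // Tiny stand-in for LVILatticeVal: undefined < interval < overdefined.
    struct LatticeVal {
      enum { Undefined, Interval, Overdefined } Tag = Undefined;
      int64_t Lo = 0, Hi = 0; // inclusive bounds, valid when Tag == Interval

      static LatticeVal undef() { return {}; }
      static LatticeVal overdef() { return {Overdefined, 0, 0}; }
      static LatticeVal range(int64_t L, int64_t H) { return {Interval, L, H}; }
      bool isSingle() const { return Tag == Interval && Lo == Hi; }
    };

    // Mirror of the intersect rules: undefined wins (unreachable path),
    // overdefined loses, single values are already maximally precise,
    // and two intervals intersect.
    LatticeVal intersect(LatticeVal A, LatticeVal B) {
      if (A.Tag == LatticeVal::Undefined) return A;
      if (B.Tag == LatticeVal::Undefined) return B;
      if (A.Tag == LatticeVal::Overdefined) return B; // keep the usable fact
      if (B.Tag == LatticeVal::Overdefined) return A;
      if (A.isSingle()) return A; // can't get more precise than a constant
      if (B.isSingle()) return B;
      int64_t Lo = std::max(A.Lo, B.Lo), Hi = std::min(A.Hi, B.Hi);
      if (Lo > Hi)                     // conflicting facts: the path is dead,
        return LatticeVal::overdef();  // but stay conservative, per the TODO
      return LatticeVal::range(Lo, Hi);
    }

For instance, intersect(range(0, 100), range(50, 200)) yields [50, 100], while disjoint intervals signal unreachable code, which the sketch, like the real code, maps to overdefined rather than undefined.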
@@ -493,8 +560,8 @@ void LazyValueInfoCache::eraseBlock(BasicBlock *BB) { if (ODI != OverDefinedCache.end()) OverDefinedCache.erase(ODI); - for (auto I = ValueCache.begin(), E = ValueCache.end(); I != E; ++I) - I->second.erase(BB); + for (auto &I : ValueCache) + I.second.erase(BB); } void LazyValueInfoCache::solve() { @@ -508,6 +575,9 @@ void LazyValueInfoCache::solve() { assert(hasCachedValueInfo(e.second, e.first) && "Result should be in cache!"); + DEBUG(dbgs() << "POP " << *e.second << " in " << e.first->getName() + << " = " << getCachedValueInfo(e.second, e.first) << "\n"); + BlockValueStack.pop(); BlockValueSet.erase(e); } else { @@ -542,15 +612,12 @@ static LVILatticeVal getFromRangeMetadata(Instruction *BBI) { case Instruction::Invoke: if (MDNode *Ranges = BBI->getMetadata(LLVMContext::MD_range)) if (isa<IntegerType>(BBI->getType())) { - ConstantRange Result = getConstantRangeFromMetadata(*Ranges); - return LVILatticeVal::getRange(Result); + return LVILatticeVal::getRange(getConstantRangeFromMetadata(*Ranges)); } break; }; - // Nothing known - Note that we do not want overdefined here. We may know - // something else about the value and not having range metadata shouldn't - // cause us to throw away those facts. - return LVILatticeVal(); + // Nothing known - will be intersected with other facts. + return LVILatticeVal::getOverdefined(); } bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { @@ -587,44 +654,47 @@ bool LazyValueInfoCache::solveBlockValue(Value *Val, BasicBlock *BB) { return true; } - // If this value is a nonnull pointer, record its range and bail out. - PointerType *PT = dyn_cast<PointerType>(BBI->getType()); - if (PT && isKnownNonNull(BBI)) { - Res = LVILatticeVal::getNot(ConstantPointerNull::get(PT)); + if (auto *SI = dyn_cast<SelectInst>(BBI)) { + if (!solveBlockValueSelect(Res, SI, BB)) + return false; insertResult(Val, BB, Res); return true; } - // If this is an instruction which supports range metadata, return the - // implied range. TODO: This should be an intersection, not a union. - Res.mergeIn(getFromRangeMetadata(BBI), DL); - - // We can only analyze the definitions of certain classes of instructions - // (integral binops and casts at the moment), so bail if this isn't one. - LVILatticeVal Result; - if ((!isa<BinaryOperator>(BBI) && !isa<CastInst>(BBI)) || - !BBI->getType()->isIntegerTy()) { - DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined because inst def found.\n"); - Res.markOverdefined(); + // If this value is a nonnull pointer, record its range and bail out. Note + // that for all other pointer typed values, we terminate the search at the + // definition. We could easily extend this to look through geps, bitcasts, + // and the like to prove non-nullness, but it's not clear that's worth it + // compile-time-wise. The context-insensitive value walk done inside + // isKnownNonNull gets most of the profitable cases at much less expense. + // This does mean that we have a sensitivity to where the defining + // instruction is placed, even if it could legally be hoisted much higher. + // That is unfortunate. + PointerType *PT = dyn_cast<PointerType>(BBI->getType()); + if (PT && isKnownNonNull(BBI)) { + Res = LVILatticeVal::getNot(ConstantPointerNull::get(PT)); insertResult(Val, BB, Res); return true; } - - // FIXME: We're currently limited to binops with a constant RHS. This should - // be improved. 
- BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI); - if (BO && !isa<ConstantInt>(BO->getOperand(1))) { - DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined because inst def found.\n"); - - Res.markOverdefined(); - insertResult(Val, BB, Res); - return true; + if (BBI->getType()->isIntegerTy()) { + if (isa<CastInst>(BBI)) { + if (!solveBlockValueCast(Res, BBI, BB)) + return false; + insertResult(Val, BB, Res); + return true; + } + BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI); + if (BO && isa<ConstantInt>(BO->getOperand(1))) { + if (!solveBlockValueBinaryOp(Res, BBI, BB)) + return false; + insertResult(Val, BB, Res); + return true; + } } - if (!solveBlockValueConstantRange(Res, BBI, BB)) - return false; + DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - unknown inst def found.\n"); + Res = getFromRangeMetadata(BBI); insertResult(Val, BB, Res); return true; } @@ -660,37 +730,36 @@ static bool InstructionDereferencesPointer(Instruction *I, Value *Ptr) { return false; } +/// Return true if the allocation associated with Val is ever dereferenced +/// within the given basic block. This establishes the fact Val is not null, +/// but does not imply that the memory at Val is dereferenceable. (Val may +/// point off the end of the dereferenceable part of the object.) +static bool isObjectDereferencedInBlock(Value *Val, BasicBlock *BB) { + assert(Val->getType()->isPointerTy()); + + const DataLayout &DL = BB->getModule()->getDataLayout(); + Value *UnderlyingVal = GetUnderlyingObject(Val, DL); + // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge + // inside InstructionDereferencesPointer either. + if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, DL, 1)) + for (Instruction &I : *BB) + if (InstructionDereferencesPointer(&I, UnderlyingVal)) + return true; + return false; +} + bool LazyValueInfoCache::solveBlockValueNonLocal(LVILatticeVal &BBLV, Value *Val, BasicBlock *BB) { LVILatticeVal Result; // Start Undefined. - // If this is a pointer, and there's a load from that pointer in this BB, - // then we know that the pointer can't be NULL. - bool NotNull = false; - if (Val->getType()->isPointerTy()) { - if (isKnownNonNull(Val)) { - NotNull = true; - } else { - const DataLayout &DL = BB->getModule()->getDataLayout(); - Value *UnderlyingVal = GetUnderlyingObject(Val, DL); - // If 'GetUnderlyingObject' didn't converge, skip it. It won't converge - // inside InstructionDereferencesPointer either. - if (UnderlyingVal == GetUnderlyingObject(UnderlyingVal, DL, 1)) { - for (Instruction &I : *BB) { - if (InstructionDereferencesPointer(&I, UnderlyingVal)) { - NotNull = true; - break; - } - } - } - } - } - // If this is the entry block, we must be asking about an argument. The // value is overdefined. if (BB == &BB->getParent()->getEntryBlock()) { assert(isa<Argument>(Val) && "Unknown live-in to the entry block"); - if (NotNull) { + // Before giving up, see if we can prove the pointer non-null local to + // this particular block. + if (Val->getType()->isPointerTy() && + (isKnownNonNull(Val) || isObjectDereferencedInBlock(Val, BB))) { PointerType *PTy = cast<PointerType>(Val->getType()); Result = LVILatticeVal::getNot(ConstantPointerNull::get(PTy)); } else { @@ -715,10 +784,11 @@ bool LazyValueInfoCache::solveBlockValueNonLocal(LVILatticeVal &BBLV, // to overdefined. 
if (Result.isOverdefined()) { DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined because of pred.\n"); - // If we previously determined that this is a pointer that can't be null - // then return that rather than giving up entirely. - if (NotNull) { + << "' - overdefined because of pred (non local).\n"); + // Before giving up, see if we can prove the pointer non-null local to + // this particular block. + if (Val->getType()->isPointerTy() && + isObjectDereferencedInBlock(Val, BB)) { PointerType *PTy = cast<PointerType>(Val->getType()); Result = LVILatticeVal::getNot(ConstantPointerNull::get(PTy)); } @@ -760,7 +830,7 @@ bool LazyValueInfoCache::solveBlockValuePHINode(LVILatticeVal &BBLV, // to overdefined. if (Result.isOverdefined()) { DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined because of pred.\n"); + << "' - overdefined because of pred (local).\n"); BBLV = Result; return true; @@ -779,10 +849,9 @@ static bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, LVILatticeVal &Result, bool isTrueDest = true); -// If we can determine a constant range for the value Val in the context -// provided by the instruction BBI, then merge it into BBLV. If we did find a -// constant range, return true. -void LazyValueInfoCache::mergeAssumeBlockValueConstantRange(Value *Val, +// If we can determine a constraint on the value given conditions assumed by +// the program, intersect those constraints with BBLV. +void LazyValueInfoCache::intersectAssumeBlockValueConstantRange(Value *Val, LVILatticeVal &BBLV, Instruction *BBI) { BBI = BBI ? BBI : dyn_cast<Instruction>(Val); @@ -799,46 +868,264 @@ void LazyValueInfoCache::intersectAssumeBlockValueConstantRange(Value *Val, Value *C = I->getArgOperand(0); if (ICmpInst *ICI = dyn_cast<ICmpInst>(C)) { LVILatticeVal Result; - if (getValueFromFromCondition(Val, ICI, Result)) { - if (BBLV.isOverdefined()) - BBLV = Result; - else - BBLV.mergeIn(Result, DL); - } + if (getValueFromFromCondition(Val, ICI, Result)) + BBLV = intersect(BBLV, Result); } } } -bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, - Instruction *BBI, - BasicBlock *BB) { - // Figure out the range of the LHS. If that fails, bail. - if (!hasBlockValue(BBI->getOperand(0), BB)) { - if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) +bool LazyValueInfoCache::solveBlockValueSelect(LVILatticeVal &BBLV, + SelectInst *SI, BasicBlock *BB) { + + // Recurse on our inputs if needed + if (!hasBlockValue(SI->getTrueValue(), BB)) { + if (pushBlockValue(std::make_pair(BB, SI->getTrueValue()))) return false; BBLV.markOverdefined(); return true; } + LVILatticeVal TrueVal = getBlockValue(SI->getTrueValue(), BB); + // If we hit overdefined, don't ask more queries. We want to avoid poisoning + // extra slots in the table if we can. + if (TrueVal.isOverdefined()) { + BBLV.markOverdefined(); + return true; + } - LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); - mergeAssumeBlockValueConstantRange(BBI->getOperand(0), LHSVal, BBI); - if (!LHSVal.isConstantRange()) { + if (!hasBlockValue(SI->getFalseValue(), BB)) { + if (pushBlockValue(std::make_pair(BB, SI->getFalseValue()))) + return false; + BBLV.markOverdefined(); + return true; + } + LVILatticeVal FalseVal = getBlockValue(SI->getFalseValue(), BB); + // If we hit overdefined, don't ask more queries. We want to avoid poisoning + // extra slots in the table if we can. 
+ if (FalseVal.isOverdefined()) { BBLV.markOverdefined(); return true; } - ConstantRange LHSRange = LHSVal.getConstantRange(); - ConstantRange RHSRange(1); - IntegerType *ResultTy = cast<IntegerType>(BBI->getType()); - if (isa<BinaryOperator>(BBI)) { - if (ConstantInt *RHS = dyn_cast<ConstantInt>(BBI->getOperand(1))) { - RHSRange = ConstantRange(RHS->getValue()); - } else { - BBLV.markOverdefined(); - return true; + if (TrueVal.isConstantRange() && FalseVal.isConstantRange()) { + ConstantRange TrueCR = TrueVal.getConstantRange(); + ConstantRange FalseCR = FalseVal.getConstantRange(); + Value *LHS = nullptr; + Value *RHS = nullptr; + SelectPatternResult SPR = matchSelectPattern(SI, LHS, RHS); + // Is this a min specifically of our two inputs? (Avoid the risk of + // ValueTracking getting smarter looking back past our immediate inputs.) + if (SelectPatternResult::isMinOrMax(SPR.Flavor) && + LHS == SI->getTrueValue() && RHS == SI->getFalseValue()) { + switch (SPR.Flavor) { + default: + llvm_unreachable("unexpected minmax type!"); + case SPF_SMIN: /// Signed minimum + BBLV.markConstantRange(TrueCR.smin(FalseCR)); + return true; + case SPF_UMIN: /// Unsigned minimum + BBLV.markConstantRange(TrueCR.umin(FalseCR)); + return true; + case SPF_SMAX: /// Signed maximum + BBLV.markConstantRange(TrueCR.smax(FalseCR)); + return true; + case SPF_UMAX: /// Unsigned maximum + BBLV.markConstantRange(TrueCR.umax(FalseCR)); + return true; + }; + } + + // TODO: ABS, NABS from the SelectPatternResult + } + + // Can we constrain the facts about the true and false values by using the + // condition itself? This shows up with idioms like e.g. select(a > 5, a, 5). + // TODO: We could potentially refine an overdefined true value above. + if (auto *ICI = dyn_cast<ICmpInst>(SI->getCondition())) { + LVILatticeVal TrueValTaken, FalseValTaken; + if (!getValueFromFromCondition(SI->getTrueValue(), ICI, + TrueValTaken, true)) + TrueValTaken.markOverdefined(); + if (!getValueFromFromCondition(SI->getFalseValue(), ICI, + FalseValTaken, false)) + FalseValTaken.markOverdefined(); + + TrueVal = intersect(TrueVal, TrueValTaken); + FalseVal = intersect(FalseVal, FalseValTaken); + + + // Handle clamp idioms such as: + // %24 = constantrange<0, 17> + // %39 = icmp eq i32 %24, 0 + // %40 = add i32 %24, -1 + // %siv.next = select i1 %39, i32 16, i32 %40 + // %siv.next = constantrange<0, 17> not <-1, 17> + // In general, this can handle any clamp idiom which tests the edge + // condition via an equality or inequality. + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *A = ICI->getOperand(0); + if (ConstantInt *CIBase = dyn_cast<ConstantInt>(ICI->getOperand(1))) { + auto addConstants = [](ConstantInt *A, ConstantInt *B) { + assert(A->getType() == B->getType()); + return ConstantInt::get(A->getType(), A->getValue() + B->getValue()); + }; + // See if either input is A + C2, subject to the constraint from the + // condition that A != C when that input is used. We can assume that + // that input doesn't include C + C2. 
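When the select is recognized as a min/max idiom, the result range follows pointwise from the two input ranges. A minimal sketch of the signed-minimum case over closed intervals; smin here is a stand-in for ConstantRange::smin, and the Interval type is invented for illustration:

    #include <algorithm>
    #include <cstdint>

    struct Interval { int64_t Lo, Hi; }; // closed: [Lo, Hi]

    // smin(x, y) for x in A, y in B: the minimum ranges over
    // [min(A.Lo, B.Lo), min(A.Hi, B.Hi)].
    Interval smin(Interval A, Interval B) {
      return {std::min(A.Lo, B.Lo), std::min(A.Hi, B.Hi)};
    }
    // e.g. smin([0,10], [5,7]) == [0,7]: the result can be as small as 0
    // (from the first input) but never larger than 7 (the second caps it).

The umin/smax/umax cases handled above follow the same pattern with the corresponding bound comparisons.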
+ ConstantInt *CIAdded; + switch (Pred) { + default: break; + case ICmpInst::ICMP_EQ: + if (match(SI->getFalseValue(), m_Add(m_Specific(A), + m_ConstantInt(CIAdded)))) { + auto ResNot = addConstants(CIBase, CIAdded); + FalseVal = intersect(FalseVal, + LVILatticeVal::getNot(ResNot)); + } + break; + case ICmpInst::ICMP_NE: + if (match(SI->getTrueValue(), m_Add(m_Specific(A), + m_ConstantInt(CIAdded)))) { + auto ResNot = addConstants(CIBase, CIAdded); + TrueVal = intersect(TrueVal, + LVILatticeVal::getNot(ResNot)); + } + break; + }; } } + LVILatticeVal Result; // Start Undefined. + Result.mergeIn(TrueVal, DL); + Result.mergeIn(FalseVal, DL); + BBLV = Result; + return true; +} + +bool LazyValueInfoCache::solveBlockValueCast(LVILatticeVal &BBLV, + Instruction *BBI, + BasicBlock *BB) { + if (!BBI->getOperand(0)->getType()->isSized()) { + // Without knowing how wide the input is, we can't analyze it in any useful + // way. + BBLV.markOverdefined(); + return true; + } + + // Filter out casts we don't know how to reason about before attempting to + // recurse on our operand. This can cut a long search short if we know we're + // not going to be able to get any useful information anways. + switch (BBI->getOpcode()) { + case Instruction::Trunc: + case Instruction::SExt: + case Instruction::ZExt: + case Instruction::BitCast: + break; + default: + // Unhandled instructions are overdefined. + DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined (unknown cast).\n"); + BBLV.markOverdefined(); + return true; + } + + // Figure out the range of the LHS. If that fails, we still apply the + // transfer rule on the full set since we may be able to locally infer + // interesting facts. + if (!hasBlockValue(BBI->getOperand(0), BB)) + if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) + // More work to do before applying this transfer rule. + return false; + + const unsigned OperandBitWidth = + DL.getTypeSizeInBits(BBI->getOperand(0)->getType()); + ConstantRange LHSRange = ConstantRange(OperandBitWidth); + if (hasBlockValue(BBI->getOperand(0), BB)) { + LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); + intersectAssumeBlockValueConstantRange(BBI->getOperand(0), LHSVal, BBI); + if (LHSVal.isConstantRange()) + LHSRange = LHSVal.getConstantRange(); + } + + const unsigned ResultBitWidth = + cast<IntegerType>(BBI->getType())->getBitWidth(); + + // NOTE: We're currently limited by the set of operations that ConstantRange + // can evaluate symbolically. Enhancing that set will allows us to analyze + // more definitions. + LVILatticeVal Result; + switch (BBI->getOpcode()) { + case Instruction::Trunc: + Result.markConstantRange(LHSRange.truncate(ResultBitWidth)); + break; + case Instruction::SExt: + Result.markConstantRange(LHSRange.signExtend(ResultBitWidth)); + break; + case Instruction::ZExt: + Result.markConstantRange(LHSRange.zeroExtend(ResultBitWidth)); + break; + case Instruction::BitCast: + Result.markConstantRange(LHSRange); + break; + default: + // Should be dead if the code above is correct + llvm_unreachable("inconsistent with above"); + break; + } + + BBLV = Result; + return true; +} + +bool LazyValueInfoCache::solveBlockValueBinaryOp(LVILatticeVal &BBLV, + Instruction *BBI, + BasicBlock *BB) { + + assert(BBI->getOperand(0)->getType()->isSized() && + "all operands to binary operators are sized"); + + // Filter out operators we don't know how to reason about before attempting to + // recurse on our operand(s). 
This can cut a long search short if we know + // we're not going to be able to get any useful information anways. + switch (BBI->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::UDiv: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::And: + case Instruction::Or: + // continue into the code below + break; + default: + // Unhandled instructions are overdefined. + DEBUG(dbgs() << " compute BB '" << BB->getName() + << "' - overdefined (unknown binary operator).\n"); + BBLV.markOverdefined(); + return true; + }; + + // Figure out the range of the LHS. If that fails, use a conservative range, + // but apply the transfer rule anyways. This lets us pick up facts from + // expressions like "and i32 (call i32 @foo()), 32" + if (!hasBlockValue(BBI->getOperand(0), BB)) + if (pushBlockValue(std::make_pair(BB, BBI->getOperand(0)))) + // More work to do before applying this transfer rule. + return false; + + const unsigned OperandBitWidth = + DL.getTypeSizeInBits(BBI->getOperand(0)->getType()); + ConstantRange LHSRange = ConstantRange(OperandBitWidth); + if (hasBlockValue(BBI->getOperand(0), BB)) { + LVILatticeVal LHSVal = getBlockValue(BBI->getOperand(0), BB); + intersectAssumeBlockValueConstantRange(BBI->getOperand(0), LHSVal, BBI); + if (LHSVal.isConstantRange()) + LHSRange = LHSVal.getConstantRange(); + } + + ConstantInt *RHS = cast<ConstantInt>(BBI->getOperand(1)); + ConstantRange RHSRange = ConstantRange(RHS->getValue()); + // NOTE: We're currently limited by the set of operations that ConstantRange // can evaluate symbolically. Enhancing that set will allows us to analyze // more definitions. @@ -862,30 +1149,15 @@ bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, case Instruction::LShr: Result.markConstantRange(LHSRange.lshr(RHSRange)); break; - case Instruction::Trunc: - Result.markConstantRange(LHSRange.truncate(ResultTy->getBitWidth())); - break; - case Instruction::SExt: - Result.markConstantRange(LHSRange.signExtend(ResultTy->getBitWidth())); - break; - case Instruction::ZExt: - Result.markConstantRange(LHSRange.zeroExtend(ResultTy->getBitWidth())); - break; - case Instruction::BitCast: - Result.markConstantRange(LHSRange); - break; case Instruction::And: Result.markConstantRange(LHSRange.binaryAnd(RHSRange)); break; case Instruction::Or: Result.markConstantRange(LHSRange.binaryOr(RHSRange)); break; - - // Unhandled instructions are overdefined. default: - DEBUG(dbgs() << " compute BB '" << BB->getName() - << "' - overdefined because inst def found.\n"); - Result.markOverdefined(); + // Should be dead if the code above is correct + llvm_unreachable("inconsistent with above"); break; } @@ -895,10 +1167,11 @@ bool LazyValueInfoCache::solveBlockValueConstantRange(LVILatticeVal &BBLV, bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, LVILatticeVal &Result, bool isTrueDest) { - if (ICI && isa<Constant>(ICI->getOperand(1))) { + assert(ICI && "precondition"); + if (isa<Constant>(ICI->getOperand(1))) { if (ICI->isEquality() && ICI->getOperand(0) == Val) { // We know that V has the RHS constant if this is a true SETEQ or - // false SETNE. + // false SETNE. if (isTrueDest == (ICI->getPredicate() == ICmpInst::ICMP_EQ)) Result = LVILatticeVal::get(cast<Constant>(ICI->getOperand(1))); else @@ -926,7 +1199,7 @@ bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, // If we're interested in the false dest, invert the condition. 
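The inverse() call that follows is cheap because ConstantRange-style sets are half-open intervals modulo 2^N that are allowed to wrap, so complementing one is just swapping its endpoints. A sketch of that representation (the Range type is illustrative, not the LLVM class):

    #include <cstdint>

    // ConstantRange-style half-open range [Lo, Hi) modulo 2^N; Hi may be
    // numerically smaller than Lo, meaning the set wraps around zero.
    struct Range { uint64_t Lo, Hi; };

    // The complement of [Lo, Hi) mod 2^N is simply [Hi, Lo), which is why
    // the false edge of a branch can reuse the true-edge range directly.
    Range inverse(Range R) { return {R.Hi, R.Lo}; }

    // e.g. for "br (icmp ult x, 10)": the true edge gives x in [0, 10);
    // the false edge gives inverse([0, 10)) == [10, 0), i.e. x >= 10.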
if (!isTrueDest) TrueValues = TrueValues.inverse(); - Result = LVILatticeVal::getRange(TrueValues); + Result = LVILatticeVal::getRange(std::move(TrueValues)); return true; } } @@ -935,7 +1208,8 @@ bool getValueFromFromCondition(Value *Val, ICmpInst *ICI, } /// \brief Compute the value of Val on the edge BBFrom -> BBTo. Returns false if -/// Val is not constrained on the edge. +/// Val is not constrained on the edge. Result is unspecified if return value +/// is false. static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, BasicBlock *BBTo, LVILatticeVal &Result) { // TODO: Handle more complex conditionals. If (v == 0 || v2 < 1) is false, we @@ -985,7 +1259,7 @@ static bool getEdgeValueLocal(Value *Val, BasicBlock *BBFrom, } else if (i.getCaseSuccessor() == BBTo) EdgesVals = EdgesVals.unionWith(EdgeVal); } - Result = LVILatticeVal::getRange(EdgesVals); + Result = LVILatticeVal::getRange(std::move(EdgesVals)); return true; } return false; @@ -1002,46 +1276,29 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, return true; } - if (getEdgeValueLocal(Val, BBFrom, BBTo, Result)) { - if (!Result.isConstantRange() || - Result.getConstantRange().getSingleElement()) - return true; - - // FIXME: this check should be moved to the beginning of the function when - // LVI better supports recursive values. Even for the single value case, we - // can intersect to detect dead code (an empty range). - if (!hasBlockValue(Val, BBFrom)) { - if (pushBlockValue(std::make_pair(BBFrom, Val))) - return false; - Result.markOverdefined(); - return true; - } - - // Try to intersect ranges of the BB and the constraint on the edge. - LVILatticeVal InBlock = getBlockValue(Val, BBFrom); - mergeAssumeBlockValueConstantRange(Val, InBlock, BBFrom->getTerminator()); - // See note on the use of the CxtI with mergeAssumeBlockValueConstantRange, - // and caching, below. - mergeAssumeBlockValueConstantRange(Val, InBlock, CxtI); - if (!InBlock.isConstantRange()) - return true; + LVILatticeVal LocalResult; + if (!getEdgeValueLocal(Val, BBFrom, BBTo, LocalResult)) + // If we couldn't constrain the value on the edge, LocalResult doesn't + // provide any information. + LocalResult.markOverdefined(); - ConstantRange Range = - Result.getConstantRange().intersectWith(InBlock.getConstantRange()); - Result = LVILatticeVal::getRange(Range); + if (hasSingleValue(LocalResult)) { + // Can't get any more precise here + Result = LocalResult; return true; } if (!hasBlockValue(Val, BBFrom)) { if (pushBlockValue(std::make_pair(BBFrom, Val))) return false; - Result.markOverdefined(); + // No new information. + Result = LocalResult; return true; } - // If we couldn't compute the value on the edge, use the value from the BB. - Result = getBlockValue(Val, BBFrom); - mergeAssumeBlockValueConstantRange(Val, Result, BBFrom->getTerminator()); + // Try to intersect ranges of the BB and the constraint on the edge. + LVILatticeVal InBlock = getBlockValue(Val, BBFrom); + intersectAssumeBlockValueConstantRange(Val, InBlock, BBFrom->getTerminator()); // We can use the context instruction (generically the ultimate instruction // the calling pass is trying to simplify) here, even though the result of // this function is generally cached when called from the solve* functions @@ -1050,7 +1307,9 @@ bool LazyValueInfoCache::getEdgeValue(Value *Val, BasicBlock *BBFrom, // functions, the context instruction is not provided. 
When called from // LazyValueInfoCache::getValueOnEdge, the context instruction is provided, // but then the result is not cached. - mergeAssumeBlockValueConstantRange(Val, Result, CxtI); + intersectAssumeBlockValueConstantRange(Val, InBlock, CxtI); + + Result = intersect(LocalResult, InBlock); return true; } @@ -1060,11 +1319,12 @@ LVILatticeVal LazyValueInfoCache::getValueInBlock(Value *V, BasicBlock *BB, << BB->getName() << "'\n"); assert(BlockValueStack.empty() && BlockValueSet.empty()); - pushBlockValue(std::make_pair(BB, V)); - - solve(); + if (!hasBlockValue(V, BB)) { + pushBlockValue(std::make_pair(BB, V)); + solve(); + } LVILatticeVal Result = getBlockValue(V, BB); - mergeAssumeBlockValueConstantRange(V, Result, CxtI); + intersectAssumeBlockValueConstantRange(V, Result, CxtI); DEBUG(dbgs() << " Result = " << Result << "\n"); return Result; @@ -1074,10 +1334,13 @@ LVILatticeVal LazyValueInfoCache::getValueAt(Value *V, Instruction *CxtI) { DEBUG(dbgs() << "LVI Getting value " << *V << " at '" << CxtI->getName() << "'\n"); - LVILatticeVal Result; + if (auto *C = dyn_cast<Constant>(V)) + return LVILatticeVal::get(C); + + LVILatticeVal Result = LVILatticeVal::getOverdefined(); if (auto *I = dyn_cast<Instruction>(V)) Result = getFromRangeMetadata(I); - mergeAssumeBlockValueConstantRange(V, Result, CxtI); + intersectAssumeBlockValueConstantRange(V, Result, CxtI); DEBUG(dbgs() << " Result = " << Result << "\n"); return Result; @@ -1172,29 +1435,32 @@ static LazyValueInfoCache &getCache(void *&PImpl, AssumptionCache *AC, return *static_cast<LazyValueInfoCache*>(PImpl); } -bool LazyValueInfo::runOnFunction(Function &F) { - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); +bool LazyValueInfoWrapperPass::runOnFunction(Function &F) { + Info.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); const DataLayout &DL = F.getParent()->getDataLayout(); DominatorTreeWrapperPass *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DT = DTWP ? &DTWP->getDomTree() : nullptr; - - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + Info.DT = DTWP ? &DTWP->getDomTree() : nullptr; + Info.TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - if (PImpl) - getCache(PImpl, AC, &DL, DT).clear(); + if (Info.PImpl) + getCache(Info.PImpl, Info.AC, &DL, Info.DT).clear(); // Fully lazy. return false; } -void LazyValueInfo::getAnalysisUsage(AnalysisUsage &AU) const { +void LazyValueInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } +LazyValueInfo &LazyValueInfoWrapperPass::getLVI() { return Info; } + +LazyValueInfo::~LazyValueInfo() { releaseMemory(); } + void LazyValueInfo::releaseMemory() { // If the cache was allocated, free it. 
if (PImpl) { @@ -1203,6 +1469,16 @@ void LazyValueInfo::releaseMemory() { } } +void LazyValueInfoWrapperPass::releaseMemory() { Info.releaseMemory(); } + +LazyValueInfo LazyValueAnalysis::run(Function &F, FunctionAnalysisManager &FAM) { + auto &AC = FAM.getResult<AssumptionAnalysis>(F); + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + + return LazyValueInfo(&AC, &TLI, DT); +} + Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, Instruction *CxtI) { const DataLayout &DL = BB->getModule()->getDataLayout(); @@ -1219,6 +1495,21 @@ Constant *LazyValueInfo::getConstant(Value *V, BasicBlock *BB, return nullptr; } +ConstantRange LazyValueInfo::getConstantRange(Value *V, BasicBlock *BB, + Instruction *CxtI) { + assert(V->getType()->isIntegerTy()); + unsigned Width = V->getType()->getIntegerBitWidth(); + const DataLayout &DL = BB->getModule()->getDataLayout(); + LVILatticeVal Result = + getCache(PImpl, AC, &DL, DT).getValueInBlock(V, BB, CxtI); + assert(!Result.isConstant()); + if (Result.isUndefined()) + return ConstantRange(Width, /*isFullSet=*/false); + if (Result.isConstantRange()) + return Result.getConstantRange(); + return ConstantRange(Width, /*isFullSet=*/true); +} + /// Determine whether the specified value is known to be a /// constant on the specified edge. Return null if not. Constant *LazyValueInfo::getConstantOnEdge(Value *V, BasicBlock *FromBB, @@ -1349,7 +1640,7 @@ LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, // We limit the search to one step backwards from the current BB and value. // We could consider extending this to search further backwards through the // CFG and/or value graph, but there are non-obvious compile time vs quality - // tradeoffs. + // tradeoffs. if (CxtI) { BasicBlock *BB = CxtI->getParent(); @@ -1369,10 +1660,10 @@ LazyValueInfo::getPredicateAt(unsigned Pred, Value *V, Constant *C, for (unsigned i = 0, e = PHI->getNumIncomingValues(); i < e; i++) { Value *Incoming = PHI->getIncomingValue(i); BasicBlock *PredBB = PHI->getIncomingBlock(i); - // Note that PredBB may be BB itself. + // Note that PredBB may be BB itself. Tristate Result = getPredicateOnEdge(Pred, Incoming, C, PredBB, BB, CxtI); - + // Keep going as long as we've seen a consistent known result for // all inputs. Baseline = (i == 0) ? Result /* First iteration */ diff --git a/lib/Analysis/Lint.cpp b/lib/Analysis/Lint.cpp index 2dfb09c95ad6..fdf5f55dab9f 100644 --- a/lib/Analysis/Lint.cpp +++ b/lib/Analysis/Lint.cpp @@ -435,7 +435,7 @@ void Lint::visitMemoryReference(Instruction &I, // If the global may be defined differently in another compilation unit // then don't warn about funky memory accesses. 
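Stepping back to the LazyValueInfo changes above: the new getConstantRange entry point gives clients a range-typed answer directly instead of only constant/not-constant queries. A hedged sketch of how a transform might consume it; getConstantAt is a hypothetical helper, while getConstantRange and getSingleElement are the calls this change exposes or relies on:

    #include "llvm/Analysis/LazyValueInfo.h"
    #include "llvm/IR/ConstantRange.h"
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/Instruction.h"

    using namespace llvm;

    // If LVI proves V holds a single integer value when execution reaches
    // CxtI, materialize it as a ConstantInt; otherwise return null. The
    // answer is only valid at CxtI, so a caller must not rewrite uses of V
    // that CxtI does not dominate.
    static ConstantInt *getConstantAt(Value *V, Instruction *CxtI,
                                      LazyValueInfo &LVI) {
      if (!V->getType()->isIntegerTy())
        return nullptr;
      ConstantRange CR = LVI.getConstantRange(V, CxtI->getParent(), CxtI);
      if (const APInt *Single = CR.getSingleElement())
        return ConstantInt::get(CxtI->getContext(), *Single);
      return nullptr;
    }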
if (GV->hasDefinitiveInitializer()) { - Type *GTy = GV->getType()->getElementType(); + Type *GTy = GV->getValueType(); if (GTy->isSized()) BaseSize = DL->getTypeAllocSize(GTy); BaseAlign = GV->getAlignment(); @@ -642,8 +642,7 @@ Value *Lint::findValueImpl(Value *V, bool OffsetOk, if (!VisitedBlocks.insert(BB).second) break; if (Value *U = - FindAvailableLoadedValue(L->getPointerOperand(), - BB, BBI, DefMaxInstsToScan, AA)) + FindAvailableLoadedValue(L, BB, BBI, DefMaxInstsToScan, AA)) return findValueImpl(U, OffsetOk, Visited); if (BBI != BB->begin()) break; BB = BB->getUniquePredecessor(); diff --git a/lib/Analysis/Loads.cpp b/lib/Analysis/Loads.cpp index 4b2fa3c6505a..75426b54195a 100644 --- a/lib/Analysis/Loads.cpp +++ b/lib/Analysis/Loads.cpp @@ -21,8 +21,125 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Statepoint.h" + using namespace llvm; +static bool isAligned(const Value *Base, const APInt &Offset, unsigned Align, + const DataLayout &DL) { + APInt BaseAlign(Offset.getBitWidth(), Base->getPointerAlignment(DL)); + + if (!BaseAlign) { + Type *Ty = Base->getType()->getPointerElementType(); + if (!Ty->isSized()) + return false; + BaseAlign = DL.getABITypeAlignment(Ty); + } + + APInt Alignment(Offset.getBitWidth(), Align); + + assert(Alignment.isPowerOf2() && "must be a power of 2!"); + return BaseAlign.uge(Alignment) && !(Offset & (Alignment-1)); +} + +static bool isAligned(const Value *Base, unsigned Align, const DataLayout &DL) { + Type *Ty = Base->getType(); + assert(Ty->isSized() && "must be sized"); + APInt Offset(DL.getTypeStoreSizeInBits(Ty), 0); + return isAligned(Base, Offset, Align, DL); +} + +/// Test if V is always a pointer to allocated and suitably aligned memory for +/// a simple load or store. +static bool isDereferenceableAndAlignedPointer( + const Value *V, unsigned Align, const APInt &Size, const DataLayout &DL, + const Instruction *CtxI, const DominatorTree *DT, + SmallPtrSetImpl<const Value *> &Visited) { + // Note that it is not safe to speculate into a malloc'd region because + // malloc may return null. + + // bitcast instructions are no-ops as far as dereferenceability is concerned. + if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(V)) + return isDereferenceableAndAlignedPointer(BC->getOperand(0), Align, Size, + DL, CtxI, DT, Visited); + + bool CheckForNonNull = false; + APInt KnownDerefBytes(Size.getBitWidth(), + V->getPointerDereferenceableBytes(DL, CheckForNonNull)); + if (KnownDerefBytes.getBoolValue()) { + if (KnownDerefBytes.uge(Size)) + if (!CheckForNonNull || isKnownNonNullAt(V, CtxI, DT)) + return isAligned(V, Align, DL); + } + + // For GEPs, determine if the indexing lands within the allocated object. + if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { + const Value *Base = GEP->getPointerOperand(); + + APInt Offset(DL.getPointerTypeSizeInBits(GEP->getType()), 0); + if (!GEP->accumulateConstantOffset(DL, Offset) || Offset.isNegative() || + !Offset.urem(APInt(Offset.getBitWidth(), Align)).isMinValue()) + return false; + + // If the base pointer is dereferenceable for Offset+Size bytes, then the + // GEP (== Base + Offset) is dereferenceable for Size bytes. If the base + // pointer is aligned to Align bytes, and the Offset is divisible by Align + // then the GEP (== Base + Offset == k_0 * Align + k_1 * Align) is also + // aligned to Align bytes. 
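The alignment reasoning spelled out in the comment above reduces to modular arithmetic on power-of-two alignments: if Base is a multiple of Align and Offset is a multiple of Align, their sum is too. A standalone sketch of the check; isAlignedAt is a hypothetical plain-integer version of the APInt-based isAligned helpers above:

    #include <cassert>
    #include <cstdint>

    // For a power-of-two alignment A, "X is a multiple of A" is the mask
    // test (X & (A - 1)) == 0, which is what the code above computes.
    bool isAlignedAt(uint64_t BaseAlign, uint64_t Offset, uint64_t Align) {
      assert((Align & (Align - 1)) == 0 && "alignment must be a power of two");
      return BaseAlign >= Align && (Offset & (Align - 1)) == 0;
    }
    // e.g. a 16-byte-aligned base plus offset 40 is 8-aligned (40 % 8 == 0)
    // but not 16-aligned (40 % 16 == 8).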
+ + return Visited.insert(Base).second && + isDereferenceableAndAlignedPointer(Base, Align, Offset + Size, DL, + CtxI, DT, Visited); + } + + // For gc.relocate, look through relocations + if (const GCRelocateInst *RelocateInst = dyn_cast<GCRelocateInst>(V)) + return isDereferenceableAndAlignedPointer( + RelocateInst->getDerivedPtr(), Align, Size, DL, CtxI, DT, Visited); + + if (const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(V)) + return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Align, Size, + DL, CtxI, DT, Visited); + + if (auto CS = ImmutableCallSite(V)) + if (const Value *RV = CS.getReturnedArgOperand()) + return isDereferenceableAndAlignedPointer(RV, Align, Size, DL, CtxI, DT, + Visited); + + // If we don't know, assume the worst. + return false; +} + +bool llvm::isDereferenceableAndAlignedPointer(const Value *V, unsigned Align, + const DataLayout &DL, + const Instruction *CtxI, + const DominatorTree *DT) { + // When dereferenceability information is provided by a dereferenceable + // attribute, we know exactly how many bytes are dereferenceable. If we can + // determine the exact offset to the attributed variable, we can use that + // information here. + Type *VTy = V->getType(); + Type *Ty = VTy->getPointerElementType(); + + // Require ABI alignment for loads without alignment specification + if (Align == 0) + Align = DL.getABITypeAlignment(Ty); + + if (!Ty->isSized()) + return false; + + SmallPtrSet<const Value *, 32> Visited; + return ::isDereferenceableAndAlignedPointer( + V, Align, APInt(DL.getTypeSizeInBits(VTy), DL.getTypeStoreSize(Ty)), DL, + CtxI, DT, Visited); +} + +bool llvm::isDereferenceablePointer(const Value *V, const DataLayout &DL, + const Instruction *CtxI, + const DominatorTree *DT) { + return isDereferenceableAndAlignedPointer(V, 1, DL, CtxI, DT); +} + /// \brief Test if A and B will obviously have the same value. /// /// This includes recognizing that %t0 and %t1 will have the same @@ -56,21 +173,29 @@ static bool AreEquivalentAddressValues(const Value *A, const Value *B) { /// \brief Check if executing a load of this pointer value cannot trap. /// +/// If DT and ScanFrom are specified this method performs context-sensitive +/// analysis and returns true if it is safe to load immediately before ScanFrom. +/// /// If it is not obviously safe to load from the specified pointer, we do /// a quick local scan of the basic block containing \c ScanFrom, to determine /// if the address is already accessed. /// /// This uses the pointee type to determine how many bytes need to be safe to /// load from the pointer. -bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, - unsigned Align) { - const DataLayout &DL = ScanFrom->getModule()->getDataLayout(); - +bool llvm::isSafeToLoadUnconditionally(Value *V, unsigned Align, + const DataLayout &DL, + Instruction *ScanFrom, + const DominatorTree *DT) { // Zero alignment means that the load has the ABI alignment for the target if (Align == 0) Align = DL.getABITypeAlignment(V->getType()->getPointerElementType()); assert(isPowerOf2_32(Align)); + // If DT is not specified we can't make context-sensitive query + const Instruction* CtxI = DT ? 
ScanFrom : nullptr; + if (isDereferenceableAndAlignedPointer(V, Align, DL, CtxI, DT)) + return true; + int64_t ByteOffset = 0; Value *Base = V; Base = GetPointerBaseWithConstantOffset(V, ByteOffset, DL); @@ -86,9 +211,9 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, BaseAlign = AI->getAlignment(); } else if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Base)) { // Global variables are not necessarily safe to load from if they are - // overridden. Their size may change or they may be weak and require a test - // to determine if they were in fact provided. - if (!GV->mayBeOverridden()) { + // interposed arbitrarily. Their size may change or they may be weak and + // require a test to determine if they were in fact provided. + if (!GV->isInterposable()) { BaseType = GV->getType()->getElementType(); BaseAlign = GV->getAlignment(); } @@ -113,6 +238,9 @@ bool llvm::isSafeToLoadUnconditionally(Value *V, Instruction *ScanFrom, } } + if (!ScanFrom) + return false; + // Otherwise, be a little bit aggressive by scanning the local block where we // want to check to see if the pointer is already being loaded or stored // from/to. If so, the previous load or store would have already trapped, @@ -174,33 +302,24 @@ llvm::DefMaxInstsToScan("available-load-scan-limit", cl::init(6), cl::Hidden, "to scan backward from a given instruction, when searching for " "available loaded value")); -/// \brief Scan the ScanBB block backwards to see if we have the value at the -/// memory address *Ptr locally available within a small number of instructions. -/// -/// The scan starts from \c ScanFrom. \c MaxInstsToScan specifies the maximum -/// instructions to scan in the block. If it is set to \c 0, it will scan the whole -/// block. -/// -/// If the value is available, this function returns it. If not, it returns the -/// iterator for the last validated instruction that the value would be live -/// through. If we scanned the entire block and didn't find something that -/// invalidates \c *Ptr or provides it, \c ScanFrom is left at the last -/// instruction processed and this returns null. -/// -/// You can also optionally specify an alias analysis implementation, which -/// makes this more precise. -/// -/// If \c AATags is non-null and a load or store is found, the AA tags from the -/// load or store are recorded there. If there are no AA tags or if no access is -/// found, it is left unmodified. -Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, +Value *llvm::FindAvailableLoadedValue(LoadInst *Load, BasicBlock *ScanBB, BasicBlock::iterator &ScanFrom, unsigned MaxInstsToScan, - AliasAnalysis *AA, AAMDNodes *AATags) { + AliasAnalysis *AA, AAMDNodes *AATags, + bool *IsLoadCSE) { if (MaxInstsToScan == 0) MaxInstsToScan = ~0U; - Type *AccessTy = cast<PointerType>(Ptr->getType())->getElementType(); + Value *Ptr = Load->getPointerOperand(); + Type *AccessTy = Load->getType(); + + // We can never remove a volatile load + if (Load->isVolatile()) + return nullptr; + + // Anything stronger than unordered is currently unimplemented. 
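The checks just below, together with the isAtomic() comparisons further on, implement a simple compatibility rule for value forwarding: a volatile load is never satisfied by CSE, ordered atomics are not handled at all, and a value may flow from an equally or more strongly ordered access to a weaker one but never the reverse. A simplified, hypothetical model of that predicate (plain bools, not llvm::AtomicOrdering):

    // Can a previously seen access (the source) supply the value of the
    // load we are trying to eliminate (the target)?
    bool canForwardValue(bool TargetVolatile, bool TargetOrderedAtomic,
                         bool SourceIsAtomic, bool TargetIsAtomic) {
      if (TargetVolatile)
        return false; // volatile loads must stay
      if (TargetOrderedAtomic)
        return false; // anything stronger than unordered is unimplemented
      return SourceIsAtomic >= TargetIsAtomic; // atomic -> non-atomic is fine
    }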
+ if (!Load->isUnordered()) + return nullptr; const DataLayout &DL = ScanBB->getModule()->getDataLayout(); @@ -231,8 +350,16 @@ Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, if (AreEquivalentAddressValues( LI->getPointerOperand()->stripPointerCasts(), StrippedPtr) && CastInst::isBitOrNoopPointerCastable(LI->getType(), AccessTy, DL)) { + + // We can value forward from an atomic to a non-atomic, but not the + // other way around. + if (LI->isAtomic() < Load->isAtomic()) + return nullptr; + if (AATags) LI->getAAMetadata(*AATags); + if (IsLoadCSE) + *IsLoadCSE = true; return LI; } @@ -244,6 +371,12 @@ Value *llvm::FindAvailableLoadedValue(Value *Ptr, BasicBlock *ScanBB, if (AreEquivalentAddressValues(StorePtr, StrippedPtr) && CastInst::isBitOrNoopPointerCastable(SI->getValueOperand()->getType(), AccessTy, DL)) { + + // We can value forward from an atomic to a non-atomic, but not the + // other way around. + if (SI->isAtomic() < Load->isAtomic()) + return nullptr; + if (AATags) SI->getAAMetadata(*AATags); return SI->getOperand(0); diff --git a/lib/Analysis/LoopAccessAnalysis.cpp b/lib/Analysis/LoopAccessAnalysis.cpp index 8bcdcb862014..0d774cf08e2f 100644 --- a/lib/Analysis/LoopAccessAnalysis.cpp +++ b/lib/Analysis/LoopAccessAnalysis.cpp @@ -14,15 +14,17 @@ #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/PassManager.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Analysis/VectorUtils.h" using namespace llvm; #define DEBUG_TYPE "loop-accesses" @@ -65,6 +67,28 @@ static cl::opt<unsigned> "loop-access analysis (default = 100)"), cl::init(100)); +/// This enables versioning on the strides of symbolically striding memory +/// accesses in code like the following. +/// for (i = 0; i < N; ++i) +/// A[i * Stride1] += B[i * Stride2] ... +/// +/// Will be roughly translated to +/// if (Stride1 == 1 && Stride2 == 1) { +/// for (i = 0; i < N; i+=4) +/// A[i:i+3] += ... +/// } else +/// ... +static cl::opt<bool> EnableMemAccessVersioning( + "enable-mem-access-versioning", cl::init(true), cl::Hidden, + cl::desc("Enable symbolic stride memory access versioning")); + +/// \brief Enable store-to-load forwarding conflict detection. This option can +/// be disabled for correctness testing. +static cl::opt<bool> EnableForwardingConflictDetection( + "store-to-load-forwarding-conflict-detection", cl::Hidden, + cl::desc("Enable conflict detection in loop-access analysis"), + cl::init(true)); + bool VectorizerParams::isInterleaveForced() { return ::VectorizationInterleave.getNumOccurrences() > 0; } @@ -81,7 +105,7 @@ void LoopAccessReport::emitAnalysis(const LoopAccessReport &Message, } Value *llvm::stripIntegerCast(Value *V) { - if (CastInst *CI = dyn_cast<CastInst>(V)) + if (auto *CI = dyn_cast<CastInst>(V)) if (CI->getOperand(0)->getType()->isIntegerTy()) return CI->getOperand(0); return V; @@ -130,26 +154,34 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr, PredicatedScalarEvolution &PSE) { // Get the stride replaced scev. 
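  // E.g. (illustrative, names assumed): for an access A[i * Stride] whose
  // symbolic Stride is mapped to 1 in Strides, the returned SCEV has the
  // analyzable shape
  //   {%A,+,4}<%loop>
  // (assuming i32 elements) instead of an opaque symbolic product.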
const SCEV *Sc = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); - const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); - assert(AR && "Invalid addrec expression"); ScalarEvolution *SE = PSE.getSE(); - const SCEV *Ex = SE->getBackedgeTakenCount(Lp); - const SCEV *ScStart = AR->getStart(); - const SCEV *ScEnd = AR->evaluateAtIteration(Ex, *SE); - const SCEV *Step = AR->getStepRecurrence(*SE); + const SCEV *ScStart; + const SCEV *ScEnd; - // For expressions with negative step, the upper bound is ScStart and the - // lower bound is ScEnd. - if (const SCEVConstant *CStep = dyn_cast<const SCEVConstant>(Step)) { - if (CStep->getValue()->isNegative()) - std::swap(ScStart, ScEnd); - } else { - // Fallback case: the step is not constant, but we can still - // get the upper and lower bounds of the interval by using min/max - // expressions. - ScStart = SE->getUMinExpr(ScStart, ScEnd); - ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd); + if (SE->isLoopInvariant(Sc, Lp)) + ScStart = ScEnd = Sc; + else { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Sc); + assert(AR && "Invalid addrec expression"); + const SCEV *Ex = PSE.getBackedgeTakenCount(); + + ScStart = AR->getStart(); + ScEnd = AR->evaluateAtIteration(Ex, *SE); + const SCEV *Step = AR->getStepRecurrence(*SE); + + // For expressions with negative step, the upper bound is ScStart and the + // lower bound is ScEnd. + if (const auto *CStep = dyn_cast<SCEVConstant>(Step)) { + if (CStep->getValue()->isNegative()) + std::swap(ScStart, ScEnd); + } else { + // Fallback case: the step is not constant, but we can still + // get the upper and lower bounds of the interval by using min/max + // expressions. + ScStart = SE->getUMinExpr(ScStart, ScEnd); + ScEnd = SE->getUMaxExpr(AR->getStart(), ScEnd); + } } Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc); @@ -452,7 +484,7 @@ public: /// (i.e. the pointers have computable bounds). bool canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &Strides, - bool ShouldCheckStride = false); + bool ShouldCheckWrap = false); /// \brief Goes over all memory accesses, checks whether a RT check is needed /// and builds sets of dependent accesses. @@ -524,6 +556,11 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, const ValueToValueMap &Strides, Value *Ptr, Loop *L) { const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + + // The bounds for a loop-invariant pointer are trivial. + if (PSE.getSE()->isLoopInvariant(PtrScev, L)) + return true; + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); if (!AR) return false; @@ -531,10 +568,21 @@ static bool hasComputableBounds(PredicatedScalarEvolution &PSE, return AR->isAffine(); } +/// \brief Check whether a pointer address cannot wrap. +static bool isNoWrap(PredicatedScalarEvolution &PSE, + const ValueToValueMap &Strides, Value *Ptr, Loop *L) { + const SCEV *PtrScev = PSE.getSCEV(Ptr); + if (PSE.getSE()->isLoopInvariant(PtrScev, L)) + return true; + + int64_t Stride = getPtrStride(PSE, Ptr, L, Strides); + return Stride == 1; +} + bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, ScalarEvolution *SE, Loop *TheLoop, const ValueToValueMap &StridesMap, - bool ShouldCheckStride) { + bool ShouldCheckWrap) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check.
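  // E.g. (illustrative, names assumed): for the addrec {%A,+,4}<%loop> and a
  // backedge-taken count of n, the checked interval is
  //   [ScStart, ScEnd] = [%A, %A + 4 * n],
  // so possible overlap of two pointers can be tested at run time by
  // comparing their two intervals.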
bool CanDoRT = true; @@ -569,8 +617,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck, if (hasComputableBounds(PSE, StridesMap, Ptr, TheLoop) && // When we run after a failing dependency check we have to make sure // we don't have wrapping pointers. - (!ShouldCheckStride || - isStridedPtr(PSE, Ptr, TheLoop, StridesMap) == 1)) { + (!ShouldCheckWrap || isNoWrap(PSE, StridesMap, Ptr, TheLoop))) { // The id of the dependence set. unsigned DepId; @@ -773,7 +820,7 @@ static bool isInBoundsGep(Value *Ptr) { /// \brief Return true if an AddRec pointer \p Ptr is unsigned non-wrapping, /// i.e. monotonically increasing/decreasing. static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, - ScalarEvolution *SE, const Loop *L) { + PredicatedScalarEvolution &PSE, const Loop *L) { // FIXME: This should probably only return true for NUW. if (AR->getNoWrapFlags(SCEV::NoWrapMask)) return true; @@ -792,11 +839,11 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, // Make sure there is only one non-const index and analyze that. Value *NonConstIndex = nullptr; - for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) - if (!isa<ConstantInt>(*Index)) { + for (Value *Index : make_range(GEP->idx_begin(), GEP->idx_end())) + if (!isa<ConstantInt>(Index)) { if (NonConstIndex) return false; - NonConstIndex = *Index; + NonConstIndex = Index; } if (!NonConstIndex) // The recurrence is on the pointer, ignore for now. return false; @@ -809,7 +856,7 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, // Assume the other operand is constant so that the AddRec can be // easily found. isa<ConstantInt>(OBO->getOperand(1))) { - auto *OpScev = SE->getSCEV(OBO->getOperand(0)); + auto *OpScev = PSE.getSCEV(OBO->getOperand(0)); if (auto *OpAR = dyn_cast<SCEVAddRecExpr>(OpScev)) return OpAR->getLoop() == L && OpAR->getNoWrapFlags(SCEV::FlagNSW); @@ -819,32 +866,36 @@ static bool isNoWrapAddRec(Value *Ptr, const SCEVAddRecExpr *AR, } /// \brief Check whether the access through \p Ptr has a constant stride. -int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, - const Loop *Lp, const ValueToValueMap &StridesMap) { +int64_t llvm::getPtrStride(PredicatedScalarEvolution &PSE, Value *Ptr, + const Loop *Lp, const ValueToValueMap &StridesMap, + bool Assume) { Type *Ty = Ptr->getType(); assert(Ty->isPointerTy() && "Unexpected non-ptr"); // Make sure that the pointer does not point to aggregate types. auto *PtrTy = cast<PointerType>(Ty); if (PtrTy->getElementType()->isAggregateType()) { - DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" - << *Ptr << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not a pointer to a scalar type" << *Ptr + << "\n"); return 0; } const SCEV *PtrScev = replaceSymbolicStrideSCEV(PSE, StridesMap, Ptr); const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PtrScev); + if (Assume && !AR) + AR = PSE.getAsAddRec(Ptr); + if (!AR) { - DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " - << *Ptr << " SCEV: " << *PtrScev << "\n"); + DEBUG(dbgs() << "LAA: Bad stride - Not an AddRecExpr pointer " << *Ptr + << " SCEV: " << *PtrScev << "\n"); return 0; } // The access function must stride over the innermost loop.
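  // E.g. (illustrative, names assumed): an i32 access with SCEV {%A,+,8}<%L>
  // advances 8 bytes per iteration of %L, i.e. a stride of 2 elements; the
  // result is only meaningful if %L is the loop being queried, which is
  // checked next.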
if (Lp != AR->getLoop()) { DEBUG(dbgs() << "LAA: Bad stride - Not striding over innermost loop " << - *Ptr << " SCEV: " << *PtrScev << "\n"); + *Ptr << " SCEV: " << *AR << "\n"); return 0; } @@ -856,12 +907,23 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, // to access the pointer value "0" which is undefined behavior in address // space 0, therefore we can also vectorize this case. bool IsInBoundsGEP = isInBoundsGep(Ptr); - bool IsNoWrapAddRec = isNoWrapAddRec(Ptr, AR, PSE.getSE(), Lp); + bool IsNoWrapAddRec = + PSE.hasNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW) || + isNoWrapAddRec(Ptr, AR, PSE, Lp); bool IsInAddressSpaceZero = PtrTy->getAddressSpace() == 0; if (!IsNoWrapAddRec && !IsInBoundsGEP && !IsInAddressSpaceZero) { - DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " - << *Ptr << " SCEV: " << *PtrScev << "\n"); - return 0; + if (Assume) { + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + IsNoWrapAddRec = true; + DEBUG(dbgs() << "LAA: Pointer may wrap in the address space:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + } else { + DEBUG(dbgs() << "LAA: Bad stride - Pointer may wrap in the address space " + << *Ptr << " SCEV: " << *AR << "\n"); + return 0; + } } // Check the step is constant. @@ -871,7 +933,7 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, const SCEVConstant *C = dyn_cast<SCEVConstant>(Step); if (!C) { DEBUG(dbgs() << "LAA: Bad stride - Not a constant strided " << *Ptr << - " SCEV: " << *PtrScev << "\n"); + " SCEV: " << *AR << "\n"); return 0; } @@ -895,12 +957,94 @@ int llvm::isStridedPtr(PredicatedScalarEvolution &PSE, Value *Ptr, // know we can't "wrap around the address space". In case of address space // zero we know that this won't happen without triggering undefined behavior. if (!IsNoWrapAddRec && (IsInBoundsGEP || IsInAddressSpaceZero) && - Stride != 1 && Stride != -1) - return 0; + Stride != 1 && Stride != -1) { + if (Assume) { + // We can avoid this case by adding a run-time check. + DEBUG(dbgs() << "LAA: Non-unit strided pointer which is neither " + << "inbounds nor in address space 0 may wrap:\n" + << "LAA: Pointer: " << *Ptr << "\n" + << "LAA: SCEV: " << *AR << "\n" + << "LAA: Added an overflow assumption\n"); + PSE.setNoOverflow(Ptr, SCEVWrapPredicate::IncrementNUSW); + } else + return 0; + } return Stride; } +/// Take the pointer operand from the Load/Store instruction. +/// Returns nullptr if this is not a valid Load/Store instruction. +static Value *getPointerOperand(Value *I) { + if (auto *LI = dyn_cast<LoadInst>(I)) + return LI->getPointerOperand(); + if (auto *SI = dyn_cast<StoreInst>(I)) + return SI->getPointerOperand(); + return nullptr; +} + +/// Take the address space operand from the Load/Store instruction. +/// Returns -1 if this is not a valid Load/Store instruction. +static unsigned getAddressSpaceOperand(Value *I) { + if (LoadInst *L = dyn_cast<LoadInst>(I)) + return L->getPointerAddressSpace(); + if (StoreInst *S = dyn_cast<StoreInst>(I)) + return S->getPointerAddressSpace(); + return -1; +} + +/// Returns true if the memory operations \p A and \p B are consecutive.
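+/// E.g. (illustrative IR, names assumed): the following i32 accesses are
+/// consecutive:
+/// \code
+///   %a = load i32, i32* %p
+///   %q = getelementptr inbounds i32, i32* %p, i64 1
+///   %b = load i32, i32* %q
+/// \endcode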
+bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL, + ScalarEvolution &SE, bool CheckType) { + Value *PtrA = getPointerOperand(A); + Value *PtrB = getPointerOperand(B); + unsigned ASA = getAddressSpaceOperand(A); + unsigned ASB = getAddressSpaceOperand(B); + + // Check that the address spaces match and that the pointers are valid. + if (!PtrA || !PtrB || (ASA != ASB)) + return false; + + // Make sure that A and B are different pointers. + if (PtrA == PtrB) + return false; + + // Make sure that A and B have the same type if required. + if(CheckType && PtrA->getType() != PtrB->getType()) + return false; + + unsigned PtrBitWidth = DL.getPointerSizeInBits(ASA); + Type *Ty = cast<PointerType>(PtrA->getType())->getElementType(); + APInt Size(PtrBitWidth, DL.getTypeStoreSize(Ty)); + + APInt OffsetA(PtrBitWidth, 0), OffsetB(PtrBitWidth, 0); + PtrA = PtrA->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetA); + PtrB = PtrB->stripAndAccumulateInBoundsConstantOffsets(DL, OffsetB); + + // OffsetDelta = OffsetB - OffsetA; + const SCEV *OffsetSCEVA = SE.getConstant(OffsetA); + const SCEV *OffsetSCEVB = SE.getConstant(OffsetB); + const SCEV *OffsetDeltaSCEV = SE.getMinusSCEV(OffsetSCEVB, OffsetSCEVA); + const SCEVConstant *OffsetDeltaC = dyn_cast<SCEVConstant>(OffsetDeltaSCEV); + const APInt &OffsetDelta = OffsetDeltaC->getAPInt(); + // Check if they are based on the same pointer. That makes the offsets + // sufficient. + if (PtrA == PtrB) + return OffsetDelta == Size; + + // Compute the necessary base pointer delta to have the necessary final delta + // equal to the size. + // BaseDelta = Size - OffsetDelta; + const SCEV *SizeSCEV = SE.getConstant(Size); + const SCEV *BaseDelta = SE.getMinusSCEV(SizeSCEV, OffsetDeltaSCEV); + + // Otherwise compute the distance with SCEV between the base pointers. + const SCEV *PtrSCEVA = SE.getSCEV(PtrA); + const SCEV *PtrSCEVB = SE.getSCEV(PtrB); + const SCEV *X = SE.getAddExpr(PtrSCEVA, BaseDelta); + return X == PtrSCEVB; +} + bool MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { switch (Type) { case NoDep: @@ -953,8 +1097,8 @@ bool MemoryDepChecker::Dependence::isForward() const { llvm_unreachable("unexpected DepType!"); } -bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance, - unsigned TypeByteSize) { +bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance, + uint64_t TypeByteSize) { // If loads occur at a distance that is not a multiple of a feasible vector // factor store-load forwarding does not take place. // Positive dependences might cause troubles because vectorizing them might @@ -964,30 +1108,34 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance, // hence on your typical architecture store-load forwarding does not take // place. Vectorizing in such cases does not make sense. // Store-load forwarding distance. - const unsigned NumCyclesForStoreLoadThroughMemory = 8*TypeByteSize; + + // After this many iterations store-to-load forwarding conflicts should not + // cause any slowdowns. + const uint64_t NumItersForStoreLoadThroughMemory = 8 * TypeByteSize; // Maximum vector factor. 
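  // E.g. (illustrative): with TypeByteSize = 4 and Distance = 12 bytes, the
  // candidate VF of 8 bytes fails (12 % 8 != 0 and 12 / 8 = 1 < 32), so the
  // cap drops to 4 bytes, which is below one two-element vector, and the
  // distance is reported as a forwarding conflict.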
- unsigned MaxVFWithoutSLForwardIssues = - VectorizerParams::MaxVectorWidth * TypeByteSize; - if(MaxSafeDepDistBytes < MaxVFWithoutSLForwardIssues) - MaxVFWithoutSLForwardIssues = MaxSafeDepDistBytes; - - for (unsigned vf = 2*TypeByteSize; vf <= MaxVFWithoutSLForwardIssues; - vf *= 2) { - if (Distance % vf && Distance / vf < NumCyclesForStoreLoadThroughMemory) { - MaxVFWithoutSLForwardIssues = (vf >>=1); + uint64_t MaxVFWithoutSLForwardIssues = std::min( + VectorizerParams::MaxVectorWidth * TypeByteSize, MaxSafeDepDistBytes); + + // Compute the smallest VF at which the store and load would be misaligned. + for (uint64_t VF = 2 * TypeByteSize; VF <= MaxVFWithoutSLForwardIssues; + VF *= 2) { + // If the number of vector iteration between the store and the load are + // small we could incur conflicts. + if (Distance % VF && Distance / VF < NumItersForStoreLoadThroughMemory) { + MaxVFWithoutSLForwardIssues = (VF >>= 1); break; } } - if (MaxVFWithoutSLForwardIssues< 2*TypeByteSize) { - DEBUG(dbgs() << "LAA: Distance " << Distance << - " that could cause a store-load forwarding conflict\n"); + if (MaxVFWithoutSLForwardIssues < 2 * TypeByteSize) { + DEBUG(dbgs() << "LAA: Distance " << Distance + << " that could cause a store-load forwarding conflict\n"); return true; } if (MaxVFWithoutSLForwardIssues < MaxSafeDepDistBytes && MaxVFWithoutSLForwardIssues != - VectorizerParams::MaxVectorWidth * TypeByteSize) + VectorizerParams::MaxVectorWidth * TypeByteSize) MaxSafeDepDistBytes = MaxVFWithoutSLForwardIssues; return false; } @@ -997,8 +1145,8 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(unsigned Distance, /// bytes. /// /// \returns true if they are independent. -static bool areStridedAccessesIndependent(unsigned Distance, unsigned Stride, - unsigned TypeByteSize) { +static bool areStridedAccessesIndependent(uint64_t Distance, uint64_t Stride, + uint64_t TypeByteSize) { assert(Stride > 1 && "The stride must be greater than 1"); assert(TypeByteSize > 0 && "The type size in byte must be non-zero"); assert(Distance > 0 && "The distance must be non-zero"); @@ -1007,7 +1155,7 @@ static bool areStridedAccessesIndependent(unsigned Distance, unsigned Stride, if (Distance % TypeByteSize) return false; - unsigned ScaledDist = Distance / TypeByteSize; + uint64_t ScaledDist = Distance / TypeByteSize; // No dependence if the scaled distance is not multiple of the stride. // E.g. @@ -1048,20 +1196,15 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, BPtr->getType()->getPointerAddressSpace()) return Dependence::Unknown; - const SCEV *AScev = replaceSymbolicStrideSCEV(PSE, Strides, APtr); - const SCEV *BScev = replaceSymbolicStrideSCEV(PSE, Strides, BPtr); - - int StrideAPtr = isStridedPtr(PSE, APtr, InnermostLoop, Strides); - int StrideBPtr = isStridedPtr(PSE, BPtr, InnermostLoop, Strides); + int64_t StrideAPtr = getPtrStride(PSE, APtr, InnermostLoop, Strides, true); + int64_t StrideBPtr = getPtrStride(PSE, BPtr, InnermostLoop, Strides, true); - const SCEV *Src = AScev; - const SCEV *Sink = BScev; + const SCEV *Src = PSE.getSCEV(APtr); + const SCEV *Sink = PSE.getSCEV(BPtr); // If the induction step is negative we have to invert source and sink of the // dependence. 
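  // E.g. (illustrative): in "for (i = N; i > 0; --i) A[i-1] = A[i];" the
  // store targets the lower address, so after the swap the dependence is
  // analyzed with a non-negative distance just like the increasing case.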
if (StrideAPtr < 0) { - //Src = BScev; - //Sink = AScev; std::swap(APtr, BPtr); std::swap(Src, Sink); std::swap(AIsWrite, BIsWrite); @@ -1094,18 +1237,30 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, Type *ATy = APtr->getType()->getPointerElementType(); Type *BTy = BPtr->getType()->getPointerElementType(); auto &DL = InnermostLoop->getHeader()->getModule()->getDataLayout(); - unsigned TypeByteSize = DL.getTypeAllocSize(ATy); + uint64_t TypeByteSize = DL.getTypeAllocSize(ATy); - // Negative distances are not plausible dependencies. const APInt &Val = C->getAPInt(); + int64_t Distance = Val.getSExtValue(); + uint64_t Stride = std::abs(StrideAPtr); + + // Attempt to prove strided accesses independent. + if (std::abs(Distance) > 0 && Stride > 1 && ATy == BTy && + areStridedAccessesIndependent(std::abs(Distance), Stride, TypeByteSize)) { + DEBUG(dbgs() << "LAA: Strided accesses are independent\n"); + return Dependence::NoDep; + } + + // Negative distances are not plausible dependencies. if (Val.isNegative()) { bool IsTrueDataDependence = (AIsWrite && !BIsWrite); - if (IsTrueDataDependence && + if (IsTrueDataDependence && EnableForwardingConflictDetection && (couldPreventStoreLoadForward(Val.abs().getZExtValue(), TypeByteSize) || - ATy != BTy)) + ATy != BTy)) { + DEBUG(dbgs() << "LAA: Forward but may prevent st->ld forwarding\n"); return Dependence::ForwardButPreventsForwarding; + } - DEBUG(dbgs() << "LAA: Dependence is negative: NoDep\n"); + DEBUG(dbgs() << "LAA: Dependence is negative\n"); return Dependence::Forward; } @@ -1126,15 +1281,6 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, return Dependence::Unknown; } - unsigned Distance = (unsigned) Val.getZExtValue(); - - unsigned Stride = std::abs(StrideAPtr); - if (Stride > 1 && - areStridedAccessesIndependent(Distance, Stride, TypeByteSize)) { - DEBUG(dbgs() << "LAA: Strided accesses are independent\n"); - return Dependence::NoDep; - } - // Bail out early if passed-in parameters make vectorization not feasible. unsigned ForcedFactor = (VectorizerParams::VectorizationFactor ? VectorizerParams::VectorizationFactor : 1); @@ -1169,9 +1315,9 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, // If MinNumIter is 4 (Say if a user forces the vectorization factor to be 4), // the minimum distance needed is 28, which is greater than distance. It is // not safe to do vectorization. - unsigned MinDistanceNeeded = + uint64_t MinDistanceNeeded = TypeByteSize * Stride * (MinNumIter - 1) + TypeByteSize; - if (MinDistanceNeeded > Distance) { + if (MinDistanceNeeded > static_cast<uint64_t>(Distance)) { DEBUG(dbgs() << "LAA: Failure because of positive distance " << Distance << '\n'); return Dependence::Backward; @@ -1201,10 +1347,10 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, // is 8, which is less than 2 and forbidden vectorization, But actually // both A and B could be vectorized by 2 iterations. MaxSafeDepDistBytes = - Distance < MaxSafeDepDistBytes ? 
Distance : MaxSafeDepDistBytes; + std::min(static_cast<uint64_t>(Distance), MaxSafeDepDistBytes); bool IsTrueDataDependence = (!AIsWrite && BIsWrite); - if (IsTrueDataDependence && + if (IsTrueDataDependence && EnableForwardingConflictDetection && couldPreventStoreLoadForward(Distance, TypeByteSize)) return Dependence::BackwardVectorizableButPreventsForwarding; @@ -1219,7 +1365,7 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, MemAccessInfoSet &CheckDeps, const ValueToValueMap &Strides) { - MaxSafeDepDistBytes = -1U; + MaxSafeDepDistBytes = -1; while (!CheckDeps.empty()) { MemAccessInfo CurAccess = *CheckDeps.begin(); @@ -1228,8 +1374,10 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, AccessSets.findValue(AccessSets.getLeaderValue(CurAccess)); // Check accesses within this set. - EquivalenceClasses<MemAccessInfo>::member_iterator AI, AE; - AI = AccessSets.member_begin(I), AE = AccessSets.member_end(); + EquivalenceClasses<MemAccessInfo>::member_iterator AI = + AccessSets.member_begin(I); + EquivalenceClasses<MemAccessInfo>::member_iterator AE = + AccessSets.member_end(); // Check every access pair. while (AI != AE) { @@ -1305,10 +1453,11 @@ void MemoryDepChecker::Dependence::print( bool LoopAccessInfo::canAnalyzeLoop() { // We need to have a loop header. - DEBUG(dbgs() << "LAA: Found a loop: " << - TheLoop->getHeader()->getName() << '\n'); + DEBUG(dbgs() << "LAA: Found a loop in " + << TheLoop->getHeader()->getParent()->getName() << ": " + << TheLoop->getHeader()->getName() << '\n'); - // We can only analyze innermost loops. + // We can only analyze innermost loops. if (!TheLoop->empty()) { DEBUG(dbgs() << "LAA: loop is not the innermost loop\n"); emitAnalysis(LoopAccessReport() << "loop is not the innermost loop"); @@ -1345,8 +1494,8 @@ bool LoopAccessInfo::canAnalyzeLoop() { } // ScalarEvolution needs to be able to find the exit count. - const SCEV *ExitCount = PSE.getSE()->getBackedgeTakenCount(TheLoop); - if (ExitCount == PSE.getSE()->getCouldNotCompute()) { + const SCEV *ExitCount = PSE->getBackedgeTakenCount(); + if (ExitCount == PSE->getSE()->getCouldNotCompute()) { emitAnalysis(LoopAccessReport() << "could not determine number of loop iterations"); DEBUG(dbgs() << "LAA: SCEV could not compute the loop exit count.\n"); @@ -1356,41 +1505,37 @@ bool LoopAccessInfo::canAnalyzeLoop() { return true; } -void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { - - typedef SmallVector<Value*, 16> ValueVector; +void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI, + const TargetLibraryInfo *TLI, + DominatorTree *DT) { typedef SmallPtrSet<Value*, 16> ValueSet; - // Holds the Load and Store *instructions*. - ValueVector Loads; - ValueVector Stores; + // Holds the Load and Store instructions. + SmallVector<LoadInst *, 16> Loads; + SmallVector<StoreInst *, 16> Stores; // Holds all the different accesses in the loop. unsigned NumReads = 0; unsigned NumReadWrites = 0; - PtrRtChecking.Pointers.clear(); - PtrRtChecking.Need = false; + PtrRtChecking->Pointers.clear(); + PtrRtChecking->Need = false; const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel(); // For each block. - for (Loop::block_iterator bb = TheLoop->block_begin(), - be = TheLoop->block_end(); bb != be; ++bb) { - + for (BasicBlock *BB : TheLoop->blocks()) { // Scan the BB and collect legal loads and stores. - for (BasicBlock::iterator it = (*bb)->begin(), e = (*bb)->end(); it != e; - ++it) { - + for (Instruction &I : *BB) { // If this is a load, save it. 
If this instruction can read from memory // but is not a load, then we quit. Notice that we don't handle function // calls that read or write. - if (it->mayReadFromMemory()) { + if (I.mayReadFromMemory()) { // Many math library functions read the rounding mode. We will only // vectorize a loop if it contains known function calls that don't set // the flag. Therefore, it is safe to ignore this read from memory. - CallInst *Call = dyn_cast<CallInst>(it); - if (Call && getIntrinsicIDForCall(Call, TLI)) + auto *Call = dyn_cast<CallInst>(&I); + if (Call && getVectorIntrinsicIDForCall(Call, TLI)) continue; // If the function has an explicit vectorized counterpart, we can safely @@ -1399,7 +1544,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { TLI->isFunctionVectorizable(Call->getCalledFunction()->getName())) continue; - LoadInst *Ld = dyn_cast<LoadInst>(it); + auto *Ld = dyn_cast<LoadInst>(&I); if (!Ld || (!Ld->isSimple() && !IsAnnotatedParallel)) { emitAnalysis(LoopAccessReport(Ld) << "read with atomic ordering or volatile read"); @@ -1409,16 +1554,18 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { } NumLoads++; Loads.push_back(Ld); - DepChecker.addAccess(Ld); + DepChecker->addAccess(Ld); + if (EnableMemAccessVersioning) + collectStridedAccess(Ld); continue; } // Save 'store' instructions. Abort if other instructions write to memory. - if (it->mayWriteToMemory()) { - StoreInst *St = dyn_cast<StoreInst>(it); + if (I.mayWriteToMemory()) { + auto *St = dyn_cast<StoreInst>(&I); if (!St) { - emitAnalysis(LoopAccessReport(&*it) << - "instruction cannot be vectorized"); + emitAnalysis(LoopAccessReport(St) + << "instruction cannot be vectorized"); CanVecMem = false; return; } @@ -1431,7 +1578,9 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { } NumStores++; Stores.push_back(St); - DepChecker.addAccess(St); + DepChecker->addAccess(St); + if (EnableMemAccessVersioning) + collectStridedAccess(St); } } // Next instr. } // Next block. @@ -1449,7 +1598,7 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { MemoryDepChecker::DepCandidates DependentAccesses; AccessAnalysis Accesses(TheLoop->getHeader()->getModule()->getDataLayout(), - AA, LI, DependentAccesses, PSE); + AA, LI, DependentAccesses, *PSE); // Holds the analyzed pointers. We don't want to call GetUnderlyingObjects // multiple times on the same object. If the ptr is accessed twice, once @@ -1458,10 +1607,8 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // writes and between reads and writes, but not between reads and reads. ValueSet Seen; - ValueVector::iterator I, IE; - for (I = Stores.begin(), IE = Stores.end(); I != IE; ++I) { - StoreInst *ST = cast<StoreInst>(*I); - Value* Ptr = ST->getPointerOperand(); + for (StoreInst *ST : Stores) { + Value *Ptr = ST->getPointerOperand(); // Check for store to loop invariant address. StoreToLoopInvariantAddress |= isUniform(Ptr); // If we did *not* see this pointer before, insert it to the read-write @@ -1488,9 +1635,8 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { return; } - for (I = Loads.begin(), IE = Loads.end(); I != IE; ++I) { - LoadInst *LD = cast<LoadInst>(*I); - Value* Ptr = LD->getPointerOperand(); + for (LoadInst *LD : Loads) { + Value *Ptr = LD->getPointerOperand(); // If we did *not* see this pointer before, insert it to the // read list. If we *did* see it before, then it is already in // the read-write list. 
This allows us to vectorize expressions @@ -1500,7 +1646,8 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // read a few words, modify, and write a few words, and some of the // words may be written to the same address. bool IsReadOnlyPtr = false; - if (Seen.insert(Ptr).second || !isStridedPtr(PSE, Ptr, TheLoop, Strides)) { + if (Seen.insert(Ptr).second || + !getPtrStride(*PSE, Ptr, TheLoop, SymbolicStrides)) { ++NumReads; IsReadOnlyPtr = true; } @@ -1529,8 +1676,8 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { // Find pointers with computable bounds. We are going to use this information // to place a runtime bound check. - bool CanDoRTIfNeeded = - Accesses.canCheckPtrAtRT(PtrRtChecking, PSE.getSE(), TheLoop, Strides); + bool CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, PSE->getSE(), + TheLoop, SymbolicStrides); if (!CanDoRTIfNeeded) { emitAnalysis(LoopAccessReport() << "cannot identify array bounds"); DEBUG(dbgs() << "LAA: We can't vectorize because we can't find " @@ -1544,22 +1691,22 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { CanVecMem = true; if (Accesses.isDependencyCheckNeeded()) { DEBUG(dbgs() << "LAA: Checking memory dependencies\n"); - CanVecMem = DepChecker.areDepsSafe( - DependentAccesses, Accesses.getDependenciesToCheck(), Strides); - MaxSafeDepDistBytes = DepChecker.getMaxSafeDepDistBytes(); + CanVecMem = DepChecker->areDepsSafe( + DependentAccesses, Accesses.getDependenciesToCheck(), SymbolicStrides); + MaxSafeDepDistBytes = DepChecker->getMaxSafeDepDistBytes(); - if (!CanVecMem && DepChecker.shouldRetryWithRuntimeCheck()) { + if (!CanVecMem && DepChecker->shouldRetryWithRuntimeCheck()) { DEBUG(dbgs() << "LAA: Retrying with memory checks\n"); // Clear the dependency checks. We assume they are not needed. - Accesses.resetDepChecks(DepChecker); + Accesses.resetDepChecks(*DepChecker); - PtrRtChecking.reset(); - PtrRtChecking.Need = true; + PtrRtChecking->reset(); + PtrRtChecking->Need = true; - auto *SE = PSE.getSE(); - CanDoRTIfNeeded = - Accesses.canCheckPtrAtRT(PtrRtChecking, SE, TheLoop, Strides, true); + auto *SE = PSE->getSE(); + CanDoRTIfNeeded = Accesses.canCheckPtrAtRT(*PtrRtChecking, SE, TheLoop, + SymbolicStrides, true); // Check that we found the bounds for the pointer. if (!CanDoRTIfNeeded) { @@ -1576,11 +1723,15 @@ void LoopAccessInfo::analyzeLoop(const ValueToValueMap &Strides) { if (CanVecMem) DEBUG(dbgs() << "LAA: No unsafe dependent memory operations in loop. We" - << (PtrRtChecking.Need ? "" : " don't") + << (PtrRtChecking->Need ? "" : " don't") << " need runtime memory checks.\n"); else { - emitAnalysis(LoopAccessReport() << - "unsafe dependent memory operations in loop"); + emitAnalysis( + LoopAccessReport() + << "unsafe dependent memory operations in loop. 
Use " + "#pragma loop distribute(enable) to allow loop distribution " + "to attempt to isolate the offending operations into a separate " + "loop"); DEBUG(dbgs() << "LAA: unsafe dependent memory operations in loop\n"); } } @@ -1600,7 +1751,7 @@ void LoopAccessInfo::emitAnalysis(LoopAccessReport &Message) { } bool LoopAccessInfo::isUniform(Value *V) const { - return (PSE.getSE()->isLoopInvariant(PSE.getSE()->getSCEV(V), TheLoop)); + return (PSE->getSE()->isLoopInvariant(PSE->getSE()->getSCEV(V), TheLoop)); } // FIXME: this function is currently a duplicate of the one in @@ -1681,10 +1832,11 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks( Instruction *Loc, const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks) const { - auto *SE = PSE.getSE(); + const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + auto *SE = PSE->getSE(); SCEVExpander Exp(*SE, DL, "induction"); auto ExpandedChecks = - expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, PtrRtChecking); + expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, *PtrRtChecking); LLVMContext &Ctx = Loc->getContext(); Instruction *FirstInst = nullptr; @@ -1740,47 +1892,68 @@ std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks( std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(Instruction *Loc) const { - if (!PtrRtChecking.Need) + if (!PtrRtChecking->Need) return std::make_pair(nullptr, nullptr); - return addRuntimeChecks(Loc, PtrRtChecking.getChecks()); + return addRuntimeChecks(Loc, PtrRtChecking->getChecks()); +} + +void LoopAccessInfo::collectStridedAccess(Value *MemAccess) { + Value *Ptr = nullptr; + if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess)) + Ptr = LI->getPointerOperand(); + else if (StoreInst *SI = dyn_cast<StoreInst>(MemAccess)) + Ptr = SI->getPointerOperand(); + else + return; + + Value *Stride = getStrideFromPointer(Ptr, PSE->getSE(), TheLoop); + if (!Stride) + return; + + DEBUG(dbgs() << "LAA: Found a strided access that we can version"); + DEBUG(dbgs() << " Ptr: " << *Ptr << " Stride: " << *Stride << "\n"); + SymbolicStrides[Ptr] = Stride; + StrideSet.insert(Stride); } LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, - const DataLayout &DL, const TargetLibraryInfo *TLI, AliasAnalysis *AA, - DominatorTree *DT, LoopInfo *LI, - const ValueToValueMap &Strides) - : PSE(*SE), PtrRtChecking(SE), DepChecker(PSE, L), TheLoop(L), DL(DL), - TLI(TLI), AA(AA), DT(DT), LI(LI), NumLoads(0), NumStores(0), - MaxSafeDepDistBytes(-1U), CanVecMem(false), + DominatorTree *DT, LoopInfo *LI) + : PSE(llvm::make_unique<PredicatedScalarEvolution>(*SE, *L)), + PtrRtChecking(llvm::make_unique<RuntimePointerChecking>(SE)), + DepChecker(llvm::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L), + NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false), StoreToLoopInvariantAddress(false) { if (canAnalyzeLoop()) - analyzeLoop(Strides); + analyzeLoop(AA, LI, TLI, DT); } void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { if (CanVecMem) { - if (PtrRtChecking.Need) - OS.indent(Depth) << "Memory dependences are safe with run-time checks\n"; - else - OS.indent(Depth) << "Memory dependences are safe\n"; + OS.indent(Depth) << "Memory dependences are safe"; + if (MaxSafeDepDistBytes != -1ULL) + OS << " with a maximum dependence distance of " << MaxSafeDepDistBytes + << " bytes"; + if (PtrRtChecking->Need) + OS << " with run-time checks"; + OS << "\n"; } if (Report) OS.indent(Depth) << "Report: " << Report->str() << 
"\n"; - if (auto *Dependences = DepChecker.getDependences()) { + if (auto *Dependences = DepChecker->getDependences()) { OS.indent(Depth) << "Dependences:\n"; for (auto &Dep : *Dependences) { - Dep.print(OS, Depth + 2, DepChecker.getMemoryInstructions()); + Dep.print(OS, Depth + 2, DepChecker->getMemoryInstructions()); OS << "\n"; } } else OS.indent(Depth) << "Too many dependences, not recorded\n"; // List the pair of accesses need run-time checks to prove independence. - PtrRtChecking.print(OS, Depth); + PtrRtChecking->print(OS, Depth); OS << "\n"; OS.indent(Depth) << "Store to invariant address was " @@ -1788,43 +1961,35 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { << "found in loop.\n"; OS.indent(Depth) << "SCEV assumptions:\n"; - PSE.getUnionPredicate().print(OS, Depth); + PSE->getUnionPredicate().print(OS, Depth); + + OS << "\n"; + + OS.indent(Depth) << "Expressions re-written:\n"; + PSE->print(OS, Depth); } -const LoopAccessInfo & -LoopAccessAnalysis::getInfo(Loop *L, const ValueToValueMap &Strides) { +const LoopAccessInfo &LoopAccessLegacyAnalysis::getInfo(Loop *L) { auto &LAI = LoopAccessInfoMap[L]; -#ifndef NDEBUG - assert((!LAI || LAI->NumSymbolicStrides == Strides.size()) && - "Symbolic strides changed for loop"); -#endif - - if (!LAI) { - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - LAI = - llvm::make_unique<LoopAccessInfo>(L, SE, DL, TLI, AA, DT, LI, Strides); -#ifndef NDEBUG - LAI->NumSymbolicStrides = Strides.size(); -#endif - } + if (!LAI) + LAI = llvm::make_unique<LoopAccessInfo>(L, SE, TLI, AA, DT, LI); + return *LAI.get(); } -void LoopAccessAnalysis::print(raw_ostream &OS, const Module *M) const { - LoopAccessAnalysis &LAA = *const_cast<LoopAccessAnalysis *>(this); - - ValueToValueMap NoSymbolicStrides; +void LoopAccessLegacyAnalysis::print(raw_ostream &OS, const Module *M) const { + LoopAccessLegacyAnalysis &LAA = *const_cast<LoopAccessLegacyAnalysis *>(this); for (Loop *TopLevelLoop : *LI) for (Loop *L : depth_first(TopLevelLoop)) { OS.indent(2) << L->getHeader()->getName() << ":\n"; - auto &LAI = LAA.getInfo(L, NoSymbolicStrides); + auto &LAI = LAA.getInfo(L); LAI.print(OS, 4); } } -bool LoopAccessAnalysis::runOnFunction(Function &F) { +bool LoopAccessLegacyAnalysis::runOnFunction(Function &F) { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); TLI = TLIP ? 
&TLIP->getTLI() : nullptr; @@ -1835,7 +2000,7 @@ bool LoopAccessAnalysis::runOnFunction(Function &F) { return false; } -void LoopAccessAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { +void LoopAccessLegacyAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); @@ -1844,19 +2009,52 @@ void LoopAccessAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); } -char LoopAccessAnalysis::ID = 0; +char LoopAccessLegacyAnalysis::ID = 0; static const char laa_name[] = "Loop Access Analysis"; #define LAA_NAME "loop-accesses" -INITIALIZE_PASS_BEGIN(LoopAccessAnalysis, LAA_NAME, laa_name, false, true) +INITIALIZE_PASS_BEGIN(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) -INITIALIZE_PASS_END(LoopAccessAnalysis, LAA_NAME, laa_name, false, true) +INITIALIZE_PASS_END(LoopAccessLegacyAnalysis, LAA_NAME, laa_name, false, true) + +char LoopAccessAnalysis::PassID; + +LoopAccessInfo LoopAccessAnalysis::run(Loop &L, AnalysisManager<Loop> &AM) { + const AnalysisManager<Function> &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L).getManager(); + Function &F = *L.getHeader()->getParent(); + auto *SE = FAM.getCachedResult<ScalarEvolutionAnalysis>(F); + auto *TLI = FAM.getCachedResult<TargetLibraryAnalysis>(F); + auto *AA = FAM.getCachedResult<AAManager>(F); + auto *DT = FAM.getCachedResult<DominatorTreeAnalysis>(F); + auto *LI = FAM.getCachedResult<LoopAnalysis>(F); + if (!SE) + report_fatal_error( + "ScalarEvolution must have been cached at a higher level"); + if (!AA) + report_fatal_error("AliasAnalysis must have been cached at a higher level"); + if (!DT) + report_fatal_error("DominatorTree must have been cached at a higher level"); + if (!LI) + report_fatal_error("LoopInfo must have been cached at a higher level"); + return LoopAccessInfo(&L, SE, TLI, AA, DT, LI); +} + +PreservedAnalyses LoopAccessInfoPrinterPass::run(Loop &L, + AnalysisManager<Loop> &AM) { + Function &F = *L.getHeader()->getParent(); + auto &LAI = AM.getResult<LoopAccessAnalysis>(L); + OS << "Loop access info in function '" << F.getName() << "':\n"; + OS.indent(2) << L.getHeader()->getName() << ":\n"; + LAI.print(OS, 4); + return PreservedAnalyses::all(); +} namespace llvm { Pass *createLAAPass() { - return new LoopAccessAnalysis(); + return new LoopAccessLegacyAnalysis(); } } diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index 0c725fcadff7..30f7ef392422 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -22,6 +22,7 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" @@ -38,7 +39,7 @@ template class llvm::LoopBase<BasicBlock, Loop>; template class llvm::LoopInfoBase<BasicBlock, Loop>; // Always verify loopinfo if expensive checking is enabled. -#ifdef XDEBUG +#ifdef EXPENSIVE_CHECKS static bool VerifyLoopInfo = true; #else static bool VerifyLoopInfo = false; @@ -47,36 +48,20 @@ static cl::opt<bool,true> VerifyLoopInfoX("verify-loop-info", cl::location(VerifyLoopInfo), cl::desc("Verify loop info (time consuming)")); -// Loop identifier metadata name. 
-static const char *const LoopMDName = "llvm.loop"; - //===----------------------------------------------------------------------===// // Loop implementation // -/// isLoopInvariant - Return true if the specified value is loop invariant -/// bool Loop::isLoopInvariant(const Value *V) const { if (const Instruction *I = dyn_cast<Instruction>(V)) return !contains(I); return true; // All non-instructions are loop invariant } -/// hasLoopInvariantOperands - Return true if all the operands of the -/// specified instruction are loop invariant. bool Loop::hasLoopInvariantOperands(const Instruction *I) const { return all_of(I->operands(), [this](Value *V) { return isLoopInvariant(V); }); } -/// makeLoopInvariant - If the given value is an instruciton inside of the -/// loop and it can be hoisted, do so to make it trivially loop-invariant. -/// Return true if the value after any hoisting is loop invariant. This -/// function can be used as a slightly more aggressive replacement for -/// isLoopInvariant. -/// -/// If InsertPt is specified, it is the point to hoist instructions to. -/// If null, the terminator of the loop preheader is used. -/// bool Loop::makeLoopInvariant(Value *V, bool &Changed, Instruction *InsertPt) const { if (Instruction *I = dyn_cast<Instruction>(V)) @@ -84,15 +69,6 @@ bool Loop::makeLoopInvariant(Value *V, bool &Changed, return true; // All non-instructions are loop-invariant. } -/// makeLoopInvariant - If the given instruction is inside of the -/// loop and it can be hoisted, do so to make it trivially loop-invariant. -/// Return true if the instruction after any hoisting is loop invariant. This -/// function can be used as a slightly more aggressive replacement for -/// isLoopInvariant. -/// -/// If InsertPt is specified, it is the point to hoist instructions to. -/// If null, the terminator of the loop preheader is used. -/// bool Loop::makeLoopInvariant(Instruction *I, bool &Changed, Instruction *InsertPt) const { // Test if the value is already loop-invariant. @@ -114,8 +90,8 @@ bool Loop::makeLoopInvariant(Instruction *I, bool &Changed, InsertPt = Preheader->getTerminator(); } // Don't hoist instructions with loop-variant operands. - for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) - if (!makeLoopInvariant(I->getOperand(i), Changed, InsertPt)) + for (Value *Operand : I->operands()) + if (!makeLoopInvariant(Operand, Changed, InsertPt)) return false; // Hoist. @@ -131,14 +107,6 @@ bool Loop::makeLoopInvariant(Instruction *I, bool &Changed, return true; } -/// getCanonicalInductionVariable - Check to see if the loop has a canonical -/// induction variable: an integer recurrence that starts at 0 and increments -/// by one each time through the loop. If so, return the phi node that -/// corresponds to it. -/// -/// The IndVarSimplify pass transforms loops to have a canonical induction -/// variable. 
-/// PHINode *Loop::getCanonicalInductionVariable() const { BasicBlock *H = getHeader(); @@ -175,18 +143,16 @@ PHINode *Loop::getCanonicalInductionVariable() const { return nullptr; } -/// isLCSSAForm - Return true if the Loop is in LCSSA form bool Loop::isLCSSAForm(DominatorTree &DT) const { - for (block_iterator BI = block_begin(), E = block_end(); BI != E; ++BI) { - BasicBlock *BB = *BI; - for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;++I) { + for (BasicBlock *BB : this->blocks()) { + for (Instruction &I : *BB) { // Tokens can't be used in PHI nodes and live-out tokens prevent loop // optimizations, so for the purposes of considered LCSSA form, we // can ignore them. - if (I->getType()->isTokenTy()) + if (I.getType()->isTokenTy()) continue; - for (Use &U : I->uses()) { + for (Use &U : I.uses()) { Instruction *UI = cast<Instruction>(U.getUser()); BasicBlock *UserBB = UI->getParent(); if (PHINode *P = dyn_cast<PHINode>(UI)) @@ -216,42 +182,24 @@ bool Loop::isRecursivelyLCSSAForm(DominatorTree &DT) const { }); } -/// isLoopSimplifyForm - Return true if the Loop is in the form that -/// the LoopSimplify form transforms loops to, which is sometimes called -/// normal form. bool Loop::isLoopSimplifyForm() const { // Normal-form loops have a preheader, a single backedge, and all of their // exits have all their predecessors inside the loop. return getLoopPreheader() && getLoopLatch() && hasDedicatedExits(); } -/// isSafeToClone - Return true if the loop body is safe to clone in practice. -/// Routines that reform the loop CFG and split edges often fail on indirectbr. +// Routines that reform the loop CFG and split edges often fail on indirectbr. bool Loop::isSafeToClone() const { // Return false if any loop blocks contain indirectbrs, or there are any calls // to noduplicate functions. - for (Loop::block_iterator I = block_begin(), E = block_end(); I != E; ++I) { - if (isa<IndirectBrInst>((*I)->getTerminator())) + for (BasicBlock *BB : this->blocks()) { + if (isa<IndirectBrInst>(BB->getTerminator())) return false; - if (const InvokeInst *II = dyn_cast<InvokeInst>((*I)->getTerminator())) { - if (II->cannotDuplicate()) - return false; - // Return false if any loop blocks contain invokes to EH-pads other than - // landingpads; we don't know how to split those edges yet. - auto *FirstNonPHI = II->getUnwindDest()->getFirstNonPHI(); - if (FirstNonPHI->isEHPad() && !isa<LandingPadInst>(FirstNonPHI)) - return false; - } - - for (BasicBlock::iterator BI = (*I)->begin(), BE = (*I)->end(); BI != BE; ++BI) { - if (const CallInst *CI = dyn_cast<CallInst>(BI)) { - if (CI->cannotDuplicate()) + for (Instruction &I : *BB) + if (auto CS = CallSite(&I)) + if (CS.cannotDuplicate()) return false; - } - if (BI->getType()->isTokenTy() && BI->isUsedOutsideOfBlock(*I)) - return false; - } } return true; } @@ -259,19 +207,19 @@ bool Loop::isSafeToClone() const { MDNode *Loop::getLoopID() const { MDNode *LoopID = nullptr; if (isLoopSimplifyForm()) { - LoopID = getLoopLatch()->getTerminator()->getMetadata(LoopMDName); + LoopID = getLoopLatch()->getTerminator()->getMetadata(LLVMContext::MD_loop); } else { // Go through each predecessor of the loop header and check the // terminator for the metadata. 
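  // E.g. (illustrative IR, names assumed): a latch terminator carrying the
  // loop ID looks like
  //   br i1 %cond, label %header, label %exit, !llvm.loop !0
  //   !0 = distinct !{!0, !1}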
BasicBlock *H = getHeader(); - for (block_iterator I = block_begin(), IE = block_end(); I != IE; ++I) { - TerminatorInst *TI = (*I)->getTerminator(); + for (BasicBlock *BB : this->blocks()) { + TerminatorInst *TI = BB->getTerminator(); MDNode *MD = nullptr; // Check if this terminator branches to the loop header. - for (unsigned i = 0, ie = TI->getNumSuccessors(); i != ie; ++i) { - if (TI->getSuccessor(i) == H) { - MD = TI->getMetadata(LoopMDName); + for (BasicBlock *Successor : TI->successors()) { + if (Successor == H) { + MD = TI->getMetadata(LLVMContext::MD_loop); break; } } @@ -296,24 +244,24 @@ void Loop::setLoopID(MDNode *LoopID) const { assert(LoopID->getOperand(0) == LoopID && "Loop ID should refer to itself"); if (isLoopSimplifyForm()) { - getLoopLatch()->getTerminator()->setMetadata(LoopMDName, LoopID); + getLoopLatch()->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); return; } BasicBlock *H = getHeader(); - for (block_iterator I = block_begin(), IE = block_end(); I != IE; ++I) { - TerminatorInst *TI = (*I)->getTerminator(); - for (unsigned i = 0, ie = TI->getNumSuccessors(); i != ie; ++i) { - if (TI->getSuccessor(i) == H) - TI->setMetadata(LoopMDName, LoopID); + for (BasicBlock *BB : this->blocks()) { + TerminatorInst *TI = BB->getTerminator(); + for (BasicBlock *Successor : TI->successors()) { + if (Successor == H) + TI->setMetadata(LLVMContext::MD_loop, LoopID); } } } bool Loop::isAnnotatedParallel() const { - MDNode *desiredLoopIdMetadata = getLoopID(); + MDNode *DesiredLoopIdMetadata = getLoopID(); - if (!desiredLoopIdMetadata) + if (!DesiredLoopIdMetadata) return false; // The loop branch contains the parallel loop metadata. In order to ensure @@ -321,108 +269,112 @@ bool Loop::isAnnotatedParallel() const { // dependencies (thus converted the loop back to a sequential loop), check // that all the memory instructions in the loop contain parallelism metadata // that point to the same unique "loop id metadata" the loop branch does. - for (block_iterator BB = block_begin(), BE = block_end(); BB != BE; ++BB) { - for (BasicBlock::iterator II = (*BB)->begin(), EE = (*BB)->end(); - II != EE; II++) { - - if (!II->mayReadOrWriteMemory()) + for (BasicBlock *BB : this->blocks()) { + for (Instruction &I : *BB) { + if (!I.mayReadOrWriteMemory()) continue; // The memory instruction can refer to the loop identifier metadata // directly or indirectly through another list metadata (in case of // nested parallel loops). The loop identifier metadata refers to // itself so we can check both cases with the same routine. - MDNode *loopIdMD = - II->getMetadata(LLVMContext::MD_mem_parallel_loop_access); + MDNode *LoopIdMD = + I.getMetadata(LLVMContext::MD_mem_parallel_loop_access); - if (!loopIdMD) + if (!LoopIdMD) return false; - bool loopIdMDFound = false; - for (unsigned i = 0, e = loopIdMD->getNumOperands(); i < e; ++i) { - if (loopIdMD->getOperand(i) == desiredLoopIdMetadata) { - loopIdMDFound = true; + bool LoopIdMDFound = false; + for (const MDOperand &MDOp : LoopIdMD->operands()) { + if (MDOp == DesiredLoopIdMetadata) { + LoopIdMDFound = true; break; } } - if (!loopIdMDFound) + if (!LoopIdMDFound) return false; } } return true; } +DebugLoc Loop::getStartLoc() const { + // If we have a debug location in the loop ID, then use it. + if (MDNode *LoopID = getLoopID()) + for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) + if (DILocation *L = dyn_cast<DILocation>(LoopID->getOperand(i))) + return DebugLoc(L); + + // Try the pre-header first. 
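+  // E.g. (illustrative, metadata numbers assumed): a loop ID of the form
+  //   !0 = distinct !{!0, !DILocation(line: 5, scope: !2)}
+  // is resolved by the scan above to line 5; only loops without such an
+  // entry fall back to the branch locations probed below.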
+ if (BasicBlock *PHeadBB = getLoopPreheader()) + if (DebugLoc DL = PHeadBB->getTerminator()->getDebugLoc()) + return DL; + + // If we have no pre-header or there are no instructions with debug + // info in it, try the header. + if (BasicBlock *HeadBB = getHeader()) + return HeadBB->getTerminator()->getDebugLoc(); + + return DebugLoc(); +} -/// hasDedicatedExits - Return true if no exit block for the loop -/// has a predecessor that is outside the loop. bool Loop::hasDedicatedExits() const { // Each predecessor of each exit block of a normal loop is contained // within the loop. SmallVector<BasicBlock *, 4> ExitBlocks; getExitBlocks(ExitBlocks); - for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) - for (pred_iterator PI = pred_begin(ExitBlocks[i]), - PE = pred_end(ExitBlocks[i]); PI != PE; ++PI) - if (!contains(*PI)) + for (BasicBlock *BB : ExitBlocks) + for (BasicBlock *Predecessor : predecessors(BB)) + if (!contains(Predecessor)) return false; // All the requirements are met. return true; } -/// getUniqueExitBlocks - Return all unique successor blocks of this loop. -/// These are the blocks _outside of the current loop_ which are branched to. -/// This assumes that loop exits are in canonical form. -/// void Loop::getUniqueExitBlocks(SmallVectorImpl<BasicBlock *> &ExitBlocks) const { assert(hasDedicatedExits() && "getUniqueExitBlocks assumes the loop has canonical form exits!"); - SmallVector<BasicBlock *, 32> switchExitBlocks; - - for (block_iterator BI = block_begin(), BE = block_end(); BI != BE; ++BI) { - - BasicBlock *current = *BI; - switchExitBlocks.clear(); - - for (succ_iterator I = succ_begin(*BI), E = succ_end(*BI); I != E; ++I) { - // If block is inside the loop then it is not a exit block. - if (contains(*I)) + SmallVector<BasicBlock *, 32> SwitchExitBlocks; + for (BasicBlock *BB : this->blocks()) { + SwitchExitBlocks.clear(); + for (BasicBlock *Successor : successors(BB)) { + // If block is inside the loop then it is not an exit block. + if (contains(Successor)) continue; - pred_iterator PI = pred_begin(*I); - BasicBlock *firstPred = *PI; + pred_iterator PI = pred_begin(Successor); + BasicBlock *FirstPred = *PI; // If current basic block is this exit block's first predecessor // then only insert exit block in to the output ExitBlocks vector. // This ensures that same exit block is not inserted twice into // ExitBlocks vector. - if (current != firstPred) + if (BB != FirstPred) continue; // If a terminator has more then two successors, for example SwitchInst, // then it is possible that there are multiple edges from current block // to one exit block. - if (std::distance(succ_begin(current), succ_end(current)) <= 2) { - ExitBlocks.push_back(*I); + if (std::distance(succ_begin(BB), succ_end(BB)) <= 2) { + ExitBlocks.push_back(Successor); continue; } // In case of multiple edges from current block to exit block, collect // only one edge in ExitBlocks. Use switchExitBlocks to keep track of // duplicate edges. - if (std::find(switchExitBlocks.begin(), switchExitBlocks.end(), *I) - == switchExitBlocks.end()) { - switchExitBlocks.push_back(*I); - ExitBlocks.push_back(*I); + if (std::find(SwitchExitBlocks.begin(), SwitchExitBlocks.end(), Successor) + == SwitchExitBlocks.end()) { + SwitchExitBlocks.push_back(Successor); + ExitBlocks.push_back(Successor); } } } } -/// getUniqueExitBlock - If getUniqueExitBlocks would return exactly one -/// block, return that block. Otherwise return null. 
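// E.g. (illustrative, names assumed): a loop whose two exiting blocks both
// branch to one %exit block still has a unique exit block, whereas a loop
// that can leave to either %exitA or %exitB makes getUniqueExitBlock()
// return null.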
BasicBlock *Loop::getUniqueExitBlock() const { SmallVector<BasicBlock *, 8> UniqueExitBlocks; getUniqueExitBlocks(UniqueExitBlocks); @@ -432,7 +384,7 @@ BasicBlock *Loop::getUniqueExitBlock() const { } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void Loop::dump() const { +LLVM_DUMP_METHOD void Loop::dump() const { print(dbgs()); } #endif @@ -445,7 +397,7 @@ namespace { /// Find the new parent loop for all blocks within the "unloop" whose last /// backedges has just been removed. class UnloopUpdater { - Loop *Unloop; + Loop &Unloop; LoopInfo *LI; LoopBlocksDFS DFS; @@ -462,7 +414,7 @@ class UnloopUpdater { public: UnloopUpdater(Loop *UL, LoopInfo *LInfo) : - Unloop(UL), LI(LInfo), DFS(UL), FoundIB(false) {} + Unloop(*UL), LI(LInfo), DFS(UL), FoundIB(false) {} void updateBlockParents(); @@ -475,29 +427,28 @@ protected: }; } // end anonymous namespace -/// updateBlockParents - Update the parent loop for all blocks that are directly -/// contained within the original "unloop". +/// Update the parent loop for all blocks that are directly contained within the +/// original "unloop". void UnloopUpdater::updateBlockParents() { - if (Unloop->getNumBlocks()) { + if (Unloop.getNumBlocks()) { // Perform a post order CFG traversal of all blocks within this loop, // propagating the nearest loop from sucessors to predecessors. LoopBlocksTraversal Traversal(DFS, LI); - for (LoopBlocksTraversal::POTIterator POI = Traversal.begin(), - POE = Traversal.end(); POI != POE; ++POI) { + for (BasicBlock *POI : Traversal) { - Loop *L = LI->getLoopFor(*POI); - Loop *NL = getNearestLoop(*POI, L); + Loop *L = LI->getLoopFor(POI); + Loop *NL = getNearestLoop(POI, L); if (NL != L) { // For reducible loops, NL is now an ancestor of Unloop. - assert((NL != Unloop && (!NL || NL->contains(Unloop))) && + assert((NL != &Unloop && (!NL || NL->contains(&Unloop))) && "uninitialized successor"); - LI->changeLoopFor(*POI, NL); + LI->changeLoopFor(POI, NL); } else { // Or the current block is part of a subloop, in which case its parent // is unchanged. - assert((FoundIB || Unloop->contains(L)) && "uninitialized successor"); + assert((FoundIB || Unloop.contains(L)) && "uninitialized successor"); } } } @@ -505,7 +456,7 @@ void UnloopUpdater::updateBlockParents() { // the DFS result cached by Traversal. bool Changed = FoundIB; for (unsigned NIters = 0; Changed; ++NIters) { - assert(NIters < Unloop->getNumBlocks() && "runaway iterative algorithm"); + assert(NIters < Unloop.getNumBlocks() && "runaway iterative algorithm"); // Iterate over the postorder list of blocks, propagating the nearest loop // from successors to predecessors as before. @@ -516,7 +467,7 @@ void UnloopUpdater::updateBlockParents() { Loop *L = LI->getLoopFor(*POI); Loop *NL = getNearestLoop(*POI, L); if (NL != L) { - assert(NL != Unloop && (!NL || NL->contains(Unloop)) && + assert(NL != &Unloop && (!NL || NL->contains(&Unloop)) && "uninitialized successor"); LI->changeLoopFor(*POI, NL); Changed = true; @@ -525,22 +476,21 @@ void UnloopUpdater::updateBlockParents() { } } -/// removeBlocksFromAncestors - Remove unloop's blocks from all ancestors below -/// their new parents. +/// Remove unloop's blocks from all ancestors below their new parents. void UnloopUpdater::removeBlocksFromAncestors() { // Remove all unloop's blocks (including those in nested subloops) from // ancestors below the new parent loop. 
- for (Loop::block_iterator BI = Unloop->block_begin(), - BE = Unloop->block_end(); BI != BE; ++BI) { + for (Loop::block_iterator BI = Unloop.block_begin(), + BE = Unloop.block_end(); BI != BE; ++BI) { Loop *OuterParent = LI->getLoopFor(*BI); - if (Unloop->contains(OuterParent)) { - while (OuterParent->getParentLoop() != Unloop) + if (Unloop.contains(OuterParent)) { + while (OuterParent->getParentLoop() != &Unloop) OuterParent = OuterParent->getParentLoop(); OuterParent = SubloopParents[OuterParent]; } // Remove blocks from former Ancestors except Unloop itself which will be // deleted. - for (Loop *OldParent = Unloop->getParentLoop(); OldParent != OuterParent; + for (Loop *OldParent = Unloop.getParentLoop(); OldParent != OuterParent; OldParent = OldParent->getParentLoop()) { assert(OldParent && "new loop is not an ancestor of the original"); OldParent->removeBlockFromLoop(*BI); @@ -548,12 +498,11 @@ void UnloopUpdater::removeBlocksFromAncestors() { } } -/// updateSubloopParents - Update the parent loop for all subloops directly -/// nested within unloop. +/// Update the parent loop for all subloops directly nested within unloop. void UnloopUpdater::updateSubloopParents() { - while (!Unloop->empty()) { - Loop *Subloop = *std::prev(Unloop->end()); - Unloop->removeChildLoop(std::prev(Unloop->end())); + while (!Unloop.empty()) { + Loop *Subloop = *std::prev(Unloop.end()); + Unloop.removeChildLoop(std::prev(Unloop.end())); assert(SubloopParents.count(Subloop) && "DFS failed to visit subloop"); if (Loop *Parent = SubloopParents[Subloop]) @@ -563,9 +512,9 @@ void UnloopUpdater::updateSubloopParents() { } } -/// getNearestLoop - Return the nearest parent loop among this block's -/// successors. If a successor is a subloop header, consider its parent to be -/// the nearest parent of the subloop's exits. +/// Return the nearest parent loop among this block's successors. If a successor +/// is a subloop header, consider its parent to be the nearest parent of the +/// subloop's exits. /// /// For subloop blocks, simply update SubloopParents and return NULL. Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) { @@ -575,16 +524,16 @@ Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) { Loop *NearLoop = BBLoop; Loop *Subloop = nullptr; - if (NearLoop != Unloop && Unloop->contains(NearLoop)) { + if (NearLoop != &Unloop && Unloop.contains(NearLoop)) { Subloop = NearLoop; // Find the subloop ancestor that is directly contained within Unloop. - while (Subloop->getParentLoop() != Unloop) { + while (Subloop->getParentLoop() != &Unloop) { Subloop = Subloop->getParentLoop(); assert(Subloop && "subloop is not an ancestor of the original loop"); } // Get the current nearest parent of the Subloop exits, initially Unloop. NearLoop = - SubloopParents.insert(std::make_pair(Subloop, Unloop)).first->second; + SubloopParents.insert(std::make_pair(Subloop, &Unloop)).first->second; } succ_iterator I = succ_begin(BB), E = succ_end(BB); @@ -597,33 +546,33 @@ Loop *UnloopUpdater::getNearestLoop(BasicBlock *BB, Loop *BBLoop) { continue; // self loops are uninteresting Loop *L = LI->getLoopFor(*I); - if (L == Unloop) { + if (L == &Unloop) { // This successor has not been processed. This path must lead to an // irreducible backedge. assert((FoundIB || !DFS.hasPostorder(*I)) && "should have seen IB"); FoundIB = true; } - if (L != Unloop && Unloop->contains(L)) { + if (L != &Unloop && Unloop.contains(L)) { // Successor is in a subloop. if (Subloop) continue; // Branching within subloops. Ignore it. 
// BB branches from the original into a subloop header. - assert(L->getParentLoop() == Unloop && "cannot skip into nested loops"); + assert(L->getParentLoop() == &Unloop && "cannot skip into nested loops"); // Get the current nearest parent of the Subloop's exits. L = SubloopParents[L]; // L could be Unloop if the only exit was an irreducible backedge. } - if (L == Unloop) { + if (L == &Unloop) { continue; } // Handle critical edges from Unloop into a sibling loop. - if (L && !L->contains(Unloop)) { + if (L && !L->contains(&Unloop)) { L = L->getParentLoop(); } // Remember the nearest parent loop among successors or subloop exits. - if (NearLoop == Unloop || !NearLoop || NearLoop->contains(L)) + if (NearLoop == &Unloop || !NearLoop || NearLoop->contains(L)) NearLoop = L; } if (Subloop) { @@ -698,7 +647,7 @@ void LoopInfo::markAsRemoved(Loop *Unloop) { char LoopAnalysis::PassID; -LoopInfo LoopAnalysis::run(Function &F, AnalysisManager<Function> *AM) { +LoopInfo LoopAnalysis::run(Function &F, AnalysisManager<Function> &AM) { // FIXME: Currently we create a LoopInfo from scratch for every function. // This may prove to be too wasteful due to deallocating and re-allocating // memory each time for the underlying map and vector datastructures. At some @@ -706,13 +655,13 @@ LoopInfo LoopAnalysis::run(Function &F, AnalysisManager<Function> *AM) { // objects. I don't want to add that kind of complexity until the scope of // the problem is better understood. LoopInfo LI; - LI.analyze(AM->getResult<DominatorTreeAnalysis>(F)); + LI.analyze(AM.getResult<DominatorTreeAnalysis>(F)); return LI; } PreservedAnalyses LoopPrinterPass::run(Function &F, - AnalysisManager<Function> *AM) { - AM->getResult<LoopAnalysis>(F).print(OS); + AnalysisManager<Function> &AM) { + AM.getResult<LoopAnalysis>(F).print(OS); return PreservedAnalyses::all(); } @@ -720,7 +669,7 @@ PrintLoopPass::PrintLoopPass() : OS(dbgs()) {} PrintLoopPass::PrintLoopPass(raw_ostream &OS, const std::string &Banner) : OS(OS), Banner(Banner) {} -PreservedAnalyses PrintLoopPass::run(Loop &L) { +PreservedAnalyses PrintLoopPass::run(Loop &L, AnalysisManager<Loop> &) { OS << Banner; for (auto *Block : L.blocks()) if (Block) diff --git a/lib/Analysis/LoopPass.cpp b/lib/Analysis/LoopPass.cpp index 8163231c3323..222345c9a980 100644 --- a/lib/Analysis/LoopPass.cpp +++ b/lib/Analysis/LoopPass.cpp @@ -14,8 +14,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopPassManager.h" #include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/OptBisect.h" #include "llvm/IR/PassManager.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" @@ -45,8 +47,10 @@ public: auto BBI = find_if(L->blocks().begin(), L->blocks().end(), [](BasicBlock *BB) { return BB; }); if (BBI != L->blocks().end() && - isFunctionInPrintList((*BBI)->getParent()->getName())) - P.run(*L); + isFunctionInPrintList((*BBI)->getParent()->getName())) { + AnalysisManager<Loop> DummyLAM; + P.run(*L, DummyLAM); + } return false; } }; @@ -105,9 +109,7 @@ void LPPassManager::cloneBasicBlockSimpleAnalysis(BasicBlock *From, /// deleteSimpleAnalysisValue - Invoke deleteAnalysisValue hook for all passes. 
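// --- Illustrative aside, not part of this patch: how a new-pass-manager pass
// consumes the result of LoopAnalysis::run() under the updated signature
// above, which now takes the analysis manager by reference rather than by
// pointer. "PrintTopLevelLoops" is a hypothetical pass; registration is
// omitted.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Support/Debug.h"
using namespace llvm;

struct PrintTopLevelLoops {
  static StringRef name() { return "PrintTopLevelLoops"; }
  PreservedAnalyses run(Function &F, AnalysisManager<Function> &AM) {
    LoopInfo &LI = AM.getResult<LoopAnalysis>(F);
    for (Loop *L : LI) // iterates over top-level loops only
      L->print(dbgs());
    return PreservedAnalyses::all();
  }
};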
void LPPassManager::deleteSimpleAnalysisValue(Value *V, Loop *L) { if (BasicBlock *BB = dyn_cast<BasicBlock>(V)) { - for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE; - ++BI) { - Instruction &I = *BI; + for (Instruction &I : *BB) { deleteSimpleAnalysisValue(&I, L); } } @@ -335,11 +337,16 @@ void LoopPass::assignPassManager(PMStack &PMS, LPPM->add(this); } -// Containing function has Attribute::OptimizeNone and transformation -// passes should skip it. -bool LoopPass::skipOptnoneFunction(const Loop *L) const { +bool LoopPass::skipLoop(const Loop *L) const { const Function *F = L->getHeader()->getParent(); - if (F && F->hasFnAttribute(Attribute::OptimizeNone)) { + if (!F) + return false; + // Check the opt bisect limit. + LLVMContext &Context = F->getContext(); + if (!Context.getOptBisect().shouldRunPass(this, *L)) + return true; + // Check for the OptimizeNone attribute. + if (F->hasFnAttribute(Attribute::OptimizeNone)) { // FIXME: Report this to dbgs() only once per function. DEBUG(dbgs() << "Skipping pass '" << getPassName() << "' in function " << F->getName() << "\n"); diff --git a/lib/Analysis/LoopPassManager.cpp b/lib/Analysis/LoopPassManager.cpp new file mode 100644 index 000000000000..8bac19a58217 --- /dev/null +++ b/lib/Analysis/LoopPassManager.cpp @@ -0,0 +1,39 @@ +//===- LoopPassManager.cpp - Loop pass management -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopPassManager.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +// Explicit instantiations for core typedef'ed templates. +namespace llvm { +template class PassManager<Loop>; +template class AnalysisManager<Loop>; +template class InnerAnalysisManagerProxy<LoopAnalysisManager, Function>; +template class OuterAnalysisManagerProxy<FunctionAnalysisManager, Loop>; +} + +PreservedAnalyses llvm::getLoopPassPreservedAnalyses() { + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + PA.preserve<ScalarEvolutionAnalysis>(); + // TODO: What we really want to do here is preserve an AA category, but that + // concept doesn't exist yet. + PA.preserve<BasicAA>(); + PA.preserve<GlobalsAA>(); + PA.preserve<SCEVAA>(); + return PA; +} diff --git a/lib/Analysis/LoopUnrollAnalyzer.cpp b/lib/Analysis/LoopUnrollAnalyzer.cpp new file mode 100644 index 000000000000..f59257ab16b5 --- /dev/null +++ b/lib/Analysis/LoopUnrollAnalyzer.cpp @@ -0,0 +1,210 @@ +//===- LoopUnrollAnalyzer.cpp - Unrolling Effect Estimation -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements UnrolledInstAnalyzer class. It's used for predicting +// potential effects that loop unrolling might have, such as enabling constant +// propagation and other optimizations. 
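// --- Illustrative aside, not part of this patch: the intended use of the
// getLoopPassPreservedAnalyses() helper added in LoopPassManager.cpp above.
// Return the conservative loop-pass preserved set only when the pass actually
// changed something. "MyLoopTransform" is a hypothetical pass.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPassManager.h"
#include "llvm/IR/PassManager.h"
using namespace llvm;

struct MyLoopTransform {
  static StringRef name() { return "MyLoopTransform"; }
  PreservedAnalyses run(Loop &L, AnalysisManager<Loop> &) {
    bool Changed = false; // A real pass would transform L here.
    return Changed ? getLoopPassPreservedAnalyses() : PreservedAnalyses::all();
  }
};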
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/LoopUnrollAnalyzer.h" +#include "llvm/IR/Dominators.h" + +using namespace llvm; + +/// \brief Try to simplify instruction \param I using its SCEV expression. +/// +/// The idea is that some AddRec expressions become constants, which then +/// could trigger folding of other instructions. However, that only happens +/// for expressions whose start value is also constant, which isn't always the +/// case. In another common and important case the start value is just some +/// address (i.e. SCEVUnknown) - in this case we compute the offset and save +/// it along with the base address instead. +bool UnrolledInstAnalyzer::simplifyInstWithSCEV(Instruction *I) { + if (!SE.isSCEVable(I->getType())) + return false; + + const SCEV *S = SE.getSCEV(I); + if (auto *SC = dyn_cast<SCEVConstant>(S)) { + SimplifiedValues[I] = SC->getValue(); + return true; + } + + auto *AR = dyn_cast<SCEVAddRecExpr>(S); + if (!AR || AR->getLoop() != L) + return false; + + const SCEV *ValueAtIteration = AR->evaluateAtIteration(IterationNumber, SE); + // Check if the AddRec expression becomes a constant. + if (auto *SC = dyn_cast<SCEVConstant>(ValueAtIteration)) { + SimplifiedValues[I] = SC->getValue(); + return true; + } + + // Check if the offset from the base address becomes a constant. + auto *Base = dyn_cast<SCEVUnknown>(SE.getPointerBase(S)); + if (!Base) + return false; + auto *Offset = + dyn_cast<SCEVConstant>(SE.getMinusSCEV(ValueAtIteration, Base)); + if (!Offset) + return false; + SimplifiedAddress Address; + Address.Base = Base->getValue(); + Address.Offset = Offset->getValue(); + SimplifiedAddresses[I] = Address; + return false; +} + +/// Try to simplify binary operator I. +/// +/// TODO: Probably it's worth to hoist the code for estimating the +/// simplifications effects to a separate class, since we have a very similar +/// code in InlineCost already. +bool UnrolledInstAnalyzer::visitBinaryOperator(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + if (!isa<Constant>(LHS)) + if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) + LHS = SimpleLHS; + if (!isa<Constant>(RHS)) + if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) + RHS = SimpleRHS; + + Value *SimpleV = nullptr; + const DataLayout &DL = I.getModule()->getDataLayout(); + if (auto FI = dyn_cast<FPMathOperator>(&I)) + SimpleV = + SimplifyFPBinOp(I.getOpcode(), LHS, RHS, FI->getFastMathFlags(), DL); + else + SimpleV = SimplifyBinOp(I.getOpcode(), LHS, RHS, DL); + + if (Constant *C = dyn_cast_or_null<Constant>(SimpleV)) + SimplifiedValues[&I] = C; + + if (SimpleV) + return true; + return Base::visitBinaryOperator(I); +} + +/// Try to fold load I. +bool UnrolledInstAnalyzer::visitLoad(LoadInst &I) { + Value *AddrOp = I.getPointerOperand(); + + auto AddressIt = SimplifiedAddresses.find(AddrOp); + if (AddressIt == SimplifiedAddresses.end()) + return false; + ConstantInt *SimplifiedAddrOp = AddressIt->second.Offset; + + auto *GV = dyn_cast<GlobalVariable>(AddressIt->second.Base); + // We're only interested in loads that can be completely folded to a + // constant. + if (!GV || !GV->hasDefinitiveInitializer() || !GV->isConstant()) + return false; + + ConstantDataSequential *CDS = + dyn_cast<ConstantDataSequential>(GV->getInitializer()); + if (!CDS) + return false; + + // We might have a vector load from an array. 
FIXME: for now we just bail + // out in this case, but we should be able to resolve and simplify such + // loads. + if(CDS->getElementType() != I.getType()) + return false; + + int ElemSize = CDS->getElementType()->getPrimitiveSizeInBits() / 8U; + if (SimplifiedAddrOp->getValue().getActiveBits() >= 64) + return false; + int64_t Index = SimplifiedAddrOp->getSExtValue() / ElemSize; + if (Index >= CDS->getNumElements()) { + // FIXME: For now we conservatively ignore out of bound accesses, but + // we're allowed to perform the optimization in this case. + return false; + } + + Constant *CV = CDS->getElementAsConstant(Index); + assert(CV && "Constant expected."); + SimplifiedValues[&I] = CV; + + return true; +} + +/// Try to simplify cast instruction. +bool UnrolledInstAnalyzer::visitCastInst(CastInst &I) { + // Propagate constants through casts. + Constant *COp = dyn_cast<Constant>(I.getOperand(0)); + if (!COp) + COp = SimplifiedValues.lookup(I.getOperand(0)); + + // If we know a simplified value for this operand and cast is valid, save the + // result to SimplifiedValues. + // The cast can be invalid, because SimplifiedValues contains results of SCEV + // analysis, which operates on integers (and, e.g., might convert i8* null to + // i32 0). + if (COp && CastInst::castIsValid(I.getOpcode(), COp, I.getType())) { + if (Constant *C = + ConstantExpr::getCast(I.getOpcode(), COp, I.getType())) { + SimplifiedValues[&I] = C; + return true; + } + } + + return Base::visitCastInst(I); +} + +/// Try to simplify cmp instruction. +bool UnrolledInstAnalyzer::visitCmpInst(CmpInst &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + + // First try to handle simplified comparisons. + if (!isa<Constant>(LHS)) + if (Constant *SimpleLHS = SimplifiedValues.lookup(LHS)) + LHS = SimpleLHS; + if (!isa<Constant>(RHS)) + if (Constant *SimpleRHS = SimplifiedValues.lookup(RHS)) + RHS = SimpleRHS; + + if (!isa<Constant>(LHS) && !isa<Constant>(RHS)) { + auto SimplifiedLHS = SimplifiedAddresses.find(LHS); + if (SimplifiedLHS != SimplifiedAddresses.end()) { + auto SimplifiedRHS = SimplifiedAddresses.find(RHS); + if (SimplifiedRHS != SimplifiedAddresses.end()) { + SimplifiedAddress &LHSAddr = SimplifiedLHS->second; + SimplifiedAddress &RHSAddr = SimplifiedRHS->second; + if (LHSAddr.Base == RHSAddr.Base) { + LHS = LHSAddr.Offset; + RHS = RHSAddr.Offset; + } + } + } + } + + if (Constant *CLHS = dyn_cast<Constant>(LHS)) { + if (Constant *CRHS = dyn_cast<Constant>(RHS)) { + if (CLHS->getType() == CRHS->getType()) { + if (Constant *C = ConstantExpr::getCompare(I.getPredicate(), CLHS, CRHS)) { + SimplifiedValues[&I] = C; + return true; + } + } + } + } + + return Base::visitCmpInst(I); +} + +bool UnrolledInstAnalyzer::visitPHINode(PHINode &PN) { + // Run base visitor first. This way we can gather some useful for later + // analysis information. + if (Base::visitPHINode(PN)) + return true; + + // The loop induction PHI nodes are definitionally free. + return PN.getParent() == L->getHeader(); +} diff --git a/lib/Analysis/Makefile b/lib/Analysis/Makefile deleted file mode 100644 index 93fd7f9bdd93..000000000000 --- a/lib/Analysis/Makefile +++ /dev/null @@ -1,15 +0,0 @@ -##===- lib/Analysis/Makefile -------------------------------*- Makefile -*-===## -# -# The LLVM Compiler Infrastructure -# -# This file is distributed under the University of Illinois Open Source -# License. See LICENSE.TXT for details. -# -##===----------------------------------------------------------------------===## - -LEVEL = ../.. 
-LIBRARYNAME = LLVMAnalysis -BUILD_ARCHIVE = 1 - -include $(LEVEL)/Makefile.common - diff --git a/lib/Analysis/MemDepPrinter.cpp b/lib/Analysis/MemDepPrinter.cpp index 078cefe51807..e7a85ae06e68 100644 --- a/lib/Analysis/MemDepPrinter.cpp +++ b/lib/Analysis/MemDepPrinter.cpp @@ -50,7 +50,7 @@ namespace { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequiredTransitive<AAResultsWrapperPass>(); - AU.addRequiredTransitive<MemoryDependenceAnalysis>(); + AU.addRequiredTransitive<MemoryDependenceWrapperPass>(); AU.setPreservesAll(); } @@ -79,7 +79,7 @@ namespace { char MemDepPrinter::ID = 0; INITIALIZE_PASS_BEGIN(MemDepPrinter, "print-memdeps", "Print MemDeps of function", false, true) -INITIALIZE_PASS_DEPENDENCY(MemoryDependenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) INITIALIZE_PASS_END(MemDepPrinter, "print-memdeps", "Print MemDeps of function", false, true) @@ -92,7 +92,7 @@ const char *const MemDepPrinter::DepTypeStr[] bool MemDepPrinter::runOnFunction(Function &F) { this->F = &F; - MemoryDependenceAnalysis &MDA = getAnalysis<MemoryDependenceAnalysis>(); + MemoryDependenceResults &MDA = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); // All this code uses non-const interfaces because MemDep is not // const-friendly, though nothing is actually modified. @@ -107,14 +107,13 @@ bool MemDepPrinter::runOnFunction(Function &F) { Deps[Inst].insert(std::make_pair(getInstTypePair(Res), static_cast<BasicBlock *>(nullptr))); } else if (auto CS = CallSite(Inst)) { - const MemoryDependenceAnalysis::NonLocalDepInfo &NLDI = + const MemoryDependenceResults::NonLocalDepInfo &NLDI = MDA.getNonLocalCallDependency(CS); DepSet &InstDeps = Deps[Inst]; - for (MemoryDependenceAnalysis::NonLocalDepInfo::const_iterator - I = NLDI.begin(), E = NLDI.end(); I != E; ++I) { - const MemDepResult &Res = I->getResult(); - InstDeps.insert(std::make_pair(getInstTypePair(Res), I->getBB())); + for (const NonLocalDepEntry &I : NLDI) { + const MemDepResult &Res = I.getResult(); + InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); } } else { SmallVector<NonLocalDepResult, 4> NLDI; @@ -123,10 +122,9 @@ bool MemDepPrinter::runOnFunction(Function &F) { MDA.getNonLocalPointerDependency(Inst, NLDI); DepSet &InstDeps = Deps[Inst]; - for (SmallVectorImpl<NonLocalDepResult>::const_iterator - I = NLDI.begin(), E = NLDI.end(); I != E; ++I) { - const MemDepResult &Res = I->getResult(); - InstDeps.insert(std::make_pair(getInstTypePair(Res), I->getBB())); + for (const NonLocalDepResult &I : NLDI) { + const MemDepResult &Res = I.getResult(); + InstDeps.insert(std::make_pair(getInstTypePair(Res), I.getBB())); } } } diff --git a/lib/Analysis/MemDerefPrinter.cpp b/lib/Analysis/MemDerefPrinter.cpp index 36f1424c8cf9..fa0cc5a46c2b 100644 --- a/lib/Analysis/MemDerefPrinter.cpp +++ b/lib/Analysis/MemDerefPrinter.cpp @@ -10,7 +10,7 @@ #include "llvm/Analysis/Passes.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/Loads.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/InstIterator.h" diff --git a/lib/Analysis/MemoryBuiltins.cpp b/lib/Analysis/MemoryBuiltins.cpp index 9e896aed0dce..f23477622bec 100644 --- a/lib/Analysis/MemoryBuiltins.cpp +++ b/lib/Analysis/MemoryBuiltins.cpp @@ -42,39 +42,38 @@ enum AllocType : uint8_t { }; struct AllocFnsTy { - LibFunc::Func Func; AllocType AllocTy; - unsigned char NumParams; + unsigned NumParams; // First and Second 
size parameters (or -1 if unused) - signed char FstParam, SndParam; + int FstParam, SndParam; }; // FIXME: certain users need more information. E.g., SimplifyLibCalls needs to // know which functions are nounwind, noalias, nocapture parameters, etc. -static const AllocFnsTy AllocationFnData[] = { - {LibFunc::malloc, MallocLike, 1, 0, -1}, - {LibFunc::valloc, MallocLike, 1, 0, -1}, - {LibFunc::Znwj, OpNewLike, 1, 0, -1}, // new(unsigned int) - {LibFunc::ZnwjRKSt9nothrow_t, MallocLike, 2, 0, -1}, // new(unsigned int, nothrow) - {LibFunc::Znwm, OpNewLike, 1, 0, -1}, // new(unsigned long) - {LibFunc::ZnwmRKSt9nothrow_t, MallocLike, 2, 0, -1}, // new(unsigned long, nothrow) - {LibFunc::Znaj, OpNewLike, 1, 0, -1}, // new[](unsigned int) - {LibFunc::ZnajRKSt9nothrow_t, MallocLike, 2, 0, -1}, // new[](unsigned int, nothrow) - {LibFunc::Znam, OpNewLike, 1, 0, -1}, // new[](unsigned long) - {LibFunc::ZnamRKSt9nothrow_t, MallocLike, 2, 0, -1}, // new[](unsigned long, nothrow) - {LibFunc::msvc_new_int, OpNewLike, 1, 0, -1}, // new(unsigned int) - {LibFunc::msvc_new_int_nothrow, MallocLike, 2, 0, -1}, // new(unsigned int, nothrow) - {LibFunc::msvc_new_longlong, OpNewLike, 1, 0, -1}, // new(unsigned long long) - {LibFunc::msvc_new_longlong_nothrow, MallocLike, 2, 0, -1}, // new(unsigned long long, nothrow) - {LibFunc::msvc_new_array_int, OpNewLike, 1, 0, -1}, // new[](unsigned int) - {LibFunc::msvc_new_array_int_nothrow, MallocLike, 2, 0, -1}, // new[](unsigned int, nothrow) - {LibFunc::msvc_new_array_longlong, OpNewLike, 1, 0, -1}, // new[](unsigned long long) - {LibFunc::msvc_new_array_longlong_nothrow, MallocLike, 2, 0, -1}, // new[](unsigned long long, nothrow) - {LibFunc::calloc, CallocLike, 2, 0, 1}, - {LibFunc::realloc, ReallocLike, 2, 1, -1}, - {LibFunc::reallocf, ReallocLike, 2, 1, -1}, - {LibFunc::strdup, StrDupLike, 1, -1, -1}, - {LibFunc::strndup, StrDupLike, 2, 1, -1} +static const std::pair<LibFunc::Func, AllocFnsTy> AllocationFnData[] = { + {LibFunc::malloc, {MallocLike, 1, 0, -1}}, + {LibFunc::valloc, {MallocLike, 1, 0, -1}}, + {LibFunc::Znwj, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc::ZnwjRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc::Znwm, {OpNewLike, 1, 0, -1}}, // new(unsigned long) + {LibFunc::ZnwmRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new(unsigned long, nothrow) + {LibFunc::Znaj, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) + {LibFunc::ZnajRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc::Znam, {OpNewLike, 1, 0, -1}}, // new[](unsigned long) + {LibFunc::ZnamRKSt9nothrow_t, {MallocLike, 2, 0, -1}}, // new[](unsigned long, nothrow) + {LibFunc::msvc_new_int, {OpNewLike, 1, 0, -1}}, // new(unsigned int) + {LibFunc::msvc_new_int_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned int, nothrow) + {LibFunc::msvc_new_longlong, {OpNewLike, 1, 0, -1}}, // new(unsigned long long) + {LibFunc::msvc_new_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new(unsigned long long, nothrow) + {LibFunc::msvc_new_array_int, {OpNewLike, 1, 0, -1}}, // new[](unsigned int) + {LibFunc::msvc_new_array_int_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned int, nothrow) + {LibFunc::msvc_new_array_longlong, {OpNewLike, 1, 0, -1}}, // new[](unsigned long long) + {LibFunc::msvc_new_array_longlong_nothrow, {MallocLike, 2, 0, -1}}, // new[](unsigned long long, nothrow) + {LibFunc::calloc, {CallocLike, 2, 0, 1}}, + {LibFunc::realloc, {ReallocLike, 2, 1, -1}}, + {LibFunc::reallocf, {ReallocLike, 2, 1, -1}}, + {LibFunc::strdup, 
{StrDupLike, 1, -1, -1}}, + {LibFunc::strndup, {StrDupLike, 2, 1, -1}} // TODO: Handle "int posix_memalign(void **, size_t, size_t)" }; @@ -96,34 +95,57 @@ static Function *getCalledFunction(const Value *V, bool LookThroughBitCast) { return Callee; } -/// \brief Returns the allocation data for the given value if it is a call to a -/// known allocation function, and NULL otherwise. -static const AllocFnsTy *getAllocationData(const Value *V, AllocType AllocTy, - const TargetLibraryInfo *TLI, - bool LookThroughBitCast = false) { +/// Returns the allocation data for the given value if it's either a call to a +/// known allocation function, or a call to a function with the allocsize +/// attribute. +static Optional<AllocFnsTy> getAllocationData(const Value *V, AllocType AllocTy, + const TargetLibraryInfo *TLI, + bool LookThroughBitCast = false) { // Skip intrinsics if (isa<IntrinsicInst>(V)) - return nullptr; + return None; - Function *Callee = getCalledFunction(V, LookThroughBitCast); + const Function *Callee = getCalledFunction(V, LookThroughBitCast); if (!Callee) - return nullptr; + return None; + + // If it has allocsize, we can skip checking if it's a known function. + // + // MallocLike is chosen here because allocsize makes no guarantees about the + // nullness of the result of the function, nor does it deal with strings, nor + // does it require that the memory returned is zeroed out. + LLVM_CONSTEXPR auto AllocSizeAllocTy = MallocLike; + if ((AllocTy & AllocSizeAllocTy) == AllocSizeAllocTy && + Callee->hasFnAttribute(Attribute::AllocSize)) { + Attribute Attr = Callee->getFnAttribute(Attribute::AllocSize); + std::pair<unsigned, Optional<unsigned>> Args = Attr.getAllocSizeArgs(); + + AllocFnsTy Result; + Result.AllocTy = AllocSizeAllocTy; + Result.NumParams = Callee->getNumOperands(); + Result.FstParam = Args.first; + Result.SndParam = Args.second.getValueOr(-1); + return Result; + } // Make sure that the function is available. StringRef FnName = Callee->getName(); LibFunc::Func TLIFn; if (!TLI || !TLI->getLibFunc(FnName, TLIFn) || !TLI->has(TLIFn)) - return nullptr; + return None; - const AllocFnsTy *FnData = + const auto *Iter = std::find_if(std::begin(AllocationFnData), std::end(AllocationFnData), - [TLIFn](const AllocFnsTy &Fn) { return Fn.Func == TLIFn; }); + [TLIFn](const std::pair<LibFunc::Func, AllocFnsTy> &P) { + return P.first == TLIFn; + }); - if (FnData == std::end(AllocationFnData)) - return nullptr; + if (Iter == std::end(AllocationFnData)) + return None; + const AllocFnsTy *FnData = &Iter->second; if ((FnData->AllocTy & AllocTy) != FnData->AllocTy) - return nullptr; + return None; // Check function prototype. int FstParam = FnData->FstParam; @@ -138,13 +160,13 @@ static const AllocFnsTy *getAllocationData(const Value *V, AllocType AllocTy, (SndParam < 0 || FTy->getParamType(SndParam)->isIntegerTy(32) || FTy->getParamType(SndParam)->isIntegerTy(64))) - return FnData; - return nullptr; + return *FnData; + return None; } static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) { ImmutableCallSite CS(LookThroughBitCast ? V->stripPointerCasts() : V); - return CS && CS.hasFnAttr(Attribute::NoAlias); + return CS && CS.paramHasAttr(AttributeSet::ReturnIndex, Attribute::NoAlias); } @@ -153,7 +175,7 @@ static bool hasNoAliasAttr(const Value *V, bool LookThroughBitCast) { /// like). 
bool llvm::isAllocationFn(const Value *V, const TargetLibraryInfo *TLI,
                          bool LookThroughBitCast) {
-  return getAllocationData(V, AnyAlloc, TLI, LookThroughBitCast);
+  return getAllocationData(V, AnyAlloc, TLI, LookThroughBitCast).hasValue();
}

/// \brief Tests if a value is a call or invoke to a function that returns a
@@ -170,21 +192,21 @@ bool llvm::isNoAliasFn(const Value *V, const TargetLibraryInfo *TLI,
/// allocates uninitialized memory (such as malloc).
bool llvm::isMallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
                          bool LookThroughBitCast) {
-  return getAllocationData(V, MallocLike, TLI, LookThroughBitCast);
+  return getAllocationData(V, MallocLike, TLI, LookThroughBitCast).hasValue();
}

/// \brief Tests if a value is a call or invoke to a library function that
/// allocates zero-filled memory (such as calloc).
bool llvm::isCallocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
                          bool LookThroughBitCast) {
-  return getAllocationData(V, CallocLike, TLI, LookThroughBitCast);
+  return getAllocationData(V, CallocLike, TLI, LookThroughBitCast).hasValue();
}

/// \brief Tests if a value is a call or invoke to a library function that
/// allocates memory (either malloc, calloc, or strdup like).
bool llvm::isAllocLikeFn(const Value *V, const TargetLibraryInfo *TLI,
                         bool LookThroughBitCast) {
-  return getAllocationData(V, AllocLike, TLI, LookThroughBitCast);
+  return getAllocationData(V, AllocLike, TLI, LookThroughBitCast).hasValue();
}

/// extractMallocCall - Returns the corresponding CallInst if the instruction
@@ -214,8 +236,7 @@ static Value *computeArraySize(const CallInst *CI, const DataLayout &DL,
  // return the multiple. Otherwise, return NULL.
  Value *MallocArg = CI->getArgOperand(0);
  Value *Multiple = nullptr;
-  if (ComputeMultiple(MallocArg, ElementSize, Multiple,
-                      LookThroughSExt))
+  if (ComputeMultiple(MallocArg, ElementSize, Multiple, LookThroughSExt))
    return Multiple;

  return nullptr;
@@ -345,29 +366,29 @@ const CallInst *llvm::isFreeCall(const Value *I, const TargetLibraryInfo *TLI) {
//===----------------------------------------------------------------------===//
//  Utility functions to compute size of objects.
//
-
+static APInt getSizeWithOverflow(const SizeOffsetType &Data) {
+  if (Data.second.isNegative() || Data.first.ult(Data.second))
+    return APInt(Data.first.getBitWidth(), 0);
+  return Data.first - Data.second;
+}

/// \brief Compute the size of the object pointed to by Ptr. Returns true and
/// the object size in Size if successful, and false otherwise.
/// If RoundToAlign is true, then Size is rounded up to the alignment of
/// allocas, byval arguments, and global variables.
bool llvm::getObjectSize(const Value *Ptr, uint64_t &Size, const DataLayout &DL, - const TargetLibraryInfo *TLI, bool RoundToAlign) { - ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), RoundToAlign); + const TargetLibraryInfo *TLI, bool RoundToAlign, + llvm::ObjSizeMode Mode) { + ObjectSizeOffsetVisitor Visitor(DL, TLI, Ptr->getContext(), + RoundToAlign, Mode); SizeOffsetType Data = Visitor.compute(const_cast<Value*>(Ptr)); if (!Visitor.bothKnown(Data)) return false; - APInt ObjSize = Data.first, Offset = Data.second; - // check for overflow - if (Offset.slt(0) || ObjSize.ult(Offset)) - Size = 0; - else - Size = (ObjSize - Offset).getZExtValue(); + Size = getSizeWithOverflow(Data).getZExtValue(); return true; } - STATISTIC(ObjectVisitorArgument, "Number of arguments with unsolved size and offset"); STATISTIC(ObjectVisitorLoad, @@ -376,15 +397,16 @@ STATISTIC(ObjectVisitorLoad, APInt ObjectSizeOffsetVisitor::align(APInt Size, uint64_t Align) { if (RoundToAlign && Align) - return APInt(IntTyBits, RoundUpToAlignment(Size.getZExtValue(), Align)); + return APInt(IntTyBits, alignTo(Size.getZExtValue(), Align)); return Size; } ObjectSizeOffsetVisitor::ObjectSizeOffsetVisitor(const DataLayout &DL, const TargetLibraryInfo *TLI, LLVMContext &Context, - bool RoundToAlign) - : DL(DL), TLI(TLI), RoundToAlign(RoundToAlign) { + bool RoundToAlign, + ObjSizeMode Mode) + : DL(DL), TLI(TLI), RoundToAlign(RoundToAlign), Mode(Mode) { // Pointer size must be rechecked for each object visited since it could have // a different address space. } @@ -443,7 +465,7 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitAllocaInst(AllocaInst &I) { } SizeOffsetType ObjectSizeOffsetVisitor::visitArgument(Argument &A) { - // no interprocedural analysis is done at the moment + // No interprocedural analysis is done at the moment. if (!A.hasByValOrInAllocaAttr()) { ++ObjectVisitorArgument; return unknown(); @@ -454,20 +476,21 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitArgument(Argument &A) { } SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) { - const AllocFnsTy *FnData = getAllocationData(CS.getInstruction(), AnyAlloc, - TLI); + Optional<AllocFnsTy> FnData = + getAllocationData(CS.getInstruction(), AnyAlloc, TLI); if (!FnData) return unknown(); - // handle strdup-like functions separately + // Handle strdup-like functions separately. if (FnData->AllocTy == StrDupLike) { APInt Size(IntTyBits, GetStringLength(CS.getArgument(0))); if (!Size) return unknown(); - // strndup limits strlen + // Strndup limits strlen. if (FnData->FstParam > 0) { - ConstantInt *Arg= dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam)); + ConstantInt *Arg = + dyn_cast<ConstantInt>(CS.getArgument(FnData->FstParam)); if (!Arg) return unknown(); @@ -482,8 +505,26 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) { if (!Arg) return unknown(); - APInt Size = Arg->getValue().zextOrSelf(IntTyBits); - // size determined by just 1 parameter + // When we're compiling N-bit code, and the user uses parameters that are + // greater than N bits (e.g. uint64_t on a 32-bit build), we can run into + // trouble with APInt size issues. This function handles resizing + overflow + // checks for us. + auto CheckedZextOrTrunc = [&](APInt &I) { + // More bits than we can handle. Checking the bit width isn't necessary, but + // it's faster than checking active bits, and should give `false` in the + // vast majority of cases. 
+ if (I.getBitWidth() > IntTyBits && I.getActiveBits() > IntTyBits) + return false; + if (I.getBitWidth() != IntTyBits) + I = I.zextOrTrunc(IntTyBits); + return true; + }; + + APInt Size = Arg->getValue(); + if (!CheckedZextOrTrunc(Size)) + return unknown(); + + // Size is determined by just 1 parameter. if (FnData->SndParam < 0) return std::make_pair(Size, Zero); @@ -491,8 +532,13 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitCallSite(CallSite CS) { if (!Arg) return unknown(); - Size *= Arg->getValue().zextOrSelf(IntTyBits); - return std::make_pair(Size, Zero); + APInt NumElems = Arg->getValue(); + if (!CheckedZextOrTrunc(NumElems)) + return unknown(); + + bool Overflow; + Size = Size.umul_ov(NumElems, Overflow); + return Overflow ? unknown() : std::make_pair(Size, Zero); // TODO: handle more standard functions (+ wchar cousins): // - strdup / strndup @@ -529,7 +575,7 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitGEPOperator(GEPOperator &GEP) { } SizeOffsetType ObjectSizeOffsetVisitor::visitGlobalAlias(GlobalAlias &GA) { - if (GA.mayBeOverridden()) + if (GA.isInterposable()) return unknown(); return compute(GA.getAliasee()); } @@ -560,8 +606,28 @@ SizeOffsetType ObjectSizeOffsetVisitor::visitPHINode(PHINode&) { SizeOffsetType ObjectSizeOffsetVisitor::visitSelectInst(SelectInst &I) { SizeOffsetType TrueSide = compute(I.getTrueValue()); SizeOffsetType FalseSide = compute(I.getFalseValue()); - if (bothKnown(TrueSide) && bothKnown(FalseSide) && TrueSide == FalseSide) - return TrueSide; + if (bothKnown(TrueSide) && bothKnown(FalseSide)) { + if (TrueSide == FalseSide) { + return TrueSide; + } + + APInt TrueResult = getSizeWithOverflow(TrueSide); + APInt FalseResult = getSizeWithOverflow(FalseSide); + + if (TrueResult == FalseResult) { + return TrueSide; + } + if (Mode == ObjSizeMode::Min) { + if (TrueResult.slt(FalseResult)) + return TrueSide; + return FalseSide; + } + if (Mode == ObjSizeMode::Max) { + if (TrueResult.sgt(FalseResult)) + return TrueSide; + return FalseSide; + } + } return unknown(); } @@ -591,11 +657,11 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute(Value *V) { SizeOffsetEvalType Result = compute_(V); if (!bothKnown(Result)) { - // erase everything that was computed in this iteration from the cache, so + // Erase everything that was computed in this iteration from the cache, so // that no dangling references are left behind. We could be a bit smarter if // we kept a dependency graph. It's probably not worth the complexity. - for (PtrSetTy::iterator I=SeenVals.begin(), E=SeenVals.end(); I != E; ++I) { - CacheMapTy::iterator CacheIt = CacheMap.find(*I); + for (const Value *SeenVal : SeenVals) { + CacheMapTy::iterator CacheIt = CacheMap.find(SeenVal); // non-computable results can be safely cached if (CacheIt != CacheMap.end() && anyKnown(CacheIt->second)) CacheMap.erase(CacheIt); @@ -615,18 +681,18 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { V = V->stripPointerCasts(); - // check cache + // Check cache. CacheMapTy::iterator CacheIt = CacheMap.find(V); if (CacheIt != CacheMap.end()) return CacheIt->second; - // always generate code immediately before the instruction being - // processed, so that the generated code dominates the same BBs + // Always generate code immediately before the instruction being + // processed, so that the generated code dominates the same BBs. 
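// --- Illustrative aside, not part of this patch: the overflow-checked sizing
// pattern that the CheckedZextOrTrunc lambda and umul_ov implement above, as
// a stand-alone helper. "checkedAllocSize" is a hypothetical name.
#include "llvm/ADT/APInt.h"
using namespace llvm;

static bool checkedAllocSize(APInt NumElems, APInt ElemSize, unsigned Bits,
                             APInt &Out) {
  // Reject values that cannot be represented in the target's index width.
  if (NumElems.getActiveBits() > Bits || ElemSize.getActiveBits() > Bits)
    return false;
  NumElems = NumElems.zextOrTrunc(Bits);
  ElemSize = ElemSize.zextOrTrunc(Bits);
  bool Overflow;
  Out = NumElems.umul_ov(ElemSize, Overflow); // unsigned multiply, flags overflow
  return !Overflow;
}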
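// --- Illustrative aside, not part of this patch: a caller-side view of the
// new ObjSizeMode parameter handled in visitSelectInst() above. With Min, a
// select between two buffers reports the smaller size, which is the safe
// choice for a bounds check. "conservativeObjectSize" is a hypothetical
// wrapper.
#include "llvm/Analysis/MemoryBuiltins.h"
using namespace llvm;

static bool conservativeObjectSize(const Value *Ptr, const DataLayout &DL,
                                   const TargetLibraryInfo *TLI,
                                   uint64_t &Size) {
  return getObjectSize(Ptr, Size, DL, TLI, /*RoundToAlign=*/false,
                       ObjSizeMode::Min);
}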
BuilderTy::InsertPointGuard Guard(Builder); if (Instruction *I = dyn_cast<Instruction>(V)) Builder.SetInsertPoint(I); - // now compute the size and offset + // Now compute the size and offset. SizeOffsetEvalType Result; // Record the pointers that were handled in this run, so that they can be @@ -643,7 +709,7 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::compute_(Value *V) { cast<ConstantExpr>(V)->getOpcode() == Instruction::IntToPtr) || isa<GlobalAlias>(V) || isa<GlobalVariable>(V)) { - // ignore values where we cannot do more than what ObjectSizeVisitor can + // Ignore values where we cannot do more than ObjectSizeVisitor. Result = unknown(); } else { DEBUG(dbgs() << "ObjectSizeOffsetEvaluator::compute() unhandled value: " @@ -670,12 +736,12 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitAllocaInst(AllocaInst &I) { } SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitCallSite(CallSite CS) { - const AllocFnsTy *FnData = getAllocationData(CS.getInstruction(), AnyAlloc, - TLI); + Optional<AllocFnsTy> FnData = + getAllocationData(CS.getInstruction(), AnyAlloc, TLI); if (!FnData) return unknown(); - // handle strdup-like functions separately + // Handle strdup-like functions separately. if (FnData->AllocTy == StrDupLike) { // TODO return unknown(); @@ -731,14 +797,14 @@ SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitLoadInst(LoadInst&) { } SizeOffsetEvalType ObjectSizeOffsetEvaluator::visitPHINode(PHINode &PHI) { - // create 2 PHIs: one for size and another for offset + // Create 2 PHIs: one for size and another for offset. PHINode *SizePHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues()); PHINode *OffsetPHI = Builder.CreatePHI(IntTy, PHI.getNumIncomingValues()); - // insert right away in the cache to handle recursive PHIs + // Insert right away in the cache to handle recursive PHIs. CacheMap[&PHI] = std::make_pair(SizePHI, OffsetPHI); - // compute offset/size for each PHI incoming pointer + // Compute offset/size for each PHI incoming pointer. for (unsigned i = 0, e = PHI.getNumIncomingValues(); i != e; ++i) { Builder.SetInsertPoint(&*PHI.getIncomingBlock(i)->getFirstInsertionPt()); SizeOffsetEvalType EdgeData = compute_(PHI.getIncomingValue(i)); diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp index 6918360536a3..33499334fefa 100644 --- a/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -45,8 +45,7 @@ STATISTIC(NumCacheNonLocalPtr, "Number of fully cached non-local ptr responses"); STATISTIC(NumCacheDirtyNonLocalPtr, "Number of cached, but dirty, non-local ptr responses"); -STATISTIC(NumUncacheNonLocalPtr, - "Number of uncached non-local ptr responses"); +STATISTIC(NumUncacheNonLocalPtr, "Number of uncached non-local ptr responses"); STATISTIC(NumCacheCompleteNonLocalPtr, "Number of block queries that were completely cached"); @@ -57,75 +56,35 @@ static cl::opt<unsigned> BlockScanLimit( cl::desc("The number of instructions to scan in a block in memory " "dependency analysis (default = 100)")); +static cl::opt<unsigned> + BlockNumberLimit("memdep-block-number-limit", cl::Hidden, cl::init(1000), + cl::desc("The number of blocks to scan during memory " + "dependency analysis (default = 1000)")); + // Limit on the number of memdep results to process. static const unsigned int NumResultsLimit = 100; -char MemoryDependenceAnalysis::ID = 0; - -// Register this pass... 
-INITIALIZE_PASS_BEGIN(MemoryDependenceAnalysis, "memdep", - "Memory Dependence Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(MemoryDependenceAnalysis, "memdep", - "Memory Dependence Analysis", false, true) - -MemoryDependenceAnalysis::MemoryDependenceAnalysis() - : FunctionPass(ID) { - initializeMemoryDependenceAnalysisPass(*PassRegistry::getPassRegistry()); -} -MemoryDependenceAnalysis::~MemoryDependenceAnalysis() { -} - -/// Clean up memory in between runs -void MemoryDependenceAnalysis::releaseMemory() { - LocalDeps.clear(); - NonLocalDeps.clear(); - NonLocalPointerDeps.clear(); - ReverseLocalDeps.clear(); - ReverseNonLocalDeps.clear(); - ReverseNonLocalPtrDeps.clear(); - PredCache.clear(); -} - -/// getAnalysisUsage - Does not modify anything. It uses Alias Analysis. +/// This is a helper function that removes Val from 'Inst's set in ReverseMap. /// -void MemoryDependenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { - AU.setPreservesAll(); - AU.addRequired<AssumptionCacheTracker>(); - AU.addRequiredTransitive<AAResultsWrapperPass>(); - AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); -} - -bool MemoryDependenceAnalysis::runOnFunction(Function &F) { - AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - DominatorTreeWrapperPass *DTWP = - getAnalysisIfAvailable<DominatorTreeWrapperPass>(); - DT = DTWP ? &DTWP->getDomTree() : nullptr; - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - return false; -} - -/// RemoveFromReverseMap - This is a helper function that removes Val from -/// 'Inst's set in ReverseMap. If the set becomes empty, remove Inst's entry. +/// If the set becomes empty, remove Inst's entry. template <typename KeyTy> -static void RemoveFromReverseMap(DenseMap<Instruction*, - SmallPtrSet<KeyTy, 4> > &ReverseMap, - Instruction *Inst, KeyTy Val) { - typename DenseMap<Instruction*, SmallPtrSet<KeyTy, 4> >::iterator - InstIt = ReverseMap.find(Inst); +static void +RemoveFromReverseMap(DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>> &ReverseMap, + Instruction *Inst, KeyTy Val) { + typename DenseMap<Instruction *, SmallPtrSet<KeyTy, 4>>::iterator InstIt = + ReverseMap.find(Inst); assert(InstIt != ReverseMap.end() && "Reverse map out of sync?"); bool Found = InstIt->second.erase(Val); - assert(Found && "Invalid reverse map!"); (void)Found; + assert(Found && "Invalid reverse map!"); + (void)Found; if (InstIt->second.empty()) ReverseMap.erase(InstIt); } -/// GetLocation - If the given instruction references a specific memory -/// location, fill in Loc with the details, otherwise set Loc.Ptr to null. -/// Return a ModRefInfo value describing the general behavior of the +/// If the given instruction references a specific memory location, fill in Loc +/// with the details, otherwise set Loc.Ptr to null. +/// +/// Returns a ModRefInfo value describing the general behavior of the /// instruction. 
static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, const TargetLibraryInfo &TLI) { @@ -134,7 +93,7 @@ static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, Loc = MemoryLocation::get(LI); return MRI_Ref; } - if (LI->getOrdering() == Monotonic) { + if (LI->getOrdering() == AtomicOrdering::Monotonic) { Loc = MemoryLocation::get(LI); return MRI_ModRef; } @@ -147,7 +106,7 @@ static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, Loc = MemoryLocation::get(SI); return MRI_Mod; } - if (SI->getOrdering() == Monotonic) { + if (SI->getOrdering() == AtomicOrdering::Monotonic) { Loc = MemoryLocation::get(SI); return MRI_ModRef; } @@ -201,11 +160,10 @@ static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, return MRI_NoModRef; } -/// getCallSiteDependencyFrom - Private helper for finding the local -/// dependencies of a call site. -MemDepResult MemoryDependenceAnalysis:: -getCallSiteDependencyFrom(CallSite CS, bool isReadOnlyCall, - BasicBlock::iterator ScanIt, BasicBlock *BB) { +/// Private helper for finding the local dependencies of a call site. +MemDepResult MemoryDependenceResults::getCallSiteDependencyFrom( + CallSite CS, bool isReadOnlyCall, BasicBlock::iterator ScanIt, + BasicBlock *BB) { unsigned Limit = BlockScanLimit; // Walk backwards through the block, looking for dependencies @@ -220,19 +178,20 @@ getCallSiteDependencyFrom(CallSite CS, bool isReadOnlyCall, // If this inst is a memory op, get the pointer it accessed MemoryLocation Loc; - ModRefInfo MR = GetLocation(Inst, Loc, *TLI); + ModRefInfo MR = GetLocation(Inst, Loc, TLI); if (Loc.Ptr) { // A simple instruction. - if (AA->getModRefInfo(CS, Loc) != MRI_NoModRef) + if (AA.getModRefInfo(CS, Loc) != MRI_NoModRef) return MemDepResult::getClobber(Inst); continue; } if (auto InstCS = CallSite(Inst)) { // Debug intrinsics don't cause dependences. - if (isa<DbgInfoIntrinsic>(Inst)) continue; + if (isa<DbgInfoIntrinsic>(Inst)) + continue; // If these two calls do not interfere, look past it. - switch (AA->getModRefInfo(CS, InstCS)) { + switch (AA.getModRefInfo(CS, InstCS)) { case MRI_NoModRef: // If the two calls are the same, return InstCS as a Def, so that // CS can be found redundant and eliminated. @@ -261,8 +220,8 @@ getCallSiteDependencyFrom(CallSite CS, bool isReadOnlyCall, return MemDepResult::getNonFuncLocal(); } -/// isLoadLoadClobberIfExtendedToFullWidth - Return true if LI is a load that -/// would fully overlap MemLoc if done as a wider legal integer load. +/// Return true if LI is a load that would fully overlap MemLoc if done as +/// a wider legal integer load. /// /// MemLocBase, MemLocOffset are lazily computed here the first time the /// base/offs of memloc is needed. @@ -276,23 +235,17 @@ static bool isLoadLoadClobberIfExtendedToFullWidth(const MemoryLocation &MemLoc, if (!MemLocBase) MemLocBase = GetPointerBaseWithConstantOffset(MemLoc.Ptr, MemLocOffs, DL); - unsigned Size = MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( + unsigned Size = MemoryDependenceResults::getLoadLoadClobberFullWidthSize( MemLocBase, MemLocOffs, MemLoc.Size, LI); return Size != 0; } -/// getLoadLoadClobberFullWidthSize - This is a little bit of analysis that -/// looks at a memory location for a load (specified by MemLocBase, Offs, -/// and Size) and compares it against a load. 
If the specified load could -/// be safely widened to a larger integer load that is 1) still efficient, -/// 2) safe for the target, and 3) would provide the specified memory -/// location value, then this function returns the size in bytes of the -/// load width to use. If not, this returns zero. -unsigned MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( +unsigned MemoryDependenceResults::getLoadLoadClobberFullWidthSize( const Value *MemLocBase, int64_t MemLocOffs, unsigned MemLocSize, const LoadInst *LI) { // We can only extend simple integer loads. - if (!isa<IntegerType>(LI->getType()) || !LI->isSimple()) return 0; + if (!isa<IntegerType>(LI->getType()) || !LI->isSimple()) + return 0; // Load widening is hostile to ThreadSanitizer: it may cause false positives // or make the reports more cryptic (access sizes are wrong). @@ -308,7 +261,8 @@ unsigned MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( // If the two pointers are not based on the same pointer, we can't tell that // they are related. - if (LIBase != MemLocBase) return 0; + if (LIBase != MemLocBase) + return 0; // Okay, the two values are based on the same pointer, but returned as // no-alias. This happens when we have things like two byte loads at "P+1" @@ -317,7 +271,8 @@ unsigned MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( // the bits required by MemLoc. // If MemLoc is before LI, then no widening of LI will help us out. - if (MemLocOffs < LIOffs) return 0; + if (MemLocOffs < LIOffs) + return 0; // Get the alignment of the load in bytes. We assume that it is safe to load // any legal integer up to this size without a problem. For example, if we're @@ -326,21 +281,22 @@ unsigned MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( // to i16. unsigned LoadAlign = LI->getAlignment(); - int64_t MemLocEnd = MemLocOffs+MemLocSize; + int64_t MemLocEnd = MemLocOffs + MemLocSize; // If no amount of rounding up will let MemLoc fit into LI, then bail out. - if (LIOffs+LoadAlign < MemLocEnd) return 0; + if (LIOffs + LoadAlign < MemLocEnd) + return 0; // This is the size of the load to try. Start with the next larger power of // two. - unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits()/8U; + unsigned NewLoadByteSize = LI->getType()->getPrimitiveSizeInBits() / 8U; NewLoadByteSize = NextPowerOf2(NewLoadByteSize); while (1) { // If this load size is bigger than our known alignment or would not fit // into a native integer register, then we fail. if (NewLoadByteSize > LoadAlign || - !DL.fitsInLegalInteger(NewLoadByteSize*8)) + !DL.fitsInLegalInteger(NewLoadByteSize * 8)) return 0; if (LIOffs + NewLoadByteSize > MemLocEnd && @@ -352,7 +308,7 @@ unsigned MemoryDependenceAnalysis::getLoadLoadClobberFullWidthSize( return 0; // If a load of this width would include all of MemLoc, then we succeed. - if (LIOffs+NewLoadByteSize >= MemLocEnd) + if (LIOffs + NewLoadByteSize >= MemLocEnd) return NewLoadByteSize; NewLoadByteSize <<= 1; @@ -369,14 +325,7 @@ static bool isVolatile(Instruction *Inst) { return false; } - -/// getPointerDependencyFrom - Return the instruction on which a memory -/// location depends. If isLoad is true, this routine ignores may-aliases with -/// read-only operations. If isLoad is false, this routine ignores may-aliases -/// with reads from read-only locations. If possible, pass the query -/// instruction as well; this function may take advantage of the metadata -/// annotated to the query instruction to refine the result. 
-MemDepResult MemoryDependenceAnalysis::getPointerDependencyFrom( +MemDepResult MemoryDependenceResults::getPointerDependencyFrom( const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt, BasicBlock *BB, Instruction *QueryInst) { @@ -393,7 +342,7 @@ MemDepResult MemoryDependenceAnalysis::getPointerDependencyFrom( } MemDepResult -MemoryDependenceAnalysis::getInvariantGroupPointerDependency(LoadInst *LI, +MemoryDependenceResults::getInvariantGroupPointerDependency(LoadInst *LI, BasicBlock *BB) { Value *LoadOperand = LI->getPointerOperand(); // It's is not safe to walk the use list of global value, because function @@ -416,21 +365,19 @@ MemoryDependenceAnalysis::getInvariantGroupPointerDependency(LoadInst *LI, continue; if (auto *BCI = dyn_cast<BitCastInst>(Ptr)) { - if (!Seen.count(BCI->getOperand(0))) { + if (Seen.insert(BCI->getOperand(0)).second) { LoadOperandsQueue.push_back(BCI->getOperand(0)); - Seen.insert(BCI->getOperand(0)); } } for (Use &Us : Ptr->uses()) { auto *U = dyn_cast<Instruction>(Us.getUser()); - if (!U || U == LI || !DT->dominates(U, LI)) + if (!U || U == LI || !DT.dominates(U, LI)) continue; if (auto *BCI = dyn_cast<BitCastInst>(U)) { - if (!Seen.count(BCI)) { + if (Seen.insert(BCI).second) { LoadOperandsQueue.push_back(BCI); - Seen.insert(BCI); } continue; } @@ -445,7 +392,7 @@ MemoryDependenceAnalysis::getInvariantGroupPointerDependency(LoadInst *LI, return Result; } -MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( +MemDepResult MemoryDependenceResults::getSimplePointerDependencyFrom( const MemoryLocation &MemLoc, bool isLoad, BasicBlock::iterator ScanIt, BasicBlock *BB, Instruction *QueryInst) { @@ -455,7 +402,7 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( bool isInvariantLoad = false; // We must be careful with atomic accesses, as they may allow another thread - // to touch this location, cloberring it. We are conservative: if the + // to touch this location, clobbering it. We are conservative: if the // QueryInst is not a simple (non-atomic) memory access, we automatically // return getClobber. // If it is simple, we know based on the results of @@ -476,9 +423,9 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( // with synchronization from 1 to 2 and from 3 to 4, resulting in %val // being 42. A key property of this program however is that if either // 1 or 4 were missing, there would be a race between the store of 42 - // either the store of 0 or the load (making the whole progam racy). + // either the store of 0 or the load (making the whole program racy). // The paper mentioned above shows that the same property is respected - // by every program that can detect any optimisation of that kind: either + // by every program that can detect any optimization of that kind: either // it is racy (undefined) or there is a release followed by an acquire // between the pair of accesses under consideration. @@ -500,13 +447,30 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( // AliasAnalysis::callCapturesBefore. OrderedBasicBlock OBB(BB); + // Return "true" if and only if the instruction I is either a non-simple + // load or a non-simple store. + auto isNonSimpleLoadOrStore = [](Instruction *I) -> bool { + if (auto *LI = dyn_cast<LoadInst>(I)) + return !LI->isSimple(); + if (auto *SI = dyn_cast<StoreInst>(I)) + return !SI->isSimple(); + return false; + }; + + // Return "true" if I is not a load and not a store, but it does access + // memory. 
+ auto isOtherMemAccess = [](Instruction *I) -> bool { + return !isa<LoadInst>(I) && !isa<StoreInst>(I) && I->mayReadOrWriteMemory(); + }; + // Walk backwards through the basic block, looking for dependencies. while (ScanIt != BB->begin()) { Instruction *Inst = &*--ScanIt; if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) // Debug intrinsics don't (and can't) cause dependencies. - if (isa<DbgInfoIntrinsic>(II)) continue; + if (isa<DbgInfoIntrinsic>(II)) + continue; // Limit the amount of scanning we do so we don't end up with quadratic // running time on extreme testcases. @@ -522,17 +486,17 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( // pointer, not on query pointers that are indexed off of them. It'd // be nice to handle that at some point (the right approach is to use // GetPointerBaseWithConstantOffset). - if (AA->isMustAlias(MemoryLocation(II->getArgOperand(1)), MemLoc)) + if (AA.isMustAlias(MemoryLocation(II->getArgOperand(1)), MemLoc)) return MemDepResult::getDef(II); continue; } } - // Values depend on loads if the pointers are must aliased. This means that - // a load depends on another must aliased load from the same value. - // One exception is atomic loads: a value can depend on an atomic load that it - // does not alias with when this atomic load indicates that another thread may - // be accessing the location. + // Values depend on loads if the pointers are must aliased. This means + // that a load depends on another must aliased load from the same value. + // One exception is atomic loads: a value can depend on an atomic load that + // it does not alias with when this atomic load indicates that another + // thread may be accessing the location. if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { // While volatile access cannot be eliminated, they do not have to clobber @@ -547,30 +511,23 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( return MemDepResult::getClobber(LI); // Otherwise, volatile doesn't imply any special ordering } - + // Atomic loads have complications involved. - // A Monotonic (or higher) load is OK if the query inst is itself not atomic. + // A Monotonic (or higher) load is OK if the query inst is itself not + // atomic. // FIXME: This is overly conservative. - if (LI->isAtomic() && LI->getOrdering() > Unordered) { - if (!QueryInst) - return MemDepResult::getClobber(LI); - if (LI->getOrdering() != Monotonic) + if (LI->isAtomic() && isStrongerThanUnordered(LI->getOrdering())) { + if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) || + isOtherMemAccess(QueryInst)) return MemDepResult::getClobber(LI); - if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { - if (!QueryLI->isSimple()) - return MemDepResult::getClobber(LI); - } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { - if (!QuerySI->isSimple()) - return MemDepResult::getClobber(LI); - } else if (QueryInst->mayReadOrWriteMemory()) { + if (LI->getOrdering() != AtomicOrdering::Monotonic) return MemDepResult::getClobber(LI); - } } MemoryLocation LoadLoc = MemoryLocation::get(LI); // If we found a pointer, check if it could be the same as our pointer. - AliasResult R = AA->alias(LoadLoc, MemLoc); + AliasResult R = AA.alias(LoadLoc, MemLoc); if (isLoad) { if (R == NoAlias) { @@ -614,7 +571,7 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( continue; // Stores don't alias loads from read-only memory. 
- if (AA->pointsToConstantMemory(LoadLoc)) + if (AA.pointsToConstantMemory(LoadLoc)) continue; // Stores depend on may/must aliased loads. @@ -625,20 +582,12 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( // Atomic stores have complications involved. // A Monotonic store is OK if the query inst is itself not atomic. // FIXME: This is overly conservative. - if (!SI->isUnordered()) { - if (!QueryInst) - return MemDepResult::getClobber(SI); - if (SI->getOrdering() != Monotonic) + if (!SI->isUnordered() && SI->isAtomic()) { + if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) || + isOtherMemAccess(QueryInst)) return MemDepResult::getClobber(SI); - if (auto *QueryLI = dyn_cast<LoadInst>(QueryInst)) { - if (!QueryLI->isSimple()) - return MemDepResult::getClobber(SI); - } else if (auto *QuerySI = dyn_cast<StoreInst>(QueryInst)) { - if (!QuerySI->isSimple()) - return MemDepResult::getClobber(SI); - } else if (QueryInst->mayReadOrWriteMemory()) { + if (SI->getOrdering() != AtomicOrdering::Monotonic) return MemDepResult::getClobber(SI); - } } // FIXME: this is overly conservative. @@ -646,12 +595,14 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( // non-aliasing locations, as normal accesses can for example be reordered // with volatile accesses. if (SI->isVolatile()) - return MemDepResult::getClobber(SI); + if (!QueryInst || isNonSimpleLoadOrStore(QueryInst) || + isOtherMemAccess(QueryInst)) + return MemDepResult::getClobber(SI); // If alias analysis can tell that this store is guaranteed to not modify // the query pointer, ignore it. Use getModRefInfo to handle cases where // the query pointer points to constant memory etc. - if (AA->getModRefInfo(SI, MemLoc) == MRI_NoModRef) + if (AA.getModRefInfo(SI, MemLoc) == MRI_NoModRef) continue; // Ok, this store might clobber the query pointer. Check to see if it is @@ -659,50 +610,46 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( MemoryLocation StoreLoc = MemoryLocation::get(SI); // If we found a pointer, check if it could be the same as our pointer. - AliasResult R = AA->alias(StoreLoc, MemLoc); + AliasResult R = AA.alias(StoreLoc, MemLoc); if (R == NoAlias) continue; if (R == MustAlias) return MemDepResult::getDef(Inst); if (isInvariantLoad) - continue; + continue; return MemDepResult::getClobber(Inst); } // If this is an allocation, and if we know that the accessed pointer is to // the allocation, return Def. This means that there is no dependence and // the access can be optimized based on that. For example, a load could - // turn into undef. - // Note: Only determine this to be a malloc if Inst is the malloc call, not - // a subsequent bitcast of the malloc call result. There can be stores to - // the malloced memory between the malloc call and its bitcast uses, and we - // need to continue scanning until the malloc call. - if (isa<AllocaInst>(Inst) || isNoAliasFn(Inst, TLI)) { + // turn into undef. Note that we can bypass the allocation itself when + // looking for a clobber in many cases; that's an alias property and is + // handled by BasicAA. + if (isa<AllocaInst>(Inst) || isNoAliasFn(Inst, &TLI)) { const Value *AccessPtr = GetUnderlyingObject(MemLoc.Ptr, DL); - - if (AccessPtr == Inst || AA->isMustAlias(Inst, AccessPtr)) + if (AccessPtr == Inst || AA.isMustAlias(Inst, AccessPtr)) return MemDepResult::getDef(Inst); - if (isInvariantLoad) - continue; - // Be conservative if the accessed pointer may alias the allocation - - // fallback to the generic handling below. 
- if ((AA->alias(Inst, AccessPtr) == NoAlias) && - // If the allocation is not aliased and does not read memory (like - // strdup), it is safe to ignore. - (isa<AllocaInst>(Inst) || isMallocLikeFn(Inst, TLI) || - isCallocLikeFn(Inst, TLI))) - continue; } if (isInvariantLoad) - continue; + continue; + + // A release fence requires that all stores complete before it, but does + // not prevent the reordering of following loads or stores 'before' the + // fence. As a result, we look past it when finding a dependency for + // loads. DSE uses this to find preceding stores to delete and thus we + // can't bypass the fence if the query instruction is a store. + if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + if (isLoad && FI->getOrdering() == AtomicOrdering::Release) + continue; // See if this instruction (e.g. a call or vaarg) mod/ref's the pointer. - ModRefInfo MR = AA->getModRefInfo(Inst, MemLoc); + ModRefInfo MR = AA.getModRefInfo(Inst, MemLoc); // If necessary, perform additional analysis. if (MR == MRI_ModRef) - MR = AA->callCapturesBefore(Inst, MemLoc, DT, &OBB); + MR = AA.callCapturesBefore(Inst, MemLoc, &DT, &OBB); switch (MR) { case MRI_NoModRef: // If the call has no effect on the queried pointer, just ignore it. @@ -727,9 +674,7 @@ MemDepResult MemoryDependenceAnalysis::getSimplePointerDependencyFrom( return MemDepResult::getNonFuncLocal(); } -/// getDependency - Return the instruction on which a memory operation -/// depends. -MemDepResult MemoryDependenceAnalysis::getDependency(Instruction *QueryInst) { +MemDepResult MemoryDependenceResults::getDependency(Instruction *QueryInst) { Instruction *ScanPos = QueryInst; // Check for a cached result @@ -760,7 +705,7 @@ MemDepResult MemoryDependenceAnalysis::getDependency(Instruction *QueryInst) { LocalCache = MemDepResult::getNonFuncLocal(); } else { MemoryLocation MemLoc; - ModRefInfo MR = GetLocation(QueryInst, MemLoc, *TLI); + ModRefInfo MR = GetLocation(QueryInst, MemLoc, TLI); if (MemLoc.Ptr) { // If we can do a pointer scan, make it happen. bool isLoad = !(MR & MRI_Mod); @@ -771,7 +716,7 @@ MemDepResult MemoryDependenceAnalysis::getDependency(Instruction *QueryInst) { MemLoc, isLoad, ScanPos->getIterator(), QueryParent, QueryInst); } else if (isa<CallInst>(QueryInst) || isa<InvokeInst>(QueryInst)) { CallSite QueryCS(QueryInst); - bool isReadOnly = AA->onlyReadsMemory(QueryCS); + bool isReadOnly = AA.onlyReadsMemory(QueryCS); LocalCache = getCallSiteDependencyFrom( QueryCS, isReadOnly, ScanPos->getIterator(), QueryParent); } else @@ -787,40 +732,29 @@ MemDepResult MemoryDependenceAnalysis::getDependency(Instruction *QueryInst) { } #ifndef NDEBUG -/// AssertSorted - This method is used when -debug is specified to verify that -/// cache arrays are properly kept sorted. -static void AssertSorted(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, +/// This method is used when -debug is specified to verify that cache arrays +/// are properly kept sorted. +static void AssertSorted(MemoryDependenceResults::NonLocalDepInfo &Cache, int Count = -1) { - if (Count == -1) Count = Cache.size(); + if (Count == -1) + Count = Cache.size(); assert(std::is_sorted(Cache.begin(), Cache.begin() + Count) && "Cache isn't sorted!"); } #endif -/// getNonLocalCallDependency - Perform a full dependency query for the -/// specified call, returning the set of blocks that the value is -/// potentially live across. The returned set of results will include a -/// "NonLocal" result for all blocks where the value is live across.
-/// -/// This method assumes the instruction returns a "NonLocal" dependency -/// within its own block. -/// -/// This returns a reference to an internal data structure that may be -/// invalidated on the next non-local query or when an instruction is -/// removed. Clients must copy this data if they want it around longer than -/// that. -const MemoryDependenceAnalysis::NonLocalDepInfo & -MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { +const MemoryDependenceResults::NonLocalDepInfo & +MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { assert(getDependency(QueryCS.getInstruction()).isNonLocal() && - "getNonLocalCallDependency should only be used on calls with non-local deps!"); + "getNonLocalCallDependency should only be used on calls with " + "non-local deps!"); PerInstNLInfo &CacheP = NonLocalDeps[QueryCS.getInstruction()]; NonLocalDepInfo &Cache = CacheP.first; - /// DirtyBlocks - This is the set of blocks that need to be recomputed. In - /// the cached case, this can happen due to instructions being deleted etc. In - /// the uncached case, this starts out as the set of predecessors we care - /// about. - SmallVector<BasicBlock*, 32> DirtyBlocks; + // This is the set of blocks that need to be recomputed. In the cached case, + // this can happen due to instructions being deleted etc. In the uncached + // case, this starts out as the set of predecessors we care about. + SmallVector<BasicBlock *, 32> DirtyBlocks; if (!Cache.empty()) { // Okay, we have a cache entry. If we know it is not dirty, just return it @@ -832,16 +766,15 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { // If we already have a partially computed set of results, scan them to // determine what is dirty, seeding our initial DirtyBlocks worklist. - for (NonLocalDepInfo::iterator I = Cache.begin(), E = Cache.end(); - I != E; ++I) - if (I->getResult().isDirty()) - DirtyBlocks.push_back(I->getBB()); + for (auto &Entry : Cache) + if (Entry.getResult().isDirty()) + DirtyBlocks.push_back(Entry.getBB()); // Sort the cache so that we can do fast binary search lookups below. std::sort(Cache.begin(), Cache.end()); ++NumCacheDirtyNonLocal; - //cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " + // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " // << Cache.size() << " cached: " << *QueryInst; } else { // Seed DirtyBlocks with each of the preds of QueryInst's block. @@ -852,9 +785,9 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { } // isReadonlyCall - If this is a read-only call, we can be more aggressive. - bool isReadonlyCall = AA->onlyReadsMemory(QueryCS); + bool isReadonlyCall = AA.onlyReadsMemory(QueryCS); - SmallPtrSet<BasicBlock*, 64> Visited; + SmallPtrSet<BasicBlock *, 32> Visited; unsigned NumSortedEntries = Cache.size(); DEBUG(AssertSorted(Cache)); @@ -872,13 +805,13 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { // the cache set. If so, find it. 
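The lookup that follows is a recurring idiom in this file: binary-search only the sorted prefix of the cache (entries past NumSortedEntries were appended out of order), then step back one slot from upper_bound to test for a hit. A self-contained sketch over plain ints:

#include <algorithm>
#include <iterator>
#include <vector>

// Find Key in the sorted prefix [0, NumSortedEntries) of Cache; newer,
// unsorted entries are deliberately not searched.
static const int *findInSortedPrefix(const std::vector<int> &Cache,
                                     unsigned NumSortedEntries, int Key) {
  auto End = Cache.begin() + NumSortedEntries;
  auto It = std::upper_bound(Cache.begin(), End, Key);
  if (It != Cache.begin() && *std::prev(It) == Key)
    return &*std::prev(It);
  return nullptr;
}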
DEBUG(AssertSorted(Cache, NumSortedEntries)); NonLocalDepInfo::iterator Entry = - std::upper_bound(Cache.begin(), Cache.begin()+NumSortedEntries, - NonLocalDepEntry(DirtyBB)); + std::upper_bound(Cache.begin(), Cache.begin() + NumSortedEntries, + NonLocalDepEntry(DirtyBB)); if (Entry != Cache.begin() && std::prev(Entry)->getBB() == DirtyBB) --Entry; NonLocalDepEntry *ExistingResult = nullptr; - if (Entry != Cache.begin()+NumSortedEntries && + if (Entry != Cache.begin() + NumSortedEntries && Entry->getBB() == DirtyBB) { // If we already have an entry, and if it isn't already dirty, the block // is done. @@ -905,7 +838,8 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { MemDepResult Dep; if (ScanPos != DirtyBB->begin()) { - Dep = getCallSiteDependencyFrom(QueryCS, isReadonlyCall,ScanPos, DirtyBB); + Dep = + getCallSiteDependencyFrom(QueryCS, isReadonlyCall, ScanPos, DirtyBB); } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) { // No dependence found. If this is the entry block of the function, it is // a clobber, otherwise it is unknown. @@ -940,16 +874,8 @@ MemoryDependenceAnalysis::getNonLocalCallDependency(CallSite QueryCS) { return Cache; } -/// getNonLocalPointerDependency - Perform a full dependency query for an -/// access to the specified (non-volatile) memory location, returning the -/// set of instructions that either define or clobber the value. -/// -/// This method assumes the pointer has a "NonLocal" dependency within its -/// own block. -/// -void MemoryDependenceAnalysis:: -getNonLocalPointerDependency(Instruction *QueryInst, - SmallVectorImpl<NonLocalDepResult> &Result) { +void MemoryDependenceResults::getNonLocalPointerDependency( + Instruction *QueryInst, SmallVectorImpl<NonLocalDepResult> &Result) { const MemoryLocation Loc = MemoryLocation::get(QueryInst); bool isLoad = isa<LoadInst>(QueryInst); BasicBlock *FromBB = QueryInst->getParent(); @@ -958,7 +884,7 @@ getNonLocalPointerDependency(Instruction *QueryInst, assert(Loc.Ptr->getType()->isPointerTy() && "Can't get pointer deps of a non-pointer!"); Result.clear(); - + // This routine does not expect to deal with volatile instructions. // Doing so would require piping through the QueryInst all the way through. // TODO: volatiles can't be elided, but they can be reordered with other @@ -976,46 +902,44 @@ getNonLocalPointerDependency(Instruction *QueryInst, return false; }; if (isVolatile(QueryInst) || isOrdered(QueryInst)) { - Result.push_back(NonLocalDepResult(FromBB, - MemDepResult::getUnknown(), + Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), const_cast<Value *>(Loc.Ptr))); return; } const DataLayout &DL = FromBB->getModule()->getDataLayout(); - PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, AC); + PHITransAddr Address(const_cast<Value *>(Loc.Ptr), DL, &AC); // This is the set of blocks we've inspected, and the pointer we consider in // each block. Because of critical edges, we currently bail out if querying // a block with multiple different pointers. This can happen during PHI // translation. 
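Before the per-block machinery continues below, a hedged sketch of how a client pass typically consumes getNonLocalPointerDependency under the new result-object API; MD is assumed to come from the new analysis or the legacy wrapper pass, and the branch bodies are illustrative:

#include "llvm/Analysis/MemoryDependenceAnalysis.h"
using namespace llvm;

// Classify, per block, whether QueryInst's location is defined, clobbered,
// or unknown there; unknown results typically block PRE.
static void inspectNonLocalDeps(MemoryDependenceResults &MD,
                                Instruction *QueryInst) {
  SmallVector<NonLocalDepResult, 16> Deps;
  MD.getNonLocalPointerDependency(QueryInst, Deps);
  for (const NonLocalDepResult &D : Deps) {
    if (D.getResult().isDef()) {
      // Value fully available at D.getResult().getInst() in D.getBB().
    } else if (D.getResult().isClobber()) {
      // Must stop at the clobbering instruction in this block.
    }
  }
}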
- DenseMap<BasicBlock*, Value*> Visited; - if (!getNonLocalPointerDepFromBB(QueryInst, Address, Loc, isLoad, FromBB, + DenseMap<BasicBlock *, Value *> Visited; + if (getNonLocalPointerDepFromBB(QueryInst, Address, Loc, isLoad, FromBB, Result, Visited, true)) return; Result.clear(); - Result.push_back(NonLocalDepResult(FromBB, - MemDepResult::getUnknown(), + Result.push_back(NonLocalDepResult(FromBB, MemDepResult::getUnknown(), const_cast<Value *>(Loc.Ptr))); } -/// GetNonLocalInfoForBlock - Compute the memdep value for BB with -/// Pointer/PointeeSize using either cached information in Cache or by doing a -/// lookup (which may use dirty cache info if available). If we do a lookup, -/// add the result to the cache. -MemDepResult MemoryDependenceAnalysis::GetNonLocalInfoForBlock( +/// Compute the memdep value for BB with Pointer/PointeeSize using either +/// cached information in Cache or by doing a lookup (which may use dirty cache +/// info if available). +/// +/// If we do a lookup, add the result to the cache. +MemDepResult MemoryDependenceResults::GetNonLocalInfoForBlock( Instruction *QueryInst, const MemoryLocation &Loc, bool isLoad, BasicBlock *BB, NonLocalDepInfo *Cache, unsigned NumSortedEntries) { // Do a binary search to see if we already have an entry for this block in // the cache set. If so, find it. - NonLocalDepInfo::iterator Entry = - std::upper_bound(Cache->begin(), Cache->begin()+NumSortedEntries, - NonLocalDepEntry(BB)); - if (Entry != Cache->begin() && (Entry-1)->getBB() == BB) + NonLocalDepInfo::iterator Entry = std::upper_bound( + Cache->begin(), Cache->begin() + NumSortedEntries, NonLocalDepEntry(BB)); + if (Entry != Cache->begin() && (Entry - 1)->getBB() == BB) --Entry; NonLocalDepEntry *ExistingResult = nullptr; - if (Entry != Cache->begin()+NumSortedEntries && Entry->getBB() == BB) + if (Entry != Cache->begin() + NumSortedEntries && Entry->getBB() == BB) ExistingResult = &*Entry; // If we have a cached entry, and it is non-dirty, use it as the value for @@ -1043,8 +967,8 @@ MemDepResult MemoryDependenceAnalysis::GetNonLocalInfoForBlock( } // Scan the block for the dependency. - MemDepResult Dep = getPointerDependencyFrom(Loc, isLoad, ScanPos, BB, - QueryInst); + MemDepResult Dep = + getPointerDependencyFrom(Loc, isLoad, ScanPos, BB, QueryInst); // If we had a dirty entry for the block, update it. Otherwise, just add // a new entry. @@ -1068,11 +992,12 @@ MemDepResult MemoryDependenceAnalysis::GetNonLocalInfoForBlock( return Dep; } -/// SortNonLocalDepInfoCache - Sort the NonLocalDepInfo cache, given a certain -/// number of elements in the array that are already properly ordered. This is -/// optimized for the case when only a few entries are added. +/// Sort the NonLocalDepInfo cache, given a certain number of elements in the +/// array that are already properly ordered. +/// +/// This is optimized for the case when only a few entries are added. static void -SortNonLocalDepInfoCache(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, +SortNonLocalDepInfoCache(MemoryDependenceResults::NonLocalDepInfo &Cache, unsigned NumSortedEntries) { switch (Cache.size() - NumSortedEntries) { case 0: @@ -1082,8 +1007,8 @@ SortNonLocalDepInfoCache(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, // Two new entries, insert the last one into place. 
NonLocalDepEntry Val = Cache.back(); Cache.pop_back(); - MemoryDependenceAnalysis::NonLocalDepInfo::iterator Entry = - std::upper_bound(Cache.begin(), Cache.end()-1, Val); + MemoryDependenceResults::NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache.begin(), Cache.end() - 1, Val); Cache.insert(Entry, Val); // FALL THROUGH. } @@ -1092,8 +1017,8 @@ SortNonLocalDepInfoCache(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, if (Cache.size() != 1) { NonLocalDepEntry Val = Cache.back(); Cache.pop_back(); - MemoryDependenceAnalysis::NonLocalDepInfo::iterator Entry = - std::upper_bound(Cache.begin(), Cache.end(), Val); + MemoryDependenceResults::NonLocalDepInfo::iterator Entry = + std::upper_bound(Cache.begin(), Cache.end(), Val); Cache.insert(Entry, Val); } break; @@ -1104,19 +1029,20 @@ SortNonLocalDepInfoCache(MemoryDependenceAnalysis::NonLocalDepInfo &Cache, } } -/// getNonLocalPointerDepFromBB - Perform a dependency query based on -/// pointer/pointeesize starting at the end of StartBB. Add any clobber/def -/// results to the results vector and keep track of which blocks are visited in -/// 'Visited'. +/// Perform a dependency query based on pointer/pointeesize starting at the end +/// of StartBB. +/// +/// Add any clobber/def results to the results vector and keep track of which +/// blocks are visited in 'Visited'. /// /// This has special behavior for the first block queries (when SkipFirstBlock /// is true). In this special case, it ignores the contents of the specified /// block and starts returning dependence info for its predecessors. /// -/// This function returns false on success, or true to indicate that it could +/// This function returns true on success, or false to indicate that it could /// not compute dependence information for some reason. This should be treated /// as a clobber dependence on the first instruction in the predecessor block. -bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( +bool MemoryDependenceResults::getNonLocalPointerDepFromBB( Instruction *QueryInst, const PHITransAddr &Pointer, const MemoryLocation &Loc, bool isLoad, BasicBlock *StartBB, SmallVectorImpl<NonLocalDepResult> &Result, @@ -1135,7 +1061,7 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // Get the NLPI for CacheKey, inserting one into the map if it doesn't // already have one. std::pair<CachedNonLocalPointerInfo::iterator, bool> Pair = - NonLocalPointerDeps.insert(std::make_pair(CacheKey, InitialNLPI)); + NonLocalPointerDeps.insert(std::make_pair(CacheKey, InitialNLPI)); NonLocalPointerInfo *CacheInfo = &Pair.first->second; // If we already have a cache entry for this CacheKey, we may need to do some @@ -1146,18 +1072,16 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // cached data and proceed with the query at the greater size. CacheInfo->Pair = BBSkipFirstBlockPair(); CacheInfo->Size = Loc.Size; - for (NonLocalDepInfo::iterator DI = CacheInfo->NonLocalDeps.begin(), - DE = CacheInfo->NonLocalDeps.end(); DI != DE; ++DI) - if (Instruction *Inst = DI->getResult().getInst()) + for (auto &Entry : CacheInfo->NonLocalDeps) + if (Instruction *Inst = Entry.getResult().getInst()) RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); CacheInfo->NonLocalDeps.clear(); } else if (CacheInfo->Size > Loc.Size) { // This query's Size is less than the cached one. Conservatively restart // the query using the greater size. 
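SortNonLocalDepInfoCache's switch is a small adaptive sort: for one or two appended entries, popping each and re-inserting it at its upper_bound position beats re-sorting the whole array. An equivalent standalone sketch over ints:

#include <algorithm>
#include <vector>

static void sortFewAppended(std::vector<int> &Cache,
                            unsigned NumSortedEntries) {
  unsigned NumNew = Cache.size() - NumSortedEntries;
  if (NumNew == 0)
    return; // Already fully sorted.
  if (NumNew <= 2) {
    // Re-insert each appended entry at its sorted position.
    while (NumNew--) {
      int Val = Cache.back();
      Cache.pop_back();
      Cache.insert(std::upper_bound(Cache.begin(), Cache.end(), Val), Val);
    }
    return;
  }
  // Too many new entries for insertion to pay off: full sort.
  std::sort(Cache.begin(), Cache.end());
}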
- return getNonLocalPointerDepFromBB(QueryInst, Pointer, - Loc.getWithNewSize(CacheInfo->Size), - isLoad, StartBB, Result, Visited, - SkipFirstBlock); + return getNonLocalPointerDepFromBB( + QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad, + StartBB, Result, Visited, SkipFirstBlock); } // If the query's AATags are inconsistent with the cached one, @@ -1167,17 +1091,15 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( if (CacheInfo->AATags) { CacheInfo->Pair = BBSkipFirstBlockPair(); CacheInfo->AATags = AAMDNodes(); - for (NonLocalDepInfo::iterator DI = CacheInfo->NonLocalDeps.begin(), - DE = CacheInfo->NonLocalDeps.end(); DI != DE; ++DI) - if (Instruction *Inst = DI->getResult().getInst()) + for (auto &Entry : CacheInfo->NonLocalDeps) + if (Instruction *Inst = Entry.getResult().getInst()) RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); CacheInfo->NonLocalDeps.clear(); } if (Loc.AATags) - return getNonLocalPointerDepFromBB(QueryInst, - Pointer, Loc.getWithoutAATags(), - isLoad, StartBB, Result, Visited, - SkipFirstBlock); + return getNonLocalPointerDepFromBB( + QueryInst, Pointer, Loc.getWithoutAATags(), isLoad, StartBB, Result, + Visited, SkipFirstBlock); } } @@ -1192,37 +1114,33 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // to ensure that if a block in the results set is in the visited set that // it was for the same pointer query. if (!Visited.empty()) { - for (NonLocalDepInfo::iterator I = Cache->begin(), E = Cache->end(); - I != E; ++I) { - DenseMap<BasicBlock*, Value*>::iterator VI = Visited.find(I->getBB()); + for (auto &Entry : *Cache) { + DenseMap<BasicBlock *, Value *>::iterator VI = + Visited.find(Entry.getBB()); if (VI == Visited.end() || VI->second == Pointer.getAddr()) continue; - // We have a pointer mismatch in a block. Just return clobber, saying + // We have a pointer mismatch in a block. Just return false, saying // that something was clobbered in this result. We could also do a // non-fully cached query, but there is little point in doing this. - return true; + return false; } } Value *Addr = Pointer.getAddr(); - for (NonLocalDepInfo::iterator I = Cache->begin(), E = Cache->end(); - I != E; ++I) { - Visited.insert(std::make_pair(I->getBB(), Addr)); - if (I->getResult().isNonLocal()) { + for (auto &Entry : *Cache) { + Visited.insert(std::make_pair(Entry.getBB(), Addr)); + if (Entry.getResult().isNonLocal()) { continue; } - if (!DT) { - Result.push_back(NonLocalDepResult(I->getBB(), - MemDepResult::getUnknown(), - Addr)); - } else if (DT->isReachableFromEntry(I->getBB())) { - Result.push_back(NonLocalDepResult(I->getBB(), I->getResult(), Addr)); + if (DT.isReachableFromEntry(Entry.getBB())) { + Result.push_back( + NonLocalDepResult(Entry.getBB(), Entry.getResult(), Addr)); } } ++NumCacheCompleteNonLocalPtr; - return false; + return true; } // Otherwise, either this is a new block, a block with an invalid cache @@ -1234,11 +1152,11 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( else CacheInfo->Pair = BBSkipFirstBlockPair(); - SmallVector<BasicBlock*, 32> Worklist; + SmallVector<BasicBlock *, 32> Worklist; Worklist.push_back(StartBB); // PredList used inside loop. - SmallVector<std::pair<BasicBlock*, PHITransAddr>, 16> PredList; + SmallVector<std::pair<BasicBlock *, PHITransAddr>, 16> PredList; // Keep track of the entries that we know are sorted. Previously cached // entries will all be sorted. 
The entries we add we only sort on demand (we @@ -1246,6 +1164,8 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // won't get any reuse from currently inserted values, because we don't // revisit blocks after we insert info for them. unsigned NumSortedEntries = Cache->size(); + unsigned WorklistEntries = BlockNumberLimit; + bool GotWorklistLimit = false; DEBUG(AssertSorted(*Cache)); while (!Worklist.empty()) { @@ -1266,7 +1186,7 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // specific block queries) but we can't do the fastpath "return all // results from the set". Clear out the indicator for this. CacheInfo->Pair = BBSkipFirstBlockPair(); - return true; + return false; } // Skip the first block if we have it. @@ -1278,18 +1198,12 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // Get the dependency info for Pointer in BB. If we have cached // information, we will use it, otherwise we compute it. DEBUG(AssertSorted(*Cache, NumSortedEntries)); - MemDepResult Dep = GetNonLocalInfoForBlock(QueryInst, - Loc, isLoad, BB, Cache, - NumSortedEntries); + MemDepResult Dep = GetNonLocalInfoForBlock(QueryInst, Loc, isLoad, BB, + Cache, NumSortedEntries); // If we got a Def or Clobber, add this to the list of results. if (!Dep.isNonLocal()) { - if (!DT) { - Result.push_back(NonLocalDepResult(BB, - MemDepResult::getUnknown(), - Pointer.getAddr())); - continue; - } else if (DT->isReachableFromEntry(BB)) { + if (DT.isReachableFromEntry(BB)) { Result.push_back(NonLocalDepResult(BB, Dep, Pointer.getAddr())); continue; } @@ -1302,11 +1216,11 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // the same Pointer. if (!Pointer.NeedsPHITranslationFromBlock(BB)) { SkipFirstBlock = false; - SmallVector<BasicBlock*, 16> NewBlocks; + SmallVector<BasicBlock *, 16> NewBlocks; for (BasicBlock *Pred : PredCache.get(BB)) { // Verify that we haven't looked at this block yet. - std::pair<DenseMap<BasicBlock*,Value*>::iterator, bool> - InsertRes = Visited.insert(std::make_pair(Pred, Pointer.getAddr())); + std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes = + Visited.insert(std::make_pair(Pred, Pointer.getAddr())); if (InsertRes.second) { // First time we've looked at *PI. NewBlocks.push_back(Pred); @@ -1324,6 +1238,15 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( goto PredTranslationFailure; } } + if (NewBlocks.size() > WorklistEntries) { + // Make sure to clean up the Visited map before continuing on to + // PredTranslationFailure. + for (unsigned i = 0; i < NewBlocks.size(); i++) + Visited.erase(NewBlocks[i]); + GotWorklistLimit = true; + goto PredTranslationFailure; + } + WorklistEntries -= NewBlocks.size(); Worklist.append(NewBlocks.begin(), NewBlocks.end()); continue; } @@ -1351,7 +1274,7 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // Get the PHI translated pointer in this predecessor. This can fail if // not translatable, in which case the getAddr() returns null. 
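The new WorklistEntries counter (seeded from BlockNumberLimit) caps how many blocks a single query may enqueue; once exhausted, the query bails out to a conservative answer via GotWorklistLimit. A generic sketch of the budgeting pattern, with BlockT standing in for BasicBlock:

#include <vector>

// Returns false when the budget is exhausted; the caller should then
// record a conservative "unknown" result instead of growing the search.
template <typename BlockT>
static bool enqueueWithBudget(std::vector<BlockT *> &Worklist,
                              const std::vector<BlockT *> &NewBlocks,
                              unsigned &Budget) {
  if (NewBlocks.size() > Budget)
    return false;
  Budget -= NewBlocks.size();
  Worklist.insert(Worklist.end(), NewBlocks.begin(), NewBlocks.end());
  return true;
}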
PHITransAddr &PredPointer = PredList.back().second; - PredPointer.PHITranslateValue(BB, Pred, DT, /*MustDominate=*/false); + PredPointer.PHITranslateValue(BB, Pred, &DT, /*MustDominate=*/false); Value *PredPtrVal = PredPointer.getAddr(); // Check to see if we have already visited this pred block with another @@ -1359,8 +1282,8 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // with PHI translation when a critical edge exists and the PHI node in // the successor translates to a pointer value different than the // pointer the block was first analyzed with. - std::pair<DenseMap<BasicBlock*,Value*>::iterator, bool> - InsertRes = Visited.insert(std::make_pair(Pred, PredPtrVal)); + std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> InsertRes = + Visited.insert(std::make_pair(Pred, PredPtrVal)); if (!InsertRes.second) { // We found the pred; take it off the list of preds to visit. @@ -1411,10 +1334,9 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // result conflicted with the Visited list; we have to conservatively // assume it is unknown, but this also does not block PRE of the load. if (!CanTranslate || - getNonLocalPointerDepFromBB(QueryInst, PredPointer, - Loc.getWithNewPtr(PredPtrVal), - isLoad, Pred, - Result, Visited)) { + !getNonLocalPointerDepFromBB(QueryInst, PredPointer, + Loc.getWithNewPtr(PredPtrVal), isLoad, + Pred, Result, Visited)) { // Add the entry to the Result list. NonLocalDepResult Entry(Pred, MemDepResult::getUnknown(), PredPtrVal); Result.push_back(Entry); @@ -1467,35 +1389,38 @@ bool MemoryDependenceAnalysis::getNonLocalPointerDepFromBB( // incoming value. Since we can't phi translate to one of the predecessors, // we have to bail out. if (SkipFirstBlock) - return true; + return false; - for (NonLocalDepInfo::reverse_iterator I = Cache->rbegin(); ; ++I) { - assert(I != Cache->rend() && "Didn't find current block??"); - if (I->getBB() != BB) + bool foundBlock = false; + for (NonLocalDepEntry &I : llvm::reverse(*Cache)) { + if (I.getBB() != BB) continue; - assert((I->getResult().isNonLocal() || !DT->isReachableFromEntry(BB)) && + assert((GotWorklistLimit || I.getResult().isNonLocal() || + !DT.isReachableFromEntry(BB)) && "Should only be here with transparent block"); - I->setResult(MemDepResult::getUnknown()); - Result.push_back(NonLocalDepResult(I->getBB(), I->getResult(), - Pointer.getAddr())); + foundBlock = true; + I.setResult(MemDepResult::getUnknown()); + Result.push_back( + NonLocalDepResult(I.getBB(), I.getResult(), Pointer.getAddr())); break; } + (void)foundBlock; (void)GotWorklistLimit; + assert((foundBlock || GotWorklistLimit) && "Current block not in cache?"); } // Okay, we're done now. If we added new values to the cache, re-sort it. SortNonLocalDepInfoCache(*Cache, NumSortedEntries); DEBUG(AssertSorted(*Cache)); - return false; + return true; } -/// RemoveCachedNonLocalPointerDependencies - If P exists in -/// CachedNonLocalPointerInfo, remove it. -void MemoryDependenceAnalysis:: -RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair P) { - CachedNonLocalPointerInfo::iterator It = - NonLocalPointerDeps.find(P); - if (It == NonLocalPointerDeps.end()) return; +/// If P exists in CachedNonLocalPointerInfo, remove it. +void MemoryDependenceResults::RemoveCachedNonLocalPointerDependencies( + ValueIsLoadPair P) { + CachedNonLocalPointerInfo::iterator It = NonLocalPointerDeps.find(P); + if (It == NonLocalPointerDeps.end()) + return; // Remove all of the entries in the BB->val map. 
This involves removing // instructions from the reverse map. @@ -1503,7 +1428,8 @@ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair P) { for (unsigned i = 0, e = PInfo.size(); i != e; ++i) { Instruction *Target = PInfo[i].getResult().getInst(); - if (!Target) continue; // Ignore non-local dep results. + if (!Target) + continue; // Ignore non-local dep results. assert(Target->getParent() == PInfo[i].getBB()); // Eliminating the dirty entry from 'Cache', so update the reverse info. @@ -1514,41 +1440,28 @@ RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair P) { NonLocalPointerDeps.erase(It); } - -/// invalidateCachedPointerInfo - This method is used to invalidate cached -/// information about the specified pointer, because it may be too -/// conservative in memdep. This is an optional call that can be used when -/// the client detects an equivalence between the pointer and some other -/// value and replaces the other value with ptr. This can make Ptr available -/// in more places that cached info does not necessarily keep. -void MemoryDependenceAnalysis::invalidateCachedPointerInfo(Value *Ptr) { +void MemoryDependenceResults::invalidateCachedPointerInfo(Value *Ptr) { // If Ptr isn't really a pointer, just ignore it. - if (!Ptr->getType()->isPointerTy()) return; + if (!Ptr->getType()->isPointerTy()) + return; // Flush store info for the pointer. RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, false)); // Flush load info for the pointer. RemoveCachedNonLocalPointerDependencies(ValueIsLoadPair(Ptr, true)); } -/// invalidateCachedPredecessors - Clear the PredIteratorCache info. -/// This needs to be done when the CFG changes, e.g., due to splitting -/// critical edges. -void MemoryDependenceAnalysis::invalidateCachedPredecessors() { +void MemoryDependenceResults::invalidateCachedPredecessors() { PredCache.clear(); } -/// removeInstruction - Remove an instruction from the dependence analysis, -/// updating the dependence of instructions that previously depended on it. -/// This method attempts to keep the cache coherent using the reverse map. -void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { +void MemoryDependenceResults::removeInstruction(Instruction *RemInst) { // Walk through the Non-local dependencies, removing this one as the value // for any cached queries. NonLocalDepMapType::iterator NLDI = NonLocalDeps.find(RemInst); if (NLDI != NonLocalDeps.end()) { NonLocalDepInfo &BlockMap = NLDI->second.first; - for (NonLocalDepInfo::iterator DI = BlockMap.begin(), DE = BlockMap.end(); - DI != DE; ++DI) - if (Instruction *Inst = DI->getResult().getInst()) + for (auto &Entry : BlockMap) + if (Instruction *Inst = Entry.getResult().getInst()) RemoveFromReverseMap(ReverseNonLocalDeps, Inst, RemInst); NonLocalDeps.erase(NLDI); } @@ -1578,7 +1491,7 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { // Loop over all of the things that depend on the instruction we're removing. // - SmallVector<std::pair<Instruction*, Instruction*>, 8> ReverseDepsToAdd; + SmallVector<std::pair<Instruction *, Instruction *>, 8> ReverseDepsToAdd; // If we find RemInst as a clobber or Def in any of the maps for other values, // we need to replace its entry with a dirty version of the instruction after @@ -1603,10 +1516,11 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { LocalDeps[InstDependingOnRemInst] = NewDirtyVal; // Make sure to remember that new things depend on NewDepInst. 
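removeInstruction keeps the caches coherent by pairing every forward map with a reverse map, so deleting an instruction can dirty exactly the queries that named it. A toy model of that invariant (Deps plays LocalDeps, Reverse plays ReverseLocalDeps, ints stand in for instructions):

#include <cassert>
#include <map>
#include <set>

static void removeValue(std::map<int, int> &Deps,
                        std::map<int, std::set<int>> &Reverse, int X) {
  auto It = Reverse.find(X);
  if (It == Reverse.end())
    return;
  for (int Query : It->second) {
    assert(Deps.count(Query) && Deps[Query] == X && "reverse map stale");
    Deps.erase(Query); // Stand-in for marking the entry dirty.
  }
  Reverse.erase(It);
}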
- assert(NewDirtyVal.getInst() && "There is no way something else can have " + assert(NewDirtyVal.getInst() && + "There is no way something else can have " "a local dep on this if it is a terminator!"); - ReverseDepsToAdd.push_back(std::make_pair(NewDirtyVal.getInst(), - InstDependingOnRemInst)); + ReverseDepsToAdd.push_back( + std::make_pair(NewDirtyVal.getInst(), InstDependingOnRemInst)); } ReverseLocalDeps.erase(ReverseDepIt); @@ -1614,8 +1528,8 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { // Add new reverse deps after scanning the set, to avoid invalidating the // 'ReverseDeps' reference. while (!ReverseDepsToAdd.empty()) { - ReverseLocalDeps[ReverseDepsToAdd.back().first] - .insert(ReverseDepsToAdd.back().second); + ReverseLocalDeps[ReverseDepsToAdd.back().first].insert( + ReverseDepsToAdd.back().second); ReverseDepsToAdd.pop_back(); } } @@ -1629,12 +1543,12 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { // The information is now dirty! INLD.second = true; - for (NonLocalDepInfo::iterator DI = INLD.first.begin(), - DE = INLD.first.end(); DI != DE; ++DI) { - if (DI->getResult().getInst() != RemInst) continue; + for (auto &Entry : INLD.first) { + if (Entry.getResult().getInst() != RemInst) + continue; // Convert to a dirty entry for the subsequent instruction. - DI->setResult(NewDirtyVal); + Entry.setResult(NewDirtyVal); if (Instruction *NextI = NewDirtyVal.getInst()) ReverseDepsToAdd.push_back(std::make_pair(NextI, I)); @@ -1645,8 +1559,8 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { // Add new reverse deps after scanning the set, to avoid invalidating 'Set' while (!ReverseDepsToAdd.empty()) { - ReverseNonLocalDeps[ReverseDepsToAdd.back().first] - .insert(ReverseDepsToAdd.back().second); + ReverseNonLocalDeps[ReverseDepsToAdd.back().first].insert( + ReverseDepsToAdd.back().second); ReverseDepsToAdd.pop_back(); } } @@ -1654,9 +1568,10 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { // If the instruction is in ReverseNonLocalPtrDeps then it appears as a // value in the NonLocalPointerDeps info. ReverseNonLocalPtrDepTy::iterator ReversePtrDepIt = - ReverseNonLocalPtrDeps.find(RemInst); + ReverseNonLocalPtrDeps.find(RemInst); if (ReversePtrDepIt != ReverseNonLocalPtrDeps.end()) { - SmallVector<std::pair<Instruction*, ValueIsLoadPair>,8> ReversePtrDepsToAdd; + SmallVector<std::pair<Instruction *, ValueIsLoadPair>, 8> + ReversePtrDepsToAdd; for (ValueIsLoadPair P : ReversePtrDepIt->second) { assert(P.getPointer() != RemInst && @@ -1668,12 +1583,12 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { NonLocalPointerDeps[P].Pair = BBSkipFirstBlockPair(); // Update any entries for RemInst to use the instruction after it. - for (NonLocalDepInfo::iterator DI = NLPDI.begin(), DE = NLPDI.end(); - DI != DE; ++DI) { - if (DI->getResult().getInst() != RemInst) continue; + for (auto &Entry : NLPDI) { + if (Entry.getResult().getInst() != RemInst) + continue; // Convert to a dirty entry for the subsequent instruction. 
- DI->setResult(NewDirtyVal); + Entry.setResult(NewDirtyVal); if (Instruction *NewDirtyInst = NewDirtyVal.getInst()) ReversePtrDepsToAdd.push_back(std::make_pair(NewDirtyInst, P)); @@ -1687,70 +1602,107 @@ void MemoryDependenceAnalysis::removeInstruction(Instruction *RemInst) { ReverseNonLocalPtrDeps.erase(ReversePtrDepIt); while (!ReversePtrDepsToAdd.empty()) { - ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first] - .insert(ReversePtrDepsToAdd.back().second); + ReverseNonLocalPtrDeps[ReversePtrDepsToAdd.back().first].insert( + ReversePtrDepsToAdd.back().second); ReversePtrDepsToAdd.pop_back(); } } - assert(!NonLocalDeps.count(RemInst) && "RemInst got reinserted?"); DEBUG(verifyRemoved(RemInst)); } -/// verifyRemoved - Verify that the specified instruction does not occur -/// in our internal data structures. This function verifies by asserting in -/// debug builds. -void MemoryDependenceAnalysis::verifyRemoved(Instruction *D) const { + +/// Verify that the specified instruction does not occur in our internal data +/// structures. +/// +/// This function verifies by asserting in debug builds. +void MemoryDependenceResults::verifyRemoved(Instruction *D) const { #ifndef NDEBUG - for (LocalDepMapType::const_iterator I = LocalDeps.begin(), - E = LocalDeps.end(); I != E; ++I) { - assert(I->first != D && "Inst occurs in data structures"); - assert(I->second.getInst() != D && - "Inst occurs in data structures"); + for (const auto &DepKV : LocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + assert(DepKV.second.getInst() != D && "Inst occurs in data structures"); } - for (CachedNonLocalPointerInfo::const_iterator I =NonLocalPointerDeps.begin(), - E = NonLocalPointerDeps.end(); I != E; ++I) { - assert(I->first.getPointer() != D && "Inst occurs in NLPD map key"); - const NonLocalDepInfo &Val = I->second.NonLocalDeps; - for (NonLocalDepInfo::const_iterator II = Val.begin(), E = Val.end(); - II != E; ++II) - assert(II->getResult().getInst() != D && "Inst occurs as NLPD value"); + for (const auto &DepKV : NonLocalPointerDeps) { + assert(DepKV.first.getPointer() != D && "Inst occurs in NLPD map key"); + for (const auto &Entry : DepKV.second.NonLocalDeps) + assert(Entry.getResult().getInst() != D && "Inst occurs as NLPD value"); } - for (NonLocalDepMapType::const_iterator I = NonLocalDeps.begin(), - E = NonLocalDeps.end(); I != E; ++I) { - assert(I->first != D && "Inst occurs in data structures"); - const PerInstNLInfo &INLD = I->second; - for (NonLocalDepInfo::const_iterator II = INLD.first.begin(), - EE = INLD.first.end(); II != EE; ++II) - assert(II->getResult().getInst() != D && "Inst occurs in data structures"); + for (const auto &DepKV : NonLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + const PerInstNLInfo &INLD = DepKV.second; + for (const auto &Entry : INLD.first) + assert(Entry.getResult().getInst() != D && + "Inst occurs in data structures"); } - for (ReverseDepMapType::const_iterator I = ReverseLocalDeps.begin(), - E = ReverseLocalDeps.end(); I != E; ++I) { - assert(I->first != D && "Inst occurs in data structures"); - for (Instruction *Inst : I->second) + for (const auto &DepKV : ReverseLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + for (Instruction *Inst : DepKV.second) assert(Inst != D && "Inst occurs in data structures"); } - for (ReverseDepMapType::const_iterator I = ReverseNonLocalDeps.begin(), - E = ReverseNonLocalDeps.end(); - I != E; ++I) { - assert(I->first != D && "Inst occurs in data 
structures"); - for (Instruction *Inst : I->second) + for (const auto &DepKV : ReverseNonLocalDeps) { + assert(DepKV.first != D && "Inst occurs in data structures"); + for (Instruction *Inst : DepKV.second) assert(Inst != D && "Inst occurs in data structures"); } - for (ReverseNonLocalPtrDepTy::const_iterator - I = ReverseNonLocalPtrDeps.begin(), - E = ReverseNonLocalPtrDeps.end(); I != E; ++I) { - assert(I->first != D && "Inst occurs in rev NLPD map"); + for (const auto &DepKV : ReverseNonLocalPtrDeps) { + assert(DepKV.first != D && "Inst occurs in rev NLPD map"); - for (ValueIsLoadPair P : I->second) - assert(P != ValueIsLoadPair(D, false) && - P != ValueIsLoadPair(D, true) && + for (ValueIsLoadPair P : DepKV.second) + assert(P != ValueIsLoadPair(D, false) && P != ValueIsLoadPair(D, true) && "Inst occurs in ReverseNonLocalPtrDeps map"); } #endif } + +char MemoryDependenceAnalysis::PassID; + +MemoryDependenceResults +MemoryDependenceAnalysis::run(Function &F, AnalysisManager<Function> &AM) { + auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + return MemoryDependenceResults(AA, AC, TLI, DT); +} + +char MemoryDependenceWrapperPass::ID = 0; + +INITIALIZE_PASS_BEGIN(MemoryDependenceWrapperPass, "memdep", + "Memory Dependence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(MemoryDependenceWrapperPass, "memdep", + "Memory Dependence Analysis", false, true) + +MemoryDependenceWrapperPass::MemoryDependenceWrapperPass() : FunctionPass(ID) { + initializeMemoryDependenceWrapperPassPass(*PassRegistry::getPassRegistry()); +} +MemoryDependenceWrapperPass::~MemoryDependenceWrapperPass() {} + +void MemoryDependenceWrapperPass::releaseMemory() { + MemDep.reset(); +} + +void MemoryDependenceWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequiredTransitive<AAResultsWrapperPass>(); + AU.addRequiredTransitive<TargetLibraryInfoWrapperPass>(); +} + +bool MemoryDependenceWrapperPass::runOnFunction(Function &F) { + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + MemDep.emplace(AA, AC, TLI, DT); + return false; +} diff --git a/lib/Analysis/MemoryLocation.cpp b/lib/Analysis/MemoryLocation.cpp index e4491261e055..a0ae72f1415f 100644 --- a/lib/Analysis/MemoryLocation.cpp +++ b/lib/Analysis/MemoryLocation.cpp @@ -90,23 +90,6 @@ MemoryLocation MemoryLocation::getForDest(const MemIntrinsic *MTI) { return MemoryLocation(MTI->getRawDest(), Size, AATags); } -// FIXME: This code is duplicated with BasicAliasAnalysis and should be hoisted -// to some common utility location. 
-static bool isMemsetPattern16(const Function *MS, - const TargetLibraryInfo &TLI) { - if (TLI.has(LibFunc::memset_pattern16) && - MS->getName() == "memset_pattern16") { - FunctionType *MemsetType = MS->getFunctionType(); - if (!MemsetType->isVarArg() && MemsetType->getNumParams() == 3 && - isa<PointerType>(MemsetType->getParamType(0)) && - isa<PointerType>(MemsetType->getParamType(1)) && - isa<IntegerType>(MemsetType->getParamType(2))) - return true; - } - - return false; -} - MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, unsigned ArgIdx, const TargetLibraryInfo &TLI) { @@ -159,8 +142,9 @@ MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, // for memcpy/memset. This is particularly important because the // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 // whenever possible. - if (CS.getCalledFunction() && - isMemsetPattern16(CS.getCalledFunction(), TLI)) { + LibFunc::Func F; + if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && + F == LibFunc::memset_pattern16 && TLI.has(F)) { assert((ArgIdx == 0 || ArgIdx == 1) && "Invalid argument index for memset_pattern16"); if (ArgIdx == 1) diff --git a/lib/Analysis/ModuleSummaryAnalysis.cpp b/lib/Analysis/ModuleSummaryAnalysis.cpp new file mode 100644 index 000000000000..c9ac2bdb7942 --- /dev/null +++ b/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -0,0 +1,249 @@ +//===- ModuleSummaryAnalysis.cpp - Module summary index builder -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass builds a ModuleSummaryIndex object for the module, to be written +// to bitcode or LLVM assembly. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ModuleSummaryAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/IndirectCallPromotionAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/ValueSymbolTable.h" +#include "llvm/Pass.h" +using namespace llvm; + +#define DEBUG_TYPE "module-summary-analysis" + +// Walk through the operands of a given User via worklist iteration and populate +// the set of GlobalValue references encountered. Invoked either on an +// Instruction or a GlobalVariable (which walks its initializer). +static void findRefEdges(const User *CurUser, DenseSet<const Value *> &RefEdges, + SmallPtrSet<const User *, 8> &Visited) { + SmallVector<const User *, 32> Worklist; + Worklist.push_back(CurUser); + + while (!Worklist.empty()) { + const User *U = Worklist.pop_back_val(); + + if (!Visited.insert(U).second) + continue; + + ImmutableCallSite CS(U); + + for (const auto &OI : U->operands()) { + const User *Operand = dyn_cast<User>(OI); + if (!Operand) + continue; + if (isa<BlockAddress>(Operand)) + continue; + if (isa<GlobalValue>(Operand)) { + // We have a reference to a global value. This should be added to + // the reference set unless it is a callee. Callees are handled + // specially by WriteFunction and are added to a separate list. 
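findRefEdges, added just above, is a classic operand-graph walk: pop a user from the worklist, record global-value operands as reference edges, and push everything else for transitive inspection. The same shape on a toy node type (Node and IsGlobal are stand-ins for llvm::User operands and isa<GlobalValue>):

#include <set>
#include <vector>

struct Node {
  bool IsGlobal = false;
  std::vector<Node *> Operands;
};

static void findRefs(Node *Root, std::set<Node *> &Refs) {
  std::set<Node *> Visited;
  std::vector<Node *> Worklist{Root};
  while (!Worklist.empty()) {
    Node *U = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(U).second)
      continue; // Already processed.
    for (Node *Op : U->Operands) {
      if (Op->IsGlobal)
        Refs.insert(Op); // Record the reference edge.
      else
        Worklist.push_back(Op); // Walk through constants and the like.
    }
  }
}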
+ if (!(CS && CS.isCallee(&OI))) + RefEdges.insert(Operand); + continue; + } + Worklist.push_back(Operand); + } + } +} + +void ModuleSummaryIndexBuilder::computeFunctionSummary( + const Function &F, BlockFrequencyInfo *BFI) { + // Summary not currently supported for anonymous functions, they must + // be renamed. + if (!F.hasName()) + return; + + unsigned NumInsts = 0; + // Map from callee ValueId to profile count. Used to accumulate profile + // counts for all static calls to a given callee. + DenseMap<const Value *, CalleeInfo> CallGraphEdges; + DenseMap<GlobalValue::GUID, CalleeInfo> IndirectCallEdges; + DenseSet<const Value *> RefEdges; + ICallPromotionAnalysis ICallAnalysis; + + SmallPtrSet<const User *, 8> Visited; + for (const BasicBlock &BB : F) + for (const Instruction &I : BB) { + if (!isa<DbgInfoIntrinsic>(I)) + ++NumInsts; + + if (auto CS = ImmutableCallSite(&I)) { + auto *CalledFunction = CS.getCalledFunction(); + // Check if this is a direct call to a known function. + if (CalledFunction) { + if (CalledFunction->hasName() && !CalledFunction->isIntrinsic()) { + auto ScaledCount = BFI ? BFI->getBlockProfileCount(&BB) : None; + auto *CalleeId = + M->getValueSymbolTable().lookup(CalledFunction->getName()); + CallGraphEdges[CalleeId] += + (ScaledCount ? ScaledCount.getValue() : 0); + } + } else { + // Otherwise, check for an indirect call (call to a non-const value + // that isn't an inline assembly call). + const CallInst *CI = dyn_cast<CallInst>(&I); + if (CS.getCalledValue() && !isa<Constant>(CS.getCalledValue()) && + !(CI && CI->isInlineAsm())) { + uint32_t NumVals, NumCandidates; + uint64_t TotalCount; + auto CandidateProfileData = + ICallAnalysis.getPromotionCandidatesForInstruction( + &I, NumVals, TotalCount, NumCandidates); + for (auto &Candidate : CandidateProfileData) + IndirectCallEdges[Candidate.Value] += Candidate.Count; + } + } + } + findRefEdges(&I, RefEdges, Visited); + } + + GlobalValueSummary::GVFlags Flags(F); + std::unique_ptr<FunctionSummary> FuncSummary = + llvm::make_unique<FunctionSummary>(Flags, NumInsts); + FuncSummary->addCallGraphEdges(CallGraphEdges); + FuncSummary->addCallGraphEdges(IndirectCallEdges); + FuncSummary->addRefEdges(RefEdges); + Index->addGlobalValueSummary(F.getName(), std::move(FuncSummary)); +} + +void ModuleSummaryIndexBuilder::computeVariableSummary( + const GlobalVariable &V) { + DenseSet<const Value *> RefEdges; + SmallPtrSet<const User *, 8> Visited; + findRefEdges(&V, RefEdges, Visited); + GlobalValueSummary::GVFlags Flags(V); + std::unique_ptr<GlobalVarSummary> GVarSummary = + llvm::make_unique<GlobalVarSummary>(Flags); + GVarSummary->addRefEdges(RefEdges); + Index->addGlobalValueSummary(V.getName(), std::move(GVarSummary)); +} + +ModuleSummaryIndexBuilder::ModuleSummaryIndexBuilder( + const Module *M, + std::function<BlockFrequencyInfo *(const Function &F)> Ftor) + : Index(llvm::make_unique<ModuleSummaryIndex>()), M(M) { + // Check if the module can be promoted, otherwise just disable importing from + // it by not emitting any summary. + // FIXME: we could still import *into* it most of the time. + if (!moduleCanBeRenamedForThinLTO(*M)) + return; + + // Compute summaries for all functions defined in module, and save in the + // index. 
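In computeFunctionSummary above, several static calls to the same callee fold their optional block profile counts into a single edge; a call site without profile data still creates the edge, contributing zero. Sketched with std::map standing in for the DenseMap<const Value *, CalleeInfo>:

#include <cstdint>
#include <map>
#include <string>

// Mirrors "CallGraphEdges[CalleeId] += (ScaledCount ? ... : 0)"; a null
// ScaledCount models a block without profile information.
static void recordCall(std::map<std::string, uint64_t> &Edges,
                       const std::string &Callee,
                       const uint64_t *ScaledCount) {
  Edges[Callee] += ScaledCount ? *ScaledCount : 0;
}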
+ for (auto &F : *M) { + if (F.isDeclaration()) + continue; + + BlockFrequencyInfo *BFI = nullptr; + std::unique_ptr<BlockFrequencyInfo> BFIPtr; + if (Ftor) + BFI = Ftor(F); + else if (F.getEntryCount().hasValue()) { + LoopInfo LI{DominatorTree(const_cast<Function &>(F))}; + BranchProbabilityInfo BPI{F, LI}; + BFIPtr = llvm::make_unique<BlockFrequencyInfo>(F, BPI, LI); + BFI = BFIPtr.get(); + } + + computeFunctionSummary(F, BFI); + } + + // Compute summaries for all variables defined in module, and save in the + // index. + for (const GlobalVariable &G : M->globals()) { + if (G.isDeclaration()) + continue; + computeVariableSummary(G); + } +} + +char ModuleSummaryIndexWrapperPass::ID = 0; +INITIALIZE_PASS_BEGIN(ModuleSummaryIndexWrapperPass, "module-summary-analysis", + "Module Summary Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_END(ModuleSummaryIndexWrapperPass, "module-summary-analysis", + "Module Summary Analysis", false, true) + +ModulePass *llvm::createModuleSummaryIndexWrapperPass() { + return new ModuleSummaryIndexWrapperPass(); +} + +ModuleSummaryIndexWrapperPass::ModuleSummaryIndexWrapperPass() + : ModulePass(ID) { + initializeModuleSummaryIndexWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) { + IndexBuilder = llvm::make_unique<ModuleSummaryIndexBuilder>( + &M, [this](const Function &F) { + return &(this->getAnalysis<BlockFrequencyInfoWrapperPass>( + *const_cast<Function *>(&F)) + .getBFI()); + }); + return false; +} + +bool ModuleSummaryIndexWrapperPass::doFinalization(Module &M) { + IndexBuilder.reset(); + return false; +} + +void ModuleSummaryIndexWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.setPreservesAll(); + AU.addRequired<BlockFrequencyInfoWrapperPass>(); +} + +bool llvm::moduleCanBeRenamedForThinLTO(const Module &M) { + // We cannot currently promote or rename anything used in inline assembly, + // which are not visible to the compiler. Detect a possible case by looking + // for a llvm.used local value, in conjunction with an inline assembly call + // in the module. Prevent importing of any modules containing these uses by + // suppressing generation of the index. This also prevents importing + // into this module, which is also necessary to avoid needing to rename + // in case of a name clash between a local in this module and an imported + // global. + // FIXME: If we find we need a finer-grained approach of preventing promotion + // and renaming of just the functions using inline assembly we will need to: + // - Add flag in the function summaries to identify those with inline asm. + // - Prevent importing of any functions with flag set. + // - Prevent importing of any global function with the same name as a + // function in current module that has the flag set. + // - For any llvm.used value that is exported and promoted, add a private + // alias to the original name in the current module (even if we don't + // export the function using those values in inline asm, another function + // with a reference could be exported). 
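When no BFI getter is supplied and the function has an entry count, the builder computes block frequencies from scratch, chaining dominators into loops into branch probabilities into frequencies. A sketch of that chain; treat the exact constructor signatures of this era as an assumption:

#include "llvm/Analysis/BlockFrequencyInfo.h"
#include "llvm/Analysis/BranchProbabilityInfo.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

// Each analysis feeds the next; nothing is cached, so this is only worth
// doing when the entry count says the numbers will actually be used.
static BlockFrequencyInfo computeBFI(Function &F) {
  DominatorTree DT(F);
  LoopInfo LI(DT);
  BranchProbabilityInfo BPI(F, LI);
  return BlockFrequencyInfo(F, BPI, LI);
}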
+ SmallPtrSet<GlobalValue *, 8> Used; + collectUsedGlobalVariables(M, Used, /*CompilerUsed*/ false); + bool LocalIsUsed = + llvm::any_of(Used, [](GlobalValue *V) { return V->hasLocalLinkage(); }); + if (!LocalIsUsed) + return true; + + // Walk all the instructions in the module and find if one is inline ASM + auto HasInlineAsm = llvm::any_of(M, [](const Function &F) { + return llvm::any_of(instructions(F), [](const Instruction &I) { + const CallInst *CallI = dyn_cast<CallInst>(&I); + if (!CallI) + return false; + return CallI->isInlineAsm(); + }); + }); + return !HasInlineAsm; +} diff --git a/lib/Analysis/ObjCARCAliasAnalysis.cpp b/lib/Analysis/ObjCARCAliasAnalysis.cpp index 25f660ffe221..9bb1048ea8ba 100644 --- a/lib/Analysis/ObjCARCAliasAnalysis.cpp +++ b/lib/Analysis/ObjCARCAliasAnalysis.cpp @@ -131,19 +131,13 @@ ModRefInfo ObjCARCAAResult::getModRefInfo(ImmutableCallSite CS, return AAResultBase::getModRefInfo(CS, Loc); } -ObjCARCAAResult ObjCARCAA::run(Function &F, AnalysisManager<Function> *AM) { - return ObjCARCAAResult(F.getParent()->getDataLayout(), - AM->getResult<TargetLibraryAnalysis>(F)); +ObjCARCAAResult ObjCARCAA::run(Function &F, AnalysisManager<Function> &AM) { + return ObjCARCAAResult(F.getParent()->getDataLayout()); } -char ObjCARCAA::PassID; - char ObjCARCAAWrapperPass::ID = 0; -INITIALIZE_PASS_BEGIN(ObjCARCAAWrapperPass, "objc-arc-aa", - "ObjC-ARC-Based Alias Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ObjCARCAAWrapperPass, "objc-arc-aa", - "ObjC-ARC-Based Alias Analysis", false, true) +INITIALIZE_PASS(ObjCARCAAWrapperPass, "objc-arc-aa", + "ObjC-ARC-Based Alias Analysis", false, true) ImmutablePass *llvm::createObjCARCAAWrapperPass() { return new ObjCARCAAWrapperPass(); @@ -154,8 +148,7 @@ ObjCARCAAWrapperPass::ObjCARCAAWrapperPass() : ImmutablePass(ID) { } bool ObjCARCAAWrapperPass::doInitialization(Module &M) { - Result.reset(new ObjCARCAAResult( - M.getDataLayout(), getAnalysis<TargetLibraryInfoWrapperPass>().getTLI())); + Result.reset(new ObjCARCAAResult(M.getDataLayout())); return false; } @@ -166,5 +159,4 @@ bool ObjCARCAAWrapperPass::doFinalization(Module &M) { void ObjCARCAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); } diff --git a/lib/Analysis/ObjCARCInstKind.cpp b/lib/Analysis/ObjCARCInstKind.cpp index 133b63513c87..3dc1463b8d8b 100644 --- a/lib/Analysis/ObjCARCInstKind.cpp +++ b/lib/Analysis/ObjCARCInstKind.cpp @@ -34,6 +34,8 @@ raw_ostream &llvm::objcarc::operator<<(raw_ostream &OS, return OS << "ARCInstKind::Retain"; case ARCInstKind::RetainRV: return OS << "ARCInstKind::RetainRV"; + case ARCInstKind::ClaimRV: + return OS << "ARCInstKind::ClaimRV"; case ARCInstKind::RetainBlock: return OS << "ARCInstKind::RetainBlock"; case ARCInstKind::Release: @@ -103,6 +105,8 @@ ARCInstKind llvm::objcarc::GetFunctionClass(const Function *F) { return StringSwitch<ARCInstKind>(F->getName()) .Case("objc_retain", ARCInstKind::Retain) .Case("objc_retainAutoreleasedReturnValue", ARCInstKind::RetainRV) + .Case("objc_unsafeClaimAutoreleasedReturnValue", + ARCInstKind::ClaimRV) .Case("objc_retainBlock", ARCInstKind::RetainBlock) .Case("objc_release", ARCInstKind::Release) .Case("objc_autorelease", ARCInstKind::Autorelease) @@ -350,6 +354,7 @@ bool llvm::objcarc::IsUser(ARCInstKind Class) { case ARCInstKind::StoreStrong: case ARCInstKind::Call: case ARCInstKind::None: + case ARCInstKind::ClaimRV: return false; } llvm_unreachable("covered 
switch isn't covered?"); @@ -385,6 +390,7 @@ bool llvm::objcarc::IsRetain(ARCInstKind Class) { case ARCInstKind::Call: case ARCInstKind::User: case ARCInstKind::None: + case ARCInstKind::ClaimRV: return false; } llvm_unreachable("covered switch isn't covered?"); @@ -398,6 +404,7 @@ bool llvm::objcarc::IsAutorelease(ARCInstKind Class) { return true; case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::RetainBlock: case ARCInstKind::Release: case ARCInstKind::AutoreleasepoolPush: @@ -429,6 +436,7 @@ bool llvm::objcarc::IsForwarding(ARCInstKind Class) { switch (Class) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::Autorelease: case ARCInstKind::AutoreleaseRV: case ARCInstKind::NoopCast: @@ -463,6 +471,7 @@ bool llvm::objcarc::IsNoopOnNull(ARCInstKind Class) { switch (Class) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::Release: case ARCInstKind::Autorelease: case ARCInstKind::AutoreleaseRV: @@ -498,6 +507,7 @@ bool llvm::objcarc::IsAlwaysTail(ARCInstKind Class) { switch (Class) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::AutoreleaseRV: return true; case ARCInstKind::Release: @@ -538,6 +548,7 @@ bool llvm::objcarc::IsNeverTail(ARCInstKind Class) { return true; case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::AutoreleaseRV: case ARCInstKind::Release: case ARCInstKind::RetainBlock: @@ -572,6 +583,7 @@ bool llvm::objcarc::IsNoThrow(ARCInstKind Class) { switch (Class) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::Release: case ARCInstKind::Autorelease: case ARCInstKind::AutoreleaseRV: @@ -616,6 +628,7 @@ bool llvm::objcarc::CanInterruptRV(ARCInstKind Class) { return true; case ARCInstKind::Retain: case ARCInstKind::RetainRV: + case ARCInstKind::ClaimRV: case ARCInstKind::Release: case ARCInstKind::AutoreleasepoolPush: case ARCInstKind::RetainBlock: @@ -668,6 +681,7 @@ bool llvm::objcarc::CanDecrementRefCount(ARCInstKind Kind) { case ARCInstKind::StoreStrong: case ARCInstKind::CallOrUser: case ARCInstKind::Call: + case ARCInstKind::ClaimRV: return true; } diff --git a/lib/Analysis/OptimizationDiagnosticInfo.cpp b/lib/Analysis/OptimizationDiagnosticInfo.cpp new file mode 100644 index 000000000000..e979ba2531e4 --- /dev/null +++ b/lib/Analysis/OptimizationDiagnosticInfo.cpp @@ -0,0 +1,88 @@ +//===- OptimizationDiagnosticInfo.cpp - Optimization Diagnostic -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Optimization diagnostic interfaces. It's packaged as an analysis pass so +// that by using this service passes become dependent on BFI as well. BFI is +// used to compute the "hotness" of the diagnostic message. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/LazyBlockFrequencyInfo.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/LLVMContext.h" + +using namespace llvm; + +Optional<uint64_t> OptimizationRemarkEmitter::computeHotness(Value *V) { + if (!BFI) + return None; + + return BFI->getBlockProfileCount(cast<BasicBlock>(V)); +} + +void OptimizationRemarkEmitter::emitOptimizationRemarkMissed( + const char *PassName, const DebugLoc &DLoc, Value *V, const Twine &Msg) { + LLVMContext &Ctx = F->getContext(); + Ctx.diagnose(DiagnosticInfoOptimizationRemarkMissed(PassName, *F, DLoc, Msg, + computeHotness(V))); +} + +void OptimizationRemarkEmitter::emitOptimizationRemarkMissed( + const char *PassName, Loop *L, const Twine &Msg) { + emitOptimizationRemarkMissed(PassName, L->getStartLoc(), L->getHeader(), Msg); +} + +OptimizationRemarkEmitterWrapperPass::OptimizationRemarkEmitterWrapperPass() + : FunctionPass(ID) { + initializeOptimizationRemarkEmitterWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +bool OptimizationRemarkEmitterWrapperPass::runOnFunction(Function &Fn) { + BlockFrequencyInfo *BFI; + + if (Fn.getContext().getDiagnosticHotnessRequested()) + BFI = &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI(); + else + BFI = nullptr; + + ORE = llvm::make_unique<OptimizationRemarkEmitter>(&Fn, BFI); + return false; +} + +void OptimizationRemarkEmitterWrapperPass::getAnalysisUsage( + AnalysisUsage &AU) const { + LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); + AU.setPreservesAll(); +} + +char OptimizationRemarkEmitterAnalysis::PassID; + +OptimizationRemarkEmitter +OptimizationRemarkEmitterAnalysis::run(Function &F, AnalysisManager<Function> &AM) { + BlockFrequencyInfo *BFI; + + if (F.getContext().getDiagnosticHotnessRequested()) + BFI = &AM.getResult<BlockFrequencyAnalysis>(F); + else + BFI = nullptr; + + return OptimizationRemarkEmitter(&F, BFI); +} + +char OptimizationRemarkEmitterWrapperPass::ID = 0; +static const char ore_name[] = "Optimization Remark Emitter"; +#define ORE_NAME "opt-remark-emitter" + +INITIALIZE_PASS_BEGIN(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name, + false, true) +INITIALIZE_PASS_DEPENDENCY(LazyBFIPass) +INITIALIZE_PASS_END(OptimizationRemarkEmitterWrapperPass, ORE_NAME, ore_name, + false, true) diff --git a/lib/Analysis/PHITransAddr.cpp b/lib/Analysis/PHITransAddr.cpp index f7545ea05a39..b4aad74d50dc 100644 --- a/lib/Analysis/PHITransAddr.cpp +++ b/lib/Analysis/PHITransAddr.cpp @@ -42,7 +42,7 @@ static bool CanPHITrans(Instruction *Inst) { } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void PHITransAddr::dump() const { +LLVM_DUMP_METHOD void PHITransAddr::dump() const { if (!Addr) { dbgs() << "PHITransAddr: null\n"; return; @@ -229,7 +229,8 @@ Value *PHITransAddr::PHITranslateSubExpr(Value *V, BasicBlock *CurBB, return GEP; // Simplify the GEP to handle 'gep x, 0' -> x etc. 
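// (Editor's aside, not part of the patch: the fold mentioned above rewrites,
// e.g., `%p = getelementptr i32, i32* %x, i64 0` into plain `%x`. The hunk
// below also starts passing the GEP's source element type to SimplifyGEPInst
// explicitly instead of leaving it to be rediscovered from the base pointer's
// pointee type.)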
- if (Value *V = SimplifyGEPInst(GEPOps, DL, TLI, DT, AC)) { + if (Value *V = SimplifyGEPInst(GEP->getSourceElementType(), + GEPOps, DL, TLI, DT, AC)) { for (unsigned i = 0, e = GEPOps.size(); i != e; ++i) RemoveInstInputs(GEPOps[i], InstInputs); diff --git a/lib/Analysis/PostDominators.cpp b/lib/Analysis/PostDominators.cpp index 6d929091e3d2..73550805d5ba 100644 --- a/lib/Analysis/PostDominators.cpp +++ b/lib/Analysis/PostDominators.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/SetOperations.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" #include "llvm/Support/Debug.h" #include "llvm/Support/GenericDomTreeConstruction.h" using namespace llvm; @@ -26,25 +27,39 @@ using namespace llvm; // PostDominatorTree Implementation //===----------------------------------------------------------------------===// -char PostDominatorTree::ID = 0; -INITIALIZE_PASS(PostDominatorTree, "postdomtree", +char PostDominatorTreeWrapperPass::ID = 0; +INITIALIZE_PASS(PostDominatorTreeWrapperPass, "postdomtree", "Post-Dominator Tree Construction", true, true) -bool PostDominatorTree::runOnFunction(Function &F) { - DT->recalculate(F); +bool PostDominatorTreeWrapperPass::runOnFunction(Function &F) { + DT.recalculate(F); return false; } -PostDominatorTree::~PostDominatorTree() { - delete DT; +void PostDominatorTreeWrapperPass::print(raw_ostream &OS, const Module *) const { + DT.print(OS); } -void PostDominatorTree::print(raw_ostream &OS, const Module *) const { - DT->print(OS); +FunctionPass* llvm::createPostDomTree() { + return new PostDominatorTreeWrapperPass(); } +char PostDominatorTreeAnalysis::PassID; -FunctionPass* llvm::createPostDomTree() { - return new PostDominatorTree(); +PostDominatorTree PostDominatorTreeAnalysis::run(Function &F, + FunctionAnalysisManager &) { + PostDominatorTree PDT; + PDT.recalculate(F); + return PDT; } +PostDominatorTreePrinterPass::PostDominatorTreePrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses +PostDominatorTreePrinterPass::run(Function &F, FunctionAnalysisManager &AM) { + OS << "PostDominatorTree for function: " << F.getName() << "\n"; + AM.getResult<PostDominatorTreeAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} diff --git a/lib/Analysis/ProfileSummaryInfo.cpp b/lib/Analysis/ProfileSummaryInfo.cpp new file mode 100644 index 000000000000..9cf99af49581 --- /dev/null +++ b/lib/Analysis/ProfileSummaryInfo.cpp @@ -0,0 +1,166 @@ +//===- ProfileSummaryInfo.cpp - Global profile summary information --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains a pass that provides access to the global profile summary +// information. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ProfileSummary.h" +using namespace llvm; + +// The following two parameters determine the threshold for a count to be +// considered hot/cold. These two parameters are percentile values (multiplied +// by 10000). If the counts are sorted in descending order, the minimum count to +// reach ProfileSummaryCutoffHot gives the threshold to determine a hot count. 
+// Similarly, the minimum count to reach ProfileSummaryCutoffCold gives the +// threshold for determining cold count (everything <= this threshold is +// considered cold). + +static cl::opt<int> ProfileSummaryCutoffHot( + "profile-summary-cutoff-hot", cl::Hidden, cl::init(999000), cl::ZeroOrMore, + cl::desc("A count is hot if it exceeds the minimum count to" + " reach this percentile of total counts.")); + +static cl::opt<int> ProfileSummaryCutoffCold( + "profile-summary-cutoff-cold", cl::Hidden, cl::init(999999), cl::ZeroOrMore, + cl::desc("A count is cold if it is below the minimum count" + " to reach this percentile of total counts.")); + +// Find the minimum count to reach a desired percentile of counts. +static uint64_t getMinCountForPercentile(SummaryEntryVector &DS, + uint64_t Percentile) { + auto Compare = [](const ProfileSummaryEntry &Entry, uint64_t Percentile) { + return Entry.Cutoff < Percentile; + }; + auto It = std::lower_bound(DS.begin(), DS.end(), Percentile, Compare); + // The required percentile has to be <= one of the percentiles in the + // detailed summary. + if (It == DS.end()) + report_fatal_error("Desired percentile exceeds the maximum cutoff"); + return It->MinCount; +} + +// The profile summary metadata may be attached either by the frontend or by +// any backend passes (IR level instrumentation, for example). This method +// checks if the Summary is null and if so checks if the summary metadata is now +// available in the module and parses it to get the Summary object. +void ProfileSummaryInfo::computeSummary() { + if (Summary) + return; + auto *SummaryMD = M.getProfileSummary(); + if (!SummaryMD) + return; + Summary.reset(ProfileSummary::getFromMD(SummaryMD)); +} + +// Returns true if the function is a hot function. If it returns false, it +// either means it is not hot or it is unknown whether F is hot or not (for +// example, no profile data is available). +bool ProfileSummaryInfo::isHotFunction(const Function *F) { + computeSummary(); + if (!F || !Summary) + return false; + auto FunctionCount = F->getEntryCount(); + // FIXME: The heuristic used below for determining hotness is based on + // preliminary SPEC tuning for inliner. This will eventually be a + // convenience method that calls isHotCount. + return (FunctionCount && + FunctionCount.getValue() >= + (uint64_t)(0.3 * (double)Summary->getMaxFunctionCount())); +} + +// Returns true if the function is a cold function. If it returns false, it +// either means it is not cold or it is unknown whether F is cold or not (for +// example, no profile data is available). +bool ProfileSummaryInfo::isColdFunction(const Function *F) { + computeSummary(); + if (!F) + return false; + if (F->hasFnAttribute(Attribute::Cold)) { + return true; + } + if (!Summary) + return false; + auto FunctionCount = F->getEntryCount(); + // FIXME: The heuristic used below for determining coldness is based on + // preliminary SPEC tuning for inliner. This will eventually be a + // convenience method that calls isColdCount. + return (FunctionCount && + FunctionCount.getValue() <= + (uint64_t)(0.01 * (double)Summary->getMaxFunctionCount())); +} + +// Compute the hot and cold thresholds.
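As a reader aid (an editorial sketch, not part of the patch): the detailed summary consulted by getMinCountForPercentile above is a vector of entries sorted by ascending cutoff, so looking up a threshold is a single std::lower_bound. A self-contained approximation, with hypothetical numbers:

  #include <algorithm>
  #include <cstdint>
  #include <vector>

  struct Entry { uint64_t Cutoff, MinCount; };    // sorted by ascending Cutoff
  static const std::vector<Entry> DS = {
      {990000, 500}, {999000, 120}, {999999, 3}}; // made-up summary data

  static uint64_t minCountFor(uint64_t Percentile) {
    auto It = std::lower_bound(
        DS.begin(), DS.end(), Percentile,
        [](const Entry &E, uint64_t P) { return E.Cutoff < P; });
    return It->MinCount; // assumes Percentile <= the maximum cutoff
  }

  // With these numbers, minCountFor(999000) == 120, so counts >= 120 are
  // treated as hot, and minCountFor(999999) == 3, so counts <= 3 are cold.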
+void ProfileSummaryInfo::computeThresholds() { + if (!Summary) + computeSummary(); + if (!Summary) + return; + auto &DetailedSummary = Summary->getDetailedSummary(); + HotCountThreshold = + getMinCountForPercentile(DetailedSummary, ProfileSummaryCutoffHot); + ColdCountThreshold = + getMinCountForPercentile(DetailedSummary, ProfileSummaryCutoffCold); +} + +bool ProfileSummaryInfo::isHotCount(uint64_t C) { + if (!HotCountThreshold) + computeThresholds(); + return HotCountThreshold && C >= HotCountThreshold.getValue(); +} + +bool ProfileSummaryInfo::isColdCount(uint64_t C) { + if (!ColdCountThreshold) + computeThresholds(); + return ColdCountThreshold && C <= ColdCountThreshold.getValue(); +} + +ProfileSummaryInfo *ProfileSummaryInfoWrapperPass::getPSI(Module &M) { + if (!PSI) + PSI.reset(new ProfileSummaryInfo(M)); + return PSI.get(); +} + +INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", + "Profile summary info", false, true) + +ProfileSummaryInfoWrapperPass::ProfileSummaryInfoWrapperPass() + : ImmutablePass(ID) { + initializeProfileSummaryInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +char ProfileSummaryAnalysis::PassID; +ProfileSummaryInfo ProfileSummaryAnalysis::run(Module &M, + ModuleAnalysisManager &) { + return ProfileSummaryInfo(M); +} + +// FIXME: This only tests isHotFunction and isColdFunction and not the +// isHotCount and isColdCount calls. +PreservedAnalyses ProfileSummaryPrinterPass::run(Module &M, + AnalysisManager<Module> &AM) { + ProfileSummaryInfo &PSI = AM.getResult<ProfileSummaryAnalysis>(M); + + OS << "Functions in " << M.getName() << " with hot/cold annotations: \n"; + for (auto &F : M) { + OS << F.getName(); + if (PSI.isHotFunction(&F)) + OS << " :hot "; + else if (PSI.isColdFunction(&F)) + OS << " :cold "; + OS << "\n"; + } + return PreservedAnalyses::all(); +} + +char ProfileSummaryInfoWrapperPass::ID = 0; diff --git a/lib/Analysis/RegionInfo.cpp b/lib/Analysis/RegionInfo.cpp index f59d26730327..6860a3e63953 100644 --- a/lib/Analysis/RegionInfo.cpp +++ b/lib/Analysis/RegionInfo.cpp @@ -15,12 +15,10 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfoImpl.h" #include "llvm/Analysis/RegionIterator.h" +#include "llvm/IR/PassManager.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" -#include <algorithm> -#include <iterator> -#include <set> #ifndef NDEBUG #include "llvm/Analysis/RegionPrinter.h" #endif @@ -128,8 +126,8 @@ bool RegionInfoPass::runOnFunction(Function &F) { releaseMemory(); auto DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - auto PDT = &getAnalysis<PostDominatorTree>(); - auto DF = &getAnalysis<DominanceFrontier>(); + auto PDT = &getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + auto DF = &getAnalysis<DominanceFrontierWrapperPass>().getDominanceFrontier(); RI.recalculate(F, DT, PDT, DF); return false; @@ -146,8 +144,8 @@ void RegionInfoPass::verifyAnalysis() const { void RegionInfoPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequiredTransitive<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTree>(); - AU.addRequired<DominanceFrontier>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<DominanceFrontierWrapperPass>(); } void RegionInfoPass::print(raw_ostream &OS, const Module *) const { @@ -155,7 +153,7 @@ void RegionInfoPass::print(raw_ostream &OS, const Module *) const { } #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) -void 
RegionInfoPass::dump() const { +LLVM_DUMP_METHOD void RegionInfoPass::dump() const { RI.dump(); } #endif @@ -165,8 +163,8 @@ char RegionInfoPass::ID = 0; INITIALIZE_PASS_BEGIN(RegionInfoPass, "regions", "Detect single entry single exit regions", true, true) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTree) -INITIALIZE_PASS_DEPENDENCY(DominanceFrontier) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominanceFrontierWrapperPass) INITIALIZE_PASS_END(RegionInfoPass, "regions", "Detect single entry single exit regions", true, true) @@ -180,3 +178,36 @@ namespace llvm { } } +//===----------------------------------------------------------------------===// +// RegionInfoAnalysis implementation +// + +char RegionInfoAnalysis::PassID; + +RegionInfo RegionInfoAnalysis::run(Function &F, AnalysisManager<Function> &AM) { + RegionInfo RI; + auto *DT = &AM.getResult<DominatorTreeAnalysis>(F); + auto *PDT = &AM.getResult<PostDominatorTreeAnalysis>(F); + auto *DF = &AM.getResult<DominanceFrontierAnalysis>(F); + + RI.recalculate(F, DT, PDT, DF); + return RI; +} + +RegionInfoPrinterPass::RegionInfoPrinterPass(raw_ostream &OS) + : OS(OS) {} + +PreservedAnalyses RegionInfoPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "Region Tree for function: " << F.getName() << "\n"; + AM.getResult<RegionInfoAnalysis>(F).print(OS); + + return PreservedAnalyses::all(); +} + +PreservedAnalyses RegionInfoVerifierPass::run(Function &F, + AnalysisManager<Function> &AM) { + AM.getResult<RegionInfoAnalysis>(F).verifyAnalysis(); + + return PreservedAnalyses::all(); +} diff --git a/lib/Analysis/RegionPrinter.cpp b/lib/Analysis/RegionPrinter.cpp index acb218d5fea0..30a4e011060e 100644 --- a/lib/Analysis/RegionPrinter.cpp +++ b/lib/Analysis/RegionPrinter.cpp @@ -117,8 +117,8 @@ struct DOTGraphTraits<RegionInfo *> : public DOTGraphTraits<RegionNode *> { << ((R.getDepth() * 2 % 12) + 2) << "\n"; } - for (Region::const_iterator RI = R.begin(), RE = R.end(); RI != RE; ++RI) - printRegionCluster(**RI, GW, depth + 1); + for (const auto &RI : R) + printRegionCluster(*RI, GW, depth + 1); const RegionInfo &RI = *static_cast<const RegionInfo*>(R.getRegionInfo()); diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index ef1bb3a36c8d..2abbf3480358 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -111,10 +111,14 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, "derived loop"), cl::init(100)); -// FIXME: Enable this with XDEBUG when the test suite is clean. +// FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. 
static cl::opt<bool> VerifySCEV("verify-scev", cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); +static cl::opt<bool> + VerifySCEVMap("verify-scev-maps", + cl::desc("Verify no dangling value in ScalarEvolution's " + "ExprValueMap (slow)")); //===----------------------------------------------------------------------===// // SCEV class definitions @@ -162,11 +166,11 @@ void SCEV::print(raw_ostream &OS) const { for (unsigned i = 1, e = AR->getNumOperands(); i != e; ++i) OS << ",+," << *AR->getOperand(i); OS << "}<"; - if (AR->getNoWrapFlags(FlagNUW)) + if (AR->hasNoUnsignedWrap()) OS << "nuw><"; - if (AR->getNoWrapFlags(FlagNSW)) + if (AR->hasNoSignedWrap()) OS << "nsw><"; - if (AR->getNoWrapFlags(FlagNW) && + if (AR->hasNoSelfWrap() && !AR->getNoWrapFlags((NoWrapFlags)(FlagNUW | FlagNSW))) OS << "nw><"; AR->getLoop()->getHeader()->printAsOperand(OS, /*PrintType=*/false); @@ -196,9 +200,9 @@ void SCEV::print(raw_ostream &OS) const { switch (NAry->getSCEVType()) { case scAddExpr: case scMulExpr: - if (NAry->getNoWrapFlags(FlagNUW)) + if (NAry->hasNoUnsignedWrap()) OS << "<nuw>"; - if (NAry->getNoWrapFlags(FlagNSW)) + if (NAry->hasNoSignedWrap()) OS << "<nsw>"; } return; @@ -283,8 +287,6 @@ bool SCEV::isAllOnesValue() const { return false; } -/// isNonConstantNegative - Return true if the specified scev is negated, but -/// not a constant. bool SCEV::isNonConstantNegative() const { const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(this); if (!Mul) return false; @@ -620,10 +622,10 @@ public: }; } // end anonymous namespace -/// GroupByComplexity - Given a list of SCEV objects, order them by their -/// complexity, and group objects of the same complexity together by value. -/// When this routine is finished, we know that any duplicates in the vector are -/// consecutive and that complexity is monotonically increasing. +/// Given a list of SCEV objects, order them by their complexity, and group +/// objects of the same complexity together by value. When this routine is +/// finished, we know that any duplicates in the vector are consecutive and that +/// complexity is monotonically increasing. /// /// Note that we take special precautions to ensure that we get deterministic /// results from this routine. In other words, we don't want the results of @@ -723,7 +725,7 @@ public: } // Split the Denominator when it is a product. - if (const SCEVMulExpr *T = dyn_cast<const SCEVMulExpr>(Denominator)) { + if (const SCEVMulExpr *T = dyn_cast<SCEVMulExpr>(Denominator)) { const SCEV *Q, *R; *Quotient = Numerator; for (const SCEV *Op : T->operands()) { @@ -922,8 +924,7 @@ private: // Simple SCEV method implementations //===----------------------------------------------------------------------===// -/// BinomialCoefficient - Compute BC(It, K). The result has width W. -/// Assume, K > 0. +/// Compute BC(It, K). The result has width W. Assume K > 0. static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, ScalarEvolution &SE, Type *ResultTy) { @@ -1034,10 +1035,10 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, SE.getTruncateOrZeroExtend(DivResult, ResultTy)); } -/// evaluateAtIteration - Return the value of this chain of recurrences at -/// the specified iteration number. We can evaluate this recurrence by -/// multiplying each element in the chain by the binomial coefficient -/// corresponding to it. In other words, we can evaluate {A,+,B,+,C,+,D} as: +/// Return the value of this chain of recurrences at the specified iteration +/// number.
We can evaluate this recurrence by multiplying each element in the +/// chain by the binomial coefficient corresponding to it. In other words, we +/// can evaluate {A,+,B,+,C,+,D} as: /// /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// @@ -1450,9 +1451,14 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + if (!AR->hasNoUnsignedWrap()) { + auto NewFlags = proveNoWrapViaConstantRanges(AR); + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); + } + // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->getNoWrapFlags(SCEV::FlagNUW)) + if (AR->hasNoUnsignedWrap()) return getAddRecExpr( getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this), getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); @@ -1512,11 +1518,22 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + } - // If the backedge is guarded by a comparison with the pre-inc value - // the addrec is safe. Also, if the entry is guarded by a comparison - // with the start value and the backedge is guarded by a comparison - // with the post-inc value, the addrec is safe. + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { + // If the backedge is guarded by a comparison with the pre-inc + // value the addrec is safe. Also, if the entry is guarded by + // a comparison with the start value and the backedge is + // guarded by a comparison with the post-inc value, the addrec + // is safe. if (isKnownPositive(Step)) { const SCEV *N = getConstant(APInt::getMinValue(BitWidth) - getUnsignedRange(Step).getUnsignedMax()); @@ -1524,7 +1541,8 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR->getPostIncExpr(*this), N))) { - // Cache knowledge of AR NUW, which is propagated to this AddRec. + // Cache knowledge of AR NUW, which is propagated to this + // AddRec. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW); // Return the expression with the addrec on the outside. return getAddRecExpr( @@ -1538,8 +1556,9 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) && isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR->getPostIncExpr(*this), N))) { - // Cache knowledge of AR NW, which is propagated to this AddRec. - // Negative step causes unsigned wrap, but it still can't self-wrap. + // Cache knowledge of AR NW, which is propagated to this + // AddRec. Negative step causes unsigned wrap, but it + // still can't self-wrap. const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW); // Return the expression with the addrec on the outside. 
return getAddRecExpr( @@ -1559,7 +1578,7 @@ const SCEV *ScalarEvolution::getZeroExtendExpr(const SCEV *Op, if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) { // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw> - if (SA->getNoWrapFlags(SCEV::FlagNUW)) { + if (SA->hasNoUnsignedWrap()) { // If the addition does not unsign overflow then we can, by definition, // commute the zero extension with the addition operation. SmallVector<const SCEV *, 4> Ops; @@ -1608,10 +1627,6 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, void *IP = nullptr; if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; - // If the input value is provably positive, build a zext instead. - if (isKnownNonNegative(Op)) - return getZeroExtendExpr(Op, Ty); - // sext(trunc(x)) --> sext(x) or x or trunc(x) if (const SCEVTruncateExpr *ST = dyn_cast<SCEVTruncateExpr>(Op)) { // It's possible the bits taken off by the truncate were all sign bits. If @@ -1643,7 +1658,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> - if (SA->getNoWrapFlags(SCEV::FlagNSW)) { + if (SA->hasNoSignedWrap()) { // If the addition does not sign overflow then we can, by definition, // commute the sign extension with the addition operation. SmallVector<const SCEV *, 4> Ops; @@ -1663,9 +1678,14 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, unsigned BitWidth = getTypeSizeInBits(AR->getType()); const Loop *L = AR->getLoop(); + if (!AR->hasNoSignedWrap()) { + auto NewFlags = proveNoWrapViaConstantRanges(AR); + const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags); + } + // If we have special knowledge that this addrec won't overflow, // we don't need to do any further analysis. - if (AR->getNoWrapFlags(SCEV::FlagNSW)) + if (AR->hasNoSignedWrap()) return getAddRecExpr( getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this), getSignExtendExpr(Step, Ty), L, SCEV::FlagNSW); @@ -1732,11 +1752,23 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getZeroExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + } - // If the backedge is guarded by a comparison with the pre-inc value - // the addrec is safe. Also, if the entry is guarded by a comparison - // with the start value and the backedge is guarded by a comparison - // with the post-inc value, the addrec is safe. + // Normally, in the cases we can prove no-overflow via a + // backedge guarding condition, we can also compute a backedge + // taken count for the loop. The exceptions are assumptions and + // guards present in the loop -- SCEV is not great at exploiting + // these to compute max backedge taken counts, but can still use + // these to prove lack of overflow. Use this fact to avoid + // doing extra work that may not pay off. + + if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards || + !AC.assumptions().empty()) { + // If the backedge is guarded by a comparison with the pre-inc + // value the addrec is safe. Also, if the entry is guarded by + // a comparison with the start value and the backedge is + // guarded by a comparison with the post-inc value, the addrec + // is safe. 
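// (Editor's example, not part of the patch: for an i8 recurrence {0,+,1}
// whose backedge is guarded by `i slt n` on the pre-increment value i, the
// increment only runs when i < n <= 127, so the recurrence can never
// sign-wrap, and the extension folds into the addrec:
// sext({0,+,1}) == {0,+,1} in the wider type.)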
ICmpInst::Predicate Pred; const SCEV *OverflowLimit = getSignedOverflowLimitForStep(Step, &Pred, this); @@ -1752,6 +1784,7 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, getSignExtendExpr(Step, Ty), L, AR->getNoWrapFlags()); } } + // If Start and Step are constants, check if we can apply this // transformation: // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2 @@ -1777,6 +1810,11 @@ const SCEV *ScalarEvolution::getSignExtendExpr(const SCEV *Op, } } + // If the input value is provably positive and we could not simplify + // away the sext build a zext instead. + if (isKnownNonNegative(Op)) + return getZeroExtendExpr(Op, Ty); + // The cast wasn't folded; create an explicit cast node. // Recompute the insert position, as it may have been invalidated. if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S; @@ -1836,11 +1874,10 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, return ZExt; } -/// CollectAddOperandsWithScales - Process the given Ops list, which is -/// a list of operands to be added under the given scale, update the given -/// map. This is a helper function for getAddRecExpr. As an example of -/// what it does, given a sequence of operands that would form an add -/// expression like this: +/// Process the given Ops list, which is a list of operands to be added under +/// the given scale, update the given map. This is a helper function for +/// getAddRecExpr. As an example of what it does, given a sequence of operands +/// that would form an add expression like this: /// /// m + n + 13 + (A * (o + p + (B * (q + m + 29)))) + r + (-1 * r) /// @@ -1899,7 +1936,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, // the map. SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end()); const SCEV *Key = SE.getMulExpr(MulOps); - auto Pair = M.insert(std::make_pair(Key, NewScale)); + auto Pair = M.insert({Key, NewScale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { @@ -1912,7 +1949,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, } else { // An ordinary operand. Update the map. std::pair<DenseMap<const SCEV *, APInt>::iterator, bool> Pair = - M.insert(std::make_pair(Ops[i], Scale)); + M.insert({Ops[i], Scale}); if (Pair.second) { NewOps.push_back(Pair.first->first); } else { @@ -1965,15 +2002,14 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt(); if (!(SignOrUnsignWrap & SCEV::FlagNSW)) { - auto NSWRegion = - ConstantRange::makeNoWrapRegion(Instruction::Add, C, OBO::NoSignedWrap); + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, C, OBO::NoSignedWrap); if (NSWRegion.contains(SE->getSignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); } if (!(SignOrUnsignWrap & SCEV::FlagNUW)) { - auto NUWRegion = - ConstantRange::makeNoWrapRegion(Instruction::Add, C, - OBO::NoUnsignedWrap); + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, C, OBO::NoUnsignedWrap); if (NUWRegion.contains(SE->getUnsignedRange(Ops[1]))) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); } @@ -1982,8 +2018,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, return Flags; } -/// getAddExpr - Get a canonical add expression, or something simpler if -/// possible. +/// Get a canonical add expression, or something simpler if possible. 
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) && @@ -2266,7 +2301,10 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(), AddRec->op_end()); - AddRecOps[0] = getAddExpr(LIOps); + // This follows from the fact that the no-wrap flags on the outer add + // expression are applicable on the 0th iteration, when the add recurrence + // will be equal to its start value. + AddRecOps[0] = getAddExpr(LIOps, Flags); // Build the new addrec. Propagate the NUW and NSW flags if both the // outer add and the inner addrec are guaranteed to have no overflow. @@ -2391,8 +2429,7 @@ static bool containsConstantSomewhere(const SCEV *StartExpr) { return false; } -/// getMulExpr - Get a canonical multiply expression, or something simpler if -/// possible. +/// Get a canonical multiply expression, or something simpler if possible. const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) && @@ -2632,8 +2669,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, return S; } -/// getUDivExpr - Get a canonical unsigned division expression, or something -/// simpler if possible. +/// Get a canonical unsigned division expression, or something simpler if +/// possible. const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, const SCEV *RHS) { assert(getEffectiveSCEVType(LHS->getType()) == @@ -2764,10 +2801,10 @@ static const APInt gcd(const SCEVConstant *C1, const SCEVConstant *C2) { return APIntOps::GreatestCommonDivisor(A, B); } -/// getUDivExactExpr - Get a canonical unsigned division expression, or -/// something simpler if possible. There is no representation for an exact udiv -/// in SCEV IR, but we can attempt to remove factors from the LHS and RHS. -/// We can't do this when it's not exact because the udiv may be clearing bits. +/// Get a canonical unsigned division expression, or something simpler if +/// possible. There is no representation for an exact udiv in SCEV IR, but we +/// can attempt to remove factors from the LHS and RHS. We can't do this when +/// it's not exact because the udiv may be clearing bits. const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, const SCEV *RHS) { // TODO: we could try to find factors in all sorts of things, but for now we @@ -2821,8 +2858,8 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS, return getUDivExpr(LHS, RHS); } -/// getAddRecExpr - Get an add recurrence expression for the specified loop. -/// Simplify the expression as much as possible. +/// Get an add recurrence expression for the specified loop. Simplify the +/// expression as much as possible. const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, const Loop *L, SCEV::NoWrapFlags Flags) { @@ -2838,8 +2875,8 @@ const SCEV *ScalarEvolution::getAddRecExpr(const SCEV *Start, const SCEV *Step, return getAddRecExpr(Operands, L, Flags); } -/// getAddRecExpr - Get an add recurrence expression for the specified loop. -/// Simplify the expression as much as possible. +/// Get an add recurrence expression for the specified loop. Simplify the +/// expression as much as possible. 
const SCEV * ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, const Loop *L, SCEV::NoWrapFlags Flags) { @@ -2985,9 +3022,7 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, const SCEV *ScalarEvolution::getSMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector<const SCEV *, 2> Ops; - Ops.push_back(LHS); - Ops.push_back(RHS); + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; return getSMaxExpr(Ops); } @@ -3088,9 +3123,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { const SCEV *ScalarEvolution::getUMaxExpr(const SCEV *LHS, const SCEV *RHS) { - SmallVector<const SCEV *, 2> Ops; - Ops.push_back(LHS); - Ops.push_back(RHS); + SmallVector<const SCEV *, 2> Ops = {LHS, RHS}; return getUMaxExpr(Ops); } @@ -3244,26 +3277,25 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) { // Basic SCEV Analysis and PHI Idiom Recognition Code // -/// isSCEVable - Test if values of the given type are analyzable within -/// the SCEV framework. This primarily includes integer types, and it -/// can optionally include pointer types if the ScalarEvolution class -/// has access to target-specific information. +/// Test if values of the given type are analyzable within the SCEV +/// framework. This primarily includes integer types, and it can optionally +/// include pointer types if the ScalarEvolution class has access to +/// target-specific information. bool ScalarEvolution::isSCEVable(Type *Ty) const { // Integers and pointers are always SCEVable. return Ty->isIntegerTy() || Ty->isPointerTy(); } -/// getTypeSizeInBits - Return the size in bits of the specified type, -/// for which isSCEVable must return true. +/// Return the size in bits of the specified type, for which isSCEVable must +/// return true. uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); return getDataLayout().getTypeSizeInBits(Ty); } -/// getEffectiveSCEVType - Return a type with the same bitwidth as -/// the given type and which represents how SCEV will treat the given -/// type, for which isSCEVable must return true. For pointer types, -/// this is the pointer-sized integer type. +/// Return a type with the same bitwidth as the given type and which represents +/// how SCEV will treat the given type, for which isSCEVable must return +/// true. For pointer types, this is the pointer-sized integer type. Type *ScalarEvolution::getEffectiveSCEVType(Type *Ty) const { assert(isSCEVable(Ty) && "Type is not SCEVable!"); @@ -3310,15 +3342,88 @@ bool ScalarEvolution::checkValidity(const SCEV *S) const { return !F.FindOne; } -/// getSCEV - Return an existing SCEV if it exists, otherwise analyze the -/// expression and create a new one. +namespace { +// Helper class working with SCEVTraversal to figure out if a SCEV contains +// a sub-SCEV of scAddRecExpr type. FindAddRecurrence::FoundOne is set +// iff such a sub-SCEV is found.
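For context (an editorial aside, not part of the patch): SCEVTraversal drives any visitor that provides the two members used below -- follow(), which says whether to descend into a node's operands, and isDone(), which lets the walk stop early. A hypothetical visitor, for illustration only:

  // Counts the nodes of a SCEV expression tree (illustrative sketch; the
  // follow()/isDone() protocol is the same one FindAddRecurrence uses below).
  struct CountSCEVNodes {
    unsigned N = 0;
    bool follow(const SCEV *) { ++N; return true; } // always descend
    bool isDone() const { return false; }           // never stop early
  };
  // Usage mirrors containsAddRecurrence below:
  //   CountSCEVNodes C;
  //   SCEVTraversal<CountSCEVNodes> ST(C);
  //   ST.visitAll(S);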
+struct FindAddRecurrence { + bool FoundOne; + FindAddRecurrence() : FoundOne(false) {} + + bool follow(const SCEV *S) { + switch (static_cast<SCEVTypes>(S->getSCEVType())) { + case scAddRecExpr: + FoundOne = true; // fall through + case scConstant: + case scUnknown: + case scCouldNotCompute: + return false; + default: + return true; + } + } + bool isDone() const { return FoundOne; } +}; +} + +bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { + HasRecMapType::iterator I = HasRecMap.find_as(S); + if (I != HasRecMap.end()) + return I->second; + + FindAddRecurrence F; + SCEVTraversal<FindAddRecurrence> ST(F); + ST.visitAll(S); + HasRecMap.insert({S, F.FoundOne}); + return F.FoundOne; +} + +/// Return the Value set from S. +SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) { + ExprValueMapType::iterator SI = ExprValueMap.find_as(S); + if (SI == ExprValueMap.end()) + return nullptr; +#ifndef NDEBUG + if (VerifySCEVMap) { + // Check there is no dangling Value in the set returned. + for (const auto &VE : SI->second) + assert(ValueExprMap.count(VE)); + } +#endif + return &SI->second; +} + +/// Erase Value from ValueExprMap and ExprValueMap. When ValueExprMap.erase(V) +/// is not paired with forgetMemoizedResults(S), eraseValueFromMap should be +/// used instead, to ensure that whenever V->S is removed from ValueExprMap, V +/// is also removed from the set ExprValueMap[S]. +void ScalarEvolution::eraseValueFromMap(Value *V) { + ValueExprMapType::iterator I = ValueExprMap.find_as(V); + if (I != ValueExprMap.end()) { + const SCEV *S = I->second; + SetVector<Value *> *SV = getSCEVValues(S); + // Remove V from the set of ExprValueMap[S] + if (SV) + SV->remove(V); + ValueExprMap.erase(V); + } +} + +/// Return an existing SCEV if it exists, otherwise analyze the expression and +/// create a new one. const SCEV *ScalarEvolution::getSCEV(Value *V) { assert(isSCEVable(V->getType()) && "Value is not SCEVable!"); const SCEV *S = getExistingSCEV(V); if (S == nullptr) { S = createSCEV(V); - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(V, this), S)); + // During PHI resolution, it is possible to create two SCEVs for the same + // V, so we need to double-check whether V->S has been inserted into + // ValueExprMap before inserting S->V into ExprValueMap. + std::pair<ValueExprMapType::iterator, bool> Pair = + ValueExprMap.insert({SCEVCallbackVH(V, this), S}); + if (Pair.second) + ExprValueMap[S].insert(V); } return S; } @@ -3331,12 +3436,13 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { const SCEV *S = I->second; if (checkValidity(S)) return S; + forgetMemoizedResults(S); ValueExprMap.erase(I); } return nullptr; } -/// getNegativeSCEV - Return a SCEV corresponding to -V = -1*V +/// Return a SCEV corresponding to -V = -1*V ///
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, SCEV::NoWrapFlags Flags) { // Fast path: X - X --> 0. @@ -3402,9 +3507,6 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS, return getAddExpr(LHS, getNegativeSCEV(RHS, NegFlags), AddFlags); } -/// getTruncateOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. const SCEV * ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3418,9 +3520,6 @@ ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) { return getZeroExtendExpr(V, Ty); } -/// getTruncateOrSignExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is sign -/// extended. const SCEV * ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, Type *Ty) { @@ -3435,9 +3534,6 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V, return getSignExtendExpr(V, Ty); } -/// getNoopOrZeroExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is zero -/// extended. The conversion must not be narrowing. const SCEV * ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3451,9 +3547,6 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) { return getZeroExtendExpr(V, Ty); } -/// getNoopOrSignExtend - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. If the type must be extended, it is sign -/// extended. The conversion must not be narrowing. const SCEV * ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3467,10 +3560,6 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) { return getSignExtendExpr(V, Ty); } -/// getNoopOrAnyExtend - Return a SCEV corresponding to a conversion of -/// the input value to the specified type. If the type must be extended, -/// it is extended with unspecified bits. The conversion must not be -/// narrowing. const SCEV * ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3484,8 +3573,6 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) { return getAnyExtendExpr(V, Ty); } -/// getTruncateOrNoop - Return a SCEV corresponding to a conversion of the -/// input value to the specified type. The conversion must not be widening. const SCEV * ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { Type *SrcTy = V->getType(); @@ -3499,9 +3586,6 @@ ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) { return getTruncateExpr(V, Ty); } -/// getUMaxFromMismatchedTypes - Promote the operands to the wider of -/// the types using zero-extension, and then perform a umax operation -/// with them. const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; @@ -3515,9 +3599,6 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS, return getUMaxExpr(PromotedLHS, PromotedRHS); } -/// getUMinFromMismatchedTypes - Promote the operands to the wider of -/// the types using zero-extension, and then perform a umin operation -/// with them. 
const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, const SCEV *RHS) { const SCEV *PromotedLHS = LHS; @@ -3531,10 +3612,6 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS, return getUMinExpr(PromotedLHS, PromotedRHS); } -/// getPointerBase - Transitively follow the chain of pointer-type operands -/// until reaching a SCEV that does not have a single pointer operand. This -/// returns a SCEVUnknown pointer for well-formed pointer-type expressions, -/// but corner cases do exist. const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { // A pointer operand may evaluate to a nonpointer expression, such as null. if (!V->getType()->isPointerTy()) @@ -3559,8 +3636,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) { return V; } -/// PushDefUseChildren - Push users of the given Instruction -/// onto the given Worklist. +/// Push users of the given Instruction onto the given Worklist. static void PushDefUseChildren(Instruction *I, SmallVectorImpl<Instruction *> &Worklist) { @@ -3569,12 +3645,7 @@ PushDefUseChildren(Instruction *I, Worklist.push_back(cast<Instruction>(U)); } -/// ForgetSymbolicValue - This looks up computed SCEV values for all -/// instructions that depend on the given instruction and removes them from -/// the ValueExprMapType map if they reference SymName. This is used during PHI -/// resolution. -void -ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { +void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { SmallVector<Instruction *, 16> Worklist; PushDefUseChildren(PN, Worklist); @@ -3616,10 +3687,10 @@ ScalarEvolution::ForgetSymbolicName(Instruction *PN, const SCEV *SymName) { namespace { class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVInitRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(Scev); + const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } @@ -3649,10 +3720,10 @@ private: class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, const Loop *L, + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) { SCEVShiftRewriter Rewriter(L, SE); - const SCEV *Result = Rewriter.visit(Scev); + const SCEV *Result = Rewriter.visit(S); return Rewriter.isValid() ? 
Result : SE.getCouldNotCompute(); } @@ -3680,6 +3751,167 @@ private: }; } // end anonymous namespace +SCEV::NoWrapFlags +ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { + if (!AR->isAffine()) + return SCEV::FlagAnyWrap; + + typedef OverflowingBinaryOperator OBO; + SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; + + if (!AR->hasNoSignedWrap()) { + ConstantRange AddRecRange = getSignedRange(AR); + ConstantRange IncRange = getSignedRange(AR->getStepRecurrence(*this)); + + auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoSignedWrap); + if (NSWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNSW); + } + + if (!AR->hasNoUnsignedWrap()) { + ConstantRange AddRecRange = getUnsignedRange(AR); + ConstantRange IncRange = getUnsignedRange(AR->getStepRecurrence(*this)); + + auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + Instruction::Add, IncRange, OBO::NoUnsignedWrap); + if (NUWRegion.contains(AddRecRange)) + Result = ScalarEvolution::setFlags(Result, SCEV::FlagNUW); + } + + return Result; +} + +namespace { +/// Represents an abstract binary operation. This may exist as a +/// normal instruction or constant expression, or may have been +/// derived from an expression tree. +struct BinaryOp { + unsigned Opcode; + Value *LHS; + Value *RHS; + bool IsNSW; + bool IsNUW; + + /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or + /// constant expression. + Operator *Op; + + explicit BinaryOp(Operator *Op) + : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), + IsNSW(false), IsNUW(false), Op(Op) { + if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { + IsNSW = OBO->hasNoSignedWrap(); + IsNUW = OBO->hasNoUnsignedWrap(); + } + } + + explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, + bool IsNUW = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), + Op(nullptr) {} +}; +} + + +/// Try to map \p V into a BinaryOp, and return \c None on failure. +static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { + auto *Op = dyn_cast<Operator>(V); + if (!Op) + return None; + + // Implementation detail: all the cleverness here should happen without + // creating new SCEV expressions -- our caller knows tricks to avoid creating + // SCEV expressions when possible, and we should not break that. + + switch (Op->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::UDiv: + case Instruction::And: + case Instruction::Or: + case Instruction::AShr: + case Instruction::Shl: + return BinaryOp(Op); + + case Instruction::Xor: + if (auto *RHSC = dyn_cast<ConstantInt>(Op->getOperand(1))) + // If the RHS of the xor is a signbit, then this is just an add. + // Instcombine turns add of signbit into xor as a strength reduction step. + if (RHSC->getValue().isSignBit()) + return BinaryOp(Instruction::Add, Op->getOperand(0), Op->getOperand(1)); + return BinaryOp(Op); + + case Instruction::LShr: + // Turn a logical shift right by a constant into an unsigned divide. + if (ConstantInt *SA = dyn_cast<ConstantInt>(Op->getOperand(1))) { + uint32_t BitWidth = cast<IntegerType>(Op->getType())->getBitWidth(); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler.
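// (Editor's illustration, not part of the patch: `lshr i32 %x, 3` is modeled
// below as `udiv i32 %x, 8`, i.e. division by 1 << 3; a shift amount of 32
// or more skips this path and falls through to the plain BinaryOp(Op) case.)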
+ if (SA->getValue().ult(BitWidth)) { + Constant *X = + ConstantInt::get(SA->getContext(), + APInt::getOneBitSet(BitWidth, SA->getZExtValue())); + return BinaryOp(Instruction::UDiv, Op->getOperand(0), X); + } + } + return BinaryOp(Op); + + case Instruction::ExtractValue: { + auto *EVI = cast<ExtractValueInst>(Op); + if (EVI->getNumIndices() != 1 || EVI->getIndices()[0] != 0) + break; + + auto *CI = dyn_cast<CallInst>(EVI->getAggregateOperand()); + if (!CI) + break; + + if (auto *F = CI->getCalledFunction()) + switch (F->getIntrinsicID()) { + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: { + if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1)); + + // Now that we know that all uses of the arithmetic-result component of + // CI are guarded by the overflow check, we can go ahead and pretend + // that the arithmetic is non-overflowing. + if (F->getIntrinsicID() == Intrinsic::sadd_with_overflow) + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ true, + /* IsNUW = */ false); + else + return BinaryOp(Instruction::Add, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ false, + /* IsNUW*/ true); + } + + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + return BinaryOp(Instruction::Sub, CI->getArgOperand(0), + CI->getArgOperand(1)); + + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + return BinaryOp(Instruction::Mul, CI->getArgOperand(0), + CI->getArgOperand(1)); + default: + break; + } + } + + default: + break; + } + + return None; +} + const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const Loop *L = LI.getLoopFor(PN->getParent()); if (!L || L->getHeader() != PN->getParent()) @@ -3710,7 +3942,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const SCEV *SymbolicName = getUnknown(PN); assert(ValueExprMap.find_as(PN) == ValueExprMap.end() && "PHI node already processed?"); - ValueExprMap.insert(std::make_pair(SCEVCallbackVH(PN, this), SymbolicName)); + ValueExprMap.insert({SCEVCallbackVH(PN, this), SymbolicName}); // Using this symbolic name for the PHI, analyze the value coming around // the back-edge. @@ -3747,13 +3979,11 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; - // If the increment doesn't overflow, then neither the addrec nor - // the post-increment will overflow. - if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) { - if (OBO->getOperand(0) == PN) { - if (OBO->hasNoUnsignedWrap()) + if (auto BO = MatchBinaryOp(BEValueV, DT)) { + if (BO->Opcode == Instruction::Add && BO->LHS == PN) { + if (BO->IsNUW) Flags = setFlags(Flags, SCEV::FlagNUW); - if (OBO->hasNoSignedWrap()) + if (BO->IsNSW) Flags = setFlags(Flags, SCEV::FlagNSW); } } else if (GEPOperator *GEP = dyn_cast<GEPOperator>(BEValueV)) { @@ -3779,16 +4009,19 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { const SCEV *StartVal = getSCEV(StartValueV); const SCEV *PHISCEV = getAddRecExpr(StartVal, Accum, L, Flags); - // Since the no-wrap flags are on the increment, they apply to the - // post-incremented value as well. - if (isLoopInvariant(Accum, L)) - (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); - // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. 
We now need to go back and purge all of the + // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); + forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = PHISCEV; + + // We can add Flags to the post-inc expression only if we + // know that it is *undefined behavior* for BEValueV to + // overflow. + if (auto *BEInst = dyn_cast<Instruction>(BEValueV)) + if (isLoopInvariant(Accum, L) && isAddRecNeverPoison(BEInst, L)) + (void)getAddRecExpr(getAddExpr(StartVal, Accum), Accum, L, Flags); + return PHISCEV; } } @@ -3811,12 +4044,18 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { // Okay, for the entire analysis of this edge we assumed the PHI // to be symbolic. We now need to go back and purge all of the // entries for the scalars that use the symbolic expression. - ForgetSymbolicName(PN, SymbolicName); + forgetSymbolicName(PN, SymbolicName); ValueExprMap[SCEVCallbackVH(PN, this)] = Shifted; return Shifted; } } } + + // Remove the temporary PHI node SCEV that has been inserted while intending + // to create an AddRecExpr for this PHI node. We cannot keep this temporary, + // as it would prevent later (possibly simpler) SCEV expressions from being + // added to the ValueExprMap. + ValueExprMap.erase(PN); } return nullptr; @@ -4083,26 +4322,21 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I, return getUnknown(I); } -/// createNodeForGEP - Expand GEP instructions into add and multiply -/// operations. This allows them to be analyzed by regular SCEV code. -/// +/// Expand GEP instructions into add and multiply operations. This allows them +/// to be analyzed by regular SCEV code. const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) { - Value *Base = GEP->getOperand(0); // Don't attempt to analyze GEPs over unsized objects. - if (!Base->getType()->getPointerElementType()->isSized()) + if (!GEP->getSourceElementType()->isSized()) return getUnknown(GEP); SmallVector<const SCEV *, 4> IndexExprs; for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) IndexExprs.push_back(getSCEV(*Index)); - return getGEPExpr(GEP->getSourceElementType(), getSCEV(Base), IndexExprs, - GEP->isInBounds()); + return getGEPExpr(GEP->getSourceElementType(), + getSCEV(GEP->getPointerOperand()), + IndexExprs, GEP->isInBounds()); } -/// GetMinTrailingZeros - Determine the minimum number of zero bits that S is -/// guaranteed to end in (at every loop iteration). It is, at the same time, -/// the minimum number of times S is divisible by 2. For example, given {4,+,8} -/// it returns 2. If S is guaranteed to be 0, it returns the bitwidth of S. uint32_t ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) @@ -4180,8 +4414,7 @@ ScalarEvolution::GetMinTrailingZeros(const SCEV *S) { return 0; } -/// GetRangeFromMetadata - Helper method to assign a range to V from -/// metadata present in the IR. +/// Helper method to assign a range to V from metadata present in the IR. static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { if (Instruction *I = dyn_cast<Instruction>(V)) if (MDNode *MD = I->getMetadata(LLVMContext::MD_range)) @@ -4190,10 +4423,9 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) { return None; } -/// getRange - Determine the range for a particular SCEV. If SignHint is +/// Determine the range for a particular SCEV. If SignHint is /// HINT_RANGE_UNSIGNED (resp.
HINT_RANGE_SIGNED) then getRange prefers ranges /// with a "cleaner" unsigned (resp. signed) representation. -/// ConstantRange ScalarEvolution::getRange(const SCEV *S, ScalarEvolution::RangeSignHint SignHint) { @@ -4282,7 +4514,7 @@ ScalarEvolution::getRange(const SCEV *S, if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) { // If there's no unsigned wrap, the value will never be less than its // initial value. - if (AddRec->getNoWrapFlags(SCEV::FlagNUW)) + if (AddRec->hasNoUnsignedWrap()) if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart())) if (!C->getValue()->isZero()) ConservativeResult = ConservativeResult.intersectWith( @@ -4290,7 +4522,7 @@ ScalarEvolution::getRange(const SCEV *S, // If there's no signed wrap, and all the operands have the same sign or // zero, the value won't ever change sign. - if (AddRec->getNoWrapFlags(SCEV::FlagNSW)) { + if (AddRec->hasNoSignedWrap()) { bool AllNonNeg = true; bool AllNonPos = true; for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) { @@ -4309,66 +4541,22 @@ ScalarEvolution::getRange(const SCEV *S, // TODO: non-affine addrec if (AddRec->isAffine()) { - Type *Ty = AddRec->getType(); const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop()); if (!isa<SCEVCouldNotCompute>(MaxBECount) && getTypeSizeInBits(MaxBECount->getType()) <= BitWidth) { - - // Check for overflow. This must be done with ConstantRange arithmetic - // because we could be called from within the ScalarEvolution overflow - // checking code. - - MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty); - ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); - ConstantRange ZExtMaxBECountRange = - MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); - - const SCEV *Start = AddRec->getStart(); - const SCEV *Step = AddRec->getStepRecurrence(*this); - ConstantRange StepSRange = getSignedRange(Step); - ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); - - ConstantRange StartURange = getUnsignedRange(Start); - ConstantRange EndURange = - StartURange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for unsigned overflow. - ConstantRange ZExtStartURange = - StartURange.zextOrTrunc(BitWidth * 2 + 1); - ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); - if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - ZExtEndURange) { - APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), - EndURange.getUnsignedMin()); - APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), - EndURange.getUnsignedMax()); - bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); - if (!IsFullRange) - ConservativeResult = - ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); - } - - ConstantRange StartSRange = getSignedRange(Start); - ConstantRange EndSRange = - StartSRange.add(MaxBECountRange.multiply(StepSRange)); - - // Check for signed overflow. This must be done with ConstantRange - // arithmetic because we could be called from within the ScalarEvolution - // overflow checking code. 
- ConstantRange SExtStartSRange = - StartSRange.sextOrTrunc(BitWidth * 2 + 1); - ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); - if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == - SExtEndSRange) { - APInt Min = APIntOps::smin(StartSRange.getSignedMin(), - EndSRange.getSignedMin()); - APInt Max = APIntOps::smax(StartSRange.getSignedMax(), - EndSRange.getSignedMax()); - bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); - if (!IsFullRange) - ConservativeResult = - ConservativeResult.intersectWith(ConstantRange(Min, Max + 1)); - } + auto RangeFromAffine = getRangeForAffineAR( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromAffine.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromAffine); + + auto RangeFromFactoring = getRangeViaFactoring( + AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount, + BitWidth); + if (!RangeFromFactoring.isFullSet()) + ConservativeResult = + ConservativeResult.intersectWith(RangeFromFactoring); } } @@ -4408,6 +4596,186 @@ ScalarEvolution::getRange(const SCEV *S, return setRange(S, SignHint, ConservativeResult); } +ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + assert(!isa<SCEVCouldNotCompute>(MaxBECount) && + getTypeSizeInBits(MaxBECount->getType()) <= BitWidth && + "Precondition!"); + + ConstantRange Result(BitWidth, /* isFullSet = */ true); + + // Check for overflow. This must be done with ConstantRange arithmetic + // because we could be called from within the ScalarEvolution overflow + // checking code. + + MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType()); + ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); + ConstantRange ZExtMaxBECountRange = + MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StepSRange = getSignedRange(Step); + ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); + + ConstantRange StartURange = getUnsignedRange(Start); + ConstantRange EndURange = + StartURange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for unsigned overflow. + ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1); + ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); + if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + ZExtEndURange) { + APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), + EndURange.getUnsignedMin()); + APInt Max = APIntOps::umax(StartURange.getUnsignedMax(), + EndURange.getUnsignedMax()); + bool IsFullRange = Min.isMinValue() && Max.isMaxValue(); + if (!IsFullRange) + Result = + Result.intersectWith(ConstantRange(Min, Max + 1)); + } + + ConstantRange StartSRange = getSignedRange(Start); + ConstantRange EndSRange = + StartSRange.add(MaxBECountRange.multiply(StepSRange)); + + // Check for signed overflow. This must be done with ConstantRange + // arithmetic because we could be called from within the ScalarEvolution + // overflow checking code. 
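With this refactoring, getRange obtains two independent over-approximations for an affine AddRec, RangeFromAffine and RangeFromFactoring, and intersects each non-full result into ConservativeResult; any two sound over-approximations can be intersected without losing soundness. A minimal standalone sketch of that combining step, using closed int64 intervals as a stand-in for LLVM's ConstantRange (so wrapping sets are not modeled):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Closed interval [Lo, Hi]; INT64_MIN..INT64_MAX plays the role of the
    // full set. This mirrors only the intersectWith/isFullSet calls above.
    struct Range {
      int64_t Lo, Hi;
      bool isFullSet() const { return Lo == INT64_MIN && Hi == INT64_MAX; }
      Range intersectWith(const Range &Other) const {
        return {std::max(Lo, Other.Lo), std::min(Hi, Other.Hi)};
      }
    };

    int main() {
      Range ConservativeResult{INT64_MIN, INT64_MAX}; // nothing known yet
      Range RangeFromAffine{0, 100};     // hypothetical affine-analysis result
      Range RangeFromFactoring{10, 200}; // hypothetical factoring result
      if (!RangeFromAffine.isFullSet())
        ConservativeResult = ConservativeResult.intersectWith(RangeFromAffine);
      if (!RangeFromFactoring.isFullSet())
        ConservativeResult = ConservativeResult.intersectWith(RangeFromFactoring);
      std::cout << "[" << ConservativeResult.Lo << ", " << ConservativeResult.Hi
                << "]\n"; // [10, 100]
      return 0;
    }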
+ ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1); + ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); + if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) == + SExtEndSRange) { + APInt Min = + APIntOps::smin(StartSRange.getSignedMin(), EndSRange.getSignedMin()); + APInt Max = + APIntOps::smax(StartSRange.getSignedMax(), EndSRange.getSignedMax()); + bool IsFullRange = Min.isMinSignedValue() && Max.isMaxSignedValue(); + if (!IsFullRange) + Result = + Result.intersectWith(ConstantRange(Min, Max + 1)); + } + + return Result; +} + +ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start, + const SCEV *Step, + const SCEV *MaxBECount, + unsigned BitWidth) { + // RangeOf({C?A:B,+,C?P:Q}) == RangeOf(C?{A,+,P}:{B,+,Q}) + // == RangeOf({A,+,P}) union RangeOf({B,+,Q}) + + struct SelectPattern { + Value *Condition = nullptr; + APInt TrueValue; + APInt FalseValue; + + explicit SelectPattern(ScalarEvolution &SE, unsigned BitWidth, + const SCEV *S) { + Optional<unsigned> CastOp; + APInt Offset(BitWidth, 0); + + assert(SE.getTypeSizeInBits(S->getType()) == BitWidth && + "Should be!"); + + // Peel off a constant offset: + if (auto *SA = dyn_cast<SCEVAddExpr>(S)) { + // In the future we could consider being smarter here and handle + // {Start+Step,+,Step} too. + if (SA->getNumOperands() != 2 || !isa<SCEVConstant>(SA->getOperand(0))) + return; + + Offset = cast<SCEVConstant>(SA->getOperand(0))->getAPInt(); + S = SA->getOperand(1); + } + + // Peel off a cast operation + if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) { + CastOp = SCast->getSCEVType(); + S = SCast->getOperand(); + } + + using namespace llvm::PatternMatch; + + auto *SU = dyn_cast<SCEVUnknown>(S); + const APInt *TrueVal, *FalseVal; + if (!SU || + !match(SU->getValue(), m_Select(m_Value(Condition), m_APInt(TrueVal), + m_APInt(FalseVal)))) { + Condition = nullptr; + return; + } + + TrueValue = *TrueVal; + FalseValue = *FalseVal; + + // Re-apply the cast we peeled off earlier + if (CastOp.hasValue()) + switch (*CastOp) { + default: + llvm_unreachable("Unknown SCEV cast type!"); + + case scTruncate: + TrueValue = TrueValue.trunc(BitWidth); + FalseValue = FalseValue.trunc(BitWidth); + break; + case scZeroExtend: + TrueValue = TrueValue.zext(BitWidth); + FalseValue = FalseValue.zext(BitWidth); + break; + case scSignExtend: + TrueValue = TrueValue.sext(BitWidth); + FalseValue = FalseValue.sext(BitWidth); + break; + } + + // Re-apply the constant offset we peeled off earlier + TrueValue += Offset; + FalseValue += Offset; + } + + bool isRecognized() { return Condition != nullptr; } + }; + + SelectPattern StartPattern(*this, BitWidth, Start); + if (!StartPattern.isRecognized()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + SelectPattern StepPattern(*this, BitWidth, Step); + if (!StepPattern.isRecognized()) + return ConstantRange(BitWidth, /* isFullSet = */ true); + + if (StartPattern.Condition != StepPattern.Condition) { + // We don't handle this case today; but we could, by considering four + // possibilities below instead of two. I'm not sure if there are cases where + // that will help over what getRange already does, though. + return ConstantRange(BitWidth, /* isFullSet = */ true); + } + + // NB! Calling ScalarEvolution::getConstant is fine, but we should not try to + // construct arbitrary general SCEV expressions here. 
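The overflow test in getRangeForAffineAR above is easier to see with concrete numbers. Below is a standalone sketch with single values standing in for ranges; the agreement check at a wider width plays the role of the zextOrTrunc/sextOrTrunc comparison (the real code presumably widens to BitWidth * 2 + 1 bits so that the zero-extended count times the sign-extended step cannot itself wrap):

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // Compute End = Start + MaxBECount * Step both at the narrow width and at
    // a wide width; only when the two agree did the narrow computation not
    // wrap, and only then is [min, max] of Start and End a valid range.
    int main() {
      uint8_t Start = 200, Step = 10, MaxBECount = 10;
      uint8_t NarrowEnd = static_cast<uint8_t>(Start + MaxBECount * Step);
      uint32_t WideEnd =
          uint32_t(Start) + uint32_t(MaxBECount) * uint32_t(Step);
      if (WideEnd == NarrowEnd)
        std::cout << "no wrap; range [" << unsigned(std::min(Start, NarrowEnd))
                  << ", " << unsigned(std::max(Start, NarrowEnd)) << "]\n";
      else
        std::cout << "wrapped (" << WideEnd << " vs " << unsigned(NarrowEnd)
                  << "); keep the full set\n"; // prints: wrapped (300 vs 44)
      return 0;
    }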
This function is called + // from deep in the call stack, and calling getSCEV (on a sext instruction, + // say) can end up caching a suboptimal value. + + // FIXME: without the explicit `this` receiver below, MSVC errors out with + // C2352 and C2512 (otherwise it isn't needed). + + const SCEV *TrueStart = this->getConstant(StartPattern.TrueValue); + const SCEV *TrueStep = this->getConstant(StepPattern.TrueValue); + const SCEV *FalseStart = this->getConstant(StartPattern.FalseValue); + const SCEV *FalseStep = this->getConstant(StepPattern.FalseValue); + + ConstantRange TrueRange = + this->getRangeForAffineAR(TrueStart, TrueStep, MaxBECount, BitWidth); + ConstantRange FalseRange = + this->getRangeForAffineAR(FalseStart, FalseStep, MaxBECount, BitWidth); + + return TrueRange.unionWith(FalseRange); +} + SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { if (isa<ConstantExpr>(V)) return SCEV::FlagAnyWrap; const BinaryOperator *BinOp = cast<BinaryOperator>(V); @@ -4418,273 +4786,359 @@ SCEV::NoWrapFlags ScalarEvolution::getNoWrapFlagsFromUB(const Value *V) { Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW); if (BinOp->hasNoSignedWrap()) Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW); - if (Flags == SCEV::FlagAnyWrap) { + if (Flags == SCEV::FlagAnyWrap) return SCEV::FlagAnyWrap; - } - // Here we check that BinOp is in the header of the innermost loop - // containing BinOp, since we only deal with instructions in the loop - // header. The actual loop we need to check later will come from an add - // recurrence, but getting that requires computing the SCEV of the operands, - // which can be expensive. This check we can do cheaply to rule out some - // cases early. - Loop *innermostContainingLoop = LI.getLoopFor(BinOp->getParent()); - if (innermostContainingLoop == nullptr || - innermostContainingLoop->getHeader() != BinOp->getParent()) - return SCEV::FlagAnyWrap; + return isSCEVExprNeverPoison(BinOp) ? Flags : SCEV::FlagAnyWrap; +} - // Only proceed if we can prove that BinOp does not yield poison. - if (!isKnownNotFullPoison(BinOp)) return SCEV::FlagAnyWrap; +bool ScalarEvolution::isSCEVExprNeverPoison(const Instruction *I) { + // Here we check that I is in the header of the innermost loop containing I, + // since we only deal with instructions in the loop header. The actual loop we + // need to check later will come from an add recurrence, but getting that + // requires computing the SCEV of the operands, which can be expensive. This + // check we can do cheaply to rule out some cases early. + Loop *InnermostContainingLoop = LI.getLoopFor(I->getParent()); + if (InnermostContainingLoop == nullptr || + InnermostContainingLoop->getHeader() != I->getParent()) + return false; + + // Only proceed if we can prove that I does not yield poison. + if (!isKnownNotFullPoison(I)) return false; - // At this point we know that if V is executed, then it does not wrap - // according to at least one of NSW or NUW. If V is not executed, then we do - // not know if the calculation that V represents would wrap. Multiple - // instructions can map to the same SCEV. If we apply NSW or NUW from V to + // At this point we know that if I is executed, then it does not wrap + // according to at least one of NSW or NUW. If I is not executed, then we do + // not know if the calculation that I represents would wrap. Multiple + // instructions can map to the same SCEV. 
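Returning to getRangeViaFactoring above: the identity RangeOf({C?A:B,+,C?P:Q}) == RangeOf({A,+,P}) union RangeOf({B,+,Q}) can be checked numerically. The start/step values below are arbitrary illustrations:

    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    // {c ? 0 : 100, +, c ? 1 : -2} takes exactly the values of {0,+,1} when c
    // is true and of {100,+,-2} when c is false, so its value set over the
    // first MaxBECount iterations is the union of the two simple ranges.
    int main() {
      const int64_t MaxBECount = 10;
      for (int C = 0; C <= 1; ++C) {
        int64_t Start = C ? 0 : 100;
        int64_t Step = C ? 1 : -2;
        int64_t Lo = INT64_MAX, Hi = INT64_MIN;
        for (int64_t I = 0; I <= MaxBECount; ++I) {
          int64_t V = Start + Step * I;
          Lo = std::min(Lo, V);
          Hi = std::max(Hi, V);
        }
        std::cout << (C ? "true" : "false") << " arm: [" << Lo << ", " << Hi
                  << "]\n"; // [80, 100] then [0, 10]
      }
      return 0;
    }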
If we apply NSW or NUW from I to // the SCEV, we must guarantee no wrapping for that SCEV also when it is // derived from other instructions that map to the same SCEV. We cannot make - // that guarantee for cases where V is not executed. So we need to find the - // loop that V is considered in relation to and prove that V is executed for - // every iteration of that loop. That implies that the value that V + // that guarantee for cases where I is not executed. So we need to find the + // loop that I is considered in relation to and prove that I is executed for + // every iteration of that loop. That implies that the value that I // calculates does not wrap anywhere in the loop, so then we can apply the // flags to the SCEV. // - // We check isLoopInvariant to disambiguate in case we are adding two - // recurrences from different loops, so that we know which loop to prove - // that V is executed in. - for (int OpIndex = 0; OpIndex < 2; ++OpIndex) { - const SCEV *Op = getSCEV(BinOp->getOperand(OpIndex)); + // We check isLoopInvariant to disambiguate in case we are adding recurrences + // from different loops, so that we know which loop to prove that I is + // executed in. + for (unsigned OpIndex = 0; OpIndex < I->getNumOperands(); ++OpIndex) { + const SCEV *Op = getSCEV(I->getOperand(OpIndex)); if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(Op)) { - const int OtherOpIndex = 1 - OpIndex; - const SCEV *OtherOp = getSCEV(BinOp->getOperand(OtherOpIndex)); - if (isLoopInvariant(OtherOp, AddRec->getLoop()) && - isGuaranteedToExecuteForEveryIteration(BinOp, AddRec->getLoop())) - return Flags; + bool AllOtherOpsLoopInvariant = true; + for (unsigned OtherOpIndex = 0; OtherOpIndex < I->getNumOperands(); + ++OtherOpIndex) { + if (OtherOpIndex != OpIndex) { + const SCEV *OtherOp = getSCEV(I->getOperand(OtherOpIndex)); + if (!isLoopInvariant(OtherOp, AddRec->getLoop())) { + AllOtherOpsLoopInvariant = false; + break; + } + } + } + if (AllOtherOpsLoopInvariant && + isGuaranteedToExecuteForEveryIteration(I, AddRec->getLoop())) + return true; + } + } + return false; +} + +bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { + // If we know that \c I can never be poison period, then that's enough. + if (isSCEVExprNeverPoison(I)) + return true; + + // For an add recurrence specifically, we assume that infinite loops without + // side effects are undefined behavior, and then reason as follows: + // + // If the add recurrence is poison in any iteration, it is poison on all + // future iterations (since incrementing poison yields poison). If the result + // of the add recurrence is fed into the loop latch condition and the loop + // does not contain any throws or exiting blocks other than the latch, we now + // have the ability to "choose" whether the backedge is taken or not (by + // choosing a sufficiently evil value for the poison feeding into the branch) + // for every iteration including and after the one in which \p I first became + // poison. There are two possibilities (let's call the iteration in which \p + // I first became poison K): + // + // 1. In the set of iterations including and after K, the loop body executes + // no side effects. In this case executing the backedge an infinite number + // of times will yield undefined behavior. + // + // 2. In the set of iterations including and after K, the loop body executes + // at least one side effect.
In this case, that specific instance of side + // effect is control dependent on poison, which also yields undefined + // behavior. + + auto *ExitingBB = L->getExitingBlock(); + auto *LatchBB = L->getLoopLatch(); + if (!ExitingBB || !LatchBB || ExitingBB != LatchBB) + return false; + + SmallPtrSet<const Instruction *, 16> Pushed; + SmallVector<const Instruction *, 8> PoisonStack; + + // We start by assuming \c I, the post-inc add recurrence, is poison. Only + // things that are known to be fully poison under that assumption go on the + // PoisonStack. + Pushed.insert(I); + PoisonStack.push_back(I); + + bool LatchControlDependentOnPoison = false; + while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { + const Instruction *Poison = PoisonStack.pop_back_val(); + + for (auto *PoisonUser : Poison->users()) { + if (propagatesFullPoison(cast<Instruction>(PoisonUser))) { + if (Pushed.insert(cast<Instruction>(PoisonUser)).second) + PoisonStack.push_back(cast<Instruction>(PoisonUser)); + } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { + assert(BI->isConditional() && "Only possibility!"); + if (BI->getParent() == LatchBB) { + LatchControlDependentOnPoison = true; + break; + } + } } } - return SCEV::FlagAnyWrap; + + return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L); +} + +bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { + auto Itr = LoopHasNoAbnormalExits.find(L); + if (Itr == LoopHasNoAbnormalExits.end()) { + auto NoAbnormalExitInBB = [&](BasicBlock *BB) { + return all_of(*BB, [](Instruction &I) { + return isGuaranteedToTransferExecutionToSuccessor(&I); + }); + }; + + auto InsertPair = LoopHasNoAbnormalExits.insert( + {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); + assert(InsertPair.second && "We just checked!"); + Itr = InsertPair.first; + } + + return Itr->second; } -/// createSCEV - We know that there is no SCEV for the specified value. Analyze -/// the expression. -/// const SCEV *ScalarEvolution::createSCEV(Value *V) { if (!isSCEVable(V->getType())) return getUnknown(V); - unsigned Opcode = Instruction::UserOp1; if (Instruction *I = dyn_cast<Instruction>(V)) { - Opcode = I->getOpcode(); - // Don't attempt to analyze instructions in blocks that aren't // reachable. Such instructions don't matter, and they aren't required // to obey basic rules for definitions dominating uses which this // analysis depends on. if (!DT.isReachableFromEntry(I->getParent())) return getUnknown(V); - } else if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) - Opcode = CE->getOpcode(); - else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) + } else if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) return getConstant(CI); else if (isa<ConstantPointerNull>(V)) return getZero(V->getType()); else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) - return GA->mayBeOverridden() ? getUnknown(V) : getSCEV(GA->getAliasee()); - else + return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee()); + else if (!isa<ConstantExpr>(V)) return getUnknown(V); Operator *U = cast<Operator>(V); - switch (Opcode) { - case Instruction::Add: { - // The simple thing to do would be to just call getSCEV on both operands - // and call getAddExpr with the result. However if we're looking at a - // bunch of things all added together, this can be quite inefficient, - // because it leads to N-1 getAddExpr calls for N ultimate operands. - // Instead, gather up all the operands and make a single getAddExpr call. - // LLVM IR canonical form means we need only traverse the left operands. 
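The latch walk described in the isAddRecNeverPoison comment above can be modeled with a small worklist over a toy use graph; none of the types below are LLVM's, and the three-node graph is a made-up example:

    #include <cstdio>
    #include <set>
    #include <vector>

    struct Node {
      bool PropagatesPoison;  // e.g. add/mul do; a select or phi may not
      bool IsLatchBranch;     // conditional branch in the latch block
      std::vector<int> Users;
    };

    int main() {
      // 0: %iv.next (assumed poison) -> 1: %cmp -> 2: latch branch
      std::vector<Node> Nodes = {{true, false, {1}},
                                 {true, false, {2}},
                                 {false, true, {}}};
      std::set<int> Pushed = {0};
      std::vector<int> PoisonStack = {0};
      bool LatchControlDependentOnPoison = false;
      while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
        int Poison = PoisonStack.back();
        PoisonStack.pop_back();
        for (int User : Nodes[Poison].Users) {
          if (Nodes[User].IsLatchBranch)
            LatchControlDependentOnPoison = true;
          else if (Nodes[User].PropagatesPoison && Pushed.insert(User).second)
            PoisonStack.push_back(User);
        }
      }
      std::printf("latch control dependent on poison: %s\n",
                  LatchControlDependentOnPoison ? "yes" : "no"); // yes
      return 0;
    }

As in the real routine, the result is only usable when the loop additionally has no abnormal exits, since the argument needs the latch to be the only way out.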
- SmallVector<const SCEV *, 4> AddOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - unsigned Opcode = U ? U->getOpcode() : 0; - if (!U || (Opcode != Instruction::Add && Opcode != Instruction::Sub)) { - assert(Op != V && "V should be an add"); - AddOps.push_back(getSCEV(Op)); - break; - } + if (auto BO = MatchBinaryOp(U, DT)) { + switch (BO->Opcode) { + case Instruction::Add: { + // The simple thing to do would be to just call getSCEV on both operands + // and call getAddExpr with the result. However if we're looking at a + // bunch of things all added together, this can be quite inefficient, + // because it leads to N-1 getAddExpr calls for N ultimate operands. + // Instead, gather up all the operands and make a single getAddExpr call. + // LLVM IR canonical form means we need only traverse the left operands. + SmallVector<const SCEV *, 4> AddOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + AddOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - AddOps.push_back(OpSCEV); - break; - } + // If a NUW or NSW flag can be applied to the SCEV for this + // addition, then compute the SCEV for this addition by itself + // with a separate call to getAddExpr. We need to do that + // instead of pushing the operands of the addition onto AddOps, + // since the flags are only known to apply to this particular + // addition - they may not apply to other additions that can be + // formed with operands from AddOps. + const SCEV *RHS = getSCEV(BO->RHS); + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + const SCEV *LHS = getSCEV(BO->LHS); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + else + AddOps.push_back(getAddExpr(LHS, RHS, Flags)); + break; + } + } - // If a NUW or NSW flag can be applied to the SCEV for this - // addition, then compute the SCEV for this addition by itself - // with a separate call to getAddExpr. We need to do that - // instead of pushing the operands of the addition onto AddOps, - // since the flags are only known to apply to this particular - // addition - they may not apply to other additions that can be - // formed with operands from AddOps. 
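The do/while rewrite above is the iterative left-spine walk that the "LLVM IR canonical form" comment relies on. A toy illustration of why the chain is gathered first, flattening ((a + b) + c) + d into one operand list in a single pass:

    #include <iostream>
    #include <memory>
    #include <vector>

    // Toy expression tree: a node is either a leaf or a '+' with two children.
    struct Expr {
      char Leaf;                       // valid when Lhs/Rhs are null
      std::unique_ptr<Expr> Lhs, Rhs;  // valid for '+' nodes
    };

    std::unique_ptr<Expr> leaf(char C) {
      return std::make_unique<Expr>(Expr{C, nullptr, nullptr});
    }
    std::unique_ptr<Expr> add(std::unique_ptr<Expr> L, std::unique_ptr<Expr> R) {
      return std::make_unique<Expr>(Expr{0, std::move(L), std::move(R)});
    }

    int main() {
      auto E = add(add(add(leaf('a'), leaf('b')), leaf('c')), leaf('d'));
      std::vector<char> Ops;
      const Expr *Cur = E.get();
      while (Cur->Lhs) {          // canonical form: only the LHS keeps chaining
        Ops.push_back(Cur->Rhs->Leaf);
        Cur = Cur->Lhs.get();
      }
      Ops.push_back(Cur->Leaf);
      for (auto It = Ops.rbegin(); It != Ops.rend(); ++It)
        std::cout << *It;
      std::cout << "\n"; // abcd
      return 0;
    }

Calling one N-ary combine on the gathered operands avoids the N-1 separate getAddExpr calls the simple recursive translation would make.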
- const SCEV *RHS = getSCEV(U->getOperand(1)); - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getMinusSCEV(LHS, RHS, Flags)); + if (BO->Opcode == Instruction::Sub) + AddOps.push_back(getNegativeSCEV(getSCEV(BO->RHS))); else - AddOps.push_back(getAddExpr(LHS, RHS, Flags)); - break; - } + AddOps.push_back(getSCEV(BO->RHS)); - if (Opcode == Instruction::Sub) - AddOps.push_back(getNegativeSCEV(RHS)); - else - AddOps.push_back(RHS); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || (NewBO->Opcode != Instruction::Add && + NewBO->Opcode != Instruction::Sub)) { + AddOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getAddExpr(AddOps); } - return getAddExpr(AddOps); - } - case Instruction::Mul: { - SmallVector<const SCEV *, 4> MulOps; - for (Value *Op = U;; Op = U->getOperand(0)) { - U = dyn_cast<Operator>(Op); - if (!U || U->getOpcode() != Instruction::Mul) { - assert(Op != V && "V should be a mul"); - MulOps.push_back(getSCEV(Op)); - break; - } + case Instruction::Mul: { + SmallVector<const SCEV *, 4> MulOps; + do { + if (BO->Op) { + if (auto *OpSCEV = getExistingSCEV(BO->Op)) { + MulOps.push_back(OpSCEV); + break; + } - if (auto *OpSCEV = getExistingSCEV(U)) { - MulOps.push_back(OpSCEV); - break; - } + SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(BO->Op); + if (Flags != SCEV::FlagAnyWrap) { + MulOps.push_back( + getMulExpr(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags)); + break; + } + } - SCEV::NoWrapFlags Flags = getNoWrapFlagsFromUB(U); - if (Flags != SCEV::FlagAnyWrap) { - MulOps.push_back(getMulExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1)), Flags)); - break; + MulOps.push_back(getSCEV(BO->RHS)); + auto NewBO = MatchBinaryOp(BO->LHS, DT); + if (!NewBO || NewBO->Opcode != Instruction::Mul) { + MulOps.push_back(getSCEV(BO->LHS)); + break; + } + BO = NewBO; + } while (true); + + return getMulExpr(MulOps); + } + case Instruction::UDiv: + return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::Sub: { + SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; + if (BO->Op) + Flags = getNoWrapFlagsFromUB(BO->Op); + return getMinusSCEV(getSCEV(BO->LHS), getSCEV(BO->RHS), Flags); + } + case Instruction::And: + // For an expression like x&255 that merely masks off the high bits, + // use zext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + if (CI->isNullValue()) + return getSCEV(BO->RHS); + if (CI->isAllOnesValue()) + return getSCEV(BO->LHS); + const APInt &A = CI->getValue(); + + // Instcombine's ShrinkDemandedConstant may strip bits out of + // constants, obscuring what would otherwise be a low-bits mask. + // Use computeKnownBits to compute what ShrinkDemandedConstant + // knew about to reconstruct a low-bits mask value. 
+ unsigned LZ = A.countLeadingZeros(); + unsigned TZ = A.countTrailingZeros(); + unsigned BitWidth = A.getBitWidth(); + APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); + computeKnownBits(BO->LHS, KnownZero, KnownOne, getDataLayout(), + 0, &AC, nullptr, &DT); + + APInt EffectiveMask = + APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); + if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { + const SCEV *MulCount = getConstant(ConstantInt::get( + getContext(), APInt::getOneBitSet(BitWidth, TZ))); + return getMulExpr( + getZeroExtendExpr( + getTruncateExpr( + getUDivExactExpr(getSCEV(BO->LHS), MulCount), + IntegerType::get(getContext(), BitWidth - LZ - TZ)), + BO->LHS->getType()), + MulCount); + } } + break; - MulOps.push_back(getSCEV(U->getOperand(1))); - } - return getMulExpr(MulOps); - } - case Instruction::UDiv: - return getUDivExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - case Instruction::Sub: - return getMinusSCEV(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)), - getNoWrapFlagsFromUB(U)); - case Instruction::And: - // For an expression like x&255 that merely masks off the high bits, - // use zext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - if (CI->isNullValue()) - return getSCEV(U->getOperand(1)); - if (CI->isAllOnesValue()) - return getSCEV(U->getOperand(0)); - const APInt &A = CI->getValue(); - - // Instcombine's ShrinkDemandedConstant may strip bits out of - // constants, obscuring what would otherwise be a low-bits mask. - // Use computeKnownBits to compute what ShrinkDemandedConstant - // knew about to reconstruct a low-bits mask value. - unsigned LZ = A.countLeadingZeros(); - unsigned TZ = A.countTrailingZeros(); - unsigned BitWidth = A.getBitWidth(); - APInt KnownZero(BitWidth, 0), KnownOne(BitWidth, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, getDataLayout(), - 0, &AC, nullptr, &DT); - - APInt EffectiveMask = - APInt::getLowBitsSet(BitWidth, BitWidth - LZ - TZ).shl(TZ); - if ((LZ != 0 || TZ != 0) && !((~A & ~KnownZero) & EffectiveMask)) { - const SCEV *MulCount = getConstant( - ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, TZ))); - return getMulExpr( - getZeroExtendExpr( - getTruncateExpr( - getUDivExactExpr(getSCEV(U->getOperand(0)), MulCount), - IntegerType::get(getContext(), BitWidth - LZ - TZ)), - U->getType()), - MulCount); + case Instruction::Or: + // If the RHS of the Or is a constant, we may have something like: + // X*4+1 which got turned into X*4|1. Handle this as an Add so loop + // optimizations will transparently handle this case. + // + // In order for this transformation to be safe, the LHS must be of the + // form X*(2^n) and the Or constant must be less than 2^n. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + const SCEV *LHS = getSCEV(BO->LHS); + const APInt &CIVal = CI->getValue(); + if (GetMinTrailingZeros(LHS) >= + (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { + // Build a plain add SCEV. + const SCEV *S = getAddExpr(LHS, getSCEV(CI)); + // If the LHS of the add was an addrec and it has no-wrap flags, + // transfer the no-wrap flags, since an or won't introduce a wrap. 
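A quick numeric check of the And-case rewrite above, which models x & C as zext(trunc(x /u 2^TZ)) * 2^TZ. Here C = 0xF0 in 8 bits (TZ = 4, LZ = 0), and x is restricted to values whose low TZ bits are zero, the situation the computeKnownBits query is used to establish:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint8_t A = 0xF0;
      for (unsigned X = 0; X < 256; X += 16) {      // low 4 bits known zero
        uint8_t Div = uint8_t(X) / 16;              // exact udiv by 2^TZ
        uint8_t Trunc = Div & 0xF;                  // trunc to i4
        uint8_t Rewritten = uint8_t(Trunc * 16);    // zext and multiply back
        if (Rewritten != (uint8_t(X) & A)) {
          std::printf("mismatch at %u\n", X);
          return 1;
        }
      }
      std::puts("rewrite holds for all x with low bits zero");
      return 0;
    }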
+ if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { + const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); + const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( + OldAR->getNoWrapFlags()); + } + return S; + } } - } - break; + break; - case Instruction::Or: - // If the RHS of the Or is a constant, we may have something like: - // X*4+1 which got turned into X*4|1. Handle this as an Add so loop - // optimizations will transparently handle this case. - // - // In order for this transformation to be safe, the LHS must be of the - // form X*(2^n) and the Or constant must be less than 2^n. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - const SCEV *LHS = getSCEV(U->getOperand(0)); - const APInt &CIVal = CI->getValue(); - if (GetMinTrailingZeros(LHS) >= - (CIVal.getBitWidth() - CIVal.countLeadingZeros())) { - // Build a plain add SCEV. - const SCEV *S = getAddExpr(LHS, getSCEV(CI)); - // If the LHS of the add was an addrec and it has no-wrap flags, - // transfer the no-wrap flags, since an or won't introduce a wrap. - if (const SCEVAddRecExpr *NewAR = dyn_cast<SCEVAddRecExpr>(S)) { - const SCEVAddRecExpr *OldAR = cast<SCEVAddRecExpr>(LHS); - const_cast<SCEVAddRecExpr *>(NewAR)->setNoWrapFlags( - OldAR->getNoWrapFlags()); - } - return S; + case Instruction::Xor: + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) { + // If the RHS of xor is -1, then this is a not operation. + if (CI->isAllOnesValue()) + return getNotSCEV(getSCEV(BO->LHS)); + + // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. + // This is a variant of the check for xor with -1, and it handles + // the case where instcombine has trimmed non-demanded bits out + // of an xor with -1. + if (auto *LBO = dyn_cast<BinaryOperator>(BO->LHS)) + if (ConstantInt *LCI = dyn_cast<ConstantInt>(LBO->getOperand(1))) + if (LBO->getOpcode() == Instruction::And && + LCI->getValue() == CI->getValue()) + if (const SCEVZeroExtendExpr *Z = + dyn_cast<SCEVZeroExtendExpr>(getSCEV(BO->LHS))) { + Type *UTy = BO->LHS->getType(); + const SCEV *Z0 = Z->getOperand(); + Type *Z0Ty = Z0->getType(); + unsigned Z0TySize = getTypeSizeInBits(Z0Ty); + + // If C is a low-bits mask, the zero extend is serving to + // mask off the high bits. Complement the operand and + // re-apply the zext. + if (APIntOps::isMask(Z0TySize, CI->getValue())) + return getZeroExtendExpr(getNotSCEV(Z0), UTy); + + // If C is a single bit, it may be in the sign-bit position + // before the zero-extend. In this case, represent the xor + // using an add, which is equivalent, and re-apply the zext. + APInt Trunc = CI->getValue().trunc(Z0TySize); + if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && + Trunc.isSignBit()) + return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), + UTy); + } } - } - break; - case Instruction::Xor: - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) { - // If the RHS of the xor is a signbit, then this is just an add. - // Instcombine turns add of signbit into xor as a strength reduction step. - if (CI->getValue().isSignBit()) - return getAddExpr(getSCEV(U->getOperand(0)), - getSCEV(U->getOperand(1))); - - // If the RHS of xor is -1, then this is a not operation. - if (CI->isAllOnesValue()) - return getNotSCEV(getSCEV(U->getOperand(0))); - - // Model xor(and(x, C), C) as and(~x, C), if C is a low-bits mask. - // This is a variant of the check for xor with -1, and it handles - // the case where instcombine has trimmed non-demanded bits out - // of an xor with -1. 
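The safety condition for the Or-to-add rewrite above can also be confirmed numerically: when GetMinTrailingZeros of the LHS covers every significant bit of the constant, no set bits overlap, so x | c equals x + c:

    #include <cstdint>
    #include <cstdio>

    int main() {
      for (uint32_t N = 0; N < 1000; ++N) {
        uint32_t X = 4 * N;          // GetMinTrailingZeros(LHS) >= 2
        uint32_t C = 1;              // fits in the low 2 bits
        if ((X | C) != (X + C)) {
          std::printf("mismatch at n=%u\n", N);
          return 1;
        }
      }
      std::puts("x|1 == x+1 for all x = 4n");
      return 0;
    }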
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(U->getOperand(0))) - if (ConstantInt *LCI = dyn_cast<ConstantInt>(BO->getOperand(1))) - if (BO->getOpcode() == Instruction::And && - LCI->getValue() == CI->getValue()) - if (const SCEVZeroExtendExpr *Z = - dyn_cast<SCEVZeroExtendExpr>(getSCEV(U->getOperand(0)))) { - Type *UTy = U->getType(); - const SCEV *Z0 = Z->getOperand(); - Type *Z0Ty = Z0->getType(); - unsigned Z0TySize = getTypeSizeInBits(Z0Ty); - - // If C is a low-bits mask, the zero extend is serving to - // mask off the high bits. Complement the operand and - // re-apply the zext. - if (APIntOps::isMask(Z0TySize, CI->getValue())) - return getZeroExtendExpr(getNotSCEV(Z0), UTy); - - // If C is a single bit, it may be in the sign-bit position - // before the zero-extend. In this case, represent the xor - // using an add, which is equivalent, and re-apply the zext. - APInt Trunc = CI->getValue().trunc(Z0TySize); - if (Trunc.zext(getTypeSizeInBits(UTy)) == CI->getValue() && - Trunc.isSignBit()) - return getZeroExtendExpr(getAddExpr(Z0, getConstant(Trunc)), - UTy); - } - } - break; + break; case Instruction::Shl: // Turn shift left of a constant amount into a multiply. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); + if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) { + uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth(); // If the shift count is not less than the bitwidth, the result of // the shift is undefined. Don't try to analyze it, because the @@ -4700,58 +5154,43 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html // and http://reviews.llvm.org/D8890 . auto Flags = SCEV::FlagAnyWrap; - if (SA->getValue().ult(BitWidth - 1)) Flags = getNoWrapFlagsFromUB(U); + if (BO->Op && SA->getValue().ult(BitWidth - 1)) + Flags = getNoWrapFlagsFromUB(BO->Op); Constant *X = ConstantInt::get(getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getMulExpr(getSCEV(U->getOperand(0)), getSCEV(X), Flags); + return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags); } break; - case Instruction::LShr: - // Turn logical shift right of a constant into a unsigned divide. - if (ConstantInt *SA = dyn_cast<ConstantInt>(U->getOperand(1))) { - uint32_t BitWidth = cast<IntegerType>(U->getType())->getBitWidth(); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (SA->getValue().uge(BitWidth)) - break; + case Instruction::AShr: + // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. + if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS)) + if (Operator *L = dyn_cast<Operator>(BO->LHS)) + if (L->getOpcode() == Instruction::Shl && + L->getOperand(1) == BO->RHS) { + uint64_t BitWidth = getTypeSizeInBits(BO->LHS->getType()); + + // If the shift count is not less than the bitwidth, the result of + // the shift is undefined. Don't try to analyze it, because the + // resolution chosen here may differ from the resolution chosen in + // other parts of the compiler. 
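An exhaustive 8-bit check of the two Xor identities used above, x ^ -1 == ~x and (x & C) ^ C == ~x & C for a low-bits mask C (here C = 0x3F, an arbitrary choice):

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint8_t C = 0x3F; // low-bits mask
      for (unsigned I = 0; I < 256; ++I) {
        uint8_t X = uint8_t(I);
        if (uint8_t(X ^ 0xFF) != uint8_t(~X)) return 1;
        if (uint8_t((X & C) ^ C) != uint8_t(~X & C)) return 2;
      }
      std::puts("both xor identities hold for all 8-bit x");
      return 0;
    }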
+ if (CI->getValue().uge(BitWidth)) + break; - Constant *X = ConstantInt::get(getContext(), - APInt::getOneBitSet(BitWidth, SA->getZExtValue())); - return getUDivExpr(getSCEV(U->getOperand(0)), getSCEV(X)); + uint64_t Amt = BitWidth - CI->getZExtValue(); + if (Amt == BitWidth) + return getSCEV(L->getOperand(0)); // shift by zero --> noop + return getSignExtendExpr( + getTruncateExpr(getSCEV(L->getOperand(0)), + IntegerType::get(getContext(), Amt)), + BO->LHS->getType()); + } + break; } - break; - - case Instruction::AShr: - // For a two-shift sext-inreg, use sext(trunc(x)) as the SCEV expression. - if (ConstantInt *CI = dyn_cast<ConstantInt>(U->getOperand(1))) - if (Operator *L = dyn_cast<Operator>(U->getOperand(0))) - if (L->getOpcode() == Instruction::Shl && - L->getOperand(1) == U->getOperand(1)) { - uint64_t BitWidth = getTypeSizeInBits(U->getType()); - - // If the shift count is not less than the bitwidth, the result of - // the shift is undefined. Don't try to analyze it, because the - // resolution chosen here may differ from the resolution chosen in - // other parts of the compiler. - if (CI->getValue().uge(BitWidth)) - break; - - uint64_t Amt = BitWidth - CI->getZExtValue(); - if (Amt == BitWidth) - return getSCEV(L->getOperand(0)); // shift by zero --> noop - return - getSignExtendExpr(getTruncateExpr(getSCEV(L->getOperand(0)), - IntegerType::get(getContext(), - Amt)), - U->getType()); - } - break; + } + switch (U->getOpcode()) { case Instruction::Trunc: return getTruncateExpr(getSCEV(U->getOperand(0)), U->getType()); @@ -4786,8 +5225,12 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { if (isa<Instruction>(U)) return createNodeForSelectOrPHI(cast<Instruction>(U), U->getOperand(0), U->getOperand(1), U->getOperand(2)); + break; - default: // We cannot analyze this expression. + case Instruction::Call: + case Instruction::Invoke: + if (Value *RV = CallSite(U).getReturnedArgOperand()) + return getSCEV(RV); break; } @@ -4808,16 +5251,6 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) { return 0; } -/// getSmallConstantTripCount - Returns the maximum trip count of this loop as a -/// normal unsigned value. Returns 0 if the trip count is unknown or not -/// constant. Will also return 0 if the maximum trip count is very large (>= -/// 2^32). -/// -/// This "trip count" assumes that control exits via ExitingBlock. More -/// precisely, it is the number of times that control may reach ExitingBlock -/// before taking the branch. For loops with multiple exits, it may not be the -/// number times that the loop header executes because the loop may exit -/// prematurely via another branch. unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L, BasicBlock *ExitingBlock) { assert(ExitingBlock && "Must pass a non-null exiting block!"); @@ -4846,10 +5279,10 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { return 0; } -/// getSmallConstantTripMultiple - Returns the largest constant divisor of the -/// trip count of this loop as a normal unsigned value, if possible. This -/// means that the actual trip count is always a multiple of the returned -/// value (don't forget the trip count could very well be zero as well!). +/// Returns the largest constant divisor of the trip count of this loop as a +/// normal unsigned value, if possible. This means that the actual trip count is +/// always a multiple of the returned value (don't forget the trip count could +/// very well be zero as well!). 
/// /// Returns 1 if the trip count is unknown or not guaranteed to be the /// multiple of a constant (which is also the case if the trip count is simply @@ -4891,37 +5324,30 @@ ScalarEvolution::getSmallConstantTripMultiple(Loop *L, return (unsigned)Result->getZExtValue(); } -// getExitCount - Get the expression for the number of loop iterations for which -// this loop is guaranteed not to exit via ExitingBlock. Otherwise return -// SCEVCouldNotCompute. +/// Get the expression for the number of loop iterations for which this loop is +/// guaranteed not to exit via ExitingBlock. Otherwise return +/// SCEVCouldNotCompute. const SCEV *ScalarEvolution::getExitCount(Loop *L, BasicBlock *ExitingBlock) { return getBackedgeTakenInfo(L).getExact(ExitingBlock, this); } -/// getBackedgeTakenCount - If the specified loop has a predictable -/// backedge-taken count, return it, otherwise return a SCEVCouldNotCompute -/// object. The backedge-taken count is the number of times the loop header -/// will be branched to from within the loop. This is one less than the -/// trip count of the loop, since it doesn't count the first iteration, -/// when the header is branched to from outside the loop. -/// -/// Note that it is not valid to call this method on a loop without a -/// loop-invariant backedge-taken count (see -/// hasLoopInvariantBackedgeTakenCount). -/// +const SCEV * +ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L, + SCEVUnionPredicate &Preds) { + return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds); +} + const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getExact(this); } -/// getMaxBackedgeTakenCount - Similar to getBackedgeTakenCount, except -/// return the least SCEV value that is known never to be less than the -/// actual backedge taken count. +/// Similar to getBackedgeTakenCount, except return the least SCEV value that is +/// known never to be less than the actual backedge taken count. const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) { return getBackedgeTakenInfo(L).getMax(this); } -/// PushLoopPHIs - Push PHI nodes in the header of the given loop -/// onto the given Worklist. +/// Push PHI nodes in the header of the given loop onto the given Worklist. static void PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { BasicBlock *Header = L->getHeader(); @@ -4933,6 +5359,23 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { } const ScalarEvolution::BackedgeTakenInfo & +ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) { + auto &BTI = getBackedgeTakenInfo(L); + if (BTI.hasFullInfo()) + return BTI; + + auto Pair = PredicatedBackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); + + if (!Pair.second) + return Pair.first->second; + + BackedgeTakenInfo Result = + computeBackedgeTakenCount(L, /*AllowPredicates=*/true); + + return PredicatedBackedgeTakenCounts.find(L)->second = Result; +} + +const ScalarEvolution::BackedgeTakenInfo & ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // Initially insert an invalid entry for this loop. If the insertion // succeeds, proceed to actually compute a backedge-taken count and @@ -4940,7 +5383,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { // code elsewhere that it shouldn't attempt to request a new // backedge-taken count, which could result in infinite recursion. 
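The comment above describes a re-entrancy guard: an invalid entry is inserted before the computation starts so that a recursive query terminates instead of recursing forever. The same pattern in miniature, with an ordinary map and ints standing in for the real types:

    #include <iostream>
    #include <map>

    // Seed the cache with a sentinel before computing, so a re-entrant query
    // for the same key sees the sentinel instead of recomputing.
    static std::map<int, int> Counts; // -1 stands in for "invalid entry"

    static int compute(int Key);

    static int getCount(int Key) {
      auto Pair = Counts.insert({Key, -1}); // invalid entry first
      if (!Pair.second)
        return Pair.first->second; // cached result or the sentinel
      int Result = compute(Key);   // may query getCount() re-entrantly
      return Counts[Key] = Result;
    }

    static int compute(int Key) { return Key == 0 ? 0 : getCount(Key - 1) + 1; }

    int main() {
      std::cout << getCount(3) << "\n"; // 3
      return 0;
    }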
std::pair<DenseMap<const Loop *, BackedgeTakenInfo>::iterator, bool> Pair = - BackedgeTakenCounts.insert(std::make_pair(L, BackedgeTakenInfo())); + BackedgeTakenCounts.insert({L, BackedgeTakenInfo()}); if (!Pair.second) return Pair.first->second; @@ -5007,17 +5450,19 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { return BackedgeTakenCounts.find(L)->second = Result; } -/// forgetLoop - This method should be called by the client when it has -/// changed a loop in a way that may effect ScalarEvolution's ability to -/// compute a trip count, or if the loop is deleted. void ScalarEvolution::forgetLoop(const Loop *L) { // Drop any stored trip count value. - DenseMap<const Loop*, BackedgeTakenInfo>::iterator BTCPos = - BackedgeTakenCounts.find(L); - if (BTCPos != BackedgeTakenCounts.end()) { - BTCPos->second.clear(); - BackedgeTakenCounts.erase(BTCPos); - } + auto RemoveLoopFromBackedgeMap = + [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { + auto BTCPos = Map.find(L); + if (BTCPos != Map.end()) { + BTCPos->second.clear(); + Map.erase(BTCPos); + } + }; + + RemoveLoopFromBackedgeMap(BackedgeTakenCounts); + RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts); // Drop information about expressions based on loop-header PHIs. SmallVector<Instruction *, 16> Worklist; @@ -5043,13 +5488,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) { // Forget all contained loops too, to avoid dangling entries in the // ValuesAtScopes map. - for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) - forgetLoop(*I); + for (Loop *I : *L) + forgetLoop(I); + + LoopHasNoAbnormalExits.erase(L); } -/// forgetValue - This method should be called by the client when it has -/// changed a value in a way that may effect its value, or which may -/// disconnect it from a def-use chain linking it to a loop. void ScalarEvolution::forgetValue(Value *V) { Instruction *I = dyn_cast<Instruction>(V); if (!I) return; @@ -5077,16 +5521,17 @@ void ScalarEvolution::forgetValue(Value *V) { } } -/// getExact - Get the exact loop backedge taken count considering all loop -/// exits. A computable result can only be returned for loops with a single -/// exit. Returning the minimum taken count among all exits is incorrect -/// because one of the loop's exit limit's may have been skipped. HowFarToZero -/// assumes that the limit of each loop test is never skipped. This is a valid -/// assumption as long as the loop exits via that test. For precise results, it -/// is the caller's responsibility to specify the relevant loop exit using +/// Get the exact loop backedge taken count considering all loop exits. A +/// computable result can only be returned for loops with a single exit. +/// Returning the minimum taken count among all exits is incorrect because one +/// of the loop's exit limits may have been skipped. howFarToZero assumes that +/// the limit of each loop test is never skipped. This is a valid assumption as +/// long as the loop exits via that test. For precise results, it is the +/// caller's responsibility to specify the relevant loop exit using /// getExact(ExitingBlock, SE). const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { +ScalarEvolution::BackedgeTakenInfo::getExact( + ScalarEvolution *SE, SCEVUnionPredicate *Preds) const { // If any exits were not computable, the loop is not computable.
if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); @@ -5095,36 +5540,42 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE) const { assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); const SCEV *BECount = nullptr; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - assert(ENT->ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); + for (auto &ENT : ExitNotTaken) { + assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV"); if (!BECount) - BECount = ENT->ExactNotTaken; - else if (BECount != ENT->ExactNotTaken) + BECount = ENT.ExactNotTaken; + else if (BECount != ENT.ExactNotTaken) return SE->getCouldNotCompute(); + if (Preds && ENT.getPred()) + Preds->add(ENT.getPred()); + + assert((Preds || ENT.hasAlwaysTruePred()) && + "Predicate should be always true!"); } + assert(BECount && "Invalid not taken count for loop exit"); return BECount; } -/// getExact - Get the exact not taken count for this loop exit. +/// Get the exact not taken count for this loop exit. const SCEV * ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock, ScalarEvolution *SE) const { - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) + return ENT.ExactNotTaken; - if (ENT->ExitingBlock == ExitingBlock) - return ENT->ExactNotTaken; - } return SE->getCouldNotCompute(); } /// getMax - Get the max backedge taken count for the loop. const SCEV * ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { + for (auto &ENT : ExitNotTaken) + if (!ENT.hasAlwaysTruePred()) + return SE->getCouldNotCompute(); + return Max ? Max : SE->getCouldNotCompute(); } @@ -5136,22 +5587,19 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, if (!ExitNotTaken.ExitingBlock) return false; - for (const ExitNotTakenInfo *ENT = &ExitNotTaken; - ENT != nullptr; ENT = ENT->getNextExit()) { - - if (ENT->ExactNotTaken != SE->getCouldNotCompute() - && SE->hasOperand(ENT->ExactNotTaken, S)) { + for (auto &ENT : ExitNotTaken) + if (ENT.ExactNotTaken != SE->getCouldNotCompute() && + SE->hasOperand(ENT.ExactNotTaken, S)) return true; - } - } + return false; } /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each /// computable exit into a persistent ExitNotTakenInfo array. ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( - SmallVectorImpl< std::pair<BasicBlock *, const SCEV *> > &ExitCounts, - bool Complete, const SCEV *MaxCount) : Max(MaxCount) { + SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount) + : Max(MaxCount) { if (!Complete) ExitNotTaken.setIncomplete(); @@ -5159,36 +5607,63 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( unsigned NumExits = ExitCounts.size(); if (NumExits == 0) return; - ExitNotTaken.ExitingBlock = ExitCounts[0].first; - ExitNotTaken.ExactNotTaken = ExitCounts[0].second; - if (NumExits == 1) return; + ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; + ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; + + // Determine the number of ExitNotTakenExtras structures that we need. 
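The agreement rule implemented by getExact above can be restated over plain integers: an exact count exists only when every computable exit reports the same not-taken count. In this sketch nullopt stands in for SCEVCouldNotCompute and the types are placeholders:

    #include <iostream>
    #include <optional>
    #include <vector>

    static std::optional<unsigned> getExact(const std::vector<unsigned> &Exits) {
      std::optional<unsigned> BECount;
      for (unsigned E : Exits) {
        if (!BECount)
          BECount = E;              // first computable exit
        else if (*BECount != E)
          return std::nullopt;      // disagreement: could not compute
      }
      return BECount;
    }

    int main() {
      std::cout << getExact({7, 7}).value_or(~0u) << "\n"; // 7
      std::cout << (getExact({7, 9}) ? "exact" : "could not compute") << "\n";
      return 0;
    }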
+ unsigned ExtraInfoSize = 0; + if (NumExits > 1) + ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), + ExitCounts.end(), [](EdgeInfo &Entry) { + return !Entry.Pred.isAlwaysTrue(); + }); + else if (!ExitCounts[0].Pred.isAlwaysTrue()) + ExtraInfoSize = 1; + + ExitNotTakenExtras *ENT = nullptr; + + // Allocate the ExitNotTakenExtras structures and initialize the first + // element (ExitNotTaken). + if (ExtraInfoSize > 0) { + ENT = new ExitNotTakenExtras[ExtraInfoSize]; + ExitNotTaken.ExtraInfo = &ENT[0]; + *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); + } + + if (NumExits == 1) + return; + + assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); + + auto &Exits = ExitNotTaken.ExtraInfo->Exits; // Handle the rare case of multiple computable exits. - ExitNotTakenInfo *ENT = new ExitNotTakenInfo[NumExits-1]; + for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { + ExitNotTakenExtras *Ptr = nullptr; + if (!ExitCounts[i].Pred.isAlwaysTrue()) { + Ptr = &ENT[PredPos++]; + Ptr->Pred = std::move(ExitCounts[i].Pred); + } - ExitNotTakenInfo *PrevENT = &ExitNotTaken; - for (unsigned i = 1; i < NumExits; ++i, PrevENT = ENT, ++ENT) { - PrevENT->setNextExit(ENT); - ENT->ExitingBlock = ExitCounts[i].first; - ENT->ExactNotTaken = ExitCounts[i].second; + Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); } } -/// clear - Invalidate this result and free the ExitNotTakenInfo array. +/// Invalidate this result and free the ExitNotTakenInfo array. void ScalarEvolution::BackedgeTakenInfo::clear() { ExitNotTaken.ExitingBlock = nullptr; ExitNotTaken.ExactNotTaken = nullptr; - delete[] ExitNotTaken.getNextExit(); + delete[] ExitNotTaken.ExtraInfo; } -/// computeBackedgeTakenCount - Compute the number of times the backedge -/// of the specified loop will execute. +/// Compute the number of times the backedge of the specified loop will execute. ScalarEvolution::BackedgeTakenInfo -ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { +ScalarEvolution::computeBackedgeTakenCount(const Loop *L, + bool AllowPredicates) { SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - SmallVector<std::pair<BasicBlock *, const SCEV *>, 4> ExitCounts; + SmallVector<EdgeInfo, 4> ExitCounts; bool CouldComputeBECount = true; BasicBlock *Latch = L->getLoopLatch(); // may be NULL. const SCEV *MustExitMaxBECount = nullptr; @@ -5196,9 +5671,13 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts // and compute maxBECount. + // Do a union of all the predicates here. for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) { BasicBlock *ExitBB = ExitingBlocks[i]; - ExitLimit EL = computeExitLimit(L, ExitBB); + ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); + + assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && + "Predicated exit limit when predicates are not allowed!"); // 1. For each exit that can be computed, add an entry to ExitCounts. // CouldComputeBECount is true only if all exits can be computed. @@ -5207,7 +5686,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { // we won't be able to compute an exact value for the loop. CouldComputeBECount = false; else - ExitCounts.push_back(std::make_pair(ExitBB, EL.Exact)); + ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); // 2. Derive the loop's MaxBECount from each exit's max number of // non-exiting iterations. 
Partition the loop exits into two kinds: @@ -5241,20 +5720,20 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L) { } ScalarEvolution::ExitLimit -ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { +ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, + bool AllowPredicates) { // Okay, we've chosen an exiting block. See what condition causes us to exit // at this block and remember the exit block and whether all other targets // lead to the loop header. bool MustExecuteLoopHeader = true; BasicBlock *Exit = nullptr; - for (succ_iterator SI = succ_begin(ExitingBlock), SE = succ_end(ExitingBlock); - SI != SE; ++SI) - if (!L->contains(*SI)) { + for (auto *SBB : successors(ExitingBlock)) + if (!L->contains(SBB)) { if (Exit) // Multiple exit successors. return getCouldNotCompute(); - Exit = *SI; - } else if (*SI != L->getHeader()) { + Exit = SBB; + } else if (SBB != L->getHeader()) { MustExecuteLoopHeader = false; } @@ -5307,9 +5786,9 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); // Proceed to the next level to examine the exit condition expression. - return computeExitLimitFromCond(L, BI->getCondition(), BI->getSuccessor(0), - BI->getSuccessor(1), - /*ControlsExit=*/IsOnlyExit); + return computeExitLimitFromCond( + L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1), + /*ControlsExit=*/IsOnlyExit, AllowPredicates); } if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) @@ -5319,29 +5798,24 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock) { return getCouldNotCompute(); } -/// computeExitLimitFromCond - Compute the number of times the -/// backedge of the specified loop will execute if its exit condition -/// were a conditional branch of ExitCond, TBB, and FBB. -/// -/// @param ControlsExit is true if ExitCond directly controls the exit -/// branch. In this case, we can assume that the loop exits only if the -/// condition is true and can infer that failing to meet the condition prior to -/// integer wraparound results in undefined behavior. ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // Check if the controlling expression for this loop is an And or Or. if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) { if (BO->getOpcode() == Instruction::And) { // Recurse on the operands of the and. bool EitherMayExit = L->contains(TBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5368,6 +5842,9 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); // There are cases (e.g. PR26207) where computeExitLimitFromCond is able // to be more aggressive when computing BECount than when computing // MaxBECount. 
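The umin combination used for an "exit if (a && b)" test above matches what such a loop actually does. A tiny concrete check, with sub-condition counts 12 and 7 chosen arbitrarily:

    #include <cstdio>

    int main() {
      unsigned I = 0;
      // Exit when (I != 12 && I != 7) turns false, i.e. when either operand
      // alone would exit; the trip count is the smaller sub-count.
      while (I != 12 && I != 7)
        ++I;
      std::printf("exited after %u iterations == umin(12, 7)\n", I); // 7
      return 0;
    }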
In these cases it is possible for EL0.Exact and EL1.Exact @@ -5376,15 +5853,17 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, !isa<SCEVCouldNotCompute>(BECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount, NP); } if (BO->getOpcode() == Instruction::Or) { // Recurse on the operands of the or. bool EitherMayExit = L->contains(FBB); ExitLimit EL0 = computeExitLimitFromCond(L, BO->getOperand(0), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); ExitLimit EL1 = computeExitLimitFromCond(L, BO->getOperand(1), TBB, FBB, - ControlsExit && !EitherMayExit); + ControlsExit && !EitherMayExit, + AllowPredicates); const SCEV *BECount = getCouldNotCompute(); const SCEV *MaxBECount = getCouldNotCompute(); if (EitherMayExit) { @@ -5411,14 +5890,25 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L, BECount = EL0.Exact; } - return ExitLimit(BECount, MaxBECount); + SCEVUnionPredicate NP; + NP.add(&EL0.Pred); + NP.add(&EL1.Pred); + return ExitLimit(BECount, MaxBECount, NP); } } // With an icmp, it may be feasible to compute an exact backedge-taken count. // Proceed to the next level to examine the icmp. - if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) - return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) { + ExitLimit EL = + computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit); + if (EL.hasFullInfo() || !AllowPredicates) + return EL; + + // Try again, but use SCEV predicates this time. + return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit, + /*AllowPredicates=*/true); + } // Check for a constant condition. These are normally stripped out by // SimplifyCFG, but ScalarEvolution may be used by a pass which wishes to @@ -5442,7 +5932,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, ICmpInst *ExitCond, BasicBlock *TBB, BasicBlock *FBB, - bool ControlsExit) { + bool ControlsExit, + bool AllowPredicates) { // If the condition was exit on true, convert the condition to exit on false ICmpInst::Predicate Cond; @@ -5460,11 +5951,6 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ItCnt; } - ExitLimit ShiftEL = computeShiftCompareExitLimit( - ExitCond->getOperand(0), ExitCond->getOperand(1), L, Cond); - if (ShiftEL.hasAnyInfo()) - return ShiftEL; - const SCEV *LHS = getSCEV(ExitCond->getOperand(0)); const SCEV *RHS = getSCEV(ExitCond->getOperand(1)); @@ -5499,34 +5985,46 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, switch (Cond) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_EQ: { // while (X == Y) // Convert to: while (X-Y == 0) - ExitLimit EL = HowFarToNonZero(getMinusSCEV(LHS, RHS), L); + ExitLimit EL = howFarToNonZero(getMinusSCEV(LHS, RHS), L); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) bool IsSigned = Cond == ICmpInst::ICMP_SLT; - ExitLimit EL = HowManyLessThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) bool 
IsSigned = Cond == ICmpInst::ICMP_SGT; - ExitLimit EL = HowManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit); + ExitLimit EL = + howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, + AllowPredicates); if (EL.hasAnyInfo()) return EL; break; } default: break; } - return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + auto *ExhaustiveCount = + computeExitCountExhaustively(L, ExitCond, !L->contains(TBB)); + + if (!isa<SCEVCouldNotCompute>(ExhaustiveCount)) + return ExhaustiveCount; + + return computeShiftCompareExitLimit(ExitCond->getOperand(0), + ExitCond->getOperand(1), L, Cond); } ScalarEvolution::ExitLimit @@ -5546,7 +6044,7 @@ ScalarEvolution::computeExitLimitFromSingleExitSwitch(const Loop *L, const SCEV *RHS = getConstant(Switch->findCaseDest(ExitingBlock)); // while (X != Y) --> while (X-Y != 0) - ExitLimit EL = HowFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); + ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit); if (EL.hasAnyInfo()) return EL; @@ -5563,9 +6061,8 @@ EvaluateConstantChrecAtConstant(const SCEVAddRecExpr *AddRec, ConstantInt *C, return cast<SCEVConstant>(Val)->getValue(); } -/// computeLoadConstantCompareExitLimit - Given an exit condition of -/// 'icmp op load X, cst', try to see if we can compute the backedge -/// execution count. +/// Given an exit condition of 'icmp op load X, cst', try to see if we can +/// compute the backedge execution count. ScalarEvolution::ExitLimit ScalarEvolution::computeLoadConstantCompareExitLimit( LoadInst *LI, @@ -5781,14 +6278,15 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit( unsigned BitWidth = getTypeSizeInBits(RHS->getType()); const SCEV *UpperBound = getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); - return ExitLimit(getCouldNotCompute(), UpperBound); + SCEVUnionPredicate P; + return ExitLimit(getCouldNotCompute(), UpperBound, P); } return getCouldNotCompute(); } -/// CanConstantFold - Return true if we can constant fold an instruction of the -/// specified type, assuming that all operands were constants. +/// Return true if we can constant fold an instruction of the specified type, +/// assuming that all operands were constants. static bool CanConstantFold(const Instruction *I) { if (isa<BinaryOperator>(I) || isa<CmpInst>(I) || isa<SelectInst>(I) || isa<CastInst>(I) || isa<GetElementPtrInst>(I) || @@ -5916,10 +6414,9 @@ static Constant *EvaluateExpression(Value *V, const Loop *L, Operands[1], DL, TLI); if (LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - return ConstantFoldLoadFromConstPtr(Operands[0], DL); + return ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } - return ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, DL, - TLI); + return ConstantFoldInstOperands(I, Operands, DL, TLI); } @@ -6107,16 +6604,6 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L, return getCouldNotCompute(); } -/// getSCEVAtScope - Return a SCEV expression for the specified value -/// at the specified scope in the program. The L value specifies a loop -/// nest to evaluate the expression at, where null is the top-level or a -/// specified loop is immediately inside of the loop. -/// -/// This method can be used to compute the exit value for a variable defined -/// in a loop by querying what the value will hold in the parent loop. -/// -/// In the case that a relevant loop exit value cannot be computed, the -/// original value V is returned. 
const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) { SmallVector<std::pair<const Loop *, const SCEV *>, 2> &Values = ValuesAtScopes[V]; @@ -6305,10 +6792,9 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { Operands[1], DL, &TLI); else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) { if (!LI->isVolatile()) - C = ConstantFoldLoadFromConstPtr(Operands[0], DL); + C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL); } else - C = ConstantFoldInstOperands(I->getOpcode(), I->getType(), Operands, - DL, &TLI); + C = ConstantFoldInstOperands(I, Operands, DL, &TLI); if (!C) return V; return getSCEV(C); } @@ -6428,14 +6914,11 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) { llvm_unreachable("Unknown SCEV type!"); } -/// getSCEVAtScope - This is a convenience function which does -/// getSCEVAtScope(getSCEV(V), L). const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) { return getSCEVAtScope(getSCEV(V), L); } -/// SolveLinEquationWithOverflow - Finds the minimum unsigned root of the -/// following equation: +/// Finds the minimum unsigned root of the following equation: /// /// A * X = B (mod N) /// @@ -6482,11 +6965,11 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, return SE.getConstant(Result.trunc(BW)); } -/// SolveQuadraticEquation - Find the roots of the quadratic equation for the -/// given quadratic chrec {L,+,M,+,N}. This returns either the two roots (which -/// might be the same) or two SCEVCouldNotCompute objects. +/// Find the roots of the quadratic equation for the given quadratic chrec +/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or +/// two SCEVCouldNotCompute objects. /// -static std::pair<const SCEV *,const SCEV *> +static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); @@ -6494,10 +6977,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); // We currently can only solve this if the coefficients are constants. - if (!LC || !MC || !NC) { - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); - } + if (!LC || !MC || !NC) + return None; uint32_t BitWidth = LC->getAPInt().getBitWidth(); const APInt &L = LC->getAPInt(); @@ -6524,8 +7005,7 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { if (SqrtTerm.isNegative()) { // The loop is provably infinite. - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); + return None; } // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest @@ -6536,10 +7016,8 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { // The divisions must be performed as signed divisions. 
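// Clarifying note (not part of the patch): the code below applies the usual
// quadratic formula X = (-B +/- sqrt(B*B - 4*A*C)) / (2*A) in fixed-width
// APInt arithmetic. The divisions are signed (sdiv) because -B +/- sqrt(...)
// may be negative even though the chrec is analyzed as an unsigned quantity.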
APInt NegB(-B); APInt TwoA(A << 1); - if (TwoA.isMinValue()) { - const SCEV *CNC = SE.getCouldNotCompute(); - return std::make_pair(CNC, CNC); - } + if (TwoA.isMinValue()) + return None; LLVMContext &Context = SE.getContext(); @@ -6548,20 +7026,21 @@ SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { ConstantInt *Solution2 = ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); - return std::make_pair(SE.getConstant(Solution1), - SE.getConstant(Solution2)); + return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), + cast<SCEVConstant>(SE.getConstant(Solution2))); } // end APIntOps namespace } -/// HowFarToZero - Return the number of times a backedge comparing the specified -/// value to zero will execute. If not computable, return CouldNotCompute. -/// -/// This is only used for loops with a "x != y" exit test. The exit condition is -/// now expressed as a single expression, V = x-y. So the exit test is -/// effectively V != 0. We know and take advantage of the fact that this -/// expression only being used in a comparison by zero context. ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { +ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, + bool AllowPredicates) { + + // This is only used for loops with a "x != y" exit test. The exit condition + // is now expressed as a single expression, V = x-y. So the exit test is + // effectively V != 0. We know and take advantage of the fact that this + // expression only being used in a comparison by zero context. + + SCEVUnionPredicate P; // If the value is a constant if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) { // If the value is already zero, the branch will execute zero times. @@ -6570,31 +7049,33 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { } const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V); + if (!AddRec && AllowPredicates) + // Try to make this an AddRec using runtime tests, in the first X + // iterations of this loop, where X is the SCEV expression found by the + // algorithm below. + AddRec = convertSCEVToAddRecWithPredicates(V, L, P); + if (!AddRec || AddRec->getLoop() != L) return getCouldNotCompute(); // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - std::pair<const SCEV *,const SCEV *> Roots = - SolveQuadraticEquation(AddRec, *this); - const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first); - const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second); - if (R1 && R2) { + if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { + const SCEVConstant *R1 = Roots->first; + const SCEVConstant *R2 = Roots->second; // Pick the smallest positive root value. - if (ConstantInt *CB = - dyn_cast<ConstantInt>(ConstantExpr::getICmp(CmpInst::ICMP_ULT, - R1->getValue(), - R2->getValue()))) { + if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( + CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. + std::swap(R1, R2); // R1 is the minimum root now. // We can only use this value if the chrec ends up with an exact zero // value at this index. When solving for "X*X != 5", for example, we // should not accept a root of 2. const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); if (Val->isZero()) - return R1; // We found a quadratic root! 
+ return ExitLimit(R1, R1, P); // We found a quadratic root! } } return getCouldNotCompute(); @@ -6651,7 +7132,7 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { else MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount); + return ExitLimit(Distance, MaxBECount, P); } // As a special case, handle the instance where Step is a positive power of @@ -6704,7 +7185,9 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { auto *NarrowTy = IntegerType::get(getContext(), NarrowWidth); auto *WideTy = Distance->getType(); - return getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + const SCEV *Limit = + getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); + return ExitLimit(Limit, Limit, P); } } @@ -6713,24 +7196,24 @@ ScalarEvolution::HowFarToZero(const SCEV *V, const Loop *L, bool ControlsExit) { // compute the backedge count. In this case, the step may not divide the // distance, but we don't care because if the condition is "missed" the loop // will have undefined behavior due to wrapping. - if (ControlsExit && AddRec->getNoWrapFlags(SCEV::FlagNW)) { + if (ControlsExit && AddRec->hasNoSelfWrap() && + loopHasNoAbnormalExits(AddRec->getLoop())) { const SCEV *Exact = getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); - return ExitLimit(Exact, Exact); + return ExitLimit(Exact, Exact, P); } // Then, try to solve the above equation provided that Start is constant. - if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) - return SolveLinEquationWithOverflow(StepC->getAPInt(), -StartC->getAPInt(), - *this); + if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) { + const SCEV *E = SolveLinEquationWithOverflow( + StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); + return ExitLimit(E, E, P); + } return getCouldNotCompute(); } -/// HowFarToNonZero - Return the number of times a backedge checking the -/// specified value for nonzero will execute. If not computable, return -/// CouldNotCompute ScalarEvolution::ExitLimit -ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { +ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) { // Loops that look like: while (X == 0) are very strange indeed. We don't // handle them yet except for the trivial case. This could be expanded in the // future as needed. @@ -6748,33 +7231,27 @@ ScalarEvolution::HowFarToNonZero(const SCEV *V, const Loop *L) { return getCouldNotCompute(); } -/// getPredecessorWithUniqueSuccessorForBB - Return a predecessor of BB -/// (which may not be an immediate predecessor) which has exactly one -/// successor from which BB is reachable, or null if no such block is -/// found. -/// std::pair<BasicBlock *, BasicBlock *> ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { // If the block has a unique predecessor, then there is no path from the // predecessor to the block that does not go through the direct edge // from the predecessor to the block. if (BasicBlock *Pred = BB->getSinglePredecessor()) - return std::make_pair(Pred, BB); + return {Pred, BB}; // A loop's header is defined to be a block that dominates the loop. // If the header has a unique predecessor outside the loop, it must be // a block that has exactly one successor that can reach the loop. 
if (Loop *L = LI.getLoopFor(BB)) - return std::make_pair(L->getLoopPredecessor(), L->getHeader()); + return {L->getLoopPredecessor(), L->getHeader()}; - return std::pair<BasicBlock *, BasicBlock *>(); + return {nullptr, nullptr}; } -/// HasSameValue - SCEV structural equivalence is usually sufficient for -/// testing whether two expressions are equal, however for the purposes of -/// looking for a condition guarding a loop, it can be useful to be a little -/// more general, since a front-end may have replicated the controlling -/// expression. +/// SCEV structural equivalence is usually sufficient for testing whether two +/// expressions are equal, however for the purposes of looking for a condition +/// guarding a loop, it can be useful to be a little more general, since a +/// front-end may have replicated the controlling expression. /// static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. @@ -6800,9 +7277,6 @@ static bool HasSameValue(const SCEV *A, const SCEV *B) { return false; } -/// SimplifyICmpOperands - Simplify LHS and RHS in a comparison with -/// predicate Pred. Return true iff any changes were made. -/// bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth) { @@ -7134,7 +7608,7 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred, return true; // Otherwise see what can be done with known constant ranges. - return isKnownPredicateWithRanges(Pred, LHS, RHS); + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS); } bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS, @@ -7180,7 +7654,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, case ICmpInst::ICMP_UGE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: - if (!LHS->getNoWrapFlags(SCEV::FlagNUW)) + if (!LHS->hasNoUnsignedWrap()) return false; Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE; @@ -7190,7 +7664,7 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS, case ICmpInst::ICMP_SGE: case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: { - if (!LHS->getNoWrapFlags(SCEV::FlagNSW)) + if (!LHS->hasNoSignedWrap()) return false; const SCEV *Step = LHS->getStepRecurrence(*this); @@ -7264,78 +7738,34 @@ bool ScalarEvolution::isLoopInvariantPredicate( return true; } -bool -ScalarEvolution::isKnownPredicateWithRanges(ICmpInst::Predicate Pred, - const SCEV *LHS, const SCEV *RHS) { +bool ScalarEvolution::isKnownPredicateViaConstantRanges( + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { if (HasSameValue(LHS, RHS)) return ICmpInst::isTrueWhenEqual(Pred); // This code is split out from isKnownPredicate because it is called from // within isLoopEntryGuardedByCond. 
-  switch (Pred) {
-  default:
-    llvm_unreachable("Unexpected ICmpInst::Predicate value!");
-  case ICmpInst::ICMP_SGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLT: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().slt(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sge(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_SGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_SLE: {
-    ConstantRange LHSRange = getSignedRange(LHS);
-    ConstantRange RHSRange = getSignedRange(RHS);
-    if (LHSRange.getSignedMax().sle(RHSRange.getSignedMin()))
-      return true;
-    if (LHSRange.getSignedMin().sgt(RHSRange.getSignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGT:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULT: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ult(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().uge(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_UGE:
-    std::swap(LHS, RHS);
-  case ICmpInst::ICMP_ULE: {
-    ConstantRange LHSRange = getUnsignedRange(LHS);
-    ConstantRange RHSRange = getUnsignedRange(RHS);
-    if (LHSRange.getUnsignedMax().ule(RHSRange.getUnsignedMin()))
-      return true;
-    if (LHSRange.getUnsignedMin().ugt(RHSRange.getUnsignedMax()))
-      return false;
-    break;
-  }
-  case ICmpInst::ICMP_NE: {
-    if (getUnsignedRange(LHS).intersectWith(getUnsignedRange(RHS)).isEmptySet())
-      return true;
-    if (getSignedRange(LHS).intersectWith(getSignedRange(RHS)).isEmptySet())
-      return true;
-    const SCEV *Diff = getMinusSCEV(LHS, RHS);
-    if (isKnownNonZero(Diff))
-      return true;
-    break;
-  }
-  case ICmpInst::ICMP_EQ:
-    // The check at the top of the function catches the case where
-    // the values are known to be equal.
-    break;
-  }
-  return false;
+  auto CheckRanges =
+      [&](const ConstantRange &RangeLHS, const ConstantRange &RangeRHS) {
+    return ConstantRange::makeSatisfyingICmpRegion(Pred, RangeRHS)
+        .contains(RangeLHS);
+  };
+
+  // The check at the top of the function catches the case where the values are
+  // known to be equal.
+  if (Pred == CmpInst::ICMP_EQ)
+    return false;
+
+  if (Pred == CmpInst::ICMP_NE)
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS)) ||
+           CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS)) ||
+           isKnownNonZero(getMinusSCEV(LHS, RHS));
+
+  if (CmpInst::isSigned(Pred))
+    return CheckRanges(getSignedRange(LHS), getSignedRange(RHS));
+
+  return CheckRanges(getUnsignedRange(LHS), getUnsignedRange(RHS));
}

bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
@@ -7416,6 +7846,23 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
         isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}

+bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+                                        ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS) {
+  // No need to even try if we know the module has no guards.
+  if (!HasGuards)
+    return false;
+
+  return any_of(*BB, [&](Instruction &I) {
+    using namespace llvm::PatternMatch;
+
+    Value *Condition;
+    return match(&I, m_Intrinsic<Intrinsic::experimental_guard>(
+                         m_Value(Condition))) &&
+           isImpliedCond(Pred, LHS, RHS, Condition, false);
+  });
+}
+
/// isLoopBackedgeGuardedByCond - Test whether the backedge of the loop is
/// protected by a conditional between LHS and RHS. This is used to
/// eliminate casts.
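The CheckRanges lambda introduced above collapses the old per-predicate switch into a single ConstantRange query. A self-contained illustration of that query, using assumed example ranges that are not taken from the patch:

    #include "llvm/ADT/APInt.h"
    #include "llvm/IR/ConstantRange.h"
    #include "llvm/IR/InstrTypes.h"
    using namespace llvm;

    static bool provablyULT() {
      ConstantRange Lhs(APInt(8, 0), APInt(8, 5));    // LHS is some value in [0, 5)
      ConstantRange Rhs(APInt(8, 10), APInt(8, 20));  // RHS is some value in [10, 20)
      // makeSatisfyingICmpRegion(ULT, [10, 20)) is [0, 10), the set of values
      // that compare u< every element of [10, 20). [0, 5) lies inside it, so
      // "LHS u< RHS" is provable for every combination of values.
      return ConstantRange::makeSatisfyingICmpRegion(CmpInst::ICMP_ULT, Rhs)
          .contains(Lhs);
    }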
@@ -7427,7 +7874,8 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; - if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS)) + return true; BasicBlock *Latch = L->getLoopLatch(); if (!Latch) @@ -7482,12 +7930,18 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, if (!DT.isReachableFromEntry(L->getHeader())) return false; + if (isImpliedViaGuard(Latch, Pred, LHS, RHS)) + return true; + for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()]; DTN != HeaderDTN; DTN = DTN->getIDom()) { assert(DTN && "should reach the loop header before reaching the root!"); BasicBlock *BB = DTN->getBlock(); + if (isImpliedViaGuard(BB, Pred, LHS, RHS)) + return true; + BasicBlock *PBB = BB->getSinglePredecessor(); if (!PBB) continue; @@ -7518,9 +7972,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, return false; } -/// isLoopEntryGuardedByCond - Test whether entry to the loop is protected -/// by a conditional between LHS and RHS. This is used to help avoid max -/// expressions in loop trip counts, and to eliminate casts. bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, ICmpInst::Predicate Pred, @@ -7529,7 +7980,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; - if (isKnownPredicateWithRanges(Pred, LHS, RHS)) return true; + if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS)) + return true; // Starting at the loop predecessor, climb up the predecessor chain, as long // as there are predecessors that can be found that have unique successors @@ -7539,6 +7991,9 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) { + if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS)) + return true; + BranchInst *LoopEntryPredicate = dyn_cast<BranchInst>(Pair.first->getTerminator()); if (!LoopEntryPredicate || @@ -7586,8 +8041,6 @@ struct MarkPendingLoopPredicate { }; } // end anonymous namespace -/// isImpliedCond - Test whether the condition described by Pred, LHS, -/// and RHS is true whenever the given Cond value evaluates to true. bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Value *FoundCondValue, @@ -7910,9 +8363,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( getConstant(FoundRHSLimit)); } -/// isImpliedCondOperands - Test whether the condition described by Pred, -/// LHS, and RHS is true whenever the condition described by Pred, FoundLHS, -/// and FoundRHS is true. bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const SCEV *FoundLHS, @@ -8037,9 +8487,6 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE, llvm_unreachable("covered switch fell through?!"); } -/// isImpliedCondOperandsHelper - Test whether the condition described by -/// Pred, LHS, and RHS is true whenever the condition described by Pred, -/// FoundLHS, and FoundRHS is true. 
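The isImpliedViaGuard hook added above now also feeds the loop-guard queries changed in the hunks above. A minimal caller sketch, assuming SE and L come from a surrounding pass (illustrative, not part of this patch):

    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Analysis/ScalarEvolution.h"
    using namespace llvm;

    // True when "A u< B" is known to hold whenever L is entered; after this
    // patch the proof may come from a dominating branch or from a call to
    // @llvm.experimental.guard on the path into the loop.
    static bool knownULTOnEntry(ScalarEvolution &SE, const Loop *L,
                                Value *A, Value *B) {
      return SE.isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT,
                                         SE.getSCEV(A), SE.getSCEV(B));
    }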
bool ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, @@ -8047,7 +8494,7 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, const SCEV *FoundRHS) { auto IsKnownPredicateFull = [this](ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - return isKnownPredicateWithRanges(Pred, LHS, RHS) || + return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) || IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) || IsKnownPredicateViaAddRecStart(*this, Pred, LHS, RHS) || isKnownPredicateViaNoOverflow(Pred, LHS, RHS); @@ -8089,8 +8536,6 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred, return false; } -/// isImpliedCondOperandsViaRanges - helper function for isImpliedCondOperands. -/// Tries to get cases like "X `sgt` 0 => X - 1 `sgt` -1". bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, @@ -8129,9 +8574,6 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred, return SatisfyingLHSRange.contains(LHSRange); } -// Verify if an linear IV with positive stride can overflow when in a -// less-than comparison, knowing the invariant term of the comparison, the -// stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { if (NoWrap) return false; @@ -8158,9 +8600,6 @@ bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride, return (MaxValue - MaxStrideMinusOne).ult(MaxRHS); } -// Verify if an linear IV with negative stride can overflow when in a -// greater-than comparison, knowing the invariant term of the comparison, -// the stride and the knowledge of NSW/NUW flags on the recurrence. bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, bool IsSigned, bool NoWrap) { if (NoWrap) return false; @@ -8187,8 +8626,6 @@ bool ScalarEvolution::doesIVOverflowOnGT(const SCEV *RHS, const SCEV *Stride, return (MinValue + MaxStrideMinusOne).ugt(MinRHS); } -// Compute the backedge taken count knowing the interval difference, the -// stride and presence of the equality in the comparison. const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, bool Equality) { const SCEV *One = getOne(Step->getType()); @@ -8197,22 +8634,21 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step, return getUDivExpr(Delta, Step); } -/// HowManyLessThans - Return the number of times a backedge containing the -/// specified less-than comparison will execute. If not computable, return -/// CouldNotCompute. -/// -/// @param ControlsExit is true when the LHS < RHS condition directly controls -/// the branch (loops exits only if condition is true). In this case, we can use -/// NoWrapFlags to skip overflow checks. ScalarEvolution::ExitLimit -ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, +ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsExit) { + bool ControlsExit, bool AllowPredicates) { + SCEVUnionPredicate P; // We handle only IV < Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); + if (!IV && AllowPredicates) + // Try to make this an AddRec using runtime tests, in the first X + // iterations of this loop, where X is the SCEV expression found by the + // algorithm below. 
+ IV = convertSCEVToAddRecWithPredicates(LHS, L, P); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8238,19 +8674,8 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, : ICmpInst::ICMP_ULT; const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) { - const SCEV *Diff = getMinusSCEV(RHS, Start); - // If we have NoWrap set, then we can assume that the increment won't - // overflow, in which case if RHS - Start is a constant, we don't need to - // do a max operation since we can just figure it out statically - if (NoWrap && isa<SCEVConstant>(Diff)) { - APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt(); - if (D.isNegative()) - End = Start; - } else - End = IsSigned ? getSMaxExpr(RHS, Start) - : getUMaxExpr(RHS, Start); - } + if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) + End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); @@ -8281,18 +8706,24 @@ ScalarEvolution::HowManyLessThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount, P); } ScalarEvolution::ExitLimit -ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, +ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const Loop *L, bool IsSigned, - bool ControlsExit) { + bool ControlsExit, bool AllowPredicates) { + SCEVUnionPredicate P; // We handle only IV > Invariant if (!isLoopInvariant(RHS, L)) return getCouldNotCompute(); const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); + if (!IV && AllowPredicates) + // Try to make this an AddRec using runtime tests, in the first X + // iterations of this loop, where X is the SCEV expression found by the + // algorithm below. + IV = convertSCEVToAddRecWithPredicates(LHS, L, P); // Avoid weird loops if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8319,19 +8750,8 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, const SCEV *Start = IV->getStart(); const SCEV *End = RHS; - if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) { - const SCEV *Diff = getMinusSCEV(RHS, Start); - // If we have NoWrap set, then we can assume that the increment won't - // overflow, in which case if RHS - Start is a constant, we don't need to - // do a max operation since we can just figure it out statically - if (NoWrap && isa<SCEVConstant>(Diff)) { - APInt D = dyn_cast<const SCEVConstant>(Diff)->getAPInt(); - if (!D.isNegative()) - End = Start; - } else - End = IsSigned ? getSMinExpr(RHS, Start) - : getUMinExpr(RHS, Start); - } + if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) + End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start); const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false); @@ -8363,15 +8783,10 @@ ScalarEvolution::HowManyGreaterThans(const SCEV *LHS, const SCEV *RHS, if (isa<SCEVCouldNotCompute>(MaxBECount)) MaxBECount = BECount; - return ExitLimit(BECount, MaxBECount); + return ExitLimit(BECount, MaxBECount, P); } -/// getNumIterationsInRange - Return the number of iterations of this loop that -/// produce values in the specified constant range. 
Another way of looking at -/// this is that it returns the first iteration number where the value is not in -/// the condition, thus computing the exit count. If the iteration count can't -/// be computed, an instance of SCEVCouldNotCompute is returned. -const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, +const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ScalarEvolution &SE) const { if (Range.isFullSet()) // Infinite loop. return SE.getCouldNotCompute(); @@ -8445,22 +8860,21 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, FlagAnyWrap); // Next, solve the constructed addrec - auto Roots = SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE); - const SCEVConstant *R1 = dyn_cast<SCEVConstant>(Roots.first); - const SCEVConstant *R2 = dyn_cast<SCEVConstant>(Roots.second); - if (R1) { + if (auto Roots = + SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) { + const SCEVConstant *R1 = Roots->first; + const SCEVConstant *R2 = Roots->second; // Pick the smallest positive root value. if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. + std::swap(R1, R2); // R1 is the minimum root now. // Make sure the root is not off by one. The returned iteration should // not be in the range, but the previous one should be. When solving // for "X*X < 5", for example, we should not return a root of 2. - ConstantInt *R1Val = EvaluateConstantChrecAtConstant(this, - R1->getValue(), - SE); + ConstantInt *R1Val = + EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); if (Range.contains(R1Val->getValue())) { // The next iteration must be out of the range... ConstantInt *NextVal = @@ -8469,7 +8883,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (!Range.contains(R1Val->getValue())) return SE.getConstant(NextVal); - return SE.getCouldNotCompute(); // Something strange happened + return SE.getCouldNotCompute(); // Something strange happened } // If R1 was not in the range, then it is a good return value. Make @@ -8479,7 +8893,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(ConstantRange Range, R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); if (Range.contains(R1Val->getValue())) return R1; - return SE.getCouldNotCompute(); // Something strange happened + return SE.getCouldNotCompute(); // Something strange happened } } } @@ -8789,12 +9203,9 @@ const SCEV *ScalarEvolution::getElementSize(Instruction *Inst) { return getSizeOfExpr(ETy, Ty); } -/// Second step of delinearization: compute the array dimensions Sizes from the -/// set of Terms extracted from the memory access function of this SCEVAddRec. void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, SmallVectorImpl<const SCEV *> &Sizes, const SCEV *ElementSize) const { - if (Terms.size() < 1 || !ElementSize) return; @@ -8858,8 +9269,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, }); } -/// Third step of delinearization: compute the access functions for the -/// Subscripts based on the dimensions in Sizes. 
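The deleted comments above describe steps two and three of SCEV delinearization (findArrayDimensions and computeAccessFunctions). As a hedged sketch of the client-side driver, with assumed shapes not taken from the patch: for an access A[i][j] into a parametric double A[n][m], whose linearized offset is {{A,+,(8 * %m)}<outer>,+,8}<inner>, delinearize is expected to recover Sizes roughly as {%m, 8} and Subscripts as {i, j}.

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Analysis/ScalarEvolution.h"
    using namespace llvm;

    static void delinearizeAccess(ScalarEvolution &SE, const SCEV *AccessFn,
                                  const SCEV *ElementSize) {
      SmallVector<const SCEV *, 4> Subscripts, Sizes;
      // Collects parametric terms, then runs findArrayDimensions and
      // computeAccessFunctions; both vectors are left empty on failure.
      SE.delinearize(AccessFn, Subscripts, Sizes, ElementSize);
    }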
void ScalarEvolution::computeAccessFunctions( const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts, SmallVectorImpl<const SCEV *> &Sizes) { @@ -9012,7 +9421,7 @@ void ScalarEvolution::SCEVCallbackVH::deleted() { assert(SE && "SCEVCallbackVH called with a null ScalarEvolution!"); if (PHINode *PN = dyn_cast<PHINode>(getValPtr())) SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->ValueExprMap.erase(getValPtr()); + SE->eraseValueFromMap(getValPtr()); // this now dangles! } @@ -9035,13 +9444,13 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) { continue; if (PHINode *PN = dyn_cast<PHINode>(U)) SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->ValueExprMap.erase(U); + SE->eraseValueFromMap(U); Worklist.insert(Worklist.end(), U->user_begin(), U->user_end()); } // Delete the Old value. if (PHINode *PN = dyn_cast<PHINode>(Old)) SE->ConstantEvolutionLoopExitValue.erase(PN); - SE->ValueExprMap.erase(Old); + SE->eraseValueFromMap(Old); // this now dangles! } @@ -9059,14 +9468,31 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI, CouldNotCompute(new SCEVCouldNotCompute()), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64), - FirstUnknown(nullptr) {} + FirstUnknown(nullptr) { + + // To use guards for proving predicates, we need to scan every instruction in + // relevant basic blocks, and not just terminators. Doing this is a waste of + // time if the IR does not actually contain any calls to + // @llvm.experimental.guard, so do a quick check and remember this beforehand. + // + // This pessimizes the case where a pass that preserves ScalarEvolution wants + // to _add_ guards to the module when there weren't any before, and wants + // ScalarEvolution to optimize based on those guards. For now we prefer to be + // efficient in lieu of being smart in that rather obscure case. + + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + HasGuards = GuardDecl && !GuardDecl->use_empty(); +} ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg) - : F(Arg.F), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), LI(Arg.LI), - CouldNotCompute(std::move(Arg.CouldNotCompute)), + : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT), + LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)), ValueExprMap(std::move(Arg.ValueExprMap)), WalkingBEDominatingConds(false), ProvingSplitPredicate(false), BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)), + PredicatedBackedgeTakenCounts( + std::move(Arg.PredicatedBackedgeTakenCounts)), ConstantEvolutionLoopExitValue( std::move(Arg.ConstantEvolutionLoopExitValue)), ValuesAtScopes(std::move(Arg.ValuesAtScopes)), @@ -9091,12 +9517,16 @@ ScalarEvolution::~ScalarEvolution() { } FirstUnknown = nullptr; + ExprValueMap.clear(); ValueExprMap.clear(); + HasRecMap.clear(); // Free any extra memory created for ExitNotTakenInfo in the unlikely event // that a loop had multiple computable exits. 
for (auto &BTCI : BackedgeTakenCounts) BTCI.second.clear(); + for (auto &BTCI : PredicatedBackedgeTakenCounts) + BTCI.second.clear(); assert(PendingLoopPredicates.empty() && "isImpliedCond garbage"); assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!"); @@ -9110,8 +9540,8 @@ bool ScalarEvolution::hasLoopInvariantBackedgeTakenCount(const Loop *L) { static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, const Loop *L) { // Print all inner loops first - for (Loop::iterator I = L->begin(), E = L->end(); I != E; ++I) - PrintLoopInfo(OS, SE, *I); + for (Loop *I : *L) + PrintLoopInfo(OS, SE, I); OS << "Loop "; L->getHeader()->printAsOperand(OS, /*PrintType=*/false); @@ -9139,9 +9569,35 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE, OS << "Unpredictable max backedge-taken count. "; } + OS << "\n" + "Loop "; + L->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": "; + + SCEVUnionPredicate Pred; + auto PBT = SE->getPredicatedBackedgeTakenCount(L, Pred); + if (!isa<SCEVCouldNotCompute>(PBT)) { + OS << "Predicated backedge-taken count is " << *PBT << "\n"; + OS << " Predicates:\n"; + Pred.print(OS, 4); + } else { + OS << "Unpredictable predicated backedge-taken count. "; + } OS << "\n"; } +static StringRef loopDispositionToStr(ScalarEvolution::LoopDisposition LD) { + switch (LD) { + case ScalarEvolution::LoopVariant: + return "Variant"; + case ScalarEvolution::LoopInvariant: + return "Invariant"; + case ScalarEvolution::LoopComputable: + return "Computable"; + } + llvm_unreachable("Unknown ScalarEvolution::LoopDisposition kind!"); +} + void ScalarEvolution::print(raw_ostream &OS) const { // ScalarEvolution's implementation of the print method is to print // out SCEV values of all instructions that are interesting. 
Doing @@ -9189,6 +9645,35 @@ void ScalarEvolution::print(raw_ostream &OS) const { } else { OS << *ExitValue; } + + bool First = true; + for (auto *Iter = L; Iter; Iter = Iter->getParentLoop()) { + if (First) { + OS << "\t\t" "LoopDispositions: { "; + First = false; + } else { + OS << ", "; + } + + Iter->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, Iter)); + } + + for (auto *InnerL : depth_first(L)) { + if (InnerL == L) + continue; + if (First) { + OS << "\t\t" "LoopDispositions: { "; + First = false; + } else { + OS << ", "; + } + + InnerL->getHeader()->printAsOperand(OS, /*PrintType=*/false); + OS << ": " << loopDispositionToStr(SE.getLoopDisposition(SV, InnerL)); + } + + OS << " }"; } OS << "\n"; @@ -9197,8 +9682,8 @@ void ScalarEvolution::print(raw_ostream &OS) const { OS << "Determining loop execution counts for: "; F.printAsOperand(OS, /*PrintType=*/false); OS << "\n"; - for (LoopInfo::iterator I = LI.begin(), E = LI.end(); I != E; ++I) - PrintLoopInfo(OS, &SE, *I); + for (Loop *I : LI) + PrintLoopInfo(OS, &SE, I); } ScalarEvolution::LoopDisposition @@ -9420,17 +9905,23 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { BlockDispositions.erase(S); UnsignedRanges.erase(S); SignedRanges.erase(S); + ExprValueMap.erase(S); + HasRecMap.erase(S); + + auto RemoveSCEVFromBackedgeMap = + [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { + for (auto I = Map.begin(), E = Map.end(); I != E;) { + BackedgeTakenInfo &BEInfo = I->second; + if (BEInfo.hasOperand(S, this)) { + BEInfo.clear(); + Map.erase(I++); + } else + ++I; + } + }; - for (DenseMap<const Loop*, BackedgeTakenInfo>::iterator I = - BackedgeTakenCounts.begin(), E = BackedgeTakenCounts.end(); I != E; ) { - BackedgeTakenInfo &BEInfo = I->second; - if (BEInfo.hasOperand(S, this)) { - BEInfo.clear(); - BackedgeTakenCounts.erase(I++); - } - else - ++I; - } + RemoveSCEVFromBackedgeMap(BackedgeTakenCounts); + RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts); } typedef DenseMap<const Loop *, std::string> VerifyMap; @@ -9516,16 +10007,16 @@ void ScalarEvolution::verify() const { char ScalarEvolutionAnalysis::PassID; ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, - AnalysisManager<Function> *AM) { - return ScalarEvolution(F, AM->getResult<TargetLibraryAnalysis>(F), - AM->getResult<AssumptionAnalysis>(F), - AM->getResult<DominatorTreeAnalysis>(F), - AM->getResult<LoopAnalysis>(F)); + AnalysisManager<Function> &AM) { + return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F), + AM.getResult<DominatorTreeAnalysis>(F), + AM.getResult<LoopAnalysis>(F)); } PreservedAnalyses -ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> *AM) { - AM->getResult<ScalarEvolutionAnalysis>(F).print(OS); +ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) { + AM.getResult<ScalarEvolutionAnalysis>(F).print(OS); return PreservedAnalyses::all(); } @@ -9590,36 +10081,121 @@ ScalarEvolution::getEqualPredicate(const SCEVUnknown *LHS, return Eq; } +const SCEVPredicate *ScalarEvolution::getWrapPredicate( + const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + FoldingSetNodeID ID; + // Unique this node based on the arguments + ID.AddInteger(SCEVPredicate::P_Wrap); + ID.AddPointer(AR); + ID.AddInteger(AddedFlags); + void *IP = nullptr; + if (const auto *S = UniquePreds.FindNodeOrInsertPos(ID, IP)) + return S; + auto *OF = new (SCEVAllocator) + 
SCEVWrapPredicate(ID.Intern(SCEVAllocator), AR, AddedFlags); + UniquePreds.InsertNode(OF, IP); + return OF; +} + namespace { + class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> { public: - static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE, - SCEVUnionPredicate &A) { - SCEVPredicateRewriter Rewriter(SE, A); - return Rewriter.visit(Scev); + // Rewrites \p S in the context of a loop L and the predicate A. + // If Assume is true, rewrite is free to add further predicates to A + // such that the result will be an AddRecExpr. + static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, + SCEVUnionPredicate &A, bool Assume) { + SCEVPredicateRewriter Rewriter(L, SE, A, Assume); + return Rewriter.visit(S); } - SCEVPredicateRewriter(ScalarEvolution &SE, SCEVUnionPredicate &P) - : SCEVRewriteVisitor(SE), P(P) {} + SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, + SCEVUnionPredicate &P, bool Assume) + : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} const SCEV *visitUnknown(const SCEVUnknown *Expr) { auto ExprPreds = P.getPredicatesForExpr(Expr); for (auto *Pred : ExprPreds) - if (const auto *IPred = dyn_cast<const SCEVEqualPredicate>(Pred)) + if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) if (IPred->getLHS() == Expr) return IPred->getRHS(); return Expr; } + const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nuw + // flag. Add the nusw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNUSW)) + return SE.getAddRecExpr(SE.getZeroExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getZeroExtendExpr(Operand, Expr->getType()); + } + + const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) { + const SCEV *Operand = visit(Expr->getOperand()); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Operand); + if (AR && AR->getLoop() == L && AR->isAffine()) { + // This couldn't be folded because the operand didn't have the nsw + // flag. Add the nssw flag as an assumption that we could make. + const SCEV *Step = AR->getStepRecurrence(SE); + Type *Ty = Expr->getType(); + if (addOverflowAssumption(AR, SCEVWrapPredicate::IncrementNSSW)) + return SE.getAddRecExpr(SE.getSignExtendExpr(AR->getStart(), Ty), + SE.getSignExtendExpr(Step, Ty), L, + AR->getNoWrapFlags()); + } + return SE.getSignExtendExpr(Operand, Expr->getType()); + } + private: + bool addOverflowAssumption(const SCEVAddRecExpr *AR, + SCEVWrapPredicate::IncrementWrapFlags AddedFlags) { + auto *A = SE.getWrapPredicate(AR, AddedFlags); + if (!Assume) { + // Check if we've already made this assumption. 
+    if (P.implies(A))
+      return true;
+    return false;
+  }
+  P.add(A);
+  return true;
+  }
+
  SCEVUnionPredicate &P;
+  const Loop *L;
+  bool Assume;
};
} // end anonymous namespace

-const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *Scev,
+const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
+                                                   SCEVUnionPredicate &Preds) {
+  return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false);
+}
+
+const SCEVAddRecExpr *
+ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,
                                                    SCEVUnionPredicate &Preds) {
-  return SCEVPredicateRewriter::rewrite(Scev, *this, Preds);
+  SCEVUnionPredicate TransformPreds;
+  S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true);
+  auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+
+  if (!AddRec)
+    return nullptr;
+
+  // Since the transformation was successful, we can now transfer the SCEV
+  // predicates.
+  Preds.add(&TransformPreds);
+  return AddRec;
}

/// SCEV predicates
@@ -9633,7 +10209,7 @@ SCEVEqualPredicate::SCEVEqualPredicate(const FoldingSetNodeIDRef ID,
    : SCEVPredicate(ID, P_Equal), LHS(LHS), RHS(RHS) {}

bool SCEVEqualPredicate::implies(const SCEVPredicate *N) const {
-  const auto *Op = dyn_cast<const SCEVEqualPredicate>(N);
+  const auto *Op = dyn_cast<SCEVEqualPredicate>(N);

  if (!Op)
    return false;
@@ -9649,6 +10225,59 @@ void SCEVEqualPredicate::print(raw_ostream &OS, unsigned Depth) const {
  OS.indent(Depth) << "Equal predicate: " << *LHS << " == " << *RHS << "\n";
}

+SCEVWrapPredicate::SCEVWrapPredicate(const FoldingSetNodeIDRef ID,
+                                     const SCEVAddRecExpr *AR,
+                                     IncrementWrapFlags Flags)
+    : SCEVPredicate(ID, P_Wrap), AR(AR), Flags(Flags) {}
+
+const SCEV *SCEVWrapPredicate::getExpr() const { return AR; }
+
+bool SCEVWrapPredicate::implies(const SCEVPredicate *N) const {
+  const auto *Op = dyn_cast<SCEVWrapPredicate>(N);
+
+  return Op && Op->AR == AR && setFlags(Flags, Op->Flags) == Flags;
+}
+
+bool SCEVWrapPredicate::isAlwaysTrue() const {
+  SCEV::NoWrapFlags ScevFlags = AR->getNoWrapFlags();
+  IncrementWrapFlags IFlags = Flags;
+
+  if (ScalarEvolution::setFlags(ScevFlags, SCEV::FlagNSW) == ScevFlags)
+    IFlags = clearFlags(IFlags, IncrementNSSW);
+
+  return IFlags == IncrementAnyWrap;
+}
+
+void SCEVWrapPredicate::print(raw_ostream &OS, unsigned Depth) const {
+  OS.indent(Depth) << *getExpr() << " Added Flags: ";
+  if (SCEVWrapPredicate::IncrementNUSW & getFlags())
+    OS << "<nusw>";
+  if (SCEVWrapPredicate::IncrementNSSW & getFlags())
+    OS << "<nssw>";
+  OS << "\n";
+}
+
+SCEVWrapPredicate::IncrementWrapFlags
+SCEVWrapPredicate::getImpliedFlags(const SCEVAddRecExpr *AR,
+                                   ScalarEvolution &SE) {
+  IncrementWrapFlags ImpliedFlags = IncrementAnyWrap;
+  SCEV::NoWrapFlags StaticFlags = AR->getNoWrapFlags();
+
+  // We can safely transfer the NSW flag as NSSW.
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNSW) == StaticFlags)
+    ImpliedFlags = IncrementNSSW;
+
+  if (ScalarEvolution::setFlags(StaticFlags, SCEV::FlagNUW) == StaticFlags) {
+    // If the increment is positive, the SCEV NUW flag will also imply the
+    // WrapPredicate NUSW flag.
+    if (const auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE)))
+      if (Step->getValue()->getValue().isNonNegative())
+        ImpliedFlags = setFlags(ImpliedFlags, IncrementNUSW);
+  }
+
+  return ImpliedFlags;
+}
+
/// Union predicates don't get cached so create a dummy set ID for it.
SCEVUnionPredicate::SCEVUnionPredicate() : SCEVPredicate(FoldingSetNodeIDRef(nullptr, 0), P_Union) {} @@ -9667,7 +10296,7 @@ SCEVUnionPredicate::getPredicatesForExpr(const SCEV *Expr) { } bool SCEVUnionPredicate::implies(const SCEVPredicate *N) const { - if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) + if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) return all_of(Set->Preds, [this](const SCEVPredicate *I) { return this->implies(I); }); @@ -9688,7 +10317,7 @@ void SCEVUnionPredicate::print(raw_ostream &OS, unsigned Depth) const { } void SCEVUnionPredicate::add(const SCEVPredicate *N) { - if (const auto *Set = dyn_cast<const SCEVUnionPredicate>(N)) { + if (const auto *Set = dyn_cast<SCEVUnionPredicate>(N)) { for (auto Pred : Set->Preds) add(Pred); return; @@ -9705,8 +10334,9 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) { Preds.push_back(N); } -PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE) - : SE(SE), Generation(0) {} +PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE, + Loop &L) + : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {} const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { const SCEV *Expr = SE.getSCEV(V); @@ -9721,12 +10351,21 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) { if (Entry.second) Expr = Entry.second; - const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, Preds); + const SCEV *NewSCEV = SE.rewriteUsingPredicate(Expr, &L, Preds); Entry = {Generation, NewSCEV}; return NewSCEV; } +const SCEV *PredicatedScalarEvolution::getBackedgeTakenCount() { + if (!BackedgeCount) { + SCEVUnionPredicate BackedgePred; + BackedgeCount = SE.getPredicatedBackedgeTakenCount(&L, BackedgePred); + addPredicate(BackedgePred); + } + return BackedgeCount; +} + void PredicatedScalarEvolution::addPredicate(const SCEVPredicate &Pred) { if (Preds.implies(&Pred)) return; @@ -9743,7 +10382,82 @@ void PredicatedScalarEvolution::updateGeneration() { if (++Generation == 0) { for (auto &II : RewriteMap) { const SCEV *Rewritten = II.second.second; - II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, Preds)}; + II.second = {Generation, SE.rewriteUsingPredicate(Rewritten, &L, Preds)}; } } } + +void PredicatedScalarEvolution::setNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + auto ImpliedFlags = SCEVWrapPredicate::getImpliedFlags(AR, SE); + + // Clear the statically implied flags. 
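// Clarifying note (not part of the patch): getImpliedFlags() above reports
// the wrap facts already guaranteed by the AddRec's static NUW/NSW flags.
// Clearing them from the requested Flags below keeps the recorded
// SCEVWrapPredicate down to what actually needs a runtime check.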
+ Flags = SCEVWrapPredicate::clearFlags(Flags, ImpliedFlags); + addPredicate(*SE.getWrapPredicate(AR, Flags)); + + auto II = FlagsMap.insert({V, Flags}); + if (!II.second) + II.first->second = SCEVWrapPredicate::setFlags(Flags, II.first->second); +} + +bool PredicatedScalarEvolution::hasNoOverflow( + Value *V, SCEVWrapPredicate::IncrementWrapFlags Flags) { + const SCEV *Expr = getSCEV(V); + const auto *AR = cast<SCEVAddRecExpr>(Expr); + + Flags = SCEVWrapPredicate::clearFlags( + Flags, SCEVWrapPredicate::getImpliedFlags(AR, SE)); + + auto II = FlagsMap.find(V); + + if (II != FlagsMap.end()) + Flags = SCEVWrapPredicate::clearFlags(Flags, II->second); + + return Flags == SCEVWrapPredicate::IncrementAnyWrap; +} + +const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) { + const SCEV *Expr = this->getSCEV(V); + auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); + + if (!New) + return nullptr; + + updateGeneration(); + RewriteMap[SE.getSCEV(V)] = {Generation, New}; + return New; +} + +PredicatedScalarEvolution::PredicatedScalarEvolution( + const PredicatedScalarEvolution &Init) + : RewriteMap(Init.RewriteMap), SE(Init.SE), L(Init.L), Preds(Init.Preds), + Generation(Init.Generation), BackedgeCount(Init.BackedgeCount) { + for (const auto &I : Init.FlagsMap) + FlagsMap.insert(I); +} + +void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const { + // For each block. + for (auto *BB : L.getBlocks()) + for (auto &I : *BB) { + if (!SE.isSCEVable(I.getType())) + continue; + + auto *Expr = SE.getSCEV(&I); + auto II = RewriteMap.find(Expr); + + if (II == RewriteMap.end()) + continue; + + // Don't print things that are not interesting. + if (II->second.second == Expr) + continue; + + OS.indent(Depth) << "[PSE]" << I << ":\n"; + OS.indent(Depth + 2) << *Expr << "\n"; + OS.indent(Depth + 2) << "--> " << *II->second.second << "\n"; + } +} diff --git a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp index 2e50c80c4e73..61fb411d3150 100644 --- a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp +++ b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -20,7 +20,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" -#include "llvm/Analysis/TargetLibraryInfo.h" using namespace llvm; AliasResult SCEVAAResult::alias(const MemoryLocation &LocA, @@ -111,18 +110,16 @@ Value *SCEVAAResult::GetBaseValue(const SCEV *S) { return nullptr; } -SCEVAAResult SCEVAA::run(Function &F, AnalysisManager<Function> *AM) { - return SCEVAAResult(AM->getResult<TargetLibraryAnalysis>(F), - AM->getResult<ScalarEvolutionAnalysis>(F)); -} - char SCEVAA::PassID; +SCEVAAResult SCEVAA::run(Function &F, AnalysisManager<Function> &AM) { + return SCEVAAResult(AM.getResult<ScalarEvolutionAnalysis>(F)); +} + char SCEVAAWrapperPass::ID = 0; INITIALIZE_PASS_BEGIN(SCEVAAWrapperPass, "scev-aa", "ScalarEvolution-based Alias Analysis", false, true) INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(SCEVAAWrapperPass, "scev-aa", "ScalarEvolution-based Alias Analysis", false, true) @@ -136,13 +133,11 @@ SCEVAAWrapperPass::SCEVAAWrapperPass() : FunctionPass(ID) { bool SCEVAAWrapperPass::runOnFunction(Function &F) { Result.reset( - new SCEVAAResult(getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), - getAnalysis<ScalarEvolutionWrapperPass>().getSE())); + new 
SCEVAAResult(getAnalysis<ScalarEvolutionWrapperPass>().getSE())); return false; } void SCEVAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); AU.addRequired<ScalarEvolutionWrapperPass>(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); } diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 921403ddc0fd..77e4ec7ab40c 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis --*- C++ -*-===// +//===- ScalarEvolutionExpander.cpp - Scalar Evolution Analysis ------------===// // // The LLVM Compiler Infrastructure // @@ -95,14 +95,12 @@ static BasicBlock::iterator findInsertPointAfter(Instruction *I, while (isa<PHINode>(IP)) ++IP; - while (IP->isEHPad()) { - if (isa<FuncletPadInst>(IP) || isa<LandingPadInst>(IP)) { - ++IP; - } else if (isa<CatchSwitchInst>(IP)) { - IP = MustDominate->getFirstInsertionPt(); - } else { - llvm_unreachable("unexpected eh pad!"); - } + if (isa<FuncletPadInst>(IP) || isa<LandingPadInst>(IP)) { + ++IP; + } else if (isa<CatchSwitchInst>(IP)) { + IP = MustDominate->getFirstInsertionPt(); + } else { + assert(!IP->isEHPad() && "unexpected eh pad!"); } return IP; @@ -198,7 +196,7 @@ Value *SCEVExpander::InsertBinop(Instruction::BinaryOps Opcode, // Save the original insertion point so we can restore it when we're done. DebugLoc Loc = Builder.GetInsertPoint()->getDebugLoc(); - BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); // Move the insertion point out of as many loops as we can. while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { @@ -525,7 +523,7 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, } // Save the original insertion point so we can restore it when we're done. - BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); // Move the insertion point out of as many loops as we can. while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { @@ -544,39 +542,37 @@ Value *SCEVExpander::expandAddToGEP(const SCEV *const *op_begin, return GEP; } - // Save the original insertion point so we can restore it when we're done. - BuilderType::InsertPoint SaveInsertPt = Builder.saveIP(); + { + SCEVInsertPointGuard Guard(Builder, this); - // Move the insertion point out of as many loops as we can. - while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { - if (!L->isLoopInvariant(V)) break; - - bool AnyIndexNotLoopInvariant = - std::any_of(GepIndices.begin(), GepIndices.end(), - [L](Value *Op) { return !L->isLoopInvariant(Op); }); + // Move the insertion point out of as many loops as we can. + while (const Loop *L = SE.LI.getLoopFor(Builder.GetInsertBlock())) { + if (!L->isLoopInvariant(V)) break; - if (AnyIndexNotLoopInvariant) - break; + bool AnyIndexNotLoopInvariant = + std::any_of(GepIndices.begin(), GepIndices.end(), + [L](Value *Op) { return !L->isLoopInvariant(Op); }); - BasicBlock *Preheader = L->getLoopPreheader(); - if (!Preheader) break; + if (AnyIndexNotLoopInvariant) + break; - // Ok, move up a level. - Builder.SetInsertPoint(Preheader->getTerminator()); - } + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) break; - // Insert a pretty getelementptr. 
Note that this GEP is not marked inbounds, - // because ScalarEvolution may have changed the address arithmetic to - // compute a value which is beyond the end of the allocated object. - Value *Casted = V; - if (V->getType() != PTy) - Casted = InsertNoopCastOfTo(Casted, PTy); - Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep"); - Ops.push_back(SE.getUnknown(GEP)); - rememberInstruction(GEP); + // Ok, move up a level. + Builder.SetInsertPoint(Preheader->getTerminator()); + } - // Restore the original insert point. - Builder.restoreIP(SaveInsertPt); + // Insert a pretty getelementptr. Note that this GEP is not marked inbounds, + // because ScalarEvolution may have changed the address arithmetic to + // compute a value which is beyond the end of the allocated object. + Value *Casted = V; + if (V->getType() != PTy) + Casted = InsertNoopCastOfTo(Casted, PTy); + Value *GEP = Builder.CreateGEP(OriginalElTy, Casted, GepIndices, "scevgep"); + Ops.push_back(SE.getUnknown(GEP)); + rememberInstruction(GEP); + } return expand(SE.getAddExpr(Ops)); } @@ -907,6 +903,23 @@ Instruction *SCEVExpander::getIVIncOperand(Instruction *IncV, } } +/// If the insert point of the current builder or any of the builders on the +/// stack of saved builders has 'I' as its insert point, update it to point to +/// the instruction after 'I'. This is intended to be used when the instruction +/// 'I' is being moved. If this fixup is not done and 'I' is moved to a +/// different block, the inconsistent insert point (with a mismatched +/// Instruction and Block) can lead to an instruction being inserted in a block +/// other than its parent. +void SCEVExpander::fixupInsertPoints(Instruction *I) { + BasicBlock::iterator It(*I); + BasicBlock::iterator NewInsertPt = std::next(It); + if (Builder.GetInsertPoint() == It) + Builder.SetInsertPoint(&*NewInsertPt); + for (auto *InsertPtGuard : InsertPointGuards) + if (InsertPtGuard->GetInsertPoint() == It) + InsertPtGuard->SetInsertPoint(NewInsertPt); +} + /// hoistStep - Attempt to hoist a simple IV increment above InsertPos to make /// it available to other uses in this loop. Recursively hoist any operands, /// until we reach a value that dominates InsertPos. @@ -936,6 +949,7 @@ bool SCEVExpander::hoistIVInc(Instruction *IncV, Instruction *InsertPos) { break; } for (auto I = IVIncs.rbegin(), E = IVIncs.rend(); I != E; ++I) { + fixupInsertPoints(*I); (*I)->moveBefore(InsertPos); } return true; @@ -989,13 +1003,14 @@ Value *SCEVExpander::expandIVInc(PHINode *PN, Value *StepV, const Loop *L, /// \brief Hoist the addrec instruction chain rooted in the loop phi above the /// position. This routine assumes that this is possible (has been checked). -static void hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, - Instruction *Pos, PHINode *LoopPhi) { +void SCEVExpander::hoistBeforePos(DominatorTree *DT, Instruction *InstToHoist, + Instruction *Pos, PHINode *LoopPhi) { do { if (DT->dominates(InstToHoist, Pos)) break; // Make sure the increment is where we want it. But don't move it // down past a potential existing post-inc user. + fixupInsertPoints(InstToHoist); InstToHoist->moveBefore(Pos); Pos = InstToHoist; InstToHoist = cast<Instruction>(InstToHoist->getOperand(0)); @@ -1156,7 +1171,7 @@ SCEVExpander::getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, } // Save the original insertion point so we can restore it when we're done. 
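// Clarifying note (not part of the patch): the change below swaps the plain
// IRBuilder InsertPointGuard for SCEVInsertPointGuard, which registers the
// saved position with the expander so that fixupInsertPoints() (added above)
// can retarget it when the instruction it points at is moved.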
- BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); // Another AddRec may need to be recursively expanded below. For example, if // this AddRec is quadratic, the StepV may itself be an AddRec in this @@ -1273,6 +1288,13 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { if (!SE.dominates(Step, L->getHeader())) { PostLoopScale = Step; Step = SE.getConstant(Normalized->getType(), 1); + if (!Start->isZero()) { + // The normalization below assumes that Start is constant zero, so if + // it isn't re-associate Start to PostLoopOffset. + assert(!PostLoopOffset && "Start not-null but PostLoopOffset set?"); + PostLoopOffset = Start; + Start = SE.getConstant(Normalized->getType(), 0); + } Normalized = cast<SCEVAddRecExpr>(SE.getAddRecExpr( Start, Step, Normalized->getLoop(), @@ -1321,7 +1343,7 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { Value *StepV; { // Expand the step somewhere that dominates the loop header. - BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); StepV = expandCodeFor(Step, IntTy, &L->getHeader()->front()); } Result = expandIVInc(PN, StepV, L, ExpandTy, IntTy, useSubtract); @@ -1428,8 +1450,12 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { } // Just do a normal add. Pre-expand the operands to suppress folding. - return expand(SE.getAddExpr(SE.getUnknown(expand(S->getStart())), - SE.getUnknown(expand(Rest)))); + // + // The LHS and RHS values are factored out of the expand call to make the + // output independent of the argument evaluation order. + const SCEV *AddExprLHS = SE.getUnknown(expand(S->getStart())); + const SCEV *AddExprRHS = SE.getUnknown(expand(Rest)); + return expand(SE.getAddExpr(AddExprLHS, AddExprRHS)); } // If we don't yet have a canonical IV, create one. @@ -1600,6 +1626,40 @@ Value *SCEVExpander::expandCodeFor(const SCEV *SH, Type *Ty) { return V; } +Value *SCEVExpander::FindValueInExprValueMap(const SCEV *S, + const Instruction *InsertPt) { + SetVector<Value *> *Set = SE.getSCEVValues(S); + // If the expansion is not in CanonicalMode, and the SCEV contains any + // sub scAddRecExpr type SCEV, it is required to expand the SCEV literally. + if (CanonicalMode || !SE.containsAddRecurrence(S)) { + // If S is scConstant, it may be worse to reuse an existing Value. + if (S->getSCEVType() != scConstant && Set) { + // Choose a Value from the set which dominates the insertPt. + // insertPt should be inside the Value's parent loop so as not to break + // the LCSSA form. + for (auto const &Ent : *Set) { + Instruction *EntInst = nullptr; + if (Ent && isa<Instruction>(Ent) && + (EntInst = cast<Instruction>(Ent)) && + S->getType() == Ent->getType() && + EntInst->getFunction() == InsertPt->getFunction() && + SE.DT.dominates(EntInst, InsertPt) && + (SE.LI.getLoopFor(EntInst->getParent()) == nullptr || + SE.LI.getLoopFor(EntInst->getParent())->contains(InsertPt))) { + return Ent; + } + } + } + } + return nullptr; +} + +// The expansion of SCEV will either reuse a previous Value in ExprValueMap, +// or expand the SCEV literally. Specifically, if the expansion is in LSRMode, +// and the SCEV contains any sub scAddRecExpr type SCEV, it will be expanded +// literally, to prevent LSR's transformed SCEV from being reverted. Otherwise, +// the expansion will try to reuse Value from ExprValueMap, and only when it +// fails, expand the SCEV literally. 
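The reuse conditions in FindValueInExprValueMap distill to a small per-candidate predicate. A hedged restatement (canReuse is an editor-invented helper, not LLVM API), spelling out the LCSSA rationale:

    #include "llvm/Analysis/LoopInfo.h"
    #include "llvm/Analysis/ScalarEvolution.h"
    #include "llvm/IR/Dominators.h"

    // Editor's sketch of the test applied to each candidate Ent above.
    static bool canReuse(llvm::Instruction *EntInst, const llvm::SCEV *S,
                         const llvm::Instruction *InsertPt,
                         llvm::DominatorTree &DT, llvm::LoopInfo &LI) {
      if (S->getType() != EntInst->getType())
        return false;                           // types must match exactly
      if (EntInst->getFunction() != InsertPt->getFunction())
        return false;                           // same function only
      if (!DT.dominates(EntInst, InsertPt))
        return false;                           // must be available here
      // Reusing a value defined in an inner loop at a point outside that
      // loop would create a use not funneled through an LCSSA phi, so
      // InsertPt must lie inside the candidate's loop (if any).
      const llvm::Loop *L = LI.getLoopFor(EntInst->getParent());
      return !L || L->contains(InsertPt);
    }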
Value *SCEVExpander::expand(const SCEV *S) { // Compute an insertion point for this SCEV object. Hoist the instructions // as far out in the loop nest as possible. @@ -1622,9 +1682,9 @@ Value *SCEVExpander::expand(const SCEV *S) { // there) so that it is guaranteed to dominate any user inside the loop. if (L && SE.hasComputableLoopEvolution(S, L) && !PostIncLoops.count(L)) InsertPt = &*L->getHeader()->getFirstInsertionPt(); - while (InsertPt != Builder.GetInsertPoint() - && (isInsertedInstruction(InsertPt) - || isa<DbgInfoIntrinsic>(InsertPt))) { + while (InsertPt->getIterator() != Builder.GetInsertPoint() && + (isInsertedInstruction(InsertPt) || + isa<DbgInfoIntrinsic>(InsertPt))) { InsertPt = &*std::next(InsertPt->getIterator()); } break; @@ -1635,11 +1695,14 @@ Value *SCEVExpander::expand(const SCEV *S) { if (I != InsertedExpressions.end()) return I->second; - BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); Builder.SetInsertPoint(InsertPt); // Expand the expression into instructions. - Value *V = visit(S); + Value *V = FindValueInExprValueMap(S, InsertPt); + + if (!V) + V = visit(S); // Remember the expanded value for this SCEV at this location. // @@ -1673,7 +1736,7 @@ SCEVExpander::getOrInsertCanonicalInductionVariable(const Loop *L, SE.getConstant(Ty, 1), L, SCEV::FlagAnyWrap); // Emit code for it. - BuilderType::InsertPointGuard Guard(Builder); + SCEVInsertPointGuard Guard(Builder, this); PHINode *V = cast<PHINode>(expandCodeFor(H, nullptr, &L->getHeader()->front())); @@ -1742,8 +1805,8 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, PHINode *&OrigPhiRef = ExprToIVMap[SE.getSCEV(Phi)]; if (!OrigPhiRef) { OrigPhiRef = Phi; - if (Phi->getType()->isIntegerTy() && TTI - && TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { + if (Phi->getType()->isIntegerTy() && TTI && + TTI->isTruncateFree(Phi->getType(), Phis.back()->getType())) { // This phi can be freely truncated to the narrowest phi type. Map the // truncated expression to it so it will be reused for narrow types. const SCEV *TruncExpr = @@ -1759,56 +1822,59 @@ unsigned SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, continue; if (BasicBlock *LatchBlock = L->getLoopLatch()) { - Instruction *OrigInc = - cast<Instruction>(OrigPhiRef->getIncomingValueForBlock(LatchBlock)); + Instruction *OrigInc = dyn_cast<Instruction>( + OrigPhiRef->getIncomingValueForBlock(LatchBlock)); Instruction *IsomorphicInc = - cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); - - // If this phi has the same width but is more canonical, replace the - // original with it. As part of the "more canonical" determination, - // respect a prior decision to use an IV chain. - if (OrigPhiRef->getType() == Phi->getType() - && !(ChainedPhis.count(Phi) - || isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) - && (ChainedPhis.count(Phi) - || isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { - std::swap(OrigPhiRef, Phi); - std::swap(OrigInc, IsomorphicInc); - } - // Replacing the congruent phi is sufficient because acyclic redundancy - // elimination, CSE/GVN, should handle the rest. However, once SCEV proves - // that a phi is congruent, it's often the head of an IV user cycle that - // is isomorphic with the original phi. It's worth eagerly cleaning up the - // common case of a single IV increment so that DeleteDeadPHIs can remove - // cycles that had postinc uses. 
- const SCEV *TruncExpr = SE.getTruncateOrNoop(SE.getSCEV(OrigInc), - IsomorphicInc->getType()); - if (OrigInc != IsomorphicInc - && TruncExpr == SE.getSCEV(IsomorphicInc) - && ((isa<PHINode>(OrigInc) && isa<PHINode>(IsomorphicInc)) - || hoistIVInc(OrigInc, IsomorphicInc))) { - DEBUG_WITH_TYPE(DebugType, dbgs() - << "INDVARS: Eliminated congruent iv.inc: " - << *IsomorphicInc << '\n'); - Value *NewInc = OrigInc; - if (OrigInc->getType() != IsomorphicInc->getType()) { - Instruction *IP = nullptr; - if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) - IP = &*PN->getParent()->getFirstInsertionPt(); - else - IP = OrigInc->getNextNode(); - - IRBuilder<> Builder(IP); - Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); - NewInc = Builder. - CreateTruncOrBitCast(OrigInc, IsomorphicInc->getType(), IVName); + dyn_cast<Instruction>(Phi->getIncomingValueForBlock(LatchBlock)); + + if (OrigInc && IsomorphicInc) { + // If this phi has the same width but is more canonical, replace the + // original with it. As part of the "more canonical" determination, + // respect a prior decision to use an IV chain. + if (OrigPhiRef->getType() == Phi->getType() && + !(ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(OrigPhiRef, OrigInc, L)) && + (ChainedPhis.count(Phi) || + isExpandedAddRecExprPHI(Phi, IsomorphicInc, L))) { + std::swap(OrigPhiRef, Phi); + std::swap(OrigInc, IsomorphicInc); + } + // Replacing the congruent phi is sufficient because acyclic + // redundancy elimination, CSE/GVN, should handle the + // rest. However, once SCEV proves that a phi is congruent, + // it's often the head of an IV user cycle that is isomorphic + // with the original phi. It's worth eagerly cleaning up the + // common case of a single IV increment so that DeleteDeadPHIs + // can remove cycles that had postinc uses. + const SCEV *TruncExpr = + SE.getTruncateOrNoop(SE.getSCEV(OrigInc), IsomorphicInc->getType()); + if (OrigInc != IsomorphicInc && + TruncExpr == SE.getSCEV(IsomorphicInc) && + SE.LI.replacementPreservesLCSSAForm(IsomorphicInc, OrigInc) && + hoistIVInc(OrigInc, IsomorphicInc)) { + DEBUG_WITH_TYPE(DebugType, + dbgs() << "INDVARS: Eliminated congruent iv.inc: " + << *IsomorphicInc << '\n'); + Value *NewInc = OrigInc; + if (OrigInc->getType() != IsomorphicInc->getType()) { + Instruction *IP = nullptr; + if (PHINode *PN = dyn_cast<PHINode>(OrigInc)) + IP = &*PN->getParent()->getFirstInsertionPt(); + else + IP = OrigInc->getNextNode(); + + IRBuilder<> Builder(IP); + Builder.SetCurrentDebugLocation(IsomorphicInc->getDebugLoc()); + NewInc = Builder.CreateTruncOrBitCast( + OrigInc, IsomorphicInc->getType(), IVName); + } + IsomorphicInc->replaceAllUsesWith(NewInc); + DeadInsts.emplace_back(IsomorphicInc); } - IsomorphicInc->replaceAllUsesWith(NewInc); - DeadInsts.emplace_back(IsomorphicInc); } } - DEBUG_WITH_TYPE(DebugType, dbgs() - << "INDVARS: Eliminated congruent iv: " << *Phi << '\n'); + DEBUG_WITH_TYPE(DebugType, dbgs() << "INDVARS: Eliminated congruent iv: " + << *Phi << '\n'); ++NumElim; Value *NewIV = OrigPhiRef; if (OrigPhiRef->getType() != Phi->getType()) { @@ -1847,6 +1913,11 @@ Value *SCEVExpander::findExistingExpansion(const SCEV *S, return RHS; } + // Use expand's logic which is used for reusing a previous Value in + // ExprValueMap. + if (Value *Val = FindValueInExprValueMap(S, At)) + return Val; + // There is potential to make this significantly smarter, but this simple // heuristic already gets some interesting cases. 
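To make the replaceCongruentIVs truncation path above concrete, a minimal sketch of the replacement step, assuming the wide increment is neither a phi nor the block terminator (replaceNarrowInc is an illustrative name):

    #include "llvm/IR/IRBuilder.h"

    // Editor's sketch: rewrite a narrow, congruent IV increment as a
    // truncation of the wide one, placed immediately after the wide inc.
    static void replaceNarrowInc(llvm::Instruction *WideInc,
                                 llvm::Instruction *NarrowInc) {
      llvm::IRBuilder<> B(WideInc->getNextNode()); // insert after WideInc
      B.SetCurrentDebugLocation(NarrowInc->getDebugLoc());
      llvm::Value *NewInc = B.CreateTruncOrBitCast(
          WideInc, NarrowInc->getType(), "iv.trunc");
      NarrowInc->replaceAllUsesWith(NewInc); // NarrowInc is now dead
    }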
@@ -1940,6 +2011,10 @@ Value *SCEVExpander::expandCodeForPredicate(const SCEVPredicate *Pred,
     return expandUnionPredicate(cast<SCEVUnionPredicate>(Pred), IP);
   case SCEVPredicate::P_Equal:
     return expandEqualPredicate(cast<SCEVEqualPredicate>(Pred), IP);
+  case SCEVPredicate::P_Wrap: {
+    auto *AddRecPred = cast<SCEVWrapPredicate>(Pred);
+    return expandWrapPredicate(AddRecPred, IP);
+  }
   }
   llvm_unreachable("Unknown SCEV predicate type");
 }
@@ -1954,6 +2029,116 @@ Value *SCEVExpander::expandEqualPredicate(const SCEVEqualPredicate *Pred,
   return I;
 }

+Value *SCEVExpander::generateOverflowCheck(const SCEVAddRecExpr *AR,
+                                           Instruction *Loc, bool Signed) {
+  assert(AR->isAffine() && "Cannot generate RT check for "
+                           "non-affine expression");
+
+  SCEVUnionPredicate Pred;
+  const SCEV *ExitCount =
+      SE.getPredicatedBackedgeTakenCount(AR->getLoop(), Pred);
+
+  assert(ExitCount != SE.getCouldNotCompute() && "Invalid loop count");
+
+  const SCEV *Step = AR->getStepRecurrence(SE);
+  const SCEV *Start = AR->getStart();
+
+  unsigned SrcBits = SE.getTypeSizeInBits(ExitCount->getType());
+  unsigned DstBits = SE.getTypeSizeInBits(AR->getType());
+
+  // The expression {Start,+,Step} has nusw/nssw if
+  //   Step < 0, Start - |Step| * Backedge <= Start
+  //   Step >= 0, Start + |Step| * Backedge > Start
+  // and |Step| * Backedge doesn't unsigned overflow.
+
+  IntegerType *CountTy = IntegerType::get(Loc->getContext(), SrcBits);
+  Builder.SetInsertPoint(Loc);
+  Value *TripCountVal = expandCodeFor(ExitCount, CountTy, Loc);
+
+  IntegerType *Ty =
+      IntegerType::get(Loc->getContext(), SE.getTypeSizeInBits(AR->getType()));
+
+  Value *StepValue = expandCodeFor(Step, Ty, Loc);
+  Value *NegStepValue = expandCodeFor(SE.getNegativeSCEV(Step), Ty, Loc);
+  Value *StartValue = expandCodeFor(Start, Ty, Loc);
+
+  ConstantInt *Zero =
+      ConstantInt::get(Loc->getContext(), APInt::getNullValue(DstBits));
+
+  Builder.SetInsertPoint(Loc);
+  // Compute |Step|
+  Value *StepCompare = Builder.CreateICmp(ICmpInst::ICMP_SLT, StepValue, Zero);
+  Value *AbsStep = Builder.CreateSelect(StepCompare, NegStepValue, StepValue);
+
+  // Get the backedge taken count and truncate or extend it to the AR type.
+  Value *TruncTripCount = Builder.CreateZExtOrTrunc(TripCountVal, Ty);
+  auto *MulF = Intrinsic::getDeclaration(Loc->getModule(),
+                                         Intrinsic::umul_with_overflow, Ty);
+
+  // Compute |Step| * Backedge
+  CallInst *Mul = Builder.CreateCall(MulF, {AbsStep, TruncTripCount}, "mul");
+  Value *MulV = Builder.CreateExtractValue(Mul, 0, "mul.result");
+  Value *OfMul = Builder.CreateExtractValue(Mul, 1, "mul.overflow");
+
+  // Compute:
+  //   Start + |Step| * Backedge < Start
+  //   Start - |Step| * Backedge > Start
+  Value *Add = Builder.CreateAdd(StartValue, MulV);
+  Value *Sub = Builder.CreateSub(StartValue, MulV);
+
+  Value *EndCompareGT = Builder.CreateICmp(
+      Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, Sub, StartValue);
+
+  Value *EndCompareLT = Builder.CreateICmp(
+      Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, Add, StartValue);
+
+  // Select the answer based on the sign of Step.
+  Value *EndCheck =
+      Builder.CreateSelect(StepCompare, EndCompareGT, EndCompareLT);
+
+  // If the backedge taken count type is larger than the AR type,
+  // check that we don't drop any bits by truncating it. If we are
+  // dropping bits, then we have overflow (unless the step is zero).
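// (Editor's worked example, not part of the patch.) For an i8 AddRec
// {250,+,3} with backedge-taken count 10 and Signed == false (NUSW):
// |Step| * BTC = 30, Add = 250 + 30 wraps to 24, so Add u< Start and
// EndCompareLT reports the wrap. With Step = -3 instead, StepCompare
// selects EndCompareGT, Sub = 250 - 30 = 220 is not u> Start, and no
// wrap is reported -- {250,+,-3} indeed stays in range for 10 steps.
// OfMul is OR'ed in afterwards to catch |Step| * BTC itself overflowing.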
+ if (SE.getTypeSizeInBits(CountTy) > SE.getTypeSizeInBits(Ty)) { + auto MaxVal = APInt::getMaxValue(DstBits).zext(SrcBits); + auto *BackedgeCheck = + Builder.CreateICmp(ICmpInst::ICMP_UGT, TripCountVal, + ConstantInt::get(Loc->getContext(), MaxVal)); + BackedgeCheck = Builder.CreateAnd( + BackedgeCheck, Builder.CreateICmp(ICmpInst::ICMP_NE, StepValue, Zero)); + + EndCheck = Builder.CreateOr(EndCheck, BackedgeCheck); + } + + EndCheck = Builder.CreateOr(EndCheck, OfMul); + return EndCheck; +} + +Value *SCEVExpander::expandWrapPredicate(const SCEVWrapPredicate *Pred, + Instruction *IP) { + const auto *A = cast<SCEVAddRecExpr>(Pred->getExpr()); + Value *NSSWCheck = nullptr, *NUSWCheck = nullptr; + + // Add a check for NUSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNUSW) + NUSWCheck = generateOverflowCheck(A, IP, false); + + // Add a check for NSSW + if (Pred->getFlags() & SCEVWrapPredicate::IncrementNSSW) + NSSWCheck = generateOverflowCheck(A, IP, true); + + if (NUSWCheck && NSSWCheck) + return Builder.CreateOr(NUSWCheck, NSSWCheck); + + if (NUSWCheck) + return NUSWCheck; + + if (NSSWCheck) + return NSSWCheck; + + return ConstantInt::getFalse(IP->getContext()); +} + Value *SCEVExpander::expandUnionPredicate(const SCEVUnionPredicate *Union, Instruction *IP) { auto *BoolType = IntegerType::get(IP->getContext(), 1); diff --git a/lib/Analysis/ScalarEvolutionNormalization.cpp b/lib/Analysis/ScalarEvolutionNormalization.cpp index b7fd5d506175..c1f9503816ee 100644 --- a/lib/Analysis/ScalarEvolutionNormalization.cpp +++ b/lib/Analysis/ScalarEvolutionNormalization.cpp @@ -1,4 +1,4 @@ -//===- ScalarEvolutionNormalization.cpp - See below -------------*- C++ -*-===// +//===- ScalarEvolutionNormalization.cpp - See below -----------------------===// // // The LLVM Compiler Infrastructure // diff --git a/lib/Analysis/ScopedNoAliasAA.cpp b/lib/Analysis/ScopedNoAliasAA.cpp index 486f3a583284..82e65a1f2088 100644 --- a/lib/Analysis/ScopedNoAliasAA.cpp +++ b/lib/Analysis/ScopedNoAliasAA.cpp @@ -34,7 +34,6 @@ #include "llvm/Analysis/ScopedNoAliasAA.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Constants.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" @@ -51,9 +50,9 @@ static cl::opt<bool> EnableScopedNoAlias("enable-scoped-noalias", cl::init(true)); namespace { -/// AliasScopeNode - This is a simple wrapper around an MDNode which provides -/// a higher-level interface by hiding the details of how alias analysis -/// information is encoded in its operands. +/// This is a simple wrapper around an MDNode which provides a higher-level +/// interface by hiding the details of how alias analysis information is encoded +/// in its operands. class AliasScopeNode { const MDNode *Node; @@ -61,10 +60,10 @@ public: AliasScopeNode() : Node(nullptr) {} explicit AliasScopeNode(const MDNode *N) : Node(N) {} - /// getNode - Get the MDNode for this AliasScopeNode. + /// Get the MDNode for this AliasScopeNode. const MDNode *getNode() const { return Node; } - /// getDomain - Get the MDNode for this AliasScopeNode's domain. + /// Get the MDNode for this AliasScopeNode's domain. 
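// (Editor's illustration.) The metadata shape assumed here: a scope node
// references itself in operand 0 and its domain in operand 1, e.g.
//   !0 = distinct !{!0}      ; a domain
//   !1 = distinct !{!1, !0}  ; a scope whose domain is !0
// so AliasScopeNode(!1).getDomain() yields !0.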
const MDNode *getDomain() const { if (Node->getNumOperands() < 2) return nullptr; @@ -131,8 +130,8 @@ ModRefInfo ScopedNoAliasAAResult::getModRefInfo(ImmutableCallSite CS1, void ScopedNoAliasAAResult::collectMDInDomain( const MDNode *List, const MDNode *Domain, SmallPtrSetImpl<const MDNode *> &Nodes) const { - for (unsigned i = 0, ie = List->getNumOperands(); i != ie; ++i) - if (const MDNode *MD = dyn_cast<MDNode>(List->getOperand(i))) + for (const MDOperand &MDOp : List->operands()) + if (const MDNode *MD = dyn_cast<MDNode>(MDOp)) if (AliasScopeNode(MD).getDomain() == Domain) Nodes.insert(MD); } @@ -144,8 +143,8 @@ bool ScopedNoAliasAAResult::mayAliasInScopes(const MDNode *Scopes, // Collect the set of scope domains relevant to the noalias scopes. SmallPtrSet<const MDNode *, 16> Domains; - for (unsigned i = 0, ie = NoAlias->getNumOperands(); i != ie; ++i) - if (const MDNode *NAMD = dyn_cast<MDNode>(NoAlias->getOperand(i))) + for (const MDOperand &MDOp : NoAlias->operands()) + if (const MDNode *NAMD = dyn_cast<MDNode>(MDOp)) if (const MDNode *Domain = AliasScopeNode(NAMD).getDomain()) Domains.insert(Domain); @@ -173,19 +172,16 @@ bool ScopedNoAliasAAResult::mayAliasInScopes(const MDNode *Scopes, return true; } +char ScopedNoAliasAA::PassID; + ScopedNoAliasAAResult ScopedNoAliasAA::run(Function &F, - AnalysisManager<Function> *AM) { - return ScopedNoAliasAAResult(AM->getResult<TargetLibraryAnalysis>(F)); + AnalysisManager<Function> &AM) { + return ScopedNoAliasAAResult(); } -char ScopedNoAliasAA::PassID; - char ScopedNoAliasAAWrapperPass::ID = 0; -INITIALIZE_PASS_BEGIN(ScopedNoAliasAAWrapperPass, "scoped-noalias", - "Scoped NoAlias Alias Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(ScopedNoAliasAAWrapperPass, "scoped-noalias", - "Scoped NoAlias Alias Analysis", false, true) +INITIALIZE_PASS(ScopedNoAliasAAWrapperPass, "scoped-noalias", + "Scoped NoAlias Alias Analysis", false, true) ImmutablePass *llvm::createScopedNoAliasAAWrapperPass() { return new ScopedNoAliasAAWrapperPass(); @@ -196,8 +192,7 @@ ScopedNoAliasAAWrapperPass::ScopedNoAliasAAWrapperPass() : ImmutablePass(ID) { } bool ScopedNoAliasAAWrapperPass::doInitialization(Module &M) { - Result.reset(new ScopedNoAliasAAResult( - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI())); + Result.reset(new ScopedNoAliasAAResult()); return false; } @@ -208,5 +203,4 @@ bool ScopedNoAliasAAWrapperPass::doFinalization(Module &M) { void ScopedNoAliasAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); } diff --git a/lib/Analysis/SparsePropagation.cpp b/lib/Analysis/SparsePropagation.cpp index f5a927b80525..79dc84e25533 100644 --- a/lib/Analysis/SparsePropagation.cpp +++ b/lib/Analysis/SparsePropagation.cpp @@ -320,8 +320,8 @@ void SparseSolver::Solve(Function &F) { // Notify all instructions in this basic block that they are newly // executable. 
- for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) - visitInst(*I); + for (Instruction &I : *BB) + visitInst(I); } } } diff --git a/lib/Analysis/StratifiedSets.h b/lib/Analysis/StratifiedSets.h index fd3fbc0d86ad..fd3a241d79c1 100644 --- a/lib/Analysis/StratifiedSets.h +++ b/lib/Analysis/StratifiedSets.h @@ -10,60 +10,49 @@ #ifndef LLVM_ADT_STRATIFIEDSETS_H #define LLVM_ADT_STRATIFIEDSETS_H +#include "AliasAnalysisSummary.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/Optional.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Compiler.h" #include <bitset> #include <cassert> #include <cmath> -#include <limits> #include <type_traits> #include <utility> #include <vector> namespace llvm { -// \brief An index into Stratified Sets. +namespace cflaa { +/// An index into Stratified Sets. typedef unsigned StratifiedIndex; -// NOTE: ^ This can't be a short -- bootstrapping clang has a case where -// ~1M sets exist. +/// NOTE: ^ This can't be a short -- bootstrapping clang has a case where +/// ~1M sets exist. // \brief Container of information related to a value in a StratifiedSet. struct StratifiedInfo { StratifiedIndex Index; - // For field sensitivity, etc. we can tack attributes on to this struct. + /// For field sensitivity, etc. we can tack fields on here. }; -// The number of attributes that StratifiedAttrs should contain. Attributes are -// described below, and 32 was an arbitrary choice because it fits nicely in 32 -// bits (because we use a bitset for StratifiedAttrs). -static const unsigned NumStratifiedAttrs = 32; - -// These are attributes that the users of StratifiedSets/StratifiedSetBuilders -// may use for various purposes. These also have the special property of that -// they are merged down. So, if set A is above set B, and one decides to set an -// attribute in set A, then the attribute will automatically be set in set B. -typedef std::bitset<NumStratifiedAttrs> StratifiedAttrs; - -// \brief A "link" between two StratifiedSets. +/// A "link" between two StratifiedSets. struct StratifiedLink { - // \brief This is a value used to signify "does not exist" where - // the StratifiedIndex type is used. This is used instead of - // Optional<StratifiedIndex> because Optional<StratifiedIndex> would - // eat up a considerable amount of extra memory, after struct - // padding/alignment is taken into account. + /// \brief This is a value used to signify "does not exist" where the + /// StratifiedIndex type is used. + /// + /// This is used instead of Optional<StratifiedIndex> because + /// Optional<StratifiedIndex> would eat up a considerable amount of extra + /// memory, after struct padding/alignment is taken into account. static const StratifiedIndex SetSentinel; - // \brief The index for the set "above" current + /// The index for the set "above" current StratifiedIndex Above; - // \brief The link for the set "below" current + /// The link for the set "below" current StratifiedIndex Below; - // \brief Attributes for these StratifiedSets. - StratifiedAttrs Attrs; + /// Attributes for these StratifiedSets. + AliasAttrs Attrs; StratifiedLink() : Above(SetSentinel), Below(SetSentinel) {} @@ -74,46 +63,48 @@ struct StratifiedLink { void clearAbove() { Above = SetSentinel; } }; -// \brief These are stratified sets, as described in "Fast algorithms for -// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M -// R, Yuan H, and Su Z. 
-- in short, this is meant to represent different sets -// of Value*s. If two Value*s are in the same set, or if both sets have -// overlapping attributes, then the Value*s are said to alias. -// -// Sets may be related by position, meaning that one set may be considered as -// above or below another. In CFL Alias Analysis, this gives us an indication -// of how two variables are related; if the set of variable A is below a set -// containing variable B, then at some point, a variable that has interacted -// with B (or B itself) was either used in order to extract the variable A, or -// was used as storage of variable A. -// -// Sets may also have attributes (as noted above). These attributes are -// generally used for noting whether a variable in the set has interacted with -// a variable whose origins we don't quite know (i.e. globals/arguments), or if -// the variable may have had operations performed on it (modified in a function -// call). All attributes that exist in a set A must exist in all sets marked as -// below set A. +/// \brief These are stratified sets, as described in "Fast algorithms for +/// Dyck-CFL-reachability with applications to Alias Analysis" by Zhang Q, Lyu M +/// R, Yuan H, and Su Z. -- in short, this is meant to represent different sets +/// of Value*s. If two Value*s are in the same set, or if both sets have +/// overlapping attributes, then the Value*s are said to alias. +/// +/// Sets may be related by position, meaning that one set may be considered as +/// above or below another. In CFL Alias Analysis, this gives us an indication +/// of how two variables are related; if the set of variable A is below a set +/// containing variable B, then at some point, a variable that has interacted +/// with B (or B itself) was either used in order to extract the variable A, or +/// was used as storage of variable A. +/// +/// Sets may also have attributes (as noted above). These attributes are +/// generally used for noting whether a variable in the set has interacted with +/// a variable whose origins we don't quite know (i.e. globals/arguments), or if +/// the variable may have had operations performed on it (modified in a function +/// call). All attributes that exist in a set A must exist in all sets marked as +/// below set A. template <typename T> class StratifiedSets { public: - StratifiedSets() {} - - StratifiedSets(DenseMap<T, StratifiedInfo> Map, - std::vector<StratifiedLink> Links) - : Values(std::move(Map)), Links(std::move(Links)) {} + StratifiedSets() = default; - StratifiedSets(StratifiedSets<T> &&Other) { *this = std::move(Other); } + // TODO: Figure out how to make MSVC not call the copy ctor here, and delete + // it. - StratifiedSets &operator=(StratifiedSets<T> &&Other) { + // Can't default these due to compile errors in MSVC2013 + StratifiedSets(StratifiedSets &&Other) { *this = std::move(Other); } + StratifiedSets &operator=(StratifiedSets &&Other) { Values = std::move(Other.Values); Links = std::move(Other.Links); return *this; } + StratifiedSets(DenseMap<T, StratifiedInfo> Map, + std::vector<StratifiedLink> Links) + : Values(std::move(Map)), Links(std::move(Links)) {} + Optional<StratifiedInfo> find(const T &Elem) const { auto Iter = Values.find(Elem); - if (Iter == Values.end()) { - return NoneType(); - } + if (Iter == Values.end()) + return None; return Iter->second; } @@ -129,91 +120,70 @@ private: bool inbounds(StratifiedIndex Idx) const { return Idx < Links.size(); } }; -// \brief Generic Builder class that produces StratifiedSets instances. 
-// -// The goal of this builder is to efficiently produce correct StratifiedSets -// instances. To this end, we use a few tricks: -// > Set chains (A method for linking sets together) -// > Set remaps (A method for marking a set as an alias [irony?] of another) -// -// ==== Set chains ==== -// This builder has a notion of some value A being above, below, or with some -// other value B: -// > The `A above B` relationship implies that there is a reference edge going -// from A to B. Namely, it notes that A can store anything in B's set. -// > The `A below B` relationship is the opposite of `A above B`. It implies -// that there's a dereference edge going from A to B. -// > The `A with B` relationship states that there's an assignment edge going -// from A to B, and that A and B should be treated as equals. -// -// As an example, take the following code snippet: -// -// %a = alloca i32, align 4 -// %ap = alloca i32*, align 8 -// %app = alloca i32**, align 8 -// store %a, %ap -// store %ap, %app -// %aw = getelementptr %ap, 0 -// -// Given this, the follow relations exist: -// - %a below %ap & %ap above %a -// - %ap below %app & %app above %ap -// - %aw with %ap & %ap with %aw -// -// These relations produce the following sets: -// [{%a}, {%ap, %aw}, {%app}] -// -// ...Which states that the only MayAlias relationship in the above program is -// between %ap and %aw. -// -// Life gets more complicated when we actually have logic in our programs. So, -// we either must remove this logic from our programs, or make consessions for -// it in our AA algorithms. In this case, we have decided to select the latter -// option. -// -// First complication: Conditionals -// Motivation: -// %ad = alloca int, align 4 -// %a = alloca int*, align 8 -// %b = alloca int*, align 8 -// %bp = alloca int**, align 8 -// %c = call i1 @SomeFunc() -// %k = select %c, %ad, %bp -// store %ad, %a -// store %b, %bp -// -// %k has 'with' edges to both %a and %b, which ordinarily would not be linked -// together. So, we merge the set that contains %a with the set that contains -// %b. We then recursively merge the set above %a with the set above %b, and -// the set below %a with the set below %b, etc. Ultimately, the sets for this -// program would end up like: {%ad}, {%a, %b, %k}, {%bp}, where {%ad} is below -// {%a, %b, %c} is below {%ad}. -// -// Second complication: Arbitrary casts -// Motivation: -// %ip = alloca int*, align 8 -// %ipp = alloca int**, align 8 -// %i = bitcast ipp to int -// store %ip, %ipp -// store %i, %ip -// -// This is impossible to construct with any of the rules above, because a set -// containing both {%i, %ipp} is supposed to exist, the set with %i is supposed -// to be below the set with %ip, and the set with %ip is supposed to be below -// the set with %ipp. Because we don't allow circular relationships like this, -// we merge all concerned sets into one. So, the above code would generate a -// single StratifiedSet: {%ip, %ipp, %i}. -// -// ==== Set remaps ==== -// More of an implementation detail than anything -- when merging sets, we need -// to update the numbers of all of the elements mapped to those sets. Rather -// than doing this at each merge, we note in the BuilderLink structure that a -// remap has occurred, and use this information so we can defer renumbering set -// elements until build time. +/// Generic Builder class that produces StratifiedSets instances. +/// +/// The goal of this builder is to efficiently produce correct StratifiedSets +/// instances. 
To this end, we use a few tricks: +/// > Set chains (A method for linking sets together) +/// > Set remaps (A method for marking a set as an alias [irony?] of another) +/// +/// ==== Set chains ==== +/// This builder has a notion of some value A being above, below, or with some +/// other value B: +/// > The `A above B` relationship implies that there is a reference edge +/// going from A to B. Namely, it notes that A can store anything in B's set. +/// > The `A below B` relationship is the opposite of `A above B`. It implies +/// that there's a dereference edge going from A to B. +/// > The `A with B` relationship states that there's an assignment edge going +/// from A to B, and that A and B should be treated as equals. +/// +/// As an example, take the following code snippet: +/// +/// %a = alloca i32, align 4 +/// %ap = alloca i32*, align 8 +/// %app = alloca i32**, align 8 +/// store %a, %ap +/// store %ap, %app +/// %aw = getelementptr %ap, i32 0 +/// +/// Given this, the following relations exist: +/// - %a below %ap & %ap above %a +/// - %ap below %app & %app above %ap +/// - %aw with %ap & %ap with %aw +/// +/// These relations produce the following sets: +/// [{%a}, {%ap, %aw}, {%app}] +/// +/// ...Which state that the only MayAlias relationship in the above program is +/// between %ap and %aw. +/// +/// Because LLVM allows arbitrary casts, code like the following needs to be +/// supported: +/// %ip = alloca i64, align 8 +/// %ipp = alloca i64*, align 8 +/// %i = bitcast i64** ipp to i64 +/// store i64* %ip, i64** %ipp +/// store i64 %i, i64* %ip +/// +/// Which, because %ipp ends up *both* above and below %ip, is fun. +/// +/// This is solved by merging %i and %ipp into a single set (...which is the +/// only way to solve this, since their bit patterns are equivalent). Any sets +/// that ended up in between %i and %ipp at the time of merging (in this case, +/// the set containing %ip) also get conservatively merged into the set of %i +/// and %ipp. In short, the resulting StratifiedSet from the above code would be +/// {%ip, %ipp, %i}. +/// +/// ==== Set remaps ==== +/// More of an implementation detail than anything -- when merging sets, we need +/// to update the numbers of all of the elements mapped to those sets. Rather +/// than doing this at each merge, we note in the BuilderLink structure that a +/// remap has occurred, and use this information so we can defer renumbering set +/// elements until build time. template <typename T> class StratifiedSetsBuilder { - // \brief Represents a Stratified Set, with information about the Stratified - // Set above it, the set below it, and whether the current set has been - // remapped to another. + /// \brief Represents a Stratified Set, with information about the Stratified + /// Set above it, the set below it, and whether the current set has been + /// remapped to another. 
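Before the BuilderLink internals, a hedged usage sketch of the builder interface documented above, encoding the %a/%ap/%app example. It assumes an addWith counterpart to addAbove/addBelow (the "with" relation described earlier); all variable names are for illustration:

    #include "llvm/IR/Value.h"

    // Editor's sketch: encode "%a below %ap, %app above %ap, %aw with %ap".
    void encodeExample(llvm::Value *A, llvm::Value *AP, llvm::Value *APP,
                       llvm::Value *AW) {
      llvm::cflaa::StratifiedSetsBuilder<llvm::Value *> B;
      B.add(AP);            // seed a set with %ap
      B.addBelow(AP, A);    // %a below %ap   (dereference edge)
      B.addAbove(AP, APP);  // %app above %ap (reference edge)
      B.addWith(AP, AW);    // %aw with %ap   (assignment edge)
      llvm::cflaa::StratifiedSets<llvm::Value *> Sets = B.build();
      // Expected outcome: three sets, [{%a}, {%ap, %aw}, {%app}].
      (void)Sets;
    }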
struct BuilderLink { const StratifiedIndex Number; @@ -263,25 +233,19 @@ template <typename T> class StratifiedSetsBuilder { return Link.Above; } - StratifiedAttrs &getAttrs() { + AliasAttrs getAttrs() { assert(!isRemapped()); return Link.Attrs; } - void setAttr(unsigned index) { + void setAttrs(AliasAttrs Other) { assert(!isRemapped()); - assert(index < NumStratifiedAttrs); - Link.Attrs.set(index); - } - - void setAttrs(const StratifiedAttrs &other) { - assert(!isRemapped()); - Link.Attrs |= other; + Link.Attrs |= Other; } bool isRemapped() const { return Remap != StratifiedLink::SetSentinel; } - // \brief For initial remapping to another set + /// For initial remapping to another set void remapTo(StratifiedIndex Other) { assert(!isRemapped()); Remap = Other; @@ -292,15 +256,15 @@ template <typename T> class StratifiedSetsBuilder { return Remap; } - // \brief Should only be called when we're already remapped. + /// Should only be called when we're already remapped. void updateRemap(StratifiedIndex Other) { assert(isRemapped()); Remap = Other; } - // \brief Prefer the above functions to calling things directly on what's - // returned from this -- they guard against unexpected calls when the - // current BuilderLink is remapped. + /// Prefer the above functions to calling things directly on what's returned + /// from this -- they guard against unexpected calls when the current + /// BuilderLink is remapped. const StratifiedLink &getLink() const { return Link; } private: @@ -308,15 +272,14 @@ template <typename T> class StratifiedSetsBuilder { StratifiedIndex Remap; }; - // \brief This function performs all of the set unioning/value renumbering - // that we've been putting off, and generates a vector<StratifiedLink> that - // may be placed in a StratifiedSets instance. + /// \brief This function performs all of the set unioning/value renumbering + /// that we've been putting off, and generates a vector<StratifiedLink> that + /// may be placed in a StratifiedSets instance. void finalizeSets(std::vector<StratifiedLink> &StratLinks) { DenseMap<StratifiedIndex, StratifiedIndex> Remaps; for (auto &Link : Links) { - if (Link.isRemapped()) { + if (Link.isRemapped()) continue; - } StratifiedIndex Number = StratLinks.size(); Remaps.insert(std::make_pair(Link.Number, Number)); @@ -348,8 +311,8 @@ template <typename T> class StratifiedSetsBuilder { } } - // \brief There's a guarantee in StratifiedLink where all bits set in a - // Link.externals will be set in all Link.externals "below" it. + /// \brief There's a guarantee in StratifiedLink where all bits set in a + /// Link.externals will be set in all Link.externals "below" it. static void propagateAttrs(std::vector<StratifiedLink> &Links) { const auto getHighestParentAbove = [&Links](StratifiedIndex Idx) { const auto *Link = &Links[Idx]; @@ -363,9 +326,8 @@ template <typename T> class StratifiedSetsBuilder { SmallSet<StratifiedIndex, 16> Visited; for (unsigned I = 0, E = Links.size(); I < E; ++I) { auto CurrentIndex = getHighestParentAbove(I); - if (!Visited.insert(CurrentIndex).second) { + if (!Visited.insert(CurrentIndex).second) continue; - } while (Links[CurrentIndex].hasBelow()) { auto &CurrentBits = Links[CurrentIndex].Attrs; @@ -378,8 +340,8 @@ template <typename T> class StratifiedSetsBuilder { } public: - // \brief Builds a StratifiedSet from the information we've been given since - // either construction or the prior build() call. 
+ /// Builds a StratifiedSet from the information we've been given since either + /// construction or the prior build() call. StratifiedSets<T> build() { std::vector<StratifiedLink> StratLinks; finalizeSets(StratLinks); @@ -388,9 +350,6 @@ public: return StratifiedSets<T>(std::move(Values), std::move(StratLinks)); } - std::size_t size() const { return Values.size(); } - std::size_t numSets() const { return Links.size(); } - bool has(const T &Elem) const { return get(Elem).hasValue(); } bool add(const T &Main) { @@ -401,9 +360,9 @@ public: return addAtMerging(Main, NewIndex); } - // \brief Restructures the stratified sets as necessary to make "ToAdd" in a - // set above "Main". There are some cases where this is not possible (see - // above), so we merge them such that ToAdd and Main are in the same set. + /// \brief Restructures the stratified sets as necessary to make "ToAdd" in a + /// set above "Main". There are some cases where this is not possible (see + /// above), so we merge them such that ToAdd and Main are in the same set. bool addAbove(const T &Main, const T &ToAdd) { assert(has(Main)); auto Index = *indexOf(Main); @@ -414,9 +373,9 @@ public: return addAtMerging(ToAdd, Above); } - // \brief Restructures the stratified sets as necessary to make "ToAdd" in a - // set below "Main". There are some cases where this is not possible (see - // above), so we merge them such that ToAdd and Main are in the same set. + /// \brief Restructures the stratified sets as necessary to make "ToAdd" in a + /// set below "Main". There are some cases where this is not possible (see + /// above), so we merge them such that ToAdd and Main are in the same set. bool addBelow(const T &Main, const T &ToAdd) { assert(has(Main)); auto Index = *indexOf(Main); @@ -433,65 +392,18 @@ public: return addAtMerging(ToAdd, MainIndex); } - void noteAttribute(const T &Main, unsigned AttrNum) { - assert(has(Main)); - assert(AttrNum < StratifiedLink::SetSentinel); - auto *Info = *get(Main); - auto &Link = linksAt(Info->Index); - Link.setAttr(AttrNum); - } - - void noteAttributes(const T &Main, const StratifiedAttrs &NewAttrs) { + void noteAttributes(const T &Main, AliasAttrs NewAttrs) { assert(has(Main)); auto *Info = *get(Main); auto &Link = linksAt(Info->Index); Link.setAttrs(NewAttrs); } - StratifiedAttrs getAttributes(const T &Main) { - assert(has(Main)); - auto *Info = *get(Main); - auto *Link = &linksAt(Info->Index); - auto Attrs = Link->getAttrs(); - while (Link->hasAbove()) { - Link = &linksAt(Link->getAbove()); - Attrs |= Link->getAttrs(); - } - - return Attrs; - } - - bool getAttribute(const T &Main, unsigned AttrNum) { - assert(AttrNum < StratifiedLink::SetSentinel); - auto Attrs = getAttributes(Main); - return Attrs[AttrNum]; - } - - // \brief Gets the attributes that have been applied to the set that Main - // belongs to. It ignores attributes in any sets above the one that Main - // resides in. - StratifiedAttrs getRawAttributes(const T &Main) { - assert(has(Main)); - auto *Info = *get(Main); - auto &Link = linksAt(Info->Index); - return Link.getAttrs(); - } - - // \brief Gets an attribute from the attributes that have been applied to the - // set that Main belongs to. It ignores attributes in any sets above the one - // that Main resides in. 
- bool getRawAttribute(const T &Main, unsigned AttrNum) { - assert(AttrNum < StratifiedLink::SetSentinel); - auto Attrs = getRawAttributes(Main); - return Attrs[AttrNum]; - } - private: DenseMap<T, StratifiedInfo> Values; std::vector<BuilderLink> Links; - // \brief Adds the given element at the given index, merging sets if - // necessary. + /// Adds the given element at the given index, merging sets if necessary. bool addAtMerging(const T &ToAdd, StratifiedIndex Index) { StratifiedInfo Info = {Index}; auto Pair = Values.insert(std::make_pair(ToAdd, Info)); @@ -509,8 +421,8 @@ private: return false; } - // \brief Gets the BuilderLink at the given index, taking set remapping into - // account. + /// Gets the BuilderLink at the given index, taking set remapping into + /// account. BuilderLink &linksAt(StratifiedIndex Index) { auto *Start = &Links[Index]; if (!Start->isRemapped()) @@ -534,8 +446,8 @@ private: return *Current; } - // \brief Merges two sets into one another. Assumes that these sets are not - // already one in the same + /// \brief Merges two sets into one another. Assumes that these sets are not + /// already one in the same. void merge(StratifiedIndex Idx1, StratifiedIndex Idx2) { assert(inbounds(Idx1) && inbounds(Idx2)); assert(&linksAt(Idx1) != &linksAt(Idx2) && @@ -555,8 +467,8 @@ private: mergeDirect(Idx1, Idx2); } - // \brief Merges two sets assuming that the set at `Idx1` is unreachable from - // traversing above or below the set at `Idx2`. + /// \brief Merges two sets assuming that the set at `Idx1` is unreachable from + /// traversing above or below the set at `Idx2`. void mergeDirect(StratifiedIndex Idx1, StratifiedIndex Idx2) { assert(inbounds(Idx1) && inbounds(Idx2)); @@ -582,7 +494,7 @@ private: // match `LinksFrom.Below` // > If both have links above, deal with those next. while (LinksInto->hasBelow() && LinksFrom->hasBelow()) { - auto &FromAttrs = LinksFrom->getAttrs(); + auto FromAttrs = LinksFrom->getAttrs(); LinksInto->setAttrs(FromAttrs); // Remap needs to happen after getBelow(), but before @@ -599,12 +511,13 @@ private: NewBelow.setAbove(LinksInto->Number); } + LinksInto->setAttrs(LinksFrom->getAttrs()); LinksFrom->remapTo(LinksInto->Number); } - // \brief Checks to see if lowerIndex is at a level lower than upperIndex. - // If so, it will merge lowerIndex with upperIndex (and all of the sets - // between) and return true. Otherwise, it will return false. + /// Checks to see if lowerIndex is at a level lower than upperIndex. If so, it + /// will merge lowerIndex with upperIndex (and all of the sets between) and + /// return true. Otherwise, it will return false. 
bool tryMergeUpwards(StratifiedIndex LowerIndex, StratifiedIndex UpperIndex) { assert(inbounds(LowerIndex) && inbounds(UpperIndex)); auto *Lower = &linksAt(LowerIndex); @@ -644,21 +557,21 @@ private: Optional<const StratifiedInfo *> get(const T &Val) const { auto Result = Values.find(Val); if (Result == Values.end()) - return NoneType(); + return None; return &Result->second; } Optional<StratifiedInfo *> get(const T &Val) { auto Result = Values.find(Val); if (Result == Values.end()) - return NoneType(); + return None; return &Result->second; } Optional<StratifiedIndex> indexOf(const T &Val) { auto MaybeVal = get(Val); if (!MaybeVal.hasValue()) - return NoneType(); + return None; auto *Info = *MaybeVal; auto &Link = linksAt(Info->Index); return Link.Number; @@ -689,4 +602,5 @@ private: bool inbounds(StratifiedIndex N) const { return N < Links.size(); } }; } +} #endif // LLVM_ADT_STRATIFIEDSETS_H diff --git a/lib/Analysis/TargetLibraryInfo.cpp b/lib/Analysis/TargetLibraryInfo.cpp index ce3881925627..93d537ad3abb 100644 --- a/lib/Analysis/TargetLibraryInfo.cpp +++ b/lib/Analysis/TargetLibraryInfo.cpp @@ -65,14 +65,18 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc::ldexp); TLI.setUnavailable(LibFunc::ldexpf); TLI.setUnavailable(LibFunc::ldexpl); + TLI.setUnavailable(LibFunc::exp10); + TLI.setUnavailable(LibFunc::exp10f); + TLI.setUnavailable(LibFunc::exp10l); + TLI.setUnavailable(LibFunc::log10); + TLI.setUnavailable(LibFunc::log10f); + TLI.setUnavailable(LibFunc::log10l); } // There are no library implementations of mempcy and memset for AMD gpus and // these can be difficult to lower in the backend. if (T.getArch() == Triple::r600 || - T.getArch() == Triple::amdgcn || - T.getArch() == Triple::wasm32 || - T.getArch() == Triple::wasm64) { + T.getArch() == Triple::amdgcn) { TLI.setUnavailable(LibFunc::memcpy); TLI.setUnavailable(LibFunc::memset); TLI.setUnavailable(LibFunc::memset_pattern16); @@ -207,6 +211,8 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc::fmaxf); TLI.setUnavailable(LibFunc::fmodf); TLI.setUnavailable(LibFunc::logf); + TLI.setUnavailable(LibFunc::log10f); + TLI.setUnavailable(LibFunc::modff); TLI.setUnavailable(LibFunc::powf); TLI.setUnavailable(LibFunc::sinf); TLI.setUnavailable(LibFunc::sinhf); @@ -387,6 +393,24 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc::tmpfile64); } + // As currently implemented in clang, NVPTX code has no standard library to + // speak of. Headers provide a standard-ish library implementation, but many + // of the signatures are wrong -- for example, many libm functions are not + // extern "C". + // + // libdevice, an IR library provided by nvidia, is linked in by the front-end, + // but only used functions are provided to llvm. Moreover, most of the + // functions in libdevice don't map precisely to standard library functions. + // + // FIXME: Having no standard library prevents e.g. many fastmath + // optimizations, so this situation should be fixed. 
+ if (T.isNVPTX()) { + TLI.disableAllFunctions(); + TLI.setAvailable(LibFunc::nvvm_reflect); + } else { + TLI.setUnavailable(LibFunc::nvvm_reflect); + } + TLI.addVectorizableFunctionsFromVecLib(ClVectorLibrary); } @@ -444,7 +468,7 @@ static StringRef sanitizeFunctionName(StringRef funcName) { } bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, - LibFunc::Func &F) const { + LibFunc::Func &F) const { const char *const *Start = &StandardNames[0]; const char *const *End = &StandardNames[LibFunc::NumLibFuncs]; @@ -463,6 +487,518 @@ bool TargetLibraryInfoImpl::getLibFunc(StringRef funcName, return false; } +bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, + LibFunc::Func F, + const DataLayout *DL) const { + LLVMContext &Ctx = FTy.getContext(); + Type *PCharTy = Type::getInt8PtrTy(Ctx); + Type *SizeTTy = DL ? DL->getIntPtrType(Ctx, /*AS=*/0) : nullptr; + auto IsSizeTTy = [SizeTTy](Type *Ty) { + return SizeTTy ? Ty == SizeTTy : Ty->isIntegerTy(); + }; + unsigned NumParams = FTy.getNumParams(); + + switch (F) { + case LibFunc::strlen: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc::strchr: + case LibFunc::strrchr: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1)->isIntegerTy()); + + case LibFunc::strtol: + case LibFunc::strtod: + case LibFunc::strtof: + case LibFunc::strtoul: + case LibFunc::strtoll: + case LibFunc::strtold: + case LibFunc::strtoull: + return ((NumParams == 2 || NumParams == 3) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::strcat: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1) == FTy.getReturnType()); + + case LibFunc::strncat: + return (NumParams == 3 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0) == FTy.getReturnType() && + FTy.getParamType(1) == FTy.getReturnType() && + FTy.getParamType(2)->isIntegerTy()); + + case LibFunc::strcpy_chk: + case LibFunc::stpcpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + // fallthrough + case LibFunc::strcpy: + case LibFunc::stpcpy: + return (NumParams == 2 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy); + + case LibFunc::strncpy_chk: + case LibFunc::stpncpy_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + // fallthrough + case LibFunc::strncpy: + case LibFunc::stpncpy: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getParamType(0) == PCharTy && + FTy.getParamType(2)->isIntegerTy()); + + case LibFunc::strxfrm: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc::strcmp: + return (NumParams == 2 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == FTy.getParamType(1)); + + case LibFunc::strncmp: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == FTy.getParamType(1) && + FTy.getParamType(2)->isIntegerTy()); + + case LibFunc::strspn: + case LibFunc::strcspn: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(0) == 
FTy.getParamType(1) && + FTy.getReturnType()->isIntegerTy()); + + case LibFunc::strcoll: + case LibFunc::strcasecmp: + case LibFunc::strncasecmp: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc::strstr: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc::strpbrk: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0) == FTy.getParamType(1)); + + case LibFunc::strtok: + case LibFunc::strtok_r: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::scanf: + case LibFunc::setbuf: + case LibFunc::setvbuf: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::strdup: + case LibFunc::strndup: + return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case LibFunc::sscanf: + case LibFunc::stat: + case LibFunc::statvfs: + case LibFunc::sprintf: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::snprintf: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::setitimer: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::system: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::malloc: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy()); + case LibFunc::memcmp: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + + case LibFunc::memchr: + case LibFunc::memrchr: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy(32) && + FTy.getParamType(2)->isIntegerTy() && + FTy.getReturnType()->isPointerTy()); + case LibFunc::modf: + case LibFunc::modff: + case LibFunc::modfl: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + + case LibFunc::memcpy_chk: + case LibFunc::memmove_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + // fallthrough + case LibFunc::memcpy: + case LibFunc::memmove: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc::memset_chk: + --NumParams; + if (!IsSizeTTy(FTy.getParamType(NumParams))) + return false; + // fallthrough + case LibFunc::memset: + return (NumParams == 3 && FTy.getReturnType() == FTy.getParamType(0) && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy() && + IsSizeTTy(FTy.getParamType(2))); + + case LibFunc::memccpy: + return (NumParams >= 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::memalign: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc::realloc: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getReturnType()->isPointerTy()); + case LibFunc::read: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::rewind: + case LibFunc::rmdir: + case LibFunc::remove: + case LibFunc::realpath: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::rename: + return (NumParams >= 2 && 
FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::readlink: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::write: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::bcopy: + case LibFunc::bcmp: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::bzero: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::calloc: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy()); + + case LibFunc::atof: + case LibFunc::atoi: + case LibFunc::atol: + case LibFunc::atoll: + case LibFunc::ferror: + case LibFunc::getenv: + case LibFunc::getpwnam: + case LibFunc::pclose: + case LibFunc::perror: + case LibFunc::printf: + case LibFunc::puts: + case LibFunc::uname: + case LibFunc::under_IO_getc: + case LibFunc::unlink: + case LibFunc::unsetenv: + return (NumParams == 1 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc::chmod: + case LibFunc::chown: + case LibFunc::clearerr: + case LibFunc::closedir: + case LibFunc::ctermid: + case LibFunc::fclose: + case LibFunc::feof: + case LibFunc::fflush: + case LibFunc::fgetc: + case LibFunc::fileno: + case LibFunc::flockfile: + case LibFunc::free: + case LibFunc::fseek: + case LibFunc::fseeko64: + case LibFunc::fseeko: + case LibFunc::fsetpos: + case LibFunc::ftell: + case LibFunc::ftello64: + case LibFunc::ftello: + case LibFunc::ftrylockfile: + case LibFunc::funlockfile: + case LibFunc::getc: + case LibFunc::getc_unlocked: + case LibFunc::getlogin_r: + case LibFunc::mkdir: + case LibFunc::mktime: + case LibFunc::times: + return (NumParams != 0 && FTy.getParamType(0)->isPointerTy()); + + case LibFunc::access: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::fopen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::fdopen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::fputc: + case LibFunc::fstat: + case LibFunc::frexp: + case LibFunc::frexpf: + case LibFunc::frexpl: + case LibFunc::fstatvfs: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::fgets: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::fread: + return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(3)->isPointerTy()); + case LibFunc::fwrite: + return (NumParams == 4 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isIntegerTy() && + FTy.getParamType(2)->isIntegerTy() && + FTy.getParamType(3)->isPointerTy()); + case LibFunc::fputs: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::fscanf: + case LibFunc::fprintf: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::fgetpos: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::gets: + case LibFunc::getchar: + case LibFunc::getitimer: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::ungetc: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::utime: + 
case LibFunc::utimes: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::putc: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::pread: + case LibFunc::pwrite: + return (NumParams == 4 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::popen: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::vscanf: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::vsscanf: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::vfscanf: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::valloc: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc::vprintf: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::vfprintf: + case LibFunc::vsprintf: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::vsnprintf: + return (NumParams == 4 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + case LibFunc::open: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::opendir: + return (NumParams == 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case LibFunc::tmpfile: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc::htonl: + case LibFunc::htons: + case LibFunc::ntohl: + case LibFunc::ntohs: + case LibFunc::lstat: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::lchown: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::qsort: + return (NumParams == 4 && FTy.getParamType(3)->isPointerTy()); + case LibFunc::dunder_strdup: + case LibFunc::dunder_strndup: + return (NumParams >= 1 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy()); + case LibFunc::dunder_strtok_r: + return (NumParams == 3 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::under_IO_putc: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::dunder_isoc99_scanf: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::stat64: + case LibFunc::lstat64: + case LibFunc::statvfs64: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::dunder_isoc99_sscanf: + return (NumParams >= 1 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::fopen64: + return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + case LibFunc::tmpfile64: + return (FTy.getReturnType()->isPointerTy()); + case LibFunc::fstat64: + case LibFunc::fstatvfs64: + return (NumParams == 2 && FTy.getParamType(1)->isPointerTy()); + case LibFunc::open64: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy()); + case LibFunc::gettimeofday: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy()); + + case LibFunc::Znwj: // new(unsigned int); + case LibFunc::Znwm: // new(unsigned long); + case LibFunc::Znaj: // new[](unsigned int); + case LibFunc::Znam: // new[](unsigned long); + case 
LibFunc::msvc_new_int: // new(unsigned int); + case LibFunc::msvc_new_longlong: // new(unsigned long long); + case LibFunc::msvc_new_array_int: // new[](unsigned int); + case LibFunc::msvc_new_array_longlong: // new[](unsigned long long); + return (NumParams == 1); + + case LibFunc::memset_pattern16: + return (!FTy.isVarArg() && NumParams == 3 && + isa<PointerType>(FTy.getParamType(0)) && + isa<PointerType>(FTy.getParamType(1)) && + isa<IntegerType>(FTy.getParamType(2))); + + // int __nvvm_reflect(const char *); + case LibFunc::nvvm_reflect: + return (NumParams == 1 && isa<PointerType>(FTy.getParamType(0))); + + case LibFunc::sin: + case LibFunc::sinf: + case LibFunc::sinl: + case LibFunc::cos: + case LibFunc::cosf: + case LibFunc::cosl: + case LibFunc::tan: + case LibFunc::tanf: + case LibFunc::tanl: + case LibFunc::exp: + case LibFunc::expf: + case LibFunc::expl: + case LibFunc::exp2: + case LibFunc::exp2f: + case LibFunc::exp2l: + case LibFunc::log: + case LibFunc::logf: + case LibFunc::logl: + case LibFunc::log10: + case LibFunc::log10f: + case LibFunc::log10l: + case LibFunc::log2: + case LibFunc::log2f: + case LibFunc::log2l: + case LibFunc::fabs: + case LibFunc::fabsf: + case LibFunc::fabsl: + case LibFunc::floor: + case LibFunc::floorf: + case LibFunc::floorl: + case LibFunc::ceil: + case LibFunc::ceilf: + case LibFunc::ceill: + case LibFunc::trunc: + case LibFunc::truncf: + case LibFunc::truncl: + case LibFunc::rint: + case LibFunc::rintf: + case LibFunc::rintl: + case LibFunc::nearbyint: + case LibFunc::nearbyintf: + case LibFunc::nearbyintl: + case LibFunc::round: + case LibFunc::roundf: + case LibFunc::roundl: + case LibFunc::sqrt: + case LibFunc::sqrtf: + case LibFunc::sqrtl: + return (NumParams == 1 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc::fmin: + case LibFunc::fminf: + case LibFunc::fminl: + case LibFunc::fmax: + case LibFunc::fmaxf: + case LibFunc::fmaxl: + case LibFunc::copysign: + case LibFunc::copysignf: + case LibFunc::copysignl: + case LibFunc::pow: + case LibFunc::powf: + case LibFunc::powl: + return (NumParams == 2 && FTy.getReturnType()->isFloatingPointTy() && + FTy.getReturnType() == FTy.getParamType(0) && + FTy.getReturnType() == FTy.getParamType(1)); + + case LibFunc::ffs: + case LibFunc::ffsl: + case LibFunc::ffsll: + case LibFunc::isdigit: + case LibFunc::isascii: + case LibFunc::toascii: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy(32) && + FTy.getParamType(0)->isIntegerTy()); + + case LibFunc::fls: + case LibFunc::flsl: + case LibFunc::flsll: + case LibFunc::abs: + case LibFunc::labs: + case LibFunc::llabs: + return (NumParams == 1 && FTy.getReturnType()->isIntegerTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc::cxa_atexit: + return (NumParams == 3 && FTy.getReturnType()->isIntegerTy() && + FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy()); + + case LibFunc::sinpi: + case LibFunc::cospi: + return (NumParams == 1 && FTy.getReturnType()->isDoubleTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + case LibFunc::sinpif: + case LibFunc::cospif: + return (NumParams == 1 && FTy.getReturnType()->isFloatTy() && + FTy.getReturnType() == FTy.getParamType(0)); + + default: + // Assume the other functions are correct. + // FIXME: It'd be really nice to cover them all. 
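// Every case above applies one shape check: pin down the parameter
// count, then the positions that must be pointer, integer, or
// floating-point typed. An illustrative standalone helper in the same
// style (checkUnaryFPProto is a sketch, not part of this patch): the
// sin/cos/sqrt group demands one parameter, floating point, with the
// return type identical to it.
#include "llvm/IR/DerivedTypes.h"

static bool checkUnaryFPProto(const llvm::FunctionType &FTy) {
  return FTy.getNumParams() == 1 &&
         FTy.getReturnType()->isFloatingPointTy() &&
         FTy.getReturnType() == FTy.getParamType(0);
}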
+ return true; + } +} + +bool TargetLibraryInfoImpl::getLibFunc(const Function &FDecl, + LibFunc::Func &F) const { + const DataLayout *DL = + FDecl.getParent() ? &FDecl.getParent()->getDataLayout() : nullptr; + return getLibFunc(FDecl.getName(), F) && + isValidProtoForLibFunc(*FDecl.getFunctionType(), F, DL); +} + void TargetLibraryInfoImpl::disableAllFunctions() { memset(AvailableArray, 0, sizeof(AvailableArray)); } @@ -583,14 +1119,16 @@ StringRef TargetLibraryInfoImpl::getScalarizedFunction(StringRef F, return I->ScalarFnName; } -TargetLibraryInfo TargetLibraryAnalysis::run(Module &M) { +TargetLibraryInfo TargetLibraryAnalysis::run(Module &M, + ModuleAnalysisManager &) { if (PresetInfoImpl) return TargetLibraryInfo(*PresetInfoImpl); return TargetLibraryInfo(lookupInfoImpl(Triple(M.getTargetTriple()))); } -TargetLibraryInfo TargetLibraryAnalysis::run(Function &F) { +TargetLibraryInfo TargetLibraryAnalysis::run(Function &F, + FunctionAnalysisManager &) { if (PresetInfoImpl) return TargetLibraryInfo(*PresetInfoImpl); @@ -598,7 +1136,7 @@ TargetLibraryInfo TargetLibraryAnalysis::run(Function &F) { lookupInfoImpl(Triple(F.getParent()->getTargetTriple()))); } -TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(Triple T) { +TargetLibraryInfoImpl &TargetLibraryAnalysis::lookupInfoImpl(const Triple &T) { std::unique_ptr<TargetLibraryInfoImpl> &Impl = Impls[T.normalize()]; if (!Impl) diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index 9c1d3fd4f582..52013f796c56 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/Support/ErrorHandling.h" +#include <utility> using namespace llvm; @@ -66,6 +67,15 @@ int TargetTransformInfo::getCallCost(const Function *F, return Cost; } +unsigned TargetTransformInfo::getInliningThresholdMultiplier() const { + return TTIImpl->getInliningThresholdMultiplier(); +} + +int TargetTransformInfo::getGEPCost(Type *PointeeType, const Value *Ptr, + ArrayRef<const Value *> Operands) const { + return TTIImpl->getGEPCost(PointeeType, Ptr, Operands); +} + int TargetTransformInfo::getIntrinsicCost( Intrinsic::ID IID, Type *RetTy, ArrayRef<const Value *> Arguments) const { int Cost = TTIImpl->getIntrinsicCost(IID, RetTy, Arguments); @@ -172,6 +182,18 @@ bool TargetTransformInfo::enableInterleavedAccessVectorization() const { return TTIImpl->enableInterleavedAccessVectorization(); } +bool TargetTransformInfo::isFPVectorizationPotentiallyUnsafe() const { + return TTIImpl->isFPVectorizationPotentiallyUnsafe(); +} + +bool TargetTransformInfo::allowsMisalignedMemoryAccesses(unsigned BitWidth, + unsigned AddressSpace, + unsigned Alignment, + bool *Fast) const { + return TTIImpl->allowsMisalignedMemoryAccesses(BitWidth, AddressSpace, + Alignment, Fast); +} + TargetTransformInfo::PopcntSupportKind TargetTransformInfo::getPopcntSupport(unsigned IntTyWidthInBit) const { return TTIImpl->getPopcntSupport(IntTyWidthInBit); @@ -187,6 +209,14 @@ int TargetTransformInfo::getFPOpCost(Type *Ty) const { return Cost; } +int TargetTransformInfo::getIntImmCodeSizeCost(unsigned Opcode, unsigned Idx, + const APInt &Imm, + Type *Ty) const { + int Cost = TTIImpl->getIntImmCodeSizeCost(Opcode, Idx, Imm, Ty); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + int TargetTransformInfo::getIntImmCost(const APInt &Imm, Type *Ty) const { int Cost = TTIImpl->getIntImmCost(Imm, Ty); assert(Cost >= 0 && 
"TTI should not produce negative costs!"); @@ -215,6 +245,26 @@ unsigned TargetTransformInfo::getRegisterBitWidth(bool Vector) const { return TTIImpl->getRegisterBitWidth(Vector); } +unsigned TargetTransformInfo::getLoadStoreVecRegBitWidth(unsigned AS) const { + return TTIImpl->getLoadStoreVecRegBitWidth(AS); +} + +unsigned TargetTransformInfo::getCacheLineSize() const { + return TTIImpl->getCacheLineSize(); +} + +unsigned TargetTransformInfo::getPrefetchDistance() const { + return TTIImpl->getPrefetchDistance(); +} + +unsigned TargetTransformInfo::getMinPrefetchStride() const { + return TTIImpl->getMinPrefetchStride(); +} + +unsigned TargetTransformInfo::getMaxPrefetchIterationsAhead() const { + return TTIImpl->getMaxPrefetchIterationsAhead(); +} + unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { return TTIImpl->getMaxInterleaveFactor(VF); } @@ -243,6 +293,14 @@ int TargetTransformInfo::getCastInstrCost(unsigned Opcode, Type *Dst, return Cost; } +int TargetTransformInfo::getExtractWithExtendCost(unsigned Opcode, Type *Dst, + VectorType *VecTy, + unsigned Index) const { + int Cost = TTIImpl->getExtractWithExtendCost(Opcode, Dst, VecTy, Index); + assert(Cost >= 0 && "TTI should not produce negative costs!"); + return Cost; +} + int TargetTransformInfo::getCFInstrCost(unsigned Opcode) const { int Cost = TTIImpl->getCFInstrCost(Opcode); assert(Cost >= 0 && "TTI should not produce negative costs!"); @@ -299,15 +357,17 @@ int TargetTransformInfo::getInterleavedMemoryOpCost( } int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef<Type *> Tys) const { - int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys); + ArrayRef<Type *> Tys, + FastMathFlags FMF) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Tys, FMF); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } int TargetTransformInfo::getIntrinsicInstrCost(Intrinsic::ID ID, Type *RetTy, - ArrayRef<Value *> Args) const { - int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args); + ArrayRef<Value *> Args, + FastMathFlags FMF) const { + int Cost = TTIImpl->getIntrinsicInstrCost(ID, RetTy, Args, FMF); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -363,9 +423,10 @@ TargetIRAnalysis::TargetIRAnalysis() : TTICallback(&getDefaultTTI) {} TargetIRAnalysis::TargetIRAnalysis( std::function<Result(const Function &)> TTICallback) - : TTICallback(TTICallback) {} + : TTICallback(std::move(TTICallback)) {} -TargetIRAnalysis::Result TargetIRAnalysis::run(const Function &F) { +TargetIRAnalysis::Result TargetIRAnalysis::run(const Function &F, + AnalysisManager<Function> &) { return TTICallback(F); } @@ -396,7 +457,8 @@ TargetTransformInfoWrapperPass::TargetTransformInfoWrapperPass( } TargetTransformInfo &TargetTransformInfoWrapperPass::getTTI(const Function &F) { - TTI = TIRA.run(F); + AnalysisManager<Function> DummyFAM; + TTI = TIRA.run(F, DummyFAM); return *TTI; } diff --git a/lib/Analysis/Trace.cpp b/lib/Analysis/Trace.cpp index 5a1acc00fb94..c7e2c0f3412a 100644 --- a/lib/Analysis/Trace.cpp +++ b/lib/Analysis/Trace.cpp @@ -46,7 +46,7 @@ void Trace::print(raw_ostream &O) const { /// dump - Debugger convenience method; writes trace to standard error /// output stream. 
/// -void Trace::dump() const { +LLVM_DUMP_METHOD void Trace::dump() const { print(dbgs()); } #endif diff --git a/lib/Analysis/TypeBasedAliasAnalysis.cpp b/lib/Analysis/TypeBasedAliasAnalysis.cpp index 9f923913ca27..20d162a03c30 100644 --- a/lib/Analysis/TypeBasedAliasAnalysis.cpp +++ b/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -122,7 +122,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/TypeBasedAliasAnalysis.h" -#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/ADT/SetVector.h" #include "llvm/IR/Constants.h" #include "llvm/IR/LLVMContext.h" @@ -584,18 +583,15 @@ bool TypeBasedAAResult::PathAliases(const MDNode *A, const MDNode *B) const { return false; } -TypeBasedAAResult TypeBasedAA::run(Function &F, AnalysisManager<Function> *AM) { - return TypeBasedAAResult(AM->getResult<TargetLibraryAnalysis>(F)); -} - char TypeBasedAA::PassID; +TypeBasedAAResult TypeBasedAA::run(Function &F, AnalysisManager<Function> &AM) { + return TypeBasedAAResult(); +} + char TypeBasedAAWrapperPass::ID = 0; -INITIALIZE_PASS_BEGIN(TypeBasedAAWrapperPass, "tbaa", - "Type-Based Alias Analysis", false, true) -INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) -INITIALIZE_PASS_END(TypeBasedAAWrapperPass, "tbaa", "Type-Based Alias Analysis", - false, true) +INITIALIZE_PASS(TypeBasedAAWrapperPass, "tbaa", "Type-Based Alias Analysis", + false, true) ImmutablePass *llvm::createTypeBasedAAWrapperPass() { return new TypeBasedAAWrapperPass(); @@ -606,8 +602,7 @@ TypeBasedAAWrapperPass::TypeBasedAAWrapperPass() : ImmutablePass(ID) { } bool TypeBasedAAWrapperPass::doInitialization(Module &M) { - Result.reset(new TypeBasedAAResult( - getAnalysis<TargetLibraryInfoWrapperPass>().getTLI())); + Result.reset(new TypeBasedAAResult()); return false; } @@ -618,5 +613,4 @@ bool TypeBasedAAWrapperPass::doFinalization(Module &M) { void TypeBasedAAWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.setPreservesAll(); - AU.addRequired<TargetLibraryInfoWrapperPass>(); } diff --git a/lib/Analysis/TypeMetadataUtils.cpp b/lib/Analysis/TypeMetadataUtils.cpp new file mode 100644 index 000000000000..31e2b42075d6 --- /dev/null +++ b/lib/Analysis/TypeMetadataUtils.cpp @@ -0,0 +1,118 @@ +//===- TypeMetadataUtils.cpp - Utilities related to type metadata ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file contains functions that make it easier to manipulate type metadata +// for devirtualization. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" + +using namespace llvm; + +// Search for virtual calls that call FPtr and add them to DevirtCalls. 
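// A sketch of how the traversal helpers below are driven from outside:
// the pass machinery hands this file's public entry point a verified
// call to @llvm.type.test and reads back the candidate virtual call
// sites. The driver function and variable names here are illustrative.
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Instructions.h"

void collectDevirtCandidates(llvm::CallInst *TypeTestCI) {
  llvm::SmallVector<llvm::DevirtCallSite, 4> DevirtCalls;
  llvm::SmallVector<llvm::CallInst *, 4> Assumes;
  llvm::findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, TypeTestCI);
  // Each DevirtCallSite pairs a vtable offset with the virtual call
  // reached through a load at that offset.
  for (const llvm::DevirtCallSite &Site : DevirtCalls)
    (void)Site.Offset; // e.g. resolve the target function at this offset
}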
+static void +findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, + bool *HasNonCallUses, Value *FPtr, uint64_t Offset) { + for (const Use &U : FPtr->uses()) { + Value *User = U.getUser(); + if (isa<BitCastInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset); + } else if (auto CI = dyn_cast<CallInst>(User)) { + DevirtCalls.push_back({Offset, CI}); + } else if (auto II = dyn_cast<InvokeInst>(User)) { + DevirtCalls.push_back({Offset, II}); + } else if (HasNonCallUses) { + *HasNonCallUses = true; + } + } +} + +// Search for virtual calls that load from VPtr and add them to DevirtCalls. +static void +findLoadCallsAtConstantOffset(Module *M, + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + Value *VPtr, int64_t Offset) { + for (const Use &U : VPtr->uses()) { + Value *User = U.getUser(); + if (isa<BitCastInst>(User)) { + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset); + } else if (isa<LoadInst>(User)) { + findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset); + } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { + // Take into account the GEP offset. + if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { + SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); + int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType( + GEP->getSourceElementType(), Indices); + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset); + } + } + } +} + +void llvm::findDevirtualizableCallsForTypeTest( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<CallInst *> &Assumes, CallInst *CI) { + assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); + + Module *M = CI->getParent()->getParent()->getParent(); + + // Find llvm.assume intrinsics for this llvm.type.test call. + for (const Use &CIU : CI->uses()) { + auto AssumeCI = dyn_cast<CallInst>(CIU.getUser()); + if (AssumeCI) { + Function *F = AssumeCI->getCalledFunction(); + if (F && F->getIntrinsicID() == Intrinsic::assume) + Assumes.push_back(AssumeCI); + } + } + + // If we found any, search for virtual calls based on %p and add them to + // DevirtCalls. 
+ if (!Assumes.empty()) + findLoadCallsAtConstantOffset(M, DevirtCalls, + CI->getArgOperand(0)->stripPointerCasts(), 0); +} + +void llvm::findDevirtualizableCallsForTypeCheckedLoad( + SmallVectorImpl<DevirtCallSite> &DevirtCalls, + SmallVectorImpl<Instruction *> &LoadedPtrs, + SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, CallInst *CI) { + assert(CI->getCalledFunction()->getIntrinsicID() == + Intrinsic::type_checked_load); + + auto *Offset = dyn_cast<ConstantInt>(CI->getArgOperand(1)); + if (!Offset) { + HasNonCallUses = true; + return; + } + + for (Use &U : CI->uses()) { + auto CIU = U.getUser(); + if (auto EVI = dyn_cast<ExtractValueInst>(CIU)) { + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 0) { + LoadedPtrs.push_back(EVI); + continue; + } + if (EVI->getNumIndices() == 1 && EVI->getIndices()[0] == 1) { + Preds.push_back(EVI); + continue; + } + } + HasNonCallUses = true; + } + + for (Value *LoadedPtr : LoadedPtrs) + findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr, + Offset->getZExtValue()); +} diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index a83e207bd265..f2b40787443a 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -18,7 +18,9 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" @@ -36,40 +38,19 @@ #include "llvm/IR/Statepoint.h" #include "llvm/Support/Debug.h" #include "llvm/Support/MathExtras.h" +#include <algorithm> +#include <array> #include <cstring> using namespace llvm; using namespace llvm::PatternMatch; const unsigned MaxDepth = 6; -/// Enable an experimental feature to leverage information about dominating -/// conditions to compute known bits. The individual options below control how -/// hard we search. The defaults are chosen to be fairly aggressive. If you -/// run into compile time problems when testing, scale them back and report -/// your findings. -static cl::opt<bool> EnableDomConditions("value-tracking-dom-conditions", - cl::Hidden, cl::init(false)); - -// This is expensive, so we only do it for the top level query value. -// (TODO: evaluate cost vs profit, consider higher thresholds) -static cl::opt<unsigned> DomConditionsMaxDepth("dom-conditions-max-depth", - cl::Hidden, cl::init(1)); - -/// How many dominating blocks should be scanned looking for dominating -/// conditions? -static cl::opt<unsigned> DomConditionsMaxDomBlocks("dom-conditions-dom-blocks", - cl::Hidden, - cl::init(20)); - // Controls the number of uses of the value searched for possible // dominating comparisons. static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses", cl::Hidden, cl::init(20)); -// If true, don't consider only compares whose only use is a branch. -static cl::opt<bool> DomConditionsSingleCmpUse("dom-conditions-single-cmp-use", - cl::Hidden, cl::init(false)); - /// Returns the bitwidth of the given scalar or pointer type (if unknown returns /// 0). For vector types, returns the element type's bitwidth. 
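// A sketch of the public entry point whose internal plumbing is rewired
// below: callers keep passing the DataLayout and the optional context
// pieces explicitly, while the patch moves that state into the internal
// Query object. queryKnownBits and its arguments are illustrative, and
// V is assumed to have integer type.
#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"

void queryKnownBits(llvm::Value *V, const llvm::DataLayout &DL,
                    llvm::AssumptionCache *AC, const llvm::Instruction *CxtI,
                    const llvm::DominatorTree *DT) {
  unsigned BW = V->getType()->getIntegerBitWidth();
  llvm::APInt KnownZero(BW, 0), KnownOne(BW, 0);
  llvm::computeKnownBits(V, KnownZero, KnownOne, DL, /*Depth=*/0, AC, CxtI,
                         DT);
  // Bits set in KnownZero/KnownOne are provably 0/1 on every execution.
}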
static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { @@ -79,34 +60,45 @@ static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { return DL.getPointerTypeSizeInBits(Ty); } -// Many of these functions have internal versions that take an assumption -// exclusion set. This is because of the potential for mutual recursion to -// cause computeKnownBits to repeatedly visit the same assume intrinsic. The -// classic case of this is assume(x = y), which will attempt to determine -// bits in x from bits in y, which will attempt to determine bits in y from -// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call -// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and -// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so on. -typedef SmallPtrSet<const Value *, 8> ExclInvsSet; - namespace { // Simplifying using an assume can only be done in a particular control-flow // context (the context instruction provides that context). If an assume and // the context instruction are not in the same block then the DT helps in // figuring out if we can use it. struct Query { - ExclInvsSet ExclInvs; + const DataLayout &DL; AssumptionCache *AC; const Instruction *CxtI; const DominatorTree *DT; - Query(AssumptionCache *AC = nullptr, const Instruction *CxtI = nullptr, - const DominatorTree *DT = nullptr) - : AC(AC), CxtI(CxtI), DT(DT) {} + /// Set of assumptions that should be excluded from further queries. + /// This is because of the potential for mutual recursion to cause + /// computeKnownBits to repeatedly visit the same assume intrinsic. The + /// classic case of this is assume(x = y), which will attempt to determine + /// bits in x from bits in y, which will attempt to determine bits in y from + /// bits in x, etc. Regarding the mutual recursion, computeKnownBits can call + /// isKnownNonZero, which calls computeKnownBits and ComputeSignBit and + /// isKnownToBeAPowerOfTwo (all of which can call computeKnownBits), and so + /// on. 
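// The members below replace the removed heap-allocating SmallPtrSet with
// a fixed-size array: each recursion level excludes at most one assume,
// so the number of live entries can never exceed the MaxDepth recursion
// cap, and a linear scan over at most six pointers is cheaper than
// hashing. The same pattern in isolation (BoundedExclusions is
// illustrative only):
#include <algorithm>
#include <array>
#include <cassert>

struct BoundedExclusions {
  std::array<const void *, 6> Slots{}; // 6 mirrors MaxDepth above
  unsigned Count = 0;

  void exclude(const void *P) {
    assert(Count < Slots.size() && "deeper than the recursion cap");
    Slots[Count++] = P;
  }
  bool isExcluded(const void *P) const {
    auto End = Slots.begin() + Count;
    return std::find(Slots.begin(), End, P) != End;
  }
};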
+ std::array<const Value*, MaxDepth> Excluded; + unsigned NumExcluded; + + Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), NumExcluded(0) {} Query(const Query &Q, const Value *NewExcl) - : ExclInvs(Q.ExclInvs), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT) { - ExclInvs.insert(NewExcl); + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), NumExcluded(Q.NumExcluded) { + Excluded = Q.Excluded; + Excluded[NumExcluded++] = NewExcl; + assert(NumExcluded <= Excluded.size()); + } + + bool isExcluded(const Value *Value) const { + if (NumExcluded == 0) + return false; + auto End = Excluded.begin() + NumExcluded; + return std::find(Excluded.begin(), End, Value) != End; } }; } // end anonymous namespace @@ -128,15 +120,14 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) { } static void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - const DataLayout &DL, unsigned Depth, - const Query &Q); + unsigned Depth, const Query &Q); void llvm::computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - ::computeKnownBits(V, KnownZero, KnownOne, DL, Depth, - Query(AC, safeCxtI(V, CxtI), DT)); + ::computeKnownBits(V, KnownZero, KnownOne, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT)); } bool llvm::haveNoCommonBitsSet(Value *LHS, Value *RHS, const DataLayout &DL, @@ -155,35 +146,33 @@ bool llvm::haveNoCommonBitsSet(Value *LHS, Value *RHS, const DataLayout &DL, } static void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout &DL, unsigned Depth, - const Query &Q); + unsigned Depth, const Query &Q); void llvm::ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - ::ComputeSignBit(V, KnownZero, KnownOne, DL, Depth, - Query(AC, safeCxtI(V, CxtI), DT)); + ::ComputeSignBit(V, KnownZero, KnownOne, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT)); } static bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, - const Query &Q, const DataLayout &DL); + const Query &Q); bool llvm::isKnownToBeAPowerOfTwo(Value *V, const DataLayout &DL, bool OrZero, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, - Query(AC, safeCxtI(V, CxtI), DT), DL); + Query(DL, AC, safeCxtI(V, CxtI), DT)); } -static bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, - const Query &Q); +static bool isKnownNonZero(Value *V, unsigned Depth, const Query &Q); bool llvm::isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - return ::isKnownNonZero(V, DL, Depth, Query(AC, safeCxtI(V, CxtI), DT)); + return ::isKnownNonZero(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); } bool llvm::isKnownNonNegative(Value *V, const DataLayout &DL, unsigned Depth, @@ -194,42 +183,59 @@ bool llvm::isKnownNonNegative(Value *V, const DataLayout &DL, unsigned Depth, return NonNegative; } -static bool isKnownNonEqual(Value *V1, Value *V2, const DataLayout &DL, - const Query &Q); +bool llvm::isKnownPositive(Value *V, const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + if (auto *CI = dyn_cast<ConstantInt>(V)) + return 
CI->getValue().isStrictlyPositive(); + + // TODO: We're doing two recursive queries here. We should factor this such + // that only a single query is needed. + return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT) && + isKnownNonZero(V, DL, Depth, AC, CxtI, DT); +} + +bool llvm::isKnownNegative(Value *V, const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT) { + bool NonNegative, Negative; + ComputeSignBit(V, NonNegative, Negative, DL, Depth, AC, CxtI, DT); + return Negative; +} + +static bool isKnownNonEqual(Value *V1, Value *V2, const Query &Q); bool llvm::isKnownNonEqual(Value *V1, Value *V2, const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - return ::isKnownNonEqual(V1, V2, DL, Query(AC, - safeCxtI(V1, safeCxtI(V2, CxtI)), - DT)); + return ::isKnownNonEqual(V1, V2, Query(DL, AC, + safeCxtI(V1, safeCxtI(V2, CxtI)), + DT)); } -static bool MaskedValueIsZero(Value *V, const APInt &Mask, const DataLayout &DL, - unsigned Depth, const Query &Q); +static bool MaskedValueIsZero(Value *V, const APInt &Mask, unsigned Depth, + const Query &Q); bool llvm::MaskedValueIsZero(Value *V, const APInt &Mask, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - return ::MaskedValueIsZero(V, Mask, DL, Depth, - Query(AC, safeCxtI(V, CxtI), DT)); + return ::MaskedValueIsZero(V, Mask, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT)); } -static unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, - unsigned Depth, const Query &Q); +static unsigned ComputeNumSignBits(Value *V, unsigned Depth, const Query &Q); unsigned llvm::ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - return ::ComputeNumSignBits(V, DL, Depth, Query(AC, safeCxtI(V, CxtI), DT)); + return ::ComputeNumSignBits(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); } static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout &DL, unsigned Depth, - const Query &Q) { + unsigned Depth, const Query &Q) { if (!Add) { if (ConstantInt *CLHS = dyn_cast<ConstantInt>(Op0)) { // We know that the top bits of C-X are clear if X contains less bits @@ -240,7 +246,7 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, unsigned NLZ = (CLHS->getValue()+1).countLeadingZeros(); // NLZ can't be BitWidth with no sign bit APInt MaskV = APInt::getHighBitsSet(BitWidth, NLZ+1); - computeKnownBits(Op1, KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); // If all of the MaskV bits are known to be zero, then we know the // output top bits are zero, because we now know that the output is @@ -259,8 +265,8 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, // If an initial sequence of bits in the result is not needed, the // corresponding bits in the operands are not needed. APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, DL, Depth + 1, Q); - computeKnownBits(Op1, KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(Op0, LHSKnownZero, LHSKnownOne, Depth + 1, Q); + computeKnownBits(Op1, KnownZero2, KnownOne2, Depth + 1, Q); // Carry in a 1 for a subtract, rather than a 0.
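// A standalone rerun of the carry computation below on fixed 8-bit
// masks, making the subtract-carry comment above concrete. KZ*/KO* are
// known-zero / known-one masks; all values are illustrative. With x's
// low three bits known to be 100 and y fully known to be 1, the low
// three bits of x + y come out known as 101 and the rest stays unknown.
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t KZx = 0x03, KOx = 0x04; // x: bits 0-1 known 0, bit 2 known 1
  uint8_t KZy = 0xFE, KOy = 0x01; // y: fully known, equal to 1
  // Sums with every unknown bit forced to 1 resp. 0 (carry-in is 0 for
  // an add).
  uint8_t MaxSum = (uint8_t)(~KZx + ~KZy);
  uint8_t MinSum = (uint8_t)(KOx + KOy);
  // Known carry bits fall out of the xor trick used below.
  uint8_t CarryKZ = (uint8_t)~(MaxSum ^ KZx ^ KZy);
  uint8_t CarryKO = (uint8_t)(MinSum ^ KOx ^ KOy);
  // A result bit is known only where both operand bits and the carry
  // into that position are known.
  uint8_t Known = (uint8_t)((KZx | KOx) & (KZy | KOy) & (CarryKZ | CarryKO));
  printf("known-one 0x%02x known-zero 0x%02x\n",
         (uint8_t)(MinSum & Known), (uint8_t)(~MinSum & Known));
}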
APInt CarryIn(BitWidth, 0); @@ -308,11 +314,10 @@ static void computeKnownBitsAddSub(bool Add, Value *Op0, Value *Op1, bool NSW, static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout &DL, unsigned Depth, - const Query &Q) { + unsigned Depth, const Query &Q) { unsigned BitWidth = KnownZero.getBitWidth(); - computeKnownBits(Op1, KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(Op0, KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(Op1, KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(Op0, KnownZero2, KnownOne2, Depth + 1, Q); bool isKnownNegative = false; bool isKnownNonNegative = false; @@ -333,9 +338,9 @@ static void computeKnownBitsMul(Value *Op0, Value *Op1, bool NSW, // negative or zero. if (!isKnownNonNegative) isKnownNegative = (isKnownNegativeOp1 && isKnownNonNegativeOp0 && - isKnownNonZero(Op0, DL, Depth, Q)) || + isKnownNonZero(Op0, Depth, Q)) || (isKnownNegativeOp0 && isKnownNonNegativeOp1 && - isKnownNonZero(Op1, DL, Depth, Q)); + isKnownNonZero(Op1, Depth, Q)); } } @@ -451,7 +456,8 @@ static bool isAssumeLikeIntrinsic(const Instruction *I) { return false; } -static bool isValidAssumeForContext(Value *V, const Query &Q) { +static bool isValidAssumeForContext(Value *V, const Instruction *CxtI, + const DominatorTree *DT) { Instruction *Inv = cast<Instruction>(V); // There are two restrictions on the use of an assume: @@ -462,43 +468,43 @@ static bool isValidAssumeForContext(Value *V, const Query &Q) { // feeding the assume is trivially true, thus causing the removal of // the assume). - if (Q.DT) { - if (Q.DT->dominates(Inv, Q.CxtI)) { + if (DT) { + if (DT->dominates(Inv, CxtI)) { return true; - } else if (Inv->getParent() == Q.CxtI->getParent()) { + } else if (Inv->getParent() == CxtI->getParent()) { // The context comes first, but they're both in the same block. Make sure // there is nothing in between that might interrupt the control flow. for (BasicBlock::const_iterator I = - std::next(BasicBlock::const_iterator(Q.CxtI)), + std::next(BasicBlock::const_iterator(CxtI)), IE(Inv); I != IE; ++I) if (!isSafeToSpeculativelyExecute(&*I) && !isAssumeLikeIntrinsic(&*I)) return false; - return !isEphemeralValueOf(Inv, Q.CxtI); + return !isEphemeralValueOf(Inv, CxtI); } return false; } // When we don't have a DT, we do a limited search... - if (Inv->getParent() == Q.CxtI->getParent()->getSinglePredecessor()) { + if (Inv->getParent() == CxtI->getParent()->getSinglePredecessor()) { return true; - } else if (Inv->getParent() == Q.CxtI->getParent()) { + } else if (Inv->getParent() == CxtI->getParent()) { // Search forward from the assume until we reach the context (or the end // of the block); the common case is that the assume will come first. for (BasicBlock::iterator I = std::next(BasicBlock::iterator(Inv)), IE = Inv->getParent()->end(); I != IE; ++I) - if (&*I == Q.CxtI) + if (&*I == CxtI) return true; // The context must come first... 
for (BasicBlock::const_iterator I = - std::next(BasicBlock::const_iterator(Q.CxtI)), + std::next(BasicBlock::const_iterator(CxtI)), IE(Inv); I != IE; ++I) if (!isSafeToSpeculativelyExecute(&*I) && !isAssumeLikeIntrinsic(&*I)) return false; - return !isEphemeralValueOf(Inv, Q.CxtI); + return !isEphemeralValueOf(Inv, CxtI); } return false; @@ -507,226 +513,12 @@ static bool isValidAssumeForContext(Value *V, const Query &Q) { bool llvm::isValidAssumeForContext(const Instruction *I, const Instruction *CxtI, const DominatorTree *DT) { - return ::isValidAssumeForContext(const_cast<Instruction *>(I), - Query(nullptr, CxtI, DT)); -} - -template<typename LHS, typename RHS> -inline match_combine_or<CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>, - CmpClass_match<RHS, LHS, ICmpInst, ICmpInst::Predicate>> -m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) { - return m_CombineOr(m_ICmp(Pred, L, R), m_ICmp(Pred, R, L)); -} - -template<typename LHS, typename RHS> -inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::And>, - BinaryOp_match<RHS, LHS, Instruction::And>> -m_c_And(const LHS &L, const RHS &R) { - return m_CombineOr(m_And(L, R), m_And(R, L)); -} - -template<typename LHS, typename RHS> -inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Or>, - BinaryOp_match<RHS, LHS, Instruction::Or>> -m_c_Or(const LHS &L, const RHS &R) { - return m_CombineOr(m_Or(L, R), m_Or(R, L)); -} - -template<typename LHS, typename RHS> -inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Xor>, - BinaryOp_match<RHS, LHS, Instruction::Xor>> -m_c_Xor(const LHS &L, const RHS &R) { - return m_CombineOr(m_Xor(L, R), m_Xor(R, L)); -} - -/// Compute known bits in 'V' under the assumption that the condition 'Cmp' is -/// true (at the context instruction.) This is mostly a utility function for -/// the prototype dominating conditions reasoning below. -static void computeKnownBitsFromTrueCondition(Value *V, ICmpInst *Cmp, - APInt &KnownZero, - APInt &KnownOne, - const DataLayout &DL, - unsigned Depth, const Query &Q) { - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - // TODO: We could potentially be more aggressive here. This would be worth - // evaluating. If we can, explore commoning this code with the assume - // handling logic. - if (LHS != V && RHS != V) - return; - - const unsigned BitWidth = KnownZero.getBitWidth(); - - switch (Cmp->getPredicate()) { - default: - // We know nothing from this condition - break; - // TODO: implement unsigned bound from below (known one bits) - // TODO: common condition check implementations with assumes - // TODO: implement other patterns from assume (e.g. V & B == A) - case ICmpInst::ICMP_SGT: - if (LHS == V) { - APInt KnownZeroTemp(BitWidth, 0), KnownOneTemp(BitWidth, 0); - computeKnownBits(RHS, KnownZeroTemp, KnownOneTemp, DL, Depth + 1, Q); - if (KnownOneTemp.isAllOnesValue() || KnownZeroTemp.isNegative()) { - // We know that the sign bit is zero. 
- KnownZero |= APInt::getSignBit(BitWidth); - } - } - break; - case ICmpInst::ICMP_EQ: - { - APInt KnownZeroTemp(BitWidth, 0), KnownOneTemp(BitWidth, 0); - if (LHS == V) - computeKnownBits(RHS, KnownZeroTemp, KnownOneTemp, DL, Depth + 1, Q); - else if (RHS == V) - computeKnownBits(LHS, KnownZeroTemp, KnownOneTemp, DL, Depth + 1, Q); - else - llvm_unreachable("missing use?"); - KnownZero |= KnownZeroTemp; - KnownOne |= KnownOneTemp; - } - break; - case ICmpInst::ICMP_ULE: - if (LHS == V) { - APInt KnownZeroTemp(BitWidth, 0), KnownOneTemp(BitWidth, 0); - computeKnownBits(RHS, KnownZeroTemp, KnownOneTemp, DL, Depth + 1, Q); - // The known zero bits carry over - unsigned SignBits = KnownZeroTemp.countLeadingOnes(); - KnownZero |= APInt::getHighBitsSet(BitWidth, SignBits); - } - break; - case ICmpInst::ICMP_ULT: - if (LHS == V) { - APInt KnownZeroTemp(BitWidth, 0), KnownOneTemp(BitWidth, 0); - computeKnownBits(RHS, KnownZeroTemp, KnownOneTemp, DL, Depth + 1, Q); - // Whatever high bits in rhs are zero are known to be zero (if rhs is a - // power of 2, then one more). - unsigned SignBits = KnownZeroTemp.countLeadingOnes(); - if (isKnownToBeAPowerOfTwo(RHS, false, Depth + 1, Query(Q, Cmp), DL)) - SignBits++; - KnownZero |= APInt::getHighBitsSet(BitWidth, SignBits); - } - break; - }; -} - -/// Compute known bits in 'V' from conditions which are known to be true along -/// all paths leading to the context instruction. In particular, look for -/// cases where one branch of an interesting condition dominates the context -/// instruction. This does not do general dataflow. -/// NOTE: This code is EXPERIMENTAL and currently off by default. -static void computeKnownBitsFromDominatingCondition(Value *V, APInt &KnownZero, - APInt &KnownOne, - const DataLayout &DL, - unsigned Depth, - const Query &Q) { - // Need both the dominator tree and the query location to do anything useful - if (!Q.DT || !Q.CxtI) - return; - Instruction *Cxt = const_cast<Instruction *>(Q.CxtI); - // The context instruction might be in a statically unreachable block. If - // so, asking dominator queries may yield suprising results. (e.g. the block - // may not have a dom tree node) - if (!Q.DT->isReachableFromEntry(Cxt->getParent())) - return; - - // Avoid useless work - if (auto VI = dyn_cast<Instruction>(V)) - if (VI->getParent() == Cxt->getParent()) - return; - - // Note: We currently implement two options. It's not clear which of these - // will survive long term, we need data for that. - // Option 1 - Try walking the dominator tree looking for conditions which - // might apply. This works well for local conditions (loop guards, etc..), - // but not as well for things far from the context instruction (presuming a - // low max blocks explored). If we can set an high enough limit, this would - // be all we need. - // Option 2 - We restrict out search to those conditions which are uses of - // the value we're interested in. This is independent of dom structure, - // but is slightly less powerful without looking through lots of use chains. - // It does handle conditions far from the context instruction (e.g. early - // function exits on entry) really well though. 
- - // Option 1 - Search the dom tree - unsigned NumBlocksExplored = 0; - BasicBlock *Current = Cxt->getParent(); - while (true) { - // Stop searching if we've gone too far up the chain - if (NumBlocksExplored >= DomConditionsMaxDomBlocks) - break; - NumBlocksExplored++; - - if (!Q.DT->getNode(Current)->getIDom()) - break; - Current = Q.DT->getNode(Current)->getIDom()->getBlock(); - if (!Current) - // found function entry - break; - - BranchInst *BI = dyn_cast<BranchInst>(Current->getTerminator()); - if (!BI || BI->isUnconditional()) - continue; - ICmpInst *Cmp = dyn_cast<ICmpInst>(BI->getCondition()); - if (!Cmp) - continue; - - // We're looking for conditions that are guaranteed to hold at the context - // instruction. Finding a condition where one path dominates the context - // isn't enough because both the true and false cases could merge before - // the context instruction we're actually interested in. Instead, we need - // to ensure that the taken *edge* dominates the context instruction. We - // know that the edge must be reachable since we started from a reachable - // block. - BasicBlock *BB0 = BI->getSuccessor(0); - BasicBlockEdge Edge(BI->getParent(), BB0); - if (!Edge.isSingleEdge() || !Q.DT->dominates(Edge, Q.CxtI->getParent())) - continue; - - computeKnownBitsFromTrueCondition(V, Cmp, KnownZero, KnownOne, DL, Depth, - Q); - } - - // Option 2 - Search the other uses of V - unsigned NumUsesExplored = 0; - for (auto U : V->users()) { - // Avoid massive lists - if (NumUsesExplored >= DomConditionsMaxUses) - break; - NumUsesExplored++; - // Consider only compare instructions uniquely controlling a branch - ICmpInst *Cmp = dyn_cast<ICmpInst>(U); - if (!Cmp) - continue; - - if (DomConditionsSingleCmpUse && !Cmp->hasOneUse()) - continue; - - for (auto *CmpU : Cmp->users()) { - BranchInst *BI = dyn_cast<BranchInst>(CmpU); - if (!BI || BI->isUnconditional()) - continue; - // We're looking for conditions that are guaranteed to hold at the - // context instruction. Finding a condition where one path dominates - // the context isn't enough because both the true and false cases could - // merge before the context instruction we're actually interested in. - // Instead, we need to ensure that the taken *edge* dominates the context - // instruction. - BasicBlock *BB0 = BI->getSuccessor(0); - BasicBlockEdge Edge(BI->getParent(), BB0); - if (!Edge.isSingleEdge() || !Q.DT->dominates(Edge, Q.CxtI->getParent())) - continue; - - computeKnownBitsFromTrueCondition(V, Cmp, KnownZero, KnownOne, DL, Depth, - Q); - } - } + return ::isValidAssumeForContext(const_cast<Instruction *>(I), CxtI, DT); } static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, - APInt &KnownOne, const DataLayout &DL, - unsigned Depth, const Query &Q) { + APInt &KnownOne, unsigned Depth, + const Query &Q) { // Use of assumptions is context-sensitive. If we don't have a context, we // cannot use them! if (!Q.AC || !Q.CxtI) @@ -740,7 +532,7 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, CallInst *I = cast<CallInst>(AssumeVH); assert(I->getParent()->getParent() == Q.CxtI->getParent()->getParent() && "Got assumption for the wrong function!"); - if (Q.ExclInvs.count(I)) + if (Q.isExcluded(I)) continue; // Warning: This loop can end up being somewhat performance sensetive. 
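// The pattern dispatch resumed below is easiest to read with one
// concrete assume in mind. A standalone restatement of the
// "assume((v & b) == a)" rule on 8-bit masks; every value here is
// illustrative: wherever b is known to be one, the matching bit of v
// must equal the matching bit of a.
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t BKnownOne = 0x07;                    // b == 7, fully known
  uint8_t AKnownOne = 0x05, AKnownZero = 0xFA; // a == 5, fully known
  uint8_t VKnownOne = AKnownOne & BKnownOne;   // bits 0 and 2 of v are one
  uint8_t VKnownZero = AKnownZero & BKnownOne; // bit 1 of v is zero
  printf("v: known-one 0x%02x known-zero 0x%02x\n", VKnownOne, VKnownZero);
}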
@@ -752,7 +544,7 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, Value *Arg = I->getArgOperand(0); - if (Arg == V && isValidAssumeForContext(I, Q)) { + if (Arg == V && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { assert(BitWidth == 1 && "assume operand is not i1?"); KnownZero.clearAllBits(); KnownOne.setAllBits(); @@ -772,19 +564,20 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, ConstantInt *C; // assume(v = a) if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); KnownZero |= RHSKnownZero; KnownOne |= RHSKnownOne; // assume(v & b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); - computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, Depth+1, Query(Q, I)); // For those bits in the mask that are known to be one, we can propagate // known bits from the RHS to V. @@ -793,11 +586,12 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(~(v & b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt MaskKnownZero(BitWidth, 0), MaskKnownOne(BitWidth, 0); - computeKnownBits(B, MaskKnownZero, MaskKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, MaskKnownZero, MaskKnownOne, Depth+1, Query(Q, I)); // For those bits in the mask that are known to be one, we can propagate // inverted known bits from the RHS to V. @@ -806,11 +600,12 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(v | b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. 
@@ -819,11 +614,12 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(~(v | b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. @@ -832,11 +628,12 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(v ^ b = a) } else if (match(Arg, m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate known // bits from the RHS to V. For those bits in B that are known to be one, @@ -848,11 +645,12 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(~(v ^ b) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); APInt BKnownZero(BitWidth, 0), BKnownOne(BitWidth, 0); - computeKnownBits(B, BKnownZero, BKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(B, BKnownZero, BKnownOne, Depth+1, Query(Q, I)); // For those bits in B that are known to be zero, we can propagate // inverted known bits from the RHS to V. For those bits in B that are @@ -864,9 +662,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(v << c = a) } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. 
KnownZero |= RHSKnownZero.lshr(C->getZExtValue()); @@ -874,9 +673,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, // assume(~(v << c) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. KnownZero |= RHSKnownOne.lshr(C->getZExtValue()); @@ -886,9 +686,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, m_c_ICmp(Pred, m_CombineOr(m_LShr(m_V, m_ConstantInt(C)), m_AShr(m_V, m_ConstantInt(C))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. KnownZero |= RHSKnownZero << C->getZExtValue(); @@ -898,18 +699,20 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, m_LShr(m_V, m_ConstantInt(C)), m_AShr(m_V, m_ConstantInt(C)))), m_Value(A))) && - Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_EQ && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. KnownZero |= RHSKnownOne << C->getZExtValue(); KnownOne |= RHSKnownZero << C->getZExtValue(); // assume(v >=_s c) where c is non-negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SGE && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_SGE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); if (RHSKnownZero.isNegative()) { // We know that the sign bit is zero. @@ -917,9 +720,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, } // assume(v >_s c) where c is at least -1. } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SGT && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_SGT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); if (RHSKnownOne.isAllOnesValue() || RHSKnownZero.isNegative()) { // We know that the sign bit is zero. 
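// The shifted-assume rules above run the shift backwards: from
// "assume((v << C) == a)" the known bits of a, shifted right by C,
// constrain v, while v's top C bits stay unknown because the shift
// discarded them. A standalone 8-bit restatement (values illustrative):
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t AKnownZero = 0x57, AKnownOne = 0xA8; // a == 0b10101000, known
  unsigned C = 3;                              // assume((v << 3) == a)
  uint8_t VKnownZero = (uint8_t)(AKnownZero >> C); // 0x0a: bits 1, 3 zero
  uint8_t VKnownOne = (uint8_t)(AKnownOne >> C);   // 0x15: bits 0, 2, 4 one
  printf("v: known-zero 0x%02x known-one 0x%02x\n", VKnownZero, VKnownOne);
}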
@@ -927,9 +731,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, } // assume(v <=_s c) where c is negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SLE && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_SLE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); if (RHSKnownOne.isNegative()) { // We know that the sign bit is one. @@ -937,9 +742,10 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, } // assume(v <_s c) where c is non-positive } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_SLT && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_SLT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); if (RHSKnownZero.isAllOnesValue() || RHSKnownOne.isNegative()) { // We know that the sign bit is one. @@ -947,22 +753,24 @@ static void computeKnownBitsFromAssume(Value *V, APInt &KnownZero, } // assume(v <=_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_ULE && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_ULE && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // Whatever high bits in c are zero are known to be zero. KnownZero |= APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()); // assume(v <_u c) } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && - Pred == ICmpInst::ICMP_ULT && isValidAssumeForContext(I, Q)) { + Pred == ICmpInst::ICMP_ULT && + isValidAssumeForContext(I, Q.CxtI, Q.DT)) { APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - computeKnownBits(A, RHSKnownZero, RHSKnownOne, DL, Depth+1, Query(Q, I)); + computeKnownBits(A, RHSKnownZero, RHSKnownOne, Depth+1, Query(Q, I)); // Whatever high bits in c are zero are known to be zero (if c is a power // of 2, then one more). 
- if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I), DL)) + if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) KnownZero |= APInt::getHighBitsSet(BitWidth, RHSKnownZero.countLeadingOnes()+1); else @@ -984,20 +792,19 @@ template <typename KZFunctor, typename KOFunctor> static void computeKnownBitsFromShiftOperator(Operator *I, APInt &KnownZero, APInt &KnownOne, APInt &KnownZero2, APInt &KnownOne2, - const DataLayout &DL, unsigned Depth, const Query &Q, - KZFunctor KZF, KOFunctor KOF) { + unsigned Depth, const Query &Q, KZFunctor KZF, KOFunctor KOF) { unsigned BitWidth = KnownZero.getBitWidth(); if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) { unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); KnownZero = KZF(KnownZero, ShiftAmt); KnownOne = KOF(KnownOne, ShiftAmt); return; } - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); // Note: We cannot use KnownZero.getLimitedValue() here, because if // BitWidth > 64 and any upper bits are known, we'll end up returning the @@ -1007,7 +814,8 @@ static void computeKnownBitsFromShiftOperator(Operator *I, // It would be more-clearly correct to use the two temporaries for this // calculation. Reusing the APInts here to prevent unnecessary allocations. - KnownZero.clearAllBits(), KnownOne.clearAllBits(); + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); // If we know the shifter operand is nonzero, we can sometimes infer more // known bits. However this is expensive to compute, so be lazy about it and @@ -1017,12 +825,12 @@ static void computeKnownBitsFromShiftOperator(Operator *I, // Early exit if we can't constrain any well-defined shift amount. if (!(ShiftAmtKZ & (BitWidth - 1)) && !(ShiftAmtKO & (BitWidth - 1))) { ShifterOperandIsNonZero = - isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q); + isKnownNonZero(I->getOperand(1), Depth + 1, Q); if (!*ShifterOperandIsNonZero) return; } - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); KnownZero = KnownOne = APInt::getAllOnesValue(BitWidth); for (unsigned ShiftAmt = 0; ShiftAmt < BitWidth; ++ShiftAmt) { @@ -1038,7 +846,7 @@ static void computeKnownBitsFromShiftOperator(Operator *I, if (ShiftAmt == 0) { if (!ShifterOperandIsNonZero.hasValue()) ShifterOperandIsNonZero = - isKnownNonZero(I->getOperand(1), DL, Depth + 1, Q); + isKnownNonZero(I->getOperand(1), Depth + 1, Q); if (*ShifterOperandIsNonZero) continue; } @@ -1052,13 +860,15 @@ static void computeKnownBitsFromShiftOperator(Operator *I, // return anything we'd like, but we need to make sure the sets of known bits // stay disjoint (it should be better for some other code to actually // propagate the undef than to pick a value here using known bits). 
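// A standalone rerun of the enumeration above for x << s on 8 bits:
// every shift amount consistent with the known bits of s contributes a
// candidate known-bits set, and only the intersection survives. Here s
// is known odd, so amounts 1, 3, 5, 7 remain and bit 0 of the result is
// always a vacated zero. All values are illustrative.
#include <cstdint>
#include <cstdio>

int main() {
  uint8_t SKnownZero = 0x00, SKnownOne = 0x01; // s: bit 0 known one
  uint8_t ResKZ = 0xFF, ResKO = 0xFF;          // start from "all known"
  for (unsigned Amt = 0; Amt < 8; ++Amt) {
    // Skip amounts that contradict a known bit of s (only the low three
    // bits of s select an in-range amount here).
    if ((Amt & SKnownZero) != 0 || ((~Amt & SKnownOne) & 0x07) != 0)
      continue;
    // x itself is fully unknown, so after << Amt only the vacated low
    // bits are known (zero).
    ResKZ &= (uint8_t)((1u << Amt) - 1);
    ResKO &= 0;
  }
  printf("known-zero 0x%02x known-one 0x%02x\n", ResKZ, ResKO);
}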
- if ((KnownZero & KnownOne) != 0) - KnownZero.clearAllBits(), KnownOne.clearAllBits(); + if ((KnownZero & KnownOne) != 0) { + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); + } } static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, - APInt &KnownOne, const DataLayout &DL, - unsigned Depth, const Query &Q) { + APInt &KnownOne, unsigned Depth, + const Query &Q) { unsigned BitWidth = KnownZero.getBitWidth(); APInt KnownZero2(KnownZero), KnownOne2(KnownOne); @@ -1070,8 +880,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); // Output known-1 bits are only known if set in both the LHS & RHS. KnownOne &= KnownOne2; @@ -1089,15 +899,15 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), m_Value(Y)))) { APInt KnownZero3(BitWidth, 0), KnownOne3(BitWidth, 0); - computeKnownBits(Y, KnownZero3, KnownOne3, DL, Depth + 1, Q); + computeKnownBits(Y, KnownZero3, KnownOne3, Depth + 1, Q); if (KnownOne3.countTrailingOnes() > 0) KnownZero |= APInt::getLowBitsSet(BitWidth, 1); } break; } case Instruction::Or: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); // Output known-0 bits are only known if clear in both the LHS & RHS. KnownZero &= KnownZero2; @@ -1106,8 +916,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; } case Instruction::Xor: { - computeKnownBits(I->getOperand(1), KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); // Output known-0 bits are known if clear or set in both the LHS & RHS. APInt KnownZeroOut = (KnownZero & KnownZero2) | (KnownOne & KnownOne2); @@ -1119,19 +929,19 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, case Instruction::Mul: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, KnownZero, - KnownOne, KnownZero2, KnownOne2, DL, Depth, Q); + KnownOne, KnownZero2, KnownOne2, Depth, Q); break; } case Instruction::UDiv: { // For the purposes of computing leading zeros we can conservatively // treat a udiv as a logical right shift by the power of 2 known to // be less than the denominator. 
- computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); unsigned LeadZ = KnownZero2.countLeadingOnes(); KnownOne2.clearAllBits(); KnownZero2.clearAllBits(); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); unsigned RHSUnknownLeadingOnes = KnownOne2.countLeadingZeros(); if (RHSUnknownLeadingOnes != BitWidth) LeadZ = std::min(BitWidth, @@ -1141,8 +951,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; } case Instruction::Select: - computeKnownBits(I->getOperand(2), KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(2), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); // Only known if known in both the LHS and RHS. KnownOne &= KnownOne2; @@ -1166,12 +976,12 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, unsigned SrcBitWidth; // Note that we handle pointer operands here because of inttoptr/ptrtoint // which fall through here. - SrcBitWidth = DL.getTypeSizeInBits(SrcTy->getScalarType()); + SrcBitWidth = Q.DL.getTypeSizeInBits(SrcTy->getScalarType()); assert(SrcBitWidth && "SrcBitWidth can't be zero"); KnownZero = KnownZero.zextOrTrunc(SrcBitWidth); KnownOne = KnownOne.zextOrTrunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); KnownZero = KnownZero.zextOrTrunc(BitWidth); KnownOne = KnownOne.zextOrTrunc(BitWidth); // Any top bits are known to be zero. 
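The And/Or/Xor cases above are pure transfer functions over (KnownZero, KnownOne) mask pairs. A minimal standalone sketch of those rules, using uint64_t masks in place of APInt — KnownBits64 and the known* helpers are illustrative names, not LLVM API:

#include <cassert>
#include <cstdint>

// Bits known to be zero/one; the two masks must stay disjoint.
struct KnownBits64 {
  uint64_t Zero = 0;
  uint64_t One = 0;
};

// and: a result bit is 1 only if set in both inputs, 0 if clear in either.
KnownBits64 knownAnd(KnownBits64 A, KnownBits64 B) {
  return {A.Zero | B.Zero, A.One & B.One};
}

// or: a result bit is 0 only if clear in both inputs, 1 if set in either.
KnownBits64 knownOr(KnownBits64 A, KnownBits64 B) {
  return {A.Zero & B.Zero, A.One | B.One};
}

// xor: a result bit is known only where it is known in both inputs.
KnownBits64 knownXor(KnownBits64 A, KnownBits64 B) {
  uint64_t Zero = (A.Zero & B.Zero) | (A.One & B.One); // equal known bits -> 0
  uint64_t One = (A.Zero & B.One) | (A.One & B.Zero);  // opposite known bits -> 1
  return {Zero, One};
}

int main() {
  KnownBits64 X{0xF0, 0x01}; // bit 0 is 1, bits 4-7 are 0, bits 1-3 unknown
  KnownBits64 Y{0x0F, 0x10}; // bits 0-3 are 0, bit 4 is 1, bits 5-7 unknown

  KnownBits64 A = knownAnd(X, Y);
  assert(A.Zero == 0xFF && A.One == 0); // the whole low byte of X & Y is 0

  KnownBits64 E = knownXor(X, Y);
  assert(E.One == 0x11);                // bits 0 and 4 of X ^ Y are 1
  assert((E.Zero & E.One) == 0);        // the masks stay disjoint
  return 0;
}

The final assert mirrors the invariant the real code maintains throughout: a bit may be known zero or known one, never both.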
@@ -1181,12 +991,11 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, } case Instruction::BitCast: { Type *SrcTy = I->getOperand(0)->getType(); - if ((SrcTy->isIntegerTy() || SrcTy->isPointerTy() || - SrcTy->isFloatingPointTy()) && + if ((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); break; } break; @@ -1197,7 +1006,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); KnownZero = KnownZero.zext(BitWidth); KnownOne = KnownOne.zext(BitWidth); @@ -1221,8 +1030,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, }; computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, DL, Depth, Q, - KZF, KOF); + KnownZero2, KnownOne2, Depth, Q, KZF, + KOF); break; } case Instruction::LShr: { @@ -1238,8 +1047,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, }; computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, DL, Depth, Q, - KZF, KOF); + KnownZero2, KnownOne2, Depth, Q, KZF, + KOF); break; } case Instruction::AShr: { @@ -1253,22 +1062,22 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, }; computeKnownBitsFromShiftOperator(I, KnownZero, KnownOne, - KnownZero2, KnownOne2, DL, Depth, Q, - KZF, KOF); + KnownZero2, KnownOne2, Depth, Q, KZF, + KOF); break; } case Instruction::Sub: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, DL, - Depth, Q); + KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, + Q); break; } case Instruction::Add: { bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, - KnownZero, KnownOne, KnownZero2, KnownOne2, DL, - Depth, Q); + KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, + Q); break; } case Instruction::SRem: @@ -1276,7 +1085,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, APInt RA = Rem->getValue().abs(); if (RA.isPowerOf2()) { APInt LowBits = RA - 1; - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, Depth + 1, + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); // The low bits of the first operand are unchanged by the srem. @@ -1301,8 +1110,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, // remainder is zero. if (KnownZero.isNonNegative()) { APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, DL, - Depth + 1, Q); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, + Q); // If it's known zero, our sign bit is also zero. 
if (LHSKnownZero.isNegative()) KnownZero.setBit(BitWidth - 1); @@ -1311,11 +1120,10 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; case Instruction::URem: { if (ConstantInt *Rem = dyn_cast<ConstantInt>(I->getOperand(1))) { - APInt RA = Rem->getValue(); + const APInt &RA = Rem->getValue(); if (RA.isPowerOf2()) { APInt LowBits = (RA - 1); - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, - Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); KnownZero |= ~LowBits; KnownOne &= LowBits; break; @@ -1324,8 +1132,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, // Since the result is less than or equal to either operand, any leading // zero bits in either operand must also exist in the result. - computeKnownBits(I->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, Q); - computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); + computeKnownBits(I->getOperand(1), KnownZero2, KnownOne2, Depth + 1, Q); unsigned Leaders = std::max(KnownZero.countLeadingOnes(), KnownZero2.countLeadingOnes()); @@ -1338,7 +1146,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, AllocaInst *AI = cast<AllocaInst>(I); unsigned Align = AI->getAlignment(); if (Align == 0) - Align = DL.getABITypeAlignment(AI->getType()->getElementType()); + Align = Q.DL.getABITypeAlignment(AI->getAllocatedType()); if (Align > 0) KnownZero = APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); @@ -1348,8 +1156,8 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, // Analyze all of the subscripts of this getelementptr instruction // to determine if we can prove known low zero bits. APInt LocalKnownZero(BitWidth, 0), LocalKnownOne(BitWidth, 0); - computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, DL, - Depth + 1, Q); + computeKnownBits(I->getOperand(0), LocalKnownZero, LocalKnownOne, Depth + 1, + Q); unsigned TrailZ = LocalKnownZero.countTrailingOnes(); gep_type_iterator GTI = gep_type_begin(I); @@ -1367,7 +1175,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, Index = CIndex->getSplatValue(); unsigned Idx = cast<ConstantInt>(Index)->getZExtValue(); - const StructLayout *SL = DL.getStructLayout(STy); + const StructLayout *SL = Q.DL.getStructLayout(STy); uint64_t Offset = SL->getElementOffset(Idx); TrailZ = std::min<unsigned>(TrailZ, countTrailingZeros(Offset)); @@ -1379,10 +1187,9 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; } unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits(); - uint64_t TypeSize = DL.getTypeAllocSize(IndexedTy); + uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy); LocalKnownZero = LocalKnownOne = APInt(GEPOpiBits, 0); - computeKnownBits(Index, LocalKnownZero, LocalKnownOne, DL, Depth + 1, - Q); + computeKnownBits(Index, LocalKnownZero, LocalKnownOne, Depth + 1, Q); TrailZ = std::min(TrailZ, unsigned(countTrailingZeros(TypeSize) + LocalKnownZero.countTrailingOnes())); @@ -1424,11 +1231,11 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; // Ok, we have a PHI of the form L op= R. Check for low // zero bits. 
- computeKnownBits(R, KnownZero2, KnownOne2, DL, Depth + 1, Q); + computeKnownBits(R, KnownZero2, KnownOne2, Depth + 1, Q); // We need to take the minimum number of known bits APInt KnownZero3(KnownZero), KnownOne3(KnownOne); - computeKnownBits(L, KnownZero3, KnownOne3, DL, Depth + 1, Q); + computeKnownBits(L, KnownZero3, KnownOne3, Depth + 1, Q); KnownZero = APInt::getLowBitsSet(BitWidth, std::min(KnownZero2.countTrailingOnes(), @@ -1459,8 +1266,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, KnownOne2 = APInt(BitWidth, 0); // Recurse, but cap the recursion to one level, because we don't // want to waste time spinning around in loops. - computeKnownBits(IncValue, KnownZero2, KnownOne2, DL, - MaxDepth - 1, Q); + computeKnownBits(IncValue, KnownZero2, KnownOne2, MaxDepth - 1, Q); KnownZero &= KnownZero2; KnownOne &= KnownOne2; // If all bits have been ruled out, there's no need to check @@ -1473,17 +1279,21 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, } case Instruction::Call: case Instruction::Invoke: + // If range metadata is attached to this call, set known bits from that, + // and then intersect with known bits based on other properties of the + // function. if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, KnownZero, KnownOne); - // If a range metadata is attached to this IntrinsicInst, intersect the - // explicit range specified by the metadata and the implicit range of - // the intrinsic. + if (Value *RV = CallSite(I).getReturnedArgOperand()) { + computeKnownBits(RV, KnownZero2, KnownOne2, Depth + 1, Q); + KnownZero |= KnownZero2; + KnownOne |= KnownOne2; + } if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { default: break; case Intrinsic::bswap: - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, - Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); KnownZero |= KnownZero2.byteSwap(); KnownOne |= KnownOne2.byteSwap(); break; @@ -1497,8 +1307,7 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, break; } case Intrinsic::ctpop: { - computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, DL, - Depth + 1, Q); + computeKnownBits(I->getOperand(0), KnownZero2, KnownOne2, Depth + 1, Q); // We can bound the space the count needs. Also, bits known to be zero // can't contribute to the population. unsigned BitsPossiblySet = BitWidth - KnownZero2.countPopulation(); @@ -1511,12 +1320,6 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, // of bits which might be set provided by popcnt KnownOne2. 
break; } - case Intrinsic::fabs: { - Type *Ty = II->getType(); - APInt SignBit = APInt::getSignBit(Ty->getScalarSizeInBits()); - KnownZero |= APInt::getSplat(Ty->getPrimitiveSizeInBits(), SignBit); - break; - } case Intrinsic::x86_sse42_crc32_64_64: KnownZero |= APInt::getHighBitsSet(64, 32); break; @@ -1534,19 +1337,19 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, case Intrinsic::sadd_with_overflow: computeKnownBitsAddSub(true, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, DL, Depth, Q); + KnownOne, KnownZero2, KnownOne2, Depth, Q); break; case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: computeKnownBitsAddSub(false, II->getArgOperand(0), II->getArgOperand(1), false, KnownZero, - KnownOne, KnownZero2, KnownOne2, DL, Depth, Q); + KnownOne, KnownZero2, KnownOne2, Depth, Q); break; case Intrinsic::umul_with_overflow: case Intrinsic::smul_with_overflow: computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false, - KnownZero, KnownOne, KnownZero2, KnownOne2, DL, - Depth, Q); + KnownZero, KnownOne, KnownZero2, KnownOne2, Depth, + Q); break; } } @@ -1554,46 +1357,6 @@ static void computeKnownBitsFromOperator(Operator *I, APInt &KnownZero, } } -static unsigned getAlignment(const Value *V, const DataLayout &DL) { - unsigned Align = 0; - if (auto *GO = dyn_cast<GlobalObject>(V)) { - Align = GO->getAlignment(); - if (Align == 0) { - if (auto *GVar = dyn_cast<GlobalVariable>(GO)) { - Type *ObjectType = GVar->getType()->getElementType(); - if (ObjectType->isSized()) { - // If the object is defined in the current Module, we'll be giving - // it the preferred alignment. Otherwise, we have to assume that it - // may only have the minimum ABI alignment. - if (GVar->isStrongDefinitionForLinker()) - Align = DL.getPreferredAlignment(GVar); - else - Align = DL.getABITypeAlignment(ObjectType); - } - } - } - } else if (const Argument *A = dyn_cast<Argument>(V)) { - Align = A->getType()->isPointerTy() ? A->getParamAlignment() : 0; - - if (!Align && A->hasStructRetAttr()) { - // An sret parameter has at least the ABI alignment of the return type. - Type *EltTy = cast<PointerType>(A->getType())->getElementType(); - if (EltTy->isSized()) - Align = DL.getABITypeAlignment(EltTy); - } - } else if (const AllocaInst *AI = dyn_cast<AllocaInst>(V)) - Align = AI->getAlignment(); - else if (auto CS = ImmutableCallSite(V)) - Align = CS.getAttributes().getParamAlignment(AttributeSet::ReturnIndex); - else if (const LoadInst *LI = dyn_cast<LoadInst>(V)) - if (MDNode *MD = LI->getMetadata(LLVMContext::MD_align)) { - ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0)); - Align = CI->getLimitedValue(); - } - - return Align; -} - /// Determine which bits of V are known to be either zero or one and return /// them in the KnownZero/KnownOne bit sets. /// @@ -1610,16 +1373,15 @@ static unsigned getAlignment(const Value *V, const DataLayout &DL) { /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. 
void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, - const DataLayout &DL, unsigned Depth, const Query &Q) { + unsigned Depth, const Query &Q) { assert(V && "No Value?"); assert(Depth <= MaxDepth && "Limit Search Depth"); unsigned BitWidth = KnownZero.getBitWidth(); assert((V->getType()->isIntOrIntVectorTy() || - V->getType()->isFPOrFPVectorTy() || V->getType()->getScalarType()->isPointerTy()) && - "Not integer, floating point, or pointer type!"); - assert((DL.getTypeSizeInBits(V->getType()->getScalarType()) == BitWidth) && + "Not integer or pointer type!"); + assert((Q.DL.getTypeSizeInBits(V->getType()->getScalarType()) == BitWidth) && (!V->getType()->isIntOrIntVectorTy() || V->getType()->getScalarSizeInBits() == BitWidth) && KnownZero.getBitWidth() == BitWidth && @@ -1633,15 +1395,13 @@ void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, return; } // Null and aggregate-zero are all-zeros. - if (isa<ConstantPointerNull>(V) || - isa<ConstantAggregateZero>(V)) { + if (isa<ConstantPointerNull>(V) || isa<ConstantAggregateZero>(V)) { KnownOne.clearAllBits(); KnownZero = APInt::getAllOnesValue(BitWidth); return; } // Handle a constant vector by taking the intersection of the known bits of - // each element. There is no real need to handle ConstantVector here, because - // we don't handle undef in any particularly useful way. + // each element. if (ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) { // We know that CDS must be a vector of integers. Take the intersection of // each element. @@ -1655,6 +1415,26 @@ void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, return; } + if (auto *CV = dyn_cast<ConstantVector>(V)) { + // We know that CV must be a vector of integers. Take the intersection of + // each element. + KnownZero.setAllBits(); KnownOne.setAllBits(); + APInt Elt(KnownZero.getBitWidth(), 0); + for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { + Constant *Element = CV->getAggregateElement(i); + auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); + if (!ElementCI) { + KnownZero.clearAllBits(); + KnownOne.clearAllBits(); + return; + } + Elt = ElementCI->getValue(); + KnownZero &= ~Elt; + KnownOne &= Elt; + } + return; + } + // Start out not knowing anything. KnownZero.clearAllBits(); KnownOne.clearAllBits(); @@ -1666,33 +1446,26 @@ void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, // A weak GlobalAlias is totally unknown. A non-weak GlobalAlias has // the bits of its aliasee. if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (!GA->mayBeOverridden()) - computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, DL, Depth + 1, Q); + if (!GA->isInterposable()) + computeKnownBits(GA->getAliasee(), KnownZero, KnownOne, Depth + 1, Q); return; } if (Operator *I = dyn_cast<Operator>(V)) - computeKnownBitsFromOperator(I, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBitsFromOperator(I, KnownZero, KnownOne, Depth, Q); // Aligned pointers have trailing zeros - refine KnownZero set if (V->getType()->isPointerTy()) { - unsigned Align = getAlignment(V, DL); + unsigned Align = V->getPointerAlignment(Q.DL); if (Align) KnownZero |= APInt::getLowBitsSet(BitWidth, countTrailingZeros(Align)); } - // computeKnownBitsFromAssume and computeKnownBitsFromDominatingCondition - // strictly refines KnownZero and KnownOne. Therefore, we run them after - // computeKnownBitsFromOperator. + // computeKnownBitsFromAssume strictly refines KnownZero and + // KnownOne. 
Therefore, we run it after computeKnownBitsFromOperator. // Check whether a nearby assume intrinsic can determine some known bits. - computeKnownBitsFromAssume(V, KnownZero, KnownOne, DL, Depth, Q); - - // Check whether there's a dominating condition which implies something about - // this value at the given context. - if (EnableDomConditions && Depth <= DomConditionsMaxDepth) - computeKnownBitsFromDominatingCondition(V, KnownZero, KnownOne, DL, Depth, - Q); + computeKnownBitsFromAssume(V, KnownZero, KnownOne, Depth, Q); assert((KnownZero & KnownOne) == 0 && "Bits known to be one AND zero?"); } @@ -1700,8 +1473,8 @@ void computeKnownBits(Value *V, APInt &KnownZero, APInt &KnownOne, /// Determine whether the sign bit is known to be zero or one. /// Convenience wrapper around computeKnownBits. void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, - const DataLayout &DL, unsigned Depth, const Query &Q) { - unsigned BitWidth = getBitWidth(V->getType(), DL); + unsigned Depth, const Query &Q) { + unsigned BitWidth = getBitWidth(V->getType(), Q.DL); if (!BitWidth) { KnownZero = false; KnownOne = false; @@ -1709,7 +1482,7 @@ void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, } APInt ZeroBits(BitWidth, 0); APInt OneBits(BitWidth, 0); - computeKnownBits(V, ZeroBits, OneBits, DL, Depth, Q); + computeKnownBits(V, ZeroBits, OneBits, Depth, Q); KnownOne = OneBits[BitWidth - 1]; KnownZero = ZeroBits[BitWidth - 1]; } @@ -1719,13 +1492,14 @@ void ComputeSignBit(Value *V, bool &KnownZero, bool &KnownOne, /// be a power of two when defined. Supports values with integer or pointer /// types and vectors of integers. bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, - const Query &Q, const DataLayout &DL) { + const Query &Q) { if (Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return OrZero; - if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) - return CI->getValue().isPowerOf2(); - // TODO: Handle vector constants. + + const APInt *ConstIntOrConstSplatInt; + if (match(C, m_APInt(ConstIntOrConstSplatInt))) + return ConstIntOrConstSplatInt->isPowerOf2(); } // 1 << X is clearly a power of two if the one is not shifted off the end. If @@ -1747,19 +1521,19 @@ bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, // or zero. if (OrZero && (match(V, m_Shl(m_Value(X), m_Value())) || match(V, m_LShr(m_Value(X), m_Value())))) - return isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q, DL); + return isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q); if (ZExtInst *ZI = dyn_cast<ZExtInst>(V)) - return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q, DL); + return isKnownToBeAPowerOfTwo(ZI->getOperand(0), OrZero, Depth, Q); if (SelectInst *SI = dyn_cast<SelectInst>(V)) - return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q, DL) && - isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q, DL); + return isKnownToBeAPowerOfTwo(SI->getTrueValue(), OrZero, Depth, Q) && - isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q, DL); + isKnownToBeAPowerOfTwo(SI->getFalseValue(), OrZero, Depth, Q); if (OrZero && match(V, m_And(m_Value(X), m_Value(Y)))) { // A power of two and'd with anything is a power of two or zero. - if (isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q, DL) || - isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, Depth, Q, DL)) + if (isKnownToBeAPowerOfTwo(X, /*OrZero*/ true, Depth, Q) || + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, Depth, Q)) return true; // X & (-X) is always a power of two or zero.
if (match(X, m_Neg(m_Specific(Y))) || match(Y, m_Neg(m_Specific(X)))) @@ -1774,19 +1548,19 @@ bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) - if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q, DL)) + if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) return true; if (match(Y, m_And(m_Specific(X), m_Value())) || match(Y, m_And(m_Value(), m_Specific(X)))) - if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q, DL)) + if (isKnownToBeAPowerOfTwo(X, OrZero, Depth, Q)) return true; unsigned BitWidth = V->getType()->getScalarSizeInBits(); APInt LHSZeroBits(BitWidth, 0), LHSOneBits(BitWidth, 0); - computeKnownBits(X, LHSZeroBits, LHSOneBits, DL, Depth, Q); + computeKnownBits(X, LHSZeroBits, LHSOneBits, Depth, Q); APInt RHSZeroBits(BitWidth, 0), RHSOneBits(BitWidth, 0); - computeKnownBits(Y, RHSZeroBits, RHSOneBits, DL, Depth, Q); + computeKnownBits(Y, RHSZeroBits, RHSOneBits, Depth, Q); // If i8 V is a power of two or zero: // ZeroBits: 1 1 1 0 1 1 1 1 // ~ZeroBits: 0 0 0 1 0 0 0 0 @@ -1804,7 +1578,7 @@ bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, if (match(V, m_Exact(m_LShr(m_Value(), m_Value()))) || match(V, m_Exact(m_UDiv(m_Value(), m_Value())))) { return isKnownToBeAPowerOfTwo(cast<Operator>(V)->getOperand(0), OrZero, - Depth, Q, DL); + Depth, Q); } return false; @@ -1816,8 +1590,8 @@ bool isKnownToBeAPowerOfTwo(Value *V, bool OrZero, unsigned Depth, /// to be non-null. /// /// Currently this routine does not support vector GEPs. -static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout &DL, - unsigned Depth, const Query &Q) { +static bool isGEPKnownNonNull(GEPOperator *GEP, unsigned Depth, + const Query &Q) { if (!GEP->isInBounds() || GEP->getPointerAddressSpace() != 0) return false; @@ -1826,7 +1600,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout &DL, // If the base pointer is non-null, we cannot walk to a null address with an // inbounds GEP in address space zero. - if (isKnownNonZero(GEP->getPointerOperand(), DL, Depth, Q)) + if (isKnownNonZero(GEP->getPointerOperand(), Depth, Q)) return true; // Walk the GEP operands and see if any operand introduces a non-zero offset. @@ -1838,7 +1612,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout &DL, if (StructType *STy = dyn_cast<StructType>(*GTI)) { ConstantInt *OpC = cast<ConstantInt>(GTI.getOperand()); unsigned ElementIdx = OpC->getZExtValue(); - const StructLayout *SL = DL.getStructLayout(STy); + const StructLayout *SL = Q.DL.getStructLayout(STy); uint64_t ElementOffset = SL->getElementOffset(ElementIdx); if (ElementOffset > 0) return true; @@ -1846,7 +1620,7 @@ static bool isGEPKnownNonNull(GEPOperator *GEP, const DataLayout &DL, } // If we have a zero-sized type, the index doesn't matter. Keep looping. 
- if (DL.getTypeAllocSize(GTI.getIndexedType()) == 0) + if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0) continue; // Fast path the constant operand case both for efficiency and so we don't @@ -1865,7 +1639,7 @@ if (Depth++ >= MaxDepth) continue; - if (isKnownNonZero(GTI.getOperand(), DL, Depth, Q)) + if (isKnownNonZero(GTI.getOperand(), Depth, Q)) return true; } @@ -1875,8 +1649,7 @@ /// Does the 'Range' metadata (which must be a valid MD_range operand list) /// ensure that the value it's attached to is never Value? 'RangeType' is /// the type of the value described by the range. -static bool rangeMetadataExcludesValue(MDNode* Ranges, - const APInt& Value) { +static bool rangeMetadataExcludesValue(MDNode* Ranges, const APInt& Value) { const unsigned NumRanges = Ranges->getNumOperands() / 2; assert(NumRanges >= 1); for (unsigned i = 0; i < NumRanges; ++i) { @@ -1895,23 +1668,35 @@ static bool rangeMetadataExcludesValue(MDNode* Ranges, /// For vectors return true if every element is known to be non-zero when /// defined. Supports values with integer or pointer type and vectors of /// integers. -bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, - const Query &Q) { - if (Constant *C = dyn_cast<Constant>(V)) { +bool isKnownNonZero(Value *V, unsigned Depth, const Query &Q) { + if (auto *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return false; if (isa<ConstantInt>(C)) // Must be non-zero due to null test above. return true; - // TODO: Handle vectors + + // For constant vectors, check that all elements are undefined or known + // non-zero to determine that the whole vector is known non-zero. + if (auto *VecTy = dyn_cast<VectorType>(C->getType())) { + for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) { + Constant *Elt = C->getAggregateElement(i); + if (!Elt || Elt->isNullValue()) + return false; + if (!isa<UndefValue>(Elt) && !isa<ConstantInt>(Elt)) + return false; + } + return true; + } + return false; } - if (Instruction* I = dyn_cast<Instruction>(V)) { + if (auto *I = dyn_cast<Instruction>(V)) { if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { // If the possible ranges don't contain zero, then the value is // definitely non-zero. - if (IntegerType* Ty = dyn_cast<IntegerType>(V->getType())) { + if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { const APInt ZeroValue(Ty->getBitWidth(), 0); if (rangeMetadataExcludesValue(Ranges, ZeroValue)) return true; @@ -1926,22 +1711,22 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, // Check for pointer simplifications. if (V->getType()->isPointerTy()) { if (isKnownNonNull(V)) - return true; + return true; if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) - if (isGEPKnownNonNull(GEP, DL, Depth, Q)) + if (isGEPKnownNonNull(GEP, Depth, Q)) return true; } - unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), DL); + unsigned BitWidth = getBitWidth(V->getType()->getScalarType(), Q.DL); // X | Y != 0 if X != 0 or Y != 0. Value *X = nullptr, *Y = nullptr; if (match(V, m_Or(m_Value(X), m_Value(Y)))) - return isKnownNonZero(X, DL, Depth, Q) || isKnownNonZero(Y, DL, Depth, Q); + return isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q); // ext X != 0 if X != 0.
if (isa<SExtInst>(V) || isa<ZExtInst>(V)) - return isKnownNonZero(cast<Instruction>(V)->getOperand(0), DL, Depth, Q); + return isKnownNonZero(cast<Instruction>(V)->getOperand(0), Depth, Q); // shl X, Y != 0 if X is odd. Note that the value of the shift is undefined // if the lowest bit is shifted off the end. @@ -1949,11 +1734,11 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, // shl nuw can't remove any non-zero bits. OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); if (BO->hasNoUnsignedWrap()) - return isKnownNonZero(X, DL, Depth, Q); + return isKnownNonZero(X, Depth, Q); APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBits(X, KnownZero, KnownOne, Depth, Q); if (KnownOne[0]) return true; } @@ -1963,10 +1748,10 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, // shr exact can only shift out zero bits. PossiblyExactOperator *BO = cast<PossiblyExactOperator>(V); if (BO->isExact()) - return isKnownNonZero(X, DL, Depth, Q); + return isKnownNonZero(X, Depth, Q); bool XKnownNonNegative, XKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, DL, Depth, Q); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, Depth, Q); if (XKnownNegative) return true; @@ -1976,32 +1761,32 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, if (ConstantInt *Shift = dyn_cast<ConstantInt>(Y)) { APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(X, KnownZero, KnownOne, DL, Depth, Q); - + computeKnownBits(X, KnownZero, KnownOne, Depth, Q); + auto ShiftVal = Shift->getLimitedValue(BitWidth - 1); // Is there a known one in the portion not shifted out? if (KnownOne.countLeadingZeros() < BitWidth - ShiftVal) return true; // Are all the bits to be shifted out known zero? if (KnownZero.countTrailingOnes() >= ShiftVal) - return isKnownNonZero(X, DL, Depth, Q); + return isKnownNonZero(X, Depth, Q); } } // div exact can only produce a zero if the dividend is zero. else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) { - return isKnownNonZero(X, DL, Depth, Q); + return isKnownNonZero(X, Depth, Q); } // X + Y. else if (match(V, m_Add(m_Value(X), m_Value(Y)))) { bool XKnownNonNegative, XKnownNegative; bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(X, XKnownNonNegative, XKnownNegative, DL, Depth, Q); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, DL, Depth, Q); + ComputeSignBit(X, XKnownNonNegative, XKnownNegative, Depth, Q); + ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Depth, Q); // If X and Y are both non-negative (as signed values) then their sum is not // zero unless both X and Y are zero. if (XKnownNonNegative && YKnownNonNegative) - if (isKnownNonZero(X, DL, Depth, Q) || isKnownNonZero(Y, DL, Depth, Q)) + if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q)) return true; // If X and Y are both negative (as signed values) then their sum is not @@ -2012,22 +1797,22 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, APInt Mask = APInt::getSignedMaxValue(BitWidth); // The sign bit of X is set. If some other bit is set then X is not equal // to INT_MIN. - computeKnownBits(X, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBits(X, KnownZero, KnownOne, Depth, Q); if ((KnownOne & Mask) != 0) return true; // The sign bit of Y is set. If some other bit is set then Y is not equal // to INT_MIN. 
- computeKnownBits(Y, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBits(Y, KnownZero, KnownOne, Depth, Q); if ((KnownOne & Mask) != 0) return true; } // The sum of a non-negative number and a power of two is not zero. if (XKnownNonNegative && - isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q, DL)) + isKnownToBeAPowerOfTwo(Y, /*OrZero*/ false, Depth, Q)) return true; if (YKnownNonNegative && - isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q, DL)) + isKnownToBeAPowerOfTwo(X, /*OrZero*/ false, Depth, Q)) return true; } // X * Y. @@ -2036,13 +1821,13 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && - isKnownNonZero(X, DL, Depth, Q) && isKnownNonZero(Y, DL, Depth, Q)) + isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q)) return true; } // (C ? X : Y) != 0 if X != 0 and Y != 0. else if (SelectInst *SI = dyn_cast<SelectInst>(V)) { - if (isKnownNonZero(SI->getTrueValue(), DL, Depth, Q) && - isKnownNonZero(SI->getFalseValue(), DL, Depth, Q)) + if (isKnownNonZero(SI->getTrueValue(), Depth, Q) && + isKnownNonZero(SI->getFalseValue(), Depth, Q)) return true; } // PHI @@ -2064,18 +1849,23 @@ bool isKnownNonZero(Value *V, const DataLayout &DL, unsigned Depth, } } } + // Check if all incoming values are non-zero constant. + bool AllNonZeroConstants = all_of(PN->operands(), [](Value *V) { + return isa<ConstantInt>(V) && !cast<ConstantInt>(V)->isZeroValue(); + }); + if (AllNonZeroConstants) + return true; } if (!BitWidth) return false; APInt KnownZero(BitWidth, 0); APInt KnownOne(BitWidth, 0); - computeKnownBits(V, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBits(V, KnownZero, KnownOne, Depth, Q); return KnownOne != 0; } /// Return true if V2 == V1 + X, where X is known non-zero. -static bool isAddOfNonZero(Value *V1, Value *V2, const DataLayout &DL, - const Query &Q) { +static bool isAddOfNonZero(Value *V1, Value *V2, const Query &Q) { BinaryOperator *BO = dyn_cast<BinaryOperator>(V1); if (!BO || BO->getOpcode() != Instruction::Add) return false; @@ -2086,18 +1876,17 @@ static bool isAddOfNonZero(Value *V1, Value *V2, const DataLayout &DL, Op = BO->getOperand(0); else return false; - return isKnownNonZero(Op, DL, 0, Q); + return isKnownNonZero(Op, 0, Q); } /// Return true if it is known that V1 != V2. -static bool isKnownNonEqual(Value *V1, Value *V2, const DataLayout &DL, - const Query &Q) { +static bool isKnownNonEqual(Value *V1, Value *V2, const Query &Q) { if (V1->getType()->isVectorTy() || V1 == V2) return false; if (V1->getType() != V2->getType()) // We can't look through casts yet. 
return false; - if (isAddOfNonZero(V1, V2, DL, Q) || isAddOfNonZero(V2, V1, DL, Q)) + if (isAddOfNonZero(V1, V2, Q) || isAddOfNonZero(V2, V1, Q)) return true; if (IntegerType *Ty = dyn_cast<IntegerType>(V1->getType())) { @@ -2106,10 +1895,10 @@ static bool isKnownNonEqual(Value *V1, Value *V2, const DataLayout &DL, auto BitWidth = Ty->getBitWidth(); APInt KnownZero1(BitWidth, 0); APInt KnownOne1(BitWidth, 0); - computeKnownBits(V1, KnownZero1, KnownOne1, DL, 0, Q); + computeKnownBits(V1, KnownZero1, KnownOne1, 0, Q); APInt KnownZero2(BitWidth, 0); APInt KnownOne2(BitWidth, 0); - computeKnownBits(V2, KnownZero2, KnownOne2, DL, 0, Q); + computeKnownBits(V2, KnownZero2, KnownOne2, 0, Q); auto OppositeBits = (KnownZero1 & KnownOne2) | (KnownZero2 & KnownOne1); if (OppositeBits.getBoolValue()) @@ -2127,26 +1916,48 @@ static bool isKnownNonEqual(Value *V1, Value *V2, const DataLayout &DL, /// where V is a vector, the mask, known zero, and known one values are the /// same width as the vector element, and the bit is set only if it is true /// for all of the elements in the vector. -bool MaskedValueIsZero(Value *V, const APInt &Mask, const DataLayout &DL, - unsigned Depth, const Query &Q) { +bool MaskedValueIsZero(Value *V, const APInt &Mask, unsigned Depth, + const Query &Q) { APInt KnownZero(Mask.getBitWidth(), 0), KnownOne(Mask.getBitWidth(), 0); - computeKnownBits(V, KnownZero, KnownOne, DL, Depth, Q); + computeKnownBits(V, KnownZero, KnownOne, Depth, Q); return (KnownZero & Mask) == Mask; } +/// For vector constants, loop over the elements and find the constant with the +/// minimum number of sign bits. Return 0 if the value is not a vector constant +/// or if any element was not analyzed; otherwise, return the count for the +/// element with the minimum number of sign bits. +static unsigned computeNumSignBitsVectorConstant(Value *V, unsigned TyBits) { + auto *CV = dyn_cast<Constant>(V); + if (!CV || !CV->getType()->isVectorTy()) + return 0; + unsigned MinSignBits = TyBits; + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + // If we find a non-ConstantInt, bail out. + auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i)); + if (!Elt) + return 0; + + // If the sign bit is 1, flip the bits, so we always count leading zeros. + APInt EltVal = Elt->getValue(); + if (EltVal.isNegative()) + EltVal = ~EltVal; + MinSignBits = std::min(MinSignBits, EltVal.countLeadingZeros()); + } + + return MinSignBits; +} /// Return the number of times the sign bit of the register is replicated into /// the other bits. We know that at least 1 bit is always equal to the sign bit /// (itself), but other cases can give us information. For example, immediately /// after an "ashr X, 2", we know that the top 3 bits are all equal to each -/// other, so we return 3. -/// -/// 'Op' must have a scalar integer type. -/// -unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, - const Query &Q) { - unsigned TyBits = DL.getTypeSizeInBits(V->getType()->getScalarType()); +/// other, so we return 3. For vectors, return the number of sign bits for the +/// vector element with the minimum number of known sign bits.
+unsigned ComputeNumSignBits(Value *V, unsigned Depth, const Query &Q) { + unsigned TyBits = Q.DL.getTypeSizeInBits(V->getType()->getScalarType()); unsigned Tmp, Tmp2; unsigned FirstAnswer = 1; @@ -2161,7 +1972,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, default: break; case Instruction::SExt: Tmp = TyBits - U->getOperand(0)->getType()->getScalarSizeInBits(); - return ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q) + Tmp; + return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q) + Tmp; case Instruction::SDiv: { const APInt *Denominator; @@ -2173,7 +1984,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, break; // Calculate the incoming numerator bits. - unsigned NumBits = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + unsigned NumBits = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); // Add floor(log(C)) bits to the numerator bits. return std::min(TyBits, NumBits + Denominator->logBase2()); @@ -2195,7 +2006,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, // Calculate the incoming numerator bits. SRem by a positive constant // can't lower the number of sign bits. unsigned NumrBits = - ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); // Calculate the leading sign bit constraints by examining the // denominator. Given that the denominator is positive, there are two @@ -2217,7 +2028,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, } case Instruction::AShr: { - Tmp = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { @@ -2230,7 +2041,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. - Tmp = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); Tmp2 = ShAmt->getZExtValue(); if (Tmp2 >= TyBits || // Bad shift. Tmp2 >= Tmp) break; // Shifted all sign bits out. @@ -2242,9 +2053,9 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, case Instruction::Or: case Instruction::Xor: // NOT is handled here. // Logical binary ops preserve the number of sign bits at the worst. - Tmp = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); if (Tmp != 1) { - Tmp2 = ComputeNumSignBits(U->getOperand(1), DL, Depth + 1, Q); + Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); FirstAnswer = std::min(Tmp, Tmp2); // We computed what we know about the sign bits as our first // answer. Now proceed to the generic code that uses @@ -2253,23 +2064,22 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, break; case Instruction::Select: - Tmp = ComputeNumSignBits(U->getOperand(1), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); if (Tmp == 1) return 1; // Early out. - Tmp2 = ComputeNumSignBits(U->getOperand(2), DL, Depth + 1, Q); + Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); return std::min(Tmp, Tmp2); case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. 
- Tmp = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); if (Tmp == 1) return 1; // Early out. // Special case decrementing a value (ADD X, -1): if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) if (CRHS->isAllOnesValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(0), KnownZero, KnownOne, DL, Depth + 1, - Q); + computeKnownBits(U->getOperand(0), KnownZero, KnownOne, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. @@ -2282,20 +2092,19 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, return Tmp; } - Tmp2 = ComputeNumSignBits(U->getOperand(1), DL, Depth + 1, Q); + Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); if (Tmp2 == 1) return 1; return std::min(Tmp, Tmp2)-1; case Instruction::Sub: - Tmp2 = ComputeNumSignBits(U->getOperand(1), DL, Depth + 1, Q); + Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); if (Tmp2 == 1) return 1; // Handle NEG. if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) if (CLHS->isNullValue()) { APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - computeKnownBits(U->getOperand(1), KnownZero, KnownOne, DL, Depth + 1, - Q); + computeKnownBits(U->getOperand(1), KnownZero, KnownOne, Depth + 1, Q); // If the input is known to be 0 or 1, the output is 0/-1, which is all // sign bits set. if ((KnownZero | APInt(TyBits, 1)).isAllOnesValue()) @@ -2311,7 +2120,7 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. - Tmp = ComputeNumSignBits(U->getOperand(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); if (Tmp == 1) return 1; // Early out. return std::min(Tmp, Tmp2)-1; @@ -2325,11 +2134,11 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, // Take the minimum of all incoming values. This can't infinitely loop // because of our depth threshold. - Tmp = ComputeNumSignBits(PN->getIncomingValue(0), DL, Depth + 1, Q); + Tmp = ComputeNumSignBits(PN->getIncomingValue(0), Depth + 1, Q); for (unsigned i = 1, e = NumIncomingValues; i != e; ++i) { if (Tmp == 1) return Tmp; Tmp = std::min( - Tmp, ComputeNumSignBits(PN->getIncomingValue(i), DL, Depth + 1, Q)); + Tmp, ComputeNumSignBits(PN->getIncomingValue(i), Depth + 1, Q)); } return Tmp; } @@ -2342,26 +2151,25 @@ unsigned ComputeNumSignBits(Value *V, const DataLayout &DL, unsigned Depth, // Finally, if we can prove that the top bits of the result are 0's or 1's, // use this information. + + // If we can examine all elements of a vector constant successfully, we're + // done (we can't do any better than that). If not, keep trying. + if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits)) + return VecSignBits; + APInt KnownZero(TyBits, 0), KnownOne(TyBits, 0); - APInt Mask; - computeKnownBits(V, KnownZero, KnownOne, DL, Depth, Q); - - if (KnownZero.isNegative()) { // sign bit is 0 - Mask = KnownZero; - } else if (KnownOne.isNegative()) { // sign bit is 1; - Mask = KnownOne; - } else { - // Nothing known. - return FirstAnswer; - } + computeKnownBits(V, KnownZero, KnownOne, Depth, Q); + + // If we know that the sign bit is either zero or one, determine the number of + // identical bits in the top of the input value. 
+ if (KnownZero.isNegative()) + return std::max(FirstAnswer, KnownZero.countLeadingOnes()); - // Okay, we know that the sign bit in Mask is set. Use CLZ to determine - // the number of identical bits in the top of the input value. - Mask = ~Mask; - Mask <<= Mask.getBitWidth()-TyBits; - // Return # leading zeros. We use 'min' here in case Val was zero before - // shifting. We don't want to return '64' as for an i32 "0". - return std::max(FirstAnswer, std::min(TyBits, Mask.countLeadingZeros())); + if (KnownOne.isNegative()) + return std::max(FirstAnswer, KnownOne.countLeadingOnes()); + + // computeKnownBits gave us no extra information about the top bits. + return FirstAnswer; } /// This function computes the integer multiple of Base that equals V. @@ -2484,13 +2292,124 @@ bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple, return false; } +Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, + const TargetLibraryInfo *TLI) { + const Function *F = ICS.getCalledFunction(); + if (!F) + return Intrinsic::not_intrinsic; + + if (F->isIntrinsic()) + return F->getIntrinsicID(); + + if (!TLI) + return Intrinsic::not_intrinsic; + + LibFunc::Func Func; + // We're going to make assumptions on the semantics of the functions, check + // that the target knows that it's available in this environment and it does + // not have local linkage. + if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(*F, Func)) + return Intrinsic::not_intrinsic; + + if (!ICS.onlyReadsMemory()) + return Intrinsic::not_intrinsic; + + // Otherwise check if we have a call to a function that can be turned into a + // vector intrinsic. + switch (Func) { + default: + break; + case LibFunc::sin: + case LibFunc::sinf: + case LibFunc::sinl: + return Intrinsic::sin; + case LibFunc::cos: + case LibFunc::cosf: + case LibFunc::cosl: + return Intrinsic::cos; + case LibFunc::exp: + case LibFunc::expf: + case LibFunc::expl: + return Intrinsic::exp; + case LibFunc::exp2: + case LibFunc::exp2f: + case LibFunc::exp2l: + return Intrinsic::exp2; + case LibFunc::log: + case LibFunc::logf: + case LibFunc::logl: + return Intrinsic::log; + case LibFunc::log10: + case LibFunc::log10f: + case LibFunc::log10l: + return Intrinsic::log10; + case LibFunc::log2: + case LibFunc::log2f: + case LibFunc::log2l: + return Intrinsic::log2; + case LibFunc::fabs: + case LibFunc::fabsf: + case LibFunc::fabsl: + return Intrinsic::fabs; + case LibFunc::fmin: + case LibFunc::fminf: + case LibFunc::fminl: + return Intrinsic::minnum; + case LibFunc::fmax: + case LibFunc::fmaxf: + case LibFunc::fmaxl: + return Intrinsic::maxnum; + case LibFunc::copysign: + case LibFunc::copysignf: + case LibFunc::copysignl: + return Intrinsic::copysign; + case LibFunc::floor: + case LibFunc::floorf: + case LibFunc::floorl: + return Intrinsic::floor; + case LibFunc::ceil: + case LibFunc::ceilf: + case LibFunc::ceill: + return Intrinsic::ceil; + case LibFunc::trunc: + case LibFunc::truncf: + case LibFunc::truncl: + return Intrinsic::trunc; + case LibFunc::rint: + case LibFunc::rintf: + case LibFunc::rintl: + return Intrinsic::rint; + case LibFunc::nearbyint: + case LibFunc::nearbyintf: + case LibFunc::nearbyintl: + return Intrinsic::nearbyint; + case LibFunc::round: + case LibFunc::roundf: + case LibFunc::roundl: + return Intrinsic::round; + case LibFunc::pow: + case LibFunc::powf: + case LibFunc::powl: + return Intrinsic::pow; + case LibFunc::sqrt: + case LibFunc::sqrtf: + case LibFunc::sqrtl: + if (ICS->hasNoNaNs()) + return Intrinsic::sqrt; + return 
Intrinsic::not_intrinsic; + } + + return Intrinsic::not_intrinsic; +} + /// Return true if we can prove that the specified FP value is never equal to /// -0.0. /// /// NOTE: this function will need to be revisited when we support non-default /// rounding modes! /// -bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) { +bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, + unsigned Depth) { if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) return !CFP->getValueAPF().isNegZero(); @@ -2518,30 +2437,26 @@ bool llvm::CannotBeNegativeZero(const Value *V, unsigned Depth) { if (isa<SIToFPInst>(I) || isa<UIToFPInst>(I)) return true; - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) + if (const CallInst *CI = dyn_cast<CallInst>(I)) { + Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + switch (IID) { + default: + break; // sqrt(-0.0) = -0.0, no other negative results are possible. - if (II->getIntrinsicID() == Intrinsic::sqrt) - return CannotBeNegativeZero(II->getArgOperand(0), Depth+1); - - if (const CallInst *CI = dyn_cast<CallInst>(I)) - if (const Function *F = CI->getCalledFunction()) { - if (F->isDeclaration()) { - // abs(x) != -0.0 - if (F->getName() == "abs") return true; - // fabs[lf](x) != -0.0 - if (F->getName() == "fabs") return true; - if (F->getName() == "fabsf") return true; - if (F->getName() == "fabsl") return true; - if (F->getName() == "sqrt" || F->getName() == "sqrtf" || - F->getName() == "sqrtl") - return CannotBeNegativeZero(CI->getArgOperand(0), Depth+1); - } + case Intrinsic::sqrt: + return CannotBeNegativeZero(CI->getArgOperand(0), TLI, Depth + 1); + // fabs(x) != -0.0 + case Intrinsic::fabs: + return true; } + } return false; } -bool llvm::CannotBeOrderedLessThanZero(const Value *V, unsigned Depth) { +bool llvm::CannotBeOrderedLessThanZero(const Value *V, + const TargetLibraryInfo *TLI, + unsigned Depth) { if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) return !CFP->getValueAPF().isNegative() || CFP->getValueAPF().isZero(); @@ -2561,52 +2476,53 @@ bool llvm::CannotBeOrderedLessThanZero(const Value *V, unsigned Depth) { return true; case Instruction::FMul: // x*x is always non-negative or a NaN. - if (I->getOperand(0) == I->getOperand(1)) + if (I->getOperand(0) == I->getOperand(1)) return true; // Fall through case Instruction::FAdd: case Instruction::FDiv: case Instruction::FRem: - return CannotBeOrderedLessThanZero(I->getOperand(0), Depth+1) && - CannotBeOrderedLessThanZero(I->getOperand(1), Depth+1); + return CannotBeOrderedLessThanZero(I->getOperand(0), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(I->getOperand(1), TLI, Depth + 1); case Instruction::Select: - return CannotBeOrderedLessThanZero(I->getOperand(1), Depth+1) && - CannotBeOrderedLessThanZero(I->getOperand(2), Depth+1); + return CannotBeOrderedLessThanZero(I->getOperand(1), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(I->getOperand(2), TLI, Depth + 1); case Instruction::FPExt: case Instruction::FPTrunc: // Widening/narrowing never change sign. 
- return CannotBeOrderedLessThanZero(I->getOperand(0), Depth+1); - case Instruction::Call: - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) - switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::maxnum: - return CannotBeOrderedLessThanZero(I->getOperand(0), Depth+1) || - CannotBeOrderedLessThanZero(I->getOperand(1), Depth+1); - case Intrinsic::minnum: - return CannotBeOrderedLessThanZero(I->getOperand(0), Depth+1) && - CannotBeOrderedLessThanZero(I->getOperand(1), Depth+1); - case Intrinsic::exp: - case Intrinsic::exp2: - case Intrinsic::fabs: - case Intrinsic::sqrt: - return true; - case Intrinsic::powi: - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { - // powi(x,n) is non-negative if n is even. - if (CI->getBitWidth() <= 64 && CI->getSExtValue() % 2u == 0) - return true; - } - return CannotBeOrderedLessThanZero(I->getOperand(0), Depth+1); - case Intrinsic::fma: - case Intrinsic::fmuladd: - // x*x+y is non-negative if y is non-negative. - return I->getOperand(0) == I->getOperand(1) && - CannotBeOrderedLessThanZero(I->getOperand(2), Depth+1); + return CannotBeOrderedLessThanZero(I->getOperand(0), TLI, Depth + 1); + case Instruction::Call: + Intrinsic::ID IID = getIntrinsicForCallSite(cast<CallInst>(I), TLI); + switch (IID) { + default: + break; + case Intrinsic::maxnum: + return CannotBeOrderedLessThanZero(I->getOperand(0), TLI, Depth + 1) || + CannotBeOrderedLessThanZero(I->getOperand(1), TLI, Depth + 1); + case Intrinsic::minnum: + return CannotBeOrderedLessThanZero(I->getOperand(0), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(I->getOperand(1), TLI, Depth + 1); + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::fabs: + case Intrinsic::sqrt: + return true; + case Intrinsic::powi: + if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + // powi(x,n) is non-negative if n is even. + if (CI->getBitWidth() <= 64 && CI->getSExtValue() % 2u == 0) + return true; } + return CannotBeOrderedLessThanZero(I->getOperand(0), TLI, Depth + 1); + case Intrinsic::fma: + case Intrinsic::fmuladd: + // x*x+y is non-negative if y is non-negative. + return I->getOperand(0) == I->getOperand(1) && + CannotBeOrderedLessThanZero(I->getOperand(2), TLI, Depth + 1); + } break; } - return false; + return false; } /// If the specified value can be set by repeating the same byte in memory, @@ -2863,7 +2779,7 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, Operator::getOpcode(Ptr) == Instruction::AddrSpaceCast) { Ptr = cast<Operator>(Ptr)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) break; Ptr = GA->getAliasee(); } else { @@ -2874,6 +2790,24 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, return Ptr; } +bool llvm::isGEPBasedOnPointerToString(const GEPOperator *GEP) { + // Make sure the GEP has exactly three arguments. + if (GEP->getNumOperands() != 3) + return false; + + // Make sure the index-ee is a pointer to array of i8. + ArrayType *AT = dyn_cast<ArrayType>(GEP->getSourceElementType()); + if (!AT || !AT->getElementType()->isIntegerTy(8)) + return false; + + // Check to make sure that the first operand of the GEP is an integer and + // has value 0 so that we are sure we're indexing into the initializer. 
+ const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1)); + if (!FirstIdx || !FirstIdx->isZero()) + return false; + + return true; +} /// This function computes the length of a null-terminated C string pointed to /// by V. If successful, it returns true and returns the string in Str. @@ -2888,20 +2822,9 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, // If the value is a GEP instruction or constant expression, treat it as an // offset. if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { - // Make sure the GEP has exactly three arguments. - if (GEP->getNumOperands() != 3) - return false; - - // Make sure the index-ee is a pointer to array of i8. - PointerType *PT = cast<PointerType>(GEP->getOperand(0)->getType()); - ArrayType *AT = dyn_cast<ArrayType>(PT->getElementType()); - if (!AT || !AT->getElementType()->isIntegerTy(8)) - return false; - - // Check to make sure that the first operand of the GEP is an integer and - // has value 0 so that we are sure we're indexing into the initializer. - const ConstantInt *FirstIdx = dyn_cast<ConstantInt>(GEP->getOperand(1)); - if (!FirstIdx || !FirstIdx->isZero()) + // The GEP operator should be based on a pointer to string constant, and is + // indexing into the string constant. + if (!isGEPBasedOnPointerToString(GEP)) return false; // If the second index isn't a ConstantInt, then this is a variable index @@ -2923,7 +2846,7 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, if (!GV || !GV->isConstant() || !GV->hasDefinitiveInitializer()) return false; - // Handle the all-zeros case + // Handle the all-zeros case. if (GV->getInitializer()->isNullValue()) { // This is a degenerate case. The initializer is constant zero so the // length of the string must be zero. @@ -2931,13 +2854,12 @@ bool llvm::getConstantStringInfo(const Value *V, StringRef &Str, return true; } - // Must be a Constant Array - const ConstantDataArray *Array = - dyn_cast<ConstantDataArray>(GV->getInitializer()); + // This must be a ConstantDataArray. + const auto *Array = dyn_cast<ConstantDataArray>(GV->getInitializer()); if (!Array || !Array->isString()) return false; - // Get the number of elements in the array + // Get the number of elements in the array. uint64_t NumElts = Array->getType()->getArrayNumElements(); // Start out with the entire array in the StringRef. @@ -3060,10 +2982,16 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, Operator::getOpcode(V) == Instruction::AddrSpaceCast) { V = cast<Operator>(V)->getOperand(0); } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { - if (GA->mayBeOverridden()) + if (GA->isInterposable()) return V; V = GA->getAliasee(); } else { + if (auto CS = CallSite(V)) + if (Value *RV = CS.getReturnedArgOperand()) { + V = RV; + continue; + } + // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) // TODO: Acquire a DominatorTree and AssumptionCache and use them. 
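The GetUnderlyingObject change just above lets the walk continue through calls that return one of their arguments unchanged (LLVM's 'returned' attribute). A rough standalone model of that stripping loop, with a toy Node graph standing in for LLVM's Value and CallSite API — Node, IsCast, and ReturnedArgIndex are illustrative names, not real API:

#include <vector>

// Toy stand-ins for the Value graph; ReturnedArgIndex models
// CallSite::getReturnedArgOperand().
struct Node {
  std::vector<Node *> Operands;
  int ReturnedArgIndex = -1; // >= 0: a call returning Operands[i] unchanged
  bool IsCast = false;       // bitcast/GEP-like: Operands[0] is the base
};

Node *underlyingObject(Node *V, unsigned MaxLookup = 6) {
  for (unsigned i = 0; i < MaxLookup; ++i) {
    if (V->IsCast && !V->Operands.empty())
      V = V->Operands[0];                   // strip casts and GEPs
    else if (V->ReturnedArgIndex >= 0)
      V = V->Operands[V->ReturnedArgIndex]; // new: look through 'returned' calls
    else
      break;                                // opaque value: stop here
  }
  return V;
}

int main() {
  Node Alloca;                  // the allocation the walk should reach
  Node Cast;
  Cast.IsCast = true;
  Cast.Operands = {&Alloca};
  Node Call;                    // e.g. a memcpy-like call returning its dest
  Call.ReturnedArgIndex = 0;
  Call.Operands = {&Cast};
  return underlyingObject(&Call) == &Alloca ? 0 : 1;
}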
@@ -3133,213 +3061,9 @@ bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { return true; } -static bool isDereferenceableFromAttribute(const Value *BV, APInt Offset, - Type *Ty, const DataLayout &DL, - const Instruction *CtxI, - const DominatorTree *DT, - const TargetLibraryInfo *TLI) { - assert(Offset.isNonNegative() && "offset can't be negative"); - assert(Ty->isSized() && "must be sized"); - - APInt DerefBytes(Offset.getBitWidth(), 0); - bool CheckForNonNull = false; - if (const Argument *A = dyn_cast<Argument>(BV)) { - DerefBytes = A->getDereferenceableBytes(); - if (!DerefBytes.getBoolValue()) { - DerefBytes = A->getDereferenceableOrNullBytes(); - CheckForNonNull = true; - } - } else if (auto CS = ImmutableCallSite(BV)) { - DerefBytes = CS.getDereferenceableBytes(0); - if (!DerefBytes.getBoolValue()) { - DerefBytes = CS.getDereferenceableOrNullBytes(0); - CheckForNonNull = true; - } - } else if (const LoadInst *LI = dyn_cast<LoadInst>(BV)) { - if (MDNode *MD = LI->getMetadata(LLVMContext::MD_dereferenceable)) { - ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0)); - DerefBytes = CI->getLimitedValue(); - } - if (!DerefBytes.getBoolValue()) { - if (MDNode *MD = - LI->getMetadata(LLVMContext::MD_dereferenceable_or_null)) { - ConstantInt *CI = mdconst::extract<ConstantInt>(MD->getOperand(0)); - DerefBytes = CI->getLimitedValue(); - } - CheckForNonNull = true; - } - } - - if (DerefBytes.getBoolValue()) - if (DerefBytes.uge(Offset + DL.getTypeStoreSize(Ty))) - if (!CheckForNonNull || isKnownNonNullAt(BV, CtxI, DT, TLI)) - return true; - - return false; -} - -static bool isDereferenceableFromAttribute(const Value *V, const DataLayout &DL, - const Instruction *CtxI, - const DominatorTree *DT, - const TargetLibraryInfo *TLI) { - Type *VTy = V->getType(); - Type *Ty = VTy->getPointerElementType(); - if (!Ty->isSized()) - return false; - - APInt Offset(DL.getTypeStoreSizeInBits(VTy), 0); - return isDereferenceableFromAttribute(V, Offset, Ty, DL, CtxI, DT, TLI); -} - -static bool isAligned(const Value *Base, APInt Offset, unsigned Align, - const DataLayout &DL) { - APInt BaseAlign(Offset.getBitWidth(), getAlignment(Base, DL)); - - if (!BaseAlign) { - Type *Ty = Base->getType()->getPointerElementType(); - if (!Ty->isSized()) - return false; - BaseAlign = DL.getABITypeAlignment(Ty); - } - - APInt Alignment(Offset.getBitWidth(), Align); - - assert(Alignment.isPowerOf2() && "must be a power of 2!"); - return BaseAlign.uge(Alignment) && !(Offset & (Alignment-1)); -} - -static bool isAligned(const Value *Base, unsigned Align, const DataLayout &DL) { - Type *Ty = Base->getType(); - assert(Ty->isSized() && "must be sized"); - APInt Offset(DL.getTypeStoreSizeInBits(Ty), 0); - return isAligned(Base, Offset, Align, DL); -} - -/// Test if V is always a pointer to allocated and suitably aligned memory for -/// a simple load or store. -static bool isDereferenceableAndAlignedPointer( - const Value *V, unsigned Align, const DataLayout &DL, - const Instruction *CtxI, const DominatorTree *DT, - const TargetLibraryInfo *TLI, SmallPtrSetImpl<const Value *> &Visited) { - // Note that it is not safe to speculate into a malloc'd region because - // malloc may return null. - - // These are obviously ok if aligned. - if (isa<AllocaInst>(V)) - return isAligned(V, Align, DL); - - // It's not always safe to follow a bitcast, for example: - // bitcast i8* (alloca i8) to i32* - // would result in a 4-byte load from a 1-byte alloca. 
However, - // if we're casting from a pointer from a type of larger size - // to a type of smaller size (or the same size), and the alignment - // is at least as large as for the resulting pointer type, then - // we can look through the bitcast. - if (const BitCastOperator *BC = dyn_cast<BitCastOperator>(V)) { - Type *STy = BC->getSrcTy()->getPointerElementType(), - *DTy = BC->getDestTy()->getPointerElementType(); - if (STy->isSized() && DTy->isSized() && - (DL.getTypeStoreSize(STy) >= DL.getTypeStoreSize(DTy)) && - (DL.getABITypeAlignment(STy) >= DL.getABITypeAlignment(DTy))) - return isDereferenceableAndAlignedPointer(BC->getOperand(0), Align, DL, - CtxI, DT, TLI, Visited); - } - - // Global variables which can't collapse to null are ok. - if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) - if (!GV->hasExternalWeakLinkage()) - return isAligned(V, Align, DL); - - // byval arguments are okay. - if (const Argument *A = dyn_cast<Argument>(V)) - if (A->hasByValAttr()) - return isAligned(V, Align, DL); - - if (isDereferenceableFromAttribute(V, DL, CtxI, DT, TLI)) - return isAligned(V, Align, DL); - - // For GEPs, determine if the indexing lands within the allocated object. - if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { - Type *VTy = GEP->getType(); - Type *Ty = VTy->getPointerElementType(); - const Value *Base = GEP->getPointerOperand(); - - // Conservatively require that the base pointer be fully dereferenceable - // and aligned. - if (!Visited.insert(Base).second) - return false; - if (!isDereferenceableAndAlignedPointer(Base, Align, DL, CtxI, DT, TLI, - Visited)) - return false; - - APInt Offset(DL.getPointerTypeSizeInBits(VTy), 0); - if (!GEP->accumulateConstantOffset(DL, Offset)) - return false; - - // Check if the load is within the bounds of the underlying object - // and offset is aligned. - uint64_t LoadSize = DL.getTypeStoreSize(Ty); - Type *BaseType = Base->getType()->getPointerElementType(); - assert(isPowerOf2_32(Align) && "must be a power of 2!"); - return (Offset + LoadSize).ule(DL.getTypeAllocSize(BaseType)) && - !(Offset & APInt(Offset.getBitWidth(), Align-1)); - } - - // For gc.relocate, look through relocations - if (const GCRelocateInst *RelocateInst = dyn_cast<GCRelocateInst>(V)) - return isDereferenceableAndAlignedPointer( - RelocateInst->getDerivedPtr(), Align, DL, CtxI, DT, TLI, Visited); - - if (const AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(V)) - return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Align, DL, - CtxI, DT, TLI, Visited); - - // If we don't know, assume the worst. - return false; -} - -bool llvm::isDereferenceableAndAlignedPointer(const Value *V, unsigned Align, - const DataLayout &DL, - const Instruction *CtxI, - const DominatorTree *DT, - const TargetLibraryInfo *TLI) { - // When dereferenceability information is provided by a dereferenceable - // attribute, we know exactly how many bytes are dereferenceable. If we can - // determine the exact offset to the attributed variable, we can use that - // information here. 
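// Illustration (not part of this change): the isAligned() helpers removed
// above reduce to two checks -- the base pointer's known alignment must
// cover the requested alignment, and the accumulated byte offset must be a
// multiple of it. Alignments are powers of two, so the divisibility test is
// a bit-mask. A standalone sketch:
#include <cassert>
#include <cstdint>

bool isAligned(uint64_t BaseAlign, uint64_t Offset, uint64_t Align) {
  assert(Align != 0 && (Align & (Align - 1)) == 0 && "must be a power of 2!");
  return BaseAlign >= Align && (Offset & (Align - 1)) == 0;
}
// isAligned(16, 8, 8) is true; isAligned(16, 4, 8) is false, because a
// 4-byte offset from a 16-byte-aligned base is only 4-byte aligned.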
- Type *VTy = V->getType(); - Type *Ty = VTy->getPointerElementType(); - - // Require ABI alignment for loads without alignment specification - if (Align == 0) - Align = DL.getABITypeAlignment(Ty); - - if (Ty->isSized()) { - APInt Offset(DL.getTypeStoreSizeInBits(VTy), 0); - const Value *BV = V->stripAndAccumulateInBoundsConstantOffsets(DL, Offset); - - if (Offset.isNonNegative()) - if (isDereferenceableFromAttribute(BV, Offset, Ty, DL, CtxI, DT, TLI) && - isAligned(BV, Offset, Align, DL)) - return true; - } - - SmallPtrSet<const Value *, 32> Visited; - return ::isDereferenceableAndAlignedPointer(V, Align, DL, CtxI, DT, TLI, - Visited); -} - -bool llvm::isDereferenceablePointer(const Value *V, const DataLayout &DL, - const Instruction *CtxI, - const DominatorTree *DT, - const TargetLibraryInfo *TLI) { - return isDereferenceableAndAlignedPointer(V, 1, DL, CtxI, DT, TLI); -} - bool llvm::isSafeToSpeculativelyExecute(const Value *V, const Instruction *CtxI, - const DominatorTree *DT, - const TargetLibraryInfo *TLI) { + const DominatorTree *DT) { const Operator *Inst = dyn_cast<Operator>(V); if (!Inst) return false; @@ -3383,15 +3107,13 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, const LoadInst *LI = cast<LoadInst>(Inst); if (!LI->isUnordered() || // Speculative load may create a race that did not exist in the source. - LI->getParent()->getParent()->hasFnAttribute( - Attribute::SanitizeThread) || + LI->getFunction()->hasFnAttribute(Attribute::SanitizeThread) || // Speculative load may load data from dirty regions. - LI->getParent()->getParent()->hasFnAttribute( - Attribute::SanitizeAddress)) + LI->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) return false; const DataLayout &DL = LI->getModule()->getDataLayout(); - return isDereferenceableAndAlignedPointer( - LI->getPointerOperand(), LI->getAlignment(), DL, CtxI, DT, TLI); + return isDereferenceableAndAlignedPointer(LI->getPointerOperand(), + LI->getAlignment(), DL, CtxI, DT); } case Instruction::Call: { if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { @@ -3416,17 +3138,29 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, case Intrinsic::umul_with_overflow: case Intrinsic::usub_with_overflow: return true; - // Sqrt should be OK, since the llvm sqrt intrinsic isn't defined to set - // errno like libm sqrt would. + // These intrinsics are defined to have the same behavior as libm + // functions except for setting errno. case Intrinsic::sqrt: case Intrinsic::fma: case Intrinsic::fmuladd: + return true; + // These intrinsics are defined to have the same behavior as libm + // functions, and the corresponding libm functions never set errno. + case Intrinsic::trunc: + case Intrinsic::copysign: case Intrinsic::fabs: case Intrinsic::minnum: case Intrinsic::maxnum: return true; - // TODO: some fp intrinsics are marked as having the same error handling - // as libm. They're safe to speculate when they won't error. + // These intrinsics are defined to have the same behavior as libm + // functions, which never overflow when operating on the IEEE754 types + // that we support, and never set errno otherwise. + case Intrinsic::ceil: + case Intrinsic::floor: + case Intrinsic::nearbyint: + case Intrinsic::rint: + case Intrinsic::round: + return true; // TODO: are convert_{from,to}_fp16 safe? // TODO: can we list target-specific intrinsics here? default: break; @@ -3464,7 +3198,7 @@ bool llvm::mayBeMemoryDependent(const Instruction &I) { } /// Return true if we know that the specified value is never null. 
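// Illustration (not part of this change): the function below derives
// non-nullness from a few structural facts. A standalone model of that
// rule, with hypothetical field names in place of LLVM's predicates:
struct PtrFacts {
  bool IsAlloca = false;               // allocas are never null; malloc may be
  bool IsGlobal = false;               // global variable in address space 0
  bool HasExternalWeakLinkage = false; // extern_weak globals may be null
  bool LoadHasNonNullMD = false;       // load tagged with !nonnull metadata
};

bool knownNonNull(const PtrFacts &P) {
  return P.IsAlloca || P.LoadHasNonNullMD ||
         (P.IsGlobal && !P.HasExternalWeakLinkage);
}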
-bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { +bool llvm::isKnownNonNull(const Value *V) { assert(V->getType()->isPointerTy() && "V must be pointer type"); // Alloca never returns null, malloc might. @@ -3481,7 +3215,7 @@ bool llvm::isKnownNonNull(const Value *V, const TargetLibraryInfo *TLI) { return !GV->hasExternalWeakLinkage() && GV->getType()->getAddressSpace() == 0; - // A Load tagged w/nonnull metadata is never null. + // A Load tagged with nonnull metadata is never null. if (const LoadInst *LI = dyn_cast<LoadInst>(V)) return LI->getMetadata(LLVMContext::MD_nonnull); @@ -3498,41 +3232,31 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, assert(V->getType()->isPointerTy() && "V must be pointer type"); unsigned NumUsesExplored = 0; - for (auto U : V->users()) { + for (auto *U : V->users()) { // Avoid massive lists if (NumUsesExplored >= DomConditionsMaxUses) break; NumUsesExplored++; // Consider only compare instructions uniquely controlling a branch - const ICmpInst *Cmp = dyn_cast<ICmpInst>(U); - if (!Cmp) - continue; - - if (DomConditionsSingleCmpUse && !Cmp->hasOneUse()) + CmpInst::Predicate Pred; + if (!match(const_cast<User *>(U), + m_c_ICmp(Pred, m_Specific(V), m_Zero())) || + (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) continue; - for (auto *CmpU : Cmp->users()) { - const BranchInst *BI = dyn_cast<BranchInst>(CmpU); - if (!BI) - continue; - - assert(BI->isConditional() && "uses a comparison!"); - - BasicBlock *NonNullSuccessor = nullptr; - CmpInst::Predicate Pred; - - if (match(const_cast<ICmpInst*>(Cmp), - m_c_ICmp(Pred, m_Specific(V), m_Zero()))) { - if (Pred == ICmpInst::ICMP_EQ) - NonNullSuccessor = BI->getSuccessor(1); - else if (Pred == ICmpInst::ICMP_NE) - NonNullSuccessor = BI->getSuccessor(0); - } + for (auto *CmpU : U->users()) { + if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { + assert(BI->isConditional() && "uses a comparison!"); - if (NonNullSuccessor) { + BasicBlock *NonNullSuccessor = + BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) return true; + } else if (Pred == ICmpInst::ICMP_NE && + match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && + DT->dominates(cast<Instruction>(CmpU), CtxI)) { + return true; } } } @@ -3541,8 +3265,8 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, } bool llvm::isKnownNonNullAt(const Value *V, const Instruction *CtxI, - const DominatorTree *DT, const TargetLibraryInfo *TLI) { - if (isKnownNonNull(V, TLI)) + const DominatorTree *DT) { + if (isKnownNonNull(V)) return true; return CtxI ? 
             ::isKnownNonNullFromDominatingCondition(V, CtxI, DT) : false;
@@ -3671,6 +3395,67 @@ static OverflowResult computeOverflowForSignedAdd(
   return OverflowResult::MayOverflow;
 }
 
+bool llvm::isOverflowIntrinsicNoWrap(IntrinsicInst *II, DominatorTree &DT) {
+#ifndef NDEBUG
+  auto IID = II->getIntrinsicID();
+  assert((IID == Intrinsic::sadd_with_overflow ||
+          IID == Intrinsic::uadd_with_overflow ||
+          IID == Intrinsic::ssub_with_overflow ||
+          IID == Intrinsic::usub_with_overflow ||
+          IID == Intrinsic::smul_with_overflow ||
+          IID == Intrinsic::umul_with_overflow) &&
+         "Not an overflow intrinsic!");
+#endif
+
+  SmallVector<BranchInst *, 2> GuardingBranches;
+  SmallVector<ExtractValueInst *, 2> Results;
+
+  for (User *U : II->users()) {
+    if (auto *EVI = dyn_cast<ExtractValueInst>(U)) {
+      assert(EVI->getNumIndices() == 1 && "Obvious from CI's type");
+
+      if (EVI->getIndices()[0] == 0)
+        Results.push_back(EVI);
+      else {
+        assert(EVI->getIndices()[0] == 1 && "Obvious from CI's type");
+
+        for (auto *U : EVI->users())
+          if (auto *B = dyn_cast<BranchInst>(U)) {
+            assert(B->isConditional() && "How else is it using an i1?");
+            GuardingBranches.push_back(B);
+          }
+      }
+    } else {
+      // We are using the aggregate directly in a way we don't want to analyze
+      // here (storing it to a global, say).
+      return false;
+    }
+  }
+
+  auto AllUsesGuardedByBranch = [&](BranchInst *BI) {
+    BasicBlockEdge NoWrapEdge(BI->getParent(), BI->getSuccessor(1));
+    if (!NoWrapEdge.isSingleEdge())
+      return false;
+
+    // Check if all users of the add are provably no-wrap.
+    for (auto *Result : Results) {
+      // If the extractvalue itself is not executed on overflow, then we don't
+      // need to check each use separately, since domination is transitive.
+      if (DT.dominates(NoWrapEdge, Result->getParent()))
+        continue;
+
+      for (auto &RU : Result->uses())
+        if (!DT.dominates(NoWrapEdge, RU))
+          return false;
+    }
+
+    return true;
+  };
+
+  return any_of(GuardingBranches, AllUsesGuardedByBranch);
+}
+
+
 OverflowResult llvm::computeOverflowForSignedAdd(AddOperator *Add,
                                                  const DataLayout &DL,
                                                  AssumptionCache *AC,
@@ -3689,16 +3474,46 @@ OverflowResult llvm::computeOverflowForSignedAdd(Value *LHS, Value *RHS,
 }
 
 bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) {
-  // FIXME: This conservative implementation can be relaxed. E.g. most
-  // atomic operations are guaranteed to terminate on most platforms
-  // and most functions terminate.
-
-  return !I->isAtomic() &&       // atomics may never succeed on some platforms
-         !isa<CallInst>(I) &&    // could throw and might not terminate
-         !isa<InvokeInst>(I) &&  // might not terminate and could throw to
-                                 // non-successor (see bug 24185 for details).
-         !isa<ResumeInst>(I) &&  // has no successors
-         !isa<ReturnInst>(I);    // has no successors
+  // A memory operation returns normally if it isn't volatile. A volatile
+  // operation is allowed to trap.
+  //
+  // An atomic operation isn't guaranteed to return in a reasonable amount of
+  // time because it's possible for another thread to interfere with it for an
+  // arbitrary length of time, but programs aren't allowed to rely on that.
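// Illustration (not part of this change): isOverflowIntrinsicNoWrap, added
// above, recognizes guards of the shape produced by code like the C++ below,
// where the overflow bit of llvm.sadd.with.overflow feeds a branch and the
// sum is only consumed on the no-overflow edge. On that edge the add is
// known not to wrap, so it can be treated as nsw.
bool checkedAdd(int A, int B, int &Out) {
  int Sum;
  if (__builtin_sadd_overflow(A, B, &Sum)) // lowers to sadd.with.overflow
    return false; // overflow edge: Sum is never used here
  Out = Sum;      // no-wrap edge: every use is dominated by the guard
  return true;
}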
+ if (const LoadInst *LI = dyn_cast<LoadInst>(I)) + return !LI->isVolatile(); + if (const StoreInst *SI = dyn_cast<StoreInst>(I)) + return !SI->isVolatile(); + if (const AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(I)) + return !CXI->isVolatile(); + if (const AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(I)) + return !RMWI->isVolatile(); + if (const MemIntrinsic *MII = dyn_cast<MemIntrinsic>(I)) + return !MII->isVolatile(); + + // If there is no successor, then execution can't transfer to it. + if (const auto *CRI = dyn_cast<CleanupReturnInst>(I)) + return !CRI->unwindsToCaller(); + if (const auto *CatchSwitch = dyn_cast<CatchSwitchInst>(I)) + return !CatchSwitch->unwindsToCaller(); + if (isa<ResumeInst>(I)) + return false; + if (isa<ReturnInst>(I)) + return false; + + // Calls can throw, or contain an infinite loop, or kill the process. + if (CallSite CS = CallSite(const_cast<Instruction*>(I))) { + // Calls which don't write to arbitrary memory are safe. + // FIXME: Ignoring infinite loops without any side-effects is too aggressive, + // but it's consistent with other passes. See http://llvm.org/PR965 . + // FIXME: This isn't aggressive enough; a call which only writes to a + // global is guaranteed to return. + return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory() || + match(I, m_Intrinsic<Intrinsic::assume>()); + } + + // Other instructions return normally. + return true; } bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, @@ -3775,6 +3590,11 @@ bool llvm::propagatesFullPoison(const Instruction *I) { return false; } + case Instruction::ICmp: + // Comparing poison with any value yields poison. This is why, for + // instance, x s< (x +nsw 1) can be folded to true. + return true; + case Instruction::GetElementPtr: // A GEP implicitly represents a sequence of additions, subtractions, // truncations, sign extensions and multiplications. The multiplications @@ -3827,26 +3647,44 @@ bool llvm::isKnownNotFullPoison(const Instruction *PoisonI) { // Set of instructions that we have proved will yield poison if PoisonI // does. SmallSet<const Value *, 16> YieldsPoison; + SmallSet<const BasicBlock *, 4> Visited; YieldsPoison.insert(PoisonI); + Visited.insert(PoisonI->getParent()); - for (BasicBlock::const_iterator I = PoisonI->getIterator(), E = BB->end(); - I != E; ++I) { - if (&*I != PoisonI) { - const Value *NotPoison = getGuaranteedNonFullPoisonOp(&*I); - if (NotPoison != nullptr && YieldsPoison.count(NotPoison)) return true; - if (!isGuaranteedToTransferExecutionToSuccessor(&*I)) - return false; + BasicBlock::const_iterator Begin = PoisonI->getIterator(), End = BB->end(); + + unsigned Iter = 0; + while (Iter++ < MaxDepth) { + for (auto &I : make_range(Begin, End)) { + if (&I != PoisonI) { + const Value *NotPoison = getGuaranteedNonFullPoisonOp(&I); + if (NotPoison != nullptr && YieldsPoison.count(NotPoison)) + return true; + if (!isGuaranteedToTransferExecutionToSuccessor(&I)) + return false; + } + + // Mark poison that propagates from I through uses of I. + if (YieldsPoison.count(&I)) { + for (const User *User : I.users()) { + const Instruction *UserI = cast<Instruction>(User); + if (propagatesFullPoison(UserI)) + YieldsPoison.insert(User); + } + } } - // Mark poison that propagates from I through uses of I. 
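// Illustration (not part of this change): the new ICmp case in
// propagatesFullPoison is what licenses folds like "x s< (x +nsw 1)" -> true.
// If the nsw add overflows it produces poison, comparing poison yields
// poison, and a poison branch lets the optimizer pick either outcome. The
// same reasoning is visible at the C++ level, where signed overflow is
// undefined:
bool alwaysTrue(int X) {
  return X < X + 1; // clang -O2 folds this to 'return true'
}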
- if (YieldsPoison.count(&*I)) { - for (const User *User : I->users()) { - const Instruction *UserI = cast<Instruction>(User); - if (UserI->getParent() == BB && propagatesFullPoison(UserI)) - YieldsPoison.insert(User); + if (auto *NextBB = BB->getSingleSuccessor()) { + if (Visited.insert(NextBB).second) { + BB = NextBB; + Begin = BB->getFirstNonPHI()->getIterator(); + End = BB->end(); + continue; } } - } + + break; + }; return false; } @@ -3979,10 +3817,11 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, return {(CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false}; } } - + // Y >s C ? ~Y : ~C == ~Y <s ~C ? ~Y : ~C = SMIN(~Y, ~C) if (const auto *C2 = dyn_cast<ConstantInt>(FalseVal)) { - if (C1->getType() == C2->getType() && ~C1->getValue() == C2->getValue() && + if (Pred == ICmpInst::ICMP_SGT && C1->getType() == C2->getType() && + ~C1->getValue() == C2->getValue() && (match(TrueVal, m_Not(m_Specific(CmpLHS))) || match(CmpLHS, m_Not(m_Specific(TrueVal))))) { LHS = TrueVal; @@ -4001,12 +3840,11 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, Instruction::CastOps *CastOp) { CastInst *CI = dyn_cast<CastInst>(V1); Constant *C = dyn_cast<Constant>(V2); - CastInst *CI2 = dyn_cast<CastInst>(V2); if (!CI) return nullptr; *CastOp = CI->getOpcode(); - if (CI2) { + if (auto *CI2 = dyn_cast<CastInst>(V2)) { // If V1 and V2 are both the same cast from the same type, we can look // through V1. if (CI2->getOpcode() == CI->getOpcode() && @@ -4017,43 +3855,48 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, return nullptr; } - if (isa<SExtInst>(CI) && CmpI->isSigned()) { - Constant *T = ConstantExpr::getTrunc(C, CI->getSrcTy()); - // This is only valid if the truncated value can be sign-extended - // back to the original value. - if (ConstantExpr::getSExt(T, C->getType()) == C) - return T; - return nullptr; - } + Constant *CastedTo = nullptr; + if (isa<ZExtInst>(CI) && CmpI->isUnsigned()) - return ConstantExpr::getTrunc(C, CI->getSrcTy()); + CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy()); + + if (isa<SExtInst>(CI) && CmpI->isSigned()) + CastedTo = ConstantExpr::getTrunc(C, CI->getSrcTy(), true); if (isa<TruncInst>(CI)) - return ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned()); + CastedTo = ConstantExpr::getIntegerCast(C, CI->getSrcTy(), CmpI->isSigned()); + + if (isa<FPTruncInst>(CI)) + CastedTo = ConstantExpr::getFPExtend(C, CI->getSrcTy(), true); + + if (isa<FPExtInst>(CI)) + CastedTo = ConstantExpr::getFPTrunc(C, CI->getSrcTy(), true); if (isa<FPToUIInst>(CI)) - return ConstantExpr::getUIToFP(C, CI->getSrcTy(), true); + CastedTo = ConstantExpr::getUIToFP(C, CI->getSrcTy(), true); if (isa<FPToSIInst>(CI)) - return ConstantExpr::getSIToFP(C, CI->getSrcTy(), true); + CastedTo = ConstantExpr::getSIToFP(C, CI->getSrcTy(), true); if (isa<UIToFPInst>(CI)) - return ConstantExpr::getFPToUI(C, CI->getSrcTy(), true); + CastedTo = ConstantExpr::getFPToUI(C, CI->getSrcTy(), true); if (isa<SIToFPInst>(CI)) - return ConstantExpr::getFPToSI(C, CI->getSrcTy(), true); + CastedTo = ConstantExpr::getFPToSI(C, CI->getSrcTy(), true); - if (isa<FPTruncInst>(CI)) - return ConstantExpr::getFPExtend(C, CI->getSrcTy(), true); + if (!CastedTo) + return nullptr; - if (isa<FPExtInst>(CI)) - return ConstantExpr::getFPTrunc(C, CI->getSrcTy(), true); + Constant *CastedBack = + ConstantExpr::getCast(CI->getOpcode(), CastedTo, C->getType(), true); + // Make sure the cast doesn't lose any information. 
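// Illustration (not part of this change): the round-trip check below is the
// usual way to verify that a constant survives a narrowing cast. For a
// 64-to-32-bit truncation, in C++ terms:
//
//   int64_t C = ...;
//   int32_t T = static_cast<int32_t>(C);           // CastedTo
//   bool Lossless = static_cast<int64_t>(T) == C;  // CastedBack == C
//
// C = 5 round-trips; C = 0x100000001 truncates to 1 and fails, so
// lookThroughCast refuses to match a select pattern through that cast.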
+ if (CastedBack != C) + return nullptr; - return nullptr; + return CastedTo; } -SelectPatternResult llvm::matchSelectPattern(Value *V, - Value *&LHS, Value *&RHS, +SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, Instruction::CastOps *CastOp) { SelectInst *SI = dyn_cast<SelectInst>(V); if (!SI) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4172,46 +4015,105 @@ static bool isTruePredicate(CmpInst::Predicate Pred, Value *LHS, Value *RHS, } /// Return true if "icmp Pred BLHS BRHS" is true whenever "icmp Pred -/// ALHS ARHS" is true. -static bool isImpliedCondOperands(CmpInst::Predicate Pred, Value *ALHS, - Value *ARHS, Value *BLHS, Value *BRHS, - const DataLayout &DL, unsigned Depth, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { +/// ALHS ARHS" is true. Otherwise, return None. +static Optional<bool> +isImpliedCondOperands(CmpInst::Predicate Pred, Value *ALHS, Value *ARHS, + Value *BLHS, Value *BRHS, const DataLayout &DL, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT) { switch (Pred) { default: - return false; + return None; case CmpInst::ICMP_SLT: case CmpInst::ICMP_SLE: - return isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth, AC, CxtI, - DT) && - isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth, AC, CxtI, - DT); + if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth, AC, CxtI, + DT) && + isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth, AC, CxtI, DT)) + return true; + return None; case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: - return isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth, AC, CxtI, - DT) && - isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth, AC, CxtI, - DT); + if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth, AC, CxtI, + DT) && + isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth, AC, CxtI, DT)) + return true; + return None; } } -bool llvm::isImpliedCondition(Value *LHS, Value *RHS, const DataLayout &DL, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { - assert(LHS->getType() == RHS->getType() && "mismatched type"); +/// Return true if the operands of the two compares match. IsSwappedOps is true +/// when the operands match, but are swapped. +static bool isMatchingOps(Value *ALHS, Value *ARHS, Value *BLHS, Value *BRHS, + bool &IsSwappedOps) { + + bool IsMatchingOps = (ALHS == BLHS && ARHS == BRHS); + IsSwappedOps = (ALHS == BRHS && ARHS == BLHS); + return IsMatchingOps || IsSwappedOps; +} + +/// Return true if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS BRHS" is +/// true. Return false if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS +/// BRHS" is false. Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, + Value *ALHS, Value *ARHS, + CmpInst::Predicate BPred, + Value *BLHS, Value *BRHS, + bool IsSwappedOps) { + // Canonicalize the operands so they're matching. + if (IsSwappedOps) { + std::swap(BLHS, BRHS); + BPred = ICmpInst::getSwappedPredicate(BPred); + } + if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred)) + return true; + if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred)) + return false; + + return None; +} + +/// Return true if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS C2" is +/// true. Return false if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS +/// C2" is false. Otherwise, return None if we can't infer anything. 
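// Illustration (not part of this change): a standalone model of the
// constant-range test implemented below, using half-open [Lo, Hi) unsigned
// ranges in place of LLVM's ConstantRange. Names are hypothetical; wrapped
// ranges and empty inputs are ignored for brevity.
#include <optional>

struct Range { unsigned Lo, Hi; }; // values v with Lo <= v < Hi, Lo < Hi

// Does "x in Dom" force "x in CR" (true), exclude it (false), or neither?
std::optional<bool> impliedByRange(Range Dom, Range CR) {
  bool Intersects = Dom.Lo < CR.Hi && CR.Lo < Dom.Hi;
  bool DomInsideCR = CR.Lo <= Dom.Lo && Dom.Hi <= CR.Hi;
  if (!Intersects)
    return false;        // e.g. "x < 5" makes "x >= 10" false
  if (DomInsideCR)
    return true;         // e.g. "x < 5" makes "x < 10" true
  return std::nullopt;   // partial overlap: nothing can be inferred
}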
+static Optional<bool> +isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, Value *ALHS, + ConstantInt *C1, CmpInst::Predicate BPred, + Value *BLHS, ConstantInt *C2) { + assert(ALHS == BLHS && "LHS operands must match."); + ConstantRange DomCR = + ConstantRange::makeExactICmpRegion(APred, C1->getValue()); + ConstantRange CR = + ConstantRange::makeAllowedICmpRegion(BPred, C2->getValue()); + ConstantRange Intersection = DomCR.intersectWith(CR); + ConstantRange Difference = DomCR.difference(CR); + if (Intersection.isEmptySet()) + return false; + if (Difference.isEmptySet()) + return true; + return None; +} + +Optional<bool> llvm::isImpliedCondition(Value *LHS, Value *RHS, + const DataLayout &DL, bool InvertAPred, + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + // A mismatch occurs when we compare a scalar cmp to a vector cmp, for example. + if (LHS->getType() != RHS->getType()) + return None; + Type *OpTy = LHS->getType(); assert(OpTy->getScalarType()->isIntegerTy(1)); // LHS ==> RHS by definition - if (LHS == RHS) return true; + if (!InvertAPred && LHS == RHS) + return true; if (OpTy->isVectorTy()) // TODO: extending the code below to handle vectors - return false; + return None; assert(OpTy->isIntegerTy(1) && "implied by above"); ICmpInst::Predicate APred, BPred; @@ -4220,11 +4122,37 @@ bool llvm::isImpliedCondition(Value *LHS, Value *RHS, const DataLayout &DL, if (!match(LHS, m_ICmp(APred, m_Value(ALHS), m_Value(ARHS))) || !match(RHS, m_ICmp(BPred, m_Value(BLHS), m_Value(BRHS)))) - return false; + return None; + + if (InvertAPred) + APred = CmpInst::getInversePredicate(APred); + + // Can we infer anything when the two compares have matching operands? + bool IsSwappedOps; + if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, IsSwappedOps)) { + if (Optional<bool> Implication = isImpliedCondMatchingOperands( + APred, ALHS, ARHS, BPred, BLHS, BRHS, IsSwappedOps)) + return Implication; + // No amount of additional analysis will infer the second condition, so + // early exit. + return None; + } + + // Can we infer anything when the LHS operands match and the RHS operands are + // constants (not necessarily matching)? + if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) { + if (Optional<bool> Implication = isImpliedCondMatchingImmOperands( + APred, ALHS, cast<ConstantInt>(ARHS), BPred, BLHS, + cast<ConstantInt>(BRHS))) + return Implication; + // No amount of additional analysis will infer the second condition, so + // early exit. 
+    return None;
+  }
 
   if (APred == BPred)
     return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth, AC,
                                  CxtI, DT);
 
-  return false;
+  return None;
 }
diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp
index 4b244ec5e1f6..53e7153a350f 100644
--- a/lib/Analysis/VectorUtils.cpp
+++ b/lib/Analysis/VectorUtils.cpp
@@ -17,6 +17,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/ScalarEvolution.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/PatternMatch.h"
@@ -51,6 +52,7 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) {
   case Intrinsic::nearbyint:
   case Intrinsic::round:
   case Intrinsic::bswap:
+  case Intrinsic::bitreverse:
   case Intrinsic::ctpop:
   case Intrinsic::pow:
   case Intrinsic::fma:
@@ -78,150 +80,18 @@ bool llvm::hasVectorInstrinsicScalarOpd(Intrinsic::ID ID,
   }
 }
 
-/// \brief Check call has a unary float signature
-/// It checks following:
-/// a) call should have a single argument
-/// b) argument type should be floating point type
-/// c) call instruction type and argument type should be same
-/// d) call should only reads memory.
-/// If all these condition is met then return ValidIntrinsicID
-/// else return not_intrinsic.
-Intrinsic::ID
-llvm::checkUnaryFloatSignature(const CallInst &I,
-                               Intrinsic::ID ValidIntrinsicID) {
-  if (I.getNumArgOperands() != 1 ||
-      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
-      I.getType() != I.getArgOperand(0)->getType() || !I.onlyReadsMemory())
-    return Intrinsic::not_intrinsic;
-
-  return ValidIntrinsicID;
-}
-
-/// \brief Check call has a binary float signature
-/// It checks following:
-/// a) call should have 2 arguments.
-/// b) arguments type should be floating point type
-/// c) call instruction type and arguments type should be same
-/// d) call should only reads memory.
-/// If all these condition is met then return ValidIntrinsicID
-/// else return not_intrinsic.
-Intrinsic::ID
-llvm::checkBinaryFloatSignature(const CallInst &I,
-                                Intrinsic::ID ValidIntrinsicID) {
-  if (I.getNumArgOperands() != 2 ||
-      !I.getArgOperand(0)->getType()->isFloatingPointTy() ||
-      !I.getArgOperand(1)->getType()->isFloatingPointTy() ||
-      I.getType() != I.getArgOperand(0)->getType() ||
-      I.getType() != I.getArgOperand(1)->getType() || !I.onlyReadsMemory())
-    return Intrinsic::not_intrinsic;
-
-  return ValidIntrinsicID;
-}
-
 /// \brief Returns intrinsic ID for call.
 /// For the input call instruction it finds the mapping intrinsic and returns
 /// its ID; if no mapping is found, it returns not_intrinsic.
-Intrinsic::ID llvm::getIntrinsicIDForCall(CallInst *CI,
-                                          const TargetLibraryInfo *TLI) {
-  // If we have an intrinsic call, check if it is trivially vectorizable.
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { - Intrinsic::ID ID = II->getIntrinsicID(); - if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || - ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) - return ID; +Intrinsic::ID llvm::getVectorIntrinsicIDForCall(const CallInst *CI, + const TargetLibraryInfo *TLI) { + Intrinsic::ID ID = getIntrinsicForCallSite(CI, TLI); + if (ID == Intrinsic::not_intrinsic) return Intrinsic::not_intrinsic; - } - - if (!TLI) - return Intrinsic::not_intrinsic; - - LibFunc::Func Func; - Function *F = CI->getCalledFunction(); - // We're going to make assumptions on the semantics of the functions, check - // that the target knows that it's available in this environment and it does - // not have local linkage. - if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(F->getName(), Func)) - return Intrinsic::not_intrinsic; - - // Otherwise check if we have a call to a function that can be turned into a - // vector intrinsic. - switch (Func) { - default: - break; - case LibFunc::sin: - case LibFunc::sinf: - case LibFunc::sinl: - return checkUnaryFloatSignature(*CI, Intrinsic::sin); - case LibFunc::cos: - case LibFunc::cosf: - case LibFunc::cosl: - return checkUnaryFloatSignature(*CI, Intrinsic::cos); - case LibFunc::exp: - case LibFunc::expf: - case LibFunc::expl: - return checkUnaryFloatSignature(*CI, Intrinsic::exp); - case LibFunc::exp2: - case LibFunc::exp2f: - case LibFunc::exp2l: - return checkUnaryFloatSignature(*CI, Intrinsic::exp2); - case LibFunc::log: - case LibFunc::logf: - case LibFunc::logl: - return checkUnaryFloatSignature(*CI, Intrinsic::log); - case LibFunc::log10: - case LibFunc::log10f: - case LibFunc::log10l: - return checkUnaryFloatSignature(*CI, Intrinsic::log10); - case LibFunc::log2: - case LibFunc::log2f: - case LibFunc::log2l: - return checkUnaryFloatSignature(*CI, Intrinsic::log2); - case LibFunc::fabs: - case LibFunc::fabsf: - case LibFunc::fabsl: - return checkUnaryFloatSignature(*CI, Intrinsic::fabs); - case LibFunc::fmin: - case LibFunc::fminf: - case LibFunc::fminl: - return checkBinaryFloatSignature(*CI, Intrinsic::minnum); - case LibFunc::fmax: - case LibFunc::fmaxf: - case LibFunc::fmaxl: - return checkBinaryFloatSignature(*CI, Intrinsic::maxnum); - case LibFunc::copysign: - case LibFunc::copysignf: - case LibFunc::copysignl: - return checkBinaryFloatSignature(*CI, Intrinsic::copysign); - case LibFunc::floor: - case LibFunc::floorf: - case LibFunc::floorl: - return checkUnaryFloatSignature(*CI, Intrinsic::floor); - case LibFunc::ceil: - case LibFunc::ceilf: - case LibFunc::ceill: - return checkUnaryFloatSignature(*CI, Intrinsic::ceil); - case LibFunc::trunc: - case LibFunc::truncf: - case LibFunc::truncl: - return checkUnaryFloatSignature(*CI, Intrinsic::trunc); - case LibFunc::rint: - case LibFunc::rintf: - case LibFunc::rintl: - return checkUnaryFloatSignature(*CI, Intrinsic::rint); - case LibFunc::nearbyint: - case LibFunc::nearbyintf: - case LibFunc::nearbyintl: - return checkUnaryFloatSignature(*CI, Intrinsic::nearbyint); - case LibFunc::round: - case LibFunc::roundf: - case LibFunc::roundl: - return checkUnaryFloatSignature(*CI, Intrinsic::round); - case LibFunc::pow: - case LibFunc::powf: - case LibFunc::powl: - return checkBinaryFloatSignature(*CI, Intrinsic::pow); - } + if (isTriviallyVectorizable(ID) || ID == Intrinsic::lifetime_start || + ID == Intrinsic::lifetime_end || ID == Intrinsic::assume) + return ID; return Intrinsic::not_intrinsic; } @@ -231,8 +101,7 @@ Intrinsic::ID 
llvm::getIntrinsicIDForCall(CallInst *CI, unsigned llvm::getGEPInductionOperand(const GetElementPtrInst *Gep) { const DataLayout &DL = Gep->getModule()->getDataLayout(); unsigned LastOperand = Gep->getNumOperands() - 1; - unsigned GEPAllocSize = DL.getTypeAllocSize( - cast<PointerType>(Gep->getType()->getScalarType())->getElementType()); + unsigned GEPAllocSize = DL.getTypeAllocSize(Gep->getResultElementType()); // Walk backwards and try to peel off zeros. while (LastOperand > 1 && match(Gep->getOperand(LastOperand), m_Zero())) { @@ -318,8 +187,6 @@ Value *llvm::getStrideFromPointer(Value *Ptr, ScalarEvolution *SE, Loop *Lp) { // Strip off the size of access multiplication if we are still analyzing the // pointer. if (OrigPtr == Ptr) { - const DataLayout &DL = Lp->getHeader()->getModule()->getDataLayout(); - DL.getTypeAllocSize(PtrTy->getElementType()); if (const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(V)) { if (M->getOperand(0)->getSCEVType() != scConstant) return nullptr; @@ -502,6 +369,7 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, uint64_t V = DB.getDemandedBits(I).getZExtValue(); DBits[Leader] |= V; + DBits[I] = V; // Casts, loads and instructions outside of our range terminate a chain // successfully. @@ -552,6 +420,20 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, // Round up to a power of 2 if (!isPowerOf2_64((uint64_t)MinBW)) MinBW = NextPowerOf2(MinBW); + + // We don't modify the types of PHIs. Reductions will already have been + // truncated if possible, and inductions' sizes will have been chosen by + // indvars. + // If we are required to shrink a PHI, abandon this entire equivalence class. + bool Abort = false; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) + if (isa<PHINode>(*MI) && MinBW < (*MI)->getType()->getScalarSizeInBits()) { + Abort = true; + break; + } + if (Abort) + continue; + for (auto MI = ECs.member_begin(I), ME = ECs.member_end(); MI != ME; ++MI) { if (!isa<Instruction>(*MI)) continue; @@ -565,3 +447,44 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, return MinBWs; } + +/// \returns \p I after propagating metadata from \p VL. +Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { + Instruction *I0 = cast<Instruction>(VL[0]); + SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; + I0->getAllMetadataOtherThanDebugLoc(Metadata); + + for (auto Kind : { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_fpmath, + LLVMContext::MD_nontemporal }) { + MDNode *MD = I0->getMetadata(Kind); + + for (int J = 1, E = VL.size(); MD && J != E; ++J) { + const Instruction *IJ = cast<Instruction>(VL[J]); + MDNode *IMD = IJ->getMetadata(Kind); + switch (Kind) { + case LLVMContext::MD_tbaa: + MD = MDNode::getMostGenericTBAA(MD, IMD); + break; + case LLVMContext::MD_alias_scope: + MD = MDNode::getMostGenericAliasScope(MD, IMD); + break; + case LLVMContext::MD_noalias: + MD = MDNode::intersect(MD, IMD); + break; + case LLVMContext::MD_fpmath: + MD = MDNode::getMostGenericFPMath(MD, IMD); + break; + case LLVMContext::MD_nontemporal: + MD = MDNode::intersect(MD, IMD); + break; + default: + llvm_unreachable("unhandled metadata"); + } + } + + Inst->setMetadata(Kind, MD); + } + + return Inst; +} |