Diffstat (limited to 'lib/Analysis')
67 files changed, 7708 insertions, 1995 deletions
diff --git a/lib/Analysis/AliasAnalysis.cpp b/lib/Analysis/AliasAnalysis.cpp index a6585df949f8..3446aef39938 100644 --- a/lib/Analysis/AliasAnalysis.cpp +++ b/lib/Analysis/AliasAnalysis.cpp @@ -40,7 +40,6 @@ #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" @@ -118,11 +117,11 @@ bool AAResults::pointsToConstantMemory(const MemoryLocation &Loc, return false; } -ModRefInfo AAResults::getArgModRefInfo(ImmutableCallSite CS, unsigned ArgIdx) { +ModRefInfo AAResults::getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) { ModRefInfo Result = ModRefInfo::ModRef; for (const auto &AA : AAs) { - Result = intersectModRef(Result, AA->getArgModRefInfo(CS, ArgIdx)); + Result = intersectModRef(Result, AA->getArgModRefInfo(Call, ArgIdx)); // Early-exit the moment we reach the bottom of the lattice. if (isNoModRef(Result)) @@ -132,11 +131,11 @@ ModRefInfo AAResults::getArgModRefInfo(ImmutableCallSite CS, unsigned ArgIdx) { return Result; } -ModRefInfo AAResults::getModRefInfo(Instruction *I, ImmutableCallSite Call) { +ModRefInfo AAResults::getModRefInfo(Instruction *I, const CallBase *Call2) { // We may have two calls. - if (auto CS = ImmutableCallSite(I)) { + if (const auto *Call1 = dyn_cast<CallBase>(I)) { // Check if the two calls modify the same memory. - return getModRefInfo(CS, Call); + return getModRefInfo(Call1, Call2); } else if (I->isFenceLike()) { // If this is a fence, just return ModRef. return ModRefInfo::ModRef; @@ -146,19 +145,19 @@ ModRefInfo AAResults::getModRefInfo(Instruction *I, ImmutableCallSite Call) { // is that if the call references what this instruction // defines, it must be clobbered by this location. const MemoryLocation DefLoc = MemoryLocation::get(I); - ModRefInfo MR = getModRefInfo(Call, DefLoc); + ModRefInfo MR = getModRefInfo(Call2, DefLoc); if (isModOrRefSet(MR)) return setModAndRef(MR); } return ModRefInfo::NoModRef; } -ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, +ModRefInfo AAResults::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { ModRefInfo Result = ModRefInfo::ModRef; for (const auto &AA : AAs) { - Result = intersectModRef(Result, AA->getModRefInfo(CS, Loc)); + Result = intersectModRef(Result, AA->getModRefInfo(Call, Loc)); // Early-exit the moment we reach the bottom of the lattice. if (isNoModRef(Result)) @@ -167,7 +166,7 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, // Try to refine the mod-ref info further using other API entry points to the // aggregate set of AA results. 
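These AliasAnalysis.cpp hunks replace ImmutableCallSite with const CallBase * throughout AAResults. A minimal caller-side sketch of the new query shape, assuming an AAResults &AA is in scope; the helper name is illustrative, not part of the patch:

#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/IR/InstrTypes.h" // llvm::CallBase
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Call-ness is now tested with a plain dyn_cast instead of wrapping the
// instruction in an ImmutableCallSite.
static ModRefInfo classifyCallAccess(AAResults &AA, const Instruction *I,
                                     const MemoryLocation &Loc) {
  if (const auto *Call = dyn_cast<CallBase>(I))
    return AA.getModRefInfo(Call, Loc); // CallInst and InvokeInst both derive
                                        // from CallBase, so both are covered
  return ModRefInfo::ModRef;            // conservative default for non-calls
}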
- auto MRB = getModRefBehavior(CS); + auto MRB = getModRefBehavior(Call); if (MRB == FMRB_DoesNotAccessMemory || MRB == FMRB_OnlyAccessesInaccessibleMem) return ModRefInfo::NoModRef; @@ -178,20 +177,19 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, Result = clearRef(Result); if (onlyAccessesArgPointees(MRB) || onlyAccessesInaccessibleOrArgMem(MRB)) { - bool DoesAlias = false; bool IsMustAlias = true; ModRefInfo AllArgsMask = ModRefInfo::NoModRef; if (doesAccessArgPointees(MRB)) { - for (auto AI = CS.arg_begin(), AE = CS.arg_end(); AI != AE; ++AI) { + for (auto AI = Call->arg_begin(), AE = Call->arg_end(); AI != AE; ++AI) { const Value *Arg = *AI; if (!Arg->getType()->isPointerTy()) continue; - unsigned ArgIdx = std::distance(CS.arg_begin(), AI); - MemoryLocation ArgLoc = MemoryLocation::getForArgument(CS, ArgIdx, TLI); + unsigned ArgIdx = std::distance(Call->arg_begin(), AI); + MemoryLocation ArgLoc = + MemoryLocation::getForArgument(Call, ArgIdx, TLI); AliasResult ArgAlias = alias(ArgLoc, Loc); if (ArgAlias != NoAlias) { - ModRefInfo ArgMask = getArgModRefInfo(CS, ArgIdx); - DoesAlias = true; + ModRefInfo ArgMask = getArgModRefInfo(Call, ArgIdx); AllArgsMask = unionModRef(AllArgsMask, ArgMask); } // Conservatively clear IsMustAlias unless only MustAlias is found. @@ -199,7 +197,7 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, } } // Return NoModRef if no alias found with any argument. - if (!DoesAlias) + if (isNoModRef(AllArgsMask)) return ModRefInfo::NoModRef; // Logical & between other AA analyses and argument analysis. Result = intersectModRef(Result, AllArgsMask); @@ -215,12 +213,12 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS, return Result; } -ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) { +ModRefInfo AAResults::getModRefInfo(const CallBase *Call1, + const CallBase *Call2) { ModRefInfo Result = ModRefInfo::ModRef; for (const auto &AA : AAs) { - Result = intersectModRef(Result, AA->getModRefInfo(CS1, CS2)); + Result = intersectModRef(Result, AA->getModRefInfo(Call1, Call2)); // Early-exit the moment we reach the bottom of the lattice. if (isNoModRef(Result)) @@ -230,59 +228,61 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, // Try to refine the mod-ref info further using other API entry points to the // aggregate set of AA results. - // If CS1 or CS2 are readnone, they don't interact. - auto CS1B = getModRefBehavior(CS1); - if (CS1B == FMRB_DoesNotAccessMemory) + // If Call1 or Call2 are readnone, they don't interact. + auto Call1B = getModRefBehavior(Call1); + if (Call1B == FMRB_DoesNotAccessMemory) return ModRefInfo::NoModRef; - auto CS2B = getModRefBehavior(CS2); - if (CS2B == FMRB_DoesNotAccessMemory) + auto Call2B = getModRefBehavior(Call2); + if (Call2B == FMRB_DoesNotAccessMemory) return ModRefInfo::NoModRef; // If they both only read from memory, there is no dependence. - if (onlyReadsMemory(CS1B) && onlyReadsMemory(CS2B)) + if (onlyReadsMemory(Call1B) && onlyReadsMemory(Call2B)) return ModRefInfo::NoModRef; - // If CS1 only reads memory, the only dependence on CS2 can be - // from CS1 reading memory written by CS2. - if (onlyReadsMemory(CS1B)) + // If Call1 only reads memory, the only dependence on Call2 can be + // from Call1 reading memory written by Call2. 
+ if (onlyReadsMemory(Call1B)) Result = clearMod(Result); - else if (doesNotReadMemory(CS1B)) + else if (doesNotReadMemory(Call1B)) Result = clearRef(Result); - // If CS2 only access memory through arguments, accumulate the mod/ref - // information from CS1's references to the memory referenced by - // CS2's arguments. - if (onlyAccessesArgPointees(CS2B)) { - if (!doesAccessArgPointees(CS2B)) + // If Call2 only access memory through arguments, accumulate the mod/ref + // information from Call1's references to the memory referenced by + // Call2's arguments. + if (onlyAccessesArgPointees(Call2B)) { + if (!doesAccessArgPointees(Call2B)) return ModRefInfo::NoModRef; ModRefInfo R = ModRefInfo::NoModRef; bool IsMustAlias = true; - for (auto I = CS2.arg_begin(), E = CS2.arg_end(); I != E; ++I) { + for (auto I = Call2->arg_begin(), E = Call2->arg_end(); I != E; ++I) { const Value *Arg = *I; if (!Arg->getType()->isPointerTy()) continue; - unsigned CS2ArgIdx = std::distance(CS2.arg_begin(), I); - auto CS2ArgLoc = MemoryLocation::getForArgument(CS2, CS2ArgIdx, TLI); - - // ArgModRefCS2 indicates what CS2 might do to CS2ArgLoc, and the - // dependence of CS1 on that location is the inverse: - // - If CS2 modifies location, dependence exists if CS1 reads or writes. - // - If CS2 only reads location, dependence exists if CS1 writes. - ModRefInfo ArgModRefCS2 = getArgModRefInfo(CS2, CS2ArgIdx); + unsigned Call2ArgIdx = std::distance(Call2->arg_begin(), I); + auto Call2ArgLoc = + MemoryLocation::getForArgument(Call2, Call2ArgIdx, TLI); + + // ArgModRefC2 indicates what Call2 might do to Call2ArgLoc, and the + // dependence of Call1 on that location is the inverse: + // - If Call2 modifies location, dependence exists if Call1 reads or + // writes. + // - If Call2 only reads location, dependence exists if Call1 writes. + ModRefInfo ArgModRefC2 = getArgModRefInfo(Call2, Call2ArgIdx); ModRefInfo ArgMask = ModRefInfo::NoModRef; - if (isModSet(ArgModRefCS2)) + if (isModSet(ArgModRefC2)) ArgMask = ModRefInfo::ModRef; - else if (isRefSet(ArgModRefCS2)) + else if (isRefSet(ArgModRefC2)) ArgMask = ModRefInfo::Mod; - // ModRefCS1 indicates what CS1 might do to CS2ArgLoc, and we use + // ModRefC1 indicates what Call1 might do to Call2ArgLoc, and we use // above ArgMask to update dependence info. - ModRefInfo ModRefCS1 = getModRefInfo(CS1, CS2ArgLoc); - ArgMask = intersectModRef(ArgMask, ModRefCS1); + ModRefInfo ModRefC1 = getModRefInfo(Call1, Call2ArgLoc); + ArgMask = intersectModRef(ArgMask, ModRefC1); // Conservatively clear IsMustAlias unless only MustAlias is found. - IsMustAlias &= isMustSet(ModRefCS1); + IsMustAlias &= isMustSet(ModRefC1); R = intersectModRef(unionModRef(R, ArgMask), Result); if (R == Result) { @@ -300,31 +300,32 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, return IsMustAlias ? setMust(R) : clearMust(R); } - // If CS1 only accesses memory through arguments, check if CS2 references - // any of the memory referenced by CS1's arguments. If not, return NoModRef. - if (onlyAccessesArgPointees(CS1B)) { - if (!doesAccessArgPointees(CS1B)) + // If Call1 only accesses memory through arguments, check if Call2 references + // any of the memory referenced by Call1's arguments. If not, return NoModRef. 
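The call-versus-call refinement above is likewise rewritten in terms of Call1/Call2. A hypothetical client of the new overload that asks whether two calls can be reordered; the query is explicitly not commutative (see the guard special case in BasicAA further below), so a cautious client checks both directions:

#include "llvm/Analysis/AliasAnalysis.h"
using namespace llvm;

static bool callsAreIndependent(AAResults &AA, const CallBase *CallA,
                                const CallBase *CallB) {
  // getModRefInfo(CallA, CallB) describes what CallA may do to the memory
  // CallB accesses; require NoModRef in both directions before reordering.
  return isNoModRef(AA.getModRefInfo(CallA, CallB)) &&
         isNoModRef(AA.getModRefInfo(CallB, CallA));
}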
+ if (onlyAccessesArgPointees(Call1B)) { + if (!doesAccessArgPointees(Call1B)) return ModRefInfo::NoModRef; ModRefInfo R = ModRefInfo::NoModRef; bool IsMustAlias = true; - for (auto I = CS1.arg_begin(), E = CS1.arg_end(); I != E; ++I) { + for (auto I = Call1->arg_begin(), E = Call1->arg_end(); I != E; ++I) { const Value *Arg = *I; if (!Arg->getType()->isPointerTy()) continue; - unsigned CS1ArgIdx = std::distance(CS1.arg_begin(), I); - auto CS1ArgLoc = MemoryLocation::getForArgument(CS1, CS1ArgIdx, TLI); - - // ArgModRefCS1 indicates what CS1 might do to CS1ArgLoc; if CS1 might - // Mod CS1ArgLoc, then we care about either a Mod or a Ref by CS2. If - // CS1 might Ref, then we care only about a Mod by CS2. - ModRefInfo ArgModRefCS1 = getArgModRefInfo(CS1, CS1ArgIdx); - ModRefInfo ModRefCS2 = getModRefInfo(CS2, CS1ArgLoc); - if ((isModSet(ArgModRefCS1) && isModOrRefSet(ModRefCS2)) || - (isRefSet(ArgModRefCS1) && isModSet(ModRefCS2))) - R = intersectModRef(unionModRef(R, ArgModRefCS1), Result); + unsigned Call1ArgIdx = std::distance(Call1->arg_begin(), I); + auto Call1ArgLoc = + MemoryLocation::getForArgument(Call1, Call1ArgIdx, TLI); + + // ArgModRefC1 indicates what Call1 might do to Call1ArgLoc; if Call1 + // might Mod Call1ArgLoc, then we care about either a Mod or a Ref by + // Call2. If Call1 might Ref, then we care only about a Mod by Call2. + ModRefInfo ArgModRefC1 = getArgModRefInfo(Call1, Call1ArgIdx); + ModRefInfo ModRefC2 = getModRefInfo(Call2, Call1ArgLoc); + if ((isModSet(ArgModRefC1) && isModOrRefSet(ModRefC2)) || + (isRefSet(ArgModRefC1) && isModSet(ModRefC2))) + R = intersectModRef(unionModRef(R, ArgModRefC1), Result); // Conservatively clear IsMustAlias unless only MustAlias is found. - IsMustAlias &= isMustSet(ModRefCS2); + IsMustAlias &= isMustSet(ModRefC2); if (R == Result) { // On early exit, not all args were checked, cannot set Must. @@ -344,11 +345,11 @@ ModRefInfo AAResults::getModRefInfo(ImmutableCallSite CS1, return Result; } -FunctionModRefBehavior AAResults::getModRefBehavior(ImmutableCallSite CS) { +FunctionModRefBehavior AAResults::getModRefBehavior(const CallBase *Call) { FunctionModRefBehavior Result = FMRB_UnknownModRefBehavior; for (const auto &AA : AAs) { - Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(CS)); + Result = FunctionModRefBehavior(Result & AA->getModRefBehavior(Call)); // Early-exit the moment we reach the bottom of the lattice. if (Result == FMRB_DoesNotAccessMemory) @@ -560,8 +561,8 @@ ModRefInfo AAResults::callCapturesBefore(const Instruction *I, isa<Constant>(Object)) return ModRefInfo::ModRef; - ImmutableCallSite CS(I); - if (!CS.getInstruction() || CS.getInstruction() == Object) + const auto *Call = dyn_cast<CallBase>(I); + if (!Call || Call == Object) return ModRefInfo::ModRef; if (PointerMayBeCapturedBefore(Object, /* ReturnCaptures */ true, @@ -574,14 +575,14 @@ ModRefInfo AAResults::callCapturesBefore(const Instruction *I, ModRefInfo R = ModRefInfo::NoModRef; bool IsMustAlias = true; // Set flag only if no May found and all operands processed. - for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); + for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end(); CI != CE; ++CI, ++ArgNo) { // Only look at the no-capture or byval pointer arguments. If this // pointer were passed to arguments that were neither of these, then it // couldn't be no-capture. 
if (!(*CI)->getType()->isPointerTy() || - (!CS.doesNotCapture(ArgNo) && - ArgNo < CS.getNumArgOperands() && !CS.isByValArgument(ArgNo))) + (!Call->doesNotCapture(ArgNo) && ArgNo < Call->getNumArgOperands() && + !Call->isByValArgument(ArgNo))) continue; AliasResult AR = alias(MemoryLocation(*CI), MemoryLocation(Object)); @@ -593,9 +594,9 @@ ModRefInfo AAResults::callCapturesBefore(const Instruction *I, IsMustAlias = false; if (AR == NoAlias) continue; - if (CS.doesNotAccessMemory(ArgNo)) + if (Call->doesNotAccessMemory(ArgNo)) continue; - if (CS.onlyReadsMemory(ArgNo)) { + if (Call->onlyReadsMemory(ArgNo)) { R = ModRefInfo::Ref; continue; } @@ -642,28 +643,6 @@ AnalysisKey AAManager::Key; namespace { -/// A wrapper pass for external alias analyses. This just squirrels away the -/// callback used to run any analyses and register their results. -struct ExternalAAWrapperPass : ImmutablePass { - using CallbackT = std::function<void(Pass &, Function &, AAResults &)>; - - CallbackT CB; - - static char ID; - - ExternalAAWrapperPass() : ImmutablePass(ID) { - initializeExternalAAWrapperPassPass(*PassRegistry::getPassRegistry()); - } - - explicit ExternalAAWrapperPass(CallbackT CB) - : ImmutablePass(ID), CB(std::move(CB)) { - initializeExternalAAWrapperPassPass(*PassRegistry::getPassRegistry()); - } - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.setPreservesAll(); - } -}; } // end anonymous namespace @@ -799,8 +778,8 @@ AAResults llvm::createLegacyPMAAResults(Pass &P, Function &F, } bool llvm::isNoAliasCall(const Value *V) { - if (auto CS = ImmutableCallSite(V)) - return CS.hasRetAttr(Attribute::NoAlias); + if (const auto *Call = dyn_cast<CallBase>(V)) + return Call->hasRetAttr(Attribute::NoAlias); return false; } diff --git a/lib/Analysis/AliasAnalysisEvaluator.cpp b/lib/Analysis/AliasAnalysisEvaluator.cpp index 764ae9160350..85dd4fe95b33 100644 --- a/lib/Analysis/AliasAnalysisEvaluator.cpp +++ b/lib/Analysis/AliasAnalysisEvaluator.cpp @@ -66,11 +66,10 @@ static inline void PrintModRefResults(const char *Msg, bool P, Instruction *I, } } -static inline void PrintModRefResults(const char *Msg, bool P, CallSite CSA, - CallSite CSB, Module *M) { +static inline void PrintModRefResults(const char *Msg, bool P, CallBase *CallA, + CallBase *CallB, Module *M) { if (PrintAll || P) { - errs() << " " << Msg << ": " << *CSA.getInstruction() << " <-> " - << *CSB.getInstruction() << '\n'; + errs() << " " << Msg << ": " << *CallA << " <-> " << *CallB << '\n'; } } @@ -98,7 +97,7 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) { ++FunctionCount; SetVector<Value *> Pointers; - SmallSetVector<CallSite, 16> CallSites; + SmallSetVector<CallBase *, 16> Calls; SetVector<Value *> Loads; SetVector<Value *> Stores; @@ -114,16 +113,16 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) { if (EvalAAMD && isa<StoreInst>(&*I)) Stores.insert(&*I); Instruction &Inst = *I; - if (auto CS = CallSite(&Inst)) { - Value *Callee = CS.getCalledValue(); + if (auto *Call = dyn_cast<CallBase>(&Inst)) { + Value *Callee = Call->getCalledValue(); // Skip actual functions for direct function calls. if (!isa<Function>(Callee) && isInterestingPointer(Callee)) Pointers.insert(Callee); // Consider formals. - for (Use &DataOp : CS.data_ops()) + for (Use &DataOp : Call->data_ops()) if (isInterestingPointer(DataOp)) Pointers.insert(DataOp); - CallSites.insert(CS); + Calls.insert(Call); } else { // Consider all operands. 
for (Instruction::op_iterator OI = Inst.op_begin(), OE = Inst.op_end(); @@ -136,19 +135,21 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) { if (PrintAll || PrintNoAlias || PrintMayAlias || PrintPartialAlias || PrintMustAlias || PrintNoModRef || PrintMod || PrintRef || PrintModRef) errs() << "Function: " << F.getName() << ": " << Pointers.size() - << " pointers, " << CallSites.size() << " call sites\n"; + << " pointers, " << Calls.size() << " call sites\n"; // iterate over the worklist, and run the full (n^2)/2 disambiguations for (SetVector<Value *>::iterator I1 = Pointers.begin(), E = Pointers.end(); I1 != E; ++I1) { - uint64_t I1Size = MemoryLocation::UnknownSize; + auto I1Size = LocationSize::unknown(); Type *I1ElTy = cast<PointerType>((*I1)->getType())->getElementType(); - if (I1ElTy->isSized()) I1Size = DL.getTypeStoreSize(I1ElTy); + if (I1ElTy->isSized()) + I1Size = LocationSize::precise(DL.getTypeStoreSize(I1ElTy)); for (SetVector<Value *>::iterator I2 = Pointers.begin(); I2 != I1; ++I2) { - uint64_t I2Size = MemoryLocation::UnknownSize; - Type *I2ElTy =cast<PointerType>((*I2)->getType())->getElementType(); - if (I2ElTy->isSized()) I2Size = DL.getTypeStoreSize(I2ElTy); + auto I2Size = LocationSize::unknown(); + Type *I2ElTy = cast<PointerType>((*I2)->getType())->getElementType(); + if (I2ElTy->isSized()) + I2Size = LocationSize::precise(DL.getTypeStoreSize(I2ElTy)); AliasResult AR = AA.alias(*I1, I1Size, *I2, I2Size); switch (AR) { @@ -228,49 +229,48 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) { } // Mod/ref alias analysis: compare all pairs of calls and values - for (CallSite C : CallSites) { - Instruction *I = C.getInstruction(); - + for (CallBase *Call : Calls) { for (auto Pointer : Pointers) { - uint64_t Size = MemoryLocation::UnknownSize; + auto Size = LocationSize::unknown(); Type *ElTy = cast<PointerType>(Pointer->getType())->getElementType(); - if (ElTy->isSized()) Size = DL.getTypeStoreSize(ElTy); + if (ElTy->isSized()) + Size = LocationSize::precise(DL.getTypeStoreSize(ElTy)); - switch (AA.getModRefInfo(C, Pointer, Size)) { + switch (AA.getModRefInfo(Call, Pointer, Size)) { case ModRefInfo::NoModRef: - PrintModRefResults("NoModRef", PrintNoModRef, I, Pointer, + PrintModRefResults("NoModRef", PrintNoModRef, Call, Pointer, F.getParent()); ++NoModRefCount; break; case ModRefInfo::Mod: - PrintModRefResults("Just Mod", PrintMod, I, Pointer, F.getParent()); + PrintModRefResults("Just Mod", PrintMod, Call, Pointer, F.getParent()); ++ModCount; break; case ModRefInfo::Ref: - PrintModRefResults("Just Ref", PrintRef, I, Pointer, F.getParent()); + PrintModRefResults("Just Ref", PrintRef, Call, Pointer, F.getParent()); ++RefCount; break; case ModRefInfo::ModRef: - PrintModRefResults("Both ModRef", PrintModRef, I, Pointer, + PrintModRefResults("Both ModRef", PrintModRef, Call, Pointer, F.getParent()); ++ModRefCount; break; case ModRefInfo::Must: - PrintModRefResults("Must", PrintMust, I, Pointer, F.getParent()); + PrintModRefResults("Must", PrintMust, Call, Pointer, F.getParent()); ++MustCount; break; case ModRefInfo::MustMod: - PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, I, Pointer, + PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, Call, Pointer, F.getParent()); ++MustModCount; break; case ModRefInfo::MustRef: - PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, I, Pointer, + PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, Call, Pointer, F.getParent()); ++MustRefCount; break; case ModRefInfo::MustModRef: - 
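The evaluator hunks above also replace raw uint64_t sizes (with MemoryLocation::UnknownSize as the sentinel) by the LocationSize type. A sketch of the sizing idiom, factored into a hypothetical helper:

#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Value.h"
using namespace llvm;

// Start from an unknown LocationSize and tighten it to a precise byte count
// only when the pointee type is sized, mirroring the I1Size/I2Size logic.
static LocationSize sizeOfPointee(const DataLayout &DL, const Value *Ptr) {
  Type *ElTy = cast<PointerType>(Ptr->getType())->getElementType();
  return ElTy->isSized() ? LocationSize::precise(DL.getTypeStoreSize(ElTy))
                         : LocationSize::unknown();
}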
PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, I, + PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, Call, Pointer, F.getParent()); ++MustModRefCount; break; @@ -279,44 +279,46 @@ void AAEvaluator::runInternal(Function &F, AAResults &AA) { } // Mod/ref alias analysis: compare all pairs of calls - for (auto C = CallSites.begin(), Ce = CallSites.end(); C != Ce; ++C) { - for (auto D = CallSites.begin(); D != Ce; ++D) { - if (D == C) + for (CallBase *CallA : Calls) { + for (CallBase *CallB : Calls) { + if (CallA == CallB) continue; - switch (AA.getModRefInfo(*C, *D)) { + switch (AA.getModRefInfo(CallA, CallB)) { case ModRefInfo::NoModRef: - PrintModRefResults("NoModRef", PrintNoModRef, *C, *D, F.getParent()); + PrintModRefResults("NoModRef", PrintNoModRef, CallA, CallB, + F.getParent()); ++NoModRefCount; break; case ModRefInfo::Mod: - PrintModRefResults("Just Mod", PrintMod, *C, *D, F.getParent()); + PrintModRefResults("Just Mod", PrintMod, CallA, CallB, F.getParent()); ++ModCount; break; case ModRefInfo::Ref: - PrintModRefResults("Just Ref", PrintRef, *C, *D, F.getParent()); + PrintModRefResults("Just Ref", PrintRef, CallA, CallB, F.getParent()); ++RefCount; break; case ModRefInfo::ModRef: - PrintModRefResults("Both ModRef", PrintModRef, *C, *D, F.getParent()); + PrintModRefResults("Both ModRef", PrintModRef, CallA, CallB, + F.getParent()); ++ModRefCount; break; case ModRefInfo::Must: - PrintModRefResults("Must", PrintMust, *C, *D, F.getParent()); + PrintModRefResults("Must", PrintMust, CallA, CallB, F.getParent()); ++MustCount; break; case ModRefInfo::MustMod: - PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, *C, *D, + PrintModRefResults("Just Mod (MustAlias)", PrintMustMod, CallA, CallB, F.getParent()); ++MustModCount; break; case ModRefInfo::MustRef: - PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, *C, *D, + PrintModRefResults("Just Ref (MustAlias)", PrintMustRef, CallA, CallB, F.getParent()); ++MustRefCount; break; case ModRefInfo::MustModRef: - PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, *C, *D, - F.getParent()); + PrintModRefResults("Both ModRef (MustAlias)", PrintMustModRef, CallA, + CallB, F.getParent()); ++MustModRefCount; break; } diff --git a/lib/Analysis/AliasSetTracker.cpp b/lib/Analysis/AliasSetTracker.cpp index 8f903fa4f1e8..f6ad704cc914 100644 --- a/lib/Analysis/AliasSetTracker.cpp +++ b/lib/Analysis/AliasSetTracker.cpp @@ -13,9 +13,9 @@ #include "llvm/Analysis/AliasSetTracker.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/MemoryLocation.h" #include "llvm/Config/llvm-config.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Function.h" @@ -24,6 +24,7 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/AtomicOrdering.h" @@ -55,7 +56,6 @@ void AliasSet::mergeSetIn(AliasSet &AS, AliasSetTracker &AST) { // Update the alias and access types of this set... Access |= AS.Access; Alias |= AS.Alias; - Volatile |= AS.Volatile; if (Alias == SetMustAlias) { // Check that these two merged sets really are must aliases. 
Since both @@ -113,10 +113,9 @@ void AliasSetTracker::removeAliasSet(AliasSet *AS) { if (AliasSet *Fwd = AS->Forward) { Fwd->dropRef(*this); AS->Forward = nullptr; - } - - if (AS->Alias == AliasSet::SetMayAlias) - TotalMayAliasSetSize -= AS->size(); + } else // Update TotalMayAliasSetSize only if not forwarding. + if (AS->Alias == AliasSet::SetMayAlias) + TotalMayAliasSetSize -= AS->size(); AliasSets.erase(AS); } @@ -169,7 +168,12 @@ void AliasSet::addUnknownInst(Instruction *I, AliasAnalysis &AA) { addRef(); UnknownInsts.emplace_back(I); - if (!I->mayWriteToMemory()) { + // Guards are marked as modifying memory for control flow modelling purposes, + // but don't actually modify any specific memory location. + using namespace PatternMatch; + bool MayWriteMemory = I->mayWriteToMemory() && !isGuard(I) && + !(I->use_empty() && match(I, m_Intrinsic<Intrinsic::invariant_start>())); + if (!MayWriteMemory) { Alias = SetMayAlias; Access |= RefAccess; return; @@ -226,12 +230,13 @@ bool AliasSet::aliasesUnknownInst(const Instruction *Inst, if (AliasAny) return true; - if (!Inst->mayReadOrWriteMemory()) - return false; + assert(Inst->mayReadOrWriteMemory() && + "Instruction must either read or write memory."); for (unsigned i = 0, e = UnknownInsts.size(); i != e; ++i) { if (auto *UnknownInst = getUnknownInst(i)) { - ImmutableCallSite C1(UnknownInst), C2(Inst); + const auto *C1 = dyn_cast<CallBase>(UnknownInst); + const auto *C2 = dyn_cast<CallBase>(Inst); if (!C1 || !C2 || isModOrRefSet(AA.getModRefInfo(C1, C2)) || isModOrRefSet(AA.getModRefInfo(C2, C1))) return true; @@ -246,6 +251,31 @@ bool AliasSet::aliasesUnknownInst(const Instruction *Inst, return false; } +Instruction* AliasSet::getUniqueInstruction() { + if (AliasAny) + // May have collapses alias set + return nullptr; + if (begin() != end()) { + if (!UnknownInsts.empty()) + // Another instruction found + return nullptr; + if (std::next(begin()) != end()) + // Another instruction found + return nullptr; + Value *Addr = begin()->getValue(); + assert(!Addr->user_empty() && + "where's the instruction which added this pointer?"); + if (std::next(Addr->user_begin()) != Addr->user_end()) + // Another instruction found -- this is really restrictive + // TODO: generalize! + return nullptr; + return cast<Instruction>(*(Addr->user_begin())); + } + if (1 != UnknownInsts.size()) + return nullptr; + return cast<Instruction>(UnknownInsts[0]); +} + void AliasSetTracker::clear() { // Delete all the PointerRec entries. for (PointerMapType::iterator I = PointerMap.begin(), E = PointerMap.end(); @@ -280,13 +310,6 @@ AliasSet *AliasSetTracker::mergeAliasSetsForPointer(const Value *Ptr, return FoundSet; } -bool AliasSetTracker::containsUnknown(const Instruction *Inst) const { - for (const AliasSet &AS : *this) - if (!AS.Forward && AS.aliasesUnknownInst(Inst, AA)) - return true; - return false; -} - AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { AliasSet *FoundSet = nullptr; for (iterator I = begin(), E = end(); I != E;) { @@ -295,17 +318,18 @@ AliasSet *AliasSetTracker::findAliasSetForUnknownInst(Instruction *Inst) { continue; if (!FoundSet) // If this is the first alias set ptr can go into. FoundSet = &*Cur; // Remember it. - else if (!Cur->Forward) // Otherwise, we must merge the sets. + else // Otherwise, we must merge the sets. FoundSet->mergeSetIn(*Cur, *this); // Merge in contents. } return FoundSet; } -/// getAliasSetForPointer - Return the alias set that the specified pointer -/// lives in. 
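The AliasSetTracker changes in this region add AliasSet::getUniqueInstruction(), which returns the set's single memory instruction when that can be established and nullptr otherwise (collapsed sets, several pointers, several users). A hypothetical consumer, only to show the calling convention:

#include "llvm/ADT/STLExtras.h" // function_ref
#include "llvm/Analysis/AliasSetTracker.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

static void forEachSingletonSet(AliasSetTracker &AST,
                                function_ref<void(Instruction &)> Visit) {
  for (AliasSet &AS : AST) {
    if (AS.isForwardingAliasSet())
      continue; // merged-away sets carry no information of their own
    if (Instruction *I = AS.getUniqueInstruction())
      Visit(*I); // nullptr means collapsed or not provably a singleton
  }
}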
-AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, - LocationSize Size, - const AAMDNodes &AAInfo) { +AliasSet &AliasSetTracker::getAliasSetFor(const MemoryLocation &MemLoc) { + + Value * const Pointer = const_cast<Value*>(MemLoc.Ptr); + const LocationSize Size = MemLoc.Size; + const AAMDNodes &AAInfo = MemLoc.AATags; + AliasSet::PointerRec &Entry = getEntryFor(Pointer); if (AliasAnyAS) { @@ -351,83 +375,32 @@ AliasSet &AliasSetTracker::getAliasSetForPointer(Value *Pointer, void AliasSetTracker::add(Value *Ptr, LocationSize Size, const AAMDNodes &AAInfo) { - addPointer(Ptr, Size, AAInfo, AliasSet::NoAccess); + addPointer(MemoryLocation(Ptr, Size, AAInfo), AliasSet::NoAccess); } void AliasSetTracker::add(LoadInst *LI) { - if (isStrongerThanMonotonic(LI->getOrdering())) return addUnknown(LI); - - AAMDNodes AAInfo; - LI->getAAMetadata(AAInfo); - - AliasSet::AccessLattice Access = AliasSet::RefAccess; - const DataLayout &DL = LI->getModule()->getDataLayout(); - AliasSet &AS = addPointer(LI->getOperand(0), - DL.getTypeStoreSize(LI->getType()), AAInfo, Access); - if (LI->isVolatile()) AS.setVolatile(); + if (isStrongerThanMonotonic(LI->getOrdering())) + return addUnknown(LI); + addPointer(MemoryLocation::get(LI), AliasSet::RefAccess); } void AliasSetTracker::add(StoreInst *SI) { - if (isStrongerThanMonotonic(SI->getOrdering())) return addUnknown(SI); - - AAMDNodes AAInfo; - SI->getAAMetadata(AAInfo); - - AliasSet::AccessLattice Access = AliasSet::ModAccess; - const DataLayout &DL = SI->getModule()->getDataLayout(); - Value *Val = SI->getOperand(0); - AliasSet &AS = addPointer( - SI->getOperand(1), DL.getTypeStoreSize(Val->getType()), AAInfo, Access); - if (SI->isVolatile()) AS.setVolatile(); + if (isStrongerThanMonotonic(SI->getOrdering())) + return addUnknown(SI); + addPointer(MemoryLocation::get(SI), AliasSet::ModAccess); } void AliasSetTracker::add(VAArgInst *VAAI) { - AAMDNodes AAInfo; - VAAI->getAAMetadata(AAInfo); - - addPointer(VAAI->getOperand(0), MemoryLocation::UnknownSize, AAInfo, - AliasSet::ModRefAccess); + addPointer(MemoryLocation::get(VAAI), AliasSet::ModRefAccess); } void AliasSetTracker::add(AnyMemSetInst *MSI) { - AAMDNodes AAInfo; - MSI->getAAMetadata(AAInfo); - - uint64_t Len; - - if (ConstantInt *C = dyn_cast<ConstantInt>(MSI->getLength())) - Len = C->getZExtValue(); - else - Len = MemoryLocation::UnknownSize; - - AliasSet &AS = - addPointer(MSI->getRawDest(), Len, AAInfo, AliasSet::ModAccess); - auto *MS = dyn_cast<MemSetInst>(MSI); - if (MS && MS->isVolatile()) - AS.setVolatile(); + addPointer(MemoryLocation::getForDest(MSI), AliasSet::ModAccess); } void AliasSetTracker::add(AnyMemTransferInst *MTI) { - AAMDNodes AAInfo; - MTI->getAAMetadata(AAInfo); - - uint64_t Len; - if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) - Len = C->getZExtValue(); - else - Len = MemoryLocation::UnknownSize; - - AliasSet &ASSrc = - addPointer(MTI->getRawSource(), Len, AAInfo, AliasSet::RefAccess); - - AliasSet &ASDst = - addPointer(MTI->getRawDest(), Len, AAInfo, AliasSet::ModAccess); - - auto* MT = dyn_cast<MemTransferInst>(MTI); - if (MT && MT->isVolatile()) { - ASSrc.setVolatile(); - ASDst.setVolatile(); - } + addPointer(MemoryLocation::getForDest(MTI), AliasSet::ModAccess); + addPointer(MemoryLocation::getForSource(MTI), AliasSet::RefAccess); } void AliasSetTracker::addUnknown(Instruction *Inst) { @@ -471,6 +444,46 @@ void AliasSetTracker::add(Instruction *I) { return add(MSI); if (AnyMemTransferInst *MTI = dyn_cast<AnyMemTransferInst>(I)) return add(MTI); + + 
// Handle all calls with known mod/ref sets genericall + if (auto *Call = dyn_cast<CallBase>(I)) + if (Call->onlyAccessesArgMemory()) { + auto getAccessFromModRef = [](ModRefInfo MRI) { + if (isRefSet(MRI) && isModSet(MRI)) + return AliasSet::ModRefAccess; + else if (isModSet(MRI)) + return AliasSet::ModAccess; + else if (isRefSet(MRI)) + return AliasSet::RefAccess; + else + return AliasSet::NoAccess; + }; + + ModRefInfo CallMask = createModRefInfo(AA.getModRefBehavior(Call)); + + // Some intrinsics are marked as modifying memory for control flow + // modelling purposes, but don't actually modify any specific memory + // location. + using namespace PatternMatch; + if (Call->use_empty() && + match(Call, m_Intrinsic<Intrinsic::invariant_start>())) + CallMask = clearMod(CallMask); + + for (auto IdxArgPair : enumerate(Call->args())) { + int ArgIdx = IdxArgPair.index(); + const Value *Arg = IdxArgPair.value(); + if (!Arg->getType()->isPointerTy()) + continue; + MemoryLocation ArgLoc = + MemoryLocation::getForArgument(Call, ArgIdx, nullptr); + ModRefInfo ArgMask = AA.getArgModRefInfo(Call, ArgIdx); + ArgMask = intersectModRef(CallMask, ArgMask); + if (!isNoModRef(ArgMask)) + addPointer(ArgLoc, getAccessFromModRef(ArgMask)); + } + return; + } + return addUnknown(I); } @@ -496,12 +509,10 @@ void AliasSetTracker::add(const AliasSetTracker &AST) { add(Inst); // Loop over all of the pointers in this alias set. - for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) { - AliasSet &NewAS = - addPointer(ASI.getPointer(), ASI.getSize(), ASI.getAAInfo(), - (AliasSet::AccessLattice)AS.Access); - if (AS.isVolatile()) NewAS.setVolatile(); - } + for (AliasSet::iterator ASI = AS.begin(), E = AS.end(); ASI != E; ++ASI) + addPointer( + MemoryLocation(ASI.getPointer(), ASI.getSize(), ASI.getAAInfo()), + (AliasSet::AccessLattice)AS.Access); } } @@ -594,10 +605,9 @@ AliasSet &AliasSetTracker::mergeAllAliasSets() { return *AliasAnyAS; } -AliasSet &AliasSetTracker::addPointer(Value *P, LocationSize Size, - const AAMDNodes &AAInfo, +AliasSet &AliasSetTracker::addPointer(MemoryLocation Loc, AliasSet::AccessLattice E) { - AliasSet &AS = getAliasSetForPointer(P, Size, AAInfo); + AliasSet &AS = getAliasSetFor(Loc); AS.Access |= E; if (!AliasAnyAS && (TotalMayAliasSetSize > SaturationThreshold)) { @@ -623,7 +633,6 @@ void AliasSet::print(raw_ostream &OS) const { case ModRefAccess: OS << "Mod/Ref "; break; default: llvm_unreachable("Bad value for Access!"); } - if (isVolatile()) OS << "[volatile] "; if (Forward) OS << " forwarding to " << (void*)Forward; @@ -632,7 +641,10 @@ void AliasSet::print(raw_ostream &OS) const { for (iterator I = begin(), E = end(); I != E; ++I) { if (I != begin()) OS << ", "; I.getPointer()->printAsOperand(OS << "("); - OS << ", " << I.getSize() << ")"; + if (I.getSize() == LocationSize::unknown()) + OS << ", unknown)"; + else + OS << ", " << I.getSize() << ")"; } } if (!UnknownInsts.empty()) { diff --git a/lib/Analysis/Analysis.cpp b/lib/Analysis/Analysis.cpp index 30576cf1ae10..bb8742123a0f 100644 --- a/lib/Analysis/Analysis.cpp +++ b/lib/Analysis/Analysis.cpp @@ -39,7 +39,6 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeDependenceAnalysisWrapperPassPass(Registry); initializeDelinearizationPass(Registry); initializeDemandedBitsWrapperPassPass(Registry); - initializeDivergenceAnalysisPass(Registry); initializeDominanceFrontierWrapperPassPass(Registry); initializeDomViewerPass(Registry); initializeDomPrinterPass(Registry); @@ -58,6 +57,7 @@ void 
llvm::initializeAnalysis(PassRegistry &Registry) { initializeLazyBlockFrequencyInfoPassPass(Registry); initializeLazyValueInfoWrapperPassPass(Registry); initializeLazyValueInfoPrinterPass(Registry); + initializeLegacyDivergenceAnalysisPass(Registry); initializeLintPass(Registry); initializeLoopInfoWrapperPassPass(Registry); initializeMemDepPrinterPass(Registry); @@ -77,6 +77,8 @@ void llvm::initializeAnalysis(PassRegistry &Registry) { initializeRegionOnlyPrinterPass(Registry); initializeSCEVAAWrapperPassPass(Registry); initializeScalarEvolutionWrapperPassPass(Registry); + initializeStackSafetyGlobalInfoWrapperPassPass(Registry); + initializeStackSafetyInfoWrapperPassPass(Registry); initializeTargetTransformInfoWrapperPassPass(Registry); initializeTypeBasedAAWrapperPassPass(Registry); initializeScopedNoAliasAAWrapperPassPass(Registry); diff --git a/lib/Analysis/BasicAliasAnalysis.cpp b/lib/Analysis/BasicAliasAnalysis.cpp index 1a24ae3dba15..332eeaa00e73 100644 --- a/lib/Analysis/BasicAliasAnalysis.cpp +++ b/lib/Analysis/BasicAliasAnalysis.cpp @@ -31,7 +31,6 @@ #include "llvm/Analysis/PhiValues.h" #include "llvm/IR/Argument.h" #include "llvm/IR/Attributes.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" @@ -68,6 +67,16 @@ using namespace llvm; /// Enable analysis of recursive PHI nodes. static cl::opt<bool> EnableRecPhiAnalysis("basicaa-recphi", cl::Hidden, cl::init(false)); + +/// By default, even on 32-bit architectures we use 64-bit integers for +/// calculations. This will allow us to more-aggressively decompose indexing +/// expressions calculated using i64 values (e.g., long long in C) which is +/// common enough to worry about. +static cl::opt<bool> ForceAtLeast64Bits("basicaa-force-at-least-64b", + cl::Hidden, cl::init(true)); +static cl::opt<bool> DoubleCalcBits("basicaa-double-calc-bits", + cl::Hidden, cl::init(false)); + /// SearchLimitReached / SearchTimes shows how often the limit of /// to decompose GEPs is reached. It will affect the precision /// of basic alias analysis. @@ -134,7 +143,7 @@ static bool isNonEscapingLocalObject(const Value *V) { /// Returns true if the pointer is one which would have been considered an /// escape by isNonEscapingLocalObject. static bool isEscapeSource(const Value *V) { - if (ImmutableCallSite(V)) + if (isa<CallBase>(V)) return true; if (isa<Argument>(V)) @@ -381,13 +390,22 @@ static bool isObjectSize(const Value *V, uint64_t Size, const DataLayout &DL, } /// To ensure a pointer offset fits in an integer of size PointerSize -/// (in bits) when that size is smaller than 64. This is an issue in -/// particular for 32b programs with negative indices that rely on two's -/// complement wrap-arounds for precise alias information. -static int64_t adjustToPointerSize(int64_t Offset, unsigned PointerSize) { - assert(PointerSize <= 64 && "Invalid PointerSize!"); - unsigned ShiftBits = 64 - PointerSize; - return (int64_t)((uint64_t)Offset << ShiftBits) >> ShiftBits; +/// (in bits) when that size is smaller than the maximum pointer size. This is +/// an issue, for example, in particular for 32b pointers with negative indices +/// that rely on two's complement wrap-arounds for precise alias information +/// where the maximum pointer size is 64b. 
+static APInt adjustToPointerSize(APInt Offset, unsigned PointerSize) { + assert(PointerSize <= Offset.getBitWidth() && "Invalid PointerSize!"); + unsigned ShiftBits = Offset.getBitWidth() - PointerSize; + return (Offset << ShiftBits).ashr(ShiftBits); +} + +static unsigned getMaxPointerSize(const DataLayout &DL) { + unsigned MaxPointerSize = DL.getMaxPointerSizeInBits(); + if (MaxPointerSize < 64 && ForceAtLeast64Bits) MaxPointerSize = 64; + if (DoubleCalcBits) MaxPointerSize *= 2; + + return MaxPointerSize; } /// If V is a symbolic pointer expression, decompose it into a base pointer @@ -410,8 +428,7 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, unsigned MaxLookup = MaxLookupSearchDepth; SearchTimes++; - Decomposed.StructOffset = 0; - Decomposed.OtherOffset = 0; + unsigned MaxPointerSize = getMaxPointerSize(DL); Decomposed.VarIndices.clear(); do { // See if this is a bitcast or GEP. @@ -436,7 +453,7 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, const GEPOperator *GEPOp = dyn_cast<GEPOperator>(Op); if (!GEPOp) { - if (auto CS = ImmutableCallSite(V)) { + if (const auto *Call = dyn_cast<CallBase>(V)) { // CaptureTracking can know about special capturing properties of some // intrinsics like launder.invariant.group, that can't be expressed with // the attributes, but have properties like returning aliasing pointer. @@ -446,7 +463,7 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, // because it should be in sync with CaptureTracking. Not using it may // cause weird miscompilations where 2 aliasing pointers are assumed to // noalias. - if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { + if (auto *RP = getArgumentAliasingToReturnedPointer(Call)) { V = RP; continue; } @@ -501,13 +518,15 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, if (CIdx->isZero()) continue; Decomposed.OtherOffset += - DL.getTypeAllocSize(GTI.getIndexedType()) * CIdx->getSExtValue(); + (DL.getTypeAllocSize(GTI.getIndexedType()) * + CIdx->getValue().sextOrSelf(MaxPointerSize)) + .sextOrTrunc(MaxPointerSize); continue; } GepHasConstantOffset = false; - uint64_t Scale = DL.getTypeAllocSize(GTI.getIndexedType()); + APInt Scale(MaxPointerSize, DL.getTypeAllocSize(GTI.getIndexedType())); unsigned ZExtBits = 0, SExtBits = 0; // If the integer type is smaller than the pointer size, it is implicitly @@ -519,20 +538,34 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, // Use GetLinearExpression to decompose the index into a C1*V+C2 form. APInt IndexScale(Width, 0), IndexOffset(Width, 0); bool NSW = true, NUW = true; + const Value *OrigIndex = Index; Index = GetLinearExpression(Index, IndexScale, IndexOffset, ZExtBits, SExtBits, DL, 0, AC, DT, NSW, NUW); - // All GEP math happens in the width of the pointer type, - // so we can truncate the value to 64-bits as we don't handle - // currently pointers larger than 64 bits and we would crash - // later. TODO: Make `Scale` an APInt to avoid this problem. - if (IndexScale.getBitWidth() > 64) - IndexScale = IndexScale.sextOrTrunc(64); - // The GEP index scale ("Scale") scales C1*V+C2, yielding (C1*V+C2)*Scale. // This gives us an aggregate computation of (C1*Scale)*V + C2*Scale. - Decomposed.OtherOffset += IndexOffset.getSExtValue() * Scale; - Scale *= IndexScale.getSExtValue(); + + // It can be the case that, even through C1*V+C2 does not overflow for + // relevant values of V, (C2*Scale) can overflow. In that case, we cannot + // decompose the expression in this way. 
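The BasicAliasAnalysis.cpp hunks here move GEP decomposition from int64_t to APInt offsets sized by getMaxPointerSize() and wrapped back to the real pointer width. A standalone illustration of the wrap-to-pointer-width idiom, assuming only llvm/ADT/APInt.h:

#include "llvm/ADT/APInt.h"
using namespace llvm;

// Keep offsets in a wide APInt during the computation, then model the
// target's two's-complement wrap-around by shifting out the high bits and
// sign-extending back, as adjustToPointerSize() does above.
static APInt wrapToPointerWidth(APInt Offset, unsigned PointerSizeInBits) {
  unsigned ShiftBits = Offset.getBitWidth() - PointerSizeInBits;
  return (Offset << ShiftBits).ashr(ShiftBits);
}

For example, a 64-bit-wide offset of 0xFFFFFFFC wraps to -4 when the pointer width is 32, which is exactly the two's-complement behaviour a negative 32-bit index relies on.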
+ // + // FIXME: C1*Scale and the other operations in the decomposed + // (C1*Scale)*V+C2*Scale can also overflow. We should check for this + // possibility. + APInt WideScaledOffset = IndexOffset.sextOrTrunc(MaxPointerSize*2) * + Scale.sext(MaxPointerSize*2); + if (WideScaledOffset.getMinSignedBits() > MaxPointerSize) { + Index = OrigIndex; + IndexScale = 1; + IndexOffset = 0; + + ZExtBits = SExtBits = 0; + if (PointerSize > Width) + SExtBits += PointerSize - Width; + } else { + Decomposed.OtherOffset += IndexOffset.sextOrTrunc(MaxPointerSize) * Scale; + Scale *= IndexScale.sextOrTrunc(MaxPointerSize); + } // If we already had an occurrence of this index variable, merge this // scale into it. For example, we want to handle: @@ -552,9 +585,8 @@ bool BasicAAResult::DecomposeGEPExpression(const Value *V, // pointer size. Scale = adjustToPointerSize(Scale, PointerSize); - if (Scale) { - VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, - static_cast<int64_t>(Scale)}; + if (!!Scale) { + VariableGEPIndex Entry = {Index, ZExtBits, SExtBits, Scale}; Decomposed.VarIndices.push_back(Entry); } } @@ -640,8 +672,8 @@ bool BasicAAResult::pointsToConstantMemory(const MemoryLocation &Loc, } /// Returns the behavior when calling the given call site. -FunctionModRefBehavior BasicAAResult::getModRefBehavior(ImmutableCallSite CS) { - if (CS.doesNotAccessMemory()) +FunctionModRefBehavior BasicAAResult::getModRefBehavior(const CallBase *Call) { + if (Call->doesNotAccessMemory()) // Can't do better than this. return FMRB_DoesNotAccessMemory; @@ -649,23 +681,23 @@ FunctionModRefBehavior BasicAAResult::getModRefBehavior(ImmutableCallSite CS) { // If the callsite knows it only reads memory, don't return worse // than that. - if (CS.onlyReadsMemory()) + if (Call->onlyReadsMemory()) Min = FMRB_OnlyReadsMemory; - else if (CS.doesNotReadMemory()) + else if (Call->doesNotReadMemory()) Min = FMRB_DoesNotReadMemory; - if (CS.onlyAccessesArgMemory()) + if (Call->onlyAccessesArgMemory()) Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesArgumentPointees); - else if (CS.onlyAccessesInaccessibleMemory()) + else if (Call->onlyAccessesInaccessibleMemory()) Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleMem); - else if (CS.onlyAccessesInaccessibleMemOrArgMem()) + else if (Call->onlyAccessesInaccessibleMemOrArgMem()) Min = FunctionModRefBehavior(Min & FMRB_OnlyAccessesInaccessibleOrArgMem); - // If CS has operand bundles then aliasing attributes from the function it - // calls do not directly apply to the CallSite. This can be made more - // precise in the future. - if (!CS.hasOperandBundles()) - if (const Function *F = CS.getCalledFunction()) + // If the call has operand bundles then aliasing attributes from the function + // it calls do not directly apply to the call. This can be made more precise + // in the future. + if (!Call->hasOperandBundles()) + if (const Function *F = Call->getCalledFunction()) Min = FunctionModRefBehavior(Min & getBestAAResults().getModRefBehavior(F)); @@ -698,9 +730,9 @@ FunctionModRefBehavior BasicAAResult::getModRefBehavior(const Function *F) { } /// Returns true if this is a writeonly (i.e Mod only) parameter. 
-static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, +static bool isWriteOnlyParam(const CallBase *Call, unsigned ArgIdx, const TargetLibraryInfo &TLI) { - if (CS.paramHasAttr(ArgIdx, Attribute::WriteOnly)) + if (Call->paramHasAttr(ArgIdx, Attribute::WriteOnly)) return true; // We can bound the aliasing properties of memset_pattern16 just as we can @@ -710,7 +742,8 @@ static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, // FIXME Consider handling this in InferFunctionAttr.cpp together with other // attributes. LibFunc F; - if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && + if (Call->getCalledFunction() && + TLI.getLibFunc(*Call->getCalledFunction(), F) && F == LibFunc_memset_pattern16 && TLI.has(F)) if (ArgIdx == 0) return true; @@ -722,23 +755,23 @@ static bool isWriteOnlyParam(ImmutableCallSite CS, unsigned ArgIdx, return false; } -ModRefInfo BasicAAResult::getArgModRefInfo(ImmutableCallSite CS, +ModRefInfo BasicAAResult::getArgModRefInfo(const CallBase *Call, unsigned ArgIdx) { // Checking for known builtin intrinsics and target library functions. - if (isWriteOnlyParam(CS, ArgIdx, TLI)) + if (isWriteOnlyParam(Call, ArgIdx, TLI)) return ModRefInfo::Mod; - if (CS.paramHasAttr(ArgIdx, Attribute::ReadOnly)) + if (Call->paramHasAttr(ArgIdx, Attribute::ReadOnly)) return ModRefInfo::Ref; - if (CS.paramHasAttr(ArgIdx, Attribute::ReadNone)) + if (Call->paramHasAttr(ArgIdx, Attribute::ReadNone)) return ModRefInfo::NoModRef; - return AAResultBase::getArgModRefInfo(CS, ArgIdx); + return AAResultBase::getArgModRefInfo(Call, ArgIdx); } -static bool isIntrinsicCall(ImmutableCallSite CS, Intrinsic::ID IID) { - const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction()); +static bool isIntrinsicCall(const CallBase *Call, Intrinsic::ID IID) { + const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call); return II && II->getIntrinsicID() == IID; } @@ -794,27 +827,34 @@ AliasResult BasicAAResult::alias(const MemoryLocation &LocA, /// Since we only look at local properties of this function, we really can't /// say much about this query. We do, however, use simple "address taken" /// analysis on local objects. -ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, +ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { - assert(notDifferentParent(CS.getInstruction(), Loc.Ptr) && + assert(notDifferentParent(Call, Loc.Ptr) && "AliasAnalysis query involving multiple functions!"); const Value *Object = GetUnderlyingObject(Loc.Ptr, DL); - // If this is a tail call and Loc.Ptr points to a stack location, we know that - // the tail call cannot access or modify the local stack. - // We cannot exclude byval arguments here; these belong to the caller of - // the current function not to the current function, and a tail callee - // may reference them. + // Calls marked 'tail' cannot read or write allocas from the current frame + // because the current frame might be destroyed by the time they run. However, + // a tail call may use an alloca with byval. Calling with byval copies the + // contents of the alloca into argument registers or stack slots, so there is + // no lifetime issue. 
if (isa<AllocaInst>(Object)) - if (const CallInst *CI = dyn_cast<CallInst>(CS.getInstruction())) - if (CI->isTailCall()) + if (const CallInst *CI = dyn_cast<CallInst>(Call)) + if (CI->isTailCall() && + !CI->getAttributes().hasAttrSomewhere(Attribute::ByVal)) return ModRefInfo::NoModRef; + // Stack restore is able to modify unescaped dynamic allocas. Assume it may + // modify them even though the alloca is not escaped. + if (auto *AI = dyn_cast<AllocaInst>(Object)) + if (!AI->isStaticAlloca() && isIntrinsicCall(Call, Intrinsic::stackrestore)) + return ModRefInfo::Mod; + // If the pointer is to a locally allocated object that does not escape, // then the call can not mod/ref the pointer unless the call takes the pointer // as an argument, and itself doesn't capture it. - if (!isa<Constant>(Object) && CS.getInstruction() != Object && + if (!isa<Constant>(Object) && Call != Object && isNonEscapingLocalObject(Object)) { // Optimistically assume that call doesn't touch Object and check this @@ -823,19 +863,20 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, bool IsMustAlias = true; unsigned OperandNo = 0; - for (auto CI = CS.data_operands_begin(), CE = CS.data_operands_end(); + for (auto CI = Call->data_operands_begin(), CE = Call->data_operands_end(); CI != CE; ++CI, ++OperandNo) { // Only look at the no-capture or byval pointer arguments. If this // pointer were passed to arguments that were neither of these, then it // couldn't be no-capture. if (!(*CI)->getType()->isPointerTy() || - (!CS.doesNotCapture(OperandNo) && - OperandNo < CS.getNumArgOperands() && !CS.isByValArgument(OperandNo))) + (!Call->doesNotCapture(OperandNo) && + OperandNo < Call->getNumArgOperands() && + !Call->isByValArgument(OperandNo))) continue; // Call doesn't access memory through this operand, so we don't care // if it aliases with Object. - if (CS.doesNotAccessMemory(OperandNo)) + if (Call->doesNotAccessMemory(OperandNo)) continue; // If this is a no-capture pointer argument, see if we can tell that it @@ -849,12 +890,12 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, continue; // Operand aliases 'Object', but call doesn't modify it. Strengthen // initial assumption and keep looking in case if there are more aliases. - if (CS.onlyReadsMemory(OperandNo)) { + if (Call->onlyReadsMemory(OperandNo)) { Result = setRef(Result); continue; } // Operand aliases 'Object' but call only writes into it. - if (CS.doesNotReadMemory(OperandNo)) { + if (Call->doesNotReadMemory(OperandNo)) { Result = setMod(Result); continue; } @@ -878,17 +919,16 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, } } - // If the CallSite is to malloc or calloc, we can assume that it doesn't + // If the call is to malloc or calloc, we can assume that it doesn't // modify any IR visible value. This is only valid because we assume these // routines do not read values visible in the IR. TODO: Consider special // casing realloc and strdup routines which access only their arguments as // well. Or alternatively, replace all of this with inaccessiblememonly once // that's implemented fully. - auto *Inst = CS.getInstruction(); - if (isMallocOrCallocLikeFn(Inst, &TLI)) { + if (isMallocOrCallocLikeFn(Call, &TLI)) { // Be conservative if the accessed pointer may alias the allocation - // fallback to the generic handling below. 
- if (getBestAAResults().alias(MemoryLocation(Inst), Loc) == NoAlias) + if (getBestAAResults().alias(MemoryLocation(Call), Loc) == NoAlias) return ModRefInfo::NoModRef; } @@ -896,7 +936,7 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // operands, i.e., source and destination of any given memcpy must no-alias. // If Loc must-aliases either one of these two locations, then it necessarily // no-aliases the other. - if (auto *Inst = dyn_cast<AnyMemCpyInst>(CS.getInstruction())) { + if (auto *Inst = dyn_cast<AnyMemCpyInst>(Call)) { AliasResult SrcAA, DestAA; if ((SrcAA = getBestAAResults().alias(MemoryLocation::getForSource(Inst), @@ -920,7 +960,7 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // While the assume intrinsic is marked as arbitrarily writing so that // proper control dependencies will be maintained, it never aliases any // particular memory location. - if (isIntrinsicCall(CS, Intrinsic::assume)) + if (isIntrinsicCall(Call, Intrinsic::assume)) return ModRefInfo::NoModRef; // Like assumes, guard intrinsics are also marked as arbitrarily writing so @@ -930,7 +970,7 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // *Unlike* assumes, guard intrinsics are modeled as reading memory since the // heap state at the point the guard is issued needs to be consistent in case // the guard invokes the "deopt" continuation. - if (isIntrinsicCall(CS, Intrinsic::experimental_guard)) + if (isIntrinsicCall(Call, Intrinsic::experimental_guard)) return ModRefInfo::Ref; // Like assumes, invariant.start intrinsics were also marked as arbitrarily @@ -956,20 +996,20 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS, // The transformation will cause the second store to be ignored (based on // rules of invariant.start) and print 40, while the first program always // prints 50. - if (isIntrinsicCall(CS, Intrinsic::invariant_start)) + if (isIntrinsicCall(Call, Intrinsic::invariant_start)) return ModRefInfo::Ref; // The AAResultBase base class has some smarts, lets use them. - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); } -ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) { +ModRefInfo BasicAAResult::getModRefInfo(const CallBase *Call1, + const CallBase *Call2) { // While the assume intrinsic is marked as arbitrarily writing so that // proper control dependencies will be maintained, it never aliases any // particular memory location. - if (isIntrinsicCall(CS1, Intrinsic::assume) || - isIntrinsicCall(CS2, Intrinsic::assume)) + if (isIntrinsicCall(Call1, Intrinsic::assume) || + isIntrinsicCall(Call2, Intrinsic::assume)) return ModRefInfo::NoModRef; // Like assumes, guard intrinsics are also marked as arbitrarily writing so @@ -983,26 +1023,26 @@ ModRefInfo BasicAAResult::getModRefInfo(ImmutableCallSite CS1, // NB! This function is *not* commutative, so we specical case two // possibilities for guard intrinsics. - if (isIntrinsicCall(CS1, Intrinsic::experimental_guard)) - return isModSet(createModRefInfo(getModRefBehavior(CS2))) + if (isIntrinsicCall(Call1, Intrinsic::experimental_guard)) + return isModSet(createModRefInfo(getModRefBehavior(Call2))) ? ModRefInfo::Ref : ModRefInfo::NoModRef; - if (isIntrinsicCall(CS2, Intrinsic::experimental_guard)) - return isModSet(createModRefInfo(getModRefBehavior(CS1))) + if (isIntrinsicCall(Call2, Intrinsic::experimental_guard)) + return isModSet(createModRefInfo(getModRefBehavior(Call1))) ? 
ModRefInfo::Mod : ModRefInfo::NoModRef; // The AAResultBase base class has some smarts, lets use them. - return AAResultBase::getModRefInfo(CS1, CS2); + return AAResultBase::getModRefInfo(Call1, Call2); } /// Provide ad-hoc rules to disambiguate accesses through two GEP operators, /// both having the exact same pointer operand. static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, - LocationSize V1Size, + LocationSize MaybeV1Size, const GEPOperator *GEP2, - LocationSize V2Size, + LocationSize MaybeV2Size, const DataLayout &DL) { assert(GEP1->getPointerOperand()->stripPointerCastsAndInvariantGroups() == GEP2->getPointerOperand()->stripPointerCastsAndInvariantGroups() && @@ -1018,10 +1058,13 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, // If we don't know the size of the accesses through both GEPs, we can't // determine whether the struct fields accessed can't alias. - if (V1Size == MemoryLocation::UnknownSize || - V2Size == MemoryLocation::UnknownSize) + if (MaybeV1Size == LocationSize::unknown() || + MaybeV2Size == LocationSize::unknown()) return MayAlias; + const uint64_t V1Size = MaybeV1Size.getValue(); + const uint64_t V2Size = MaybeV2Size.getValue(); + ConstantInt *C1 = dyn_cast<ConstantInt>(GEP1->getOperand(GEP1->getNumOperands() - 1)); ConstantInt *C2 = @@ -1029,8 +1072,12 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, // If the last (struct) indices are constants and are equal, the other indices // might be also be dynamically equal, so the GEPs can alias. - if (C1 && C2 && C1->getSExtValue() == C2->getSExtValue()) - return MayAlias; + if (C1 && C2) { + unsigned BitWidth = std::max(C1->getBitWidth(), C2->getBitWidth()); + if (C1->getValue().sextOrSelf(BitWidth) == + C2->getValue().sextOrSelf(BitWidth)) + return MayAlias; + } // Find the last-indexed type of the GEP, i.e., the type you'd get if // you stripped the last index. @@ -1113,6 +1160,10 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, return MayAlias; } + if (C1->getValue().getActiveBits() > 64 || + C2->getValue().getActiveBits() > 64) + return MayAlias; + // We know that: // - both GEPs begin indexing from the exact same pointer; // - the last indices in both GEPs are constants, indexing into a struct; @@ -1178,11 +1229,13 @@ static AliasResult aliasSameBasePointerGEPs(const GEPOperator *GEP1, // than (%alloca - 1), and so is not inbounds, a contradiction. bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp, const DecomposedGEP &DecompGEP, const DecomposedGEP &DecompObject, - LocationSize ObjectAccessSize) { + LocationSize MaybeObjectAccessSize) { // If the object access size is unknown, or the GEP isn't inbounds, bail. - if (ObjectAccessSize == MemoryLocation::UnknownSize || !GEPOp->isInBounds()) + if (MaybeObjectAccessSize == LocationSize::unknown() || !GEPOp->isInBounds()) return false; + const uint64_t ObjectAccessSize = MaybeObjectAccessSize.getValue(); + // We need the object to be an alloca or a globalvariable, and want to know // the offset of the pointer from the object precisely, so no variable // indices are allowed. 
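These hunks also switch size handling from uint64_t to LocationSize, bailing out on unknown sizes before calling getValue(). A small sketch of that consumption pattern, with hypothetical names:

#include "llvm/ADT/APInt.h"
#include "llvm/Analysis/MemoryLocation.h"
using namespace llvm;

// Unknown sizes defeat the reasoning; otherwise compare the precise byte
// count against an APInt gap, as the offset checks above and below do.
static bool accessFitsInGap(LocationSize MaybeSize, const APInt &Gap) {
  if (MaybeSize == LocationSize::unknown())
    return false;
  return Gap.uge(MaybeSize.getValue());
}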
@@ -1191,8 +1244,8 @@ bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp, !DecompObject.VarIndices.empty()) return false; - int64_t ObjectBaseOffset = DecompObject.StructOffset + - DecompObject.OtherOffset; + APInt ObjectBaseOffset = DecompObject.StructOffset + + DecompObject.OtherOffset; // If the GEP has no variable indices, we know the precise offset // from the base, then use it. If the GEP has variable indices, @@ -1200,10 +1253,11 @@ bool BasicAAResult::isGEPBaseAtNegativeOffset(const GEPOperator *GEPOp, // false in that case. if (!DecompGEP.VarIndices.empty()) return false; - int64_t GEPBaseOffset = DecompGEP.StructOffset; + + APInt GEPBaseOffset = DecompGEP.StructOffset; GEPBaseOffset += DecompGEP.OtherOffset; - return (GEPBaseOffset >= ObjectBaseOffset + (int64_t)ObjectAccessSize); + return GEPBaseOffset.sge(ObjectBaseOffset + (int64_t)ObjectAccessSize); } /// Provides a bunch of ad-hoc rules to disambiguate a GEP instruction against @@ -1218,13 +1272,17 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, LocationSize V2Size, const AAMDNodes &V2AAInfo, const Value *UnderlyingV1, const Value *UnderlyingV2) { DecomposedGEP DecompGEP1, DecompGEP2; + unsigned MaxPointerSize = getMaxPointerSize(DL); + DecompGEP1.StructOffset = DecompGEP1.OtherOffset = APInt(MaxPointerSize, 0); + DecompGEP2.StructOffset = DecompGEP2.OtherOffset = APInt(MaxPointerSize, 0); + bool GEP1MaxLookupReached = DecomposeGEPExpression(GEP1, DecompGEP1, DL, &AC, DT); bool GEP2MaxLookupReached = DecomposeGEPExpression(V2, DecompGEP2, DL, &AC, DT); - int64_t GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; - int64_t GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; + APInt GEP1BaseOffset = DecompGEP1.StructOffset + DecompGEP1.OtherOffset; + APInt GEP2BaseOffset = DecompGEP2.StructOffset + DecompGEP2.OtherOffset; assert(DecompGEP1.Base == UnderlyingV1 && DecompGEP2.Base == UnderlyingV2 && "DecomposeGEPExpression returned a result different from " @@ -1247,8 +1305,8 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, return NoAlias; // Do the base pointers alias? AliasResult BaseAlias = - aliasCheck(UnderlyingV1, MemoryLocation::UnknownSize, AAMDNodes(), - UnderlyingV2, MemoryLocation::UnknownSize, AAMDNodes()); + aliasCheck(UnderlyingV1, LocationSize::unknown(), AAMDNodes(), + UnderlyingV2, LocationSize::unknown(), AAMDNodes()); // Check for geps of non-aliasing underlying pointers where the offsets are // identical. @@ -1307,13 +1365,12 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // pointer, we know they cannot alias. // If both accesses are unknown size, we can't do anything useful here. - if (V1Size == MemoryLocation::UnknownSize && - V2Size == MemoryLocation::UnknownSize) + if (V1Size == LocationSize::unknown() && V2Size == LocationSize::unknown()) return MayAlias; - AliasResult R = aliasCheck(UnderlyingV1, MemoryLocation::UnknownSize, - AAMDNodes(), V2, MemoryLocation::UnknownSize, - V2AAInfo, nullptr, UnderlyingV2); + AliasResult R = + aliasCheck(UnderlyingV1, LocationSize::unknown(), AAMDNodes(), V2, + LocationSize::unknown(), V2AAInfo, nullptr, UnderlyingV2); if (R != MustAlias) { // If V2 may alias GEP base pointer, conservatively returns MayAlias. // If V2 is known not to alias GEP base pointer, then the two values @@ -1343,9 +1400,9 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // that the objects are partially overlapping. 
If the difference is // greater, we know they do not overlap. if (GEP1BaseOffset != 0 && DecompGEP1.VarIndices.empty()) { - if (GEP1BaseOffset >= 0) { - if (V2Size != MemoryLocation::UnknownSize) { - if ((uint64_t)GEP1BaseOffset < V2Size) + if (GEP1BaseOffset.sge(0)) { + if (V2Size != LocationSize::unknown()) { + if (GEP1BaseOffset.ult(V2Size.getValue())) return PartialAlias; return NoAlias; } @@ -1358,9 +1415,9 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // GEP1 V2 // We need to know that V2Size is not unknown, otherwise we might have // stripped a gep with negative index ('gep <ptr>, -1, ...). - if (V1Size != MemoryLocation::UnknownSize && - V2Size != MemoryLocation::UnknownSize) { - if (-(uint64_t)GEP1BaseOffset < V1Size) + if (V1Size != LocationSize::unknown() && + V2Size != LocationSize::unknown()) { + if ((-GEP1BaseOffset).ult(V1Size.getValue())) return PartialAlias; return NoAlias; } @@ -1368,7 +1425,7 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, } if (!DecompGEP1.VarIndices.empty()) { - uint64_t Modulo = 0; + APInt Modulo(MaxPointerSize, 0); bool AllPositive = true; for (unsigned i = 0, e = DecompGEP1.VarIndices.size(); i != e; ++i) { @@ -1376,7 +1433,7 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // Grab the least significant bit set in any of the scales. We // don't need std::abs here (even if the scale's negative) as we'll // be ^'ing Modulo with itself later. - Modulo |= (uint64_t)DecompGEP1.VarIndices[i].Scale; + Modulo |= DecompGEP1.VarIndices[i].Scale; if (AllPositive) { // If the Value could change between cycles, then any reasoning about @@ -1397,9 +1454,9 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // If the variable begins with a zero then we know it's // positive, regardless of whether the value is signed or // unsigned. - int64_t Scale = DecompGEP1.VarIndices[i].Scale; + APInt Scale = DecompGEP1.VarIndices[i].Scale; AllPositive = - (SignKnownZero && Scale >= 0) || (SignKnownOne && Scale < 0); + (SignKnownZero && Scale.sge(0)) || (SignKnownOne && Scale.slt(0)); } } @@ -1408,16 +1465,18 @@ BasicAAResult::aliasGEP(const GEPOperator *GEP1, LocationSize V1Size, // We can compute the difference between the two addresses // mod Modulo. Check whether that difference guarantees that the // two locations do not alias. - uint64_t ModOffset = (uint64_t)GEP1BaseOffset & (Modulo - 1); - if (V1Size != MemoryLocation::UnknownSize && - V2Size != MemoryLocation::UnknownSize && ModOffset >= V2Size && - V1Size <= Modulo - ModOffset) + APInt ModOffset = GEP1BaseOffset & (Modulo - 1); + if (V1Size != LocationSize::unknown() && + V2Size != LocationSize::unknown() && ModOffset.uge(V2Size.getValue()) && + (Modulo - ModOffset).uge(V1Size.getValue())) return NoAlias; // If we know all the variables are positive, then GEP1 >= GEP1BasePtr. // If GEP1BasePtr > V2 (GEP1BaseOffset > 0) then we know the pointers // don't alias if V2Size can fit in the gap between V2 and GEP1BasePtr. - if (AllPositive && GEP1BaseOffset > 0 && V2Size <= (uint64_t)GEP1BaseOffset) + if (AllPositive && GEP1BaseOffset.sgt(0) && + V2Size != LocationSize::unknown() && + GEP1BaseOffset.uge(V2Size.getValue())) return NoAlias; if (constantOffsetHeuristic(DecompGEP1.VarIndices, V1Size, V2Size, @@ -1597,7 +1656,7 @@ AliasResult BasicAAResult::aliasPHI(const PHINode *PN, LocationSize PNSize, // unknown to represent all the possible values the GEP could advance the // pointer to. 
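The modulo reasoning behind the ModOffset check in the aliasGEP hunk above can be verified with concrete numbers. The sketch below is plain C++ with made-up sizes (Modulo, offsets and access sizes are illustrative), and it only demonstrates the interval argument, not the LLVM code: every address GEP1 can form differs from V2 by ModOffset plus a multiple of Modulo, so when ModOffset >= V2Size and Modulo - ModOffset >= V1Size the two accesses can never overlap.

  #include <cassert>
  #include <cstdint>

  int main() {
    const int64_t Modulo = 8, ModOffset = 4;  // all variable scales are multiples of 8
    const int64_t V1Size = 4, V2Size = 4;     // both accesses are 4 bytes

    assert(ModOffset >= V2Size && Modulo - ModOffset >= V1Size);

    for (int64_t k = -4; k <= 4; ++k) {
      int64_t Begin1 = k * Modulo + ModOffset;    // GEP1 access: [Begin1, Begin1 + V1Size)
      int64_t End1 = Begin1 + V1Size;
      bool Overlap = Begin1 < V2Size && End1 > 0; // V2 access: [0, V2Size)
      assert(!Overlap);                           // never overlaps, for any k
    }
    return 0;
  }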
if (isRecursive) - PNSize = MemoryLocation::UnknownSize; + PNSize = LocationSize::unknown(); AliasResult Alias = aliasCheck(V2, V2Size, V2AAInfo, V1Srcs[0], @@ -1631,7 +1690,7 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size, const Value *O1, const Value *O2) { // If either of the memory references is empty, it doesn't matter what the // pointer values are. - if (V1Size == 0 || V2Size == 0) + if (V1Size.isZero() || V2Size.isZero()) return NoAlias; // Strip off any casts if they exist. @@ -1705,10 +1764,10 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size, // If the size of one access is larger than the entire object on the other // side, then we know such behavior is undefined and can assume no alias. bool NullIsValidLocation = NullPointerIsDefined(&F); - if ((V1Size != MemoryLocation::UnknownSize && - isObjectSmallerThan(O2, V1Size, DL, TLI, NullIsValidLocation)) || - (V2Size != MemoryLocation::UnknownSize && - isObjectSmallerThan(O1, V2Size, DL, TLI, NullIsValidLocation))) + if ((V1Size.isPrecise() && isObjectSmallerThan(O2, V1Size.getValue(), DL, TLI, + NullIsValidLocation)) || + (V2Size.isPrecise() && isObjectSmallerThan(O1, V2Size.getValue(), DL, TLI, + NullIsValidLocation))) return NoAlias; // Check the cache before climbing up use-def chains. This also terminates @@ -1766,10 +1825,9 @@ AliasResult BasicAAResult::aliasCheck(const Value *V1, LocationSize V1Size, // If both pointers are pointing into the same object and one of them // accesses the entire object, then the accesses must overlap in some way. if (O1 == O2) - if (V1Size != MemoryLocation::UnknownSize && - V2Size != MemoryLocation::UnknownSize && - (isObjectSize(O1, V1Size, DL, TLI, NullIsValidLocation) || - isObjectSize(O2, V2Size, DL, TLI, NullIsValidLocation))) + if (V1Size.isPrecise() && V2Size.isPrecise() && + (isObjectSize(O1, V1Size.getValue(), DL, TLI, NullIsValidLocation) || + isObjectSize(O2, V2Size.getValue(), DL, TLI, NullIsValidLocation))) return AliasCache[Locs] = PartialAlias; // Recurse back into the best AA results we have, potentially with refined @@ -1824,7 +1882,7 @@ void BasicAAResult::GetIndexDifference( for (unsigned i = 0, e = Src.size(); i != e; ++i) { const Value *V = Src[i].V; unsigned ZExtBits = Src[i].ZExtBits, SExtBits = Src[i].SExtBits; - int64_t Scale = Src[i].Scale; + APInt Scale = Src[i].Scale; // Find V in Dest. This is N^2, but pointer indices almost never have more // than a few variable indexes. @@ -1844,7 +1902,7 @@ void BasicAAResult::GetIndexDifference( } // If we didn't consume this entry, add it to the end of the Dest list. 
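The "access larger than the whole object" rule used in the aliasCheck hunk above (the isObjectSmallerThan checks, now guarded by LocationSize::isPrecise) rests on the argument sketched below. This is an ordinary C++ analogy, not LLVM IR, and the types are only an example:

  #include <cstdint>

  int32_t Small;                       // a 4-byte object with known size

  int32_t example(int64_t *P) {
    *P = 0;        // 8-byte store; if it touched 'Small' it could not fit
                   // inside the 4-byte object, which would be UB
    return Small;  // 4-byte load; may therefore be treated as NoAlias
                   // with the store above
  }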
- if (Scale) { + if (!!Scale) { VariableGEPIndex Entry = {V, ZExtBits, SExtBits, -Scale}; Dest.push_back(Entry); } @@ -1852,13 +1910,16 @@ void BasicAAResult::GetIndexDifference( } bool BasicAAResult::constantOffsetHeuristic( - const SmallVectorImpl<VariableGEPIndex> &VarIndices, LocationSize V1Size, - LocationSize V2Size, int64_t BaseOffset, AssumptionCache *AC, - DominatorTree *DT) { - if (VarIndices.size() != 2 || V1Size == MemoryLocation::UnknownSize || - V2Size == MemoryLocation::UnknownSize) + const SmallVectorImpl<VariableGEPIndex> &VarIndices, + LocationSize MaybeV1Size, LocationSize MaybeV2Size, APInt BaseOffset, + AssumptionCache *AC, DominatorTree *DT) { + if (VarIndices.size() != 2 || MaybeV1Size == LocationSize::unknown() || + MaybeV2Size == LocationSize::unknown()) return false; + const uint64_t V1Size = MaybeV1Size.getValue(); + const uint64_t V2Size = MaybeV2Size.getValue(); + const VariableGEPIndex &Var0 = VarIndices[0], &Var1 = VarIndices[1]; if (Var0.ZExtBits != Var1.ZExtBits || Var0.SExtBits != Var1.SExtBits || @@ -1895,14 +1956,15 @@ bool BasicAAResult::constantOffsetHeuristic( // the minimum distance between %i and %i + 5 is 3. APInt MinDiff = V0Offset - V1Offset, Wrapped = -MinDiff; MinDiff = APIntOps::umin(MinDiff, Wrapped); - uint64_t MinDiffBytes = MinDiff.getZExtValue() * std::abs(Var0.Scale); + APInt MinDiffBytes = + MinDiff.zextOrTrunc(Var0.Scale.getBitWidth()) * Var0.Scale.abs(); // We can't definitely say whether GEP1 is before or after V2 due to wrapping // arithmetic (i.e. for some values of GEP1 and V2 GEP1 < V2, and for other // values GEP1 > V2). We'll therefore only declare NoAlias if both V1Size and // V2Size can fit in the MinDiffBytes gap. - return V1Size + std::abs(BaseOffset) <= MinDiffBytes && - V2Size + std::abs(BaseOffset) <= MinDiffBytes; + return MinDiffBytes.uge(V1Size + BaseOffset.abs()) && + MinDiffBytes.uge(V2Size + BaseOffset.abs()); } //===----------------------------------------------------------------------===// diff --git a/lib/Analysis/BlockFrequencyInfo.cpp b/lib/Analysis/BlockFrequencyInfo.cpp index 41c295895213..ef27c36517ea 100644 --- a/lib/Analysis/BlockFrequencyInfo.cpp +++ b/lib/Analysis/BlockFrequencyInfo.cpp @@ -252,8 +252,8 @@ void BlockFrequencyInfo::setBlockFreqAndScale( /// Pop up a ghostview window with the current block frequency propagation /// rendered using dot. -void BlockFrequencyInfo::view() const { - ViewGraph(const_cast<BlockFrequencyInfo *>(this), "BlockFrequencyDAGs"); +void BlockFrequencyInfo::view(StringRef title) const { + ViewGraph(const_cast<BlockFrequencyInfo *>(this), title); } const Function *BlockFrequencyInfo::getFunction() const { diff --git a/lib/Analysis/BlockFrequencyInfoImpl.cpp b/lib/Analysis/BlockFrequencyInfoImpl.cpp index 3d095068e7ff..08ebcc47a807 100644 --- a/lib/Analysis/BlockFrequencyInfoImpl.cpp +++ b/lib/Analysis/BlockFrequencyInfoImpl.cpp @@ -156,9 +156,9 @@ static void combineWeight(Weight &W, const Weight &OtherW) { static void combineWeightsBySorting(WeightList &Weights) { // Sort so edges to the same node are adjacent. - llvm::sort(Weights.begin(), Weights.end(), - [](const Weight &L, - const Weight &R) { return L.TargetNode < R.TargetNode; }); + llvm::sort(Weights, [](const Weight &L, const Weight &R) { + return L.TargetNode < R.TargetNode; + }); // Combine adjacent edges. 
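The wrapped-distance computation in the constantOffsetHeuristic hunk above (take the smaller of the difference and its negation, then scale into bytes) can be reproduced with plain integers. The index width, scale and access sizes below are illustrative; the point is that with an IndexBits-wide index, offsets differing by 5 stay at least min(5, 2^IndexBits - 5) index units apart no matter how the index wraps:

  #include <algorithm>
  #include <cassert>
  #include <cstdint>

  int main() {
    const uint64_t IndexBits = 3;                 // width of the GEP index
    const uint64_t Mask = (1u << IndexBits) - 1;  // 0b111
    const uint64_t V0Offset = 5, V1Offset = 0;    // constant parts of the two indices

    uint64_t MinDiff = (V0Offset - V1Offset) & Mask;   // 5
    uint64_t Wrapped = (0 - MinDiff) & Mask;           // 3
    MinDiff = std::min(MinDiff, Wrapped);              // minimum wrapped distance: 3

    const uint64_t Scale = 4;                          // bytes per indexed element
    uint64_t MinDiffBytes = MinDiff * Scale;           // 12-byte guaranteed gap

    const uint64_t V1Size = 4, V2Size = 4, BaseOffset = 0;
    assert(MinDiffBytes >= V1Size + BaseOffset &&
           MinDiffBytes >= V2Size + BaseOffset);       // both accesses fit in the gap -> NoAlias
    return 0;
  }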
WeightList::iterator O = Weights.begin(); @@ -573,7 +573,9 @@ BlockFrequencyInfoImplBase::getProfileCountFromFreq(const Function &F, APInt BlockFreq(128, Freq); APInt EntryFreq(128, getEntryFreq()); BlockCount *= BlockFreq; - BlockCount = BlockCount.udiv(EntryFreq); + // Rounded division of BlockCount by EntryFreq. Since EntryFreq is unsigned + // lshr by 1 gives EntryFreq/2. + BlockCount = (BlockCount + EntryFreq.lshr(1)).udiv(EntryFreq); return BlockCount.getLimitedValue(); } @@ -705,7 +707,7 @@ static void findIrreducibleHeaders( "Expected irreducible CFG; -loop-info is likely invalid"); if (Headers.size() == InSCC.size()) { // Every block is a header. - llvm::sort(Headers.begin(), Headers.end()); + llvm::sort(Headers); return; } @@ -740,8 +742,8 @@ static void findIrreducibleHeaders( Others.push_back(Irr.Node); LLVM_DEBUG(dbgs() << " => other = " << BFI.getBlockName(Irr.Node) << "\n"); } - llvm::sort(Headers.begin(), Headers.end()); - llvm::sort(Others.begin(), Others.end()); + llvm::sort(Headers); + llvm::sort(Others); } static void createIrreducibleLoop( diff --git a/lib/Analysis/BranchProbabilityInfo.cpp b/lib/Analysis/BranchProbabilityInfo.cpp index 54a657073f0f..7f544b27fe9d 100644 --- a/lib/Analysis/BranchProbabilityInfo.cpp +++ b/lib/Analysis/BranchProbabilityInfo.cpp @@ -135,7 +135,7 @@ static const uint32_t IH_NONTAKEN_WEIGHT = 1; /// Add \p BB to PostDominatedByUnreachable set if applicable. void BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) { - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 0) { if (isa<UnreachableInst>(TI) || // If this block is terminated by a call to @@ -167,7 +167,7 @@ BranchProbabilityInfo::updatePostDominatedByUnreachable(const BasicBlock *BB) { void BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) { assert(!PostDominatedByColdCall.count(BB)); - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); if (TI->getNumSuccessors() == 0) return; @@ -202,7 +202,7 @@ BranchProbabilityInfo::updatePostDominatedByColdCall(const BasicBlock *BB) { /// Predict that a successor which leads necessarily to an /// unreachable-terminated block as extremely unlikely. bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); (void) TI; assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); assert(!isa<InvokeInst>(TI) && @@ -246,7 +246,7 @@ bool BranchProbabilityInfo::calcUnreachableHeuristics(const BasicBlock *BB) { // heuristic. The probability of the edge coming to unreachable block is // set to min of metadata and unreachable heuristic. bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || isa<IndirectBrInst>(TI))) return false; @@ -348,7 +348,7 @@ bool BranchProbabilityInfo::calcMetadataWeights(const BasicBlock *BB) { /// Return true if we could compute the weights for cold edges. /// Return false, otherwise. 
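The rounded division introduced in BlockFrequencyInfoImplBase::getProfileCountFromFreq above (add half of the divisor before the truncating udiv) is the usual round-to-nearest trick for unsigned integers. A minimal plain-C++ version with made-up numbers:

  #include <cassert>
  #include <cstdint>

  uint64_t divideRounded(uint64_t N, uint64_t D) {
    return (N + D / 2) / D;   // round to nearest instead of always down
  }

  int main() {
    assert(divideRounded(10, 4) == 3);  // 2.5 rounds up to 3
    assert(divideRounded(9, 4) == 2);   // 2.25 rounds down to 2
    assert(10 / 4 == 2);                // plain truncating division for comparison
    return 0;
  }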
bool BranchProbabilityInfo::calcColdCallHeuristics(const BasicBlock *BB) { - const TerminatorInst *TI = BB->getTerminator(); + const Instruction *TI = BB->getTerminator(); (void) TI; assert(TI->getNumSuccessors() > 1 && "expected more than one successor!"); assert(!isa<InvokeInst>(TI) && diff --git a/lib/Analysis/CFG.cpp b/lib/Analysis/CFG.cpp index a319be8092f9..aa880a62b754 100644 --- a/lib/Analysis/CFG.cpp +++ b/lib/Analysis/CFG.cpp @@ -71,7 +71,7 @@ void llvm::FindFunctionBackedges(const Function &F, /// successor. unsigned llvm::GetSuccessorNumber(const BasicBlock *BB, const BasicBlock *Succ) { - const TerminatorInst *Term = BB->getTerminator(); + const Instruction *Term = BB->getTerminator(); #ifndef NDEBUG unsigned e = Term->getNumSuccessors(); #endif @@ -85,8 +85,9 @@ unsigned llvm::GetSuccessorNumber(const BasicBlock *BB, /// isCriticalEdge - Return true if the specified edge is a critical edge. /// Critical edges are edges from a block with multiple successors to a block /// with multiple predecessors. -bool llvm::isCriticalEdge(const TerminatorInst *TI, unsigned SuccNum, +bool llvm::isCriticalEdge(const Instruction *TI, unsigned SuccNum, bool AllowIdenticalEdges) { + assert(TI->isTerminator() && "Must be a terminator to have successors!"); assert(SuccNum < TI->getNumSuccessors() && "Illegal edge specification!"); if (TI->getNumSuccessors() == 1) return false; diff --git a/lib/Analysis/CFGPrinter.cpp b/lib/Analysis/CFGPrinter.cpp index 5b170dfa7903..6d01e9d5d447 100644 --- a/lib/Analysis/CFGPrinter.cpp +++ b/lib/Analysis/CFGPrinter.cpp @@ -7,9 +7,10 @@ // //===----------------------------------------------------------------------===// // -// This file defines a '-dot-cfg' analysis pass, which emits the -// cfg.<fnname>.dot file for each function in the program, with a graph of the -// CFG for that function. +// This file defines a `-dot-cfg` analysis pass, which emits the +// `<prefix>.<fnname>.dot` file for each function in the program, with a graph +// of the CFG for that function. The default value for `<prefix>` is `cfg` but +// can be customized as needed. // // The other main feature of this file is that it implements the // Function::viewCFG method, which is useful for debugging passes which operate @@ -27,6 +28,10 @@ static cl::opt<std::string> CFGFuncName( cl::desc("The name of a function (or its substring)" " whose CFG is viewed/printed.")); +static cl::opt<std::string> CFGDotFilenamePrefix( + "cfg-dot-filename-prefix", cl::Hidden, + cl::desc("The prefix used for the CFG dot file names.")); + namespace { struct CFGViewerLegacyPass : public FunctionPass { static char ID; // Pass identifcation, replacement for typeid @@ -90,7 +95,8 @@ PreservedAnalyses CFGOnlyViewerPass::run(Function &F, static void writeCFGToDotFile(Function &F, bool CFGOnly = false) { if (!CFGFuncName.empty() && !F.getName().contains(CFGFuncName)) return; - std::string Filename = ("cfg." + F.getName() + ".dot").str(); + std::string Filename = + (CFGDotFilenamePrefix + "." 
+ F.getName() + ".dot").str(); errs() << "Writing '" << Filename << "'..."; std::error_code EC; diff --git a/lib/Analysis/CFLAndersAliasAnalysis.cpp b/lib/Analysis/CFLAndersAliasAnalysis.cpp index 194983418b08..1c61dd369a05 100644 --- a/lib/Analysis/CFLAndersAliasAnalysis.cpp +++ b/lib/Analysis/CFLAndersAliasAnalysis.cpp @@ -395,7 +395,7 @@ populateAliasMap(DenseMap<const Value *, std::vector<OffsetValue>> &AliasMap, } // Sort AliasList for faster lookup - llvm::sort(AliasList.begin(), AliasList.end()); + llvm::sort(AliasList); } } @@ -479,7 +479,7 @@ static void populateExternalRelations( } // Remove duplicates in ExtRelations - llvm::sort(ExtRelations.begin(), ExtRelations.end()); + llvm::sort(ExtRelations); ExtRelations.erase(std::unique(ExtRelations.begin(), ExtRelations.end()), ExtRelations.end()); } @@ -515,10 +515,9 @@ CFLAndersAAResult::FunctionInfo::getAttrs(const Value *V) const { return None; } -bool CFLAndersAAResult::FunctionInfo::mayAlias(const Value *LHS, - LocationSize LHSSize, - const Value *RHS, - LocationSize RHSSize) const { +bool CFLAndersAAResult::FunctionInfo::mayAlias( + const Value *LHS, LocationSize MaybeLHSSize, const Value *RHS, + LocationSize MaybeRHSSize) const { assert(LHS && RHS); // Check if we've seen LHS and RHS before. Sometimes LHS or RHS can be created @@ -557,11 +556,14 @@ bool CFLAndersAAResult::FunctionInfo::mayAlias(const Value *LHS, OffsetValue{RHS, 0}, Comparator); if (RangePair.first != RangePair.second) { - // Be conservative about UnknownSize - if (LHSSize == MemoryLocation::UnknownSize || - RHSSize == MemoryLocation::UnknownSize) + // Be conservative about unknown sizes + if (MaybeLHSSize == LocationSize::unknown() || + MaybeRHSSize == LocationSize::unknown()) return true; + const uint64_t LHSSize = MaybeLHSSize.getValue(); + const uint64_t RHSSize = MaybeRHSSize.getValue(); + for (const auto &OVal : make_range(RangePair)) { // Be conservative about UnknownOffset if (OVal.Offset == UnknownOffset) diff --git a/lib/Analysis/CFLGraph.h b/lib/Analysis/CFLGraph.h index 86812009da7c..12121d717433 100644 --- a/lib/Analysis/CFLGraph.h +++ b/lib/Analysis/CFLGraph.h @@ -594,7 +594,7 @@ template <typename CFLAA> class CFLGraphBuilder { // Determines whether or not we an instruction is useless to us (e.g. // FenceInst) static bool hasUsefulEdges(Instruction *Inst) { - bool IsNonInvokeRetTerminator = isa<TerminatorInst>(Inst) && + bool IsNonInvokeRetTerminator = Inst->isTerminator() && !isa<InvokeInst>(Inst) && !isa<ReturnInst>(Inst); return !isa<CmpInst>(Inst) && !isa<FenceInst>(Inst) && diff --git a/lib/Analysis/CGSCCPassManager.cpp b/lib/Analysis/CGSCCPassManager.cpp index b325afb8e7c5..fd2292ced017 100644 --- a/lib/Analysis/CGSCCPassManager.cpp +++ b/lib/Analysis/CGSCCPassManager.cpp @@ -54,6 +54,11 @@ PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, CGSCCUpdateResult &>::run(LazyCallGraph::SCC &InitialC, CGSCCAnalysisManager &AM, LazyCallGraph &G, CGSCCUpdateResult &UR) { + // Request PassInstrumentation from analysis manager, will use it to run + // instrumenting callbacks for the passes later. 
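The PassInstrumentation plumbing that the CGSCCPassManager hunk below adds follows a simple protocol: ask every registered "before" callback whether the pass should run at all, then fire the matching "after" (or "after invalidated") callback once it has. The standalone sketch below shows only that shape; the type and member names are invented for illustration and are not the LLVM PassInstrumentation API:

  #include <functional>
  #include <string>
  #include <vector>

  struct InstrumentationSketch {
    std::vector<std::function<bool(const std::string &)>> Before;
    std::vector<std::function<void(const std::string &)>> After;

    // Returns false if any callback asks to skip the pass entirely.
    bool runBefore(const std::string &PassName) const {
      for (const auto &Callback : Before)
        if (!Callback(PassName))
          return false;
      return true;
    }

    void runAfter(const std::string &PassName) const {
      for (const auto &Callback : After)
        Callback(PassName);
    }
  };

A pass-manager loop then wraps each pass as "if (!PI.runBefore(Name)) continue; run the pass; PI.runAfter(Name);", which is the pattern the hunk below applies to CGSCC passes.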
+ PassInstrumentation PI = + AM.getResult<PassInstrumentationAnalysis>(InitialC, G); + PreservedAnalyses PA = PreservedAnalyses::all(); if (DebugLogging) @@ -67,8 +72,18 @@ PassManager<LazyCallGraph::SCC, CGSCCAnalysisManager, LazyCallGraph &, if (DebugLogging) dbgs() << "Running pass: " << Pass->name() << " on " << *C << "\n"; + // Check the PassInstrumentation's BeforePass callbacks before running the + // pass, skip its execution completely if asked to (callback returns false). + if (!PI.runBeforePass(*Pass, *C)) + continue; + PreservedAnalyses PassPA = Pass->run(*C, AM, G, UR); + if (UR.InvalidatedSCCs.count(C)) + PI.runAfterPassInvalidated<LazyCallGraph::SCC>(*Pass); + else + PI.runAfterPass<LazyCallGraph::SCC>(*Pass, *C); + // Update the SCC if necessary. C = UR.UpdatedC ? UR.UpdatedC : C; diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 8e8535abd053..c57d8ef69d69 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -30,10 +30,13 @@ add_llvm_library(LLVMAnalysis DominanceFrontier.cpp EHPersonalities.cpp GlobalsModRef.cpp + GuardUtils.cpp + IVDescriptors.cpp IVUsers.cpp IndirectCallPromotionAnalysis.cpp InlineCost.cpp InstCount.cpp + InstructionPrecedenceTracking.cpp InstructionSimplify.cpp Interval.cpp IntervalPartition.cpp @@ -42,6 +45,7 @@ add_llvm_library(LLVMAnalysis LazyBlockFrequencyInfo.cpp LazyCallGraph.cpp LazyValueInfo.cpp + LegacyDivergenceAnalysis.cpp Lint.cpp Loads.cpp LoopAccessAnalysis.cpp @@ -64,6 +68,7 @@ add_llvm_library(LLVMAnalysis ObjCARCInstKind.cpp OptimizationRemarkEmitter.cpp OrderedBasicBlock.cpp + OrderedInstructions.cpp PHITransAddr.cpp PhiValues.cpp PostDominators.cpp @@ -76,6 +81,8 @@ add_llvm_library(LLVMAnalysis ScalarEvolutionAliasAnalysis.cpp ScalarEvolutionExpander.cpp ScalarEvolutionNormalization.cpp + StackSafetyAnalysis.cpp + SyncDependenceAnalysis.cpp SyntheticCountsUtils.cpp TargetLibraryInfo.cpp TargetTransformInfo.cpp diff --git a/lib/Analysis/CallGraph.cpp b/lib/Analysis/CallGraph.cpp index cbdf5f63c557..0da678e1611b 100644 --- a/lib/Analysis/CallGraph.cpp +++ b/lib/Analysis/CallGraph.cpp @@ -97,8 +97,7 @@ void CallGraph::print(raw_ostream &OS) const { for (const auto &I : *this) Nodes.push_back(I.second.get()); - llvm::sort(Nodes.begin(), Nodes.end(), - [](CallGraphNode *LHS, CallGraphNode *RHS) { + llvm::sort(Nodes, [](CallGraphNode *LHS, CallGraphNode *RHS) { if (Function *LF = LHS->getFunction()) if (Function *RF = RHS->getFunction()) return LF->getName() < RF->getName(); diff --git a/lib/Analysis/CallGraphSCCPass.cpp b/lib/Analysis/CallGraphSCCPass.cpp index 4c33c420b65d..0aed57a39387 100644 --- a/lib/Analysis/CallGraphSCCPass.cpp +++ b/lib/Analysis/CallGraphSCCPass.cpp @@ -22,11 +22,13 @@ #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManagers.h" #include "llvm/IR/Module.h" #include "llvm/IR/OptBisect.h" +#include "llvm/IR/PassTimingInfo.h" #include "llvm/Pass.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" @@ -123,24 +125,34 @@ bool CGPassManager::RunPassOnSCC(Pass *P, CallGraphSCC &CurSCC, Module &M = CG.getModule(); if (!PM) { - CallGraphSCCPass *CGSP = (CallGraphSCCPass*)P; + CallGraphSCCPass *CGSP = (CallGraphSCCPass *)P; if (!CallGraphUpToDate) { DevirtualizedCall |= RefreshCallGraph(CurSCC, CG, false); CallGraphUpToDate = true; } { - unsigned InstrCount = 0; + unsigned 
InstrCount, SCCCount = 0; + StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount; bool EmitICRemark = M.shouldEmitInstrCountChangedRemark(); TimeRegion PassTimer(getPassTimer(CGSP)); if (EmitICRemark) - InstrCount = initSizeRemarkInfo(M); + InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount); Changed = CGSP->runOnSCC(CurSCC); - // If the pass modified the module, it may have modified the instruction - // count of the module. Try emitting a remark. - if (EmitICRemark) - emitInstrCountChangedRemark(P, M, InstrCount); + if (EmitICRemark) { + // FIXME: Add getInstructionCount to CallGraphSCC. + SCCCount = M.getInstructionCount(); + // Is there a difference in the number of instructions in the module? + if (SCCCount != InstrCount) { + // Yep. Emit a remark and update InstrCount. + int64_t Delta = + static_cast<int64_t>(SCCCount) - static_cast<int64_t>(InstrCount); + emitInstrCountChangedRemark(P, M, Delta, InstrCount, + FunctionToInstrCount); + InstrCount = SCCCount; + } + } } // After the CGSCCPass is done, when assertions are enabled, use @@ -621,23 +633,40 @@ namespace { bool runOnSCC(CallGraphSCC &SCC) override { bool BannerPrinted = false; - auto PrintBannerOnce = [&] () { + auto PrintBannerOnce = [&]() { if (BannerPrinted) return; OS << Banner; BannerPrinted = true; - }; + }; + + bool NeedModule = llvm::forcePrintModuleIR(); + if (isFunctionInPrintList("*") && NeedModule) { + PrintBannerOnce(); + OS << "\n"; + SCC.getCallGraph().getModule().print(OS, nullptr); + return false; + } + bool FoundFunction = false; for (CallGraphNode *CGN : SCC) { if (Function *F = CGN->getFunction()) { if (!F->isDeclaration() && isFunctionInPrintList(F->getName())) { - PrintBannerOnce(); - F->print(OS); + FoundFunction = true; + if (!NeedModule) { + PrintBannerOnce(); + F->print(OS); + } } } else if (isFunctionInPrintList("*")) { PrintBannerOnce(); OS << "\nPrinting <null> Function\n"; } } + if (NeedModule && FoundFunction) { + PrintBannerOnce(); + OS << "\n"; + SCC.getCallGraph().getModule().print(OS, nullptr); + } return false; } diff --git a/lib/Analysis/CaptureTracking.cpp b/lib/Analysis/CaptureTracking.cpp index d4f73bdb4361..669f4f2835fa 100644 --- a/lib/Analysis/CaptureTracking.cpp +++ b/lib/Analysis/CaptureTracking.cpp @@ -23,7 +23,6 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/OrderedBasicBlock.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" @@ -158,7 +157,8 @@ namespace { /// storing the value (or part of it) into memory anywhere automatically /// counts as capturing it or not. 
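The size-remark change in RunPassOnSCC above compares two unsigned instruction counts and reports their difference; the casts to int64_t before subtracting matter because unsigned subtraction wraps. A small standalone illustration (the counts are made up, and the wrapped value assumes a 32-bit unsigned):

  #include <cassert>
  #include <cstdint>

  int main() {
    unsigned InstrCount = 100;   // module size before the pass
    unsigned SCCCount = 90;      // after the pass removed 10 instructions

    int64_t Delta =
        static_cast<int64_t>(SCCCount) - static_cast<int64_t>(InstrCount);
    assert(Delta == -10);                       // signed delta is negative, as expected

    // The unsigned difference wraps instead of going negative (32-bit unsigned).
    assert(SCCCount - InstrCount == 4294967286u);
    return 0;
  }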
bool llvm::PointerMayBeCaptured(const Value *V, - bool ReturnCaptures, bool StoreCaptures) { + bool ReturnCaptures, bool StoreCaptures, + unsigned MaxUsesToExplore) { assert(!isa<GlobalValue>(V) && "It doesn't make sense to ask whether a global is captured."); @@ -169,7 +169,7 @@ bool llvm::PointerMayBeCaptured(const Value *V, (void)StoreCaptures; SimpleCaptureTracker SCT(ReturnCaptures); - PointerMayBeCaptured(V, &SCT); + PointerMayBeCaptured(V, &SCT, MaxUsesToExplore); return SCT.Captured; } @@ -186,13 +186,15 @@ bool llvm::PointerMayBeCaptured(const Value *V, bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures, bool StoreCaptures, const Instruction *I, const DominatorTree *DT, bool IncludeI, - OrderedBasicBlock *OBB) { + OrderedBasicBlock *OBB, + unsigned MaxUsesToExplore) { assert(!isa<GlobalValue>(V) && "It doesn't make sense to ask whether a global is captured."); bool UseNewOBB = OBB == nullptr; if (!DT) - return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures); + return PointerMayBeCaptured(V, ReturnCaptures, StoreCaptures, + MaxUsesToExplore); if (UseNewOBB) OBB = new OrderedBasicBlock(I->getParent()); @@ -200,29 +202,25 @@ bool llvm::PointerMayBeCapturedBefore(const Value *V, bool ReturnCaptures, // with StoreCaptures. CapturesBefore CB(ReturnCaptures, I, DT, IncludeI, OBB); - PointerMayBeCaptured(V, &CB); + PointerMayBeCaptured(V, &CB, MaxUsesToExplore); if (UseNewOBB) delete OBB; return CB.Captured; } -/// TODO: Write a new FunctionPass AliasAnalysis so that it can keep -/// a cache. Then we can move the code from BasicAliasAnalysis into -/// that path, and remove this threshold. -static int const Threshold = 20; - -void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { +void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker, + unsigned MaxUsesToExplore) { assert(V->getType()->isPointerTy() && "Capture is for pointers only!"); - SmallVector<const Use *, Threshold> Worklist; - SmallSet<const Use *, Threshold> Visited; + SmallVector<const Use *, DefaultMaxUsesToExplore> Worklist; + SmallSet<const Use *, DefaultMaxUsesToExplore> Visited; auto AddUses = [&](const Value *V) { - int Count = 0; + unsigned Count = 0; for (const Use &U : V->uses()) { // If there are lots of uses, conservatively say that the value // is captured to avoid taking too much compile time. - if (Count++ >= Threshold) + if (Count++ >= MaxUsesToExplore) return Tracker->tooManyUses(); if (!Visited.insert(&U).second) continue; @@ -241,11 +239,12 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { switch (I->getOpcode()) { case Instruction::Call: case Instruction::Invoke: { - CallSite CS(I); + auto *Call = cast<CallBase>(I); // Not captured if the callee is readonly, doesn't return a copy through // its return value and doesn't unwind (a readonly function can leak bits // by throwing an exception or not depending on the input value). - if (CS.onlyReadsMemory() && CS.doesNotThrow() && I->getType()->isVoidTy()) + if (Call->onlyReadsMemory() && Call->doesNotThrow() && + Call->getType()->isVoidTy()) break; // The pointer is not captured if returned pointer is not captured. @@ -253,14 +252,14 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { // marked with nocapture do not capture. This means that places like // GetUnderlyingObject in ValueTracking or DecomposeGEPExpression // in BasicAA also need to know about this property. 
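The MaxUsesToExplore parameter threaded through the CaptureTracking hunks above bounds a worklist walk over the pointer's transitive uses. Below is a standalone sketch of that shape only; the Node type and mayBeCaptured name are invented for illustration and this is not the LLVM CaptureTracker interface:

  #include <set>
  #include <vector>

  struct Node {
    std::vector<Node *> Uses;   // stand-in for a Value's use list
    bool Captures = false;      // stand-in for "this use escapes the pointer"
  };

  bool mayBeCaptured(Node *Root, unsigned MaxUsesToExplore) {
    std::vector<Node *> Worklist{Root};
    std::set<Node *> Visited{Root};
    unsigned Seen = 0;
    while (!Worklist.empty()) {
      Node *N = Worklist.back();
      Worklist.pop_back();
      for (Node *U : N->Uses) {
        if (++Seen > MaxUsesToExplore)
          return true;                 // budget exhausted: conservatively "captured"
        if (!Visited.insert(U).second)
          continue;                    // already examined this use
        if (U->Captures)
          return true;                 // found a use that really captures
        Worklist.push_back(U);         // otherwise keep following its uses
      }
    }
    return false;                      // no capturing use found within the budget
  }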
- if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) { - AddUses(I); + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call)) { + AddUses(Call); break; } // Volatile operations effectively capture the memory location that they // load and store to. - if (auto *MI = dyn_cast<MemIntrinsic>(I)) + if (auto *MI = dyn_cast<MemIntrinsic>(Call)) if (MI->isVolatile()) if (Tracker->captured(U)) return; @@ -272,13 +271,14 @@ void llvm::PointerMayBeCaptured(const Value *V, CaptureTracker *Tracker) { // that loading a value from a pointer does not cause the pointer to be // captured, even though the loaded value might be the pointer itself // (think of self-referential objects). - CallSite::data_operand_iterator B = - CS.data_operands_begin(), E = CS.data_operands_end(); - for (CallSite::data_operand_iterator A = B; A != E; ++A) - if (A->get() == V && !CS.doesNotCapture(A - B)) + for (auto IdxOpPair : enumerate(Call->data_ops())) { + int Idx = IdxOpPair.index(); + Value *A = IdxOpPair.value(); + if (A == V && !Call->doesNotCapture(Idx)) // The parameter is not marked 'nocapture' - captured. if (Tracker->captured(U)) return; + } break; } case Instruction::Load: diff --git a/lib/Analysis/CmpInstAnalysis.cpp b/lib/Analysis/CmpInstAnalysis.cpp index 159c1a2d135a..27071babec5c 100644 --- a/lib/Analysis/CmpInstAnalysis.cpp +++ b/lib/Analysis/CmpInstAnalysis.cpp @@ -40,28 +40,28 @@ unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) { } } -Value *llvm::getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, - CmpInst::Predicate &NewICmpPred) { +Constant *llvm::getPredForICmpCode(unsigned Code, bool Sign, Type *OpTy, + CmpInst::Predicate &Pred) { switch (Code) { default: llvm_unreachable("Illegal ICmp code!"); case 0: // False. - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case 1: NewICmpPred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; - case 2: NewICmpPred = ICmpInst::ICMP_EQ; break; - case 3: NewICmpPred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; - case 4: NewICmpPred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; - case 5: NewICmpPred = ICmpInst::ICMP_NE; break; - case 6: NewICmpPred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; + return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 0); + case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; + case 2: Pred = ICmpInst::ICMP_EQ; break; + case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; + case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; + case 5: Pred = ICmpInst::ICMP_NE; break; + case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; case 7: // True. 
- return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); + return ConstantInt::get(CmpInst::makeCmpResultType(OpTy), 1); } return nullptr; } -bool llvm::PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { - return (CmpInst::isSigned(p1) == CmpInst::isSigned(p2)) || - (CmpInst::isSigned(p1) && ICmpInst::isEquality(p2)) || - (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1)); +bool llvm::predicatesFoldable(ICmpInst::Predicate P1, ICmpInst::Predicate P2) { + return (CmpInst::isSigned(P1) == CmpInst::isSigned(P2)) || + (CmpInst::isSigned(P1) && ICmpInst::isEquality(P2)) || + (CmpInst::isSigned(P2) && ICmpInst::isEquality(P1)); } bool llvm::decomposeBitTestICmp(Value *LHS, Value *RHS, diff --git a/lib/Analysis/ConstantFolding.cpp b/lib/Analysis/ConstantFolding.cpp index c5281c57bc19..5da29d6d2372 100644 --- a/lib/Analysis/ConstantFolding.cpp +++ b/lib/Analysis/ConstantFolding.cpp @@ -347,9 +347,20 @@ Constant *llvm::ConstantFoldLoadThroughBitcast(Constant *C, Type *DestTy, // We're simulating a load through a pointer that was bitcast to point to // a different type, so we can try to walk down through the initial - // elements of an aggregate to see if some part of th e aggregate is + // elements of an aggregate to see if some part of the aggregate is // castable to implement the "load" semantic model. - C = C->getAggregateElement(0u); + if (SrcTy->isStructTy()) { + // Struct types might have leading zero-length elements like [0 x i32], + // which are certainly not what we are looking for, so skip them. + unsigned Elem = 0; + Constant *ElemC; + do { + ElemC = C->getAggregateElement(Elem++); + } while (ElemC && DL.getTypeSizeInBits(ElemC->getType()) == 0); + C = ElemC; + } else { + C = C->getAggregateElement(0u); + } } while (C); return nullptr; @@ -960,10 +971,8 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, NewIdxs.size() > *LastIRIndex) { InRangeIndex = LastIRIndex; for (unsigned I = 0; I <= *LastIRIndex; ++I) - if (NewIdxs[I] != InnermostGEP->getOperand(I + 1)) { - InRangeIndex = None; - break; - } + if (NewIdxs[I] != InnermostGEP->getOperand(I + 1)) + return nullptr; } // Create a GEP. @@ -985,11 +994,6 @@ Constant *SymbolicallyEvaluateGEP(const GEPOperator *GEP, /// returned, if not, null is returned. Note that this function can fail when /// attempting to fold instructions like loads and stores, which have no /// constant expression form. -/// -/// TODO: This function neither utilizes nor preserves nsw/nuw/inbounds/inrange -/// etc information, due to only being passed an opcode and operands. Constant -/// folding using this function strips this information. 
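The 3-bit predicate codes decoded by getPredForICmpCode above (bit 2 for "less", bit 1 for "equal", bit 0 for "greater", matching switch cases 1 through 6) make combining comparisons of the same operands a plain bitwise operation on their codes. A small self-checking example in ordinary C++ (the loop bounds are arbitrary):

  #include <cassert>

  // Code of the single relation that actually holds between A and B.
  unsigned icmpCode(int A, int B) {
    return (A < B ? 4u : 0u) | (A == B ? 2u : 0u) | (A > B ? 1u : 0u);
  }

  int main() {
    for (int A = -2; A <= 2; ++A)
      for (int B = -2; B <= 2; ++B) {
        // (A < B) || (A == B) has code 4 | 2 == 6, which decodes to "<=".
        assert(((icmpCode(A, B) & 6u) != 0) == (A <= B));
        // (A < B) && (A != B) has code 4 & 5 == 4, which decodes to "<".
        assert(((icmpCode(A, B) & 4u) != 0) == (A < B));
      }
    return 0;
  }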
-/// Constant *ConstantFoldInstOperandsImpl(const Value *InstOrCE, unsigned Opcode, ArrayRef<Constant *> Ops, const DataLayout &DL, @@ -1370,6 +1374,8 @@ bool llvm::canConstantFoldCallTo(ImmutableCallSite CS, const Function *F) { case Intrinsic::fabs: case Intrinsic::minnum: case Intrinsic::maxnum: + case Intrinsic::minimum: + case Intrinsic::maximum: case Intrinsic::log: case Intrinsic::log2: case Intrinsic::log10: @@ -1389,6 +1395,8 @@ bool llvm::canConstantFoldCallTo(ImmutableCallSite CS, const Function *F) { case Intrinsic::ctpop: case Intrinsic::ctlz: case Intrinsic::cttz: + case Intrinsic::fshl: + case Intrinsic::fshr: case Intrinsic::fma: case Intrinsic::fmuladd: case Intrinsic::copysign: @@ -1402,6 +1410,10 @@ bool llvm::canConstantFoldCallTo(ImmutableCallSite CS, const Function *F) { case Intrinsic::usub_with_overflow: case Intrinsic::smul_with_overflow: case Intrinsic::umul_with_overflow: + case Intrinsic::sadd_sat: + case Intrinsic::uadd_sat: + case Intrinsic::ssub_sat: + case Intrinsic::usub_sat: case Intrinsic::convert_from_fp16: case Intrinsic::convert_to_fp16: case Intrinsic::bitreverse: @@ -1413,6 +1425,23 @@ bool llvm::canConstantFoldCallTo(ImmutableCallSite CS, const Function *F) { case Intrinsic::x86_sse2_cvtsd2si64: case Intrinsic::x86_sse2_cvttsd2si: case Intrinsic::x86_sse2_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: + case Intrinsic::is_constant: return true; default: return false; @@ -1553,7 +1582,7 @@ Constant *ConstantFoldBinaryFP(double (*NativeFP)(double, double), double V, /// result. Returns null if the conversion cannot be performed, otherwise /// returns the Constant value resulting from the conversion. Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero, - Type *Ty) { + Type *Ty, bool IsSigned) { // All of these conversion intrinsics form an integer of at most 64bits. 
unsigned ResultWidth = Ty->getIntegerBitWidth(); assert(ResultWidth <= 64 && @@ -1565,11 +1594,11 @@ Constant *ConstantFoldSSEConvertToInt(const APFloat &Val, bool roundTowardZero, : APFloat::rmNearestTiesToEven; APFloat::opStatus status = Val.convertToInteger(makeMutableArrayRef(UIntVal), ResultWidth, - /*isSigned=*/true, mode, &isExact); + IsSigned, mode, &isExact); if (status != APFloat::opOK && (!roundTowardZero || status != APFloat::opInexact)) return nullptr; - return ConstantInt::get(Ty, UIntVal, /*isSigned=*/true); + return ConstantInt::get(Ty, UIntVal, IsSigned); } double getValueAsDouble(ConstantFP *Op) { @@ -1587,14 +1616,49 @@ double getValueAsDouble(ConstantFP *Op) { return APF.convertToDouble(); } +static bool isManifestConstant(const Constant *c) { + if (isa<ConstantData>(c)) { + return true; + } else if (isa<ConstantAggregate>(c) || isa<ConstantExpr>(c)) { + for (const Value *subc : c->operand_values()) { + if (!isManifestConstant(cast<Constant>(subc))) + return false; + } + return true; + } + return false; +} + +static bool getConstIntOrUndef(Value *Op, const APInt *&C) { + if (auto *CI = dyn_cast<ConstantInt>(Op)) { + C = &CI->getValue(); + return true; + } + if (isa<UndefValue>(Op)) { + C = nullptr; + return true; + } + return false; +} + Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, ArrayRef<Constant *> Operands, const TargetLibraryInfo *TLI, ImmutableCallSite CS) { if (Operands.size() == 1) { + if (IntrinsicID == Intrinsic::is_constant) { + // We know we have a "Constant" argument. But we want to only + // return true for manifest constants, not those that depend on + // constants with unknowable values, e.g. GlobalValue or BlockAddress. + if (isManifestConstant(Operands[0])) + return ConstantInt::getTrue(Ty->getContext()); + return nullptr; + } if (isa<UndefValue>(Operands[0])) { - // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN - if (IntrinsicID == Intrinsic::cos) + // cosine(arg) is between -1 and 1. cosine(invalid arg) is NaN. + // ctpop() is between 0 and bitwidth, pick 0 for undef. 
+ if (IntrinsicID == Intrinsic::cos || + IntrinsicID == Intrinsic::ctpop) return Constant::getNullValue(Ty); if (IntrinsicID == Intrinsic::bswap || IntrinsicID == Intrinsic::bitreverse || @@ -1849,7 +1913,8 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, if (ConstantFP *FPOp = dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), - /*roundTowardZero=*/false, Ty); + /*roundTowardZero=*/false, Ty, + /*IsSigned*/true); break; case Intrinsic::x86_sse_cvttss2si: case Intrinsic::x86_sse_cvttss2si64: @@ -1858,7 +1923,8 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, if (ConstantFP *FPOp = dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), - /*roundTowardZero=*/true, Ty); + /*roundTowardZero=*/true, Ty, + /*IsSigned*/true); break; } } @@ -1899,6 +1965,18 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, return ConstantFP::get(Ty->getContext(), maxnum(C1, C2)); } + if (IntrinsicID == Intrinsic::minimum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), minimum(C1, C2)); + } + + if (IntrinsicID == Intrinsic::maximum) { + const APFloat &C1 = Op1->getValueAPF(); + const APFloat &C2 = Op2->getValueAPF(); + return ConstantFP::get(Ty->getContext(), maximum(C1, C2)); + } + if (!TLI) return nullptr; if ((Name == "pow" && TLI->has(LibFunc_pow)) || @@ -1931,58 +2009,149 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, return nullptr; } - if (auto *Op1 = dyn_cast<ConstantInt>(Operands[0])) { - if (auto *Op2 = dyn_cast<ConstantInt>(Operands[1])) { + if (Operands[0]->getType()->isIntegerTy() && + Operands[1]->getType()->isIntegerTy()) { + const APInt *C0, *C1; + if (!getConstIntOrUndef(Operands[0], C0) || + !getConstIntOrUndef(Operands[1], C1)) + return nullptr; + + switch (IntrinsicID) { + default: break; + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + // Even if both operands are undef, we cannot fold muls to undef + // in the general case. For example, on i2 there are no inputs + // that would produce { i2 -1, i1 true } as the result. 
+ if (!C0 || !C1) + return Constant::getNullValue(Ty); + LLVM_FALLTHROUGH; + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: { + if (!C0 || !C1) + return UndefValue::get(Ty); + + APInt Res; + bool Overflow; switch (IntrinsicID) { - default: break; + default: llvm_unreachable("Invalid case"); case Intrinsic::sadd_with_overflow: + Res = C0->sadd_ov(*C1, Overflow); + break; case Intrinsic::uadd_with_overflow: + Res = C0->uadd_ov(*C1, Overflow); + break; case Intrinsic::ssub_with_overflow: + Res = C0->ssub_ov(*C1, Overflow); + break; case Intrinsic::usub_with_overflow: + Res = C0->usub_ov(*C1, Overflow); + break; case Intrinsic::smul_with_overflow: - case Intrinsic::umul_with_overflow: { - APInt Res; - bool Overflow; - switch (IntrinsicID) { - default: llvm_unreachable("Invalid case"); - case Intrinsic::sadd_with_overflow: - Res = Op1->getValue().sadd_ov(Op2->getValue(), Overflow); - break; - case Intrinsic::uadd_with_overflow: - Res = Op1->getValue().uadd_ov(Op2->getValue(), Overflow); - break; - case Intrinsic::ssub_with_overflow: - Res = Op1->getValue().ssub_ov(Op2->getValue(), Overflow); - break; - case Intrinsic::usub_with_overflow: - Res = Op1->getValue().usub_ov(Op2->getValue(), Overflow); - break; - case Intrinsic::smul_with_overflow: - Res = Op1->getValue().smul_ov(Op2->getValue(), Overflow); - break; - case Intrinsic::umul_with_overflow: - Res = Op1->getValue().umul_ov(Op2->getValue(), Overflow); - break; - } - Constant *Ops[] = { - ConstantInt::get(Ty->getContext(), Res), - ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow) - }; - return ConstantStruct::get(cast<StructType>(Ty), Ops); - } - case Intrinsic::cttz: - if (Op2->isOne() && Op1->isZero()) // cttz(0, 1) is undef. - return UndefValue::get(Ty); - return ConstantInt::get(Ty, Op1->getValue().countTrailingZeros()); - case Intrinsic::ctlz: - if (Op2->isOne() && Op1->isZero()) // ctlz(0, 1) is undef. - return UndefValue::get(Ty); - return ConstantInt::get(Ty, Op1->getValue().countLeadingZeros()); + Res = C0->smul_ov(*C1, Overflow); + break; + case Intrinsic::umul_with_overflow: + Res = C0->umul_ov(*C1, Overflow); + break; } + Constant *Ops[] = { + ConstantInt::get(Ty->getContext(), Res), + ConstantInt::get(Type::getInt1Ty(Ty->getContext()), Overflow) + }; + return ConstantStruct::get(cast<StructType>(Ty), Ops); + } + case Intrinsic::uadd_sat: + case Intrinsic::sadd_sat: + if (!C0 && !C1) + return UndefValue::get(Ty); + if (!C0 || !C1) + return Constant::getAllOnesValue(Ty); + if (IntrinsicID == Intrinsic::uadd_sat) + return ConstantInt::get(Ty, C0->uadd_sat(*C1)); + else + return ConstantInt::get(Ty, C0->sadd_sat(*C1)); + case Intrinsic::usub_sat: + case Intrinsic::ssub_sat: + if (!C0 && !C1) + return UndefValue::get(Ty); + if (!C0 || !C1) + return Constant::getNullValue(Ty); + if (IntrinsicID == Intrinsic::usub_sat) + return ConstantInt::get(Ty, C0->usub_sat(*C1)); + else + return ConstantInt::get(Ty, C0->ssub_sat(*C1)); + case Intrinsic::cttz: + case Intrinsic::ctlz: + assert(C1 && "Must be constant int"); + + // cttz(0, 1) and ctlz(0, 1) are undef. 
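The saturating-arithmetic folds above (uadd_sat/usub_sat and their signed variants) rely on results clamping to the type's range instead of wrapping; that is also why one undef operand lets the add fold to all-ones and the sub fold to zero, since undef may be chosen as the value that forces saturation. A plain-C++ sketch of the unsigned semantics at an 8-bit width (this mirrors the intrinsic semantics, not the APInt implementation):

  #include <cassert>
  #include <cstdint>

  uint8_t uaddSat(uint8_t A, uint8_t B) {
    unsigned Sum = unsigned(A) + unsigned(B);
    return Sum > 0xFF ? 0xFF : uint8_t(Sum);   // clamp to the maximum
  }

  uint8_t usubSat(uint8_t A, uint8_t B) {
    return A > B ? uint8_t(A - B) : uint8_t(0); // clamp to zero
  }

  int main() {
    assert(uaddSat(200, 100) == 255);  // clamps instead of wrapping to 44
    assert(uaddSat(20, 30) == 50);     // in-range sums are unchanged
    assert(usubSat(10, 30) == 0);      // clamps instead of wrapping to 236
    return 0;
  }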
+ if (C1->isOneValue() && (!C0 || C0->isNullValue())) + return UndefValue::get(Ty); + if (!C0) + return Constant::getNullValue(Ty); + if (IntrinsicID == Intrinsic::cttz) + return ConstantInt::get(Ty, C0->countTrailingZeros()); + else + return ConstantInt::get(Ty, C0->countLeadingZeros()); } return nullptr; } + + // Support ConstantVector in case we have an Undef in the top. + if ((isa<ConstantVector>(Operands[0]) || + isa<ConstantDataVector>(Operands[0])) && + // Check for default rounding mode. + // FIXME: Support other rounding modes? + isa<ConstantInt>(Operands[1]) && + cast<ConstantInt>(Operands[1])->getValue() == 4) { + auto *Op = cast<Constant>(Operands[0]); + switch (IntrinsicID) { + default: break; + case Intrinsic::x86_avx512_vcvtss2si32: + case Intrinsic::x86_avx512_vcvtss2si64: + case Intrinsic::x86_avx512_vcvtsd2si32: + case Intrinsic::x86_avx512_vcvtsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/false, Ty, + /*IsSigned*/true); + break; + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtsd2usi64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/false, Ty, + /*IsSigned*/false); + break; + case Intrinsic::x86_avx512_cvttss2si: + case Intrinsic::x86_avx512_cvttss2si64: + case Intrinsic::x86_avx512_cvttsd2si: + case Intrinsic::x86_avx512_cvttsd2si64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/true, Ty, + /*IsSigned*/true); + break; + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: + if (ConstantFP *FPOp = + dyn_cast_or_null<ConstantFP>(Op->getAggregateElement(0U))) + return ConstantFoldSSEConvertToInt(FPOp->getValueAPF(), + /*roundTowardZero=*/true, Ty, + /*IsSigned*/false); + break; + } + } return nullptr; } @@ -2010,6 +2179,36 @@ Constant *ConstantFoldScalarCall(StringRef Name, unsigned IntrinsicID, Type *Ty, } } + if (IntrinsicID == Intrinsic::fshl || IntrinsicID == Intrinsic::fshr) { + const APInt *C0, *C1, *C2; + if (!getConstIntOrUndef(Operands[0], C0) || + !getConstIntOrUndef(Operands[1], C1) || + !getConstIntOrUndef(Operands[2], C2)) + return nullptr; + + bool IsRight = IntrinsicID == Intrinsic::fshr; + if (!C2) + return Operands[IsRight ? 1 : 0]; + if (!C0 && !C1) + return UndefValue::get(Ty); + + // The shift amount is interpreted as modulo the bitwidth. If the shift + // amount is effectively 0, avoid UB due to oversized inverse shift below. + unsigned BitWidth = C2->getBitWidth(); + unsigned ShAmt = C2->urem(BitWidth); + if (!ShAmt) + return Operands[IsRight ? 1 : 0]; + + // (C0 << ShlAmt) | (C1 >> LshrAmt) + unsigned LshrAmt = IsRight ? ShAmt : BitWidth - ShAmt; + unsigned ShlAmt = !IsRight ? 
ShAmt : BitWidth - ShAmt; + if (!C0) + return ConstantInt::get(Ty, C1->lshr(LshrAmt)); + if (!C1) + return ConstantInt::get(Ty, C0->shl(ShlAmt)); + return ConstantInt::get(Ty, C0->shl(ShlAmt) | C1->lshr(LshrAmt)); + } + return nullptr; } diff --git a/lib/Analysis/DemandedBits.cpp b/lib/Analysis/DemandedBits.cpp index e7637cd88327..34f785fb02be 100644 --- a/lib/Analysis/DemandedBits.cpp +++ b/lib/Analysis/DemandedBits.cpp @@ -21,8 +21,7 @@ #include "llvm/Analysis/DemandedBits.h" #include "llvm/ADT/APInt.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ValueTracking.h" @@ -39,6 +38,7 @@ #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" #include "llvm/IR/Use.h" #include "llvm/Pass.h" @@ -50,6 +50,7 @@ #include <cstdint> using namespace llvm; +using namespace llvm::PatternMatch; #define DEBUG_TYPE "demanded-bits" @@ -78,13 +79,14 @@ void DemandedBitsWrapperPass::print(raw_ostream &OS, const Module *M) const { } static bool isAlwaysLive(Instruction *I) { - return isa<TerminatorInst>(I) || isa<DbgInfoIntrinsic>(I) || - I->isEHPad() || I->mayHaveSideEffects(); + return I->isTerminator() || isa<DbgInfoIntrinsic>(I) || I->isEHPad() || + I->mayHaveSideEffects(); } void DemandedBits::determineLiveOperandBits( - const Instruction *UserI, const Instruction *I, unsigned OperandNo, - const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2) { + const Instruction *UserI, const Value *Val, unsigned OperandNo, + const APInt &AOut, APInt &AB, KnownBits &Known, KnownBits &Known2, + bool &KnownBitsComputed) { unsigned BitWidth = AB.getBitWidth(); // We're called once per operand, but for some instructions, we need to @@ -95,7 +97,11 @@ void DemandedBits::determineLiveOperandBits( // provided here. auto ComputeKnownBits = [&](unsigned BitWidth, const Value *V1, const Value *V2) { - const DataLayout &DL = I->getModule()->getDataLayout(); + if (KnownBitsComputed) + return; + KnownBitsComputed = true; + + const DataLayout &DL = UserI->getModule()->getDataLayout(); Known = KnownBits(BitWidth); computeKnownBits(V1, Known, DL, 0, &AC, UserI, &DT); @@ -127,7 +133,7 @@ void DemandedBits::determineLiveOperandBits( // We need some output bits, so we need all bits of the // input to the left of, and including, the leftmost bit // known to be one. - ComputeKnownBits(BitWidth, I, nullptr); + ComputeKnownBits(BitWidth, Val, nullptr); AB = APInt::getHighBitsSet(BitWidth, std::min(BitWidth, Known.countMaxLeadingZeros()+1)); } @@ -137,11 +143,33 @@ void DemandedBits::determineLiveOperandBits( // We need some output bits, so we need all bits of the // input to the right of, and including, the rightmost bit // known to be one. - ComputeKnownBits(BitWidth, I, nullptr); + ComputeKnownBits(BitWidth, Val, nullptr); AB = APInt::getLowBitsSet(BitWidth, std::min(BitWidth, Known.countMaxTrailingZeros()+1)); } break; + case Intrinsic::fshl: + case Intrinsic::fshr: { + const APInt *SA; + if (OperandNo == 2) { + // Shift amount is modulo the bitwidth. For powers of two we have + // SA % BW == SA & (BW - 1). + if (isPowerOf2_32(BitWidth)) + AB = BitWidth - 1; + } else if (match(II->getOperand(2), m_APInt(SA))) { + // Normalize to funnel shift left. APInt shifts of BitWidth are well- + // defined, so no need to special-case zero shifts here. 
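The fshl/fshr folding in the ConstantFolding hunk above, and the demanded-bits handling for the same intrinsics that follows, both use the fact that the shift amount is interpreted modulo the bit width and that the result splices the two inputs together. A plain-C++ model at a fixed 32-bit width (illustrative values, not the APInt code):

  #include <cassert>
  #include <cstdint>

  uint32_t fshl(uint32_t C0, uint32_t C1, uint32_t C2) {
    unsigned ShAmt = C2 % 32;          // shift amount is modulo the width
    if (ShAmt == 0)
      return C0;                       // avoid the oversized inverse shift
    return (C0 << ShAmt) | (C1 >> (32 - ShAmt));
  }

  uint32_t fshr(uint32_t C0, uint32_t C1, uint32_t C2) {
    unsigned ShAmt = C2 % 32;
    if (ShAmt == 0)
      return C1;
    return (C0 << (32 - ShAmt)) | (C1 >> ShAmt);
  }

  int main() {
    // With both inputs equal, funnel shifts are rotates.
    assert(fshl(0x12345678, 0x12345678, 8) == 0x34567812);
    assert(fshr(0x12345678, 0x12345678, 8) == 0x78123456);
    // In general the result keeps C0's low 24 bits on top and C1's high 8 bits below.
    assert(fshl(0xAABBCCDD, 0x11223344, 8) == 0xBBCCDD11);
    return 0;
  }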
+ uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + if (OperandNo == 0) + AB = AOut.lshr(ShiftAmt); + else if (OperandNo == 1) + AB = AOut.shl(BitWidth - ShiftAmt); + } + break; + } } break; case Instruction::Add: @@ -153,8 +181,9 @@ void DemandedBits::determineLiveOperandBits( AB = APInt::getLowBitsSet(BitWidth, AOut.getActiveBits()); break; case Instruction::Shl: - if (OperandNo == 0) - if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); AB = AOut.lshr(ShiftAmt); @@ -166,10 +195,12 @@ void DemandedBits::determineLiveOperandBits( else if (S->hasNoUnsignedWrap()) AB |= APInt::getHighBitsSet(BitWidth, ShiftAmt); } + } break; case Instruction::LShr: - if (OperandNo == 0) - if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); AB = AOut.shl(ShiftAmt); @@ -178,10 +209,12 @@ void DemandedBits::determineLiveOperandBits( if (cast<LShrOperator>(UserI)->isExact()) AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); } + } break; case Instruction::AShr: - if (OperandNo == 0) - if (auto *ShiftAmtC = dyn_cast<ConstantInt>(UserI->getOperand(1))) { + if (OperandNo == 0) { + const APInt *ShiftAmtC; + if (match(UserI->getOperand(1), m_APInt(ShiftAmtC))) { uint64_t ShiftAmt = ShiftAmtC->getLimitedValue(BitWidth - 1); AB = AOut.shl(ShiftAmt); // Because the high input bit is replicated into the @@ -196,6 +229,7 @@ void DemandedBits::determineLiveOperandBits( if (cast<AShrOperator>(UserI)->isExact()) AB |= APInt::getLowBitsSet(BitWidth, ShiftAmt); } + } break; case Instruction::And: AB = AOut; @@ -204,14 +238,11 @@ void DemandedBits::determineLiveOperandBits( // other operand are dead (unless they're both zero, in which // case they can't both be dead, so just mark the LHS bits as // dead). - if (OperandNo == 0) { - ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + if (OperandNo == 0) AB &= ~Known2.Zero; - } else { - if (!isa<Instruction>(UserI->getOperand(0))) - ComputeKnownBits(BitWidth, UserI->getOperand(0), I); + else AB &= ~(Known.Zero & ~Known2.Zero); - } break; case Instruction::Or: AB = AOut; @@ -220,14 +251,11 @@ void DemandedBits::determineLiveOperandBits( // other operand are dead (unless they're both one, in which // case they can't both be dead, so just mark the LHS bits as // dead). 
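The And rule in determineLiveOperandBits above can be followed with concrete masks: bits of one operand that are known to be zero in the other operand cannot affect the result, so they are not demanded. A worked plain-C++ example with made-up masks (uint32_t stands in for APInt/KnownBits):

  #include <cassert>
  #include <cstdint>

  int main() {
    // Suppose UserI is "r = x & 0xFF00" and only the low 16 bits of r are
    // demanded by r's own users.
    uint32_t AOut = 0x0000FFFF;              // alive bits of the 'and' result
    uint32_t KnownZeroOfMask = ~0x0000FF00u; // bits known zero in the constant operand

    uint32_t AliveBitsOfX = AOut & ~KnownZeroOfMask;
    assert(AliveBitsOfX == 0x0000FF00);      // only 8 bits of x are demanded
    return 0;
  }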
- if (OperandNo == 0) { - ComputeKnownBits(BitWidth, I, UserI->getOperand(1)); + ComputeKnownBits(BitWidth, UserI->getOperand(0), UserI->getOperand(1)); + if (OperandNo == 0) AB &= ~Known2.One; - } else { - if (!isa<Instruction>(UserI->getOperand(0))) - ComputeKnownBits(BitWidth, UserI->getOperand(0), I); + else AB &= ~(Known.One & ~Known2.One); - } break; case Instruction::Xor: case Instruction::PHI: @@ -253,6 +281,15 @@ void DemandedBits::determineLiveOperandBits( if (OperandNo != 0) AB = AOut; break; + case Instruction::ExtractElement: + if (OperandNo == 0) + AB = AOut; + break; + case Instruction::InsertElement: + case Instruction::ShuffleVector: + if (OperandNo == 0 || OperandNo == 1) + AB = AOut; + break; } } @@ -275,8 +312,9 @@ void DemandedBits::performAnalysis() { Visited.clear(); AliveBits.clear(); + DeadUses.clear(); - SmallVector<Instruction*, 128> Worklist; + SmallSetVector<Instruction*, 16> Worklist; // Collect the set of "root" instructions that are known live. for (Instruction &I : instructions(F)) { @@ -288,9 +326,10 @@ void DemandedBits::performAnalysis() { // bits and add the instruction to the work list. For other instructions // add their operands to the work list (for integer values operands, mark // all bits as live). - if (IntegerType *IT = dyn_cast<IntegerType>(I.getType())) { - if (AliveBits.try_emplace(&I, IT->getBitWidth(), 0).second) - Worklist.push_back(&I); + Type *T = I.getType(); + if (T->isIntOrIntVectorTy()) { + if (AliveBits.try_emplace(&I, T->getScalarSizeInBits(), 0).second) + Worklist.insert(&I); continue; } @@ -298,9 +337,10 @@ void DemandedBits::performAnalysis() { // Non-integer-typed instructions... for (Use &OI : I.operands()) { if (Instruction *J = dyn_cast<Instruction>(OI)) { - if (IntegerType *IT = dyn_cast<IntegerType>(J->getType())) - AliveBits[J] = APInt::getAllOnesValue(IT->getBitWidth()); - Worklist.push_back(J); + Type *T = J->getType(); + if (T->isIntOrIntVectorTy()) + AliveBits[J] = APInt::getAllOnesValue(T->getScalarSizeInBits()); + Worklist.insert(J); } } // To save memory, we don't add I to the Visited set here. Instead, we @@ -315,35 +355,51 @@ void DemandedBits::performAnalysis() { LLVM_DEBUG(dbgs() << "DemandedBits: Visiting: " << *UserI); APInt AOut; - if (UserI->getType()->isIntegerTy()) { + if (UserI->getType()->isIntOrIntVectorTy()) { AOut = AliveBits[UserI]; - LLVM_DEBUG(dbgs() << " Alive Out: " << AOut); + LLVM_DEBUG(dbgs() << " Alive Out: 0x" + << Twine::utohexstr(AOut.getLimitedValue())); } LLVM_DEBUG(dbgs() << "\n"); - if (!UserI->getType()->isIntegerTy()) + if (!UserI->getType()->isIntOrIntVectorTy()) Visited.insert(UserI); KnownBits Known, Known2; + bool KnownBitsComputed = false; // Compute the set of alive bits for each operand. These are anded into the // existing set, if any, and if that changes the set of alive bits, the // operand is added to the work-list. for (Use &OI : UserI->operands()) { - if (Instruction *I = dyn_cast<Instruction>(OI)) { - if (IntegerType *IT = dyn_cast<IntegerType>(I->getType())) { - unsigned BitWidth = IT->getBitWidth(); - APInt AB = APInt::getAllOnesValue(BitWidth); - if (UserI->getType()->isIntegerTy() && !AOut && - !isAlwaysLive(UserI)) { - AB = APInt(BitWidth, 0); - } else { - // If all bits of the output are dead, then all bits of the input - // Bits of each operand that are used to compute alive bits of the - // output are alive, all others are dead. 
- determineLiveOperandBits(UserI, I, OI.getOperandNo(), AOut, AB, - Known, Known2); - } + // We also want to detect dead uses of arguments, but will only store + // demanded bits for instructions. + Instruction *I = dyn_cast<Instruction>(OI); + if (!I && !isa<Argument>(OI)) + continue; + + Type *T = OI->getType(); + if (T->isIntOrIntVectorTy()) { + unsigned BitWidth = T->getScalarSizeInBits(); + APInt AB = APInt::getAllOnesValue(BitWidth); + if (UserI->getType()->isIntOrIntVectorTy() && !AOut && + !isAlwaysLive(UserI)) { + // If all bits of the output are dead, then all bits of the input + // are also dead. + AB = APInt(BitWidth, 0); + } else { + // Bits of each operand that are used to compute alive bits of the + // output are alive, all others are dead. + determineLiveOperandBits(UserI, OI, OI.getOperandNo(), AOut, AB, + Known, Known2, KnownBitsComputed); + + // Keep track of uses which have no demanded bits. + if (AB.isNullValue()) + DeadUses.insert(&OI); + else + DeadUses.erase(&OI); + } + if (I) { // If we've added to the set of alive bits (or the operand has not // been previously visited), then re-queue the operand to be visited // again. @@ -355,11 +411,11 @@ void DemandedBits::performAnalysis() { APInt ABNew = AB | ABPrev; if (ABNew != ABPrev || ABI == AliveBits.end()) { AliveBits[I] = std::move(ABNew); - Worklist.push_back(I); + Worklist.insert(I); } - } else if (!Visited.count(I)) { - Worklist.push_back(I); } + } else if (I && !Visited.count(I)) { + Worklist.insert(I); } } } @@ -368,11 +424,13 @@ void DemandedBits::performAnalysis() { APInt DemandedBits::getDemandedBits(Instruction *I) { performAnalysis(); - const DataLayout &DL = I->getModule()->getDataLayout(); auto Found = AliveBits.find(I); if (Found != AliveBits.end()) return Found->second; - return APInt::getAllOnesValue(DL.getTypeSizeInBits(I->getType())); + + const DataLayout &DL = I->getModule()->getDataLayout(); + return APInt::getAllOnesValue( + DL.getTypeSizeInBits(I->getType()->getScalarType())); } bool DemandedBits::isInstructionDead(Instruction *I) { @@ -382,6 +440,31 @@ bool DemandedBits::isInstructionDead(Instruction *I) { !isAlwaysLive(I); } +bool DemandedBits::isUseDead(Use *U) { + // We only track integer uses, everything else is assumed live. + if (!(*U)->getType()->isIntOrIntVectorTy()) + return false; + + // Uses by always-live instructions are never dead. + Instruction *UserI = cast<Instruction>(U->getUser()); + if (isAlwaysLive(UserI)) + return false; + + performAnalysis(); + if (DeadUses.count(U)) + return true; + + // If no output bits are demanded, no input bits are demanded and the use + // is dead. These uses might not be explicitly present in the DeadUses map. 
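Stepping outside the patch for a moment: a hypothetical consumer of the per-use query being defined here might look like the sketch below. The function name countDeadUses is invented for illustration, and a computed DemandedBits result DB is assumed to already be available for F.

#include "llvm/Analysis/DemandedBits.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/InstIterator.h"
using namespace llvm;

// Walk every use in F and count those the analysis proves to have no
// demanded bits.
static unsigned countDeadUses(Function &F, DemandedBits &DB) {
  unsigned NumDead = 0;
  for (Instruction &I : instructions(F))
    for (Use &U : I.operands())
      if (DB.isUseDead(&U))
        ++NumDead;
  return NumDead;
}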
+ if (UserI->getType()->isIntOrIntVectorTy()) { + auto Found = AliveBits.find(UserI); + if (Found != AliveBits.end() && Found->second.isNullValue()) + return true; + } + + return false; +} + void DemandedBits::print(raw_ostream &OS) { performAnalysis(); for (auto &KV : AliveBits) { diff --git a/lib/Analysis/DependenceAnalysis.cpp b/lib/Analysis/DependenceAnalysis.cpp index 79c2728d5620..3f4dfa52e1da 100644 --- a/lib/Analysis/DependenceAnalysis.cpp +++ b/lib/Analysis/DependenceAnalysis.cpp @@ -194,6 +194,13 @@ void DependenceAnalysisWrapperPass::print(raw_ostream &OS, dumpExampleDependence(OS, info.get()); } +PreservedAnalyses +DependenceAnalysisPrinterPass::run(Function &F, FunctionAnalysisManager &FAM) { + OS << "'Dependence Analysis' for function '" << F.getName() << "':\n"; + dumpExampleDependence(OS, &FAM.getResult<DependenceAnalysis>(F)); + return PreservedAnalyses::all(); +} + //===----------------------------------------------------------------------===// // Dependence methods @@ -633,8 +640,8 @@ static AliasResult underlyingObjectsAlias(AliasAnalysis *AA, const MemoryLocation &LocB) { // Check the original locations (minus size) for noalias, which can happen for // tbaa, incompatible underlying object locations, etc. - MemoryLocation LocAS(LocA.Ptr, MemoryLocation::UnknownSize, LocA.AATags); - MemoryLocation LocBS(LocB.Ptr, MemoryLocation::UnknownSize, LocB.AATags); + MemoryLocation LocAS(LocA.Ptr, LocationSize::unknown(), LocA.AATags); + MemoryLocation LocBS(LocB.Ptr, LocationSize::unknown(), LocB.AATags); if (AA->alias(LocAS, LocBS) == NoAlias) return NoAlias; diff --git a/lib/Analysis/DivergenceAnalysis.cpp b/lib/Analysis/DivergenceAnalysis.cpp index f5f1874c9303..7ba23854a3cc 100644 --- a/lib/Analysis/DivergenceAnalysis.cpp +++ b/lib/Analysis/DivergenceAnalysis.cpp @@ -7,8 +7,9 @@ // //===----------------------------------------------------------------------===// // -// This file implements divergence analysis which determines whether a branch -// in a GPU program is divergent.It can help branch optimizations such as jump +// This file implements a general divergence analysis for loop vectorization +// and GPU programs. It determines which branches and values in a loop or GPU +// program are divergent. It can help branch optimizations such as jump // threading and loop unswitching to make better decisions. // // GPU programs typically use the SIMD execution model, where multiple threads @@ -16,25 +17,29 @@ // code contains divergent branches (i.e., threads in a group do not agree on // which path of the branch to take), the group of threads has to execute all // the paths from that branch with different subsets of threads enabled until -// they converge at the immediately post-dominating BB of the paths. +// they re-converge. // // Due to this execution model, some optimizations such as jump -// threading and loop unswitching can be unfortunately harmful when performed on -// divergent branches. Therefore, an analysis that computes which branches in a -// GPU program are divergent can help the compiler to selectively run these -// optimizations. +// threading and loop unswitching can interfere with thread re-convergence. +// Therefore, an analysis that computes which branches in a GPU program are +// divergent can help the compiler to selectively run these optimizations. // -// This file defines divergence analysis which computes a conservative but -// non-trivial approximation of all divergent branches in a GPU program. 
It -// partially implements the approach described in +// This implementation is derived from the Vectorization Analysis of the +// Region Vectorizer (RV). That implementation in turn is based on the approach +// described in // -// Divergence Analysis -// Sampaio, Souza, Collange, Pereira -// TOPLAS '13 +// Improving Performance of OpenCL on CPUs +// Ralf Karrenberg and Sebastian Hack +// CC '12 // -// The divergence analysis identifies the sources of divergence (e.g., special -// variables that hold the thread ID), and recursively marks variables that are -// data or sync dependent on a source of divergence as divergent. +// This DivergenceAnalysis implementation is generic in the sense that it does +// not itself identify original sources of divergence. +// Instead specialized adapter classes, (LoopDivergenceAnalysis) for loops and +// (GPUDivergenceAnalysis) for GPU programs, identify the sources of divergence +// (e.g., special variables that hold the thread ID or the iteration variable). +// +// The generic implementation propagates divergence to variables that are data +// or sync dependent on a source of divergence. // // While data dependency is a well-known concept, the notion of sync dependency // is worth more explanation. Sync dependence characterizes the control flow @@ -54,287 +59,399 @@ // because the branch "br i1 %cond" depends on %tid and affects which value %a // is assigned to. // -// The current implementation has the following limitations: +// The sync dependence detection (which branch induces divergence in which join +// points) is implemented in the SyncDependenceAnalysis. +// +// The current DivergenceAnalysis implementation has the following limitations: // 1. intra-procedural. It conservatively considers the arguments of a // non-kernel-entry function and the return value of a function call as // divergent. // 2. memory as black box. It conservatively considers values loaded from // generic or local address as divergent. This can be improved by leveraging -// pointer analysis. +// pointer analysis and/or by modelling non-escaping memory objects in SSA +// as done in RV. // //===----------------------------------------------------------------------===// #include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/Passes.h" #include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Value.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include <vector> + using namespace llvm; -#define DEBUG_TYPE "divergence" - -namespace { - -class DivergencePropagator { -public: - DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, - PostDominatorTree &PDT, DenseSet<const Value *> &DV) - : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {} - void populateWithSourcesOfDivergence(); - void propagate(); - -private: - // A helper function that explores data dependents of V. - void exploreDataDependency(Value *V); - // A helper function that explores sync dependents of TI. - void exploreSyncDependency(TerminatorInst *TI); - // Computes the influence region from Start to End. This region includes all - // basic blocks on any simple path from Start to End. 
- void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion); - // Finds all users of I that are outside the influence region, and add these - // users to Worklist. - void findUsersOutsideInfluenceRegion( - Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion); - - Function &F; - TargetTransformInfo &TTI; - DominatorTree &DT; - PostDominatorTree &PDT; - std::vector<Value *> Worklist; // Stack for DFS. - DenseSet<const Value *> &DV; // Stores all divergent values. -}; - -void DivergencePropagator::populateWithSourcesOfDivergence() { - Worklist.clear(); - DV.clear(); - for (auto &I : instructions(F)) { - if (TTI.isSourceOfDivergence(&I)) { - Worklist.push_back(&I); - DV.insert(&I); - } +#define DEBUG_TYPE "divergence-analysis" + +// class DivergenceAnalysis +DivergenceAnalysis::DivergenceAnalysis( + const Function &F, const Loop *RegionLoop, const DominatorTree &DT, + const LoopInfo &LI, SyncDependenceAnalysis &SDA, bool IsLCSSAForm) + : F(F), RegionLoop(RegionLoop), DT(DT), LI(LI), SDA(SDA), + IsLCSSAForm(IsLCSSAForm) {} + +void DivergenceAnalysis::markDivergent(const Value &DivVal) { + assert(isa<Instruction>(DivVal) || isa<Argument>(DivVal)); + assert(!isAlwaysUniform(DivVal) && "cannot be a divergent"); + DivergentValues.insert(&DivVal); +} + +void DivergenceAnalysis::addUniformOverride(const Value &UniVal) { + UniformOverrides.insert(&UniVal); +} + +bool DivergenceAnalysis::updateTerminator(const Instruction &Term) const { + if (Term.getNumSuccessors() <= 1) + return false; + if (auto *BranchTerm = dyn_cast<BranchInst>(&Term)) { + assert(BranchTerm->isConditional()); + return isDivergent(*BranchTerm->getCondition()); } - for (auto &Arg : F.args()) { - if (TTI.isSourceOfDivergence(&Arg)) { - Worklist.push_back(&Arg); - DV.insert(&Arg); - } + if (auto *SwitchTerm = dyn_cast<SwitchInst>(&Term)) { + return isDivergent(*SwitchTerm->getCondition()); + } + if (isa<InvokeInst>(Term)) { + return false; // ignore abnormal executions through landingpad } + + llvm_unreachable("unexpected terminator"); } -void DivergencePropagator::exploreSyncDependency(TerminatorInst *TI) { - // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's - // immediate post dominator are divergent. This rule handles if-then-else - // patterns. For example, - // - // if (tid < 5) - // a1 = 1; - // else - // a2 = 2; - // a = phi(a1, a2); // sync dependent on (tid < 5) - BasicBlock *ThisBB = TI->getParent(); - - // Unreachable blocks may not be in the dominator tree. - if (!DT.isReachableFromEntry(ThisBB)) - return; +bool DivergenceAnalysis::updateNormalInstruction(const Instruction &I) const { + // TODO function calls with side effects, etc + for (const auto &Op : I.operands()) { + if (isDivergent(*Op)) + return true; + } + return false; +} - // If the function has no exit blocks or doesn't reach any exit blocks, the - // post dominator may be null. 
- DomTreeNode *ThisNode = PDT.getNode(ThisBB); - if (!ThisNode) - return; +bool DivergenceAnalysis::isTemporalDivergent(const BasicBlock &ObservingBlock, + const Value &Val) const { + const auto *Inst = dyn_cast<const Instruction>(&Val); + if (!Inst) + return false; + // check whether any divergent loop carrying Val terminates before control + // proceeds to ObservingBlock + for (const auto *Loop = LI.getLoopFor(Inst->getParent()); + Loop != RegionLoop && !Loop->contains(&ObservingBlock); + Loop = Loop->getParentLoop()) { + if (DivergentLoops.find(Loop) != DivergentLoops.end()) + return true; + } - BasicBlock *IPostDom = ThisNode->getIDom()->getBlock(); - if (IPostDom == nullptr) - return; + return false; +} - for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) { - // A PHINode is uniform if it returns the same value no matter which path is - // taken. - if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second) - Worklist.push_back(&*I); +bool DivergenceAnalysis::updatePHINode(const PHINode &Phi) const { + // joining divergent disjoint path in Phi parent block + if (!Phi.hasConstantOrUndefValue() && isJoinDivergent(*Phi.getParent())) { + return true; } - // Propagation rule 2: if a value defined in a loop is used outside, the user - // is sync dependent on the condition of the loop exits that dominate the - // user. For example, - // - // int i = 0; - // do { - // i++; - // if (foo(i)) ... // uniform - // } while (i < tid); - // if (bar(i)) ... // divergent + // An incoming value could be divergent by itself. + // Otherwise, an incoming value could be uniform within the loop + // that carries its definition but it may appear divergent + // from outside the loop. This happens when divergent loop exits + // drop definitions of that uniform value in different iterations. // - // A program may contain unstructured loops. Therefore, we cannot leverage - // LoopInfo, which only recognizes natural loops. - // - // The algorithm used here handles both natural and unstructured loops. Given - // a branch TI, we first compute its influence region, the union of all simple - // paths from TI to its immediate post dominator (IPostDom). Then, we search - // for all the values defined in the influence region but used outside. All - // these users are sync dependent on TI. - DenseSet<BasicBlock *> InfluenceRegion; - computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion); - // An insight that can speed up the search process is that all the in-region - // values that are used outside must dominate TI. Therefore, instead of - // searching every basic blocks in the influence region, we search all the - // dominators of TI until it is outside the influence region. 
- BasicBlock *InfluencedBB = ThisBB; - while (InfluenceRegion.count(InfluencedBB)) { - for (auto &I : *InfluencedBB) - findUsersOutsideInfluenceRegion(I, InfluenceRegion); - DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom(); - if (IDomNode == nullptr) - break; - InfluencedBB = IDomNode->getBlock(); + // for (int i = 0; i < n; ++i) { // 'i' is uniform inside the loop + // if (i % thread_id == 0) break; // divergent loop exit + // } + // int divI = i; // divI is divergent + for (size_t i = 0; i < Phi.getNumIncomingValues(); ++i) { + const auto *InVal = Phi.getIncomingValue(i); + if (isDivergent(*Phi.getIncomingValue(i)) || + isTemporalDivergent(*Phi.getParent(), *InVal)) { + return true; + } } + return false; } -void DivergencePropagator::findUsersOutsideInfluenceRegion( - Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) { - for (User *U : I.users()) { - Instruction *UserInst = cast<Instruction>(U); - if (!InfluenceRegion.count(UserInst->getParent())) { - if (DV.insert(UserInst).second) - Worklist.push_back(UserInst); +bool DivergenceAnalysis::inRegion(const Instruction &I) const { + return I.getParent() && inRegion(*I.getParent()); +} + +bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { + return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); +} + +// marks all users of loop-carried values of the loop headed by LoopHeader as +// divergent +void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { + auto *DivLoop = LI.getLoopFor(&LoopHeader); + assert(DivLoop && "loopHeader is not actually part of a loop"); + + SmallVector<BasicBlock *, 8> TaintStack; + DivLoop->getExitBlocks(TaintStack); + + // Otherwise potential users of loop-carried values could be anywhere in the + // dominance region of DivLoop (including its fringes for phi nodes) + DenseSet<const BasicBlock *> Visited; + for (auto *Block : TaintStack) { + Visited.insert(Block); + } + Visited.insert(&LoopHeader); + + while (!TaintStack.empty()) { + auto *UserBlock = TaintStack.back(); + TaintStack.pop_back(); + + // don't spread divergence beyond the region + if (!inRegion(*UserBlock)) + continue; + + assert(!DivLoop->contains(UserBlock) && + "irreducible control flow detected"); + + // phi nodes at the fringes of the dominance region + if (!DT.dominates(&LoopHeader, UserBlock)) { + // all PHI nodes of UserBlock become divergent + for (auto &Phi : UserBlock->phis()) { + Worklist.push_back(&Phi); + } + continue; + } + + // taint outside users of values carried by DivLoop + for (auto &I : *UserBlock) { + if (isAlwaysUniform(I)) + continue; + if (isDivergent(I)) + continue; + + for (auto &Op : I.operands()) { + auto *OpInst = dyn_cast<Instruction>(&Op); + if (!OpInst) + continue; + if (DivLoop->contains(OpInst->getParent())) { + markDivergent(I); + pushUsers(I); + break; + } + } + } + + // visit all blocks in the dominance region + for (auto *SuccBlock : successors(UserBlock)) { + if (!Visited.insert(SuccBlock).second) { + continue; + } + TaintStack.push_back(SuccBlock); } } } -// A helper function for computeInfluenceRegion that adds successors of "ThisBB" -// to the influence region. 
-static void -addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion, - std::vector<BasicBlock *> &InfluenceStack) { - for (BasicBlock *Succ : successors(ThisBB)) { - if (Succ != End && InfluenceRegion.insert(Succ).second) - InfluenceStack.push_back(Succ); +void DivergenceAnalysis::pushPHINodes(const BasicBlock &Block) { + for (const auto &Phi : Block.phis()) { + if (isDivergent(Phi)) + continue; + Worklist.push_back(&Phi); } } -void DivergencePropagator::computeInfluenceRegion( - BasicBlock *Start, BasicBlock *End, - DenseSet<BasicBlock *> &InfluenceRegion) { - assert(PDT.properlyDominates(End, Start) && - "End does not properly dominate Start"); - - // The influence region starts from the end of "Start" to the beginning of - // "End". Therefore, "Start" should not be in the region unless "Start" is in - // a loop that doesn't contain "End". - std::vector<BasicBlock *> InfluenceStack; - addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack); - while (!InfluenceStack.empty()) { - BasicBlock *BB = InfluenceStack.back(); - InfluenceStack.pop_back(); - addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack); +void DivergenceAnalysis::pushUsers(const Value &V) { + for (const auto *User : V.users()) { + const auto *UserInst = dyn_cast<const Instruction>(User); + if (!UserInst) + continue; + + if (isDivergent(*UserInst)) + continue; + + // only compute divergent inside loop + if (!inRegion(*UserInst)) + continue; + Worklist.push_back(UserInst); } } -void DivergencePropagator::exploreDataDependency(Value *V) { - // Follow def-use chains of V. - for (User *U : V->users()) { - Instruction *UserInst = cast<Instruction>(U); - if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second) - Worklist.push_back(UserInst); +bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock, + const Loop *BranchLoop) { + LLVM_DEBUG(dbgs() << "\tpropJoinDiv " << JoinBlock.getName() << "\n"); + + // ignore divergence outside the region + if (!inRegion(JoinBlock)) { + return false; + } + + // push non-divergent phi nodes in JoinBlock to the worklist + pushPHINodes(JoinBlock); + + // JoinBlock is a divergent loop exit + if (BranchLoop && !BranchLoop->contains(&JoinBlock)) { + return true; } + + // disjoint-paths divergent at JoinBlock + markBlockJoinDivergent(JoinBlock); + return false; } -void DivergencePropagator::propagate() { - // Traverse the dependency graph using DFS. - while (!Worklist.empty()) { - Value *V = Worklist.back(); - Worklist.pop_back(); - if (TerminatorInst *TI = dyn_cast<TerminatorInst>(V)) { - // Terminators with less than two successors won't introduce sync - // dependency. Ignore them. - if (TI->getNumSuccessors() > 1) - exploreSyncDependency(TI); +void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { + LLVM_DEBUG(dbgs() << "propBranchDiv " << Term.getParent()->getName() << "\n"); + + markDivergent(Term); + + const auto *BranchLoop = LI.getLoopFor(Term.getParent()); + + // whether there is a divergent loop exit from BranchLoop (if any) + bool IsBranchLoopDivergent = false; + + // iterate over all blocks reachable by disjoint from Term within the loop + // also iterates over loop exits that become divergent due to Term. 
+ for (const auto *JoinBlock : SDA.join_blocks(Term)) { + IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + } + + // Branch loop is a divergent loop due to the divergent branch in Term + if (IsBranchLoopDivergent) { + assert(BranchLoop); + if (!DivergentLoops.insert(BranchLoop).second) { + return; } - exploreDataDependency(V); + propagateLoopDivergence(*BranchLoop); } } -} /// end namespace anonymous +void DivergenceAnalysis::propagateLoopDivergence(const Loop &ExitingLoop) { + LLVM_DEBUG(dbgs() << "propLoopDiv " << ExitingLoop.getName() << "\n"); + + // don't propagate beyond region + if (!inRegion(*ExitingLoop.getHeader())) + return; -// Register this pass. -char DivergenceAnalysis::ID = 0; -INITIALIZE_PASS_BEGIN(DivergenceAnalysis, "divergence", "Divergence Analysis", - false, true) -INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) -INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) -INITIALIZE_PASS_END(DivergenceAnalysis, "divergence", "Divergence Analysis", - false, true) + const auto *BranchLoop = ExitingLoop.getParentLoop(); -FunctionPass *llvm::createDivergenceAnalysisPass() { - return new DivergenceAnalysis(); -} + // Uses of loop-carried values could occur anywhere + // within the dominance region of the definition. All loop-carried + // definitions are dominated by the loop header (reducible control). + // Thus all users have to be in the dominance region of the loop header, + // except PHI nodes that can also live at the fringes of the dom region + // (incoming defining value). + if (!IsLCSSAForm) + taintLoopLiveOuts(*ExitingLoop.getHeader()); + + // whether there is a divergent loop exit from BranchLoop (if any) + bool IsBranchLoopDivergent = false; + + // iterate over all blocks reachable by disjoint paths from exits of + // ExitingLoop also iterates over loop exits (of BranchLoop) that in turn + // become divergent. + for (const auto *JoinBlock : SDA.join_blocks(ExitingLoop)) { + IsBranchLoopDivergent |= propagateJoinDivergence(*JoinBlock, BranchLoop); + } -void DivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { - AU.addRequired<DominatorTreeWrapperPass>(); - AU.addRequired<PostDominatorTreeWrapperPass>(); - AU.setPreservesAll(); + // Branch loop is a divergent due to divergent loop exit in ExitingLoop + if (IsBranchLoopDivergent) { + assert(BranchLoop); + if (!DivergentLoops.insert(BranchLoop).second) { + return; + } + propagateLoopDivergence(*BranchLoop); + } } -bool DivergenceAnalysis::runOnFunction(Function &F) { - auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); - if (TTIWP == nullptr) - return false; +void DivergenceAnalysis::compute() { + for (auto *DivVal : DivergentValues) { + pushUsers(*DivVal); + } - TargetTransformInfo &TTI = TTIWP->getTTI(F); - // Fast path: if the target does not have branch divergence, we do not mark - // any branch as divergent. 
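As an editorial aside, the GPUDivergenceAnalysis front-end added further down in this file can be exercised roughly as sketched below. The helper name isValueDivergent and the on-the-spot construction of the dominator, post-dominator and loop analyses are illustrative assumptions, not code from this patch; the sketch assumes the class is declared in llvm/Analysis/DivergenceAnalysis.h as in this change.

#include "llvm/Analysis/DivergenceAnalysis.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;

// Build the analyses the GPU front-end needs and query a single value.
static bool isValueDivergent(Function &F, const TargetTransformInfo &TTI,
                             const Value &V) {
  DominatorTree DT(F);
  PostDominatorTree PDT(F);
  LoopInfo LI(DT);
  GPUDivergenceAnalysis DA(F, DT, PDT, LI, TTI);
  return DA.isDivergent(V);
}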
- if (!TTI.hasBranchDivergence()) - return false; + // propagate divergence + while (!Worklist.empty()) { + const Instruction &I = *Worklist.back(); + Worklist.pop_back(); - DivergentValues.clear(); - auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - DivergencePropagator DP(F, TTI, - getAnalysis<DominatorTreeWrapperPass>().getDomTree(), - PDT, DivergentValues); - DP.populateWithSourcesOfDivergence(); - DP.propagate(); - LLVM_DEBUG( - dbgs() << "\nAfter divergence analysis on " << F.getName() << ":\n"; - print(dbgs(), F.getParent()) - ); - return false; + // maintain uniformity of overrides + if (isAlwaysUniform(I)) + continue; + + bool WasDivergent = isDivergent(I); + if (WasDivergent) + continue; + + // propagate divergence caused by terminator + if (I.isTerminator()) { + if (updateTerminator(I)) { + // propagate control divergence to affected instructions + propagateBranchDivergence(I); + continue; + } + } + + // update divergence of I due to divergent operands + bool DivergentUpd = false; + const auto *Phi = dyn_cast<const PHINode>(&I); + if (Phi) { + DivergentUpd = updatePHINode(*Phi); + } else { + DivergentUpd = updateNormalInstruction(I); + } + + // propagate value divergence to users + if (DivergentUpd) { + markDivergent(I); + pushUsers(I); + } + } +} + +bool DivergenceAnalysis::isAlwaysUniform(const Value &V) const { + return UniformOverrides.find(&V) != UniformOverrides.end(); +} + +bool DivergenceAnalysis::isDivergent(const Value &V) const { + return DivergentValues.find(&V) != DivergentValues.end(); } void DivergenceAnalysis::print(raw_ostream &OS, const Module *) const { if (DivergentValues.empty()) return; - const Value *FirstDivergentValue = *DivergentValues.begin(); - const Function *F; - if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) { - F = Arg->getParent(); - } else if (const Instruction *I = - dyn_cast<Instruction>(FirstDivergentValue)) { - F = I->getParent()->getParent(); - } else { - llvm_unreachable("Only arguments and instructions can be divergent"); + // iterate instructions using instructions() to ensure a deterministic order. + for (auto &I : instructions(F)) { + if (isDivergent(I)) + OS << "DIVERGENT:" << I << '\n'; } +} - // Dumps all divergent values in F, arguments and then instructions. - for (auto &Arg : F->args()) { - OS << (DivergentValues.count(&Arg) ? "DIVERGENT: " : " "); - OS << Arg << "\n"; +// class GPUDivergenceAnalysis +GPUDivergenceAnalysis::GPUDivergenceAnalysis(Function &F, + const DominatorTree &DT, + const PostDominatorTree &PDT, + const LoopInfo &LI, + const TargetTransformInfo &TTI) + : SDA(DT, PDT, LI), DA(F, nullptr, DT, LI, SDA, false) { + for (auto &I : instructions(F)) { + if (TTI.isSourceOfDivergence(&I)) { + DA.markDivergent(I); + } else if (TTI.isAlwaysUniform(&I)) { + DA.addUniformOverride(I); + } } - // Iterate instructions using instructions() to ensure a deterministic order. - for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) { - auto &BB = *BI; - OS << "\n " << BB.getName() << ":\n"; - for (auto &I : BB.instructionsWithoutDebug()) { - OS << (DivergentValues.count(&I) ? 
"DIVERGENT: " : " "); - OS << I << "\n"; + for (auto &Arg : F.args()) { + if (TTI.isSourceOfDivergence(&Arg)) { + DA.markDivergent(Arg); } } - OS << "\n"; + + DA.compute(); +} + +bool GPUDivergenceAnalysis::isDivergent(const Value &val) const { + return DA.isDivergent(val); +} + +void GPUDivergenceAnalysis::print(raw_ostream &OS, const Module *mod) const { + OS << "Divergence of kernel " << DA.getFunction().getName() << " {\n"; + DA.print(OS, mod); + OS << "}\n"; } diff --git a/lib/Analysis/EHPersonalities.cpp b/lib/Analysis/EHPersonalities.cpp index 2d35a3fa9118..0df73aeebbdc 100644 --- a/lib/Analysis/EHPersonalities.cpp +++ b/lib/Analysis/EHPersonalities.cpp @@ -120,7 +120,7 @@ DenseMap<BasicBlock *, ColorVector> llvm::colorEHFunclets(Function &F) { << "\'.\n"); BasicBlock *SuccColor = Color; - TerminatorInst *Terminator = Visiting->getTerminator(); + Instruction *Terminator = Visiting->getTerminator(); if (auto *CatchRet = dyn_cast<CatchReturnInst>(Terminator)) { Value *ParentPad = CatchRet->getCatchSwitchParentPad(); if (isa<ConstantTokenNone>(ParentPad)) diff --git a/lib/Analysis/GlobalsModRef.cpp b/lib/Analysis/GlobalsModRef.cpp index 2c503609d96b..b28abcadca4a 100644 --- a/lib/Analysis/GlobalsModRef.cpp +++ b/lib/Analysis/GlobalsModRef.cpp @@ -255,11 +255,11 @@ FunctionModRefBehavior GlobalsAAResult::getModRefBehavior(const Function *F) { } FunctionModRefBehavior -GlobalsAAResult::getModRefBehavior(ImmutableCallSite CS) { +GlobalsAAResult::getModRefBehavior(const CallBase *Call) { FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; - if (!CS.hasOperandBundles()) - if (const Function *F = CS.getCalledFunction()) + if (!Call->hasOperandBundles()) + if (const Function *F = Call->getCalledFunction()) if (FunctionInfo *FI = getFunctionInfo(F)) { if (!isModOrRefSet(FI->getModRefInfo())) Min = FMRB_DoesNotAccessMemory; @@ -267,7 +267,7 @@ GlobalsAAResult::getModRefBehavior(ImmutableCallSite CS) { Min = FMRB_OnlyReadsMemory; } - return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); + return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min); } /// Returns the function info for the function, or null if we don't have @@ -366,14 +366,14 @@ bool GlobalsAAResult::AnalyzeUsesOfPointer(Value *V, } else if (Operator::getOpcode(I) == Instruction::BitCast) { if (AnalyzeUsesOfPointer(I, Readers, Writers, OkayStoreDest)) return true; - } else if (auto CS = CallSite(I)) { + } else if (auto *Call = dyn_cast<CallBase>(I)) { // Make sure that this is just the function being called, not that it is // passing into the function. - if (CS.isDataOperand(&U)) { + if (Call->isDataOperand(&U)) { // Detect calls to free. - if (CS.isArgOperand(&U) && isFreeCall(I, &TLI)) { + if (Call->isArgOperand(&U) && isFreeCall(I, &TLI)) { if (Writers) - Writers->insert(CS->getParent()->getParent()); + Writers->insert(Call->getParent()->getParent()); } else { return true; // Argument of an unknown call. } @@ -576,15 +576,15 @@ void GlobalsAAResult::AnalyzeCallGraph(CallGraph &CG, Module &M) { // We handle calls specially because the graph-relevant aspects are // handled above. - if (auto CS = CallSite(&I)) { - if (isAllocationFn(&I, &TLI) || isFreeCall(&I, &TLI)) { + if (auto *Call = dyn_cast<CallBase>(&I)) { + if (isAllocationFn(Call, &TLI) || isFreeCall(Call, &TLI)) { // FIXME: It is completely unclear why this is necessary and not // handled by the above graph code. 
FI.addModRefInfo(ModRefInfo::ModRef); - } else if (Function *Callee = CS.getCalledFunction()) { + } else if (Function *Callee = Call->getCalledFunction()) { // The callgraph doesn't include intrinsic calls. if (Callee->isIntrinsic()) { - if (isa<DbgInfoIntrinsic>(I)) + if (isa<DbgInfoIntrinsic>(Call)) // Don't let dbg intrinsics affect alias info. continue; @@ -885,16 +885,16 @@ AliasResult GlobalsAAResult::alias(const MemoryLocation &LocA, return AAResultBase::alias(LocA, LocB); } -ModRefInfo GlobalsAAResult::getModRefInfoForArgument(ImmutableCallSite CS, +ModRefInfo GlobalsAAResult::getModRefInfoForArgument(const CallBase *Call, const GlobalValue *GV) { - if (CS.doesNotAccessMemory()) + if (Call->doesNotAccessMemory()) return ModRefInfo::NoModRef; ModRefInfo ConservativeResult = - CS.onlyReadsMemory() ? ModRefInfo::Ref : ModRefInfo::ModRef; + Call->onlyReadsMemory() ? ModRefInfo::Ref : ModRefInfo::ModRef; // Iterate through all the arguments to the called function. If any argument // is based on GV, return the conservative result. - for (auto &A : CS.args()) { + for (auto &A : Call->args()) { SmallVector<Value*, 4> Objects; GetUnderlyingObjects(A, Objects, DL); @@ -914,7 +914,7 @@ ModRefInfo GlobalsAAResult::getModRefInfoForArgument(ImmutableCallSite CS, return ModRefInfo::NoModRef; } -ModRefInfo GlobalsAAResult::getModRefInfo(ImmutableCallSite CS, +ModRefInfo GlobalsAAResult::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { ModRefInfo Known = ModRefInfo::ModRef; @@ -923,15 +923,15 @@ ModRefInfo GlobalsAAResult::getModRefInfo(ImmutableCallSite CS, if (const GlobalValue *GV = dyn_cast<GlobalValue>(GetUnderlyingObject(Loc.Ptr, DL))) if (GV->hasLocalLinkage()) - if (const Function *F = CS.getCalledFunction()) + if (const Function *F = Call->getCalledFunction()) if (NonAddressTakenGlobals.count(GV)) if (const FunctionInfo *FI = getFunctionInfo(F)) Known = unionModRef(FI->getModRefInfoForGlobal(*GV), - getModRefInfoForArgument(CS, GV)); + getModRefInfoForArgument(Call, GV)); if (!isModOrRefSet(Known)) return ModRefInfo::NoModRef; // No need to query other mod/ref analyses - return intersectModRef(Known, AAResultBase::getModRefInfo(CS, Loc)); + return intersectModRef(Known, AAResultBase::getModRefInfo(Call, Loc)); } GlobalsAAResult::GlobalsAAResult(const DataLayout &DL, diff --git a/lib/Analysis/GuardUtils.cpp b/lib/Analysis/GuardUtils.cpp new file mode 100644 index 000000000000..08fa6abeafb5 --- /dev/null +++ b/lib/Analysis/GuardUtils.cpp @@ -0,0 +1,21 @@ +//===-- GuardUtils.cpp - Utils for work with guards -------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// Utils that are used to perform analyzes related to guards and their +// conditions. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/GuardUtils.h" +#include "llvm/IR/PatternMatch.h" + +using namespace llvm; + +bool llvm::isGuard(const User *U) { + using namespace llvm::PatternMatch; + return match(U, m_Intrinsic<Intrinsic::experimental_guard>()); +} diff --git a/lib/Analysis/IVDescriptors.cpp b/lib/Analysis/IVDescriptors.cpp new file mode 100644 index 000000000000..aaebc4a481ec --- /dev/null +++ b/lib/Analysis/IVDescriptors.cpp @@ -0,0 +1,1089 @@ +//===- llvm/Analysis/IVDescriptors.cpp - IndVar Descriptors -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file "describes" induction and recurrence variables. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/IVDescriptors.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MustExecute.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/DomTreeUpdater.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/KnownBits.h" + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "iv-descriptors" + +bool RecurrenceDescriptor::areAllUsesIn(Instruction *I, + SmallPtrSetImpl<Instruction *> &Set) { + for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; ++Use) + if (!Set.count(dyn_cast<Instruction>(*Use))) + return false; + return true; +} + +bool RecurrenceDescriptor::isIntegerRecurrenceKind(RecurrenceKind Kind) { + switch (Kind) { + default: + break; + case RK_IntegerAdd: + case RK_IntegerMult: + case RK_IntegerOr: + case RK_IntegerAnd: + case RK_IntegerXor: + case RK_IntegerMinMax: + return true; + } + return false; +} + +bool RecurrenceDescriptor::isFloatingPointRecurrenceKind(RecurrenceKind Kind) { + return (Kind != RK_NoRecurrence) && !isIntegerRecurrenceKind(Kind); +} + +bool RecurrenceDescriptor::isArithmeticRecurrenceKind(RecurrenceKind Kind) { + switch (Kind) { + default: + break; + case RK_IntegerAdd: + case RK_IntegerMult: + case RK_FloatAdd: + case RK_FloatMult: + return true; + } + return false; +} + +/// Determines if Phi may have been type-promoted. If Phi has a single user +/// that ANDs the Phi with a type mask, return the user. RT is updated to +/// account for the narrower bit width represented by the mask, and the AND +/// instruction is added to CI. 
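A small editorial illustration, not taken from the patch, of the mask test performed by the function that follows: an AND mask M denotes a narrower type exactly when M + 1 is a power of two, and the recovered bit width is log2(M + 1). The helper maskedBitWidth is a hypothetical stand-alone equivalent of the (*M + 1).exactLogBase2() check used below.

#include <cassert>
#include <cstdint>

// Return the bit width encoded by a low-bit mask M (e.g. 0xFF -> 8), or 0 if
// M is not of the form 2^x - 1 with x > 0.
static unsigned maskedBitWidth(uint64_t M) {
  uint64_t P = M + 1;
  if (P == 0 || (P & (P - 1)) != 0) // M == ~0 or M + 1 is not a power of two.
    return 0;
  unsigned Bits = 0;
  while (P > 1) {
    P >>= 1;
    ++Bits;
  }
  return Bits;
}

int main() {
  assert(maskedBitWidth(0xFF) == 8);    // an i8 recurrence promoted to i32
  assert(maskedBitWidth(0xFFFF) == 16); // i16
  assert(maskedBitWidth(0xF0) == 0);    // not a low-bit mask
  return 0;
}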
+static Instruction *lookThroughAnd(PHINode *Phi, Type *&RT, + SmallPtrSetImpl<Instruction *> &Visited, + SmallPtrSetImpl<Instruction *> &CI) { + if (!Phi->hasOneUse()) + return Phi; + + const APInt *M = nullptr; + Instruction *I, *J = cast<Instruction>(Phi->use_begin()->getUser()); + + // Matches either I & 2^x-1 or 2^x-1 & I. If we find a match, we update RT + // with a new integer type of the corresponding bit width. + if (match(J, m_c_And(m_Instruction(I), m_APInt(M)))) { + int32_t Bits = (*M + 1).exactLogBase2(); + if (Bits > 0) { + RT = IntegerType::get(Phi->getContext(), Bits); + Visited.insert(Phi); + CI.insert(J); + return J; + } + } + return Phi; +} + +/// Compute the minimal bit width needed to represent a reduction whose exit +/// instruction is given by Exit. +static std::pair<Type *, bool> computeRecurrenceType(Instruction *Exit, + DemandedBits *DB, + AssumptionCache *AC, + DominatorTree *DT) { + bool IsSigned = false; + const DataLayout &DL = Exit->getModule()->getDataLayout(); + uint64_t MaxBitWidth = DL.getTypeSizeInBits(Exit->getType()); + + if (DB) { + // Use the demanded bits analysis to determine the bits that are live out + // of the exit instruction, rounding up to the nearest power of two. If the + // use of demanded bits results in a smaller bit width, we know the value + // must be positive (i.e., IsSigned = false), because if this were not the + // case, the sign bit would have been demanded. + auto Mask = DB->getDemandedBits(Exit); + MaxBitWidth = Mask.getBitWidth() - Mask.countLeadingZeros(); + } + + if (MaxBitWidth == DL.getTypeSizeInBits(Exit->getType()) && AC && DT) { + // If demanded bits wasn't able to limit the bit width, we can try to use + // value tracking instead. This can be the case, for example, if the value + // may be negative. + auto NumSignBits = ComputeNumSignBits(Exit, DL, 0, AC, nullptr, DT); + auto NumTypeBits = DL.getTypeSizeInBits(Exit->getType()); + MaxBitWidth = NumTypeBits - NumSignBits; + KnownBits Bits = computeKnownBits(Exit, DL); + if (!Bits.isNonNegative()) { + // If the value is not known to be non-negative, we set IsSigned to true, + // meaning that we will use sext instructions instead of zext + // instructions to restore the original type. + IsSigned = true; + if (!Bits.isNegative()) + // If the value is not known to be negative, we don't known what the + // upper bit is, and therefore, we don't know what kind of extend we + // will need. In this case, just increase the bit width by one bit and + // use sext. + ++MaxBitWidth; + } + } + if (!isPowerOf2_64(MaxBitWidth)) + MaxBitWidth = NextPowerOf2(MaxBitWidth); + + return std::make_pair(Type::getIntNTy(Exit->getContext(), MaxBitWidth), + IsSigned); +} + +/// Collect cast instructions that can be ignored in the vectorizer's cost +/// model, given a reduction exit value and the minimal type in which the +/// reduction can be represented. +static void collectCastsToIgnore(Loop *TheLoop, Instruction *Exit, + Type *RecurrenceType, + SmallPtrSetImpl<Instruction *> &Casts) { + + SmallVector<Instruction *, 8> Worklist; + SmallPtrSet<Instruction *, 8> Visited; + Worklist.push_back(Exit); + + while (!Worklist.empty()) { + Instruction *Val = Worklist.pop_back_val(); + Visited.insert(Val); + if (auto *Cast = dyn_cast<CastInst>(Val)) + if (Cast->getSrcTy() == RecurrenceType) { + // If the source type of a cast instruction is equal to the recurrence + // type, it will be eliminated, and should be ignored in the vectorizer + // cost model. 
+ Casts.insert(Cast); + continue; + } + + // Add all operands to the work list if they are loop-varying values that + // we haven't yet visited. + for (Value *O : cast<User>(Val)->operands()) + if (auto *I = dyn_cast<Instruction>(O)) + if (TheLoop->contains(I) && !Visited.count(I)) + Worklist.push_back(I); + } +} + +bool RecurrenceDescriptor::AddReductionVar(PHINode *Phi, RecurrenceKind Kind, + Loop *TheLoop, bool HasFunNoNaNAttr, + RecurrenceDescriptor &RedDes, + DemandedBits *DB, + AssumptionCache *AC, + DominatorTree *DT) { + if (Phi->getNumIncomingValues() != 2) + return false; + + // Reduction variables are only found in the loop header block. + if (Phi->getParent() != TheLoop->getHeader()) + return false; + + // Obtain the reduction start value from the value that comes from the loop + // preheader. + Value *RdxStart = Phi->getIncomingValueForBlock(TheLoop->getLoopPreheader()); + + // ExitInstruction is the single value which is used outside the loop. + // We only allow for a single reduction value to be used outside the loop. + // This includes users of the reduction, variables (which form a cycle + // which ends in the phi node). + Instruction *ExitInstruction = nullptr; + // Indicates that we found a reduction operation in our scan. + bool FoundReduxOp = false; + + // We start with the PHI node and scan for all of the users of this + // instruction. All users must be instructions that can be used as reduction + // variables (such as ADD). We must have a single out-of-block user. The cycle + // must include the original PHI. + bool FoundStartPHI = false; + + // To recognize min/max patterns formed by a icmp select sequence, we store + // the number of instruction we saw from the recognized min/max pattern, + // to make sure we only see exactly the two instructions. + unsigned NumCmpSelectPatternInst = 0; + InstDesc ReduxDesc(false, nullptr); + + // Data used for determining if the recurrence has been type-promoted. + Type *RecurrenceType = Phi->getType(); + SmallPtrSet<Instruction *, 4> CastInsts; + Instruction *Start = Phi; + bool IsSigned = false; + + SmallPtrSet<Instruction *, 8> VisitedInsts; + SmallVector<Instruction *, 8> Worklist; + + // Return early if the recurrence kind does not match the type of Phi. If the + // recurrence kind is arithmetic, we attempt to look through AND operations + // resulting from the type promotion performed by InstCombine. Vector + // operations are not limited to the legal integer widths, so we may be able + // to evaluate the reduction in the narrower width. + if (RecurrenceType->isFloatingPointTy()) { + if (!isFloatingPointRecurrenceKind(Kind)) + return false; + } else { + if (!isIntegerRecurrenceKind(Kind)) + return false; + if (isArithmeticRecurrenceKind(Kind)) + Start = lookThroughAnd(Phi, RecurrenceType, VisitedInsts, CastInsts); + } + + Worklist.push_back(Start); + VisitedInsts.insert(Start); + + // A value in the reduction can be used: + // - By the reduction: + // - Reduction operation: + // - One use of reduction value (safe). + // - Multiple use of reduction value (not safe). + // - PHI: + // - All uses of the PHI must be the reduction (safe). + // - Otherwise, not safe. + // - By instructions outside of the loop (safe). + // * One value may have several outside users, but all outside + // uses must be of the same value. + // - By an instruction that is not part of the reduction (not safe). + // This is either: + // * An instruction type other than PHI or the reduction operation. + // * A PHI in the header other than the initial PHI. 
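To make the rules above concrete, here is an editorial source-level example of a well-formed integer add reduction; it is not taken from the patch and the name sumArray is invented.

// Sum is the reduction variable: inside the loop it is consumed only by the
// reduction operation (the +=), and exactly one value, the final Sum, escapes
// the loop. Lowered to IR this becomes a header PHI feeding an add whose
// result both feeds back into the PHI and is the single out-of-loop user
// (the ExitInstruction in the scan below).
static int sumArray(const int *A, int N) {
  int Sum = 0;                 // RdxStart, incoming from the preheader
  for (int I = 0; I < N; ++I)
    Sum += A[I];               // the reduction operation (RK_IntegerAdd)
  return Sum;                  // the single use outside the loop
}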
+ while (!Worklist.empty()) { + Instruction *Cur = Worklist.back(); + Worklist.pop_back(); + + // No Users. + // If the instruction has no users then this is a broken chain and can't be + // a reduction variable. + if (Cur->use_empty()) + return false; + + bool IsAPhi = isa<PHINode>(Cur); + + // A header PHI use other than the original PHI. + if (Cur != Phi && IsAPhi && Cur->getParent() == Phi->getParent()) + return false; + + // Reductions of instructions such as Div, and Sub is only possible if the + // LHS is the reduction variable. + if (!Cur->isCommutative() && !IsAPhi && !isa<SelectInst>(Cur) && + !isa<ICmpInst>(Cur) && !isa<FCmpInst>(Cur) && + !VisitedInsts.count(dyn_cast<Instruction>(Cur->getOperand(0)))) + return false; + + // Any reduction instruction must be of one of the allowed kinds. We ignore + // the starting value (the Phi or an AND instruction if the Phi has been + // type-promoted). + if (Cur != Start) { + ReduxDesc = isRecurrenceInstr(Cur, Kind, ReduxDesc, HasFunNoNaNAttr); + if (!ReduxDesc.isRecurrence()) + return false; + } + + bool IsASelect = isa<SelectInst>(Cur); + + // A conditional reduction operation must only have 2 or less uses in + // VisitedInsts. + if (IsASelect && (Kind == RK_FloatAdd || Kind == RK_FloatMult) && + hasMultipleUsesOf(Cur, VisitedInsts, 2)) + return false; + + // A reduction operation must only have one use of the reduction value. + if (!IsAPhi && !IsASelect && Kind != RK_IntegerMinMax && + Kind != RK_FloatMinMax && hasMultipleUsesOf(Cur, VisitedInsts, 1)) + return false; + + // All inputs to a PHI node must be a reduction value. + if (IsAPhi && Cur != Phi && !areAllUsesIn(Cur, VisitedInsts)) + return false; + + if (Kind == RK_IntegerMinMax && + (isa<ICmpInst>(Cur) || isa<SelectInst>(Cur))) + ++NumCmpSelectPatternInst; + if (Kind == RK_FloatMinMax && (isa<FCmpInst>(Cur) || isa<SelectInst>(Cur))) + ++NumCmpSelectPatternInst; + + // Check whether we found a reduction operator. + FoundReduxOp |= !IsAPhi && Cur != Start; + + // Process users of current instruction. Push non-PHI nodes after PHI nodes + // onto the stack. This way we are going to have seen all inputs to PHI + // nodes once we get to them. + SmallVector<Instruction *, 8> NonPHIs; + SmallVector<Instruction *, 8> PHIs; + for (User *U : Cur->users()) { + Instruction *UI = cast<Instruction>(U); + + // Check if we found the exit user. + BasicBlock *Parent = UI->getParent(); + if (!TheLoop->contains(Parent)) { + // If we already know this instruction is used externally, move on to + // the next user. + if (ExitInstruction == Cur) + continue; + + // Exit if you find multiple values used outside or if the header phi + // node is being used. In this case the user uses the value of the + // previous iteration, in which case we would loose "VF-1" iterations of + // the reduction operation if we vectorize. + if (ExitInstruction != nullptr || Cur == Phi) + return false; + + // The instruction used by an outside user must be the last instruction + // before we feed back to the reduction phi. Otherwise, we loose VF-1 + // operations on the value. + if (!is_contained(Phi->operands(), Cur)) + return false; + + ExitInstruction = Cur; + continue; + } + + // Process instructions only once (termination). Each reduction cycle + // value must only be used once, except by phi nodes and min/max + // reductions which are represented as a cmp followed by a select. 
+ InstDesc IgnoredVal(false, nullptr); + if (VisitedInsts.insert(UI).second) { + if (isa<PHINode>(UI)) + PHIs.push_back(UI); + else + NonPHIs.push_back(UI); + } else if (!isa<PHINode>(UI) && + ((!isa<FCmpInst>(UI) && !isa<ICmpInst>(UI) && + !isa<SelectInst>(UI)) || + (!isConditionalRdxPattern(Kind, UI).isRecurrence() && + !isMinMaxSelectCmpPattern(UI, IgnoredVal).isRecurrence()))) + return false; + + // Remember that we completed the cycle. + if (UI == Phi) + FoundStartPHI = true; + } + Worklist.append(PHIs.begin(), PHIs.end()); + Worklist.append(NonPHIs.begin(), NonPHIs.end()); + } + + // This means we have seen one but not the other instruction of the + // pattern or more than just a select and cmp. + if ((Kind == RK_IntegerMinMax || Kind == RK_FloatMinMax) && + NumCmpSelectPatternInst != 2) + return false; + + if (!FoundStartPHI || !FoundReduxOp || !ExitInstruction) + return false; + + if (Start != Phi) { + // If the starting value is not the same as the phi node, we speculatively + // looked through an 'and' instruction when evaluating a potential + // arithmetic reduction to determine if it may have been type-promoted. + // + // We now compute the minimal bit width that is required to represent the + // reduction. If this is the same width that was indicated by the 'and', we + // can represent the reduction in the smaller type. The 'and' instruction + // will be eliminated since it will essentially be a cast instruction that + // can be ignore in the cost model. If we compute a different type than we + // did when evaluating the 'and', the 'and' will not be eliminated, and we + // will end up with different kinds of operations in the recurrence + // expression (e.g., RK_IntegerAND, RK_IntegerADD). We give up if this is + // the case. + // + // The vectorizer relies on InstCombine to perform the actual + // type-shrinking. It does this by inserting instructions to truncate the + // exit value of the reduction to the width indicated by RecurrenceType and + // then extend this value back to the original width. If IsSigned is false, + // a 'zext' instruction will be generated; otherwise, a 'sext' will be + // used. + // + // TODO: We should not rely on InstCombine to rewrite the reduction in the + // smaller type. We should just generate a correctly typed expression + // to begin with. + Type *ComputedType; + std::tie(ComputedType, IsSigned) = + computeRecurrenceType(ExitInstruction, DB, AC, DT); + if (ComputedType != RecurrenceType) + return false; + + // The recurrence expression will be represented in a narrower type. If + // there are any cast instructions that will be unnecessary, collect them + // in CastInsts. Note that the 'and' instruction was already included in + // this list. + // + // TODO: A better way to represent this may be to tag in some way all the + // instructions that are a part of the reduction. The vectorizer cost + // model could then apply the recurrence type to these instructions, + // without needing a white list of instructions to ignore. + collectCastsToIgnore(TheLoop, ExitInstruction, RecurrenceType, CastInsts); + } + + // We found a reduction var if we have reached the original phi node and we + // only have a single instruction with out-of-loop users. + + // The ExitInstruction(Instruction which is allowed to have out-of-loop users) + // is saved as part of the RecurrenceDescriptor. + + // Save the description of this reduction variable. 
+ RecurrenceDescriptor RD( + RdxStart, ExitInstruction, Kind, ReduxDesc.getMinMaxKind(), + ReduxDesc.getUnsafeAlgebraInst(), RecurrenceType, IsSigned, CastInsts); + RedDes = RD; + + return true; +} + +/// Returns true if the instruction is a Select(ICmp(X, Y), X, Y) instruction +/// pattern corresponding to a min(X, Y) or max(X, Y). +RecurrenceDescriptor::InstDesc +RecurrenceDescriptor::isMinMaxSelectCmpPattern(Instruction *I, InstDesc &Prev) { + + assert((isa<ICmpInst>(I) || isa<FCmpInst>(I) || isa<SelectInst>(I)) && + "Expect a select instruction"); + Instruction *Cmp = nullptr; + SelectInst *Select = nullptr; + + // We must handle the select(cmp()) as a single instruction. Advance to the + // select. + if ((Cmp = dyn_cast<ICmpInst>(I)) || (Cmp = dyn_cast<FCmpInst>(I))) { + if (!Cmp->hasOneUse() || !(Select = dyn_cast<SelectInst>(*I->user_begin()))) + return InstDesc(false, I); + return InstDesc(Select, Prev.getMinMaxKind()); + } + + // Only handle single use cases for now. + if (!(Select = dyn_cast<SelectInst>(I))) + return InstDesc(false, I); + if (!(Cmp = dyn_cast<ICmpInst>(I->getOperand(0))) && + !(Cmp = dyn_cast<FCmpInst>(I->getOperand(0)))) + return InstDesc(false, I); + if (!Cmp->hasOneUse()) + return InstDesc(false, I); + + Value *CmpLeft; + Value *CmpRight; + + // Look for a min/max pattern. + if (m_UMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_UIntMin); + else if (m_UMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_UIntMax); + else if (m_SMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_SIntMax); + else if (m_SMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_SIntMin); + else if (m_OrdFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_FloatMin); + else if (m_OrdFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_FloatMax); + else if (m_UnordFMin(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_FloatMin); + else if (m_UnordFMax(m_Value(CmpLeft), m_Value(CmpRight)).match(Select)) + return InstDesc(Select, MRK_FloatMax); + + return InstDesc(false, I); +} + +/// Returns true if the select instruction has users in the compare-and-add +/// reduction pattern below. The select instruction argument is the last one +/// in the sequence. +/// +/// %sum.1 = phi ... +/// ... +/// %cmp = fcmp pred %0, %CFP +/// %add = fadd %0, %sum.1 +/// %sum.2 = select %cmp, %add, %sum.1 +RecurrenceDescriptor::InstDesc +RecurrenceDescriptor::isConditionalRdxPattern( + RecurrenceKind Kind, Instruction *I) { + SelectInst *SI = dyn_cast<SelectInst>(I); + if (!SI) + return InstDesc(false, I); + + CmpInst *CI = dyn_cast<CmpInst>(SI->getCondition()); + // Only handle single use cases for now. + if (!CI || !CI->hasOneUse()) + return InstDesc(false, I); + + Value *TrueVal = SI->getTrueValue(); + Value *FalseVal = SI->getFalseValue(); + // Handle only when either of operands of select instruction is a PHI + // node for now. + if ((isa<PHINode>(*TrueVal) && isa<PHINode>(*FalseVal)) || + (!isa<PHINode>(*TrueVal) && !isa<PHINode>(*FalseVal))) + return InstDesc(false, I); + + Instruction *I1 = + isa<PHINode>(*TrueVal) ? 
dyn_cast<Instruction>(FalseVal) + : dyn_cast<Instruction>(TrueVal); + if (!I1 || !I1->isBinaryOp()) + return InstDesc(false, I); + + Value *Op1, *Op2; + if ((m_FAdd(m_Value(Op1), m_Value(Op2)).match(I1) || + m_FSub(m_Value(Op1), m_Value(Op2)).match(I1)) && + I1->isFast()) + return InstDesc(Kind == RK_FloatAdd, SI); + + if (m_FMul(m_Value(Op1), m_Value(Op2)).match(I1) && (I1->isFast())) + return InstDesc(Kind == RK_FloatMult, SI); + + return InstDesc(false, I); +} + +RecurrenceDescriptor::InstDesc +RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind, + InstDesc &Prev, bool HasFunNoNaNAttr) { + bool FP = I->getType()->isFloatingPointTy(); + Instruction *UAI = Prev.getUnsafeAlgebraInst(); + if (!UAI && FP && !I->isFast()) + UAI = I; // Found an unsafe (unvectorizable) algebra instruction. + + switch (I->getOpcode()) { + default: + return InstDesc(false, I); + case Instruction::PHI: + return InstDesc(I, Prev.getMinMaxKind(), Prev.getUnsafeAlgebraInst()); + case Instruction::Sub: + case Instruction::Add: + return InstDesc(Kind == RK_IntegerAdd, I); + case Instruction::Mul: + return InstDesc(Kind == RK_IntegerMult, I); + case Instruction::And: + return InstDesc(Kind == RK_IntegerAnd, I); + case Instruction::Or: + return InstDesc(Kind == RK_IntegerOr, I); + case Instruction::Xor: + return InstDesc(Kind == RK_IntegerXor, I); + case Instruction::FMul: + return InstDesc(Kind == RK_FloatMult, I, UAI); + case Instruction::FSub: + case Instruction::FAdd: + return InstDesc(Kind == RK_FloatAdd, I, UAI); + case Instruction::Select: + if (Kind == RK_FloatAdd || Kind == RK_FloatMult) + return isConditionalRdxPattern(Kind, I); + LLVM_FALLTHROUGH; + case Instruction::FCmp: + case Instruction::ICmp: + if (Kind != RK_IntegerMinMax && + (!HasFunNoNaNAttr || Kind != RK_FloatMinMax)) + return InstDesc(false, I); + return isMinMaxSelectCmpPattern(I, Prev); + } +} + +bool RecurrenceDescriptor::hasMultipleUsesOf( + Instruction *I, SmallPtrSetImpl<Instruction *> &Insts, + unsigned MaxNumUses) { + unsigned NumUses = 0; + for (User::op_iterator Use = I->op_begin(), E = I->op_end(); Use != E; + ++Use) { + if (Insts.count(dyn_cast<Instruction>(*Use))) + ++NumUses; + if (NumUses > MaxNumUses) + return true; + } + + return false; +} +bool RecurrenceDescriptor::isReductionPHI(PHINode *Phi, Loop *TheLoop, + RecurrenceDescriptor &RedDes, + DemandedBits *DB, AssumptionCache *AC, + DominatorTree *DT) { + + BasicBlock *Header = TheLoop->getHeader(); + Function &F = *Header->getParent(); + bool HasFunNoNaNAttr = + F.getFnAttribute("no-nans-fp-math").getValueAsString() == "true"; + + if (AddReductionVar(Phi, RK_IntegerAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an ADD reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_IntegerMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a MUL reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_IntegerOr, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an OR reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_IntegerAnd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an AND reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_IntegerXor, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a XOR reduction PHI." 
<< *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_IntegerMinMax, TheLoop, HasFunNoNaNAttr, RedDes, + DB, AC, DT)) { + LLVM_DEBUG(dbgs() << "Found a MINMAX reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_FloatMult, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an FMult reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_FloatAdd, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an FAdd reduction PHI." << *Phi << "\n"); + return true; + } + if (AddReductionVar(Phi, RK_FloatMinMax, TheLoop, HasFunNoNaNAttr, RedDes, DB, + AC, DT)) { + LLVM_DEBUG(dbgs() << "Found an float MINMAX reduction PHI." << *Phi + << "\n"); + return true; + } + // Not a reduction of known type. + return false; +} + +bool RecurrenceDescriptor::isFirstOrderRecurrence( + PHINode *Phi, Loop *TheLoop, + DenseMap<Instruction *, Instruction *> &SinkAfter, DominatorTree *DT) { + + // Ensure the phi node is in the loop header and has two incoming values. + if (Phi->getParent() != TheLoop->getHeader() || + Phi->getNumIncomingValues() != 2) + return false; + + // Ensure the loop has a preheader and a single latch block. The loop + // vectorizer will need the latch to set up the next iteration of the loop. + auto *Preheader = TheLoop->getLoopPreheader(); + auto *Latch = TheLoop->getLoopLatch(); + if (!Preheader || !Latch) + return false; + + // Ensure the phi node's incoming blocks are the loop preheader and latch. + if (Phi->getBasicBlockIndex(Preheader) < 0 || + Phi->getBasicBlockIndex(Latch) < 0) + return false; + + // Get the previous value. The previous value comes from the latch edge while + // the initial value comes form the preheader edge. + auto *Previous = dyn_cast<Instruction>(Phi->getIncomingValueForBlock(Latch)); + if (!Previous || !TheLoop->contains(Previous) || isa<PHINode>(Previous) || + SinkAfter.count(Previous)) // Cannot rely on dominance due to motion. + return false; + + // Ensure every user of the phi node is dominated by the previous value. + // The dominance requirement ensures the loop vectorizer will not need to + // vectorize the initial value prior to the first iteration of the loop. + // TODO: Consider extending this sinking to handle other kinds of instructions + // and expressions, beyond sinking a single cast past Previous. + if (Phi->hasOneUse()) { + auto *I = Phi->user_back(); + if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() && + DT->dominates(Previous, I->user_back())) { + if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking. + SinkAfter[I] = Previous; + return true; + } + } + + for (User *U : Phi->users()) + if (auto *I = dyn_cast<Instruction>(U)) { + if (!DT->dominates(Previous, I)) + return false; + } + + return true; +} + +/// This function returns the identity element (or neutral element) for +/// the operation K. +Constant *RecurrenceDescriptor::getRecurrenceIdentity(RecurrenceKind K, + Type *Tp) { + switch (K) { + case RK_IntegerXor: + case RK_IntegerAdd: + case RK_IntegerOr: + // Adding, Xoring, Oring zero to a number does not change it. + return ConstantInt::get(Tp, 0); + case RK_IntegerMult: + // Multiplying a number by 1 does not change it. + return ConstantInt::get(Tp, 1); + case RK_IntegerAnd: + // AND-ing a number with an all-1 value does not change it. + return ConstantInt::get(Tp, -1, true); + case RK_FloatMult: + // Multiplying a number by 1 does not change it. 
+ return ConstantFP::get(Tp, 1.0L); + case RK_FloatAdd: + // Adding zero to a number does not change it. + return ConstantFP::get(Tp, 0.0L); + default: + llvm_unreachable("Unknown recurrence kind"); + } +} + +/// This function translates the recurrence kind to an LLVM binary operator. +unsigned RecurrenceDescriptor::getRecurrenceBinOp(RecurrenceKind Kind) { + switch (Kind) { + case RK_IntegerAdd: + return Instruction::Add; + case RK_IntegerMult: + return Instruction::Mul; + case RK_IntegerOr: + return Instruction::Or; + case RK_IntegerAnd: + return Instruction::And; + case RK_IntegerXor: + return Instruction::Xor; + case RK_FloatMult: + return Instruction::FMul; + case RK_FloatAdd: + return Instruction::FAdd; + case RK_IntegerMinMax: + return Instruction::ICmp; + case RK_FloatMinMax: + return Instruction::FCmp; + default: + llvm_unreachable("Unknown recurrence operation"); + } +} + +InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, + const SCEV *Step, BinaryOperator *BOp, + SmallVectorImpl<Instruction *> *Casts) + : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) { + assert(IK != IK_NoInduction && "Not an induction"); + + // Start value type should match the induction kind and the value + // itself should not be null. + assert(StartValue && "StartValue is null"); + assert((IK != IK_PtrInduction || StartValue->getType()->isPointerTy()) && + "StartValue is not a pointer for pointer induction"); + assert((IK != IK_IntInduction || StartValue->getType()->isIntegerTy()) && + "StartValue is not an integer for integer induction"); + + // Check the Step Value. It should be non-zero integer value. + assert((!getConstIntStepValue() || !getConstIntStepValue()->isZero()) && + "Step value is zero"); + + assert((IK != IK_PtrInduction || getConstIntStepValue()) && + "Step value should be constant for pointer induction"); + assert((IK == IK_FpInduction || Step->getType()->isIntegerTy()) && + "StepValue is not an integer"); + + assert((IK != IK_FpInduction || Step->getType()->isFloatingPointTy()) && + "StepValue is not FP for FpInduction"); + assert((IK != IK_FpInduction || + (InductionBinOp && + (InductionBinOp->getOpcode() == Instruction::FAdd || + InductionBinOp->getOpcode() == Instruction::FSub))) && + "Binary opcode should be specified for FP induction"); + + if (Casts) { + for (auto &Inst : *Casts) { + RedundantCasts.push_back(Inst); + } + } +} + +int InductionDescriptor::getConsecutiveDirection() const { + ConstantInt *ConstStep = getConstIntStepValue(); + if (ConstStep && (ConstStep->isOne() || ConstStep->isMinusOne())) + return ConstStep->getSExtValue(); + return 0; +} + +ConstantInt *InductionDescriptor::getConstIntStepValue() const { + if (isa<SCEVConstant>(Step)) + return dyn_cast<ConstantInt>(cast<SCEVConstant>(Step)->getValue()); + return nullptr; +} + +bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, + ScalarEvolution *SE, + InductionDescriptor &D) { + + // Here we only handle FP induction variables. + assert(Phi->getType()->isFloatingPointTy() && "Unexpected Phi type"); + + if (TheLoop->getHeader() != Phi->getParent()) + return false; + + // The loop may have multiple entrances or multiple exits; we can analyze + // this phi if it has a unique entry value and a unique backedge value. 
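Note: the identity constants returned by getRecurrenceIdentity above exist so that extra lanes of a vector accumulator can be seeded without changing the reduced value. A minimal standalone sketch of that use, with an illustrative kind enum rather than the RecurrenceKind defined here:

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Illustrative kinds only; not the RecurrenceKind enum above.
enum class SketchKind { IntAdd, IntMul, IntOr, IntAnd, IntXor };

static int64_t identity(SketchKind K) {
  switch (K) {
  case SketchKind::IntAdd:
  case SketchKind::IntOr:
  case SketchKind::IntXor:
    return 0;  // x + 0 == x, x | 0 == x, x ^ 0 == x
  case SketchKind::IntMul:
    return 1;  // x * 1 == x
  case SketchKind::IntAnd:
    return -1; // x & ~0 == x
  }
  return 0;
}

int main() {
  // Four "lanes" of an accumulator: lane 0 carries a running value, the other
  // lanes start at the identity so the final horizontal fold is unchanged.
  std::vector<int64_t> Lanes(4, identity(SketchKind::IntAdd));
  Lanes[0] = 42;
  std::cout << std::accumulate(Lanes.begin(), Lanes.end(), int64_t(0))
            << "\n"; // prints 42
}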
+ if (Phi->getNumIncomingValues() != 2) + return false; + Value *BEValue = nullptr, *StartValue = nullptr; + if (TheLoop->contains(Phi->getIncomingBlock(0))) { + BEValue = Phi->getIncomingValue(0); + StartValue = Phi->getIncomingValue(1); + } else { + assert(TheLoop->contains(Phi->getIncomingBlock(1)) && + "Unexpected Phi node in the loop"); + BEValue = Phi->getIncomingValue(1); + StartValue = Phi->getIncomingValue(0); + } + + BinaryOperator *BOp = dyn_cast<BinaryOperator>(BEValue); + if (!BOp) + return false; + + Value *Addend = nullptr; + if (BOp->getOpcode() == Instruction::FAdd) { + if (BOp->getOperand(0) == Phi) + Addend = BOp->getOperand(1); + else if (BOp->getOperand(1) == Phi) + Addend = BOp->getOperand(0); + } else if (BOp->getOpcode() == Instruction::FSub) + if (BOp->getOperand(0) == Phi) + Addend = BOp->getOperand(1); + + if (!Addend) + return false; + + // The addend should be loop invariant + if (auto *I = dyn_cast<Instruction>(Addend)) + if (TheLoop->contains(I)) + return false; + + // FP Step has unknown SCEV + const SCEV *Step = SE->getUnknown(Addend); + D = InductionDescriptor(StartValue, IK_FpInduction, Step, BOp); + return true; +} + +/// This function is called when we suspect that the update-chain of a phi node +/// (whose symbolic SCEV expression sin \p PhiScev) contains redundant casts, +/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime +/// predicate P under which the SCEV expression for the phi can be the +/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the +/// cast instructions that are involved in the update-chain of this induction. +/// A caller that adds the required runtime predicate can be free to drop these +/// cast instructions, and compute the phi using \p AR (instead of some scev +/// expression with casts). +/// +/// For example, without a predicate the scev expression can take the following +/// form: +/// (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy) +/// +/// It corresponds to the following IR sequence: +/// %for.body: +/// %x = phi i64 [ 0, %ph ], [ %add, %for.body ] +/// %casted_phi = "ExtTrunc i64 %x" +/// %add = add i64 %casted_phi, %step +/// +/// where %x is given in \p PN, +/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate, +/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of +/// several forms, for example, such as: +/// ExtTrunc1: %casted_phi = and %x, 2^n-1 +/// or: +/// ExtTrunc2: %t = shl %x, m +/// %casted_phi = ashr %t, m +/// +/// If we are able to find such sequence, we return the instructions +/// we found, namely %casted_phi and the instructions on its use-def chain up +/// to the phi (not including the phi). +static bool getCastsForInductionPHI(PredicatedScalarEvolution &PSE, + const SCEVUnknown *PhiScev, + const SCEVAddRecExpr *AR, + SmallVectorImpl<Instruction *> &CastInsts) { + + assert(CastInsts.empty() && "CastInsts is expected to be empty."); + auto *PN = cast<PHINode>(PhiScev->getValue()); + assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression"); + const Loop *L = AR->getLoop(); + + // Find any cast instructions that participate in the def-use chain of + // PhiScev in the loop. + // FORNOW/TODO: We currently expect the def-use chain to include only + // two-operand instructions, where one of the operands is an invariant. + // createAddRecFromPHIWithCasts() currently does not support anything more + // involved than that, so we keep the search simple. This can be + // extended/generalized as needed. 
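Note: the ExtTrunc1/ExtTrunc2 shapes in the comment above are ordinary zero- and sign-extensions of a truncated value. A small standalone sketch, assuming a 64-bit value narrowed to N bits with 0 < N < 64; the IR shl/ashr pair is modeled with an unsigned left shift followed by an arithmetic right shift.

#include <cassert>
#include <cstdint>

// ExtTrunc1: zext(trunc(x to iN) to i64) == x & (2^N - 1).
static uint64_t zextTrunc(uint64_t X, unsigned N) {
  return X & ((uint64_t(1) << N) - 1);
}

// ExtTrunc2: sext(trunc(x to iN) to i64) == ashr(shl(x, 64 - N), 64 - N).
// The cast plus signed right shift assumes the usual two's-complement behavior.
static int64_t sextTrunc(uint64_t X, unsigned N) {
  unsigned M = 64 - N;
  return static_cast<int64_t>(X << M) >> M;
}

int main() {
  assert(zextTrunc(0x1234567890ABCDEFULL, 32) == 0x90ABCDEFULL);
  assert(sextTrunc(0x00000000FFFFFFFFULL, 32) == -1);
  return 0;
}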
+ + auto getDef = [&](const Value *Val) -> Value * { + const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val); + if (!BinOp) + return nullptr; + Value *Op0 = BinOp->getOperand(0); + Value *Op1 = BinOp->getOperand(1); + Value *Def = nullptr; + if (L->isLoopInvariant(Op0)) + Def = Op1; + else if (L->isLoopInvariant(Op1)) + Def = Op0; + return Def; + }; + + // Look for the instruction that defines the induction via the + // loop backedge. + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return false; + Value *Val = PN->getIncomingValueForBlock(Latch); + if (!Val) + return false; + + // Follow the def-use chain until the induction phi is reached. + // If on the way we encounter a Value that has the same SCEV Expr as the + // phi node, we can consider the instructions we visit from that point + // as part of the cast-sequence that can be ignored. + bool InCastSequence = false; + auto *Inst = dyn_cast<Instruction>(Val); + while (Val != PN) { + // If we encountered a phi node other than PN, or if we left the loop, + // we bail out. + if (!Inst || !L->contains(Inst)) { + return false; + } + auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val)); + if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR)) + InCastSequence = true; + if (InCastSequence) { + // Only the last instruction in the cast sequence is expected to have + // uses outside the induction def-use chain. + if (!CastInsts.empty()) + if (!Inst->hasOneUse()) + return false; + CastInsts.push_back(Inst); + } + Val = getDef(Val); + if (!Val) + return false; + Inst = dyn_cast<Instruction>(Val); + } + + return InCastSequence; +} + +bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, + PredicatedScalarEvolution &PSE, + InductionDescriptor &D, bool Assume) { + Type *PhiTy = Phi->getType(); + + // Handle integer and pointer inductions variables. + // Now we handle also FP induction but not trying to make a + // recurrent expression from the PHI node in-place. + + if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy() && !PhiTy->isFloatTy() && + !PhiTy->isDoubleTy() && !PhiTy->isHalfTy()) + return false; + + if (PhiTy->isFloatingPointTy()) + return isFPInductionPHI(Phi, TheLoop, PSE.getSE(), D); + + const SCEV *PhiScev = PSE.getSCEV(Phi); + const auto *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); + + // We need this expression to be an AddRecExpr. + if (Assume && !AR) + AR = PSE.getAsAddRec(Phi); + + if (!AR) { + LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); + return false; + } + + // Record any Cast instructions that participate in the induction update + const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev); + // If we started from an UnknownSCEV, and managed to build an addRecurrence + // only after enabling Assume with PSCEV, this means we may have encountered + // cast instructions that required adding a runtime check in order to + // guarantee the correctness of the AddRecurence respresentation of the + // induction. + if (PhiScev != AR && SymbolicPhi) { + SmallVector<Instruction *, 2> Casts; + if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts)) + return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts); + } + + return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR); +} + +bool InductionDescriptor::isInductionPHI( + PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE, + InductionDescriptor &D, const SCEV *Expr, + SmallVectorImpl<Instruction *> *CastsToIgnore) { + Type *PhiTy = Phi->getType(); + // We only handle integer and pointer inductions variables. 
+ if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy()) + return false; + + // Check that the PHI is consecutive. + const SCEV *PhiScev = Expr ? Expr : SE->getSCEV(Phi); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(PhiScev); + + if (!AR) { + LLVM_DEBUG(dbgs() << "LV: PHI is not a poly recurrence.\n"); + return false; + } + + if (AR->getLoop() != TheLoop) { + // FIXME: We should treat this as a uniform. Unfortunately, we + // don't currently know how to handled uniform PHIs. + LLVM_DEBUG( + dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); + return false; + } + + Value *StartValue = + Phi->getIncomingValueForBlock(AR->getLoop()->getLoopPreheader()); + const SCEV *Step = AR->getStepRecurrence(*SE); + // Calculate the pointer stride and check if it is consecutive. + // The stride may be a constant or a loop invariant integer value. + const SCEVConstant *ConstStep = dyn_cast<SCEVConstant>(Step); + if (!ConstStep && !SE->isLoopInvariant(Step, TheLoop)) + return false; + + if (PhiTy->isIntegerTy()) { + D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/nullptr, + CastsToIgnore); + return true; + } + + assert(PhiTy->isPointerTy() && "The PHI must be a pointer"); + // Pointer induction should be a constant. + if (!ConstStep) + return false; + + ConstantInt *CV = ConstStep->getValue(); + Type *PointerElementType = PhiTy->getPointerElementType(); + // The pointer stride cannot be determined if the pointer element type is not + // sized. + if (!PointerElementType->isSized()) + return false; + + const DataLayout &DL = Phi->getModule()->getDataLayout(); + int64_t Size = static_cast<int64_t>(DL.getTypeAllocSize(PointerElementType)); + if (!Size) + return false; + + int64_t CVSize = CV->getSExtValue(); + if (CVSize % Size) + return false; + auto *StepValue = + SE->getConstant(CV->getType(), CVSize / Size, true /* signed */); + D = InductionDescriptor(StartValue, IK_PtrInduction, StepValue); + return true; +} diff --git a/lib/Analysis/IndirectCallPromotionAnalysis.cpp b/lib/Analysis/IndirectCallPromotionAnalysis.cpp index 4659c0a00629..d6e6e76af03c 100644 --- a/lib/Analysis/IndirectCallPromotionAnalysis.cpp +++ b/lib/Analysis/IndirectCallPromotionAnalysis.cpp @@ -15,7 +15,7 @@ #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/ADT/STLExtras.h" -#include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/IndirectCallVisitor.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/InstVisitor.h" diff --git a/lib/Analysis/InlineCost.cpp b/lib/Analysis/InlineCost.cpp index a6cccc3b5910..6ddb3cbc01a3 100644 --- a/lib/Analysis/InlineCost.cpp +++ b/lib/Analysis/InlineCost.cpp @@ -23,6 +23,7 @@ #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" @@ -30,6 +31,7 @@ #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/InstVisitor.h" @@ -137,7 +139,7 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { bool HasReturn; bool HasIndirectBr; bool HasUninlineableIntrinsic; - bool UsesVarArgs; + bool InitsVargArgs; /// Number of bytes allocated statically by the callee. 
uint64_t AllocatedSize; @@ -227,7 +229,8 @@ class CallAnalyzer : public InstVisitor<CallAnalyzer, bool> { BlockFrequencyInfo *CallerBFI); // Custom analysis routines. - bool analyzeBlock(BasicBlock *BB, SmallPtrSetImpl<const Value *> &EphValues); + InlineResult analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues); // Disable several entry points to the visitor so we don't accidentally use // them by declaring but not defining them here. @@ -282,7 +285,7 @@ public: IsCallerRecursive(false), IsRecursiveCall(false), ExposesReturnsTwice(false), HasDynamicAlloca(false), ContainsNoDuplicateCall(false), HasReturn(false), HasIndirectBr(false), - HasUninlineableIntrinsic(false), UsesVarArgs(false), AllocatedSize(0), + HasUninlineableIntrinsic(false), InitsVargArgs(false), AllocatedSize(0), NumInstructions(0), NumVectorInstructions(0), VectorBonus(0), SingleBBBonus(0), EnableLoadElimination(true), LoadEliminationCost(0), NumConstantArgs(0), NumConstantOffsetPtrArgs(0), NumAllocaArgs(0), @@ -290,7 +293,7 @@ public: NumInstructionsSimplified(0), SROACostSavings(0), SROACostSavingsLost(0) {} - bool analyzeCall(CallSite CS); + InlineResult analyzeCall(CallSite CS); int getThreshold() { return Threshold; } int getCost() { return Cost; } @@ -719,6 +722,7 @@ bool CallAnalyzer::visitCastInst(CastInst &I) { case Instruction::FPToSI: if (TTI.getFPOpCost(I.getType()) == TargetTransformInfo::TCC_Expensive) Cost += InlineConstants::CallPenalty; + break; default: break; } @@ -1238,8 +1242,7 @@ bool CallAnalyzer::visitCallSite(CallSite CS) { HasUninlineableIntrinsic = true; return false; case Intrinsic::vastart: - case Intrinsic::vaend: - UsesVarArgs = true; + InitsVargArgs = true; return false; } } @@ -1541,8 +1544,9 @@ bool CallAnalyzer::visitInstruction(Instruction &I) { /// aborts early if the threshold has been exceeded or an impossible to inline /// construct has been detected. It returns false if inlining is no longer /// viable, and true if inlining remains viable. -bool CallAnalyzer::analyzeBlock(BasicBlock *BB, - SmallPtrSetImpl<const Value *> &EphValues) { +InlineResult +CallAnalyzer::analyzeBlock(BasicBlock *BB, + SmallPtrSetImpl<const Value *> &EphValues) { for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ++I) { // FIXME: Currently, the number of instructions in a function regardless of // our ability to simplify them during inline to constants or dead code, @@ -1574,16 +1578,29 @@ bool CallAnalyzer::analyzeBlock(BasicBlock *BB, using namespace ore; // If the visit this instruction detected an uninlinable pattern, abort. 
- if (IsRecursiveCall || ExposesReturnsTwice || HasDynamicAlloca || - HasIndirectBr || HasUninlineableIntrinsic || UsesVarArgs) { + InlineResult IR; + if (IsRecursiveCall) + IR = "recursive"; + else if (ExposesReturnsTwice) + IR = "exposes returns twice"; + else if (HasDynamicAlloca) + IR = "dynamic alloca"; + else if (HasIndirectBr) + IR = "indirect branch"; + else if (HasUninlineableIntrinsic) + IR = "uninlinable intrinsic"; + else if (InitsVargArgs) + IR = "varargs"; + if (!IR) { if (ORE) ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", CandidateCS.getInstruction()) - << NV("Callee", &F) - << " has uninlinable pattern and cost is not fully computed"; + << NV("Callee", &F) << " has uninlinable pattern (" + << NV("InlineResult", IR.message) + << ") and cost is not fully computed"; }); - return false; + return IR; } // If the caller is a recursive function then we don't want to inline @@ -1591,15 +1608,15 @@ bool CallAnalyzer::analyzeBlock(BasicBlock *BB, // the caller stack usage dramatically. if (IsCallerRecursive && AllocatedSize > InlineConstants::TotalAllocaSizeRecursiveCaller) { + InlineResult IR = "recursive and allocates too much stack space"; if (ORE) ORE->emit([&]() { return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", CandidateCS.getInstruction()) - << NV("Callee", &F) - << " is recursive and allocates too much stack space. Cost is " - "not fully computed"; + << NV("Callee", &F) << " is " << NV("InlineResult", IR.message) + << ". Cost is not fully computed"; }); - return false; + return IR; } // Check if we've past the maximum possible threshold so we don't spin in @@ -1695,7 +1712,7 @@ void CallAnalyzer::findDeadBlocks(BasicBlock *CurrBB, BasicBlock *NextBB) { /// factors and heuristics. If this method returns false but the computed cost /// is below the computed threshold, then inlining was forcibly disabled by /// some artifact of the routine. -bool CallAnalyzer::analyzeCall(CallSite CS) { +InlineResult CallAnalyzer::analyzeCall(CallSite CS) { ++NumCallsAnalyzed; // Perform some tweaks to the cost and threshold based on the direct @@ -1714,6 +1731,13 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // Update the threshold based on callsite properties updateThreshold(CS, F); + // While Threshold depends on commandline options that can take negative + // values, we want to enforce the invariant that the computed threshold and + // bonuses are non-negative. + assert(Threshold >= 0); + assert(SingleBBBonus >= 0); + assert(VectorBonus >= 0); + // Speculatively apply all possible bonuses to Threshold. If cost exceeds // this Threshold any time, and cost cannot decrease, we can stop processing // the rest of the function body. @@ -1730,7 +1754,7 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // Check if we're done. This can happen due to bonuses and penalties. if (Cost >= Threshold && !ComputeFullInlineCost) - return false; + return "high cost"; if (F.empty()) return true; @@ -1809,14 +1833,15 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // site. If the blockaddress escapes the function, e.g., via a global // variable, inlining may lead to an invalid cross-function reference. if (BB->hasAddressTaken()) - return false; + return "blockaddress"; // Analyze the cost of this block. If we blow through the threshold, this // returns false, and we can bail on out. 
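Note: the InlineResult values threaded through analyzeBlock above behave like the old booleans but carry a reason string for the missed-optimization remarks. A simplified standalone model of that idiom (the struct and helper below are illustrative, not the LLVM definitions):

#include <iostream>

// Simplified model, not the LLVM type: converts to bool like the old return
// value, and carries a reason string when the answer is "don't inline".
struct InlineResultSketch {
  const char *Message = nullptr; // null means success
  InlineResultSketch() = default;                             // success
  InlineResultSketch(const char *Reason) : Message(Reason) {} // failure + why
  explicit operator bool() const { return Message == nullptr; }
};

// Stand-in for a block analysis that can fail for several distinct reasons.
static InlineResultSketch analyzeBlockSketch(bool HasIndirectBr,
                                             bool HasDynamicAlloca) {
  if (HasIndirectBr)
    return "indirect branch";
  if (HasDynamicAlloca)
    return "dynamic alloca";
  return {};
}

int main() {
  InlineResultSketch IR = analyzeBlockSketch(false, true);
  if (!IR)
    std::cout << "not inlining: " << IR.Message << "\n"; // dynamic alloca
}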
- if (!analyzeBlock(BB, EphValues)) - return false; + InlineResult IR = analyzeBlock(BB, EphValues); + if (!IR) + return IR; - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); // Add in the live successors by first checking whether we have terminator // that may be simplified based on the values simplified by this call. @@ -1867,7 +1892,25 @@ bool CallAnalyzer::analyzeCall(CallSite CS) { // inlining this would cause the removal of the caller (so the instruction // is not actually duplicated, just moved). if (!OnlyOneCallAndLocalLinkage && ContainsNoDuplicateCall) - return false; + return "noduplicate"; + + // Loops generally act a lot like calls in that they act like barriers to + // movement, require a certain amount of setup, etc. So when optimising for + // size, we penalise any call sites that perform loops. We do this after all + // other costs here, so will likely only be dealing with relatively small + // functions (and hence DT and LI will hopefully be cheap). + if (Caller->optForMinSize()) { + DominatorTree DT(F); + LoopInfo LI(DT); + int NumLoops = 0; + for (Loop *L : LI) { + // Ignore loops that will not be executed + if (DeadBlocks.count(L->getHeader())) + continue; + NumLoops++; + } + Cost += NumLoops * InlineConstants::CallPenalty; + } // We applied the maximum possible vector bonus at the beginning. Now, // subtract the excess bonus, if any, from the Threshold before @@ -1961,7 +2004,7 @@ InlineCost llvm::getInlineCost( // Cannot inline indirect calls. if (!Callee) - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getNever("indirect call"); // Never inline calls with byval arguments that does not have the alloca // address space. Since byval arguments can be replaced with a copy to an @@ -1973,54 +2016,59 @@ InlineCost llvm::getInlineCost( if (CS.isByValArgument(I)) { PointerType *PTy = cast<PointerType>(CS.getArgument(I)->getType()); if (PTy->getAddressSpace() != AllocaAS) - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getNever("byval arguments without alloca" + " address space"); } // Calls to functions with always-inline attributes should be inlined // whenever possible. if (CS.hasFnAttr(Attribute::AlwaysInline)) { if (isInlineViable(*Callee)) - return llvm::InlineCost::getAlways(); - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getAlways("always inline attribute"); + return llvm::InlineCost::getNever("inapplicable always inline attribute"); } // Never inline functions with conflicting attributes (unless callee has // always-inline attribute). Function *Caller = CS.getCaller(); if (!functionsHaveCompatibleAttributes(Caller, Callee, CalleeTTI)) - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getNever("conflicting attributes"); // Don't inline this call if the caller has the optnone attribute. if (Caller->hasFnAttribute(Attribute::OptimizeNone)) - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getNever("optnone attribute"); // Don't inline a function that treats null pointer as valid into a caller // that does not have this attribute. if (!Caller->nullPointerIsDefined() && Callee->nullPointerIsDefined()) - return llvm::InlineCost::getNever(); + return llvm::InlineCost::getNever("nullptr definitions incompatible"); + + // Don't inline functions which can be interposed at link-time. + if (Callee->isInterposable()) + return llvm::InlineCost::getNever("interposable"); + + // Don't inline functions marked noinline. 
+  if (Callee->hasFnAttribute(Attribute::NoInline))
+    return llvm::InlineCost::getNever("noinline function attribute");
-  // Don't inline functions which can be interposed at link-time. Don't inline
-  // functions marked noinline or call sites marked noinline.
-  // Note: inlining non-exact non-interposable functions is fine, since we know
-  // we have *a* correct implementation of the source level function.
-  if (Callee->isInterposable() || Callee->hasFnAttribute(Attribute::NoInline) ||
-      CS.isNoInline())
-    return llvm::InlineCost::getNever();
+  // Don't inline call sites marked noinline.
+  if (CS.isNoInline())
+    return llvm::InlineCost::getNever("noinline call site attribute");
   LLVM_DEBUG(llvm::dbgs() << " Analyzing call of " << Callee->getName()
                           << "... (caller:" << Caller->getName() << ")\n");
   CallAnalyzer CA(CalleeTTI, GetAssumptionCache, GetBFI, PSI, ORE, *Callee, CS,
                   Params);
-  bool ShouldInline = CA.analyzeCall(CS);
+  InlineResult ShouldInline = CA.analyzeCall(CS);
   LLVM_DEBUG(CA.dump());
   // Check if there was a reason to force inlining or no inlining.
   if (!ShouldInline && CA.getCost() < CA.getThreshold())
-    return InlineCost::getNever();
+    return InlineCost::getNever(ShouldInline.message);
   if (ShouldInline && CA.getCost() >= CA.getThreshold())
-    return InlineCost::getAlways();
+    return InlineCost::getAlways("empty function");
   return llvm::InlineCost::get(CA.getCost(), CA.getThreshold());
 }
@@ -2058,9 +2106,8 @@ bool llvm::isInlineViable(Function &F) {
     // Disallow inlining functions that call @llvm.localescape. Doing this
     // correctly would require major changes to the inliner.
     case llvm::Intrinsic::localescape:
-    // Disallow inlining of functions that access VarArgs.
+    // Disallow inlining of functions that initialize VarArgs with va_start.
     case llvm::Intrinsic::vastart:
-    case llvm::Intrinsic::vaend:
       return false;
     }
   }
diff --git a/lib/Analysis/InstructionPrecedenceTracking.cpp b/lib/Analysis/InstructionPrecedenceTracking.cpp
new file mode 100644
index 000000000000..816126f407ca
--- /dev/null
+++ b/lib/Analysis/InstructionPrecedenceTracking.cpp
@@ -0,0 +1,157 @@
+//===-- InstructionPrecedenceTracking.cpp -----------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+// Implements a class that is able to define some instructions as "special"
+// (e.g. as having implicit control flow, or writing memory, or having another
+// interesting property) and then efficiently answer queries of the types:
+// 1. Are there any special instructions in the block of interest?
+// 2. Return the first of the special instructions in the given block;
+// 3. Check if the given instruction is preceded by the first special
+//    instruction in the same block.
+// The class provides caching that allows these queries to be answered quickly.
+// The user must make sure that the cached data is invalidated properly
+// whenever the content of some tracked block is changed.
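Note: a minimal standalone sketch of the per-block caching contract described in this header comment, using a toy block type instead of BasicBlock; the class and names below are illustrative only.

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

using ToyBlock = std::vector<std::string>; // a "block" of instruction names

class FirstSpecialCache {
  std::map<const ToyBlock *, const std::string *> Cache;
  std::function<bool(const std::string &)> IsSpecial;

public:
  explicit FirstSpecialCache(std::function<bool(const std::string &)> Pred)
      : IsSpecial(std::move(Pred)) {}

  // Queries 1 and 2: the first special instruction in BB, or nullptr if none.
  // The answer is computed once per block and then served from the cache.
  const std::string *getFirstSpecial(const ToyBlock &BB) {
    auto It = Cache.find(&BB);
    if (It != Cache.end())
      return It->second;
    const std::string *First = nullptr;
    for (const std::string &I : BB)
      if (IsSpecial(I)) {
        First = &I;
        break;
      }
    Cache[&BB] = First;
    return First;
  }

  // Query 3: is the instruction at position Idx preceded by a special one?
  bool isPrecededBySpecial(const ToyBlock &BB, size_t Idx) {
    const std::string *First = getFirstSpecial(BB);
    return First && static_cast<size_t>(First - BB.data()) < Idx;
  }

  // Must be called whenever BB's contents change; otherwise the cached
  // pointer goes stale, which is the invalidation contract described above.
  void invalidate(const ToyBlock &BB) { Cache.erase(&BB); }
};

int main() {
  FirstSpecialCache Cache([](const std::string &I) { return I == "guard"; });
  ToyBlock BB = {"add", "guard", "store", "ret"};
  return Cache.isPrecededBySpecial(BB, 2) ? 0 : 1; // "guard" precedes "store"
}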
+//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/InstructionPrecedenceTracking.h" +#include "llvm/Analysis/ValueTracking.h" + +using namespace llvm; + +#ifndef NDEBUG +static cl::opt<bool> ExpensiveAsserts( + "ipt-expensive-asserts", + cl::desc("Perform expensive assert validation on every query to Instruction" + " Precedence Tracking"), + cl::init(false), cl::Hidden); +#endif + +const Instruction *InstructionPrecedenceTracking::getFirstSpecialInstruction( + const BasicBlock *BB) { +#ifndef NDEBUG + // If there is a bug connected to invalid cache, turn on ExpensiveAsserts to + // catch this situation as early as possible. + if (ExpensiveAsserts) + validateAll(); + else + validate(BB); +#endif + + if (FirstSpecialInsts.find(BB) == FirstSpecialInsts.end()) { + fill(BB); + assert(FirstSpecialInsts.find(BB) != FirstSpecialInsts.end() && "Must be!"); + } + return FirstSpecialInsts[BB]; +} + +bool InstructionPrecedenceTracking::hasSpecialInstructions( + const BasicBlock *BB) { + return getFirstSpecialInstruction(BB) != nullptr; +} + +bool InstructionPrecedenceTracking::isPreceededBySpecialInstruction( + const Instruction *Insn) { + const Instruction *MaybeFirstSpecial = + getFirstSpecialInstruction(Insn->getParent()); + return MaybeFirstSpecial && OI.dominates(MaybeFirstSpecial, Insn); +} + +void InstructionPrecedenceTracking::fill(const BasicBlock *BB) { + FirstSpecialInsts.erase(BB); + for (auto &I : *BB) + if (isSpecialInstruction(&I)) { + FirstSpecialInsts[BB] = &I; + return; + } + + // Mark this block as having no special instructions. + FirstSpecialInsts[BB] = nullptr; +} + +#ifndef NDEBUG +void InstructionPrecedenceTracking::validate(const BasicBlock *BB) const { + auto It = FirstSpecialInsts.find(BB); + // Bail if we don't have anything cached for this block. + if (It == FirstSpecialInsts.end()) + return; + + for (const Instruction &Insn : *BB) + if (isSpecialInstruction(&Insn)) { + assert(It->second == &Insn && + "Cached first special instruction is wrong!"); + return; + } + + assert(It->second == nullptr && + "Block is marked as having special instructions but in fact it has " + "none!"); +} + +void InstructionPrecedenceTracking::validateAll() const { + // Check that for every known block the cached value is correct. + for (auto &It : FirstSpecialInsts) + validate(It.first); +} +#endif + +void InstructionPrecedenceTracking::insertInstructionTo(const Instruction *Inst, + const BasicBlock *BB) { + if (isSpecialInstruction(Inst)) + FirstSpecialInsts.erase(BB); + OI.invalidateBlock(BB); +} + +void InstructionPrecedenceTracking::removeInstruction(const Instruction *Inst) { + if (isSpecialInstruction(Inst)) + FirstSpecialInsts.erase(Inst->getParent()); + OI.invalidateBlock(Inst->getParent()); +} + +void InstructionPrecedenceTracking::clear() { + for (auto It : FirstSpecialInsts) + OI.invalidateBlock(It.first); + FirstSpecialInsts.clear(); +#ifndef NDEBUG + // The map should be valid after clearing (at least empty). + validateAll(); +#endif +} + +bool ImplicitControlFlowTracking::isSpecialInstruction( + const Instruction *Insn) const { + // If a block's instruction doesn't always pass the control to its successor + // instruction, mark the block as having implicit control flow. We use them + // to avoid wrong assumptions of sort "if A is executed and B post-dominates + // A, then B is also executed". This is not true is there is an implicit + // control flow instruction (e.g. a guard) between them. 
+ // + // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false + // for volatile stores and loads because they can trap. The discussion on + // whether or not it is correct is still ongoing. We might want to get rid + // of this logic in the future. Anyways, trapping instructions shouldn't + // introduce implicit control flow, so we explicitly allow them here. This + // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed. + if (isGuaranteedToTransferExecutionToSuccessor(Insn)) + return false; + if (isa<LoadInst>(Insn)) { + assert(cast<LoadInst>(Insn)->isVolatile() && + "Non-volatile load should transfer execution to successor!"); + return false; + } + if (isa<StoreInst>(Insn)) { + assert(cast<StoreInst>(Insn)->isVolatile() && + "Non-volatile store should transfer execution to successor!"); + return false; + } + return true; +} + +bool MemoryWriteTracking::isSpecialInstruction( + const Instruction *Insn) const { + return Insn->mayWriteToMemory(); +} diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 7fc7c15a0c25..ccf907c144f0 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -861,8 +861,10 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // (X / Y) * Y -> X if the division is exact. Value *X = nullptr; - if (match(Op0, m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y - match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0))))) // Y * (X / Y) + if (Q.IIQ.UseInstrInfo && + (match(Op0, + m_Exact(m_IDiv(m_Value(X), m_Specific(Op1)))) || // (X / Y) * Y + match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))) // Y * (X / Y) return X; // i1 mul -> and. @@ -1035,8 +1037,8 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { auto *Mul = cast<OverflowingBinaryOperator>(Op0); // If the Mul does not overflow, then we are good to go. - if ((IsSigned && Mul->hasNoSignedWrap()) || - (!IsSigned && Mul->hasNoUnsignedWrap())) + if ((IsSigned && Q.IIQ.hasNoSignedWrap(Mul)) || + (!IsSigned && Q.IIQ.hasNoUnsignedWrap(Mul))) return X; // If X has the form X = A / Y, then X * Y cannot overflow. if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || @@ -1094,10 +1096,11 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, return Op0; // (X << Y) % X -> 0 - if ((Opcode == Instruction::SRem && - match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) || - (Opcode == Instruction::URem && - match(Op0, m_NUWShl(m_Specific(Op1), m_Value())))) + if (Q.IIQ.UseInstrInfo && + ((Opcode == Instruction::SRem && + match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) || + (Opcode == Instruction::URem && + match(Op0, m_NUWShl(m_Specific(Op1), m_Value()))))) return Constant::getNullValue(Op0->getType()); // If the operation is with the result of a select instruction, check whether @@ -1295,7 +1298,8 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // (X >> A) << A -> X Value *X; - if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1))))) + if (Q.IIQ.UseInstrInfo && + match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1))))) return X; // shl nuw i8 C, %x -> C iff C has sign bit set. 
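Note: the "(X >> A) << A -> X" fold in the hunk above is only sound when the right shift is known to be exact (no bits shifted out), which is why it is now also gated on whether instruction flags may be trusted (UseInstrInfo). A tiny standalone sketch of the exact and non-exact cases:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t A = 1;

  uint64_t X1 = 6;                // low bit clear: the lshr is "exact"
  assert(((X1 >> A) << A) == X1); // folding to X is safe

  uint64_t X2 = 5;                // low bit set: a bit is shifted out
  assert(((X2 >> A) << A) == 4);  // 4 != 5, so the fold would be wrong
  return 0;
}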
@@ -1338,7 +1342,7 @@ static Value *SimplifyLShrInst(Value *Op0, Value *Op1, bool isExact, const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); const unsigned Width = Op0->getType()->getScalarSizeInBits(); const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); - if (EffWidthY <= ShRAmt->getZExtValue()) + if (ShRAmt->uge(EffWidthY)) return X; } @@ -1365,7 +1369,7 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact, // (X << A) >> A -> X Value *X; - if (match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) + if (Q.IIQ.UseInstrInfo && match(Op0, m_NSWShl(m_Value(X), m_Specific(Op1)))) return X; // Arithmetic shifting an all-sign-bit value is a no-op. @@ -1552,7 +1556,8 @@ static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; } -static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { +static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { // (icmp (add V, C0), C1) & (icmp V, C0) ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; @@ -1563,13 +1568,13 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { if (!match(Op1, m_ICmp(Pred1, m_Specific(V), m_Value()))) return nullptr; - auto *AddInst = cast<BinaryOperator>(Op0->getOperand(0)); + auto *AddInst = cast<OverflowingBinaryOperator>(Op0->getOperand(0)); if (AddInst->getOperand(1) != Op1->getOperand(1)) return nullptr; Type *ITy = Op0->getType(); - bool isNSW = AddInst->hasNoSignedWrap(); - bool isNUW = AddInst->hasNoUnsignedWrap(); + bool isNSW = IIQ.hasNoSignedWrap(AddInst); + bool isNUW = IIQ.hasNoUnsignedWrap(AddInst); const APInt Delta = *C1 - *C0; if (C0->isStrictlyPositive()) { @@ -1598,7 +1603,8 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { return nullptr; } -static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, /*IsAnd=*/true)) return X; if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true)) @@ -1615,15 +1621,16 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) return X; - if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1)) + if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1, IIQ)) return X; - if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0)) + if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0, IIQ)) return X; return nullptr; } -static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { +static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { // (icmp (add V, C0), C1) | (icmp V, C0) ICmpInst::Predicate Pred0, Pred1; const APInt *C0, *C1; @@ -1639,8 +1646,8 @@ static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { return nullptr; Type *ITy = Op0->getType(); - bool isNSW = AddInst->hasNoSignedWrap(); - bool isNUW = AddInst->hasNoUnsignedWrap(); + bool isNSW = IIQ.hasNoSignedWrap(AddInst); + bool isNUW = IIQ.hasNoUnsignedWrap(AddInst); const APInt Delta = *C1 - *C0; if (C0->isStrictlyPositive()) { @@ -1669,7 +1676,8 @@ static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) { return nullptr; } -static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { +static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, + const InstrInfoQuery &IIQ) { if (Value *X = simplifyUnsignedRangeCheck(Op0, Op1, 
/*IsAnd=*/false)) return X; if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false)) @@ -1686,15 +1694,16 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) { if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) return X; - if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1)) + if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1, IIQ)) return X; - if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0)) + if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0, IIQ)) return X; return nullptr; } -static Value *simplifyAndOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { +static Value *simplifyAndOrOfFCmps(const TargetLibraryInfo *TLI, + FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); if (LHS0->getType() != RHS0->getType()) @@ -1711,8 +1720,8 @@ static Value *simplifyAndOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { // (fcmp uno NNAN, X) | (fcmp uno Y, X) --> fcmp uno Y, X // (fcmp uno X, NNAN) | (fcmp uno X, Y) --> fcmp uno X, Y // (fcmp uno X, NNAN) | (fcmp uno Y, X) --> fcmp uno Y, X - if ((isKnownNeverNaN(LHS0) && (LHS1 == RHS0 || LHS1 == RHS1)) || - (isKnownNeverNaN(LHS1) && (LHS0 == RHS0 || LHS0 == RHS1))) + if ((isKnownNeverNaN(LHS0, TLI) && (LHS1 == RHS0 || LHS1 == RHS1)) || + (isKnownNeverNaN(LHS1, TLI) && (LHS0 == RHS0 || LHS0 == RHS1))) return RHS; // (fcmp ord X, Y) & (fcmp ord NNAN, X) --> fcmp ord X, Y @@ -1723,15 +1732,16 @@ static Value *simplifyAndOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { // (fcmp uno Y, X) | (fcmp uno NNAN, X) --> fcmp uno Y, X // (fcmp uno X, Y) | (fcmp uno X, NNAN) --> fcmp uno X, Y // (fcmp uno Y, X) | (fcmp uno X, NNAN) --> fcmp uno Y, X - if ((isKnownNeverNaN(RHS0) && (RHS1 == LHS0 || RHS1 == LHS1)) || - (isKnownNeverNaN(RHS1) && (RHS0 == LHS0 || RHS0 == LHS1))) + if ((isKnownNeverNaN(RHS0, TLI) && (RHS1 == LHS0 || RHS1 == LHS1)) || + (isKnownNeverNaN(RHS1, TLI) && (RHS0 == LHS0 || RHS0 == LHS1))) return LHS; } return nullptr; } -static Value *simplifyAndOrOfCmps(Value *Op0, Value *Op1, bool IsAnd) { +static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, + Value *Op0, Value *Op1, bool IsAnd) { // Look through casts of the 'and' operands to find compares. auto *Cast0 = dyn_cast<CastInst>(Op0); auto *Cast1 = dyn_cast<CastInst>(Op1); @@ -1745,13 +1755,13 @@ static Value *simplifyAndOrOfCmps(Value *Op0, Value *Op1, bool IsAnd) { auto *ICmp0 = dyn_cast<ICmpInst>(Op0); auto *ICmp1 = dyn_cast<ICmpInst>(Op1); if (ICmp0 && ICmp1) - V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1) : - simplifyOrOfICmps(ICmp0, ICmp1); + V = IsAnd ? simplifyAndOfICmps(ICmp0, ICmp1, Q.IIQ) + : simplifyOrOfICmps(ICmp0, ICmp1, Q.IIQ); auto *FCmp0 = dyn_cast<FCmpInst>(Op0); auto *FCmp1 = dyn_cast<FCmpInst>(Op1); if (FCmp0 && FCmp1) - V = simplifyAndOrOfFCmps(FCmp0, FCmp1, IsAnd); + V = simplifyAndOrOfFCmps(Q.TLI, FCmp0, FCmp1, IsAnd); if (!V) return nullptr; @@ -1831,7 +1841,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return Op1; } - if (Value *V = simplifyAndOrOfCmps(Op0, Op1, true)) + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true)) return V; // Try some generic simplifications for associative operations. @@ -1863,6 +1873,40 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, MaxRecurse)) return V; + // Assuming the effective width of Y is not larger than A, i.e. 
all bits + // from X and Y are disjoint in (X << A) | Y, + // if the mask of this AND op covers all bits of X or Y, while it covers + // no bits from the other, we can bypass this AND op. E.g., + // ((X << A) | Y) & Mask -> Y, + // if Mask = ((1 << effective_width_of(Y)) - 1) + // ((X << A) | Y) & Mask -> X << A, + // if Mask = ((1 << effective_width_of(X)) - 1) << A + // SimplifyDemandedBits in InstCombine can optimize the general case. + // This pattern aims to help other passes for a common case. + Value *Y, *XShifted; + if (match(Op1, m_APInt(Mask)) && + match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)), + m_Value(XShifted)), + m_Value(Y)))) { + const unsigned Width = Op0->getType()->getScalarSizeInBits(); + const unsigned ShftCnt = ShAmt->getLimitedValue(Width); + const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const unsigned EffWidthY = Width - YKnown.countMinLeadingZeros(); + if (EffWidthY <= ShftCnt) { + const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, + Q.DT); + const unsigned EffWidthX = Width - XKnown.countMinLeadingZeros(); + const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); + const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; + // If the mask is extracting all bits from X or Y as is, we can skip + // this AND op. + if (EffBitsY.isSubsetOf(*Mask) && !EffBitsX.intersects(*Mask)) + return Y; + if (EffBitsX.isSubsetOf(*Mask) && !EffBitsY.intersects(*Mask)) + return XShifted; + } + } + return nullptr; } @@ -1947,7 +1991,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) return Op0; - if (Value *V = simplifyAndOrOfCmps(Op0, Op1, false)) + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false)) return V; // Try some generic simplifications for associative operations. @@ -2108,13 +2152,15 @@ static Constant * computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, const DominatorTree *DT, CmpInst::Predicate Pred, AssumptionCache *AC, const Instruction *CxtI, - Value *LHS, Value *RHS) { + const InstrInfoQuery &IIQ, Value *LHS, Value *RHS) { // First, skip past any trivial no-ops. LHS = LHS->stripPointerCasts(); RHS = RHS->stripPointerCasts(); // A non-null pointer is not equal to a null pointer. - if (llvm::isKnownNonZero(LHS, DL) && isa<ConstantPointerNull>(RHS) && + if (llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr, + IIQ.UseInstrInfo) && + isa<ConstantPointerNull>(RHS) && (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_NE)) return ConstantInt::get(GetCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); @@ -2379,12 +2425,12 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return getTrue(ITy); case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULE: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) return getFalse(ITy); break; case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: - if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) return getTrue(ITy); break; case ICmpInst::ICMP_SLT: { @@ -2429,17 +2475,18 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, /// Many binary operators with a constant operand have an easy-to-compute /// range of outputs. This can be used to fold a comparison to always true or /// always false. 
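Note: a concrete standalone instance of the disjoint-bits AND simplification described in the comment above, with small hypothetical constants standing in for X, Y, A and the mask:

#include <cassert>
#include <cstdint>

int main() {
  uint64_t X = 0xABCD; // effective width 16
  uint64_t Y = 0xFF;   // effective width 8
  unsigned A = 8;      // shift amount >= effective width of Y

  uint64_t Or = (X << A) | Y; // bits of X << A and Y do not overlap

  // Mask covering only Y's effective bits: ((X << A) | Y) & Mask == Y.
  assert((Or & 0xFFull) == Y);

  // Mask covering only (X << A)'s bits: ((X << A) | Y) & Mask == X << A.
  assert((Or & 0xFFFF00ull) == (X << A));
  return 0;
}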
-static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { +static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper, + const InstrInfoQuery &IIQ) { unsigned Width = Lower.getBitWidth(); const APInt *C; switch (BO.getOpcode()) { case Instruction::Add: if (match(BO.getOperand(1), m_APInt(C)) && !C->isNullValue()) { // FIXME: If we have both nuw and nsw, we should reduce the range further. - if (BO.hasNoUnsignedWrap()) { + if (IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(&BO))) { // 'add nuw x, C' produces [C, UINT_MAX]. Lower = *C; - } else if (BO.hasNoSignedWrap()) { + } else if (IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(&BO))) { if (C->isNegative()) { // 'add nsw x, -C' produces [SINT_MIN, SINT_MAX - C]. Lower = APInt::getSignedMinValue(Width); @@ -2472,7 +2519,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { Upper = APInt::getSignedMaxValue(Width).ashr(*C) + 1; } else if (match(BO.getOperand(0), m_APInt(C))) { unsigned ShiftAmount = Width - 1; - if (!C->isNullValue() && BO.isExact()) + if (!C->isNullValue() && IIQ.isExact(&BO)) ShiftAmount = C->countTrailingZeros(); if (C->isNegative()) { // 'ashr C, x' produces [C, C >> (Width-1)] @@ -2493,7 +2540,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { } else if (match(BO.getOperand(0), m_APInt(C))) { // 'lshr C, x' produces [C >> (Width-1), C]. unsigned ShiftAmount = Width - 1; - if (!C->isNullValue() && BO.isExact()) + if (!C->isNullValue() && IIQ.isExact(&BO)) ShiftAmount = C->countTrailingZeros(); Lower = C->lshr(ShiftAmount); Upper = *C + 1; @@ -2502,7 +2549,7 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { case Instruction::Shl: if (match(BO.getOperand(0), m_APInt(C))) { - if (BO.hasNoUnsignedWrap()) { + if (IIQ.hasNoUnsignedWrap(&BO)) { // 'shl nuw C, x' produces [C, C << CLZ(C)] Lower = *C; Upper = Lower.shl(Lower.countLeadingZeros()) + 1; @@ -2583,8 +2630,72 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) { } } +/// Some intrinsics with a constant operand have an easy-to-compute range of +/// outputs. This can be used to fold a comparison to always true or always +/// false. +static void setLimitsForIntrinsic(IntrinsicInst &II, APInt &Lower, + APInt &Upper) { + unsigned Width = Lower.getBitWidth(); + const APInt *C; + switch (II.getIntrinsicID()) { + case Intrinsic::uadd_sat: + // uadd.sat(x, C) produces [C, UINT_MAX]. + if (match(II.getOperand(0), m_APInt(C)) || + match(II.getOperand(1), m_APInt(C))) + Lower = *C; + break; + case Intrinsic::sadd_sat: + if (match(II.getOperand(0), m_APInt(C)) || + match(II.getOperand(1), m_APInt(C))) { + if (C->isNegative()) { + // sadd.sat(x, -C) produces [SINT_MIN, SINT_MAX + (-C)]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) + *C + 1; + } else { + // sadd.sat(x, +C) produces [SINT_MIN + C, SINT_MAX]. + Lower = APInt::getSignedMinValue(Width) + *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } + break; + case Intrinsic::usub_sat: + // usub.sat(C, x) produces [0, C]. + if (match(II.getOperand(0), m_APInt(C))) + Upper = *C + 1; + // usub.sat(x, C) produces [0, UINT_MAX - C]. + else if (match(II.getOperand(1), m_APInt(C))) + Upper = APInt::getMaxValue(Width) - *C + 1; + break; + case Intrinsic::ssub_sat: + if (match(II.getOperand(0), m_APInt(C))) { + if (C->isNegative()) { + // ssub.sat(-C, x) produces [SINT_MIN, -SINT_MIN + (-C)]. 
+ Lower = APInt::getSignedMinValue(Width); + Upper = *C - APInt::getSignedMinValue(Width) + 1; + } else { + // ssub.sat(+C, x) produces [-SINT_MAX + C, SINT_MAX]. + Lower = *C - APInt::getSignedMaxValue(Width); + Upper = APInt::getSignedMaxValue(Width) + 1; + } + } else if (match(II.getOperand(1), m_APInt(C))) { + if (C->isNegative()) { + // ssub.sat(x, -C) produces [SINT_MIN - (-C), SINT_MAX]: + Lower = APInt::getSignedMinValue(Width) - *C; + Upper = APInt::getSignedMaxValue(Width) + 1; + } else { + // ssub.sat(x, +C) produces [SINT_MIN, SINT_MAX - C]. + Lower = APInt::getSignedMinValue(Width); + Upper = APInt::getSignedMaxValue(Width) - *C + 1; + } + } + break; + default: + break; + } +} + static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, - Value *RHS) { + Value *RHS, const InstrInfoQuery &IIQ) { Type *ITy = GetCompareTy(RHS); // The return type. Value *X; @@ -2615,13 +2726,15 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, APInt Lower = APInt(Width, 0); APInt Upper = APInt(Width, 0); if (auto *BO = dyn_cast<BinaryOperator>(LHS)) - setLimitsForBinOp(*BO, Lower, Upper); + setLimitsForBinOp(*BO, Lower, Upper, IIQ); + else if (auto *II = dyn_cast<IntrinsicInst>(LHS)) + setLimitsForIntrinsic(*II, Lower, Upper); ConstantRange LHS_CR = Lower != Upper ? ConstantRange(Lower, Upper) : ConstantRange(Width, true); if (auto *I = dyn_cast<Instruction>(LHS)) - if (auto *Ranges = I->getMetadata(LLVMContext::MD_range)) + if (auto *Ranges = IIQ.getMetadata(I, LLVMContext::MD_range)) LHS_CR = LHS_CR.intersectWith(getConstantRangeFromMetadata(*Ranges)); if (!LHS_CR.isFullSet()) { @@ -2654,16 +2767,20 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, B = LBO->getOperand(1); NoLHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && LBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && LBO->hasNoSignedWrap()); + (CmpInst::isUnsigned(Pred) && + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO))) || + (CmpInst::isSigned(Pred) && + Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO))); } if (RBO && RBO->getOpcode() == Instruction::Add) { C = RBO->getOperand(0); D = RBO->getOperand(1); NoRHSWrapProblem = ICmpInst::isEquality(Pred) || - (CmpInst::isUnsigned(Pred) && RBO->hasNoUnsignedWrap()) || - (CmpInst::isSigned(Pred) && RBO->hasNoSignedWrap()); + (CmpInst::isUnsigned(Pred) && + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(RBO))) || + (CmpInst::isSigned(Pred) && + Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(RBO))); } // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. @@ -2881,7 +2998,8 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // - The shift is nuw, we can't shift out the one bit. 
// - CI2 is one // - CI isn't zero - if (LBO->hasNoSignedWrap() || LBO->hasNoUnsignedWrap() || + if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || CI2Val->isOneValue() || !CI->isZero()) { if (Pred == ICmpInst::ICMP_EQ) return ConstantInt::getFalse(RHS->getContext()); @@ -2905,29 +3023,31 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, break; case Instruction::UDiv: case Instruction::LShr: - if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact()) + if (ICmpInst::isSigned(Pred) || !Q.IIQ.isExact(LBO) || + !Q.IIQ.isExact(RBO)) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::SDiv: - if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact()) + if (!ICmpInst::isEquality(Pred) || !Q.IIQ.isExact(LBO) || + !Q.IIQ.isExact(RBO)) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::AShr: - if (!LBO->isExact() || !RBO->isExact()) + if (!Q.IIQ.isExact(LBO) || !Q.IIQ.isExact(RBO)) break; if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), RBO->getOperand(0), Q, MaxRecurse - 1)) return V; break; case Instruction::Shl: { - bool NUW = LBO->hasNoUnsignedWrap() && RBO->hasNoUnsignedWrap(); - bool NSW = LBO->hasNoSignedWrap() && RBO->hasNoSignedWrap(); + bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO); + bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO); if (!NUW && !NSW) break; if (!NSW && ICmpInst::isSigned(Pred)) @@ -2942,6 +3062,44 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +static Value *simplifyICmpWithAbsNabs(CmpInst::Predicate Pred, Value *Op0, + Value *Op1) { + // We need a comparison with a constant. + const APInt *C; + if (!match(Op1, m_APInt(C))) + return nullptr; + + // matchSelectPattern returns the negation part of an abs pattern in SP1. + // If the negate has an NSW flag, abs(INT_MIN) is undefined. Without that + // constraint, we can't make a contiguous range for the result of abs. + ICmpInst::Predicate AbsPred = ICmpInst::BAD_ICMP_PREDICATE; + Value *SP0, *SP1; + SelectPatternFlavor SPF = matchSelectPattern(Op0, SP0, SP1).Flavor; + if (SPF == SelectPatternFlavor::SPF_ABS && + cast<Instruction>(SP1)->hasNoSignedWrap()) + // The result of abs(X) is >= 0 (with nsw). + AbsPred = ICmpInst::ICMP_SGE; + if (SPF == SelectPatternFlavor::SPF_NABS) + // The result of -abs(X) is <= 0. + AbsPred = ICmpInst::ICMP_SLE; + + if (AbsPred == ICmpInst::BAD_ICMP_PREDICATE) + return nullptr; + + // If there is no intersection between abs/nabs and the range of this icmp, + // the icmp must be false. If the abs/nabs range is a subset of the icmp + // range, the icmp must be true. + APInt Zero = APInt::getNullValue(C->getBitWidth()); + ConstantRange AbsRange = ConstantRange::makeExactICmpRegion(AbsPred, Zero); + ConstantRange CmpRange = ConstantRange::makeExactICmpRegion(Pred, *C); + if (AbsRange.intersectWith(CmpRange).isEmptySet()) + return getFalse(GetCompareTy(Op0)); + if (CmpRange.contains(AbsRange)) + return getTrue(GetCompareTy(Op0)); + + return nullptr; +} + /// Simplify integer comparisons where at least one operand of the compare /// matches an integer min/max idiom. 
static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, @@ -3175,7 +3333,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Value *V = simplifyICmpWithZero(Pred, LHS, RHS, Q)) return V; - if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS)) + if (Value *V = simplifyICmpWithConstant(Pred, LHS, RHS, Q.IIQ)) return V; // If both operands have range metadata, use the metadata @@ -3184,8 +3342,8 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, auto RHS_Instr = cast<Instruction>(RHS); auto LHS_Instr = cast<Instruction>(LHS); - if (RHS_Instr->getMetadata(LLVMContext::MD_range) && - LHS_Instr->getMetadata(LLVMContext::MD_range)) { + if (Q.IIQ.getMetadata(RHS_Instr, LLVMContext::MD_range) && + Q.IIQ.getMetadata(LHS_Instr, LLVMContext::MD_range)) { auto RHS_CR = getConstantRangeFromMetadata( *RHS_Instr->getMetadata(LLVMContext::MD_range)); auto LHS_CR = getConstantRangeFromMetadata( @@ -3363,7 +3521,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // icmp eq|ne X, Y -> false|true if X != Y if (ICmpInst::isEquality(Pred) && - isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT)) { + isKnownNonEqual(LHS, RHS, Q.DL, Q.AC, Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo)) { return Pred == ICmpInst::ICMP_NE ? getTrue(ITy) : getFalse(ITy); } @@ -3373,11 +3531,14 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Value *V = simplifyICmpWithMinMax(Pred, LHS, RHS, Q, MaxRecurse)) return V; + if (Value *V = simplifyICmpWithAbsNabs(Pred, LHS, RHS)) + return V; + // Simplify comparisons of related pointers using a powerful, recursive // GEP-walk when we have target data available.. if (LHS->getType()->isPointerTy()) - if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, LHS, - RHS)) + if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, + Q.IIQ, LHS, RHS)) return C; if (auto *CLHS = dyn_cast<PtrToIntOperator>(LHS)) if (auto *CRHS = dyn_cast<PtrToIntOperator>(RHS)) @@ -3386,7 +3547,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, Q.DL.getTypeSizeInBits(CRHS->getPointerOperandType()) == Q.DL.getTypeSizeInBits(CRHS->getType())) if (auto *C = computePointerICmp(Q.DL, Q.TLI, Q.DT, Pred, Q.AC, Q.CxtI, - CLHS->getPointerOperand(), + Q.IIQ, CLHS->getPointerOperand(), CRHS->getPointerOperand())) return C; @@ -3457,13 +3618,11 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Pred == FCmpInst::FCMP_TRUE) return getTrue(RetTy); - // UNO/ORD predicates can be trivially folded if NaNs are ignored. - if (FMF.noNaNs()) { - if (Pred == FCmpInst::FCMP_UNO) - return getFalse(RetTy); - if (Pred == FCmpInst::FCMP_ORD) - return getTrue(RetTy); - } + // Fold (un)ordered comparison if we can determine there are no NaNs. + if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD) + if (FMF.noNaNs() || + (isKnownNeverNaN(LHS, Q.TLI) && isKnownNeverNaN(RHS, Q.TLI))) + return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD); // NaN is unordered; NaN is not ordered. 
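For readers less familiar with the FCMP_ORD/FCMP_UNO terminology used in the fold above: "ordered" means neither operand is NaN, so once isKnownNeverNaN() holds for both sides, FCMP_ORD is trivially true and FCMP_UNO trivially false. A small standalone C++ refresher (plain IEEE-754 behaviour, not LLVM code):

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  double NaN = std::numeric_limits<double>::quiet_NaN();
  assert(!(NaN == NaN));                // every ordered comparison involving NaN is false
  assert(NaN != NaN);                   // != is the unordered-or-not-equal comparison
  assert(std::isunordered(NaN, 1.0));   // FCMP_UNO-style result: true when a NaN is present
  assert(!std::isunordered(2.0, 1.0));  // no NaNs -> ordered, matching the FCMP_ORD fold
  return 0;
}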
assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) && @@ -3518,12 +3677,19 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } if (C->isZero()) { switch (Pred) { + case FCmpInst::FCMP_OGE: + if (FMF.noNaNs() && CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return getTrue(RetTy); + break; case FCmpInst::FCMP_UGE: if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) return getTrue(RetTy); break; + case FCmpInst::FCMP_ULT: + if (FMF.noNaNs() && CannotBeOrderedLessThanZero(LHS, Q.TLI)) + return getFalse(RetTy); + break; case FCmpInst::FCMP_OLT: - // X < 0 if (CannotBeOrderedLessThanZero(LHS, Q.TLI)) return getFalse(RetTy); break; @@ -3600,11 +3766,10 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, // // We can't replace %sel with %add unless we strip away the flags. if (isa<OverflowingBinaryOperator>(B)) - if (B->hasNoSignedWrap() || B->hasNoUnsignedWrap()) - return nullptr; - if (isa<PossiblyExactOperator>(B)) - if (B->isExact()) + if (Q.IIQ.hasNoSignedWrap(B) || Q.IIQ.hasNoUnsignedWrap(B)) return nullptr; + if (isa<PossiblyExactOperator>(B) && Q.IIQ.isExact(B)) + return nullptr; if (MaxRecurse) { if (B->getOperand(0) == Op) @@ -3738,6 +3903,28 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, Pred == ICmpInst::ICMP_EQ)) return V; + + // Test for zero-shift-guard-ops around funnel shifts. These are used to + // avoid UB from oversized shifts in raw IR rotate patterns, but the + // intrinsics do not have that problem. + Value *ShAmt; + auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), + m_Value(ShAmt)), + m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), + m_Value(ShAmt))); + // (ShAmt != 0) ? fshl(X, *, ShAmt) : X --> fshl(X, *, ShAmt) + // (ShAmt != 0) ? fshr(*, X, ShAmt) : X --> fshr(*, X, ShAmt) + // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X + // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X + if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt) + return Pred == ICmpInst::ICMP_NE ? TrueVal : X; + + // (ShAmt == 0) ? X : fshl(X, *, ShAmt) --> fshl(X, *, ShAmt) + // (ShAmt == 0) ? X : fshr(*, X, ShAmt) --> fshr(*, X, ShAmt) + // (ShAmt != 0) ? X : fshl(X, *, ShAmt) --> X + // (ShAmt != 0) ? X : fshr(*, X, ShAmt) --> X + if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt) + return Pred == ICmpInst::ICMP_EQ ? FalseVal : X; } // Check for other compares that behave like bit test. @@ -3775,6 +3962,34 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, return nullptr; } +/// Try to simplify a select instruction when its condition operand is a +/// floating-point comparison. +static Value *simplifySelectWithFCmp(Value *Cond, Value *T, Value *F) { + FCmpInst::Predicate Pred; + if (!match(Cond, m_FCmp(Pred, m_Specific(T), m_Specific(F))) && + !match(Cond, m_FCmp(Pred, m_Specific(F), m_Specific(T)))) + return nullptr; + + // TODO: The transform may not be valid with -0.0. An incomplete way of + // testing for that possibility is to check if at least one operand is a + // non-zero constant. + const APFloat *C; + if ((match(T, m_APFloat(C)) && C->isNonZero()) || + (match(F, m_APFloat(C)) && C->isNonZero())) { + // (T == F) ? T : F --> F + // (F == T) ? T : F --> F + if (Pred == FCmpInst::FCMP_OEQ) + return F; + + // (T != F) ? T : F --> T + // (F != T) ? 
T : F --> T + if (Pred == FCmpInst::FCMP_UNE) + return T; + } + + return nullptr; +} + /// Given operands for a SelectInst, see if we can fold the result. /// If not, this returns null. static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, @@ -3811,9 +4026,16 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse)) return V; + if (Value *V = simplifySelectWithFCmp(Cond, TrueVal, FalseVal)) + return V; + if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal)) return V; + Optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL); + if (Imp) + return *Imp ? TrueVal : FalseVal; + return nullptr; } @@ -4325,6 +4547,14 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))) return ConstantFP::getNullValue(Op0->getType()); + // (X - Y) + Y --> X + // Y + (X - Y) --> X + Value *X; + if (FMF.noSignedZeros() && FMF.allowReassoc() && + (match(Op0, m_FSub(m_Value(X), m_Specific(Op1))) || + match(Op1, m_FSub(m_Value(X), m_Specific(Op0))))) + return X; + return nullptr; } @@ -4362,6 +4592,13 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (FMF.noNaNs() && Op0 == Op1) return Constant::getNullValue(Op0->getType()); + // Y - (Y - X) --> X + // (X + Y) - Y --> X + if (FMF.noSignedZeros() && FMF.allowReassoc() && + (match(Op1, m_FSub(m_Specific(Op0), m_Value(X))) || + match(Op0, m_c_FAdd(m_Specific(Op1), m_Value(X))))) + return X; + return nullptr; } @@ -4442,10 +4679,8 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, // -X / X -> -1.0 and // X / -X -> -1.0 are legal when NaNs are ignored. // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored. - if ((BinaryOperator::isFNeg(Op0, /*IgnoreZeroSign=*/true) && - BinaryOperator::getFNegArgument(Op0) == Op1) || - (BinaryOperator::isFNeg(Op1, /*IgnoreZeroSign=*/true) && - BinaryOperator::getFNegArgument(Op1) == Op0)) + if (match(Op0, m_FNegNSZ(m_Specific(Op1))) || + match(Op1, m_FNegNSZ(m_Specific(Op0)))) return ConstantFP::get(Op0->getType(), -1.0); } @@ -4747,6 +4982,40 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, if (match(Op0, m_Undef()) || match(Op1, m_Undef())) return Constant::getNullValue(ReturnType); break; + case Intrinsic::uadd_sat: + // sat(MAX + X) -> MAX + // sat(X + MAX) -> MAX + if (match(Op0, m_AllOnes()) || match(Op1, m_AllOnes())) + return Constant::getAllOnesValue(ReturnType); + LLVM_FALLTHROUGH; + case Intrinsic::sadd_sat: + // sat(X + undef) -> -1 + // sat(undef + X) -> -1 + // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1). + // For signed: Assume undef is ~X, in which case X + ~X = -1. 
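Returning to the new floating-point folds earlier in this hunk, (X - Y) + Y --> X and Y - (Y - X) --> X are only exact in real arithmetic, which is why the code requires both allowReassoc and noSignedZeros. A standalone C++ illustration (plain IEEE floats, not LLVM code) of what goes wrong without those flags:

#include <cmath>
#include <cstdio>

int main() {
  // Rounding: 1 - 1e20 rounds to -1e20, so the round-trip loses X entirely.
  float X = 1.0f, Y = 1.0e20f;
  std::printf("(1 - 1e20) + 1e20 = %g (not 1)\n", (X - Y) + Y);

  // Signed zero: (-0.0 - 0.0) + 0.0 is +0.0, while the folded result would be -0.0.
  float Xz = -0.0f, Yz = 0.0f;
  std::printf("signbit of (Xz - Yz) + Yz = %d, signbit of Xz = %d\n",
              (int)std::signbit((Xz - Yz) + Yz), (int)std::signbit(Xz));
  return 0;
}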
+ if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + return Constant::getAllOnesValue(ReturnType); + + // X + 0 -> X + if (match(Op1, m_Zero())) + return Op0; + // 0 + X -> X + if (match(Op0, m_Zero())) + return Op1; + break; + case Intrinsic::usub_sat: + // sat(0 - X) -> 0, sat(X - MAX) -> 0 + if (match(Op0, m_Zero()) || match(Op1, m_AllOnes())) + return Constant::getNullValue(ReturnType); + LLVM_FALLTHROUGH; + case Intrinsic::ssub_sat: + // X - X -> 0, X - undef -> 0, undef - X -> 0 + if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef())) + return Constant::getNullValue(ReturnType); + // X - 0 -> X + if (match(Op1, m_Zero())) + return Op0; + break; case Intrinsic::load_relative: if (auto *C0 = dyn_cast<Constant>(Op0)) if (auto *C1 = dyn_cast<Constant>(Op1)) @@ -4764,10 +5033,51 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, break; case Intrinsic::maxnum: case Intrinsic::minnum: - // If one argument is NaN, return the other argument. - if (match(Op0, m_NaN())) return Op1; - if (match(Op1, m_NaN())) return Op0; + case Intrinsic::maximum: + case Intrinsic::minimum: { + // If the arguments are the same, this is a no-op. + if (Op0 == Op1) return Op0; + + // If one argument is undef, return the other argument. + if (match(Op0, m_Undef())) + return Op1; + if (match(Op1, m_Undef())) + return Op0; + + // If one argument is NaN, return other or NaN appropriately. + bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; + if (match(Op0, m_NaN())) + return PropagateNaN ? Op0 : Op1; + if (match(Op1, m_NaN())) + return PropagateNaN ? Op1 : Op0; + + // Min/max of the same operation with common operand: + // m(m(X, Y)), X --> m(X, Y) (4 commuted variants) + if (auto *M0 = dyn_cast<IntrinsicInst>(Op0)) + if (M0->getIntrinsicID() == IID && + (M0->getOperand(0) == Op1 || M0->getOperand(1) == Op1)) + return Op0; + if (auto *M1 = dyn_cast<IntrinsicInst>(Op1)) + if (M1->getIntrinsicID() == IID && + (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0)) + return Op1; + + // min(X, -Inf) --> -Inf (and commuted variant) + // max(X, +Inf) --> +Inf (and commuted variant) + bool UseNegInf = IID == Intrinsic::minnum || IID == Intrinsic::minimum; + const APFloat *C; + if ((match(Op0, m_APFloat(C)) && C->isInfinity() && + C->isNegative() == UseNegInf) || + (match(Op1, m_APFloat(C)) && C->isInfinity() && + C->isNegative() == UseNegInf)) + return ConstantFP::getInfinity(ReturnType, UseNegInf); + + // TODO: minnum(nnan x, inf) -> x + // TODO: minnum(nnan ninf x, flt_max) -> x + // TODO: maxnum(nnan x, -inf) -> x + // TODO: maxnum(nnan ninf x, -flt_max) -> x break; + } default: break; } @@ -4802,7 +5112,16 @@ static Value *simplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd, } case Intrinsic::fshl: case Intrinsic::fshr: { - Value *ShAmtArg = ArgBegin[2]; + Value *Op0 = ArgBegin[0], *Op1 = ArgBegin[1], *ShAmtArg = ArgBegin[2]; + + // If both operands are undef, the result is undef. + if (match(Op0, m_Undef()) && match(Op1, m_Undef())) + return UndefValue::get(F->getReturnType()); + + // If shift amount is undef, assume it is zero. + if (match(ShAmtArg, m_Undef())) + return ArgBegin[IID == Intrinsic::fshl ? 0 : 1]; + const APInt *ShAmtC; if (match(ShAmtArg, m_APInt(ShAmtC))) { // If there's effectively no shift, return the 1st arg or 2nd arg. 
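The funnel-shift folds above (both the zero-shift select guards and the undef shift-amount handling) lean on one property of fshl/fshr: a shift amount of zero simply returns the untouched operand. A standalone 8-bit model of fshl (ordinary C++, not the intrinsic itself) makes that concrete:

#include <cassert>
#include <cstdint>

// 8-bit funnel shift left: conceptually concatenate Hi:Lo, shift left by
// ShAmt (taken modulo the bit width), and keep the high 8 bits.
static uint8_t fshl8(uint8_t Hi, uint8_t Lo, unsigned ShAmt) {
  ShAmt %= 8;
  if (ShAmt == 0)
    return Hi;                     // zero shift funnels in no bits of Lo
  return uint8_t((Hi << ShAmt) | (Lo >> (8 - ShAmt)));
}

int main() {
  assert(fshl8(0xAB, 0xCD, 0) == 0xAB);  // identity on the first operand
  assert(fshl8(0xAB, 0xCD, 4) == 0xBC);  // 0xAB:0xCD funneled left by 4 keeps 0xBC
  return 0;
}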
@@ -4889,18 +5208,20 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, I->getFastMathFlags(), Q); break; case Instruction::Add: - Result = SimplifyAddInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); + Result = + SimplifyAddInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FSub: Result = SimplifyFSubInst(I->getOperand(0), I->getOperand(1), I->getFastMathFlags(), Q); break; case Instruction::Sub: - Result = SimplifySubInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); + Result = + SimplifySubInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::FMul: Result = SimplifyFMulInst(I->getOperand(0), I->getOperand(1), @@ -4930,17 +5251,18 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, I->getFastMathFlags(), Q); break; case Instruction::Shl: - Result = SimplifyShlInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->hasNoSignedWrap(), - cast<BinaryOperator>(I)->hasNoUnsignedWrap(), Q); + Result = + SimplifyShlInst(I->getOperand(0), I->getOperand(1), + Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)), + Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q); break; case Instruction::LShr: Result = SimplifyLShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::AShr: Result = SimplifyAShrInst(I->getOperand(0), I->getOperand(1), - cast<BinaryOperator>(I)->isExact(), Q); + Q.IIQ.isExact(cast<BinaryOperator>(I)), Q); break; case Instruction::And: Result = SimplifyAndInst(I->getOperand(0), I->getOperand(1), Q); @@ -5066,7 +5388,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, // Gracefully handle edge cases where the instruction is not wired into any // parent block. - if (I->getParent() && !I->isEHPad() && !isa<TerminatorInst>(I) && + if (I->getParent() && !I->isEHPad() && !I->isTerminator() && !I->mayHaveSideEffects()) I->eraseFromParent(); } else { @@ -5095,7 +5417,7 @@ static bool replaceAndRecursivelySimplifyImpl(Instruction *I, Value *SimpleV, // Gracefully handle edge cases where the instruction is not wired into any // parent block. - if (I->getParent() && !I->isEHPad() && !isa<TerminatorInst>(I) && + if (I->getParent() && !I->isEHPad() && !I->isTerminator() && !I->mayHaveSideEffects()) I->eraseFromParent(); } diff --git a/lib/Analysis/IteratedDominanceFrontier.cpp b/lib/Analysis/IteratedDominanceFrontier.cpp index e7751d32aab3..000fe5ddad54 100644 --- a/lib/Analysis/IteratedDominanceFrontier.cpp +++ b/lib/Analysis/IteratedDominanceFrontier.cpp @@ -17,6 +17,7 @@ #include <queue> namespace llvm { + template <class NodeTy, bool IsPostDom> void IDFCalculator<NodeTy, IsPostDom>::calculate( SmallVectorImpl<BasicBlock *> &PHIBlocks) { @@ -61,29 +62,39 @@ void IDFCalculator<NodeTy, IsPostDom>::calculate( BasicBlock *BB = Node->getBlock(); // Succ is the successor in the direction we are calculating IDF, so it is // successor for IDF, and predecessor for Reverse IDF. 
- for (auto *Succ : children<NodeTy>(BB)) { + auto DoWork = [&](BasicBlock *Succ) { DomTreeNode *SuccNode = DT.getNode(Succ); // Quickly skip all CFG edges that are also dominator tree edges instead // of catching them below. if (SuccNode->getIDom() == Node) - continue; + return; const unsigned SuccLevel = SuccNode->getLevel(); if (SuccLevel > RootLevel) - continue; + return; if (!VisitedPQ.insert(SuccNode).second) - continue; + return; BasicBlock *SuccBB = SuccNode->getBlock(); if (useLiveIn && !LiveInBlocks->count(SuccBB)) - continue; + return; PHIBlocks.emplace_back(SuccBB); if (!DefBlocks->count(SuccBB)) PQ.push(std::make_pair( SuccNode, std::make_pair(SuccLevel, SuccNode->getDFSNumIn()))); + }; + + if (GD) { + for (auto Pair : children< + std::pair<const GraphDiff<BasicBlock *, IsPostDom> *, NodeTy>>( + {GD, BB})) + DoWork(Pair.second); + } else { + for (auto *Succ : children<NodeTy>(BB)) + DoWork(Succ); } for (auto DomChild : *Node) { diff --git a/lib/Analysis/LazyCallGraph.cpp b/lib/Analysis/LazyCallGraph.cpp index b1d585bfc683..3f22ada803c9 100644 --- a/lib/Analysis/LazyCallGraph.cpp +++ b/lib/Analysis/LazyCallGraph.cpp @@ -619,7 +619,7 @@ LazyCallGraph::RefSCC::switchInternalEdgeToCall( // If the merge range is empty, then adding the edge didn't actually form any // new cycles. We're done. - if (MergeRange.begin() == MergeRange.end()) { + if (empty(MergeRange)) { // Now that the SCC structure is finalized, flip the kind to call. SourceN->setEdgeKind(TargetN, Edge::Call); return false; // No new cycle. diff --git a/lib/Analysis/LazyValueInfo.cpp b/lib/Analysis/LazyValueInfo.cpp index ee0148e0d795..110c085d3f35 100644 --- a/lib/Analysis/LazyValueInfo.cpp +++ b/lib/Analysis/LazyValueInfo.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" @@ -420,6 +421,8 @@ namespace { BasicBlock *BB); bool solveBlockValueSelect(ValueLatticeElement &BBLV, SelectInst *S, BasicBlock *BB); + Optional<ConstantRange> getRangeForOperand(unsigned Op, Instruction *I, + BasicBlock *BB); bool solveBlockValueBinaryOp(ValueLatticeElement &BBLV, BinaryOperator *BBI, BasicBlock *BB); bool solveBlockValueCast(ValueLatticeElement &BBLV, CastInst *CI, @@ -634,8 +637,7 @@ bool LazyValueInfoImpl::solveBlockValueImpl(ValueLatticeElement &Res, if (auto *CI = dyn_cast<CastInst>(BBI)) return solveBlockValueCast(Res, CI, BB); - BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI); - if (BO && isa<ConstantInt>(BO->getOperand(1))) + if (BinaryOperator *BO = dyn_cast<BinaryOperator>(BBI)) return solveBlockValueBinaryOp(Res, BO, BB); } @@ -951,6 +953,25 @@ bool LazyValueInfoImpl::solveBlockValueSelect(ValueLatticeElement &BBLV, return true; } +Optional<ConstantRange> LazyValueInfoImpl::getRangeForOperand(unsigned Op, + Instruction *I, + BasicBlock *BB) { + if (!hasBlockValue(I->getOperand(Op), BB)) + if (pushBlockValue(std::make_pair(BB, I->getOperand(Op)))) + return None; + + const unsigned OperandBitWidth = + DL.getTypeSizeInBits(I->getOperand(Op)->getType()); + ConstantRange Range = ConstantRange(OperandBitWidth); + if (hasBlockValue(I->getOperand(Op), BB)) { + ValueLatticeElement Val = getBlockValue(I->getOperand(Op), BB); + intersectAssumeOrGuardBlockValueConstantRange(I->getOperand(Op), Val, I); + if (Val.isConstantRange()) + Range = Val.getConstantRange(); + } + return Range; +} + bool 
LazyValueInfoImpl::solveBlockValueCast(ValueLatticeElement &BBLV, CastInst *CI, BasicBlock *BB) { @@ -981,21 +1002,11 @@ bool LazyValueInfoImpl::solveBlockValueCast(ValueLatticeElement &BBLV, // Figure out the range of the LHS. If that fails, we still apply the // transfer rule on the full set since we may be able to locally infer // interesting facts. - if (!hasBlockValue(CI->getOperand(0), BB)) - if (pushBlockValue(std::make_pair(BB, CI->getOperand(0)))) - // More work to do before applying this transfer rule. - return false; - - const unsigned OperandBitWidth = - DL.getTypeSizeInBits(CI->getOperand(0)->getType()); - ConstantRange LHSRange = ConstantRange(OperandBitWidth); - if (hasBlockValue(CI->getOperand(0), BB)) { - ValueLatticeElement LHSVal = getBlockValue(CI->getOperand(0), BB); - intersectAssumeOrGuardBlockValueConstantRange(CI->getOperand(0), LHSVal, - CI); - if (LHSVal.isConstantRange()) - LHSRange = LHSVal.getConstantRange(); - } + Optional<ConstantRange> LHSRes = getRangeForOperand(0, CI, BB); + if (!LHSRes.hasValue()) + // More work to do before applying this transfer rule. + return false; + ConstantRange LHSRange = LHSRes.getValue(); const unsigned ResultBitWidth = CI->getType()->getIntegerBitWidth(); @@ -1037,27 +1048,19 @@ bool LazyValueInfoImpl::solveBlockValueBinaryOp(ValueLatticeElement &BBLV, return true; }; - // Figure out the range of the LHS. If that fails, use a conservative range, - // but apply the transfer rule anyways. This lets us pick up facts from - // expressions like "and i32 (call i32 @foo()), 32" - if (!hasBlockValue(BO->getOperand(0), BB)) - if (pushBlockValue(std::make_pair(BB, BO->getOperand(0)))) - // More work to do before applying this transfer rule. - return false; + // Figure out the ranges of the operands. If that fails, use a + // conservative range, but apply the transfer rule anyways. This + // lets us pick up facts from expressions like "and i32 (call i32 + // @foo()), 32" + Optional<ConstantRange> LHSRes = getRangeForOperand(0, BO, BB); + Optional<ConstantRange> RHSRes = getRangeForOperand(1, BO, BB); - const unsigned OperandBitWidth = - DL.getTypeSizeInBits(BO->getOperand(0)->getType()); - ConstantRange LHSRange = ConstantRange(OperandBitWidth); - if (hasBlockValue(BO->getOperand(0), BB)) { - ValueLatticeElement LHSVal = getBlockValue(BO->getOperand(0), BB); - intersectAssumeOrGuardBlockValueConstantRange(BO->getOperand(0), LHSVal, - BO); - if (LHSVal.isConstantRange()) - LHSRange = LHSVal.getConstantRange(); - } + if (!LHSRes.hasValue() || !RHSRes.hasValue()) + // More work to do before applying this transfer rule. + return false; - ConstantInt *RHS = cast<ConstantInt>(BO->getOperand(1)); - ConstantRange RHSRange = ConstantRange(RHS->getValue()); + ConstantRange LHSRange = LHSRes.getValue(); + ConstantRange RHSRange = RHSRes.getValue(); // NOTE: We're currently limited by the set of operations that ConstantRange // can evaluate symbolically. Enhancing that set will allows us to analyze diff --git a/lib/Analysis/LegacyDivergenceAnalysis.cpp b/lib/Analysis/LegacyDivergenceAnalysis.cpp new file mode 100644 index 000000000000..5540859ebdda --- /dev/null +++ b/lib/Analysis/LegacyDivergenceAnalysis.cpp @@ -0,0 +1,391 @@ +//===- LegacyDivergenceAnalysis.cpp --------- Legacy Divergence Analysis +//Implementation -==// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// +// +// This file implements divergence analysis which determines whether a branch +// in a GPU program is divergent.It can help branch optimizations such as jump +// threading and loop unswitching to make better decisions. +// +// GPU programs typically use the SIMD execution model, where multiple threads +// in the same execution group have to execute in lock-step. Therefore, if the +// code contains divergent branches (i.e., threads in a group do not agree on +// which path of the branch to take), the group of threads has to execute all +// the paths from that branch with different subsets of threads enabled until +// they converge at the immediately post-dominating BB of the paths. +// +// Due to this execution model, some optimizations such as jump +// threading and loop unswitching can be unfortunately harmful when performed on +// divergent branches. Therefore, an analysis that computes which branches in a +// GPU program are divergent can help the compiler to selectively run these +// optimizations. +// +// This file defines divergence analysis which computes a conservative but +// non-trivial approximation of all divergent branches in a GPU program. It +// partially implements the approach described in +// +// Divergence Analysis +// Sampaio, Souza, Collange, Pereira +// TOPLAS '13 +// +// The divergence analysis identifies the sources of divergence (e.g., special +// variables that hold the thread ID), and recursively marks variables that are +// data or sync dependent on a source of divergence as divergent. +// +// While data dependency is a well-known concept, the notion of sync dependency +// is worth more explanation. Sync dependence characterizes the control flow +// aspect of the propagation of branch divergence. For example, +// +// %cond = icmp slt i32 %tid, 10 +// br i1 %cond, label %then, label %else +// then: +// br label %merge +// else: +// br label %merge +// merge: +// %a = phi i32 [ 0, %then ], [ 1, %else ] +// +// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid +// because %tid is not on its use-def chains, %a is sync dependent on %tid +// because the branch "br i1 %cond" depends on %tid and affects which value %a +// is assigned to. +// +// The current implementation has the following limitations: +// 1. intra-procedural. It conservatively considers the arguments of a +// non-kernel-entry function and the return value of a function call as +// divergent. +// 2. memory as black box. It conservatively considers values loaded from +// generic or local address as divergent. This can be improved by leveraging +// pointer analysis. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LegacyDivergenceAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include <vector> +using namespace llvm; + +#define DEBUG_TYPE "divergence" + +// transparently use the GPUDivergenceAnalysis +static cl::opt<bool> UseGPUDA("use-gpu-divergence-analysis", cl::init(false), + cl::Hidden, + cl::desc("turn the LegacyDivergenceAnalysis into " + "a wrapper for GPUDivergenceAnalysis")); + +namespace { + +class DivergencePropagator { +public: + DivergencePropagator(Function &F, TargetTransformInfo &TTI, DominatorTree &DT, + PostDominatorTree &PDT, DenseSet<const Value *> &DV) + : F(F), TTI(TTI), DT(DT), PDT(PDT), DV(DV) {} + void populateWithSourcesOfDivergence(); + void propagate(); + +private: + // A helper function that explores data dependents of V. + void exploreDataDependency(Value *V); + // A helper function that explores sync dependents of TI. + void exploreSyncDependency(Instruction *TI); + // Computes the influence region from Start to End. This region includes all + // basic blocks on any simple path from Start to End. + void computeInfluenceRegion(BasicBlock *Start, BasicBlock *End, + DenseSet<BasicBlock *> &InfluenceRegion); + // Finds all users of I that are outside the influence region, and add these + // users to Worklist. + void findUsersOutsideInfluenceRegion( + Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion); + + Function &F; + TargetTransformInfo &TTI; + DominatorTree &DT; + PostDominatorTree &PDT; + std::vector<Value *> Worklist; // Stack for DFS. + DenseSet<const Value *> &DV; // Stores all divergent values. +}; + +void DivergencePropagator::populateWithSourcesOfDivergence() { + Worklist.clear(); + DV.clear(); + for (auto &I : instructions(F)) { + if (TTI.isSourceOfDivergence(&I)) { + Worklist.push_back(&I); + DV.insert(&I); + } + } + for (auto &Arg : F.args()) { + if (TTI.isSourceOfDivergence(&Arg)) { + Worklist.push_back(&Arg); + DV.insert(&Arg); + } + } +} + +void DivergencePropagator::exploreSyncDependency(Instruction *TI) { + // Propagation rule 1: if branch TI is divergent, all PHINodes in TI's + // immediate post dominator are divergent. This rule handles if-then-else + // patterns. For example, + // + // if (tid < 5) + // a1 = 1; + // else + // a2 = 2; + // a = phi(a1, a2); // sync dependent on (tid < 5) + BasicBlock *ThisBB = TI->getParent(); + + // Unreachable blocks may not be in the dominator tree. + if (!DT.isReachableFromEntry(ThisBB)) + return; + + // If the function has no exit blocks or doesn't reach any exit blocks, the + // post dominator may be null. + DomTreeNode *ThisNode = PDT.getNode(ThisBB); + if (!ThisNode) + return; + + BasicBlock *IPostDom = ThisNode->getIDom()->getBlock(); + if (IPostDom == nullptr) + return; + + for (auto I = IPostDom->begin(); isa<PHINode>(I); ++I) { + // A PHINode is uniform if it returns the same value no matter which path is + // taken. 
+ if (!cast<PHINode>(I)->hasConstantOrUndefValue() && DV.insert(&*I).second) + Worklist.push_back(&*I); + } + + // Propagation rule 2: if a value defined in a loop is used outside, the user + // is sync dependent on the condition of the loop exits that dominate the + // user. For example, + // + // int i = 0; + // do { + // i++; + // if (foo(i)) ... // uniform + // } while (i < tid); + // if (bar(i)) ... // divergent + // + // A program may contain unstructured loops. Therefore, we cannot leverage + // LoopInfo, which only recognizes natural loops. + // + // The algorithm used here handles both natural and unstructured loops. Given + // a branch TI, we first compute its influence region, the union of all simple + // paths from TI to its immediate post dominator (IPostDom). Then, we search + // for all the values defined in the influence region but used outside. All + // these users are sync dependent on TI. + DenseSet<BasicBlock *> InfluenceRegion; + computeInfluenceRegion(ThisBB, IPostDom, InfluenceRegion); + // An insight that can speed up the search process is that all the in-region + // values that are used outside must dominate TI. Therefore, instead of + // searching every basic blocks in the influence region, we search all the + // dominators of TI until it is outside the influence region. + BasicBlock *InfluencedBB = ThisBB; + while (InfluenceRegion.count(InfluencedBB)) { + for (auto &I : *InfluencedBB) + findUsersOutsideInfluenceRegion(I, InfluenceRegion); + DomTreeNode *IDomNode = DT.getNode(InfluencedBB)->getIDom(); + if (IDomNode == nullptr) + break; + InfluencedBB = IDomNode->getBlock(); + } +} + +void DivergencePropagator::findUsersOutsideInfluenceRegion( + Instruction &I, const DenseSet<BasicBlock *> &InfluenceRegion) { + for (User *U : I.users()) { + Instruction *UserInst = cast<Instruction>(U); + if (!InfluenceRegion.count(UserInst->getParent())) { + if (DV.insert(UserInst).second) + Worklist.push_back(UserInst); + } + } +} + +// A helper function for computeInfluenceRegion that adds successors of "ThisBB" +// to the influence region. +static void +addSuccessorsToInfluenceRegion(BasicBlock *ThisBB, BasicBlock *End, + DenseSet<BasicBlock *> &InfluenceRegion, + std::vector<BasicBlock *> &InfluenceStack) { + for (BasicBlock *Succ : successors(ThisBB)) { + if (Succ != End && InfluenceRegion.insert(Succ).second) + InfluenceStack.push_back(Succ); + } +} + +void DivergencePropagator::computeInfluenceRegion( + BasicBlock *Start, BasicBlock *End, + DenseSet<BasicBlock *> &InfluenceRegion) { + assert(PDT.properlyDominates(End, Start) && + "End does not properly dominate Start"); + + // The influence region starts from the end of "Start" to the beginning of + // "End". Therefore, "Start" should not be in the region unless "Start" is in + // a loop that doesn't contain "End". + std::vector<BasicBlock *> InfluenceStack; + addSuccessorsToInfluenceRegion(Start, End, InfluenceRegion, InfluenceStack); + while (!InfluenceStack.empty()) { + BasicBlock *BB = InfluenceStack.back(); + InfluenceStack.pop_back(); + addSuccessorsToInfluenceRegion(BB, End, InfluenceRegion, InfluenceStack); + } +} + +void DivergencePropagator::exploreDataDependency(Value *V) { + // Follow def-use chains of V. + for (User *U : V->users()) { + Instruction *UserInst = cast<Instruction>(U); + if (!TTI.isAlwaysUniform(U) && DV.insert(UserInst).second) + Worklist.push_back(UserInst); + } +} + +void DivergencePropagator::propagate() { + // Traverse the dependency graph using DFS. 
+ while (!Worklist.empty()) { + Value *V = Worklist.back(); + Worklist.pop_back(); + if (Instruction *I = dyn_cast<Instruction>(V)) { + // Terminators with less than two successors won't introduce sync + // dependency. Ignore them. + if (I->isTerminator() && I->getNumSuccessors() > 1) + exploreSyncDependency(I); + } + exploreDataDependency(V); + } +} + +} // namespace + +// Register this pass. +char LegacyDivergenceAnalysis::ID = 0; +INITIALIZE_PASS_BEGIN(LegacyDivergenceAnalysis, "divergence", + "Legacy Divergence Analysis", false, true) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(LegacyDivergenceAnalysis, "divergence", + "Legacy Divergence Analysis", false, true) + +FunctionPass *llvm::createLegacyDivergenceAnalysisPass() { + return new LegacyDivergenceAnalysis(); +} + +void LegacyDivergenceAnalysis::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + if (UseGPUDA) + AU.addRequired<LoopInfoWrapperPass>(); + AU.setPreservesAll(); +} + +bool LegacyDivergenceAnalysis::shouldUseGPUDivergenceAnalysis( + const Function &F) const { + if (!UseGPUDA) + return false; + + // GPUDivergenceAnalysis requires a reducible CFG. + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + using RPOTraversal = ReversePostOrderTraversal<const Function *>; + RPOTraversal FuncRPOT(&F); + return !containsIrreducibleCFG<const BasicBlock *, const RPOTraversal, + const LoopInfo>(FuncRPOT, LI); +} + +bool LegacyDivergenceAnalysis::runOnFunction(Function &F) { + auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + if (TTIWP == nullptr) + return false; + + TargetTransformInfo &TTI = TTIWP->getTTI(F); + // Fast path: if the target does not have branch divergence, we do not mark + // any branch as divergent. 
+ if (!TTI.hasBranchDivergence()) + return false; + + DivergentValues.clear(); + gpuDA = nullptr; + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + + if (shouldUseGPUDivergenceAnalysis(F)) { + // run the new GPU divergence analysis + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + gpuDA = llvm::make_unique<GPUDivergenceAnalysis>(F, DT, PDT, LI, TTI); + + } else { + // run LLVM's existing DivergenceAnalysis + DivergencePropagator DP(F, TTI, DT, PDT, DivergentValues); + DP.populateWithSourcesOfDivergence(); + DP.propagate(); + } + + LLVM_DEBUG(dbgs() << "\nAfter divergence analysis on " << F.getName() + << ":\n"; + print(dbgs(), F.getParent())); + + return false; +} + +bool LegacyDivergenceAnalysis::isDivergent(const Value *V) const { + if (gpuDA) { + return gpuDA->isDivergent(*V); + } + return DivergentValues.count(V); +} + +void LegacyDivergenceAnalysis::print(raw_ostream &OS, const Module *) const { + if ((!gpuDA || !gpuDA->hasDivergence()) && DivergentValues.empty()) + return; + + const Function *F = nullptr; + if (!DivergentValues.empty()) { + const Value *FirstDivergentValue = *DivergentValues.begin(); + if (const Argument *Arg = dyn_cast<Argument>(FirstDivergentValue)) { + F = Arg->getParent(); + } else if (const Instruction *I = + dyn_cast<Instruction>(FirstDivergentValue)) { + F = I->getParent()->getParent(); + } else { + llvm_unreachable("Only arguments and instructions can be divergent"); + } + } else if (gpuDA) { + F = &gpuDA->getFunction(); + } + if (!F) + return; + + // Dumps all divergent values in F, arguments and then instructions. + for (auto &Arg : F->args()) { + OS << (isDivergent(&Arg) ? "DIVERGENT: " : " "); + OS << Arg << "\n"; + } + // Iterate instructions using instructions() to ensure a deterministic order. + for (auto BI = F->begin(), BE = F->end(); BI != BE; ++BI) { + auto &BB = *BI; + OS << "\n " << BB.getName() << ":\n"; + for (auto &I : BB.instructionsWithoutDebug()) { + OS << (isDivergent(&I) ? "DIVERGENT: " : " "); + OS << I << "\n"; + } + } + OS << "\n"; +} diff --git a/lib/Analysis/Lint.cpp b/lib/Analysis/Lint.cpp index db919bd233bf..5d0a627f8426 100644 --- a/lib/Analysis/Lint.cpp +++ b/lib/Analysis/Lint.cpp @@ -330,12 +330,12 @@ void Lint::visitCallSite(CallSite CS) { // Check that the memcpy arguments don't overlap. The AliasAnalysis API // isn't expressive enough for what we really want to do. Known partial // overlap is not distinguished from the case where nothing is known. 
- uint64_t Size = 0; + auto Size = LocationSize::unknown(); if (const ConstantInt *Len = dyn_cast<ConstantInt>(findValue(MCI->getLength(), /*OffsetOk=*/false))) if (Len->getValue().isIntN(32)) - Size = Len->getValue().getZExtValue(); + Size = LocationSize::precise(Len->getValue().getZExtValue()); Assert(AA->alias(MCI->getSource(), Size, MCI->getDest(), Size) != MustAlias, "Undefined behavior: memcpy source and destination overlap", &I); diff --git a/lib/Analysis/Loads.cpp b/lib/Analysis/Loads.cpp index d319d4c249d3..8129795bc0c1 100644 --- a/lib/Analysis/Loads.cpp +++ b/lib/Analysis/Loads.cpp @@ -107,8 +107,8 @@ static bool isDereferenceableAndAlignedPointer( return isDereferenceableAndAlignedPointer(ASC->getOperand(0), Align, Size, DL, CtxI, DT, Visited); - if (auto CS = ImmutableCallSite(V)) - if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) + if (const auto *Call = dyn_cast<CallBase>(V)) + if (auto *RP = getArgumentAliasingToReturnedPointer(Call)) return isDereferenceableAndAlignedPointer(RP, Align, Size, DL, CtxI, DT, Visited); @@ -345,7 +345,7 @@ Value *llvm::FindAvailablePtrLoadStore(Value *Ptr, Type *AccessTy, const DataLayout &DL = ScanBB->getModule()->getDataLayout(); // Try to get the store size for the type. - uint64_t AccessSize = DL.getTypeStoreSize(AccessTy); + auto AccessSize = LocationSize::precise(DL.getTypeStoreSize(AccessTy)); Value *StrippedPtr = Ptr->stripPointerCasts(); diff --git a/lib/Analysis/LoopAccessAnalysis.cpp b/lib/Analysis/LoopAccessAnalysis.cpp index a24d66011b8d..7f3480f512ab 100644 --- a/lib/Analysis/LoopAccessAnalysis.cpp +++ b/lib/Analysis/LoopAccessAnalysis.cpp @@ -342,7 +342,7 @@ void RuntimePointerChecking::groupChecks( // // The above case requires that we have an UnknownDependence between // accesses to the same underlying object. This cannot happen unless - // ShouldRetryWithRuntimeCheck is set, and therefore UseDependencies + // FoundNonConstantDistanceDependence is set, and therefore UseDependencies // is also false. In this case we will use the fallback path and create // separate checking groups for all pointers. @@ -420,7 +420,7 @@ void RuntimePointerChecking::groupChecks( // We've computed the grouped checks for this partition. // Save the results and continue with the next one. - std::copy(Groups.begin(), Groups.end(), std::back_inserter(CheckingGroups)); + llvm::copy(Groups, std::back_inserter(CheckingGroups)); } } @@ -509,7 +509,7 @@ public: /// Register a load and whether it is only read from. void addLoad(MemoryLocation &Loc, bool IsReadOnly) { Value *Ptr = const_cast<Value*>(Loc.Ptr); - AST.add(Ptr, MemoryLocation::UnknownSize, Loc.AATags); + AST.add(Ptr, LocationSize::unknown(), Loc.AATags); Accesses.insert(MemAccessInfo(Ptr, false)); if (IsReadOnly) ReadOnlyPtr.insert(Ptr); @@ -518,7 +518,7 @@ public: /// Register a store. void addStore(MemoryLocation &Loc) { Value *Ptr = const_cast<Value*>(Loc.Ptr); - AST.add(Ptr, MemoryLocation::UnknownSize, Loc.AATags); + AST.add(Ptr, LocationSize::unknown(), Loc.AATags); Accesses.insert(MemAccessInfo(Ptr, true)); } @@ -556,7 +556,7 @@ public: /// perform dependency checking. /// /// Note that this can later be cleared if we retry memcheck analysis without - /// dependency checking (i.e. ShouldRetryWithRuntimeCheck). + /// dependency checking (i.e. FoundNonConstantDistanceDependence). bool isDependencyCheckNeeded() { return !CheckDeps.empty(); } /// We decided that no dependence analysis would be used. Reset the state. 
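The Lint and Loads hunks above show the recurring LocationSize migration in this change: memory access sizes are passed as LocationSize values, so an unknown size is spelled LocationSize::unknown() instead of a sentinel integer. A hedged sketch of the same pattern in isolation (assumes the usual LLVM headers; the helper name below is made up for illustration):

#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/IR/Constants.h"

// Illustrative helper (not from the patch): map an optional constant length to
// a LocationSize, defaulting to "unknown" exactly as the Lint hunk does.
static llvm::LocationSize sizeForLength(const llvm::ConstantInt *Len) {
  if (Len && Len->getValue().isIntN(32))
    return llvm::LocationSize::precise(Len->getValue().getZExtValue());
  return llvm::LocationSize::unknown();
}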
@@ -604,8 +604,8 @@ private: /// /// Note that, this is different from isDependencyCheckNeeded. When we retry /// memcheck analysis without dependency checking - /// (i.e. ShouldRetryWithRuntimeCheck), isDependencyCheckNeeded is cleared - /// while this remains set if we have potentially dependent accesses. + /// (i.e. FoundNonConstantDistanceDependence), isDependencyCheckNeeded is + /// cleared while this remains set if we have potentially dependent accesses. bool IsRTCheckAnalysisNeeded; /// The SCEV predicate containing all the SCEV-related assumptions. @@ -1221,18 +1221,20 @@ bool llvm::isConsecutiveAccess(Value *A, Value *B, const DataLayout &DL, return X == PtrSCEVB; } -bool MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { +MemoryDepChecker::VectorizationSafetyStatus +MemoryDepChecker::Dependence::isSafeForVectorization(DepType Type) { switch (Type) { case NoDep: case Forward: case BackwardVectorizable: - return true; + return VectorizationSafetyStatus::Safe; case Unknown: + return VectorizationSafetyStatus::PossiblySafeWithRtChecks; case ForwardButPreventsForwarding: case Backward: case BackwardVectorizableButPreventsForwarding: - return false; + return VectorizationSafetyStatus::Unsafe; } llvm_unreachable("unexpected DepType!"); } @@ -1317,6 +1319,11 @@ bool MemoryDepChecker::couldPreventStoreLoadForward(uint64_t Distance, return false; } +void MemoryDepChecker::mergeInStatus(VectorizationSafetyStatus S) { + if (Status < S) + Status = S; +} + /// Given a non-constant (unknown) dependence-distance \p Dist between two /// memory accesses, that have the same stride whose absolute value is given /// in \p Stride, and that have the same type size \p TypeByteSize, @@ -1485,7 +1492,7 @@ MemoryDepChecker::isDependent(const MemAccessInfo &A, unsigned AIdx, return Dependence::NoDep; LLVM_DEBUG(dbgs() << "LAA: Dependence because of non-constant distance\n"); - ShouldRetryWithRuntimeCheck = true; + FoundNonConstantDistanceDependence = true; return Dependence::Unknown; } @@ -1652,7 +1659,7 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, Dependence::DepType Type = isDependent(*A.first, A.second, *B.first, B.second, Strides); - SafeForVectorization &= Dependence::isSafeForVectorization(Type); + mergeInStatus(Dependence::isSafeForVectorization(Type)); // Gather dependences unless we accumulated MaxDependences // dependences. In that case return as soon as we find the first @@ -1669,7 +1676,7 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, << "Too many dependences, stopped recording\n"); } } - if (!RecordDependences && !SafeForVectorization) + if (!RecordDependences && !isSafeForVectorization()) return false; } ++OI; @@ -1679,7 +1686,7 @@ bool MemoryDepChecker::areDepsSafe(DepCandidates &AccessSets, } LLVM_DEBUG(dbgs() << "Total Dependences: " << Dependences.size() << "\n"); - return SafeForVectorization; + return isSafeForVectorization(); } SmallVector<Instruction *, 4> @@ -1862,10 +1869,17 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI, // writes and between reads and writes, but not between reads and reads. ValueSet Seen; + // Record uniform store addresses to identify if we have multiple stores + // to the same address. + ValueSet UniformStores; + for (StoreInst *ST : Stores) { Value *Ptr = ST->getPointerOperand(); - // Check for store to loop invariant address. 
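The new MemoryDepChecker::VectorizationSafetyStatus above replaces the old boolean SafeForVectorization with a small three-point lattice, and mergeInStatus() simply keeps the worst status seen so far; this relies on the enumerators being declared from best to worst. A standalone sketch of that merge rule (plain C++, with the ordering assumed as the "Status < S" comparison implies):

#include <algorithm>
#include <cassert>

// Assumed ordering, best to worst, mirroring the comparison in mergeInStatus().
enum class SafetyStatus { Safe, PossiblySafeWithRtChecks, Unsafe };

static void mergeIn(SafetyStatus &Status, SafetyStatus S) {
  Status = std::max(Status, S);  // equivalent to: if (Status < S) Status = S;
}

int main() {
  SafetyStatus Status = SafetyStatus::Safe;
  mergeIn(Status, SafetyStatus::PossiblySafeWithRtChecks);
  mergeIn(Status, SafetyStatus::Safe);   // a later safe dependence cannot improve it
  assert(Status == SafetyStatus::PossiblySafeWithRtChecks);
  mergeIn(Status, SafetyStatus::Unsafe); // Unsafe is sticky
  assert(Status == SafetyStatus::Unsafe);
  return 0;
}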
- StoreToLoopInvariantAddress |= isUniform(Ptr); + + if (isUniform(Ptr)) + HasDependenceInvolvingLoopInvariantAddress |= + !UniformStores.insert(Ptr).second; + // If we did *not* see this pointer before, insert it to the read-write // list. At this phase it is only a 'write' list. if (Seen.insert(Ptr).second) { @@ -1907,6 +1921,14 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI, IsReadOnlyPtr = true; } + // See if there is an unsafe dependency between a load to a uniform address and + // store to the same uniform address. + if (UniformStores.count(Ptr)) { + LLVM_DEBUG(dbgs() << "LAA: Found an unsafe dependency between a uniform " + "load and uniform store to the same address!\n"); + HasDependenceInvolvingLoopInvariantAddress = true; + } + MemoryLocation Loc = MemoryLocation::get(LD); // The TBAA metadata could have a control dependency on the predication // condition, so we cannot rely on it when determining whether or not we @@ -2265,7 +2287,7 @@ LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE, PtrRtChecking(llvm::make_unique<RuntimePointerChecking>(SE)), DepChecker(llvm::make_unique<MemoryDepChecker>(*PSE, L)), TheLoop(L), NumLoads(0), NumStores(0), MaxSafeDepDistBytes(-1), CanVecMem(false), - StoreToLoopInvariantAddress(false) { + HasDependenceInvolvingLoopInvariantAddress(false) { if (canAnalyzeLoop()) analyzeLoop(AA, LI, TLI, DT); } @@ -2297,8 +2319,8 @@ void LoopAccessInfo::print(raw_ostream &OS, unsigned Depth) const { PtrRtChecking->print(OS, Depth); OS << "\n"; - OS.indent(Depth) << "Store to invariant address was " - << (StoreToLoopInvariantAddress ? "" : "not ") + OS.indent(Depth) << "Non vectorizable stores to invariant address were " + << (HasDependenceInvolvingLoopInvariantAddress ? "" : "not ") << "found in loop.\n"; OS.indent(Depth) << "SCEV assumptions:\n"; diff --git a/lib/Analysis/LoopAnalysisManager.cpp b/lib/Analysis/LoopAnalysisManager.cpp index 074023a7e1e2..2a3b29d7fbca 100644 --- a/lib/Analysis/LoopAnalysisManager.cpp +++ b/lib/Analysis/LoopAnalysisManager.cpp @@ -147,8 +147,8 @@ PreservedAnalyses llvm::getLoopPassPreservedAnalyses() { PA.preserve<LoopAnalysis>(); PA.preserve<LoopAnalysisManagerFunctionProxy>(); PA.preserve<ScalarEvolutionAnalysis>(); - // FIXME: Uncomment this when all loop passes preserve MemorySSA - // PA.preserve<MemorySSAAnalysis>(); + if (EnableMSSALoopDependency) + PA.preserve<MemorySSAAnalysis>(); // FIXME: What we really want to do here is preserve an AA category, but that // concept doesn't exist yet. PA.preserve<AAManager>(); diff --git a/lib/Analysis/LoopInfo.cpp b/lib/Analysis/LoopInfo.cpp index 3f78456b3586..ef2b1257015c 100644 --- a/lib/Analysis/LoopInfo.cpp +++ b/lib/Analysis/LoopInfo.cpp @@ -26,6 +26,7 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRPrintingPasses.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" @@ -213,33 +214,21 @@ bool Loop::isSafeToClone() const { MDNode *Loop::getLoopID() const { MDNode *LoopID = nullptr; - if (BasicBlock *Latch = getLoopLatch()) { - LoopID = Latch->getTerminator()->getMetadata(LLVMContext::MD_loop); - } else { - assert(!getLoopLatch() && - "The loop should have no single latch at this point"); - // Go through each predecessor of the loop header and check the - // terminator for the metadata. 
- BasicBlock *H = getHeader(); - for (BasicBlock *BB : this->blocks()) { - TerminatorInst *TI = BB->getTerminator(); - MDNode *MD = nullptr; - - // Check if this terminator branches to the loop header. - for (BasicBlock *Successor : TI->successors()) { - if (Successor == H) { - MD = TI->getMetadata(LLVMContext::MD_loop); - break; - } - } - if (!MD) - return nullptr; - if (!LoopID) - LoopID = MD; - else if (MD != LoopID) - return nullptr; - } + // Go through the latch blocks and check the terminator for the metadata. + SmallVector<BasicBlock *, 4> LatchesBlocks; + getLoopLatches(LatchesBlocks); + for (BasicBlock *BB : LatchesBlocks) { + Instruction *TI = BB->getTerminator(); + MDNode *MD = TI->getMetadata(LLVMContext::MD_loop); + + if (!MD) + return nullptr; + + if (!LoopID) + LoopID = MD; + else if (MD != LoopID) + return nullptr; } if (!LoopID || LoopID->getNumOperands() == 0 || LoopID->getOperand(0) != LoopID) @@ -248,23 +237,19 @@ MDNode *Loop::getLoopID() const { } void Loop::setLoopID(MDNode *LoopID) const { - assert(LoopID && "Loop ID should not be null"); - assert(LoopID->getNumOperands() > 0 && "Loop ID needs at least one operand"); - assert(LoopID->getOperand(0) == LoopID && "Loop ID should refer to itself"); + assert((!LoopID || LoopID->getNumOperands() > 0) && + "Loop ID needs at least one operand"); + assert((!LoopID || LoopID->getOperand(0) == LoopID) && + "Loop ID should refer to itself"); - if (BasicBlock *Latch = getLoopLatch()) { - Latch->getTerminator()->setMetadata(LLVMContext::MD_loop, LoopID); - return; - } - - assert(!getLoopLatch() && - "The loop should have no single latch at this point"); BasicBlock *H = getHeader(); for (BasicBlock *BB : this->blocks()) { - TerminatorInst *TI = BB->getTerminator(); - for (BasicBlock *Successor : TI->successors()) { - if (Successor == H) + Instruction *TI = BB->getTerminator(); + for (BasicBlock *Successor : successors(TI)) { + if (Successor == H) { TI->setMetadata(LLVMContext::MD_loop, LoopID); + break; + } } } } @@ -308,16 +293,50 @@ bool Loop::isAnnotatedParallel() const { if (!DesiredLoopIdMetadata) return false; + MDNode *ParallelAccesses = + findOptionMDForLoop(this, "llvm.loop.parallel_accesses"); + SmallPtrSet<MDNode *, 4> + ParallelAccessGroups; // For scalable 'contains' check. + if (ParallelAccesses) { + for (const MDOperand &MD : drop_begin(ParallelAccesses->operands(), 1)) { + MDNode *AccGroup = cast<MDNode>(MD.get()); + assert(isValidAsAccessGroup(AccGroup) && + "List item must be an access group"); + ParallelAccessGroups.insert(AccGroup); + } + } + // The loop branch contains the parallel loop metadata. In order to ensure // that any parallel-loop-unaware optimization pass hasn't added loop-carried // dependencies (thus converted the loop back to a sequential loop), check - // that all the memory instructions in the loop contain parallelism metadata - // that point to the same unique "loop id metadata" the loop branch does. + // that all the memory instructions in the loop belong to an access group that + // is parallel to this loop. 
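The access-group scheme referenced above identifies parallel memory accesses by distinct, operand-less metadata nodes; the isValidAsAccessGroup() helper added further down in this patch checks exactly that shape. A hedged sketch of constructing such a node (assumes the usual LLVM headers and an existing context):

#include "llvm/ADT/ArrayRef.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
#include <cassert>

int main() {
  llvm::LLVMContext Ctx;
  // An access group is a distinct MDNode with no operands; instructions refer
  // to it via !llvm.access.group and a parallel loop lists its groups in
  // !llvm.loop.parallel_accesses.
  llvm::MDNode *AccessGroup =
      llvm::MDNode::getDistinct(Ctx, llvm::ArrayRef<llvm::Metadata *>());
  assert(AccessGroup->isDistinct() && AccessGroup->getNumOperands() == 0);
  return 0;
}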
for (BasicBlock *BB : this->blocks()) { for (Instruction &I : *BB) { if (!I.mayReadOrWriteMemory()) continue; + if (MDNode *AccessGroup = I.getMetadata(LLVMContext::MD_access_group)) { + auto ContainsAccessGroup = [&ParallelAccessGroups](MDNode *AG) -> bool { + if (AG->getNumOperands() == 0) { + assert(isValidAsAccessGroup(AG) && "Item must be an access group"); + return ParallelAccessGroups.count(AG); + } + + for (const MDOperand &AccessListItem : AG->operands()) { + MDNode *AccGroup = cast<MDNode>(AccessListItem.get()); + assert(isValidAsAccessGroup(AccGroup) && + "List item must be an access group"); + if (ParallelAccessGroups.count(AccGroup)) + return true; + } + return false; + }; + + if (ContainsAccessGroup(AccessGroup)) + continue; + } + // The memory instruction can refer to the loop identifier metadata // directly or indirectly through another list metadata (in case of // nested parallel loops). The loop identifier metadata refers to @@ -708,6 +727,40 @@ void llvm::printLoop(Loop &L, raw_ostream &OS, const std::string &Banner) { } } +MDNode *llvm::findOptionMDForLoopID(MDNode *LoopID, StringRef Name) { + // No loop metadata node, no loop properties. + if (!LoopID) + return nullptr; + + // First operand should refer to the metadata node itself, for legacy reasons. + assert(LoopID->getNumOperands() > 0 && "requires at least one operand"); + assert(LoopID->getOperand(0) == LoopID && "invalid loop id"); + + // Iterate over the metdata node operands and look for MDString metadata. + for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) { + MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); + if (!MD || MD->getNumOperands() < 1) + continue; + MDString *S = dyn_cast<MDString>(MD->getOperand(0)); + if (!S) + continue; + // Return the operand node if MDString holds expected metadata. + if (Name.equals(S->getString())) + return MD; + } + + // Loop property not found. + return nullptr; +} + +MDNode *llvm::findOptionMDForLoop(const Loop *TheLoop, StringRef Name) { + return findOptionMDForLoopID(TheLoop->getLoopID(), Name); +} + +bool llvm::isValidAsAccessGroup(MDNode *Node) { + return Node->getNumOperands() == 0 && Node->isDistinct(); +} + //===----------------------------------------------------------------------===// // LoopInfo implementation // diff --git a/lib/Analysis/LoopPass.cpp b/lib/Analysis/LoopPass.cpp index 07a151ce0fce..a68f114b83a0 100644 --- a/lib/Analysis/LoopPass.cpp +++ b/lib/Analysis/LoopPass.cpp @@ -20,6 +20,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/OptBisect.h" #include "llvm/IR/PassManager.h" +#include "llvm/IR/PassTimingInfo.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" #include "llvm/Support/raw_ostream.h" @@ -193,8 +194,14 @@ bool LPPassManager::runOnFunction(Function &F) { } // Walk Loops - unsigned InstrCount = 0; + unsigned InstrCount, FunctionSize = 0; + StringMap<std::pair<unsigned, unsigned>> FunctionToInstrCount; bool EmitICRemark = M.shouldEmitInstrCountChangedRemark(); + // Collect the initial size of the module and the function we're looking at. 
+ if (EmitICRemark) { + InstrCount = initSizeRemarkInfo(M, FunctionToInstrCount); + FunctionSize = F.getInstructionCount(); + } while (!LQ.empty()) { CurrentLoopDeleted = false; CurrentLoop = LQ.back(); @@ -209,17 +216,28 @@ bool LPPassManager::runOnFunction(Function &F) { initializeAnalysisImpl(P); + bool LocalChanged = false; { PassManagerPrettyStackEntry X(P, *CurrentLoop->getHeader()); TimeRegion PassTimer(getPassTimer(P)); - if (EmitICRemark) - InstrCount = initSizeRemarkInfo(M); - Changed |= P->runOnLoop(CurrentLoop, *this); - if (EmitICRemark) - emitInstrCountChangedRemark(P, M, InstrCount); + LocalChanged = P->runOnLoop(CurrentLoop, *this); + Changed |= LocalChanged; + if (EmitICRemark) { + unsigned NewSize = F.getInstructionCount(); + // Update the size of the function, emit a remark, and update the + // size of the module. + if (NewSize != FunctionSize) { + int64_t Delta = static_cast<int64_t>(NewSize) - + static_cast<int64_t>(FunctionSize); + emitInstrCountChangedRemark(P, M, Delta, InstrCount, + FunctionToInstrCount, &F); + InstrCount = static_cast<int64_t>(InstrCount) + Delta; + FunctionSize = NewSize; + } + } } - if (Changed) + if (LocalChanged) dumpPassInfo(P, MODIFICATION_MSG, ON_LOOP_MSG, CurrentLoopDeleted ? "<deleted loop>" : CurrentLoop->getName()); diff --git a/lib/Analysis/MemDepPrinter.cpp b/lib/Analysis/MemDepPrinter.cpp index 5a6bbd7b2ac6..907b321b231a 100644 --- a/lib/Analysis/MemDepPrinter.cpp +++ b/lib/Analysis/MemDepPrinter.cpp @@ -13,7 +13,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/Passes.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" @@ -106,9 +105,9 @@ bool MemDepPrinter::runOnFunction(Function &F) { if (!Res.isNonLocal()) { Deps[Inst].insert(std::make_pair(getInstTypePair(Res), static_cast<BasicBlock *>(nullptr))); - } else if (auto CS = CallSite(Inst)) { + } else if (auto *Call = dyn_cast<CallBase>(Inst)) { const MemoryDependenceResults::NonLocalDepInfo &NLDI = - MDA.getNonLocalCallDependency(CS); + MDA.getNonLocalCallDependency(Call); DepSet &InstDeps = Deps[Inst]; for (const NonLocalDepEntry &I : NLDI) { diff --git a/lib/Analysis/MemoryDependenceAnalysis.cpp b/lib/Analysis/MemoryDependenceAnalysis.cpp index feae53c54ecb..e22182b99e11 100644 --- a/lib/Analysis/MemoryDependenceAnalysis.cpp +++ b/lib/Analysis/MemoryDependenceAnalysis.cpp @@ -31,7 +31,6 @@ #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -182,8 +181,8 @@ static ModRefInfo GetLocation(const Instruction *Inst, MemoryLocation &Loc, } /// Private helper for finding the local dependencies of a call site. -MemDepResult MemoryDependenceResults::getCallSiteDependencyFrom( - CallSite CS, bool isReadOnlyCall, BasicBlock::iterator ScanIt, +MemDepResult MemoryDependenceResults::getCallDependencyFrom( + CallBase *Call, bool isReadOnlyCall, BasicBlock::iterator ScanIt, BasicBlock *BB) { unsigned Limit = BlockScanLimit; @@ -205,21 +204,21 @@ MemDepResult MemoryDependenceResults::getCallSiteDependencyFrom( ModRefInfo MR = GetLocation(Inst, Loc, TLI); if (Loc.Ptr) { // A simple instruction. 
- if (isModOrRefSet(AA.getModRefInfo(CS, Loc))) + if (isModOrRefSet(AA.getModRefInfo(Call, Loc))) return MemDepResult::getClobber(Inst); continue; } - if (auto InstCS = CallSite(Inst)) { + if (auto *CallB = dyn_cast<CallBase>(Inst)) { // If these two calls do not interfere, look past it. - if (isNoModRef(AA.getModRefInfo(CS, InstCS))) { - // If the two calls are the same, return InstCS as a Def, so that - // CS can be found redundant and eliminated. + if (isNoModRef(AA.getModRefInfo(Call, CallB))) { + // If the two calls are the same, return Inst as a Def, so that + // Call can be found redundant and eliminated. if (isReadOnlyCall && !isModSet(MR) && - CS.getInstruction()->isIdenticalToWhenDefined(Inst)) + Call->isIdenticalToWhenDefined(CallB)) return MemDepResult::getDef(Inst); - // Otherwise if the two calls don't interact (e.g. InstCS is readnone) + // Otherwise if the two calls don't interact (e.g. CallB is readnone) // keep scanning. continue; } else @@ -750,11 +749,10 @@ MemDepResult MemoryDependenceResults::getDependency(Instruction *QueryInst) { LocalCache = getPointerDependencyFrom( MemLoc, isLoad, ScanPos->getIterator(), QueryParent, QueryInst); - } else if (isa<CallInst>(QueryInst) || isa<InvokeInst>(QueryInst)) { - CallSite QueryCS(QueryInst); - bool isReadOnly = AA.onlyReadsMemory(QueryCS); - LocalCache = getCallSiteDependencyFrom( - QueryCS, isReadOnly, ScanPos->getIterator(), QueryParent); + } else if (auto *QueryCall = dyn_cast<CallBase>(QueryInst)) { + bool isReadOnly = AA.onlyReadsMemory(QueryCall); + LocalCache = getCallDependencyFrom(QueryCall, isReadOnly, + ScanPos->getIterator(), QueryParent); } else // Non-memory instruction. LocalCache = MemDepResult::getUnknown(); @@ -780,11 +778,11 @@ static void AssertSorted(MemoryDependenceResults::NonLocalDepInfo &Cache, #endif const MemoryDependenceResults::NonLocalDepInfo & -MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { - assert(getDependency(QueryCS.getInstruction()).isNonLocal() && +MemoryDependenceResults::getNonLocalCallDependency(CallBase *QueryCall) { + assert(getDependency(QueryCall).isNonLocal() && "getNonLocalCallDependency should only be used on calls with " "non-local deps!"); - PerInstNLInfo &CacheP = NonLocalDeps[QueryCS.getInstruction()]; + PerInstNLInfo &CacheP = NonLocalDeps[QueryCall]; NonLocalDepInfo &Cache = CacheP.first; // This is the set of blocks that need to be recomputed. In the cached case, @@ -807,21 +805,21 @@ MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { DirtyBlocks.push_back(Entry.getBB()); // Sort the cache so that we can do fast binary search lookups below. - llvm::sort(Cache.begin(), Cache.end()); + llvm::sort(Cache); ++NumCacheDirtyNonLocal; // cerr << "CACHED CASE: " << DirtyBlocks.size() << " dirty: " // << Cache.size() << " cached: " << *QueryInst; } else { // Seed DirtyBlocks with each of the preds of QueryInst's block. - BasicBlock *QueryBB = QueryCS.getInstruction()->getParent(); + BasicBlock *QueryBB = QueryCall->getParent(); for (BasicBlock *Pred : PredCache.get(QueryBB)) DirtyBlocks.push_back(Pred); ++NumUncacheNonLocal; } // isReadonlyCall - If this is a read-only call, we can be more aggressive. 
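
A sketch of the CallBase-based query path, mirroring the MemDepPrinter code above; walkCallDeps is a hypothetical caller, and MDA is assumed to be the MemoryDependenceResults for the enclosing function.

#include "llvm/Analysis/MemoryDependenceAnalysis.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;

static void walkCallDeps(MemoryDependenceResults &MDA, CallBase *Call) {
  MemDepResult Local = MDA.getDependency(Call);
  if (!Local.isNonLocal())
    return; // the dependency (or clobber) is within Call's own block
  // Otherwise enumerate the per-block results, as MemDepPrinter does above.
  for (const NonLocalDepEntry &Entry : MDA.getNonLocalCallDependency(Call)) {
    BasicBlock *BB = Entry.getBB();
    MemDepResult Res = Entry.getResult();
    (void)BB;
    (void)Res; // a real client would inspect Res.getInst(), Res.isDef(), ...
  }
}
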
- bool isReadonlyCall = AA.onlyReadsMemory(QueryCS); + bool isReadonlyCall = AA.onlyReadsMemory(QueryCall); SmallPtrSet<BasicBlock *, 32> Visited; @@ -865,8 +863,8 @@ MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { if (Instruction *Inst = ExistingResult->getResult().getInst()) { ScanPos = Inst->getIterator(); // We're removing QueryInst's use of Inst. - RemoveFromReverseMap(ReverseNonLocalDeps, Inst, - QueryCS.getInstruction()); + RemoveFromReverseMap<Instruction *>(ReverseNonLocalDeps, Inst, + QueryCall); } } @@ -874,8 +872,7 @@ MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { MemDepResult Dep; if (ScanPos != DirtyBB->begin()) { - Dep = - getCallSiteDependencyFrom(QueryCS, isReadonlyCall, ScanPos, DirtyBB); + Dep = getCallDependencyFrom(QueryCall, isReadonlyCall, ScanPos, DirtyBB); } else if (DirtyBB != &DirtyBB->getParent()->getEntryBlock()) { // No dependence found. If this is the entry block of the function, it is // a clobber, otherwise it is unknown. @@ -897,7 +894,7 @@ MemoryDependenceResults::getNonLocalCallDependency(CallSite QueryCS) { // Keep the ReverseNonLocalDeps map up to date so we can efficiently // update this when we remove instructions. if (Instruction *Inst = Dep.getInst()) - ReverseNonLocalDeps[Inst].insert(QueryCS.getInstruction()); + ReverseNonLocalDeps[Inst].insert(QueryCall); } else { // If the block *is* completely transparent to the load, we need to check @@ -1070,7 +1067,7 @@ SortNonLocalDepInfoCache(MemoryDependenceResults::NonLocalDepInfo &Cache, break; default: // Added many values, do a full scale sort. - llvm::sort(Cache.begin(), Cache.end()); + llvm::sort(Cache); break; } } @@ -1113,21 +1110,36 @@ bool MemoryDependenceResults::getNonLocalPointerDepFromBB( // If we already have a cache entry for this CacheKey, we may need to do some // work to reconcile the cache entry and the current query. if (!Pair.second) { - if (CacheInfo->Size < Loc.Size) { - // The query's Size is greater than the cached one. Throw out the - // cached data and proceed with the query at the greater size. - CacheInfo->Pair = BBSkipFirstBlockPair(); - CacheInfo->Size = Loc.Size; - for (auto &Entry : CacheInfo->NonLocalDeps) - if (Instruction *Inst = Entry.getResult().getInst()) - RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); - CacheInfo->NonLocalDeps.clear(); - } else if (CacheInfo->Size > Loc.Size) { - // This query's Size is less than the cached one. Conservatively restart - // the query using the greater size. - return getNonLocalPointerDepFromBB( - QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad, - StartBB, Result, Visited, SkipFirstBlock); + if (CacheInfo->Size != Loc.Size) { + bool ThrowOutEverything; + if (CacheInfo->Size.hasValue() && Loc.Size.hasValue()) { + // FIXME: We may be able to do better in the face of results with mixed + // precision. We don't appear to get them in practice, though, so just + // be conservative. + ThrowOutEverything = + CacheInfo->Size.isPrecise() != Loc.Size.isPrecise() || + CacheInfo->Size.getValue() < Loc.Size.getValue(); + } else { + // For our purposes, unknown size > all others. + ThrowOutEverything = !Loc.Size.hasValue(); + } + + if (ThrowOutEverything) { + // The query's Size is greater than the cached one. Throw out the + // cached data and proceed with the query at the greater size. 
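
The cache-compatibility decision above can be restated in isolation; this is a self-contained sketch assuming only the LocationSize interface used in this patch, and mustDiscardCache is a made-up name.

#include "llvm/Analysis/MemoryLocation.h"
using namespace llvm;

// Returns true when a cached non-local result computed for size Cached cannot
// be reused for a query of size Query and must be recomputed.
static bool mustDiscardCache(LocationSize Cached, LocationSize Query) {
  if (Cached == Query)
    return false;
  if (Cached.hasValue() && Query.hasValue())
    // Mixed precision, or a strictly larger query, invalidates the cache.
    return Cached.isPrecise() != Query.isPrecise() ||
           Cached.getValue() < Query.getValue();
  // Otherwise only an unknown-sized query subsumes the cached entry.
  return !Query.hasValue();
}

For example, mustDiscardCache(LocationSize::precise(4), LocationSize::unknown()) is true, while the reverse keeps the cache and restarts the query at the larger cached size, which is what the code above does.
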
+ CacheInfo->Pair = BBSkipFirstBlockPair(); + CacheInfo->Size = Loc.Size; + for (auto &Entry : CacheInfo->NonLocalDeps) + if (Instruction *Inst = Entry.getResult().getInst()) + RemoveFromReverseMap(ReverseNonLocalPtrDeps, Inst, CacheKey); + CacheInfo->NonLocalDeps.clear(); + } else { + // This query's Size is less than the cached one. Conservatively restart + // the query using the greater size. + return getNonLocalPointerDepFromBB( + QueryInst, Pointer, Loc.getWithNewSize(CacheInfo->Size), isLoad, + StartBB, Result, Visited, SkipFirstBlock); + } } // If the query's AATags are inconsistent with the cached one, @@ -1572,7 +1584,7 @@ void MemoryDependenceResults::removeInstruction(Instruction *RemInst) { ReverseDepMapType::iterator ReverseDepIt = ReverseLocalDeps.find(RemInst); if (ReverseDepIt != ReverseLocalDeps.end()) { // RemInst can't be the terminator if it has local stuff depending on it. - assert(!ReverseDepIt->second.empty() && !isa<TerminatorInst>(RemInst) && + assert(!ReverseDepIt->second.empty() && !RemInst->isTerminator() && "Nothing can locally depend on a terminator"); for (Instruction *InstDependingOnRemInst : ReverseDepIt->second) { @@ -1662,7 +1674,7 @@ void MemoryDependenceResults::removeInstruction(Instruction *RemInst) { // Re-sort the NonLocalDepInfo. Changing the dirty entry to its // subsequent value may invalidate the sortedness. - llvm::sort(NLPDI.begin(), NLPDI.end()); + llvm::sort(NLPDI); } ReverseNonLocalPtrDeps.erase(ReversePtrDepIt); diff --git a/lib/Analysis/MemoryLocation.cpp b/lib/Analysis/MemoryLocation.cpp index 55924db284ec..27e8d72b8e89 100644 --- a/lib/Analysis/MemoryLocation.cpp +++ b/lib/Analysis/MemoryLocation.cpp @@ -18,13 +18,28 @@ #include "llvm/IR/Type.h" using namespace llvm; +void LocationSize::print(raw_ostream &OS) const { + OS << "LocationSize::"; + if (*this == unknown()) + OS << "unknown"; + else if (*this == mapEmpty()) + OS << "mapEmpty"; + else if (*this == mapTombstone()) + OS << "mapTombstone"; + else if (isPrecise()) + OS << "precise(" << getValue() << ')'; + else + OS << "upperBound(" << getValue() << ')'; +} + MemoryLocation MemoryLocation::get(const LoadInst *LI) { AAMDNodes AATags; LI->getAAMetadata(AATags); const auto &DL = LI->getModule()->getDataLayout(); - return MemoryLocation(LI->getPointerOperand(), - DL.getTypeStoreSize(LI->getType()), AATags); + return MemoryLocation( + LI->getPointerOperand(), + LocationSize::precise(DL.getTypeStoreSize(LI->getType())), AATags); } MemoryLocation MemoryLocation::get(const StoreInst *SI) { @@ -33,7 +48,8 @@ MemoryLocation MemoryLocation::get(const StoreInst *SI) { const auto &DL = SI->getModule()->getDataLayout(); return MemoryLocation(SI->getPointerOperand(), - DL.getTypeStoreSize(SI->getValueOperand()->getType()), + LocationSize::precise(DL.getTypeStoreSize( + SI->getValueOperand()->getType())), AATags); } @@ -41,7 +57,8 @@ MemoryLocation MemoryLocation::get(const VAArgInst *VI) { AAMDNodes AATags; VI->getAAMetadata(AATags); - return MemoryLocation(VI->getPointerOperand(), UnknownSize, AATags); + return MemoryLocation(VI->getPointerOperand(), LocationSize::unknown(), + AATags); } MemoryLocation MemoryLocation::get(const AtomicCmpXchgInst *CXI) { @@ -49,9 +66,10 @@ MemoryLocation MemoryLocation::get(const AtomicCmpXchgInst *CXI) { CXI->getAAMetadata(AATags); const auto &DL = CXI->getModule()->getDataLayout(); - return MemoryLocation( - CXI->getPointerOperand(), - DL.getTypeStoreSize(CXI->getCompareOperand()->getType()), AATags); + return MemoryLocation(CXI->getPointerOperand(), + 
LocationSize::precise(DL.getTypeStoreSize( + CXI->getCompareOperand()->getType())), + AATags); } MemoryLocation MemoryLocation::get(const AtomicRMWInst *RMWI) { @@ -60,7 +78,8 @@ MemoryLocation MemoryLocation::get(const AtomicRMWInst *RMWI) { const auto &DL = RMWI->getModule()->getDataLayout(); return MemoryLocation(RMWI->getPointerOperand(), - DL.getTypeStoreSize(RMWI->getValOperand()->getType()), + LocationSize::precise(DL.getTypeStoreSize( + RMWI->getValOperand()->getType())), AATags); } @@ -73,9 +92,9 @@ MemoryLocation MemoryLocation::getForSource(const AtomicMemTransferInst *MTI) { } MemoryLocation MemoryLocation::getForSource(const AnyMemTransferInst *MTI) { - uint64_t Size = UnknownSize; + auto Size = LocationSize::unknown(); if (ConstantInt *C = dyn_cast<ConstantInt>(MTI->getLength())) - Size = C->getValue().getZExtValue(); + Size = LocationSize::precise(C->getValue().getZExtValue()); // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. @@ -94,9 +113,9 @@ MemoryLocation MemoryLocation::getForDest(const AtomicMemIntrinsic *MI) { } MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) { - uint64_t Size = UnknownSize; + auto Size = LocationSize::unknown(); if (ConstantInt *C = dyn_cast<ConstantInt>(MI->getLength())) - Size = C->getValue().getZExtValue(); + Size = LocationSize::precise(C->getValue().getZExtValue()); // memcpy/memmove can have AA tags. For memcpy, they apply // to both the source and the destination. @@ -106,15 +125,15 @@ MemoryLocation MemoryLocation::getForDest(const AnyMemIntrinsic *MI) { return MemoryLocation(MI->getRawDest(), Size, AATags); } -MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, +MemoryLocation MemoryLocation::getForArgument(const CallBase *Call, unsigned ArgIdx, - const TargetLibraryInfo &TLI) { + const TargetLibraryInfo *TLI) { AAMDNodes AATags; - CS->getAAMetadata(AATags); - const Value *Arg = CS.getArgument(ArgIdx); + Call->getAAMetadata(AATags); + const Value *Arg = Call->getArgOperand(ArgIdx); // We may be able to produce an exact size for known intrinsics. - if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { + if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Call)) { const DataLayout &DL = II->getModule()->getDataLayout(); switch (II->getIntrinsicID()) { @@ -126,7 +145,8 @@ MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, assert((ArgIdx == 0 || ArgIdx == 1) && "Invalid argument index for memory intrinsic"); if (ConstantInt *LenCI = dyn_cast<ConstantInt>(II->getArgOperand(2))) - return MemoryLocation(Arg, LenCI->getZExtValue(), AATags); + return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()), + AATags); break; case Intrinsic::lifetime_start: @@ -134,23 +154,37 @@ MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, case Intrinsic::invariant_start: assert(ArgIdx == 1 && "Invalid argument index"); return MemoryLocation( - Arg, cast<ConstantInt>(II->getArgOperand(0))->getZExtValue(), AATags); + Arg, + LocationSize::precise( + cast<ConstantInt>(II->getArgOperand(0))->getZExtValue()), + AATags); case Intrinsic::invariant_end: + // The first argument to an invariant.end is a "descriptor" type (e.g. a + // pointer to a empty struct) which is never actually dereferenced. 
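
A usage sketch of the updated getForArgument interface, which now takes a CallBase and a TargetLibraryInfo pointer that may be null; describeArgs is hypothetical.

#include "llvm/Analysis/MemoryLocation.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Type.h"
using namespace llvm;

static void describeArgs(const CallBase *Call, const TargetLibraryInfo *TLI) {
  for (unsigned I = 0, E = Call->arg_size(); I != E; ++I) {
    if (!Call->getArgOperand(I)->getType()->isPointerTy())
      continue;
    // TLI may be null; it is only consulted for known library calls such as
    // memset_pattern16.
    MemoryLocation Loc = MemoryLocation::getForArgument(Call, I, TLI);
    // Loc.Size is precise for recognized intrinsics with constant lengths and
    // LocationSize::unknown() otherwise.
    (void)Loc;
  }
}
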
+ if (ArgIdx == 0) + return MemoryLocation(Arg, LocationSize::precise(0), AATags); assert(ArgIdx == 2 && "Invalid argument index"); return MemoryLocation( - Arg, cast<ConstantInt>(II->getArgOperand(1))->getZExtValue(), AATags); + Arg, + LocationSize::precise( + cast<ConstantInt>(II->getArgOperand(1))->getZExtValue()), + AATags); case Intrinsic::arm_neon_vld1: assert(ArgIdx == 0 && "Invalid argument index"); // LLVM's vld1 and vst1 intrinsics currently only support a single // vector register. - return MemoryLocation(Arg, DL.getTypeStoreSize(II->getType()), AATags); + return MemoryLocation( + Arg, LocationSize::precise(DL.getTypeStoreSize(II->getType())), + AATags); case Intrinsic::arm_neon_vst1: assert(ArgIdx == 0 && "Invalid argument index"); - return MemoryLocation( - Arg, DL.getTypeStoreSize(II->getArgOperand(1)->getType()), AATags); + return MemoryLocation(Arg, + LocationSize::precise(DL.getTypeStoreSize( + II->getArgOperand(1)->getType())), + AATags); } } @@ -159,16 +193,20 @@ MemoryLocation MemoryLocation::getForArgument(ImmutableCallSite CS, // LoopIdiomRecognizer likes to turn loops into calls to memset_pattern16 // whenever possible. LibFunc F; - if (CS.getCalledFunction() && TLI.getLibFunc(*CS.getCalledFunction(), F) && - F == LibFunc_memset_pattern16 && TLI.has(F)) { + if (TLI && Call->getCalledFunction() && + TLI->getLibFunc(*Call->getCalledFunction(), F) && + F == LibFunc_memset_pattern16 && TLI->has(F)) { assert((ArgIdx == 0 || ArgIdx == 1) && "Invalid argument index for memset_pattern16"); if (ArgIdx == 1) - return MemoryLocation(Arg, 16, AATags); - if (const ConstantInt *LenCI = dyn_cast<ConstantInt>(CS.getArgument(2))) - return MemoryLocation(Arg, LenCI->getZExtValue(), AATags); + return MemoryLocation(Arg, LocationSize::precise(16), AATags); + if (const ConstantInt *LenCI = + dyn_cast<ConstantInt>(Call->getArgOperand(2))) + return MemoryLocation(Arg, LocationSize::precise(LenCI->getZExtValue()), + AATags); } // FIXME: Handle memset_pattern4 and memset_pattern8 also. - return MemoryLocation(CS.getArgument(ArgIdx), UnknownSize, AATags); + return MemoryLocation(Call->getArgOperand(ArgIdx), LocationSize::unknown(), + AATags); } diff --git a/lib/Analysis/MemorySSA.cpp b/lib/Analysis/MemorySSA.cpp index f57d490ce96e..6a5567ed765b 100644 --- a/lib/Analysis/MemorySSA.cpp +++ b/lib/Analysis/MemorySSA.cpp @@ -30,7 +30,6 @@ #include "llvm/Config/llvm-config.h" #include "llvm/IR/AssemblyAnnotationWriter.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instruction.h" @@ -77,9 +76,15 @@ static cl::opt<unsigned> MaxCheckLimit( cl::desc("The maximum number of stores/phis MemorySSA" "will consider trying to walk past (default = 100)")); -static cl::opt<bool> - VerifyMemorySSA("verify-memoryssa", cl::init(false), cl::Hidden, - cl::desc("Verify MemorySSA in legacy printer pass.")); +// Always verify MemorySSA if expensive checking is enabled. 
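
With VerifyMemorySSA now an externally visible flag (assuming the matching extern declaration lives in MemorySSA.h), a transform can guard its own verification on it; the helper below is only a sketch.

#include "llvm/Analysis/MemorySSA.h"
using namespace llvm;

static void verifyIfRequested(MemorySSA &MSSA) {
  // True by default under EXPENSIVE_CHECKS, and settable via -verify-memoryssa.
  if (VerifyMemorySSA)
    MSSA.verifyMemorySSA();
}
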
+#ifdef EXPENSIVE_CHECKS +bool llvm::VerifyMemorySSA = true; +#else +bool llvm::VerifyMemorySSA = false; +#endif +static cl::opt<bool, true> + VerifyMemorySSAX("verify-memoryssa", cl::location(VerifyMemorySSA), + cl::Hidden, cl::desc("Enable verification of MemorySSA.")); namespace llvm { @@ -119,16 +124,15 @@ class MemoryLocOrCall { public: bool IsCall = false; - MemoryLocOrCall() = default; MemoryLocOrCall(MemoryUseOrDef *MUD) : MemoryLocOrCall(MUD->getMemoryInst()) {} MemoryLocOrCall(const MemoryUseOrDef *MUD) : MemoryLocOrCall(MUD->getMemoryInst()) {} MemoryLocOrCall(Instruction *Inst) { - if (ImmutableCallSite(Inst)) { + if (auto *C = dyn_cast<CallBase>(Inst)) { IsCall = true; - CS = ImmutableCallSite(Inst); + Call = C; } else { IsCall = false; // There is no such thing as a memorylocation for a fence inst, and it is @@ -140,9 +144,9 @@ public: explicit MemoryLocOrCall(const MemoryLocation &Loc) : Loc(Loc) {} - ImmutableCallSite getCS() const { + const CallBase *getCall() const { assert(IsCall); - return CS; + return Call; } MemoryLocation getLoc() const { @@ -157,16 +161,17 @@ public: if (!IsCall) return Loc == Other.Loc; - if (CS.getCalledValue() != Other.CS.getCalledValue()) + if (Call->getCalledValue() != Other.Call->getCalledValue()) return false; - return CS.arg_size() == Other.CS.arg_size() && - std::equal(CS.arg_begin(), CS.arg_end(), Other.CS.arg_begin()); + return Call->arg_size() == Other.Call->arg_size() && + std::equal(Call->arg_begin(), Call->arg_end(), + Other.Call->arg_begin()); } private: union { - ImmutableCallSite CS; + const CallBase *Call; MemoryLocation Loc; }; }; @@ -192,9 +197,9 @@ template <> struct DenseMapInfo<MemoryLocOrCall> { hash_code hash = hash_combine(MLOC.IsCall, DenseMapInfo<const Value *>::getHashValue( - MLOC.getCS().getCalledValue())); + MLOC.getCall()->getCalledValue())); - for (const Value *Arg : MLOC.getCS().args()) + for (const Value *Arg : MLOC.getCall()->args()) hash = hash_combine(hash, DenseMapInfo<const Value *>::getHashValue(Arg)); return hash; } @@ -247,24 +252,29 @@ struct ClobberAlias { // Return a pair of {IsClobber (bool), AR (AliasResult)}. It relies on AR being // ignored if IsClobber = false. -static ClobberAlias instructionClobbersQuery(MemoryDef *MD, +static ClobberAlias instructionClobbersQuery(const MemoryDef *MD, const MemoryLocation &UseLoc, const Instruction *UseInst, AliasAnalysis &AA) { Instruction *DefInst = MD->getMemoryInst(); assert(DefInst && "Defining instruction not actually an instruction"); - ImmutableCallSite UseCS(UseInst); + const auto *UseCall = dyn_cast<CallBase>(UseInst); Optional<AliasResult> AR; if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(DefInst)) { // These intrinsics will show up as affecting memory, but they are just - // markers. + // markers, mostly. + // + // FIXME: We probably don't actually want MemorySSA to model these at all + // (including creating MemoryAccesses for them): we just end up inventing + // clobbers where they don't really exist at all. Please see D43269 for + // context. 
switch (II->getIntrinsicID()) { case Intrinsic::lifetime_start: - if (UseCS) + if (UseCall) return {false, NoAlias}; AR = AA.alias(MemoryLocation(II->getArgOperand(1)), UseLoc); - return {AR == MustAlias, AR}; + return {AR != NoAlias, AR}; case Intrinsic::lifetime_end: case Intrinsic::invariant_start: case Intrinsic::invariant_end: @@ -275,8 +285,8 @@ static ClobberAlias instructionClobbersQuery(MemoryDef *MD, } } - if (UseCS) { - ModRefInfo I = AA.getModRefInfo(DefInst, UseCS); + if (UseCall) { + ModRefInfo I = AA.getModRefInfo(DefInst, UseCall); AR = isMustSet(I) ? MustAlias : MayAlias; return {isModOrRefSet(I), AR}; } @@ -322,11 +332,12 @@ struct UpwardsMemoryQuery { // The MemoryAccess we actually got called with, used to test local domination const MemoryAccess *OriginalAccess = nullptr; Optional<AliasResult> AR = MayAlias; + bool SkipSelfAccess = false; UpwardsMemoryQuery() = default; UpwardsMemoryQuery(const Instruction *Inst, const MemoryAccess *Access) - : IsCall(ImmutableCallSite(Inst)), Inst(Inst), OriginalAccess(Access) { + : IsCall(isa<CallBase>(Inst)), Inst(Inst), OriginalAccess(Access) { if (!IsCall) StartingLoc = MemoryLocation::get(Inst); } @@ -366,13 +377,15 @@ static bool isUseTriviallyOptimizableToLiveOnEntry(AliasAnalysis &AA, /// \param Start The MemoryAccess that we want to walk from. /// \param ClobberAt A clobber for Start. /// \param StartLoc The MemoryLocation for Start. -/// \param MSSA The MemorySSA isntance that Start and ClobberAt belong to. +/// \param MSSA The MemorySSA instance that Start and ClobberAt belong to. /// \param Query The UpwardsMemoryQuery we used for our search. /// \param AA The AliasAnalysis we used for our search. -static void LLVM_ATTRIBUTE_UNUSED -checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, +/// \param AllowImpreciseClobber Always false, unless we do relaxed verify. +static void +checkClobberSanity(const MemoryAccess *Start, MemoryAccess *ClobberAt, const MemoryLocation &StartLoc, const MemorySSA &MSSA, - const UpwardsMemoryQuery &Query, AliasAnalysis &AA) { + const UpwardsMemoryQuery &Query, AliasAnalysis &AA, + bool AllowImpreciseClobber = false) { assert(MSSA.dominates(ClobberAt, Start) && "Clobber doesn't dominate start?"); if (MSSA.isLiveOnEntryDef(Start)) { @@ -382,21 +395,21 @@ checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, } bool FoundClobber = false; - DenseSet<MemoryAccessPair> VisitedPhis; - SmallVector<MemoryAccessPair, 8> Worklist; + DenseSet<ConstMemoryAccessPair> VisitedPhis; + SmallVector<ConstMemoryAccessPair, 8> Worklist; Worklist.emplace_back(Start, StartLoc); // Walk all paths from Start to ClobberAt, while looking for clobbers. If one // is found, complain. while (!Worklist.empty()) { - MemoryAccessPair MAP = Worklist.pop_back_val(); + auto MAP = Worklist.pop_back_val(); // All we care about is that nothing from Start to ClobberAt clobbers Start. // We learn nothing from revisiting nodes. if (!VisitedPhis.insert(MAP).second) continue; - for (MemoryAccess *MA : def_chain(MAP.first)) { + for (const auto *MA : def_chain(MAP.first)) { if (MA == ClobberAt) { - if (auto *MD = dyn_cast<MemoryDef>(MA)) { + if (const auto *MD = dyn_cast<MemoryDef>(MA)) { // instructionClobbersQuery isn't essentially free, so don't use `|=`, // since it won't let us short-circuit. // @@ -418,19 +431,39 @@ checkClobberSanity(MemoryAccess *Start, MemoryAccess *ClobberAt, // We should never hit liveOnEntry, unless it's the clobber. 
assert(!MSSA.isLiveOnEntryDef(MA) && "Hit liveOnEntry before clobber?"); - if (auto *MD = dyn_cast<MemoryDef>(MA)) { - (void)MD; + if (const auto *MD = dyn_cast<MemoryDef>(MA)) { + // If Start is a Def, skip self. + if (MD == Start) + continue; + assert(!instructionClobbersQuery(MD, MAP.second, Query.Inst, AA) .IsClobber && "Found clobber before reaching ClobberAt!"); continue; } + if (const auto *MU = dyn_cast<MemoryUse>(MA)) { + (void)MU; + assert (MU == Start && + "Can only find use in def chain if Start is a use"); + continue; + } + assert(isa<MemoryPhi>(MA)); - Worklist.append(upward_defs_begin({MA, MAP.second}), upward_defs_end()); + Worklist.append( + upward_defs_begin({const_cast<MemoryAccess *>(MA), MAP.second}), + upward_defs_end()); } } + // If the verify is done following an optimization, it's possible that + // ClobberAt was a conservative clobbering, that we can now infer is not a + // true clobbering access. Don't fail the verify if that's the case. + // We do have accesses that claim they're optimized, but could be optimized + // further. Updating all these can be expensive, so allow it for now (FIXME). + if (AllowImpreciseClobber) + return; + // If ClobberAt is a MemoryPhi, we can assume something above it acted as a // clobber. Otherwise, `ClobberAt` should've acted as a clobber at some point. assert((isa<MemoryPhi>(ClobberAt) || FoundClobber) && @@ -503,13 +536,13 @@ class ClobberWalker { /// /// This does not test for whether StopAt is a clobber UpwardsWalkResult - walkToPhiOrClobber(DefPath &Desc, - const MemoryAccess *StopAt = nullptr) const { + walkToPhiOrClobber(DefPath &Desc, const MemoryAccess *StopAt = nullptr, + const MemoryAccess *SkipStopAt = nullptr) const { assert(!isa<MemoryUse>(Desc.Last) && "Uses don't exist in my world"); for (MemoryAccess *Current : def_chain(Desc.Last)) { Desc.Last = Current; - if (Current == StopAt) + if (Current == StopAt || Current == SkipStopAt) return {Current, false, MayAlias}; if (auto *MD = dyn_cast<MemoryDef>(Current)) { @@ -587,9 +620,16 @@ class ClobberWalker { if (!VisitedPhis.insert({Node.Last, Node.Loc}).second) continue; - UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere); + const MemoryAccess *SkipStopWhere = nullptr; + if (Query->SkipSelfAccess && Node.Loc == Query->StartingLoc) { + assert(isa<MemoryDef>(Query->OriginalAccess)); + SkipStopWhere = Query->OriginalAccess; + } + + UpwardsWalkResult Res = walkToPhiOrClobber(Node, /*StopAt=*/StopWhere, + /*SkipStopAt=*/SkipStopWhere); if (Res.IsKnownClobber) { - assert(Res.Result != StopWhere); + assert(Res.Result != StopWhere && Res.Result != SkipStopWhere); // If this wasn't a cache hit, we hit a clobber when walking. That's a // failure. TerminatedPath Term{Res.Result, PathIndex}; @@ -601,10 +641,13 @@ class ClobberWalker { continue; } - if (Res.Result == StopWhere) { + if (Res.Result == StopWhere || Res.Result == SkipStopWhere) { // We've hit our target. Save this path off for if we want to continue - // walking. - NewPaused.push_back(PathIndex); + // walking. If we are in the mode of skipping the OriginalAccess, and + // we've reached back to the OriginalAccess, do not save path, we've + // just looped back to self. 
+ if (Res.Result != SkipStopWhere) + NewPaused.push_back(PathIndex); continue; } @@ -875,7 +918,8 @@ public: } #ifdef EXPENSIVE_CHECKS - checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); + if (!Q.SkipSelfAccess) + checkClobberSanity(Current, Result, Q.StartingLoc, MSSA, Q, AA); #endif return Result; } @@ -903,28 +947,76 @@ struct RenamePassData { namespace llvm { +class MemorySSA::ClobberWalkerBase { + ClobberWalker Walker; + MemorySSA *MSSA; + +public: + ClobberWalkerBase(MemorySSA *M, AliasAnalysis *A, DominatorTree *D) + : Walker(*M, *A, *D), MSSA(M) {} + + MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *, + const MemoryLocation &); + // Second argument (bool), defines whether the clobber search should skip the + // original queried access. If true, there will be a follow-up query searching + // for a clobber access past "self". Note that the Optimized access is not + // updated if a new clobber is found by this SkipSelf search. If this + // additional query becomes heavily used we may decide to cache the result. + // Walker instantiations will decide how to set the SkipSelf bool. + MemoryAccess *getClobberingMemoryAccessBase(MemoryAccess *, bool); + void verify(const MemorySSA *MSSA) { Walker.verify(MSSA); } +}; + /// A MemorySSAWalker that does AA walks to disambiguate accesses. It no /// longer does caching on its own, but the name has been retained for the /// moment. class MemorySSA::CachingWalker final : public MemorySSAWalker { - ClobberWalker Walker; - - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, UpwardsMemoryQuery &); + ClobberWalkerBase *Walker; public: - CachingWalker(MemorySSA *, AliasAnalysis *, DominatorTree *); + CachingWalker(MemorySSA *M, ClobberWalkerBase *W) + : MemorySSAWalker(M), Walker(W) {} ~CachingWalker() override = default; using MemorySSAWalker::getClobberingMemoryAccess; - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *) override; - MemoryAccess *getClobberingMemoryAccess(MemoryAccess *, - const MemoryLocation &) override; - void invalidateInfo(MemoryAccess *) override; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) override; + + void invalidateInfo(MemoryAccess *MA) override { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->resetOptimized(); + } void verify(const MemorySSA *MSSA) override { MemorySSAWalker::verify(MSSA); - Walker.verify(MSSA); + Walker->verify(MSSA); + } +}; + +class MemorySSA::SkipSelfWalker final : public MemorySSAWalker { + ClobberWalkerBase *Walker; + +public: + SkipSelfWalker(MemorySSA *M, ClobberWalkerBase *W) + : MemorySSAWalker(M), Walker(W) {} + ~SkipSelfWalker() override = default; + + using MemorySSAWalker::getClobberingMemoryAccess; + + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA) override; + MemoryAccess *getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) override; + + void invalidateInfo(MemoryAccess *MA) override { + if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) + MUD->resetOptimized(); + } + + void verify(const MemorySSA *MSSA) override { + MemorySSAWalker::verify(MSSA); + Walker->verify(MSSA); } }; @@ -1063,7 +1155,7 @@ void MemorySSA::markUnreachableAsLiveOnEntry(BasicBlock *BB) { MemorySSA::MemorySSA(Function &Func, AliasAnalysis *AA, DominatorTree *DT) : AA(AA), DT(DT), F(Func), LiveOnEntryDef(nullptr), Walker(nullptr), - NextID(0) { + SkipWalker(nullptr), NextID(0) { buildMemorySSA(); } @@ -1394,10 +1486,25 @@ 
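
A sketch of the new skip-self query path: for an already-optimized MemoryDef, the walker returns the clobber found once the walk is allowed to step past the def itself, and, per the comment above, the result is not written back to the optimized field. clobberPastSelf is a hypothetical caller.

#include "llvm/Analysis/MemorySSA.h"
using namespace llvm;

static MemoryAccess *clobberPastSelf(MemorySSA &MSSA, MemoryDef *MD) {
  // Unlike getWalker(), the skip-self walker may step over MD's own access.
  return MSSA.getSkipSelfWalker()->getClobberingMemoryAccess(MD);
}
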
MemorySSA::CachingWalker *MemorySSA::getWalkerImpl() { if (Walker) return Walker.get(); - Walker = llvm::make_unique<CachingWalker>(this, AA, DT); + if (!WalkerBase) + WalkerBase = llvm::make_unique<ClobberWalkerBase>(this, AA, DT); + + Walker = llvm::make_unique<CachingWalker>(this, WalkerBase.get()); return Walker.get(); } +MemorySSAWalker *MemorySSA::getSkipSelfWalker() { + if (SkipWalker) + return SkipWalker.get(); + + if (!WalkerBase) + WalkerBase = llvm::make_unique<ClobberWalkerBase>(this, AA, DT); + + SkipWalker = llvm::make_unique<SkipSelfWalker>(this, WalkerBase.get()); + return SkipWalker.get(); + } + + // This is a helper function used by the creation routines. It places NewAccess // into the access and defs lists for a given basic block, at the given // insertion point. @@ -1461,15 +1568,25 @@ void MemorySSA::insertIntoListsBefore(MemoryAccess *What, const BasicBlock *BB, BlockNumberingValid.erase(BB); } +void MemorySSA::prepareForMoveTo(MemoryAccess *What, BasicBlock *BB) { + // Keep it in the lookup tables, remove from the lists + removeFromLists(What, false); + + // Note that moving should implicitly invalidate the optimized state of a + // MemoryUse (and Phis can't be optimized). However, it doesn't do so for a + // MemoryDef. + if (auto *MD = dyn_cast<MemoryDef>(What)) + MD->resetOptimized(); + What->setBlock(BB); +} + // Move What before Where in the IR. The end result is that What will belong to // the right lists and have the right Block set, but will not otherwise be // correct. It will not have the right defining access, and if it is a def, // things below it will not properly be updated. void MemorySSA::moveTo(MemoryUseOrDef *What, BasicBlock *BB, AccessList::iterator Where) { - // Keep it in the lookup tables, remove from the lists - removeFromLists(What, false); - What->setBlock(BB); + prepareForMoveTo(What, BB); insertIntoListsBefore(What, BB, Where); } @@ -1485,8 +1602,7 @@ void MemorySSA::moveTo(MemoryAccess *What, BasicBlock *BB, assert(Inserted && "Cannot move a Phi to a block that already has one"); } - removeFromLists(What, false); - What->setBlock(BB); + prepareForMoveTo(What, BB); insertIntoListsForBlock(What, BB, Point); } @@ -1500,9 +1616,10 @@ MemoryPhi *MemorySSA::createMemoryPhi(BasicBlock *BB) { } MemoryUseOrDef *MemorySSA::createDefinedAccess(Instruction *I, - MemoryAccess *Definition) { + MemoryAccess *Definition, + const MemoryUseOrDef *Template) { assert(!isa<PHINode>(I) && "Cannot create a defined access for a PHI"); - MemoryUseOrDef *NewAccess = createNewAccess(I); + MemoryUseOrDef *NewAccess = createNewAccess(I, Template); assert( NewAccess != nullptr && "Tried to create a memory access for a non-memory touching instruction"); @@ -1525,7 +1642,8 @@ static inline bool isOrdered(const Instruction *I) { } /// Helper function to create new memory accesses -MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { +MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I, + const MemoryUseOrDef *Template) { // The assume intrinsic has a control dependency which we model by claiming // that it writes arbitrarily. Ignore that fake memory dependency here. // FIXME: Replace this special casing with a more accurate modelling of @@ -1534,18 +1652,31 @@ MemoryUseOrDef *MemorySSA::createNewAccess(Instruction *I) { if (II->getIntrinsicID() == Intrinsic::assume) return nullptr; - // Find out what affect this instruction has on memory. 
- ModRefInfo ModRef = AA->getModRefInfo(I, None); - // The isOrdered check is used to ensure that volatiles end up as defs - // (atomics end up as ModRef right now anyway). Until we separate the - // ordering chain from the memory chain, this enables people to see at least - // some relative ordering to volatiles. Note that getClobberingMemoryAccess - // will still give an answer that bypasses other volatile loads. TODO: - // Separate memory aliasing and ordering into two different chains so that we - // can precisely represent both "what memory will this read/write/is clobbered - // by" and "what instructions can I move this past". - bool Def = isModSet(ModRef) || isOrdered(I); - bool Use = isRefSet(ModRef); + bool Def, Use; + if (Template) { + Def = dyn_cast_or_null<MemoryDef>(Template) != nullptr; + Use = dyn_cast_or_null<MemoryUse>(Template) != nullptr; +#if !defined(NDEBUG) + ModRefInfo ModRef = AA->getModRefInfo(I, None); + bool DefCheck, UseCheck; + DefCheck = isModSet(ModRef) || isOrdered(I); + UseCheck = isRefSet(ModRef); + assert(Def == DefCheck && (Def || Use == UseCheck) && "Invalid template"); +#endif + } else { + // Find out what affect this instruction has on memory. + ModRefInfo ModRef = AA->getModRefInfo(I, None); + // The isOrdered check is used to ensure that volatiles end up as defs + // (atomics end up as ModRef right now anyway). Until we separate the + // ordering chain from the memory chain, this enables people to see at least + // some relative ordering to volatiles. Note that getClobberingMemoryAccess + // will still give an answer that bypasses other volatile loads. TODO: + // Separate memory aliasing and ordering into two different chains so that + // we can precisely represent both "what memory will this read/write/is + // clobbered by" and "what instructions can I move this past". + Def = isModSet(ModRef) || isOrdered(I); + Use = isRefSet(ModRef); + } // It's possible for an instruction to not modify memory at all. During // construction, we ignore them. @@ -1648,6 +1779,34 @@ void MemorySSA::verifyMemorySSA() const { verifyOrdering(F); verifyDominationNumbers(F); Walker->verify(this); + verifyClobberSanity(F); +} + +/// Check sanity of the clobbering instruction for access MA. +void MemorySSA::checkClobberSanityAccess(const MemoryAccess *MA) const { + if (const auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) { + if (!MUD->isOptimized()) + return; + auto *I = MUD->getMemoryInst(); + auto Loc = MemoryLocation::getOrNone(I); + if (Loc == None) + return; + auto *Clobber = MUD->getOptimized(); + UpwardsMemoryQuery Q(I, MUD); + checkClobberSanity(MUD, Clobber, *Loc, *this, Q, *AA, true); + } +} + +void MemorySSA::verifyClobberSanity(const Function &F) const { +#if !defined(NDEBUG) && defined(EXPENSIVE_CHECKS) + for (const BasicBlock &BB : F) { + const AccessList *Accesses = getBlockAccesses(&BB); + if (!Accesses) + continue; + for (const MemoryAccess &MA : *Accesses) + checkClobberSanityAccess(&MA); + } +#endif } /// Verify that all of the blocks we believe to have valid domination numbers @@ -1691,6 +1850,7 @@ void MemorySSA::verifyDominationNumbers(const Function &F) const { /// Verify that the order and existence of MemoryAccesses matches the /// order and existence of memory affecting instructions. void MemorySSA::verifyOrdering(Function &F) const { +#ifndef NDEBUG // Walk all the blocks, comparing what the lookups think and what the access // lists think, as well as the order in the blocks vs the order in the access // lists. 
@@ -1749,6 +1909,7 @@ void MemorySSA::verifyOrdering(Function &F) const { } ActualDefs.clear(); } +#endif } /// Verify the domination properties of MemorySSA by checking that each @@ -1791,6 +1952,7 @@ void MemorySSA::verifyUseInDefs(MemoryAccess *Def, MemoryAccess *Use) const { /// accesses and verifying that, for each use, it appears in the /// appropriate def's use list void MemorySSA::verifyDefUses(Function &F) const { +#ifndef NDEBUG for (BasicBlock &B : F) { // Phi nodes are attached to basic blocks if (MemoryPhi *Phi = getMemoryAccess(&B)) { @@ -1811,14 +1973,7 @@ void MemorySSA::verifyDefUses(Function &F) const { } } } -} - -MemoryUseOrDef *MemorySSA::getMemoryAccess(const Instruction *I) const { - return cast_or_null<MemoryUseOrDef>(ValueToMemoryAccess.lookup(I)); -} - -MemoryPhi *MemorySSA::getMemoryAccess(const BasicBlock *BB) const { - return cast_or_null<MemoryPhi>(ValueToMemoryAccess.lookup(cast<Value>(BB))); +#endif } /// Perform a local numbering on blocks so that instruction ordering can be @@ -2051,25 +2206,11 @@ void MemorySSAWrapperPass::print(raw_ostream &OS, const Module *M) const { MemorySSAWalker::MemorySSAWalker(MemorySSA *M) : MSSA(M) {} -MemorySSA::CachingWalker::CachingWalker(MemorySSA *M, AliasAnalysis *A, - DominatorTree *D) - : MemorySSAWalker(M), Walker(*M, *A, *D) {} - -void MemorySSA::CachingWalker::invalidateInfo(MemoryAccess *MA) { - if (auto *MUD = dyn_cast<MemoryUseOrDef>(MA)) - MUD->resetOptimized(); -} - -/// Walk the use-def chains starting at \p MA and find +/// Walk the use-def chains starting at \p StartingAccess and find /// the MemoryAccess that actually clobbers Loc. /// /// \returns our clobbering memory access -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( - MemoryAccess *StartingAccess, UpwardsMemoryQuery &Q) { - return Walker.findClobber(StartingAccess, Q); -} - -MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( +MemoryAccess *MemorySSA::ClobberWalkerBase::getClobberingMemoryAccessBase( MemoryAccess *StartingAccess, const MemoryLocation &Loc) { if (isa<MemoryPhi>(StartingAccess)) return StartingAccess; @@ -2082,7 +2223,7 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( // Conservatively, fences are always clobbers, so don't perform the walk if we // hit a fence. - if (!ImmutableCallSite(I) && I->isFenceLike()) + if (!isa<CallBase>(I) && I->isFenceLike()) return StartingUseOrDef; UpwardsMemoryQuery Q; @@ -2093,11 +2234,12 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( // Unlike the other function, do not walk to the def of a def, because we are // handed something we already believe is the clobbering access. + // We never set SkipSelf to true in Q in this method. MemoryAccess *DefiningAccess = isa<MemoryUse>(StartingUseOrDef) ? 
StartingUseOrDef->getDefiningAccess() : StartingUseOrDef; - MemoryAccess *Clobber = getClobberingMemoryAccess(DefiningAccess, Q); + MemoryAccess *Clobber = Walker.findClobber(DefiningAccess, Q); LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); LLVM_DEBUG(dbgs() << *StartingUseOrDef << "\n"); LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); @@ -2106,26 +2248,33 @@ MemoryAccess *MemorySSA::CachingWalker::getClobberingMemoryAccess( } MemoryAccess * -MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { +MemorySSA::ClobberWalkerBase::getClobberingMemoryAccessBase(MemoryAccess *MA, + bool SkipSelf) { auto *StartingAccess = dyn_cast<MemoryUseOrDef>(MA); // If this is a MemoryPhi, we can't do anything. if (!StartingAccess) return MA; + bool IsOptimized = false; + // If this is an already optimized use or def, return the optimized result. // Note: Currently, we store the optimized def result in a separate field, // since we can't use the defining access. - if (StartingAccess->isOptimized()) - return StartingAccess->getOptimized(); + if (StartingAccess->isOptimized()) { + if (!SkipSelf || !isa<MemoryDef>(StartingAccess)) + return StartingAccess->getOptimized(); + IsOptimized = true; + } const Instruction *I = StartingAccess->getMemoryInst(); - UpwardsMemoryQuery Q(I, StartingAccess); // We can't sanely do anything with a fence, since they conservatively clobber // all memory, and have no locations to get pointers from to try to // disambiguate. - if (!Q.IsCall && I->isFenceLike()) + if (!isa<CallBase>(I) && I->isFenceLike()) return StartingAccess; + UpwardsMemoryQuery Q(I, StartingAccess); + if (isUseTriviallyOptimizableToLiveOnEntry(*MSSA->AA, I)) { MemoryAccess *LiveOnEntry = MSSA->getLiveOnEntryDef(); StartingAccess->setOptimized(LiveOnEntry); @@ -2133,33 +2282,71 @@ MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { return LiveOnEntry; } - // Start with the thing we already think clobbers this location - MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + MemoryAccess *OptimizedAccess; + if (!IsOptimized) { + // Start with the thing we already think clobbers this location + MemoryAccess *DefiningAccess = StartingAccess->getDefiningAccess(); + + // At this point, DefiningAccess may be the live on entry def. + // If it is, we will not get a better result. + if (MSSA->isLiveOnEntryDef(DefiningAccess)) { + StartingAccess->setOptimized(DefiningAccess); + StartingAccess->setOptimizedAccessType(None); + return DefiningAccess; + } - // At this point, DefiningAccess may be the live on entry def. - // If it is, we will not get a better result. 
- if (MSSA->isLiveOnEntryDef(DefiningAccess)) { - StartingAccess->setOptimized(DefiningAccess); - StartingAccess->setOptimizedAccessType(None); - return DefiningAccess; - } + OptimizedAccess = Walker.findClobber(DefiningAccess, Q); + StartingAccess->setOptimized(OptimizedAccess); + if (MSSA->isLiveOnEntryDef(OptimizedAccess)) + StartingAccess->setOptimizedAccessType(None); + else if (Q.AR == MustAlias) + StartingAccess->setOptimizedAccessType(MustAlias); + } else + OptimizedAccess = StartingAccess->getOptimized(); - MemoryAccess *Result = getClobberingMemoryAccess(DefiningAccess, Q); LLVM_DEBUG(dbgs() << "Starting Memory SSA clobber for " << *I << " is "); - LLVM_DEBUG(dbgs() << *DefiningAccess << "\n"); - LLVM_DEBUG(dbgs() << "Final Memory SSA clobber for " << *I << " is "); - LLVM_DEBUG(dbgs() << *Result << "\n"); - - StartingAccess->setOptimized(Result); - if (MSSA->isLiveOnEntryDef(Result)) - StartingAccess->setOptimizedAccessType(None); - else if (Q.AR == MustAlias) - StartingAccess->setOptimizedAccessType(MustAlias); + LLVM_DEBUG(dbgs() << *StartingAccess << "\n"); + LLVM_DEBUG(dbgs() << "Optimized Memory SSA clobber for " << *I << " is "); + LLVM_DEBUG(dbgs() << *OptimizedAccess << "\n"); + + MemoryAccess *Result; + if (SkipSelf && isa<MemoryPhi>(OptimizedAccess) && + isa<MemoryDef>(StartingAccess)) { + assert(isa<MemoryDef>(Q.OriginalAccess)); + Q.SkipSelfAccess = true; + Result = Walker.findClobber(OptimizedAccess, Q); + } else + Result = OptimizedAccess; + + LLVM_DEBUG(dbgs() << "Result Memory SSA clobber [SkipSelf = " << SkipSelf); + LLVM_DEBUG(dbgs() << "] for " << *I << " is " << *Result << "\n"); return Result; } MemoryAccess * +MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + return Walker->getClobberingMemoryAccessBase(MA, false); +} + +MemoryAccess * +MemorySSA::CachingWalker::getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) { + return Walker->getClobberingMemoryAccessBase(MA, Loc); +} + +MemoryAccess * +MemorySSA::SkipSelfWalker::getClobberingMemoryAccess(MemoryAccess *MA) { + return Walker->getClobberingMemoryAccessBase(MA, true); +} + +MemoryAccess * +MemorySSA::SkipSelfWalker::getClobberingMemoryAccess(MemoryAccess *MA, + const MemoryLocation &Loc) { + return Walker->getClobberingMemoryAccessBase(MA, Loc); +} + +MemoryAccess * DoNothingMemorySSAWalker::getClobberingMemoryAccess(MemoryAccess *MA) { if (auto *Use = dyn_cast<MemoryUseOrDef>(MA)) return Use->getDefiningAccess(); diff --git a/lib/Analysis/MemorySSAUpdater.cpp b/lib/Analysis/MemorySSAUpdater.cpp index abe2b3c25a58..6c817d203684 100644 --- a/lib/Analysis/MemorySSAUpdater.cpp +++ b/lib/Analysis/MemorySSAUpdater.cpp @@ -12,7 +12,9 @@ //===----------------------------------------------------------------===// #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" @@ -91,7 +93,7 @@ MemoryAccess *MemorySSAUpdater::getPreviousDefRecursive( // FIXME: Figure out whether this is dead code and if so remove it. if (!std::equal(Phi->op_begin(), Phi->op_end(), PhiOps.begin())) { // These will have been filled in by the recursive read we did above. 
- std::copy(PhiOps.begin(), PhiOps.end(), Phi->op_begin()); + llvm::copy(PhiOps, Phi->op_begin()); std::copy(pred_begin(BB), pred_end(BB), Phi->block_begin()); } } else { @@ -264,16 +266,15 @@ void MemorySSAUpdater::insertDef(MemoryDef *MD, bool RenameUses) { for (auto UI = DefBefore->use_begin(), UE = DefBefore->use_end(); UI != UE;) { Use &U = *UI++; - // Leave the uses alone - if (isa<MemoryUse>(U.getUser())) + // Leave the MemoryUses alone. + // Also make sure we skip ourselves to avoid self references. + if (isa<MemoryUse>(U.getUser()) || U.getUser() == MD) continue; U.set(MD); } } // and that def is now our defining access. - // We change them in this order otherwise we will appear in the use list - // above and reset ourselves. MD->setDefiningAccess(DefBefore); SmallVector<WeakVH, 8> FixupList(InsertedPHIs.begin(), InsertedPHIs.end()); @@ -392,6 +393,522 @@ void MemorySSAUpdater::fixupDefs(const SmallVectorImpl<WeakVH> &Vars) { } } +void MemorySSAUpdater::removeEdge(BasicBlock *From, BasicBlock *To) { + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) { + MPhi->unorderedDeleteIncomingBlock(From); + if (MPhi->getNumIncomingValues() == 1) + removeMemoryAccess(MPhi); + } +} + +void MemorySSAUpdater::removeDuplicatePhiEdgesBetween(BasicBlock *From, + BasicBlock *To) { + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(To)) { + bool Found = false; + MPhi->unorderedDeleteIncomingIf([&](const MemoryAccess *, BasicBlock *B) { + if (From != B) + return false; + if (Found) + return true; + Found = true; + return false; + }); + if (MPhi->getNumIncomingValues() == 1) + removeMemoryAccess(MPhi); + } +} + +void MemorySSAUpdater::cloneUsesAndDefs(BasicBlock *BB, BasicBlock *NewBB, + const ValueToValueMapTy &VMap, + PhiToDefMap &MPhiMap) { + auto GetNewDefiningAccess = [&](MemoryAccess *MA) -> MemoryAccess * { + MemoryAccess *InsnDefining = MA; + if (MemoryUseOrDef *DefMUD = dyn_cast<MemoryUseOrDef>(InsnDefining)) { + if (!MSSA->isLiveOnEntryDef(DefMUD)) { + Instruction *DefMUDI = DefMUD->getMemoryInst(); + assert(DefMUDI && "Found MemoryUseOrDef with no Instruction."); + if (Instruction *NewDefMUDI = + cast_or_null<Instruction>(VMap.lookup(DefMUDI))) + InsnDefining = MSSA->getMemoryAccess(NewDefMUDI); + } + } else { + MemoryPhi *DefPhi = cast<MemoryPhi>(InsnDefining); + if (MemoryAccess *NewDefPhi = MPhiMap.lookup(DefPhi)) + InsnDefining = NewDefPhi; + } + assert(InsnDefining && "Defining instruction cannot be nullptr."); + return InsnDefining; + }; + + const MemorySSA::AccessList *Acc = MSSA->getBlockAccesses(BB); + if (!Acc) + return; + for (const MemoryAccess &MA : *Acc) { + if (const MemoryUseOrDef *MUD = dyn_cast<MemoryUseOrDef>(&MA)) { + Instruction *Insn = MUD->getMemoryInst(); + // Entry does not exist if the clone of the block did not clone all + // instructions. This occurs in LoopRotate when cloning instructions + // from the old header to the old preheader. The cloned instruction may + // also be a simplified Value, not an Instruction (see LoopRotate). 
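
Two small usage sketches for the edge helpers added above; the wrapper names and the CFG scenario are hypothetical, while the MemorySSAUpdater calls are the ones defined in this patch.

#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/BasicBlock.h"
using namespace llvm;

// The branch From->To was removed: drop From's incoming value from To's
// MemoryPhi (the Phi itself is removed if only one incoming value remains).
static void onEdgeRemoved(MemorySSAUpdater &MSSAU, BasicBlock *From,
                          BasicBlock *To) {
  MSSAU.removeEdge(From, To);
}

// Two parallel From->To edges were folded into one (e.g. a conditional branch
// with identical successors made unconditional): drop the duplicate Phi
// entries for From, keeping the first.
static void onDuplicateEdgeFolded(MemorySSAUpdater &MSSAU, BasicBlock *From,
                                  BasicBlock *To) {
  MSSAU.removeDuplicatePhiEdgesBetween(From, To);
}
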
+ if (Instruction *NewInsn = + dyn_cast_or_null<Instruction>(VMap.lookup(Insn))) { + MemoryAccess *NewUseOrDef = MSSA->createDefinedAccess( + NewInsn, GetNewDefiningAccess(MUD->getDefiningAccess()), MUD); + MSSA->insertIntoListsForBlock(NewUseOrDef, NewBB, MemorySSA::End); + } + } + } +} + +void MemorySSAUpdater::updateForClonedLoop(const LoopBlocksRPO &LoopBlocks, + ArrayRef<BasicBlock *> ExitBlocks, + const ValueToValueMapTy &VMap, + bool IgnoreIncomingWithNoClones) { + PhiToDefMap MPhiMap; + + auto FixPhiIncomingValues = [&](MemoryPhi *Phi, MemoryPhi *NewPhi) { + assert(Phi && NewPhi && "Invalid Phi nodes."); + BasicBlock *NewPhiBB = NewPhi->getBlock(); + SmallPtrSet<BasicBlock *, 4> NewPhiBBPreds(pred_begin(NewPhiBB), + pred_end(NewPhiBB)); + for (unsigned It = 0, E = Phi->getNumIncomingValues(); It < E; ++It) { + MemoryAccess *IncomingAccess = Phi->getIncomingValue(It); + BasicBlock *IncBB = Phi->getIncomingBlock(It); + + if (BasicBlock *NewIncBB = cast_or_null<BasicBlock>(VMap.lookup(IncBB))) + IncBB = NewIncBB; + else if (IgnoreIncomingWithNoClones) + continue; + + // Now we have IncBB, and will need to add incoming from it to NewPhi. + + // If IncBB is not a predecessor of NewPhiBB, then do not add it. + // NewPhiBB was cloned without that edge. + if (!NewPhiBBPreds.count(IncBB)) + continue; + + // Determine incoming value and add it as incoming from IncBB. + if (MemoryUseOrDef *IncMUD = dyn_cast<MemoryUseOrDef>(IncomingAccess)) { + if (!MSSA->isLiveOnEntryDef(IncMUD)) { + Instruction *IncI = IncMUD->getMemoryInst(); + assert(IncI && "Found MemoryUseOrDef with no Instruction."); + if (Instruction *NewIncI = + cast_or_null<Instruction>(VMap.lookup(IncI))) { + IncMUD = MSSA->getMemoryAccess(NewIncI); + assert(IncMUD && + "MemoryUseOrDef cannot be null, all preds processed."); + } + } + NewPhi->addIncoming(IncMUD, IncBB); + } else { + MemoryPhi *IncPhi = cast<MemoryPhi>(IncomingAccess); + if (MemoryAccess *NewDefPhi = MPhiMap.lookup(IncPhi)) + NewPhi->addIncoming(NewDefPhi, IncBB); + else + NewPhi->addIncoming(IncPhi, IncBB); + } + } + }; + + auto ProcessBlock = [&](BasicBlock *BB) { + BasicBlock *NewBlock = cast_or_null<BasicBlock>(VMap.lookup(BB)); + if (!NewBlock) + return; + + assert(!MSSA->getWritableBlockAccesses(NewBlock) && + "Cloned block should have no accesses"); + + // Add MemoryPhi. + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) { + MemoryPhi *NewPhi = MSSA->createMemoryPhi(NewBlock); + MPhiMap[MPhi] = NewPhi; + } + // Update Uses and Defs. + cloneUsesAndDefs(BB, NewBlock, VMap, MPhiMap); + }; + + for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks)) + ProcessBlock(BB); + + for (auto BB : llvm::concat<BasicBlock *const>(LoopBlocks, ExitBlocks)) + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) + if (MemoryAccess *NewPhi = MPhiMap.lookup(MPhi)) + FixPhiIncomingValues(MPhi, cast<MemoryPhi>(NewPhi)); +} + +void MemorySSAUpdater::updateForClonedBlockIntoPred( + BasicBlock *BB, BasicBlock *P1, const ValueToValueMapTy &VM) { + // All defs/phis from outside BB that are used in BB, are valid uses in P1. + // Since those defs/phis must have dominated BB, and also dominate P1. + // Defs from BB being used in BB will be replaced with the cloned defs from + // VM. The uses of BB's Phi (if it exists) in BB will be replaced by the + // incoming def into the Phi from P1. 
+ PhiToDefMap MPhiMap; + if (MemoryPhi *MPhi = MSSA->getMemoryAccess(BB)) + MPhiMap[MPhi] = MPhi->getIncomingValueForBlock(P1); + cloneUsesAndDefs(BB, P1, VM, MPhiMap); +} + +template <typename Iter> +void MemorySSAUpdater::privateUpdateExitBlocksForClonedLoop( + ArrayRef<BasicBlock *> ExitBlocks, Iter ValuesBegin, Iter ValuesEnd, + DominatorTree &DT) { + SmallVector<CFGUpdate, 4> Updates; + // Update/insert phis in all successors of exit blocks. + for (auto *Exit : ExitBlocks) + for (const ValueToValueMapTy *VMap : make_range(ValuesBegin, ValuesEnd)) + if (BasicBlock *NewExit = cast_or_null<BasicBlock>(VMap->lookup(Exit))) { + BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0); + Updates.push_back({DT.Insert, NewExit, ExitSucc}); + } + applyInsertUpdates(Updates, DT); +} + +void MemorySSAUpdater::updateExitBlocksForClonedLoop( + ArrayRef<BasicBlock *> ExitBlocks, const ValueToValueMapTy &VMap, + DominatorTree &DT) { + const ValueToValueMapTy *const Arr[] = {&VMap}; + privateUpdateExitBlocksForClonedLoop(ExitBlocks, std::begin(Arr), + std::end(Arr), DT); +} + +void MemorySSAUpdater::updateExitBlocksForClonedLoop( + ArrayRef<BasicBlock *> ExitBlocks, + ArrayRef<std::unique_ptr<ValueToValueMapTy>> VMaps, DominatorTree &DT) { + auto GetPtr = [&](const std::unique_ptr<ValueToValueMapTy> &I) { + return I.get(); + }; + using MappedIteratorType = + mapped_iterator<const std::unique_ptr<ValueToValueMapTy> *, + decltype(GetPtr)>; + auto MapBegin = MappedIteratorType(VMaps.begin(), GetPtr); + auto MapEnd = MappedIteratorType(VMaps.end(), GetPtr); + privateUpdateExitBlocksForClonedLoop(ExitBlocks, MapBegin, MapEnd, DT); +} + +void MemorySSAUpdater::applyUpdates(ArrayRef<CFGUpdate> Updates, + DominatorTree &DT) { + SmallVector<CFGUpdate, 4> RevDeleteUpdates; + SmallVector<CFGUpdate, 4> InsertUpdates; + for (auto &Update : Updates) { + if (Update.getKind() == DT.Insert) + InsertUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()}); + else + RevDeleteUpdates.push_back({DT.Insert, Update.getFrom(), Update.getTo()}); + } + + if (!RevDeleteUpdates.empty()) { + // Update for inserted edges: use newDT and snapshot CFG as if deletes had + // not occured. + // FIXME: This creates a new DT, so it's more expensive to do mix + // delete/inserts vs just inserts. We can do an incremental update on the DT + // to revert deletes, than re-delete the edges. Teaching DT to do this, is + // part of a pending cleanup. + DominatorTree NewDT(DT, RevDeleteUpdates); + GraphDiff<BasicBlock *> GD(RevDeleteUpdates); + applyInsertUpdates(InsertUpdates, NewDT, &GD); + } else { + GraphDiff<BasicBlock *> GD; + applyInsertUpdates(InsertUpdates, DT, &GD); + } + + // Update for deleted edges + for (auto &Update : RevDeleteUpdates) + removeEdge(Update.getFrom(), Update.getTo()); +} + +void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates, + DominatorTree &DT) { + GraphDiff<BasicBlock *> GD; + applyInsertUpdates(Updates, DT, &GD); +} + +void MemorySSAUpdater::applyInsertUpdates(ArrayRef<CFGUpdate> Updates, + DominatorTree &DT, + const GraphDiff<BasicBlock *> *GD) { + // Get recursive last Def, assuming well formed MSSA and updated DT. + auto GetLastDef = [&](BasicBlock *BB) -> MemoryAccess * { + while (true) { + MemorySSA::DefsList *Defs = MSSA->getWritableBlockDefs(BB); + // Return last Def or Phi in BB, if it exists. + if (Defs) + return &*(--Defs->end()); + + // Check number of predecessors, we only care if there's more than one. 
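
A plausible call sequence after cloning a loop body, using the two entry points added above; the wrapper and its arguments are assumed to come from whatever utility produced the cloned blocks and the value map, and passing IgnoreIncomingWithNoClones=true is only right when some incoming blocks were deliberately left uncloned.

#include "llvm/ADT/ArrayRef.h"
#include "llvm/Analysis/LoopIterator.h"
#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/ValueMapper.h"
using namespace llvm;

static void updateMSSAAfterLoopClone(MemorySSAUpdater &MSSAU,
                                     const LoopBlocksRPO &LoopBlocks,
                                     ArrayRef<BasicBlock *> ExitBlocks,
                                     const ValueToValueMapTy &VMap,
                                     DominatorTree &DT) {
  // Create MemoryPhis, Defs and Uses for the cloned blocks recorded in VMap...
  MSSAU.updateForClonedLoop(LoopBlocks, ExitBlocks, VMap,
                            /*IgnoreIncomingWithNoClones=*/true);
  // ...then add or extend Phis in the successors of the cloned exit blocks.
  MSSAU.updateExitBlocksForClonedLoop(ExitBlocks, VMap, DT);
}
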
+ unsigned Count = 0; + BasicBlock *Pred = nullptr; + for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) { + Pred = Pair.second; + Count++; + if (Count == 2) + break; + } + + // If BB has multiple predecessors, get last definition from IDom. + if (Count != 1) { + // [SimpleLoopUnswitch] If BB is a dead block, about to be deleted, its + // DT is invalidated. Return LoE as its last def. This will be added to + // MemoryPhi node, and later deleted when the block is deleted. + if (!DT.getNode(BB)) + return MSSA->getLiveOnEntryDef(); + if (auto *IDom = DT.getNode(BB)->getIDom()) + if (IDom->getBlock() != BB) { + BB = IDom->getBlock(); + continue; + } + return MSSA->getLiveOnEntryDef(); + } else { + // Single predecessor, BB cannot be dead. GetLastDef of Pred. + assert(Count == 1 && Pred && "Single predecessor expected."); + BB = Pred; + } + }; + llvm_unreachable("Unable to get last definition."); + }; + + // Get nearest IDom given a set of blocks. + // TODO: this can be optimized by starting the search at the node with the + // lowest level (highest in the tree). + auto FindNearestCommonDominator = + [&](const SmallSetVector<BasicBlock *, 2> &BBSet) -> BasicBlock * { + BasicBlock *PrevIDom = *BBSet.begin(); + for (auto *BB : BBSet) + PrevIDom = DT.findNearestCommonDominator(PrevIDom, BB); + return PrevIDom; + }; + + // Get all blocks that dominate PrevIDom, stop when reaching CurrIDom. Do not + // include CurrIDom. + auto GetNoLongerDomBlocks = + [&](BasicBlock *PrevIDom, BasicBlock *CurrIDom, + SmallVectorImpl<BasicBlock *> &BlocksPrevDom) { + if (PrevIDom == CurrIDom) + return; + BlocksPrevDom.push_back(PrevIDom); + BasicBlock *NextIDom = PrevIDom; + while (BasicBlock *UpIDom = + DT.getNode(NextIDom)->getIDom()->getBlock()) { + if (UpIDom == CurrIDom) + break; + BlocksPrevDom.push_back(UpIDom); + NextIDom = UpIDom; + } + }; + + // Map a BB to its predecessors: added + previously existing. To get a + // deterministic order, store predecessors as SetVectors. The order in each + // will be defined by teh order in Updates (fixed) and the order given by + // children<> (also fixed). Since we further iterate over these ordered sets, + // we lose the information of multiple edges possibly existing between two + // blocks, so we'll keep and EdgeCount map for that. + // An alternate implementation could keep unordered set for the predecessors, + // traverse either Updates or children<> each time to get the deterministic + // order, and drop the usage of EdgeCount. This alternate approach would still + // require querying the maps for each predecessor, and children<> call has + // additional computation inside for creating the snapshot-graph predecessors. + // As such, we favor using a little additional storage and less compute time. + // This decision can be revisited if we find the alternative more favorable. + + struct PredInfo { + SmallSetVector<BasicBlock *, 2> Added; + SmallSetVector<BasicBlock *, 2> Prev; + }; + SmallDenseMap<BasicBlock *, PredInfo> PredMap; + + for (auto &Edge : Updates) { + BasicBlock *BB = Edge.getTo(); + auto &AddedBlockSet = PredMap[BB].Added; + AddedBlockSet.insert(Edge.getFrom()); + } + + // Store all existing predecessor for each BB, at least one must exist. 
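Before the predecessor bookkeeping that follows, it helps to keep a picture of the GetLastDef lambda defined above: it finds the memory definition reaching the end of a block by taking the block's last Def if there is one, following a chain of single predecessors otherwise, and hopping to the immediate dominator at join points, with live-on-entry as the fallback once the walk runs out (a dead block, or no dominator left). A stand-alone sketch of that walk on a simplified CFG; the Block struct and every name below are invented for illustration and are not part of the patch.

#include <vector>

// Simplified model of the GetLastDef walk.  LastDef == -1 means the block
// defines nothing; IDom is null for the entry block.
struct Block {
  int LastDef = -1;
  std::vector<Block *> Preds;
  Block *IDom = nullptr;
};

// Returns the id of the reaching definition, or -1 for "live on entry".
static int getLastDef(Block *BB) {
  while (BB) {
    if (BB->LastDef != -1)
      return BB->LastDef;        // A Def or Phi lives in this block.
    if (BB->Preds.size() == 1) {
      BB = BB->Preds.front();    // Single predecessor: keep walking back.
      continue;
    }
    BB = BB->IDom;               // Join point: defer to the immediate dominator.
  }
  return -1;                     // Ran out of dominators: live-on-entry.
}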
+ SmallDenseMap<std::pair<BasicBlock *, BasicBlock *>, int> EdgeCountMap; + SmallPtrSet<BasicBlock *, 2> NewBlocks; + for (auto &BBPredPair : PredMap) { + auto *BB = BBPredPair.first; + const auto &AddedBlockSet = BBPredPair.second.Added; + auto &PrevBlockSet = BBPredPair.second.Prev; + for (auto &Pair : children<GraphDiffInvBBPair>({GD, BB})) { + BasicBlock *Pi = Pair.second; + if (!AddedBlockSet.count(Pi)) + PrevBlockSet.insert(Pi); + EdgeCountMap[{Pi, BB}]++; + } + + if (PrevBlockSet.empty()) { + assert(pred_size(BB) == AddedBlockSet.size() && "Duplicate edges added."); + LLVM_DEBUG( + dbgs() + << "Adding a predecessor to a block with no predecessors. " + "This must be an edge added to a new, likely cloned, block. " + "Its memory accesses must be already correct, assuming completed " + "via the updateExitBlocksForClonedLoop API. " + "Assert a single such edge is added so no phi addition or " + "additional processing is required.\n"); + assert(AddedBlockSet.size() == 1 && + "Can only handle adding one predecessor to a new block."); + // Need to remove new blocks from PredMap. Remove below to not invalidate + // iterator here. + NewBlocks.insert(BB); + } + } + // Nothing to process for new/cloned blocks. + for (auto *BB : NewBlocks) + PredMap.erase(BB); + + SmallVector<BasicBlock *, 8> BlocksToProcess; + SmallVector<BasicBlock *, 16> BlocksWithDefsToReplace; + + // First create MemoryPhis in all blocks that don't have one. Create in the + // order found in Updates, not in PredMap, to get deterministic numbering. + for (auto &Edge : Updates) { + BasicBlock *BB = Edge.getTo(); + if (PredMap.count(BB) && !MSSA->getMemoryAccess(BB)) + MSSA->createMemoryPhi(BB); + } + + // Now we'll fill in the MemoryPhis with the right incoming values. + for (auto &BBPredPair : PredMap) { + auto *BB = BBPredPair.first; + const auto &PrevBlockSet = BBPredPair.second.Prev; + const auto &AddedBlockSet = BBPredPair.second.Added; + assert(!PrevBlockSet.empty() && + "At least one previous predecessor must exist."); + + // TODO: if this becomes a bottleneck, we can save on GetLastDef calls by + // keeping this map before the loop. We can reuse already populated entries + // if an edge is added from the same predecessor to two different blocks, + // and this does happen in rotate. Note that the map needs to be updated + // when deleting non-necessary phis below, if the phi is in the map by + // replacing the value with DefP1. + SmallDenseMap<BasicBlock *, MemoryAccess *> LastDefAddedPred; + for (auto *AddedPred : AddedBlockSet) { + auto *DefPn = GetLastDef(AddedPred); + assert(DefPn != nullptr && "Unable to find last definition."); + LastDefAddedPred[AddedPred] = DefPn; + } + + MemoryPhi *NewPhi = MSSA->getMemoryAccess(BB); + // If Phi is not empty, add an incoming edge from each added pred. Must + // still compute blocks with defs to replace for this block below. + if (NewPhi->getNumOperands()) { + for (auto *Pred : AddedBlockSet) { + auto *LastDefForPred = LastDefAddedPred[Pred]; + for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I) + NewPhi->addIncoming(LastDefForPred, Pred); + } + } else { + // Pick any existing predecessor and get its definition. All other + // existing predecessors should have the same one, since no phi existed. + auto *P1 = *PrevBlockSet.begin(); + MemoryAccess *DefP1 = GetLastDef(P1); + + // Check DefP1 against all Defs in LastDefPredPair. If all the same, + // nothing to add. 
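The check that follows is the familiar trivial-phi test from SSA construction: if every newly added edge would feed the block the same definition that already reaches it through its old predecessors, no phi is needed, and the code right after this redirects the placeholder phi's uses to that single value and deletes it. A toy version of just the test, with plain integers standing in for memory accesses; all names are invented for illustration.

#include <unordered_map>

using Def = int;

// Returns true if the tentative phi is actually needed, i.e. some newly added
// predecessor brings in a definition different from the one already reaching
// the block (ReachingDef).
static bool phiIsNeeded(Def ReachingDef,
                        const std::unordered_map<int, Def> &LastDefPerNewPred) {
  for (const auto &Entry : LastDefPerNewPred)
    if (Entry.second != ReachingDef)
      return true;   // A new edge carries a different definition: keep the phi.
  return false;      // Every edge carries the same value: the phi is redundant.
}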
+ bool InsertPhi = false; + for (auto LastDefPredPair : LastDefAddedPred) + if (DefP1 != LastDefPredPair.second) { + InsertPhi = true; + break; + } + if (!InsertPhi) { + // Since NewPhi may be used in other newly added Phis, replace all uses + // of NewPhi with the definition coming from all predecessors (DefP1), + // before deleting it. + NewPhi->replaceAllUsesWith(DefP1); + removeMemoryAccess(NewPhi); + continue; + } + + // Update Phi with new values for new predecessors and old value for all + // other predecessors. Since AddedBlockSet and PrevBlockSet are ordered + // sets, the order of entries in NewPhi is deterministic. + for (auto *Pred : AddedBlockSet) { + auto *LastDefForPred = LastDefAddedPred[Pred]; + for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I) + NewPhi->addIncoming(LastDefForPred, Pred); + } + for (auto *Pred : PrevBlockSet) + for (int I = 0, E = EdgeCountMap[{Pred, BB}]; I < E; ++I) + NewPhi->addIncoming(DefP1, Pred); + + // Insert BB in the set of blocks that now have definition. We'll use this + // to compute IDF and add Phis there next. + BlocksToProcess.push_back(BB); + } + + // Get all blocks that used to dominate BB and no longer do after adding + // AddedBlockSet, where PrevBlockSet are the previously known predecessors. + assert(DT.getNode(BB)->getIDom() && "BB does not have valid idom"); + BasicBlock *PrevIDom = FindNearestCommonDominator(PrevBlockSet); + assert(PrevIDom && "Previous IDom should exists"); + BasicBlock *NewIDom = DT.getNode(BB)->getIDom()->getBlock(); + assert(NewIDom && "BB should have a new valid idom"); + assert(DT.dominates(NewIDom, PrevIDom) && + "New idom should dominate old idom"); + GetNoLongerDomBlocks(PrevIDom, NewIDom, BlocksWithDefsToReplace); + } + + // Compute IDF and add Phis in all IDF blocks that do not have one. + SmallVector<BasicBlock *, 32> IDFBlocks; + if (!BlocksToProcess.empty()) { + ForwardIDFCalculator IDFs(DT); + SmallPtrSet<BasicBlock *, 16> DefiningBlocks(BlocksToProcess.begin(), + BlocksToProcess.end()); + IDFs.setDefiningBlocks(DefiningBlocks); + IDFs.calculate(IDFBlocks); + for (auto *BBIDF : IDFBlocks) { + if (auto *IDFPhi = MSSA->getMemoryAccess(BBIDF)) { + // Update existing Phi. + // FIXME: some updates may be redundant, try to optimize and skip some. + for (unsigned I = 0, E = IDFPhi->getNumIncomingValues(); I < E; ++I) + IDFPhi->setIncomingValue(I, GetLastDef(IDFPhi->getIncomingBlock(I))); + } else { + IDFPhi = MSSA->createMemoryPhi(BBIDF); + for (auto &Pair : children<GraphDiffInvBBPair>({GD, BBIDF})) { + BasicBlock *Pi = Pair.second; + IDFPhi->addIncoming(GetLastDef(Pi), Pi); + } + } + } + } + + // Now for all defs in BlocksWithDefsToReplace, if there are uses they no + // longer dominate, replace those with the closest dominating def. + // This will also update optimized accesses, as they're also uses. 
+ for (auto *BlockWithDefsToReplace : BlocksWithDefsToReplace) { + if (auto DefsList = MSSA->getWritableBlockDefs(BlockWithDefsToReplace)) { + for (auto &DefToReplaceUses : *DefsList) { + BasicBlock *DominatingBlock = DefToReplaceUses.getBlock(); + Value::use_iterator UI = DefToReplaceUses.use_begin(), + E = DefToReplaceUses.use_end(); + for (; UI != E;) { + Use &U = *UI; + ++UI; + MemoryAccess *Usr = dyn_cast<MemoryAccess>(U.getUser()); + if (MemoryPhi *UsrPhi = dyn_cast<MemoryPhi>(Usr)) { + BasicBlock *DominatedBlock = UsrPhi->getIncomingBlock(U); + if (!DT.dominates(DominatingBlock, DominatedBlock)) + U.set(GetLastDef(DominatedBlock)); + } else { + BasicBlock *DominatedBlock = Usr->getBlock(); + if (!DT.dominates(DominatingBlock, DominatedBlock)) { + if (auto *DomBlPhi = MSSA->getMemoryAccess(DominatedBlock)) + U.set(DomBlPhi); + else { + auto *IDom = DT.getNode(DominatedBlock)->getIDom(); + assert(IDom && "Block must have a valid IDom."); + U.set(GetLastDef(IDom->getBlock())); + } + cast<MemoryUseOrDef>(Usr)->resetOptimized(); + } + } + } + } + } + } +} + // Move What before Where in the MemorySSA IR. template <class WhereType> void MemorySSAUpdater::moveTo(MemoryUseOrDef *What, BasicBlock *BB, @@ -498,13 +1015,14 @@ static MemoryAccess *onlySingleValue(MemoryPhi *MP) { } void MemorySSAUpdater::wireOldPredecessorsToNewImmediatePredecessor( - BasicBlock *Old, BasicBlock *New, ArrayRef<BasicBlock *> Preds) { + BasicBlock *Old, BasicBlock *New, ArrayRef<BasicBlock *> Preds, + bool IdenticalEdgesWereMerged) { assert(!MSSA->getWritableBlockAccesses(New) && "Access list should be null for a new block."); MemoryPhi *Phi = MSSA->getMemoryAccess(Old); if (!Phi) return; - if (pred_size(Old) == 1) { + if (Old->hasNPredecessors(1)) { assert(pred_size(New) == Preds.size() && "Should have moved all predecessors."); MSSA->moveTo(Phi, New, MemorySSA::Beginning); @@ -513,9 +1031,17 @@ void MemorySSAUpdater::wireOldPredecessorsToNewImmediatePredecessor( "new immediate predecessor."); MemoryPhi *NewPhi = MSSA->createMemoryPhi(New); SmallPtrSet<BasicBlock *, 16> PredsSet(Preds.begin(), Preds.end()); + // Currently only support the case of removing a single incoming edge when + // identical edges were not merged. + if (!IdenticalEdgesWereMerged) + assert(PredsSet.size() == Preds.size() && + "If identical edges were not merged, we cannot have duplicate " + "blocks in the predecessors"); Phi->unorderedDeleteIncomingIf([&](MemoryAccess *MA, BasicBlock *B) { if (PredsSet.count(B)) { NewPhi->addIncoming(MA, B); + if (!IdenticalEdgesWereMerged) + PredsSet.erase(B); return true; } return false; @@ -578,9 +1104,9 @@ void MemorySSAUpdater::removeBlocks( const SmallPtrSetImpl<BasicBlock *> &DeadBlocks) { // First delete all uses of BB in MemoryPhis. 
for (BasicBlock *BB : DeadBlocks) { - TerminatorInst *TI = BB->getTerminator(); + Instruction *TI = BB->getTerminator(); assert(TI && "Basic block expected to have a terminator instruction"); - for (BasicBlock *Succ : TI->successors()) + for (BasicBlock *Succ : successors(TI)) if (!DeadBlocks.count(Succ)) if (MemoryPhi *MP = MSSA->getMemoryAccess(Succ)) { MP->unorderedDeleteIncomingBlock(BB); diff --git a/lib/Analysis/ModuleSummaryAnalysis.cpp b/lib/Analysis/ModuleSummaryAnalysis.cpp index 17dae20ce3a1..87f76d43bb1e 100644 --- a/lib/Analysis/ModuleSummaryAnalysis.cpp +++ b/lib/Analysis/ModuleSummaryAnalysis.cpp @@ -74,9 +74,17 @@ cl::opt<FunctionSummary::ForceSummaryHotnessType, true> FSEC( // Walk through the operands of a given User via worklist iteration and populate // the set of GlobalValue references encountered. Invoked either on an // Instruction or a GlobalVariable (which walks its initializer). -static void findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, +// Return true if any of the operands contains blockaddress. This is important +// to know when computing summary for global var, because if global variable +// references basic block address we can't import it separately from function +// containing that basic block. For simplicity we currently don't import such +// global vars at all. When importing function we aren't interested if any +// instruction in it takes an address of any basic block, because instruction +// can only take an address of basic block located in the same function. +static bool findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, SetVector<ValueInfo> &RefEdges, SmallPtrSet<const User *, 8> &Visited) { + bool HasBlockAddress = false; SmallVector<const User *, 32> Worklist; Worklist.push_back(CurUser); @@ -92,8 +100,10 @@ static void findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, const User *Operand = dyn_cast<User>(OI); if (!Operand) continue; - if (isa<BlockAddress>(Operand)) + if (isa<BlockAddress>(Operand)) { + HasBlockAddress = true; continue; + } if (auto *GV = dyn_cast<GlobalValue>(Operand)) { // We have a reference to a global value. This should be added to // the reference set unless it is a callee. 
Callees are handled @@ -105,6 +115,7 @@ static void findRefEdges(ModuleSummaryIndex &Index, const User *CurUser, Worklist.push_back(Operand); } } + return HasBlockAddress; } static CalleeInfo::HotnessType getHotness(uint64_t ProfileCount, @@ -147,7 +158,8 @@ static void addIntrinsicToSummary( SetVector<FunctionSummary::VFuncId> &TypeTestAssumeVCalls, SetVector<FunctionSummary::VFuncId> &TypeCheckedLoadVCalls, SetVector<FunctionSummary::ConstVCall> &TypeTestAssumeConstVCalls, - SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls) { + SetVector<FunctionSummary::ConstVCall> &TypeCheckedLoadConstVCalls, + DominatorTree &DT) { switch (CI->getCalledFunction()->getIntrinsicID()) { case Intrinsic::type_test: { auto *TypeMDVal = cast<MetadataAsValue>(CI->getArgOperand(1)); @@ -172,7 +184,7 @@ static void addIntrinsicToSummary( SmallVector<DevirtCallSite, 4> DevirtCalls; SmallVector<CallInst *, 4> Assumes; - findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI); + findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, CI, DT); for (auto &Call : DevirtCalls) addVCallToSet(Call, Guid, TypeTestAssumeVCalls, TypeTestAssumeConstVCalls); @@ -192,7 +204,7 @@ static void addIntrinsicToSummary( SmallVector<Instruction *, 4> Preds; bool HasNonCallUses = false; findDevirtualizableCallsForTypeCheckedLoad(DevirtCalls, LoadedPtrs, Preds, - HasNonCallUses, CI); + HasNonCallUses, CI, DT); // Any non-call uses of the result of llvm.type.checked.load will // prevent us from optimizing away the llvm.type.test. if (HasNonCallUses) @@ -208,11 +220,19 @@ static void addIntrinsicToSummary( } } -static void -computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, - const Function &F, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, bool HasLocalsInUsedOrAsm, - DenseSet<GlobalValue::GUID> &CantBePromoted) { +static bool isNonVolatileLoad(const Instruction *I) { + if (const auto *LI = dyn_cast<LoadInst>(I)) + return !LI->isVolatile(); + + return false; +} + +static void computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, + const Function &F, BlockFrequencyInfo *BFI, + ProfileSummaryInfo *PSI, DominatorTree &DT, + bool HasLocalsInUsedOrAsm, + DenseSet<GlobalValue::GUID> &CantBePromoted, + bool IsThinLTO) { // Summary not currently supported for anonymous functions, they should // have been named. assert(F.hasName()); @@ -233,6 +253,7 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, // Add personality function, prefix data and prologue data to function's ref // list. 
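Those references are gathered by findRefEdges, invoked on the next line. As the hunks above show, its operand walk is a plain worklist traversal that now also reports whether a blockaddress constant was encountered, because a global variable whose initializer captures a basic block address cannot be imported separately from the function containing that block. A self-contained sketch of the same pattern on an invented Node type rather than the real User hierarchy; all names below are illustrative only.

#include <unordered_set>
#include <vector>

// Toy operand graph: a node is either an ordinary user, a global, or a
// blockaddress-like constant.
struct Node {
  enum Kind { Plain, Global, BlockAddr };
  Kind K = Plain;
  std::vector<Node *> Operands;
};

// Walk the operands of Root, collect referenced globals into Refs, and report
// whether any blockaddress was seen along the way.
static bool collectRefs(Node *Root, std::vector<Node *> &Refs) {
  bool HasBlockAddress = false;
  std::unordered_set<Node *> Visited;
  std::vector<Node *> Worklist{Root};
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(N).second)
      continue;                   // Already processed.
    for (Node *Op : N->Operands) {
      if (Op->K == Node::BlockAddr) {
        HasBlockAddress = true;   // Remember the fact, keep scanning.
        continue;
      }
      if (Op->K == Node::Global) {
        Refs.push_back(Op);       // A reference edge; do not walk into it.
        continue;
      }
      Worklist.push_back(Op);     // Ordinary operand: visit its operands too.
    }
  }
  return HasBlockAddress;
}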
findRefEdges(Index, &F, RefEdges, Visited); + std::vector<const Instruction *> NonVolatileLoads; bool HasInlineAsmMaybeReferencingInternal = false; for (const BasicBlock &BB : F) @@ -240,6 +261,13 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, if (isa<DbgInfoIntrinsic>(I)) continue; ++NumInsts; + if (isNonVolatileLoad(&I)) { + // Postpone processing of non-volatile load instructions + // See comments below + Visited.insert(&I); + NonVolatileLoads.push_back(&I); + continue; + } findRefEdges(Index, &I, RefEdges, Visited); auto CS = ImmutableCallSite(&I); if (!CS) @@ -273,7 +301,7 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, if (CI && CalledFunction->isIntrinsic()) { addIntrinsicToSummary( CI, TypeTests, TypeTestAssumeVCalls, TypeCheckedLoadVCalls, - TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls); + TypeTestAssumeConstVCalls, TypeCheckedLoadConstVCalls, DT); continue; } // We should have named any anonymous globals @@ -329,6 +357,24 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, } } + // By now we processed all instructions in a function, except + // non-volatile loads. All new refs we add in a loop below + // are obviously constant. All constant refs are grouped in the + // end of RefEdges vector, so we can use a single integer value + // to identify them. + unsigned RefCnt = RefEdges.size(); + for (const Instruction *I : NonVolatileLoads) { + Visited.erase(I); + findRefEdges(Index, I, RefEdges, Visited); + } + std::vector<ValueInfo> Refs = RefEdges.takeVector(); + // Regular LTO module doesn't participate in ThinLTO import, + // so no reference from it can be readonly, since this would + // require importing variable as local copy + if (IsThinLTO) + for (; RefCnt < Refs.size(); ++RefCnt) + Refs[RefCnt].setReadOnly(); + // Explicit add hot edges to enforce importing for designated GUIDs for // sample PGO, to enable the same inlines as the profiled optimized binary. for (auto &I : F.getImportGUIDs()) @@ -339,22 +385,18 @@ computeFunctionSummary(ModuleSummaryIndex &Index, const Module &M, bool NonRenamableLocal = isNonRenamableLocal(F); bool NotEligibleForImport = - NonRenamableLocal || HasInlineAsmMaybeReferencingInternal || - // Inliner doesn't handle variadic functions. - // FIXME: refactor this to use the same code that inliner is using. - F.isVarArg() || - // Don't try to import functions with noinline attribute. - F.getAttributes().hasFnAttribute(Attribute::NoInline); + NonRenamableLocal || HasInlineAsmMaybeReferencingInternal; GlobalValueSummary::GVFlags Flags(F.getLinkage(), NotEligibleForImport, /* Live = */ false, F.isDSOLocal()); FunctionSummary::FFlags FunFlags{ F.hasFnAttribute(Attribute::ReadNone), F.hasFnAttribute(Attribute::ReadOnly), - F.hasFnAttribute(Attribute::NoRecurse), - F.returnDoesNotAlias(), - }; + F.hasFnAttribute(Attribute::NoRecurse), F.returnDoesNotAlias(), + // FIXME: refactor this to use the same code that inliner is using. + // Don't try to import functions with noinline attribute. 
+ F.getAttributes().hasFnAttribute(Attribute::NoInline)}; auto FuncSummary = llvm::make_unique<FunctionSummary>( - Flags, NumInsts, FunFlags, RefEdges.takeVector(), + Flags, NumInsts, FunFlags, /*EntryCount=*/0, std::move(Refs), CallGraphEdges.takeVector(), TypeTests.takeVector(), TypeTestAssumeVCalls.takeVector(), TypeCheckedLoadVCalls.takeVector(), TypeTestAssumeConstVCalls.takeVector(), @@ -369,14 +411,21 @@ computeVariableSummary(ModuleSummaryIndex &Index, const GlobalVariable &V, DenseSet<GlobalValue::GUID> &CantBePromoted) { SetVector<ValueInfo> RefEdges; SmallPtrSet<const User *, 8> Visited; - findRefEdges(Index, &V, RefEdges, Visited); + bool HasBlockAddress = findRefEdges(Index, &V, RefEdges, Visited); bool NonRenamableLocal = isNonRenamableLocal(V); GlobalValueSummary::GVFlags Flags(V.getLinkage(), NonRenamableLocal, /* Live = */ false, V.isDSOLocal()); - auto GVarSummary = - llvm::make_unique<GlobalVarSummary>(Flags, RefEdges.takeVector()); + + // Don't mark variables we won't be able to internalize as read-only. + GlobalVarSummary::GVarFlags VarFlags( + !V.hasComdat() && !V.hasAppendingLinkage() && !V.isInterposable() && + !V.hasAvailableExternallyLinkage() && !V.hasDLLExportStorageClass()); + auto GVarSummary = llvm::make_unique<GlobalVarSummary>(Flags, VarFlags, + RefEdges.takeVector()); if (NonRenamableLocal) CantBePromoted.insert(V.getGUID()); + if (HasBlockAddress) + GVarSummary->setNotEligibleToImport(); Index.addGlobalValueSummary(V, std::move(GVarSummary)); } @@ -408,7 +457,11 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( std::function<BlockFrequencyInfo *(const Function &F)> GetBFICallback, ProfileSummaryInfo *PSI) { assert(PSI); - ModuleSummaryIndex Index(/*HaveGVs=*/true); + bool EnableSplitLTOUnit = false; + if (auto *MD = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag("EnableSplitLTOUnit"))) + EnableSplitLTOUnit = MD->getZExtValue(); + ModuleSummaryIndex Index(/*HaveGVs=*/true, EnableSplitLTOUnit); // Identify the local values in the llvm.used and llvm.compiler.used sets, // which should not be exported as they would then require renaming and @@ -460,13 +513,15 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( if (Function *F = dyn_cast<Function>(GV)) { std::unique_ptr<FunctionSummary> Summary = llvm::make_unique<FunctionSummary>( - GVFlags, 0, + GVFlags, /*InstCount=*/0, FunctionSummary::FFlags{ F->hasFnAttribute(Attribute::ReadNone), F->hasFnAttribute(Attribute::ReadOnly), F->hasFnAttribute(Attribute::NoRecurse), - F->returnDoesNotAlias()}, - ArrayRef<ValueInfo>{}, ArrayRef<FunctionSummary::EdgeTy>{}, + F->returnDoesNotAlias(), + /* NoInline = */ false}, + /*EntryCount=*/0, ArrayRef<ValueInfo>{}, + ArrayRef<FunctionSummary::EdgeTy>{}, ArrayRef<GlobalValue::GUID>{}, ArrayRef<FunctionSummary::VFuncId>{}, ArrayRef<FunctionSummary::VFuncId>{}, @@ -475,33 +530,40 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( Index.addGlobalValueSummary(*GV, std::move(Summary)); } else { std::unique_ptr<GlobalVarSummary> Summary = - llvm::make_unique<GlobalVarSummary>(GVFlags, - ArrayRef<ValueInfo>{}); + llvm::make_unique<GlobalVarSummary>( + GVFlags, GlobalVarSummary::GVarFlags(), + ArrayRef<ValueInfo>{}); Index.addGlobalValueSummary(*GV, std::move(Summary)); } }); } + bool IsThinLTO = true; + if (auto *MD = + mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("ThinLTO"))) + IsThinLTO = MD->getZExtValue(); + // Compute summaries for all functions defined in module, and save in the // index. 
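The loop that follows drives computeFunctionSummary for every defined function. Inside that computation, shown in the hunks above, non-volatile loads are deliberately processed after every other instruction, so the references they introduce occupy a contiguous suffix of the insertion-ordered reference set and a single index (RefCnt) identifies exactly that suffix; the ThinLTO module flag read just above then decides whether the suffix may be marked read-only at all, since a regular LTO module never imports variables as local copies. A toy model of that suffix marking with invented names; the real code operates on ValueInfo entries.

#include <cstddef>
#include <string>
#include <vector>

struct Ref {
  std::string Name;
  bool ReadOnly = false;
};

// Refs[0..RefCnt) were found while scanning ordinary instructions; everything
// from RefCnt onwards was discovered only through non-volatile loads and is
// therefore a constant reference that may be marked read-only.
static void markLoadOnlyRefsReadOnly(std::vector<Ref> &Refs, std::size_t RefCnt,
                                     bool IsThinLTO) {
  if (!IsThinLTO)
    return;            // Regular LTO: no importing, so no read-only promotion.
  for (std::size_t I = RefCnt; I < Refs.size(); ++I)
    Refs[I].ReadOnly = true;
}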
for (auto &F : M) { if (F.isDeclaration()) continue; + DominatorTree DT(const_cast<Function &>(F)); BlockFrequencyInfo *BFI = nullptr; std::unique_ptr<BlockFrequencyInfo> BFIPtr; if (GetBFICallback) BFI = GetBFICallback(F); else if (F.hasProfileData()) { - LoopInfo LI{DominatorTree(const_cast<Function &>(F))}; + LoopInfo LI{DT}; BranchProbabilityInfo BPI{F, LI}; BFIPtr = llvm::make_unique<BlockFrequencyInfo>(F, BPI, LI); BFI = BFIPtr.get(); } - computeFunctionSummary(Index, M, F, BFI, PSI, + computeFunctionSummary(Index, M, F, BFI, PSI, DT, !LocalsUsed.empty() || HasLocalInlineAsmSymbol, - CantBePromoted); + CantBePromoted, IsThinLTO); } // Compute summaries for all variables defined in module, and save in the @@ -532,11 +594,6 @@ ModuleSummaryIndex llvm::buildModuleSummaryIndex( setLiveRoot(Index, "llvm.global_dtors"); setLiveRoot(Index, "llvm.global.annotations"); - bool IsThinLTO = true; - if (auto *MD = - mdconst::extract_or_null<ConstantInt>(M.getModuleFlag("ThinLTO"))) - IsThinLTO = MD->getZExtValue(); - for (auto &GlobalList : Index) { // Ignore entries for references that are undefined in the current module. if (GlobalList.second.SummaryList.empty()) @@ -606,7 +663,7 @@ ModuleSummaryIndexWrapperPass::ModuleSummaryIndexWrapperPass() } bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) { - auto &PSI = *getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); Index.emplace(buildModuleSummaryIndex( M, [this](const Function &F) { @@ -614,7 +671,7 @@ bool ModuleSummaryIndexWrapperPass::runOnModule(Module &M) { *const_cast<Function *>(&F)) .getBFI()); }, - &PSI)); + PSI)); return false; } diff --git a/lib/Analysis/MustExecute.cpp b/lib/Analysis/MustExecute.cpp index 8e85366b4618..180c38ddacc2 100644 --- a/lib/Analysis/MustExecute.cpp +++ b/lib/Analysis/MustExecute.cpp @@ -22,20 +22,32 @@ #include "llvm/Support/raw_ostream.h" using namespace llvm; -/// Computes loop safety information, checks loop body & header -/// for the possibility of may throw exception. -/// -void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { +const DenseMap<BasicBlock *, ColorVector> & +LoopSafetyInfo::getBlockColors() const { + return BlockColors; +} + +void LoopSafetyInfo::copyColors(BasicBlock *New, BasicBlock *Old) { + ColorVector &ColorsForNewBlock = BlockColors[New]; + ColorVector &ColorsForOldBlock = BlockColors[Old]; + ColorsForNewBlock = ColorsForOldBlock; +} + +bool SimpleLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const { + (void)BB; + return anyBlockMayThrow(); +} + +bool SimpleLoopSafetyInfo::anyBlockMayThrow() const { + return MayThrow; +} + +void SimpleLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) { assert(CurLoop != nullptr && "CurLoop can't be null"); BasicBlock *Header = CurLoop->getHeader(); - // Setting default safety values. - SafetyInfo->MayThrow = false; - SafetyInfo->HeaderMayThrow = false; // Iterate over header and compute safety info. - SafetyInfo->HeaderMayThrow = - !isGuaranteedToTransferExecutionToSuccessor(Header); - - SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow; + HeaderMayThrow = !isGuaranteedToTransferExecutionToSuccessor(Header); + MayThrow = HeaderMayThrow; // Iterate over loop instructions and compute safety info. // Skip header as it has been computed and stored in HeaderMayThrow. // The first block in loopinfo.Blocks is guaranteed to be the header. 
@@ -43,23 +55,59 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { "First block must be header"); for (Loop::block_iterator BB = std::next(CurLoop->block_begin()), BBE = CurLoop->block_end(); - (BB != BBE) && !SafetyInfo->MayThrow; ++BB) - SafetyInfo->MayThrow |= - !isGuaranteedToTransferExecutionToSuccessor(*BB); + (BB != BBE) && !MayThrow; ++BB) + MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(*BB); + + computeBlockColors(CurLoop); +} +bool ICFLoopSafetyInfo::blockMayThrow(const BasicBlock *BB) const { + return ICF.hasICF(BB); +} + +bool ICFLoopSafetyInfo::anyBlockMayThrow() const { + return MayThrow; +} + +void ICFLoopSafetyInfo::computeLoopSafetyInfo(const Loop *CurLoop) { + assert(CurLoop != nullptr && "CurLoop can't be null"); + ICF.clear(); + MW.clear(); + MayThrow = false; + // Figure out the fact that at least one block may throw. + for (auto &BB : CurLoop->blocks()) + if (ICF.hasICF(&*BB)) { + MayThrow = true; + break; + } + computeBlockColors(CurLoop); +} + +void ICFLoopSafetyInfo::insertInstructionTo(const Instruction *Inst, + const BasicBlock *BB) { + ICF.insertInstructionTo(Inst, BB); + MW.insertInstructionTo(Inst, BB); +} + +void ICFLoopSafetyInfo::removeInstruction(const Instruction *Inst) { + ICF.removeInstruction(Inst); + MW.removeInstruction(Inst); +} + +void LoopSafetyInfo::computeBlockColors(const Loop *CurLoop) { // Compute funclet colors if we might sink/hoist in a function with a funclet // personality routine. Function *Fn = CurLoop->getHeader()->getParent(); if (Fn->hasPersonalityFn()) if (Constant *PersonalityFn = Fn->getPersonalityFn()) if (isScopedEHPersonality(classifyEHPersonality(PersonalityFn))) - SafetyInfo->BlockColors = colorEHFunclets(*Fn); + BlockColors = colorEHFunclets(*Fn); } /// Return true if we can prove that the given ExitBlock is not reached on the /// first iteration of the given loop. That is, the backedge of the loop must /// be executed before the ExitBlock is executed in any dynamic execution trace. -static bool CanProveNotTakenFirstIteration(BasicBlock *ExitBlock, +static bool CanProveNotTakenFirstIteration(const BasicBlock *ExitBlock, const DominatorTree *DT, const Loop *CurLoop) { auto *CondExitBlock = ExitBlock->getSinglePredecessor(); @@ -99,15 +147,94 @@ static bool CanProveNotTakenFirstIteration(BasicBlock *ExitBlock, return SimpleCst->isAllOnesValue(); } +/// Collect all blocks from \p CurLoop which lie on all possible paths from +/// the header of \p CurLoop (inclusive) to BB (exclusive) into the set +/// \p Predecessors. If \p BB is the header, \p Predecessors will be empty. +static void collectTransitivePredecessors( + const Loop *CurLoop, const BasicBlock *BB, + SmallPtrSetImpl<const BasicBlock *> &Predecessors) { + assert(Predecessors.empty() && "Garbage in predecessors set?"); + assert(CurLoop->contains(BB) && "Should only be called for loop blocks!"); + if (BB == CurLoop->getHeader()) + return; + SmallVector<const BasicBlock *, 4> WorkList; + for (auto *Pred : predecessors(BB)) { + Predecessors.insert(Pred); + WorkList.push_back(Pred); + } + while (!WorkList.empty()) { + auto *Pred = WorkList.pop_back_val(); + assert(CurLoop->contains(Pred) && "Should only reach loop blocks!"); + // We are not interested in backedges and we don't want to leave loop. + if (Pred == CurLoop->getHeader()) + continue; + // TODO: If BB lies in an inner loop of CurLoop, this will traverse over all + // blocks of this inner loop, even those that are always executed AFTER the + // BB. 
It may make our analysis more conservative than it could be, see test + // @nested and @nested_no_throw in test/Analysis/MustExecute/loop-header.ll. + // We can ignore backedge of all loops containing BB to get a sligtly more + // optimistic result. + for (auto *PredPred : predecessors(Pred)) + if (Predecessors.insert(PredPred).second) + WorkList.push_back(PredPred); + } +} + +bool LoopSafetyInfo::allLoopPathsLeadToBlock(const Loop *CurLoop, + const BasicBlock *BB, + const DominatorTree *DT) const { + assert(CurLoop->contains(BB) && "Should only be called for loop blocks!"); + + // Fast path: header is always reached once the loop is entered. + if (BB == CurLoop->getHeader()) + return true; + + // Collect all transitive predecessors of BB in the same loop. This set will + // be a subset of the blocks within the loop. + SmallPtrSet<const BasicBlock *, 4> Predecessors; + collectTransitivePredecessors(CurLoop, BB, Predecessors); + + // Make sure that all successors of all predecessors of BB are either: + // 1) BB, + // 2) Also predecessors of BB, + // 3) Exit blocks which are not taken on 1st iteration. + // Memoize blocks we've already checked. + SmallPtrSet<const BasicBlock *, 4> CheckedSuccessors; + for (auto *Pred : Predecessors) { + // Predecessor block may throw, so it has a side exit. + if (blockMayThrow(Pred)) + return false; + for (auto *Succ : successors(Pred)) + if (CheckedSuccessors.insert(Succ).second && + Succ != BB && !Predecessors.count(Succ)) + // By discharging conditions that are not executed on the 1st iteration, + // we guarantee that *at least* on the first iteration all paths from + // header that *may* execute will lead us to the block of interest. So + // that if we had virtually peeled one iteration away, in this peeled + // iteration the set of predecessors would contain only paths from + // header to BB without any exiting edges that may execute. + // + // TODO: We only do it for exiting edges currently. We could use the + // same function to skip some of the edges within the loop if we know + // that they will not be taken on the 1st iteration. + // + // TODO: If we somehow know the number of iterations in loop, the same + // check may be done for any arbitrary N-th iteration as long as N is + // not greater than minimum number of iterations in this loop. + if (CurLoop->contains(Succ) || + !CanProveNotTakenFirstIteration(Succ, DT, CurLoop)) + return false; + } + + // All predecessors can only lead us to BB. + return true; +} + /// Returns true if the instruction in a loop is guaranteed to execute at least /// once. -bool llvm::isGuaranteedToExecute(const Instruction &Inst, - const DominatorTree *DT, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo) { - // We have to check to make sure that the instruction dominates all - // of the exit blocks. If it doesn't, then there is a path out of the loop - // which does not execute this instruction, so we can't hoist it. - +bool SimpleLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst, + const DominatorTree *DT, + const Loop *CurLoop) const { // If the instruction is in the header block for the loop (which is very // common), it is always guaranteed to dominate the exit blocks. Since this // is a common case, and can save some work, check it now. @@ -116,52 +243,48 @@ bool llvm::isGuaranteedToExecute(const Instruction &Inst, // Inst unless we can prove that Inst comes before the potential implicit // exit. 
At the moment, we use a (cheap) hack for the common case where // the instruction of interest is the first one in the block. - return !SafetyInfo->HeaderMayThrow || - Inst.getParent()->getFirstNonPHIOrDbg() == &Inst; + return !HeaderMayThrow || + Inst.getParent()->getFirstNonPHIOrDbg() == &Inst; - // Somewhere in this loop there is an instruction which may throw and make us - // exit the loop. - if (SafetyInfo->MayThrow) - return false; + // If there is a path from header to exit or latch that doesn't lead to our + // instruction's block, return false. + return allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT); +} - // Note: There are two styles of reasoning intermixed below for - // implementation efficiency reasons. They are: - // 1) If we can prove that the instruction dominates all exit blocks, then we - // know the instruction must have executed on *some* iteration before we - // exit. We do not prove *which* iteration the instruction must execute on. - // 2) If we can prove that the instruction dominates the latch and all exits - // which might be taken on the first iteration, we know the instruction must - // execute on the first iteration. This second style allows a conditional - // exit before the instruction of interest which is provably not taken on the - // first iteration. This is a quite common case for range check like - // patterns. TODO: support loops with multiple latches. - - const bool InstDominatesLatch = - CurLoop->getLoopLatch() != nullptr && - DT->dominates(Inst.getParent(), CurLoop->getLoopLatch()); - - // Get the exit blocks for the current loop. - SmallVector<BasicBlock *, 8> ExitBlocks; - CurLoop->getExitBlocks(ExitBlocks); - - // Verify that the block dominates each of the exit blocks of the loop. - for (BasicBlock *ExitBlock : ExitBlocks) - if (!DT->dominates(Inst.getParent(), ExitBlock)) - if (!InstDominatesLatch || - !CanProveNotTakenFirstIteration(ExitBlock, DT, CurLoop)) - return false; - - // As a degenerate case, if the loop is statically infinite then we haven't - // proven anything since there are no exit blocks. - if (ExitBlocks.empty()) - return false; +bool ICFLoopSafetyInfo::isGuaranteedToExecute(const Instruction &Inst, + const DominatorTree *DT, + const Loop *CurLoop) const { + return !ICF.isDominatedByICFIFromSameBlock(&Inst) && + allLoopPathsLeadToBlock(CurLoop, Inst.getParent(), DT); +} + +bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const BasicBlock *BB, + const Loop *CurLoop) const { + assert(CurLoop->contains(BB) && "Should only be called for loop blocks!"); - // FIXME: In general, we have to prove that the loop isn't an infinite loop. - // See http::llvm.org/PR24078 . (The "ExitBlocks.empty()" check above is - // just a special case of this.) + // Fast path: there are no instructions before header. + if (BB == CurLoop->getHeader()) + return true; + + // Collect all transitive predecessors of BB in the same loop. This set will + // be a subset of the blocks within the loop. + SmallPtrSet<const BasicBlock *, 4> Predecessors; + collectTransitivePredecessors(CurLoop, BB, Predecessors); + // Find if there any instruction in either predecessor that could write + // to memory. 
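The loop that follows then just queries the per-block may-write summary (MW) for each collected predecessor. The interesting part is collectTransitivePredecessors, shown earlier in this file: gather every loop block that can lie on a path from the header to BB, stopping at the header so the walk neither follows the backedge nor leaves the loop. A stand-alone sketch of that collection on a simplified block type; the names below are invented for illustration.

#include <unordered_set>
#include <vector>

struct Block {
  std::vector<Block *> Preds;
  bool IsHeader = false;     // True for the loop header.
};

// Collect the transitive predecessors of BB within the loop.  The header's
// own predecessors are never expanded, so the walk stays inside the loop.
static void collectPreds(Block *BB, std::unordered_set<Block *> &Out) {
  std::vector<Block *> Worklist(BB->Preds.begin(), BB->Preds.end());
  Out.insert(BB->Preds.begin(), BB->Preds.end());
  while (!Worklist.empty()) {
    Block *Pred = Worklist.back();
    Worklist.pop_back();
    if (Pred->IsHeader)
      continue;              // Stop here: do not walk through the backedge.
    for (Block *PP : Pred->Preds)
      if (Out.insert(PP).second)
        Worklist.push_back(PP);
  }
}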
+ for (auto *Pred : Predecessors) + if (MW.mayWriteToMemory(Pred)) + return false; return true; } +bool ICFLoopSafetyInfo::doesNotWriteMemoryBefore(const Instruction &I, + const Loop *CurLoop) const { + auto *BB = I.getParent(); + assert(CurLoop->contains(BB) && "Should only be called for loop blocks!"); + return !MW.isDominatedByMemoryWriteFromSameBlock(&I) && + doesNotWriteMemoryBefore(BB, CurLoop); +} namespace { struct MustExecutePrinter : public FunctionPass { @@ -195,9 +318,9 @@ static bool isMustExecuteIn(const Instruction &I, Loop *L, DominatorTree *DT) { // TODO: merge these two routines. For the moment, we display the best // result obtained by *either* implementation. This is a bit unfair since no // caller actually gets the full power at the moment. - LoopSafetyInfo LSI; - computeLoopSafetyInfo(&LSI, L); - return isGuaranteedToExecute(I, DT, L, &LSI) || + SimpleLoopSafetyInfo LSI; + LSI.computeLoopSafetyInfo(L); + return LSI.isGuaranteedToExecute(I, DT, L) || isGuaranteedToExecuteForEveryIteration(&I, L); } diff --git a/lib/Analysis/ObjCARCAliasAnalysis.cpp b/lib/Analysis/ObjCARCAliasAnalysis.cpp index 096ea661ecb6..95ae1a6e744f 100644 --- a/lib/Analysis/ObjCARCAliasAnalysis.cpp +++ b/lib/Analysis/ObjCARCAliasAnalysis.cpp @@ -106,12 +106,12 @@ FunctionModRefBehavior ObjCARCAAResult::getModRefBehavior(const Function *F) { return AAResultBase::getModRefBehavior(F); } -ModRefInfo ObjCARCAAResult::getModRefInfo(ImmutableCallSite CS, +ModRefInfo ObjCARCAAResult::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { if (!EnableARCOpts) - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); - switch (GetBasicARCInstKind(CS.getInstruction())) { + switch (GetBasicARCInstKind(Call)) { case ARCInstKind::Retain: case ARCInstKind::RetainRV: case ARCInstKind::Autorelease: @@ -128,7 +128,7 @@ ModRefInfo ObjCARCAAResult::getModRefInfo(ImmutableCallSite CS, break; } - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); } ObjCARCAAResult ObjCARCAA::run(Function &F, FunctionAnalysisManager &AM) { diff --git a/lib/Analysis/ObjCARCInstKind.cpp b/lib/Analysis/ObjCARCInstKind.cpp index f268e2a9abdd..31c432711834 100644 --- a/lib/Analysis/ObjCARCInstKind.cpp +++ b/lib/Analysis/ObjCARCInstKind.cpp @@ -85,97 +85,73 @@ raw_ostream &llvm::objcarc::operator<<(raw_ostream &OS, } ARCInstKind llvm::objcarc::GetFunctionClass(const Function *F) { - Function::const_arg_iterator AI = F->arg_begin(), AE = F->arg_end(); - // No (mandatory) arguments. - if (AI == AE) - return StringSwitch<ARCInstKind>(F->getName()) - .Case("objc_autoreleasePoolPush", ARCInstKind::AutoreleasepoolPush) - .Case("clang.arc.use", ARCInstKind::IntrinsicUser) - .Default(ARCInstKind::CallOrUser); - - // One argument. - const Argument *A0 = &*AI++; - if (AI == AE) { - // Argument is a pointer. - PointerType *PTy = dyn_cast<PointerType>(A0->getType()); - if (!PTy) - return ARCInstKind::CallOrUser; - - Type *ETy = PTy->getElementType(); - // Argument is i8*. 
- if (ETy->isIntegerTy(8)) - return StringSwitch<ARCInstKind>(F->getName()) - .Case("objc_retain", ARCInstKind::Retain) - .Case("objc_retainAutoreleasedReturnValue", ARCInstKind::RetainRV) - .Case("objc_unsafeClaimAutoreleasedReturnValue", ARCInstKind::ClaimRV) - .Case("objc_retainBlock", ARCInstKind::RetainBlock) - .Case("objc_release", ARCInstKind::Release) - .Case("objc_autorelease", ARCInstKind::Autorelease) - .Case("objc_autoreleaseReturnValue", ARCInstKind::AutoreleaseRV) - .Case("objc_autoreleasePoolPop", ARCInstKind::AutoreleasepoolPop) - .Case("objc_retainedObject", ARCInstKind::NoopCast) - .Case("objc_unretainedObject", ARCInstKind::NoopCast) - .Case("objc_unretainedPointer", ARCInstKind::NoopCast) - .Case("objc_retain_autorelease", ARCInstKind::FusedRetainAutorelease) - .Case("objc_retainAutorelease", ARCInstKind::FusedRetainAutorelease) - .Case("objc_retainAutoreleaseReturnValue", - ARCInstKind::FusedRetainAutoreleaseRV) - .Case("objc_sync_enter", ARCInstKind::User) - .Case("objc_sync_exit", ARCInstKind::User) - .Default(ARCInstKind::CallOrUser); - - // Argument is i8** - if (PointerType *Pte = dyn_cast<PointerType>(ETy)) - if (Pte->getElementType()->isIntegerTy(8)) - return StringSwitch<ARCInstKind>(F->getName()) - .Case("objc_loadWeakRetained", ARCInstKind::LoadWeakRetained) - .Case("objc_loadWeak", ARCInstKind::LoadWeak) - .Case("objc_destroyWeak", ARCInstKind::DestroyWeak) - .Default(ARCInstKind::CallOrUser); - - // Anything else with one argument. + Intrinsic::ID ID = F->getIntrinsicID(); + switch (ID) { + default: return ARCInstKind::CallOrUser; + case Intrinsic::objc_autorelease: + return ARCInstKind::Autorelease; + case Intrinsic::objc_autoreleasePoolPop: + return ARCInstKind::AutoreleasepoolPop; + case Intrinsic::objc_autoreleasePoolPush: + return ARCInstKind::AutoreleasepoolPush; + case Intrinsic::objc_autoreleaseReturnValue: + return ARCInstKind::AutoreleaseRV; + case Intrinsic::objc_copyWeak: + return ARCInstKind::CopyWeak; + case Intrinsic::objc_destroyWeak: + return ARCInstKind::DestroyWeak; + case Intrinsic::objc_initWeak: + return ARCInstKind::InitWeak; + case Intrinsic::objc_loadWeak: + return ARCInstKind::LoadWeak; + case Intrinsic::objc_loadWeakRetained: + return ARCInstKind::LoadWeakRetained; + case Intrinsic::objc_moveWeak: + return ARCInstKind::MoveWeak; + case Intrinsic::objc_release: + return ARCInstKind::Release; + case Intrinsic::objc_retain: + return ARCInstKind::Retain; + case Intrinsic::objc_retainAutorelease: + return ARCInstKind::FusedRetainAutorelease; + case Intrinsic::objc_retainAutoreleaseReturnValue: + return ARCInstKind::FusedRetainAutoreleaseRV; + case Intrinsic::objc_retainAutoreleasedReturnValue: + return ARCInstKind::RetainRV; + case Intrinsic::objc_retainBlock: + return ARCInstKind::RetainBlock; + case Intrinsic::objc_storeStrong: + return ARCInstKind::StoreStrong; + case Intrinsic::objc_storeWeak: + return ARCInstKind::StoreWeak; + case Intrinsic::objc_clang_arc_use: + return ARCInstKind::IntrinsicUser; + case Intrinsic::objc_unsafeClaimAutoreleasedReturnValue: + return ARCInstKind::ClaimRV; + case Intrinsic::objc_retainedObject: + return ARCInstKind::NoopCast; + case Intrinsic::objc_unretainedObject: + return ARCInstKind::NoopCast; + case Intrinsic::objc_unretainedPointer: + return ARCInstKind::NoopCast; + case Intrinsic::objc_retain_autorelease: + return ARCInstKind::FusedRetainAutorelease; + case Intrinsic::objc_sync_enter: + return ARCInstKind::User; + case Intrinsic::objc_sync_exit: + return ARCInstKind::User; + case 
Intrinsic::objc_arc_annotation_topdown_bbstart: + case Intrinsic::objc_arc_annotation_topdown_bbend: + case Intrinsic::objc_arc_annotation_bottomup_bbstart: + case Intrinsic::objc_arc_annotation_bottomup_bbend: + // Ignore annotation calls. This is important to stop the + // optimizer from treating annotations as uses which would + // make the state of the pointers they are attempting to + // elucidate to be incorrect. + return ARCInstKind::None; } - - // Two arguments, first is i8**. - const Argument *A1 = &*AI++; - if (AI == AE) - if (PointerType *PTy = dyn_cast<PointerType>(A0->getType())) - if (PointerType *Pte = dyn_cast<PointerType>(PTy->getElementType())) - if (Pte->getElementType()->isIntegerTy(8)) - if (PointerType *PTy1 = dyn_cast<PointerType>(A1->getType())) { - Type *ETy1 = PTy1->getElementType(); - // Second argument is i8* - if (ETy1->isIntegerTy(8)) - return StringSwitch<ARCInstKind>(F->getName()) - .Case("objc_storeWeak", ARCInstKind::StoreWeak) - .Case("objc_initWeak", ARCInstKind::InitWeak) - .Case("objc_storeStrong", ARCInstKind::StoreStrong) - .Default(ARCInstKind::CallOrUser); - // Second argument is i8**. - if (PointerType *Pte1 = dyn_cast<PointerType>(ETy1)) - if (Pte1->getElementType()->isIntegerTy(8)) - return StringSwitch<ARCInstKind>(F->getName()) - .Case("objc_moveWeak", ARCInstKind::MoveWeak) - .Case("objc_copyWeak", ARCInstKind::CopyWeak) - // Ignore annotation calls. This is important to stop the - // optimizer from treating annotations as uses which would - // make the state of the pointers they are attempting to - // elucidate to be incorrect. - .Case("llvm.arc.annotation.topdown.bbstart", - ARCInstKind::None) - .Case("llvm.arc.annotation.topdown.bbend", - ARCInstKind::None) - .Case("llvm.arc.annotation.bottomup.bbstart", - ARCInstKind::None) - .Case("llvm.arc.annotation.bottomup.bbend", - ARCInstKind::None) - .Default(ARCInstKind::CallOrUser); - } - - // Anything else. - return ARCInstKind::CallOrUser; } // A whitelist of intrinsics that we know do not use objc pointers or decrement diff --git a/lib/Analysis/OrderedBasicBlock.cpp b/lib/Analysis/OrderedBasicBlock.cpp index 6c47651eae9e..5f4fe0f7dda2 100644 --- a/lib/Analysis/OrderedBasicBlock.cpp +++ b/lib/Analysis/OrderedBasicBlock.cpp @@ -37,6 +37,8 @@ bool OrderedBasicBlock::comesBefore(const Instruction *A, const Instruction *Inst = nullptr; assert(!(LastInstFound == BB->end() && NextInstPos != 0) && "Instruction supposed to be in NumberedInsts"); + assert(A->getParent() == BB && "Instruction supposed to be in the block!"); + assert(B->getParent() == BB && "Instruction supposed to be in the block!"); // Start the search with the instruction found in the last lookup round. auto II = BB->begin(); @@ -65,6 +67,7 @@ bool OrderedBasicBlock::comesBefore(const Instruction *A, bool OrderedBasicBlock::dominates(const Instruction *A, const Instruction *B) { assert(A->getParent() == B->getParent() && "Instructions must be in the same basic block!"); + assert(A->getParent() == BB && "Instructions must be in the tracked block!"); // First we lookup the instructions. If they don't exist, lookup will give us // back ::end(). If they both exist, we compare the numbers. 
Otherwise, if NA diff --git a/lib/Analysis/OrderedInstructions.cpp b/lib/Analysis/OrderedInstructions.cpp new file mode 100644 index 000000000000..7b155208c02e --- /dev/null +++ b/lib/Analysis/OrderedInstructions.cpp @@ -0,0 +1,51 @@ +//===-- OrderedInstructions.cpp - Instruction dominance function ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file defines utility to check dominance relation of 2 instructions. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/OrderedInstructions.h" +using namespace llvm; + +bool OrderedInstructions::localDominates(const Instruction *InstA, + const Instruction *InstB) const { + assert(InstA->getParent() == InstB->getParent() && + "Instructions must be in the same basic block"); + + const BasicBlock *IBB = InstA->getParent(); + auto OBB = OBBMap.find(IBB); + if (OBB == OBBMap.end()) + OBB = OBBMap.insert({IBB, make_unique<OrderedBasicBlock>(IBB)}).first; + return OBB->second->dominates(InstA, InstB); +} + +/// Given 2 instructions, use OrderedBasicBlock to check for dominance relation +/// if the instructions are in the same basic block, Otherwise, use dominator +/// tree. +bool OrderedInstructions::dominates(const Instruction *InstA, + const Instruction *InstB) const { + // Use ordered basic block to do dominance check in case the 2 instructions + // are in the same basic block. + if (InstA->getParent() == InstB->getParent()) + return localDominates(InstA, InstB); + return DT->dominates(InstA->getParent(), InstB->getParent()); +} + +bool OrderedInstructions::dfsBefore(const Instruction *InstA, + const Instruction *InstB) const { + // Use ordered basic block in case the 2 instructions are in the same basic + // block. + if (InstA->getParent() == InstB->getParent()) + return localDominates(InstA, InstB); + + DomTreeNode *DA = DT->getNode(InstA->getParent()); + DomTreeNode *DB = DT->getNode(InstB->getParent()); + return DA->getDFSNumIn() < DB->getDFSNumIn(); +} diff --git a/lib/Analysis/PhiValues.cpp b/lib/Analysis/PhiValues.cpp index ef121815d2cf..729227c86697 100644 --- a/lib/Analysis/PhiValues.cpp +++ b/lib/Analysis/PhiValues.cpp @@ -14,6 +14,16 @@ using namespace llvm; +void PhiValues::PhiValuesCallbackVH::deleted() { + PV->invalidateValue(getValPtr()); +} + +void PhiValues::PhiValuesCallbackVH::allUsesReplacedWith(Value *) { + // We could potentially update the cached values we have with the new value, + // but it's simpler to just treat the old value as invalidated. + PV->invalidateValue(getValPtr()); +} + bool PhiValues::invalidate(Function &, const PreservedAnalyses &PA, FunctionAnalysisManager::Invalidator &) { // PhiValues is invalidated if it isn't preserved. @@ -46,6 +56,7 @@ void PhiValues::processPhi(const PHINode *Phi, DepthMap[Phi] = DepthNumber; // Recursively process the incoming phis of this phi. + TrackedValues.insert(PhiValuesCallbackVH(const_cast<PHINode *>(Phi), this)); for (Value *PhiOp : Phi->incoming_values()) { if (PHINode *PhiPhiOp = dyn_cast<PHINode>(PhiOp)) { // Recurse if the phi has not yet been visited. @@ -56,6 +67,8 @@ void PhiValues::processPhi(const PHINode *Phi, // phi are part of the same component, so adjust the depth number. 
if (!ReachableMap.count(DepthMap[PhiPhiOp])) DepthMap[Phi] = std::min(DepthMap[Phi], DepthMap[PhiPhiOp]); + } else { + TrackedValues.insert(PhiValuesCallbackVH(PhiOp, this)); } } @@ -122,6 +135,10 @@ void PhiValues::invalidateValue(const Value *V) { NonPhiReachableMap.erase(N); ReachableMap.erase(N); } + // This value is no longer tracked + auto It = TrackedValues.find_as(V); + if (It != TrackedValues.end()) + TrackedValues.erase(It); } void PhiValues::releaseMemory() { diff --git a/lib/Analysis/ProfileSummaryInfo.cpp b/lib/Analysis/ProfileSummaryInfo.cpp index fb591f5d6a69..1d70c75f2e1c 100644 --- a/lib/Analysis/ProfileSummaryInfo.cpp +++ b/lib/Analysis/ProfileSummaryInfo.cpp @@ -39,11 +39,6 @@ static cl::opt<int> ProfileSummaryCutoffCold( cl::desc("A count is cold if it is below the minimum count" " to reach this percentile of total counts.")); -static cl::opt<bool> ProfileSampleAccurate( - "profile-sample-accurate", cl::Hidden, cl::init(false), - cl::desc("If the sample profile is accurate, we will mark all un-sampled " - "callsite as cold. Otherwise, treat un-sampled callsites as if " - "we have no profile.")); static cl::opt<unsigned> ProfileSummaryHugeWorkingSetSizeThreshold( "profile-summary-huge-working-set-size-threshold", cl::Hidden, cl::init(15000), cl::ZeroOrMore, @@ -51,6 +46,18 @@ static cl::opt<unsigned> ProfileSummaryHugeWorkingSetSizeThreshold( " blocks required to reach the -profile-summary-cutoff-hot" " percentile exceeds this count.")); +// The next two options override the counts derived from summary computation and +// are useful for debugging purposes. +static cl::opt<int> ProfileSummaryHotCount( + "profile-summary-hot-count", cl::ReallyHidden, cl::ZeroOrMore, + cl::desc("A fixed hot count that overrides the count derived from" + " profile-summary-cutoff-hot")); + +static cl::opt<int> ProfileSummaryColdCount( + "profile-summary-cold-count", cl::ReallyHidden, cl::ZeroOrMore, + cl::desc("A fixed cold count that overrides the count derived from" + " profile-summary-cutoff-cold")); + // Find the summary entry for a desired percentile of counts. 
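getEntryForPercentile, defined next, is what turns a cutoff such as profile-summary-cutoff-hot into a concrete count threshold: the detailed summary is a list of (cutoff, min-count) entries, the threshold is the MinCount of the entry that covers the requested percentile, and the new profile-summary-hot-count / profile-summary-cold-count options above simply override whatever was derived. A toy model of the hot-threshold derivation; Entry and all names below are invented, and the real code works on ProfileSummaryEntry values.

#include <cstdint>
#include <optional>
#include <vector>

struct Entry {
  uint64_t Cutoff;    // Scaled percentile, e.g. 990000 for 99%.
  uint64_t MinCount;  // Smallest count needed to reach that percentile.
};

// Derive the hot-count threshold from a summary sorted by increasing cutoff;
// an explicit override, when present, wins unconditionally.
static uint64_t hotThreshold(const std::vector<Entry> &Summary,
                             uint64_t CutoffHot,
                             std::optional<uint64_t> Override) {
  if (Override)
    return *Override;            // e.g. -profile-summary-hot-count.
  for (const Entry &E : Summary)
    if (E.Cutoff >= CutoffHot)
      return E.MinCount;         // First entry covering the requested cutoff.
  return UINT64_MAX;             // Degenerate summary: treat nothing as hot.
}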
static const ProfileSummaryEntry &getEntryForPercentile(SummaryEntryVector &DS, uint64_t Percentile) { @@ -139,7 +146,7 @@ bool ProfileSummaryInfo::isFunctionHotInCallGraph(const Function *F, return true; } for (const auto &BB : *F) - if (isHotBB(&BB, &BFI)) + if (isHotBlock(&BB, &BFI)) return true; return false; } @@ -168,7 +175,7 @@ bool ProfileSummaryInfo::isFunctionColdInCallGraph(const Function *F, return false; } for (const auto &BB : *F) - if (!isColdBB(&BB, &BFI)) + if (!isColdBlock(&BB, &BFI)) return false; return true; } @@ -198,9 +205,15 @@ void ProfileSummaryInfo::computeThresholds() { auto &HotEntry = getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffHot); HotCountThreshold = HotEntry.MinCount; + if (ProfileSummaryHotCount.getNumOccurrences() > 0) + HotCountThreshold = ProfileSummaryHotCount; auto &ColdEntry = getEntryForPercentile(DetailedSummary, ProfileSummaryCutoffCold); ColdCountThreshold = ColdEntry.MinCount; + if (ProfileSummaryColdCount.getNumOccurrences() > 0) + ColdCountThreshold = ProfileSummaryColdCount; + assert(ColdCountThreshold <= HotCountThreshold && + "Cold count threshold cannot exceed hot count threshold!"); HasHugeWorkingSetSize = HotEntry.NumCounts > ProfileSummaryHugeWorkingSetSizeThreshold; } @@ -226,23 +239,23 @@ bool ProfileSummaryInfo::isColdCount(uint64_t C) { uint64_t ProfileSummaryInfo::getOrCompHotCountThreshold() { if (!HotCountThreshold) computeThresholds(); - return HotCountThreshold && HotCountThreshold.getValue(); + return HotCountThreshold ? HotCountThreshold.getValue() : UINT64_MAX; } uint64_t ProfileSummaryInfo::getOrCompColdCountThreshold() { if (!ColdCountThreshold) computeThresholds(); - return ColdCountThreshold && ColdCountThreshold.getValue(); + return ColdCountThreshold ? ColdCountThreshold.getValue() : 0; } -bool ProfileSummaryInfo::isHotBB(const BasicBlock *B, BlockFrequencyInfo *BFI) { - auto Count = BFI->getBlockProfileCount(B); +bool ProfileSummaryInfo::isHotBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) { + auto Count = BFI->getBlockProfileCount(BB); return Count && isHotCount(*Count); } -bool ProfileSummaryInfo::isColdBB(const BasicBlock *B, +bool ProfileSummaryInfo::isColdBlock(const BasicBlock *BB, BlockFrequencyInfo *BFI) { - auto Count = BFI->getBlockProfileCount(B); + auto Count = BFI->getBlockProfileCount(BB); return Count && isColdCount(*Count); } @@ -260,11 +273,7 @@ bool ProfileSummaryInfo::isColdCallSite(const CallSite &CS, // In SamplePGO, if the caller has been sampled, and there is no profile // annotated on the callsite, we consider the callsite as cold. - // If there is no profile for the caller, and we know the profile is - // accurate, we consider the callsite as cold. 
- return (hasSampleProfile() && - (CS.getCaller()->hasProfileData() || ProfileSampleAccurate || - CS.getCaller()->hasFnAttribute("profile-sample-accurate"))); + return hasSampleProfile() && CS.getCaller()->hasProfileData(); } INITIALIZE_PASS(ProfileSummaryInfoWrapperPass, "profile-summary-info", diff --git a/lib/Analysis/RegionPass.cpp b/lib/Analysis/RegionPass.cpp index ed17df2e7e93..a101ff109199 100644 --- a/lib/Analysis/RegionPass.cpp +++ b/lib/Analysis/RegionPass.cpp @@ -15,6 +15,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/RegionPass.h" #include "llvm/IR/OptBisect.h" +#include "llvm/IR/PassTimingInfo.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Timer.h" #include "llvm/Support/raw_ostream.h" diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 0e715b8814ff..e5134f2eeda9 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -112,6 +112,7 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -162,6 +163,11 @@ static cl::opt<bool> cl::desc("Verify no dangling value in ScalarEvolution's " "ExprValueMap (slow)")); +static cl::opt<bool> VerifyIR( + "scev-verify-ir", cl::Hidden, + cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), + cl::init(false)); + static cl::opt<unsigned> MulOpsInlineThreshold( "scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), @@ -204,7 +210,7 @@ static cl::opt<unsigned> static cl::opt<unsigned> MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), - cl::init(16)); + cl::init(8)); //===----------------------------------------------------------------------===// // SCEV class definitions @@ -692,10 +698,6 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; - // Compare NoWrap flags. - if (LA->getNoWrapFlags() != RA->getNoWrapFlags()) - return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags(); - // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, @@ -720,10 +722,6 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; - // Compare NoWrap flags. 
- if (LC->getNoWrapFlags() != RC->getNoWrapFlags()) - return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags(); - for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(i), RC->getOperand(i), DT, @@ -2767,6 +2765,29 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, } const SCEV * +ScalarEvolution::getOrCreateAddRecExpr(SmallVectorImpl<const SCEV *> &Ops, + const Loop *L, SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + ID.AddPointer(L); + void *IP = nullptr; + SCEVAddRecExpr *S = + static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + S = new (SCEVAllocator) + SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + } + S->setNoWrapFlags(Flags); + return S; +} + +const SCEV * ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; @@ -3045,7 +3066,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV*, 7> AddRecOps; for (int x = 0, xe = AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getZero(Ty); + SmallVector <const SCEV *, 7> SumOps; for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), @@ -3060,12 +3081,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEV *CoeffTerm = getConstant(Ty, Coeff); const SCEV *Term1 = AddRec->getOperand(y-z); const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2, - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2, + SCEV::FlagAnyWrap, Depth + 1)); } } - AddRecOps.push_back(Term); + if (SumOps.empty()) + SumOps.push_back(getZero(Ty)); + AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); } if (!Overflow) { const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), @@ -3416,24 +3438,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, // Okay, it looks like we really DO need an addrec expr. Check to see if we // already have one, otherwise create a new one. 
- FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - for (unsigned i = 0, e = Operands.size(); i != e; ++i) - ID.AddPointer(Operands[i]); - ID.AddPointer(L); - void *IP = nullptr; - SCEVAddRecExpr *S = - static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size()); - std::uninitialized_copy(Operands.begin(), Operands.end(), O); - S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), - O, Operands.size(), L); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - } - S->setNoWrapFlags(Flags); - return S; + return getOrCreateAddRecExpr(Operands, L, Flags); } const SCEV * @@ -7080,7 +7085,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, return getCouldNotCompute(); bool IsOnlyExit = (L->getExitingBlock() != nullptr); - TerminatorInst *Term = ExitingBlock->getTerminator(); + Instruction *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); @@ -8344,69 +8349,273 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } -/// Find the roots of the quadratic equation for the given quadratic chrec -/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or -/// two SCEVCouldNotCompute objects. -static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> -SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { +/// For a given quadratic addrec, generate coefficients of the corresponding +/// quadratic equation, multiplied by a common value to ensure that they are +/// integers. +/// The returned value is a tuple { A, B, C, M, BitWidth }, where +/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C +/// were multiplied by, and BitWidth is the bit width of the original addrec +/// coefficients. +/// This function returns None if the addrec coefficients are not compile- +/// time constants. +static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>> +GetQuadraticEquation(const SCEVAddRecExpr *AddRec) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); + LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: " + << *AddRec << '\n'); // We currently can only solve this if the coefficients are constants. 
- if (!LC || !MC || !NC) + if (!LC || !MC || !NC) { + LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n"); return None; + } - uint32_t BitWidth = LC->getAPInt().getBitWidth(); - const APInt &L = LC->getAPInt(); - const APInt &M = MC->getAPInt(); - const APInt &N = NC->getAPInt(); - APInt Two(BitWidth, 2); - - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + APInt L = LC->getAPInt(); + APInt M = MC->getAPInt(); + APInt N = NC->getAPInt(); + assert(!N.isNullValue() && "This is not a quadratic addrec"); + + unsigned BitWidth = LC->getAPInt().getBitWidth(); + unsigned NewWidth = BitWidth + 1; + LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: " + << BitWidth << '\n'); + // The sign-extension (as opposed to a zero-extension) here matches the + // extension used in SolveQuadraticEquationWrap (with the same motivation). + N = N.sext(NewWidth); + M = M.sext(NewWidth); + L = L.sext(NewWidth); + + // The increments are M, M+N, M+2N, ..., so the accumulated values are + // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is, + // L+M, L+2M+N, L+3M+3N, ... + // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N. + // + // The equation Acc = 0 is then + // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0. + // In a quadratic form it becomes: + // N n^2 + (2M-N) n + 2L = 0. + + APInt A = N; + APInt B = 2 * M - A; + APInt C = 2 * L; + APInt T = APInt(NewWidth, 2); + LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B + << "x + " << C << ", coeff bw: " << NewWidth + << ", multiplied by " << T << '\n'); + return std::make_tuple(A, B, C, T, BitWidth); +} + +/// Helper function to compare optional APInts: +/// (a) if X and Y both exist, return min(X, Y), +/// (b) if neither X nor Y exist, return None, +/// (c) if exactly one of X and Y exists, return that value. +static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) { + if (X.hasValue() && Y.hasValue()) { + unsigned W = std::max(X->getBitWidth(), Y->getBitWidth()); + APInt XW = X->sextOrSelf(W); + APInt YW = Y->sextOrSelf(W); + return XW.slt(YW) ? *X : *Y; + } + if (!X.hasValue() && !Y.hasValue()) + return None; + return X.hasValue() ? *X : *Y; +} - // The A coefficient is N/2 - APInt A = N.sdiv(Two); +/// Helper function to truncate an optional APInt to a given BitWidth. +/// When solving addrec-related equations, it is preferable to return a value +/// that has the same bit width as the original addrec's coefficients. If the +/// solution fits in the original bit width, truncate it (except for i1). +/// Returning a value of a different bit width may inhibit some optimizations. +/// +/// In general, a solution to a quadratic equation generated from an addrec +/// may require BW+1 bits, where BW is the bit width of the addrec's +/// coefficients. The reason is that the coefficients of the quadratic +/// equation are BW+1 bits wide (to avoid truncation when converting from +/// the addrec to the equation). +static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) { + if (!X.hasValue()) + return None; + unsigned W = X->getBitWidth(); + if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth)) + return X->trunc(BitWidth); + return X; +} + +/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n +/// iterations. The values L, M, N are assumed to be signed, and they +/// should all have the same bit widths. 
+/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW, +/// where BW is the bit width of the addrec's coefficients. +/// If the calculated value is a BW-bit integer (for BW > 1), it will be +/// returned as such, otherwise the bit width of the returned value may +/// be greater than BW. +/// +/// This function returns None if +/// (a) the addrec coefficients are not constant, or +/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases +/// like x^2 = 5, no integer solutions exist, in other cases an integer +/// solution may exist, but SolveQuadraticEquationWrap may fail to find it. +static Optional<APInt> +SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { + APInt A, B, C, M; + unsigned BitWidth; + auto T = GetQuadraticEquation(AddRec); + if (!T.hasValue()) + return None; - // The B coefficient is M-N/2 - APInt B = M; - B -= A; // A is the same as N/2. + std::tie(A, B, C, M, BitWidth) = *T; + LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n"); + Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1); + if (!X.hasValue()) + return None; - // The C coefficient is L. - const APInt& C = L; + ConstantInt *CX = ConstantInt::get(SE.getContext(), *X); + ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE); + if (!V->isZero()) + return None; - // Compute the B^2-4ac term. - APInt SqrtTerm = B; - SqrtTerm *= B; - SqrtTerm -= 4 * (A * C); + return TruncIfPossible(X, BitWidth); +} - if (SqrtTerm.isNegative()) { - // The loop is provably infinite. +/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n +/// iterations. The values M, N are assumed to be signed, and they +/// should all have the same bit widths. +/// Find the least n such that c(n) does not belong to the given range, +/// while c(n-1) does. +/// +/// This function returns None if +/// (a) the addrec coefficients are not constant, or +/// (b) SolveQuadraticEquationWrap was unable to find a solution for the +/// bounds of the range. +static Optional<APInt> +SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, + const ConstantRange &Range, ScalarEvolution &SE) { + assert(AddRec->getOperand(0)->isZero() && + "Starting value of addrec should be 0"); + LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range " + << Range << ", addrec " << *AddRec << '\n'); + // This case is handled in getNumIterationsInRange. Here we can assume that + // we start in the range. + assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) && + "Addrec's initial value should be in range"); + + APInt A, B, C, M; + unsigned BitWidth; + auto T = GetQuadraticEquation(AddRec); + if (!T.hasValue()) return None; - } - // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest - // integer value or else APInt::sqrt() will assert. - APInt SqrtVal = SqrtTerm.sqrt(); + // Be careful about the return value: there can be two reasons for not + // returning an actual number. First, if no solutions to the equations + // were found, and second, if the solutions don't leave the given range. + // The first case means that the actual solution is "unknown", the second + // means that it's known, but not valid. If the solution is unknown, we + // cannot make any conclusions. + // Return a pair: the optional solution and a flag indicating if the + // solution was found. 
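
// [Editor's illustration -- not part of the patch.] A concrete instance of the
// coefficient derivation in GetQuadraticEquation and of the exact-zero solve:
// for the addrec {-15,+,3,+,2} (L = -15, M = 3, N = 2) the accumulated value
// after n iterations is c(n) = L + n*M + n*(n-1)/2 * N = n^2 + 2n - 15, and the
// integer equation built above is A*n^2 + B*n + C = 2n^2 + 4n - 30 with
// A = N, B = 2M - N, C = 2L (everything scaled by the multiplier 2). Its
// non-negative root is n = 3, and c(3) == 0, which is exactly the re-check that
// SolveQuadraticAddRecExact performs via EvaluateConstantChrecAtConstant.
// A plain-integer sketch (no APInt, no wrapping), assuming the coefficients
// fit in 64 bits; the patch uses APIntOps::SolveQuadraticEquationWrap instead
// of the brute-force loop below:
#include <cassert>
#include <cstdint>

static int64_t chrecAt(int64_t L, int64_t M, int64_t N, int64_t n) {
  // Value of the chrec {L,+,M,+,N} after n iterations.
  return L + n * M + n * (n - 1) / 2 * N;
}

int main() {
  const int64_t L = -15, M = 3, N = 2;
  // Coefficients of the scaled quadratic, as derived in GetQuadraticEquation.
  const int64_t A = N, B = 2 * M - N, C = 2 * L;
  int64_t n = 0;
  while (A * n * n + B * n + C != 0)
    ++n;
  assert(n == 3);
  assert(chrecAt(L, M, N, n) == 0); // the "exact zero" verification
}
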
+ auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> { + // Solve for signed overflow and unsigned overflow, pick the lower + // solution. + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary " + << Bound << " (before multiplying by " << M << ")\n"); + Bound *= M; // The quadratic equation multiplier. + + Optional<APInt> SO = None; + if (BitWidth > 1) { + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for " + "signed overflow\n"); + SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth); + } + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for " + "unsigned overflow\n"); + Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, + BitWidth+1); + + auto LeavesRange = [&] (const APInt &X) { + ConstantInt *C0 = ConstantInt::get(SE.getContext(), X); + ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE); + if (Range.contains(V0->getValue())) + return false; + // X should be at least 1, so X-1 is non-negative. + ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1); + ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE); + if (Range.contains(V1->getValue())) + return true; + return false; + }; - // Compute the two solutions for the quadratic formula. - // The divisions must be performed as signed divisions. - APInt NegB = -std::move(B); - APInt TwoA = std::move(A); - TwoA <<= 1; - if (TwoA.isNullValue()) - return None; + // If SolveQuadraticEquationWrap returns None, it means that there can + // be a solution, but the function failed to find it. We cannot treat it + // as "no solution". + if (!SO.hasValue() || !UO.hasValue()) + return { None, false }; + + // Check the smaller value first to see if it leaves the range. + // At this point, both SO and UO must have values. + Optional<APInt> Min = MinOptional(SO, UO); + if (LeavesRange(*Min)) + return { Min, true }; + Optional<APInt> Max = Min == SO ? UO : SO; + if (LeavesRange(*Max)) + return { Max, true }; + + // Solutions were found, but were eliminated, hence the "true". + return { None, true }; + }; - LLVMContext &Context = SE.getContext(); + std::tie(A, B, C, M, BitWidth) = *T; + // Lower bound is inclusive, subtract 1 to represent the exiting value. + APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1; + APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth()); + auto SL = SolveForBoundary(Lower); + auto SU = SolveForBoundary(Upper); + // If any of the solutions was unknown, no meaninigful conclusions can + // be made. + if (!SL.second || !SU.second) + return None; - ConstantInt *Solution1 = - ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = - ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + // Claim: The correct solution is not some value between Min and Max. + // + // Justification: Assuming that Min and Max are different values, one of + // them is when the first signed overflow happens, the other is when the + // first unsigned overflow happens. Crossing the range boundary is only + // possible via an overflow (treating 0 as a special case of it, modeling + // an overflow as crossing k*2^W for some k). + // + // The interesting case here is when Min was eliminated as an invalid + // solution, but Max was not. The argument is that if there was another + // overflow between Min and Max, it would also have been eliminated if + // it was considered. 
+ // + // For a given boundary, it is possible to have two overflows of the same + // type (signed/unsigned) without having the other type in between: this + // can happen when the vertex of the parabola is between the iterations + // corresponding to the overflows. This is only possible when the two + // overflows cross k*2^W for the same k. In such case, if the second one + // left the range (and was the first one to do so), the first overflow + // would have to enter the range, which would mean that either we had left + // the range before or that we started outside of it. Both of these cases + // are contradictions. + // + // Claim: In the case where SolveForBoundary returns None, the correct + // solution is not some value between the Max for this boundary and the + // Min of the other boundary. + // + // Justification: Assume that we had such Max_A and Min_B corresponding + // to range boundaries A and B and such that Max_A < Min_B. If there was + // a solution between Max_A and Min_B, it would have to be caused by an + // overflow corresponding to either A or B. It cannot correspond to B, + // since Min_B is the first occurrence of such an overflow. If it + // corresponded to A, it would have to be either a signed or an unsigned + // overflow that is larger than both eliminated overflows for A. But + // between the eliminated overflows and this overflow, the values would + // cover the entire value space, thus crossing the other boundary, which + // is a contradiction. - return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), - cast<SCEVConstant>(SE.getConstant(Solution2))); + return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth); } ScalarEvolution::ExitLimit @@ -8441,23 +8650,12 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( - CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - - // We can only use this value if the chrec ends up with an exact zero - // value at this index. When solving for "X*X != 5", for example, we - // should not accept a root of 2. - const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); - if (Val->isZero()) - // We found a quadratic root! - return ExitLimit(R1, R1, false, Predicates); - } + // We can only use this value if the chrec ends up with an exact zero + // value at this index. When solving for "X*X != 5", for example, we + // should not accept a root of 2. + if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) { + const auto *R = cast<SCEVConstant>(getConstant(S.getValue())); + return ExitLimit(R, R, false, Predicates); } return getCouldNotCompute(); } @@ -8617,7 +8815,13 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth) { bool Changed = false; - + // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or + // '0 != 0'. + auto TrivialCase = [&](bool TriviallyTrue) { + LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); + Pred = TriviallyTrue ? 
ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + return true; + }; // If we hit the max recursion limit bail out. if (Depth >= 3) return false; @@ -8629,9 +8833,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, if (ConstantExpr::getICmp(Pred, LHSC->getValue(), RHSC->getValue())->isNullValue()) - goto trivially_false; + return TrivialCase(false); else - goto trivially_true; + return TrivialCase(true); } // Otherwise swap the operands to put the constant on the right. std::swap(LHS, RHS); @@ -8661,9 +8865,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, if (!ICmpInst::isEquality(Pred)) { ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); if (ExactCR.isFullSet()) - goto trivially_true; + return TrivialCase(true); else if (ExactCR.isEmptySet()) - goto trivially_false; + return TrivialCase(false); APInt NewRHS; CmpInst::Predicate NewPred; @@ -8699,7 +8903,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // The "Should have been caught earlier!" messages refer to the fact // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above // should have fired on the corresponding cases, and canonicalized the - // check to trivially_true or trivially_false. + // check to trivial case. case ICmpInst::ICMP_UGE: assert(!RA.isMinValue() && "Should have been caught earlier!"); @@ -8732,9 +8936,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // Check for obvious equality. if (HasSameValue(LHS, RHS)) { if (ICmpInst::isTrueWhenEqual(Pred)) - goto trivially_true; + return TrivialCase(true); if (ICmpInst::isFalseWhenEqual(Pred)) - goto trivially_false; + return TrivialCase(false); } // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by @@ -8802,18 +9006,6 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); return Changed; - -trivially_true: - // Return 0 == 0. - LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); - Pred = ICmpInst::ICMP_EQ; - return true; - -trivially_false: - // Return 0 != 0. - LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); - Pred = ICmpInst::ICMP_NE; - return true; } bool ScalarEvolution::isKnownNegative(const SCEV *S) { @@ -9184,6 +9376,11 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (VerifyIR) + assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) && + "This cannot be done on broken IR!"); + + if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS)) return true; @@ -9289,6 +9486,10 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; + if (VerifyIR) + assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) && + "This cannot be done on broken IR!"); + // Both LHS and RHS must be available at loop entry. assert(isAvailableAtLoopEntry(LHS, L) && "LHS is not available at Loop Entry"); @@ -10565,52 +10766,11 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); - } else if (isQuadratic()) { - // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the - // quadratic equation to solve it. 
To do this, we must frame our problem in - // terms of figuring out when zero is crossed, instead of when - // Range.getUpper() is crossed. - SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end()); - NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); - - // Next, solve the constructed addrec - if (auto Roots = - SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( - ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - - // Make sure the root is not off by one. The returned iteration should - // not be in the range, but the previous one should be. When solving - // for "X*X < 5", for example, we should not return a root of 2. - ConstantInt *R1Val = - EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); - if (Range.contains(R1Val->getValue())) { - // The next iteration must be out of the range... - ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); - - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (!Range.contains(R1Val->getValue())) - return SE.getConstant(NextVal); - return SE.getCouldNotCompute(); // Something strange happened - } + } - // If R1 was not in the range, then it is a good return value. Make - // sure that R1-1 WAS in the range though, just in case. - ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (Range.contains(R1Val->getValue())) - return R1; - return SE.getCouldNotCompute(); // Something strange happened - } - } + if (isQuadratic()) { + if (auto S = SolveQuadraticAddRecRange(this, Range, SE)) + return SE.getConstant(S.getValue()); } return SE.getCouldNotCompute(); @@ -10920,7 +11080,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); // Put larger terms first. - llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { + llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) { return numberOfTerms(LHS) > numberOfTerms(RHS); }); diff --git a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp index 7bea994121c8..289d4f8ae49a 100644 --- a/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp +++ b/lib/Analysis/ScalarEvolutionAliasAnalysis.cpp @@ -27,7 +27,7 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA, // If either of the memory references is empty, it doesn't matter what the // pointer values are. This allows the code below to ignore this special // case. - if (LocA.Size == 0 || LocB.Size == 0) + if (LocA.Size.isZero() || LocB.Size.isZero()) return NoAlias; // This is SCEVAAResult. Get the SCEVs! @@ -43,8 +43,12 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA, if (SE.getEffectiveSCEVType(AS->getType()) == SE.getEffectiveSCEVType(BS->getType())) { unsigned BitWidth = SE.getTypeSizeInBits(AS->getType()); - APInt ASizeInt(BitWidth, LocA.Size); - APInt BSizeInt(BitWidth, LocB.Size); + APInt ASizeInt(BitWidth, LocA.Size.hasValue() + ? LocA.Size.getValue() + : MemoryLocation::UnknownSize); + APInt BSizeInt(BitWidth, LocB.Size.hasValue() + ? 
LocB.Size.getValue() + : MemoryLocation::UnknownSize); // Compute the difference between the two pointers. const SCEV *BA = SE.getMinusSCEV(BS, AS); @@ -78,10 +82,10 @@ AliasResult SCEVAAResult::alias(const MemoryLocation &LocA, Value *BO = GetBaseValue(BS); if ((AO && AO != LocA.Ptr) || (BO && BO != LocB.Ptr)) if (alias(MemoryLocation(AO ? AO : LocA.Ptr, - AO ? +MemoryLocation::UnknownSize : LocA.Size, + AO ? LocationSize::unknown() : LocA.Size, AO ? AAMDNodes() : LocA.AATags), MemoryLocation(BO ? BO : LocB.Ptr, - BO ? +MemoryLocation::UnknownSize : LocB.Size, + BO ? LocationSize::unknown() : LocB.Size, BO ? AAMDNodes() : LocB.AATags)) == NoAlias) return NoAlias; diff --git a/lib/Analysis/ScalarEvolutionExpander.cpp b/lib/Analysis/ScalarEvolutionExpander.cpp index 8f89389c4b5d..ca5cf1663b83 100644 --- a/lib/Analysis/ScalarEvolutionExpander.cpp +++ b/lib/Analysis/ScalarEvolutionExpander.cpp @@ -1867,7 +1867,7 @@ SCEVExpander::replaceCongruentIVs(Loop *L, const DominatorTree *DT, Phis.push_back(&PN); if (TTI) - llvm::sort(Phis.begin(), Phis.end(), [](Value *LHS, Value *RHS) { + llvm::sort(Phis, [](Value *LHS, Value *RHS) { // Put pointers at the back and make sure pointer < pointer = false. if (!LHS->getType()->isIntegerTy() || !RHS->getType()->isIntegerTy()) return RHS->getType()->isIntegerTy() && !LHS->getType()->isIntegerTy(); diff --git a/lib/Analysis/ScopedNoAliasAA.cpp b/lib/Analysis/ScopedNoAliasAA.cpp index f12275aff387..9a581fe46afc 100644 --- a/lib/Analysis/ScopedNoAliasAA.cpp +++ b/lib/Analysis/ScopedNoAliasAA.cpp @@ -95,39 +95,36 @@ AliasResult ScopedNoAliasAAResult::alias(const MemoryLocation &LocA, return AAResultBase::alias(LocA, LocB); } -ModRefInfo ScopedNoAliasAAResult::getModRefInfo(ImmutableCallSite CS, +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { if (!EnableScopedNoAlias) - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); - if (!mayAliasInScopes(Loc.AATags.Scope, CS.getInstruction()->getMetadata( - LLVMContext::MD_noalias))) + if (!mayAliasInScopes(Loc.AATags.Scope, + Call->getMetadata(LLVMContext::MD_noalias))) return ModRefInfo::NoModRef; - if (!mayAliasInScopes( - CS.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), - Loc.AATags.NoAlias)) + if (!mayAliasInScopes(Call->getMetadata(LLVMContext::MD_alias_scope), + Loc.AATags.NoAlias)) return ModRefInfo::NoModRef; - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); } -ModRefInfo ScopedNoAliasAAResult::getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) { +ModRefInfo ScopedNoAliasAAResult::getModRefInfo(const CallBase *Call1, + const CallBase *Call2) { if (!EnableScopedNoAlias) - return AAResultBase::getModRefInfo(CS1, CS2); + return AAResultBase::getModRefInfo(Call1, Call2); - if (!mayAliasInScopes( - CS1.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), - CS2.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + if (!mayAliasInScopes(Call1->getMetadata(LLVMContext::MD_alias_scope), + Call2->getMetadata(LLVMContext::MD_noalias))) return ModRefInfo::NoModRef; - if (!mayAliasInScopes( - CS2.getInstruction()->getMetadata(LLVMContext::MD_alias_scope), - CS1.getInstruction()->getMetadata(LLVMContext::MD_noalias))) + if (!mayAliasInScopes(Call2->getMetadata(LLVMContext::MD_alias_scope), + Call1->getMetadata(LLVMContext::MD_noalias))) return ModRefInfo::NoModRef; - return AAResultBase::getModRefInfo(CS1, CS2); + return 
AAResultBase::getModRefInfo(Call1, Call2); } static void collectMDInDomain(const MDNode *List, const MDNode *Domain, diff --git a/lib/Analysis/StackSafetyAnalysis.cpp b/lib/Analysis/StackSafetyAnalysis.cpp new file mode 100644 index 000000000000..66b03845864f --- /dev/null +++ b/lib/Analysis/StackSafetyAnalysis.cpp @@ -0,0 +1,673 @@ +//===- StackSafetyAnalysis.cpp - Stack memory safety analysis -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/StackSafetyAnalysis.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/raw_ostream.h" + +using namespace llvm; + +#define DEBUG_TYPE "stack-safety" + +static cl::opt<int> StackSafetyMaxIterations("stack-safety-max-iterations", + cl::init(20), cl::Hidden); + +namespace { + +/// Rewrite an SCEV expression for a memory access address to an expression that +/// represents offset from the given alloca. +class AllocaOffsetRewriter : public SCEVRewriteVisitor<AllocaOffsetRewriter> { + const Value *AllocaPtr; + +public: + AllocaOffsetRewriter(ScalarEvolution &SE, const Value *AllocaPtr) + : SCEVRewriteVisitor(SE), AllocaPtr(AllocaPtr) {} + + const SCEV *visit(const SCEV *Expr) { + // Only re-write the expression if the alloca is used in an addition + // expression (it can be used in other types of expressions if it's cast to + // an int and passed as an argument.) + if (!isa<SCEVAddRecExpr>(Expr) && !isa<SCEVAddExpr>(Expr) && + !isa<SCEVUnknown>(Expr)) + return Expr; + return SCEVRewriteVisitor<AllocaOffsetRewriter>::visit(Expr); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + // FIXME: look through one or several levels of definitions? + // This can be inttoptr(AllocaPtr) and SCEV would not unwrap + // it for us. + if (Expr->getValue() == AllocaPtr) + return SE.getZero(Expr->getType()); + return Expr; + } +}; + +/// Describes use of address in as a function call argument. +struct PassAsArgInfo { + /// Function being called. + const GlobalValue *Callee = nullptr; + /// Index of argument which pass address. + size_t ParamNo = 0; + // Offset range of address from base address (alloca or calling function + // argument). + // Range should never set to empty-set, that is an invalid access range + // that can cause empty-set to be propagated with ConstantRange::add + ConstantRange Offset; + PassAsArgInfo(const GlobalValue *Callee, size_t ParamNo, ConstantRange Offset) + : Callee(Callee), ParamNo(ParamNo), Offset(Offset) {} + + StringRef getName() const { return Callee->getName(); } +}; + +raw_ostream &operator<<(raw_ostream &OS, const PassAsArgInfo &P) { + return OS << "@" << P.getName() << "(arg" << P.ParamNo << ", " << P.Offset + << ")"; +} + +/// Describe uses of address (alloca or parameter) inside of the function. +struct UseInfo { + // Access range if the address (alloca or parameters). + // It is allowed to be empty-set when there are no known accesses. + ConstantRange Range; + + // List of calls which pass address as an argument. 
+ SmallVector<PassAsArgInfo, 4> Calls; + + explicit UseInfo(unsigned PointerSize) : Range{PointerSize, false} {} + + void updateRange(ConstantRange R) { Range = Range.unionWith(R); } +}; + +raw_ostream &operator<<(raw_ostream &OS, const UseInfo &U) { + OS << U.Range; + for (auto &Call : U.Calls) + OS << ", " << Call; + return OS; +} + +struct AllocaInfo { + const AllocaInst *AI = nullptr; + uint64_t Size = 0; + UseInfo Use; + + AllocaInfo(unsigned PointerSize, const AllocaInst *AI, uint64_t Size) + : AI(AI), Size(Size), Use(PointerSize) {} + + StringRef getName() const { return AI->getName(); } +}; + +raw_ostream &operator<<(raw_ostream &OS, const AllocaInfo &A) { + return OS << A.getName() << "[" << A.Size << "]: " << A.Use; +} + +struct ParamInfo { + const Argument *Arg = nullptr; + UseInfo Use; + + explicit ParamInfo(unsigned PointerSize, const Argument *Arg) + : Arg(Arg), Use(PointerSize) {} + + StringRef getName() const { return Arg ? Arg->getName() : "<N/A>"; } +}; + +raw_ostream &operator<<(raw_ostream &OS, const ParamInfo &P) { + return OS << P.getName() << "[]: " << P.Use; +} + +/// Calculate the allocation size of a given alloca. Returns 0 if the +/// size can not be statically determined. +uint64_t getStaticAllocaAllocationSize(const AllocaInst *AI) { + const DataLayout &DL = AI->getModule()->getDataLayout(); + uint64_t Size = DL.getTypeAllocSize(AI->getAllocatedType()); + if (AI->isArrayAllocation()) { + auto C = dyn_cast<ConstantInt>(AI->getArraySize()); + if (!C) + return 0; + Size *= C->getZExtValue(); + } + return Size; +} + +} // end anonymous namespace + +/// Describes uses of allocas and parameters inside of a single function. +struct StackSafetyInfo::FunctionInfo { + // May be a Function or a GlobalAlias + const GlobalValue *GV = nullptr; + // Informations about allocas uses. + SmallVector<AllocaInfo, 4> Allocas; + // Informations about parameters uses. + SmallVector<ParamInfo, 4> Params; + // TODO: describe return value as depending on one or more of its arguments. + + // StackSafetyDataFlowAnalysis counter stored here for faster access. + int UpdateCount = 0; + + FunctionInfo(const StackSafetyInfo &SSI) : FunctionInfo(*SSI.Info) {} + + explicit FunctionInfo(const Function *F) : GV(F){}; + // Creates FunctionInfo that forwards all the parameters to the aliasee. + explicit FunctionInfo(const GlobalAlias *A); + + FunctionInfo(FunctionInfo &&) = default; + + bool IsDSOLocal() const { return GV->isDSOLocal(); }; + + bool IsInterposable() const { return GV->isInterposable(); }; + + StringRef getName() const { return GV->getName(); } + + void print(raw_ostream &O) const { + // TODO: Consider different printout format after + // StackSafetyDataFlowAnalysis. Calls and parameters are irrelevant then. + O << " @" << getName() << (IsDSOLocal() ? "" : " dso_preemptable") + << (IsInterposable() ? 
" interposable" : "") << "\n"; + O << " args uses:\n"; + for (auto &P : Params) + O << " " << P << "\n"; + O << " allocas uses:\n"; + for (auto &AS : Allocas) + O << " " << AS << "\n"; + } + +private: + FunctionInfo(const FunctionInfo &) = default; +}; + +StackSafetyInfo::FunctionInfo::FunctionInfo(const GlobalAlias *A) : GV(A) { + unsigned PointerSize = A->getParent()->getDataLayout().getPointerSizeInBits(); + const GlobalObject *Aliasee = A->getBaseObject(); + const FunctionType *Type = cast<FunctionType>(Aliasee->getValueType()); + // 'Forward' all parameters to this alias to the aliasee + for (unsigned ArgNo = 0; ArgNo < Type->getNumParams(); ArgNo++) { + Params.emplace_back(PointerSize, nullptr); + UseInfo &US = Params.back().Use; + US.Calls.emplace_back(Aliasee, ArgNo, ConstantRange(APInt(PointerSize, 0))); + } +} + +namespace { + +class StackSafetyLocalAnalysis { + const Function &F; + const DataLayout &DL; + ScalarEvolution &SE; + unsigned PointerSize = 0; + + const ConstantRange UnknownRange; + + ConstantRange offsetFromAlloca(Value *Addr, const Value *AllocaPtr); + ConstantRange getAccessRange(Value *Addr, const Value *AllocaPtr, + uint64_t AccessSize); + ConstantRange getMemIntrinsicAccessRange(const MemIntrinsic *MI, const Use &U, + const Value *AllocaPtr); + + bool analyzeAllUses(const Value *Ptr, UseInfo &AS); + + ConstantRange getRange(uint64_t Lower, uint64_t Upper) const { + return ConstantRange(APInt(PointerSize, Lower), APInt(PointerSize, Upper)); + } + +public: + StackSafetyLocalAnalysis(const Function &F, ScalarEvolution &SE) + : F(F), DL(F.getParent()->getDataLayout()), SE(SE), + PointerSize(DL.getPointerSizeInBits()), + UnknownRange(PointerSize, true) {} + + // Run the transformation on the associated function. + StackSafetyInfo run(); +}; + +ConstantRange +StackSafetyLocalAnalysis::offsetFromAlloca(Value *Addr, + const Value *AllocaPtr) { + if (!SE.isSCEVable(Addr->getType())) + return UnknownRange; + + AllocaOffsetRewriter Rewriter(SE, AllocaPtr); + const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr)); + ConstantRange Offset = SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize); + assert(!Offset.isEmptySet()); + return Offset; +} + +ConstantRange StackSafetyLocalAnalysis::getAccessRange(Value *Addr, + const Value *AllocaPtr, + uint64_t AccessSize) { + if (!SE.isSCEVable(Addr->getType())) + return UnknownRange; + + AllocaOffsetRewriter Rewriter(SE, AllocaPtr); + const SCEV *Expr = Rewriter.visit(SE.getSCEV(Addr)); + + ConstantRange AccessStartRange = + SE.getUnsignedRange(Expr).zextOrTrunc(PointerSize); + ConstantRange SizeRange = getRange(0, AccessSize); + ConstantRange AccessRange = AccessStartRange.add(SizeRange); + assert(!AccessRange.isEmptySet()); + return AccessRange; +} + +ConstantRange StackSafetyLocalAnalysis::getMemIntrinsicAccessRange( + const MemIntrinsic *MI, const Use &U, const Value *AllocaPtr) { + if (auto MTI = dyn_cast<MemTransferInst>(MI)) { + if (MTI->getRawSource() != U && MTI->getRawDest() != U) + return getRange(0, 1); + } else { + if (MI->getRawDest() != U) + return getRange(0, 1); + } + const auto *Len = dyn_cast<ConstantInt>(MI->getLength()); + // Non-constant size => unsafe. FIXME: try SCEV getRange. + if (!Len) + return UnknownRange; + ConstantRange AccessRange = getAccessRange(U, AllocaPtr, Len->getZExtValue()); + return AccessRange; +} + +/// The function analyzes all local uses of Ptr (alloca or argument) and +/// calculates local access range and all function calls where it was used. 
+bool StackSafetyLocalAnalysis::analyzeAllUses(const Value *Ptr, UseInfo &US) { + SmallPtrSet<const Value *, 16> Visited; + SmallVector<const Value *, 8> WorkList; + WorkList.push_back(Ptr); + + // A DFS search through all uses of the alloca in bitcasts/PHI/GEPs/etc. + while (!WorkList.empty()) { + const Value *V = WorkList.pop_back_val(); + for (const Use &UI : V->uses()) { + auto I = cast<const Instruction>(UI.getUser()); + assert(V == UI.get()); + + switch (I->getOpcode()) { + case Instruction::Load: { + US.updateRange( + getAccessRange(UI, Ptr, DL.getTypeStoreSize(I->getType()))); + break; + } + + case Instruction::VAArg: + // "va-arg" from a pointer is safe. + break; + case Instruction::Store: { + if (V == I->getOperand(0)) { + // Stored the pointer - conservatively assume it may be unsafe. + US.updateRange(UnknownRange); + return false; + } + US.updateRange(getAccessRange( + UI, Ptr, DL.getTypeStoreSize(I->getOperand(0)->getType()))); + break; + } + + case Instruction::Ret: + // Information leak. + // FIXME: Process parameters correctly. This is a leak only if we return + // alloca. + US.updateRange(UnknownRange); + return false; + + case Instruction::Call: + case Instruction::Invoke: { + ImmutableCallSite CS(I); + + if (I->isLifetimeStartOrEnd()) + break; + + if (const MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) { + US.updateRange(getMemIntrinsicAccessRange(MI, UI, Ptr)); + break; + } + + // FIXME: consult devirt? + // Do not follow aliases, otherwise we could inadvertently follow + // dso_preemptable aliases or aliases with interposable linkage. + const GlobalValue *Callee = dyn_cast<GlobalValue>( + CS.getCalledValue()->stripPointerCastsNoFollowAliases()); + if (!Callee) { + US.updateRange(UnknownRange); + return false; + } + + assert(isa<Function>(Callee) || isa<GlobalAlias>(Callee)); + + ImmutableCallSite::arg_iterator B = CS.arg_begin(), E = CS.arg_end(); + for (ImmutableCallSite::arg_iterator A = B; A != E; ++A) { + if (A->get() == V) { + ConstantRange Offset = offsetFromAlloca(UI, Ptr); + US.Calls.emplace_back(Callee, A - B, Offset); + } + } + + break; + } + + default: + if (Visited.insert(I).second) + WorkList.push_back(cast<const Instruction>(I)); + } + } + } + + return true; +} + +StackSafetyInfo StackSafetyLocalAnalysis::run() { + StackSafetyInfo::FunctionInfo Info(&F); + assert(!F.isDeclaration() && + "Can't run StackSafety on a function declaration"); + + LLVM_DEBUG(dbgs() << "[StackSafety] " << F.getName() << "\n"); + + for (auto &I : instructions(F)) { + if (auto AI = dyn_cast<AllocaInst>(&I)) { + Info.Allocas.emplace_back(PointerSize, AI, + getStaticAllocaAllocationSize(AI)); + AllocaInfo &AS = Info.Allocas.back(); + analyzeAllUses(AI, AS.Use); + } + } + + for (const Argument &A : make_range(F.arg_begin(), F.arg_end())) { + Info.Params.emplace_back(PointerSize, &A); + ParamInfo &PS = Info.Params.back(); + analyzeAllUses(&A, PS.Use); + } + + LLVM_DEBUG(dbgs() << "[StackSafety] done\n"); + LLVM_DEBUG(Info.print(dbgs())); + return StackSafetyInfo(std::move(Info)); +} + +class StackSafetyDataFlowAnalysis { + using FunctionMap = + std::map<const GlobalValue *, StackSafetyInfo::FunctionInfo>; + + FunctionMap Functions; + // Callee-to-Caller multimap. 
+ DenseMap<const GlobalValue *, SmallVector<const GlobalValue *, 4>> Callers; + SetVector<const GlobalValue *> WorkList; + + unsigned PointerSize = 0; + const ConstantRange UnknownRange; + + ConstantRange getArgumentAccessRange(const GlobalValue *Callee, + unsigned ParamNo) const; + bool updateOneUse(UseInfo &US, bool UpdateToFullSet); + void updateOneNode(const GlobalValue *Callee, + StackSafetyInfo::FunctionInfo &FS); + void updateOneNode(const GlobalValue *Callee) { + updateOneNode(Callee, Functions.find(Callee)->second); + } + void updateAllNodes() { + for (auto &F : Functions) + updateOneNode(F.first, F.second); + } + void runDataFlow(); + void verifyFixedPoint(); + +public: + StackSafetyDataFlowAnalysis( + Module &M, std::function<const StackSafetyInfo &(Function &)> FI); + StackSafetyGlobalInfo run(); +}; + +StackSafetyDataFlowAnalysis::StackSafetyDataFlowAnalysis( + Module &M, std::function<const StackSafetyInfo &(Function &)> FI) + : PointerSize(M.getDataLayout().getPointerSizeInBits()), + UnknownRange(PointerSize, true) { + // Without ThinLTO, run the local analysis for every function in the TU and + // then run the DFA. + for (auto &F : M.functions()) + if (!F.isDeclaration()) + Functions.emplace(&F, FI(F)); + for (auto &A : M.aliases()) + if (isa<Function>(A.getBaseObject())) + Functions.emplace(&A, StackSafetyInfo::FunctionInfo(&A)); +} + +ConstantRange +StackSafetyDataFlowAnalysis::getArgumentAccessRange(const GlobalValue *Callee, + unsigned ParamNo) const { + auto IT = Functions.find(Callee); + // Unknown callee (outside of LTO domain or an indirect call). + if (IT == Functions.end()) + return UnknownRange; + const StackSafetyInfo::FunctionInfo &FS = IT->second; + // The definition of this symbol may not be the definition in this linkage + // unit. + if (!FS.IsDSOLocal() || FS.IsInterposable()) + return UnknownRange; + if (ParamNo >= FS.Params.size()) // possibly vararg + return UnknownRange; + return FS.Params[ParamNo].Use.Range; +} + +bool StackSafetyDataFlowAnalysis::updateOneUse(UseInfo &US, + bool UpdateToFullSet) { + bool Changed = false; + for (auto &CS : US.Calls) { + assert(!CS.Offset.isEmptySet() && + "Param range can't be empty-set, invalid offset range"); + + ConstantRange CalleeRange = getArgumentAccessRange(CS.Callee, CS.ParamNo); + CalleeRange = CalleeRange.add(CS.Offset); + if (!US.Range.contains(CalleeRange)) { + Changed = true; + if (UpdateToFullSet) + US.Range = UnknownRange; + else + US.Range = US.Range.unionWith(CalleeRange); + } + } + return Changed; +} + +void StackSafetyDataFlowAnalysis::updateOneNode( + const GlobalValue *Callee, StackSafetyInfo::FunctionInfo &FS) { + bool UpdateToFullSet = FS.UpdateCount > StackSafetyMaxIterations; + bool Changed = false; + for (auto &AS : FS.Allocas) + Changed |= updateOneUse(AS.Use, UpdateToFullSet); + for (auto &PS : FS.Params) + Changed |= updateOneUse(PS.Use, UpdateToFullSet); + + if (Changed) { + LLVM_DEBUG(dbgs() << "=== update [" << FS.UpdateCount + << (UpdateToFullSet ? ", full-set" : "") << "] " + << FS.getName() << "\n"); + // Callers of this function may need updating. 
+ for (auto &CallerID : Callers[Callee]) + WorkList.insert(CallerID); + + ++FS.UpdateCount; + } +} + +void StackSafetyDataFlowAnalysis::runDataFlow() { + Callers.clear(); + WorkList.clear(); + + SmallVector<const GlobalValue *, 16> Callees; + for (auto &F : Functions) { + Callees.clear(); + StackSafetyInfo::FunctionInfo &FS = F.second; + for (auto &AS : FS.Allocas) + for (auto &CS : AS.Use.Calls) + Callees.push_back(CS.Callee); + for (auto &PS : FS.Params) + for (auto &CS : PS.Use.Calls) + Callees.push_back(CS.Callee); + + llvm::sort(Callees); + Callees.erase(std::unique(Callees.begin(), Callees.end()), Callees.end()); + + for (auto &Callee : Callees) + Callers[Callee].push_back(F.first); + } + + updateAllNodes(); + + while (!WorkList.empty()) { + const GlobalValue *Callee = WorkList.back(); + WorkList.pop_back(); + updateOneNode(Callee); + } +} + +void StackSafetyDataFlowAnalysis::verifyFixedPoint() { + WorkList.clear(); + updateAllNodes(); + assert(WorkList.empty()); +} + +StackSafetyGlobalInfo StackSafetyDataFlowAnalysis::run() { + runDataFlow(); + LLVM_DEBUG(verifyFixedPoint()); + + StackSafetyGlobalInfo SSI; + for (auto &F : Functions) + SSI.emplace(F.first, std::move(F.second)); + return SSI; +} + +void print(const StackSafetyGlobalInfo &SSI, raw_ostream &O, const Module &M) { + size_t Count = 0; + for (auto &F : M.functions()) + if (!F.isDeclaration()) { + SSI.find(&F)->second.print(O); + O << "\n"; + ++Count; + } + for (auto &A : M.aliases()) { + SSI.find(&A)->second.print(O); + O << "\n"; + ++Count; + } + assert(Count == SSI.size() && "Unexpected functions in the result"); +} + +} // end anonymous namespace + +StackSafetyInfo::StackSafetyInfo() = default; +StackSafetyInfo::StackSafetyInfo(StackSafetyInfo &&) = default; +StackSafetyInfo &StackSafetyInfo::operator=(StackSafetyInfo &&) = default; + +StackSafetyInfo::StackSafetyInfo(FunctionInfo &&Info) + : Info(new FunctionInfo(std::move(Info))) {} + +StackSafetyInfo::~StackSafetyInfo() = default; + +void StackSafetyInfo::print(raw_ostream &O) const { Info->print(O); } + +AnalysisKey StackSafetyAnalysis::Key; + +StackSafetyInfo StackSafetyAnalysis::run(Function &F, + FunctionAnalysisManager &AM) { + StackSafetyLocalAnalysis SSLA(F, AM.getResult<ScalarEvolutionAnalysis>(F)); + return SSLA.run(); +} + +PreservedAnalyses StackSafetyPrinterPass::run(Function &F, + FunctionAnalysisManager &AM) { + OS << "'Stack Safety Local Analysis' for function '" << F.getName() << "'\n"; + AM.getResult<StackSafetyAnalysis>(F).print(OS); + return PreservedAnalyses::all(); +} + +char StackSafetyInfoWrapperPass::ID = 0; + +StackSafetyInfoWrapperPass::StackSafetyInfoWrapperPass() : FunctionPass(ID) { + initializeStackSafetyInfoWrapperPassPass(*PassRegistry::getPassRegistry()); +} + +void StackSafetyInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const { + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.setPreservesAll(); +} + +void StackSafetyInfoWrapperPass::print(raw_ostream &O, const Module *M) const { + SSI.print(O); +} + +bool StackSafetyInfoWrapperPass::runOnFunction(Function &F) { + StackSafetyLocalAnalysis SSLA( + F, getAnalysis<ScalarEvolutionWrapperPass>().getSE()); + SSI = StackSafetyInfo(SSLA.run()); + return false; +} + +AnalysisKey StackSafetyGlobalAnalysis::Key; + +StackSafetyGlobalInfo +StackSafetyGlobalAnalysis::run(Module &M, ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + StackSafetyDataFlowAnalysis SSDFA( + M, [&FAM](Function &F) -> const 
StackSafetyInfo & { + return FAM.getResult<StackSafetyAnalysis>(F); + }); + return SSDFA.run(); +} + +PreservedAnalyses StackSafetyGlobalPrinterPass::run(Module &M, + ModuleAnalysisManager &AM) { + OS << "'Stack Safety Analysis' for module '" << M.getName() << "'\n"; + print(AM.getResult<StackSafetyGlobalAnalysis>(M), OS, M); + return PreservedAnalyses::all(); +} + +char StackSafetyGlobalInfoWrapperPass::ID = 0; + +StackSafetyGlobalInfoWrapperPass::StackSafetyGlobalInfoWrapperPass() + : ModulePass(ID) { + initializeStackSafetyGlobalInfoWrapperPassPass( + *PassRegistry::getPassRegistry()); +} + +void StackSafetyGlobalInfoWrapperPass::print(raw_ostream &O, + const Module *M) const { + ::print(SSI, O, *M); +} + +void StackSafetyGlobalInfoWrapperPass::getAnalysisUsage( + AnalysisUsage &AU) const { + AU.addRequired<StackSafetyInfoWrapperPass>(); +} + +bool StackSafetyGlobalInfoWrapperPass::runOnModule(Module &M) { + StackSafetyDataFlowAnalysis SSDFA( + M, [this](Function &F) -> const StackSafetyInfo & { + return getAnalysis<StackSafetyInfoWrapperPass>(F).getResult(); + }); + SSI = SSDFA.run(); + return false; +} + +static const char LocalPassArg[] = "stack-safety-local"; +static const char LocalPassName[] = "Stack Safety Local Analysis"; +INITIALIZE_PASS_BEGIN(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName, + false, true) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(StackSafetyInfoWrapperPass, LocalPassArg, LocalPassName, + false, true) + +static const char GlobalPassName[] = "Stack Safety Analysis"; +INITIALIZE_PASS_BEGIN(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE, + GlobalPassName, false, false) +INITIALIZE_PASS_DEPENDENCY(StackSafetyInfoWrapperPass) +INITIALIZE_PASS_END(StackSafetyGlobalInfoWrapperPass, DEBUG_TYPE, + GlobalPassName, false, false) diff --git a/lib/Analysis/SyncDependenceAnalysis.cpp b/lib/Analysis/SyncDependenceAnalysis.cpp new file mode 100644 index 000000000000..e1a7e4476d12 --- /dev/null +++ b/lib/Analysis/SyncDependenceAnalysis.cpp @@ -0,0 +1,380 @@ +//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation +//--===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements an algorithm that returns for a divergent branch +// the set of basic blocks whose phi nodes become divergent due to divergent +// control. These are the blocks that are reachable by two disjoint paths from +// the branch or loop exits that have a reaching path that is disjoint from a +// path to the loop latch. +// +// The SyncDependenceAnalysis is used in the DivergenceAnalysis to model +// control-induced divergence in phi nodes. +// +// -- Summary -- +// The SyncDependenceAnalysis lazily computes sync dependences [3]. +// The analysis evaluates the disjoint path criterion [2] by a reduction +// to SSA construction. The SSA construction algorithm is implemented as +// a simple data-flow analysis [1]. 
+// +// [1] "A Simple, Fast Dominance Algorithm", SPI '01, Cooper, Harvey and Kennedy +// [2] "Efficiently Computing Static Single Assignment Form +// and the Control Dependence Graph", TOPLAS '91, +// Cytron, Ferrante, Rosen, Wegman and Zadeck +// [3] "Improving Performance of OpenCL on CPUs", CC '12, Karrenberg and Hack +// [4] "Divergence Analysis", TOPLAS '13, Sampaio, Souza, Collange and Pereira +// +// -- Sync dependence -- +// Sync dependence [4] characterizes the control flow aspect of the +// propagation of branch divergence. For example, +// +// %cond = icmp slt i32 %tid, 10 +// br i1 %cond, label %then, label %else +// then: +// br label %merge +// else: +// br label %merge +// merge: +// %a = phi i32 [ 0, %then ], [ 1, %else ] +// +// Suppose %tid holds the thread ID. Although %a is not data dependent on %tid +// because %tid is not on its use-def chains, %a is sync dependent on %tid +// because the branch "br i1 %cond" depends on %tid and affects which value %a +// is assigned to. +// +// -- Reduction to SSA construction -- +// There are two disjoint paths from A to X, if a certain variant of SSA +// construction places a phi node in X under the following set-up scheme [2]. +// +// This variant of SSA construction ignores incoming undef values. +// That is paths from the entry without a definition do not result in +// phi nodes. +// +// entry +// / \ +// A \ +// / \ Y +// B C / +// \ / \ / +// D E +// \ / +// F +// Assume that A contains a divergent branch. We are interested +// in the set of all blocks where each block is reachable from A +// via two disjoint paths. This would be the set {D, F} in this +// case. +// To generally reduce this query to SSA construction we introduce +// a virtual variable x and assign to x different values in each +// successor block of A. +// entry +// / \ +// A \ +// / \ Y +// x = 0 x = 1 / +// \ / \ / +// D E +// \ / +// F +// Our flavor of SSA construction for x will construct the following +// entry +// / \ +// A \ +// / \ Y +// x0 = 0 x1 = 1 / +// \ / \ / +// x2=phi E +// \ / +// x3=phi +// The blocks D and F contain phi nodes and are thus each reachable +// by two disjoins paths from A. +// +// -- Remarks -- +// In case of loop exits we need to check the disjoint path criterion for loops +// [2]. To this end, we check whether the definition of x differs between the +// loop exit and the loop header (_after_ SSA construction). 
+// +//===----------------------------------------------------------------------===// +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/SyncDependenceAnalysis.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" + +#include <stack> +#include <unordered_set> + +#define DEBUG_TYPE "sync-dependence" + +namespace llvm { + +ConstBlockSet SyncDependenceAnalysis::EmptyBlockSet; + +SyncDependenceAnalysis::SyncDependenceAnalysis(const DominatorTree &DT, + const PostDominatorTree &PDT, + const LoopInfo &LI) + : FuncRPOT(DT.getRoot()->getParent()), DT(DT), PDT(PDT), LI(LI) {} + +SyncDependenceAnalysis::~SyncDependenceAnalysis() {} + +using FunctionRPOT = ReversePostOrderTraversal<const Function *>; + +// divergence propagator for reducible CFGs +struct DivergencePropagator { + const FunctionRPOT &FuncRPOT; + const DominatorTree &DT; + const PostDominatorTree &PDT; + const LoopInfo &LI; + + // identified join points + std::unique_ptr<ConstBlockSet> JoinBlocks; + + // reached loop exits (by a path disjoint to a path to the loop header) + SmallPtrSet<const BasicBlock *, 4> ReachedLoopExits; + + // if DefMap[B] == C then C is the dominating definition at block B + // if DefMap[B] ~ undef then we haven't seen B yet + // if DefMap[B] == B then B is a join point of disjoint paths from X or B is + // an immediate successor of X (initial value). + using DefiningBlockMap = std::map<const BasicBlock *, const BasicBlock *>; + DefiningBlockMap DefMap; + + // all blocks with pending visits + std::unordered_set<const BasicBlock *> PendingUpdates; + + DivergencePropagator(const FunctionRPOT &FuncRPOT, const DominatorTree &DT, + const PostDominatorTree &PDT, const LoopInfo &LI) + : FuncRPOT(FuncRPOT), DT(DT), PDT(PDT), LI(LI), + JoinBlocks(new ConstBlockSet) {} + + // set the definition at @block and mark @block as pending for a visit + void addPending(const BasicBlock &Block, const BasicBlock &DefBlock) { + bool WasAdded = DefMap.emplace(&Block, &DefBlock).second; + if (WasAdded) + PendingUpdates.insert(&Block); + } + + void printDefs(raw_ostream &Out) { + Out << "Propagator::DefMap {\n"; + for (const auto *Block : FuncRPOT) { + auto It = DefMap.find(Block); + Out << Block->getName() << " : "; + if (It == DefMap.end()) { + Out << "\n"; + } else { + const auto *DefBlock = It->second; + Out << (DefBlock ? DefBlock->getName() : "<null>") << "\n"; + } + } + Out << "}\n"; + } + + // process @succBlock with reaching definition @defBlock + // the original divergent branch was in @parentLoop (if any) + void visitSuccessor(const BasicBlock &SuccBlock, const Loop *ParentLoop, + const BasicBlock &DefBlock) { + + // @succBlock is a loop exit + if (ParentLoop && !ParentLoop->contains(&SuccBlock)) { + DefMap.emplace(&SuccBlock, &DefBlock); + ReachedLoopExits.insert(&SuccBlock); + return; + } + + // first reaching def? + auto ItLastDef = DefMap.find(&SuccBlock); + if (ItLastDef == DefMap.end()) { + addPending(SuccBlock, DefBlock); + return; + } + + // a join of at least two definitions + if (ItLastDef->second != &DefBlock) { + // do we know this join already? + if (!JoinBlocks->insert(&SuccBlock).second) + return; + + // update the definition + addPending(SuccBlock, SuccBlock); + } + } + + // find all blocks reachable by two disjoint paths from @rootTerm. + // This method works for both divergent terminators and loops with + // divergent exits. 
+ // @rootBlock is either the block containing the branch or the header of the + // divergent loop. + // @nodeSuccessors is the set of successors of the node (Loop or Terminator) + // headed by @rootBlock. + // @parentLoop is the parent loop of the Loop or the loop that contains the + // Terminator. + template <typename SuccessorIterable> + std::unique_ptr<ConstBlockSet> + computeJoinPoints(const BasicBlock &RootBlock, + SuccessorIterable NodeSuccessors, const Loop *ParentLoop) { + assert(JoinBlocks); + + // immediate post dominator (no join block beyond that block) + const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(&RootBlock)); + const auto *IpdNode = PdNode->getIDom(); + const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr; + + // bootstrap with branch targets + for (const auto *SuccBlock : NodeSuccessors) { + DefMap.emplace(SuccBlock, SuccBlock); + + if (ParentLoop && !ParentLoop->contains(SuccBlock)) { + // immediate loop exit from node. + ReachedLoopExits.insert(SuccBlock); + continue; + } else { + // regular successor + PendingUpdates.insert(SuccBlock); + } + } + + auto ItBeginRPO = FuncRPOT.begin(); + + // skip until term (TODO RPOT won't let us start at @term directly) + for (; *ItBeginRPO != &RootBlock; ++ItBeginRPO) {} + + auto ItEndRPO = FuncRPOT.end(); + assert(ItBeginRPO != ItEndRPO); + + // propagate definitions at the immediate successors of the node in RPO + auto ItBlockRPO = ItBeginRPO; + while (++ItBlockRPO != ItEndRPO && *ItBlockRPO != PdBoundBlock) { + const auto *Block = *ItBlockRPO; + + // skip @block if not pending update + auto ItPending = PendingUpdates.find(Block); + if (ItPending == PendingUpdates.end()) + continue; + PendingUpdates.erase(ItPending); + + // propagate definition at @block to its successors + auto ItDef = DefMap.find(Block); + const auto *DefBlock = ItDef->second; + assert(DefBlock); + + auto *BlockLoop = LI.getLoopFor(Block); + if (ParentLoop && + (ParentLoop != BlockLoop && ParentLoop->contains(BlockLoop))) { + // if the successor is the header of a nested loop pretend its a + // single node with the loop's exits as successors + SmallVector<BasicBlock *, 4> BlockLoopExits; + BlockLoop->getExitBlocks(BlockLoopExits); + for (const auto *BlockLoopExit : BlockLoopExits) { + visitSuccessor(*BlockLoopExit, ParentLoop, *DefBlock); + } + + } else { + // the successors are either on the same loop level or loop exits + for (const auto *SuccBlock : successors(Block)) { + visitSuccessor(*SuccBlock, ParentLoop, *DefBlock); + } + } + } + + // We need to know the definition at the parent loop header to decide + // whether the definition at the header is different from the definition at + // the loop exits, which would indicate a divergent loop exits. + // + // A // loop header + // | + // B // nested loop header + // | + // C -> X (exit from B loop) -..-> (A latch) + // | + // D -> back to B (B latch) + // | + // proper exit from both loops + // + // D post-dominates B as it is the only proper exit from the "A loop". + // If C has a divergent branch, propagation will therefore stop at D. + // That implies that B will never receive a definition. + // But that definition can only be the same as at D (D itself in thise case) + // because all paths to anywhere have to pass through D. + // + const BasicBlock *ParentLoopHeader = + ParentLoop ? 
ParentLoop->getHeader() : nullptr; + if (ParentLoop && ParentLoop->contains(PdBoundBlock)) { + DefMap[ParentLoopHeader] = DefMap[PdBoundBlock]; + } + + // analyze reached loop exits + if (!ReachedLoopExits.empty()) { + assert(ParentLoop); + const auto *HeaderDefBlock = DefMap[ParentLoopHeader]; + LLVM_DEBUG(printDefs(dbgs())); + assert(HeaderDefBlock && "no definition in header of carrying loop"); + + for (const auto *ExitBlock : ReachedLoopExits) { + auto ItExitDef = DefMap.find(ExitBlock); + assert((ItExitDef != DefMap.end()) && + "no reaching def at reachable loop exit"); + if (ItExitDef->second != HeaderDefBlock) { + JoinBlocks->insert(ExitBlock); + } + } + } + + return std::move(JoinBlocks); + } +}; + +const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) { + using LoopExitVec = SmallVector<BasicBlock *, 4>; + LoopExitVec LoopExits; + Loop.getExitBlocks(LoopExits); + if (LoopExits.size() < 1) { + return EmptyBlockSet; + } + + // already available in cache? + auto ItCached = CachedLoopExitJoins.find(&Loop); + if (ItCached != CachedLoopExitJoins.end()) + return *ItCached->second; + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>( + *Loop.getHeader(), LoopExits, Loop.getParentLoop()); + + auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +const ConstBlockSet & +SyncDependenceAnalysis::join_blocks(const Instruction &Term) { + // trivial case + if (Term.getNumSuccessors() < 1) { + return EmptyBlockSet; + } + + // already available in cache? + auto ItCached = CachedBranchJoins.find(&Term); + if (ItCached != CachedBranchJoins.end()) + return *ItCached->second; + + // compute all join points + DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI}; + const auto &TermBlock = *Term.getParent(); + auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>( + TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock)); + + auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks)); + assert(ItInserted.second); + return *ItInserted.first->second; +} + +} // namespace llvm diff --git a/lib/Analysis/SyntheticCountsUtils.cpp b/lib/Analysis/SyntheticCountsUtils.cpp index b085fa274d7f..c2d7bb11a4cf 100644 --- a/lib/Analysis/SyntheticCountsUtils.cpp +++ b/lib/Analysis/SyntheticCountsUtils.cpp @@ -14,22 +14,21 @@ #include "llvm/Analysis/SyntheticCountsUtils.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SCCIterator.h" -#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/ModuleSummaryIndex.h" using namespace llvm; // Given an SCC, propagate entry counts along the edge of the SCC nodes. template <typename CallGraphType> void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( - const SccTy &SCC, GetRelBBFreqTy GetRelBBFreq, GetCountTy GetCount, - AddCountTy AddCount) { + const SccTy &SCC, GetProfCountTy GetProfCount, AddCountTy AddCount) { - SmallPtrSet<NodeRef, 8> SCCNodes; + DenseSet<NodeRef> SCCNodes; SmallVector<std::pair<NodeRef, EdgeRef>, 8> SCCEdges, NonSCCEdges; for (auto &Node : SCC) @@ -54,17 +53,13 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( // This ensures that the order of // traversal of nodes within the SCC doesn't affect the final result. 
- DenseMap<NodeRef, uint64_t> AdditionalCounts; + DenseMap<NodeRef, Scaled64> AdditionalCounts; for (auto &E : SCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - uint64_t AdditionalCount = RelFreq.toInt<uint64_t>(); - AdditionalCounts[Callee] += AdditionalCount; + AdditionalCounts[Callee] += OptProfCount.getValue(); } // Update the counts for the nodes in the SCC. @@ -73,14 +68,11 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( // Now update the counts for nodes outside the SCC. for (auto &E : NonSCCEdges) { - auto OptRelFreq = GetRelBBFreq(E.second); - if (!OptRelFreq) + auto OptProfCount = GetProfCount(E.first, E.second); + if (!OptProfCount) continue; - Scaled64 RelFreq = OptRelFreq.getValue(); - auto Caller = E.first; auto Callee = CGT::edge_dest(E.second); - RelFreq *= Scaled64(GetCount(Caller), 0); - AddCount(Callee, RelFreq.toInt<uint64_t>()); + AddCount(Callee, OptProfCount.getValue()); } } @@ -94,8 +86,7 @@ void SyntheticCountsUtils<CallGraphType>::propagateFromSCC( template <typename CallGraphType> void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG, - GetRelBBFreqTy GetRelBBFreq, - GetCountTy GetCount, + GetProfCountTy GetProfCount, AddCountTy AddCount) { std::vector<SccTy> SCCs; @@ -107,7 +98,8 @@ void SyntheticCountsUtils<CallGraphType>::propagate(const CallGraphType &CG, // The scc iterator returns the scc in bottom-up order, so reverse the SCCs // and call propagateFromSCC. for (auto &SCC : reverse(SCCs)) - propagateFromSCC(SCC, GetRelBBFreq, GetCount, AddCount); + propagateFromSCC(SCC, GetProfCount, AddCount); } template class llvm::SyntheticCountsUtils<const CallGraph *>; +template class llvm::SyntheticCountsUtils<ModuleSummaryIndex *>; diff --git a/lib/Analysis/TargetLibraryInfo.cpp b/lib/Analysis/TargetLibraryInfo.cpp index 102135fbf313..4643f75da42d 100644 --- a/lib/Analysis/TargetLibraryInfo.cpp +++ b/lib/Analysis/TargetLibraryInfo.cpp @@ -413,17 +413,17 @@ static void initialize(TargetLibraryInfoImpl &TLI, const Triple &T, TLI.setUnavailable(LibFunc_flsll); } - // The following functions are available on Linux, - // but Android uses bionic instead of glibc. - if (!T.isOSLinux() || T.isAndroid()) { + // The following functions are only available on GNU/Linux (using glibc). + // Linux variants without glibc (eg: bionic, musl) may have some subset. + if (!T.isOSLinux() || !T.isGNUEnvironment()) { TLI.setUnavailable(LibFunc_dunder_strdup); TLI.setUnavailable(LibFunc_dunder_strtok_r); TLI.setUnavailable(LibFunc_dunder_isoc99_scanf); TLI.setUnavailable(LibFunc_dunder_isoc99_sscanf); TLI.setUnavailable(LibFunc_under_IO_getc); TLI.setUnavailable(LibFunc_under_IO_putc); - // But, Android has memalign. - if (!T.isAndroid()) + // But, Android and musl have memalign. 
+ if (!T.isAndroid() && !T.isMusl()) TLI.setUnavailable(LibFunc_memalign); TLI.setUnavailable(LibFunc_fopen64); TLI.setUnavailable(LibFunc_fseeko64); @@ -613,6 +613,24 @@ bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, unsigned NumParams = FTy.getNumParams(); switch (F) { + case LibFunc_execl: + case LibFunc_execlp: + case LibFunc_execle: + return (NumParams >= 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_execv: + case LibFunc_execvp: + return (NumParams == 2 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); + case LibFunc_execvP: + case LibFunc_execvpe: + case LibFunc_execve: + return (NumParams == 3 && FTy.getParamType(0)->isPointerTy() && + FTy.getParamType(1)->isPointerTy() && + FTy.getParamType(2)->isPointerTy() && + FTy.getReturnType()->isIntegerTy(32)); case LibFunc_strlen: return (NumParams == 1 && FTy.getParamType(0)->isPointerTy() && FTy.getReturnType()->isIntegerTy()); @@ -863,6 +881,8 @@ bool TargetLibraryInfoImpl::isValidProtoForLibFunc(const FunctionType &FTy, return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(0)->isPointerTy() && FTy.getParamType(1)->isPointerTy()); + case LibFunc_fork: + return (NumParams == 0 && FTy.getReturnType()->isIntegerTy(32)); case LibFunc_fdopen: return (NumParams == 2 && FTy.getReturnType()->isPointerTy() && FTy.getParamType(1)->isPointerTy()); @@ -1399,10 +1419,10 @@ static bool compareWithVectorFnName(const VecDesc &LHS, StringRef S) { void TargetLibraryInfoImpl::addVectorizableFunctions(ArrayRef<VecDesc> Fns) { VectorDescs.insert(VectorDescs.end(), Fns.begin(), Fns.end()); - llvm::sort(VectorDescs.begin(), VectorDescs.end(), compareByScalarFnName); + llvm::sort(VectorDescs, compareByScalarFnName); ScalarDescs.insert(ScalarDescs.end(), Fns.begin(), Fns.end()); - llvm::sort(ScalarDescs.begin(), ScalarDescs.end(), compareByVectorFnName); + llvm::sort(ScalarDescs, compareByVectorFnName); } void TargetLibraryInfoImpl::addVectorizableFunctionsFromVecLib( diff --git a/lib/Analysis/TargetTransformInfo.cpp b/lib/Analysis/TargetTransformInfo.cpp index 7233a86e5daf..9151d46c6cce 100644 --- a/lib/Analysis/TargetTransformInfo.cpp +++ b/lib/Analysis/TargetTransformInfo.cpp @@ -268,6 +268,10 @@ bool TargetTransformInfo::enableInterleavedAccessVectorization() const { return TTIImpl->enableInterleavedAccessVectorization(); } +bool TargetTransformInfo::enableMaskedInterleavedAccessVectorization() const { + return TTIImpl->enableMaskedInterleavedAccessVectorization(); +} + bool TargetTransformInfo::isFPVectorizationPotentiallyUnsafe() const { return TTIImpl->isFPVectorizationPotentiallyUnsafe(); } @@ -384,6 +388,55 @@ unsigned TargetTransformInfo::getMaxInterleaveFactor(unsigned VF) const { return TTIImpl->getMaxInterleaveFactor(VF); } +TargetTransformInfo::OperandValueKind +TargetTransformInfo::getOperandInfo(Value *V, OperandValueProperties &OpProps) { + OperandValueKind OpInfo = OK_AnyValue; + OpProps = OP_None; + + if (auto *CI = dyn_cast<ConstantInt>(V)) { + if (CI->getValue().isPowerOf2()) + OpProps = OP_PowerOf2; + return OK_UniformConstantValue; + } + + // A broadcast shuffle creates a uniform value. + // TODO: Add support for non-zero index broadcasts. + // TODO: Add support for different source vector width. 
+ if (auto *ShuffleInst = dyn_cast<ShuffleVectorInst>(V)) + if (ShuffleInst->isZeroEltSplat()) + OpInfo = OK_UniformValue; + + const Value *Splat = getSplatValue(V); + + // Check for a splat of a constant or for a non uniform vector of constants + // and check if the constant(s) are all powers of two. + if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) { + OpInfo = OK_NonUniformConstantValue; + if (Splat) { + OpInfo = OK_UniformConstantValue; + if (auto *CI = dyn_cast<ConstantInt>(Splat)) + if (CI->getValue().isPowerOf2()) + OpProps = OP_PowerOf2; + } else if (auto *CDS = dyn_cast<ConstantDataSequential>(V)) { + OpProps = OP_PowerOf2; + for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) { + if (auto *CI = dyn_cast<ConstantInt>(CDS->getElementAsConstant(I))) + if (CI->getValue().isPowerOf2()) + continue; + OpProps = OP_None; + break; + } + } + } + + // Check for a splat of a uniform value. This is not loop aware, so return + // true only for the obviously uniform cases (argument, globalvalue) + if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat))) + OpInfo = OK_UniformValue; + + return OpInfo; +} + int TargetTransformInfo::getArithmeticInstrCost( unsigned Opcode, Type *Ty, OperandValueKind Opd1Info, OperandValueKind Opd2Info, OperandValueProperties Opd1PropInfo, @@ -472,9 +525,12 @@ int TargetTransformInfo::getGatherScatterOpCost(unsigned Opcode, Type *DataTy, int TargetTransformInfo::getInterleavedMemoryOpCost( unsigned Opcode, Type *VecTy, unsigned Factor, ArrayRef<unsigned> Indices, - unsigned Alignment, unsigned AddressSpace) const { + unsigned Alignment, unsigned AddressSpace, bool UseMaskForCond, + bool UseMaskForGaps) const { int Cost = TTIImpl->getInterleavedMemoryOpCost(Opcode, VecTy, Factor, Indices, - Alignment, AddressSpace); + Alignment, AddressSpace, + UseMaskForCond, + UseMaskForGaps); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } @@ -569,6 +625,12 @@ bool TargetTransformInfo::areInlineCompatible(const Function *Caller, return TTIImpl->areInlineCompatible(Caller, Callee); } +bool TargetTransformInfo::areFunctionArgsABICompatible( + const Function *Caller, const Function *Callee, + SmallPtrSetImpl<Argument *> &Args) const { + return TTIImpl->areFunctionArgsABICompatible(Caller, Callee, Args); +} + bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const { return TTIImpl->isIndexedLoadLegal(Mode, Ty); @@ -630,49 +692,6 @@ int TargetTransformInfo::getInstructionLatency(const Instruction *I) const { return TTIImpl->getInstructionLatency(I); } -static TargetTransformInfo::OperandValueKind -getOperandInfo(Value *V, TargetTransformInfo::OperandValueProperties &OpProps) { - TargetTransformInfo::OperandValueKind OpInfo = - TargetTransformInfo::OK_AnyValue; - OpProps = TargetTransformInfo::OP_None; - - if (auto *CI = dyn_cast<ConstantInt>(V)) { - if (CI->getValue().isPowerOf2()) - OpProps = TargetTransformInfo::OP_PowerOf2; - return TargetTransformInfo::OK_UniformConstantValue; - } - - const Value *Splat = getSplatValue(V); - - // Check for a splat of a constant or for a non uniform vector of constants - // and check if the constant(s) are all powers of two. 
- if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) { - OpInfo = TargetTransformInfo::OK_NonUniformConstantValue; - if (Splat) { - OpInfo = TargetTransformInfo::OK_UniformConstantValue; - if (auto *CI = dyn_cast<ConstantInt>(Splat)) - if (CI->getValue().isPowerOf2()) - OpProps = TargetTransformInfo::OP_PowerOf2; - } else if (auto *CDS = dyn_cast<ConstantDataSequential>(V)) { - OpProps = TargetTransformInfo::OP_PowerOf2; - for (unsigned I = 0, E = CDS->getNumElements(); I != E; ++I) { - if (auto *CI = dyn_cast<ConstantInt>(CDS->getElementAsConstant(I))) - if (CI->getValue().isPowerOf2()) - continue; - OpProps = TargetTransformInfo::OP_None; - break; - } - } - } - - // Check for a splat of a uniform value. This is not loop aware, so return - // true only for the obviously uniform cases (argument, globalvalue) - if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat))) - OpInfo = TargetTransformInfo::OK_UniformValue; - - return OpInfo; -} - static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft, unsigned Level) { // We don't need a shuffle if we just want to have element 0 in position 0 of @@ -1101,14 +1120,20 @@ int TargetTransformInfo::getInstructionThroughput(const Instruction *I) const { } case Instruction::ShuffleVector: { const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); - // TODO: Identify and add costs for insert/extract subvector, etc. + Type *Ty = Shuffle->getType(); + Type *SrcTy = Shuffle->getOperand(0)->getType(); + + // TODO: Identify and add costs for insert subvector, etc. + int SubIndex; + if (Shuffle->isExtractSubvectorMask(SubIndex)) + return TTIImpl->getShuffleCost(SK_ExtractSubvector, SrcTy, SubIndex, Ty); + if (Shuffle->changesLength()) return -1; if (Shuffle->isIdentity()) return 0; - Type *Ty = Shuffle->getType(); if (Shuffle->isReverse()) return TTIImpl->getShuffleCost(SK_Reverse, Ty, 0, nullptr); diff --git a/lib/Analysis/TypeBasedAliasAnalysis.cpp b/lib/Analysis/TypeBasedAliasAnalysis.cpp index 25a154edf4ac..83974da30a54 100644 --- a/lib/Analysis/TypeBasedAliasAnalysis.cpp +++ b/lib/Analysis/TypeBasedAliasAnalysis.cpp @@ -399,20 +399,20 @@ bool TypeBasedAAResult::pointsToConstantMemory(const MemoryLocation &Loc, } FunctionModRefBehavior -TypeBasedAAResult::getModRefBehavior(ImmutableCallSite CS) { +TypeBasedAAResult::getModRefBehavior(const CallBase *Call) { if (!EnableTBAA) - return AAResultBase::getModRefBehavior(CS); + return AAResultBase::getModRefBehavior(Call); FunctionModRefBehavior Min = FMRB_UnknownModRefBehavior; // If this is an "immutable" type, we can assume the call doesn't write // to memory. 
- if (const MDNode *M = CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa)) if ((!isStructPathTBAA(M) && TBAANode(M).isTypeImmutable()) || (isStructPathTBAA(M) && TBAAStructTagNode(M).isTypeImmutable())) Min = FMRB_OnlyReadsMemory; - return FunctionModRefBehavior(AAResultBase::getModRefBehavior(CS) & Min); + return FunctionModRefBehavior(AAResultBase::getModRefBehavior(Call) & Min); } FunctionModRefBehavior TypeBasedAAResult::getModRefBehavior(const Function *F) { @@ -420,33 +420,30 @@ FunctionModRefBehavior TypeBasedAAResult::getModRefBehavior(const Function *F) { return AAResultBase::getModRefBehavior(F); } -ModRefInfo TypeBasedAAResult::getModRefInfo(ImmutableCallSite CS, +ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call, const MemoryLocation &Loc) { if (!EnableTBAA) - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); if (const MDNode *L = Loc.AATags.TBAA) - if (const MDNode *M = - CS.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + if (const MDNode *M = Call->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(L, M)) return ModRefInfo::NoModRef; - return AAResultBase::getModRefInfo(CS, Loc); + return AAResultBase::getModRefInfo(Call, Loc); } -ModRefInfo TypeBasedAAResult::getModRefInfo(ImmutableCallSite CS1, - ImmutableCallSite CS2) { +ModRefInfo TypeBasedAAResult::getModRefInfo(const CallBase *Call1, + const CallBase *Call2) { if (!EnableTBAA) - return AAResultBase::getModRefInfo(CS1, CS2); + return AAResultBase::getModRefInfo(Call1, Call2); - if (const MDNode *M1 = - CS1.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) - if (const MDNode *M2 = - CS2.getInstruction()->getMetadata(LLVMContext::MD_tbaa)) + if (const MDNode *M1 = Call1->getMetadata(LLVMContext::MD_tbaa)) + if (const MDNode *M2 = Call2->getMetadata(LLVMContext::MD_tbaa)) if (!Aliases(M1, M2)) return ModRefInfo::NoModRef; - return AAResultBase::getModRefInfo(CS1, CS2); + return AAResultBase::getModRefInfo(Call1, Call2); } bool MDNode::isTBAAVtableAccess() const { diff --git a/lib/Analysis/TypeMetadataUtils.cpp b/lib/Analysis/TypeMetadataUtils.cpp index 6871e4887c9e..bd13a43b8d46 100644 --- a/lib/Analysis/TypeMetadataUtils.cpp +++ b/lib/Analysis/TypeMetadataUtils.cpp @@ -14,6 +14,7 @@ #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" @@ -22,11 +23,21 @@ using namespace llvm; // Search for virtual calls that call FPtr and add them to DevirtCalls. static void findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, - bool *HasNonCallUses, Value *FPtr, uint64_t Offset) { + bool *HasNonCallUses, Value *FPtr, uint64_t Offset, + const CallInst *CI, DominatorTree &DT) { for (const Use &U : FPtr->uses()) { - Value *User = U.getUser(); + Instruction *User = cast<Instruction>(U.getUser()); + // Ignore this instruction if it is not dominated by the type intrinsic + // being analyzed. Otherwise we may transform a call sharing the same + // vtable pointer incorrectly. Specifically, this situation can arise + // after indirect call promotion and inlining, where we may have uses + // of the vtable pointer guarded by a function pointer check, and a fallback + // indirect call. 
+ if (!DT.dominates(CI, User)) + continue; if (isa<BitCastInst>(User)) { - findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset); + findCallsAtConstantOffset(DevirtCalls, HasNonCallUses, User, Offset, CI, + DT); } else if (auto CI = dyn_cast<CallInst>(User)) { DevirtCalls.push_back({Offset, CI}); } else if (auto II = dyn_cast<InvokeInst>(User)) { @@ -38,23 +49,23 @@ findCallsAtConstantOffset(SmallVectorImpl<DevirtCallSite> &DevirtCalls, } // Search for virtual calls that load from VPtr and add them to DevirtCalls. -static void -findLoadCallsAtConstantOffset(const Module *M, - SmallVectorImpl<DevirtCallSite> &DevirtCalls, - Value *VPtr, int64_t Offset) { +static void findLoadCallsAtConstantOffset( + const Module *M, SmallVectorImpl<DevirtCallSite> &DevirtCalls, Value *VPtr, + int64_t Offset, const CallInst *CI, DominatorTree &DT) { for (const Use &U : VPtr->uses()) { Value *User = U.getUser(); if (isa<BitCastInst>(User)) { - findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset); + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset, CI, DT); } else if (isa<LoadInst>(User)) { - findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset); + findCallsAtConstantOffset(DevirtCalls, nullptr, User, Offset, CI, DT); } else if (auto GEP = dyn_cast<GetElementPtrInst>(User)) { // Take into account the GEP offset. if (VPtr == GEP->getPointerOperand() && GEP->hasAllConstantIndices()) { SmallVector<Value *, 8> Indices(GEP->op_begin() + 1, GEP->op_end()); int64_t GEPOffset = M->getDataLayout().getIndexedOffsetInType( GEP->getSourceElementType(), Indices); - findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset); + findLoadCallsAtConstantOffset(M, DevirtCalls, User, Offset + GEPOffset, + CI, DT); } } } @@ -62,7 +73,8 @@ findLoadCallsAtConstantOffset(const Module *M, void llvm::findDevirtualizableCallsForTypeTest( SmallVectorImpl<DevirtCallSite> &DevirtCalls, - SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI) { + SmallVectorImpl<CallInst *> &Assumes, const CallInst *CI, + DominatorTree &DT) { assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_test); const Module *M = CI->getParent()->getParent()->getParent(); @@ -79,15 +91,15 @@ void llvm::findDevirtualizableCallsForTypeTest( // If we found any, search for virtual calls based on %p and add them to // DevirtCalls. 
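Editorial sketch (not from the patch) of the caller-side effect of threading a DominatorTree through these helpers: callers of findDevirtualizableCallsForTypeTest now supply the dominator tree of the function containing the llvm.type.test call, so that vtable-pointer uses the intrinsic does not dominate are skipped as described in the comment above. TypeTestCall and the surrounding setup are assumed inputs.

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/TypeMetadataUtils.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Instructions.h"
using namespace llvm;

void collectDevirtCandidates(CallInst *TypeTestCall) {
  SmallVector<DevirtCallSite, 1> DevirtCalls;
  SmallVector<CallInst *, 1> Assumes;
  // Dominator tree of the function containing the type test.
  DominatorTree DT(*TypeTestCall->getFunction());
  findDevirtualizableCallsForTypeTest(DevirtCalls, Assumes, TypeTestCall, DT);
  // ... DevirtCalls now holds only calls found through uses dominated by the
  // type test ...
}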
if (!Assumes.empty()) - findLoadCallsAtConstantOffset(M, DevirtCalls, - CI->getArgOperand(0)->stripPointerCasts(), 0); + findLoadCallsAtConstantOffset( + M, DevirtCalls, CI->getArgOperand(0)->stripPointerCasts(), 0, CI, DT); } void llvm::findDevirtualizableCallsForTypeCheckedLoad( SmallVectorImpl<DevirtCallSite> &DevirtCalls, SmallVectorImpl<Instruction *> &LoadedPtrs, SmallVectorImpl<Instruction *> &Preds, bool &HasNonCallUses, - const CallInst *CI) { + const CallInst *CI, DominatorTree &DT) { assert(CI->getCalledFunction()->getIntrinsicID() == Intrinsic::type_checked_load); @@ -114,5 +126,5 @@ void llvm::findDevirtualizableCallsForTypeCheckedLoad( for (Value *LoadedPtr : LoadedPtrs) findCallsAtConstantOffset(DevirtCalls, &HasNonCallUses, LoadedPtr, - Offset->getZExtValue()); + Offset->getZExtValue(), CI, DT); } diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index 0ef39163bda3..0446426c0e66 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" @@ -118,14 +119,18 @@ struct Query { /// (all of which can call computeKnownBits), and so on. std::array<const Value *, MaxDepth> Excluded; + /// If true, it is safe to use metadata during simplification. + InstrInfoQuery IIQ; + unsigned NumExcluded = 0; Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE) {} + const DominatorTree *DT, bool UseInstrInfo, + OptimizationRemarkEmitter *ORE = nullptr) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {} Query(const Query &Q, const Value *NewExcl) - : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ), NumExcluded(Q.NumExcluded) { Excluded = Q.Excluded; Excluded[NumExcluded++] = NewExcl; @@ -165,9 +170,9 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE) { + OptimizationRemarkEmitter *ORE, bool UseInstrInfo) { ::computeKnownBits(V, Known, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); + Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); } static KnownBits computeKnownBits(const Value *V, unsigned Depth, @@ -177,15 +182,16 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE) { - return ::computeKnownBits(V, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); + OptimizationRemarkEmitter *ORE, + bool UseInstrInfo) { + return ::computeKnownBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); } bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { assert(LHS->getType() == RHS->getType() && "LHS and RHS should have the same type"); assert(LHS->getType()->isIntOrIntVectorTy() && 
@@ -201,8 +207,8 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); KnownBits LHSKnown(IT->getBitWidth()); KnownBits RHSKnown(IT->getBitWidth()); - computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT); - computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT); + computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); + computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue(); } @@ -222,69 +228,71 @@ static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, const Query &Q); bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL, - bool OrZero, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + bool OrZero, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::isKnownToBeAPowerOfTwo( + V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q); bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownNonZero(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::isKnownNonZero(V, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL, - unsigned Depth, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownBits Known = + computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); return Known.isNonNegative(); } bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { + const DominatorTree *DT, bool UseInstrInfo) { if (auto *CI = dyn_cast<ConstantInt>(V)) return CI->getValue().isStrictlyPositive(); // TODO: We'd doing two recursive queries here. We should factor this such // that only a single query is needed. 
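Editorial aside, not from the patch: the new trailing UseInstrInfo parameter on the queries above lets callers opt out of instruction-level facts (nsw/nuw flags, !range and !nonnull metadata) that may not be safe to reuse in their context. A minimal sketch, with all inputs assumed to be in scope as parameters:

#include "llvm/Analysis/AssumptionCache.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Support/KnownBits.h"
using namespace llvm;

// Ask whether V is known non-zero without trusting IR flags or metadata.
bool knownNonZeroIgnoringIRFlags(const Value *V, const DataLayout &DL,
                                 AssumptionCache &AC, const Instruction *CxtI,
                                 const DominatorTree &DT) {
  // Known bits computed with UseInstrInfo = false ignore nsw/nuw, !range, etc.
  KnownBits Known = computeKnownBits(V, DL, /*Depth=*/0, &AC, CxtI, &DT,
                                     /*ORE=*/nullptr, /*UseInstrInfo=*/false);
  return !Known.One.isNullValue() ||
         isKnownNonZero(V, DL, /*Depth=*/0, &AC, CxtI, &DT,
                        /*UseInstrInfo=*/false);
}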
- return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT) && - isKnownNonZero(V, DL, Depth, AC, CxtI, DT); + return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT, UseInstrInfo) && + isKnownNonZero(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); } bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + const DominatorTree *DT, bool UseInstrInfo) { + KnownBits Known = + computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); return Known.isNegative(); } static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q); bool llvm::isKnownNonEqual(const Value *V1, const Value *V2, - const DataLayout &DL, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownNonEqual(V1, V2, Query(DL, AC, - safeCxtI(V1, safeCxtI(V2, CxtI)), - DT)); + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + return ::isKnownNonEqual(V1, V2, + Query(DL, AC, safeCxtI(V1, safeCxtI(V2, CxtI)), DT, + UseInstrInfo, /*ORE=*/nullptr)); } static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, const Query &Q); bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask, - const DataLayout &DL, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, const DominatorTree *DT) { - return ::MaskedValueIsZero(V, Mask, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::MaskedValueIsZero( + V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, @@ -293,8 +301,9 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::ComputeNumSignBits(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::ComputeNumSignBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, @@ -965,7 +974,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, switch (I->getOpcode()) { default: break; case Instruction::Load: - if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range)) + if (MDNode *MD = + Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, Known); break; case Instruction::And: { @@ -1014,7 +1024,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, break; } case Instruction::Mul: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; @@ -1082,7 +1092,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // RHS from matchSelectPattern returns the negation part of abs pattern. // If the negate has an NSW flag we can assume the sign bit of the result // will be 0 because that makes abs(INT_MIN) undefined. 
- if (cast<Instruction>(RHS)->hasNoSignedWrap()) + if (Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS))) MaxHighZeros = 1; } @@ -1151,7 +1161,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, } case Instruction::Shl: { // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { APInt KZResult = KnownZero << ShiftAmt; KZResult.setLowBits(ShiftAmt); // Low bits known 0. @@ -1202,13 +1212,13 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, break; } case Instruction::Sub: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; } case Instruction::Add: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; @@ -1369,7 +1379,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, Known3.countMinTrailingZeros())); auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU); - if (OverflowOp && OverflowOp->hasNoSignedWrap()) { + if (OverflowOp && Q.IIQ.hasNoSignedWrap(OverflowOp)) { // If initial value of recurrence is nonnegative, and we are adding // a nonnegative number with nsw, the result can only be nonnegative // or poison value regardless of the number of times we execute the @@ -1442,7 +1452,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // If range metadata is attached to this call, set known bits from that, // and then intersect with known bits based on other properties of the // function. - if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range)) + if (MDNode *MD = + Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, Known); if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) { computeKnownBits(RV, Known2, Depth + 1, Q); @@ -1495,6 +1506,27 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // of bits which might be set provided by popcnt KnownOne2. break; } + case Intrinsic::fshr: + case Intrinsic::fshl: { + const APInt *SA; + if (!match(I->getOperand(2), m_APInt(SA))) + break; + + // Normalize to funnel shift left. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + KnownBits Known3(Known); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q); + + Known.Zero = + Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt); + Known.One = + Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt); + break; + } case Intrinsic::x86_sse42_crc32_64_64: Known.Zero.setBitsFrom(32); break; @@ -1722,7 +1754,8 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, // either the original power-of-two, a larger power-of-two or zero. 
if (match(V, m_Add(m_Value(X), m_Value(Y)))) { const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V); - if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { + if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) || + Q.IIQ.hasNoSignedWrap(VOBO)) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) @@ -1860,19 +1893,41 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) continue; + SmallVector<const User *, 4> WorkList; + SmallPtrSet<const User *, 4> Visited; for (auto *CmpU : U->users()) { - if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { - assert(BI->isConditional() && "uses a comparison!"); + assert(WorkList.empty() && "Should be!"); + if (Visited.insert(CmpU).second) + WorkList.push_back(CmpU); + + while (!WorkList.empty()) { + auto *Curr = WorkList.pop_back_val(); + + // If a user is an AND, add all its users to the work list. We only + // propagate "pred != null" condition through AND because it is only + // correct to assume that all conditions of AND are met in true branch. + // TODO: Support similar logic of OR and EQ predicate? + if (Pred == ICmpInst::ICMP_NE) + if (auto *BO = dyn_cast<BinaryOperator>(Curr)) + if (BO->getOpcode() == Instruction::And) { + for (auto *BOU : BO->users()) + if (Visited.insert(BOU).second) + WorkList.push_back(BOU); + continue; + } - BasicBlock *NonNullSuccessor = - BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); - BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); - if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) + if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) { + assert(BI->isConditional() && "uses a comparison!"); + + BasicBlock *NonNullSuccessor = + BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); + BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); + if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) + return true; + } else if (Pred == ICmpInst::ICMP_NE && isGuard(Curr) && + DT->dominates(cast<Instruction>(Curr), CtxI)) { return true; - } else if (Pred == ICmpInst::ICMP_NE && - match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && - DT->dominates(cast<Instruction>(CmpU), CtxI)) { - return true; + } } } } @@ -1937,7 +1992,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } if (auto *I = dyn_cast<Instruction>(V)) { - if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { + if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) { // If the possible ranges don't contain zero, then the value is // definitely non-zero. if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { @@ -1965,13 +2020,13 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { // A Load tagged with nonnull metadata is never null. 
if (const LoadInst *LI = dyn_cast<LoadInst>(V)) - if (LI->getMetadata(LLVMContext::MD_nonnull)) + if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull)) return true; - if (auto CS = ImmutableCallSite(V)) { - if (CS.isReturnNonNull()) + if (const auto *Call = dyn_cast<CallBase>(V)) { + if (Call->isReturnNonNull()) return true; - if (const auto *RP = getArgumentAliasingToReturnedPointer(CS)) + if (const auto *RP = getArgumentAliasingToReturnedPointer(Call)) return isKnownNonZero(RP, Depth, Q); } } @@ -2003,7 +2058,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (match(V, m_Shl(m_Value(X), m_Value(Y)))) { // shl nuw can't remove any non-zero bits. const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); - if (BO->hasNoUnsignedWrap()) + if (Q.IIQ.hasNoUnsignedWrap(BO)) return isKnownNonZero(X, Depth, Q); KnownBits Known(BitWidth); @@ -2078,7 +2133,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. - if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && + if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) && isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q)) return true; } @@ -2100,7 +2155,8 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) { if (!C->isZero() && !C->isNegative()) { ConstantInt *X; - if ((match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || + if (Q.IIQ.UseInstrInfo && + (match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) && !X->isNegative()) return true; @@ -2174,6 +2230,36 @@ bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, return Mask.isSubsetOf(Known.Zero); } +// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow). +// Returns the input and lower/upper bounds. +static bool isSignedMinMaxClamp(const Value *Select, const Value *&In, + const APInt *&CLow, const APInt *&CHigh) { + assert(isa<Operator>(Select) && + cast<Operator>(Select)->getOpcode() == Instruction::Select && + "Input should be a Select!"); + + const Value *LHS, *RHS, *LHS2, *RHS2; + SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor; + if (SPF != SPF_SMAX && SPF != SPF_SMIN) + return false; + + if (!match(RHS, m_APInt(CLow))) + return false; + + SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor; + if (getInverseMinMaxFlavor(SPF) != SPF2) + return false; + + if (!match(RHS2, m_APInt(CHigh))) + return false; + + if (SPF == SPF_SMIN) + std::swap(CLow, CHigh); + + In = LHS2; + return CLow->sle(*CHigh); +} + /// For vector constants, loop over the elements and find the constant with the /// minimum number of sign bits. Return 0 if the value is not a vector constant /// or if any element was not analyzed; otherwise, return the count for the @@ -2335,11 +2421,19 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, } break; - case Instruction::Select: + case Instruction::Select: { + // If we have a clamp pattern, we know that the number of sign bits will be + // the minimum of the clamp min/max range. 
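// (Editorial worked example, not part of the patch.) Concrete instance of the
// clamp rule: for an i32 value clamped into [-100, 100], both bounds fit in
// 8 signed bits, so CLow->getNumSignBits() and CHigh->getNumSignBits() are
// each 32 - 8 + 1 = 25, and the select is known to carry at least 25 sign
// bits.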
+ const Value *X; + const APInt *CLow, *CHigh; + if (isSignedMinMaxClamp(U, X, CLow, CHigh)) + return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits()); + Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); if (Tmp == 1) break; Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); return std::min(Tmp, Tmp2); + } case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output @@ -2437,6 +2531,44 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // valid for all elements of the vector (for example if vector is sign // extended, shifted, etc). return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + + case Instruction::ShuffleVector: { + // TODO: This is copied almost directly from the SelectionDAG version of + // ComputeNumSignBits. It would be better if we could share common + // code. If not, make sure that changes are translated to the DAG. + + // Collect the minimum number of sign bits that are shared by every vector + // element referenced by the shuffle. + auto *Shuf = cast<ShuffleVectorInst>(U); + int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); + int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); + APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); + for (int i = 0; i != NumMaskElts; ++i) { + int M = Shuf->getMaskValue(i); + assert(M < NumElts * 2 && "Invalid shuffle mask constant"); + // For undef elements, we don't know anything about the common state of + // the shuffle result. + if (M == -1) + return 1; + if (M < NumElts) + DemandedLHS.setBit(M % NumElts); + else + DemandedRHS.setBit(M % NumElts); + } + Tmp = std::numeric_limits<unsigned>::max(); + if (!!DemandedLHS) + Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q); + if (!!DemandedRHS) { + Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q); + Tmp = std::min(Tmp, Tmp2); + } + // If we don't know anything, early out and try computeKnownBits fall-back. + if (Tmp == 1) + break; + assert(Tmp <= V->getType()->getScalarSizeInBits() && + "Failed to determine minimum sign bits"); + return Tmp; + } } // Finally, if we can prove that the top bits of the result are 0's or 1's, @@ -2722,6 +2854,7 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, break; // sqrt(-0.0) = -0.0, no other negative results are possible. 
case Intrinsic::sqrt: + case Intrinsic::canonicalize: return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); // fabs(x) != -0.0 case Intrinsic::fabs: @@ -2817,11 +2950,20 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, default: break; case Intrinsic::maxnum: + return (isKnownNeverNaN(I->getOperand(0), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, + SignBitOnly, Depth + 1)) || + (isKnownNeverNaN(I->getOperand(1), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, + SignBitOnly, Depth + 1)); + + case Intrinsic::maximum: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1) || cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, Depth + 1); case Intrinsic::minnum: + case Intrinsic::minimum: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1) && cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, @@ -2882,7 +3024,8 @@ bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); } -bool llvm::isKnownNeverNaN(const Value *V) { +bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, + unsigned Depth) { assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); // If we're told that NaNs won't happen, assume they won't. @@ -2890,13 +3033,60 @@ bool llvm::isKnownNeverNaN(const Value *V) { if (FPMathOp->hasNoNaNs()) return true; - // TODO: Handle instructions and potentially recurse like other 'isKnown' - // functions. For example, the result of sitofp is never NaN. - // Handle scalar constants. if (auto *CFP = dyn_cast<ConstantFP>(V)) return !CFP->isNaN(); + if (Depth == MaxDepth) + return false; + + if (auto *Inst = dyn_cast<Instruction>(V)) { + switch (Inst->getOpcode()) { + case Instruction::FAdd: + case Instruction::FMul: + case Instruction::FSub: + case Instruction::FDiv: + case Instruction::FRem: { + // TODO: Need isKnownNeverInfinity + return false; + } + case Instruction::Select: { + return isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && + isKnownNeverNaN(Inst->getOperand(2), TLI, Depth + 1); + } + case Instruction::SIToFP: + case Instruction::UIToFP: + return true; + case Instruction::FPTrunc: + case Instruction::FPExt: + return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1); + default: + break; + } + } + + if (const auto *II = dyn_cast<IntrinsicInst>(V)) { + switch (II->getIntrinsicID()) { + case Intrinsic::canonicalize: + case Intrinsic::fabs: + case Intrinsic::copysign: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::floor: + case Intrinsic::ceil: + case Intrinsic::trunc: + case Intrinsic::rint: + case Intrinsic::nearbyint: + case Intrinsic::round: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1); + case Intrinsic::sqrt: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(II->getArgOperand(0), TLI); + default: + return false; + } + } + // Bail out for constant expressions, but try to handle vector constants. if (!V->getType()->isVectorTy() || !isa<Constant>(V)) return false; @@ -2917,62 +3107,92 @@ bool llvm::isKnownNeverNaN(const Value *V) { return true; } -/// If the specified value can be set by repeating the same byte in memory, -/// return the i8 value that it is represented with. This is -/// true for all i8 values obviously, but is also true for i32 0, i32 -1, -/// i16 0xF0F0, double 0.0 etc. 
If the value can't be handled with a repeated -/// byte store (e.g. i16 0x1234), return null. Value *llvm::isBytewiseValue(Value *V) { + // All byte-wide stores are splatable, even of arbitrary variables. - if (V->getType()->isIntegerTy(8)) return V; + if (V->getType()->isIntegerTy(8)) + return V; + + LLVMContext &Ctx = V->getContext(); + + // Undef don't care. + auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx)); + if (isa<UndefValue>(V)) + return UndefInt8; + + Constant *C = dyn_cast<Constant>(V); + if (!C) { + // Conceptually, we could handle things like: + // %a = zext i8 %X to i16 + // %b = shl i16 %a, 8 + // %c = or i16 %a, %b + // but until there is an example that actually needs this, it doesn't seem + // worth worrying about. + return nullptr; + } // Handle 'null' ConstantArrayZero etc. - if (Constant *C = dyn_cast<Constant>(V)) - if (C->isNullValue()) - return Constant::getNullValue(Type::getInt8Ty(V->getContext())); + if (C->isNullValue()) + return Constant::getNullValue(Type::getInt8Ty(Ctx)); - // Constant float and double values can be handled as integer values if the + // Constant floating-point values can be handled as integer values if the // corresponding integer value is "byteable". An important case is 0.0. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { - if (CFP->getType()->isFloatTy()) - V = ConstantExpr::getBitCast(CFP, Type::getInt32Ty(V->getContext())); - if (CFP->getType()->isDoubleTy()) - V = ConstantExpr::getBitCast(CFP, Type::getInt64Ty(V->getContext())); + if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { + Type *Ty = nullptr; + if (CFP->getType()->isHalfTy()) + Ty = Type::getInt16Ty(Ctx); + else if (CFP->getType()->isFloatTy()) + Ty = Type::getInt32Ty(Ctx); + else if (CFP->getType()->isDoubleTy()) + Ty = Type::getInt64Ty(Ctx); // Don't handle long double formats, which have strange constraints. + return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty)) : nullptr; } // We can handle constant integers that are multiple of 8 bits. - if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) { if (CI->getBitWidth() % 8 == 0) { assert(CI->getBitWidth() > 8 && "8 bits should be handled above!"); - if (!CI->getValue().isSplat(8)) return nullptr; - return ConstantInt::get(V->getContext(), CI->getValue().trunc(8)); + return ConstantInt::get(Ctx, CI->getValue().trunc(8)); } } - // A ConstantDataArray/Vector is splatable if all its members are equal and - // also splatable. - if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(V)) { - Value *Elt = CA->getElementAsConstant(0); - Value *Val = isBytewiseValue(Elt); - if (!Val) + auto Merge = [&](Value *LHS, Value *RHS) -> Value * { + if (LHS == RHS) + return LHS; + if (!LHS || !RHS) return nullptr; + if (LHS == UndefInt8) + return RHS; + if (RHS == UndefInt8) + return LHS; + return nullptr; + }; - for (unsigned I = 1, E = CA->getNumElements(); I != E; ++I) - if (CA->getElementAsConstant(I) != Elt) + if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I))))) return nullptr; + return Val; + } + if (isa<ConstantVector>(C)) { + Constant *Splat = cast<ConstantVector>(C)->getSplatValue(); + return Splat ? 
isBytewiseValue(Splat) : nullptr; + } + + if (isa<ConstantArray>(C) || isa<ConstantStruct>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I))))) + return nullptr; return Val; } - // Conceptually, we could handle things like: - // %a = zext i8 %X to i16 - // %b = shl i16 %a, 8 - // %c = or i16 %a, %b - // but until there is an example that actually needs this, it doesn't seem - // worth worrying about. + // Don't try to handle the handful of other constants. return nullptr; } @@ -3169,7 +3389,14 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, if (!GEP->accumulateConstantOffset(DL, GEPOffset)) break; - ByteOffset += GEPOffset.getSExtValue(); + APInt OrigByteOffset(ByteOffset); + ByteOffset += GEPOffset.sextOrTrunc(ByteOffset.getBitWidth()); + if (ByteOffset.getMinSignedBits() > 64) { + // Stop traversal if the pointer offset wouldn't fit into int64_t + // (this should be removed if Offset is updated to an APInt) + ByteOffset = OrigByteOffset; + break; + } Ptr = GEP->getPointerOperand(); } else if (Operator::getOpcode(Ptr) == Instruction::BitCast || @@ -3397,21 +3624,21 @@ uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { return Len == ~0ULL ? 1 : Len; } -const Value *llvm::getArgumentAliasingToReturnedPointer(ImmutableCallSite CS) { - assert(CS && - "getArgumentAliasingToReturnedPointer only works on nonnull CallSite"); - if (const Value *RV = CS.getReturnedArgOperand()) +const Value *llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call) { + assert(Call && + "getArgumentAliasingToReturnedPointer only works on nonnull calls"); + if (const Value *RV = Call->getReturnedArgOperand()) return RV; // This can be used only as a aliasing property. - if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) - return CS.getArgOperand(0); + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call)) + return Call->getArgOperand(0); return nullptr; } bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( - ImmutableCallSite CS) { - return CS.getIntrinsicID() == Intrinsic::launder_invariant_group || - CS.getIntrinsicID() == Intrinsic::strip_invariant_group; + const CallBase *Call) { + return Call->getIntrinsicID() == Intrinsic::launder_invariant_group || + Call->getIntrinsicID() == Intrinsic::strip_invariant_group; } /// \p PN defines a loop-variant pointer to an object. Check if the @@ -3459,7 +3686,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // An alloca can't be further simplified. return V; } else { - if (auto CS = CallSite(V)) { + if (auto *Call = dyn_cast<CallBase>(V)) { // CaptureTracking can know about special capturing properties of some // intrinsics like launder.invariant.group, that can't be expressed with // the attributes, but have properties like returning aliasing pointer. @@ -3469,7 +3696,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // because it should be in sync with CaptureTracking. Not using it may // cause weird miscompilations where 2 aliasing pointers are assumed to // noalias. 
- if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { + if (auto *RP = getArgumentAliasingToReturnedPointer(Call)) { V = RP; continue; } @@ -3602,8 +3829,7 @@ bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); if (!II) return false; - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) + if (!II->isLifetimeStartOrEnd()) return false; } return true; @@ -3700,12 +3926,10 @@ bool llvm::mayBeMemoryDependent(const Instruction &I) { return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I); } -OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { +OverflowResult llvm::computeOverflowForUnsignedMul( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { // Multiplying n * m significant bits yields a result of n + m significant // bits. If the total number of significant bits does not exceed the // result bit width (minus 1), there is no overflow. @@ -3715,8 +3939,10 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); KnownBits LHSKnown(BitWidth); KnownBits RHSKnown(BitWidth); - computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); - computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); + computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr, + UseInstrInfo); + computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr, + UseInstrInfo); // Note that underestimating the number of zero bits gives a more // conservative answer. unsigned ZeroBits = LHSKnown.countMinLeadingZeros() + @@ -3747,12 +3973,11 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, return OverflowResult::MayOverflow; } -OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { +OverflowResult +llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS, + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { // Multiplying n * m significant bits yields a result of n + m significant // bits. If the total number of significant bits does not exceed the // result bit width (minus 1), there is no overflow. @@ -3781,33 +4006,33 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, // product is exactly the minimum negative number. // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 // For simplicity we just check if at least one side is not negative. 
- KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) return OverflowResult::NeverOverflows; } return OverflowResult::MayOverflow; } -OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +OverflowResult llvm::computeOverflowForUnsignedAdd( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNegative() && RHSKnown.isNegative()) { // The sign bit is set in both cases: this MUST overflow. - // Create a simple add instruction, and insert it into the struct. return OverflowResult::AlwaysOverflows; } if (LHSKnown.isNonNegative() && RHSKnown.isNonNegative()) { // The sign bit is clear in both cases: this CANNOT overflow. - // Create a simple add instruction, and insert it into the struct. return OverflowResult::NeverOverflows; } } @@ -3924,11 +4149,18 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - // If the LHS is negative and the RHS is non-negative, no unsigned wrap. KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); - if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) - return OverflowResult::NeverOverflows; + if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + + // If the LHS is negative and the RHS is non-negative, no unsigned wrap. + if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) + return OverflowResult::NeverOverflows; + + // If the LHS is non-negative and the RHS negative, we always wrap. 
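// Illustrative sketch (not from this patch): why "LHS non-negative, RHS
// negative" always wraps for an unsigned subtract. A clear sign bit bounds
// LHS below 2^(BitWidth-1); a set sign bit bounds RHS at or above
// 2^(BitWidth-1), so LHS < RHS and the unsigned difference must wrap.
// Checked here for i8 in plain C++:
#include <cassert>
#include <cstdint>
int main() {
  uint8_t LHS = 0x20; // sign bit clear: LHS <= 0x7f
  uint8_t RHS = 0x80; // sign bit set:   RHS >= 0x80
  assert(LHS < RHS);  // hence the subtraction wraps:
  assert(static_cast<uint8_t>(LHS - RHS) == 0xa0); // 32 - 128 wraps to 160
}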
+ if (LHSKnown.isNonNegative() && RHSKnown.isNegative()) + return OverflowResult::AlwaysOverflows; + } return OverflowResult::MayOverflow; } @@ -4241,12 +4473,34 @@ static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) { if (auto *C = dyn_cast<ConstantFP>(V)) return !C->isNaN(); + + if (auto *C = dyn_cast<ConstantDataVector>(V)) { + if (!C->getElementType()->isFloatingPointTy()) + return false; + for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) { + if (C->getElementAsAPFloat(I).isNaN()) + return false; + } + return true; + } + return false; } static bool isKnownNonZero(const Value *V) { if (auto *C = dyn_cast<ConstantFP>(V)) return !C->isZero(); + + if (auto *C = dyn_cast<ConstantDataVector>(V)) { + if (!C->getElementType()->isFloatingPointTy()) + return false; + for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) { + if (C->getElementAsAPFloat(I).isZero()) + return false; + } + return true; + } + return false; } @@ -4538,6 +4792,27 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS, unsigned Depth) { + if (CmpInst::isFPPredicate(Pred)) { + // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one + // 0.0 operand, set the compare's 0.0 operands to that same value for the + // purpose of identifying min/max. Disregard vector constants with undefined + // elements because those can not be back-propagated for analysis. + Value *OutputZeroVal = nullptr; + if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && + !cast<Constant>(TrueVal)->containsUndefElement()) + OutputZeroVal = TrueVal; + else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && + !cast<Constant>(FalseVal)->containsUndefElement()) + OutputZeroVal = FalseVal; + + if (OutputZeroVal) { + if (match(CmpLHS, m_AnyZeroFP())) + CmpLHS = OutputZeroVal; + if (match(CmpRHS, m_AnyZeroFP())) + CmpRHS = OutputZeroVal; + } + } + LHS = CmpLHS; RHS = CmpRHS; @@ -4967,21 +5242,16 @@ static bool isMatchingOps(const Value *ALHS, const Value *ARHS, return IsMatchingOps || IsSwappedOps; } -/// Return true if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS BRHS" is -/// true. Return false if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS -/// BRHS" is false. Otherwise, return None if we can't infer anything. +/// Return true if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is true. +/// Return false if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is false. +/// Otherwise, return None if we can't infer anything. static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, - const Value *ALHS, - const Value *ARHS, CmpInst::Predicate BPred, - const Value *BLHS, - const Value *BRHS, - bool IsSwappedOps) { - // Canonicalize the operands so they're matching. - if (IsSwappedOps) { - std::swap(BLHS, BRHS); + bool AreSwappedOps) { + // Canonicalize the predicate as if the operands were not commuted. + if (AreSwappedOps) BPred = ICmpInst::getSwappedPredicate(BPred); - } + if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred)) return true; if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred)) @@ -4990,15 +5260,14 @@ static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, return None; } -/// Return true if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS C2" is -/// true. Return false if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS -/// C2" is false. Otherwise, return None if we can't infer anything. 
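// Illustrative sketch (not from this patch): how isImpliedCondMatchingOperands
// above uses the swapped-predicate canonicalization. Given "icmp sgt %x, %y"
// and "icmp ne %y, %x" (same operands, swapped), swapping the second
// predicate reduces the query to a predicate-only question answered by the
// CmpInst helpers declared in llvm/IR/InstrTypes.h.
#include "llvm/IR/InstrTypes.h"
using namespace llvm;
static bool sgtImpliesNeOnSwappedOperands() {
  CmpInst::Predicate APred = CmpInst::ICMP_SGT;             // icmp sgt %x, %y
  CmpInst::Predicate BPred = CmpInst::ICMP_NE;              // icmp ne  %y, %x
  BPred = CmpInst::getSwappedPredicate(BPred);              // ne is symmetric
  return CmpInst::isImpliedTrueByMatchingCmp(APred, BPred); // x > y => x != y
}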
+/// Return true if "icmp APred X, C1" implies "icmp BPred X, C2" is true. +/// Return false if "icmp APred X, C1" implies "icmp BPred X, C2" is false. +/// Otherwise, return None if we can't infer anything. static Optional<bool> -isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const Value *ALHS, +isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const ConstantInt *C1, CmpInst::Predicate BPred, - const Value *BLHS, const ConstantInt *C2) { - assert(ALHS == BLHS && "LHS operands must match."); + const ConstantInt *C2) { ConstantRange DomCR = ConstantRange::makeExactICmpRegion(APred, C1->getValue()); ConstantRange CR = @@ -5030,10 +5299,10 @@ static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, ICmpInst::Predicate BPred = RHS->getPredicate(); // Can we infer anything when the two compares have matching operands? - bool IsSwappedOps; - if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, IsSwappedOps)) { + bool AreSwappedOps; + if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, AreSwappedOps)) { if (Optional<bool> Implication = isImpliedCondMatchingOperands( - APred, ALHS, ARHS, BPred, BLHS, BRHS, IsSwappedOps)) + APred, BPred, AreSwappedOps)) return Implication; // No amount of additional analysis will infer the second condition, so // early exit. @@ -5044,8 +5313,7 @@ static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, // constants (not necessarily matching)? if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) { if (Optional<bool> Implication = isImpliedCondMatchingImmOperands( - APred, ALHS, cast<ConstantInt>(ARHS), BPred, BLHS, - cast<ConstantInt>(BRHS))) + APred, cast<ConstantInt>(ARHS), BPred, cast<ConstantInt>(BRHS))) return Implication; // No amount of additional analysis will infer the second condition, so // early exit. @@ -5130,3 +5398,35 @@ Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, } return None; } + +Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond, + const Instruction *ContextI, + const DataLayout &DL) { + assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool"); + if (!ContextI || !ContextI->getParent()) + return None; + + // TODO: This is a poor/cheap way to determine dominance. Should we use a + // dominator tree (eg, from a SimplifyQuery) instead? + const BasicBlock *ContextBB = ContextI->getParent(); + const BasicBlock *PredBB = ContextBB->getSinglePredecessor(); + if (!PredBB) + return None; + + // We need a conditional branch in the predecessor. + Value *PredCond; + BasicBlock *TrueBB, *FalseBB; + if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB))) + return None; + + // The branch should get simplified. Don't bother simplifying this condition. + if (TrueBB == FalseBB) + return None; + + assert((TrueBB == ContextBB || FalseBB == ContextBB) && + "Predecessor block does not point to successor?"); + + // Is this condition implied by the predecessor condition? 
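// Illustrative sketch (not from this patch): the constant-range reasoning in
// isImpliedCondMatchingImmOperands above, reduced to one concrete query:
// "icmp ult %x, 8" implies "icmp ult %x, 16" because every value satisfying
// the first compare, i.e. [0, 8), also lies in [0, 16). The in-tree code
// phrases this with range intersection/difference; containment shows the
// same idea.
#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/InstrTypes.h"
using namespace llvm;
static bool ultEightImpliesUltSixteen() {
  ConstantRange DomCR = ConstantRange::makeExactICmpRegion(
      CmpInst::ICMP_ULT, APInt(32, 8));  // values where "ult 8" holds: [0, 8)
  ConstantRange CR = ConstantRange::makeExactICmpRegion(
      CmpInst::ICMP_ULT, APInt(32, 16)); // values where "ult 16" holds: [0, 16)
  return CR.contains(DomCR);             // true: the first implies the second
}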
+ bool CondIsTrue = TrueBB == ContextBB; + return isImpliedCondition(PredCond, Cond, DL, CondIsTrue); +} diff --git a/lib/Analysis/VectorUtils.cpp b/lib/Analysis/VectorUtils.cpp index d73d24736439..5656a19d7e0d 100644 --- a/lib/Analysis/VectorUtils.cpp +++ b/lib/Analysis/VectorUtils.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/EquivalenceClasses.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -25,16 +26,30 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Value.h" +#define DEBUG_TYPE "vectorutils" + using namespace llvm; using namespace llvm::PatternMatch; -/// Identify if the intrinsic is trivially vectorizable. -/// This method returns true if the intrinsic's argument types are all -/// scalars for the scalar form of the intrinsic and all vectors for -/// the vector form of the intrinsic. +/// Maximum factor for an interleaved memory access. +static cl::opt<unsigned> MaxInterleaveGroupFactor( + "max-interleave-group-factor", cl::Hidden, + cl::desc("Maximum factor for an interleaved access group (default = 8)"), + cl::init(8)); + +/// Return true if all of the intrinsic's arguments and return type are scalars +/// for the scalar form of the intrinsic and vectors for the vector form of the +/// intrinsic. bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { switch (ID) { - case Intrinsic::sqrt: + case Intrinsic::bswap: // Begin integer bit-manipulation. + case Intrinsic::bitreverse: + case Intrinsic::ctpop: + case Intrinsic::ctlz: + case Intrinsic::cttz: + case Intrinsic::fshl: + case Intrinsic::fshr: + case Intrinsic::sqrt: // Begin floating-point. case Intrinsic::sin: case Intrinsic::cos: case Intrinsic::exp: @@ -45,6 +60,8 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::fabs: case Intrinsic::minnum: case Intrinsic::maxnum: + case Intrinsic::minimum: + case Intrinsic::maximum: case Intrinsic::copysign: case Intrinsic::floor: case Intrinsic::ceil: @@ -52,15 +69,15 @@ bool llvm::isTriviallyVectorizable(Intrinsic::ID ID) { case Intrinsic::rint: case Intrinsic::nearbyint: case Intrinsic::round: - case Intrinsic::bswap: - case Intrinsic::bitreverse: - case Intrinsic::ctpop: case Intrinsic::pow: case Intrinsic::fma: case Intrinsic::fmuladd: - case Intrinsic::ctlz: - case Intrinsic::cttz: case Intrinsic::powi: + case Intrinsic::canonicalize: + case Intrinsic::sadd_sat: + case Intrinsic::ssub_sat: + case Intrinsic::uadd_sat: + case Intrinsic::usub_sat: return true; default: return false; @@ -270,9 +287,10 @@ Value *llvm::findScalarElement(Value *V, unsigned EltNo) { } // Extract a value from a vector add operation with a constant zero. - Value *Val = nullptr; Constant *Con = nullptr; - if (match(V, m_Add(m_Value(Val), m_Constant(Con)))) - if (Constant *Elt = Con->getAggregateElement(EltNo)) + // TODO: Use getBinOpIdentity() to generalize this. + Value *Val; Constant *C; + if (match(V, m_Add(m_Value(Val), m_Constant(C)))) + if (Constant *Elt = C->getAggregateElement(EltNo)) if (Elt->isNullValue()) return findScalarElement(Val, EltNo); @@ -450,16 +468,100 @@ llvm::computeMinimumValueSizes(ArrayRef<BasicBlock *> Blocks, DemandedBits &DB, return MinBWs; } +/// Add all access groups in @p AccGroups to @p List. 
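// Illustrative sketch (not from this patch): what "trivially vectorizable"
// means for the intrinsics accepted above. The vector form is obtained by
// substituting vector types into the scalar signature, e.g. llvm.sadd.sat.i32
// becomes llvm.sadd.sat.v4i32. The helper below is hypothetical; it assumes
// the caller already has a Module, and VF is the vectorization factor.
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
using namespace llvm;
static Function *getVectorSaddSat(Module &M, unsigned VF) {
  Type *VecTy = VectorType::get(Type::getInt32Ty(M.getContext()), VF);
  // Overloaded intrinsics are mangled on the vectorized operand type:
  //   declare <VF x i32> @llvm.sadd.sat.vVFi32(<VF x i32>, <VF x i32>)
  return Intrinsic::getDeclaration(&M, Intrinsic::sadd_sat, {VecTy});
}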
+template <typename ListT> +static void addToAccessGroupList(ListT &List, MDNode *AccGroups) { + // Interpret an access group as a list containing itself. + if (AccGroups->getNumOperands() == 0) { + assert(isValidAsAccessGroup(AccGroups) && "Node must be an access group"); + List.insert(AccGroups); + return; + } + + for (auto &AccGroupListOp : AccGroups->operands()) { + auto *Item = cast<MDNode>(AccGroupListOp.get()); + assert(isValidAsAccessGroup(Item) && "List item must be an access group"); + List.insert(Item); + } +} + +MDNode *llvm::uniteAccessGroups(MDNode *AccGroups1, MDNode *AccGroups2) { + if (!AccGroups1) + return AccGroups2; + if (!AccGroups2) + return AccGroups1; + if (AccGroups1 == AccGroups2) + return AccGroups1; + + SmallSetVector<Metadata *, 4> Union; + addToAccessGroupList(Union, AccGroups1); + addToAccessGroupList(Union, AccGroups2); + + if (Union.size() == 0) + return nullptr; + if (Union.size() == 1) + return cast<MDNode>(Union.front()); + + LLVMContext &Ctx = AccGroups1->getContext(); + return MDNode::get(Ctx, Union.getArrayRef()); +} + +MDNode *llvm::intersectAccessGroups(const Instruction *Inst1, + const Instruction *Inst2) { + bool MayAccessMem1 = Inst1->mayReadOrWriteMemory(); + bool MayAccessMem2 = Inst2->mayReadOrWriteMemory(); + + if (!MayAccessMem1 && !MayAccessMem2) + return nullptr; + if (!MayAccessMem1) + return Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MayAccessMem2) + return Inst1->getMetadata(LLVMContext::MD_access_group); + + MDNode *MD1 = Inst1->getMetadata(LLVMContext::MD_access_group); + MDNode *MD2 = Inst2->getMetadata(LLVMContext::MD_access_group); + if (!MD1 || !MD2) + return nullptr; + if (MD1 == MD2) + return MD1; + + // Use set for scalable 'contains' check. + SmallPtrSet<Metadata *, 4> AccGroupSet2; + addToAccessGroupList(AccGroupSet2, MD2); + + SmallVector<Metadata *, 4> Intersection; + if (MD1->getNumOperands() == 0) { + assert(isValidAsAccessGroup(MD1) && "Node must be an access group"); + if (AccGroupSet2.count(MD1)) + Intersection.push_back(MD1); + } else { + for (const MDOperand &Node : MD1->operands()) { + auto *Item = cast<MDNode>(Node.get()); + assert(isValidAsAccessGroup(Item) && "List item must be an access group"); + if (AccGroupSet2.count(Item)) + Intersection.push_back(Item); + } + } + + if (Intersection.size() == 0) + return nullptr; + if (Intersection.size() == 1) + return cast<MDNode>(Intersection.front()); + + LLVMContext &Ctx = Inst1->getContext(); + return MDNode::get(Ctx, Intersection); +} + /// \returns \p I after propagating metadata from \p VL. 
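// Illustrative sketch (not from this patch): the intended semantics of the
// two access-group helpers above. An access group is a distinct MDNode with
// no operands; a list of groups is an MDNode whose operands are such nodes.
// Assumes the declarations this patch adds alongside llvm/Analysis/VectorUtils.h.
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Metadata.h"
using namespace llvm;
static void accessGroupExample(LLVMContext &Ctx) {
  MDNode *G1 = MDNode::getDistinct(Ctx, {}); // !0 = distinct !{}
  MDNode *G2 = MDNode::getDistinct(Ctx, {}); // !1 = distinct !{}
  // Combining instructions from either group: the union is the list !{!0, !1}.
  MDNode *Union = uniteAccessGroups(G1, G2);
  (void)Union; // Union->getNumOperands() == 2
  // Instructions that share no group: intersectAccessGroups returns null, so
  // the combined instruction carries no !llvm.access.group at all.
}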
Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { Instruction *I0 = cast<Instruction>(VL[0]); SmallVector<std::pair<unsigned, MDNode *>, 4> Metadata; I0->getAllMetadataOtherThanDebugLoc(Metadata); - for (auto Kind : - {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, - LLVMContext::MD_noalias, LLVMContext::MD_fpmath, - LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load}) { + for (auto Kind : {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_fpmath, + LLVMContext::MD_nontemporal, LLVMContext::MD_invariant_load, + LLVMContext::MD_access_group}) { MDNode *MD = I0->getMetadata(Kind); for (int J = 1, E = VL.size(); MD && J != E; ++J) { @@ -480,6 +582,9 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { case LLVMContext::MD_invariant_load: MD = MDNode::intersect(MD, IMD); break; + case LLVMContext::MD_access_group: + MD = intersectAccessGroups(Inst, IJ); + break; default: llvm_unreachable("unhandled metadata"); } @@ -491,6 +596,36 @@ Instruction *llvm::propagateMetadata(Instruction *Inst, ArrayRef<Value *> VL) { return Inst; } +Constant * +llvm::createBitMaskForGaps(IRBuilder<> &Builder, unsigned VF, + const InterleaveGroup<Instruction> &Group) { + // All 1's means mask is not needed. + if (Group.getNumMembers() == Group.getFactor()) + return nullptr; + + // TODO: support reversed access. + assert(!Group.isReverse() && "Reversed group not supported."); + + SmallVector<Constant *, 16> Mask; + for (unsigned i = 0; i < VF; i++) + for (unsigned j = 0; j < Group.getFactor(); ++j) { + unsigned HasMember = Group.getMember(j) ? 1 : 0; + Mask.push_back(Builder.getInt1(HasMember)); + } + + return ConstantVector::get(Mask); +} + +Constant *llvm::createReplicatedMask(IRBuilder<> &Builder, + unsigned ReplicationFactor, unsigned VF) { + SmallVector<Constant *, 16> MaskVec; + for (unsigned i = 0; i < VF; i++) + for (unsigned j = 0; j < ReplicationFactor; j++) + MaskVec.push_back(Builder.getInt32(i)); + + return ConstantVector::get(MaskVec); +} + Constant *llvm::createInterleaveMask(IRBuilder<> &Builder, unsigned VF, unsigned NumVecs) { SmallVector<Constant *, 16> Mask; @@ -575,3 +710,364 @@ Value *llvm::concatenateVectors(IRBuilder<> &Builder, ArrayRef<Value *> Vecs) { return ResList[0]; } + +bool InterleavedAccessInfo::isStrided(int Stride) { + unsigned Factor = std::abs(Stride); + return Factor >= 2 && Factor <= MaxInterleaveGroupFactor; +} + +void InterleavedAccessInfo::collectConstStrideAccesses( + MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, + const ValueToValueMap &Strides) { + auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); + + // Since it's desired that the load/store instructions be maintained in + // "program order" for the interleaved access analysis, we have to visit the + // blocks in the loop in reverse postorder (i.e., in a topological order). + // Such an ordering will ensure that any load/store that may be executed + // before a second load/store will precede the second load/store in + // AccessStrideInfo. + LoopBlocksDFS DFS(TheLoop); + DFS.perform(LI); + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) + for (auto &I : *BB) { + auto *LI = dyn_cast<LoadInst>(&I); + auto *SI = dyn_cast<StoreInst>(&I); + if (!LI && !SI) + continue; + + Value *Ptr = getLoadStorePointerOperand(&I); + // We don't check wrapping here because we don't know yet if Ptr will be + // part of a full group or a group with gaps. 
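// Illustrative sketch (not from this patch): a factor-3 load group with a
// gap, the situation referred to here and the one createBitMaskForGaps above
// serves. Plain C++; the mask layout mirrors the helper's loops rather than
// calling the LLVM API.
#include <cassert>
#include <vector>
int gapExample(const int *A, int I) {
  return A[3 * I] + A[3 * I + 2]; // members at indices 0 and 2, gap at 1
}
int main() {
  // For this group and VF = 2, createBitMaskForGaps enables only the lanes
  // that have a member: <1,0,1, 1,0,1>.
  std::vector<int> Mask;
  for (int Lane = 0; Lane < 2; ++Lane) // VF
    for (int J = 0; J < 3; ++J)        // Factor
      Mask.push_back(J != 1 ? 1 : 0);  // HasMember at index J
  assert((Mask == std::vector<int>{1, 0, 1, 1, 0, 1}));
}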
Checking wrapping for all + // pointers (even those that end up in groups with no gaps) will be overly + // conservative. For full groups, wrapping should be ok since if we would + // wrap around the address space we would do a memory access at nullptr + // even without the transformation. The wrapping checks are therefore + // deferred until after we've formed the interleaved groups. + int64_t Stride = getPtrStride(PSE, Ptr, TheLoop, Strides, + /*Assume=*/true, /*ShouldCheckWrap=*/false); + + const SCEV *Scev = replaceSymbolicStrideSCEV(PSE, Strides, Ptr); + PointerType *PtrTy = dyn_cast<PointerType>(Ptr->getType()); + uint64_t Size = DL.getTypeAllocSize(PtrTy->getElementType()); + + // An alignment of 0 means target ABI alignment. + unsigned Align = getLoadStoreAlignment(&I); + if (!Align) + Align = DL.getABITypeAlignment(PtrTy->getElementType()); + + AccessStrideInfo[&I] = StrideDescriptor(Stride, Scev, Size, Align); + } +} + +// Analyze interleaved accesses and collect them into interleaved load and +// store groups. +// +// When generating code for an interleaved load group, we effectively hoist all +// loads in the group to the location of the first load in program order. When +// generating code for an interleaved store group, we sink all stores to the +// location of the last store. This code motion can change the order of load +// and store instructions and may break dependences. +// +// The code generation strategy mentioned above ensures that we won't violate +// any write-after-read (WAR) dependences. +// +// E.g., for the WAR dependence: a = A[i]; // (1) +// A[i] = b; // (2) +// +// The store group of (2) is always inserted at or below (2), and the load +// group of (1) is always inserted at or above (1). Thus, the instructions will +// never be reordered. All other dependences are checked to ensure the +// correctness of the instruction reordering. +// +// The algorithm visits all memory accesses in the loop in bottom-up program +// order. Program order is established by traversing the blocks in the loop in +// reverse postorder when collecting the accesses. +// +// We visit the memory accesses in bottom-up order because it can simplify the +// construction of store groups in the presence of write-after-write (WAW) +// dependences. +// +// E.g., for the WAW dependence: A[i] = a; // (1) +// A[i] = b; // (2) +// A[i + 1] = c; // (3) +// +// We will first create a store group with (3) and (2). (1) can't be added to +// this group because it and (2) are dependent. However, (1) can be grouped +// with other accesses that may precede it in program order. Note that a +// bottom-up order does not imply that WAW dependences should not be checked. +void InterleavedAccessInfo::analyzeInterleaving( + bool EnablePredicatedInterleavedMemAccesses) { + LLVM_DEBUG(dbgs() << "LV: Analyzing interleaved accesses...\n"); + const ValueToValueMap &Strides = LAI->getSymbolicStrides(); + + // Holds all accesses with a constant stride. + MapVector<Instruction *, StrideDescriptor> AccessStrideInfo; + collectConstStrideAccesses(AccessStrideInfo, Strides); + + if (AccessStrideInfo.empty()) + return; + + // Collect the dependences in the loop. + collectDependences(); + + // Holds all interleaved store groups temporarily. + SmallSetVector<InterleaveGroup<Instruction> *, 4> StoreGroups; + // Holds all interleaved load groups temporarily. 
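// Illustrative sketch (not from this patch): the WAR guarantee above made
// concrete. The two loads of A form a load group whose wide load is emitted
// at the position of the first load; the two stores form a store group whose
// wide store is emitted at the position of the last store. The read of
// A[2*I] therefore still happens before the write to it, so the WAR
// dependence is preserved without any extra checking.
void warSafeExample(int *A, int N) {
  for (int I = 0; I < N; ++I) {
    int X = A[2 * I];   // (1) load group insert position
    int Y = A[2 * I + 1];
    A[2 * I] = Y;
    A[2 * I + 1] = X;   // (2) store group insert position
  }
}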
+ SmallSetVector<InterleaveGroup<Instruction> *, 4> LoadGroups; + + // Search in bottom-up program order for pairs of accesses (A and B) that can + // form interleaved load or store groups. In the algorithm below, access A + // precedes access B in program order. We initialize a group for B in the + // outer loop of the algorithm, and then in the inner loop, we attempt to + // insert each A into B's group if: + // + // 1. A and B have the same stride, + // 2. A and B have the same memory object size, and + // 3. A belongs in B's group according to its distance from B. + // + // Special care is taken to ensure group formation will not break any + // dependences. + for (auto BI = AccessStrideInfo.rbegin(), E = AccessStrideInfo.rend(); + BI != E; ++BI) { + Instruction *B = BI->first; + StrideDescriptor DesB = BI->second; + + // Initialize a group for B if it has an allowable stride. Even if we don't + // create a group for B, we continue with the bottom-up algorithm to ensure + // we don't break any of B's dependences. + InterleaveGroup<Instruction> *Group = nullptr; + if (isStrided(DesB.Stride) && + (!isPredicated(B->getParent()) || EnablePredicatedInterleavedMemAccesses)) { + Group = getInterleaveGroup(B); + if (!Group) { + LLVM_DEBUG(dbgs() << "LV: Creating an interleave group with:" << *B + << '\n'); + Group = createInterleaveGroup(B, DesB.Stride, DesB.Align); + } + if (B->mayWriteToMemory()) + StoreGroups.insert(Group); + else + LoadGroups.insert(Group); + } + + for (auto AI = std::next(BI); AI != E; ++AI) { + Instruction *A = AI->first; + StrideDescriptor DesA = AI->second; + + // Our code motion strategy implies that we can't have dependences + // between accesses in an interleaved group and other accesses located + // between the first and last member of the group. Note that this also + // means that a group can't have more than one member at a given offset. + // The accesses in a group can have dependences with other accesses, but + // we must ensure we don't extend the boundaries of the group such that + // we encompass those dependent accesses. + // + // For example, assume we have the sequence of accesses shown below in a + // stride-2 loop: + // + // (1, 2) is a group | A[i] = a; // (1) + // | A[i-1] = b; // (2) | + // A[i-3] = c; // (3) + // A[i] = d; // (4) | (2, 4) is not a group + // + // Because accesses (2) and (3) are dependent, we can group (2) with (1) + // but not with (4). If we did, the dependent access (3) would be within + // the boundaries of the (2, 4) group. + if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { + // If a dependence exists and A is already in a group, we know that A + // must be a store since A precedes B and WAR dependences are allowed. + // Thus, A would be sunk below B. We release A's group to prevent this + // illegal code motion. A will then be free to form another group with + // instructions that precede it. + if (isInterleaved(A)) { + InterleaveGroup<Instruction> *StoreGroup = getInterleaveGroup(A); + StoreGroups.remove(StoreGroup); + releaseGroup(StoreGroup); + } + + // If a dependence exists and A is not already in a group (or it was + // and we just released it), B might be hoisted above A (if B is a + // load) or another store might be sunk below A (if B is a store). In + // either case, we can't add additional instructions to B's group. B + // will only form a group with instructions that it precedes. + break; + } + + // At this point, we've checked for illegal code motion. 
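// Illustrative sketch (not from this patch): rules 1-3 worked through for
// two i32 loads. A precedes B in program order, so the bottom-up walk seeds
// a group at B and then tries to insert A into it.
int rulesExample(const int *P, int I) {
  int X = P[3 * I + 1]; // access A: stride 3, size 4 bytes
  int Y = P[3 * I];     // access B: stride 3, size 4 bytes
  // Rule 1: equal strides (3).  Rule 2: equal sizes (4 bytes).
  // Rule 3: DistanceToB = (3*I + 1) - (3*I) elements = 4 bytes, 4 % 4 == 0,
  //         so A is inserted at index 0 + 4/4 = 1 of B's factor-3 group,
  //         leaving index 2 as a gap unless another member fills it.
  return X + Y;
}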
If either A or B + // isn't strided, there's nothing left to do. + if (!isStrided(DesA.Stride) || !isStrided(DesB.Stride)) + continue; + + // Ignore A if it's already in a group or isn't the same kind of memory + // operation as B. + // Note that mayReadFromMemory() isn't mutually exclusive to + // mayWriteToMemory in the case of atomic loads. We shouldn't see those + // here, canVectorizeMemory() should have returned false - except for the + // case we asked for optimization remarks. + if (isInterleaved(A) || + (A->mayReadFromMemory() != B->mayReadFromMemory()) || + (A->mayWriteToMemory() != B->mayWriteToMemory())) + continue; + + // Check rules 1 and 2. Ignore A if its stride or size is different from + // that of B. + if (DesA.Stride != DesB.Stride || DesA.Size != DesB.Size) + continue; + + // Ignore A if the memory object of A and B don't belong to the same + // address space + if (getLoadStoreAddressSpace(A) != getLoadStoreAddressSpace(B)) + continue; + + // Calculate the distance from A to B. + const SCEVConstant *DistToB = dyn_cast<SCEVConstant>( + PSE.getSE()->getMinusSCEV(DesA.Scev, DesB.Scev)); + if (!DistToB) + continue; + int64_t DistanceToB = DistToB->getAPInt().getSExtValue(); + + // Check rule 3. Ignore A if its distance to B is not a multiple of the + // size. + if (DistanceToB % static_cast<int64_t>(DesB.Size)) + continue; + + // All members of a predicated interleave-group must have the same predicate, + // and currently must reside in the same BB. + BasicBlock *BlockA = A->getParent(); + BasicBlock *BlockB = B->getParent(); + if ((isPredicated(BlockA) || isPredicated(BlockB)) && + (!EnablePredicatedInterleavedMemAccesses || BlockA != BlockB)) + continue; + + // The index of A is the index of B plus A's distance to B in multiples + // of the size. + int IndexA = + Group->getIndex(B) + DistanceToB / static_cast<int64_t>(DesB.Size); + + // Try to insert A into B's group. + if (Group->insertMember(A, IndexA, DesA.Align)) { + LLVM_DEBUG(dbgs() << "LV: Inserted:" << *A << '\n' + << " into the interleave group with" << *B + << '\n'); + InterleaveGroupMap[A] = Group; + + // Set the first load in program order as the insert position. + if (A->mayReadFromMemory()) + Group->setInsertPos(A); + } + } // Iteration over A accesses. + } // Iteration over B accesses. + + // Remove interleaved store groups with gaps. + for (auto *Group : StoreGroups) + if (Group->getNumMembers() != Group->getFactor()) { + LLVM_DEBUG( + dbgs() << "LV: Invalidate candidate interleaved store group due " + "to gaps.\n"); + releaseGroup(Group); + } + // Remove interleaved groups with gaps (currently only loads) whose memory + // accesses may wrap around. We have to revisit the getPtrStride analysis, + // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does + // not check wrapping (see documentation there). + // FORNOW we use Assume=false; + // TODO: Change to Assume=true but making sure we don't exceed the threshold + // of runtime SCEV assumptions checks (thereby potentially failing to + // vectorize altogether). + // Additional optional optimizations: + // TODO: If we are peeling the loop and we know that the first pointer doesn't + // wrap then we can deduce that all pointers in the group don't wrap. + // This means that we can forcefully peel the loop in order to only have to + // check the first pointer for no-wrap. When we'll change to use Assume=true + // we'll only need at most one runtime check per interleaved group. + for (auto *Group : LoadGroups) { + // Case 1: A full group. 
Can Skip the checks; For full groups, if the wide + // load would wrap around the address space we would do a memory access at + // nullptr even without the transformation. + if (Group->getNumMembers() == Group->getFactor()) + continue; + + // Case 2: If first and last members of the group don't wrap this implies + // that all the pointers in the group don't wrap. + // So we check only group member 0 (which is always guaranteed to exist), + // and group member Factor - 1; If the latter doesn't exist we rely on + // peeling (if it is a non-reveresed accsess -- see Case 3). + Value *FirstMemberPtr = getLoadStorePointerOperand(Group->getMember(0)); + if (!getPtrStride(PSE, FirstMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + LLVM_DEBUG( + dbgs() << "LV: Invalidate candidate interleaved group due to " + "first group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + continue; + } + Instruction *LastMember = Group->getMember(Group->getFactor() - 1); + if (LastMember) { + Value *LastMemberPtr = getLoadStorePointerOperand(LastMember); + if (!getPtrStride(PSE, LastMemberPtr, TheLoop, Strides, /*Assume=*/false, + /*ShouldCheckWrap=*/true)) { + LLVM_DEBUG( + dbgs() << "LV: Invalidate candidate interleaved group due to " + "last group member potentially pointer-wrapping.\n"); + releaseGroup(Group); + } + } else { + // Case 3: A non-reversed interleaved load group with gaps: We need + // to execute at least one scalar epilogue iteration. This will ensure + // we don't speculatively access memory out-of-bounds. We only need + // to look for a member at index factor - 1, since every group must have + // a member at index zero. + if (Group->isReverse()) { + LLVM_DEBUG( + dbgs() << "LV: Invalidate candidate interleaved group due to " + "a reverse access with gaps.\n"); + releaseGroup(Group); + continue; + } + LLVM_DEBUG( + dbgs() << "LV: Interleaved group requires epilogue iteration.\n"); + RequiresScalarEpilogue = true; + } + } +} + +void InterleavedAccessInfo::invalidateGroupsRequiringScalarEpilogue() { + // If no group had triggered the requirement to create an epilogue loop, + // there is nothing to do. + if (!requiresScalarEpilogue()) + return; + + // Avoid releasing a Group twice. + SmallPtrSet<InterleaveGroup<Instruction> *, 4> DelSet; + for (auto &I : InterleaveGroupMap) { + InterleaveGroup<Instruction> *Group = I.second; + if (Group->requiresScalarEpilogue()) + DelSet.insert(Group); + } + for (auto *Ptr : DelSet) { + LLVM_DEBUG( + dbgs() + << "LV: Invalidate candidate interleaved group due to gaps that " + "require a scalar epilogue (not allowed under optsize) and cannot " + "be masked (not enabled). \n"); + releaseGroup(Ptr); + } + + RequiresScalarEpilogue = false; +} + +template <typename InstT> +void InterleaveGroup<InstT>::addMetadata(InstT *NewInst) const { + llvm_unreachable("addMetadata can only be used for Instruction"); +} + +namespace llvm { +template <> +void InterleaveGroup<Instruction>::addMetadata(Instruction *NewInst) const { + SmallVector<Value *, 4> VL; + std::transform(Members.begin(), Members.end(), std::back_inserter(VL), + [](std::pair<int, Instruction *> p) { return p.second; }); + propagateMetadata(NewInst, VL); +} +} |
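// Illustrative sketch (not from this patch): how a client such as the loop
// vectorizer drives this analysis end to end. The constructor signature is
// assumed from the InterleavedAccessInfo declaration that accompanies this
// patch in llvm/Analysis/VectorUtils.h; PSE, L, DT, LI and LAI are set up by
// the caller as usual.
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Dominators.h"
using namespace llvm;
static void analyzeLoopInterleaving(PredicatedScalarEvolution &PSE, Loop *L,
                                    DominatorTree *DT, LoopInfo *LI,
                                    const LoopAccessInfo *LAI,
                                    bool OptForSize) {
  InterleavedAccessInfo IAI(PSE, L, DT, LI, LAI);
  IAI.analyzeInterleaving(/*EnablePredicatedInterleavedMemAccesses=*/false);
  // Groups with gaps may require a scalar epilogue iteration; under optsize
  // that is not allowed, so such groups are dropped again.
  if (OptForSize && IAI.requiresScalarEpilogue())
    IAI.invalidateGroupsRequiringScalarEpilogue();
}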