Diffstat (limited to 'llvm/lib/Analysis/LoopAccessAnalysis.cpp')
-rw-r--r-- | llvm/lib/Analysis/LoopAccessAnalysis.cpp | 254
1 file changed, 60 insertions, 194 deletions
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
index 26fa5112c29a7..ae282a7a10952 100644
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -30,7 +30,6 @@
 #include "llvm/Analysis/MemoryLocation.h"
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/ScalarEvolution.h"
-#include "llvm/Analysis/ScalarEvolutionExpander.h"
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
@@ -43,7 +42,6 @@
 #include "llvm/IR/DiagnosticInfo.h"
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Function.h"
-#include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
@@ -174,6 +172,13 @@ const SCEV *llvm::replaceSymbolicStrideSCEV(PredicatedScalarEvolution &PSE,
   return OrigSCEV;
 }
 
+RuntimeCheckingPtrGroup::RuntimeCheckingPtrGroup(
+    unsigned Index, RuntimePointerChecking &RtCheck)
+    : RtCheck(RtCheck), High(RtCheck.Pointers[Index].End),
+      Low(RtCheck.Pointers[Index].Start) {
+  Members.push_back(Index);
+}
+
 /// Calculate Start and End points of memory access.
 /// Let's assume A is the first access and B is a memory access on N-th loop
 /// iteration. Then B is calculated as:
@@ -231,14 +236,14 @@ void RuntimePointerChecking::insert(Loop *Lp, Value *Ptr, bool WritePtr,
   Pointers.emplace_back(Ptr, ScStart, ScEnd, WritePtr, DepSetId, ASId, Sc);
 }
 
-SmallVector<RuntimePointerChecking::PointerCheck, 4>
+SmallVector<RuntimePointerCheck, 4>
 RuntimePointerChecking::generateChecks() const {
-  SmallVector<PointerCheck, 4> Checks;
+  SmallVector<RuntimePointerCheck, 4> Checks;
 
   for (unsigned I = 0; I < CheckingGroups.size(); ++I) {
     for (unsigned J = I + 1; J < CheckingGroups.size(); ++J) {
-      const RuntimePointerChecking::CheckingPtrGroup &CGI = CheckingGroups[I];
-      const RuntimePointerChecking::CheckingPtrGroup &CGJ = CheckingGroups[J];
+      const RuntimeCheckingPtrGroup &CGI = CheckingGroups[I];
+      const RuntimeCheckingPtrGroup &CGJ = CheckingGroups[J];
 
       if (needsChecking(CGI, CGJ))
         Checks.push_back(std::make_pair(&CGI, &CGJ));
@@ -254,8 +259,8 @@ void RuntimePointerChecking::generateChecks(
   Checks = generateChecks();
 }
 
-bool RuntimePointerChecking::needsChecking(const CheckingPtrGroup &M,
-                                           const CheckingPtrGroup &N) const {
+bool RuntimePointerChecking::needsChecking(
+    const RuntimeCheckingPtrGroup &M, const RuntimeCheckingPtrGroup &N) const {
   for (unsigned I = 0, EI = M.Members.size(); EI != I; ++I)
     for (unsigned J = 0, EJ = N.Members.size(); EJ != J; ++J)
       if (needsChecking(M.Members[I], N.Members[J]))
@@ -277,7 +282,7 @@ static const SCEV *getMinFromExprs(const SCEV *I, const SCEV *J,
   return I;
 }
 
-bool RuntimePointerChecking::CheckingPtrGroup::addPointer(unsigned Index) {
+bool RuntimeCheckingPtrGroup::addPointer(unsigned Index) {
   const SCEV *Start = RtCheck.Pointers[Index].Start;
   const SCEV *End = RtCheck.Pointers[Index].End;
 
@@ -352,7 +357,7 @@ void RuntimePointerChecking::groupChecks(
   // pointers to the same underlying object.
   if (!UseDependencies) {
     for (unsigned I = 0; I < Pointers.size(); ++I)
-      CheckingGroups.push_back(CheckingPtrGroup(I, *this));
+      CheckingGroups.push_back(RuntimeCheckingPtrGroup(I, *this));
     return;
   }
 
@@ -378,7 +383,7 @@ void RuntimePointerChecking::groupChecks(
     MemoryDepChecker::MemAccessInfo Access(Pointers[I].PointerValue,
                                            Pointers[I].IsWritePtr);
 
-    SmallVector<CheckingPtrGroup, 2> Groups;
+    SmallVector<RuntimeCheckingPtrGroup, 2> Groups;
     auto LeaderI = DepCands.findValue(DepCands.getLeaderValue(Access));
 
     // Because DepCands is constructed by visiting accesses in the order in
@@ -395,7 +400,7 @@ void RuntimePointerChecking::groupChecks(
 
       // Go through all the existing sets and see if we can find one
       // which can include this pointer.
-      for (CheckingPtrGroup &Group : Groups) {
+      for (RuntimeCheckingPtrGroup &Group : Groups) {
        // Don't perform more than a certain amount of comparisons.
        // This should limit the cost of grouping the pointers to something
        // reasonable.  If we do end up hitting this threshold, the algorithm
@@ -415,7 +420,7 @@ void RuntimePointerChecking::groupChecks(
      // We couldn't add this pointer to any existing set or the threshold
      // for the number of comparisons has been reached. Create a new group
      // to hold the current pointer.
-      Groups.push_back(CheckingPtrGroup(Pointer, *this));
+      Groups.push_back(RuntimeCheckingPtrGroup(Pointer, *this));
     }
 
     // We've computed the grouped checks for this partition.
@@ -451,7 +456,7 @@ bool RuntimePointerChecking::needsChecking(unsigned I, unsigned J) const {
 }
 
 void RuntimePointerChecking::printChecks(
-    raw_ostream &OS, const SmallVectorImpl<PointerCheck> &Checks,
+    raw_ostream &OS, const SmallVectorImpl<RuntimePointerCheck> &Checks,
     unsigned Depth) const {
   unsigned N = 0;
   for (const auto &Check : Checks) {
@@ -500,7 +505,7 @@ public:
   typedef PointerIntPair<Value *, 1, bool> MemAccessInfo;
   typedef SmallVector<MemAccessInfo, 8> MemAccessInfoList;
 
-  AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AliasAnalysis *AA,
+  AccessAnalysis(const DataLayout &Dl, Loop *TheLoop, AAResults *AA,
                  LoopInfo *LI, MemoryDepChecker::DepCandidates &DA,
                  PredicatedScalarEvolution &PSE)
       : DL(Dl), TheLoop(TheLoop), AST(*AA), LI(LI), DepCands(DA),
@@ -700,18 +705,19 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
   // to place a runtime bound check.
   bool CanDoRT = true;
 
-  bool NeedRTCheck = false;
+  bool MayNeedRTCheck = false;
   if (!IsRTCheckAnalysisNeeded) return true;
 
   bool IsDepCheckNeeded = isDependencyCheckNeeded();
 
   // We assign a consecutive id to access from different alias sets.
   // Accesses between different groups doesn't need to be checked.
-  unsigned ASId = 1;
+  unsigned ASId = 0;
   for (auto &AS : AST) {
     int NumReadPtrChecks = 0;
     int NumWritePtrChecks = 0;
     bool CanDoAliasSetRT = true;
+    ++ASId;
 
     // We assign consecutive id to access from different dependence sets.
     // Accesses within the same set don't need a runtime check.
@@ -742,14 +748,30 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
    //  check them.  But there is no need to checks if there is only one
    //  dependence set for this alias set.
    //
-    // Note that this function computes CanDoRT and NeedRTCheck independently.
-    // For example CanDoRT=false, NeedRTCheck=false means that we have a pointer
-    // for which we couldn't find the bounds but we don't actually need to emit
-    // any checks so it does not matter.
+    // Note that this function computes CanDoRT and MayNeedRTCheck
+    // independently. For example CanDoRT=false, MayNeedRTCheck=false means that
+    // we have a pointer for which we couldn't find the bounds but we don't
+    // actually need to emit any checks so it does not matter.
     bool NeedsAliasSetRTCheck = false;
-    if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2))
+    if (!(IsDepCheckNeeded && CanDoAliasSetRT && RunningDepId == 2)) {
       NeedsAliasSetRTCheck = (NumWritePtrChecks >= 2 ||
                              (NumReadPtrChecks >= 1 && NumWritePtrChecks >= 1));
+      // For alias sets without at least 2 writes or 1 write and 1 read, there
+      // is no need to generate RT checks and CanDoAliasSetRT for this alias set
+      // does not impact whether runtime checks can be generated.
+      if (!NeedsAliasSetRTCheck) {
+        assert((AS.size() <= 1 ||
+                all_of(AS,
+                       [this](auto AC) {
+                         MemAccessInfo AccessWrite(AC.getValue(), true);
+                         return DepCands.findValue(AccessWrite) ==
+                                DepCands.end();
+                       })) &&
+               "Can only skip updating CanDoRT below, if all entries in AS "
+               "are reads or there is at most 1 entry");
+        continue;
+      }
+    }
 
     // We need to perform run-time alias checks, but some pointers had bounds
     // that couldn't be checked.
@@ -768,7 +790,7 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
     }
 
     CanDoRT &= CanDoAliasSetRT;
-    NeedRTCheck |= NeedsAliasSetRTCheck;
+    MayNeedRTCheck |= NeedsAliasSetRTCheck;
     ++ASId;
   }
 
@@ -802,15 +824,18 @@ bool AccessAnalysis::canCheckPtrAtRT(RuntimePointerChecking &RtCheck,
     }
   }
 
-  if (NeedRTCheck && CanDoRT)
+  if (MayNeedRTCheck && CanDoRT)
     RtCheck.generateChecks(DepCands, IsDepCheckNeeded);
 
   LLVM_DEBUG(dbgs() << "LAA: We need to do " << RtCheck.getNumberOfChecks()
                     << " pointer comparisons.\n");
 
-  RtCheck.Need = NeedRTCheck;
+  // If we can do run-time checks, but there are no checks, no runtime checks
+  // are needed. This can happen when all pointers point to the same underlying
+  // object for example.
+  RtCheck.Need = CanDoRT ? RtCheck.getNumberOfChecks() != 0 : MayNeedRTCheck;
 
-  bool CanDoRTIfNeeded = !NeedRTCheck || CanDoRT;
+  bool CanDoRTIfNeeded = !RtCheck.Need || CanDoRT;
   if (!CanDoRTIfNeeded)
     RtCheck.reset();
   return CanDoRTIfNeeded;
@@ -1787,7 +1812,7 @@ bool LoopAccessInfo::canAnalyzeLoop() {
   return true;
 }
 
-void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
+void LoopAccessInfo::analyzeLoop(AAResults *AA, LoopInfo *LI,
                                  const TargetLibraryInfo *TLI,
                                  DominatorTree *DT) {
   typedef SmallPtrSet<Value*, 16> ValueSet;
@@ -1810,6 +1835,10 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
 
   const bool IsAnnotatedParallel = TheLoop->isAnnotatedParallel();
 
+  const bool EnableMemAccessVersioningOfLoop =
+      EnableMemAccessVersioning &&
+      !TheLoop->getHeader()->getParent()->hasOptSize();
+
   // For each block.
   for (BasicBlock *BB : TheLoop->blocks()) {
     // Scan the BB and collect legal loads and stores. Also detect any
@@ -1845,7 +1874,7 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
       // If the function has an explicit vectorized counterpart, we can safely
      // assume that it can be vectorized.
      if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() &&
-          TLI->isFunctionVectorizable(Call->getCalledFunction()->getName()))
+          !VFDatabase::getMappings(*Call).empty())
        continue;
 
      auto *Ld = dyn_cast<LoadInst>(&I);
@@ -1865,7 +1894,7 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
        NumLoads++;
        Loads.push_back(Ld);
        DepChecker->addAccess(Ld);
-        if (EnableMemAccessVersioning)
+        if (EnableMemAccessVersioningOfLoop)
          collectStridedAccess(Ld);
        continue;
      }
@@ -1889,7 +1918,7 @@ void LoopAccessInfo::analyzeLoop(AliasAnalysis *AA, LoopInfo *LI,
        NumStores++;
        Stores.push_back(St);
        DepChecker->addAccess(St);
-        if (EnableMemAccessVersioning)
+        if (EnableMemAccessVersioningOfLoop)
          collectStridedAccess(St);
      }
    } // Next instr.
@@ -2116,169 +2145,6 @@ bool LoopAccessInfo::isUniform(Value *V) const {
   return (SE->isLoopInvariant(SE->getSCEV(V), TheLoop));
 }
 
-// FIXME: this function is currently a duplicate of the one in
-// LoopVectorize.cpp.
-static Instruction *getFirstInst(Instruction *FirstInst, Value *V,
-                                 Instruction *Loc) {
-  if (FirstInst)
-    return FirstInst;
-  if (Instruction *I = dyn_cast<Instruction>(V))
-    return I->getParent() == Loc->getParent() ? I : nullptr;
-  return nullptr;
-}
-
-namespace {
-
-/// IR Values for the lower and upper bounds of a pointer evolution.  We
-/// need to use value-handles because SCEV expansion can invalidate previously
-/// expanded values.  Thus expansion of a pointer can invalidate the bounds for
-/// a previous one.
-struct PointerBounds {
-  TrackingVH<Value> Start;
-  TrackingVH<Value> End;
-};
-
-} // end anonymous namespace
-
-/// Expand code for the lower and upper bound of the pointer group \p CG
-/// in \p TheLoop.  \return the values for the bounds.
-static PointerBounds
-expandBounds(const RuntimePointerChecking::CheckingPtrGroup *CG, Loop *TheLoop,
-             Instruction *Loc, SCEVExpander &Exp, ScalarEvolution *SE,
-             const RuntimePointerChecking &PtrRtChecking) {
-  Value *Ptr = PtrRtChecking.Pointers[CG->Members[0]].PointerValue;
-  const SCEV *Sc = SE->getSCEV(Ptr);
-
-  unsigned AS = Ptr->getType()->getPointerAddressSpace();
-  LLVMContext &Ctx = Loc->getContext();
-
-  // Use this type for pointer arithmetic.
-  Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
-
-  if (SE->isLoopInvariant(Sc, TheLoop)) {
-    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
-                      << *Ptr << "\n");
-    // Ptr could be in the loop body. If so, expand a new one at the correct
-    // location.
-    Instruction *Inst = dyn_cast<Instruction>(Ptr);
-    Value *NewPtr = (Inst && TheLoop->contains(Inst))
-                        ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
-                        : Ptr;
-    // We must return a half-open range, which means incrementing Sc.
-    const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
-    Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
-    return {NewPtr, NewPtrPlusOne};
-  } else {
-    Value *Start = nullptr, *End = nullptr;
-    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
-    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
-    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
-    LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
-                      << "\n");
-    return {Start, End};
-  }
-}
-
-/// Turns a collection of checks into a collection of expanded upper and
-/// lower bounds for both pointers in the check.
-static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(
-    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks,
-    Loop *L, Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp,
-    const RuntimePointerChecking &PtrRtChecking) {
-  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
-
-  // Here we're relying on the SCEV Expander's cache to only emit code for the
-  // same bounds once.
-  transform(
-      PointerChecks, std::back_inserter(ChecksWithBounds),
-      [&](const RuntimePointerChecking::PointerCheck &Check) {
-        PointerBounds
-          First = expandBounds(Check.first, L, Loc, Exp, SE, PtrRtChecking),
-          Second = expandBounds(Check.second, L, Loc, Exp, SE, PtrRtChecking);
-        return std::make_pair(First, Second);
-      });
-
-  return ChecksWithBounds;
-}
-
-std::pair<Instruction *, Instruction *> LoopAccessInfo::addRuntimeChecks(
-    Instruction *Loc,
-    const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &PointerChecks)
-    const {
-  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
-  auto *SE = PSE->getSE();
-  SCEVExpander Exp(*SE, DL, "induction");
-  auto ExpandedChecks =
-      expandBounds(PointerChecks, TheLoop, Loc, SE, Exp, *PtrRtChecking);
-
-  LLVMContext &Ctx = Loc->getContext();
-  Instruction *FirstInst = nullptr;
-  IRBuilder<> ChkBuilder(Loc);
-  // Our instructions might fold to a constant.
-  Value *MemoryRuntimeCheck = nullptr;
-
-  for (const auto &Check : ExpandedChecks) {
-    const PointerBounds &A = Check.first, &B = Check.second;
-    // Check if two pointers (A and B) conflict where conflict is computed as:
-    // start(A) <= end(B) && start(B) <= end(A)
-    unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
-    unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
-
-    assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
-           (AS1 == A.End->getType()->getPointerAddressSpace()) &&
-           "Trying to bounds check pointers with different address spaces");
-
-    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
-    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
-    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
-    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
-    Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
-    Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
-
-    // [A|B].Start points to the first accessed byte under base [A|B].
-    // [A|B].End points to the last accessed byte, plus one.
-    // There is no conflict when the intervals are disjoint:
-    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
-    //
-    // bound0 = (B.Start < A.End)
-    // bound1 = (A.Start < B.End)
-    //  IsConflict = bound0 & bound1
-    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
-    FirstInst = getFirstInst(FirstInst, Cmp0, Loc);
-    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
-    FirstInst = getFirstInst(FirstInst, Cmp1, Loc);
-    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
-    FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-    if (MemoryRuntimeCheck) {
-      IsConflict =
-          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
-      FirstInst = getFirstInst(FirstInst, IsConflict, Loc);
-    }
-    MemoryRuntimeCheck = IsConflict;
-  }
-
-  if (!MemoryRuntimeCheck)
-    return std::make_pair(nullptr, nullptr);
-
-  // We have to do this trickery because the IRBuilder might fold the check to a
-  // constant expression in which case there is no Instruction anchored in a
-  // the block.
-  Instruction *Check = BinaryOperator::CreateAnd(MemoryRuntimeCheck,
-                                                 ConstantInt::getTrue(Ctx));
-  ChkBuilder.Insert(Check, "memcheck.conflict");
-  FirstInst = getFirstInst(FirstInst, Check, Loc);
-  return std::make_pair(FirstInst, Check);
-}
-
-std::pair<Instruction *, Instruction *>
-LoopAccessInfo::addRuntimeChecks(Instruction *Loc) const {
-  if (!PtrRtChecking->Need)
-    return std::make_pair(nullptr, nullptr);
-
-  return addRuntimeChecks(Loc, PtrRtChecking->getChecks());
-}
-
 void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
   Value *Ptr = nullptr;
   if (LoadInst *LI = dyn_cast<LoadInst>(MemAccess))
@@ -2343,7 +2209,7 @@ void LoopAccessInfo::collectStridedAccess(Value *MemAccess) {
 }
 
 LoopAccessInfo::LoopAccessInfo(Loop *L, ScalarEvolution *SE,
-                               const TargetLibraryInfo *TLI, AliasAnalysis *AA,
+                               const TargetLibraryInfo *TLI, AAResults *AA,
                                DominatorTree *DT, LoopInfo *LI)
     : PSE(std::make_unique<PredicatedScalarEvolution>(*SE, *L)),
      PtrRtChecking(std::make_unique<RuntimePointerChecking>(SE)),
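For reference, the check that the removed addRuntimeChecks emitted reduces to a pairwise, half-open interval overlap test between the expanded bounds of two pointer groups. Below is a minimal standalone C++ sketch of that predicate; PtrBounds and mayConflict are hypothetical names used only for illustration here (the real code compares SCEV-expanded Start/End values in IR, not raw addresses):

#include <cstdint>

// Hypothetical stand-in for the expanded bounds of one RuntimeCheckingPtrGroup:
// Start is the first accessed byte, End is one past the last accessed byte.
struct PtrBounds {
  uintptr_t Start;
  uintptr_t End;
};

// Mirrors the comment in the removed code:
//   NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
// so the two groups may overlap iff both bound comparisons below hold.
static bool mayConflict(const PtrBounds &A, const PtrBounds &B) {
  bool Bound0 = B.Start < A.End; // "bound0" in the emitted IR
  bool Bound1 = A.Start < B.End; // "bound1" in the emitted IR
  return Bound0 && Bound1;       // "found.conflict"
}

The half-open convention (End pointing one past the last accessed byte) is what lets strict unsigned less-than comparisons detect any overlap; the per-pair results are OR-ed together into the single memcheck.conflict value that the emitted runtime check branches on.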