diff options
Diffstat (limited to 'llvm/lib/Analysis/DivergenceAnalysis.cpp')
-rw-r--r-- | llvm/lib/Analysis/DivergenceAnalysis.cpp | 40 |
1 files changed, 25 insertions, 15 deletions
diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp index 3d1be1e1cce09..343406c9bba16 100644 --- a/llvm/lib/Analysis/DivergenceAnalysis.cpp +++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp @@ -184,6 +184,17 @@ bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const { return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB); } +static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) { + for (auto &Op : I.operands()) { + auto *OpInst = dyn_cast<Instruction>(&Op); + if (!OpInst) + continue; + if (DivLoop->contains(OpInst->getParent())) + return true; + } + return false; +} + // marks all users of loop-carried values of the loop headed by LoopHeader as // divergent void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { @@ -227,16 +238,14 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) { continue; if (isDivergent(I)) continue; + if (!usesLiveOut(I, DivLoop)) + continue; - for (auto &Op : I.operands()) { - auto *OpInst = dyn_cast<Instruction>(&Op); - if (!OpInst) - continue; - if (DivLoop->contains(OpInst->getParent())) { - markDivergent(I); - pushUsers(I); - break; - } + markDivergent(I); + if (I.isTerminator()) { + propagateBranchDivergence(I); + } else { + pushUsers(I); } } @@ -286,14 +295,11 @@ bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock, // push non-divergent phi nodes in JoinBlock to the worklist pushPHINodes(JoinBlock); - // JoinBlock is a divergent loop exit - if (BranchLoop && !BranchLoop->contains(&JoinBlock)) { - return true; - } - // disjoint-paths divergent at JoinBlock markBlockJoinDivergent(JoinBlock); - return false; + + // JoinBlock is a divergent loop exit + return BranchLoop && !BranchLoop->contains(&JoinBlock); } void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { @@ -301,6 +307,10 @@ void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) { markDivergent(Term); + // Don't propagate divergence from unreachable blocks. + if (!DT.isReachableFromEntry(Term.getParent())) + return; + const auto *BranchLoop = LI.getLoopFor(Term.getParent()); // whether there is a divergent loop exit from BranchLoop (if any) |