summaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/DivergenceAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/DivergenceAnalysis.cpp')
-rw-r--r--llvm/lib/Analysis/DivergenceAnalysis.cpp40
1 files changed, 25 insertions, 15 deletions
diff --git a/llvm/lib/Analysis/DivergenceAnalysis.cpp b/llvm/lib/Analysis/DivergenceAnalysis.cpp
index 3d1be1e1cce09..343406c9bba16 100644
--- a/llvm/lib/Analysis/DivergenceAnalysis.cpp
+++ b/llvm/lib/Analysis/DivergenceAnalysis.cpp
@@ -184,6 +184,17 @@ bool DivergenceAnalysis::inRegion(const BasicBlock &BB) const {
return (!RegionLoop && BB.getParent() == &F) || RegionLoop->contains(&BB);
}
+static bool usesLiveOut(const Instruction &I, const Loop *DivLoop) {
+ for (auto &Op : I.operands()) {
+ auto *OpInst = dyn_cast<Instruction>(&Op);
+ if (!OpInst)
+ continue;
+ if (DivLoop->contains(OpInst->getParent()))
+ return true;
+ }
+ return false;
+}
+
// marks all users of loop-carried values of the loop headed by LoopHeader as
// divergent
void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
@@ -227,16 +238,14 @@ void DivergenceAnalysis::taintLoopLiveOuts(const BasicBlock &LoopHeader) {
continue;
if (isDivergent(I))
continue;
+ if (!usesLiveOut(I, DivLoop))
+ continue;
- for (auto &Op : I.operands()) {
- auto *OpInst = dyn_cast<Instruction>(&Op);
- if (!OpInst)
- continue;
- if (DivLoop->contains(OpInst->getParent())) {
- markDivergent(I);
- pushUsers(I);
- break;
- }
+ markDivergent(I);
+ if (I.isTerminator()) {
+ propagateBranchDivergence(I);
+ } else {
+ pushUsers(I);
}
}
@@ -286,14 +295,11 @@ bool DivergenceAnalysis::propagateJoinDivergence(const BasicBlock &JoinBlock,
// push non-divergent phi nodes in JoinBlock to the worklist
pushPHINodes(JoinBlock);
- // JoinBlock is a divergent loop exit
- if (BranchLoop && !BranchLoop->contains(&JoinBlock)) {
- return true;
- }
-
// disjoint-paths divergent at JoinBlock
markBlockJoinDivergent(JoinBlock);
- return false;
+
+ // JoinBlock is a divergent loop exit
+ return BranchLoop && !BranchLoop->contains(&JoinBlock);
}
void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
@@ -301,6 +307,10 @@ void DivergenceAnalysis::propagateBranchDivergence(const Instruction &Term) {
markDivergent(Term);
+ // Don't propagate divergence from unreachable blocks.
+ if (!DT.isReachableFromEntry(Term.getParent()))
+ return;
+
const auto *BranchLoop = LI.getLoopFor(Term.getParent());
// whether there is a divergent loop exit from BranchLoop (if any)