aboutsummaryrefslogtreecommitdiff
path: root/lib/Analysis/SyncDependenceAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Analysis/SyncDependenceAnalysis.cpp')
-rw-r--r--lib/Analysis/SyncDependenceAnalysis.cpp35
1 files changed, 22 insertions, 13 deletions
diff --git a/lib/Analysis/SyncDependenceAnalysis.cpp b/lib/Analysis/SyncDependenceAnalysis.cpp
index e1a7e4476d12..3cf248a31142 100644
--- a/lib/Analysis/SyncDependenceAnalysis.cpp
+++ b/lib/Analysis/SyncDependenceAnalysis.cpp
@@ -1,10 +1,9 @@
//===- SyncDependenceAnalysis.cpp - Divergent Branch Dependence Calculation
//--===//
//
-// The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
@@ -219,14 +218,9 @@ struct DivergencePropagator {
template <typename SuccessorIterable>
std::unique_ptr<ConstBlockSet>
computeJoinPoints(const BasicBlock &RootBlock,
- SuccessorIterable NodeSuccessors, const Loop *ParentLoop) {
+ SuccessorIterable NodeSuccessors, const Loop *ParentLoop, const BasicBlock * PdBoundBlock) {
assert(JoinBlocks);
- // immediate post dominator (no join block beyond that block)
- const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(&RootBlock));
- const auto *IpdNode = PdNode->getIDom();
- const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
-
// bootstrap with branch targets
for (const auto *SuccBlock : NodeSuccessors) {
DefMap.emplace(SuccBlock, SuccBlock);
@@ -341,13 +335,23 @@ const ConstBlockSet &SyncDependenceAnalysis::join_blocks(const Loop &Loop) {
// already available in cache?
auto ItCached = CachedLoopExitJoins.find(&Loop);
- if (ItCached != CachedLoopExitJoins.end())
+ if (ItCached != CachedLoopExitJoins.end()) {
return *ItCached->second;
+ }
+
+ // dont propagte beyond the immediate post dom of the loop
+ const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Loop.getHeader()));
+ const auto *IpdNode = PdNode->getIDom();
+ const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
+ while (PdBoundBlock && Loop.contains(PdBoundBlock)) {
+ IpdNode = IpdNode->getIDom();
+ PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
+ }
// compute all join points
DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
auto JoinBlocks = Propagator.computeJoinPoints<const LoopExitVec &>(
- *Loop.getHeader(), LoopExits, Loop.getParentLoop());
+ *Loop.getHeader(), LoopExits, Loop.getParentLoop(), PdBoundBlock);
auto ItInserted = CachedLoopExitJoins.emplace(&Loop, std::move(JoinBlocks));
assert(ItInserted.second);
@@ -366,11 +370,16 @@ SyncDependenceAnalysis::join_blocks(const Instruction &Term) {
if (ItCached != CachedBranchJoins.end())
return *ItCached->second;
+ // dont propagate beyond the immediate post dominator of the branch
+ const auto *PdNode = PDT.getNode(const_cast<BasicBlock *>(Term.getParent()));
+ const auto *IpdNode = PdNode->getIDom();
+ const auto *PdBoundBlock = IpdNode ? IpdNode->getBlock() : nullptr;
+
// compute all join points
DivergencePropagator Propagator{FuncRPOT, DT, PDT, LI};
const auto &TermBlock = *Term.getParent();
auto JoinBlocks = Propagator.computeJoinPoints<succ_const_range>(
- TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock));
+ TermBlock, successors(Term.getParent()), LI.getLoopFor(&TermBlock), PdBoundBlock);
auto ItInserted = CachedBranchJoins.emplace(&Term, std::move(JoinBlocks));
assert(ItInserted.second);