summaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Scalar/LoopPredication.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
-rw-r--r--llvm/lib/Transforms/Scalar/LoopPredication.cpp160
1 files changed, 114 insertions, 46 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
index 4f97641e2027..aa7e79a589f2 100644
--- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp
@@ -183,6 +183,8 @@
#include "llvm/Analysis/GuardUtils.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/MemorySSAUpdater.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Function.h"
@@ -254,7 +256,7 @@ class LoopPredication {
DominatorTree *DT;
ScalarEvolution *SE;
LoopInfo *LI;
- BranchProbabilityInfo *BPI;
+ MemorySSAUpdater *MSSAU;
Loop *L;
const DataLayout *DL;
@@ -302,16 +304,15 @@ class LoopPredication {
// If the loop always exits through another block in the loop, we should not
// predicate based on the latch check. For example, the latch check can be a
// very coarse grained check and there can be more fine grained exit checks
- // within the loop. We identify such unprofitable loops through BPI.
+ // within the loop.
bool isLoopProfitableToPredicate();
bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter);
public:
- LoopPredication(AliasAnalysis *AA, DominatorTree *DT,
- ScalarEvolution *SE, LoopInfo *LI,
- BranchProbabilityInfo *BPI)
- : AA(AA), DT(DT), SE(SE), LI(LI), BPI(BPI) {};
+ LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE,
+ LoopInfo *LI, MemorySSAUpdater *MSSAU)
+ : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){};
bool runOnLoop(Loop *L);
};
@@ -325,6 +326,7 @@ public:
void getAnalysisUsage(AnalysisUsage &AU) const override {
AU.addRequired<BranchProbabilityInfoWrapperPass>();
getLoopAnalysisUsage(AU);
+ AU.addPreserved<MemorySSAWrapperPass>();
}
bool runOnLoop(Loop *L, LPPassManager &LPM) override {
@@ -333,10 +335,12 @@ public:
auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
- BranchProbabilityInfo &BPI =
- getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
+ auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>();
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
+ if (MSSAWP)
+ MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA());
auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
- LoopPredication LP(AA, DT, SE, LI, &BPI);
+ LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr);
return LP.runOnLoop(L);
}
};
@@ -358,16 +362,18 @@ Pass *llvm::createLoopPredicationPass() {
PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
LoopStandardAnalysisResults &AR,
LPMUpdater &U) {
- Function *F = L.getHeader()->getParent();
- // For the new PM, we also can't use BranchProbabilityInfo as an analysis
- // pass. Function analyses need to be preserved across loop transformations
- // but BPI is not preserved, hence a newly built one is needed.
- BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI, &AR.DT, nullptr);
- LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI);
+ std::unique_ptr<MemorySSAUpdater> MSSAU;
+ if (AR.MSSA)
+ MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA);
+ LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI,
+ MSSAU ? MSSAU.get() : nullptr);
if (!LP.runOnLoop(&L))
return PreservedAnalyses::all();
- return getLoopPassPreservedAnalyses();
+ auto PA = getLoopPassPreservedAnalyses();
+ if (AR.MSSA)
+ PA.preserve<MemorySSAAnalysis>();
+ return PA;
}
Optional<LoopICmp>
@@ -809,7 +815,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
Value *AllChecks = Builder.CreateAnd(Checks);
auto *OldCond = Guard->getOperand(0);
Guard->setOperand(0, AllChecks);
- RecursivelyDeleteTriviallyDeadInstructions(OldCond);
+ RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
return true;
@@ -835,7 +841,7 @@ bool LoopPredication::widenWidenableBranchGuardConditions(
Value *AllChecks = Builder.CreateAnd(Checks);
auto *OldCond = BI->getCondition();
BI->setCondition(AllChecks);
- RecursivelyDeleteTriviallyDeadInstructions(OldCond);
+ RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU);
assert(isGuardAsWidenableBranch(BI) &&
"Stopped being a guard after transform?");
@@ -912,7 +918,7 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() {
bool LoopPredication::isLoopProfitableToPredicate() {
- if (SkipProfitabilityChecks || !BPI)
+ if (SkipProfitabilityChecks)
return true;
SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges;
@@ -934,8 +940,61 @@ bool LoopPredication::isLoopProfitableToPredicate() {
"expected to be an exiting block with 2 succs!");
unsigned LatchBrExitIdx =
LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0;
+ // We compute branch probabilities without BPI. We do not rely on BPI since
+ // Loop predication is usually run in an LPM and BPI is only preserved
+ // lossily within loop pass managers, while BPI has an inherent notion of
+ // being complete for an entire function.
+
+ // If the latch exits into a deoptimize or an unreachable block, do not
+ // predicate on that latch check.
+ auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx);
+ if (isa<UnreachableInst>(LatchTerm) ||
+ LatchExitBlock->getTerminatingDeoptimizeCall())
+ return false;
+
+ auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) {
+ if (!ProfileData || !ProfileData->getOperand(0))
+ return false;
+ if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0)))
+ if (!MDS->getString().equals("branch_weights"))
+ return false;
+ if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors())
+ return false;
+ return true;
+ };
+ MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof);
+ // Latch terminator has no valid profile data, so nothing to check
+ // profitability on.
+ if (!IsValidProfileData(LatchProfileData, LatchTerm))
+ return true;
+
+ auto ComputeBranchProbability =
+ [&](const BasicBlock *ExitingBlock,
+ const BasicBlock *ExitBlock) -> BranchProbability {
+ auto *Term = ExitingBlock->getTerminator();
+ MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof);
+ unsigned NumSucc = Term->getNumSuccessors();
+ if (IsValidProfileData(ProfileData, Term)) {
+ uint64_t Numerator = 0, Denominator = 0, ProfVal = 0;
+ for (unsigned i = 0; i < NumSucc; i++) {
+ ConstantInt *CI =
+ mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1));
+ ProfVal = CI->getValue().getZExtValue();
+ if (Term->getSuccessor(i) == ExitBlock)
+ Numerator += ProfVal;
+ Denominator += ProfVal;
+ }
+ return BranchProbability::getBranchProbability(Numerator, Denominator);
+ } else {
+ assert(LatchBlock != ExitingBlock &&
+ "Latch term should always have profile data!");
+ // No profile data, so we choose the weight as 1/num_of_succ(Src)
+ return BranchProbability::getBranchProbability(1, NumSucc);
+ }
+ };
+
BranchProbability LatchExitProbability =
- BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx);
+ ComputeBranchProbability(LatchBlock, LatchExitBlock);
// Protect against degenerate inputs provided by the user. Providing a value
// less than one, can invert the definition of profitable loop predication.
@@ -948,18 +1007,18 @@ bool LoopPredication::isLoopProfitableToPredicate() {
LLVM_DEBUG(dbgs() << "The value is set to 1.0\n");
ScaleFactor = 1.0;
}
- const auto LatchProbabilityThreshold =
- LatchExitProbability * ScaleFactor;
+ const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor;
for (const auto &ExitEdge : ExitEdges) {
BranchProbability ExitingBlockProbability =
- BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second);
+ ComputeBranchProbability(ExitEdge.first, ExitEdge.second);
// Some exiting edge has higher probability than the latch exiting edge.
// No longer profitable to predicate.
if (ExitingBlockProbability > LatchProbabilityThreshold)
return false;
}
- // Using BPI, we have concluded that the most probable way to exit from the
+
+ // We have concluded that the most probable way to exit from the
// loop is through the latch (or there's no profile information and all
// exits are equally likely).
return true;
@@ -1071,28 +1130,26 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
// widen so that we gain ability to analyze its exit count and perform this
// transform. TODO: It'd be nice to know for sure the exit became
// analyzable after dropping widenability.
- {
- bool Invalidate = false;
+ bool ChangedLoop = false;
- for (auto *ExitingBB : ExitingBlocks) {
- if (LI->getLoopFor(ExitingBB) != L)
- continue;
+ for (auto *ExitingBB : ExitingBlocks) {
+ if (LI->getLoopFor(ExitingBB) != L)
+ continue;
- auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
- if (!BI)
- continue;
+ auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator());
+ if (!BI)
+ continue;
- Use *Cond, *WC;
- BasicBlock *IfTrueBB, *IfFalseBB;
- if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
- L->contains(IfTrueBB)) {
- WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
- Invalidate = true;
- }
+ Use *Cond, *WC;
+ BasicBlock *IfTrueBB, *IfFalseBB;
+ if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) &&
+ L->contains(IfTrueBB)) {
+ WC->set(ConstantInt::getTrue(IfTrueBB->getContext()));
+ ChangedLoop = true;
}
- if (Invalidate)
- SE->forgetLoop(L);
}
+ if (ChangedLoop)
+ SE->forgetLoop(L);
// The use of umin(all analyzeable exits) instead of latch is subtle, but
// important for profitability. We may have a loop which hasn't been fully
@@ -1104,18 +1161,24 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() ||
!SE->isLoopInvariant(MinEC, L) ||
!isSafeToExpandAt(MinEC, WidenableBR, *SE))
- return false;
+ return ChangedLoop;
// Subtlety: We need to avoid inserting additional uses of the WC. We know
// that it can only have one transitive use at the moment, and thus moving
// that use to just before the branch and inserting code before it and then
// modifying the operand is legal.
auto *IP = cast<Instruction>(WidenableBR->getCondition());
+ // Here we unconditionally modify the IR, so after this point we should return
+ // only `true`!
IP->moveBefore(WidenableBR);
+ if (MSSAU)
+ if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP))
+ MSSAU->moveToPlace(MUD, WidenableBR->getParent(),
+ MemorySSA::BeforeTerminator);
Rewriter.setInsertPoint(IP);
IRBuilder<> B(IP);
- bool Changed = false;
+ bool InvalidateLoop = false;
Value *MinECV = nullptr; // lazily generated if needed
for (BasicBlock *ExitingBB : ExitingBlocks) {
// If our exiting block exits multiple loops, we can only rewrite the
@@ -1172,16 +1235,18 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) {
Value *OldCond = BI->getCondition();
BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue));
- Changed = true;
+ InvalidateLoop = true;
}
- if (Changed)
+ if (InvalidateLoop)
// We just mutated a bunch of loop exits changing their exit counts
// widely. We need to force recomputation of the exit counts given these
// changes. Note that all of the inserted exits are never taken, and
// should be removed next time the CFG is modified.
SE->forgetLoop(L);
- return Changed;
+
+ // Always return `true` since we have moved the WidenableBR's condition.
+ return true;
}
bool LoopPredication::runOnLoop(Loop *Loop) {
@@ -1242,5 +1307,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) {
for (auto *Guard : GuardsAsWidenableBranches)
Changed |= widenWidenableBranchGuardConditions(Guard, Expander);
Changed |= predicateLoopExits(L, Expander);
+
+ if (MSSAU && VerifyMemorySSA)
+ MSSAU->getMemorySSA()->verifyMemorySSA();
return Changed;
}