diff options
Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopPredication.cpp')
| -rw-r--r-- | llvm/lib/Transforms/Scalar/LoopPredication.cpp | 160 |
1 files changed, 114 insertions, 46 deletions
diff --git a/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/llvm/lib/Transforms/Scalar/LoopPredication.cpp index 4f97641e2027..aa7e79a589f2 100644 --- a/llvm/lib/Transforms/Scalar/LoopPredication.cpp +++ b/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -183,6 +183,8 @@ #include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/IR/Function.h" @@ -254,7 +256,7 @@ class LoopPredication { DominatorTree *DT; ScalarEvolution *SE; LoopInfo *LI; - BranchProbabilityInfo *BPI; + MemorySSAUpdater *MSSAU; Loop *L; const DataLayout *DL; @@ -302,16 +304,15 @@ class LoopPredication { // If the loop always exits through another block in the loop, we should not // predicate based on the latch check. For example, the latch check can be a // very coarse grained check and there can be more fine grained exit checks - // within the loop. We identify such unprofitable loops through BPI. + // within the loop. bool isLoopProfitableToPredicate(); bool predicateLoopExits(Loop *L, SCEVExpander &Rewriter); public: - LoopPredication(AliasAnalysis *AA, DominatorTree *DT, - ScalarEvolution *SE, LoopInfo *LI, - BranchProbabilityInfo *BPI) - : AA(AA), DT(DT), SE(SE), LI(LI), BPI(BPI) {}; + LoopPredication(AliasAnalysis *AA, DominatorTree *DT, ScalarEvolution *SE, + LoopInfo *LI, MemorySSAUpdater *MSSAU) + : AA(AA), DT(DT), SE(SE), LI(LI), MSSAU(MSSAU){}; bool runOnLoop(Loop *L); }; @@ -325,6 +326,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BranchProbabilityInfoWrapperPass>(); getLoopAnalysisUsage(AU); + AU.addPreserved<MemorySSAWrapperPass>(); } bool runOnLoop(Loop *L, LPPassManager &LPM) override { @@ -333,10 +335,12 @@ public: auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); - BranchProbabilityInfo &BPI = - getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI(); + auto *MSSAWP = getAnalysisIfAvailable<MemorySSAWrapperPass>(); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (MSSAWP) + MSSAU = std::make_unique<MemorySSAUpdater>(&MSSAWP->getMSSA()); auto *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - LoopPredication LP(AA, DT, SE, LI, &BPI); + LoopPredication LP(AA, DT, SE, LI, MSSAU ? MSSAU.get() : nullptr); return LP.runOnLoop(L); } }; @@ -358,16 +362,18 @@ Pass *llvm::createLoopPredicationPass() { PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &U) { - Function *F = L.getHeader()->getParent(); - // For the new PM, we also can't use BranchProbabilityInfo as an analysis - // pass. Function analyses need to be preserved across loop transformations - // but BPI is not preserved, hence a newly built one is needed. - BranchProbabilityInfo BPI(*F, AR.LI, &AR.TLI, &AR.DT, nullptr); - LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, &BPI); + std::unique_ptr<MemorySSAUpdater> MSSAU; + if (AR.MSSA) + MSSAU = std::make_unique<MemorySSAUpdater>(AR.MSSA); + LoopPredication LP(&AR.AA, &AR.DT, &AR.SE, &AR.LI, + MSSAU ? MSSAU.get() : nullptr); if (!LP.runOnLoop(&L)) return PreservedAnalyses::all(); - return getLoopPassPreservedAnalyses(); + auto PA = getLoopPassPreservedAnalyses(); + if (AR.MSSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; } Optional<LoopICmp> @@ -809,7 +815,7 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = Guard->getOperand(0); Guard->setOperand(0, AllChecks); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); + RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); LLVM_DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n"); return true; @@ -835,7 +841,7 @@ bool LoopPredication::widenWidenableBranchGuardConditions( Value *AllChecks = Builder.CreateAnd(Checks); auto *OldCond = BI->getCondition(); BI->setCondition(AllChecks); - RecursivelyDeleteTriviallyDeadInstructions(OldCond); + RecursivelyDeleteTriviallyDeadInstructions(OldCond, nullptr /* TLI */, MSSAU); assert(isGuardAsWidenableBranch(BI) && "Stopped being a guard after transform?"); @@ -912,7 +918,7 @@ Optional<LoopICmp> LoopPredication::parseLoopLatchICmp() { bool LoopPredication::isLoopProfitableToPredicate() { - if (SkipProfitabilityChecks || !BPI) + if (SkipProfitabilityChecks) return true; SmallVector<std::pair<BasicBlock *, BasicBlock *>, 8> ExitEdges; @@ -934,8 +940,61 @@ bool LoopPredication::isLoopProfitableToPredicate() { "expected to be an exiting block with 2 succs!"); unsigned LatchBrExitIdx = LatchTerm->getSuccessor(0) == L->getHeader() ? 1 : 0; + // We compute branch probabilities without BPI. We do not rely on BPI since + // Loop predication is usually run in an LPM and BPI is only preserved + // lossily within loop pass managers, while BPI has an inherent notion of + // being complete for an entire function. + + // If the latch exits into a deoptimize or an unreachable block, do not + // predicate on that latch check. + auto *LatchExitBlock = LatchTerm->getSuccessor(LatchBrExitIdx); + if (isa<UnreachableInst>(LatchTerm) || + LatchExitBlock->getTerminatingDeoptimizeCall()) + return false; + + auto IsValidProfileData = [](MDNode *ProfileData, const Instruction *Term) { + if (!ProfileData || !ProfileData->getOperand(0)) + return false; + if (MDString *MDS = dyn_cast<MDString>(ProfileData->getOperand(0))) + if (!MDS->getString().equals("branch_weights")) + return false; + if (ProfileData->getNumOperands() != 1 + Term->getNumSuccessors()) + return false; + return true; + }; + MDNode *LatchProfileData = LatchTerm->getMetadata(LLVMContext::MD_prof); + // Latch terminator has no valid profile data, so nothing to check + // profitability on. + if (!IsValidProfileData(LatchProfileData, LatchTerm)) + return true; + + auto ComputeBranchProbability = + [&](const BasicBlock *ExitingBlock, + const BasicBlock *ExitBlock) -> BranchProbability { + auto *Term = ExitingBlock->getTerminator(); + MDNode *ProfileData = Term->getMetadata(LLVMContext::MD_prof); + unsigned NumSucc = Term->getNumSuccessors(); + if (IsValidProfileData(ProfileData, Term)) { + uint64_t Numerator = 0, Denominator = 0, ProfVal = 0; + for (unsigned i = 0; i < NumSucc; i++) { + ConstantInt *CI = + mdconst::extract<ConstantInt>(ProfileData->getOperand(i + 1)); + ProfVal = CI->getValue().getZExtValue(); + if (Term->getSuccessor(i) == ExitBlock) + Numerator += ProfVal; + Denominator += ProfVal; + } + return BranchProbability::getBranchProbability(Numerator, Denominator); + } else { + assert(LatchBlock != ExitingBlock && + "Latch term should always have profile data!"); + // No profile data, so we choose the weight as 1/num_of_succ(Src) + return BranchProbability::getBranchProbability(1, NumSucc); + } + }; + BranchProbability LatchExitProbability = - BPI->getEdgeProbability(LatchBlock, LatchBrExitIdx); + ComputeBranchProbability(LatchBlock, LatchExitBlock); // Protect against degenerate inputs provided by the user. Providing a value // less than one, can invert the definition of profitable loop predication. @@ -948,18 +1007,18 @@ bool LoopPredication::isLoopProfitableToPredicate() { LLVM_DEBUG(dbgs() << "The value is set to 1.0\n"); ScaleFactor = 1.0; } - const auto LatchProbabilityThreshold = - LatchExitProbability * ScaleFactor; + const auto LatchProbabilityThreshold = LatchExitProbability * ScaleFactor; for (const auto &ExitEdge : ExitEdges) { BranchProbability ExitingBlockProbability = - BPI->getEdgeProbability(ExitEdge.first, ExitEdge.second); + ComputeBranchProbability(ExitEdge.first, ExitEdge.second); // Some exiting edge has higher probability than the latch exiting edge. // No longer profitable to predicate. if (ExitingBlockProbability > LatchProbabilityThreshold) return false; } - // Using BPI, we have concluded that the most probable way to exit from the + + // We have concluded that the most probable way to exit from the // loop is through the latch (or there's no profile information and all // exits are equally likely). return true; @@ -1071,28 +1130,26 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { // widen so that we gain ability to analyze it's exit count and perform this // transform. TODO: It'd be nice to know for sure the exit became // analyzeable after dropping widenability. - { - bool Invalidate = false; + bool ChangedLoop = false; - for (auto *ExitingBB : ExitingBlocks) { - if (LI->getLoopFor(ExitingBB) != L) - continue; + for (auto *ExitingBB : ExitingBlocks) { + if (LI->getLoopFor(ExitingBB) != L) + continue; - auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); - if (!BI) - continue; + auto *BI = dyn_cast<BranchInst>(ExitingBB->getTerminator()); + if (!BI) + continue; - Use *Cond, *WC; - BasicBlock *IfTrueBB, *IfFalseBB; - if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) && - L->contains(IfTrueBB)) { - WC->set(ConstantInt::getTrue(IfTrueBB->getContext())); - Invalidate = true; - } + Use *Cond, *WC; + BasicBlock *IfTrueBB, *IfFalseBB; + if (parseWidenableBranch(BI, Cond, WC, IfTrueBB, IfFalseBB) && + L->contains(IfTrueBB)) { + WC->set(ConstantInt::getTrue(IfTrueBB->getContext())); + ChangedLoop = true; } - if (Invalidate) - SE->forgetLoop(L); } + if (ChangedLoop) + SE->forgetLoop(L); // The use of umin(all analyzeable exits) instead of latch is subtle, but // important for profitability. We may have a loop which hasn't been fully @@ -1104,18 +1161,24 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { if (isa<SCEVCouldNotCompute>(MinEC) || MinEC->getType()->isPointerTy() || !SE->isLoopInvariant(MinEC, L) || !isSafeToExpandAt(MinEC, WidenableBR, *SE)) - return false; + return ChangedLoop; // Subtlety: We need to avoid inserting additional uses of the WC. We know // that it can only have one transitive use at the moment, and thus moving // that use to just before the branch and inserting code before it and then // modifying the operand is legal. auto *IP = cast<Instruction>(WidenableBR->getCondition()); + // Here we unconditionally modify the IR, so after this point we should return + // only `true`! IP->moveBefore(WidenableBR); + if (MSSAU) + if (auto *MUD = MSSAU->getMemorySSA()->getMemoryAccess(IP)) + MSSAU->moveToPlace(MUD, WidenableBR->getParent(), + MemorySSA::BeforeTerminator); Rewriter.setInsertPoint(IP); IRBuilder<> B(IP); - bool Changed = false; + bool InvalidateLoop = false; Value *MinECV = nullptr; // lazily generated if needed for (BasicBlock *ExitingBB : ExitingBlocks) { // If our exiting block exits multiple loops, we can only rewrite the @@ -1172,16 +1235,18 @@ bool LoopPredication::predicateLoopExits(Loop *L, SCEVExpander &Rewriter) { Value *OldCond = BI->getCondition(); BI->setCondition(ConstantInt::get(OldCond->getType(), !ExitIfTrue)); - Changed = true; + InvalidateLoop = true; } - if (Changed) + if (InvalidateLoop) // We just mutated a bunch of loop exits changing there exit counts // widely. We need to force recomputation of the exit counts given these // changes. Note that all of the inserted exits are never taken, and // should be removed next time the CFG is modified. SE->forgetLoop(L); - return Changed; + + // Always return `true` since we have moved the WidenableBR's condition. + return true; } bool LoopPredication::runOnLoop(Loop *Loop) { @@ -1242,5 +1307,8 @@ bool LoopPredication::runOnLoop(Loop *Loop) { for (auto *Guard : GuardsAsWidenableBranches) Changed |= widenWidenableBranchGuardConditions(Guard, Expander); Changed |= predicateLoopExits(L, Expander); + + if (MSSAU && VerifyMemorySSA) + MSSAU->getMemorySSA()->verifyMemorySSA(); return Changed; } |
