diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2021-08-22 19:00:43 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-13 20:39:49 +0000 |
commit | fe6060f10f634930ff71b7c50291ddc610da2475 (patch) | |
tree | 1483580c790bd4d27b6500a7542b5ee00534d3cc /contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | |
parent | b61bce17f346d79cecfd8f195a64b10f77be43b1 (diff) | |
parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp')
-rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp | 402 |
1 files changed, 271 insertions, 131 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index 9d3c8d0f3739..b9cccc2af309 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -38,6 +38,7 @@ #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/InitializePasses.h" @@ -63,6 +64,7 @@ #define DEBUG_TYPE "simple-loop-unswitch" using namespace llvm; +using namespace llvm::PatternMatch; STATISTIC(NumBranches, "Number of branches unswitched"); STATISTIC(NumSwitches, "Number of switches unswitched"); @@ -101,6 +103,11 @@ static cl::opt<bool> DropNonTrivialImplicitNullChecks( cl::init(false), cl::Hidden, cl::desc("If enabled, drop make.implicit metadata in unswitched implicit " "null checks to save time analyzing if we can keep it.")); +static cl::opt<unsigned> + MSSAThreshold("simple-loop-unswitch-memoryssa-threshold", + cl::desc("Max number of memory uses to explore during " + "partial unswitching analysis"), + cl::init(100), cl::Hidden); /// Collect all of the loop invariant input values transitively used by the /// homogeneous instruction graph from a given root. @@ -116,6 +123,9 @@ collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, "Only need to walk the graph if root itself is not invariant."); TinyPtrVector<Value *> Invariants; + bool IsRootAnd = match(&Root, m_LogicalAnd()); + bool IsRootOr = match(&Root, m_LogicalOr()); + // Build a worklist and recurse through operators collecting invariants. SmallVector<Instruction *, 4> Worklist; SmallPtrSet<Instruction *, 8> Visited; @@ -136,12 +146,13 @@ collectHomogenousInstGraphLoopInvariants(Loop &L, Instruction &Root, // If not an instruction with the same opcode, nothing we can do. Instruction *OpI = dyn_cast<Instruction>(OpV); - if (!OpI || OpI->getOpcode() != Root.getOpcode()) - continue; - // Visit this operand. - if (Visited.insert(OpI).second) - Worklist.push_back(OpI); + if (OpI && ((IsRootAnd && match(OpI, m_LogicalAnd())) || + (IsRootOr && match(OpI, m_LogicalOr())))) { + // Visit this operand. + if (Visited.insert(OpI).second) + Worklist.push_back(OpI); + } } } while (!Worklist.empty()); @@ -153,14 +164,13 @@ static void replaceLoopInvariantUses(Loop &L, Value *Invariant, assert(!isa<Constant>(Invariant) && "Why are we unswitching on a constant?"); // Replace uses of LIC in the loop with the given constant. - for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); UI != UE;) { - // Grab the use and walk past it so we can clobber it in the use list. - Use *U = &*UI++; - Instruction *UserI = dyn_cast<Instruction>(U->getUser()); + // We use make_early_inc_range as set invalidates the iterator. + for (Use &U : llvm::make_early_inc_range(Invariant->uses())) { + Instruction *UserI = dyn_cast<Instruction>(U.getUser()); // Replace this use within the loop body. if (UserI && L.contains(UserI)) - U->set(&Replacement); + U.set(&Replacement); } } @@ -182,8 +192,9 @@ static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, llvm_unreachable("Basic blocks should never be empty!"); } -/// Insert code to test a set of loop invariant values, and conditionally branch -/// on them. +/// Copy a set of loop invariant values \p ToDuplicate and insert them at the +/// end of \p BB and conditionally branch on the copied condition. We only +/// branch on a single value. static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, ArrayRef<Value *> Invariants, bool Direction, @@ -197,6 +208,49 @@ static void buildPartialUnswitchConditionalBranch(BasicBlock &BB, Direction ? &NormalSucc : &UnswitchedSucc); } +/// Copy a set of loop invariant values, and conditionally branch on them. +static void buildPartialInvariantUnswitchConditionalBranch( + BasicBlock &BB, ArrayRef<Value *> ToDuplicate, bool Direction, + BasicBlock &UnswitchedSucc, BasicBlock &NormalSucc, Loop &L, + MemorySSAUpdater *MSSAU) { + ValueToValueMapTy VMap; + for (auto *Val : reverse(ToDuplicate)) { + Instruction *Inst = cast<Instruction>(Val); + Instruction *NewInst = Inst->clone(); + BB.getInstList().insert(BB.end(), NewInst); + RemapInstruction(NewInst, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + VMap[Val] = NewInst; + + if (!MSSAU) + continue; + + MemorySSA *MSSA = MSSAU->getMemorySSA(); + if (auto *MemUse = + dyn_cast_or_null<MemoryUse>(MSSA->getMemoryAccess(Inst))) { + auto *DefiningAccess = MemUse->getDefiningAccess(); + // Get the first defining access before the loop. + while (L.contains(DefiningAccess->getBlock())) { + // If the defining access is a MemoryPhi, get the incoming + // value for the pre-header as defining access. + if (auto *MemPhi = dyn_cast<MemoryPhi>(DefiningAccess)) + DefiningAccess = + MemPhi->getIncomingValueForBlock(L.getLoopPreheader()); + else + DefiningAccess = cast<MemoryDef>(DefiningAccess)->getDefiningAccess(); + } + MSSAU->createMemoryAccessInBB(NewInst, DefiningAccess, + NewInst->getParent(), + MemorySSA::BeforeTerminator); + } + } + + IRBuilder<> IRB(&BB); + Value *Cond = VMap[ToDuplicate[0]]; + IRB.CreateCondBr(Cond, Direction ? &UnswitchedSucc : &NormalSucc, + Direction ? &NormalSucc : &UnswitchedSucc); +} + /// Rewrite the PHI nodes in an unswitched loop exit basic block. /// /// Requires that the loop exit and unswitched basic block are the same, and @@ -366,7 +420,7 @@ static Loop *getTopMostExitingLoop(BasicBlock *ExitBB, LoopInfo &LI) { /// hoists the branch above that split. Preserves loop simplified form /// (splitting the exit block as necessary). It simplifies the branch within /// the loop to an unconditional branch but doesn't remove it entirely. Further -/// cleanup can be done with some simplify-cfg like pass. +/// cleanup can be done with some simplifycfg like pass. /// /// If `SE` is not null, it will be updated based on the potential loop SCEVs /// invalidated by this. @@ -389,9 +443,10 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, } else { if (auto *CondInst = dyn_cast<Instruction>(BI.getCondition())) Invariants = collectHomogenousInstGraphLoopInvariants(L, *CondInst, LI); - if (Invariants.empty()) - // Couldn't find invariant inputs! + if (Invariants.empty()) { + LLVM_DEBUG(dbgs() << " Couldn't find invariant inputs!\n"); return false; + } } // Check that one of the branch's successors exits, and which one. @@ -402,13 +457,17 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, ExitDirection = false; LoopExitSuccIdx = 1; LoopExitBB = BI.getSuccessor(1); - if (L.contains(LoopExitBB)) + if (L.contains(LoopExitBB)) { + LLVM_DEBUG(dbgs() << " Branch doesn't exit the loop!\n"); return false; + } } auto *ContinueBB = BI.getSuccessor(1 - LoopExitSuccIdx); auto *ParentBB = BI.getParent(); - if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) + if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) { + LLVM_DEBUG(dbgs() << " Loop exit PHI's aren't loop-invariant!\n"); return false; + } // When unswitching only part of the branch's condition, we need the exit // block to be reached directly from the partially unswitched input. This can @@ -416,12 +475,11 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // is a graph of `or` operations, or the exit block is along the false edge // and the condition is a graph of `and` operations. if (!FullUnswitch) { - if (ExitDirection) { - if (cast<Instruction>(BI.getCondition())->getOpcode() != Instruction::Or) - return false; - } else { - if (cast<Instruction>(BI.getCondition())->getOpcode() != Instruction::And) - return false; + if (ExitDirection ? !match(BI.getCondition(), m_LogicalOr()) + : !match(BI.getCondition(), m_LogicalAnd())) { + LLVM_DEBUG(dbgs() << " Branch condition is in improper form for " + "non-full unswitch!\n"); + return false; } } @@ -498,13 +556,13 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // Only unswitching a subset of inputs to the condition, so we will need to // build a new branch that merges the invariant inputs. if (ExitDirection) - assert(cast<Instruction>(BI.getCondition())->getOpcode() == - Instruction::Or && - "Must have an `or` of `i1`s for the condition!"); + assert(match(BI.getCondition(), m_LogicalOr()) && + "Must have an `or` of `i1`s or `select i1 X, true, Y`s for the " + "condition!"); else - assert(cast<Instruction>(BI.getCondition())->getOpcode() == - Instruction::And && - "Must have an `and` of `i1`s for the condition!"); + assert(match(BI.getCondition(), m_LogicalAnd()) && + "Must have an `and` of `i1`s or `select i1 X, Y, false`s for the" + " condition!"); buildPartialUnswitchConditionalBranch(*OldPH, Invariants, ExitDirection, *UnswitchedBB, *NewPH); } @@ -590,7 +648,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, /// considered for unswitching so this is a stable transform and the same /// switch will not be revisited. If after unswitching there is only a single /// in-loop successor, the switch is further simplified to an unconditional -/// branch. Still more cleanup can be done with some simplify-cfg like pass. +/// branch. Still more cleanup can be done with some simplifycfg like pass. /// /// If `SE` is not null, it will be updated based on the potential loop SCEVs /// invalidated by this. @@ -925,7 +983,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, if (auto *SI = dyn_cast<SwitchInst>(CurrentTerm)) { // Don't bother trying to unswitch past a switch with a constant // condition. This should be removed prior to running this pass by - // simplify-cfg. + // simplifycfg. if (isa<Constant>(SI->getCondition())) return Changed; @@ -954,7 +1012,7 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, return Changed; // Don't bother trying to unswitch past an unconditional branch or a branch - // with a constant value. These should be removed by simplify-cfg prior to + // with a constant value. These should be removed by simplifycfg prior to // running this pass. if (!BI->isConditional() || isa<Constant>(BI->getCondition())) return Changed; @@ -1108,9 +1166,8 @@ static BasicBlock *buildClonedLoopBlocks( for (Instruction &I : *ClonedBB) { RemapInstruction(&I, VMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); - if (auto *II = dyn_cast<IntrinsicInst>(&I)) - if (II->getIntrinsicID() == Intrinsic::assume) - AC.registerAssumption(II); + if (auto *II = dyn_cast<AssumeInst>(&I)) + AC.registerAssumption(II); } // Update any PHI nodes in the cloned successors of the skipped blocks to not @@ -1959,18 +2016,22 @@ void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { static void unswitchNontrivialInvariants( Loop &L, Instruction &TI, ArrayRef<Value *> Invariants, - SmallVectorImpl<BasicBlock *> &ExitBlocks, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, + SmallVectorImpl<BasicBlock *> &ExitBlocks, IVConditionInfo &PartialIVInfo, + DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { auto *ParentBB = TI.getParent(); BranchInst *BI = dyn_cast<BranchInst>(&TI); SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI); // We can only unswitch switches, conditional branches with an invariant - // condition, or combining invariant conditions with an instruction. + // condition, or combining invariant conditions with an instruction or + // partially invariant instructions. assert((SI || (BI && BI->isConditional())) && "Can only unswitch switches and conditional branch!"); - bool FullUnswitch = SI || BI->getCondition() == Invariants[0]; + bool PartiallyInvariant = !PartialIVInfo.InstToDuplicate.empty(); + bool FullUnswitch = + SI || (BI->getCondition() == Invariants[0] && !PartiallyInvariant); if (FullUnswitch) assert(Invariants.size() == 1 && "Cannot have other invariants with full unswitching!"); @@ -1984,19 +2045,24 @@ static void unswitchNontrivialInvariants( // Constant and BBs tracking the cloned and continuing successor. When we are // unswitching the entire condition, this can just be trivially chosen to // unswitch towards `true`. However, when we are unswitching a set of - // invariants combined with `and` or `or`, the combining operation determines - // the best direction to unswitch: we want to unswitch the direction that will - // collapse the branch. + // invariants combined with `and` or `or` or partially invariant instructions, + // the combining operation determines the best direction to unswitch: we want + // to unswitch the direction that will collapse the branch. bool Direction = true; int ClonedSucc = 0; if (!FullUnswitch) { - if (cast<Instruction>(BI->getCondition())->getOpcode() != Instruction::Or) { - assert(cast<Instruction>(BI->getCondition())->getOpcode() == - Instruction::And && - "Only `or` and `and` instructions can combine invariants being " - "unswitched."); - Direction = false; - ClonedSucc = 1; + Value *Cond = BI->getCondition(); + (void)Cond; + assert(((match(Cond, m_LogicalAnd()) ^ match(Cond, m_LogicalOr())) || + PartiallyInvariant) && + "Only `or`, `and`, an `select`, partially invariant instructions " + "can combine invariants being unswitched."); + if (!match(BI->getCondition(), m_LogicalOr())) { + if (match(BI->getCondition(), m_LogicalAnd()) || + (PartiallyInvariant && !PartialIVInfo.KnownValue->isOneValue())) { + Direction = false; + ClonedSucc = 1; + } } } @@ -2214,8 +2280,12 @@ static void unswitchNontrivialInvariants( BasicBlock *ClonedPH = ClonedPHs.begin()->second; // When doing a partial unswitch, we have to do a bit more work to build up // the branch in the split block. - buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, - *ClonedPH, *LoopPH); + if (PartiallyInvariant) + buildPartialInvariantUnswitchConditionalBranch( + *SplitBB, Invariants, Direction, *ClonedPH, *LoopPH, L, MSSAU); + else + buildPartialUnswitchConditionalBranch(*SplitBB, Invariants, Direction, + *ClonedPH, *LoopPH); DTUpdates.push_back({DominatorTree::Insert, SplitBB, ClonedPH}); if (MSSAU) { @@ -2267,7 +2337,7 @@ static void unswitchNontrivialInvariants( // verification steps. assert(DT.verify(DominatorTree::VerificationLevel::Fast)); - if (BI) { + if (BI && !PartiallyInvariant) { // If we unswitched a branch which collapses the condition to a known // constant we want to replace all the uses of the invariants within both // the original and cloned blocks. We do this here so that we can use the @@ -2285,7 +2355,8 @@ static void unswitchNontrivialInvariants( // for each invariant operand. // So it happens that for multiple-partial case we dont replace // in the unswitched branch. - bool ReplaceUnswitched = FullUnswitch || (Invariants.size() == 1); + bool ReplaceUnswitched = + FullUnswitch || (Invariants.size() == 1) || PartiallyInvariant; ConstantInt *UnswitchedReplacement = Direction ? ConstantInt::getTrue(BI->getContext()) @@ -2294,21 +2365,19 @@ static void unswitchNontrivialInvariants( Direction ? ConstantInt::getFalse(BI->getContext()) : ConstantInt::getTrue(BI->getContext()); for (Value *Invariant : Invariants) - for (auto UI = Invariant->use_begin(), UE = Invariant->use_end(); - UI != UE;) { - // Grab the use and walk past it so we can clobber it in the use list. - Use *U = &*UI++; - Instruction *UserI = dyn_cast<Instruction>(U->getUser()); + // Use make_early_inc_range here as set invalidates the iterator. + for (Use &U : llvm::make_early_inc_range(Invariant->uses())) { + Instruction *UserI = dyn_cast<Instruction>(U.getUser()); if (!UserI) continue; // Replace it with the 'continue' side if in the main loop body, and the // unswitched if in the cloned blocks. if (DT.dominates(LoopPH, UserI->getParent())) - U->set(ContinueReplacement); + U.set(ContinueReplacement); else if (ReplaceUnswitched && DT.dominates(ClonedPH, UserI->getParent())) - U->set(UnswitchedReplacement); + U.set(UnswitchedReplacement); } } @@ -2382,7 +2451,7 @@ static void unswitchNontrivialInvariants( for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) if (UpdatedL->getParentLoop() == ParentL) SibLoops.push_back(UpdatedL); - UnswitchCB(IsStillLoop, SibLoops); + UnswitchCB(IsStillLoop, PartiallyInvariant, SibLoops); if (MSSAU && VerifyMemorySSA) MSSAU->getMemorySSA()->verifyMemorySSA(); @@ -2399,10 +2468,10 @@ static void unswitchNontrivialInvariants( /// The recursive computation is memozied into the provided DT-indexed cost map /// to allow querying it for most nodes in the domtree without it becoming /// quadratic. -static int -computeDomSubtreeCost(DomTreeNode &N, - const SmallDenseMap<BasicBlock *, int, 4> &BBCostMap, - SmallDenseMap<DomTreeNode *, int, 4> &DTCostMap) { +static InstructionCost computeDomSubtreeCost( + DomTreeNode &N, + const SmallDenseMap<BasicBlock *, InstructionCost, 4> &BBCostMap, + SmallDenseMap<DomTreeNode *, InstructionCost, 4> &DTCostMap) { // Don't accumulate cost (or recurse through) blocks not in our block cost // map and thus not part of the duplication cost being considered. auto BBCostIt = BBCostMap.find(N.getBlock()); @@ -2416,8 +2485,9 @@ computeDomSubtreeCost(DomTreeNode &N, // If not, we have to compute it. We can't use insert above and update // because computing the cost may insert more things into the map. - int Cost = std::accumulate( - N.begin(), N.end(), BBCostIt->second, [&](int Sum, DomTreeNode *ChildN) { + InstructionCost Cost = std::accumulate( + N.begin(), N.end(), BBCostIt->second, + [&](InstructionCost Sum, DomTreeNode *ChildN) -> InstructionCost { return Sum + computeDomSubtreeCost(*ChildN, BBCostMap, DTCostMap); }); bool Inserted = DTCostMap.insert({&N, Cost}).second; @@ -2596,11 +2666,11 @@ static int CalculateUnswitchCostMultiplier( return CostMultiplier; } -static bool -unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, TargetTransformInfo &TTI, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { +static bool unswitchBestCondition( + Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + AAResults &AA, TargetTransformInfo &TTI, + function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { // Collect all invariant conditions within this loop (as opposed to an inner // loop which would be handled when visiting that inner loop). SmallVector<std::pair<Instruction *, TinyPtrVector<Value *>>, 4> @@ -2615,6 +2685,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, CollectGuards = true; } + IVConditionInfo PartialIVInfo; for (auto *BB : L.blocks()) { if (LI.getLoopFor(BB) != &L) continue; @@ -2642,22 +2713,48 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, BI->getSuccessor(0) == BI->getSuccessor(1)) continue; + // If BI's condition is 'select _, true, false', simplify it to confuse + // matchers + Value *Cond = BI->getCondition(), *CondNext; + while (match(Cond, m_Select(m_Value(CondNext), m_One(), m_Zero()))) + Cond = CondNext; + BI->setCondition(Cond); + if (L.isLoopInvariant(BI->getCondition())) { UnswitchCandidates.push_back({BI, {BI->getCondition()}}); continue; } Instruction &CondI = *cast<Instruction>(BI->getCondition()); - if (CondI.getOpcode() != Instruction::And && - CondI.getOpcode() != Instruction::Or) - continue; + if (match(&CondI, m_CombineOr(m_LogicalAnd(), m_LogicalOr()))) { + TinyPtrVector<Value *> Invariants = + collectHomogenousInstGraphLoopInvariants(L, CondI, LI); + if (Invariants.empty()) + continue; - TinyPtrVector<Value *> Invariants = - collectHomogenousInstGraphLoopInvariants(L, CondI, LI); - if (Invariants.empty()) + UnswitchCandidates.push_back({BI, std::move(Invariants)}); continue; + } + } - UnswitchCandidates.push_back({BI, std::move(Invariants)}); + Instruction *PartialIVCondBranch = nullptr; + if (MSSAU && !findOptionMDForLoop(&L, "llvm.loop.unswitch.partial.disable") && + !any_of(UnswitchCandidates, [&L](auto &TerminatorAndInvariants) { + return TerminatorAndInvariants.first == L.getHeader()->getTerminator(); + })) { + MemorySSA *MSSA = MSSAU->getMemorySSA(); + if (auto Info = hasPartialIVCondition(L, MSSAThreshold, *MSSA, AA)) { + LLVM_DEBUG( + dbgs() << "simple-loop-unswitch: Found partially invariant condition " + << *Info->InstToDuplicate[0] << "\n"); + PartialIVInfo = *Info; + PartialIVCondBranch = L.getHeader()->getTerminator(); + TinyPtrVector<Value *> ValsToDuplicate; + for (auto *Inst : Info->InstToDuplicate) + ValsToDuplicate.push_back(Inst); + UnswitchCandidates.push_back( + {L.getHeader()->getTerminator(), std::move(ValsToDuplicate)}); + } } // If we didn't find any candidates, we're done. @@ -2678,15 +2775,18 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, SmallVector<BasicBlock *, 4> ExitBlocks; L.getUniqueExitBlocks(ExitBlocks); - // We cannot unswitch if exit blocks contain a cleanuppad instruction as we - // don't know how to split those exit blocks. + // We cannot unswitch if exit blocks contain a cleanuppad/catchswitch + // instruction as we don't know how to split those exit blocks. // FIXME: We should teach SplitBlock to handle this and remove this // restriction. - for (auto *ExitBB : ExitBlocks) - if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI())) { - dbgs() << "Cannot unswitch because of cleanuppad in exit block\n"; + for (auto *ExitBB : ExitBlocks) { + auto *I = ExitBB->getFirstNonPHI(); + if (isa<CleanupPadInst>(I) || isa<CatchSwitchInst>(I)) { + LLVM_DEBUG(dbgs() << "Cannot unswitch because of cleanuppad/catchswitch " + "in exit block\n"); return false; } + } LLVM_DEBUG( dbgs() << "Considering " << UnswitchCandidates.size() @@ -2699,7 +2799,7 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, // subsets of the loop for duplication during unswitching. SmallPtrSet<const Value *, 4> EphValues; CodeMetrics::collectEphemeralValues(&L, &AC, EphValues); - SmallDenseMap<BasicBlock *, int, 4> BBCostMap; + SmallDenseMap<BasicBlock *, InstructionCost, 4> BBCostMap; // Compute the cost of each block, as well as the total loop cost. Also, bail // out if we see instructions which are incompatible with loop unswitching @@ -2710,9 +2810,9 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, L.getHeader()->getParent()->hasMinSize() ? TargetTransformInfo::TCK_CodeSize : TargetTransformInfo::TCK_SizeAndLatency; - int LoopCost = 0; + InstructionCost LoopCost = 0; for (auto *BB : L.blocks()) { - int Cost = 0; + InstructionCost Cost = 0; for (auto &I : *BB) { if (EphValues.count(&I)) continue; @@ -2746,37 +2846,38 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, // This requires memoizing each dominator subtree to avoid redundant work. // // FIXME: Need to actually do the number of candidates part above. - SmallDenseMap<DomTreeNode *, int, 4> DTCostMap; + SmallDenseMap<DomTreeNode *, InstructionCost, 4> DTCostMap; // Given a terminator which might be unswitched, computes the non-duplicated // cost for that terminator. - auto ComputeUnswitchedCost = [&](Instruction &TI, bool FullUnswitch) { + auto ComputeUnswitchedCost = [&](Instruction &TI, + bool FullUnswitch) -> InstructionCost { BasicBlock &BB = *TI.getParent(); SmallPtrSet<BasicBlock *, 4> Visited; - int Cost = LoopCost; + InstructionCost Cost = 0; for (BasicBlock *SuccBB : successors(&BB)) { // Don't count successors more than once. if (!Visited.insert(SuccBB).second) continue; // If this is a partial unswitch candidate, then it must be a conditional - // branch with a condition of either `or` or `and`. In that case, one of + // branch with a condition of either `or`, `and`, their corresponding + // select forms or partially invariant instructions. In that case, one of // the successors is necessarily duplicated, so don't even try to remove // its cost. if (!FullUnswitch) { auto &BI = cast<BranchInst>(TI); - if (cast<Instruction>(BI.getCondition())->getOpcode() == - Instruction::And) { + if (match(BI.getCondition(), m_LogicalAnd())) { if (SuccBB == BI.getSuccessor(1)) continue; - } else { - assert(cast<Instruction>(BI.getCondition())->getOpcode() == - Instruction::Or && - "Only `and` and `or` conditions can result in a partial " - "unswitch!"); + } else if (match(BI.getCondition(), m_LogicalOr())) { if (SuccBB == BI.getSuccessor(0)) continue; - } + } else if ((PartialIVInfo.KnownValue->isOneValue() && + SuccBB == BI.getSuccessor(0)) || + (!PartialIVInfo.KnownValue->isOneValue() && + SuccBB == BI.getSuccessor(1))) + continue; } // This successor's domtree will not need to be duplicated after @@ -2787,8 +2888,8 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, llvm::all_of(predecessors(SuccBB), [&](BasicBlock *PredBB) { return PredBB == &BB || DT.dominates(SuccBB, PredBB); })) { - Cost -= computeDomSubtreeCost(*DT[SuccBB], BBCostMap, DTCostMap); - assert(Cost >= 0 && + Cost += computeDomSubtreeCost(*DT[SuccBB], BBCostMap, DTCostMap); + assert(Cost <= LoopCost && "Non-duplicated cost should never exceed total loop cost!"); } } @@ -2801,16 +2902,16 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, int SuccessorsCount = isGuard(&TI) ? 2 : Visited.size(); assert(SuccessorsCount > 1 && "Cannot unswitch a condition without multiple distinct successors!"); - return Cost * (SuccessorsCount - 1); + return (LoopCost - Cost) * (SuccessorsCount - 1); }; Instruction *BestUnswitchTI = nullptr; - int BestUnswitchCost = 0; + InstructionCost BestUnswitchCost = 0; ArrayRef<Value *> BestUnswitchInvariants; for (auto &TerminatorAndInvariants : UnswitchCandidates) { Instruction &TI = *TerminatorAndInvariants.first; ArrayRef<Value *> Invariants = TerminatorAndInvariants.second; BranchInst *BI = dyn_cast<BranchInst>(&TI); - int CandidateCost = ComputeUnswitchedCost( + InstructionCost CandidateCost = ComputeUnswitchedCost( TI, /*FullUnswitch*/ !BI || (Invariants.size() == 1 && Invariants[0] == BI->getCondition())); // Calculate cost multiplier which is a tool to limit potentially @@ -2844,6 +2945,9 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, return false; } + if (BestUnswitchTI != PartialIVCondBranch) + PartialIVInfo.InstToDuplicate.clear(); + // If the best candidate is a guard, turn it into a branch. if (isGuard(BestUnswitchTI)) BestUnswitchTI = turnGuardIntoBranch(cast<IntrinsicInst>(BestUnswitchTI), L, @@ -2853,7 +2957,8 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, << BestUnswitchCost << ") terminator: " << *BestUnswitchTI << "\n"); unswitchNontrivialInvariants(L, *BestUnswitchTI, BestUnswitchInvariants, - ExitBlocks, DT, LI, AC, UnswitchCB, SE, MSSAU); + ExitBlocks, PartialIVInfo, DT, LI, AC, + UnswitchCB, SE, MSSAU); return true; } @@ -2864,9 +2969,9 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. /// -/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also -/// updated based on the unswitch. -/// The `MSSA` analysis is also updated if valid (i.e. its use is enabled). +/// The `DT`, `LI`, `AC`, `AA`, `TTI` parameters are required analyses that are +/// also updated based on the unswitch. The `MSSA` analysis is also updated if +/// valid (i.e. its use is enabled). /// /// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is /// true, we will attempt to do non-trivial unswitching as well as trivial @@ -2878,11 +2983,12 @@ unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI, /// /// If `SE` is non-null, we will update that analysis based on the unswitching /// done. -static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC, TargetTransformInfo &TTI, - bool NonTrivial, - function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB, - ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { +static bool +unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + AAResults &AA, TargetTransformInfo &TTI, bool Trivial, + bool NonTrivial, + function_ref<void(bool, bool, ArrayRef<Loop *>)> UnswitchCB, + ScalarEvolution *SE, MemorySSAUpdater *MSSAU) { assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); @@ -2891,23 +2997,37 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, return false; // Try trivial unswitch first before loop over other basic blocks in the loop. - if (unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { + if (Trivial && unswitchAllTrivialConditions(L, DT, LI, SE, MSSAU)) { // If we unswitched successfully we will want to clean up the loop before // processing it further so just mark it as unswitched and return. - UnswitchCB(/*CurrentLoopValid*/ true, {}); + UnswitchCB(/*CurrentLoopValid*/ true, false, {}); return true; } - // If we're not doing non-trivial unswitching, we're done. We both accept - // a parameter but also check a local flag that can be used for testing - // a debugging. - if (!NonTrivial && !EnableNonTrivialUnswitch) + // Check whether we should continue with non-trivial conditions. + // EnableNonTrivialUnswitch: Global variable that forces non-trivial + // unswitching for testing and debugging. + // NonTrivial: Parameter that enables non-trivial unswitching for this + // invocation of the transform. But this should be allowed only + // for targets without branch divergence. + // + // FIXME: If divergence analysis becomes available to a loop + // transform, we should allow unswitching for non-trivial uniform + // branches even on targets that have divergence. + // https://bugs.llvm.org/show_bug.cgi?id=48819 + bool ContinueWithNonTrivial = + EnableNonTrivialUnswitch || (NonTrivial && !TTI.hasBranchDivergence()); + if (!ContinueWithNonTrivial) return false; // Skip non-trivial unswitching for optsize functions. if (L.getHeader()->getParent()->hasOptSize()) return false; + // Skip non-trivial unswitching for loops that cannot be cloned. + if (!L.isSafeToClone()) + return false; + // For non-trivial unswitching, because it often creates new loops, we rely on // the pass manager to iterate on the loops rather than trying to immediately // reach a fixed point. There is no substantial advantage to iterating @@ -2916,7 +3036,7 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, // Try to unswitch the best invariant condition. We prefer this full unswitch to // a partial unswitch when possible below the threshold. - if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE, MSSAU)) + if (unswitchBestCondition(L, DT, LI, AC, AA, TTI, UnswitchCB, SE, MSSAU)) return true; // No other opportunities to unswitch. @@ -2937,6 +3057,7 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, std::string LoopName = std::string(L.getName()); auto UnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, + bool PartiallyInvariant, ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. if (!NewLoops.empty()) @@ -2944,9 +3065,21 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, // If the current loop remains valid, we should revisit it to catch any // other unswitch opportunities. Otherwise, we need to mark it as deleted. - if (CurrentLoopValid) - U.revisitCurrentLoop(); - else + if (CurrentLoopValid) { + if (PartiallyInvariant) { + // Mark the new loop as partially unswitched, to avoid unswitching on + // the same condition again. + auto &Context = L.getHeader()->getContext(); + MDNode *DisableUnswitchMD = MDNode::get( + Context, + MDString::get(Context, "llvm.loop.unswitch.partial.disable")); + MDNode *NewLoopID = makePostTransformationMetadata( + Context, L.getLoopID(), {"llvm.loop.unswitch.partial"}, + {DisableUnswitchMD}); + L.setLoopID(NewLoopID); + } else + U.revisitCurrentLoop(); + } else U.markLoopAsDeleted(L, LoopName); }; @@ -2956,8 +3089,9 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, if (VerifyMemorySSA) AR.MSSA->verifyMemorySSA(); } - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB, - &AR.SE, MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.AA, AR.TTI, Trivial, NonTrivial, + UnswitchCB, &AR.SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr)) return PreservedAnalyses::all(); if (AR.MSSA && VerifyMemorySSA) @@ -3014,6 +3148,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); MemorySSA *MSSA = nullptr; Optional<MemorySSAUpdater> MSSAU; @@ -3025,7 +3160,7 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); auto *SE = SEWP ? &SEWP->getSE() : nullptr; - auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, + auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid, bool PartiallyInvariant, ArrayRef<Loop *> NewLoops) { // If we did a non-trivial unswitch, we have added new (cloned) loops. for (auto *NewL : NewLoops) @@ -3034,17 +3169,22 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { // If the current loop remains valid, re-add it to the queue. This is // a little wasteful as we'll finish processing the current loop as well, // but it is the best we can do in the old PM. - if (CurrentLoopValid) - LPM.addLoop(*L); - else + if (CurrentLoopValid) { + // If the current loop has been unswitched using a partially invariant + // condition, we should not re-add the current loop to avoid unswitching + // on the same condition again. + if (!PartiallyInvariant) + LPM.addLoop(*L); + } else LPM.markLoopAsDeleted(*L); }; if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); - bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE, - MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); + bool Changed = + unswitchLoop(*L, DT, LI, AC, AA, TTI, true, NonTrivial, UnswitchCB, SE, + MSSAU.hasValue() ? MSSAU.getPointer() : nullptr); if (MSSA && VerifyMemorySSA) MSSA->verifyMemorySSA(); |