diff options
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp | 440 |
1 files changed, 218 insertions, 222 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp index 6749d3db743c..a92cb6a313d3 100644 --- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -22,6 +22,7 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" @@ -35,6 +36,7 @@ #include "llvm/Transforms/Utils.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h" #include "llvm/Transforms/Utils/UnrollLoop.h" @@ -167,8 +169,11 @@ static void ConnectProlog(Loop *L, Value *BECount, unsigned Count, // Add the branch to the exit block (around the unrolled loop) B.CreateCondBr(BrLoopExit, OriginalLoopLatchExit, NewPreHeader); InsertPt->eraseFromParent(); - if (DT) - DT->changeImmediateDominator(OriginalLoopLatchExit, PrologExit); + if (DT) { + auto *NewDom = DT->findNearestCommonDominator(OriginalLoopLatchExit, + PrologExit); + DT->changeImmediateDominator(OriginalLoopLatchExit, NewDom); + } } /// Connect the unrolling epilog code to the original loop. @@ -215,7 +220,10 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // PN = PHI [I, Latch] // ... // Exit: - // EpilogPN = PHI [PN, EpilogPreHeader] + // EpilogPN = PHI [PN, EpilogPreHeader], [X, Exit2], [Y, Exit2.epil] + // + // Exits from non-latch blocks point to the original exit block and the + // epilogue edges have already been added. // // There is EpilogPreHeader incoming block instead of NewExit as // NewExit was spilt 1 more time to get EpilogPreHeader. @@ -282,8 +290,10 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, // Add the branch to the exit block (around the unrolling loop) B.CreateCondBr(BrLoopExit, EpilogPreHeader, Exit); InsertPt->eraseFromParent(); - if (DT) - DT->changeImmediateDominator(Exit, NewExit); + if (DT) { + auto *NewDom = DT->findNearestCommonDominator(Exit, NewExit); + DT->changeImmediateDominator(Exit, NewDom); + } // Split the main loop exit to maintain canonicalization guarantees. SmallVector<BasicBlock*, 4> NewExitPreds{Latch}; @@ -291,17 +301,15 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, PreserveLCSSA); } -/// Create a clone of the blocks in a loop and connect them together. -/// If CreateRemainderLoop is false, loop structure will not be cloned, -/// otherwise a new loop will be created including all cloned blocks, and the -/// iterator of it switches to count NewIter down to 0. +/// Create a clone of the blocks in a loop and connect them together. A new +/// loop will be created including all cloned blocks, and the iterator of the +/// new loop switched to count NewIter down to 0. /// The cloned blocks should be inserted between InsertTop and InsertBot. -/// If loop structure is cloned InsertTop should be new preheader, InsertBot -/// new loop exit. -/// Return the new cloned loop that is created when CreateRemainderLoop is true. +/// InsertTop should be new preheader, InsertBot new loop exit. +/// Returns the new cloned loop that is created. static Loop * -CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, - const bool UseEpilogRemainder, const bool UnrollRemainder, +CloneLoopBlocks(Loop *L, Value *NewIter, const bool UseEpilogRemainder, + const bool UnrollRemainder, BasicBlock *InsertTop, BasicBlock *InsertBot, BasicBlock *Preheader, std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, @@ -315,8 +323,6 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, Loop *ParentLoop = L->getParentLoop(); NewLoopsMap NewLoops; NewLoops[ParentLoop] = ParentLoop; - if (!CreateRemainderLoop) - NewLoops[L] = ParentLoop; // For each block in the original loop, create a new copy, // and update the value map with the newly created values. @@ -324,11 +330,7 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, BasicBlock *NewBB = CloneBasicBlock(*BB, VMap, "." + suffix, F); NewBlocks.push_back(NewBB); - // If we're unrolling the outermost loop, there's no remainder loop, - // and this block isn't in a nested loop, then the new block is not - // in any loop. Otherwise, add it to loopinfo. - if (CreateRemainderLoop || LI->getLoopFor(*BB) != L || ParentLoop) - addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops); + addClonedBlockToLoopInfo(*BB, NewBB, LI, NewLoops); VMap[*BB] = NewBB; if (Header == *BB) { @@ -349,27 +351,24 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, } if (Latch == *BB) { - // For the last block, if CreateRemainderLoop is false, create a direct - // jump to InsertBot. If not, create a loop back to cloned head. + // For the last block, create a loop back to cloned head. VMap.erase((*BB)->getTerminator()); + // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. + // Subtle: NewIter can be 0 if we wrapped when computing the trip count, + // thus we must compare the post-increment (wrapping) value. BasicBlock *FirstLoopBB = cast<BasicBlock>(VMap[Header]); BranchInst *LatchBR = cast<BranchInst>(NewBB->getTerminator()); IRBuilder<> Builder(LatchBR); - if (!CreateRemainderLoop) { - Builder.CreateBr(InsertBot); - } else { - PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, - suffix + ".iter", - FirstLoopBB->getFirstNonPHI()); - Value *IdxSub = - Builder.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), - NewIdx->getName() + ".sub"); - Value *IdxCmp = - Builder.CreateIsNotNull(IdxSub, NewIdx->getName() + ".cmp"); - Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot); - NewIdx->addIncoming(NewIter, InsertTop); - NewIdx->addIncoming(IdxSub, NewBB); - } + PHINode *NewIdx = PHINode::Create(NewIter->getType(), 2, + suffix + ".iter", + FirstLoopBB->getFirstNonPHI()); + auto *Zero = ConstantInt::get(NewIdx->getType(), 0); + auto *One = ConstantInt::get(NewIdx->getType(), 1); + Value *IdxNext = Builder.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); + Value *IdxCmp = Builder.CreateICmpNE(IdxNext, NewIter, NewIdx->getName() + ".cmp"); + Builder.CreateCondBr(IdxCmp, FirstLoopBB, InsertBot); + NewIdx->addIncoming(Zero, InsertTop); + NewIdx->addIncoming(IdxNext, NewBB); LatchBR->eraseFromParent(); } } @@ -378,99 +377,45 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, // cloned loop. for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) { PHINode *NewPHI = cast<PHINode>(VMap[&*I]); - if (!CreateRemainderLoop) { - if (UseEpilogRemainder) { - unsigned idx = NewPHI->getBasicBlockIndex(Preheader); - NewPHI->setIncomingBlock(idx, InsertTop); - NewPHI->removeIncomingValue(Latch, false); - } else { - VMap[&*I] = NewPHI->getIncomingValueForBlock(Preheader); - cast<BasicBlock>(VMap[Header])->getInstList().erase(NewPHI); - } - } else { - unsigned idx = NewPHI->getBasicBlockIndex(Preheader); - NewPHI->setIncomingBlock(idx, InsertTop); - BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); - idx = NewPHI->getBasicBlockIndex(Latch); - Value *InVal = NewPHI->getIncomingValue(idx); - NewPHI->setIncomingBlock(idx, NewLatch); - if (Value *V = VMap.lookup(InVal)) - NewPHI->setIncomingValue(idx, V); - } - } - if (CreateRemainderLoop) { - Loop *NewLoop = NewLoops[L]; - assert(NewLoop && "L should have been cloned"); - MDNode *LoopID = NewLoop->getLoopID(); - - // Only add loop metadata if the loop is not going to be completely - // unrolled. - if (UnrollRemainder) - return NewLoop; - - Optional<MDNode *> NewLoopID = makeFollowupLoopID( - LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); - if (NewLoopID.hasValue()) { - NewLoop->setLoopID(NewLoopID.getValue()); - - // Do not setLoopAlreadyUnrolled if loop attributes have been defined - // explicitly. - return NewLoop; - } - - // Add unroll disable metadata to disable future unrolling for this loop. - NewLoop->setLoopAlreadyUnrolled(); - return NewLoop; + unsigned idx = NewPHI->getBasicBlockIndex(Preheader); + NewPHI->setIncomingBlock(idx, InsertTop); + BasicBlock *NewLatch = cast<BasicBlock>(VMap[Latch]); + idx = NewPHI->getBasicBlockIndex(Latch); + Value *InVal = NewPHI->getIncomingValue(idx); + NewPHI->setIncomingBlock(idx, NewLatch); + if (Value *V = VMap.lookup(InVal)) + NewPHI->setIncomingValue(idx, V); } - else - return nullptr; -} -/// Returns true if we can safely unroll a multi-exit/exiting loop. OtherExits -/// is populated with all the loop exit blocks other than the LatchExit block. -static bool canSafelyUnrollMultiExitLoop(Loop *L, BasicBlock *LatchExit, - bool PreserveLCSSA, - bool UseEpilogRemainder) { + Loop *NewLoop = NewLoops[L]; + assert(NewLoop && "L should have been cloned"); + MDNode *LoopID = NewLoop->getLoopID(); - // We currently have some correctness constrains in unrolling a multi-exit - // loop. Check for these below. + // Only add loop metadata if the loop is not going to be completely + // unrolled. + if (UnrollRemainder) + return NewLoop; - // We rely on LCSSA form being preserved when the exit blocks are transformed. - if (!PreserveLCSSA) - return false; + Optional<MDNode *> NewLoopID = makeFollowupLoopID( + LoopID, {LLVMLoopUnrollFollowupAll, LLVMLoopUnrollFollowupRemainder}); + if (NewLoopID.hasValue()) { + NewLoop->setLoopID(NewLoopID.getValue()); - // TODO: Support multiple exiting blocks jumping to the `LatchExit` when - // UnrollRuntimeMultiExit is true. This will need updating the logic in - // connectEpilog/connectProlog. - if (!LatchExit->getSinglePredecessor()) { - LLVM_DEBUG( - dbgs() << "Bailout for multi-exit handling when latch exit has >1 " - "predecessor.\n"); - return false; + // Do not setLoopAlreadyUnrolled if loop attributes have been defined + // explicitly. + return NewLoop; } - // FIXME: We bail out of multi-exit unrolling when epilog loop is generated - // and L is an inner loop. This is because in presence of multiple exits, the - // outer loop is incorrect: we do not add the EpilogPreheader and exit to the - // outer loop. This is automatically handled in the prolog case, so we do not - // have that bug in prolog generation. - if (UseEpilogRemainder && L->getParentLoop()) - return false; - // All constraints have been satisfied. - return true; + // Add unroll disable metadata to disable future unrolling for this loop. + NewLoop->setLoopAlreadyUnrolled(); + return NewLoop; } /// Returns true if we can profitably unroll the multi-exit loop L. Currently, /// we return true only if UnrollRuntimeMultiExit is set to true. static bool canProfitablyUnrollMultiExitLoop( Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, BasicBlock *LatchExit, - bool PreserveLCSSA, bool UseEpilogRemainder) { - -#if !defined(NDEBUG) - assert(canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, - UseEpilogRemainder) && - "Should be safe to unroll before checking profitability!"); -#endif + bool UseEpilogRemainder) { // Priority goes to UnrollRuntimeMultiExit if it's supplied. if (UnrollRuntimeMultiExit.getNumOccurrences()) @@ -523,24 +468,56 @@ static void updateLatchBranchWeightsForRemainderLoop(Loop *OrigLoop, uint64_t TrueWeight, FalseWeight; BranchInst *LatchBR = cast<BranchInst>(OrigLoop->getLoopLatch()->getTerminator()); - if (LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) { - uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() - ? FalseWeight - : TrueWeight; - assert(UnrollFactor > 1); - uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; - BasicBlock *Header = RemainderLoop->getHeader(); - BasicBlock *Latch = RemainderLoop->getLoopLatch(); - auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); - unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); - MDBuilder MDB(RemainderLatchBR->getContext()); - MDNode *WeightNode = - HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) - : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); - RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); - } + if (!LatchBR->extractProfMetadata(TrueWeight, FalseWeight)) + return; + uint64_t ExitWeight = LatchBR->getSuccessor(0) == OrigLoop->getHeader() + ? FalseWeight + : TrueWeight; + assert(UnrollFactor > 1); + uint64_t BackEdgeWeight = (UnrollFactor - 1) * ExitWeight; + BasicBlock *Header = RemainderLoop->getHeader(); + BasicBlock *Latch = RemainderLoop->getLoopLatch(); + auto *RemainderLatchBR = cast<BranchInst>(Latch->getTerminator()); + unsigned HeaderIdx = (RemainderLatchBR->getSuccessor(0) == Header ? 0 : 1); + MDBuilder MDB(RemainderLatchBR->getContext()); + MDNode *WeightNode = + HeaderIdx ? MDB.createBranchWeights(ExitWeight, BackEdgeWeight) + : MDB.createBranchWeights(BackEdgeWeight, ExitWeight); + RemainderLatchBR->setMetadata(LLVMContext::MD_prof, WeightNode); } +/// Calculate ModVal = (BECount + 1) % Count on the abstract integer domain +/// accounting for the possibility of unsigned overflow in the 2s complement +/// domain. Preconditions: +/// 1) TripCount = BECount + 1 (allowing overflow) +/// 2) Log2(Count) <= BitWidth(BECount) +static Value *CreateTripRemainder(IRBuilder<> &B, Value *BECount, + Value *TripCount, unsigned Count) { + // Note that TripCount is BECount + 1. + if (isPowerOf2_32(Count)) + // If the expression is zero, then either: + // 1. There are no iterations to be run in the prolog/epilog loop. + // OR + // 2. The addition computing TripCount overflowed. + // + // If (2) is true, we know that TripCount really is (1 << BEWidth) and so + // the number of iterations that remain to be run in the original loop is a + // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (a + // precondition of this method). + return B.CreateAnd(TripCount, Count - 1, "xtraiter"); + + // As (BECount + 1) can potentially unsigned overflow we count + // (BECount % Count) + 1 which is overflow safe as BECount % Count < Count. + Constant *CountC = ConstantInt::get(BECount->getType(), Count); + Value *ModValTmp = B.CreateURem(BECount, CountC); + Value *ModValAdd = B.CreateAdd(ModValTmp, + ConstantInt::get(ModValTmp->getType(), 1)); + // At that point (BECount % Count) + 1 could be equal to Count. + // To handle this case we need to take mod by Count one more time. + return B.CreateURem(ModValAdd, CountC, "xtraiter"); +} + + /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. /// @@ -624,19 +601,22 @@ bool llvm::UnrollRuntimeLoopRemainder( // These are exit blocks other than the target of the latch exiting block. SmallVector<BasicBlock *, 4> OtherExits; L->getUniqueNonLatchExitBlocks(OtherExits); - bool isMultiExitUnrollingEnabled = - canSafelyUnrollMultiExitLoop(L, LatchExit, PreserveLCSSA, - UseEpilogRemainder) && - canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, - UseEpilogRemainder); - // Support only single exit and exiting block unless multi-exit loop unrolling is enabled. - if (!isMultiExitUnrollingEnabled && - (!L->getExitingBlock() || OtherExits.size())) { - LLVM_DEBUG( - dbgs() - << "Multiple exit/exiting blocks in loop and multi-exit unrolling not " - "enabled!\n"); - return false; + // Support only single exit and exiting block unless multi-exit loop + // unrolling is enabled. + if (!L->getExitingBlock() || OtherExits.size()) { + // We rely on LCSSA form being preserved when the exit blocks are transformed. + // (Note that only an off-by-default mode of the old PM disables PreserveLCCA.) + if (!PreserveLCSSA) + return false; + + if (!canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, + UseEpilogRemainder)) { + LLVM_DEBUG( + dbgs() + << "Multiple exit/exiting blocks in loop and multi-exit unrolling not " + "enabled!\n"); + return false; + } } // Use Scalar Evolution to compute the trip count. This allows more loops to // be unrolled than relying on induction var simplification. @@ -659,6 +639,7 @@ bool llvm::UnrollRuntimeLoopRemainder( unsigned BEWidth = cast<IntegerType>(BECountSC->getType())->getBitWidth(); // Add 1 since the backedge count doesn't include the first loop iteration. + // (Note that overflow can occur, this is handled explicitly below) const SCEV *TripCountSC = SE->getAddExpr(BECountSC, SE->getConstant(BECountSC->getType(), 1)); if (isa<SCEVCouldNotCompute>(TripCountSC)) { @@ -706,8 +687,7 @@ bool llvm::UnrollRuntimeLoopRemainder( NewPreHeader = SplitBlock(PreHeader, PreHeader->getTerminator(), DT, LI); NewPreHeader->setName(PreHeader->getName() + ".new"); // Split LatchExit to create phi nodes from branch above. - SmallVector<BasicBlock*, 4> Preds(predecessors(LatchExit)); - NewExit = SplitBlockPredecessors(LatchExit, Preds, ".unr-lcssa", DT, LI, + NewExit = SplitBlockPredecessors(LatchExit, {Latch}, ".unr-lcssa", DT, LI, nullptr, PreserveLCSSA); // NewExit gets its DebugLoc from LatchExit, which is not part of the // original Loop. @@ -717,6 +697,21 @@ bool llvm::UnrollRuntimeLoopRemainder( // Split NewExit to insert epilog remainder loop. EpilogPreHeader = SplitBlock(NewExit, NewExitTerminator, DT, LI); EpilogPreHeader->setName(Header->getName() + ".epil.preheader"); + + // If the latch exits from multiple level of nested loops, then + // by assumption there must be another loop exit which branches to the + // outer loop and we must adjust the loop for the newly inserted blocks + // to account for the fact that our epilogue is still in the same outer + // loop. Note that this leaves loopinfo temporarily out of sync with the + // CFG until the actual epilogue loop is inserted. + if (auto *ParentL = L->getParentLoop()) + if (LI->getLoopFor(LatchExit) != ParentL) { + LI->removeBlock(NewExit); + ParentL->addBasicBlockToLoop(NewExit, *LI); + LI->removeBlock(EpilogPreHeader); + ParentL->addBasicBlockToLoop(EpilogPreHeader, *LI); + } + } else { // If prolog remainder // Split the original preheader twice to insert prolog remainder loop @@ -751,35 +746,8 @@ bool llvm::UnrollRuntimeLoopRemainder( Value *BECount = Expander.expandCodeFor(BECountSC, BECountSC->getType(), PreHeaderBR); IRBuilder<> B(PreHeaderBR); - Value *ModVal; - // Calculate ModVal = (BECount + 1) % Count. - // Note that TripCount is BECount + 1. - if (isPowerOf2_32(Count)) { - // When Count is power of 2 we don't BECount for epilog case, however we'll - // need it for a branch around unrolling loop for prolog case. - ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter"); - // 1. There are no iterations to be run in the prolog/epilog loop. - // OR - // 2. The addition computing TripCount overflowed. - // - // If (2) is true, we know that TripCount really is (1 << BEWidth) and so - // the number of iterations that remain to be run in the original loop is a - // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (we - // explicitly check this above). - } else { - // As (BECount + 1) can potentially unsigned overflow we count - // (BECount % Count) + 1 which is overflow safe as BECount % Count < Count. - Value *ModValTmp = B.CreateURem(BECount, - ConstantInt::get(BECount->getType(), - Count)); - Value *ModValAdd = B.CreateAdd(ModValTmp, - ConstantInt::get(ModValTmp->getType(), 1)); - // At that point (BECount % Count) + 1 could be equal to Count. - // To handle this case we need to take mod by Count one more time. - ModVal = B.CreateURem(ModValAdd, - ConstantInt::get(BECount->getType(), Count), - "xtraiter"); - } + Value * const ModVal = CreateTripRemainder(B, BECount, TripCount, Count); + Value *BranchVal = UseEpilogRemainder ? B.CreateICmpULT(BECount, ConstantInt::get(BECount->getType(), @@ -810,18 +778,13 @@ bool llvm::UnrollRuntimeLoopRemainder( std::vector<BasicBlock *> NewBlocks; ValueToValueMapTy VMap; - // For unroll factor 2 remainder loop will have 1 iterations. - // Do not create 1 iteration loop. - bool CreateRemainderLoop = (Count != 2); - // Clone all the basic blocks in the loop. If Count is 2, we don't clone // the loop, otherwise we create a cloned loop to execute the extra // iterations. This function adds the appropriate CFG connections. BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit; BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; Loop *remainderLoop = CloneLoopBlocks( - L, ModVal, CreateRemainderLoop, UseEpilogRemainder, UnrollRemainder, - InsertTop, InsertBot, + L, ModVal, UseEpilogRemainder, UnrollRemainder, InsertTop, InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); // Assign the maximum possible trip count as the back edge weight for the @@ -840,36 +803,33 @@ bool llvm::UnrollRuntimeLoopRemainder( // work is to update the phi nodes in the original loop, and take in the // values from the cloned region. for (auto *BB : OtherExits) { - for (auto &II : *BB) { - - // Given we preserve LCSSA form, we know that the values used outside the - // loop will be used through these phi nodes at the exit blocks that are - // transformed below. - if (!isa<PHINode>(II)) - break; - PHINode *Phi = cast<PHINode>(&II); - unsigned oldNumOperands = Phi->getNumIncomingValues(); + // Given we preserve LCSSA form, we know that the values used outside the + // loop will be used through these phi nodes at the exit blocks that are + // transformed below. + for (PHINode &PN : BB->phis()) { + unsigned oldNumOperands = PN.getNumIncomingValues(); // Add the incoming values from the remainder code to the end of the phi // node. - for (unsigned i =0; i < oldNumOperands; i++){ - Value *newVal = VMap.lookup(Phi->getIncomingValue(i)); - // newVal can be a constant or derived from values outside the loop, and - // hence need not have a VMap value. Also, since lookup already generated - // a default "null" VMap entry for this value, we need to populate that - // VMap entry correctly, with the mapped entry being itself. - if (!newVal) { - newVal = Phi->getIncomingValue(i); - VMap[Phi->getIncomingValue(i)] = Phi->getIncomingValue(i); - } - Phi->addIncoming(newVal, - cast<BasicBlock>(VMap[Phi->getIncomingBlock(i)])); + for (unsigned i = 0; i < oldNumOperands; i++){ + auto *PredBB =PN.getIncomingBlock(i); + if (PredBB == Latch) + // The latch exit is handled seperately, see connectX + continue; + if (!L->contains(PredBB)) + // Even if we had dedicated exits, the code above inserted an + // extra branch which can reach the latch exit. + continue; + + auto *V = PN.getIncomingValue(i); + if (Instruction *I = dyn_cast<Instruction>(V)) + if (L->contains(I)) + V = VMap.lookup(I); + PN.addIncoming(V, cast<BasicBlock>(VMap[PredBB])); } } #if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) for (BasicBlock *SuccBB : successors(BB)) { - assert(!(any_of(OtherExits, - [SuccBB](BasicBlock *EB) { return EB == SuccBB; }) || - SuccBB == LatchExit) && + assert(!(llvm::is_contained(OtherExits, SuccBB) || SuccBB == LatchExit) && "Breaks the definition of dedicated exits!"); } #endif @@ -931,23 +891,22 @@ bool llvm::UnrollRuntimeLoopRemainder( PreserveLCSSA); // Update counter in loop for unrolling. - // I should be multiply of Count. + // Use an incrementing IV. Pre-incr/post-incr is backedge/trip count. + // Subtle: TestVal can be 0 if we wrapped when computing the trip count, + // thus we must compare the post-increment (wrapping) value. IRBuilder<> B2(NewPreHeader->getTerminator()); Value *TestVal = B2.CreateSub(TripCount, ModVal, "unroll_iter"); BranchInst *LatchBR = cast<BranchInst>(Latch->getTerminator()); - B2.SetInsertPoint(LatchBR); PHINode *NewIdx = PHINode::Create(TestVal->getType(), 2, "niter", Header->getFirstNonPHI()); - Value *IdxSub = - B2.CreateSub(NewIdx, ConstantInt::get(NewIdx->getType(), 1), - NewIdx->getName() + ".nsub"); - Value *IdxCmp; - if (LatchBR->getSuccessor(0) == Header) - IdxCmp = B2.CreateIsNotNull(IdxSub, NewIdx->getName() + ".ncmp"); - else - IdxCmp = B2.CreateIsNull(IdxSub, NewIdx->getName() + ".ncmp"); - NewIdx->addIncoming(TestVal, NewPreHeader); - NewIdx->addIncoming(IdxSub, Latch); + B2.SetInsertPoint(LatchBR); + auto *Zero = ConstantInt::get(NewIdx->getType(), 0); + auto *One = ConstantInt::get(NewIdx->getType(), 1); + Value *IdxNext = B2.CreateAdd(NewIdx, One, NewIdx->getName() + ".next"); + auto Pred = LatchBR->getSuccessor(0) == Header ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ; + Value *IdxCmp = B2.CreateICmp(Pred, IdxNext, TestVal, NewIdx->getName() + ".ncmp"); + NewIdx->addIncoming(Zero, NewPreHeader); + NewIdx->addIncoming(IdxNext, Latch); LatchBR->setCondition(IdxCmp); } else { // Connect the prolog code to the original loop and update the @@ -960,12 +919,49 @@ bool llvm::UnrollRuntimeLoopRemainder( // of its parent loops, so the Scalar Evolution pass needs to be run again. SE->forgetTopmostLoop(L); - // Verify that the Dom Tree is correct. + // Verify that the Dom Tree and Loop Info are correct. #if defined(EXPENSIVE_CHECKS) && !defined(NDEBUG) - if (DT) + if (DT) { assert(DT->verify(DominatorTree::VerificationLevel::Full)); + LI->verify(*DT); + } #endif + // For unroll factor 2 remainder loop will have 1 iteration. + if (Count == 2 && DT && LI && SE) { + // TODO: This code could probably be pulled out into a helper function + // (e.g. breakLoopBackedgeAndSimplify) and reused in loop-deletion. + BasicBlock *RemainderLatch = remainderLoop->getLoopLatch(); + assert(RemainderLatch); + SmallVector<BasicBlock*> RemainderBlocks(remainderLoop->getBlocks().begin(), + remainderLoop->getBlocks().end()); + breakLoopBackedge(remainderLoop, *DT, *SE, *LI, nullptr); + remainderLoop = nullptr; + + // Simplify loop values after breaking the backedge + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SmallVector<WeakTrackingVH, 16> DeadInsts; + for (BasicBlock *BB : RemainderBlocks) { + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { + if (Value *V = SimplifyInstruction(&Inst, {DL, nullptr, DT, AC})) + if (LI->replacementPreservesLCSSAForm(&Inst, V)) + Inst.replaceAllUsesWith(V); + if (isInstructionTriviallyDead(&Inst)) + DeadInsts.emplace_back(&Inst); + } + // We can't do recursive deletion until we're done iterating, as we might + // have a phi which (potentially indirectly) uses instructions later in + // the block we're iterating through. + RecursivelyDeleteTriviallyDeadInstructions(DeadInsts); + } + + // Merge latch into exit block. + auto *ExitBB = RemainderLatch->getSingleSuccessor(); + assert(ExitBB && "required after breaking cond br backedge"); + DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager); + MergeBlockIntoPredecessor(ExitBB, &DTU, LI); + } + // Canonicalize to LoopSimplifyForm both original and remainder loops. We // cannot rely on the LoopUnrollPass to do this because it only does // canonicalization for parent/subloops and not the sibling loops. |