Diffstat (limited to 'lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | lib/Transforms/Utils/LoopUtils.cpp | 371 |
1 file changed, 312 insertions, 59 deletions
diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp
index 3c522786641a..c3fa05a11a24 100644
--- a/lib/Transforms/Utils/LoopUtils.cpp
+++ b/lib/Transforms/Utils/LoopUtils.cpp
@@ -432,7 +432,7 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind,
                                         InstDesc &Prev, bool HasFunNoNaNAttr) {
   bool FP = I->getType()->isFloatingPointTy();
   Instruction *UAI = Prev.getUnsafeAlgebraInst();
-  if (!UAI && FP && !I->hasUnsafeAlgebra())
+  if (!UAI && FP && !I->isFast())
     UAI = I; // Found an unsafe (unvectorizable) algebra instruction.
 
   switch (I->getOpcode()) {
@@ -565,7 +565,8 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence(
     auto *I = Phi->user_back();
     if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() &&
         DT->dominates(Previous, I->user_back())) {
-      SinkAfter[I] = Previous;
+      if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking.
+        SinkAfter[I] = Previous;
       return true;
     }
   }
@@ -659,11 +660,11 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder,
     break;
   }
 
-  // We only match FP sequences with unsafe algebra, so we can unconditionally
+  // We only match FP sequences that are 'fast', so we can unconditionally
   // set it on any generated instructions.
   IRBuilder<>::FastMathFlagGuard FMFG(Builder);
   FastMathFlags FMF;
-  FMF.setUnsafeAlgebra();
+  FMF.setFast();
   Builder.setFastMathFlags(FMF);
 
   Value *Cmp;
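Note on the fast-math hunks above: isFast()/setFast() replace hasUnsafeAlgebra()/setUnsafeAlgebra() after the old UnsafeAlgebra bit was split into individual fast-math flags, and setFast() sets all of them. A minimal sketch of the RAII guard pattern used in createMinMaxOp, assuming only a live IRBuilder<> named Builder (illustrative, not part of the patch):

    {
      IRBuilder<>::FastMathFlagGuard FMFG(Builder); // saves the current flags
      FastMathFlags FMF;
      FMF.setFast(); // sets reassoc, nnan, ninf, nsz, arcp, contract and afn
      Builder.setFastMathFlags(FMF);
      // Instructions created here carry the 'fast' flags.
    } // FMFG's destructor restores the previous flags here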
@@ -677,7 +678,8 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder,
 }
 
 InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
-                                         const SCEV *Step, BinaryOperator *BOp)
+                                         const SCEV *Step, BinaryOperator *BOp,
+                                         SmallVectorImpl<Instruction *> *Casts)
     : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
   assert(IK != IK_NoInduction && "Not an induction");
 
@@ -704,6 +706,12 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
           (InductionBinOp->getOpcode() == Instruction::FAdd ||
            InductionBinOp->getOpcode() == Instruction::FSub))) &&
          "Binary opcode should be specified for FP induction");
+
+  if (Casts) {
+    for (auto &Inst : *Casts) {
+      RedundantCasts.push_back(Inst);
+    }
+  }
 }
 
 int InductionDescriptor::getConsecutiveDirection() const {
@@ -767,7 +775,7 @@ Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index,
 
     // Floating point operations had to be 'fast' to enable the induction.
     FastMathFlags Flags;
-    Flags.setUnsafeAlgebra();
+    Flags.setFast();
 
     Value *MulExp = B.CreateFMul(StepValue, Index);
     if (isa<Instruction>(MulExp))
@@ -807,7 +815,7 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
     StartValue = Phi->getIncomingValue(1);
   } else {
     assert(TheLoop->contains(Phi->getIncomingBlock(1)) &&
-          "Unexpected Phi node in the loop");
+           "Unexpected Phi node in the loop");
     BEValue = Phi->getIncomingValue(1);
     StartValue = Phi->getIncomingValue(0);
   }
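The next hunk adds the helper that recovers the casts made redundant by a PSCEV runtime predicate. For orientation, this is how its output gets threaded through later in this same patch (abridged from the isInductionPHI hunk further down):

    // When PSCEV could only build the AddRec under a runtime predicate,
    // collect the now-redundant casts and forward them.
    if (PhiScev != AR && SymbolicPhi) {
      SmallVector<Instruction *, 2> Casts;
      if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
        return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
    }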
@@ -840,6 +848,110 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop,
   return true;
 }
 
+/// This function is called when we suspect that the update-chain of a phi node
+/// (whose symbolic SCEV expression is in \p PhiScev) contains redundant casts
+/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime
+/// predicate P under which the SCEV expression for the phi can be the
+/// AddRecurrence \p AR; see createAddRecFromPHIWithCasts). We want to find the
+/// cast instructions that are involved in the update-chain of this induction.
+/// A caller that adds the required runtime predicate is free to drop these
+/// cast instructions, and compute the phi using \p AR (instead of some SCEV
+/// expression with casts).
+///
+/// For example, without a predicate the SCEV expression can take the following
+/// form:
+///      (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy)
+///
+/// It corresponds to the following IR sequence:
+/// %for.body:
+///   %x = phi i64 [ 0, %ph ], [ %add, %for.body ]
+///   %casted_phi = "ExtTrunc i64 %x"
+///   %add = add i64 %casted_phi, %step
+///
+/// where %x is given in \p PN,
+/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate,
+/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of
+/// several forms, for example:
+///   ExtTrunc1:    %casted_phi = and %x, 2^n-1
+/// or:
+///   ExtTrunc2:    %t = shl %x, m
+///                 %casted_phi = ashr %t, m
+///
+/// If we are able to find such a sequence, we return the instructions
+/// we found, namely %casted_phi and the instructions on its use-def chain up
+/// to the phi (not including the phi).
+bool getCastsForInductionPHI(
+    PredicatedScalarEvolution &PSE, const SCEVUnknown *PhiScev,
+    const SCEVAddRecExpr *AR, SmallVectorImpl<Instruction *> &CastInsts) {
+
+  assert(CastInsts.empty() && "CastInsts is expected to be empty.");
+  auto *PN = cast<PHINode>(PhiScev->getValue());
+  assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
+  const Loop *L = AR->getLoop();
+
+  // Find any cast instructions that participate in the def-use chain of
+  // PhiScev in the loop.
+  // FORNOW/TODO: We currently expect the def-use chain to include only
+  // two-operand instructions, where one of the operands is an invariant.
+  // createAddRecFromPHIWithCasts() currently does not support anything more
+  // involved than that, so we keep the search simple. This can be
+  // extended/generalized as needed.
+
+  auto getDef = [&](const Value *Val) -> Value * {
+    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val);
+    if (!BinOp)
+      return nullptr;
+    Value *Op0 = BinOp->getOperand(0);
+    Value *Op1 = BinOp->getOperand(1);
+    Value *Def = nullptr;
+    if (L->isLoopInvariant(Op0))
+      Def = Op1;
+    else if (L->isLoopInvariant(Op1))
+      Def = Op0;
+    return Def;
+  };
+
+  // Look for the instruction that defines the induction via the
+  // loop backedge.
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return false;
+  Value *Val = PN->getIncomingValueForBlock(Latch);
+  if (!Val)
+    return false;
+
+  // Follow the def-use chain until the induction phi is reached.
+  // If on the way we encounter a Value that has the same SCEV Expr as the
+  // phi node, we can consider the instructions we visit from that point
+  // as part of the cast-sequence that can be ignored.
+  bool InCastSequence = false;
+  auto *Inst = dyn_cast<Instruction>(Val);
+  while (Val != PN) {
+    // If we encountered a phi node other than PN, or if we left the loop,
+    // we bail out.
+    if (!Inst || !L->contains(Inst)) {
+      return false;
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
+    if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
+      InCastSequence = true;
+    if (InCastSequence) {
+      // Only the last instruction in the cast sequence is expected to have
+      // uses outside the induction def-use chain.
+      if (!CastInsts.empty())
+        if (!Inst->hasOneUse())
+          return false;
+      CastInsts.push_back(Inst);
+    }
+    Val = getDef(Val);
+    if (!Val)
+      return false;
+    Inst = dyn_cast<Instruction>(Val);
+  }
+
+  return InCastSequence;
+}
+
 bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
                                          PredicatedScalarEvolution &PSE,
                                          InductionDescriptor &D,
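One subtlety in the walk above: every instruction in the cast sequence except the last must have a single use, because an intermediate value with users outside the chain could not be dropped. How a client consumes the recorded casts is outside this file; a hypothetical consumer might look like the following (the getCastInsts() accessor over RedundantCasts is assumed from the matching header change, which is not shown in this diff):

    // Hypothetical consumer: once the runtime predicate is emitted, the
    // recorded casts compute the same value as the AddRec, so a vectorizer
    // can treat them as dead.
    for (Instruction *CastInst : ID.getCastInsts()) // accessor assumed
      DeadInstructions.insert(CastInst);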
@@ -869,13 +981,26 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
     return false;
   }
+  // Record any Cast instructions that participate in the induction update.
+  const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev);
+  // If we started from an UnknownSCEV, and managed to build an addRecurrence
+  // only after enabling Assume with PSCEV, this means we may have encountered
+  // cast instructions that required adding a runtime check in order to
+  // guarantee the correctness of the AddRecurrence representation of the
+  // induction.
+  if (PhiScev != AR && SymbolicPhi) {
+    SmallVector<Instruction *, 2> Casts;
+    if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
+      return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
+  }
+
   return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
 }
 
-bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
-                                         ScalarEvolution *SE,
-                                         InductionDescriptor &D,
-                                         const SCEV *Expr) {
+bool InductionDescriptor::isInductionPHI(
+    PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
+    InductionDescriptor &D, const SCEV *Expr,
+    SmallVectorImpl<Instruction *> *CastsToIgnore) {
   Type *PhiTy = Phi->getType();
   // We only handle integer and pointer induction variables.
   if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
     return false;
@@ -894,7 +1019,7 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
     // FIXME: We should treat this as a uniform. Unfortunately, we
     // don't currently know how to handle uniform PHIs.
     DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n");
-    return false;
+    return false;
   }
 
   Value *StartValue =
@@ -907,7 +1032,8 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
     return false;
 
   if (PhiTy->isIntegerTy()) {
-    D = InductionDescriptor(StartValue, IK_IntInduction, Step);
+    D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/ nullptr,
+                            CastsToIgnore);
     return true;
  }
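Note on the overload change above: the ScalarEvolution-based isInductionPHI gains a trailing CastsToIgnore parameter so the PSE-based overload can forward any recovered casts. Presumably the parameter defaults to nullptr in the header (not part of this diff), which keeps existing call sites unchanged; a sketch:

    // Existing callers are unaffected and record no casts (Phi, TheLoop, SE
    // and PhiScev come from the hypothetical caller):
    InductionDescriptor D;
    if (InductionDescriptor::isInductionPHI(Phi, TheLoop, SE, D, PhiScev))
      handleInduction(Phi, D); // hypothetical follow-up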
@@ -1115,6 +1241,149 @@ Optional<const MDOperand *> llvm::findStringMetadataForLoop(Loop *TheLoop,
   return None;
 }
 
+/// Does a BFS from a given node to all of its children inside a given loop.
+/// The returned vector of nodes includes the starting point.
+SmallVector<DomTreeNode *, 16>
+llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) {
+  SmallVector<DomTreeNode *, 16> Worklist;
+  auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
+    // Only include subregions in the top level loop.
+    BasicBlock *BB = DTN->getBlock();
+    if (CurLoop->contains(BB))
+      Worklist.push_back(DTN);
+  };
+
+  AddRegionToWorklist(N);
+
+  for (size_t I = 0; I < Worklist.size(); I++)
+    for (DomTreeNode *Child : Worklist[I]->getChildren())
+      AddRegionToWorklist(Child);
+
+  return Worklist;
+}
+
+void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
+                          ScalarEvolution *SE = nullptr,
+                          LoopInfo *LI = nullptr) {
+  assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
+  auto *Preheader = L->getLoopPreheader();
+  assert(Preheader && "Preheader should exist!");
+
+  // Now that we know the removal is safe, remove the loop by changing the
+  // branch from the preheader to go to the single exit block.
+  //
+  // Because we're deleting a large chunk of code at once, the sequence in which
+  // we remove things is very important to avoid invalidation issues.
+
+  // Tell ScalarEvolution that the loop is deleted. Do this before
+  // deleting the loop so that ScalarEvolution can look at the loop
+  // to determine what it needs to clean up.
+  if (SE)
+    SE->forgetLoop(L);
+
+  auto *ExitBlock = L->getUniqueExitBlock();
+  assert(ExitBlock && "Should have a unique exit block!");
+  assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");
+
+  auto *OldBr = dyn_cast<BranchInst>(Preheader->getTerminator());
+  assert(OldBr && "Preheader must end with a branch");
+  assert(OldBr->isUnconditional() && "Preheader must have a single successor");
+  // Connect the preheader to the exit block. Keep the old edge to the header
+  // around to perform the dominator tree update in two separate steps
+  // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge
+  // preheader -> header.
+  //
+  //
+  // 0. Preheader          1. Preheader           2. Preheader
+  //        |                    |   |                    |
+  //        V                    |   V                    |
+  //      Header <--\            | Header <--\            | Header <--\
+  //     |    |     |            |  |    |   |            |  |    |   |
+  //     |    V     |            |  |    V   |            |  |    V   |
+  //     |   Body --/            |  | Body --/            |  | Body --/
+  //     V                       V  V                     V  V
+  //    Exit                     Exit                     Exit
+  //
+  // By doing this in two separate steps we can perform the dominator tree
+  // update without using the batch update API.
+  //
+  // Even when the loop is never executed, we cannot remove the edge from the
+  // source block to the exit block. Consider the case where the unexecuted loop
+  // branches back to an outer loop. If we deleted the loop and removed the edge
+  // coming to this inner loop, this will break the outer loop structure (by
+  // deleting the backedge of the outer loop). If the outer loop is indeed a
+  // non-loop, it will be deleted in a future iteration of the loop deletion
+  // pass.
+  IRBuilder<> Builder(OldBr);
+  Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock);
+  // Remove the old branch. The conditional branch becomes a new terminator.
+  OldBr->eraseFromParent();
+
+  // Rewrite phis in the exit block to get their inputs from the Preheader
+  // instead of the exiting block.
+  BasicBlock::iterator BI = ExitBlock->begin();
+  while (PHINode *P = dyn_cast<PHINode>(BI)) {
+    // Set the zero'th element of Phi to be from the preheader and remove all
+    // other incoming values. Given the loop has dedicated exits, all other
+    // incoming values must be from the exiting blocks.
+    int PredIndex = 0;
+    P->setIncomingBlock(PredIndex, Preheader);
+    // Removes all incoming values from all other exiting blocks (including
+    // duplicate values from an exiting block).
+    // Nuke all entries except the zero'th entry which is the preheader entry.
+    // NOTE! We need to remove Incoming Values in the reverse order as done
+    // below, to keep the indices valid for deletion (removeIncomingValues
+    // updates getNumIncomingValues and shifts all values down into the operand
+    // being deleted).
+    for (unsigned i = 0, e = P->getNumIncomingValues() - 1; i != e; ++i)
+      P->removeIncomingValue(e - i, false);
+
+    assert((P->getNumIncomingValues() == 1 &&
+            P->getIncomingBlock(PredIndex) == Preheader) &&
+           "Should have exactly one value and that's from the preheader!");
+    ++BI;
+  }
+
+  // Disconnect the loop body by branching directly to its exit.
+  Builder.SetInsertPoint(Preheader->getTerminator());
+  Builder.CreateBr(ExitBlock);
+  // Remove the old branch.
+  Preheader->getTerminator()->eraseFromParent();
+
+  if (DT) {
+    // Update the dominator tree by informing it about the new edge from the
+    // preheader to the exit.
+    DT->insertEdge(Preheader, ExitBlock);
+    // Inform the dominator tree about the removed edge.
+    DT->deleteEdge(Preheader, L->getHeader());
+  }
+
+  // Remove the block from the reference counting scheme, so that we can
+  // delete it freely later.
+  for (auto *Block : L->blocks())
+    Block->dropAllReferences();
+
+  if (LI) {
+    // Erase the instructions and the blocks without having to worry
+    // about ordering because we already dropped the references.
+    // NOTE: This iteration is safe because erasing the block does not remove
+    // its entry from the loop's block list. We do that in the next section.
+    for (Loop::block_iterator LpI = L->block_begin(), LpE = L->block_end();
+         LpI != LpE; ++LpI)
+      (*LpI)->eraseFromParent();
+
+    // Finally, remove the blocks from LoopInfo. This has to happen late
+    // because otherwise our loop iterators won't work.
+
+    SmallPtrSet<BasicBlock *, 8> blocks;
+    blocks.insert(L->block_begin(), L->block_end());
+    for (BasicBlock *BB : blocks)
+      LI->removeBlock(BB);
+
+    // The last step is to update LoopInfo now that we've eliminated this loop.
+    LI->erase(L);
+  }
+}
+
 /// Returns true if the instruction in a loop is guaranteed to execute at least
 /// once.
 bool llvm::isGuaranteedToExecute(const Instruction &Inst,
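The rewritten deleteDeadLoop treats all three analyses as optional and updates the dominator tree with two single-edge operations (insertEdge, then deleteEdge) rather than the batch update API. A minimal calling sketch, assuming a caller such as LoopDeletion with DT, SE and LI in scope (illustrative):

    // Preconditions enforced by the asserts above: LCSSA form, a preheader,
    // a unique exit block, and dedicated exits.
    deleteDeadLoop(L, &DT, &SE, &LI);
    // Analyses the caller does not maintain may be passed as null:
    deleteDeadLoop(L, /*DT=*/nullptr, /*SE=*/nullptr, /*LI=*/nullptr);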
@@ -1194,7 +1463,7 @@ Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) {
 static Value *addFastMathFlag(Value *V) {
   if (isa<FPMathOperator>(V)) {
     FastMathFlags Flags;
-    Flags.setUnsafeAlgebra();
+    Flags.setFast();
     cast<Instruction>(V)->setFastMathFlags(Flags);
   }
   return V;
@@ -1256,8 +1525,8 @@ Value *llvm::createSimpleTargetReduction(
   using RD = RecurrenceDescriptor;
   RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid;
   // TODO: Support creating ordered reductions.
-  FastMathFlags FMFUnsafe;
-  FMFUnsafe.setUnsafeAlgebra();
+  FastMathFlags FMFFast;
+  FMFFast.setFast();
 
   switch (Opcode) {
   case Instruction::Add:
@@ -1278,14 +1547,14 @@ Value *llvm::createSimpleTargetReduction(
   case Instruction::FAdd:
     BuildFunc = [&]() {
       auto Rdx = Builder.CreateFAddReduce(ScalarUdf, Src);
-      cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe);
+      cast<CallInst>(Rdx)->setFastMathFlags(FMFFast);
       return Rdx;
     };
     break;
   case Instruction::FMul:
     BuildFunc = [&]() {
       auto Rdx = Builder.CreateFMulReduce(ScalarUdf, Src);
-      cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe);
+      cast<CallInst>(Rdx)->setFastMathFlags(FMFFast);
       return Rdx;
     };
     break;
@@ -1321,55 +1590,39 @@ Value *llvm::createSimpleTargetReduction(
 }
 
 /// Create a vector reduction using a given recurrence descriptor.
-Value *llvm::createTargetReduction(IRBuilder<> &Builder,
+Value *llvm::createTargetReduction(IRBuilder<> &B,
                                    const TargetTransformInfo *TTI,
                                    RecurrenceDescriptor &Desc, Value *Src,
                                    bool NoNaN) {
   // TODO: Support in-order reductions based on the recurrence descriptor.
-  RecurrenceDescriptor::RecurrenceKind RecKind = Desc.getRecurrenceKind();
+  using RD = RecurrenceDescriptor;
+  RD::RecurrenceKind RecKind = Desc.getRecurrenceKind();
   TargetTransformInfo::ReductionFlags Flags;
   Flags.NoNaN = NoNaN;
-  auto getSimpleRdx = [&](unsigned Opc) {
-    return createSimpleTargetReduction(Builder, TTI, Opc, Src, Flags);
-  };
   switch (RecKind) {
-  case RecurrenceDescriptor::RK_FloatAdd:
-    return getSimpleRdx(Instruction::FAdd);
-  case RecurrenceDescriptor::RK_FloatMult:
-    return getSimpleRdx(Instruction::FMul);
-  case RecurrenceDescriptor::RK_IntegerAdd:
-    return getSimpleRdx(Instruction::Add);
-  case RecurrenceDescriptor::RK_IntegerMult:
-    return getSimpleRdx(Instruction::Mul);
-  case RecurrenceDescriptor::RK_IntegerAnd:
-    return getSimpleRdx(Instruction::And);
-  case RecurrenceDescriptor::RK_IntegerOr:
-    return getSimpleRdx(Instruction::Or);
-  case RecurrenceDescriptor::RK_IntegerXor:
-    return getSimpleRdx(Instruction::Xor);
-  case RecurrenceDescriptor::RK_IntegerMinMax: {
-    switch (Desc.getMinMaxRecurrenceKind()) {
-    case RecurrenceDescriptor::MRK_SIntMax:
-      Flags.IsSigned = true;
-      Flags.IsMaxOp = true;
-      break;
-    case RecurrenceDescriptor::MRK_UIntMax:
-      Flags.IsMaxOp = true;
-      break;
-    case RecurrenceDescriptor::MRK_SIntMin:
-      Flags.IsSigned = true;
-      break;
-    case RecurrenceDescriptor::MRK_UIntMin:
-      break;
-    default:
-      llvm_unreachable("Unhandled MRK");
-    }
-    return getSimpleRdx(Instruction::ICmp);
+  case RD::RK_FloatAdd:
+    return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags);
+  case RD::RK_FloatMult:
+    return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags);
+  case RD::RK_IntegerAdd:
+    return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags);
+  case RD::RK_IntegerMult:
+    return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags);
+  case RD::RK_IntegerAnd:
+    return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags);
+  case RD::RK_IntegerOr:
+    return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags);
+  case RD::RK_IntegerXor:
+    return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags);
+  case RD::RK_IntegerMinMax: {
+    RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
+    Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax);
+    Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin);
+    return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags);
   }
-  case RecurrenceDescriptor::RK_FloatMinMax: {
-    Flags.IsMaxOp =
-        Desc.getMinMaxRecurrenceKind() == RecurrenceDescriptor::MRK_FloatMax;
-    return getSimpleRdx(Instruction::FCmp);
+  case RD::RK_FloatMinMax: {
+    Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax;
+    return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags);
   }
   default:
     llvm_unreachable("Unhandled RecKind");
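The createTargetReduction rewrite above folds the nested min/max switch into two boolean assignments. For a signed-max recurrence, for example, the new code reduces to the equivalent of (illustrative):

    TargetTransformInfo::ReductionFlags Flags;
    Flags.NoNaN = NoNaN;
    Flags.IsMaxOp = true;  // MMKind == RD::MRK_SIntMax
    Flags.IsSigned = true; // the signed min/max kinds set IsSigned
    return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags);

For MRK_UIntMin both flags stay false, matching the old switch's empty default-like case, so the behavior is unchanged.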