Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 773
1 file changed, 720 insertions(+), 53 deletions(-)
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index c4c40189fda46..43363736684ee 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -11,12 +11,19 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/PriorityWorklist.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/AliasAnalysis.h"
 #include "llvm/Analysis/BasicAliasAnalysis.h"
 #include "llvm/Analysis/DomTreeUpdater.h"
 #include "llvm/Analysis/GlobalsModRef.h"
 #include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/Analysis/LoopInfo.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Analysis/MemorySSA.h"
@@ -31,7 +38,9 @@
 #include "llvm/IR/Dominators.h"
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/IR/ValueHandle.h"
 #include "llvm/InitializePasses.h"
@@ -39,10 +48,17 @@
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/KnownBits.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
 
 using namespace llvm;
 using namespace llvm::PatternMatch;
 
+static cl::opt<bool> ForceReductionIntrinsic(
+    "force-reduction-intrinsics", cl::Hidden,
+    cl::desc("Force creating reduction intrinsics for testing."),
+    cl::init(false));
+
 #define DEBUG_TYPE "loop-utils"
 
 static const char *LLVMLoopDisableNonforced = "llvm.loop.disable_nonforced";
@@ -496,20 +512,24 @@ llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) {
 
   AddRegionToWorklist(N);
 
-  for (size_t I = 0; I < Worklist.size(); I++)
-    for (DomTreeNode *Child : Worklist[I]->getChildren())
+  for (size_t I = 0; I < Worklist.size(); I++) {
+    for (DomTreeNode *Child : Worklist[I]->children())
       AddRegionToWorklist(Child);
+  }
 
   return Worklist;
 }
 
-void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
-                          ScalarEvolution *SE = nullptr,
-                          LoopInfo *LI = nullptr) {
+void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
+                          LoopInfo *LI, MemorySSA *MSSA) {
   assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
   auto *Preheader = L->getLoopPreheader();
   assert(Preheader && "Preheader should exist!");
 
+  std::unique_ptr<MemorySSAUpdater> MSSAU;
+  if (MSSA)
+    MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
+
   // Now that we know the removal is safe, remove the loop by changing the
   // branch from the preheader to go to the single exit block.
   //
@@ -582,18 +602,33 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
            "Should have exactly one value and that's from the preheader!");
   }
 
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
+  if (DT) {
+    DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}});
+    if (MSSA) {
+      MSSAU->applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock}},
+                          *DT);
+      if (VerifyMemorySSA)
+        MSSA->verifyMemorySSA();
+    }
+  }
+
   // Disconnect the loop body by branching directly to its exit.
   Builder.SetInsertPoint(Preheader->getTerminator());
   Builder.CreateBr(ExitBlock);
   // Remove the old branch.
   Preheader->getTerminator()->eraseFromParent();
 
-  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
   if (DT) {
-    // Update the dominator tree by informing it about the new edge from the
-    // preheader to the exit and the removed edge.
-    DTU.applyUpdates({{DominatorTree::Insert, Preheader, ExitBlock},
-                      {DominatorTree::Delete, Preheader, L->getHeader()}});
+    DTU.applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}});
+    if (MSSA) {
+      MSSAU->applyUpdates({{DominatorTree::Delete, Preheader, L->getHeader()}},
+                          *DT);
+      SmallSetVector<BasicBlock *, 8> DeadBlockSet(L->block_begin(),
+                                                   L->block_end());
+      MSSAU->removeBlocks(DeadBlockSet);
+      if (VerifyMemorySSA)
+        MSSA->verifyMemorySSA();
+    }
   }
 
   // Use a map to unique and a vector to guarantee deterministic ordering.
@@ -654,6 +689,9 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
   for (auto *Block : L->blocks())
     Block->dropAllReferences();
 
+  if (MSSA && VerifyMemorySSA)
+    MSSA->verifyMemorySSA();
+
   if (LI) {
     // Erase the instructions and the blocks without having to worry
     // about ordering because we already dropped the references.
@@ -676,11 +714,11 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr,
     // its parent. While removeLoop/removeChildLoop remove the given loop but
     // not relink its subloops, which is what we want.
     if (Loop *ParentLoop = L->getParentLoop()) {
-      Loop::iterator I = find(ParentLoop->begin(), ParentLoop->end(), L);
+      Loop::iterator I = find(*ParentLoop, L);
       assert(I != ParentLoop->end() && "Couldn't find loop");
       ParentLoop->removeChildLoop(I);
     } else {
-      Loop::iterator I = find(LI->begin(), LI->end(), L);
+      Loop::iterator I = find(*LI, L);
      assert(I != LI->end() && "Couldn't find loop");
       LI->removeLoop(I);
     }
@@ -688,17 +726,17 @@
   }
 }
 
-Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) {
-  // Support loops with an exiting latch and other existing exists only
-  // deoptimize.
-
-  // Get the branch weights for the loop's backedge.
+/// Checks if \p L has a single exit through the latch block, except possibly
+/// for "deoptimizing" exits. Returns the branch instruction terminating the
+/// loop latch if the check succeeds, nullptr otherwise.
+static BranchInst *getExpectedExitLoopLatchBranch(Loop *L) {
   BasicBlock *Latch = L->getLoopLatch();
   if (!Latch)
-    return None;
+    return nullptr;
+
   BranchInst *LatchBR = dyn_cast<BranchInst>(Latch->getTerminator());
   if (!LatchBR || LatchBR->getNumSuccessors() != 2 || !L->isLoopExiting(Latch))
-    return None;
+    return nullptr;
 
   assert((LatchBR->getSuccessor(0) == L->getHeader() ||
           LatchBR->getSuccessor(1) == L->getHeader()) &&
@@ -709,24 +747,73 @@
   if (any_of(ExitBlocks, [](const BasicBlock *EB) {
         return !EB->getTerminatingDeoptimizeCall();
       }))
+    return nullptr;
+
+  return LatchBR;
+}
+
+Optional<unsigned>
+llvm::getLoopEstimatedTripCount(Loop *L,
+                                unsigned *EstimatedLoopInvocationWeight) {
+  // Support loops with an exiting latch where all other exits only
+  // deoptimize.
+  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
+  if (!LatchBranch)
     return None;
 
   // To estimate the number of times the loop body was executed, we want to
   // know the number of times the backedge was taken, vs. the number of times
   // we exited the loop.
   uint64_t BackedgeTakenWeight, LatchExitWeight;
-  if (!LatchBR->extractProfMetadata(BackedgeTakenWeight, LatchExitWeight))
+  if (!LatchBranch->extractProfMetadata(BackedgeTakenWeight, LatchExitWeight))
     return None;
 
-  if (LatchBR->getSuccessor(0) != L->getHeader())
+  if (LatchBranch->getSuccessor(0) != L->getHeader())
+    std::swap(BackedgeTakenWeight, LatchExitWeight);
+
+  if (!LatchExitWeight)
+    return None;
+
+  if (EstimatedLoopInvocationWeight)
+    *EstimatedLoopInvocationWeight = LatchExitWeight;
+
+  // The estimated backedge-taken count is the ratio of the backedge-taken
+  // weight to the weight of the edge exiting the loop, rounded to nearest.
+  uint64_t BackedgeTakenCount =
+      llvm::divideNearest(BackedgeTakenWeight, LatchExitWeight);
+  // The estimated trip count is one plus the estimated backedge-taken count.
+  return BackedgeTakenCount + 1;
+}
+
+bool llvm::setLoopEstimatedTripCount(Loop *L, unsigned EstimatedTripCount,
+                                     unsigned EstimatedLoopInvocationWeight) {
+  // Support loops with an exiting latch where all other exits only
+  // deoptimize.
+  BranchInst *LatchBranch = getExpectedExitLoopLatchBranch(L);
+  if (!LatchBranch)
+    return false;
+
+  // Calculate taken and exit weights.
+  unsigned LatchExitWeight = 0;
+  unsigned BackedgeTakenWeight = 0;
+
+  if (EstimatedTripCount > 0) {
+    LatchExitWeight = EstimatedLoopInvocationWeight;
+    BackedgeTakenWeight = (EstimatedTripCount - 1) * LatchExitWeight;
+  }
+
+  // Swap the weights if the backedge is taken when the condition is "false".
+  if (LatchBranch->getSuccessor(0) != L->getHeader())
     std::swap(BackedgeTakenWeight, LatchExitWeight);
 
-  if (!BackedgeTakenWeight || !LatchExitWeight)
-    return 0;
+  MDBuilder MDB(LatchBranch->getContext());
 
-  // Divide the count of the backedge by the count of the edge exiting the
-  // loop, rounding to nearest.
-  return llvm::divideNearest(BackedgeTakenWeight, LatchExitWeight);
+  // Set/update profile metadata.
+  LatchBranch->setMetadata(
+      LLVMContext::MD_prof,
+      MDB.createBranchWeights(BackedgeTakenWeight, LatchExitWeight));
+
+  return true;
 }
 
 bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
@@ -751,7 +838,7 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
   return true;
 }
 
-Value *llvm::createMinMaxOp(IRBuilder<> &Builder,
+Value *llvm::createMinMaxOp(IRBuilderBase &Builder,
                             RecurrenceDescriptor::MinMaxRecurrenceKind RK,
                             Value *Left, Value *Right) {
   CmpInst::Predicate P = CmpInst::ICMP_NE;
@@ -780,29 +867,22 @@ Value *llvm::createMinMaxOp(IRBuilder<> &Builder,
 
   // We only match FP sequences that are 'fast', so we can unconditionally
   // set it on any generated instructions.
-  IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+  IRBuilderBase::FastMathFlagGuard FMFG(Builder);
   FastMathFlags FMF;
   FMF.setFast();
   Builder.setFastMathFlags(FMF);
-
-  Value *Cmp;
-  if (RK == RecurrenceDescriptor::MRK_FloatMin ||
-      RK == RecurrenceDescriptor::MRK_FloatMax)
-    Cmp = Builder.CreateFCmp(P, Left, Right, "rdx.minmax.cmp");
-  else
-    Cmp = Builder.CreateICmp(P, Left, Right, "rdx.minmax.cmp");
-
+  Value *Cmp = Builder.CreateCmp(P, Left, Right, "rdx.minmax.cmp");
   Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
   return Select;
 }
 
 // Helper to generate an ordered reduction.
 Value *
-llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src,
+llvm::getOrderedReduction(IRBuilderBase &Builder, Value *Acc, Value *Src,
                           unsigned Op,
                           RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
                           ArrayRef<Value *> RedOps) {
-  unsigned VF = Src->getType()->getVectorNumElements();
+  unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
 
   // Extract and apply reduction ops in ascending order:
   // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ) ... + Scl[VF-1]
@@ -829,29 +909,27 @@
 
 // Helper to generate a log2 shuffle reduction.
 Value *
-llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
+llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, unsigned Op,
                           RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
                           ArrayRef<Value *> RedOps) {
-  unsigned VF = Src->getType()->getVectorNumElements();
+  unsigned VF = cast<FixedVectorType>(Src->getType())->getNumElements();
   // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
   // and vector ops, reducing the set of values being computed by half each
   // round.
   assert(isPowerOf2_32(VF) &&
          "Reduction emission only supported for pow2 vectors!");
   Value *TmpVec = Src;
-  SmallVector<Constant *, 32> ShuffleMask(VF, nullptr);
+  SmallVector<int, 32> ShuffleMask(VF);
   for (unsigned i = VF; i != 1; i >>= 1) {
     // Move the upper half of the vector to the lower half.
     for (unsigned j = 0; j != i / 2; ++j)
-      ShuffleMask[j] = Builder.getInt32(i / 2 + j);
+      ShuffleMask[j] = i / 2 + j;
 
     // Fill the rest of the mask with undef.
-    std::fill(&ShuffleMask[i / 2], ShuffleMask.end(),
-              UndefValue::get(Builder.getInt32Ty()));
+    std::fill(&ShuffleMask[i / 2], ShuffleMask.end(), -1);
 
     Value *Shuf = Builder.CreateShuffleVector(
-        TmpVec, UndefValue::get(TmpVec->getType()),
-        ConstantVector::get(ShuffleMask), "rdx.shuf");
+        TmpVec, UndefValue::get(TmpVec->getType()), ShuffleMask, "rdx.shuf");
 
     if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
       // The builder propagates its fast-math-flags setting.
@@ -864,6 +942,11 @@
     }
     if (!RedOps.empty())
       propagateIRFlags(TmpVec, RedOps);
+
+    // We may compute the reassociated scalar ops in a way that does not
+    // preserve nsw/nuw etc. Conservatively, drop those flags.
+    if (auto *ReductionInst = dyn_cast<Instruction>(TmpVec))
+      ReductionInst->dropPoisonGeneratingFlags();
   }
   // The result is in the first element of the vector.
   return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
@@ -872,10 +955,10 @@
 
 /// Create a simple vector reduction specified by an opcode and some
 /// flags (if generating min/max reductions).
 Value *llvm::createSimpleTargetReduction(
-    IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
+    IRBuilderBase &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
     Value *Src, TargetTransformInfo::ReductionFlags Flags,
     ArrayRef<Value *> RedOps) {
-  assert(isa<VectorType>(Src->getType()) && "Type must be a vector");
+  auto *SrcVTy = cast<VectorType>(Src->getType());
 
   std::function<Value *()> BuildFunc;
   using RD = RecurrenceDescriptor;
@@ -900,13 +983,13 @@
   case Instruction::FAdd:
     BuildFunc = [&]() {
       auto Rdx = Builder.CreateFAddReduce(
-          Constant::getNullValue(Src->getType()->getVectorElementType()), Src);
+          Constant::getNullValue(SrcVTy->getElementType()), Src);
       return Rdx;
     };
     break;
   case Instruction::FMul:
    BuildFunc = [&]() {
-      Type *Ty = Src->getType()->getVectorElementType();
+      Type *Ty = SrcVTy->getElementType();
       auto Rdx = Builder.CreateFMulReduce(ConstantFP::get(Ty, 1.0), Src);
       return Rdx;
     };
@@ -937,13 +1020,14 @@
     llvm_unreachable("Unhandled opcode");
     break;
   }
-  if (TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
+  if (ForceReductionIntrinsic ||
+      TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
    return BuildFunc();
   return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
 }
 
 /// Create a vector reduction using a given recurrence descriptor.
-Value *llvm::createTargetReduction(IRBuilder<> &B,
+Value *llvm::createTargetReduction(IRBuilderBase &B,
                                    const TargetTransformInfo *TTI,
                                    RecurrenceDescriptor &Desc, Value *Src,
                                    bool NoNaN) {
@@ -955,7 +1039,7 @@ Value *llvm::createTargetReduction(IRBuilder<> &B,
 
   // All ops in the reduction inherit fast-math-flags from the recurrence
   // descriptor.
-  IRBuilder<>::FastMathFlagGuard FMFGuard(B);
+  IRBuilderBase::FastMathFlagGuard FMFGuard(B);
   B.setFastMathFlags(Desc.getFastMathFlags());
 
   switch (RecKind) {
@@ -1042,3 +1126,586 @@ bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
          SE.isLoopEntryGuardedByCond(L, Predicate, S, SE.getConstant(Max));
 }
+
+//===----------------------------------------------------------------------===//
+// rewriteLoopExitValues - Optimize IV users outside the loop.
+// As a side effect, reduces the amount of IV processing within the loop.
+//===----------------------------------------------------------------------===//
+
+// Return true if the SCEV expansion generated by the rewriter can replace the
+// original value. SCEV guarantees that it produces the same value, but the way
+// it is produced may be illegal IR. Ideally, this function will only be
+// called for verification.
+static bool isValidRewrite(ScalarEvolution *SE, Value *FromVal, Value *ToVal) {
+  // If an SCEV expression subsumed multiple pointers, its expansion could
+  // reassociate the GEP changing the base pointer. This is illegal because the
+  // final address produced by a GEP chain must be inbounds relative to its
+  // underlying object. Otherwise basic alias analysis, among other things,
+  // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid
+  // producing an expression involving multiple pointers. Until then, we must
+  // bail out here.
+  //
+  // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject
+  // because it understands lcssa phis while SCEV does not.
+  Value *FromPtr = FromVal;
+  Value *ToPtr = ToVal;
+  if (auto *GEP = dyn_cast<GEPOperator>(FromVal))
+    FromPtr = GEP->getPointerOperand();
+
+  if (auto *GEP = dyn_cast<GEPOperator>(ToVal))
+    ToPtr = GEP->getPointerOperand();
+
+  if (FromPtr != FromVal || ToPtr != ToVal) {
+    // Quickly check the common case.
+    if (FromPtr == ToPtr)
+      return true;
+
+    // SCEV may have rewritten an expression that produces the GEP's pointer
+    // operand. That's ok as long as the pointer operand has the same base
+    // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the
+    // base of a recurrence. This handles the case in which SCEV expansion
+    // converts a pointer type recurrence into a nonrecurrent pointer base
+    // indexed by an integer recurrence.
+
+    // If the GEP base pointer is a vector of pointers, abort.
+    if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy())
+      return false;
+
+    const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr));
+    const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr));
+    if (FromBase == ToBase)
+      return true;
+
+    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: GEP rewrite bail out "
+                      << *FromBase << " != " << *ToBase << "\n");
+
+    return false;
+  }
+  return true;
+}
+
+static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
+  SmallPtrSet<const Instruction *, 8> Visited;
+  SmallVector<const Instruction *, 8> WorkList;
+  Visited.insert(I);
+  WorkList.push_back(I);
+  while (!WorkList.empty()) {
+    const Instruction *Curr = WorkList.pop_back_val();
+    // This use is outside the loop, nothing to do.
+    if (!L->contains(Curr))
+      continue;
+    // Do we assume it is a "hard" use which will not be eliminated easily?
+    if (Curr->mayHaveSideEffects())
+      return true;
+    // Otherwise, add all its users to the worklist.
+    for (auto U : Curr->users()) {
+      auto *UI = cast<Instruction>(U);
+      if (Visited.insert(UI).second)
+        WorkList.push_back(UI);
+    }
+  }
+  return false;
+}
+
+// Collect information about PHI nodes which can be transformed in
+// rewriteLoopExitValues.
+struct RewritePhi {
+  PHINode *PN;               // For which PHI node is this replacement?
+  unsigned Ith;              // For which incoming value?
+  const SCEV *ExpansionSCEV; // The SCEV of the incoming value we are rewriting.
+  Instruction *ExpansionPoint; // Where we'd like to expand that SCEV.
+  bool HighCost;               // Is this expansion a high-cost?
+
+  Value *Expansion = nullptr;
+  bool ValidRewrite = false;
+
+  RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
+             bool H)
+      : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
+        HighCost(H) {}
+};
+
+// Check whether it is possible to delete the loop after rewriting the exit
+// value. If it is possible, ignore ReplaceExitValue and rewrite
+// aggressively.
+static bool canLoopBeDeleted(Loop *L,
+                             SmallVector<RewritePhi, 8> &RewritePhiSet) {
+  BasicBlock *Preheader = L->getLoopPreheader();
+  // If there is no preheader, the loop will not be deleted.
+  if (!Preheader)
+    return false;
+
+  // In the LoopDeletion pass, a loop can be deleted even when
+  // ExitingBlocks.size() > 1; we restrict to the single-ExitingBlock case
+  // for simplicity.
+  // TODO: If we see a testcase where a loop with multiple ExitingBlocks can
+  // be deleted after exit value rewriting, we can enhance the logic here.
+  SmallVector<BasicBlock *, 4> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+  if (ExitBlocks.size() != 1 || ExitingBlocks.size() != 1)
+    return false;
+
+  BasicBlock *ExitBlock = ExitBlocks[0];
+  BasicBlock::iterator BI = ExitBlock->begin();
+  while (PHINode *P = dyn_cast<PHINode>(BI)) {
+    Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]);
+
+    // If the Incoming value of P is found in RewritePhiSet, we know it
+    // could be rewritten to use a loop invariant value in the transformation
+    // phase later. Skip it in the loop invariant check below.
+    bool found = false;
+    for (const RewritePhi &Phi : RewritePhiSet) {
+      if (!Phi.ValidRewrite)
+        continue;
+      unsigned i = Phi.Ith;
+      if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
+        found = true;
+        break;
+      }
+    }
+
+    Instruction *I;
+    if (!found && (I = dyn_cast<Instruction>(Incoming)))
+      if (!L->hasLoopInvariantOperands(I))
+        return false;
+
+    ++BI;
+  }
+
+  for (auto *BB : L->blocks())
+    if (llvm::any_of(*BB, [](Instruction &I) {
+          return I.mayHaveSideEffects();
+        }))
+      return false;
+
+  return true;
+}
+
+int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
+                                ScalarEvolution *SE,
+                                const TargetTransformInfo *TTI,
+                                SCEVExpander &Rewriter, DominatorTree *DT,
+                                ReplaceExitVal ReplaceExitValue,
+                                SmallVector<WeakTrackingVH, 16> &DeadInsts) {
+  // Check a pre-condition.
+  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
+         "Indvars did not preserve LCSSA!");
+
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+
+  SmallVector<RewritePhi, 8> RewritePhiSet;
+  // Find all values that are computed inside the loop, but used outside of it.
+  // Because of LCSSA, these values will only occur in LCSSA PHI Nodes. Scan
+  // the exit blocks of the loop to find them.
+  for (BasicBlock *ExitBB : ExitBlocks) {
+    // If there are no PHI nodes in this exit block, then no values defined
+    // inside the loop are used on this path, skip it.
+    PHINode *PN = dyn_cast<PHINode>(ExitBB->begin());
+    if (!PN) continue;
+
+    unsigned NumPreds = PN->getNumIncomingValues();
+
+    // Iterate over all of the PHI nodes.
+    BasicBlock::iterator BBI = ExitBB->begin();
+    while ((PN = dyn_cast<PHINode>(BBI++))) {
+      if (PN->use_empty())
+        continue; // dead use, don't replace it
+
+      if (!SE->isSCEVable(PN->getType()))
+        continue;
+
+      // It's necessary to tell ScalarEvolution about this explicitly so that
+      // it can walk the def-use list and forget all SCEVs, as it may not be
+      // watching the PHI itself. Once the new exit value is in place, there
+      // may not be a def-use connection between the loop and every instruction
+      // which got a SCEVAddRecExpr for that loop.
+      SE->forgetValue(PN);
+
+      // Iterate over all of the values in all the PHI nodes.
+      for (unsigned i = 0; i != NumPreds; ++i) {
+        // If the value being merged in is not integer or is not defined
+        // in the loop, skip it.
+        Value *InVal = PN->getIncomingValue(i);
+        if (!isa<Instruction>(InVal))
+          continue;
+
+        // If this pred is for a subloop, not L itself, skip it.
+        if (LI->getLoopFor(PN->getIncomingBlock(i)) != L)
+          continue; // The Block is in a subloop, skip it.
+
+        // Check that InVal is defined in the loop.
+        Instruction *Inst = cast<Instruction>(InVal);
+        if (!L->contains(Inst))
+          continue;
+
+        // Okay, this instruction has a user outside of the current loop
+        // and varies predictably *inside* the loop. Evaluate the value it
+        // contains when the loop exits, if possible. We prefer to start with
+        // expressions which are true for all exits (so as to maximize
+        // expression reuse by the SCEVExpander), but resort to per-exit
+        // evaluation if that fails.
+        const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop());
+        if (isa<SCEVCouldNotCompute>(ExitValue) ||
+            !SE->isLoopInvariant(ExitValue, L) ||
+            !isSafeToExpand(ExitValue, *SE)) {
+          // TODO: This should probably be sunk into SCEV in some way; maybe a
+          // getSCEVForExit(SCEV*, L, ExitingBB)? It can be generalized for
+          // most SCEV expressions and other recurrence types (e.g. shift
+          // recurrences). Is there existing code we can reuse?
+          const SCEV *ExitCount = SE->getExitCount(L, PN->getIncomingBlock(i));
+          if (isa<SCEVCouldNotCompute>(ExitCount))
+            continue;
+          if (auto *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Inst)))
+            if (AddRec->getLoop() == L)
+              ExitValue = AddRec->evaluateAtIteration(ExitCount, *SE);
+          if (isa<SCEVCouldNotCompute>(ExitValue) ||
+              !SE->isLoopInvariant(ExitValue, L) ||
+              !isSafeToExpand(ExitValue, *SE))
+            continue;
+        }
+
+        // Computing the value outside of the loop brings no benefit if it is
+        // definitely used inside the loop in a way which cannot be optimized
+        // away. Avoid doing so unless we know we have a value which computes
+        // the ExitValue already. TODO: This should be merged into the SCEV
+        // expander to leverage its knowledge of existing expressions.
+        if (ReplaceExitValue != AlwaysRepl && !isa<SCEVConstant>(ExitValue) &&
+            !isa<SCEVUnknown>(ExitValue) && hasHardUserWithinLoop(L, Inst))
+          continue;
+
+        // Check if expansions of this SCEV would count as being high cost.
+        bool HighCost = Rewriter.isHighCostExpansion(
+            ExitValue, L, SCEVCheapExpansionBudget, TTI, Inst);
+
+        // Note that we must not perform expansions until after
+        // we query *all* the costs, because if we perform a temporary
+        // expansion in between, one that we might not intend to keep, said
+        // expansion *may* affect the cost calculation of the next SCEVs we
+        // query, and the next SCEV may erroneously get a smaller cost.
+
+        // Collect all the candidate PHINodes to be rewritten.
+        RewritePhiSet.emplace_back(PN, i, ExitValue, Inst, HighCost);
+      }
+    }
+  }
+
+  // Now that we've done preliminary filtering and billed all the SCEVs,
+  // we can perform the last sanity check - the expansion must be valid.
+  for (RewritePhi &Phi : RewritePhiSet) {
+    Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(),
+                                           Phi.ExpansionPoint);
+
+    LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
+                      << *(Phi.Expansion) << '\n'
+                      << "  LoopVal = " << *(Phi.ExpansionPoint) << "\n");
+
+    // FIXME: isValidRewrite() is a hack. It should be an assert, eventually.
+    Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion);
+    if (!Phi.ValidRewrite) {
+      DeadInsts.push_back(Phi.Expansion);
+      continue;
+    }
+
+#ifndef NDEBUG
+    // If we reuse an instruction from a loop which is neither L nor one of
+    // its containing loops, we end up breaking LCSSA form for this loop by
+    // creating a new use of its instruction.
+    if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion))
+      if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
+        if (EVL != L)
+          assert(EVL->contains(L) && "LCSSA breach detected!");
+#endif
+  }
+
+  // TODO: after isValidRewrite() is an assertion, evaluate whether
+  // it is beneficial to change how we calculate high-cost:
+  // if we have SCEV 'A' which we know we will expand, should we calculate
+  // the cost of other SCEVs after expanding SCEV 'A', thus potentially
+  // giving a cost bonus to those other SCEVs?
+
+  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
+  int NumReplaced = 0;
+
+  // Transformation.
+  for (const RewritePhi &Phi : RewritePhiSet) {
+    if (!Phi.ValidRewrite)
+      continue;
+
+    PHINode *PN = Phi.PN;
+    Value *ExitVal = Phi.Expansion;
+
+    // Only do the rewrite when the ExitValue can be expanded cheaply.
+    // If LoopCanBeDel is true, rewrite the exit value aggressively.
+    if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) {
+      DeadInsts.push_back(ExitVal);
+      continue;
+    }
+
+    NumReplaced++;
+    Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith));
+    PN->setIncomingValue(Phi.Ith, ExitVal);
+
+    // If this instruction is dead now, delete it. Don't do it now to avoid
+    // invalidating iterators.
+    if (isInstructionTriviallyDead(Inst, TLI))
+      DeadInsts.push_back(Inst);
+
+    // Replace PN with ExitVal if that is legal and does not break LCSSA.
+    if (PN->getNumIncomingValues() == 1 &&
+        LI->replacementPreservesLCSSAForm(PN, ExitVal)) {
+      PN->replaceAllUsesWith(ExitVal);
+      PN->eraseFromParent();
+    }
+  }
+
+  // The insertion point instruction may have been deleted; clear it out
+  // so that the rewriter doesn't trip over it later.
+  Rewriter.clearInsertPoint();
+  return NumReplaced;
+}
+
+/// Set weights for \p UnrolledLoop and \p RemainderLoop based on weights for
+/// \p OrigLoop.
+void llvm::setProfileInfoAfterUnrolling(Loop *OrigLoop, Loop *UnrolledLoop,
+                                        Loop *RemainderLoop, uint64_t UF) {
+  assert(UF > 0 && "Zero unroll factor is not supported");
+  assert(UnrolledLoop != RemainderLoop &&
+         "Unrolled and Remainder loops are expected to be distinct");
+
+  // Get the number of iterations in the original scalar loop.
+  unsigned OrigLoopInvocationWeight = 0;
+  Optional<unsigned> OrigAverageTripCount =
+      getLoopEstimatedTripCount(OrigLoop, &OrigLoopInvocationWeight);
+  if (!OrigAverageTripCount)
+    return;
+
+  // Calculate the number of iterations in the unrolled loop.
+  unsigned UnrolledAverageTripCount = *OrigAverageTripCount / UF;
+  // Calculate the number of iterations for the remainder loop.
+  unsigned RemainderAverageTripCount = *OrigAverageTripCount % UF;
+
+  setLoopEstimatedTripCount(UnrolledLoop, UnrolledAverageTripCount,
+                            OrigLoopInvocationWeight);
+  setLoopEstimatedTripCount(RemainderLoop, RemainderAverageTripCount,
+                            OrigLoopInvocationWeight);
+}
+
+/// Utility that implements appending of loops onto a worklist.
+/// Loops are added in preorder (analogous to reverse postorder for trees),
+/// and the worklist is processed LIFO.
+template <typename RangeT>
+void llvm::appendReversedLoopsToWorklist(
+    RangeT &&Loops, SmallPriorityWorklist<Loop *, 4> &Worklist) {
+  // We use an internal worklist to build up the preorder traversal without
+  // recursion.
+  SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;
+
+  // We walk the initial sequence of loops in reverse because we generally want
+  // to visit defs before uses and the worklist is LIFO.
+  for (Loop *RootL : Loops) {
+    assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
+    assert(PreOrderWorklist.empty() &&
+           "Must start with an empty preorder walk worklist.");
+    PreOrderWorklist.push_back(RootL);
+    do {
+      Loop *L = PreOrderWorklist.pop_back_val();
+      PreOrderWorklist.append(L->begin(), L->end());
+      PreOrderLoops.push_back(L);
+    } while (!PreOrderWorklist.empty());
+
+    Worklist.insert(std::move(PreOrderLoops));
+    PreOrderLoops.clear();
+  }
+}
+
+template <typename RangeT>
+void llvm::appendLoopsToWorklist(RangeT &&Loops,
+                                 SmallPriorityWorklist<Loop *, 4> &Worklist) {
+  appendReversedLoopsToWorklist(reverse(Loops), Worklist);
+}
+
+template void llvm::appendLoopsToWorklist<ArrayRef<Loop *> &>(
+    ArrayRef<Loop *> &Loops, SmallPriorityWorklist<Loop *, 4> &Worklist);
+
+template void
+llvm::appendLoopsToWorklist<Loop &>(Loop &L,
+                                    SmallPriorityWorklist<Loop *, 4> &Worklist);
+
+void llvm::appendLoopsToWorklist(LoopInfo &LI,
+                                 SmallPriorityWorklist<Loop *, 4> &Worklist) {
+  appendReversedLoopsToWorklist(LI, Worklist);
+}
+
+Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
+                      LoopInfo *LI, LPPassManager *LPM) {
+  Loop &New = *LI->AllocateLoop();
+  if (PL)
+    PL->addChildLoop(&New);
+  else
+    LI->addTopLevelLoop(&New);
+
+  if (LPM)
+    LPM->addLoop(New);
+
+  // Add all of the blocks in L to the new loop.
+  for (Loop::block_iterator I = L->block_begin(), E = L->block_end();
+       I != E; ++I)
+    if (LI->getLoopFor(*I) == L)
+      New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI);
+
+  // Add all of the subloops to the new loop.
+  for (Loop *I : *L)
+    cloneLoop(I, &New, VM, LI, LPM);
+
+  return &New;
+}
+
+/// IR Values for the lower and upper bounds of a pointer evolution. We
+/// need to use value-handles because SCEV expansion can invalidate previously
+/// expanded values. Thus expansion of a pointer can invalidate the bounds for
+/// a previous one.
+struct PointerBounds {
+  TrackingVH<Value> Start;
+  TrackingVH<Value> End;
+};
+
+/// Expand code for the lower and upper bound of the pointer group \p CG
+/// in \p TheLoop. \return the values for the bounds.
+static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
+                                  Loop *TheLoop, Instruction *Loc,
+                                  SCEVExpander &Exp, ScalarEvolution *SE) {
+  // TODO: Add helper to retrieve pointers to CG.
+  Value *Ptr = CG->RtCheck.Pointers[CG->Members[0]].PointerValue;
+  const SCEV *Sc = SE->getSCEV(Ptr);
+
+  unsigned AS = Ptr->getType()->getPointerAddressSpace();
+  LLVMContext &Ctx = Loc->getContext();
+
+  // Use this type for pointer arithmetic.
+  Type *PtrArithTy = Type::getInt8PtrTy(Ctx, AS);
+
+  if (SE->isLoopInvariant(Sc, TheLoop)) {
+    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for a loop invariant ptr:"
+                      << *Ptr << "\n");
+    // Ptr could be in the loop body. If so, expand a new one at the correct
+    // location.
+    Instruction *Inst = dyn_cast<Instruction>(Ptr);
+    Value *NewPtr = (Inst && TheLoop->contains(Inst))
+                        ? Exp.expandCodeFor(Sc, PtrArithTy, Loc)
+                        : Ptr;
+    // We must return a half-open range, which means incrementing Sc.
+    const SCEV *ScPlusOne = SE->getAddExpr(Sc, SE->getOne(PtrArithTy));
+    Value *NewPtrPlusOne = Exp.expandCodeFor(ScPlusOne, PtrArithTy, Loc);
+    return {NewPtr, NewPtrPlusOne};
+  } else {
+    Value *Start = nullptr, *End = nullptr;
+    LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
+    Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
+    End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+    LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High
+                      << "\n");
+    return {Start, End};
+  }
+}
+
+/// Turns a collection of checks into a collection of expanded upper and
+/// lower bounds for both pointers in the check.
+static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
+expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
+             Instruction *Loc, ScalarEvolution *SE, SCEVExpander &Exp) {
+  SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
+
+  // Here we're relying on the SCEV Expander's cache to only emit code for the
+  // same bounds once.
+  transform(PointerChecks, std::back_inserter(ChecksWithBounds),
+            [&](const RuntimePointerCheck &Check) {
+              PointerBounds First = expandBounds(Check.first, L, Loc, Exp, SE),
+                            Second =
+                                expandBounds(Check.second, L, Loc, Exp, SE);
+              return std::make_pair(First, Second);
+            });
+
+  return ChecksWithBounds;
+}
+
+std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
+    Instruction *Loc, Loop *TheLoop,
+    const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
+    ScalarEvolution *SE) {
+  // TODO: Move noalias annotation code from LoopVersioning here and share
+  // with LV if possible.
+  // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if
+  // possible.
+  const DataLayout &DL = TheLoop->getHeader()->getModule()->getDataLayout();
+  SCEVExpander Exp(*SE, DL, "induction");
+  auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, SE, Exp);
+
+  LLVMContext &Ctx = Loc->getContext();
+  Instruction *FirstInst = nullptr;
+  IRBuilder<> ChkBuilder(Loc);
+  // Our instructions might fold to a constant.
+  Value *MemoryRuntimeCheck = nullptr;
+
+  // FIXME: this helper is currently a duplicate of the one in
+  // LoopVectorize.cpp.
+  auto GetFirstInst = [](Instruction *FirstInst, Value *V,
+                         Instruction *Loc) -> Instruction * {
+    if (FirstInst)
+      return FirstInst;
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      return I->getParent() == Loc->getParent() ? I : nullptr;
+    return nullptr;
+  };
+
+  for (const auto &Check : ExpandedChecks) {
+    const PointerBounds &A = Check.first, &B = Check.second;
+    // Check if two pointers (A and B) conflict, where conflict is computed as:
+    // start(A) < end(B) && start(B) < end(A)
+    unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
+    unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
+
+    assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
+           (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+           "Trying to bounds check pointers with different address spaces");
+
+    Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
+    Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
+
+    Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
+    Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
+    Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
+    Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
+
+    // [A|B].Start points to the first accessed byte under base [A|B].
+    // [A|B].End points to the last accessed byte, plus one.
+    // There is no conflict when the intervals are disjoint:
+    // NoConflict = (B.Start >= A.End) || (A.Start >= B.End)
+    //
+    // bound0 = (B.Start < A.End)
+    // bound1 = (A.Start < B.End)
+    // IsConflict = bound0 & bound1
+    Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
+    FirstInst = GetFirstInst(FirstInst, Cmp0, Loc);
+    Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+    FirstInst = GetFirstInst(FirstInst, Cmp1, Loc);
+    Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+    FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+    if (MemoryRuntimeCheck) {
+      IsConflict =
+          ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
+      FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
+    }
+    MemoryRuntimeCheck = IsConflict;
+  }
+
+  if (!MemoryRuntimeCheck)
+    return std::make_pair(nullptr, nullptr);
+
+  // We have to do this trickery because the IRBuilder might fold the check
+  // to a constant expression in which case there is no Instruction anchored
+  // in the block.
+  Instruction *Check =
+      BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx));
+  ChkBuilder.Insert(Check, "memcheck.conflict");
+  FirstInst = GetFirstInst(FirstInst, Check, Loc);
+  return std::make_pair(FirstInst, Check);
+}
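
A note on the weight arithmetic in getLoopEstimatedTripCount and
setLoopEstimatedTripCount in the patch: the estimated trip count is
divideNearest(BackedgeTakenWeight, LatchExitWeight) + 1, and the setter
inverts that by writing weights {(TripCount - 1) * Weight, Weight}. The
following standalone C++ sketch models only that arithmetic; it is
illustrative, reimplements divideNearest locally, and is not part of the
patch.

#include <cassert>
#include <cstdint>

// Round-to-nearest unsigned division, modeling llvm::divideNearest.
static uint64_t divideNearest(uint64_t Numerator, uint64_t Denominator) {
  return (Numerator + Denominator / 2) / Denominator;
}

// Model of getLoopEstimatedTripCount: the backedge-taken count, rounded to
// nearest, plus one for the final (exiting) iteration.
static uint64_t estimateTripCount(uint64_t BackedgeTakenWeight,
                                  uint64_t LatchExitWeight) {
  assert(LatchExitWeight && "callers bail out on a zero exit weight");
  return divideNearest(BackedgeTakenWeight, LatchExitWeight) + 1;
}

int main() {
  // A loop entered 20 times with !prof weights {backedge: 2000, exit: 20}
  // averaged ~100 backedges per entry, i.e. an estimated trip count of 101.
  assert(estimateTripCount(2000, 20) == 101);
  // setLoopEstimatedTripCount writes {(TC - 1) * W, W}, so get/set round-trip.
  uint64_t W = 20, TC = 101;
  assert(estimateTripCount((TC - 1) * W, W) == TC);
  return 0;
}

The same arithmetic explains setProfileInfoAfterUnrolling: an unroll factor
UF splits the estimated trip count between the unrolled body (TC / UF) and
the remainder loop (TC % UF), both re-encoded with the original invocation
weight.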
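
The log2 shuffle reduction in getShuffleReduction halves the number of live
lanes each round: the mask moves lane i/2 + j onto lane j, a vector op
combines the two halves, and after log2(VF) rounds lane 0 holds the result.
A scalar C++ model of the integer-add case, illustrative only:

#include <cassert>
#include <vector>

// Scalar simulation of the halving mask: each round folds the upper half of
// the (power-of-two sized) vector onto the lower half with the reduction op.
static int shuffleReduceAdd(std::vector<int> Vec) {
  for (unsigned i = Vec.size(); i != 1; i >>= 1)
    for (unsigned j = 0; j != i / 2; ++j)
      Vec[j] += Vec[i / 2 + j]; // lanes >= i/2 become don't-care (mask -1)
  return Vec[0];                // the extractelement at index 0
}

int main() {
  // 3 rounds for VF = 8: 8 -> 4 -> 2 -> 1 live lanes.
  assert(shuffleReduceAdd({1, 2, 3, 4, 5, 6, 7, 8}) == 36);
  return 0;
}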
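
The conflict predicate emitted by addRuntimeChecks treats each pointer group
as a half-open byte interval [Start, End) and reports a conflict exactly when
two intervals intersect. A standalone C++ model of that predicate,
illustrative only:

#include <cassert>
#include <cstdint>

// Half-open access range of a pointer group: End is one past the last
// accessed byte, as the comments in the patch require.
struct Bounds {
  uint64_t Start, End;
};

// Mirrors the emitted IR: bound0 = (B.Start < A.End),
// bound1 = (A.Start < B.End), IsConflict = bound0 & bound1.
static bool mayConflict(Bounds A, Bounds B) {
  bool Bound0 = B.Start < A.End;
  bool Bound1 = A.Start < B.End;
  return Bound0 && Bound1;
}

int main() {
  // Adjacent but disjoint ranges do not conflict...
  assert(!mayConflict({0, 64}, {64, 128}));
  // ...while a single byte of overlap does, so the checked loop would fall
  // back to its unvectorized version at run time.
  assert(mayConflict({0, 65}, {64, 128}));
  return 0;
}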