diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp | 904 |
1 files changed, 904 insertions, 0 deletions
// Extracted from a diff view; the original diff header read:
//   diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
//              b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
//   new file mode 100644, index 000000000000..ea6d952cfa7d, hunk @@ -0,0 +1,904 @@
#include "llvm/Transforms/Utils/LoopConstrainer.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/ScalarEvolution.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/IR/Dominators.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LoopSimplify.h"
#include "llvm/Transforms/Utils/LoopUtils.h"
#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"

using namespace llvm;

// Metadata kind name attached to the latch terminator of every cloned loop
// (see cloneLoop); parseLoopStructure refuses loops that carry it.
static const char *ClonedLoopTag = "loop_constrainer.loop.clone";

#define DEBUG_TYPE "loop-constrainer"

/// Given a loop with a decreasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// Returns false unless \p Pred is one of SLT/SGT/ULT/UGT and \p BoundSCEV is
/// available at the entry of \p L. \p Step is asserted to be known negative.
///
/// With BoundPred being SGT for signed \p Pred and UGT otherwise:
///  - LatchBrExitIdx == 1: the result is whether loop entry is guarded by
///    `Start BoundPred BoundSCEV`.
///  - LatchBrExitIdx == 0: the result is whether loop entry is guarded by both
///    `Start BoundPred (BoundSCEV - 1)` and
///    `BoundSCEV BoundPred (Min - (Step + 1))`, where Min is the minimum value
///    of the bound's integer type under the signedness of \p Pred.
static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  // Only strict relational predicates are handled.
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  assert(SE.isKnownNegative(Step) && "expecting negative step");

  LLVM_DEBUG(dbgs() << "isSafeDecreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");

  // Limit = Min - (Step + 1): the bound must lie strictly above it (under
  // BoundPred) for the second guard below to hold.
  const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Min = IsSigned ? APInt::getSignedMinValue(BitWidth)
                       : APInt::getMinValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Min), StepPlusOne);

  // Despite its name, MinusOne holds `BoundSCEV - 1`.
  const SCEV *MinusOne =
      SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));

  return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, MinusOne) &&
         SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit);
}

/// Given a loop with an increasing induction variable, is it possible to
/// safely calculate the bounds of a new loop using the given Predicate.
///
/// Returns false unless \p Pred is one of SLT/SGT/ULT/UGT and \p BoundSCEV is
/// available at the entry of \p L.
///
/// With BoundPred being SLT for signed \p Pred and ULT otherwise:
///  - LatchBrExitIdx == 1: the result is whether loop entry is guarded by
///    `Start BoundPred BoundSCEV`.
///  - LatchBrExitIdx == 0: the result is whether loop entry is guarded by both
///    `Start BoundPred (BoundSCEV + Step)` and
///    `BoundSCEV BoundPred (Max - (Step - 1))`, where Max is the maximum value
///    of the bound's integer type under the signedness of \p Pred.
static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
                                  const SCEV *Step, ICmpInst::Predicate Pred,
                                  unsigned LatchBrExitIdx, Loop *L,
                                  ScalarEvolution &SE) {
  // Only strict relational predicates are handled.
  if (Pred != ICmpInst::ICMP_SLT && Pred != ICmpInst::ICMP_SGT &&
      Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_UGT)
    return false;

  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
    return false;

  LLVM_DEBUG(dbgs() << "isSafeIncreasingBound with:\n");
  LLVM_DEBUG(dbgs() << "Start: " << *Start << "\n");
  LLVM_DEBUG(dbgs() << "Step: " << *Step << "\n");
  LLVM_DEBUG(dbgs() << "BoundSCEV: " << *BoundSCEV << "\n");
  LLVM_DEBUG(dbgs() << "Pred: " << Pred << "\n");
  LLVM_DEBUG(dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n");

  bool IsSigned = ICmpInst::isSigned(Pred);
  // The predicate that we need to check that the induction variable lies
  // within bounds.
  ICmpInst::Predicate BoundPred =
      IsSigned ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;

  if (LatchBrExitIdx == 1)
    return SE.isLoopEntryGuardedByCond(L, BoundPred, Start, BoundSCEV);

  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");

  // Limit = Max - (Step - 1): the bound must lie strictly below it (under
  // BoundPred) for the second guard below to hold.
  const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
  unsigned BitWidth = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
  APInt Max = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                       : APInt::getMaxValue(BitWidth);
  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(Max), StepMinusOne);

  return (SE.isLoopEntryGuardedByCond(L, BoundPred, Start,
                                      SE.getAddExpr(BoundSCEV, Step)) &&
          SE.isLoopEntryGuardedByCond(L, BoundPred, BoundSCEV, Limit));
}

/// Returns estimate for max latch taken count of the loop of the narrowest
/// available type. If the latch block has such estimate, it is returned.
/// Otherwise, we use max exit count of whole loop (that is potentially of wider
/// type than latch check itself), which is still better than no estimate.
static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
                                                          const Loop &L) {
  // Prefer the symbolic-maximum exit count computed for the latch block; fall
  // back to the loop-wide symbolic max backedge-taken count if SCEV could not
  // compute one for the latch.
  const SCEV *FromBlock =
      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
  if (isa<SCEVCouldNotCompute>(FromBlock))
    return SE.getSymbolicMaxBackedgeTakenCount(&L);
  return FromBlock;
}

/// Tries to recognize \p L as a loop whose latch ends in a conditional branch
/// on an integer icmp between an affine add recurrence of \p L with a constant
/// step and some bound.
///
/// On success, the induction variable's start value (and the bound, when it
/// had to be adjusted or lives inside the loop) is expanded as code before the
/// preheader's terminator, \p FailureReason is set to nullptr and the filled
/// in LoopStructure is returned. On failure std::nullopt is returned and
/// \p FailureReason points at a string literal naming the reason.
///
/// \param AllowUnsignedLatchCond when false, loops whose latch predicate is
///        unsigned (after the canonicalisation done here) are rejected.
std::optional<LoopStructure>
LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
                                  bool AllowUnsignedLatchCond,
                                  const char *&FailureReason) {
  if (!L.isLoopSimplifyForm()) {
    FailureReason = "loop not in LoopSimplify form";
    return std::nullopt;
  }

  BasicBlock *Latch = L.getLoopLatch();
  assert(Latch && "Simplified loops only have one latch!");

  // Loops produced by LoopConstrainer::cloneLoop carry this tag on their
  // latch terminator.
  if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
    FailureReason = "loop has already been cloned";
    return std::nullopt;
  }

  // NOTE(review): this rejects a latch that is not an exiting block; the
  // message text "no loop latch" does not say that.
  if (!L.isLoopExiting(Latch)) {
    FailureReason = "no loop latch";
    return std::nullopt;
  }

  BasicBlock *Header = L.getHeader();
  BasicBlock *Preheader = L.getLoopPreheader();
  if (!Preheader) {
    FailureReason = "no preheader";
    return std::nullopt;
  }

  BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
  if (!LatchBr || LatchBr->isUnconditional()) {
    FailureReason = "latch terminator not conditional branch";
    return std::nullopt;
  }

  // Index of the latch branch successor that leaves the loop: if successor 0
  // is the header (the backedge), the exit is successor 1, and vice versa.
  unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;

  ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
  if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
    FailureReason = "latch terminator branch not conditional on integral icmp";
    return std::nullopt;
  }

  const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
  if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
    FailureReason = "could not compute latch count";
    return std::nullopt;
  }
  assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
             ScalarEvolution::LoopInvariant &&
         "loop variant exit count doesn't make sense!");

  ICmpInst::Predicate Pred = ICI->getPredicate();
  Value *LeftValue = ICI->getOperand(0);
  const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
  IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());

  Value *RightValue = ICI->getOperand(1);
  const SCEV *RightSCEV = SE.getSCEV(RightValue);

  // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
  if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
    if (isa<SCEVAddRecExpr>(RightSCEV)) {
      std::swap(LeftSCEV, RightSCEV);
      std::swap(LeftValue, RightValue);
      Pred = ICmpInst::getSwappedPredicate(Pred);
    } else {
      FailureReason = "no add recurrences in the icmp";
      return std::nullopt;
    }
  }

  // Returns true if AR is known not to wrap in the signed sense: either it
  // already carries the NSW flag, or sign-extending the whole recurrence to
  // twice its bit width yields an add recurrence whose start and step are
  // exactly the sign-extended start and step of AR.
  auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
    if (AR->getNoWrapFlags(SCEV::FlagNSW))
      return true;

    IntegerType *Ty = cast<IntegerType>(AR->getType());
    IntegerType *WideTy =
        IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);

    const SCEVAddRecExpr *ExtendAfterOp =
        dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
    if (ExtendAfterOp) {
      const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
      const SCEV *ExtendedStep =
          SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);

      bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
                          ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;

      if (NoSignedWrap)
        return true;
    }

    // We may have proved this when computing the sign extension above.
    return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
  };

  // `ICI` is interpreted as taking the backedge if the *next* value of the
  // induction variable satisfies some constraint.

  const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
  if (IndVarBase->getLoop() != &L) {
    FailureReason = "LHS in cmp is not an AddRec for this loop";
    return std::nullopt;
  }
  if (!IndVarBase->isAffine()) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
  // NOTE(review): a non-constant step is reported with the same message as the
  // non-affine case above.
  if (!isa<SCEVConstant>(StepRec)) {
    FailureReason = "LHS in icmp not induction variable";
    return std::nullopt;
  }
  ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();

  // Note: this tests the predicate of the original `ICI`, not the possibly
  // swapped `Pred`.
  if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
    FailureReason = "LHS in icmp needs nsw for equality predicates";
    return std::nullopt;
  }

  assert(!StepCI->isZero() && "Zero step?");
  bool IsIncreasing = !StepCI->isNegative();
  bool IsSignedPredicate;
  // The compared recurrence is the *next* value, so the induction variable
  // itself starts one step earlier: IndVarStart = StartNext - Step.
  const SCEV *StartNext = IndVarBase->getStart();
  const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
  const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
  const SCEV *Step = SE.getSCEV(StepCI);

  // When non-null, the bound is re-materialised from this SCEV before the
  // preheader terminator instead of reusing `RightValue` directly.
  const SCEV *FixedRightSCEV = nullptr;

  // If RightValue resides within loop (but still being loop invariant),
  // regenerate it as preheader.
  if (auto *I = dyn_cast<Instruction>(RightValue))
    if (L.contains(I->getParent()))
      FixedRightSCEV = RightSCEV;

  if (IsIncreasing) {
    bool DecreasedRightValueByOne = false;
    if (StepCI->isOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (++i != len) {         while (++i < len) {
        //   ...                 --->     ...
        // }                            }
        // If both parts are known non-negative, it is profitable to use
        // unsigned comparison in increasing loop. This allows us to make the
        // comparison check against "RightSCEV + 1" more optimistic.
        if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
            isKnownNonNegativeInLoop(RightSCEV, &L, SE))
          Pred = ICmpInst::ICMP_ULT;
        else
          Pred = ICmpInst::ICMP_SLT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (++i == len)     --->     if (++i > len - 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
          Pred = ICmpInst::ICMP_UGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
          Pred = ICmpInst::ICMP_SGT;
          RightSCEV =
              SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
          DecreasedRightValueByOne = true;
        }
      }
    }

    // For an increasing IV we need "less than" when the backedge is taken on
    // true (exit index 1) or "greater than" when the exit is taken on true
    // (exit index 0).
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
    bool FoundExpectedPred =
        (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp slt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate = ICmpInst::isSigned(Pred);
    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe loop bounds";
      return std::nullopt;
    }
    if (LatchBrExitIdx == 0) {
      // We need to increase the right value unless we have already decreased
      // it virtually when we replaced EQ with SGT.
      if (!DecreasedRightValueByOne)
        FixedRightSCEV =
            SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!DecreasedRightValueByOne &&
             "Right value can be decreased only for LatchBrExitIdx == 0!");
    }
  } else {
    bool IncreasedRightValueByOne = false;
    if (StepCI->isMinusOne()) {
      // Try to turn eq/ne predicates to those we can work with.
      if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
        // while (--i != len) {         while (--i > len) {
        //   ...                 --->     ...
        // }                            }
        // We intentionally don't turn the predicate into UGT even if we know
        // that both operands are non-negative, because it will only pessimize
        // our check against "RightSCEV - 1".
        Pred = ICmpInst::ICMP_SGT;
      else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
        // while (true) {               while (true) {
        //   if (--i == len)     --->     if (--i < len + 1)
        //     break;                       break;
        //   ...                          ...
        // }                            }
        if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
            cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
          Pred = ICmpInst::ICMP_ULT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
          Pred = ICmpInst::ICMP_SLT;
          RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
          IncreasedRightValueByOne = true;
        }
      }
    }

    // Mirror image of the increasing case: "greater than" with exit index 1,
    // or "less than" with exit index 0.
    bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
    bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);

    bool FoundExpectedPred =
        (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);

    if (!FoundExpectedPred) {
      FailureReason = "expected icmp sgt semantically, found something else";
      return std::nullopt;
    }

    IsSignedPredicate =
        Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;

    if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
      FailureReason = "unsigned latch conditions are explicitly prohibited";
      return std::nullopt;
    }

    if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
                               LatchBrExitIdx, &L, SE)) {
      FailureReason = "Unsafe bounds";
      return std::nullopt;
    }

    if (LatchBrExitIdx == 0) {
      // We need to decrease the right value unless we have already increased
      // it virtually when we replaced EQ with SLT.
      if (!IncreasedRightValueByOne)
        FixedRightSCEV =
            SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
    } else {
      assert(!IncreasedRightValueByOne &&
             "Right value can be increased only for LatchBrExitIdx == 0!");
    }
  }
  BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);

  assert(!L.contains(LatchExit) && "expected an exit block!");
  // From here on the function emits IR: both expansions below are inserted
  // before the preheader's terminator.
  const DataLayout &DL = Preheader->getModule()->getDataLayout();
  SCEVExpander Expander(SE, DL, "loop-constrainer");
  Instruction *Ins = Preheader->getTerminator();

  if (FixedRightSCEV)
    RightValue =
        Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);

  Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
  IndVarStartV->setName("indvar.start");

  LoopStructure Result;

  Result.Tag = "main";
  Result.Header = Header;
  Result.Latch = Latch;
  Result.LatchBr = LatchBr;
  Result.LatchExit = LatchExit;
  Result.LatchBrExitIdx = LatchBrExitIdx;
  Result.IndVarStart = IndVarStartV;
  Result.IndVarStep = StepCI;
  Result.IndVarBase = LeftValue;
  Result.IndVarIncreasing = IsIncreasing;
  Result.LoopExitAt = RightValue;
  Result.IsSignedPredicate = IsSignedPredicate;
  Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());

  FailureReason = nullptr;

  return Result;
}

// Add metadata to the loop L to disable loop optimizations. Callers need to
// confirm that optimizing loop L is not beneficial.
//
// The new loop ID replaces whatever loop ID L had. It carries
// "llvm.loop.unroll.disable", "llvm.loop.vectorize.enable" = false,
// "llvm.loop.licm_versioning.disable" and "llvm.loop.distribute.enable" =
// false.
static void DisableAllLoopOptsOnLoop(Loop &L) {
  // We do not care about any existing loopID related metadata for L, since we
  // are setting all loop metadata to false.
  LLVMContext &Context = L.getHeader()->getContext();
  // Reserve first location for self reference to the LoopID metadata node.
  MDNode *Dummy = MDNode::get(Context, {});
  MDNode *DisableUnroll = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
  // i1 0, shared by the two "*.enable" entries below.
  Metadata *FalseVal =
      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
  MDNode *DisableVectorize = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
  MDNode *DisableLICMVersioning = MDNode::get(
      Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
  MDNode *DisableDistribution = MDNode::get(
      Context,
      {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
  MDNode *NewLoopID =
      MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
                            DisableLICMVersioning, DisableDistribution});
  // Set operand 0 to refer to the loop id itself.
  NewLoopID->replaceOperandWith(0, NewLoopID);
  L.setLoopID(NewLoopID);
}

/// Stores the analyses, the loop to constrain, its parsed structure \p LS, the
/// range type \p T and the requested sub-ranges \p SR. The enclosing function
/// and LLVMContext are taken from the loop header. No IR is modified here.
LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
                                 function_ref<void(Loop *, bool)> LPMAddNewLoop,
                                 const LoopStructure &LS, ScalarEvolution &SE,
                                 DominatorTree &DT, Type *T, SubRanges SR)
    : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
      DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
      MainLoopStructure(LS), SR(SR) {}

/// Clones every block of the original loop into the function, using the name
/// suffix "." + \p Tag, and fills \p Result with the cloned blocks, the
/// original-to-clone value map and the correspondingly mapped LoopStructure.
/// The cloned latch terminator is tagged with ClonedLoopTag metadata, and PHI
/// nodes in the original loop's exit blocks receive an extra incoming value
/// for each cloned exiting block.
void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
                                const char *Tag) const {
  for (BasicBlock *BB : OriginalLoop.getBlocks()) {
    BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F);
    Result.Blocks.push_back(Clone);
    Result.Map[BB] = Clone;
  }

  // Maps a value to its clone, or returns the value itself if it was not
  // cloned.
  auto GetClonedValue = [&Result](Value *V) {
    assert(V && "null values not in domain!");
    auto It = Result.Map.find(V);
    if (It == Result.Map.end())
      return V;
    return static_cast<Value *>(It->second);
  };

  // Mark the clone so parseLoopStructure will not process it again.
  auto *ClonedLatch =
      cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch()));
  ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag,
                                            MDNode::get(Ctx, {}));

  Result.Structure = MainLoopStructure.map(GetClonedValue);
  Result.Structure.Tag = Tag;

  // Result.Blocks was filled in the iteration order of
  // OriginalLoop.getBlocks(), so index i pairs a clone with its original.
  for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) {
    BasicBlock *ClonedBB = Result.Blocks[i];
    BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i];

    assert(Result.Map[OriginalBB] == ClonedBB && "invariant!");

    for (Instruction &I : *ClonedBB)
      RemapInstruction(&I, Result.Map,
                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);

    // Exit blocks will now have one more predecessor and their PHI nodes need
    // to be edited to reflect that. No phi nodes need to be introduced because
    // the loop is in LCSSA.

    for (auto *SBB : successors(OriginalBB)) {
      if (OriginalLoop.contains(SBB))
        continue; // not an exit block

      for (PHINode &PN : SBB->phis()) {
        Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB);
        PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB);
        SE.forgetValue(&PN);
      }
    }
  }
}

/// Rewrites the loop described by \p LS so that it is entered from
/// \p Preheader only if its induction variable start is within
/// \p ExitSubloopAt, and leaves via its latch as soon as the induction
/// variable reaches \p ExitSubloopAt. Two new blocks are created:
/// "<Tag>.exit.selector", which decides between the original latch exit and
/// continuing, and "<Tag>.pseudo.exit", which branches to
/// \p ContinuationBlock and holds PHIs carrying the latest header PHI values.
///
/// \returns the new blocks, the PHIs created in the pseudo exit (in header PHI
///          order) and the PHI holding the induction variable's final value.
LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
    const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
    BasicBlock *ContinuationBlock) const {
  // We start with a loop with a single latch:
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      |
  //    |                    |
  //    +--------+-----------+
  //             |      ----------------\
  //             |     /                |
  //    +--------v----v------+          |
  //    |                    |          |
  //    |      header        |          |
  //    |                    |          |
  //    +--------------------+          |
  //                                    |
  //            .....                   |
  //                                    |
  //    +--------------------+          |
  //    |                    |          |
  //    |       latch        >----------/
  //    |                    |
  //    +-------v------------+
  //            |
  //            |
  //            |   +--------------------+
  //            |   |                    |
  //            +--->   original exit    |
  //                |                    |
  //                +--------------------+
  //
  // We change the control flow to look like
  //
  //
  //    +--------------------+
  //    |                    |
  //    |     preheader      >-------------------------+
  //    |                    |                         |
  //    +--------v-----------+                         |
  //             |    /-------------+                  |
  //             |   /              |                  |
  //    +--------v--v--------+      |                  |
  //    |                    |      |                  |
  //    |      header        |      |   +--------+     |
  //    |                    |      |   |        |     |
  //    +--------------------+      |   |  +-----v-----v-----------+
  //                                |   |  |                       |
  //                                |   |  |     .pseudo.exit      |
  //                                |   |  |                       |
  //                                |   |  +-----------v-----------+
  //                                |   |              |
  //            .....               |   |              |
  //                                |   |     +--------v-------------+
  //    +--------------------+      |   |     |                      |
  //    |                    |      |   |     |   ContinuationBlock  |
  //    |       latch        >------+   |     |                      |
  //    |                    |          |     +----------------------+
  //    +---------v----------+          |
  //              |                     |
  //              |                     |
  //              |     +---------------^-----+
  //              |     |                     |
  //              +----->    .exit.selector   |
  //                    |                     |
  //                    +----------v----------+
  //                               |
  //     +--------------------+    |
  //     |                    |    |
  //     |   original exit    <----+
  //     |                    |
  //     +--------------------+

  RewrittenRangeInfo RRI;

  BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
  RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
                                        &F, BBInsertLocation);
  RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
                                      BBInsertLocation);

  BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
  bool Increasing = LS.IndVarIncreasing;
  bool IsSignedPredicate = LS.IsSignedPredicate;

  IRBuilder<> B(PreheaderJump);
  // Returns V unchanged if it already has type RangeTy; otherwise emits a
  // sign- or zero-extension to RangeTy (per the latch predicate's signedness)
  // at the builder's current insertion point.
  auto NoopOrExt = [&](Value *V) {
    if (V->getType() == RangeTy)
      return V;
    return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
                             : B.CreateZExt(V, RangeTy, "wide." + V->getName());
  };

  // EnterLoopCond - is it okay to start executing this `LS'?
  Value *EnterLoopCond = nullptr;
  auto Pred =
      Increasing
          ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
          : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
  Value *IndVarStart = NoopOrExt(LS.IndVarStart);
  EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);

  // Replace the preheader's branch with a conditional entry into the loop.
  B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
  PreheaderJump->eraseFromParent();

  // Redirect the latch's exit edge to the exit selector and replace the latch
  // condition with a comparison against the new limit.
  LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
  B.SetInsertPoint(LS.LatchBr);
  Value *IndVarBase = NoopOrExt(LS.IndVarBase);
  Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);

  // If the exit is successor 0, the branch condition must be true to *leave*
  // the loop, so the "take backedge" condition is negated.
  Value *CondForBranch = LS.LatchBrExitIdx == 1
                             ? TakeBackedgeLoopCond
                             : B.CreateNot(TakeBackedgeLoopCond);

  LS.LatchBr->setCondition(CondForBranch);

  B.SetInsertPoint(RRI.ExitSelector);

  // IterationsLeft - are there any more iterations left, given the original
  // upper bound on the induction variable? If not, we branch to the "real"
  // exit.
  Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
  Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
  B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);

  BranchInst *BranchToContinuation =
      BranchInst::Create(ContinuationBlock, RRI.PseudoExit);

  // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
  // each of the PHI nodes in the loop header. This feeds into the initial
  // value of the same PHI nodes if/when we continue execution.
  for (PHINode &PN : LS.Header->phis()) {
    PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
                                      BranchToContinuation);

    NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
    NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
                        RRI.ExitSelector);
    RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
  }

  RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
                                  BranchToContinuation);
  RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
  RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);

  // The latch exit now has a branch from `RRI.ExitSelector' instead of
  // `LS.Latch'. The PHI nodes need to be updated to reflect that.
  LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);

  return RRI;
}

/// For each PHI node in the header of \p LS, sets the value incoming from
/// \p ContinuationBlock to the matching PHI in \p RRI.PHIValuesAtPseudoExit
/// (matched by position, i.e. header PHI order), then makes \p LS start its
/// induction variable at \p RRI.IndVarEnd.
void LoopConstrainer::rewriteIncomingValuesForPHIs(
    LoopStructure &LS, BasicBlock *ContinuationBlock,
    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
  unsigned PHIIndex = 0;
  for (PHINode &PN : LS.Header->phis())
    PN.setIncomingValueForBlock(ContinuationBlock,
                                RRI.PHIValuesAtPseudoExit[PHIIndex++]);

  LS.IndVarStart = RRI.IndVarEnd;
}

/// Creates a block named \p Tag that branches unconditionally to the header
/// of \p LS, and rewrites the header's PHI nodes so that values that used to
/// come from \p OldPreheader now come from the new block. The terminator of
/// \p OldPreheader is not touched here.
BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
                                             BasicBlock *OldPreheader,
                                             const char *Tag) const {
  BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header);
  BranchInst::Create(LS.Header, Preheader);

  LS.Header->replacePhiUsesWith(OldPreheader, Preheader);

  return Preheader;
}

/// Adds every block in \p BBs to the parent loop of the original loop; does
/// nothing if the original loop has no parent.
void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
  Loop *ParentLoop = OriginalLoop.getParentLoop();
  if (!ParentLoop)
    return;

  for (BasicBlock *BB : BBs)
    ParentLoop->addBasicBlockToLoop(BB, LI);
}

/// Allocates a new Loop that mirrors \p Original for the clone described by
/// \p VM, registers it as a child of \p Parent (or as a top-level loop when
/// \p Parent is null), reports it through LPMAddNewLoop, and recurses into the
/// sub-loops of \p Original.
Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
                                                 ValueToValueMapTy &VM,
                                                 bool IsSubloop) {
  Loop &New = *LI.AllocateLoop();
  if (Parent)
    Parent->addChildLoop(&New);
  else
    LI.addTopLevelLoop(&New);
  LPMAddNewLoop(&New, IsSubloop);

  // Add all of the blocks in Original to the new loop.
  // Only blocks whose innermost loop is Original are added here; blocks of
  // sub-loops are added by the recursive calls below.
  for (auto *BB : Original->blocks())
    if (LI.getLoopFor(BB) == Original)
      New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI);

  // Add all of the subloops to the new loop.
  for (Loop *SubLoop : *Original)
    createClonedLoopStructure(SubLoop, &New, VM, /* IsSubloop */ true);

  return &New;
}

/// Performs the transformation: depending on which limits are present in SR
/// and on the direction of the induction variable, clones the original loop
/// into a "preloop" and/or "postloop" and restricts the original ("main")
/// loop's iteration space accordingly.
///
/// \returns false, before any loop is cloned, if an exit limit cannot be
///          proven free of overflow or cannot be safely expanded at the
///          preheader terminator; true once the rewrite has been done.
bool LoopConstrainer::run() {
  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
  assert(Preheader != nullptr && "precondition!");

  OriginalPreheader = Preheader;
  MainLoopPreheader = Preheader;
  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
  bool Increasing = MainLoopStructure.IndVarIncreasing;
  IntegerType *IVTy = cast<IntegerType>(RangeTy);

  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
  Instruction *InsertPt = OriginalPreheader->getTerminator();

  // It would have been better to make `PreLoop' and `PostLoop'
  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
  // constructor.
  ClonedLoop PreLoop, PostLoop;
  // The pre-loop covers the part of the range the induction variable visits
  // first: the low end when increasing, the high end when decreasing.
  bool NeedsPreLoop =
      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
  bool NeedsPostLoop =
      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();

  Value *ExitPreLoopAt = nullptr;
  Value *ExitMainLoopAt = nullptr;
  const SCEVConstant *MinusOneS =
      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));

  if (NeedsPreLoop) {
    const SCEV *ExitPreLoopAtSCEV = nullptr;

    // For a decreasing IV the limit is HighLimit - 1, which is only used if
    // HighLimit is known not to be the minimum value.
    if (Increasing)
      ExitPreLoopAtSCEV = *SR.LowLimit;
    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "preloop exit limit. HighLimit = "
                        << *(*SR.HighLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " preloop exit limit " << *ExitPreLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
    ExitPreLoopAt->setName("exit.preloop.at");
  }

  if (NeedsPostLoop) {
    const SCEV *ExitMainLoopAtSCEV = nullptr;

    // Same scheme as above, using the opposite limit.
    if (Increasing)
      ExitMainLoopAtSCEV = *SR.HighLimit;
    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
                               IsSignedPredicate))
      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
    else {
      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
                        << "mainloop exit limit. LowLimit = "
                        << *(*SR.LowLimit) << "\n");
      return false;
    }

    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
                        << " main loop exit limit " << *ExitMainLoopAtSCEV
                        << " at block " << InsertPt->getParent()->getName()
                        << "\n");
      return false;
    }

    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
    ExitMainLoopAt->setName("exit.mainloop.at");
  }

  // We clone these ahead of time so that we don't have to deal with changing
  // and temporarily invalid IR as we transform the loops.
  if (NeedsPreLoop)
    cloneLoop(PreLoop, "preloop");
  if (NeedsPostLoop)
    cloneLoop(PostLoop, "postloop");

  RewrittenRangeInfo PreLoopRRI;

  if (NeedsPreLoop) {
    // The original preheader now enters the pre-loop; the main loop gets a
    // fresh "mainloop" preheader that the pre-loop continues into.
    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
                                                  PreLoop.Structure.Header);

    MainLoopPreheader =
        createPreheader(MainLoopStructure, Preheader, "mainloop");
    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
                                         ExitPreLoopAt, MainLoopPreheader);
    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
                                 PreLoopRRI);
  }

  BasicBlock *PostLoopPreheader = nullptr;
  RewrittenRangeInfo PostLoopRRI;

  if (NeedsPostLoop) {
    // The main loop is cut short at ExitMainLoopAt and continues into the
    // post-loop through a fresh "postloop" preheader.
    PostLoopPreheader =
        createPreheader(PostLoop.Structure, Preheader, "postloop");
    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
                                          ExitMainLoopAt, PostLoopPreheader);
    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
                                 PostLoopRRI);
  }

  BasicBlock *NewMainLoopPreheader =
      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};

  // Some of the above may be nullptr, filter them out before passing to
  // addToParentLoopIfNeeded.
  auto NewBlocksEnd =
      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);

  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));

  DT.recalculate(F);

  // We need to first add all the pre and post loop blocks into the loop
  // structures (as part of createClonedLoopStructure), and then update the
  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
  // LI when LoopSimplifyForm is generated.
  Loop *PreL = nullptr, *PostL = nullptr;
  if (!PreLoop.Blocks.empty()) {
    PreL = createClonedLoopStructure(&OriginalLoop,
                                     OriginalLoop.getParentLoop(), PreLoop.Map,
                                     /* IsSubLoop */ false);
  }

  if (!PostLoop.Blocks.empty()) {
    PostL =
        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
                                  PostLoop.Map, /* IsSubLoop */ false);
  }

  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
    formLCSSARecursively(*L, DT, &LI, &SE);
    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
    // Pre/post loops are slow paths, we do not need to perform any loop
    // optimizations on them.
    if (!IsOriginalLoop)
      DisableAllLoopOptsOnLoop(*L);
  };
  if (PreL)
    CanonicalizeLoop(PreL, false);
  if (PostL)
    CanonicalizeLoop(PostL, false);
  CanonicalizeLoop(&OriginalLoop, true);

  /// At this point:
  /// - We've broken a "main loop" out of the loop in a way that the "main loop"
  /// runs with the induction variable in a subset of [Begin, End).
  /// - There is no overflow when computing "main loop" exit limit.
  /// - Max latch taken count of the loop is limited.
  /// It guarantees that induction variable will not overflow iterating in the
  /// "main loop".
  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
    if (IsSignedPredicate)
      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
          ->setHasNoSignedWrap(true);
  /// TODO: support unsigned predicate.
  /// To add NUW flag we need to prove that both operands of BO are
  /// non-negative. E.g:
  /// ...
  /// %iv.next = add nsw i32 %iv, -1
  /// %cmp = icmp ult i32 %iv.next, %n
  /// br i1 %cmp, label %loopexit, label %loop
  ///
  /// -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
  /// overflow, therefore NUW flag is not legal here.

  return true;
}
