aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2023-12-18 20:30:12 +0000
committerDimitry Andric <dim@FreeBSD.org>2024-04-19 21:12:03 +0000
commitc9157d925c489f07ba9c0b2ce47e5149b75969a5 (patch)
tree08bc4a3d9cad3f9ebffa558ddf140b9d9257b219 /contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
parent2a66844f606a35d68ad8a8061f4bea204274b3bc (diff)
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp904
1 files changed, 904 insertions, 0 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
new file mode 100644
index 000000000000..ea6d952cfa7d
--- /dev/null
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopConstrainer.cpp
@@ -0,0 +1,904 @@
+#include "llvm/Transforms/Utils/LoopConstrainer.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
+
+using namespace llvm;
+
+// Name of the metadata attached to the latch terminator of every loop clone
+// produced by cloneLoop; parseLoopStructure refuses loops that carry it.
+static const char *ClonedLoopTag = "loop_constrainer.loop.clone";
+
+// Debug type name for the LLVM_DEBUG output emitted in this file.
+#define DEBUG_TYPE "loop-constrainer"
+
+/// Checks whether the bounds of a new loop can be computed safely for a loop
+/// whose induction variable is decreasing, given the latch predicate \p Pred.
+///
+/// Only strict signed/unsigned inequalities are accepted, and \p BoundSCEV
+/// must be available at the entry of \p L. \p Step is asserted to be negative.
+static bool isSafeDecreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
+                                  const SCEV *Step, ICmpInst::Predicate Pred,
+                                  unsigned LatchBrExitIdx, Loop *L,
+                                  ScalarEvolution &SE) {
+  switch (Pred) {
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_UGT:
+    break;
+  default:
+    return false;
+  }
+
+  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
+    return false;
+
+  assert(SE.isKnownNegative(Step) && "expecting negative step");
+
+  LLVM_DEBUG({
+    dbgs() << "isSafeDecreasingBound with:\n";
+    dbgs() << "Start: " << *Start << "\n";
+    dbgs() << "Step: " << *Step << "\n";
+    dbgs() << "BoundSCEV: " << *BoundSCEV << "\n";
+    dbgs() << "Pred: " << Pred << "\n";
+    dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n";
+  });
+
+  const bool SignedCmp = ICmpInst::isSigned(Pred);
+  // Predicate used to verify that the induction variable stays within bounds.
+  const ICmpInst::Predicate BoundPred =
+      SignedCmp ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT;
+  auto GuardedAtEntry = [&](const SCEV *LHS, const SCEV *RHS) {
+    return SE.isLoopEntryGuardedByCond(L, BoundPred, LHS, RHS);
+  };
+
+  // The latch exits on its second successor: comparing the start value with
+  // the bound directly is sufficient.
+  if (LatchBrExitIdx == 1)
+    return GuardedAtEntry(Start, BoundSCEV);
+
+  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be either 0 or 1");
+
+  // Limit = TypeMin - (Step + 1), where TypeMin is the smallest value of the
+  // bound's type under the signedness of the predicate.
+  const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType()));
+  const unsigned Bits = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
+  const APInt TypeMin =
+      SignedCmp ? APInt::getSignedMinValue(Bits) : APInt::getMinValue(Bits);
+  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(TypeMin), StepPlusOne);
+
+  const SCEV *BoundMinusOne =
+      SE.getMinusSCEV(BoundSCEV, SE.getOne(BoundSCEV->getType()));
+
+  if (!GuardedAtEntry(Start, BoundMinusOne))
+    return false;
+  return GuardedAtEntry(BoundSCEV, Limit);
+}
+
+/// Checks whether the bounds of a new loop can be computed safely for a loop
+/// whose induction variable is increasing, given the latch predicate \p Pred.
+///
+/// Only strict signed/unsigned inequalities are accepted, and \p BoundSCEV
+/// must be available at the entry of \p L.
+static bool isSafeIncreasingBound(const SCEV *Start, const SCEV *BoundSCEV,
+                                  const SCEV *Step, ICmpInst::Predicate Pred,
+                                  unsigned LatchBrExitIdx, Loop *L,
+                                  ScalarEvolution &SE) {
+  switch (Pred) {
+  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SGT:
+  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_UGT:
+    break;
+  default:
+    return false;
+  }
+
+  if (!SE.isAvailableAtLoopEntry(BoundSCEV, L))
+    return false;
+
+  LLVM_DEBUG({
+    dbgs() << "isSafeIncreasingBound with:\n";
+    dbgs() << "Start: " << *Start << "\n";
+    dbgs() << "Step: " << *Step << "\n";
+    dbgs() << "BoundSCEV: " << *BoundSCEV << "\n";
+    dbgs() << "Pred: " << Pred << "\n";
+    dbgs() << "LatchExitBrIdx: " << LatchBrExitIdx << "\n";
+  });
+
+  const bool SignedCmp = ICmpInst::isSigned(Pred);
+  // Predicate used to verify that the induction variable stays within bounds.
+  const ICmpInst::Predicate BoundPred =
+      SignedCmp ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT;
+  auto GuardedAtEntry = [&](const SCEV *LHS, const SCEV *RHS) {
+    return SE.isLoopEntryGuardedByCond(L, BoundPred, LHS, RHS);
+  };
+
+  // The latch exits on its second successor: comparing the start value with
+  // the bound directly is sufficient.
+  if (LatchBrExitIdx == 1)
+    return GuardedAtEntry(Start, BoundSCEV);
+
+  assert(LatchBrExitIdx == 0 && "LatchBrExitIdx should be 0 or 1");
+
+  // Limit = TypeMax - (Step - 1), where TypeMax is the largest value of the
+  // bound's type under the signedness of the predicate.
+  const SCEV *StepMinusOne = SE.getMinusSCEV(Step, SE.getOne(Step->getType()));
+  const unsigned Bits = cast<IntegerType>(BoundSCEV->getType())->getBitWidth();
+  const APInt TypeMax =
+      SignedCmp ? APInt::getSignedMaxValue(Bits) : APInt::getMaxValue(Bits);
+  const SCEV *Limit = SE.getMinusSCEV(SE.getConstant(TypeMax), StepMinusOne);
+
+  const SCEV *BoundPlusStep = SE.getAddExpr(BoundSCEV, Step);
+
+  if (!GuardedAtEntry(Start, BoundPlusStep))
+    return false;
+  return GuardedAtEntry(BoundSCEV, Limit);
+}
+
+/// Produces an estimate of the maximum number of times the latch of \p L is
+/// taken, preferring the narrowest type available. The symbolic-maximum exit
+/// count of the latch block is used when it can be computed; otherwise the
+/// symbolic-maximum backedge-taken count of the whole loop (possibly of a
+/// wider type than the latch check) is returned, which still beats having no
+/// estimate at all.
+static const SCEV *getNarrowestLatchMaxTakenCountEstimate(ScalarEvolution &SE,
+                                                          const Loop &L) {
+  const SCEV *LatchCount =
+      SE.getExitCount(&L, L.getLoopLatch(), ScalarEvolution::SymbolicMaximum);
+  if (!isa<SCEVCouldNotCompute>(LatchCount))
+    return LatchCount;
+  // No per-latch estimate: fall back to the loop-wide one.
+  return SE.getSymbolicMaxBackedgeTakenCount(&L);
+}
+
+/// Tries to recognize \p L as a loop whose latch branch is controlled by an
+/// integer comparison between an affine induction variable with a constant
+/// step and a bound, and describes it as a LoopStructure.
+///
+/// On failure std::nullopt is returned and \p FailureReason is set to a string
+/// literal naming the reason; on success \p FailureReason is reset to nullptr.
+/// Unsigned latch predicates are rejected unless \p AllowUnsignedLatchCond is
+/// set. On success code may be emitted before the preheader terminator (via
+/// SCEVExpander) for the induction variable start value and, when required,
+/// for the adjusted exit bound.
+std::optional<LoopStructure>
+LoopStructure::parseLoopStructure(ScalarEvolution &SE, Loop &L,
+ bool AllowUnsignedLatchCond,
+ const char *&FailureReason) {
+ if (!L.isLoopSimplifyForm()) {
+ FailureReason = "loop not in LoopSimplify form";
+ return std::nullopt;
+ }
+
+ BasicBlock *Latch = L.getLoopLatch();
+ assert(Latch && "Simplified loops only have one latch!");
+
+ // Loops tagged by cloneLoop are copies we produced ourselves; skip them.
+ if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) {
+ FailureReason = "loop has already been cloned";
+ return std::nullopt;
+ }
+
+ if (!L.isLoopExiting(Latch)) {
+ FailureReason = "no loop latch";
+ return std::nullopt;
+ }
+
+ BasicBlock *Header = L.getHeader();
+ BasicBlock *Preheader = L.getLoopPreheader();
+ if (!Preheader) {
+ FailureReason = "no preheader";
+ return std::nullopt;
+ }
+
+ BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator());
+ if (!LatchBr || LatchBr->isUnconditional()) {
+ FailureReason = "latch terminator not conditional branch";
+ return std::nullopt;
+ }
+
+ // Index of the latch branch successor that is not the header, i.e. the one
+ // that leaves the loop.
+ unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0;
+
+ ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition());
+ if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) {
+ FailureReason = "latch terminator branch not conditional on integral icmp";
+ return std::nullopt;
+ }
+
+ const SCEV *MaxBETakenCount = getNarrowestLatchMaxTakenCountEstimate(SE, L);
+ if (isa<SCEVCouldNotCompute>(MaxBETakenCount)) {
+ FailureReason = "could not compute latch count";
+ return std::nullopt;
+ }
+ assert(SE.getLoopDisposition(MaxBETakenCount, &L) ==
+ ScalarEvolution::LoopInvariant &&
+ "loop variant exit count doesn't make sense!");
+
+ ICmpInst::Predicate Pred = ICI->getPredicate();
+ Value *LeftValue = ICI->getOperand(0);
+ const SCEV *LeftSCEV = SE.getSCEV(LeftValue);
+ IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType());
+
+ Value *RightValue = ICI->getOperand(1);
+ const SCEV *RightSCEV = SE.getSCEV(RightValue);
+
+ // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence.
+ if (!isa<SCEVAddRecExpr>(LeftSCEV)) {
+ if (isa<SCEVAddRecExpr>(RightSCEV)) {
+ std::swap(LeftSCEV, RightSCEV);
+ std::swap(LeftValue, RightValue);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ } else {
+ FailureReason = "no add recurrences in the icmp";
+ return std::nullopt;
+ }
+ }
+
+ // Returns true if the add recurrence is known not to wrap in the signed
+ // sense, either from its flags or because sign-extending it to twice the
+ // bit width yields an add recurrence over the sign-extended start and step.
+ auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) {
+ if (AR->getNoWrapFlags(SCEV::FlagNSW))
+ return true;
+
+ IntegerType *Ty = cast<IntegerType>(AR->getType());
+ IntegerType *WideTy =
+ IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2);
+
+ const SCEVAddRecExpr *ExtendAfterOp =
+ dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy));
+ if (ExtendAfterOp) {
+ const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy);
+ const SCEV *ExtendedStep =
+ SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy);
+
+ bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart &&
+ ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep;
+
+ if (NoSignedWrap)
+ return true;
+ }
+
+ // We may have proved this when computing the sign extension above.
+ return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap;
+ };
+
+ // `ICI` is interpreted as taking the backedge if the *next* value of the
+ // induction variable satisfies some constraint.
+
+ const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV);
+ if (IndVarBase->getLoop() != &L) {
+ FailureReason = "LHS in cmp is not an AddRec for this loop";
+ return std::nullopt;
+ }
+ if (!IndVarBase->isAffine()) {
+ FailureReason = "LHS in icmp not induction variable";
+ return std::nullopt;
+ }
+ const SCEV *StepRec = IndVarBase->getStepRecurrence(SE);
+ if (!isa<SCEVConstant>(StepRec)) {
+ FailureReason = "LHS in icmp not induction variable";
+ return std::nullopt;
+ }
+ ConstantInt *StepCI = cast<SCEVConstant>(StepRec)->getValue();
+
+ if (ICI->isEquality() && !HasNoSignedWrap(IndVarBase)) {
+ FailureReason = "LHS in icmp needs nsw for equality predicates";
+ return std::nullopt;
+ }
+
+ assert(!StepCI->isZero() && "Zero step?");
+ bool IsIncreasing = !StepCI->isNegative();
+ bool IsSignedPredicate;
+ // The compared add recurrence is treated as the *next* value of the
+ // induction variable, so one step is subtracted from its start to obtain the
+ // start value of the induction variable itself.
+ const SCEV *StartNext = IndVarBase->getStart();
+ const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE));
+ const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend);
+ const SCEV *Step = SE.getSCEV(StepCI);
+
+ // When non-null, the SCEV that is expanded in the preheader to replace
+ // RightValue as the loop exit bound.
+ const SCEV *FixedRightSCEV = nullptr;
+
+ // If RightValue resides within loop (but still being loop invariant),
+ // regenerate it as preheader.
+ if (auto *I = dyn_cast<Instruction>(RightValue))
+ if (L.contains(I->getParent()))
+ FixedRightSCEV = RightSCEV;
+
+ if (IsIncreasing) {
+ bool DecreasedRightValueByOne = false;
+ if (StepCI->isOne()) {
+ // Try to turn eq/ne predicates to those we can work with.
+ if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
+ // while (++i != len) { while (++i < len) {
+ // ... ---> ...
+ // } }
+ // If both parts are known non-negative, it is profitable to use
+ // unsigned comparison in increasing loop. This allows us to make the
+ // comparison check against "RightSCEV + 1" more optimistic.
+ if (isKnownNonNegativeInLoop(IndVarStart, &L, SE) &&
+ isKnownNonNegativeInLoop(RightSCEV, &L, SE))
+ Pred = ICmpInst::ICMP_ULT;
+ else
+ Pred = ICmpInst::ICMP_SLT;
+ else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
+ // while (true) { while (true) {
+ // if (++i == len) ---> if (++i > len - 1)
+ // break; break;
+ // ... ...
+ // } }
+ if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
+ cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ false)) {
+ Pred = ICmpInst::ICMP_UGT;
+ RightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ DecreasedRightValueByOne = true;
+ } else if (cannotBeMinInLoop(RightSCEV, &L, SE, /*Signed*/ true)) {
+ Pred = ICmpInst::ICMP_SGT;
+ RightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ DecreasedRightValueByOne = true;
+ }
+ }
+ }
+
+ bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
+ bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
+ bool FoundExpectedPred =
+ (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0);
+
+ if (!FoundExpectedPred) {
+ FailureReason = "expected icmp slt semantically, found something else";
+ return std::nullopt;
+ }
+
+ IsSignedPredicate = ICmpInst::isSigned(Pred);
+ if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
+ FailureReason = "unsigned latch conditions are explicitly prohibited";
+ return std::nullopt;
+ }
+
+ if (!isSafeIncreasingBound(IndVarStart, RightSCEV, Step, Pred,
+ LatchBrExitIdx, &L, SE)) {
+ FailureReason = "Unsafe loop bounds";
+ return std::nullopt;
+ }
+ if (LatchBrExitIdx == 0) {
+ // We need to increase the right value unless we have already decreased
+ // it virtually when we replaced EQ with SGT.
+ if (!DecreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ } else {
+ assert(!DecreasedRightValueByOne &&
+ "Right value can be decreased only for LatchBrExitIdx == 0!");
+ }
+ } else {
+ bool IncreasedRightValueByOne = false;
+ if (StepCI->isMinusOne()) {
+ // Try to turn eq/ne predicates to those we can work with.
+ if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1)
+ // while (--i != len) { while (--i > len) {
+ // ... ---> ...
+ // } }
+ // We intentionally don't turn the predicate into UGT even if we know
+ // that both operands are non-negative, because it will only pessimize
+ // our check against "RightSCEV - 1".
+ Pred = ICmpInst::ICMP_SGT;
+ else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0) {
+ // while (true) { while (true) {
+ // if (--i == len) ---> if (--i < len + 1)
+ // break; break;
+ // ... ...
+ // } }
+ if (IndVarBase->getNoWrapFlags(SCEV::FlagNUW) &&
+ cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ false)) {
+ Pred = ICmpInst::ICMP_ULT;
+ RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ IncreasedRightValueByOne = true;
+ } else if (cannotBeMaxInLoop(RightSCEV, &L, SE, /* Signed */ true)) {
+ Pred = ICmpInst::ICMP_SLT;
+ RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType()));
+ IncreasedRightValueByOne = true;
+ }
+ }
+ }
+
+ bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT);
+ bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT);
+
+ bool FoundExpectedPred =
+ (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0);
+
+ if (!FoundExpectedPred) {
+ FailureReason = "expected icmp sgt semantically, found something else";
+ return std::nullopt;
+ }
+
+ IsSignedPredicate =
+ Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT;
+
+ if (!IsSignedPredicate && !AllowUnsignedLatchCond) {
+ FailureReason = "unsigned latch conditions are explicitly prohibited";
+ return std::nullopt;
+ }
+
+ if (!isSafeDecreasingBound(IndVarStart, RightSCEV, Step, Pred,
+ LatchBrExitIdx, &L, SE)) {
+ FailureReason = "Unsafe bounds";
+ return std::nullopt;
+ }
+
+ if (LatchBrExitIdx == 0) {
+ // We need to decrease the right value unless we have already increased
+ // it virtually when we replaced EQ with SLT.
+ if (!IncreasedRightValueByOne)
+ FixedRightSCEV =
+ SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType()));
+ } else {
+ assert(!IncreasedRightValueByOne &&
+ "Right value can be increased only for LatchBrExitIdx == 0!");
+ }
+ }
+ BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx);
+
+ assert(!L.contains(LatchExit) && "expected an exit block!");
+ const DataLayout &DL = Preheader->getModule()->getDataLayout();
+ SCEVExpander Expander(SE, DL, "loop-constrainer");
+ // All code generated below is inserted before the preheader terminator.
+ Instruction *Ins = Preheader->getTerminator();
+
+ if (FixedRightSCEV)
+ RightValue =
+ Expander.expandCodeFor(FixedRightSCEV, FixedRightSCEV->getType(), Ins);
+
+ Value *IndVarStartV = Expander.expandCodeFor(IndVarStart, IndVarTy, Ins);
+ IndVarStartV->setName("indvar.start");
+
+ LoopStructure Result;
+
+ Result.Tag = "main";
+ Result.Header = Header;
+ Result.Latch = Latch;
+ Result.LatchBr = LatchBr;
+ Result.LatchExit = LatchExit;
+ Result.LatchBrExitIdx = LatchBrExitIdx;
+ Result.IndVarStart = IndVarStartV;
+ Result.IndVarStep = StepCI;
+ Result.IndVarBase = LeftValue;
+ Result.IndVarIncreasing = IsIncreasing;
+ Result.LoopExitAt = RightValue;
+ Result.IsSignedPredicate = IsSignedPredicate;
+ Result.ExitCountTy = cast<IntegerType>(MaxBETakenCount->getType());
+
+ FailureReason = nullptr;
+
+ return Result;
+}
+
+// Replaces the loop ID of L with metadata that switches loop optimizations
+// off (unrolling, vectorization, LICM versioning and distribution). Callers
+// are responsible for establishing that optimizing L is not worthwhile.
+static void DisableAllLoopOptsOnLoop(Loop &L) {
+  // Whatever loop ID metadata L already has is irrelevant: every loop
+  // property set here is a "disable", and the old ID is simply overwritten.
+  LLVMContext &Context = L.getHeader()->getContext();
+
+  Metadata *FalseVal =
+      ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
+
+  // Builds a node holding only the property name.
+  auto NameOnly = [&Context](const char *Name) -> MDNode * {
+    return MDNode::get(Context, {MDString::get(Context, Name)});
+  };
+  // Builds a node of the form {name, i1 false}.
+  auto NameSetToFalse = [&Context, FalseVal](const char *Name) -> MDNode * {
+    return MDNode::get(Context, {MDString::get(Context, Name), FalseVal});
+  };
+
+  // The first operand is a placeholder; it is patched below so that the loop
+  // ID refers to itself.
+  MDNode *SelfPlaceholder = MDNode::get(Context, {});
+  MDNode *NoUnroll = NameOnly("llvm.loop.unroll.disable");
+  MDNode *NoVectorize = NameSetToFalse("llvm.loop.vectorize.enable");
+  MDNode *NoLICMVersioning = NameOnly("llvm.loop.licm_versioning.disable");
+  MDNode *NoDistribute = NameSetToFalse("llvm.loop.distribute.enable");
+
+  MDNode *LoopID =
+      MDNode::get(Context, {SelfPlaceholder, NoUnroll, NoVectorize,
+                            NoLICMVersioning, NoDistribute});
+  LoopID->replaceOperandWith(0, LoopID);
+  L.setLoopID(LoopID);
+}
+
+/// Records the loop to constrain together with the analyses, callback, range
+/// type and sub-ranges the transformation uses. The enclosing function and
+/// the LLVMContext are taken from the header of \p L. The constructor only
+/// initializes members; no IR is touched here.
+LoopConstrainer::LoopConstrainer(Loop &L, LoopInfo &LI,
+ function_ref<void(Loop *, bool)> LPMAddNewLoop,
+ const LoopStructure &LS, ScalarEvolution &SE,
+ DominatorTree &DT, Type *T, SubRanges SR)
+ : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), SE(SE),
+ DT(DT), LI(LI), LPMAddNewLoop(LPMAddNewLoop), OriginalLoop(L), RangeTy(T),
+ MainLoopStructure(LS), SR(SR) {}
+
+/// Clones every block of the original loop into the enclosing function and
+/// records the copies (and the value mapping) in \p Result. \p Tag, preceded
+/// by a '.', is used as the name suffix of the cloned blocks and becomes the
+/// tag of the cloned LoopStructure. PHI nodes in the exit blocks receive an
+/// extra incoming value for each new predecessor.
+void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result,
+                                const char *Tag) const {
+  ArrayRef<BasicBlock *> OrigBlocks = OriginalLoop.getBlocks();
+
+  for (BasicBlock *OrigBB : OrigBlocks) {
+    BasicBlock *Copy =
+        CloneBasicBlock(OrigBB, Result.Map, Twine(".") + Tag, &F);
+    Result.Blocks.push_back(Copy);
+    Result.Map[OrigBB] = Copy;
+  }
+
+  // Maps a value to its clone, or to itself when it was not cloned.
+  auto MapToClone = [&Result](Value *V) -> Value * {
+    assert(V && "null values not in domain!");
+    auto Found = Result.Map.find(V);
+    return Found == Result.Map.end() ? V
+                                     : static_cast<Value *>(Found->second);
+  };
+
+  // Tag the cloned latch so the clone can be recognized later on.
+  BasicBlock *LatchCopy =
+      cast<BasicBlock>(MapToClone(OriginalLoop.getLoopLatch()));
+  LatchCopy->getTerminator()->setMetadata(ClonedLoopTag, MDNode::get(Ctx, {}));
+
+  Result.Structure = MainLoopStructure.map(MapToClone);
+  Result.Structure.Tag = Tag;
+
+  for (unsigned Idx = 0, NumBlocks = Result.Blocks.size(); Idx != NumBlocks;
+       ++Idx) {
+    BasicBlock *OrigBB = OrigBlocks[Idx];
+    BasicBlock *Copy = Result.Blocks[Idx];
+
+    assert(Result.Map[OrigBB] == Copy && "invariant!");
+
+    // Make the cloned instructions refer to cloned operands.
+    for (Instruction &Inst : *Copy)
+      RemapInstruction(&Inst, Result.Map,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+
+    // Each exit block gains one predecessor (the cloned block), so its PHI
+    // nodes need a matching incoming value. The loop is in LCSSA form, hence
+    // no new PHI nodes have to be created.
+    for (BasicBlock *Succ : successors(OrigBB)) {
+      if (!OriginalLoop.contains(Succ)) {
+        for (PHINode &ExitPHI : Succ->phis()) {
+          Value *FromOriginal = ExitPHI.getIncomingValueForBlock(OrigBB);
+          ExitPHI.addIncoming(MapToClone(FromOriginal), Copy);
+          SE.forgetValue(&ExitPHI);
+        }
+      }
+    }
+  }
+}
+
+/// Rewrites the loop described by \p LS so that it is entered from
+/// \p Preheader, and keeps iterating, only while its induction variable
+/// compares favourably against \p ExitSubloopAt. Two blocks are created:
+/// "<tag>.exit.selector", which takes over the latch's exit edge and chooses
+/// between the original latch exit and "<tag>.pseudo.exit", and
+/// "<tag>.pseudo.exit", which branches to \p ContinuationBlock. The returned
+/// RewrittenRangeInfo holds both blocks, the PHI copies created in the pseudo
+/// exit for each header PHI, and the PHI carrying the final induction
+/// variable value.
+LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd(
+ const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt,
+ BasicBlock *ContinuationBlock) const {
+ // We start with a loop with a single latch:
+ //
+ // +--------------------+
+ // | |
+ // | preheader |
+ // | |
+ // +--------+-----------+
+ // | ----------------\
+ // | / |
+ // +--------v----v------+ |
+ // | | |
+ // | header | |
+ // | | |
+ // +--------------------+ |
+ // |
+ // ..... |
+ // |
+ // +--------------------+ |
+ // | | |
+ // | latch >----------/
+ // | |
+ // +-------v------------+
+ // |
+ // |
+ // | +--------------------+
+ // | | |
+ // +---> original exit |
+ // | |
+ // +--------------------+
+ //
+ // We change the control flow to look like
+ //
+ //
+ // +--------------------+
+ // | |
+ // | preheader >-------------------------+
+ // | | |
+ // +--------v-----------+ |
+ // | /-------------+ |
+ // | / | |
+ // +--------v--v--------+ | |
+ // | | | |
+ // | header | | +--------+ |
+ // | | | | | |
+ // +--------------------+ | | +-----v-----v-----------+
+ // | | | |
+ // | | | .pseudo.exit |
+ // | | | |
+ // | | +-----------v-----------+
+ // | | |
+ // ..... | | |
+ // | | +--------v-------------+
+ // +--------------------+ | | | |
+ // | | | | | ContinuationBlock |
+ // | latch >------+ | | |
+ // | | | +----------------------+
+ // +---------v----------+ |
+ // | |
+ // | |
+ // | +---------------^-----+
+ // | | |
+ // +-----> .exit.selector |
+ // | |
+ // +----------v----------+
+ // |
+ // +--------------------+ |
+ // | | |
+ // | original exit <----+
+ // | |
+ // +--------------------+
+
+ RewrittenRangeInfo RRI;
+
+ // Both new blocks are placed right after the latch in the function.
+ BasicBlock *BBInsertLocation = LS.Latch->getNextNode();
+ RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector",
+ &F, BBInsertLocation);
+ RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F,
+ BBInsertLocation);
+
+ BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator());
+ bool Increasing = LS.IndVarIncreasing;
+ bool IsSignedPredicate = LS.IsSignedPredicate;
+
+ IRBuilder<> B(PreheaderJump);
+ // Returns V unchanged if it already has type RangeTy; otherwise sign- or
+ // zero-extends it to RangeTy depending on the signedness of the predicate.
+ auto NoopOrExt = [&](Value *V) {
+ if (V->getType() == RangeTy)
+ return V;
+ return IsSignedPredicate ? B.CreateSExt(V, RangeTy, "wide." + V->getName())
+ : B.CreateZExt(V, RangeTy, "wide." + V->getName());
+ };
+
+ // EnterLoopCond - is it okay to start executing this `LS'?
+ Value *EnterLoopCond = nullptr;
+ auto Pred =
+ Increasing
+ ? (IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT)
+ : (IsSignedPredicate ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT);
+ Value *IndVarStart = NoopOrExt(LS.IndVarStart);
+ EnterLoopCond = B.CreateICmp(Pred, IndVarStart, ExitSubloopAt);
+
+ // Replace the preheader terminator by a conditional branch that enters the
+ // loop or skips it via the pseudo exit.
+ B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit);
+ PreheaderJump->eraseFromParent();
+
+ // The latch now leaves the loop through the exit selector, and its
+ // condition is rebuilt to compare against ExitSubloopAt.
+ LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector);
+ B.SetInsertPoint(LS.LatchBr);
+ Value *IndVarBase = NoopOrExt(LS.IndVarBase);
+ Value *TakeBackedgeLoopCond = B.CreateICmp(Pred, IndVarBase, ExitSubloopAt);
+
+ // If the exit edge is successor 0, a true condition must mean "exit", so
+ // the take-backedge condition is negated.
+ Value *CondForBranch = LS.LatchBrExitIdx == 1
+ ? TakeBackedgeLoopCond
+ : B.CreateNot(TakeBackedgeLoopCond);
+
+ LS.LatchBr->setCondition(CondForBranch);
+
+ B.SetInsertPoint(RRI.ExitSelector);
+
+ // IterationsLeft - are there any more iterations left, given the original
+ // upper bound on the induction variable? If not, we branch to the "real"
+ // exit.
+ Value *LoopExitAt = NoopOrExt(LS.LoopExitAt);
+ Value *IterationsLeft = B.CreateICmp(Pred, IndVarBase, LoopExitAt);
+ B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit);
+
+ BranchInst *BranchToContinuation =
+ BranchInst::Create(ContinuationBlock, RRI.PseudoExit);
+
+ // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of
+ // each of the PHI nodes in the loop header. This feeds into the initial
+ // value of the same PHI nodes if/when we continue execution.
+ for (PHINode &PN : LS.Header->phis()) {
+ PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy",
+ BranchToContinuation);
+
+ NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader);
+ NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch),
+ RRI.ExitSelector);
+ RRI.PHIValuesAtPseudoExit.push_back(NewPHI);
+ }
+
+ // Value of the (possibly widened) induction variable on reaching the pseudo
+ // exit: the start value when the loop was skipped, the latch value otherwise.
+ RRI.IndVarEnd = PHINode::Create(IndVarBase->getType(), 2, "indvar.end",
+ BranchToContinuation);
+ RRI.IndVarEnd->addIncoming(IndVarStart, Preheader);
+ RRI.IndVarEnd->addIncoming(IndVarBase, RRI.ExitSelector);
+
+ // The latch exit now has a branch from `RRI.ExitSelector' instead of
+ // `LS.Latch'. The PHI nodes need to be updated to reflect that.
+ LS.LatchExit->replacePhiUsesWith(LS.Latch, RRI.ExitSelector);
+
+ return RRI;
+}
+
+/// Makes the header PHIs of \p LS take, on the edge from
+/// \p ContinuationBlock, the values merged at the pseudo exit recorded in
+/// \p RRI (matched up by position), and makes the induction variable of
+/// \p LS start at the value the previous loop ended with.
+void LoopConstrainer::rewriteIncomingValuesForPHIs(
+    LoopStructure &LS, BasicBlock *ContinuationBlock,
+    const LoopConstrainer::RewrittenRangeInfo &RRI) const {
+  unsigned Position = 0;
+  for (PHINode &HeaderPHI : LS.Header->phis()) {
+    // The i-th header PHI corresponds to the i-th PHI at the pseudo exit.
+    PHINode *MergedValue = RRI.PHIValuesAtPseudoExit[Position];
+    HeaderPHI.setIncomingValueForBlock(ContinuationBlock, MergedValue);
+    ++Position;
+  }
+
+  LS.IndVarStart = RRI.IndVarEnd;
+}
+
+/// Creates a block named \p Tag, placed right before the header of \p LS,
+/// that branches unconditionally to that header. Header PHI entries that
+/// referred to \p OldPreheader are redirected to the new block, which is
+/// returned.
+BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS,
+                                             BasicBlock *OldPreheader,
+                                             const char *Tag) const {
+  BasicBlock *LoopHeader = LS.Header;
+
+  BasicBlock *NewPreheader =
+      BasicBlock::Create(Ctx, Tag, &F, /*InsertBefore=*/LoopHeader);
+  BranchInst::Create(LoopHeader, NewPreheader);
+
+  LoopHeader->replacePhiUsesWith(OldPreheader, NewPreheader);
+  return NewPreheader;
+}
+
+/// Registers each block in \p BBs with the parent of the original loop. Does
+/// nothing when the original loop has no parent loop.
+void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) {
+  if (Loop *Enclosing = OriginalLoop.getParentLoop()) {
+    for (BasicBlock *NewBB : BBs)
+      Enclosing->addBasicBlockToLoop(NewBB, LI);
+  }
+}
+
+/// Builds the LoopInfo representation of a clone of \p Original. The new loop
+/// becomes a child of \p Parent, or a top-level loop when \p Parent is null,
+/// and is reported through the LPMAddNewLoop callback together with
+/// \p IsSubloop. Blocks that belong directly to \p Original are mapped through
+/// \p VM and added to the new loop; sub-loops are cloned recursively. Returns
+/// the new loop.
+Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent,
+                                                 ValueToValueMapTy &VM,
+                                                 bool IsSubloop) {
+  Loop *Cloned = LI.AllocateLoop();
+  if (!Parent)
+    LI.addTopLevelLoop(Cloned);
+  else
+    Parent->addChildLoop(Cloned);
+  LPMAddNewLoop(Cloned, IsSubloop);
+
+  // Only blocks whose innermost loop is Original are added here; blocks of
+  // nested loops are handled by the recursive calls below.
+  for (BasicBlock *OrigBB : Original->blocks()) {
+    if (LI.getLoopFor(OrigBB) != Original)
+      continue;
+    Cloned->addBasicBlockToLoop(cast<BasicBlock>(VM[OrigBB]), LI);
+  }
+
+  for (Loop *Nested : *Original)
+    createClonedLoopStructure(Nested, Cloned, VM, /* IsSubloop */ true);
+
+  return Cloned;
+}
+
+/// Splits the original loop into an optional pre-loop, the main loop and an
+/// optional post-loop, so that the main loop only iterates over the sub-range
+/// described by SR.
+///
+/// Returns false when a required exit limit cannot be computed without
+/// overflow or cannot be safely expanded in the preheader. In that case this
+/// function has not emitted any IR: both limits are computed and validated
+/// before either of them is expanded. Returns true once the loops have been
+/// rewritten.
+bool LoopConstrainer::run() {
+  BasicBlock *Preheader = OriginalLoop.getLoopPreheader();
+  assert(Preheader != nullptr && "precondition!");
+
+  OriginalPreheader = Preheader;
+  MainLoopPreheader = Preheader;
+  bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate;
+  bool Increasing = MainLoopStructure.IndVarIncreasing;
+  IntegerType *IVTy = cast<IntegerType>(RangeTy);
+
+  SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "loop-constrainer");
+  Instruction *InsertPt = OriginalPreheader->getTerminator();
+
+  // It would have been better to make `PreLoop' and `PostLoop'
+  // `std::optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy
+  // constructor.
+  ClonedLoop PreLoop, PostLoop;
+  bool NeedsPreLoop =
+      Increasing ? SR.LowLimit.has_value() : SR.HighLimit.has_value();
+  bool NeedsPostLoop =
+      Increasing ? SR.HighLimit.has_value() : SR.LowLimit.has_value();
+
+  const SCEVConstant *MinusOneS =
+      cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */));
+
+  // Step 1: compute and validate both exit limits without emitting any code.
+  // Every bail-out below therefore happens before this function has changed
+  // the IR, so a `false' result really means "nothing was changed here".
+  const SCEV *ExitPreLoopAtSCEV = nullptr;
+  if (NeedsPreLoop) {
+    if (Increasing)
+      ExitPreLoopAtSCEV = *SR.LowLimit;
+    else if (cannotBeMinInLoop(*SR.HighLimit, &OriginalLoop, SE,
+                               IsSignedPredicate))
+      ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS);
+    else {
+      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
+                        << "preloop exit limit. HighLimit = "
+                        << *(*SR.HighLimit) << "\n");
+      return false;
+    }
+
+    if (!Expander.isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt)) {
+      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
+                        << " preloop exit limit " << *ExitPreLoopAtSCEV
+                        << " at block " << InsertPt->getParent()->getName()
+                        << "\n");
+      return false;
+    }
+  }
+
+  const SCEV *ExitMainLoopAtSCEV = nullptr;
+  if (NeedsPostLoop) {
+    if (Increasing)
+      ExitMainLoopAtSCEV = *SR.HighLimit;
+    else if (cannotBeMinInLoop(*SR.LowLimit, &OriginalLoop, SE,
+                               IsSignedPredicate))
+      ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS);
+    else {
+      LLVM_DEBUG(dbgs() << "could not prove no-overflow when computing "
+                        << "mainloop exit limit. LowLimit = "
+                        << *(*SR.LowLimit) << "\n");
+      return false;
+    }
+
+    if (!Expander.isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt)) {
+      LLVM_DEBUG(dbgs() << "could not prove that it is safe to expand the"
+                        << " main loop exit limit " << *ExitMainLoopAtSCEV
+                        << " at block " << InsertPt->getParent()->getName()
+                        << "\n");
+      return false;
+    }
+  }
+
+  // Step 2: both limits are known to be expandable; materialize them in the
+  // preheader (pre-loop limit first, then main-loop limit).
+  Value *ExitPreLoopAt = nullptr;
+  Value *ExitMainLoopAt = nullptr;
+  if (NeedsPreLoop) {
+    ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt);
+    ExitPreLoopAt->setName("exit.preloop.at");
+  }
+  if (NeedsPostLoop) {
+    ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt);
+    ExitMainLoopAt->setName("exit.mainloop.at");
+  }
+
+  // We clone these ahead of time so that we don't have to deal with changing
+  // and temporarily invalid IR as we transform the loops.
+  if (NeedsPreLoop)
+    cloneLoop(PreLoop, "preloop");
+  if (NeedsPostLoop)
+    cloneLoop(PostLoop, "postloop");
+
+  RewrittenRangeInfo PreLoopRRI;
+
+  if (NeedsPreLoop) {
+    // The original preheader now leads into the pre-loop; the main loop gets
+    // a fresh preheader that the pre-loop continues into.
+    Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header,
+                                                  PreLoop.Structure.Header);
+
+    MainLoopPreheader =
+        createPreheader(MainLoopStructure, Preheader, "mainloop");
+    PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader,
+                                         ExitPreLoopAt, MainLoopPreheader);
+    rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader,
+                                 PreLoopRRI);
+  }
+
+  BasicBlock *PostLoopPreheader = nullptr;
+  RewrittenRangeInfo PostLoopRRI;
+
+  if (NeedsPostLoop) {
+    PostLoopPreheader =
+        createPreheader(PostLoop.Structure, Preheader, "postloop");
+    PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader,
+                                          ExitMainLoopAt, PostLoopPreheader);
+    rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader,
+                                 PostLoopRRI);
+  }
+
+  BasicBlock *NewMainLoopPreheader =
+      MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr;
+  BasicBlock *NewBlocks[] = {PostLoopPreheader,        PreLoopRRI.PseudoExit,
+                             PreLoopRRI.ExitSelector,  PostLoopRRI.PseudoExit,
+                             PostLoopRRI.ExitSelector, NewMainLoopPreheader};
+
+  // Some of the above may be nullptr, filter them out before passing to
+  // addToParentLoopIfNeeded.
+  auto NewBlocksEnd =
+      std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr);
+
+  addToParentLoopIfNeeded(ArrayRef(std::begin(NewBlocks), NewBlocksEnd));
+
+  DT.recalculate(F);
+
+  // We need to first add all the pre and post loop blocks into the loop
+  // structures (as part of createClonedLoopStructure), and then update the
+  // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating
+  // LI when LoopSimplifyForm is generated.
+  Loop *PreL = nullptr, *PostL = nullptr;
+  if (!PreLoop.Blocks.empty()) {
+    PreL = createClonedLoopStructure(&OriginalLoop,
+                                     OriginalLoop.getParentLoop(), PreLoop.Map,
+                                     /* IsSubLoop */ false);
+  }
+
+  if (!PostLoop.Blocks.empty()) {
+    PostL =
+        createClonedLoopStructure(&OriginalLoop, OriginalLoop.getParentLoop(),
+                                  PostLoop.Map, /* IsSubLoop */ false);
+  }
+
+  // This function canonicalizes the loop into Loop-Simplify and LCSSA forms.
+  auto CanonicalizeLoop = [&](Loop *L, bool IsOriginalLoop) {
+    formLCSSARecursively(*L, DT, &LI, &SE);
+    simplifyLoop(L, &DT, &LI, &SE, nullptr, nullptr, true);
+    // Pre/post loops are slow paths, we do not need to perform any loop
+    // optimizations on them.
+    if (!IsOriginalLoop)
+      DisableAllLoopOptsOnLoop(*L);
+  };
+  if (PreL)
+    CanonicalizeLoop(PreL, false);
+  if (PostL)
+    CanonicalizeLoop(PostL, false);
+  CanonicalizeLoop(&OriginalLoop, true);
+
+  // At this point:
+  // - We've broken a "main loop" out of the loop in a way that the "main loop"
+  //   runs with the induction variable in a subset of [Begin, End).
+  // - There is no overflow when computing "main loop" exit limit.
+  // - Max latch taken count of the loop is limited.
+  // It guarantees that induction variable will not overflow iterating in the
+  // "main loop".
+  if (isa<OverflowingBinaryOperator>(MainLoopStructure.IndVarBase))
+    if (IsSignedPredicate)
+      cast<BinaryOperator>(MainLoopStructure.IndVarBase)
+          ->setHasNoSignedWrap(true);
+  // TODO: support unsigned predicate.
+  // To add NUW flag we need to prove that both operands of BO are
+  // non-negative. E.g:
+  // ...
+  // %iv.next = add nsw i32 %iv, -1
+  // %cmp = icmp ult i32 %iv.next, %n
+  // br i1 %cmp, label %loopexit, label %loop
+  //
+  // -1 is MAX_UINT in terms of unsigned int. Adding anything but zero will
+  // overflow, therefore NUW flag is not legal here.
+
+  return true;
+}