aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--llvm/lib/Transforms/Utils/LoopUtils.cpp313
1 files changed, 145 insertions, 168 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index e4d78f9ada08..f0f079335683 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -612,10 +612,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
for (auto *Block : L->blocks())
for (Instruction &I : *Block) {
auto *Undef = UndefValue::get(I.getType());
- for (Value::use_iterator UI = I.use_begin(), E = I.use_end();
- UI != E;) {
- Use &U = *UI;
- ++UI;
+ for (Use &U : llvm::make_early_inc_range(I.uses())) {
if (auto *Usr = dyn_cast<Instruction>(U.getUser()))
if (L->contains(Usr->getParent()))
continue;
@@ -710,21 +707,58 @@ void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE,
SE.forgetLoop(L);
- // Note: By splitting the backedge, and then explicitly making it unreachable
- // we gracefully handle corner cases such as non-bottom tested loops and the
- // like. We also have the benefit of being able to reuse existing well tested
- // code. It might be worth special casing the common bottom tested case at
- // some point to avoid code churn.
-
std::unique_ptr<MemorySSAUpdater> MSSAU;
if (MSSA)
MSSAU = std::make_unique<MemorySSAUpdater>(MSSA);
- auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get());
+ // Update the CFG and domtree. We chose to special case a couple of
+ // of common cases for code quality and test readability reasons.
+ [&]() -> void {
+ if (auto *BI = dyn_cast<BranchInst>(Latch->getTerminator())) {
+ if (!BI->isConditional()) {
+ DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
+ (void)changeToUnreachable(BI, /*PreserveLCSSA*/ true, &DTU,
+ MSSAU.get());
+ return;
+ }
+
+ // Conditional latch/exit - note that latch can be shared by inner
+ // and outer loop so the other target doesn't need to an exit
+ if (L->isLoopExiting(Latch)) {
+ // TODO: Generalize ConstantFoldTerminator so that it can be used
+ // here without invalidating LCSSA or MemorySSA. (Tricky case for
+ // LCSSA: header is an exit block of a preceeding sibling loop w/o
+ // dedicated exits.)
+ const unsigned ExitIdx = L->contains(BI->getSuccessor(0)) ? 1 : 0;
+ BasicBlock *ExitBB = BI->getSuccessor(ExitIdx);
+
+ DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
+ Header->removePredecessor(Latch, true);
+
+ IRBuilder<> Builder(BI);
+ auto *NewBI = Builder.CreateBr(ExitBB);
+ // Transfer the metadata to the new branch instruction (minus the
+ // loop info since this is no longer a loop)
+ NewBI->copyMetadata(*BI, {LLVMContext::MD_dbg,
+ LLVMContext::MD_annotation});
+
+ BI->eraseFromParent();
+ DTU.applyUpdates({{DominatorTree::Delete, Latch, Header}});
+ if (MSSA)
+ MSSAU->applyUpdates({{DominatorTree::Delete, Latch, Header}}, DT);
+ return;
+ }
+ }
- DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
- (void)changeToUnreachable(BackedgeBB->getTerminator(),
- /*PreserveLCSSA*/ true, &DTU, MSSAU.get());
+ // General case. By splitting the backedge, and then explicitly making it
+ // unreachable we gracefully handle corner cases such as switch and invoke
+ // termiantors.
+ auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get());
+
+ DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager);
+ (void)changeToUnreachable(BackedgeBB->getTerminator(),
+ /*PreserveLCSSA*/ true, &DTU, MSSAU.get());
+ }();
// Erase (and destroy) this loop instance. Handles relinking sub-loops
// and blocks within the loop as needed.
@@ -852,32 +886,37 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
return true;
}
-Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
- Value *Right) {
- CmpInst::Predicate Pred;
+CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
switch (RK) {
default:
llvm_unreachable("Unknown min/max recurrence kind");
case RecurKind::UMin:
- Pred = CmpInst::ICMP_ULT;
- break;
+ return CmpInst::ICMP_ULT;
case RecurKind::UMax:
- Pred = CmpInst::ICMP_UGT;
- break;
+ return CmpInst::ICMP_UGT;
case RecurKind::SMin:
- Pred = CmpInst::ICMP_SLT;
- break;
+ return CmpInst::ICMP_SLT;
case RecurKind::SMax:
- Pred = CmpInst::ICMP_SGT;
- break;
+ return CmpInst::ICMP_SGT;
case RecurKind::FMin:
- Pred = CmpInst::FCMP_OLT;
- break;
+ return CmpInst::FCMP_OLT;
case RecurKind::FMax:
- Pred = CmpInst::FCMP_OGT;
- break;
+ return CmpInst::FCMP_OGT;
}
+}
+Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
+ RecurKind RK, Value *Left, Value *Right) {
+ if (auto VTy = dyn_cast<VectorType>(Left->getType()))
+ StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal);
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp");
+ return Builder.CreateSelect(Cmp, Left, Right, "rdx.select");
+}
+
+Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left,
+ Value *Right) {
+ CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK);
Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp");
Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
return Select;
@@ -955,15 +994,50 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}
+Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
+ const TargetTransformInfo *TTI,
+ Value *Src,
+ const RecurrenceDescriptor &Desc,
+ PHINode *OrigPhi) {
+ assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind(
+ Desc.getRecurrenceKind()) &&
+ "Unexpected reduction kind");
+ Value *InitVal = Desc.getRecurrenceStartValue();
+ Value *NewVal = nullptr;
+
+ // First use the original phi to determine the new value we're trying to
+ // select from in the loop.
+ SelectInst *SI = nullptr;
+ for (auto *U : OrigPhi->users()) {
+ if ((SI = dyn_cast<SelectInst>(U)))
+ break;
+ }
+ assert(SI && "One user of the original phi should be a select");
+
+ if (SI->getTrueValue() == OrigPhi)
+ NewVal = SI->getFalseValue();
+ else {
+ assert(SI->getFalseValue() == OrigPhi &&
+ "At least one input to the select should be the original Phi");
+ NewVal = SI->getTrueValue();
+ }
+
+ // Create a splat vector with the new value and compare this to the vector
+ // we want to reduce.
+ ElementCount EC = cast<VectorType>(Src->getType())->getElementCount();
+ Value *Right = Builder.CreateVectorSplat(EC, InitVal);
+ Value *Cmp =
+ Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp");
+
+ // If any predicate is true it means that we want to select the new value.
+ Cmp = Builder.CreateOrReduce(Cmp);
+ return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
+}
+
Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
const TargetTransformInfo *TTI,
Value *Src, RecurKind RdxKind,
ArrayRef<Value *> RedOps) {
- TargetTransformInfo::ReductionFlags RdxFlags;
- RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax ||
- RdxKind == RecurKind::FMax;
- RdxFlags.IsSigned = RdxKind == RecurKind::SMax || RdxKind == RecurKind::SMin;
-
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
switch (RdxKind) {
case RecurKind::Add:
@@ -1000,14 +1074,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
Value *llvm::createTargetReduction(IRBuilderBase &B,
const TargetTransformInfo *TTI,
- const RecurrenceDescriptor &Desc,
- Value *Src) {
+ const RecurrenceDescriptor &Desc, Value *Src,
+ PHINode *OrigPhi) {
// TODO: Support in-order reductions based on the recurrence descriptor.
// All ops in the reduction inherit fast-math-flags from the recurrence
// descriptor.
IRBuilderBase::FastMathFlagGuard FMFGuard(B);
B.setFastMathFlags(Desc.getFastMathFlags());
- return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind());
+
+ RecurKind RK = Desc.getRecurrenceKind();
+ if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
+ return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi);
+
+ return createSimpleTargetReduction(B, TTI, Src, RK);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,
@@ -1081,58 +1160,6 @@ bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE,
// As a side effect, reduces the amount of IV processing within the loop.
//===----------------------------------------------------------------------===//
-// Return true if the SCEV expansion generated by the rewriter can replace the
-// original value. SCEV guarantees that it produces the same value, but the way
-// it is produced may be illegal IR. Ideally, this function will only be
-// called for verification.
-static bool isValidRewrite(ScalarEvolution *SE, Value *FromVal, Value *ToVal) {
- // If an SCEV expression subsumed multiple pointers, its expansion could
- // reassociate the GEP changing the base pointer. This is illegal because the
- // final address produced by a GEP chain must be inbounds relative to its
- // underlying object. Otherwise basic alias analysis, among other things,
- // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid
- // producing an expression involving multiple pointers. Until then, we must
- // bail out here.
- //
- // Retrieve the pointer operand of the GEP. Don't use getUnderlyingObject
- // because it understands lcssa phis while SCEV does not.
- Value *FromPtr = FromVal;
- Value *ToPtr = ToVal;
- if (auto *GEP = dyn_cast<GEPOperator>(FromVal))
- FromPtr = GEP->getPointerOperand();
-
- if (auto *GEP = dyn_cast<GEPOperator>(ToVal))
- ToPtr = GEP->getPointerOperand();
-
- if (FromPtr != FromVal || ToPtr != ToVal) {
- // Quickly check the common case
- if (FromPtr == ToPtr)
- return true;
-
- // SCEV may have rewritten an expression that produces the GEP's pointer
- // operand. That's ok as long as the pointer operand has the same base
- // pointer. Unlike getUnderlyingObject(), getPointerBase() will find the
- // base of a recurrence. This handles the case in which SCEV expansion
- // converts a pointer type recurrence into a nonrecurrent pointer base
- // indexed by an integer recurrence.
-
- // If the GEP base pointer is a vector of pointers, abort.
- if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy())
- return false;
-
- const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr));
- const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr));
- if (FromBase == ToBase)
- return true;
-
- LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: GEP rewrite bail out "
- << *FromBase << " != " << *ToBase << "\n");
-
- return false;
- }
- return true;
-}
-
static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) {
SmallPtrSet<const Instruction *, 8> Visited;
SmallVector<const Instruction *, 8> WorkList;
@@ -1165,9 +1192,6 @@ struct RewritePhi {
Instruction *ExpansionPoint; // Where we'd like to expand that SCEV?
bool HighCost; // Is this expansion a high-cost?
- Value *Expansion = nullptr;
- bool ValidRewrite = false;
-
RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt,
bool H)
: PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt),
@@ -1204,8 +1228,6 @@ static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet)
// phase later. Skip it in the loop invariant check below.
bool found = false;
for (const RewritePhi &Phi : RewritePhiSet) {
- if (!Phi.ValidRewrite)
- continue;
unsigned i = Phi.Ith;
if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
found = true;
@@ -1264,13 +1286,6 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
if (!SE->isSCEVable(PN->getType()))
continue;
- // It's necessary to tell ScalarEvolution about this explicitly so that
- // it can walk the def-use list and forget all SCEVs, as it may not be
- // watching the PHI itself. Once the new exit value is in place, there
- // may not be a def-use connection between the loop and every instruction
- // which got a SCEVAddRecExpr for that loop.
- SE->forgetValue(PN);
-
// Iterate over all of the values in all the PHI nodes.
for (unsigned i = 0; i != NumPreds; ++i) {
// If the value being merged in is not integer or is not defined
@@ -1339,61 +1354,49 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
}
}
- // Now that we've done preliminary filtering and billed all the SCEV's,
- // we can perform the last sanity check - the expansion must be valid.
- for (RewritePhi &Phi : RewritePhiSet) {
- Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(),
- Phi.ExpansionPoint);
+ // TODO: evaluate whether it is beneficial to change how we calculate
+ // high-cost: if we have SCEV 'A' which we know we will expand, should we
+ // calculate the cost of other SCEV's after expanding SCEV 'A', thus
+ // potentially giving cost bonus to those other SCEV's?
- LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = "
- << *(Phi.Expansion) << '\n'
- << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");
+ bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
+ int NumReplaced = 0;
+
+ // Transformation.
+ for (const RewritePhi &Phi : RewritePhiSet) {
+ PHINode *PN = Phi.PN;
- // FIXME: isValidRewrite() is a hack. it should be an assert, eventually.
- Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion);
- if (!Phi.ValidRewrite) {
- DeadInsts.push_back(Phi.Expansion);
+ // Only do the rewrite when the ExitValue can be expanded cheaply.
+ // If LoopCanBeDel is true, rewrite exit value aggressively.
+ if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost)
continue;
- }
+
+ Value *ExitVal = Rewriter.expandCodeFor(
+ Phi.ExpansionSCEV, Phi.PN->getType(), Phi.ExpansionPoint);
+
+ LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal
+ << '\n'
+ << " LoopVal = " << *(Phi.ExpansionPoint) << "\n");
#ifndef NDEBUG
// If we reuse an instruction from a loop which is neither L nor one of
// its containing loops, we end up breaking LCSSA form for this loop by
// creating a new use of its instruction.
- if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion))
+ if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal))
if (auto *EVL = LI->getLoopFor(ExitInsn->getParent()))
if (EVL != L)
assert(EVL->contains(L) && "LCSSA breach detected!");
#endif
- }
-
- // TODO: after isValidRewrite() is an assertion, evaluate whether
- // it is beneficial to change how we calculate high-cost:
- // if we have SCEV 'A' which we know we will expand, should we calculate
- // the cost of other SCEV's after expanding SCEV 'A',
- // thus potentially giving cost bonus to those other SCEV's?
-
- bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
- int NumReplaced = 0;
-
- // Transformation.
- for (const RewritePhi &Phi : RewritePhiSet) {
- if (!Phi.ValidRewrite)
- continue;
-
- PHINode *PN = Phi.PN;
- Value *ExitVal = Phi.Expansion;
-
- // Only do the rewrite when the ExitValue can be expanded cheaply.
- // If LoopCanBeDel is true, rewrite exit value aggressively.
- if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) {
- DeadInsts.push_back(ExitVal);
- continue;
- }
NumReplaced++;
Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith));
PN->setIncomingValue(Phi.Ith, ExitVal);
+ // It's necessary to tell ScalarEvolution about this explicitly so that
+ // it can walk the def-use list and forget all SCEVs, as it may not be
+ // watching the PHI itself. Once the new exit value is in place, there
+ // may not be a def-use connection between the loop and every instruction
+ // which got a SCEVAddRecExpr for that loop.
+ SE->forgetValue(PN);
// If this instruction is dead now, delete it. Don't do it now to avoid
// invalidating iterators.
@@ -1554,7 +1557,7 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
return ChecksWithBounds;
}
-std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
+Value *llvm::addRuntimeChecks(
Instruction *Loc, Loop *TheLoop,
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
SCEVExpander &Exp) {
@@ -1563,22 +1566,10 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp);
LLVMContext &Ctx = Loc->getContext();
- Instruction *FirstInst = nullptr;
IRBuilder<> ChkBuilder(Loc);
// Our instructions might fold to a constant.
Value *MemoryRuntimeCheck = nullptr;
- // FIXME: this helper is currently a duplicate of the one in
- // LoopVectorize.cpp.
- auto GetFirstInst = [](Instruction *FirstInst, Value *V,
- Instruction *Loc) -> Instruction * {
- if (FirstInst)
- return FirstInst;
- if (Instruction *I = dyn_cast<Instruction>(V))
- return I->getParent() == Loc->getParent() ? I : nullptr;
- return nullptr;
- };
-
for (const auto &Check : ExpandedChecks) {
const PointerBounds &A = Check.first, &B = Check.second;
// Check if two pointers (A and B) conflict where conflict is computed as:
@@ -1607,30 +1598,16 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks(
// bound1 = (A.Start < B.End)
// IsConflict = bound0 & bound1
Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
- FirstInst = GetFirstInst(FirstInst, Cmp0, Loc);
Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
- FirstInst = GetFirstInst(FirstInst, Cmp1, Loc);
Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
- FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
- FirstInst = GetFirstInst(FirstInst, IsConflict, Loc);
}
MemoryRuntimeCheck = IsConflict;
}
- if (!MemoryRuntimeCheck)
- return std::make_pair(nullptr, nullptr);
-
- // We have to do this trickery because the IRBuilder might fold the check to a
- // constant expression in which case there is no Instruction anchored in a
- // the block.
- Instruction *Check =
- BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx));
- ChkBuilder.Insert(Check, "memcheck.conflict");
- FirstInst = GetFirstInst(FirstInst, Check, Loc);
- return std::make_pair(FirstInst, Check);
+ return MemoryRuntimeCheck;
}
Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L,