diff options
author | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
commit | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch) | |
tree | f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/Utils/LoopUtils.cpp | |
parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) |
Diffstat (limited to 'llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r-- | llvm/lib/Transforms/Utils/LoopUtils.cpp | 313 |
1 files changed, 145 insertions, 168 deletions
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp index e4d78f9ada08..f0f079335683 100644 --- a/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -612,10 +612,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, for (auto *Block : L->blocks()) for (Instruction &I : *Block) { auto *Undef = UndefValue::get(I.getType()); - for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); - UI != E;) { - Use &U = *UI; - ++UI; + for (Use &U : llvm::make_early_inc_range(I.uses())) { if (auto *Usr = dyn_cast<Instruction>(U.getUser())) if (L->contains(Usr->getParent())) continue; @@ -710,21 +707,58 @@ void llvm::breakLoopBackedge(Loop *L, DominatorTree &DT, ScalarEvolution &SE, SE.forgetLoop(L); - // Note: By splitting the backedge, and then explicitly making it unreachable - // we gracefully handle corner cases such as non-bottom tested loops and the - // like. We also have the benefit of being able to reuse existing well tested - // code. It might be worth special casing the common bottom tested case at - // some point to avoid code churn. - std::unique_ptr<MemorySSAUpdater> MSSAU; if (MSSA) MSSAU = std::make_unique<MemorySSAUpdater>(MSSA); - auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get()); + // Update the CFG and domtree. We chose to special case a couple of + // of common cases for code quality and test readability reasons. + [&]() -> void { + if (auto *BI = dyn_cast<BranchInst>(Latch->getTerminator())) { + if (!BI->isConditional()) { + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + (void)changeToUnreachable(BI, /*PreserveLCSSA*/ true, &DTU, + MSSAU.get()); + return; + } + + // Conditional latch/exit - note that latch can be shared by inner + // and outer loop so the other target doesn't need to an exit + if (L->isLoopExiting(Latch)) { + // TODO: Generalize ConstantFoldTerminator so that it can be used + // here without invalidating LCSSA or MemorySSA. (Tricky case for + // LCSSA: header is an exit block of a preceeding sibling loop w/o + // dedicated exits.) + const unsigned ExitIdx = L->contains(BI->getSuccessor(0)) ? 1 : 0; + BasicBlock *ExitBB = BI->getSuccessor(ExitIdx); + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + Header->removePredecessor(Latch, true); + + IRBuilder<> Builder(BI); + auto *NewBI = Builder.CreateBr(ExitBB); + // Transfer the metadata to the new branch instruction (minus the + // loop info since this is no longer a loop) + NewBI->copyMetadata(*BI, {LLVMContext::MD_dbg, + LLVMContext::MD_annotation}); + + BI->eraseFromParent(); + DTU.applyUpdates({{DominatorTree::Delete, Latch, Header}}); + if (MSSA) + MSSAU->applyUpdates({{DominatorTree::Delete, Latch, Header}}, DT); + return; + } + } - DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); - (void)changeToUnreachable(BackedgeBB->getTerminator(), - /*PreserveLCSSA*/ true, &DTU, MSSAU.get()); + // General case. By splitting the backedge, and then explicitly making it + // unreachable we gracefully handle corner cases such as switch and invoke + // termiantors. + auto *BackedgeBB = SplitEdge(Latch, Header, &DT, &LI, MSSAU.get()); + + DomTreeUpdater DTU(&DT, DomTreeUpdater::UpdateStrategy::Eager); + (void)changeToUnreachable(BackedgeBB->getTerminator(), + /*PreserveLCSSA*/ true, &DTU, MSSAU.get()); + }(); // Erase (and destroy) this loop instance. Handles relinking sub-loops // and blocks within the loop as needed. @@ -852,32 +886,37 @@ bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop, return true; } -Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, - Value *Right) { - CmpInst::Predicate Pred; +CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { switch (RK) { default: llvm_unreachable("Unknown min/max recurrence kind"); case RecurKind::UMin: - Pred = CmpInst::ICMP_ULT; - break; + return CmpInst::ICMP_ULT; case RecurKind::UMax: - Pred = CmpInst::ICMP_UGT; - break; + return CmpInst::ICMP_UGT; case RecurKind::SMin: - Pred = CmpInst::ICMP_SLT; - break; + return CmpInst::ICMP_SLT; case RecurKind::SMax: - Pred = CmpInst::ICMP_SGT; - break; + return CmpInst::ICMP_SGT; case RecurKind::FMin: - Pred = CmpInst::FCMP_OLT; - break; + return CmpInst::FCMP_OLT; case RecurKind::FMax: - Pred = CmpInst::FCMP_OGT; - break; + return CmpInst::FCMP_OGT; } +} +Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, + RecurKind RK, Value *Left, Value *Right) { + if (auto VTy = dyn_cast<VectorType>(Left->getType())) + StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Left, StartVal, "rdx.select.cmp"); + return Builder.CreateSelect(Cmp, Left, Right, "rdx.select"); +} + +Value *llvm::createMinMaxOp(IRBuilderBase &Builder, RecurKind RK, Value *Left, + Value *Right) { + CmpInst::Predicate Pred = getMinMaxReductionPredicate(RK); Value *Cmp = Builder.CreateCmp(Pred, Left, Right, "rdx.minmax.cmp"); Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select"); return Select; @@ -955,15 +994,50 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } +Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, + const TargetTransformInfo *TTI, + Value *Src, + const RecurrenceDescriptor &Desc, + PHINode *OrigPhi) { + assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind( + Desc.getRecurrenceKind()) && + "Unexpected reduction kind"); + Value *InitVal = Desc.getRecurrenceStartValue(); + Value *NewVal = nullptr; + + // First use the original phi to determine the new value we're trying to + // select from in the loop. + SelectInst *SI = nullptr; + for (auto *U : OrigPhi->users()) { + if ((SI = dyn_cast<SelectInst>(U))) + break; + } + assert(SI && "One user of the original phi should be a select"); + + if (SI->getTrueValue() == OrigPhi) + NewVal = SI->getFalseValue(); + else { + assert(SI->getFalseValue() == OrigPhi && + "At least one input to the select should be the original Phi"); + NewVal = SI->getTrueValue(); + } + + // Create a splat vector with the new value and compare this to the vector + // we want to reduce. + ElementCount EC = cast<VectorType>(Src->getType())->getElementCount(); + Value *Right = Builder.CreateVectorSplat(EC, InitVal); + Value *Cmp = + Builder.CreateCmp(CmpInst::ICMP_NE, Src, Right, "rdx.select.cmp"); + + // If any predicate is true it means that we want to select the new value. + Cmp = Builder.CreateOrReduce(Cmp); + return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); +} + Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, const TargetTransformInfo *TTI, Value *Src, RecurKind RdxKind, ArrayRef<Value *> RedOps) { - TargetTransformInfo::ReductionFlags RdxFlags; - RdxFlags.IsMaxOp = RdxKind == RecurKind::SMax || RdxKind == RecurKind::UMax || - RdxKind == RecurKind::FMax; - RdxFlags.IsSigned = RdxKind == RecurKind::SMax || RdxKind == RecurKind::SMin; - auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); switch (RdxKind) { case RecurKind::Add: @@ -1000,14 +1074,19 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *llvm::createTargetReduction(IRBuilderBase &B, const TargetTransformInfo *TTI, - const RecurrenceDescriptor &Desc, - Value *Src) { + const RecurrenceDescriptor &Desc, Value *Src, + PHINode *OrigPhi) { // TODO: Support in-order reductions based on the recurrence descriptor. // All ops in the reduction inherit fast-math-flags from the recurrence // descriptor. IRBuilderBase::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(Desc.getFastMathFlags()); - return createSimpleTargetReduction(B, TTI, Src, Desc.getRecurrenceKind()); + + RecurKind RK = Desc.getRecurrenceKind(); + if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) + return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi); + + return createSimpleTargetReduction(B, TTI, Src, RK); } Value *llvm::createOrderedReduction(IRBuilderBase &B, @@ -1081,58 +1160,6 @@ bool llvm::cannotBeMaxInLoop(const SCEV *S, const Loop *L, ScalarEvolution &SE, // As a side effect, reduces the amount of IV processing within the loop. //===----------------------------------------------------------------------===// -// Return true if the SCEV expansion generated by the rewriter can replace the -// original value. SCEV guarantees that it produces the same value, but the way -// it is produced may be illegal IR. Ideally, this function will only be -// called for verification. -static bool isValidRewrite(ScalarEvolution *SE, Value *FromVal, Value *ToVal) { - // If an SCEV expression subsumed multiple pointers, its expansion could - // reassociate the GEP changing the base pointer. This is illegal because the - // final address produced by a GEP chain must be inbounds relative to its - // underlying object. Otherwise basic alias analysis, among other things, - // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid - // producing an expression involving multiple pointers. Until then, we must - // bail out here. - // - // Retrieve the pointer operand of the GEP. Don't use getUnderlyingObject - // because it understands lcssa phis while SCEV does not. - Value *FromPtr = FromVal; - Value *ToPtr = ToVal; - if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) - FromPtr = GEP->getPointerOperand(); - - if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) - ToPtr = GEP->getPointerOperand(); - - if (FromPtr != FromVal || ToPtr != ToVal) { - // Quickly check the common case - if (FromPtr == ToPtr) - return true; - - // SCEV may have rewritten an expression that produces the GEP's pointer - // operand. That's ok as long as the pointer operand has the same base - // pointer. Unlike getUnderlyingObject(), getPointerBase() will find the - // base of a recurrence. This handles the case in which SCEV expansion - // converts a pointer type recurrence into a nonrecurrent pointer base - // indexed by an integer recurrence. - - // If the GEP base pointer is a vector of pointers, abort. - if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy()) - return false; - - const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr)); - const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr)); - if (FromBase == ToBase) - return true; - - LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: GEP rewrite bail out " - << *FromBase << " != " << *ToBase << "\n"); - - return false; - } - return true; -} - static bool hasHardUserWithinLoop(const Loop *L, const Instruction *I) { SmallPtrSet<const Instruction *, 8> Visited; SmallVector<const Instruction *, 8> WorkList; @@ -1165,9 +1192,6 @@ struct RewritePhi { Instruction *ExpansionPoint; // Where we'd like to expand that SCEV? bool HighCost; // Is this expansion a high-cost? - Value *Expansion = nullptr; - bool ValidRewrite = false; - RewritePhi(PHINode *P, unsigned I, const SCEV *Val, Instruction *ExpansionPt, bool H) : PN(P), Ith(I), ExpansionSCEV(Val), ExpansionPoint(ExpansionPt), @@ -1204,8 +1228,6 @@ static bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) // phase later. Skip it in the loop invariant check below. bool found = false; for (const RewritePhi &Phi : RewritePhiSet) { - if (!Phi.ValidRewrite) - continue; unsigned i = Phi.Ith; if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) { found = true; @@ -1264,13 +1286,6 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, if (!SE->isSCEVable(PN->getType())) continue; - // It's necessary to tell ScalarEvolution about this explicitly so that - // it can walk the def-use list and forget all SCEVs, as it may not be - // watching the PHI itself. Once the new exit value is in place, there - // may not be a def-use connection between the loop and every instruction - // which got a SCEVAddRecExpr for that loop. - SE->forgetValue(PN); - // Iterate over all of the values in all the PHI nodes. for (unsigned i = 0; i != NumPreds; ++i) { // If the value being merged in is not integer or is not defined @@ -1339,61 +1354,49 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, } } - // Now that we've done preliminary filtering and billed all the SCEV's, - // we can perform the last sanity check - the expansion must be valid. - for (RewritePhi &Phi : RewritePhiSet) { - Phi.Expansion = Rewriter.expandCodeFor(Phi.ExpansionSCEV, Phi.PN->getType(), - Phi.ExpansionPoint); + // TODO: evaluate whether it is beneficial to change how we calculate + // high-cost: if we have SCEV 'A' which we know we will expand, should we + // calculate the cost of other SCEV's after expanding SCEV 'A', thus + // potentially giving cost bonus to those other SCEV's? - LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " - << *(Phi.Expansion) << '\n' - << " LoopVal = " << *(Phi.ExpansionPoint) << "\n"); + bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); + int NumReplaced = 0; + + // Transformation. + for (const RewritePhi &Phi : RewritePhiSet) { + PHINode *PN = Phi.PN; - // FIXME: isValidRewrite() is a hack. it should be an assert, eventually. - Phi.ValidRewrite = isValidRewrite(SE, Phi.ExpansionPoint, Phi.Expansion); - if (!Phi.ValidRewrite) { - DeadInsts.push_back(Phi.Expansion); + // Only do the rewrite when the ExitValue can be expanded cheaply. + // If LoopCanBeDel is true, rewrite exit value aggressively. + if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) continue; - } + + Value *ExitVal = Rewriter.expandCodeFor( + Phi.ExpansionSCEV, Phi.PN->getType(), Phi.ExpansionPoint); + + LLVM_DEBUG(dbgs() << "rewriteLoopExitValues: AfterLoopVal = " << *ExitVal + << '\n' + << " LoopVal = " << *(Phi.ExpansionPoint) << "\n"); #ifndef NDEBUG // If we reuse an instruction from a loop which is neither L nor one of // its containing loops, we end up breaking LCSSA form for this loop by // creating a new use of its instruction. - if (auto *ExitInsn = dyn_cast<Instruction>(Phi.Expansion)) + if (auto *ExitInsn = dyn_cast<Instruction>(ExitVal)) if (auto *EVL = LI->getLoopFor(ExitInsn->getParent())) if (EVL != L) assert(EVL->contains(L) && "LCSSA breach detected!"); #endif - } - - // TODO: after isValidRewrite() is an assertion, evaluate whether - // it is beneficial to change how we calculate high-cost: - // if we have SCEV 'A' which we know we will expand, should we calculate - // the cost of other SCEV's after expanding SCEV 'A', - // thus potentially giving cost bonus to those other SCEV's? - - bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet); - int NumReplaced = 0; - - // Transformation. - for (const RewritePhi &Phi : RewritePhiSet) { - if (!Phi.ValidRewrite) - continue; - - PHINode *PN = Phi.PN; - Value *ExitVal = Phi.Expansion; - - // Only do the rewrite when the ExitValue can be expanded cheaply. - // If LoopCanBeDel is true, rewrite exit value aggressively. - if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) { - DeadInsts.push_back(ExitVal); - continue; - } NumReplaced++; Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith)); PN->setIncomingValue(Phi.Ith, ExitVal); + // It's necessary to tell ScalarEvolution about this explicitly so that + // it can walk the def-use list and forget all SCEVs, as it may not be + // watching the PHI itself. Once the new exit value is in place, there + // may not be a def-use connection between the loop and every instruction + // which got a SCEVAddRecExpr for that loop. + SE->forgetValue(PN); // If this instruction is dead now, delete it. Don't do it now to avoid // invalidating iterators. @@ -1554,7 +1557,7 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, return ChecksWithBounds; } -std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( +Value *llvm::addRuntimeChecks( Instruction *Loc, Loop *TheLoop, const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, SCEVExpander &Exp) { @@ -1563,22 +1566,10 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp); LLVMContext &Ctx = Loc->getContext(); - Instruction *FirstInst = nullptr; IRBuilder<> ChkBuilder(Loc); // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; - // FIXME: this helper is currently a duplicate of the one in - // LoopVectorize.cpp. - auto GetFirstInst = [](Instruction *FirstInst, Value *V, - Instruction *Loc) -> Instruction * { - if (FirstInst) - return FirstInst; - if (Instruction *I = dyn_cast<Instruction>(V)) - return I->getParent() == Loc->getParent() ? I : nullptr; - return nullptr; - }; - for (const auto &Check : ExpandedChecks) { const PointerBounds &A = Check.first, &B = Check.second; // Check if two pointers (A and B) conflict where conflict is computed as: @@ -1607,30 +1598,16 @@ std::pair<Instruction *, Instruction *> llvm::addRuntimeChecks( // bound1 = (A.Start < B.End) // IsConflict = bound0 & bound1 Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); - FirstInst = GetFirstInst(FirstInst, Cmp0, Loc); Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); - FirstInst = GetFirstInst(FirstInst, Cmp1, Loc); Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); - FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); - FirstInst = GetFirstInst(FirstInst, IsConflict, Loc); } MemoryRuntimeCheck = IsConflict; } - if (!MemoryRuntimeCheck) - return std::make_pair(nullptr, nullptr); - - // We have to do this trickery because the IRBuilder might fold the check to a - // constant expression in which case there is no Instruction anchored in a - // the block. - Instruction *Check = - BinaryOperator::CreateAnd(MemoryRuntimeCheck, ConstantInt::getTrue(Ctx)); - ChkBuilder.Insert(Check, "memcheck.conflict"); - FirstInst = GetFirstInst(FirstInst, Check, Loc); - return std::make_pair(FirstInst, Check); + return MemoryRuntimeCheck; } Optional<IVConditionInfo> llvm::hasPartialIVCondition(Loop &L, |