diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp | 210 |
1 files changed, 147 insertions, 63 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp index 7d6662c44f07..59485126b280 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp @@ -296,7 +296,7 @@ std::optional<MDNode *> llvm::makeFollowupLoopID( StringRef AttrName = cast<MDString>(NameMD)->getString(); // Do not inherit excluded attributes. - return !AttrName.startswith(InheritOptionsExceptPrefix); + return !AttrName.starts_with(InheritOptionsExceptPrefix); }; if (InheritThisAttribute(Op)) @@ -556,12 +556,8 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Removes all incoming values from all other exiting blocks (including // duplicate values from an exiting block). // Nuke all entries except the zero'th entry which is the preheader entry. - // NOTE! We need to remove Incoming Values in the reverse order as done - // below, to keep the indices valid for deletion (removeIncomingValues - // updates getNumIncomingValues and shifts all values down into the - // operand being deleted). - for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i) - P.removeIncomingValue(e - i, false); + P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; }, + /* DeletePHIIfEmpty */ false); assert((P.getNumIncomingValues() == 1 && P.getIncomingBlock(PredIndex) == Preheader) && @@ -608,6 +604,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // Use a map to unique and a vector to guarantee deterministic ordering. llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet; llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst; + llvm::SmallVector<DPValue *, 4> DeadDPValues; if (ExitBlock) { // Given LCSSA form is satisfied, we should not have users of instructions @@ -632,6 +629,24 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, "Unexpected user in reachable block"); U.set(Poison); } + + // RemoveDIs: do the same as below for DPValues. + if (Block->IsNewDbgInfoFormat) { + for (DPValue &DPV : + llvm::make_early_inc_range(I.getDbgValueRange())) { + DebugVariable Key(DPV.getVariable(), DPV.getExpression(), + DPV.getDebugLoc().get()); + if (!DeadDebugSet.insert(Key).second) + continue; + // Unlinks the DPV from it's container, for later insertion. + DPV.removeFromParent(); + DeadDPValues.push_back(&DPV); + } + } + + // For one of each variable encountered, preserve a debug intrinsic (set + // to Poison) and transfer it to the loop exit. This terminates any + // variable locations that were set during the loop. auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I); if (!DVI) continue; @@ -646,12 +661,22 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE, // be be replaced with undef. Loop invariant values will still be available. // Move dbg.values out the loop so that earlier location ranges are still // terminated and loop invariant assignments are preserved. - Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI(); - assert(InsertDbgValueBefore && + DIBuilder DIB(*ExitBlock->getModule()); + BasicBlock::iterator InsertDbgValueBefore = + ExitBlock->getFirstInsertionPt(); + assert(InsertDbgValueBefore != ExitBlock->end() && "There should be a non-PHI instruction in exit block, else these " "instructions will have no parent."); + for (auto *DVI : DeadDebugInst) - DVI->moveBefore(InsertDbgValueBefore); + DVI->moveBefore(*ExitBlock, InsertDbgValueBefore); + + // Due to the "head" bit in BasicBlock::iterator, we're going to insert + // each DPValue right at the start of the block, wheras dbg.values would be + // repeatedly inserted before the first instruction. To replicate this + // behaviour, do it backwards. + for (DPValue *DPV : llvm::reverse(DeadDPValues)) + ExitBlock->insertDPValueBefore(DPV, InsertDbgValueBefore); } // Remove the block from the reference counting scheme, so that we can @@ -937,8 +962,8 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) { } } -Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal, - RecurKind RK, Value *Left, Value *Right) { +Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal, + RecurKind RK, Value *Left, Value *Right) { if (auto VTy = dyn_cast<VectorType>(Left->getType())) StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal); Value *Cmp = @@ -1028,14 +1053,12 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src, return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } -Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, - const TargetTransformInfo *TTI, - Value *Src, - const RecurrenceDescriptor &Desc, - PHINode *OrigPhi) { - assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind( - Desc.getRecurrenceKind()) && - "Unexpected reduction kind"); +Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src, + const RecurrenceDescriptor &Desc, + PHINode *OrigPhi) { + assert( + RecurrenceDescriptor::isAnyOfRecurrenceKind(Desc.getRecurrenceKind()) && + "Unexpected reduction kind"); Value *InitVal = Desc.getRecurrenceStartValue(); Value *NewVal = nullptr; @@ -1068,9 +1091,8 @@ Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder, return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select"); } -Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, - const TargetTransformInfo *TTI, - Value *Src, RecurKind RdxKind) { +Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src, + RecurKind RdxKind) { auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType(); switch (RdxKind) { case RecurKind::Add: @@ -1111,7 +1133,6 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, } Value *llvm::createTargetReduction(IRBuilderBase &B, - const TargetTransformInfo *TTI, const RecurrenceDescriptor &Desc, Value *Src, PHINode *OrigPhi) { // TODO: Support in-order reductions based on the recurrence descriptor. @@ -1121,10 +1142,10 @@ Value *llvm::createTargetReduction(IRBuilderBase &B, B.setFastMathFlags(Desc.getFastMathFlags()); RecurKind RK = Desc.getRecurrenceKind(); - if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK)) - return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi); + if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK)) + return createAnyOfTargetReduction(B, Src, Desc, OrigPhi); - return createSimpleTargetReduction(B, TTI, Src, RK); + return createSimpleTargetReduction(B, Src, RK); } Value *llvm::createOrderedReduction(IRBuilderBase &B, @@ -1453,7 +1474,7 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI, // Note that we must not perform expansions until after // we query *all* the costs, because if we perform temporary expansion // inbetween, one that we might not intend to keep, said expansion - // *may* affect cost calculation of the the next SCEV's we'll query, + // *may* affect cost calculation of the next SCEV's we'll query, // and next SCEV may errneously get smaller cost. // Collect all the candidate PHINodes to be rewritten. @@ -1632,42 +1653,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, struct PointerBounds { TrackingVH<Value> Start; TrackingVH<Value> End; + Value *StrideToCheck; }; /// Expand code for the lower and upper bound of the pointer group \p CG /// in \p TheLoop. \return the values for the bounds. static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG, Loop *TheLoop, Instruction *Loc, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { LLVMContext &Ctx = Loc->getContext(); - Type *PtrArithTy = Type::getInt8PtrTy(Ctx, CG->AddressSpace); + Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace); Value *Start = nullptr, *End = nullptr; LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n"); - Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc); - End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc); + const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr; + + // If the Low and High values are themselves loop-variant, then we may want + // to expand the range to include those covered by the outer loop as well. + // There is a trade-off here with the advantage being that creating checks + // using the expanded range permits the runtime memory checks to be hoisted + // out of the outer loop. This reduces the cost of entering the inner loop, + // which can be significant for low trip counts. The disadvantage is that + // there is a chance we may now never enter the vectorized inner loop, + // whereas using a restricted range check could have allowed us to enter at + // least once. This is why the behaviour is not currently the default and is + // controlled by the parameter 'HoistRuntimeChecks'. + if (HoistRuntimeChecks && TheLoop->getParentLoop() && + isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) { + auto *HighAR = cast<SCEVAddRecExpr>(High); + auto *LowAR = cast<SCEVAddRecExpr>(Low); + const Loop *OuterLoop = TheLoop->getParentLoop(); + const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE()); + if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) && + HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) { + BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); + const SCEV *OuterExitCount = + Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch); + if (!isa<SCEVCouldNotCompute>(OuterExitCount) && + OuterExitCount->getType()->isIntegerTy()) { + const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration( + OuterExitCount, *Exp.getSE()); + if (!isa<SCEVCouldNotCompute>(NewHigh)) { + LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include " + "outer loop in order to permit hoisting\n"); + High = NewHigh; + Low = cast<SCEVAddRecExpr>(Low)->getStart(); + // If there is a possibility that the stride is negative then we have + // to generate extra checks to ensure the stride is positive. + if (!Exp.getSE()->isKnownNonNegative(Recur)) { + Stride = Recur; + LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is " + "positive: " + << *Stride << '\n'); + } + } + } + } + } + + Start = Exp.expandCodeFor(Low, PtrArithTy, Loc); + End = Exp.expandCodeFor(High, PtrArithTy, Loc); if (CG->NeedsFreeze) { IRBuilder<> Builder(Loc); Start = Builder.CreateFreeze(Start, Start->getName() + ".fr"); End = Builder.CreateFreeze(End, End->getName() + ".fr"); } - LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n"); - return {Start, End}; + Value *StrideVal = + Stride ? Exp.expandCodeFor(Stride, Stride->getType(), Loc) : nullptr; + LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n"); + return {Start, End, StrideVal}; } /// Turns a collection of checks into a collection of expanded upper and /// lower bounds for both pointers in the check. static SmallVector<std::pair<PointerBounds, PointerBounds>, 4> expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, - Instruction *Loc, SCEVExpander &Exp) { + Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) { SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds; // Here we're relying on the SCEV Expander's cache to only emit code for the // same bounds once. transform(PointerChecks, std::back_inserter(ChecksWithBounds), [&](const RuntimePointerCheck &Check) { - PointerBounds First = expandBounds(Check.first, L, Loc, Exp), - Second = expandBounds(Check.second, L, Loc, Exp); + PointerBounds First = expandBounds(Check.first, L, Loc, Exp, + HoistRuntimeChecks), + Second = expandBounds(Check.second, L, Loc, Exp, + HoistRuntimeChecks); return std::make_pair(First, Second); }); @@ -1677,10 +1748,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L, Value *llvm::addRuntimeChecks( Instruction *Loc, Loop *TheLoop, const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, - SCEVExpander &Exp) { + SCEVExpander &Exp, bool HoistRuntimeChecks) { // TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible. // TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible - auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp); + auto ExpandedChecks = + expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks); LLVMContext &Ctx = Loc->getContext(); IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx, @@ -1693,21 +1765,13 @@ Value *llvm::addRuntimeChecks( const PointerBounds &A = Check.first, &B = Check.second; // Check if two pointers (A and B) conflict where conflict is computed as: // start(A) <= end(B) && start(B) <= end(A) - unsigned AS0 = A.Start->getType()->getPointerAddressSpace(); - unsigned AS1 = B.Start->getType()->getPointerAddressSpace(); - assert((AS0 == B.End->getType()->getPointerAddressSpace()) && - (AS1 == A.End->getType()->getPointerAddressSpace()) && + assert((A.Start->getType()->getPointerAddressSpace() == + B.End->getType()->getPointerAddressSpace()) && + (B.Start->getType()->getPointerAddressSpace() == + A.End->getType()->getPointerAddressSpace()) && "Trying to bounds check pointers with different address spaces"); - Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0); - Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1); - - Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc"); - Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc"); - Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc"); - Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc"); - // [A|B].Start points to the first accessed byte under base [A|B]. // [A|B].End points to the last accessed byte, plus one. // There is no conflict when the intervals are disjoint: @@ -1716,9 +1780,21 @@ Value *llvm::addRuntimeChecks( // bound0 = (B.Start < A.End) // bound1 = (A.Start < B.End) // IsConflict = bound0 & bound1 - Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0"); - Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1"); + Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0"); + Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1"); Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict"); + if (A.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } + if (B.StrideToCheck) { + Value *IsNegativeStride = ChkBuilder.CreateICmpSLT( + B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0), + "stride.check"); + IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride); + } if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); @@ -1740,23 +1816,31 @@ Value *llvm::addDiffRuntimeChecks( // Our instructions might fold to a constant. Value *MemoryRuntimeCheck = nullptr; + auto &SE = *Expander.getSE(); + // Map to keep track of created compares, The key is the pair of operands for + // the compare, to allow detecting and re-using redundant compares. + DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares; for (const auto &C : Checks) { Type *Ty = C.SinkStart->getType(); // Compute VF * IC * AccessSize. auto *VFTimesUFTimesSize = ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()), ConstantInt::get(Ty, IC * C.AccessSize)); - Value *Sink = Expander.expandCodeFor(C.SinkStart, Ty, Loc); - Value *Src = Expander.expandCodeFor(C.SrcStart, Ty, Loc); - if (C.NeedsFreeze) { - IRBuilder<> Builder(Loc); - Sink = Builder.CreateFreeze(Sink, Sink->getName() + ".fr"); - Src = Builder.CreateFreeze(Src, Src->getName() + ".fr"); - } - Value *Diff = ChkBuilder.CreateSub(Sink, Src); - Value *IsConflict = - ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); + Value *Diff = Expander.expandCodeFor( + SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc); + + // Check if the same compare has already been created earlier. In that case, + // there is no need to check it again. + Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize}); + if (IsConflict) + continue; + IsConflict = + ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check"); + SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict}); + if (C.NeedsFreeze) + IsConflict = + ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr"); if (MemoryRuntimeCheck) { IsConflict = ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx"); |
