aboutsummaryrefslogtreecommitdiff
path: root/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp')
-rw-r--r--contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp210
1 files changed, 147 insertions, 63 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 7d6662c44f07..59485126b280 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -296,7 +296,7 @@ std::optional<MDNode *> llvm::makeFollowupLoopID(
StringRef AttrName = cast<MDString>(NameMD)->getString();
// Do not inherit excluded attributes.
- return !AttrName.startswith(InheritOptionsExceptPrefix);
+ return !AttrName.starts_with(InheritOptionsExceptPrefix);
};
if (InheritThisAttribute(Op))
@@ -556,12 +556,8 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// Removes all incoming values from all other exiting blocks (including
// duplicate values from an exiting block).
// Nuke all entries except the zero'th entry which is the preheader entry.
- // NOTE! We need to remove Incoming Values in the reverse order as done
- // below, to keep the indices valid for deletion (removeIncomingValues
- // updates getNumIncomingValues and shifts all values down into the
- // operand being deleted).
- for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i)
- P.removeIncomingValue(e - i, false);
+ P.removeIncomingValueIf([](unsigned Idx) { return Idx != 0; },
+ /* DeletePHIIfEmpty */ false);
assert((P.getNumIncomingValues() == 1 &&
P.getIncomingBlock(PredIndex) == Preheader) &&
@@ -608,6 +604,7 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// Use a map to unique and a vector to guarantee deterministic ordering.
llvm::SmallDenseSet<DebugVariable, 4> DeadDebugSet;
llvm::SmallVector<DbgVariableIntrinsic *, 4> DeadDebugInst;
+ llvm::SmallVector<DPValue *, 4> DeadDPValues;
if (ExitBlock) {
// Given LCSSA form is satisfied, we should not have users of instructions
@@ -632,6 +629,24 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
"Unexpected user in reachable block");
U.set(Poison);
}
+
+ // RemoveDIs: do the same as below for DPValues.
+ if (Block->IsNewDbgInfoFormat) {
+ for (DPValue &DPV :
+ llvm::make_early_inc_range(I.getDbgValueRange())) {
+ DebugVariable Key(DPV.getVariable(), DPV.getExpression(),
+ DPV.getDebugLoc().get());
+ if (!DeadDebugSet.insert(Key).second)
+ continue;
+ // Unlinks the DPV from it's container, for later insertion.
+ DPV.removeFromParent();
+ DeadDPValues.push_back(&DPV);
+ }
+ }
+
+ // For one of each variable encountered, preserve a debug intrinsic (set
+ // to Poison) and transfer it to the loop exit. This terminates any
+ // variable locations that were set during the loop.
auto *DVI = dyn_cast<DbgVariableIntrinsic>(&I);
if (!DVI)
continue;
@@ -646,12 +661,22 @@ void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
// be be replaced with undef. Loop invariant values will still be available.
// Move dbg.values out the loop so that earlier location ranges are still
// terminated and loop invariant assignments are preserved.
- Instruction *InsertDbgValueBefore = ExitBlock->getFirstNonPHI();
- assert(InsertDbgValueBefore &&
+ DIBuilder DIB(*ExitBlock->getModule());
+ BasicBlock::iterator InsertDbgValueBefore =
+ ExitBlock->getFirstInsertionPt();
+ assert(InsertDbgValueBefore != ExitBlock->end() &&
"There should be a non-PHI instruction in exit block, else these "
"instructions will have no parent.");
+
for (auto *DVI : DeadDebugInst)
- DVI->moveBefore(InsertDbgValueBefore);
+ DVI->moveBefore(*ExitBlock, InsertDbgValueBefore);
+
+ // Due to the "head" bit in BasicBlock::iterator, we're going to insert
+ // each DPValue right at the start of the block, wheras dbg.values would be
+ // repeatedly inserted before the first instruction. To replicate this
+ // behaviour, do it backwards.
+ for (DPValue *DPV : llvm::reverse(DeadDPValues))
+ ExitBlock->insertDPValueBefore(DPV, InsertDbgValueBefore);
}
// Remove the block from the reference counting scheme, so that we can
@@ -937,8 +962,8 @@ CmpInst::Predicate llvm::getMinMaxReductionPredicate(RecurKind RK) {
}
}
-Value *llvm::createSelectCmpOp(IRBuilderBase &Builder, Value *StartVal,
- RecurKind RK, Value *Left, Value *Right) {
+Value *llvm::createAnyOfOp(IRBuilderBase &Builder, Value *StartVal,
+ RecurKind RK, Value *Left, Value *Right) {
if (auto VTy = dyn_cast<VectorType>(Left->getType()))
StartVal = Builder.CreateVectorSplat(VTy->getElementCount(), StartVal);
Value *Cmp =
@@ -1028,14 +1053,12 @@ Value *llvm::getShuffleReduction(IRBuilderBase &Builder, Value *Src,
return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
}
-Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
- const TargetTransformInfo *TTI,
- Value *Src,
- const RecurrenceDescriptor &Desc,
- PHINode *OrigPhi) {
- assert(RecurrenceDescriptor::isSelectCmpRecurrenceKind(
- Desc.getRecurrenceKind()) &&
- "Unexpected reduction kind");
+Value *llvm::createAnyOfTargetReduction(IRBuilderBase &Builder, Value *Src,
+ const RecurrenceDescriptor &Desc,
+ PHINode *OrigPhi) {
+ assert(
+ RecurrenceDescriptor::isAnyOfRecurrenceKind(Desc.getRecurrenceKind()) &&
+ "Unexpected reduction kind");
Value *InitVal = Desc.getRecurrenceStartValue();
Value *NewVal = nullptr;
@@ -1068,9 +1091,8 @@ Value *llvm::createSelectCmpTargetReduction(IRBuilderBase &Builder,
return Builder.CreateSelect(Cmp, NewVal, InitVal, "rdx.select");
}
-Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
- const TargetTransformInfo *TTI,
- Value *Src, RecurKind RdxKind) {
+Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder, Value *Src,
+ RecurKind RdxKind) {
auto *SrcVecEltTy = cast<VectorType>(Src->getType())->getElementType();
switch (RdxKind) {
case RecurKind::Add:
@@ -1111,7 +1133,6 @@ Value *llvm::createSimpleTargetReduction(IRBuilderBase &Builder,
}
Value *llvm::createTargetReduction(IRBuilderBase &B,
- const TargetTransformInfo *TTI,
const RecurrenceDescriptor &Desc, Value *Src,
PHINode *OrigPhi) {
// TODO: Support in-order reductions based on the recurrence descriptor.
@@ -1121,10 +1142,10 @@ Value *llvm::createTargetReduction(IRBuilderBase &B,
B.setFastMathFlags(Desc.getFastMathFlags());
RecurKind RK = Desc.getRecurrenceKind();
- if (RecurrenceDescriptor::isSelectCmpRecurrenceKind(RK))
- return createSelectCmpTargetReduction(B, TTI, Src, Desc, OrigPhi);
+ if (RecurrenceDescriptor::isAnyOfRecurrenceKind(RK))
+ return createAnyOfTargetReduction(B, Src, Desc, OrigPhi);
- return createSimpleTargetReduction(B, TTI, Src, RK);
+ return createSimpleTargetReduction(B, Src, RK);
}
Value *llvm::createOrderedReduction(IRBuilderBase &B,
@@ -1453,7 +1474,7 @@ int llvm::rewriteLoopExitValues(Loop *L, LoopInfo *LI, TargetLibraryInfo *TLI,
// Note that we must not perform expansions until after
// we query *all* the costs, because if we perform temporary expansion
// inbetween, one that we might not intend to keep, said expansion
- // *may* affect cost calculation of the the next SCEV's we'll query,
+ // *may* affect cost calculation of the next SCEV's we'll query,
// and next SCEV may errneously get smaller cost.
// Collect all the candidate PHINodes to be rewritten.
@@ -1632,42 +1653,92 @@ Loop *llvm::cloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM,
struct PointerBounds {
TrackingVH<Value> Start;
TrackingVH<Value> End;
+ Value *StrideToCheck;
};
/// Expand code for the lower and upper bound of the pointer group \p CG
/// in \p TheLoop. \return the values for the bounds.
static PointerBounds expandBounds(const RuntimeCheckingPtrGroup *CG,
Loop *TheLoop, Instruction *Loc,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
LLVMContext &Ctx = Loc->getContext();
- Type *PtrArithTy = Type::getInt8PtrTy(Ctx, CG->AddressSpace);
+ Type *PtrArithTy = PointerType::get(Ctx, CG->AddressSpace);
Value *Start = nullptr, *End = nullptr;
LLVM_DEBUG(dbgs() << "LAA: Adding RT check for range:\n");
- Start = Exp.expandCodeFor(CG->Low, PtrArithTy, Loc);
- End = Exp.expandCodeFor(CG->High, PtrArithTy, Loc);
+ const SCEV *Low = CG->Low, *High = CG->High, *Stride = nullptr;
+
+ // If the Low and High values are themselves loop-variant, then we may want
+ // to expand the range to include those covered by the outer loop as well.
+ // There is a trade-off here with the advantage being that creating checks
+ // using the expanded range permits the runtime memory checks to be hoisted
+ // out of the outer loop. This reduces the cost of entering the inner loop,
+ // which can be significant for low trip counts. The disadvantage is that
+ // there is a chance we may now never enter the vectorized inner loop,
+ // whereas using a restricted range check could have allowed us to enter at
+ // least once. This is why the behaviour is not currently the default and is
+ // controlled by the parameter 'HoistRuntimeChecks'.
+ if (HoistRuntimeChecks && TheLoop->getParentLoop() &&
+ isa<SCEVAddRecExpr>(High) && isa<SCEVAddRecExpr>(Low)) {
+ auto *HighAR = cast<SCEVAddRecExpr>(High);
+ auto *LowAR = cast<SCEVAddRecExpr>(Low);
+ const Loop *OuterLoop = TheLoop->getParentLoop();
+ const SCEV *Recur = LowAR->getStepRecurrence(*Exp.getSE());
+ if (Recur == HighAR->getStepRecurrence(*Exp.getSE()) &&
+ HighAR->getLoop() == OuterLoop && LowAR->getLoop() == OuterLoop) {
+ BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+ const SCEV *OuterExitCount =
+ Exp.getSE()->getExitCount(OuterLoop, OuterLoopLatch);
+ if (!isa<SCEVCouldNotCompute>(OuterExitCount) &&
+ OuterExitCount->getType()->isIntegerTy()) {
+ const SCEV *NewHigh = cast<SCEVAddRecExpr>(High)->evaluateAtIteration(
+ OuterExitCount, *Exp.getSE());
+ if (!isa<SCEVCouldNotCompute>(NewHigh)) {
+ LLVM_DEBUG(dbgs() << "LAA: Expanded RT check for range to include "
+ "outer loop in order to permit hoisting\n");
+ High = NewHigh;
+ Low = cast<SCEVAddRecExpr>(Low)->getStart();
+ // If there is a possibility that the stride is negative then we have
+ // to generate extra checks to ensure the stride is positive.
+ if (!Exp.getSE()->isKnownNonNegative(Recur)) {
+ Stride = Recur;
+ LLVM_DEBUG(dbgs() << "LAA: ... but need to check stride is "
+ "positive: "
+ << *Stride << '\n');
+ }
+ }
+ }
+ }
+ }
+
+ Start = Exp.expandCodeFor(Low, PtrArithTy, Loc);
+ End = Exp.expandCodeFor(High, PtrArithTy, Loc);
if (CG->NeedsFreeze) {
IRBuilder<> Builder(Loc);
Start = Builder.CreateFreeze(Start, Start->getName() + ".fr");
End = Builder.CreateFreeze(End, End->getName() + ".fr");
}
- LLVM_DEBUG(dbgs() << "Start: " << *CG->Low << " End: " << *CG->High << "\n");
- return {Start, End};
+ Value *StrideVal =
+ Stride ? Exp.expandCodeFor(Stride, Stride->getType(), Loc) : nullptr;
+ LLVM_DEBUG(dbgs() << "Start: " << *Low << " End: " << *High << "\n");
+ return {Start, End, StrideVal};
}
/// Turns a collection of checks into a collection of expanded upper and
/// lower bounds for both pointers in the check.
static SmallVector<std::pair<PointerBounds, PointerBounds>, 4>
expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
- Instruction *Loc, SCEVExpander &Exp) {
+ Instruction *Loc, SCEVExpander &Exp, bool HoistRuntimeChecks) {
SmallVector<std::pair<PointerBounds, PointerBounds>, 4> ChecksWithBounds;
// Here we're relying on the SCEV Expander's cache to only emit code for the
// same bounds once.
transform(PointerChecks, std::back_inserter(ChecksWithBounds),
[&](const RuntimePointerCheck &Check) {
- PointerBounds First = expandBounds(Check.first, L, Loc, Exp),
- Second = expandBounds(Check.second, L, Loc, Exp);
+ PointerBounds First = expandBounds(Check.first, L, Loc, Exp,
+ HoistRuntimeChecks),
+ Second = expandBounds(Check.second, L, Loc, Exp,
+ HoistRuntimeChecks);
return std::make_pair(First, Second);
});
@@ -1677,10 +1748,11 @@ expandBounds(const SmallVectorImpl<RuntimePointerCheck> &PointerChecks, Loop *L,
Value *llvm::addRuntimeChecks(
Instruction *Loc, Loop *TheLoop,
const SmallVectorImpl<RuntimePointerCheck> &PointerChecks,
- SCEVExpander &Exp) {
+ SCEVExpander &Exp, bool HoistRuntimeChecks) {
// TODO: Move noalias annotation code from LoopVersioning here and share with LV if possible.
// TODO: Pass RtPtrChecking instead of PointerChecks and SE separately, if possible
- auto ExpandedChecks = expandBounds(PointerChecks, TheLoop, Loc, Exp);
+ auto ExpandedChecks =
+ expandBounds(PointerChecks, TheLoop, Loc, Exp, HoistRuntimeChecks);
LLVMContext &Ctx = Loc->getContext();
IRBuilder<InstSimplifyFolder> ChkBuilder(Ctx,
@@ -1693,21 +1765,13 @@ Value *llvm::addRuntimeChecks(
const PointerBounds &A = Check.first, &B = Check.second;
// Check if two pointers (A and B) conflict where conflict is computed as:
// start(A) <= end(B) && start(B) <= end(A)
- unsigned AS0 = A.Start->getType()->getPointerAddressSpace();
- unsigned AS1 = B.Start->getType()->getPointerAddressSpace();
- assert((AS0 == B.End->getType()->getPointerAddressSpace()) &&
- (AS1 == A.End->getType()->getPointerAddressSpace()) &&
+ assert((A.Start->getType()->getPointerAddressSpace() ==
+ B.End->getType()->getPointerAddressSpace()) &&
+ (B.Start->getType()->getPointerAddressSpace() ==
+ A.End->getType()->getPointerAddressSpace()) &&
"Trying to bounds check pointers with different address spaces");
- Type *PtrArithTy0 = Type::getInt8PtrTy(Ctx, AS0);
- Type *PtrArithTy1 = Type::getInt8PtrTy(Ctx, AS1);
-
- Value *Start0 = ChkBuilder.CreateBitCast(A.Start, PtrArithTy0, "bc");
- Value *Start1 = ChkBuilder.CreateBitCast(B.Start, PtrArithTy1, "bc");
- Value *End0 = ChkBuilder.CreateBitCast(A.End, PtrArithTy1, "bc");
- Value *End1 = ChkBuilder.CreateBitCast(B.End, PtrArithTy0, "bc");
-
// [A|B].Start points to the first accessed byte under base [A|B].
// [A|B].End points to the last accessed byte, plus one.
// There is no conflict when the intervals are disjoint:
@@ -1716,9 +1780,21 @@ Value *llvm::addRuntimeChecks(
// bound0 = (B.Start < A.End)
// bound1 = (A.Start < B.End)
// IsConflict = bound0 & bound1
- Value *Cmp0 = ChkBuilder.CreateICmpULT(Start0, End1, "bound0");
- Value *Cmp1 = ChkBuilder.CreateICmpULT(Start1, End0, "bound1");
+ Value *Cmp0 = ChkBuilder.CreateICmpULT(A.Start, B.End, "bound0");
+ Value *Cmp1 = ChkBuilder.CreateICmpULT(B.Start, A.End, "bound1");
Value *IsConflict = ChkBuilder.CreateAnd(Cmp0, Cmp1, "found.conflict");
+ if (A.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ A.StrideToCheck, ConstantInt::get(A.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
+ if (B.StrideToCheck) {
+ Value *IsNegativeStride = ChkBuilder.CreateICmpSLT(
+ B.StrideToCheck, ConstantInt::get(B.StrideToCheck->getType(), 0),
+ "stride.check");
+ IsConflict = ChkBuilder.CreateOr(IsConflict, IsNegativeStride);
+ }
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");
@@ -1740,23 +1816,31 @@ Value *llvm::addDiffRuntimeChecks(
// Our instructions might fold to a constant.
Value *MemoryRuntimeCheck = nullptr;
+ auto &SE = *Expander.getSE();
+ // Map to keep track of created compares, The key is the pair of operands for
+ // the compare, to allow detecting and re-using redundant compares.
+ DenseMap<std::pair<Value *, Value *>, Value *> SeenCompares;
for (const auto &C : Checks) {
Type *Ty = C.SinkStart->getType();
// Compute VF * IC * AccessSize.
auto *VFTimesUFTimesSize =
ChkBuilder.CreateMul(GetVF(ChkBuilder, Ty->getScalarSizeInBits()),
ConstantInt::get(Ty, IC * C.AccessSize));
- Value *Sink = Expander.expandCodeFor(C.SinkStart, Ty, Loc);
- Value *Src = Expander.expandCodeFor(C.SrcStart, Ty, Loc);
- if (C.NeedsFreeze) {
- IRBuilder<> Builder(Loc);
- Sink = Builder.CreateFreeze(Sink, Sink->getName() + ".fr");
- Src = Builder.CreateFreeze(Src, Src->getName() + ".fr");
- }
- Value *Diff = ChkBuilder.CreateSub(Sink, Src);
- Value *IsConflict =
- ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
+ Value *Diff = Expander.expandCodeFor(
+ SE.getMinusSCEV(C.SinkStart, C.SrcStart), Ty, Loc);
+
+ // Check if the same compare has already been created earlier. In that case,
+ // there is no need to check it again.
+ Value *IsConflict = SeenCompares.lookup({Diff, VFTimesUFTimesSize});
+ if (IsConflict)
+ continue;
+ IsConflict =
+ ChkBuilder.CreateICmpULT(Diff, VFTimesUFTimesSize, "diff.check");
+ SeenCompares.insert({{Diff, VFTimesUFTimesSize}, IsConflict});
+ if (C.NeedsFreeze)
+ IsConflict =
+ ChkBuilder.CreateFreeze(IsConflict, IsConflict->getName() + ".fr");
if (MemoryRuntimeCheck) {
IsConflict =
ChkBuilder.CreateOr(MemoryRuntimeCheck, IsConflict, "conflict.rdx");