Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp')
-rw-r--r--  llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 390
1 file changed, 263 insertions(+), 127 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
index 4c89f947d7fc..a4369b83e732 100644
--- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -799,7 +799,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
/// value, and mutate S to point to a new SCEV with that value excluded.
static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (C->getAPInt().getMinSignedBits() <= 64) {
+ if (C->getAPInt().getSignificantBits() <= 64) {
S = SE.getConstant(C->getType(), 0);
return C->getValue()->getSExtValue();
}
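getSignificantBits() is the current name for what getMinSignedBits() returned: the smallest bit width that can hold the constant as a signed value. A minimal, hand-written sketch of the semantics ExtractImmediate relies on (assuming the llvm/ADT/APInt.h API; not part of the patch):

    #include "llvm/ADT/APInt.h"
    #include <cassert>
    #include <cstdint>

    void significantBitsSketch() {
      // -1 sign-extends to all ones; one bit is enough to represent it signed.
      assert(llvm::APInt(128, (uint64_t)-1, /*isSigned=*/true).getSignificantBits() == 1);
      // 42 needs 6 magnitude bits plus a sign bit.
      assert(llvm::APInt(128, 42).getSignificantBits() == 7);
      // Anything wider than 64 significant bits cannot be read back with
      // getSExtValue(), which is why ExtractImmediate bails out above 64.
      assert(llvm::APInt::getSignedMaxValue(128).getSignificantBits() == 128);
    }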
@@ -896,9 +896,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI,
/// Return the type of the memory being accessed.
static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
Instruction *Inst, Value *OperandVal) {
- MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace);
+ MemAccessTy AccessTy = MemAccessTy::getUnknown(Inst->getContext());
+
+ // First get the type of memory being accessed.
+ if (Type *Ty = Inst->getAccessType())
+ AccessTy.MemTy = Ty;
+
+ // Then get the pointer address space.
if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
- AccessTy.MemTy = SI->getOperand(0)->getType();
AccessTy.AddrSpace = SI->getPointerAddressSpace();
} else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
AccessTy.AddrSpace = LI->getPointerAddressSpace();
@@ -923,7 +928,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI,
II->getArgOperand(0)->getType()->getPointerAddressSpace();
break;
case Intrinsic::masked_store:
- AccessTy.MemTy = II->getOperand(0)->getType();
AccessTy.AddrSpace =
II->getArgOperand(1)->getType()->getPointerAddressSpace();
break;
@@ -976,6 +980,7 @@ static bool isHighCostExpansion(const SCEV *S,
switch (S->getSCEVType()) {
case scUnknown:
case scConstant:
+ case scVScale:
return false;
case scTruncate:
return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(),
@@ -1414,7 +1419,7 @@ void Cost::RateFormula(const Formula &F,
C.ImmCost += 64; // Handle symbolic values conservatively.
// TODO: This should probably be the pointer size.
else if (Offset != 0)
- C.ImmCost += APInt(64, Offset, true).getMinSignedBits();
+ C.ImmCost += APInt(64, Offset, true).getSignificantBits();
// Check with target if this offset with this instruction is
// specifically not supported.
@@ -2498,7 +2503,7 @@ LSRInstance::OptimizeLoopTermCond() {
if (C->isOne() || C->isMinusOne())
goto decline_post_inc;
// Avoid weird situations.
- if (C->getValue().getMinSignedBits() >= 64 ||
+ if (C->getValue().getSignificantBits() >= 64 ||
C->getValue().isMinSignedValue())
goto decline_post_inc;
// Check for possible scaled-address reuse.
@@ -2508,13 +2513,13 @@ LSRInstance::OptimizeLoopTermCond() {
int64_t Scale = C->getSExtValue();
if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr,
/*BaseOffset=*/0,
- /*HasBaseReg=*/false, Scale,
+ /*HasBaseReg=*/true, Scale,
AccessTy.AddrSpace))
goto decline_post_inc;
Scale = -Scale;
if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr,
/*BaseOffset=*/0,
- /*HasBaseReg=*/false, Scale,
+ /*HasBaseReg=*/true, Scale,
AccessTy.AddrSpace))
goto decline_post_inc;
}
@@ -2660,8 +2665,7 @@ LSRUse *
LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF,
const LSRUse &OrigLU) {
// Search all uses for the formula. This could be more clever.
- for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
- LSRUse &LU = Uses[LUIdx];
+ for (LSRUse &LU : Uses) {
// Check whether this use is close enough to OrigLU, to see whether it's
// worthwhile looking through its formulae.
// Ignore ICmpZero uses because they may contain formulae generated by
@@ -2703,6 +2707,8 @@ void LSRInstance::CollectInterestingTypesAndFactors() {
SmallVector<const SCEV *, 4> Worklist;
for (const IVStrideUse &U : IU) {
const SCEV *Expr = IU.getExpr(U);
+ if (!Expr)
+ continue;
// Collect interesting types.
Types.insert(SE.getEffectiveSCEVType(Expr->getType()));
@@ -2740,13 +2746,13 @@ void LSRInstance::CollectInterestingTypesAndFactors() {
if (const SCEVConstant *Factor =
dyn_cast_or_null<SCEVConstant>(getExactSDiv(NewStride, OldStride,
SE, true))) {
- if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero())
+ if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero())
Factors.insert(Factor->getAPInt().getSExtValue());
} else if (const SCEVConstant *Factor =
dyn_cast_or_null<SCEVConstant>(getExactSDiv(OldStride,
NewStride,
SE, true))) {
- if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero())
+ if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero())
Factors.insert(Factor->getAPInt().getSExtValue());
}
}
@@ -2812,9 +2818,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) {
/// SCEVUnknown, we simply return the rightmost SCEV operand.
static const SCEV *getExprBase(const SCEV *S) {
switch (S->getSCEVType()) {
- default: // uncluding scUnknown.
+ default: // including scUnknown.
return S;
case scConstant:
+ case scVScale:
return nullptr;
case scTruncate:
return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand());
@@ -3175,7 +3182,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
return false;
- if (IncConst->getAPInt().getMinSignedBits() > 64)
+ if (IncConst->getAPInt().getSignificantBits() > 64)
return false;
MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand);
@@ -3320,6 +3327,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
}
const SCEV *S = IU.getExpr(U);
+ if (!S)
+ continue;
PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops();
// Equality (== and !=) ICmps are special. We can rewrite (i == N) as
@@ -3352,6 +3361,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
// S is normalized, so normalize N before folding it into S
// to keep the result normalized.
N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);
+ if (!N)
+ continue;
Kind = LSRUse::ICmpZero;
S = SE.getMinusSCEV(N, S);
} else if (L->isLoopInvariant(NV) &&
@@ -3366,6 +3377,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() {
// SCEV can't compute the difference of two unknown pointers.
N = SE.getUnknown(NV);
N = normalizeForPostIncUse(N, TmpPostIncLoops, SE);
+ if (!N)
+ continue;
Kind = LSRUse::ICmpZero;
S = SE.getMinusSCEV(N, S);
assert(!isa<SCEVCouldNotCompute>(S));
@@ -3494,8 +3507,8 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() {
if (const Instruction *Inst = dyn_cast<Instruction>(V)) {
// Look for instructions defined outside the loop.
if (L->contains(Inst)) continue;
- } else if (isa<UndefValue>(V))
- // Undef doesn't have a live range, so it doesn't matter.
+ } else if (isa<Constant>(V))
+ // Constants can be re-materialized.
continue;
for (const Use &U : V->uses()) {
const Instruction *UserInst = dyn_cast<Instruction>(U.getUser());
@@ -4137,6 +4150,29 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {
}
}
+/// Extend/Truncate \p Expr to \p ToTy considering post-inc uses in \p Loops.
+/// For all PostIncLoopSets in \p Loops, first de-normalize \p Expr, then
+/// perform the extension/truncation and normalize again, as the normalized form
+/// can result in folds that are not valid in the post-inc use contexts. The
+/// expressions for all PostIncLoopSets must match; otherwise return nullptr.
+static const SCEV *
+getAnyExtendConsideringPostIncUses(ArrayRef<PostIncLoopSet> Loops,
+ const SCEV *Expr, Type *ToTy,
+ ScalarEvolution &SE) {
+ const SCEV *Result = nullptr;
+ for (auto &L : Loops) {
+ auto *DenormExpr = denormalizeForPostIncUse(Expr, L, SE);
+ const SCEV *NewDenormExpr = SE.getAnyExtendExpr(DenormExpr, ToTy);
+ const SCEV *New = normalizeForPostIncUse(NewDenormExpr, L, SE);
+ if (!New || (Result && New != Result))
+ return nullptr;
+ Result = New;
+ }
+
+ assert(Result && "failed to create expression");
+ return Result;
+}
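The reason for the denormalize/extend/normalize round trip: for a post-increment use, the value consumed is the incremented one, and in a narrow type that increment may wrap, so any-extending the normalized (pre-increment) form does not commute with accounting for the increment afterwards. A tiny arithmetic analogy, hand-written and independent of the LSR data structures:

    #include <cassert>
    #include <cstdint>

    int main() {
      uint8_t PreInc = 255;                           // normalized (pre-increment) value
      uint8_t PostInc = uint8_t(PreInc + 1);          // denormalized value wraps to 0
      uint16_t ExtendThenInc = uint16_t(PreInc) + 1;  // extend first, then add the step: 256
      uint16_t IncThenExtend = uint16_t(PostInc);     // add the step first, then extend: 0
      assert(ExtendThenInc != IncThenExtend);         // the two orders disagree once wrapping occurs
      return 0;
    }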
+
/// Generate reuse formulae from different IV types.
void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
// Don't bother truncating symbolic values.
@@ -4156,6 +4192,10 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
[](const SCEV *S) { return S->getType()->isPointerTy(); }))
return;
+ SmallVector<PostIncLoopSet> Loops;
+ for (auto &LF : LU.Fixups)
+ Loops.push_back(LF.PostIncLoops);
+
for (Type *SrcTy : Types) {
if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) {
Formula F = Base;
@@ -4165,15 +4205,17 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
// initial node (maybe due to depth limitations), but it can do them while
// taking ext.
if (F.ScaledReg) {
- const SCEV *NewScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy);
- if (NewScaledReg->isZero())
- continue;
+ const SCEV *NewScaledReg =
+ getAnyExtendConsideringPostIncUses(Loops, F.ScaledReg, SrcTy, SE);
+ if (!NewScaledReg || NewScaledReg->isZero())
+ continue;
F.ScaledReg = NewScaledReg;
}
bool HasZeroBaseReg = false;
for (const SCEV *&BaseReg : F.BaseRegs) {
- const SCEV *NewBaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy);
- if (NewBaseReg->isZero()) {
+ const SCEV *NewBaseReg =
+ getAnyExtendConsideringPostIncUses(Loops, BaseReg, SrcTy, SE);
+ if (!NewBaseReg || NewBaseReg->isZero()) {
HasZeroBaseReg = true;
break;
}
@@ -4379,8 +4421,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() {
if ((C->getAPInt() + NewF.BaseOffset)
.abs()
.slt(std::abs(NewF.BaseOffset)) &&
- (C->getAPInt() + NewF.BaseOffset).countTrailingZeros() >=
- countTrailingZeros<uint64_t>(NewF.BaseOffset))
+ (C->getAPInt() + NewF.BaseOffset).countr_zero() >=
+ (unsigned)llvm::countr_zero<uint64_t>(NewF.BaseOffset))
goto skip_formula;
// Ok, looks good.
@@ -4982,6 +5024,32 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() {
LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs()));
}
+// Check if Best and Reg are SCEVs separated by a constant amount C, and if so
+// whether the addressing offset +C would be legal where the negative offset -C
+// is not.
+static bool IsSimplerBaseSCEVForTarget(const TargetTransformInfo &TTI,
+ ScalarEvolution &SE, const SCEV *Best,
+ const SCEV *Reg,
+ MemAccessTy AccessType) {
+ if (Best->getType() != Reg->getType() ||
+ (isa<SCEVAddRecExpr>(Best) && isa<SCEVAddRecExpr>(Reg) &&
+ cast<SCEVAddRecExpr>(Best)->getLoop() !=
+ cast<SCEVAddRecExpr>(Reg)->getLoop()))
+ return false;
+ const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Best, Reg));
+ if (!Diff)
+ return false;
+
+ return TTI.isLegalAddressingMode(
+ AccessType.MemTy, /*BaseGV=*/nullptr,
+ /*BaseOffset=*/Diff->getAPInt().getSExtValue(),
+ /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace) &&
+ !TTI.isLegalAddressingMode(
+ AccessType.MemTy, /*BaseGV=*/nullptr,
+ /*BaseOffset=*/-Diff->getAPInt().getSExtValue(),
+ /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace);
+}
+
/// Pick a register which seems likely to be profitable, and then in any use
/// which has any reference to that register, delete all formulae which do not
/// reference that register.
@@ -5010,6 +5078,19 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() {
Best = Reg;
BestNum = Count;
}
+
+ // If the scores are the same, but the Reg is simpler for the target
+ // (for example {x,+,1} as opposed to {x+C,+,1}, where the target can
+ // handle +C but not -C), opt for the simpler formula.
+ if (Count == BestNum) {
+ int LUIdx = RegUses.getUsedByIndices(Reg).find_first();
+ if (LUIdx >= 0 && Uses[LUIdx].Kind == LSRUse::Address &&
+ IsSimplerBaseSCEVForTarget(TTI, SE, Best, Reg,
+ Uses[LUIdx].AccessTy)) {
+ Best = Reg;
+ BestNum = Count;
+ }
+ }
}
}
assert(Best && "Failed to find best LSRUse candidate");
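The comment above about {x,+,1} versus {x+C,+,1} corresponds to source patterns like the following (illustrative C++; the function name is hypothetical and the concrete offsets and their legality depend on the target's addressing modes):

    #include <cstddef>

    // LSR sees two base-register candidates that differ by a constant: &a[i]
    // and &a[i + 4]. With &a[i] as the base, both loads use offsets 0 and +16
    // bytes; with &a[i + 4] as the base, one load needs offset -16. On a target
    // where the positive offset is legal but the negative one is not, the new
    // tie-break picks &a[i] even though both registers score the same use count.
    long sumPairs(const int *a, size_t n) {   // caller provides n + 4 elements
      long Sum = 0;
      for (size_t i = 0; i != n; ++i)
        Sum += a[i] + a[i + 4];
      return Sum;
    }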
@@ -5497,6 +5578,13 @@ void LSRInstance::RewriteForPHI(
PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F,
SmallVectorImpl<WeakTrackingVH> &DeadInsts) const {
DenseMap<BasicBlock *, Value *> Inserted;
+
+ // Inserting instructions in the loop and using them as a PHI's input could
+ // break LCSSA if the PHI's parent block is not a loop exit (i.e. the
+ // corresponding incoming block is not loop exiting). So collect all such
+ // instructions to form LCSSA for them later.
+ SmallVector<Instruction *, 4> InsertedNonLCSSAInsts;
+
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
if (PN->getIncomingValue(i) == LF.OperandValToReplace) {
bool needUpdateFixups = false;
@@ -5562,6 +5650,13 @@ void LSRInstance::RewriteForPHI(
FullV, LF.OperandValToReplace->getType(),
"tmp", BB->getTerminator());
+ // If the incoming block for this value is not in the loop, it means the
+ // current PHI is not in a loop exit, so we must create an LCSSA PHI for
+ // the inserted value.
+ if (auto *I = dyn_cast<Instruction>(FullV))
+ if (L->contains(I) && !L->contains(BB))
+ InsertedNonLCSSAInsts.push_back(I);
+
PN->setIncomingValue(i, FullV);
Pair.first->second = FullV;
}
@@ -5604,6 +5699,8 @@ void LSRInstance::RewriteForPHI(
}
}
}
+
+ formLCSSAForInstructions(InsertedNonLCSSAInsts, DT, LI, &SE);
}
/// Emit instructions for the leading candidate expression for this LSRUse (this
@@ -5643,6 +5740,36 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF,
DeadInsts.emplace_back(OperandIsInstr);
}
+// Try to hoist the IVInc to the loop header if all IVInc users are in
+// the loop header. This helps the backend generate post-indexed loads/stores
+// when the latch block is different from the loop header block.
+static bool canHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup,
+ const LSRUse &LU, Instruction *IVIncInsertPos,
+ Loop *L) {
+ if (LU.Kind != LSRUse::Address)
+ return false;
+
+ // For now this code does the conservative optimization and only works for
+ // the header block. Later we can hoist the IVInc to the block that
+ // post-dominates all users.
+ BasicBlock *LHeader = L->getHeader();
+ if (IVIncInsertPos->getParent() == LHeader)
+ return false;
+
+ if (!Fixup.OperandValToReplace ||
+ any_of(Fixup.OperandValToReplace->users(), [&LHeader](User *U) {
+ Instruction *UI = cast<Instruction>(U);
+ return UI->getParent() != LHeader;
+ }))
+ return false;
+
+ Instruction *I = Fixup.UserInst;
+ Type *Ty = I->getType();
+ return Ty->isIntegerTy() &&
+ ((isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) ||
+ (isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty)));
+}
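The shape of loop canHoistIVInc targets, sketched in C++ (hand-written example; whether the latch actually ends up distinct from the header, and whether a post-indexed load is formed, depends on earlier passes and on the target):

    #include <cstddef>

    // The load of p[i] is the only address use and it sits in the loop header;
    // the early exit gives the rotated loop a latch block separate from the
    // header, which is where LSR would otherwise place the IV increment.
    // Hoisting the increment next to the load lets a target with post-indexed
    // addressing fold the two into a single post-increment load.
    int findFirstNegative(const int *p, size_t n) {
      for (size_t i = 0; i != n; ++i) {
        int v = p[i];
        if (v < 0)
          return v;
      }
      return 0;
    }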
+
/// Rewrite all the fixup locations with new values, following the chosen
/// solution.
void LSRInstance::ImplementSolution(
@@ -5651,8 +5778,6 @@ void LSRInstance::ImplementSolution(
// we can remove them after we are done working.
SmallVector<WeakTrackingVH, 16> DeadInsts;
- Rewriter.setIVIncInsertPos(L, IVIncInsertPos);
-
// Mark phi nodes that terminate chains so the expander tries to reuse them.
for (const IVChain &Chain : IVChainVec) {
if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst()))
@@ -5662,6 +5787,11 @@ void LSRInstance::ImplementSolution(
// Expand the new value definitions and update the users.
for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx)
for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) {
+ Instruction *InsertPos =
+ canHoistIVInc(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, L)
+ ? L->getHeader()->getTerminator()
+ : IVIncInsertPos;
+ Rewriter.setIVIncInsertPos(L, InsertPos);
Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts);
Changed = true;
}
@@ -5994,7 +6124,7 @@ struct SCEVDbgValueBuilder {
}
bool pushConst(const SCEVConstant *C) {
- if (C->getAPInt().getMinSignedBits() > 64)
+ if (C->getAPInt().getSignificantBits() > 64)
return false;
Expr.push_back(llvm::dwarf::DW_OP_consts);
Expr.push_back(C->getAPInt().getSExtValue());
@@ -6083,7 +6213,7 @@ struct SCEVDbgValueBuilder {
/// SCEV constant value is an identity function.
bool isIdentityFunction(uint64_t Op, const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) {
- if (C->getAPInt().getMinSignedBits() > 64)
+ if (C->getAPInt().getSignificantBits() > 64)
return false;
int64_t I = C->getAPInt().getSExtValue();
switch (Op) {
@@ -6338,13 +6468,13 @@ static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec,
}
}
-/// Cached location ops may be erased during LSR, in which case an undef is
+/// Cached location ops may be erased during LSR, in which case a poison is
/// required when restoring from the cache. The type of that location is no
-/// longer available, so just use int8. The undef will be replaced by one or
+/// longer available, so just use int8. The poison will be replaced by one or
/// more locations later when a SCEVDbgValueBuilder selects alternative
/// locations to use for the salvage.
-static Value *getValueOrUndef(WeakVH &VH, LLVMContext &C) {
- return (VH) ? VH : UndefValue::get(llvm::Type::getInt8Ty(C));
+static Value *getValueOrPoison(WeakVH &VH, LLVMContext &C) {
+ return (VH) ? VH : PoisonValue::get(llvm::Type::getInt8Ty(C));
}
/// Restore the DVI's pre-LSR arguments. Substitute undef for any erased values.
@@ -6363,12 +6493,12 @@ static void restorePreTransformState(DVIRecoveryRec &DVIRec) {
// this case was not present before, so force the location back to a single
// uncontained Value.
Value *CachedValue =
- getValueOrUndef(DVIRec.LocationOps[0], DVIRec.DVI->getContext());
+ getValueOrPoison(DVIRec.LocationOps[0], DVIRec.DVI->getContext());
DVIRec.DVI->setRawLocation(ValueAsMetadata::get(CachedValue));
} else {
SmallVector<ValueAsMetadata *, 3> MetadataLocs;
for (WeakVH VH : DVIRec.LocationOps) {
- Value *CachedValue = getValueOrUndef(VH, DVIRec.DVI->getContext());
+ Value *CachedValue = getValueOrPoison(VH, DVIRec.DVI->getContext());
MetadataLocs.push_back(ValueAsMetadata::get(CachedValue));
}
auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs);
@@ -6431,7 +6561,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE,
// less DWARF ops than an iteration count-based expression.
if (std::optional<APInt> Offset =
SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) {
- if (Offset->getMinSignedBits() <= 64)
+ if (Offset->getSignificantBits() <= 64)
SalvageExpr->createOffsetExpr(Offset->getSExtValue(), LSRInductionVar);
} else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr,
SE))
@@ -6607,7 +6737,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE,
return nullptr;
}
-static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *>>
+static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>>
canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
const LoopInfo &LI) {
if (!L->isInnermost()) {
@@ -6626,16 +6756,13 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
}
BasicBlock *LoopLatch = L->getLoopLatch();
-
- // TODO: Can we do something for greater than and less than?
- // Terminating condition is foldable when it is an eq/ne icmp
- BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
- if (BI->isUnconditional())
+ BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator());
+ if (!BI || BI->isUnconditional())
return std::nullopt;
- Value *TermCond = BI->getCondition();
- if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality()) {
- LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an "
- "ICmpInst::eq / ICmpInst::ne\n");
+ auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition());
+ if (!TermCond) {
+ LLVM_DEBUG(
+ dbgs() << "Cannot fold on branching condition that is not an ICmpInst");
return std::nullopt;
}
if (!TermCond->hasOneUse()) {
@@ -6645,89 +6772,42 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
return std::nullopt;
}
- // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it
- // is only used by the terminating condition. To check for this, we may need
- // to traverse through a chain of use-def until we can examine the final
- // usage.
- // *----------------------*
- // *---->| LoopHeader: |
- // | | PrimaryIV = phi ... |
- // | *----------------------*
- // | |
- // | |
- // | chain of
- // | single use
- // used by |
- // phi |
- // | Value
- // | / \
- // | chain of chain of
- // | single use single use
- // | / \
- // | / \
- // *- Value Value --> used by terminating condition
- auto IsToFold = [&](PHINode &PN) -> bool {
- Value *V = &PN;
-
- while (V->getNumUses() == 1)
- V = *V->user_begin();
-
- if (V->getNumUses() != 2)
- return false;
+ BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0));
+ Value *RHS = TermCond->getOperand(1);
+ if (!LHS || !L->isLoopInvariant(RHS))
+ // We could pattern match the inverse form of the icmp, but that is
+ // non-canonical, and this pass is running *very* late in the pipeline.
+ return std::nullopt;
- Value *VToPN = nullptr;
- Value *VToTermCond = nullptr;
- for (User *U : V->users()) {
- while (U->getNumUses() == 1) {
- if (isa<PHINode>(U))
- VToPN = U;
- if (U == TermCond)
- VToTermCond = U;
- U = *U->user_begin();
- }
- }
- return VToPN && VToTermCond;
- };
+ // Find the IV used by the current exit condition.
+ PHINode *ToFold;
+ Value *ToFoldStart, *ToFoldStep;
+ if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep))
+ return std::nullopt;
- // If this is an IV which we could replace the terminating condition, return
- // the final value of the alternative IV on the last iteration.
- auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * {
- // FIXME: This does not properly account for overflow.
- const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
- const SCEV *BECount = SE.getBackedgeTakenCount(L);
- const SCEV *TermValueS = SE.getAddExpr(
- AddRec->getOperand(0),
- SE.getTruncateOrZeroExtend(
- SE.getMulExpr(
- AddRec->getOperand(1),
- SE.getTruncateOrZeroExtend(
- SE.getAddExpr(BECount, SE.getOne(BECount->getType())),
- AddRec->getOperand(1)->getType())),
- AddRec->getOperand(0)->getType()));
- const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
- SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
- if (!Expander.isSafeToExpand(TermValueS)) {
- LLVM_DEBUG(
- dbgs() << "Is not safe to expand terminating value for phi node" << PN
- << "\n");
- return nullptr;
- }
- return TermValueS;
- };
+ // If that IV isn't dead after we rewrite the exit condition in terms of
+ // another IV, there's no point in doing the transform.
+ if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond))
+ return std::nullopt;
+
+ const SCEV *BECount = SE.getBackedgeTakenCount(L);
+ const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+ SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
- PHINode *ToFold = nullptr;
PHINode *ToHelpFold = nullptr;
const SCEV *TermValueS = nullptr;
-
+ bool MustDropPoison = false;
for (PHINode &PN : L->getHeader()->phis()) {
+ if (ToFold == &PN)
+ continue;
+
if (!SE.isSCEVable(PN.getType())) {
LLVM_DEBUG(dbgs() << "IV of phi '" << PN
<< "' is not SCEV-able, not qualified for the "
"terminating condition folding.\n");
continue;
}
- const SCEV *S = SE.getSCEV(&PN);
- const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+ const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN));
// Only speculate on affine AddRec
if (!AddRec || !AddRec->isAffine()) {
LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN
@@ -6736,12 +6816,63 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
continue;
}
- if (IsToFold(PN))
- ToFold = &PN;
- else if (auto P = getAlternateIVEnd(PN)) {
- ToHelpFold = &PN;
- TermValueS = P;
+ // Check that we can compute the value of AddRec on the exiting iteration
+ // without soundness problems. evaluateAtIteration internally needs
+ // to multiply the stride by the iteration count, which may wrap around.
+ // The issue here is subtle because computing the result accounting for
+ // wrap is insufficient. In order to use the result in an exit test, we
+ // must also know that AddRec doesn't take the same value on any previous
+ // iteration. The simplest case to consider is a candidate IV which is
+ // narrower than the trip count (and thus original IV), but this can
+ // also happen due to non-unit strides on the candidate IVs.
+ if (!AddRec->hasNoSelfWrap())
+ continue;
+
+ const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE);
+ const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE);
+ if (!Expander.isSafeToExpand(TermValueSLocal)) {
+ LLVM_DEBUG(
+ dbgs() << "Is not safe to expand terminating value for phi node" << PN
+ << "\n");
+ continue;
}
+
+ // The candidate IV may have been otherwise dead and poison from the
+ // very first iteration. If we can't disprove that, we can't use the IV.
+ if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) {
+ LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV "
+ << PN << "\n");
+ continue;
+ }
+
+ // The candidate IV may become poison on the last iteration. If this
+ // value is not branched on, this is a well defined program. We're
+ // about to add a new use to this IV, and we have to ensure we don't
+ // insert UB which didn't previously exist.
+ bool MustDropPoisonLocal = false;
+ Instruction *PostIncV =
+ cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch));
+ if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(),
+ &DT)) {
+ LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use"
+ << PN << "\n");
+
+ // If this is a complex recurrence with multiple instructions computing
+ // the backedge value, we might need to strip poison flags from all of
+ // them.
+ if (PostIncV->getOperand(0) != &PN)
+ continue;
+
+ // In order to perform the transform, we need to drop the poison generating
+ // flags on this instruction (if any).
+ MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags();
+ }
+
+ // We pick the last legal alternate IV. We could explore choosing an optimal
+ // alternate IV if we had a decent heuristic to do so.
+ ToHelpFold = &PN;
+ TermValueS = TermValueSLocal;
+ MustDropPoison = MustDropPoisonLocal;
}
LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs()
@@ -6757,7 +6888,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT,
if (!ToFold || !ToHelpFold)
return std::nullopt;
- return std::make_tuple(ToFold, ToHelpFold, TermValueS);
+ return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison);
}
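The overall transform the rewritten canFoldTermCondOfLoop enables, sketched at the source level (hand-written before/after; the no-self-wrap and poison checks above exist to rule out alternate IVs that could take a repeated value or introduce poison on the path to the new compare):

    #include <cstddef>

    // Before: the counter i is used only by the exit test, so once LSR picks
    // the pointer IV, i can be folded away by comparing the pointer against
    // its value on the final iteration (expanded once in the preheader).
    void fill(int *dst, size_t n) {
      for (size_t i = 0; i != n; ++i)
        *dst++ = 42;
    }

    // After (conceptually):
    //   int *end = dst + n;            // TermValue expanded in the preheader
    //   while (dst != end)
    //     *dst++ = 42;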
static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
@@ -6820,7 +6951,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
if (AllowTerminatingConditionFoldingAfterLSR) {
if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) {
- auto [ToFold, ToHelpFold, TermValueS] = *Opt;
+ auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt;
Changed = true;
NumTermFold++;
@@ -6838,6 +6969,10 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
(void)StartValue;
Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch);
+ // See comment in canFoldTermCondOfLoop on why this is sufficient.
+ if (MustDrop)
+ cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags();
+
// SCEVExpander for both use in preheader and latch
const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
SCEVExpander Expander(SE, DL, "lsr_fold_term_cond");
@@ -6859,11 +6994,12 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator());
ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition());
IRBuilder<> LatchBuilder(LoopLatch->getTerminator());
- // FIXME: We are adding a use of an IV here without account for poison safety.
- // This is incorrect.
- Value *NewTermCond = LatchBuilder.CreateICmp(
- OldTermCond->getPredicate(), LoopValue, TermValue,
- "lsr_fold_term_cond.replaced_term_cond");
+ Value *NewTermCond =
+ LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue,
+ "lsr_fold_term_cond.replaced_term_cond");
+ // Swap successors to exit the loop body if the IV equals the new TermValue
+ if (BI->getSuccessor(0) == L->getHeader())
+ BI->swapSuccessors();
LLVM_DEBUG(dbgs() << "Old term-cond:\n"
<< *OldTermCond << "\n"