Diffstat (limited to 'llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp')
llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp | 390
1 file changed, 263 insertions(+), 127 deletions(-)
diff --git a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 4c89f947d7fc..a4369b83e732 100644 --- a/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -799,7 +799,7 @@ static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS, /// value, and mutate S to point to a new SCEV with that value excluded. static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - if (C->getAPInt().getMinSignedBits() <= 64) { + if (C->getAPInt().getSignificantBits() <= 64) { S = SE.getConstant(C->getType(), 0); return C->getValue()->getSExtValue(); } @@ -896,9 +896,14 @@ static bool isAddressUse(const TargetTransformInfo &TTI, /// Return the type of the memory being accessed. static MemAccessTy getAccessType(const TargetTransformInfo &TTI, Instruction *Inst, Value *OperandVal) { - MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace); + MemAccessTy AccessTy = MemAccessTy::getUnknown(Inst->getContext()); + + // First get the type of memory being accessed. + if (Type *Ty = Inst->getAccessType()) + AccessTy.MemTy = Ty; + + // Then get the pointer address space. if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { - AccessTy.MemTy = SI->getOperand(0)->getType(); AccessTy.AddrSpace = SI->getPointerAddressSpace(); } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { AccessTy.AddrSpace = LI->getPointerAddressSpace(); @@ -923,7 +928,6 @@ static MemAccessTy getAccessType(const TargetTransformInfo &TTI, II->getArgOperand(0)->getType()->getPointerAddressSpace(); break; case Intrinsic::masked_store: - AccessTy.MemTy = II->getOperand(0)->getType(); AccessTy.AddrSpace = II->getArgOperand(1)->getType()->getPointerAddressSpace(); break; @@ -976,6 +980,7 @@ static bool isHighCostExpansion(const SCEV *S, switch (S->getSCEVType()) { case scUnknown: case scConstant: + case scVScale: return false; case scTruncate: return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(), @@ -1414,7 +1419,7 @@ void Cost::RateFormula(const Formula &F, C.ImmCost += 64; // Handle symbolic values conservatively. // TODO: This should probably be the pointer size. else if (Offset != 0) - C.ImmCost += APInt(64, Offset, true).getMinSignedBits(); + C.ImmCost += APInt(64, Offset, true).getSignificantBits(); // Check with target if this offset with this instruction is // specifically not supported. @@ -2498,7 +2503,7 @@ LSRInstance::OptimizeLoopTermCond() { if (C->isOne() || C->isMinusOne()) goto decline_post_inc; // Avoid weird situations. - if (C->getValue().getMinSignedBits() >= 64 || + if (C->getValue().getSignificantBits() >= 64 || C->getValue().isMinSignedValue()) goto decline_post_inc; // Check for possible scaled-address reuse. @@ -2508,13 +2513,13 @@ LSRInstance::OptimizeLoopTermCond() { int64_t Scale = C->getSExtValue(); if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, + /*HasBaseReg=*/true, Scale, AccessTy.AddrSpace)) goto decline_post_inc; Scale = -Scale; if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, /*BaseOffset=*/0, - /*HasBaseReg=*/false, Scale, + /*HasBaseReg=*/true, Scale, AccessTy.AddrSpace)) goto decline_post_inc; } @@ -2660,8 +2665,7 @@ LSRUse * LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF, const LSRUse &OrigLU) { // Search all uses for the formula. This could be more clever. 
- for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { - LSRUse &LU = Uses[LUIdx]; + for (LSRUse &LU : Uses) { // Check whether this use is close enough to OrigLU, to see whether it's // worthwhile looking through its formulae. // Ignore ICmpZero uses because they may contain formulae generated by @@ -2703,6 +2707,8 @@ void LSRInstance::CollectInterestingTypesAndFactors() { SmallVector<const SCEV *, 4> Worklist; for (const IVStrideUse &U : IU) { const SCEV *Expr = IU.getExpr(U); + if (!Expr) + continue; // Collect interesting types. Types.insert(SE.getEffectiveSCEVType(Expr->getType())); @@ -2740,13 +2746,13 @@ void LSRInstance::CollectInterestingTypesAndFactors() { if (const SCEVConstant *Factor = dyn_cast_or_null<SCEVConstant>(getExactSDiv(NewStride, OldStride, SE, true))) { - if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero()) + if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero()) Factors.insert(Factor->getAPInt().getSExtValue()); } else if (const SCEVConstant *Factor = dyn_cast_or_null<SCEVConstant>(getExactSDiv(OldStride, NewStride, SE, true))) { - if (Factor->getAPInt().getMinSignedBits() <= 64 && !Factor->isZero()) + if (Factor->getAPInt().getSignificantBits() <= 64 && !Factor->isZero()) Factors.insert(Factor->getAPInt().getSExtValue()); } } @@ -2812,9 +2818,10 @@ static bool isCompatibleIVType(Value *LVal, Value *RVal) { /// SCEVUnknown, we simply return the rightmost SCEV operand. static const SCEV *getExprBase(const SCEV *S) { switch (S->getSCEVType()) { - default: // uncluding scUnknown. + default: // including scUnknown. return S; case scConstant: + case scVScale: return nullptr; case scTruncate: return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand()); @@ -3175,7 +3182,7 @@ static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, if (!IncConst || !isAddressUse(TTI, UserInst, Operand)) return false; - if (IncConst->getAPInt().getMinSignedBits() > 64) + if (IncConst->getAPInt().getSignificantBits() > 64) return false; MemAccessTy AccessTy = getAccessType(TTI, UserInst, Operand); @@ -3320,6 +3327,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { } const SCEV *S = IU.getExpr(U); + if (!S) + continue; PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops(); // Equality (== and !=) ICmps are special. We can rewrite (i == N) as @@ -3352,6 +3361,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // S is normalized, so normalize N before folding it into S // to keep the result normalized. N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); + if (!N) + continue; Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); } else if (L->isLoopInvariant(NV) && @@ -3366,6 +3377,8 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { // SCEV can't compute the difference of two unknown pointers. N = SE.getUnknown(NV); N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); + if (!N) + continue; Kind = LSRUse::ICmpZero; S = SE.getMinusSCEV(N, S); assert(!isa<SCEVCouldNotCompute>(S)); @@ -3494,8 +3507,8 @@ LSRInstance::CollectLoopInvariantFixupsAndFormulae() { if (const Instruction *Inst = dyn_cast<Instruction>(V)) { // Look for instructions defined outside the loop. if (L->contains(Inst)) continue; - } else if (isa<UndefValue>(V)) - // Undef doesn't have a live range, so it doesn't matter. + } else if (isa<Constant>(V)) + // Constants can be re-materialized. 
continue; for (const Use &U : V->uses()) { const Instruction *UserInst = dyn_cast<Instruction>(U.getUser()); @@ -4137,6 +4150,29 @@ void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) { } } +/// Extend/Truncate \p Expr to \p ToTy considering post-inc uses in \p Loops. +/// For all PostIncLoopSets in \p Loops, first de-normalize \p Expr, then +/// perform the extension/truncate and normalize again, as the normalized form +/// can result in folds that are not valid in the post-inc use contexts. The +/// expressions for all PostIncLoopSets must match, otherwise return nullptr. +static const SCEV * +getAnyExtendConsideringPostIncUses(ArrayRef<PostIncLoopSet> Loops, + const SCEV *Expr, Type *ToTy, + ScalarEvolution &SE) { + const SCEV *Result = nullptr; + for (auto &L : Loops) { + auto *DenormExpr = denormalizeForPostIncUse(Expr, L, SE); + const SCEV *NewDenormExpr = SE.getAnyExtendExpr(DenormExpr, ToTy); + const SCEV *New = normalizeForPostIncUse(NewDenormExpr, L, SE); + if (!New || (Result && New != Result)) + return nullptr; + Result = New; + } + + assert(Result && "failed to create expression"); + return Result; +} + /// Generate reuse formulae from different IV types. void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { // Don't bother truncating symbolic values. @@ -4156,6 +4192,10 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { [](const SCEV *S) { return S->getType()->isPointerTy(); })) return; + SmallVector<PostIncLoopSet> Loops; + for (auto &LF : LU.Fixups) + Loops.push_back(LF.PostIncLoops); + for (Type *SrcTy : Types) { if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) { Formula F = Base; @@ -4165,15 +4205,17 @@ void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) { // initial node (maybe due to depth limitations), but it can do them while // taking ext. if (F.ScaledReg) { - const SCEV *NewScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy); - if (NewScaledReg->isZero()) - continue; + const SCEV *NewScaledReg = + getAnyExtendConsideringPostIncUses(Loops, F.ScaledReg, SrcTy, SE); + if (!NewScaledReg || NewScaledReg->isZero()) + continue; F.ScaledReg = NewScaledReg; } bool HasZeroBaseReg = false; for (const SCEV *&BaseReg : F.BaseRegs) { - const SCEV *NewBaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy); - if (NewBaseReg->isZero()) { + const SCEV *NewBaseReg = + getAnyExtendConsideringPostIncUses(Loops, BaseReg, SrcTy, SE); + if (!NewBaseReg || NewBaseReg->isZero()) { HasZeroBaseReg = true; break; } @@ -4379,8 +4421,8 @@ void LSRInstance::GenerateCrossUseConstantOffsets() { if ((C->getAPInt() + NewF.BaseOffset) .abs() .slt(std::abs(NewF.BaseOffset)) && - (C->getAPInt() + NewF.BaseOffset).countTrailingZeros() >= - countTrailingZeros<uint64_t>(NewF.BaseOffset)) + (C->getAPInt() + NewF.BaseOffset).countr_zero() >= + (unsigned)llvm::countr_zero<uint64_t>(NewF.BaseOffset)) goto skip_formula; // Ok, looks good. @@ -4982,6 +5024,32 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { LLVM_DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); } +// Check if Best and Reg are SCEVs separated by a constant amount C, and if so +// would the addressing offset +C would be legal where the negative offset -C is +// not. 
+static bool IsSimplerBaseSCEVForTarget(const TargetTransformInfo &TTI, + ScalarEvolution &SE, const SCEV *Best, + const SCEV *Reg, + MemAccessTy AccessType) { + if (Best->getType() != Reg->getType() || + (isa<SCEVAddRecExpr>(Best) && isa<SCEVAddRecExpr>(Reg) && + cast<SCEVAddRecExpr>(Best)->getLoop() != + cast<SCEVAddRecExpr>(Reg)->getLoop())) + return false; + const auto *Diff = dyn_cast<SCEVConstant>(SE.getMinusSCEV(Best, Reg)); + if (!Diff) + return false; + + return TTI.isLegalAddressingMode( + AccessType.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/Diff->getAPInt().getSExtValue(), + /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace) && + !TTI.isLegalAddressingMode( + AccessType.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/-Diff->getAPInt().getSExtValue(), + /*HasBaseReg=*/true, /*Scale=*/0, AccessType.AddrSpace); +} + /// Pick a register which seems likely to be profitable, and then in any use /// which has any reference to that register, delete all formulae which do not /// reference that register. @@ -5010,6 +5078,19 @@ void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() { Best = Reg; BestNum = Count; } + + // If the scores are the same, but the Reg is simpler for the target + // (for example {x,+,1} as opposed to {x+C,+,1}, where the target can + // handle +C but not -C), opt for the simpler formula. + if (Count == BestNum) { + int LUIdx = RegUses.getUsedByIndices(Reg).find_first(); + if (LUIdx >= 0 && Uses[LUIdx].Kind == LSRUse::Address && + IsSimplerBaseSCEVForTarget(TTI, SE, Best, Reg, + Uses[LUIdx].AccessTy)) { + Best = Reg; + BestNum = Count; + } + } } } assert(Best && "Failed to find best LSRUse candidate"); @@ -5497,6 +5578,13 @@ void LSRInstance::RewriteForPHI( PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F, SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { DenseMap<BasicBlock *, Value *> Inserted; + + // Inserting instructions in the loop and using them as PHI's input could + // break LCSSA in case if PHI's parent block is not a loop exit (i.e. the + // corresponding incoming block is not loop exiting). So collect all such + // instructions to form LCSSA for them later. + SmallVector<Instruction *, 4> InsertedNonLCSSAInsts; + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) if (PN->getIncomingValue(i) == LF.OperandValToReplace) { bool needUpdateFixups = false; @@ -5562,6 +5650,13 @@ void LSRInstance::RewriteForPHI( FullV, LF.OperandValToReplace->getType(), "tmp", BB->getTerminator()); + // If the incoming block for this value is not in the loop, it means the + // current PHI is not in a loop exit, so we must create a LCSSA PHI for + // the inserted value. + if (auto *I = dyn_cast<Instruction>(FullV)) + if (L->contains(I) && !L->contains(BB)) + InsertedNonLCSSAInsts.push_back(I); + PN->setIncomingValue(i, FullV); Pair.first->second = FullV; } @@ -5604,6 +5699,8 @@ void LSRInstance::RewriteForPHI( } } } + + formLCSSAForInstructions(InsertedNonLCSSAInsts, DT, LI, &SE); } /// Emit instructions for the leading candidate expression for this LSRUse (this @@ -5643,6 +5740,36 @@ void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, DeadInsts.emplace_back(OperandIsInstr); } +// Trying to hoist the IVInc to loop header if all IVInc users are in +// the loop header. It will help backend to generate post index load/store +// when the latch block is different from loop header block. 
+static bool canHoistIVInc(const TargetTransformInfo &TTI, const LSRFixup &Fixup, + const LSRUse &LU, Instruction *IVIncInsertPos, + Loop *L) { + if (LU.Kind != LSRUse::Address) + return false; + + // For now this code do the conservative optimization, only work for + // the header block. Later we can hoist the IVInc to the block post + // dominate all users. + BasicBlock *LHeader = L->getHeader(); + if (IVIncInsertPos->getParent() == LHeader) + return false; + + if (!Fixup.OperandValToReplace || + any_of(Fixup.OperandValToReplace->users(), [&LHeader](User *U) { + Instruction *UI = cast<Instruction>(U); + return UI->getParent() != LHeader; + })) + return false; + + Instruction *I = Fixup.UserInst; + Type *Ty = I->getType(); + return Ty->isIntegerTy() && + ((isa<LoadInst>(I) && TTI.isIndexedLoadLegal(TTI.MIM_PostInc, Ty)) || + (isa<StoreInst>(I) && TTI.isIndexedStoreLegal(TTI.MIM_PostInc, Ty))); +} + /// Rewrite all the fixup locations with new values, following the chosen /// solution. void LSRInstance::ImplementSolution( @@ -5651,8 +5778,6 @@ void LSRInstance::ImplementSolution( // we can remove them after we are done working. SmallVector<WeakTrackingVH, 16> DeadInsts; - Rewriter.setIVIncInsertPos(L, IVIncInsertPos); - // Mark phi nodes that terminate chains so the expander tries to reuse them. for (const IVChain &Chain : IVChainVec) { if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst())) @@ -5662,6 +5787,11 @@ void LSRInstance::ImplementSolution( // Expand the new value definitions and update the users. for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) { + Instruction *InsertPos = + canHoistIVInc(TTI, Fixup, Uses[LUIdx], IVIncInsertPos, L) + ? L->getHeader()->getTerminator() + : IVIncInsertPos; + Rewriter.setIVIncInsertPos(L, InsertPos); Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], DeadInsts); Changed = true; } @@ -5994,7 +6124,7 @@ struct SCEVDbgValueBuilder { } bool pushConst(const SCEVConstant *C) { - if (C->getAPInt().getMinSignedBits() > 64) + if (C->getAPInt().getSignificantBits() > 64) return false; Expr.push_back(llvm::dwarf::DW_OP_consts); Expr.push_back(C->getAPInt().getSExtValue()); @@ -6083,7 +6213,7 @@ struct SCEVDbgValueBuilder { /// SCEV constant value is an identity function. bool isIdentityFunction(uint64_t Op, const SCEV *S) { if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { - if (C->getAPInt().getMinSignedBits() > 64) + if (C->getAPInt().getSignificantBits() > 64) return false; int64_t I = C->getAPInt().getSExtValue(); switch (Op) { @@ -6338,13 +6468,13 @@ static void UpdateDbgValueInst(DVIRecoveryRec &DVIRec, } } -/// Cached location ops may be erased during LSR, in which case an undef is +/// Cached location ops may be erased during LSR, in which case a poison is /// required when restoring from the cache. The type of that location is no -/// longer available, so just use int8. The undef will be replaced by one or +/// longer available, so just use int8. The poison will be replaced by one or /// more locations later when a SCEVDbgValueBuilder selects alternative /// locations to use for the salvage. -static Value *getValueOrUndef(WeakVH &VH, LLVMContext &C) { - return (VH) ? VH : UndefValue::get(llvm::Type::getInt8Ty(C)); +static Value *getValueOrPoison(WeakVH &VH, LLVMContext &C) { + return (VH) ? VH : PoisonValue::get(llvm::Type::getInt8Ty(C)); } /// Restore the DVI's pre-LSR arguments. Substitute undef for any erased values. 
@@ -6363,12 +6493,12 @@ static void restorePreTransformState(DVIRecoveryRec &DVIRec) { // this case was not present before, so force the location back to a single // uncontained Value. Value *CachedValue = - getValueOrUndef(DVIRec.LocationOps[0], DVIRec.DVI->getContext()); + getValueOrPoison(DVIRec.LocationOps[0], DVIRec.DVI->getContext()); DVIRec.DVI->setRawLocation(ValueAsMetadata::get(CachedValue)); } else { SmallVector<ValueAsMetadata *, 3> MetadataLocs; for (WeakVH VH : DVIRec.LocationOps) { - Value *CachedValue = getValueOrUndef(VH, DVIRec.DVI->getContext()); + Value *CachedValue = getValueOrPoison(VH, DVIRec.DVI->getContext()); MetadataLocs.push_back(ValueAsMetadata::get(CachedValue)); } auto ValArrayRef = llvm::ArrayRef<llvm::ValueAsMetadata *>(MetadataLocs); @@ -6431,7 +6561,7 @@ static bool SalvageDVI(llvm::Loop *L, ScalarEvolution &SE, // less DWARF ops than an iteration count-based expression. if (std::optional<APInt> Offset = SE.computeConstantDifference(DVIRec.SCEVs[i], SCEVInductionVar)) { - if (Offset->getMinSignedBits() <= 64) + if (Offset->getSignificantBits() <= 64) SalvageExpr->createOffsetExpr(Offset->getSExtValue(), LSRInductionVar); } else if (!SalvageExpr->createIterCountExpr(DVIRec.SCEVs[i], IterCountExpr, SE)) @@ -6607,7 +6737,7 @@ static llvm::PHINode *GetInductionVariable(const Loop &L, ScalarEvolution &SE, return nullptr; } -static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *>> +static std::optional<std::tuple<PHINode *, PHINode *, const SCEV *, bool>> canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, const LoopInfo &LI) { if (!L->isInnermost()) { @@ -6626,16 +6756,13 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, } BasicBlock *LoopLatch = L->getLoopLatch(); - - // TODO: Can we do something for greater than and less than? - // Terminating condition is foldable when it is an eq/ne icmp - BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); - if (BI->isUnconditional()) + BranchInst *BI = dyn_cast<BranchInst>(LoopLatch->getTerminator()); + if (!BI || BI->isUnconditional()) return std::nullopt; - Value *TermCond = BI->getCondition(); - if (!isa<ICmpInst>(TermCond) || !cast<ICmpInst>(TermCond)->isEquality()) { - LLVM_DEBUG(dbgs() << "Cannot fold on branching condition that is not an " - "ICmpInst::eq / ICmpInst::ne\n"); + auto *TermCond = dyn_cast<ICmpInst>(BI->getCondition()); + if (!TermCond) { + LLVM_DEBUG( + dbgs() << "Cannot fold on branching condition that is not an ICmpInst"); return std::nullopt; } if (!TermCond->hasOneUse()) { @@ -6645,89 +6772,42 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, return std::nullopt; } - // For `IsToFold`, a primary IV can be replaced by other affine AddRec when it - // is only used by the terminating condition. To check for this, we may need - // to traverse through a chain of use-def until we can examine the final - // usage. - // *----------------------* - // *---->| LoopHeader: | - // | | PrimaryIV = phi ... 
| - // | *----------------------* - // | | - // | | - // | chain of - // | single use - // used by | - // phi | - // | Value - // | / \ - // | chain of chain of - // | single use single use - // | / \ - // | / \ - // *- Value Value --> used by terminating condition - auto IsToFold = [&](PHINode &PN) -> bool { - Value *V = &PN; - - while (V->getNumUses() == 1) - V = *V->user_begin(); - - if (V->getNumUses() != 2) - return false; + BinaryOperator *LHS = dyn_cast<BinaryOperator>(TermCond->getOperand(0)); + Value *RHS = TermCond->getOperand(1); + if (!LHS || !L->isLoopInvariant(RHS)) + // We could pattern match the inverse form of the icmp, but that is + // non-canonical, and this pass is running *very* late in the pipeline. + return std::nullopt; - Value *VToPN = nullptr; - Value *VToTermCond = nullptr; - for (User *U : V->users()) { - while (U->getNumUses() == 1) { - if (isa<PHINode>(U)) - VToPN = U; - if (U == TermCond) - VToTermCond = U; - U = *U->user_begin(); - } - } - return VToPN && VToTermCond; - }; + // Find the IV used by the current exit condition. + PHINode *ToFold; + Value *ToFoldStart, *ToFoldStep; + if (!matchSimpleRecurrence(LHS, ToFold, ToFoldStart, ToFoldStep)) + return std::nullopt; - // If this is an IV which we could replace the terminating condition, return - // the final value of the alternative IV on the last iteration. - auto getAlternateIVEnd = [&](PHINode &PN) -> const SCEV * { - // FIXME: This does not properly account for overflow. - const SCEVAddRecExpr *AddRec = cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); - const SCEV *BECount = SE.getBackedgeTakenCount(L); - const SCEV *TermValueS = SE.getAddExpr( - AddRec->getOperand(0), - SE.getTruncateOrZeroExtend( - SE.getMulExpr( - AddRec->getOperand(1), - SE.getTruncateOrZeroExtend( - SE.getAddExpr(BECount, SE.getOne(BECount->getType())), - AddRec->getOperand(1)->getType())), - AddRec->getOperand(0)->getType())); - const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); - SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - if (!Expander.isSafeToExpand(TermValueS)) { - LLVM_DEBUG( - dbgs() << "Is not safe to expand terminating value for phi node" << PN - << "\n"); - return nullptr; - } - return TermValueS; - }; + // If that IV isn't dead after we rewrite the exit condition in terms of + // another IV, there's no point in doing the transform. 
+ if (!isAlmostDeadIV(ToFold, LoopLatch, TermCond)) + return std::nullopt; + + const SCEV *BECount = SE.getBackedgeTakenCount(L); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); - PHINode *ToFold = nullptr; PHINode *ToHelpFold = nullptr; const SCEV *TermValueS = nullptr; - + bool MustDropPoison = false; for (PHINode &PN : L->getHeader()->phis()) { + if (ToFold == &PN) + continue; + if (!SE.isSCEVable(PN.getType())) { LLVM_DEBUG(dbgs() << "IV of phi '" << PN << "' is not SCEV-able, not qualified for the " "terminating condition folding.\n"); continue; } - const SCEV *S = SE.getSCEV(&PN); - const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S); + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(&PN)); // Only speculate on affine AddRec if (!AddRec || !AddRec->isAffine()) { LLVM_DEBUG(dbgs() << "SCEV of phi '" << PN @@ -6736,12 +6816,63 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, continue; } - if (IsToFold(PN)) - ToFold = &PN; - else if (auto P = getAlternateIVEnd(PN)) { - ToHelpFold = &PN; - TermValueS = P; + // Check that we can compute the value of AddRec on the exiting iteration + // without soundness problems. evaluateAtIteration internally needs + // to multiply the stride of the iteration number - which may wrap around. + // The issue here is subtle because computing the result accounting for + // wrap is insufficient. In order to use the result in an exit test, we + // must also know that AddRec doesn't take the same value on any previous + // iteration. The simplest case to consider is a candidate IV which is + // narrower than the trip count (and thus original IV), but this can + // also happen due to non-unit strides on the candidate IVs. + if (!AddRec->hasNoSelfWrap()) + continue; + + const SCEVAddRecExpr *PostInc = AddRec->getPostIncExpr(SE); + const SCEV *TermValueSLocal = PostInc->evaluateAtIteration(BECount, SE); + if (!Expander.isSafeToExpand(TermValueSLocal)) { + LLVM_DEBUG( + dbgs() << "Is not safe to expand terminating value for phi node" << PN + << "\n"); + continue; } + + // The candidate IV may have been otherwise dead and poison from the + // very first iteration. If we can't disprove that, we can't use the IV. + if (!mustExecuteUBIfPoisonOnPathTo(&PN, LoopLatch->getTerminator(), &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety for IV " + << PN << "\n"); + continue; + } + + // The candidate IV may become poison on the last iteration. If this + // value is not branched on, this is a well defined program. We're + // about to add a new use to this IV, and we have to ensure we don't + // insert UB which didn't previously exist. + bool MustDropPoisonLocal = false; + Instruction *PostIncV = + cast<Instruction>(PN.getIncomingValueForBlock(LoopLatch)); + if (!mustExecuteUBIfPoisonOnPathTo(PostIncV, LoopLatch->getTerminator(), + &DT)) { + LLVM_DEBUG(dbgs() << "Can not prove poison safety to insert use" + << PN << "\n"); + + // If this is a complex recurrance with multiple instructions computing + // the backedge value, we might need to strip poison flags from all of + // them. + if (PostIncV->getOperand(0) != &PN) + continue; + + // In order to perform the transform, we need to drop the poison generating + // flags on this instruction (if any). + MustDropPoisonLocal = PostIncV->hasPoisonGeneratingFlags(); + } + + // We pick the last legal alternate IV. 
We could expore choosing an optimal + // alternate IV if we had a decent heuristic to do so. + ToHelpFold = &PN; + TermValueS = TermValueSLocal; + MustDropPoison = MustDropPoisonLocal; } LLVM_DEBUG(if (ToFold && !ToHelpFold) dbgs() @@ -6757,7 +6888,7 @@ canFoldTermCondOfLoop(Loop *L, ScalarEvolution &SE, DominatorTree &DT, if (!ToFold || !ToHelpFold) return std::nullopt; - return std::make_tuple(ToFold, ToHelpFold, TermValueS); + return std::make_tuple(ToFold, ToHelpFold, TermValueS, MustDropPoison); } static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, @@ -6820,7 +6951,7 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, if (AllowTerminatingConditionFoldingAfterLSR) { if (auto Opt = canFoldTermCondOfLoop(L, SE, DT, LI)) { - auto [ToFold, ToHelpFold, TermValueS] = *Opt; + auto [ToFold, ToHelpFold, TermValueS, MustDrop] = *Opt; Changed = true; NumTermFold++; @@ -6838,6 +6969,10 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, (void)StartValue; Value *LoopValue = ToHelpFold->getIncomingValueForBlock(LoopLatch); + // See comment in canFoldTermCondOfLoop on why this is sufficient. + if (MustDrop) + cast<Instruction>(LoopValue)->dropPoisonGeneratingFlags(); + // SCEVExpander for both use in preheader and latch const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); SCEVExpander Expander(SE, DL, "lsr_fold_term_cond"); @@ -6859,11 +6994,12 @@ static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE, BranchInst *BI = cast<BranchInst>(LoopLatch->getTerminator()); ICmpInst *OldTermCond = cast<ICmpInst>(BI->getCondition()); IRBuilder<> LatchBuilder(LoopLatch->getTerminator()); - // FIXME: We are adding a use of an IV here without account for poison safety. - // This is incorrect. - Value *NewTermCond = LatchBuilder.CreateICmp( - OldTermCond->getPredicate(), LoopValue, TermValue, - "lsr_fold_term_cond.replaced_term_cond"); + Value *NewTermCond = + LatchBuilder.CreateICmp(CmpInst::ICMP_EQ, LoopValue, TermValue, + "lsr_fold_term_cond.replaced_term_cond"); + // Swap successors to exit loop body if IV equals to new TermValue + if (BI->getSuccessor(0) == L->getHeader()) + BI->swapSuccessors(); LLVM_DEBUG(dbgs() << "Old term-cond:\n" << *OldTermCond << "\n" |
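The largest change above rewrites the terminating-condition fold: the ad-hoc use-chain walk is replaced by a matchSimpleRecurrence pattern match on the latch compare, explicit no-self-wrap and poison-safety checks are added, and the exit test is re-emitted as an equality compare of an alternate IV against its value on the final iteration. As a rough illustration of the loop shape this targets — a hand-written sketch with assumed names, not code or tests from the patch — the fold lets the exit branch retire the integer counter and test another induction variable instead:

// Minimal sketch (assumed example, not taken from the patch) of the loop
// shape the terminating-condition fold targets.
#include <stddef.h>

void zero_fill(char *p, size_t n) {
  // Before the fold: the latch exit test uses the primary integer IV `i`.
  for (size_t i = 0; i != n; ++i)
    p[i] = 0;

  // Hand-written equivalent of the shape after the fold: the exit test is an
  // equality compare of the pointer IV against its value after the last
  // iteration (p + n), so `i` becomes dead and can be deleted.
  for (char *q = p, *end = p + n; q != end; ++q)
    *q = 0;
}

Per the comments added in canFoldTermCondOfLoop above, the new code only performs this rewrite when the candidate alternate IV has the no-self-wrap property and poison safety can be established (dropping poison-generating flags on the post-increment where needed).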