Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--  lib/Analysis/ScalarEvolution.cpp | 1428
1 file changed, 966 insertions(+), 462 deletions(-)
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index f34549ae52b4..aa95ace93014 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -83,6 +83,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/CFG.h"
@@ -420,24 +421,21 @@ SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
 SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
                                    const SCEV *op, Type *ty)
     : SCEVCastExpr(ID, scTruncate, op, ty) {
-  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
          "Cannot truncate non-integer value!");
 }
 
 SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
                                        const SCEV *op, Type *ty)
     : SCEVCastExpr(ID, scZeroExtend, op, ty) {
-  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot zero extend non-integer value!");
 }
 
 SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
                                        const SCEV *op, Type *ty)
     : SCEVCastExpr(ID, scSignExtend, op, ty) {
-  assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot sign extend non-integer value!");
 }
 
@@ -1255,42 +1253,32 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
   if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
     return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
 
-  // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
-  // eliminate all the truncates, or we replace other casts with truncates.
-  if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
+  // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
+  // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
+  // if after transforming we have at most one truncate, not counting truncates
+  // that replace other casts.
+  if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
+    auto *CommOp = cast<SCEVCommutativeExpr>(Op);
     SmallVector<const SCEV *, 4> Operands;
-    bool hasTrunc = false;
-    for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
-      const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
-      if (!isa<SCEVCastExpr>(SA->getOperand(i)))
-        hasTrunc = isa<SCEVTruncateExpr>(S);
+    unsigned numTruncs = 0;
+    for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
+         ++i) {
+      const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty);
+      if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
+        numTruncs++;
       Operands.push_back(S);
     }
-    if (!hasTrunc)
-      return getAddExpr(Operands);
-    // In spite we checked in the beginning that ID is not in the cache,
-    // it is possible that during recursion and different modification
-    // ID came to cache, so if we found it, just return it.
-    if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
-      return S;
-  }
-
-  // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
-  // eliminate all the truncates, or we replace other casts with truncates.
-  if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
-    SmallVector<const SCEV *, 4> Operands;
-    bool hasTrunc = false;
-    for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
-      const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
-      if (!isa<SCEVCastExpr>(SM->getOperand(i)))
-        hasTrunc = isa<SCEVTruncateExpr>(S);
-      Operands.push_back(S);
+    if (numTruncs < 2) {
+      if (isa<SCEVAddExpr>(Op))
+        return getAddExpr(Operands);
+      else if (isa<SCEVMulExpr>(Op))
+        return getMulExpr(Operands);
+      else
+        llvm_unreachable("Unexpected SCEV type for Op.");
     }
-    if (!hasTrunc)
-      return getMulExpr(Operands);
-    // In spite we checked in the beginning that ID is not in the cache,
-    // it is possible that during recursion and different modification
-    // ID came to cache, so if we found it, just return it.
+    // Although we checked in the beginning that ID is not in the cache, it is
+    // possible that during recursion and other modifications ID was inserted
+    // into the cache. So if we find it, just return it.
     if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
      return S;
   }
 
@@ -1571,6 +1559,43 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
   return false;
 }
 
+// Finds an integer D for an expression (C + x + y + ...) such that the top
+// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
+// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
+// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
+// the (C + x + y + ...) expression is \p WholeAddExpr.
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+                                            const SCEVConstant *ConstantTerm,
+                                            const SCEVAddExpr *WholeAddExpr) {
+  const APInt C = ConstantTerm->getAPInt();
+  const unsigned BitWidth = C.getBitWidth();
+  // Find number of trailing zeros of (x + y + ...) w/o the C first:
+  uint32_t TZ = BitWidth;
+  for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
+    TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
+  if (TZ) {
+    // Set D to be as many least significant bits of C as possible while still
+    // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
+    return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
+  }
+  return APInt(BitWidth, 0);
+}
+
+// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
+// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
+// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
+// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+                                            const APInt &ConstantStart,
+                                            const SCEV *Step) {
+  const unsigned BitWidth = ConstantStart.getBitWidth();
+  const uint32_t TZ = SE.GetMinTrailingZeros(Step);
+  if (TZ)
+    return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
+                         : ConstantStart;
+  return APInt(BitWidth, 0);
+}
+
 const SCEV *
 ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -1727,9 +1752,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
             const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
                                         getUnsignedRangeMax(Step));
             if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
-                (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
-                 isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
-                                             AR->getPostIncExpr(*this), N))) {
+                isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
               // Cache knowledge of AR NUW, which is propagated to this
               // AddRec.
               const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
@@ -1744,9 +1767,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
             const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
                                         getSignedRangeMin(Step));
             if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
-                (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
-                 isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
-                                             AR->getPostIncExpr(*this), N))) {
+                isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
               // Cache knowledge of AR NW, which is propagated to this
               // AddRec.  Negative step causes unsigned wrap, but it
               // still can't self-wrap.
@@ -1761,6 +1782,23 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
       }
     }
 
+    // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
+    // if D + (C - D + Step * n) could be proven to not unsigned wrap
+    // where D maximizes the number of trailing zeros of (C - D + Step * n)
+    if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+      const APInt &C = SC->getAPInt();
+      const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+      if (D != 0) {
+        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+        const SCEV *SResidual =
+            getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+        return getAddExpr(SZExtD, SZExtR,
+                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+                          Depth + 1);
+      }
+    }
+
     if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
       const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
       return getAddRecExpr(
@@ -1769,6 +1807,20 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     }
   }
 
+  // zext(A % B) --> zext(A) % zext(B)
+  {
+    const SCEV *LHS;
+    const SCEV *RHS;
+    if (matchURem(Op, LHS, RHS))
+      return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
+                         getZeroExtendExpr(RHS, Ty, Depth + 1));
+  }
+
+  // zext(A / B) --> zext(A) / zext(B).
+  if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
+    return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
+                       getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
+
   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
     // zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
     if (SA->hasNoUnsignedWrap()) {
@@ -1779,6 +1831,65 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
       Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
       return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
     }
+
+    // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
+    // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
+    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+    //
+    // Address arithmetic often contains expressions like
+    // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
+    // This transformation is useful while proving that such expressions are
+    // equal or differ by a small constant amount; see the LoadStoreVectorizer
+    // pass.
+    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+      if (D != 0) {
+        const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+        const SCEV *SResidual =
+            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+        const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+        return getAddExpr(SZExtD, SZExtR,
+                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+                          Depth + 1);
+      }
+    }
+  }
+
+  if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
+    // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
+    if (SM->hasNoUnsignedWrap()) {
+      // If the multiply does not unsign overflow then we can, by definition,
+      // commute the zero extension with the multiply operation.
+      SmallVector<const SCEV *, 4> Ops;
+      for (const auto *Op : SM->operands())
+        Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
+      return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
+    }
+
+    // zext(2^K * (trunc X to iN)) to iM ->
+    // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
+    //
+    //  Proof:
+    //
+    //     zext(2^K * (trunc X to iN)) to iM
+    //   = zext((trunc X to iN) << K) to iM
+    //   = zext((trunc X to i{N-K}) << K)<nuw> to iM
+    //     (because shl removes the top K bits)
+    //   = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
+    //   = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
+    //
+    if (SM->getNumOperands() == 2)
+      if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+        if (MulLHS->getAPInt().isPowerOf2())
+          if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
+            int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
+                               MulLHS->getAPInt().logBase2();
+            Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+            return getMulExpr(
+                getZeroExtendExpr(MulLHS, Ty),
+                getZeroExtendExpr(
+                    getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
+                SCEV::FlagNUW, Depth + 1);
+          }
   }
 
   // The cast wasn't folded; create an explicit cast node.
@@ -1842,24 +1953,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     return getTruncateOrSignExtend(X, Ty);
   }
 
-  // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
   if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
-    if (SA->getNumOperands() == 2) {
-      auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
-      auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
-      if (SMul && SC1) {
-        if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
-          const APInt &C1 = SC1->getAPInt();
-          const APInt &C2 = SC2->getAPInt();
-          if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
-              C2.ugt(C1) && C2.isPowerOf2())
-            return getAddExpr(getSignExtendExpr(SC1, Ty, Depth + 1),
-                              getSignExtendExpr(SMul, Ty, Depth + 1),
-                              SCEV::FlagAnyWrap, Depth + 1);
-        }
-      }
-    }
-
     // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
     if (SA->hasNoSignedWrap()) {
       // If the addition does not sign overflow then we can, by definition,
@@ -1869,6 +1963,28 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
       Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
       return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
     }
+
+    // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
+    // if D + (C - D + x + y + ...) could be proven to not signed wrap
+    // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+    //
+    // For instance, this will bring two seemingly different expressions:
+    //     1 + sext(5 + 20 * %x + 24 * %y)  and
+    //         sext(6 + 20 * %x + 24 * %y)
+    // to the same form:
+    //     2 + sext(4 + 20 * %x + 24 * %y)
+    if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+      const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+      if (D != 0) {
+        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+        const SCEV *SResidual =
+            getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+        return getAddExpr(SSExtD, SSExtR,
+                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+                          Depth + 1);
+      }
+    }
   }
   // If the input value is a chrec scev, and we can prove that the value
   // did not overflow the old, smaller, value, we can sign extend all of the
@@ -1989,9 +2105,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
             getSignedOverflowLimitForStep(Step, &Pred, this);
         if (OverflowLimit &&
             (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
-             (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
-              isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
-                                          OverflowLimit)))) {
+             isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
           // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
           const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
           return getAddRecExpr(
@@ -2000,21 +2114,20 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
       }
     }
 
-    // If Start and Step are constants, check if we can apply this
-    // transformation:
-    // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
-    auto *SC1 = dyn_cast<SCEVConstant>(Start);
-    auto *SC2 = dyn_cast<SCEVConstant>(Step);
-    if (SC1 && SC2) {
-      const APInt &C1 = SC1->getAPInt();
-      const APInt &C2 = SC2->getAPInt();
-      if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
-          C2.isPowerOf2()) {
-        Start = getSignExtendExpr(Start, Ty, Depth + 1);
-        const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
-                                          AR->getNoWrapFlags());
-        return getAddExpr(Start, getSignExtendExpr(NewAR, Ty, Depth + 1),
-                          SCEV::FlagAnyWrap, Depth + 1);
+    // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
+    // if D + (C - D + Step * n) could be proven to not signed wrap
+    // where D maximizes the number of trailing zeros of (C - D + Step * n)
+    if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+      const APInt &C = SC->getAPInt();
+      const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+      if (D != 0) {
+        const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+        const SCEV *SResidual =
+            getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+        const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+        return getAddExpr(SSExtD, SSExtR,
+                          (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+                          Depth + 1);
       }
     }
 
@@ -2210,22 +2323,35 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
 
   SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
 
-  if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
-      Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
+  if (SignOrUnsignWrap != SignOrUnsignMask &&
+      (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
+      isa<SCEVConstant>(Ops[0])) {
 
-    // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
-    // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
+    auto Opcode = [&] {
+      switch (Type) {
+      case scAddExpr:
+        return Instruction::Add;
+      case scMulExpr:
+        return Instruction::Mul;
+      default:
+        llvm_unreachable("Unexpected SCEV op.");
+      }
+    }();
 
     const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
+
+    // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
     if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
       auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
-          Instruction::Add, C, OBO::NoSignedWrap);
+          Opcode, C, OBO::NoSignedWrap);
       if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
     }
+
+    // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
     if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
       auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
-          Instruction::Add, C, OBO::NoUnsignedWrap);
+          Opcode, C, OBO::NoUnsignedWrap);
       if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
         Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
     }
@@ -2235,59 +2361,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
 }
 
 bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
-  if (!isLoopInvariant(S, L))
-    return false;
-  // If a value depends on a SCEVUnknown which is defined after the loop, we
-  // conservatively assume that we cannot calculate it at the loop's entry.
-  struct FindDominatedSCEVUnknown {
-    bool Found = false;
-    const Loop *L;
-    DominatorTree &DT;
-    LoopInfo &LI;
-
-    FindDominatedSCEVUnknown(const Loop *L, DominatorTree &DT, LoopInfo &LI)
-        : L(L), DT(DT), LI(LI) {}
-
-    bool checkSCEVUnknown(const SCEVUnknown *SU) {
-      if (auto *I = dyn_cast<Instruction>(SU->getValue())) {
-        if (DT.dominates(L->getHeader(), I->getParent()))
-          Found = true;
-        else
-          assert(DT.dominates(I->getParent(), L->getHeader()) &&
-                 "No dominance relationship between SCEV and loop?");
-      }
-      return false;
-    }
-
-    bool follow(const SCEV *S) {
-      switch (static_cast<SCEVTypes>(S->getSCEVType())) {
-      case scConstant:
-        return false;
-      case scAddRecExpr:
-      case scTruncate:
-      case scZeroExtend:
-      case scSignExtend:
-      case scAddExpr:
-      case scMulExpr:
-      case scUMaxExpr:
-      case scSMaxExpr:
-      case scUDivExpr:
-        return true;
-      case scUnknown:
-        return checkSCEVUnknown(cast<SCEVUnknown>(S));
-      case scCouldNotCompute:
-        llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
-      }
-      return false;
-    }
-
-    bool isDone() { return Found; }
-  };
-
-  FindDominatedSCEVUnknown FSU(L, DT, LI);
-  SCEVTraversal<FindDominatedSCEVUnknown> ST(FSU);
-  ST.visitAll(S);
-  return !FSU.Found;
+  return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
 }
 
 /// Get a canonical add expression, or something simpler if possible.
@@ -2358,7 +2432,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       FoundMatch = true;
     }
   if (FoundMatch)
-    return getAddExpr(Ops, Flags);
+    return getAddExpr(Ops, Flags, Depth + 1);
 
   // Check for truncates. If all the operands are truncated from the same
   // type, see if factoring out the truncate would permit the result to be
@@ -2418,7 +2492,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       }
       if (Ok) {
         // Evaluate the expression in the larger type.
-        const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
+        const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
         // If it folds to something simple, use it. Otherwise, don't.
         if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
           return getTruncateExpr(Fold, Ty);
@@ -2796,22 +2870,21 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   unsigned Idx = 0;
   if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
-    // C1*(C2+V) -> C1*C2 + C1*V
     if (Ops.size() == 2)
-      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
-        // If any of Add's ops are Adds or Muls with a constant,
-        // apply this transformation as well.
-        if (Add->getNumOperands() == 2)
-          // TODO: There are some cases where this transformation is not
-          // profitable, for example:
-          // Add = (C0 + X) * Y + Z.
-          // Maybe the scope of this transformation should be narrowed down.
-          if (containsConstantInAddMulChain(Add))
-            return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
-                                         SCEV::FlagAnyWrap, Depth + 1),
-                              getMulExpr(LHSC, Add->getOperand(1),
-                                         SCEV::FlagAnyWrap, Depth + 1),
-                              SCEV::FlagAnyWrap, Depth + 1);
+      // C1*(C2+V) -> C1*C2 + C1*V
+      if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+        // If any of Add's ops are Adds or Muls with a constant, apply this
+        // transformation as well.
+        //
+        // TODO: There are some cases where this transformation is not
+        // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+        // this transformation should be narrowed down.
+        if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
+          return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
+                                       SCEV::FlagAnyWrap, Depth + 1),
+                            getMulExpr(LHSC, Add->getOperand(1),
+                                       SCEV::FlagAnyWrap, Depth + 1),
+                            SCEV::FlagAnyWrap, Depth + 1);
 
     ++Idx;
     while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
@@ -3123,6 +3196,21 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
         }
       }
     }
+
+    // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
+    if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
+      if (auto *DivisorConstant =
+              dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
+        bool Overflow = false;
+        APInt NewRHS =
+            DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
+        if (Overflow) {
+          return getConstant(RHSC->getType(), 0, false);
+        }
+        return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
+      }
+    }
+
     // (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
     if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
       SmallVector<const SCEV *, 4> Operands;
@@ -3574,12 +3662,13 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
     //  X umax Y umax Y  -->  X umax Y
    //  X umax Y         -->  X, if X is always greater than Y
-    if (Ops[i] == Ops[i+1] ||
-        isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
-      Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
+    if (Ops[i] == Ops[i + 1] || isKnownViaNonRecursiveReasoning(
+                                    ICmpInst::ICMP_UGE, Ops[i], Ops[i + 1])) {
+      Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
       --i; --e;
-    } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
-      Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
+    } else if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, Ops[i],
+                                               Ops[i + 1])) {
+      Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
       --i; --e;
     }
 
@@ -3606,14 +3695,35 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
 
 const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~smax(~x, ~y) == smin(x, y).
-  return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getSMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  // ~smax(~x, ~y, ~z) == smin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getSMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
                                          const SCEV *RHS) {
-  // ~umax(~x, ~y) == umin(x, y)
-  return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+  assert(!Ops.empty() && "At least one operand must be!");
+  // Trivial case.
+  if (Ops.size() == 1)
+    return Ops[0];
+
+  // ~umax(~x, ~y, ~z) == umin(x, y, z).
+  SmallVector<const SCEV *, 2> NotOps;
+  for (auto *S : Ops)
+    NotOps.push_back(getNotSCEV(S));
+  return getNotSCEV(getUMaxExpr(NotOps));
 }
 
 const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
@@ -3665,13 +3775,15 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
 /// target-specific information.
 bool ScalarEvolution::isSCEVable(Type *Ty) const {
   // Integers and pointers are always SCEVable.
-  return Ty->isIntegerTy() || Ty->isPointerTy();
+  return Ty->isIntOrPtrTy();
 }
 
 /// Return the size in bits of the specified type, for which isSCEVable must
 /// return true.
 uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
   assert(isSCEVable(Ty) && "Type is not SCEVable!");
+  if (Ty->isPointerTy())
+    return getDataLayout().getIndexTypeSizeInBits(Ty);
   return getDataLayout().getTypeSizeInBits(Ty);
 }
 
@@ -3774,6 +3886,24 @@ void ScalarEvolution::eraseValueFromMap(Value *V) {
   }
 }
 
+/// Check whether value has nuw/nsw/exact set but SCEV does not.
+/// TODO: In reality it is better to check the poison recursively
+/// but this is better than nothing.
+static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
+  if (auto *I = dyn_cast<Instruction>(V)) {
+    if (isa<OverflowingBinaryOperator>(I)) {
+      if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
+        if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
+          return true;
+        if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
+          return true;
+      }
+    } else if (isa<PossiblyExactOperator>(I) && I->isExact())
+      return true;
+  }
+  return false;
+}
+
 /// Return an existing SCEV if it exists, otherwise analyze the expression and
 /// create a new one.
 const SCEV *ScalarEvolution::getSCEV(Value *V) {
@@ -3787,7 +3917,7 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) {
     // ValueExprMap before insert S->{V, 0} into ExprValueMap.
     std::pair<ValueExprMapType::iterator, bool> Pair =
         ValueExprMap.insert({SCEVCallbackVH(V, this), S});
-    if (Pair.second) {
+    if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
       ExprValueMap[S].insert({V, nullptr});
 
       // If S == Stripped + Offset, add Stripped -> {V, Offset} into
@@ -3890,8 +4020,7 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
 const SCEV *
 ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
         "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -3904,8 +4033,7 @@ const SCEV *
 ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
                                          Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
        "Cannot truncate or zero extend with non-integer arguments!");
   if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
     return V;  // No conversion
@@ -3917,8 +4045,7 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
 const SCEV *
 ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
        "Cannot noop or zero extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrZeroExtend cannot truncate!");
@@ -3930,8 +4057,7 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
 const SCEV *
 ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
        "Cannot noop or sign extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrSignExtend cannot truncate!");
@@ -3943,8 +4069,7 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
 const SCEV *
 ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
        "Cannot noop or any extend with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
          "getNoopOrAnyExtend cannot truncate!");
@@ -3956,8 +4081,7 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
 const SCEV *
 ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
   Type *SrcTy = V->getType();
-  assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
-         (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+  assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
        "Cannot truncate or noop with non-integer arguments!");
   assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
          "getTruncateOrNoop cannot extend!");
@@ -3981,15 +4105,32 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
 
 const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
                                                         const SCEV *RHS) {
-  const SCEV *PromotedLHS = LHS;
-  const SCEV *PromotedRHS = RHS;
+  SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+  return getUMinFromMismatchedTypes(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
+    SmallVectorImpl<const SCEV *> &Ops) {
+  assert(!Ops.empty() && "At least one operand must be!");
+  // Trivial case.
+  if (Ops.size() == 1)
+    return Ops[0];
+
+  // Find the max type first.
+  Type *MaxType = nullptr;
+  for (auto *S : Ops)
+    if (MaxType)
+      MaxType = getWiderType(MaxType, S->getType());
+    else
+      MaxType = S->getType();
 
-  if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
-    PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
-  else
-    PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+  // Extend all ops to max type.
+  SmallVector<const SCEV *, 2> PromotedOps;
+  for (auto *S : Ops)
+    PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
 
-  return getUMinExpr(PromotedLHS, PromotedRHS);
+  // Generate umin.
+  return getUMinExpr(PromotedOps);
 }
 
 const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
@@ -4066,37 +4207,90 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
 
 namespace {
 
+/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its start
+/// expression if its loop is L. If its loop is not L and IgnoreOtherLoops is
+/// true, use the AddRec itself; otherwise the rewrite cannot be done.
+/// If the SCEV contains a non-invariant unknown SCEV, the rewrite cannot be
+/// done either.
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
-  static const SCEV *rewrite(const SCEV *S, const Loop *L,
-                             ScalarEvolution &SE) {
+  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+                             bool IgnoreOtherLoops = true) {
     SCEVInitRewriter Rewriter(L, SE);
     const SCEV *Result = Rewriter.visit(S);
-    return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+    if (Rewriter.hasSeenLoopVariantSCEVUnknown())
+      return SE.getCouldNotCompute();
+    return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
+               ? SE.getCouldNotCompute()
+               : Result;
  }
 
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     if (!SE.isLoopInvariant(Expr, L))
-      Valid = false;
+      SeenLoopVariantSCEVUnknown = true;
     return Expr;
   }
 
   const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
-    // Only allow AddRecExprs for this loop.
+    // Only re-write AddRecExprs for this loop.
     if (Expr->getLoop() == L)
      return Expr->getStart();
-    Valid = false;
+    SeenOtherLoops = true;
    return Expr;
   }
 
-  bool isValid() { return Valid; }
+  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+  bool hasSeenOtherLoops() { return SeenOtherLoops; }
 
 private:
   explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
      : SCEVRewriteVisitor(SE), L(L) {}
 
   const Loop *L;
-  bool Valid = true;
+  bool SeenLoopVariantSCEVUnknown = false;
+  bool SeenOtherLoops = false;
+};
+
+/// Takes SCEV S and Loop L. For each AddRec sub-expression, use its post
+/// increment expression if its loop is L; if its loop is not L, use the
+/// AddRec itself.
+/// If the SCEV contains a non-invariant unknown SCEV, the rewrite cannot be
+/// done.
+class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
+public:
+  static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
+    SCEVPostIncRewriter Rewriter(L, SE);
+    const SCEV *Result = Rewriter.visit(S);
+    return Rewriter.hasSeenLoopVariantSCEVUnknown()
+               ? SE.getCouldNotCompute()
+               : Result;
+  }
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    if (!SE.isLoopInvariant(Expr, L))
+      SeenLoopVariantSCEVUnknown = true;
+    return Expr;
+  }
+
+  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+    // Only re-write AddRecExprs for this loop.
+    if (Expr->getLoop() == L)
+      return Expr->getPostIncExpr(SE);
+    SeenOtherLoops = true;
+    return Expr;
+  }
+
+  bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+  bool hasSeenOtherLoops() { return SeenOtherLoops; }
+
+private:
+  explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L) {}
+
+  const Loop *L;
+  bool SeenLoopVariantSCEVUnknown = false;
+  bool SeenOtherLoops = false;
 };
 
 /// This class evaluates the compare condition by matching it against the
@@ -4668,7 +4862,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
 
     const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
     if (PredIsKnownFalse(StartVal, StartExtended)) {
-      DEBUG(dbgs() << "P2 is compile-time false\n";);
+      LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
      return None;
     }
 
@@ -4676,7 +4870,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
     // NSSW or NUSW)
     const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
     if (PredIsKnownFalse(Accum, AccumExtended)) {
-      DEBUG(dbgs() << "P3 is compile-time false\n";);
+      LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
       return None;
     }
 
@@ -4685,7 +4879,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
     if (Expr != ExtendedExpr &&
         !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
       const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
-      DEBUG (dbgs() << "Added Predicate: " << *Pred);
+      LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
       Predicates.push_back(Pred);
     }
   };
 
@@ -4948,7 +5142,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
       // by one iteration:
      //   PHI(f(0), f({1,+,1})) --> f({0,+,1})
      const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
-      const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+      const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
      if (Shifted != getCouldNotCompute() &&
          Start != getCouldNotCompute()) {
        const SCEV *StartVal = getSCEV(StartValueV);
@@ -5510,6 +5704,25 @@ ScalarEvolution::getRangeRef(const SCEV *S,
             APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
     }
 
+    // The range of a Phi is a subset of the union of the ranges of its inputs.
+    if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
+      // Make sure that we do not run over cyclic Phis.
+      if (PendingPhiRanges.insert(Phi).second) {
+        ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+        for (auto &Op : Phi->operands()) {
+          auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+          RangeFromOps = RangeFromOps.unionWith(OpRange);
+          // No point in continuing if we already have a full set.
+          if (RangeFromOps.isFullSet())
+            break;
+        }
+        ConservativeResult = ConservativeResult.intersectWith(RangeFromOps);
+        bool Erased = PendingPhiRanges.erase(Phi);
+        assert(Erased && "Failed to erase Phi properly?");
+        (void) Erased;
+      }
+    }
+
     return setRange(U, SignHint, std::move(ConservativeResult));
   }
 
@@ -6129,33 +6342,33 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
      }
      break;
 
-  case Instruction::Shl:
-    // Turn shift left of a constant amount into a multiply.
-    if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
-      uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
-
-      // If the shift count is not less than the bitwidth, the result of
-      // the shift is undefined. Don't try to analyze it, because the
-      // resolution chosen here may differ from the resolution chosen in
-      // other parts of the compiler.
-      if (SA->getValue().uge(BitWidth))
-        break;
+    case Instruction::Shl:
+      // Turn shift left of a constant amount into a multiply.
+      if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+        uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
 
-      // It is currently not resolved how to interpret NSW for left
-      // shift by BitWidth - 1, so we avoid applying flags in that
-      // case. Remove this check (or this comment) once the situation
-      // is resolved. See
-      // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
-      // and http://reviews.llvm.org/D8890 .
-      auto Flags = SCEV::FlagAnyWrap;
-      if (BO->Op && SA->getValue().ult(BitWidth - 1))
-        Flags = getNoWrapFlagsFromUB(BO->Op);
+        // If the shift count is not less than the bitwidth, the result of
+        // the shift is undefined. Don't try to analyze it, because the
+        // resolution chosen here may differ from the resolution chosen in
+        // other parts of the compiler.
+        if (SA->getValue().uge(BitWidth))
+          break;
 
-      Constant *X = ConstantInt::get(getContext(),
-        APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
-      return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
-    }
-    break;
+        // It is currently not resolved how to interpret NSW for left
+        // shift by BitWidth - 1, so we avoid applying flags in that
+        // case. Remove this check (or this comment) once the situation
+        // is resolved. See
+        // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
+        // and http://reviews.llvm.org/D8890 .
+        auto Flags = SCEV::FlagAnyWrap;
+        if (BO->Op && SA->getValue().ult(BitWidth - 1))
+          Flags = getNoWrapFlagsFromUB(BO->Op);
+
+        Constant *X = ConstantInt::get(
+            getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+        return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
+      }
+      break;
 
     case Instruction::AShr: {
      // AShr X, C, where C is a constant.
@@ -6379,11 +6592,11 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
 const SCEV *
 ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
                                                  SCEVUnionPredicate &Preds) {
-  return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds);
+  return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
 }
 
 const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
-  return getBackedgeTakenInfo(L).getExact(this);
+  return getBackedgeTakenInfo(L).getExact(L, this);
 }
 
 /// Similar to getBackedgeTakenCount, except return the least SCEV value that is
@@ -6402,9 +6615,8 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
   BasicBlock *Header = L->getHeader();
 
   // Push all Loop-header PHIs onto the Worklist stack.
-  for (BasicBlock::iterator I = Header->begin();
-       PHINode *PN = dyn_cast<PHINode>(I); ++I)
-    Worklist.push_back(PN);
+  for (PHINode &PN : Header->phis())
+    Worklist.push_back(&PN);
 }
 
 const ScalarEvolution::BackedgeTakenInfo &
@@ -6441,8 +6653,13 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
   // must be cleared in this scope.
   BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
 
-  if (Result.getExact(this) != getCouldNotCompute()) {
-    assert(isLoopInvariant(Result.getExact(this), L) &&
+  // In a product build, there is no usage of statistics.
+  (void)NumTripCountsComputed;
+  (void)NumTripCountsNotComputed;
+#if LLVM_ENABLE_STATS || !defined(NDEBUG)
+  const SCEV *BEExact = Result.getExact(L, this);
+  if (BEExact != getCouldNotCompute()) {
+    assert(isLoopInvariant(BEExact, L) &&
           isLoopInvariant(Result.getMax(this), L) &&
           "Computed backedge-taken count isn't loop invariant for loop!");
    ++NumTripCountsComputed;
@@ -6452,6 +6669,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
     // Only count loops that have phi nodes as not being computable.
     ++NumTripCountsNotComputed;
   }
+#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
 
   // Now that we know more about the trip count for this loop, forget any
   // existing SCEV values for PHI nodes in this loop since they are only
@@ -6587,6 +6805,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
   }
 }
 
+void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
+  while (Loop *Parent = L->getParentLoop())
+    L = Parent;
+  forgetLoop(L);
+}
+
 void ScalarEvolution::forgetValue(Value *V) {
   Instruction *I = dyn_cast<Instruction>(V);
   if (!I) return;
@@ -6615,28 +6839,35 @@ void ScalarEvolution::forgetValue(Value *V) {
 }
 
 /// Get the exact loop backedge taken count considering all loop exits. A
-/// computable result can only be returned for loops with a single exit.
-/// Returning the minimum taken count among all exits is incorrect because one
-/// of the loop's exit limit's may have been skipped. howFarToZero assumes that
-/// the limit of each loop test is never skipped. This is a valid assumption as
-/// long as the loop exits via that test. For precise results, it is the
-/// caller's responsibility to specify the relevant loop exit using
-/// getExact(ExitingBlock, SE).
+/// computable result can only be returned for loops with all exiting blocks
+/// dominating the latch. howFarToZero assumes that the limit of each loop test
+/// is never skipped. This is a valid assumption as long as the loop exits via
+/// that test. For precise results, it is the caller's responsibility to specify
+/// the relevant loop exiting block using getExact(ExitingBlock, SE).
 const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
+ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
                                              SCEVUnionPredicate *Preds) const {
   // If any exits were not computable, the loop is not computable.
   if (!isComplete() || ExitNotTaken.empty())
     return SE->getCouldNotCompute();
 
-  const SCEV *BECount = nullptr;
+  const BasicBlock *Latch = L->getLoopLatch();
+  // All exiting blocks we have collected must dominate the only backedge.
+  if (!Latch)
+    return SE->getCouldNotCompute();
+
+  // All exiting blocks we have gathered dominate the loop's latch, so the
+  // exact trip count is simply the minimum of all these calculated exit
+  // counts.
+  SmallVector<const SCEV *, 2> Ops;
   for (auto &ENT : ExitNotTaken) {
-    assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
+    const SCEV *BECount = ENT.ExactNotTaken;
+    assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
+    assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
+           "We should only have known counts for exiting blocks that dominate "
+           "latch!");
+
+    Ops.push_back(BECount);
 
-    if (!BECount)
-      BECount = ENT.ExactNotTaken;
-    else if (BECount != ENT.ExactNotTaken)
-      return SE->getCouldNotCompute();
     if (Preds && !ENT.hasAlwaysTruePredicate())
       Preds->add(ENT.Predicate.get());
 
@@ -6644,8 +6875,7 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
            "Predicate should be always true!");
   }
 
-  assert(BECount && "Invalid not taken count for loop exit");
-  return BECount;
+  return SE->getUMinFromMismatchedTypes(Ops);
 }
 
 /// Get the exact not taken count for this loop exit.
@@ -6842,99 +7072,60 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
                                   bool AllowPredicates) {
-  // Okay, we've chosen an exiting block.  See what condition causes us to exit
-  // at this block and remember the exit block and whether all other targets
-  // lead to the loop header.
-  bool MustExecuteLoopHeader = true;
-  BasicBlock *Exit = nullptr;
-  for (auto *SBB : successors(ExitingBlock))
-    if (!L->contains(SBB)) {
-      if (Exit) // Multiple exit successors.
-        return getCouldNotCompute();
-      Exit = SBB;
-    } else if (SBB != L->getHeader()) {
-      MustExecuteLoopHeader = false;
-    }
-
-  // At this point, we know we have a conditional branch that determines whether
-  // the loop is exited.  However, we don't know if the branch is executed each
-  // time through the loop.  If not, then the execution count of the branch will
-  // not be equal to the trip count of the loop.
-  //
-  // Currently we check for this by checking to see if the Exit branch goes to
-  // the loop header.  If so, we know it will always execute the same number of
-  // times as the loop.  We also handle the case where the exit block *is* the
-  // loop header.  This is common for un-rotated loops.
-  //
-  // If both of those tests fail, walk up the unique predecessor chain to the
-  // header, stopping if there is an edge that doesn't exit the loop. If the
-  // header is reached, the execution count of the branch will be equal to the
-  // trip count of the loop.
-  //
-  //  More extensive analysis could be done to handle more cases here.
-  //
-  if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) {
-    // The simple checks failed, try climbing the unique predecessor chain
-    // up to the header.
-    bool Ok = false;
-    for (BasicBlock *BB = ExitingBlock; BB; ) {
-      BasicBlock *Pred = BB->getUniquePredecessor();
-      if (!Pred)
-        return getCouldNotCompute();
-      TerminatorInst *PredTerm = Pred->getTerminator();
-      for (const BasicBlock *PredSucc : PredTerm->successors()) {
-        if (PredSucc == BB)
-          continue;
-        // If the predecessor has a successor that isn't BB and isn't
-        // outside the loop, assume the worst.
-        if (L->contains(PredSucc))
-          return getCouldNotCompute();
-      }
-      if (Pred == L->getHeader()) {
-        Ok = true;
-        break;
-      }
-      BB = Pred;
-    }
-    if (!Ok)
-      return getCouldNotCompute();
-  }
+  assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
+  // If our exiting block does not dominate the latch, then its connection
+  // with the loop's exit limit may be far from trivial.
+  const BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch || !DT.dominates(ExitingBlock, Latch))
+    return getCouldNotCompute();
 
   bool IsOnlyExit = (L->getExitingBlock() != nullptr);
   TerminatorInst *Term = ExitingBlock->getTerminator();
   if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
     assert(BI->isConditional() && "If unconditional, it can't be in loop!");
+    bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
+    assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
+           "It should have one successor in loop and one exit block!");
     // Proceed to the next level to examine the exit condition expression.
     return computeExitLimitFromCond(
-        L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1),
+        L, BI->getCondition(), ExitIfTrue,
         /*ControlsExit=*/IsOnlyExit, AllowPredicates);
   }
 
-  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
+  if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
+    // For switch, make sure that there is a single exit from the loop.
+    BasicBlock *Exit = nullptr;
+    for (auto *SBB : successors(ExitingBlock))
+      if (!L->contains(SBB)) {
+        if (Exit) // Multiple exit successors.
+          return getCouldNotCompute();
+        Exit = SBB;
+      }
+    assert(Exit && "Exiting block must have at least one exit");
     return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
                                                 /*ControlsExit=*/IsOnlyExit);
+  }
 
   return getCouldNotCompute();
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
-    const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB,
+    const Loop *L, Value *ExitCond, bool ExitIfTrue,
     bool ControlsExit, bool AllowPredicates) {
-  ScalarEvolution::ExitLimitCacheTy Cache(L, TBB, FBB, AllowPredicates);
-  return computeExitLimitFromCondCached(Cache, L, ExitCond, TBB, FBB,
+  ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
+  return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
                                         ControlsExit, AllowPredicates);
 }
 
 Optional<ScalarEvolution::ExitLimit>
 ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
-                                      BasicBlock *TBB, BasicBlock *FBB,
-                                      bool ControlsExit, bool AllowPredicates) {
+                                      bool ExitIfTrue, bool ControlsExit,
+                                      bool AllowPredicates) {
   (void)this->L;
-  (void)this->TBB;
-  (void)this->FBB;
+  (void)this->ExitIfTrue;
   (void)this->AllowPredicates;
 
-  assert(this->L == L && this->TBB == TBB && this->FBB == FBB &&
+  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
          this->AllowPredicates == AllowPredicates &&
         "Variance in assumed invariant key components!");
   auto Itr = TripCountMap.find({ExitCond, ControlsExit});
@@ -6944,47 +7135,48 @@ ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
 }
 
 void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
-                                             BasicBlock *TBB, BasicBlock *FBB,
+                                             bool ExitIfTrue,
                                             bool ControlsExit,
                                             bool AllowPredicates,
                                             const ExitLimit &EL) {
-  assert(this->L == L && this->TBB == TBB && this->FBB == FBB &&
+  assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
        this->AllowPredicates == AllowPredicates &&
        "Variance in assumed invariant key components!");
 
   auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
   assert(InsertResult.second && "Expected successful insertion!");
   (void)InsertResult;
+  (void)ExitIfTrue;
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB,
-    BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) {
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    bool ControlsExit, bool AllowPredicates) {
 
   if (auto MaybeEL =
-          Cache.find(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates))
+          Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
     return *MaybeEL;
 
-  ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, TBB, FBB,
+  ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
                                               ControlsExit, AllowPredicates);
-  Cache.insert(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates, EL);
+  Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
   return EL;
 }
 
 ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
-    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB,
-    BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) {
+    ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+    bool ControlsExit, bool AllowPredicates) {
   // Check if the controlling expression for this loop is an And or Or.
   if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
     if (BO->getOpcode() == Instruction::And) {
       // Recurse on the operands of the and.
-      bool EitherMayExit = L->contains(TBB);
+      bool EitherMayExit = !ExitIfTrue;
       ExitLimit EL0 = computeExitLimitFromCondCached(
-          Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit,
-          AllowPredicates);
+          Cache, L, BO->getOperand(0), ExitIfTrue,
+          ControlsExit && !EitherMayExit, AllowPredicates);
       ExitLimit EL1 = computeExitLimitFromCondCached(
-          Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit,
-          AllowPredicates);
+          Cache, L, BO->getOperand(1), ExitIfTrue,
+          ControlsExit && !EitherMayExit, AllowPredicates);
       const SCEV *BECount = getCouldNotCompute();
       const SCEV *MaxBECount = getCouldNotCompute();
       if (EitherMayExit) {
@@ -7006,7 +7198,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
       } else {
         // Both conditions must be true at the same time for the loop to exit.
         // For now, be conservative.
-        assert(L->contains(FBB) && "Loop block has no successor in loop!");
         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
           MaxBECount = EL0.MaxNotTaken;
         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
@@ -7027,13 +7218,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
     }
     if (BO->getOpcode() == Instruction::Or) {
       // Recurse on the operands of the or.
-      bool EitherMayExit = L->contains(FBB);
+      bool EitherMayExit = ExitIfTrue;
       ExitLimit EL0 = computeExitLimitFromCondCached(
-          Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit,
-          AllowPredicates);
+          Cache, L, BO->getOperand(0), ExitIfTrue,
+          ControlsExit && !EitherMayExit, AllowPredicates);
       ExitLimit EL1 = computeExitLimitFromCondCached(
-          Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit,
-          AllowPredicates);
+          Cache, L, BO->getOperand(1), ExitIfTrue,
+          ControlsExit && !EitherMayExit, AllowPredicates);
       const SCEV *BECount = getCouldNotCompute();
       const SCEV *MaxBECount = getCouldNotCompute();
       if (EitherMayExit) {
@@ -7055,7 +7246,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
       } else {
         // Both conditions must be false at the same time for the loop to exit.
         // For now, be conservative.
-        assert(L->contains(TBB) && "Loop block has no successor in loop!");
         if (EL0.MaxNotTaken == EL1.MaxNotTaken)
           MaxBECount = EL0.MaxNotTaken;
         if (EL0.ExactNotTaken == EL1.ExactNotTaken)
@@ -7071,12 +7261,12 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
 
   // Proceed to the next level to examine the icmp.
   if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
     ExitLimit EL =
-        computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
+        computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
     if (EL.hasFullInfo() || !AllowPredicates)
       return EL;
 
     // Try again, but use SCEV predicates this time.
-    return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit,
+    return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
                                     /*AllowPredicates=*/true);
   }
 
@@ -7085,7 +7275,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
   // preserve the CFG and is temporarily leaving constant conditions
   // in place.
   if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
-    if (L->contains(FBB) == !CI->getZExtValue())
+    if (ExitIfTrue == !CI->getZExtValue())
       // The backedge is always taken.
       return getCouldNotCompute();
     else
@@ -7094,19 +7284,18 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
   }
 
   // If it's not an integer or pointer comparison then compute it the hard way.
-  return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+  return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 }
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
                                           ICmpInst *ExitCond,
-                                          BasicBlock *TBB,
-                                          BasicBlock *FBB,
+                                          bool ExitIfTrue,
                                           bool ControlsExit,
                                           bool AllowPredicates) {
   // If the condition was exit on true, convert the condition to exit on false
   ICmpInst::Predicate Pred;
-  if (!L->contains(FBB))
+  if (!ExitIfTrue)
     Pred = ExitCond->getPredicate();
   else
     Pred = ExitCond->getInversePredicate();
@@ -7188,7 +7377,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
   }
 
   auto *ExhaustiveCount =
-      computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+      computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
 
   if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
     return ExhaustiveCount;
@@ -7638,12 +7827,9 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   if (!Latch)
     return nullptr;
 
-  for (auto &I : *Header) {
-    PHINode *PHI = dyn_cast<PHINode>(&I);
-    if (!PHI) break;
-    auto *StartCST = getOtherIncomingValue(PHI, Latch);
-    if (!StartCST) continue;
-    CurrentIterVals[PHI] = StartCST;
+  for (PHINode &PHI : Header->phis()) {
+    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+      CurrentIterVals[&PHI] = StartCST;
   }
   if (!CurrentIterVals.count(PN))
     return RetVal = nullptr;
@@ -7720,13 +7906,9 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
   BasicBlock *Latch = L->getLoopLatch();
   assert(Latch && "Should follow from NumIncomingValues == 2!");
 
-  for (auto &I : *Header) {
-    PHINode *PHI = dyn_cast<PHINode>(&I);
-    if (!PHI)
-      break;
-    auto *StartCST = getOtherIncomingValue(PHI, Latch);
-    if (!StartCST) continue;
-    CurrentIterVals[PHI] = StartCST;
+  for (PHINode &PHI : Header->phis()) {
+    if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+      CurrentIterVals[&PHI] = StartCST;
   }
   if (!CurrentIterVals.count(PN))
     return getCouldNotCompute();
@@ -8107,6 +8289,14 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
   return getSCEVAtScope(getSCEV(V), L);
 }
 
+const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
+  if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
+    return stripInjectiveFunctions(ZExt->getOperand());
+  if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
+    return stripInjectiveFunctions(SExt->getOperand());
+  return S;
+}
+
 /// Finds the minimum unsigned root of the following equation:
/// Finds the minimum unsigned root of the following equation:
///
///     A * X = B (mod N)
@@ -8236,7 +8426,9 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
       return getCouldNotCompute();  // Otherwise it will loop infinitely.
   }
 
-  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
+  const SCEVAddRecExpr *AddRec =
+      dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
+
   if (!AddRec && AllowPredicates)
     // Try to make this an AddRec using runtime tests, in the first X
     // iterations of this loop, where X is the SCEV expression found by the
@@ -8644,43 +8836,88 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
   return isKnownNegative(S) || isKnownPositive(S);
 }
 
+std::pair<const SCEV *, const SCEV *>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+  // Compute SCEV on entry of loop L.
+  const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+  if (Start == getCouldNotCompute())
+    return { Start, Start };
+  // Compute post increment SCEV for loop L.
+  const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+  assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
+  return { Start, PostInc };
+}
+
+bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
+                                          const SCEV *LHS, const SCEV *RHS) {
+  // First collect all loops.
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  getUsedLoops(LHS, LoopsUsed);
+  getUsedLoops(RHS, LoopsUsed);
+
+  if (LoopsUsed.empty())
+    return false;
+
+  // Domination relationship must be a linear order on collected loops.
+#ifndef NDEBUG
+  for (auto *L1 : LoopsUsed)
+    for (auto *L2 : LoopsUsed)
+      assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
+              DT.dominates(L2->getHeader(), L1->getHeader())) &&
+             "Domination relationship is not a linear order");
+#endif
+
+  const Loop *MDL =
+      *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
+                        [&](const Loop *L1, const Loop *L2) {
+         return DT.properlyDominates(L1->getHeader(), L2->getHeader());
+       });
+
+  // Get init and post increment value for LHS.
+  auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
+  // If LHS contains an unknown non-invariant SCEV, bail out.
+  if (SplitLHS.first == getCouldNotCompute())
+    return false;
+  assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
+  // Get init and post increment value for RHS.
+  auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
+  // If RHS contains an unknown non-invariant SCEV, bail out.
+  if (SplitRHS.first == getCouldNotCompute())
+    return false;
+  assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
+  // It is possible that init SCEV contains an invariant load but it does
+  // not dominate MDL and is not available at MDL loop entry, so we should
+  // check it here.
+  if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
+      !isAvailableAtLoopEntry(SplitRHS.first, MDL))
+    return false;
+
+  return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) &&
+         isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
                                      SplitRHS.second);
+}
+
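A worked example of the split, in assumed SCEV notation: for S = {0,+,1}<%L>, SCEVInitRewriter yields the on-entry value 0, while SCEVPostIncRewriter yields {1,+,1}<%L>, the value just after each backedge. isKnownViaInduction then proves Pred for the init pair at the loop guard and for the post-inc pair along the backedge, which together cover every iteration. A hypothetical call site:

  // Hypothetical usage; SE, MDL and S as in the function above.
  auto Split = SE.SplitIntoInitAndPostInc(MDL, S);
  // Split.first  == SCEV of S on entry of MDL (or CouldNotCompute)
  // Split.second == SCEV of S after each increment of MDL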
 bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
                                        const SCEV *LHS, const SCEV *RHS) {
   // Canonicalize the inputs first.
   (void)SimplifyICmpOperands(Pred, LHS, RHS);
 
-  // If LHS or RHS is an addrec, check to see if the condition is true in
-  // every iteration of the loop.
-  // If LHS and RHS are both addrec, both conditions must be true in
-  // every iteration of the loop.
-  const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
-  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
-  bool LeftGuarded = false;
-  bool RightGuarded = false;
-  if (LAR) {
-    const Loop *L = LAR->getLoop();
-    if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) &&
-        isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) {
-      if (!RAR) return true;
-      LeftGuarded = true;
-    }
-  }
-  if (RAR) {
-    const Loop *L = RAR->getLoop();
-    if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) &&
-        isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) {
-      if (!LAR) return true;
-      RightGuarded = true;
-    }
-  }
-  if (LeftGuarded && RightGuarded)
+  if (isKnownViaInduction(Pred, LHS, RHS))
     return true;
 
   if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
     return true;
 
-  // Otherwise see what can be done with known constant ranges.
-  return isKnownPredicateViaConstantRanges(Pred, LHS, RHS);
+  // Otherwise see what can be done with some simple reasoning.
+  return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
+}
+
+bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
+                                              const SCEVAddRecExpr *LHS,
+                                              const SCEV *RHS) {
+  const Loop *L = LHS->getLoop();
+  return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
+         isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
 }
 
 bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
@@ -8947,7 +9184,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return true;
 
-  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
     return true;
 
   BasicBlock *Latch = L->getLoopLatch();
@@ -9052,9 +9289,68 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
   // (interprocedural conditions notwithstanding).
   if (!L) return false;
 
-  if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+  // Both LHS and RHS must be available at loop entry.
+  assert(isAvailableAtLoopEntry(LHS, L) &&
+         "LHS is not available at Loop Entry");
+  assert(isAvailableAtLoopEntry(RHS, L) &&
+         "RHS is not available at Loop Entry");
+
+  if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
     return true;
 
+  // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
+  // the facts (a >= b && a != b) separately. A typical situation is when the
+  // non-strict comparison is known from ranges and non-equality is known from
+  // dominating predicates. If we are proving strict comparison, we always try
+  // to prove non-equality and non-strict comparison separately.
+  auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
+  const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
+  bool ProvedNonStrictComparison = false;
+  bool ProvedNonEquality = false;
+
+  if (ProvingStrictComparison) {
+    ProvedNonStrictComparison =
+        isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS);
+    ProvedNonEquality =
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS);
+    if (ProvedNonStrictComparison && ProvedNonEquality)
+      return true;
+  }
+
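The decomposition can be exercised in miniature. This self-contained model uses assumed semantics, not the SCEV API; it shows why tracking the two halves separately pays off: each prover may only manage one half of a strict fact, yet together they establish it.

#include <cassert>

struct Fact { bool NonStrict = false; bool NonEqual = false; };

// S1 < S2 holds if the non-strict half (S1 <= S2) and the non-equality
// half (S1 != S2) are each shown by some prover, possibly different ones.
static bool proveStrict(Fact FromRanges, Fact FromGuards) {
  bool NonStrict = FromRanges.NonStrict || FromGuards.NonStrict;
  bool NonEqual = FromRanges.NonEqual || FromGuards.NonEqual;
  return NonStrict && NonEqual;
}

int main() {
  // Ranges prove x <= y; a dominating guard proves x != y; together: x < y.
  assert(proveStrict({true, false}, {false, true}));
}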
+  // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
+  auto ProveViaGuard = [&](BasicBlock *Block) {
+    if (isImpliedViaGuard(Block, Pred, LHS, RHS))
+      return true;
+    if (ProvingStrictComparison) {
+      if (!ProvedNonStrictComparison)
+        ProvedNonStrictComparison =
+            isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS);
+      if (!ProvedNonEquality)
+        ProvedNonEquality =
+            isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS);
+      if (ProvedNonStrictComparison && ProvedNonEquality)
+        return true;
+    }
+    return false;
+  };
+
+  // Try to prove (Pred, LHS, RHS) using isImpliedCond.
+  auto ProveViaCond = [&](Value *Condition, bool Inverse) {
+    if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+      return true;
+    if (ProvingStrictComparison) {
+      if (!ProvedNonStrictComparison)
+        ProvedNonStrictComparison =
+            isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+      if (!ProvedNonEquality)
+        ProvedNonEquality =
+            isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+      if (ProvedNonStrictComparison && ProvedNonEquality)
+        return true;
+    }
+    return false;
+  };
+
   // Starting at the loop predecessor, climb up the predecessor chain, as long
   // as there are predecessors that can be found that have unique successors
   // leading to the original header.
@@ -9063,7 +9359,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
        Pair.first;
        Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
 
-    if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
+    if (ProveViaGuard(Pair.first))
       return true;
 
     BranchInst *LoopEntryPredicate =
@@ -9072,9 +9368,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
         LoopEntryPredicate->isUnconditional())
       continue;
 
-    if (isImpliedCond(Pred, LHS, RHS,
-                      LoopEntryPredicate->getCondition(),
-                      LoopEntryPredicate->getSuccessor(0) != Pair.second))
+    if (ProveViaCond(LoopEntryPredicate->getCondition(),
+                     LoopEntryPredicate->getSuccessor(0) != Pair.second))
       return true;
   }
 
@@ -9086,7 +9381,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
     if (!DT.dominates(CI, L->getHeader()))
       continue;
 
-    if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+    if (ProveViaCond(CI->getArgOperand(0), false))
      return true;
   }
 
@@ -9321,17 +9616,25 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
     return M - L;
   }
 
-  const SCEV *L, *R;
   SCEV::NoWrapFlags Flags;
-  if (splitBinaryAdd(Less, L, R, Flags))
-    if (const auto *LC = dyn_cast<SCEVConstant>(L))
-      if (R == More)
-        return -(LC->getAPInt());
-
-  if (splitBinaryAdd(More, L, R, Flags))
-    if (const auto *LC = dyn_cast<SCEVConstant>(L))
-      if (R == Less)
-        return LC->getAPInt();
+  const SCEV *LLess = nullptr, *RLess = nullptr;
+  const SCEV *LMore = nullptr, *RMore = nullptr;
+  const SCEVConstant *C1 = nullptr, *C2 = nullptr;
+  // Compare (X + C1) vs X.
+  if (splitBinaryAdd(Less, LLess, RLess, Flags))
+    if ((C1 = dyn_cast<SCEVConstant>(LLess)))
+      if (RLess == More)
+        return -(C1->getAPInt());
+
+  // Compare X vs (X + C2).
+  if (splitBinaryAdd(More, LMore, RMore, Flags))
+    if ((C2 = dyn_cast<SCEVConstant>(LMore)))
+      if (RMore == Less)
+        return C2->getAPInt();
+
+  // Compare (X + C1) vs (X + C2).
+  if (C1 && C2 && RLess == RMore)
+    return C2->getAPInt() - C1->getAPInt();
 
   return None;
 }
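A quick numeric check of the three shapes now recognized, with hypothetical SCEVs and X = %n; computeConstantDifference(More, Less) returns More - Less:

  More = %n,       Less = (%n + 2)   ->  -2   // (X + C1) vs X, returns -C1
  More = (%n + 5), Less = %n         ->   5   // X vs (X + C2), returns C2
  More = (%n + 5), Less = (%n + 2)   ->   3   // (X + C1) vs (X + C2), C2 - C1

The third case is the one the rewrite adds: when both sides split as a constant plus the same remainder X, the difference is independent of X.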
@@ -9408,10 +9711,121 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
   }
 
   // Try to prove (1) or (2), as needed.
-  return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+  return isAvailableAtLoopEntry(FoundRHS, L) &&
+         isLoopEntryGuardedByCond(L, Pred, FoundRHS,
                                   getConstant(FoundRHSLimit));
 }
 
+bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
+                                        const SCEV *LHS, const SCEV *RHS,
+                                        const SCEV *FoundLHS,
+                                        const SCEV *FoundRHS, unsigned Depth) {
+  const PHINode *LPhi = nullptr, *RPhi = nullptr;
+
+  auto ClearOnExit = make_scope_exit([&]() {
+    if (LPhi) {
+      bool Erased = PendingMerges.erase(LPhi);
+      assert(Erased && "Failed to erase LPhi!");
+      (void)Erased;
+    }
+    if (RPhi) {
+      bool Erased = PendingMerges.erase(RPhi);
+      assert(Erased && "Failed to erase RPhi!");
+      (void)Erased;
+    }
+  });
+
+  // Find the respective Phis and check that they are not already pending.
+  if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
+    if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
+      if (!PendingMerges.insert(Phi).second)
+        return false;
+      LPhi = Phi;
+    }
+  if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
+    if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
+      // If we detect a loop of Phi nodes being processed by this method, for
+      // example:
+      //
+      //   %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
+      //   %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
+      //
+      // we don't want to deal with a case that complex, so we conservatively
+      // return false.
+      if (!PendingMerges.insert(Phi).second)
+        return false;
+      RPhi = Phi;
+    }
+
+  // If neither LHS nor RHS is a Phi, there is nothing to do here.
+  if (!LPhi && !RPhi)
+    return false;
+
+  // If there is a SCEVUnknown Phi we are interested in, make it left.
+  if (!LPhi) {
+    std::swap(LHS, RHS);
+    std::swap(FoundLHS, FoundRHS);
+    std::swap(LPhi, RPhi);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
+  const BasicBlock *LBB = LPhi->getParent();
+  const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+
+  auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
+    return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
+           isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
+           isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
+  };
+
+  if (RPhi && RPhi->getParent() == LBB) {
+    // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
+    // If we compare two Phis from the same block, and the predicate holds
+    // for the incoming values from every predecessor block, then it also
+    // holds for the Phis.
+    for (const BasicBlock *IncBB : predecessors(LBB)) {
+      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
+      if (!ProvedEasily(L, R))
+        return false;
+    }
+  } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
+    // Case two: RHS is an AddRec Phi from the same basic block. This means
+    // that the loop has both AddRec and Unknown PHIs; for such a loop we can
+    // compare the values of the AddRec from above the loop and from the latch
+    // with the respective incoming values of LPhi.
+    // TODO: Generalize to handle loops with many inputs in a header.
+    if (LPhi->getNumIncomingValues() != 2) return false;
+
+    auto *RLoop = RAR->getLoop();
+    auto *Predecessor = RLoop->getLoopPredecessor();
+    assert(Predecessor && "Loop with AddRec with no predecessor?");
+    const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
+    if (!ProvedEasily(L1, RAR->getStart()))
+      return false;
+    auto *Latch = RLoop->getLoopLatch();
+    assert(Latch && "Loop with AddRec with no latch?");
+    const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
+    if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
+      return false;
+  } else {
+    // In all other cases go over the inputs of LHS and compare each of them
+    // to RHS; the predicate is true for (LHS, RHS) if it is true for all such
+    // pairs. At this point RHS is either a non-Phi, or it is a Phi from some
+    // block different from LBB.
+    for (const BasicBlock *IncBB : predecessors(LBB)) {
+      // Check that RHS is available in this block.
+      if (!dominates(RHS, IncBB))
+        return false;
+      const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+      if (!ProvedEasily(L, RHS))
+        return false;
+    }
+  }
+  return true;
+}
+
 bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                             const SCEV *LHS, const SCEV *RHS,
                                             const SCEV *FoundLHS,
@@ -9565,13 +9979,14 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   };
 
   // Acquire values from extensions.
+  auto *OrigLHS = LHS;
   auto *OrigFoundLHS = FoundLHS;
   LHS = GetOpFromSExt(LHS);
   FoundLHS = GetOpFromSExt(FoundLHS);
 
   // Can the SGT predicate be proved trivially, or by using the found context?
   auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
-    return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+    return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
           isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
                                  FoundRHS, Depth + 1);
   };
@@ -9672,11 +10087,17 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
     }
   }
 
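A concrete shape the merge proof handles, as an assumed example: with

  %lhs = phi i32 [ %a, %bb1 ], [ %b, %bb2 ]

the value of %lhs on any path is exactly one of its incoming values, so if %a <s %rhs can be shown on the %bb1 edge and %b <s %rhs on the %bb2 edge, then %lhs <s %rhs follows. In the AddRec case the two edges are the loop predecessor, compared against RAR->getStart(), and the latch, compared against RAR->getPostIncExpr(). The diff below wires this proof into isImpliedViaOperations.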
+  // If our expression contained SCEVUnknown Phis, and we split it down and
+  // now need to prove something for them, try to prove the predicate for
+  // every possible incoming value of those Phis.
+  if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
+    return true;
+
   return false;
 }
 
 bool
-ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred,
+ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS) {
   return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
          IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
@@ -9698,26 +10119,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
     break;
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_SLE:
-    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
-        isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_SGT:
   case ICmpInst::ICMP_SGE:
-    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
-        isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
      return true;
    break;
   case ICmpInst::ICMP_ULT:
   case ICmpInst::ICMP_ULE:
-    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
-        isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
       return true;
     break;
   case ICmpInst::ICMP_UGT:
   case ICmpInst::ICMP_UGE:
-    if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
-        isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+    if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+        isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
       return true;
     break;
   }
@@ -10195,6 +10616,31 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
   return SE.getCouldNotCompute();
 }
 
+const SCEVAddRecExpr *
+SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
+  assert(getNumOperands() > 1 && "AddRec with zero step?");
+  // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
+  // but in this case we cannot guarantee that the value returned will be an
+  // AddRec because SCEV does not have a fixed point where it stops
+  // simplification: it is legal to return ({rec1} + {rec2}). For example, it
+  // may happen if we reach arithmetic depth limit while simplifying. So we
+  // construct the returned value explicitly.
+  SmallVector<const SCEV *, 3> Ops;
+  // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
+  // (this + Step) is {A+B,+,B+C,+...,+,N}.
+  for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
+    Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
+  // We know that the last operand is not a constant zero (otherwise it would
+  // have been popped out earlier). This guarantees us that if the result has
+  // the same last operand, then it will also not be popped out, meaning that
+  // the returned value will be an AddRec.
+  const SCEV *Last = getOperand(getNumOperands() - 1);
+  assert(!Last->isZero() && "Recurrency with zero step?");
+  Ops.push_back(Last);
+  return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
                                               SCEV::FlagAnyWrap));
+}
+
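A sanity check of the operand arithmetic above, on a hypothetical constant instance: for this = {0,+,2,+,4}, the step recurrence is {2,+,4}, and adding it in componentwise gives

  {0,+,2,+,4} + {2,+,4} = {0+2,+,2+4,+,4} = {2,+,6,+,4},

which is exactly what the loop computes: Ops[i] = getOperand(i) + getOperand(i + 1), with the last operand carried over unchanged so the result stays an AddRec.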
 // Return true when S contains at least one undef value.
 static inline bool containsUndefs(const SCEV *S) {
   return SCEVExprContains(S, [](const SCEV *S) {
@@ -10337,22 +10783,22 @@ void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
   SCEVCollectStrides StrideCollector(*this, Strides);
   visitAll(Expr, StrideCollector);
 
-  DEBUG({
-    dbgs() << "Strides:\n";
-    for (const SCEV *S : Strides)
-      dbgs() << *S << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Strides:\n";
+    for (const SCEV *S : Strides)
+      dbgs() << *S << "\n";
+  });
 
   for (const SCEV *S : Strides) {
     SCEVCollectTerms TermCollector(Terms);
     visitAll(S, TermCollector);
   }
 
-  DEBUG({
-    dbgs() << "Terms:\n";
-    for (const SCEV *T : Terms)
-      dbgs() << *T << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Terms:\n";
+    for (const SCEV *T : Terms)
+      dbgs() << *T << "\n";
+  });
 
   SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
   visitAll(Expr, MulCollector);
@@ -10463,18 +10909,18 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
   if (!containsParameters(Terms))
     return;
 
-  DEBUG({
-    dbgs() << "Terms:\n";
-    for (const SCEV *T : Terms)
-      dbgs() << *T << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Terms:\n";
+    for (const SCEV *T : Terms)
+      dbgs() << *T << "\n";
+  });
 
   // Remove duplicates.
   array_pod_sort(Terms.begin(), Terms.end());
   Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
 
   // Put larger terms first.
-  std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
+  llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
     return numberOfTerms(LHS) > numberOfTerms(RHS);
   });
 
@@ -10494,11 +10940,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
     if (const SCEV *NewT = removeConstantFactors(*this, T))
       NewTerms.push_back(NewT);
 
-  DEBUG({
-    dbgs() << "Terms after sorting:\n";
-    for (const SCEV *T : NewTerms)
-      dbgs() << *T << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Terms after sorting:\n";
+    for (const SCEV *T : NewTerms)
+      dbgs() << *T << "\n";
+  });
 
   if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) {
     Sizes.clear();
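A worked example of what these passes recover, with hypothetical shapes: for a row-major access A[i][j] into double A[n][m], the address offset has the SCEV

  {{0,+,(8 * %m)}<%outer>,+,8}<%inner>

from which findArrayDimensions produces Sizes = [%m][8] (the element size pushed last, as noted below) and computeAccessFunctions produces Subscripts = [{0,+,1}<%outer>][{0,+,1}<%inner>], i.e. the loop induction variables i and j.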
@@ -10508,11 +10954,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
 
   // The last element to be pushed into Sizes is the size of an element.
   Sizes.push_back(ElementSize);
 
-  DEBUG({
-    dbgs() << "Sizes:\n";
-    for (const SCEV *S : Sizes)
-      dbgs() << *S << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Sizes:\n";
+    for (const SCEV *S : Sizes)
+      dbgs() << *S << "\n";
+  });
 }
 
 void ScalarEvolution::computeAccessFunctions(
@@ -10532,13 +10978,13 @@ void ScalarEvolution::computeAccessFunctions(
       const SCEV *Q, *R;
       SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
 
-      DEBUG({
-        dbgs() << "Res: " << *Res << "\n";
-        dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
-        dbgs() << "Res divided by Sizes[i]:\n";
-        dbgs() << "Quotient: " << *Q << "\n";
-        dbgs() << "Remainder: " << *R << "\n";
-      });
+      LLVM_DEBUG({
+        dbgs() << "Res: " << *Res << "\n";
+        dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
+        dbgs() << "Res divided by Sizes[i]:\n";
+        dbgs() << "Quotient: " << *Q << "\n";
+        dbgs() << "Remainder: " << *R << "\n";
+      });
 
       Res = Q;
 
@@ -10566,11 +11012,11 @@ void ScalarEvolution::computeAccessFunctions(
 
   std::reverse(Subscripts.begin(), Subscripts.end());
 
-  DEBUG({
-    dbgs() << "Subscripts:\n";
-    for (const SCEV *S : Subscripts)
-      dbgs() << *S << "\n";
-  });
+  LLVM_DEBUG({
+    dbgs() << "Subscripts:\n";
+    for (const SCEV *S : Subscripts)
+      dbgs() << *S << "\n";
+  });
 }
 
 /// Splits the SCEV into two vectors of SCEVs representing the subscripts and
@@ -10644,17 +11090,17 @@ void ScalarEvolution::delinearize(const SCEV *Expr,
   if (Subscripts.empty())
     return;
 
-  DEBUG({
-    dbgs() << "succeeded to delinearize " << *Expr << "\n";
-    dbgs() << "ArrayDecl[UnknownSize]";
-    for (const SCEV *S : Sizes)
-      dbgs() << "[" << *S << "]";
+  LLVM_DEBUG({
+    dbgs() << "succeeded to delinearize " << *Expr << "\n";
+    dbgs() << "ArrayDecl[UnknownSize]";
+    for (const SCEV *S : Sizes)
+      dbgs() << "[" << *S << "]";
 
-    dbgs() << "\nArrayRef";
-    for (const SCEV *S : Subscripts)
-      dbgs() << "[" << *S << "]";
-    dbgs() << "\n";
-  });
+    dbgs() << "\nArrayRef";
+    for (const SCEV *S : Subscripts)
+      dbgs() << "[" << *S << "]";
+    dbgs() << "\n";
+  });
 }
 
 //===----------------------------------------------------------------------===//
@@ -10731,6 +11177,8 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
+      PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
+      PendingMerges(std::move(Arg.PendingMerges)),
       MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
@@ -10774,6 +11222,8 @@ ScalarEvolution::~ScalarEvolution() {
     BTCI.second.clear();
 
   assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
+  assert(PendingPhiRanges.empty() && "getRangeRef garbage");
+  assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
   assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
   assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
 }
@@ -11184,9 +11634,13 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
-void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+void
+ScalarEvolution::getUsedLoops(const SCEV *S,
+                              SmallPtrSetImpl<const Loop *> &LoopsUsed) {
   struct FindUsedLoops {
-    SmallPtrSet<const Loop *, 8> LoopsUsed;
+    FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
+        : LoopsUsed(LoopsUsed) {}
+    SmallPtrSetImpl<const Loop *> &LoopsUsed;
    bool follow(const SCEV *S) {
      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
        LoopsUsed.insert(AR->getLoop());
@@ -11196,10 +11650,14 @@ void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
     bool isDone() const { return false; }
   };
 
-  FindUsedLoops F;
+  FindUsedLoops F(LoopsUsed);
   SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+}
 
-  for (auto *L : F.LoopsUsed)
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  getUsedLoops(S, LoopsUsed);
+  for (auto *L : LoopsUsed)
     LoopUsers[L].push_back(S);
 }
 
@@ -11482,6 +11940,12 @@ private:
     if (!PredicatedRewrite)
      return Expr;
    for (auto *P : PredicatedRewrite->second){
+      // Wrap predicates from outer loops are not supported.
+      if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
+        auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr());
+        if (L != AR->getLoop())
+          return Expr;
+      }
      if (!addOverflowAssumption(P))
        return Expr;
    }
@@ -11787,3 +12251,43 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
       OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
   }
 }
+
+// Match the mathematical pattern A - (A / B) * B, where A and B can be
+// arbitrary expressions. It's not always easy, as A and B can be folded:
+// imagine A is X / 2 and B is 4; then A / B becomes X / 8.
+bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
+                                const SCEV *&RHS) {
+  const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+  if (Add == nullptr || Add->getNumOperands() != 2)
+    return false;
+
+  const SCEV *A = Add->getOperand(1);
+  const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+  if (Mul == nullptr)
+    return false;
+
+  const auto MatchURemWithDivisor = [&](const SCEV *B) {
+    // (SomeExpr + (-(SomeExpr / B) * B)).
+    if (Expr == getURemExpr(A, B)) {
+      LHS = A;
+      RHS = B;
+      return true;
+    }
+    return false;
+  };
+
+  // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+  if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+    return MatchURemWithDivisor(Mul->getOperand(1)) ||
+           MatchURemWithDivisor(Mul->getOperand(2));
+
+  // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+  if (Mul->getNumOperands() == 2)
+    return MatchURemWithDivisor(Mul->getOperand(1)) ||
+           MatchURemWithDivisor(Mul->getOperand(0)) ||
+           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
+           MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
+  return false;
+}
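The identity behind the matcher can be checked in plain integers. A minimal standalone sketch, not part of the patch:

#include <cassert>

// For unsigned A and B with B != 0, the "mathematical pattern" above is
// exactly the unsigned remainder: A - (A / B) * B == A % B.
static unsigned uremViaPattern(unsigned A, unsigned B) {
  return A - (A / B) * B;
}

int main() {
  assert(uremViaPattern(10, 4) == 10 % 4); // == 2
  assert(uremViaPattern(7, 7) == 0);
}

The SCEV-level matcher cannot simply pattern-match the tree, because folding can disguise both A and B, which is why it reconstructs getURemExpr(A, B) and compares against the original expression instead.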