Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--  lib/Analysis/ScalarEvolution.cpp  1428
1 file changed, 966 insertions(+), 462 deletions(-)
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index f34549ae52b4..aa95ace93014 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -83,6 +83,7 @@
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Config/llvm-config.h"
#include "llvm/IR/Argument.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
@@ -420,24 +421,21 @@ SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVCastExpr(ID, scTruncate, op, ty) {
- assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate non-integer value!");
}
SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVCastExpr(ID, scZeroExtend, op, ty) {
- assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot zero extend non-integer value!");
}
SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
: SCEVCastExpr(ID, scSignExtend, op, ty) {
- assert((Op->getType()->isIntegerTy() || Op->getType()->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot sign extend non-integer value!");
}
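The Type::isIntOrPtrTy() predicate now used in these asserts is equivalent to the removed two-clause check, roughly:

    bool isIntOrPtrTy(const Type *Ty) {
      return Ty->isIntegerTy() || Ty->isPointerTy();
    }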
@@ -1255,42 +1253,32 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
if (const SCEVZeroExtendExpr *SZ = dyn_cast<SCEVZeroExtendExpr>(Op))
return getTruncateOrZeroExtend(SZ->getOperand(), Ty);
- // trunc(x1+x2+...+xN) --> trunc(x1)+trunc(x2)+...+trunc(xN) if we can
- // eliminate all the truncates, or we replace other casts with truncates.
- if (const SCEVAddExpr *SA = dyn_cast<SCEVAddExpr>(Op)) {
+ // trunc(x1 + ... + xN) --> trunc(x1) + ... + trunc(xN) and
+ // trunc(x1 * ... * xN) --> trunc(x1) * ... * trunc(xN),
+ // if after transforming we have at most one truncate, not counting truncates
+ // that replace other casts.
+ if (isa<SCEVAddExpr>(Op) || isa<SCEVMulExpr>(Op)) {
+ auto *CommOp = cast<SCEVCommutativeExpr>(Op);
SmallVector<const SCEV *, 4> Operands;
- bool hasTrunc = false;
- for (unsigned i = 0, e = SA->getNumOperands(); i != e && !hasTrunc; ++i) {
- const SCEV *S = getTruncateExpr(SA->getOperand(i), Ty);
- if (!isa<SCEVCastExpr>(SA->getOperand(i)))
- hasTrunc = isa<SCEVTruncateExpr>(S);
+ unsigned numTruncs = 0;
+ for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
+ ++i) {
+ const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty);
+ if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
+ numTruncs++;
Operands.push_back(S);
}
- if (!hasTrunc)
- return getAddExpr(Operands);
- // In spite we checked in the beginning that ID is not in the cache,
- // it is possible that during recursion and different modification
- // ID came to cache, so if we found it, just return it.
- if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
- return S;
- }
-
- // trunc(x1*x2*...*xN) --> trunc(x1)*trunc(x2)*...*trunc(xN) if we can
- // eliminate all the truncates, or we replace other casts with truncates.
- if (const SCEVMulExpr *SM = dyn_cast<SCEVMulExpr>(Op)) {
- SmallVector<const SCEV *, 4> Operands;
- bool hasTrunc = false;
- for (unsigned i = 0, e = SM->getNumOperands(); i != e && !hasTrunc; ++i) {
- const SCEV *S = getTruncateExpr(SM->getOperand(i), Ty);
- if (!isa<SCEVCastExpr>(SM->getOperand(i)))
- hasTrunc = isa<SCEVTruncateExpr>(S);
- Operands.push_back(S);
+ if (numTruncs < 2) {
+ if (isa<SCEVAddExpr>(Op))
+ return getAddExpr(Operands);
+ else if (isa<SCEVMulExpr>(Op))
+ return getMulExpr(Operands);
+ else
+ llvm_unreachable("Unexpected SCEV type for Op.");
}
- if (!hasTrunc)
- return getMulExpr(Operands);
- // In spite we checked in the beginning that ID is not in the cache,
- // it is possible that during recursion and different modification
- // ID came to cache, so if we found it, just return it.
+ // Although we checked at the beginning that ID is not in the cache, it is
+ // possible that ID was inserted into the cache by the recursive calls
+ // above. So if we find it now, just return it.
if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
return S;
}
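The relaxation from "no surviving truncates" to "at most one fresh truncate" lets, for example, trunc((zext a) + b) fold where the old code bailed. A standalone sketch of the underlying identity at 32/64-bit widths (mine, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t a = 0xDEADBEEF;
      uint64_t b = 0x123456789ABCDEF0ull;
      // trunc((zext a) + b) == a + trunc(b): truncation distributes over
      // addition, and only the b operand gains a fresh truncate here
      // (numTruncs == 1), so the new code folds this case.
      uint32_t lhs = uint32_t(uint64_t(a) + b);
      uint32_t rhs = a + uint32_t(b);
      assert(lhs == rhs);
      return 0;
    }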
@@ -1571,6 +1559,43 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
return false;
}
+// Finds an integer D for an expression (C + x + y + ...) such that the top
+// level addition in (D + (C - D + x + y + ...)) would not wrap (signed or
+// unsigned) and the number of trailing zeros of (C - D + x + y + ...) is
+// maximized, where C is the \p ConstantTerm, x, y, ... are arbitrary SCEVs, and
+// the (C + x + y + ...) expression is \p WholeAddExpr.
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+ const SCEVConstant *ConstantTerm,
+ const SCEVAddExpr *WholeAddExpr) {
+ const APInt C = ConstantTerm->getAPInt();
+ const unsigned BitWidth = C.getBitWidth();
+ // Find number of trailing zeros of (x + y + ...) w/o the C first:
+ uint32_t TZ = BitWidth;
+ for (unsigned I = 1, E = WholeAddExpr->getNumOperands(); I < E && TZ; ++I)
+ TZ = std::min(TZ, SE.GetMinTrailingZeros(WholeAddExpr->getOperand(I)));
+ if (TZ) {
+ // Set D to be as many least significant bits of C as possible while still
+ // guaranteeing that adding D to (C - D + x + y + ...) won't cause a wrap:
+ return TZ < BitWidth ? C.trunc(TZ).zext(BitWidth) : C;
+ }
+ return APInt(BitWidth, 0);
+}
+
+// Finds an integer D for an affine AddRec expression {C,+,x} such that the top
+// level addition in (D + {C-D,+,x}) would not wrap (signed or unsigned) and the
+// number of trailing zeros of (C - D + x * n) is maximized, where C is the \p
+// ConstantStart, x is an arbitrary \p Step, and n is the loop trip count.
+static APInt extractConstantWithoutWrapping(ScalarEvolution &SE,
+ const APInt &ConstantStart,
+ const SCEV *Step) {
+ const unsigned BitWidth = ConstantStart.getBitWidth();
+ const uint32_t TZ = SE.GetMinTrailingZeros(Step);
+ if (TZ)
+ return TZ < BitWidth ? ConstantStart.trunc(TZ).zext(BitWidth)
+ : ConstantStart;
+ return APInt(BitWidth, 0);
+}
+
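A standalone model of the split both overloads compute, with uint64_t standing in for APInt (an illustration, not the patch's code): D is the low TZ bits of C, so the residual keeps TZ trailing zeros, and re-adding D only fills known-zero low bits; the addition cannot carry, hence cannot wrap in either signedness.

    #include <cassert>
    #include <cstdint>

    static uint64_t extractD(uint64_t C, unsigned TZ) {
      if (TZ == 0)
        return 0;
      // Mirrors C.trunc(TZ).zext(BitWidth) at a fixed 64-bit width.
      return TZ < 64 ? (C & ((uint64_t(1) << TZ) - 1)) : C;
    }

    int main() {
      uint64_t C = 6, Rest = 20 * 7 + 24 * 3; // Rest is a multiple of 4, TZ = 2
      uint64_t D = extractD(C, 2);
      assert(D == 2 && (C - D + Rest) % 4 == 0); // residual keeps both zeros
      assert(D + (C - D + Rest) == C + Rest);    // same value, no-wrap add
      return 0;
    }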
const SCEV *
ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -1727,9 +1752,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
getUnsignedRangeMax(Step));
if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
- (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_ULT, Start, N) &&
- isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT,
- AR->getPostIncExpr(*this), N))) {
+ isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
// Cache knowledge of AR NUW, which is propagated to this
// AddRec.
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
@@ -1744,9 +1767,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
getSignedRangeMin(Step));
if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
- (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_UGT, Start, N) &&
- isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT,
- AR->getPostIncExpr(*this), N))) {
+ isKnownOnEveryIteration(ICmpInst::ICMP_UGT, AR, N)) {
// Cache knowledge of AR NW, which is propagated to this
// AddRec. Negative step causes unsigned wrap, but it
// still can't self-wrap.
@@ -1761,6 +1782,23 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
}
+ // zext({C,+,Step}) --> (zext(D) + zext({C-D,+,Step}))<nuw><nsw>
+ // if D + (C - D + Step * n) could be proven to not unsigned wrap
+ // where D maximizes the number of trailing zeros of (C - D + Step * n)
+ if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+ const APInt &C = SC->getAPInt();
+ const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+ if (D != 0) {
+ const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+ const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SZExtD, SZExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+
if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
return getAddRecExpr(
@@ -1769,6 +1807,20 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
}
+ // zext(A % B) --> zext(A) % zext(B)
+ {
+ const SCEV *LHS;
+ const SCEV *RHS;
+ if (matchURem(Op, LHS, RHS))
+ return getURemExpr(getZeroExtendExpr(LHS, Ty, Depth + 1),
+ getZeroExtendExpr(RHS, Ty, Depth + 1));
+ }
+
+ // zext(A / B) --> zext(A) / zext(B).
+ if (auto *Div = dyn_cast<SCEVUDivExpr>(Op))
+ return getUDivExpr(getZeroExtendExpr(Div->getLHS(), Ty, Depth + 1),
+ getZeroExtendExpr(Div->getRHS(), Ty, Depth + 1));
+
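Both folds are justified because zero extension commutes with unsigned division and remainder; a quick exhaustive spot-check at small values (mine, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t a = 0; a < 300; ++a)
        for (uint32_t b = 1; b < 300; ++b) {
          assert(uint64_t(a / b) == uint64_t(a) / uint64_t(b)); // zext(A / B)
          assert(uint64_t(a % b) == uint64_t(a) % uint64_t(b)); // zext(A % B)
        }
      return 0;
    }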
if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
// zext((A + B + ...)<nuw>) --> (zext(A) + zext(B) + ...)<nuw>
if (SA->hasNoUnsignedWrap()) {
@@ -1779,6 +1831,65 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
return getAddExpr(Ops, SCEV::FlagNUW, Depth + 1);
}
+
+ // zext(C + x + y + ...) --> (zext(D) + zext((C - D) + x + y + ...))
+ // if D + (C - D + x + y + ...) could be proven to not unsigned wrap
+ // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+ //
+ // Address arithmetic often contains expressions like
+ // (zext (add (shl X, C1), C2)), for instance, (zext (5 + (4 * X))).
+ // This transformation is useful when proving that such expressions are
+ // equal or differ by a small constant amount; see the LoadStoreVectorizer.
+ if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+ const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+ if (D != 0) {
+ const SCEV *SZExtD = getZeroExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+ const SCEV *SZExtR = getZeroExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SZExtD, SZExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
+ }
+
+ if (auto *SM = dyn_cast<SCEVMulExpr>(Op)) {
+ // zext((A * B * ...)<nuw>) --> (zext(A) * zext(B) * ...)<nuw>
+ if (SM->hasNoUnsignedWrap()) {
+ // If the multiply does not unsign overflow then we can, by definition,
+ // commute the zero extension with the multiply operation.
+ SmallVector<const SCEV *, 4> Ops;
+ for (const auto *Op : SM->operands())
+ Ops.push_back(getZeroExtendExpr(Op, Ty, Depth + 1));
+ return getMulExpr(Ops, SCEV::FlagNUW, Depth + 1);
+ }
+
+ // zext(2^K * (trunc X to iN)) to iM ->
+ // 2^K * (zext(trunc X to i{N-K}) to iM)<nuw>
+ //
+ // Proof:
+ //
+ // zext(2^K * (trunc X to iN)) to iM
+ // = zext((trunc X to iN) << K) to iM
+ // = zext((trunc X to i{N-K}) << K)<nuw> to iM
+ // (because shl removes the top K bits)
+ // = zext((2^K * (trunc X to i{N-K}))<nuw>) to iM
+ // = (2^K * (zext(trunc X to i{N-K}) to iM))<nuw>.
+ //
+ if (SM->getNumOperands() == 2)
+ if (auto *MulLHS = dyn_cast<SCEVConstant>(SM->getOperand(0)))
+ if (MulLHS->getAPInt().isPowerOf2())
+ if (auto *TruncRHS = dyn_cast<SCEVTruncateExpr>(SM->getOperand(1))) {
+ int NewTruncBits = getTypeSizeInBits(TruncRHS->getType()) -
+ MulLHS->getAPInt().logBase2();
+ Type *NewTruncTy = IntegerType::get(getContext(), NewTruncBits);
+ return getMulExpr(
+ getZeroExtendExpr(MulLHS, Ty),
+ getZeroExtendExpr(
+ getTruncateExpr(TruncRHS->getOperand(), NewTruncTy), Ty),
+ SCEV::FlagNUW, Depth + 1);
+ }
}
// The cast wasn't folded; create an explicit cast node.
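The proof in the comment can be checked mechanically at one instantiation, say N = 8, K = 3, M = 32 (a spot-check, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned N = 8, K = 3;
      for (unsigned X = 0; X < 4096; ++X) {
        uint8_t TruncN = uint8_t(X);           // trunc X to i8
        uint8_t MulInN = uint8_t(TruncN << K); // 2^K * ..., computed in i8
        uint32_t LHS = MulInN;                 // zext ... to i32
        uint8_t TruncNK = X & ((1u << (N - K)) - 1); // trunc X to i5
        uint32_t RHS = uint32_t(TruncNK) << K; // 2^K * zext(... to i32)
        assert(LHS == RHS);
      }
      return 0;
    }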
@@ -1842,24 +1953,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
return getTruncateOrSignExtend(X, Ty);
}
- // sext(C1 + (C2 * x)) --> C1 + sext(C2 * x) if C1 < C2
if (auto *SA = dyn_cast<SCEVAddExpr>(Op)) {
- if (SA->getNumOperands() == 2) {
- auto *SC1 = dyn_cast<SCEVConstant>(SA->getOperand(0));
- auto *SMul = dyn_cast<SCEVMulExpr>(SA->getOperand(1));
- if (SMul && SC1) {
- if (auto *SC2 = dyn_cast<SCEVConstant>(SMul->getOperand(0))) {
- const APInt &C1 = SC1->getAPInt();
- const APInt &C2 = SC2->getAPInt();
- if (C1.isStrictlyPositive() && C2.isStrictlyPositive() &&
- C2.ugt(C1) && C2.isPowerOf2())
- return getAddExpr(getSignExtendExpr(SC1, Ty, Depth + 1),
- getSignExtendExpr(SMul, Ty, Depth + 1),
- SCEV::FlagAnyWrap, Depth + 1);
- }
- }
- }
-
// sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
if (SA->hasNoSignedWrap()) {
// If the addition does not sign overflow then we can, by definition,
@@ -1869,6 +1963,28 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
Ops.push_back(getSignExtendExpr(Op, Ty, Depth + 1));
return getAddExpr(Ops, SCEV::FlagNSW, Depth + 1);
}
+
+ // sext(C + x + y + ...) --> (sext(D) + sext((C - D) + x + y + ...))
+ // if D + (C - D + x + y + ...) could be proven to not signed wrap
+ // where D maximizes the number of trailing zeros of (C - D + x + y + ...)
+ //
+ // For instance, this will bring two seemingly different expressions:
+ // 1 + sext(5 + 20 * %x + 24 * %y) and
+ // sext(6 + 20 * %x + 24 * %y)
+ // to the same form:
+ // 2 + sext(4 + 20 * %x + 24 * %y)
+ if (const auto *SC = dyn_cast<SCEVConstant>(SA->getOperand(0))) {
+ const APInt &D = extractConstantWithoutWrapping(*this, SC, SA);
+ if (D != 0) {
+ const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddExpr(getConstant(-D), SA, SCEV::FlagAnyWrap, Depth);
+ const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SSExtD, SSExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
+ }
+ }
}
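A numeric spot-check of the worked example above, evaluating in i8 and sign-extending to i32: the residual 4 + 20*x + 24*y is a multiple of 4, so adding the extracted 1 or 2 never carries, and all three forms agree for every x and y (assumes two's-complement narrowing; not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int x = -4; x <= 4; ++x)
        for (int y = -4; y <= 4; ++y) {
          int E = 20 * x + 24 * y; // always a multiple of 4
          int32_t FormA = 1 + int32_t(int8_t(5 + E));
          int32_t FormB = int32_t(int8_t(6 + E));
          int32_t Canon = 2 + int32_t(int8_t(4 + E));
          assert(FormA == Canon && FormB == Canon);
        }
      return 0;
    }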
// If the input value is a chrec scev, and we can prove that the value
// did not overflow the old, smaller, value, we can sign extend all of the
@@ -1989,9 +2105,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
getSignedOverflowLimitForStep(Step, &Pred, this);
if (OverflowLimit &&
(isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
- (isLoopEntryGuardedByCond(L, Pred, Start, OverflowLimit) &&
- isLoopBackedgeGuardedByCond(L, Pred, AR->getPostIncExpr(*this),
- OverflowLimit)))) {
+ isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
// Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
return getAddRecExpr(
@@ -2000,21 +2114,20 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
}
- // If Start and Step are constants, check if we can apply this
- // transformation:
- // sext{C1,+,C2} --> C1 + sext{0,+,C2} if C1 < C2
- auto *SC1 = dyn_cast<SCEVConstant>(Start);
- auto *SC2 = dyn_cast<SCEVConstant>(Step);
- if (SC1 && SC2) {
- const APInt &C1 = SC1->getAPInt();
- const APInt &C2 = SC2->getAPInt();
- if (C1.isStrictlyPositive() && C2.isStrictlyPositive() && C2.ugt(C1) &&
- C2.isPowerOf2()) {
- Start = getSignExtendExpr(Start, Ty, Depth + 1);
- const SCEV *NewAR = getAddRecExpr(getZero(AR->getType()), Step, L,
- AR->getNoWrapFlags());
- return getAddExpr(Start, getSignExtendExpr(NewAR, Ty, Depth + 1),
- SCEV::FlagAnyWrap, Depth + 1);
+ // sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
+ // if D + (C - D + Step * n) could be proven to not signed wrap
+ // where D maximizes the number of trailing zeros of (C - D + Step * n)
+ if (const auto *SC = dyn_cast<SCEVConstant>(Start)) {
+ const APInt &C = SC->getAPInt();
+ const APInt &D = extractConstantWithoutWrapping(*this, C, Step);
+ if (D != 0) {
+ const SCEV *SSExtD = getSignExtendExpr(getConstant(D), Ty, Depth);
+ const SCEV *SResidual =
+ getAddRecExpr(getConstant(C - D), Step, L, AR->getNoWrapFlags());
+ const SCEV *SSExtR = getSignExtendExpr(SResidual, Ty, Depth + 1);
+ return getAddExpr(SSExtD, SSExtR,
+ (SCEV::NoWrapFlags)(SCEV::FlagNSW | SCEV::FlagNUW),
+ Depth + 1);
}
}
@@ -2210,22 +2323,35 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
SignOrUnsignWrap = ScalarEvolution::maskFlags(Flags, SignOrUnsignMask);
- if (SignOrUnsignWrap != SignOrUnsignMask && Type == scAddExpr &&
- Ops.size() == 2 && isa<SCEVConstant>(Ops[0])) {
+ if (SignOrUnsignWrap != SignOrUnsignMask &&
+ (Type == scAddExpr || Type == scMulExpr) && Ops.size() == 2 &&
+ isa<SCEVConstant>(Ops[0])) {
- // (A + C) --> (A + C)<nsw> if the addition does not sign overflow
- // (A + C) --> (A + C)<nuw> if the addition does not unsign overflow
+ auto Opcode = [&] {
+ switch (Type) {
+ case scAddExpr:
+ return Instruction::Add;
+ case scMulExpr:
+ return Instruction::Mul;
+ default:
+ llvm_unreachable("Unexpected SCEV op.");
+ }
+ }();
const APInt &C = cast<SCEVConstant>(Ops[0])->getAPInt();
+
+ // (A <opcode> C) --> (A <opcode> C)<nsw> if the op doesn't sign overflow.
if (!(SignOrUnsignWrap & SCEV::FlagNSW)) {
auto NSWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
- Instruction::Add, C, OBO::NoSignedWrap);
+ Opcode, C, OBO::NoSignedWrap);
if (NSWRegion.contains(SE->getSignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNSW);
}
+
+ // (A <opcode> C) --> (A <opcode> C)<nuw> if the op doesn't unsign overflow.
if (!(SignOrUnsignWrap & SCEV::FlagNUW)) {
auto NUWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
- Instruction::Add, C, OBO::NoUnsignedWrap);
+ Opcode, C, OBO::NoUnsignedWrap);
if (NUWRegion.contains(SE->getUnsignedRange(Ops[1])))
Flags = ScalarEvolution::setFlags(Flags, SCEV::FlagNUW);
}
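For multiplication by a known constant C, the no-unsigned-wrap region reduces to a simple bound; a toy model of the check the ConstantRange machinery performs, covering only the NUW case (not the patch's code):

    #include <cassert>
    #include <cstdint>

    bool mulNoUnsignedWrap(uint64_t A, uint64_t C) {
      return C == 0 || A <= UINT64_MAX / C; // A * C fits in 64 bits
    }

    int main() {
      assert(mulNoUnsignedWrap(5, 3));
      assert(!mulNoUnsignedWrap(UINT64_MAX / 2 + 1, 2)); // 2*A == 2^64 wraps
      return 0;
    }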
@@ -2235,59 +2361,7 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
}
bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
- if (!isLoopInvariant(S, L))
- return false;
- // If a value depends on a SCEVUnknown which is defined after the loop, we
- // conservatively assume that we cannot calculate it at the loop's entry.
- struct FindDominatedSCEVUnknown {
- bool Found = false;
- const Loop *L;
- DominatorTree &DT;
- LoopInfo &LI;
-
- FindDominatedSCEVUnknown(const Loop *L, DominatorTree &DT, LoopInfo &LI)
- : L(L), DT(DT), LI(LI) {}
-
- bool checkSCEVUnknown(const SCEVUnknown *SU) {
- if (auto *I = dyn_cast<Instruction>(SU->getValue())) {
- if (DT.dominates(L->getHeader(), I->getParent()))
- Found = true;
- else
- assert(DT.dominates(I->getParent(), L->getHeader()) &&
- "No dominance relationship between SCEV and loop?");
- }
- return false;
- }
-
- bool follow(const SCEV *S) {
- switch (static_cast<SCEVTypes>(S->getSCEVType())) {
- case scConstant:
- return false;
- case scAddRecExpr:
- case scTruncate:
- case scZeroExtend:
- case scSignExtend:
- case scAddExpr:
- case scMulExpr:
- case scUMaxExpr:
- case scSMaxExpr:
- case scUDivExpr:
- return true;
- case scUnknown:
- return checkSCEVUnknown(cast<SCEVUnknown>(S));
- case scCouldNotCompute:
- llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!");
- }
- return false;
- }
-
- bool isDone() { return Found; }
- };
-
- FindDominatedSCEVUnknown FSU(L, DT, LI);
- SCEVTraversal<FindDominatedSCEVUnknown> ST(FSU);
- ST.visitAll(S);
- return !FSU.Found;
+ return isLoopInvariant(S, L) && properlyDominates(S, L->getHeader());
}
/// Get a canonical add expression, or something simpler if possible.
@@ -2358,7 +2432,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
FoundMatch = true;
}
if (FoundMatch)
- return getAddExpr(Ops, Flags);
+ return getAddExpr(Ops, Flags, Depth + 1);
// Check for truncates. If all the operands are truncated from the same
// type, see if factoring out the truncate would permit the result to be
@@ -2418,7 +2492,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
}
if (Ok) {
// Evaluate the expression in the larger type.
- const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
+ const SCEV *Fold = getAddExpr(LargeOps, SCEV::FlagAnyWrap, Depth + 1);
// If it folds to something simple, use it. Otherwise, don't.
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
return getTruncateExpr(Fold, Ty);
@@ -2796,22 +2870,21 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
- // C1*(C2+V) -> C1*C2 + C1*V
if (Ops.size() == 2)
- if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
- // If any of Add's ops are Adds or Muls with a constant,
- // apply this transformation as well.
- if (Add->getNumOperands() == 2)
- // TODO: There are some cases where this transformation is not
- // profitable, for example:
- // Add = (C0 + X) * Y + Z.
- // Maybe the scope of this transformation should be narrowed down.
- if (containsConstantInAddMulChain(Add))
- return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
- SCEV::FlagAnyWrap, Depth + 1),
- getMulExpr(LHSC, Add->getOperand(1),
- SCEV::FlagAnyWrap, Depth + 1),
- SCEV::FlagAnyWrap, Depth + 1);
+ // C1*(C2+V) -> C1*C2 + C1*V
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
+ // If any of Add's ops are Adds or Muls with a constant, apply this
+ // transformation as well.
+ //
+ // TODO: There are some cases where this transformation is not
+ // profitable; for example, Add = (C0 + X) * Y + Z. Maybe the scope of
+ // this transformation should be narrowed down.
+ if (Add->getNumOperands() == 2 && containsConstantInAddMulChain(Add))
+ return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
+ SCEV::FlagAnyWrap, Depth + 1),
+ getMulExpr(LHSC, Add->getOperand(1),
+ SCEV::FlagAnyWrap, Depth + 1),
+ SCEV::FlagAnyWrap, Depth + 1);
++Idx;
while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
@@ -3123,6 +3196,21 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
}
}
}
+
+ // (A/B)/C --> A/(B*C) if safe and B*C can be folded.
+ if (const SCEVUDivExpr *OtherDiv = dyn_cast<SCEVUDivExpr>(LHS)) {
+ if (auto *DivisorConstant =
+ dyn_cast<SCEVConstant>(OtherDiv->getRHS())) {
+ bool Overflow = false;
+ APInt NewRHS =
+ DivisorConstant->getAPInt().umul_ov(RHSC->getAPInt(), Overflow);
+ if (Overflow) {
+ return getConstant(RHSC->getType(), 0, false);
+ }
+ return getUDivExpr(OtherDiv->getLHS(), getConstant(NewRHS));
+ }
+ }
+
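The new fold rests on the integer identity (A/B)/C == A/(B*C) for unsigned values, with the overflow case collapsing to zero; a quick check (mine, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t A = 0; A < 1000; A += 7)
        for (uint32_t B = 1; B < 50; ++B)
          for (uint32_t C = 1; C < 50; ++C)
            assert((A / B) / C == A / (B * C)); // no overflow in these ranges
      // If B*C overflows the bit width, every A is smaller than B*C, so the
      // result is the constant 0 (here B*C == 2^33 at 32 bits):
      uint32_t A = UINT32_MAX, B = 1u << 16, C = 1u << 17;
      assert((A / B) / C == 0);
      return 0;
    }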
// (A+B)/C --> (A/C + B/C) if safe and A/C and B/C can be folded.
if (const SCEVAddExpr *A = dyn_cast<SCEVAddExpr>(LHS)) {
SmallVector<const SCEV *, 4> Operands;
@@ -3574,12 +3662,13 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
for (unsigned i = 0, e = Ops.size()-1; i != e; ++i)
// X umax Y umax Y --> X umax Y
// X umax Y --> X, if X is always greater than Y
- if (Ops[i] == Ops[i+1] ||
- isKnownPredicate(ICmpInst::ICMP_UGE, Ops[i], Ops[i+1])) {
- Ops.erase(Ops.begin()+i+1, Ops.begin()+i+2);
+ if (Ops[i] == Ops[i + 1] || isKnownViaNonRecursiveReasoning(
+ ICmpInst::ICMP_UGE, Ops[i], Ops[i + 1])) {
+ Ops.erase(Ops.begin() + i + 1, Ops.begin() + i + 2);
--i; --e;
- } else if (isKnownPredicate(ICmpInst::ICMP_ULE, Ops[i], Ops[i+1])) {
- Ops.erase(Ops.begin()+i, Ops.begin()+i+1);
+ } else if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, Ops[i],
+ Ops[i + 1])) {
+ Ops.erase(Ops.begin() + i, Ops.begin() + i + 1);
--i; --e;
}
@@ -3606,14 +3695,35 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
const SCEV *ScalarEvolution::getSMinExpr(const SCEV *LHS,
const SCEV *RHS) {
- // ~smax(~x, ~y) == smin(x, y).
- return getNotSCEV(getSMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getSMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getSMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ // ~smax(~x, ~y, ~z) == smin(x, y, z).
+ SmallVector<const SCEV *, 2> NotOps;
+ for (auto *S : Ops)
+ NotOps.push_back(getNotSCEV(S));
+ return getNotSCEV(getSMaxExpr(NotOps));
}
const SCEV *ScalarEvolution::getUMinExpr(const SCEV *LHS,
const SCEV *RHS) {
- // ~umax(~x, ~y) == umin(x, y)
- return getNotSCEV(getUMaxExpr(getNotSCEV(LHS), getNotSCEV(RHS)));
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getUMinExpr(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
+ assert(!Ops.empty() && "At least one operand must be!");
+ // Trivial case.
+ if (Ops.size() == 1)
+ return Ops[0];
+
+ // ~umax(~x, ~y, ~z) == umin(x, y, z).
+ SmallVector<const SCEV *, 2> NotOps;
+ for (auto *S : Ops)
+ NotOps.push_back(getNotSCEV(S));
+ return getNotSCEV(getUMaxExpr(NotOps));
}
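Both n-ary helpers lean on the complement identity: x -> ~x reverses both the unsigned and the signed order, so ~max(~x, ~y, ~z) equals min(x, y, z). A spot-check of the unsigned case (not from the patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      uint8_t x = 17, y = 200, z = 5;
      uint8_t NotMax =
          uint8_t(~std::max({uint8_t(~x), uint8_t(~y), uint8_t(~z)}));
      assert(NotMax == std::min({x, y, z})); // == 5
      return 0;
    }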
const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
@@ -3665,13 +3775,15 @@ const SCEV *ScalarEvolution::getUnknown(Value *V) {
/// target-specific information.
bool ScalarEvolution::isSCEVable(Type *Ty) const {
// Integers and pointers are always SCEVable.
- return Ty->isIntegerTy() || Ty->isPointerTy();
+ return Ty->isIntOrPtrTy();
}
/// Return the size in bits of the specified type, for which isSCEVable must
/// return true.
uint64_t ScalarEvolution::getTypeSizeInBits(Type *Ty) const {
assert(isSCEVable(Ty) && "Type is not SCEVable!");
+ if (Ty->isPointerTy())
+ return getDataLayout().getIndexTypeSizeInBits(Ty);
return getDataLayout().getTypeSizeInBits(Ty);
}
@@ -3774,6 +3886,24 @@ void ScalarEvolution::eraseValueFromMap(Value *V) {
}
}
+/// Check whether the value has nuw/nsw/exact set but the SCEV does not.
+/// TODO: Ideally the poison flags would be checked recursively, but this is
+/// better than nothing.
+static bool SCEVLostPoisonFlags(const SCEV *S, const Value *V) {
+ if (auto *I = dyn_cast<Instruction>(V)) {
+ if (isa<OverflowingBinaryOperator>(I)) {
+ if (auto *NS = dyn_cast<SCEVNAryExpr>(S)) {
+ if (I->hasNoSignedWrap() && !NS->hasNoSignedWrap())
+ return true;
+ if (I->hasNoUnsignedWrap() && !NS->hasNoUnsignedWrap())
+ return true;
+ }
+ } else if (isa<PossiblyExactOperator>(I) && I->isExact())
+ return true;
+ }
+ return false;
+}
+
/// Return an existing SCEV if it exists, otherwise analyze the expression and
/// create a new one.
const SCEV *ScalarEvolution::getSCEV(Value *V) {
@@ -3787,7 +3917,7 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) {
// ValueExprMap before insert S->{V, 0} into ExprValueMap.
std::pair<ValueExprMapType::iterator, bool> Pair =
ValueExprMap.insert({SCEVCallbackVH(V, this), S});
- if (Pair.second) {
+ if (Pair.second && !SCEVLostPoisonFlags(S, V)) {
ExprValueMap[S].insert({V, nullptr});
// If S == Stripped + Offset, add Stripped -> {V, Offset} into
@@ -3890,8 +4020,7 @@ const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
const SCEV *
ScalarEvolution::getTruncateOrZeroExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or zero extend with non-integer arguments!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; // No conversion
@@ -3904,8 +4033,7 @@ const SCEV *
ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or zero extend with non-integer arguments!");
if (getTypeSizeInBits(SrcTy) == getTypeSizeInBits(Ty))
return V; // No conversion
@@ -3917,8 +4045,7 @@ ScalarEvolution::getTruncateOrSignExtend(const SCEV *V,
const SCEV *
ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or zero extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrZeroExtend cannot truncate!");
@@ -3930,8 +4057,7 @@ ScalarEvolution::getNoopOrZeroExtend(const SCEV *V, Type *Ty) {
const SCEV *
ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or sign extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrSignExtend cannot truncate!");
@@ -3943,8 +4069,7 @@ ScalarEvolution::getNoopOrSignExtend(const SCEV *V, Type *Ty) {
const SCEV *
ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot noop or any extend with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) <= getTypeSizeInBits(Ty) &&
"getNoopOrAnyExtend cannot truncate!");
@@ -3956,8 +4081,7 @@ ScalarEvolution::getNoopOrAnyExtend(const SCEV *V, Type *Ty) {
const SCEV *
ScalarEvolution::getTruncateOrNoop(const SCEV *V, Type *Ty) {
Type *SrcTy = V->getType();
- assert((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) &&
- (Ty->isIntegerTy() || Ty->isPointerTy()) &&
+ assert(SrcTy->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate or noop with non-integer arguments!");
assert(getTypeSizeInBits(SrcTy) >= getTypeSizeInBits(Ty) &&
"getTruncateOrNoop cannot extend!");
@@ -3981,15 +4105,32 @@ const SCEV *ScalarEvolution::getUMaxFromMismatchedTypes(const SCEV *LHS,
const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(const SCEV *LHS,
const SCEV *RHS) {
- const SCEV *PromotedLHS = LHS;
- const SCEV *PromotedRHS = RHS;
+ SmallVector<const SCEV *, 2> Ops = { LHS, RHS };
+ return getUMinFromMismatchedTypes(Ops);
+}
+
+const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
+ SmallVectorImpl<const SCEV *> &Ops) {
+ assert(!Ops.empty() && "At least one operand must be!");
+ // Trivial case.
+ if (Ops.size() == 1)
+ return Ops[0];
+
+ // Find the max type first.
+ Type *MaxType = nullptr;
+ for (auto *S : Ops)
+ if (MaxType)
+ MaxType = getWiderType(MaxType, S->getType());
+ else
+ MaxType = S->getType();
- if (getTypeSizeInBits(LHS->getType()) > getTypeSizeInBits(RHS->getType()))
- PromotedRHS = getZeroExtendExpr(RHS, LHS->getType());
- else
- PromotedLHS = getNoopOrZeroExtend(LHS, RHS->getType());
+ // Extend all ops to max type.
+ SmallVector<const SCEV *, 2> PromotedOps;
+ for (auto *S : Ops)
+ PromotedOps.push_back(getNoopOrZeroExtend(S, MaxType));
- return getUMinExpr(PromotedLHS, PromotedRHS);
+ // Generate umin.
+ return getUMinExpr(PromotedOps);
}
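A sketch of the widening step, assuming zero extension (which preserves unsigned order, so widening never changes which operand is smallest; not from the patch):

    #include <algorithm>
    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t A = 7;          // i32 operand
      uint64_t B = 1ull << 40; // i64 operand, the max type
      uint16_t C = 9;          // i16 operand
      uint64_t Min = std::min({uint64_t(A), B, uint64_t(C)}); // umin at i64
      assert(Min == 7);
      return 0;
    }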
const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
@@ -4066,37 +4207,90 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
namespace {
+/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
+/// use its start expression. If the loop is not L, either keep the AddRec
+/// itself (when IgnoreOtherLoops is true) or give up on the rewrite
+/// (otherwise). If the SCEV contains a loop-variant SCEVUnknown, the rewrite
+/// cannot be done.
class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
public:
- static const SCEV *rewrite(const SCEV *S, const Loop *L,
- ScalarEvolution &SE) {
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE,
+ bool IgnoreOtherLoops = true) {
SCEVInitRewriter Rewriter(L, SE);
const SCEV *Result = Rewriter.visit(S);
- return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
+ if (Rewriter.hasSeenLoopVariantSCEVUnknown())
+ return SE.getCouldNotCompute();
+ return Rewriter.hasSeenOtherLoops() && !IgnoreOtherLoops
+ ? SE.getCouldNotCompute()
+ : Result;
}
const SCEV *visitUnknown(const SCEVUnknown *Expr) {
if (!SE.isLoopInvariant(Expr, L))
- Valid = false;
+ SeenLoopVariantSCEVUnknown = true;
return Expr;
}
const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
- // Only allow AddRecExprs for this loop.
+ // Only re-write AddRecExprs for this loop.
if (Expr->getLoop() == L)
return Expr->getStart();
- Valid = false;
+ SeenOtherLoops = true;
return Expr;
}
- bool isValid() { return Valid; }
+ bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+ bool hasSeenOtherLoops() { return SeenOtherLoops; }
private:
explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
: SCEVRewriteVisitor(SE), L(L) {}
const Loop *L;
- bool Valid = true;
+ bool SeenLoopVariantSCEVUnknown = false;
+ bool SeenOtherLoops = false;
+};
+
+/// Takes SCEV S and Loop L. For each AddRec sub-expression whose loop is L,
+/// use its post-increment expression; for any other loop, keep the AddRec
+/// itself. If the SCEV contains a loop-variant SCEVUnknown, the rewrite
+/// cannot be done.
+class SCEVPostIncRewriter : public SCEVRewriteVisitor<SCEVPostIncRewriter> {
+public:
+ static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE) {
+ SCEVPostIncRewriter Rewriter(L, SE);
+ const SCEV *Result = Rewriter.visit(S);
+ return Rewriter.hasSeenLoopVariantSCEVUnknown()
+ ? SE.getCouldNotCompute()
+ : Result;
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ if (!SE.isLoopInvariant(Expr, L))
+ SeenLoopVariantSCEVUnknown = true;
+ return Expr;
+ }
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) {
+ // Only re-write AddRecExprs for this loop.
+ if (Expr->getLoop() == L)
+ return Expr->getPostIncExpr(SE);
+ SeenOtherLoops = true;
+ return Expr;
+ }
+
+ bool hasSeenLoopVariantSCEVUnknown() { return SeenLoopVariantSCEVUnknown; }
+
+ bool hasSeenOtherLoops() { return SeenOtherLoops; }
+
+private:
+ explicit SCEVPostIncRewriter(const Loop *L, ScalarEvolution &SE)
+ : SCEVRewriteVisitor(SE), L(L) {}
+
+ const Loop *L;
+ bool SeenLoopVariantSCEVUnknown = false;
+ bool SeenOtherLoops = false;
};
/// This class evaluates the compare condition by matching it against the
@@ -4668,7 +4862,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
if (PredIsKnownFalse(StartVal, StartExtended)) {
- DEBUG(dbgs() << "P2 is compile-time false\n";);
+ LLVM_DEBUG(dbgs() << "P2 is compile-time false\n";);
return None;
}
@@ -4676,7 +4870,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
// NSSW or NUSW)
const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
if (PredIsKnownFalse(Accum, AccumExtended)) {
- DEBUG(dbgs() << "P3 is compile-time false\n";);
+ LLVM_DEBUG(dbgs() << "P3 is compile-time false\n";);
return None;
}
@@ -4685,7 +4879,7 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
if (Expr != ExtendedExpr &&
!isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
- DEBUG (dbgs() << "Added Predicate: " << *Pred);
+ LLVM_DEBUG(dbgs() << "Added Predicate: " << *Pred);
Predicates.push_back(Pred);
}
};
@@ -4948,7 +5142,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
// by one iteration:
// PHI(f(0), f({1,+,1})) --> f({0,+,1})
const SCEV *Shifted = SCEVShiftRewriter::rewrite(BEValue, L, *this);
- const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this);
+ const SCEV *Start = SCEVInitRewriter::rewrite(Shifted, L, *this, false);
if (Shifted != getCouldNotCompute() &&
Start != getCouldNotCompute()) {
const SCEV *StartVal = getSCEV(StartValueV);
@@ -5510,6 +5704,25 @@ ScalarEvolution::getRangeRef(const SCEV *S,
APInt::getSignedMaxValue(BitWidth).ashr(NS - 1) + 1));
}
+ // The range of a PHI is a subset of the union of the ranges of its inputs.
+ if (const PHINode *Phi = dyn_cast<PHINode>(U->getValue())) {
+ // Make sure that we do not recurse into cyclic PHIs.
+ if (PendingPhiRanges.insert(Phi).second) {
+ ConstantRange RangeFromOps(BitWidth, /*isFullSet=*/false);
+ for (auto &Op : Phi->operands()) {
+ auto OpRange = getRangeRef(getSCEV(Op), SignHint);
+ RangeFromOps = RangeFromOps.unionWith(OpRange);
+ // No point in continuing if we already have a full set.
+ if (RangeFromOps.isFullSet())
+ break;
+ }
+ ConservativeResult = ConservativeResult.intersectWith(RangeFromOps);
+ bool Erased = PendingPhiRanges.erase(Phi);
+ assert(Erased && "Failed to erase Phi properly?");
+ (void) Erased;
+ }
+ }
+
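A toy interval model of the rule (not from the patch): a PHI always produces one of its inputs, so its range is contained in the union of the input ranges, here approximated by the smallest enclosing interval the way a ConstantRange union is:

    #include <algorithm>
    #include <cassert>

    struct Range { int Lo, Hi; }; // hypothetical stand-in for ConstantRange

    Range unionWith(Range A, Range B) {
      return {std::min(A.Lo, B.Lo), std::max(A.Hi, B.Hi)};
    }

    int main() {
      Range A{0, 10}, B{20, 30};
      Range Phi = unionWith(A, B); // phi(a, b) lies in [0, 30]
      assert(Phi.Lo == 0 && Phi.Hi == 30);
      return 0;
    }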
return setRange(U, SignHint, std::move(ConservativeResult));
}
@@ -6129,33 +6342,33 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
}
break;
- case Instruction::Shl:
- // Turn shift left of a constant amount into a multiply.
- if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
- uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
-
- // If the shift count is not less than the bitwidth, the result of
- // the shift is undefined. Don't try to analyze it, because the
- // resolution chosen here may differ from the resolution chosen in
- // other parts of the compiler.
- if (SA->getValue().uge(BitWidth))
- break;
+ case Instruction::Shl:
+ // Turn shift left of a constant amount into a multiply.
+ if (ConstantInt *SA = dyn_cast<ConstantInt>(BO->RHS)) {
+ uint32_t BitWidth = cast<IntegerType>(SA->getType())->getBitWidth();
- // It is currently not resolved how to interpret NSW for left
- // shift by BitWidth - 1, so we avoid applying flags in that
- // case. Remove this check (or this comment) once the situation
- // is resolved. See
- // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
- // and http://reviews.llvm.org/D8890 .
- auto Flags = SCEV::FlagAnyWrap;
- if (BO->Op && SA->getValue().ult(BitWidth - 1))
- Flags = getNoWrapFlagsFromUB(BO->Op);
+ // If the shift count is not less than the bitwidth, the result of
+ // the shift is undefined. Don't try to analyze it, because the
+ // resolution chosen here may differ from the resolution chosen in
+ // other parts of the compiler.
+ if (SA->getValue().uge(BitWidth))
+ break;
- Constant *X = ConstantInt::get(getContext(),
- APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
- return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
- }
- break;
+ // It is currently not resolved how to interpret NSW for left
+ // shift by BitWidth - 1, so we avoid applying flags in that
+ // case. Remove this check (or this comment) once the situation
+ // is resolved. See
+ // http://lists.llvm.org/pipermail/llvm-dev/2015-April/084195.html
+ // and http://reviews.llvm.org/D8890 .
+ auto Flags = SCEV::FlagAnyWrap;
+ if (BO->Op && SA->getValue().ult(BitWidth - 1))
+ Flags = getNoWrapFlagsFromUB(BO->Op);
+
+ Constant *X = ConstantInt::get(
+ getContext(), APInt::getOneBitSet(BitWidth, SA->getZExtValue()));
+ return getMulExpr(getSCEV(BO->LHS), getSCEV(X), Flags);
+ }
+ break;
case Instruction::AShr: {
// AShr X, C, where C is a constant.
@@ -6379,11 +6592,11 @@ const SCEV *ScalarEvolution::getExitCount(const Loop *L,
const SCEV *
ScalarEvolution::getPredicatedBackedgeTakenCount(const Loop *L,
SCEVUnionPredicate &Preds) {
- return getPredicatedBackedgeTakenInfo(L).getExact(this, &Preds);
+ return getPredicatedBackedgeTakenInfo(L).getExact(L, this, &Preds);
}
const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L) {
- return getBackedgeTakenInfo(L).getExact(this);
+ return getBackedgeTakenInfo(L).getExact(L, this);
}
/// Similar to getBackedgeTakenCount, except return the least SCEV value that is
@@ -6402,9 +6615,8 @@ PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) {
BasicBlock *Header = L->getHeader();
// Push all Loop-header PHIs onto the Worklist stack.
- for (BasicBlock::iterator I = Header->begin();
- PHINode *PN = dyn_cast<PHINode>(I); ++I)
- Worklist.push_back(PN);
+ for (PHINode &PN : Header->phis())
+ Worklist.push_back(&PN);
}
const ScalarEvolution::BackedgeTakenInfo &
@@ -6441,8 +6653,13 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// must be cleared in this scope.
BackedgeTakenInfo Result = computeBackedgeTakenCount(L);
- if (Result.getExact(this) != getCouldNotCompute()) {
- assert(isLoopInvariant(Result.getExact(this), L) &&
+ // In a release build without statistics, these counters are unused.
+ (void)NumTripCountsComputed;
+ (void)NumTripCountsNotComputed;
+#if LLVM_ENABLE_STATS || !defined(NDEBUG)
+ const SCEV *BEExact = Result.getExact(L, this);
+ if (BEExact != getCouldNotCompute()) {
+ assert(isLoopInvariant(BEExact, L) &&
isLoopInvariant(Result.getMax(this), L) &&
"Computed backedge-taken count isn't loop invariant for loop!");
++NumTripCountsComputed;
@@ -6452,6 +6669,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// Only count loops that have phi nodes as not being computable.
++NumTripCountsNotComputed;
}
+#endif // LLVM_ENABLE_STATS || !defined(NDEBUG)
// Now that we know more about the trip count for this loop, forget any
// existing SCEV values for PHI nodes in this loop since they are only
@@ -6587,6 +6805,12 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
}
}
+void ScalarEvolution::forgetTopmostLoop(const Loop *L) {
+ while (Loop *Parent = L->getParentLoop())
+ L = Parent;
+ forgetLoop(L);
+}
+
void ScalarEvolution::forgetValue(Value *V) {
Instruction *I = dyn_cast<Instruction>(V);
if (!I) return;
@@ -6615,28 +6839,35 @@ void ScalarEvolution::forgetValue(Value *V) {
}
/// Get the exact loop backedge taken count considering all loop exits. A
-/// computable result can only be returned for loops with a single exit.
-/// Returning the minimum taken count among all exits is incorrect because one
-/// of the loop's exit limit's may have been skipped. howFarToZero assumes that
-/// the limit of each loop test is never skipped. This is a valid assumption as
-/// long as the loop exits via that test. For precise results, it is the
-/// caller's responsibility to specify the relevant loop exit using
-/// getExact(ExitingBlock, SE).
+/// computable result can only be returned for loops with all exiting blocks
+/// dominating the latch. howFarToZero assumes that the limit of each loop test
+/// is never skipped. This is a valid assumption as long as the loop exits via
+/// that test. For precise results, it is the caller's responsibility to specify
+/// the relevant loop exiting block using getExact(ExitingBlock, SE).
const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
+ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
SCEVUnionPredicate *Preds) const {
// If any exits were not computable, the loop is not computable.
if (!isComplete() || ExitNotTaken.empty())
return SE->getCouldNotCompute();
- const SCEV *BECount = nullptr;
+ const BasicBlock *Latch = L->getLoopLatch();
+ // All exiting blocks we have collected must dominate the only backedge.
+ if (!Latch)
+ return SE->getCouldNotCompute();
+
+ // All exiting blocks we have gathered dominate the loop's latch, so the
+ // exact trip count is simply the minimum of all the calculated exit counts.
+ SmallVector<const SCEV *, 2> Ops;
for (auto &ENT : ExitNotTaken) {
- assert(ENT.ExactNotTaken != SE->getCouldNotCompute() && "bad exit SCEV");
+ const SCEV *BECount = ENT.ExactNotTaken;
+ assert(BECount != SE->getCouldNotCompute() && "Bad exit SCEV!");
+ assert(SE->DT.dominates(ENT.ExitingBlock, Latch) &&
+ "We should only have known counts for exiting blocks that dominate "
+ "latch!");
+
+ Ops.push_back(BECount);
- if (!BECount)
- BECount = ENT.ExactNotTaken;
- else if (BECount != ENT.ExactNotTaken)
- return SE->getCouldNotCompute();
if (Preds && !ENT.hasAlwaysTruePredicate())
Preds->add(ENT.Predicate.get());
@@ -6644,8 +6875,7 @@ ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE,
"Predicate should be always true!");
}
- assert(BECount && "Invalid not taken count for loop exit");
- return BECount;
+ return SE->getUMinFromMismatchedTypes(Ops);
}
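The kind of loop this generalizes to, in source form: two exits that both dominate the latch, so the backedge-taken count is the umin of the per-exit counts rather than requiring one common value (a hypothetical example, not from the patch):

    #include <algorithm>
    #include <cassert>

    unsigned bodyExecutions(unsigned n, unsigned m) {
      unsigned Iters = 0;
      for (unsigned i = 0; i < n; ++i) { // exit 1: i == n
        if (i >= m)                      // exit 2: i == m
          break;
        ++Iters;                         // latch reached every iteration
      }
      return Iters;
    }

    int main() {
      assert(bodyExecutions(10, 3) == std::min(10u, 3u));
      assert(bodyExecutions(2, 8) == std::min(2u, 8u));
      return 0;
    }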
/// Get the exact not taken count for this loop exit.
@@ -6842,99 +7072,60 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
bool AllowPredicates) {
- // Okay, we've chosen an exiting block. See what condition causes us to exit
- // at this block and remember the exit block and whether all other targets
- // lead to the loop header.
- bool MustExecuteLoopHeader = true;
- BasicBlock *Exit = nullptr;
- for (auto *SBB : successors(ExitingBlock))
- if (!L->contains(SBB)) {
- if (Exit) // Multiple exit successors.
- return getCouldNotCompute();
- Exit = SBB;
- } else if (SBB != L->getHeader()) {
- MustExecuteLoopHeader = false;
- }
-
- // At this point, we know we have a conditional branch that determines whether
- // the loop is exited. However, we don't know if the branch is executed each
- // time through the loop. If not, then the execution count of the branch will
- // not be equal to the trip count of the loop.
- //
- // Currently we check for this by checking to see if the Exit branch goes to
- // the loop header. If so, we know it will always execute the same number of
- // times as the loop. We also handle the case where the exit block *is* the
- // loop header. This is common for un-rotated loops.
- //
- // If both of those tests fail, walk up the unique predecessor chain to the
- // header, stopping if there is an edge that doesn't exit the loop. If the
- // header is reached, the execution count of the branch will be equal to the
- // trip count of the loop.
- //
- // More extensive analysis could be done to handle more cases here.
- //
- if (!MustExecuteLoopHeader && ExitingBlock != L->getHeader()) {
- // The simple checks failed, try climbing the unique predecessor chain
- // up to the header.
- bool Ok = false;
- for (BasicBlock *BB = ExitingBlock; BB; ) {
- BasicBlock *Pred = BB->getUniquePredecessor();
- if (!Pred)
- return getCouldNotCompute();
- TerminatorInst *PredTerm = Pred->getTerminator();
- for (const BasicBlock *PredSucc : PredTerm->successors()) {
- if (PredSucc == BB)
- continue;
- // If the predecessor has a successor that isn't BB and isn't
- // outside the loop, assume the worst.
- if (L->contains(PredSucc))
- return getCouldNotCompute();
- }
- if (Pred == L->getHeader()) {
- Ok = true;
- break;
- }
- BB = Pred;
- }
- if (!Ok)
- return getCouldNotCompute();
- }
+ assert(L->contains(ExitingBlock) && "Exit count for non-loop block?");
+ // If our exiting block does not dominate the latch, then its connection
+ // with the loop's exit limit may be far from trivial.
+ const BasicBlock *Latch = L->getLoopLatch();
+ if (!Latch || !DT.dominates(ExitingBlock, Latch))
+ return getCouldNotCompute();
bool IsOnlyExit = (L->getExitingBlock() != nullptr);
TerminatorInst *Term = ExitingBlock->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
+ bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
+ assert(ExitIfTrue == L->contains(BI->getSuccessor(1)) &&
+ "It should have one successor in loop and one exit block!");
// Proceed to the next level to examine the exit condition expression.
return computeExitLimitFromCond(
- L, BI->getCondition(), BI->getSuccessor(0), BI->getSuccessor(1),
+ L, BI->getCondition(), ExitIfTrue,
/*ControlsExit=*/IsOnlyExit, AllowPredicates);
}
- if (SwitchInst *SI = dyn_cast<SwitchInst>(Term))
+ if (SwitchInst *SI = dyn_cast<SwitchInst>(Term)) {
+ // For switch, make sure that there is a single exit from the loop.
+ BasicBlock *Exit = nullptr;
+ for (auto *SBB : successors(ExitingBlock))
+ if (!L->contains(SBB)) {
+ if (Exit) // Multiple exit successors.
+ return getCouldNotCompute();
+ Exit = SBB;
+ }
+ assert(Exit && "Exiting block must have at least one exit");
return computeExitLimitFromSingleExitSwitch(L, SI, Exit,
/*ControlsExit=*/IsOnlyExit);
+ }
return getCouldNotCompute();
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCond(
- const Loop *L, Value *ExitCond, BasicBlock *TBB, BasicBlock *FBB,
+ const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
- ScalarEvolution::ExitLimitCacheTy Cache(L, TBB, FBB, AllowPredicates);
- return computeExitLimitFromCondCached(Cache, L, ExitCond, TBB, FBB,
+ ScalarEvolution::ExitLimitCacheTy Cache(L, ExitIfTrue, AllowPredicates);
+ return computeExitLimitFromCondCached(Cache, L, ExitCond, ExitIfTrue,
ControlsExit, AllowPredicates);
}
Optional<ScalarEvolution::ExitLimit>
ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
- BasicBlock *TBB, BasicBlock *FBB,
- bool ControlsExit, bool AllowPredicates) {
+ bool ExitIfTrue, bool ControlsExit,
+ bool AllowPredicates) {
(void)this->L;
- (void)this->TBB;
- (void)this->FBB;
+ (void)this->ExitIfTrue;
(void)this->AllowPredicates;
- assert(this->L == L && this->TBB == TBB && this->FBB == FBB &&
+ assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
this->AllowPredicates == AllowPredicates &&
"Variance in assumed invariant key components!");
auto Itr = TripCountMap.find({ExitCond, ControlsExit});
@@ -6944,47 +7135,48 @@ ScalarEvolution::ExitLimitCache::find(const Loop *L, Value *ExitCond,
}
void ScalarEvolution::ExitLimitCache::insert(const Loop *L, Value *ExitCond,
- BasicBlock *TBB, BasicBlock *FBB,
+ bool ExitIfTrue,
bool ControlsExit,
bool AllowPredicates,
const ExitLimit &EL) {
- assert(this->L == L && this->TBB == TBB && this->FBB == FBB &&
+ assert(this->L == L && this->ExitIfTrue == ExitIfTrue &&
this->AllowPredicates == AllowPredicates &&
"Variance in assumed invariant key components!");
auto InsertResult = TripCountMap.insert({{ExitCond, ControlsExit}, EL});
assert(InsertResult.second && "Expected successful insertion!");
(void)InsertResult;
+ (void)ExitIfTrue;
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB,
- BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) {
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
if (auto MaybeEL =
- Cache.find(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates))
+ Cache.find(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
return *MaybeEL;
- ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, TBB, FBB,
+ ExitLimit EL = computeExitLimitFromCondImpl(Cache, L, ExitCond, ExitIfTrue,
ControlsExit, AllowPredicates);
- Cache.insert(L, ExitCond, TBB, FBB, ControlsExit, AllowPredicates, EL);
+ Cache.insert(L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates, EL);
return EL;
}
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
- ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, BasicBlock *TBB,
- BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) {
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
// Check if the controlling expression for this loop is an And or Or.
if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
if (BO->getOpcode() == Instruction::And) {
// Recurse on the operands of the and.
- bool EitherMayExit = L->contains(TBB);
+ bool EitherMayExit = !ExitIfTrue;
ExitLimit EL0 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, BO->getOperand(0), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, BO->getOperand(1), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (EitherMayExit) {
@@ -7006,7 +7198,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
} else {
// Both conditions must be true at the same time for the loop to exit.
// For now, be conservative.
- assert(L->contains(FBB) && "Loop block has no successor in loop!");
if (EL0.MaxNotTaken == EL1.MaxNotTaken)
MaxBECount = EL0.MaxNotTaken;
if (EL0.ExactNotTaken == EL1.ExactNotTaken)
@@ -7027,13 +7218,13 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
}
if (BO->getOpcode() == Instruction::Or) {
// Recurse on the operands of the or.
- bool EitherMayExit = L->contains(FBB);
+ bool EitherMayExit = ExitIfTrue;
ExitLimit EL0 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(0), TBB, FBB, ControlsExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, BO->getOperand(0), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
ExitLimit EL1 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(1), TBB, FBB, ControlsExit && !EitherMayExit,
- AllowPredicates);
+ Cache, L, BO->getOperand(1), ExitIfTrue,
+ ControlsExit && !EitherMayExit, AllowPredicates);
const SCEV *BECount = getCouldNotCompute();
const SCEV *MaxBECount = getCouldNotCompute();
if (EitherMayExit) {
@@ -7055,7 +7246,6 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
} else {
// Both conditions must be false at the same time for the loop to exit.
// For now, be conservative.
- assert(L->contains(TBB) && "Loop block has no successor in loop!");
if (EL0.MaxNotTaken == EL1.MaxNotTaken)
MaxBECount = EL0.MaxNotTaken;
if (EL0.ExactNotTaken == EL1.ExactNotTaken)
@@ -7071,12 +7261,12 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
// Proceed to the next level to examine the icmp.
if (ICmpInst *ExitCondICmp = dyn_cast<ICmpInst>(ExitCond)) {
ExitLimit EL =
- computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit);
+ computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit);
if (EL.hasFullInfo() || !AllowPredicates)
return EL;
// Try again, but use SCEV predicates this time.
- return computeExitLimitFromICmp(L, ExitCondICmp, TBB, FBB, ControlsExit,
+ return computeExitLimitFromICmp(L, ExitCondICmp, ExitIfTrue, ControlsExit,
/*AllowPredicates=*/true);
}
@@ -7085,7 +7275,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
// preserve the CFG and is temporarily leaving constant conditions
// in place.
if (ConstantInt *CI = dyn_cast<ConstantInt>(ExitCond)) {
- if (L->contains(FBB) == !CI->getZExtValue())
+ if (ExitIfTrue == !CI->getZExtValue())
// The backedge is always taken.
return getCouldNotCompute();
else
@@ -7094,19 +7284,18 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
}
// If it's not an integer or pointer comparison then compute it the hard way.
- return computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+ return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
}
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
- BasicBlock *TBB,
- BasicBlock *FBB,
+ bool ExitIfTrue,
bool ControlsExit,
bool AllowPredicates) {
// If the condition was exit on true, convert the condition to exit on false
ICmpInst::Predicate Pred;
- if (!L->contains(FBB))
+ if (!ExitIfTrue)
Pred = ExitCond->getPredicate();
else
Pred = ExitCond->getInversePredicate();
@@ -7188,7 +7377,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
}
auto *ExhaustiveCount =
- computeExitCountExhaustively(L, ExitCond, !L->contains(TBB));
+ computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
if (!isa<SCEVCouldNotCompute>(ExhaustiveCount))
return ExhaustiveCount;
@@ -7638,12 +7827,9 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
if (!Latch)
return nullptr;
- for (auto &I : *Header) {
- PHINode *PHI = dyn_cast<PHINode>(&I);
- if (!PHI) break;
- auto *StartCST = getOtherIncomingValue(PHI, Latch);
- if (!StartCST) continue;
- CurrentIterVals[PHI] = StartCST;
+ for (PHINode &PHI : Header->phis()) {
+ if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+ CurrentIterVals[&PHI] = StartCST;
}
if (!CurrentIterVals.count(PN))
return RetVal = nullptr;
@@ -7720,13 +7906,9 @@ const SCEV *ScalarEvolution::computeExitCountExhaustively(const Loop *L,
BasicBlock *Latch = L->getLoopLatch();
assert(Latch && "Should follow from NumIncomingValues == 2!");
- for (auto &I : *Header) {
- PHINode *PHI = dyn_cast<PHINode>(&I);
- if (!PHI)
- break;
- auto *StartCST = getOtherIncomingValue(PHI, Latch);
- if (!StartCST) continue;
- CurrentIterVals[PHI] = StartCST;
+ for (PHINode &PHI : Header->phis()) {
+ if (auto *StartCST = getOtherIncomingValue(&PHI, Latch))
+ CurrentIterVals[&PHI] = StartCST;
}
if (!CurrentIterVals.count(PN))
return getCouldNotCompute();
@@ -8107,6 +8289,14 @@ const SCEV *ScalarEvolution::getSCEVAtScope(Value *V, const Loop *L) {
return getSCEVAtScope(getSCEV(V), L);
}
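+// Strip injective casts (zext and sext) off S. These casts map zero to zero
+// and are one-to-one, so the stripped expression is zero exactly when S is;
+// this lets howFarToZero look through them when searching for an AddRec.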
+const SCEV *ScalarEvolution::stripInjectiveFunctions(const SCEV *S) const {
+ if (const SCEVZeroExtendExpr *ZExt = dyn_cast<SCEVZeroExtendExpr>(S))
+ return stripInjectiveFunctions(ZExt->getOperand());
+ if (const SCEVSignExtendExpr *SExt = dyn_cast<SCEVSignExtendExpr>(S))
+ return stripInjectiveFunctions(SExt->getOperand());
+ return S;
+}
+
/// Finds the minimum unsigned root of the following equation:
///
/// A * X = B (mod N)
@@ -8236,7 +8426,9 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
return getCouldNotCompute(); // Otherwise it will loop infinitely.
}
- const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(V);
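+ // Look through injective casts (zext/sext): an extended AddRec reaches
+ // zero exactly when the underlying AddRec does.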
+ const SCEVAddRecExpr *AddRec =
+ dyn_cast<SCEVAddRecExpr>(stripInjectiveFunctions(V));
+
if (!AddRec && AllowPredicates)
// Try to make this an AddRec using runtime tests, in the first X
// iterations of this loop, where X is the SCEV expression found by the
@@ -8644,43 +8836,88 @@ bool ScalarEvolution::isKnownNonZero(const SCEV *S) {
return isKnownNegative(S) || isKnownPositive(S);
}
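+// Split S, with respect to loop L, into a pair (Start, PostInc), where Start
+// is the value of S on entry of L and PostInc is its value after each taken
+// backedge. E.g. for S = {A,+,B}<L> this yields (A, {A+B,+,B}<L>).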
+std::pair<const SCEV *, const SCEV *>
+ScalarEvolution::SplitIntoInitAndPostInc(const Loop *L, const SCEV *S) {
+ // Compute SCEV on entry of loop L.
+ const SCEV *Start = SCEVInitRewriter::rewrite(S, L, *this);
+ if (Start == getCouldNotCompute())
+ return { Start, Start };
+ // Compute the post-increment SCEV for loop L.
+ const SCEV *PostInc = SCEVPostIncRewriter::rewrite(S, L, *this);
+ assert(PostInc != getCouldNotCompute() && "Unexpected could not compute");
+ return { Start, PostInc };
+}
+
+bool ScalarEvolution::isKnownViaInduction(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS) {
+ // First collect all loops used by LHS and RHS.
+ SmallPtrSet<const Loop *, 8> LoopsUsed;
+ getUsedLoops(LHS, LoopsUsed);
+ getUsedLoops(RHS, LoopsUsed);
+
+ if (LoopsUsed.empty())
+ return false;
+
+ // The domination relationship must be a linear order on the collected loops.
+#ifndef NDEBUG
+ for (auto *L1 : LoopsUsed)
+ for (auto *L2 : LoopsUsed)
+ assert((DT.dominates(L1->getHeader(), L2->getHeader()) ||
+ DT.dominates(L2->getHeader(), L1->getHeader())) &&
+ "Domination relationship is not a linear order");
+#endif
+
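+ // MDL is the most dominated loop: the headers of all the other collected
+ // loops dominate its header.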
+ const Loop *MDL =
+ *std::max_element(LoopsUsed.begin(), LoopsUsed.end(),
+ [&](const Loop *L1, const Loop *L2) {
+ return DT.properlyDominates(L1->getHeader(), L2->getHeader());
+ });
+
+ // Get the init and post-increment values for LHS.
+ auto SplitLHS = SplitIntoInitAndPostInc(MDL, LHS);
+ // If LHS contains an unknown non-invariant SCEV, bail out.
+ if (SplitLHS.first == getCouldNotCompute())
+ return false;
+ assert(SplitLHS.second != getCouldNotCompute() && "Unexpected CNC");
+ // Get the init and post-increment values for RHS.
+ auto SplitRHS = SplitIntoInitAndPostInc(MDL, RHS);
+ // If RHS contains an unknown non-invariant SCEV, bail out.
+ if (SplitRHS.first == getCouldNotCompute())
+ return false;
+ assert(SplitRHS.second != getCouldNotCompute() && "Unexpected CNC");
+ // It is possible that an init SCEV contains an invariant load that does
+ // not dominate MDL and so is not available at MDL's loop entry; we must
+ // check that here.
+ if (!isAvailableAtLoopEntry(SplitLHS.first, MDL) ||
+ !isAvailableAtLoopEntry(SplitRHS.first, MDL))
+ return false;
+
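+ // Prove the predicate by induction: it must hold on entry of MDL for the
+ // init values and be maintained across the backedge for the post-increment
+ // values.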
+ return isLoopEntryGuardedByCond(MDL, Pred, SplitLHS.first, SplitRHS.first) &&
+ isLoopBackedgeGuardedByCond(MDL, Pred, SplitLHS.second,
+ SplitRHS.second);
+}
+
bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
// Canonicalize the inputs first.
(void)SimplifyICmpOperands(Pred, LHS, RHS);
- // If LHS or RHS is an addrec, check to see if the condition is true in
- // every iteration of the loop.
- // If LHS and RHS are both addrec, both conditions must be true in
- // every iteration of the loop.
- const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
- const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
- bool LeftGuarded = false;
- bool RightGuarded = false;
- if (LAR) {
- const Loop *L = LAR->getLoop();
- if (isLoopEntryGuardedByCond(L, Pred, LAR->getStart(), RHS) &&
- isLoopBackedgeGuardedByCond(L, Pred, LAR->getPostIncExpr(*this), RHS)) {
- if (!RAR) return true;
- LeftGuarded = true;
- }
- }
- if (RAR) {
- const Loop *L = RAR->getLoop();
- if (isLoopEntryGuardedByCond(L, Pred, LHS, RAR->getStart()) &&
- isLoopBackedgeGuardedByCond(L, Pred, LHS, RAR->getPostIncExpr(*this))) {
- if (!LAR) return true;
- RightGuarded = true;
- }
- }
- if (LeftGuarded && RightGuarded)
+ if (isKnownViaInduction(Pred, LHS, RHS))
return true;
if (isKnownPredicateViaSplitting(Pred, LHS, RHS))
return true;
- // Otherwise see what can be done with known constant ranges.
- return isKnownPredicateViaConstantRanges(Pred, LHS, RHS);
+ // Otherwise see what can be done with some simple reasoning.
+ return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
+}
+
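+// A predicate holds on every iteration of L if it holds for the start value
+// on loop entry and for the post-increment value whenever the backedge is
+// taken.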
+bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
+ const SCEVAddRecExpr *LHS,
+ const SCEV *RHS) {
+ const Loop *L = LHS->getLoop();
+ return isLoopEntryGuardedByCond(L, Pred, LHS->getStart(), RHS) &&
+ isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}
bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
@@ -8947,7 +9184,7 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return true;
- if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+ if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
BasicBlock *Latch = L->getLoopLatch();
@@ -9052,9 +9289,68 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return false;
- if (isKnownPredicateViaConstantRanges(Pred, LHS, RHS))
+ // Both LHS and RHS must be available at loop entry.
+ assert(isAvailableAtLoopEntry(LHS, L) &&
+ "LHS is not available at Loop Entry");
+ assert(isAvailableAtLoopEntry(RHS, L) &&
+ "RHS is not available at Loop Entry");
+
+ if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
+ // If we cannot prove strict comparison (e.g. a > b), maybe we can prove
+ // the facts (a >= b && a != b) separately. A typical situation is when the
+ // non-strict comparison is known from ranges and non-equality is known from
+ // dominating predicates. If we are proving strict comparison, we always try
+ // to prove non-equality and non-strict comparison separately.
+ auto NonStrictPredicate = ICmpInst::getNonStrictPredicate(Pred);
+ const bool ProvingStrictComparison = (Pred != NonStrictPredicate);
+ bool ProvedNonStrictComparison = false;
+ bool ProvedNonEquality = false;
+
+ if (ProvingStrictComparison) {
+ ProvedNonStrictComparison =
+ isKnownViaNonRecursiveReasoning(NonStrictPredicate, LHS, RHS);
+ ProvedNonEquality =
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_NE, LHS, RHS);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+
+ // Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
+ auto ProveViaGuard = [&](BasicBlock *Block) {
+ if (isImpliedViaGuard(Block, Pred, LHS, RHS))
+ return true;
+ if (ProvingStrictComparison) {
+ if (!ProvedNonStrictComparison)
+ ProvedNonStrictComparison =
+ isImpliedViaGuard(Block, NonStrictPredicate, LHS, RHS);
+ if (!ProvedNonEquality)
+ ProvedNonEquality =
+ isImpliedViaGuard(Block, ICmpInst::ICMP_NE, LHS, RHS);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+ return false;
+ };
+
+ // Try to prove (Pred, LHS, RHS) using isImpliedCond.
+ auto ProveViaCond = [&](Value *Condition, bool Inverse) {
+ if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+ return true;
+ if (ProvingStrictComparison) {
+ if (!ProvedNonStrictComparison)
+ ProvedNonStrictComparison =
+ isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+ if (!ProvedNonEquality)
+ ProvedNonEquality =
+ isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+ if (ProvedNonStrictComparison && ProvedNonEquality)
+ return true;
+ }
+ return false;
+ };
+
// Starting at the loop predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
// leading to the original header.
@@ -9063,7 +9359,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
Pair.first;
Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
- if (isImpliedViaGuard(Pair.first, Pred, LHS, RHS))
+ if (ProveViaGuard(Pair.first))
return true;
BranchInst *LoopEntryPredicate =
@@ -9072,9 +9368,8 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
LoopEntryPredicate->isUnconditional())
continue;
- if (isImpliedCond(Pred, LHS, RHS,
- LoopEntryPredicate->getCondition(),
- LoopEntryPredicate->getSuccessor(0) != Pair.second))
+ if (ProveViaCond(LoopEntryPredicate->getCondition(),
+ LoopEntryPredicate->getSuccessor(0) != Pair.second))
return true;
}
@@ -9086,7 +9381,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
if (!DT.dominates(CI, L->getHeader()))
continue;
- if (isImpliedCond(Pred, LHS, RHS, CI->getArgOperand(0), false))
+ if (ProveViaCond(CI->getArgOperand(0), false))
return true;
}
@@ -9321,17 +9616,25 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
return M - L;
}
- const SCEV *L, *R;
SCEV::NoWrapFlags Flags;
- if (splitBinaryAdd(Less, L, R, Flags))
- if (const auto *LC = dyn_cast<SCEVConstant>(L))
- if (R == More)
- return -(LC->getAPInt());
-
- if (splitBinaryAdd(More, L, R, Flags))
- if (const auto *LC = dyn_cast<SCEVConstant>(L))
- if (R == Less)
- return LC->getAPInt();
+ const SCEV *LLess = nullptr, *RLess = nullptr;
+ const SCEV *LMore = nullptr, *RMore = nullptr;
+ const SCEVConstant *C1 = nullptr, *C2 = nullptr;
+ // Compare (X + C1) vs X.
+ if (splitBinaryAdd(Less, LLess, RLess, Flags))
+ if ((C1 = dyn_cast<SCEVConstant>(LLess)))
+ if (RLess == More)
+ return -(C1->getAPInt());
+
+ // Compare X vs (X + C2).
+ if (splitBinaryAdd(More, LMore, RMore, Flags))
+ if ((C2 = dyn_cast<SCEVConstant>(LMore)))
+ if (RMore == Less)
+ return C2->getAPInt();
+
+ // Compare (X + C1) vs (X + C2).
+ if (C1 && C2 && RLess == RMore)
+ return C2->getAPInt() - C1->getAPInt();
return None;
}
@@ -9408,10 +9711,121 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
}
// Try to prove (1) or (2), as needed.
- return isLoopEntryGuardedByCond(L, Pred, FoundRHS,
+ return isAvailableAtLoopEntry(FoundRHS, L) &&
+ isLoopEntryGuardedByCond(L, Pred, FoundRHS,
getConstant(FoundRHSLimit));
}
+bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS,
+ const SCEV *FoundRHS, unsigned Depth) {
+ const PHINode *LPhi = nullptr, *RPhi = nullptr;
+
+ auto ClearOnExit = make_scope_exit([&]() {
+ if (LPhi) {
+ bool Erased = PendingMerges.erase(LPhi);
+ assert(Erased && "Failed to erase LPhi!");
+ (void)Erased;
+ }
+ if (RPhi) {
+ bool Erased = PendingMerges.erase(RPhi);
+ assert(Erased && "Failed to erase RPhi!");
+ (void)Erased;
+ }
+ });
+
+ // Find the respective Phis and check that they are not already pending.
+ if (const SCEVUnknown *LU = dyn_cast<SCEVUnknown>(LHS))
+ if (auto *Phi = dyn_cast<PHINode>(LU->getValue())) {
+ if (!PendingMerges.insert(Phi).second)
+ return false;
+ LPhi = Phi;
+ }
+ if (const SCEVUnknown *RU = dyn_cast<SCEVUnknown>(RHS))
+ if (auto *Phi = dyn_cast<PHINode>(RU->getValue())) {
+ // If we detect a loop of Phi nodes being processed by this method, for
+ // example:
+ //
+ // %a = phi i32 [ %some1, %preheader ], [ %b, %latch ]
+ // %b = phi i32 [ %some2, %preheader ], [ %a, %latch ]
+ //
+ // we don't want to deal with a case that complex, so we conservatively
+ // return false.
+ if (!PendingMerges.insert(Phi).second)
+ return false;
+ RPhi = Phi;
+ }
+
+ // If neither LHS nor RHS is a Phi, there is nothing to do here.
+ if (!LPhi && !RPhi)
+ return false;
+
+ // If there is a SCEVUnknown Phi we are interested in, make it the LHS.
+ if (!LPhi) {
+ std::swap(LHS, RHS);
+ std::swap(FoundLHS, FoundRHS);
+ std::swap(LPhi, RPhi);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ assert(LPhi && "LPhi should definitely be a SCEVUnknown Phi!");
+ const BasicBlock *LBB = LPhi->getParent();
+ const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
+
+ auto ProvedEasily = [&](const SCEV *S1, const SCEV *S2) {
+ return isKnownViaNonRecursiveReasoning(Pred, S1, S2) ||
+ isImpliedCondOperandsViaRanges(Pred, S1, S2, FoundLHS, FoundRHS) ||
+ isImpliedViaOperations(Pred, S1, S2, FoundLHS, FoundRHS, Depth);
+ };
+
+ if (RPhi && RPhi->getParent() == LBB) {
+ // Case one: RHS is also a SCEVUnknown Phi from the same basic block.
+ // If we compare two Phis from the same block, and for each predecessor
+ // block the predicate is true for the incoming values from that block,
+ // then the predicate is also true for the Phis.
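+ // E.g. given
+ //   %l = phi i32 [ %a, %bb1 ], [ %b, %bb2 ]
+ //   %r = phi i32 [ %c, %bb1 ], [ %d, %bb2 ]
+ // proving the predicate for (%a, %c) and for (%b, %d) proves it for
+ // (%l, %r).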
+ for (const BasicBlock *IncBB : predecessors(LBB)) {
+ const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+ const SCEV *R = getSCEV(RPhi->getIncomingValueForBlock(IncBB));
+ if (!ProvedEasily(L, R))
+ return false;
+ }
+ } else if (RAR && RAR->getLoop()->getHeader() == LBB) {
+ // Case two: RHS is also a Phi from the same basic block, and it is an
+ // AddRec. It means that there is a loop which has both AddRec and Unknown
+ // PHIs; for such a loop we can compare the AddRec's incoming values from
+ // above the loop and from the latch with the respective incoming values
+ // of LPhi.
+ // TODO: Generalize to handle loops with many inputs in a header.
+ if (LPhi->getNumIncomingValues() != 2) return false;
+
+ auto *RLoop = RAR->getLoop();
+ auto *Predecessor = RLoop->getLoopPredecessor();
+ assert(Predecessor && "Loop with AddRec with no predecessor?");
+ const SCEV *L1 = getSCEV(LPhi->getIncomingValueForBlock(Predecessor));
+ if (!ProvedEasily(L1, RAR->getStart()))
+ return false;
+ auto *Latch = RLoop->getLoopLatch();
+ assert(Latch && "Loop with AddRec with no latch?");
+ const SCEV *L2 = getSCEV(LPhi->getIncomingValueForBlock(Latch));
+ if (!ProvedEasily(L2, RAR->getPostIncExpr(*this)))
+ return false;
+ } else {
+ // In all other cases go over the inputs of LHS and compare each of them to
+ // RHS; the predicate is true for (LHS, RHS) if it is true for all such
+ // pairs.
+ // At this point RHS is either a non-Phi, or it is a Phi from some block
+ // different from LBB.
+ for (const BasicBlock *IncBB : predecessors(LBB)) {
+ // Check that RHS is available in this block.
+ if (!dominates(RHS, IncBB))
+ return false;
+ const SCEV *L = getSCEV(LPhi->getIncomingValueForBlock(IncBB));
+ if (!ProvedEasily(L, RHS))
+ return false;
+ }
+ }
+ return true;
+}
+
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
@@ -9565,13 +9979,14 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
};
// Acquire values from extensions.
+ auto *OrigLHS = LHS;
auto *OrigFoundLHS = FoundLHS;
LHS = GetOpFromSExt(LHS);
FoundLHS = GetOpFromSExt(FoundLHS);
// Whether the SGT predicate can be proved trivially or using the found context.
auto IsSGTViaContext = [&](const SCEV *S1, const SCEV *S2) {
- return isKnownViaSimpleReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
+ return isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGT, S1, S2) ||
isImpliedViaOperations(ICmpInst::ICMP_SGT, S1, S2, OrigFoundLHS,
FoundRHS, Depth + 1);
};
@@ -9672,11 +10087,17 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
}
}
+ // If our expression contained SCEVUnknown Phis, and we split it down and now
+ // need to prove something for them, try to prove the predicate for all
+ // possible incoming values of those Phis.
+ if (isImpliedViaMerge(Pred, OrigLHS, RHS, OrigFoundLHS, FoundRHS, Depth + 1))
+ return true;
+
return false;
}
bool
-ScalarEvolution::isKnownViaSimpleReasoning(ICmpInst::Predicate Pred,
+ScalarEvolution::isKnownViaNonRecursiveReasoning(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
return isKnownPredicateViaConstantRanges(Pred, LHS, RHS) ||
IsKnownPredicateViaMinOrMax(*this, Pred, LHS, RHS) ||
@@ -9698,26 +10119,26 @@ ScalarEvolution::isImpliedCondOperandsHelper(ICmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
- if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
- isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
- if (isKnownViaSimpleReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
- isKnownViaSimpleReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SGE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_SLE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_ULE:
- if (isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
- isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, RHS, FoundRHS))
return true;
break;
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_UGE:
- if (isKnownViaSimpleReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
- isKnownViaSimpleReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
+ if (isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_UGE, LHS, FoundLHS) &&
+ isKnownViaNonRecursiveReasoning(ICmpInst::ICMP_ULE, RHS, FoundRHS))
return true;
break;
}
@@ -10195,6 +10616,31 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
return SE.getCouldNotCompute();
}
+const SCEVAddRecExpr *
+SCEVAddRecExpr::getPostIncExpr(ScalarEvolution &SE) const {
+ assert(getNumOperands() > 1 && "AddRec with zero step?");
+ // There is a temptation to just call getAddExpr(this, getStepRecurrence(SE)),
+ // but in this case we cannot guarantee that the value returned will be an
+ // AddRec because SCEV does not have a fixed point where it stops
+ // simplification: it is legal to return ({rec1} + {rec2}). For example, this
+ // may happen if we reach the arithmetic depth limit while simplifying. So we
+ // construct the returned value explicitly.
+ SmallVector<const SCEV *, 3> Ops;
+ // If this is {A,+,B,+,C,...,+,N}, then its step is {B,+,C,+,...,+,N}, and
+ // (this + Step) is {A+B,+,B+C,+...,+,N}.
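+ // E.g. the post-increment value of {0,+,4} is {4,+,4}.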
+ for (unsigned i = 0, e = getNumOperands() - 1; i < e; ++i)
+ Ops.push_back(SE.getAddExpr(getOperand(i), getOperand(i + 1)));
+ // We know that the last operand is not a constant zero (otherwise it would
+ // have been popped out earlier). This guarantees that if the result has
+ // the same last operand, it will not be popped out either, meaning that
+ // the returned value will be an AddRec.
+ const SCEV *Last = getOperand(getNumOperands() - 1);
+ assert(!Last->isZero() && "Recurrence with zero step?");
+ Ops.push_back(Last);
+ return cast<SCEVAddRecExpr>(SE.getAddRecExpr(Ops, getLoop(),
+ SCEV::FlagAnyWrap));
+}
+
// Return true when S contains at least one undef value.
static inline bool containsUndefs(const SCEV *S) {
return SCEVExprContains(S, [](const SCEV *S) {
@@ -10337,22 +10783,22 @@ void ScalarEvolution::collectParametricTerms(const SCEV *Expr,
SCEVCollectStrides StrideCollector(*this, Strides);
visitAll(Expr, StrideCollector);
- DEBUG({
- dbgs() << "Strides:\n";
- for (const SCEV *S : Strides)
- dbgs() << *S << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Strides:\n";
+ for (const SCEV *S : Strides)
+ dbgs() << *S << "\n";
+ });
for (const SCEV *S : Strides) {
SCEVCollectTerms TermCollector(Terms);
visitAll(S, TermCollector);
}
- DEBUG({
- dbgs() << "Terms:\n";
- for (const SCEV *T : Terms)
- dbgs() << *T << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Terms:\n";
+ for (const SCEV *T : Terms)
+ dbgs() << *T << "\n";
+ });
SCEVCollectAddRecMultiplies MulCollector(Terms, *this);
visitAll(Expr, MulCollector);
@@ -10463,18 +10909,18 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
if (!containsParameters(Terms))
return;
- DEBUG({
- dbgs() << "Terms:\n";
- for (const SCEV *T : Terms)
- dbgs() << *T << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Terms:\n";
+ for (const SCEV *T : Terms)
+ dbgs() << *T << "\n";
+ });
// Remove duplicates.
array_pod_sort(Terms.begin(), Terms.end());
Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
// Put larger terms first.
- std::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
+ llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
return numberOfTerms(LHS) > numberOfTerms(RHS);
});
@@ -10494,11 +10940,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
if (const SCEV *NewT = removeConstantFactors(*this, T))
NewTerms.push_back(NewT);
- DEBUG({
- dbgs() << "Terms after sorting:\n";
- for (const SCEV *T : NewTerms)
- dbgs() << *T << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Terms after sorting:\n";
+ for (const SCEV *T : NewTerms)
+ dbgs() << *T << "\n";
+ });
if (NewTerms.empty() || !findArrayDimensionsRec(*this, NewTerms, Sizes)) {
Sizes.clear();
@@ -10508,11 +10954,11 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
// The last element to be pushed into Sizes is the size of an element.
Sizes.push_back(ElementSize);
- DEBUG({
- dbgs() << "Sizes:\n";
- for (const SCEV *S : Sizes)
- dbgs() << *S << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Sizes:\n";
+ for (const SCEV *S : Sizes)
+ dbgs() << *S << "\n";
+ });
}
void ScalarEvolution::computeAccessFunctions(
@@ -10532,13 +10978,13 @@ void ScalarEvolution::computeAccessFunctions(
const SCEV *Q, *R;
SCEVDivision::divide(*this, Res, Sizes[i], &Q, &R);
- DEBUG({
- dbgs() << "Res: " << *Res << "\n";
- dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
- dbgs() << "Res divided by Sizes[i]:\n";
- dbgs() << "Quotient: " << *Q << "\n";
- dbgs() << "Remainder: " << *R << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Res: " << *Res << "\n";
+ dbgs() << "Sizes[i]: " << *Sizes[i] << "\n";
+ dbgs() << "Res divided by Sizes[i]:\n";
+ dbgs() << "Quotient: " << *Q << "\n";
+ dbgs() << "Remainder: " << *R << "\n";
+ });
Res = Q;
@@ -10566,11 +11012,11 @@ void ScalarEvolution::computeAccessFunctions(
std::reverse(Subscripts.begin(), Subscripts.end());
- DEBUG({
- dbgs() << "Subscripts:\n";
- for (const SCEV *S : Subscripts)
- dbgs() << *S << "\n";
- });
+ LLVM_DEBUG({
+ dbgs() << "Subscripts:\n";
+ for (const SCEV *S : Subscripts)
+ dbgs() << *S << "\n";
+ });
}
/// Splits the SCEV into two vectors of SCEVs representing the subscripts and
@@ -10644,17 +11090,17 @@ void ScalarEvolution::delinearize(const SCEV *Expr,
if (Subscripts.empty())
return;
- DEBUG({
- dbgs() << "succeeded to delinearize " << *Expr << "\n";
- dbgs() << "ArrayDecl[UnknownSize]";
- for (const SCEV *S : Sizes)
- dbgs() << "[" << *S << "]";
+ LLVM_DEBUG({
+ dbgs() << "succeeded to delinearize " << *Expr << "\n";
+ dbgs() << "ArrayDecl[UnknownSize]";
+ for (const SCEV *S : Sizes)
+ dbgs() << "[" << *S << "]";
- dbgs() << "\nArrayRef";
- for (const SCEV *S : Subscripts)
- dbgs() << "[" << *S << "]";
- dbgs() << "\n";
- });
+ dbgs() << "\nArrayRef";
+ for (const SCEV *S : Subscripts)
+ dbgs() << "[" << *S << "]";
+ dbgs() << "\n";
+ });
}
//===----------------------------------------------------------------------===//
@@ -10731,6 +11177,8 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
ValueExprMap(std::move(Arg.ValueExprMap)),
PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
+ PendingPhiRanges(std::move(Arg.PendingPhiRanges)),
+ PendingMerges(std::move(Arg.PendingMerges)),
MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
PredicatedBackedgeTakenCounts(
@@ -10774,6 +11222,8 @@ ScalarEvolution::~ScalarEvolution() {
BTCI.second.clear();
assert(PendingLoopPredicates.empty() && "isImpliedCond garbage");
+ assert(PendingPhiRanges.empty() && "getRangeRef garbage");
+ assert(PendingMerges.empty() && "isImpliedViaMerge garbage");
assert(!WalkingBEDominatingConds && "isLoopBackedgeGuardedByCond garbage!");
assert(!ProvingSplitPredicate && "ProvingSplitPredicate garbage!");
}
@@ -11184,9 +11634,13 @@ ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
}
-void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+void
+ScalarEvolution::getUsedLoops(const SCEV *S,
+ SmallPtrSetImpl<const Loop *> &LoopsUsed) {
struct FindUsedLoops {
- SmallPtrSet<const Loop *, 8> LoopsUsed;
+ FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
+ : LoopsUsed(LoopsUsed) {}
+ SmallPtrSetImpl<const Loop *> &LoopsUsed;
bool follow(const SCEV *S) {
if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
LoopsUsed.insert(AR->getLoop());
@@ -11196,10 +11650,14 @@ void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
bool isDone() const { return false; }
};
- FindUsedLoops F;
+ FindUsedLoops F(LoopsUsed);
SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+}
- for (auto *L : F.LoopsUsed)
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+ SmallPtrSet<const Loop *, 8> LoopsUsed;
+ getUsedLoops(S, LoopsUsed);
+ for (auto *L : LoopsUsed)
LoopUsers[L].push_back(S);
}
@@ -11482,6 +11940,12 @@ private:
if (!PredicatedRewrite)
return Expr;
for (auto *P : PredicatedRewrite->second){
+ // Wrap predicates from outer loops are not supported.
+ if (auto *WP = dyn_cast<const SCEVWrapPredicate>(P)) {
+ auto *AR = cast<const SCEVAddRecExpr>(WP->getExpr());
+ if (L != AR->getLoop())
+ return Expr;
+ }
if (!addOverflowAssumption(P))
return Expr;
}
@@ -11787,3 +12251,43 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
OS.indent(Depth + 2) << "--> " << *II->second.second << "\n";
}
}
+
+// Match the mathematical pattern A - (A / B) * B, where A and B can be
+// arbitrary expressions.
+// It's not always easy, as A and B can be folded (imagine A is X / 2, and B
+// is 4; then A / B becomes X / 8).
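+// For instance, "A urem B" is typically represented as
+// (A + (-1 * (A /u B) * B)) or, when B is a constant into which the -1 has
+// been folded, as (A + (-B * (A /u B))); both shapes are handled below.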
+bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
+ const SCEV *&RHS) {
+ const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
+ if (Add == nullptr || Add->getNumOperands() != 2)
+ return false;
+
+ const SCEV *A = Add->getOperand(1);
+ const auto *Mul = dyn_cast<SCEVMulExpr>(Add->getOperand(0));
+
+ if (Mul == nullptr)
+ return false;
+
+ const auto MatchURemWithDivisor = [&](const SCEV *B) {
+ // (SomeExpr + (-(SomeExpr / B) * B)).
+ if (Expr == getURemExpr(A, B)) {
+ LHS = A;
+ RHS = B;
+ return true;
+ }
+ return false;
+ };
+
+ // (SomeExpr + (-1 * (SomeExpr / B) * B)).
+ if (Mul->getNumOperands() == 3 && isa<SCEVConstant>(Mul->getOperand(0)))
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(2));
+
+ // (SomeExpr + ((-SomeExpr / B) * B)) or (SomeExpr + ((SomeExpr / B) * -B)).
+ if (Mul->getNumOperands() == 2)
+ return MatchURemWithDivisor(Mul->getOperand(1)) ||
+ MatchURemWithDivisor(Mul->getOperand(0)) ||
+ MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(1))) ||
+ MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
+ return false;
+}