summaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--llvm/lib/Analysis/ScalarEvolution.cpp1995
1 files changed, 1429 insertions, 566 deletions
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 48c686b73260..fe9d8297d679 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -135,6 +135,7 @@
#include <vector>
using namespace llvm;
+using namespace PatternMatch;
#define DEBUG_TYPE "scalar-evolution"
@@ -226,6 +227,11 @@ ClassifyExpressions("scalar-evolution-classify-expressions",
cl::Hidden, cl::init(true),
cl::desc("When printing analysis, include information on every instruction"));
+static cl::opt<bool> UseExpensiveRangeSharpening(
+ "scalar-evolution-use-expensive-range-sharpening", cl::Hidden,
+ cl::init(false),
+ cl::desc("Use more powerful methods of sharpening expression ranges. May "
+ "be costly in terms of compile time"));
//===----------------------------------------------------------------------===//
// SCEV class definitions
@@ -243,10 +249,17 @@ LLVM_DUMP_METHOD void SCEV::dump() const {
#endif
void SCEV::print(raw_ostream &OS) const {
- switch (static_cast<SCEVTypes>(getSCEVType())) {
+ switch (getSCEVType()) {
case scConstant:
cast<SCEVConstant>(this)->getValue()->printAsOperand(OS, false);
return;
+ case scPtrToInt: {
+ const SCEVPtrToIntExpr *PtrToInt = cast<SCEVPtrToIntExpr>(this);
+ const SCEV *Op = PtrToInt->getOperand();
+ OS << "(ptrtoint " << *Op->getType() << " " << *Op << " to "
+ << *PtrToInt->getType() << ")";
+ return;
+ }
case scTruncate: {
const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(this);
const SCEV *Op = Trunc->getOperand();
@@ -304,6 +317,8 @@ void SCEV::print(raw_ostream &OS) const {
case scSMinExpr:
OpStr = " smin ";
break;
+ default:
+ llvm_unreachable("There are no other nary expression types.");
}
OS << "(";
for (SCEVNAryExpr::op_iterator I = NAry->op_begin(), E = NAry->op_end();
@@ -320,6 +335,10 @@ void SCEV::print(raw_ostream &OS) const {
OS << "<nuw>";
if (NAry->hasNoSignedWrap())
OS << "<nsw>";
+ break;
+ default:
+ // Nothing to print for other nary expressions.
+ break;
}
return;
}
@@ -361,9 +380,10 @@ void SCEV::print(raw_ostream &OS) const {
}
Type *SCEV::getType() const {
- switch (static_cast<SCEVTypes>(getSCEVType())) {
+ switch (getSCEVType()) {
case scConstant:
return cast<SCEVConstant>(this)->getType();
+ case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -445,28 +465,42 @@ ScalarEvolution::getConstant(Type *Ty, uint64_t V, bool isSigned) {
return getConstant(ConstantInt::get(ITy, V, isSigned));
}
-SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID,
- unsigned SCEVTy, const SCEV *op, Type *ty)
- : SCEV(ID, SCEVTy, computeExpressionSize(op)), Op(op), Ty(ty) {}
+SCEVCastExpr::SCEVCastExpr(const FoldingSetNodeIDRef ID, SCEVTypes SCEVTy,
+ const SCEV *op, Type *ty)
+ : SCEV(ID, SCEVTy, computeExpressionSize(op)), Ty(ty) {
+ Operands[0] = op;
+}
-SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID,
- const SCEV *op, Type *ty)
- : SCEVCastExpr(ID, scTruncate, op, ty) {
- assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+SCEVPtrToIntExpr::SCEVPtrToIntExpr(const FoldingSetNodeIDRef ID, const SCEV *Op,
+ Type *ITy)
+ : SCEVCastExpr(ID, scPtrToInt, Op, ITy) {
+ assert(getOperand()->getType()->isPointerTy() && Ty->isIntegerTy() &&
+ "Must be a non-bit-width-changing pointer-to-integer cast!");
+}
+
+SCEVIntegralCastExpr::SCEVIntegralCastExpr(const FoldingSetNodeIDRef ID,
+ SCEVTypes SCEVTy, const SCEV *op,
+ Type *ty)
+ : SCEVCastExpr(ID, SCEVTy, op, ty) {}
+
+SCEVTruncateExpr::SCEVTruncateExpr(const FoldingSetNodeIDRef ID, const SCEV *op,
+ Type *ty)
+ : SCEVIntegralCastExpr(ID, scTruncate, op, ty) {
+ assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot truncate non-integer value!");
}
SCEVZeroExtendExpr::SCEVZeroExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
- : SCEVCastExpr(ID, scZeroExtend, op, ty) {
- assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ : SCEVIntegralCastExpr(ID, scZeroExtend, op, ty) {
+ assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot zero extend non-integer value!");
}
SCEVSignExtendExpr::SCEVSignExtendExpr(const FoldingSetNodeIDRef ID,
const SCEV *op, Type *ty)
- : SCEVCastExpr(ID, scSignExtend, op, ty) {
- assert(Op->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
+ : SCEVIntegralCastExpr(ID, scSignExtend, op, ty) {
+ assert(getOperand()->getType()->isIntOrPtrTy() && Ty->isIntOrPtrTy() &&
"Cannot sign extend non-integer value!");
}
@@ -665,7 +699,7 @@ static int CompareSCEVComplexity(
return 0;
// Primarily, sort the SCEVs by their getSCEVType().
- unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
+ SCEVTypes LType = LHS->getSCEVType(), RType = RHS->getSCEVType();
if (LType != RType)
return (int)LType - (int)RType;
@@ -674,7 +708,7 @@ static int CompareSCEVComplexity(
// Aside from the getSCEVType() ordering, the particular ordering
// isn't very important except that it's beneficial to be consistent,
// so that (a + b) and (b + a) don't end up as different expressions.
- switch (static_cast<SCEVTypes>(LType)) {
+ switch (LType) {
case scUnknown: {
const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
@@ -776,6 +810,7 @@ static int CompareSCEVComplexity(
return X;
}
+ case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend: {
@@ -999,6 +1034,115 @@ const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
// SCEV Expression folder implementations
//===----------------------------------------------------------------------===//
+const SCEV *ScalarEvolution::getPtrToIntExpr(const SCEV *Op, Type *Ty,
+ unsigned Depth) {
+ assert(Ty->isIntegerTy() && "Target type must be an integer type!");
+ assert(Depth <= 1 && "getPtrToIntExpr() should self-recurse at most once.");
+
+ // We could be called with an integer-typed operands during SCEV rewrites.
+ // Since the operand is an integer already, just perform zext/trunc/self cast.
+ if (!Op->getType()->isPointerTy())
+ return getTruncateOrZeroExtend(Op, Ty);
+
+ // What would be an ID for such a SCEV cast expression?
+ FoldingSetNodeID ID;
+ ID.AddInteger(scPtrToInt);
+ ID.AddPointer(Op);
+
+ void *IP = nullptr;
+
+ // Is there already an expression for such a cast?
+ if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP))
+ return getTruncateOrZeroExtend(S, Ty);
+
+ // If not, is this expression something we can't reduce any further?
+ if (isa<SCEVUnknown>(Op)) {
+ // Create an explicit cast node.
+ // We can reuse the existing insert position since if we get here,
+ // we won't have made any changes which would invalidate it.
+ Type *IntPtrTy = getDataLayout().getIntPtrType(Op->getType());
+ assert(getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(
+ Op->getType())) == getDataLayout().getTypeSizeInBits(IntPtrTy) &&
+ "We can only model ptrtoint if SCEV's effective (integer) type is "
+ "sufficiently wide to represent all possible pointer values.");
+ SCEV *S = new (SCEVAllocator)
+ SCEVPtrToIntExpr(ID.Intern(SCEVAllocator), Op, IntPtrTy);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ return getTruncateOrZeroExtend(S, Ty);
+ }
+
+ assert(Depth == 0 &&
+ "getPtrToIntExpr() should not self-recurse for non-SCEVUnknown's.");
+
+ // Otherwise, we've got some expression that is more complex than just a
+ // single SCEVUnknown. But we don't want to have a SCEVPtrToIntExpr of an
+ // arbitrary expression, we want to have SCEVPtrToIntExpr of an SCEVUnknown
+ // only, and the expressions must otherwise be integer-typed.
+ // So sink the cast down to the SCEVUnknown's.
+
+ /// The SCEVPtrToIntSinkingRewriter takes a scalar evolution expression,
+ /// which computes a pointer-typed value, and rewrites the whole expression
+ /// tree so that *all* the computations are done on integers, and the only
+ /// pointer-typed operands in the expression are SCEVUnknown.
+ class SCEVPtrToIntSinkingRewriter
+ : public SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter> {
+ using Base = SCEVRewriteVisitor<SCEVPtrToIntSinkingRewriter>;
+
+ public:
+ SCEVPtrToIntSinkingRewriter(ScalarEvolution &SE) : SCEVRewriteVisitor(SE) {}
+
+ static const SCEV *rewrite(const SCEV *Scev, ScalarEvolution &SE) {
+ SCEVPtrToIntSinkingRewriter Rewriter(SE);
+ return Rewriter.visit(Scev);
+ }
+
+ const SCEV *visit(const SCEV *S) {
+ Type *STy = S->getType();
+ // If the expression is not pointer-typed, just keep it as-is.
+ if (!STy->isPointerTy())
+ return S;
+ // Else, recursively sink the cast down into it.
+ return Base::visit(S);
+ }
+
+ const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+ SmallVector<const SCEV *, 2> Operands;
+ bool Changed = false;
+ for (auto *Op : Expr->operands()) {
+ Operands.push_back(visit(Op));
+ Changed |= Op != Operands.back();
+ }
+ return !Changed ? Expr : SE.getAddExpr(Operands, Expr->getNoWrapFlags());
+ }
+
+ const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+ SmallVector<const SCEV *, 2> Operands;
+ bool Changed = false;
+ for (auto *Op : Expr->operands()) {
+ Operands.push_back(visit(Op));
+ Changed |= Op != Operands.back();
+ }
+ return !Changed ? Expr : SE.getMulExpr(Operands, Expr->getNoWrapFlags());
+ }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ Type *ExprPtrTy = Expr->getType();
+ assert(ExprPtrTy->isPointerTy() &&
+ "Should only reach pointer-typed SCEVUnknown's.");
+ Type *ExprIntPtrTy = SE.getDataLayout().getIntPtrType(ExprPtrTy);
+ return SE.getPtrToIntExpr(Expr, ExprIntPtrTy, /*Depth=*/1);
+ }
+ };
+
+ // And actually perform the cast sinking.
+ const SCEV *IntOp = SCEVPtrToIntSinkingRewriter::rewrite(Op, *this);
+ assert(IntOp->getType()->isIntegerTy() &&
+ "We must have succeeded in sinking the cast, "
+ "and ending up with an integer-typed expression!");
+ return getTruncateOrZeroExtend(IntOp, Ty);
+}
+
const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
unsigned Depth) {
assert(getTypeSizeInBits(Op->getType()) > getTypeSizeInBits(Ty) &&
@@ -1050,7 +1194,8 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
for (unsigned i = 0, e = CommOp->getNumOperands(); i != e && numTruncs < 2;
++i) {
const SCEV *S = getTruncateExpr(CommOp->getOperand(i), Ty, Depth + 1);
- if (!isa<SCEVCastExpr>(CommOp->getOperand(i)) && isa<SCEVTruncateExpr>(S))
+ if (!isa<SCEVIntegralCastExpr>(CommOp->getOperand(i)) &&
+ isa<SCEVTruncateExpr>(S))
numTruncs++;
Operands.push_back(S);
}
@@ -1077,6 +1222,11 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, Type *Ty,
return getAddRecExpr(Operands, AddRec->getLoop(), SCEV::FlagAnyWrap);
}
+ // Return zero if truncating to known zeros.
+ uint32_t MinTrailingZeros = GetMinTrailingZeros(Op);
+ if (MinTrailingZeros >= getTypeSizeInBits(Ty))
+ return getZero(Ty);
+
// The cast wasn't folded; create an explicit cast node. We can reuse
// the existing insert position since if we get here, we won't have
// made any changes which would invalidate it.
@@ -1237,7 +1387,7 @@ static const SCEV *getPreStartForExtend(const SCEVAddRecExpr *AR, Type *Ty,
// If we know `AR` == {`PreStart`+`Step`,+,`Step`} is `WrapType` (FlagNSW
// or FlagNUW) and that `PreStart` + `Step` is `WrapType` too, then
// `PreAR` == {`PreStart`,+,`Step`} is also `WrapType`. Cache this fact.
- const_cast<SCEVAddRecExpr *>(PreAR)->setNoWrapFlags(WrapType);
+ SE->setNoWrapFlags(const_cast<SCEVAddRecExpr *>(PreAR), WrapType);
}
return PreStart;
}
@@ -1441,7 +1591,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
if (!AR->hasNoUnsignedWrap()) {
auto NewFlags = proveNoWrapViaConstantRanges(AR);
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
}
// If we have special knowledge that this addrec won't overflow,
@@ -1461,8 +1611,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
// that value once it has finished.
const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
if (!isa<SCEVCouldNotCompute>(MaxBECount)) {
- // Manually compute the final value for AR, checking for
- // overflow.
+ // Manually compute the final value for AR, checking for overflow.
// Check whether the backedge-taken count can be losslessly casted to
// the addrec's type. The count is always unsigned.
@@ -1490,7 +1639,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV::FlagAnyWrap, Depth + 1);
if (ZAdd == OperandExtendedAdd) {
// Cache knowledge of AR NUW, which is propagated to this AddRec.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
@@ -1509,7 +1658,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
if (ZAdd == OperandExtendedAdd) {
// Cache knowledge of AR NW, which is propagated to this AddRec.
// Negative step causes unsigned wrap, but it still can't self-wrap.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
@@ -1529,27 +1678,24 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
// doing extra work that may not pay off.
if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
!AC.assumptions().empty()) {
- // If the backedge is guarded by a comparison with the pre-inc
- // value the addrec is safe. Also, if the entry is guarded by
- // a comparison with the start value and the backedge is
- // guarded by a comparison with the post-inc value, the addrec
- // is safe.
- if (isKnownPositive(Step)) {
- const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
- getUnsignedRangeMax(Step));
- if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
- isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
- // Cache knowledge of AR NUW, which is propagated to this
- // AddRec.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
- // Return the expression with the addrec on the outside.
- return getAddRecExpr(
+
+ auto NewFlags = proveNoUnsignedWrapViaInduction(AR);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
+ if (AR->hasNoUnsignedWrap()) {
+ // Same as nuw case above - duplicated here to avoid a compile time
+ // issue. It's not clear that the order of checks does matter, but
+ // it's one of two issue possible causes for a change which was
+ // reverted. Be conservative for the moment.
+ return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
Depth + 1),
getZeroExtendExpr(Step, Ty, Depth + 1), L,
AR->getNoWrapFlags());
- }
- } else if (isKnownNegative(Step)) {
+ }
+
+ // For a negative step, we can extend the operands iff doing so only
+ // traverses values in the range zext([0,UINT_MAX]).
+ if (isKnownNegative(Step)) {
const SCEV *N = getConstant(APInt::getMaxValue(BitWidth) -
getSignedRangeMin(Step));
if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_UGT, AR, N) ||
@@ -1557,7 +1703,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
// Cache knowledge of AR NW, which is propagated to this
// AddRec. Negative step causes unsigned wrap, but it
// still can't self-wrap.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this,
@@ -1586,7 +1732,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
if (proveNoWrapByVaryingStart<SCEVZeroExtendExpr>(Start, Step, L)) {
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNUW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNUW);
return getAddRecExpr(
getExtendAddRecStart<SCEVZeroExtendExpr>(AR, Ty, this, Depth + 1),
getZeroExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
@@ -1785,7 +1931,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
if (!AR->hasNoSignedWrap()) {
auto NewFlags = proveNoWrapViaConstantRanges(AR);
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(NewFlags);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
}
// If we have special knowledge that this addrec won't overflow,
@@ -1834,7 +1980,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
SCEV::FlagAnyWrap, Depth + 1);
if (SAdd == OperandExtendedAdd) {
// Cache knowledge of AR NSW, which is propagated to this AddRec.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this,
@@ -1859,7 +2005,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
// Thus (AR is not NW => SAdd != OperandExtendedAdd) <=>
// (SAdd == OperandExtendedAdd => AR is NW)
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNW);
// Return the expression with the addrec on the outside.
return getAddRecExpr(
@@ -1871,33 +2017,16 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
}
- // Normally, in the cases we can prove no-overflow via a
- // backedge guarding condition, we can also compute a backedge
- // taken count for the loop. The exceptions are assumptions and
- // guards present in the loop -- SCEV is not great at exploiting
- // these to compute max backedge taken counts, but can still use
- // these to prove lack of overflow. Use this fact to avoid
- // doing extra work that may not pay off.
-
- if (!isa<SCEVCouldNotCompute>(MaxBECount) || HasGuards ||
- !AC.assumptions().empty()) {
- // If the backedge is guarded by a comparison with the pre-inc
- // value the addrec is safe. Also, if the entry is guarded by
- // a comparison with the start value and the backedge is
- // guarded by a comparison with the post-inc value, the addrec
- // is safe.
- ICmpInst::Predicate Pred;
- const SCEV *OverflowLimit =
- getSignedOverflowLimitForStep(Step, &Pred, this);
- if (OverflowLimit &&
- (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
- isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
- // Cache knowledge of AR NSW, then propagate NSW to the wide AddRec.
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
- return getAddRecExpr(
- getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
- getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
- }
+ auto NewFlags = proveNoSignedWrapViaInduction(AR);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), NewFlags);
+ if (AR->hasNoSignedWrap()) {
+ // Same as nsw case above - duplicated here to avoid a compile time
+ // issue. It's not clear that the order of checks does matter, but
+ // it's one of two issue possible causes for a change which was
+ // reverted. Be conservative for the moment.
+ return getAddRecExpr(
+ getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
+ getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
}
// sext({C,+,Step}) --> (sext(D) + sext({C-D,+,Step}))<nuw><nsw>
@@ -1918,7 +2047,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
}
if (proveNoWrapByVaryingStart<SCEVSignExtendExpr>(Start, Step, L)) {
- const_cast<SCEVAddRecExpr *>(AR)->setNoWrapFlags(SCEV::FlagNSW);
+ setNoWrapFlags(const_cast<SCEVAddRecExpr *>(AR), SCEV::FlagNSW);
return getAddRecExpr(
getExtendAddRecStart<SCEVSignExtendExpr>(AR, Ty, this, Depth + 1),
getSignExtendExpr(Step, Ty, Depth + 1), L, AR->getNoWrapFlags());
@@ -2048,7 +2177,7 @@ CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
} else {
// A multiplication of a constant with some other value. Update
// the map.
- SmallVector<const SCEV *, 4> MulOps(Mul->op_begin()+1, Mul->op_end());
+ SmallVector<const SCEV *, 4> MulOps(drop_begin(Mul->operands()));
const SCEV *Key = SE.getMulExpr(MulOps);
auto Pair = M.insert({Key, NewScale});
if (Pair.second) {
@@ -2152,9 +2281,9 @@ bool ScalarEvolution::isAvailableAtLoopEntry(const SCEV *S, const Loop *L) {
/// Get a canonical add expression, or something simpler if possible.
const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
- SCEV::NoWrapFlags Flags,
+ SCEV::NoWrapFlags OrigFlags,
unsigned Depth) {
- assert(!(Flags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
+ assert(!(OrigFlags & ~(SCEV::FlagNUW | SCEV::FlagNSW)) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty add!");
if (Ops.size() == 1) return Ops[0];
@@ -2168,8 +2297,6 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, &LI, DT);
- Flags = StrengthenNoWrapFlags(this, scAddExpr, Ops, Flags);
-
// If there are any constants, fold them together.
unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
@@ -2192,12 +2319,20 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
if (Ops.size() == 1) return Ops[0];
}
+ // Delay expensive flag strengthening until necessary.
+ auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
+ return StrengthenNoWrapFlags(this, scAddExpr, Ops, OrigFlags);
+ };
+
// Limit recursion calls depth.
if (Depth > MaxArithDepth || hasHugeExpression(Ops))
- return getOrCreateAddExpr(Ops, Flags);
+ return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
if (SCEV *S = std::get<0>(findExistingSCEVInCache(scAddExpr, Ops))) {
- static_cast<SCEVAddExpr *>(S)->setNoWrapFlags(Flags);
+ // Don't strengthen flags if we have no new information.
+ SCEVAddExpr *Add = static_cast<SCEVAddExpr *>(S);
+ if (Add->getNoWrapFlags(OrigFlags) != OrigFlags)
+ Add->setNoWrapFlags(ComputeFlags(Ops));
return S;
}
@@ -2223,7 +2358,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
FoundMatch = true;
}
if (FoundMatch)
- return getAddExpr(Ops, Flags, Depth + 1);
+ return getAddExpr(Ops, OrigFlags, Depth + 1);
// Check for truncates. If all the operands are truncated from the same
// type, see if factoring out the truncate would permit the result to be
@@ -2458,11 +2593,16 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
// If we found some loop invariants, fold them into the recurrence.
if (!LIOps.empty()) {
+ // Compute nowrap flags for the addition of the loop-invariant ops and
+ // the addrec. Temporarily push it as an operand for that purpose.
+ LIOps.push_back(AddRec);
+ SCEV::NoWrapFlags Flags = ComputeFlags(LIOps);
+ LIOps.pop_back();
+
// NLI + LI + {Start,+,Step} --> NLI + {LI+Start,+,Step}
LIOps.push_back(AddRec->getStart());
- SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
- AddRec->op_end());
+ SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
// This follows from the fact that the no-wrap flags on the outer add
// expression are applicable on the 0th iteration, when the add recurrence
// will be equal to its start value.
@@ -2500,8 +2640,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
"AddRecExprs are not sorted in reverse dominance order?");
if (AddRecLoop == cast<SCEVAddRecExpr>(Ops[OtherIdx])->getLoop()) {
// Other + {A,+,B}<L> + {C,+,D}<L> --> Other + {A+C,+,B+D}<L>
- SmallVector<const SCEV *, 4> AddRecOps(AddRec->op_begin(),
- AddRec->op_end());
+ SmallVector<const SCEV *, 4> AddRecOps(AddRec->operands());
for (; OtherIdx != Ops.size() && isa<SCEVAddRecExpr>(Ops[OtherIdx]);
++OtherIdx) {
const auto *OtherAddRec = cast<SCEVAddRecExpr>(Ops[OtherIdx]);
@@ -2532,7 +2671,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
// Okay, it looks like we really DO need an add expr. Check to see if we
// already have one, otherwise create a new one.
- return getOrCreateAddExpr(Ops, Flags);
+ return getOrCreateAddExpr(Ops, ComputeFlags(Ops));
}
const SCEV *
@@ -2576,7 +2715,7 @@ ScalarEvolution::getOrCreateAddRecExpr(ArrayRef<const SCEV *> Ops,
UniqueSCEVs.InsertNode(S, IP);
addToLoopUseLists(S);
}
- S->setNoWrapFlags(Flags);
+ setNoWrapFlags(S, Flags);
return S;
}
@@ -2658,9 +2797,9 @@ static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
/// Get a canonical multiply expression, or something simpler if possible.
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
- SCEV::NoWrapFlags Flags,
+ SCEV::NoWrapFlags OrigFlags,
unsigned Depth) {
- assert(Flags == maskFlags(Flags, SCEV::FlagNUW | SCEV::FlagNSW) &&
+ assert(OrigFlags == maskFlags(OrigFlags, SCEV::FlagNUW | SCEV::FlagNSW) &&
"only nuw or nsw allowed");
assert(!Ops.empty() && "Cannot get empty mul!");
if (Ops.size() == 1) return Ops[0];
@@ -2674,24 +2813,52 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, &LI, DT);
- Flags = StrengthenNoWrapFlags(this, scMulExpr, Ops, Flags);
+ // If there are any constants, fold them together.
+ unsigned Idx = 0;
+ if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
+ ++Idx;
+ assert(Idx < Ops.size());
+ while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
+ // We found two constants, fold them together!
+ Ops[0] = getConstant(LHSC->getAPInt() * RHSC->getAPInt());
+ if (Ops.size() == 2) return Ops[0];
+ Ops.erase(Ops.begin()+1); // Erase the folded element
+ LHSC = cast<SCEVConstant>(Ops[0]);
+ }
+
+ // If we have a multiply of zero, it will always be zero.
+ if (LHSC->getValue()->isZero())
+ return LHSC;
- // Limit recursion calls depth, but fold all-constant expressions.
- // `Ops` is sorted, so it's enough to check just last one.
- if ((Depth > MaxArithDepth || hasHugeExpression(Ops)) &&
- !isa<SCEVConstant>(Ops.back()))
- return getOrCreateMulExpr(Ops, Flags);
+ // If we are left with a constant one being multiplied, strip it off.
+ if (LHSC->getValue()->isOne()) {
+ Ops.erase(Ops.begin());
+ --Idx;
+ }
+
+ if (Ops.size() == 1)
+ return Ops[0];
+ }
+
+ // Delay expensive flag strengthening until necessary.
+ auto ComputeFlags = [this, OrigFlags](const ArrayRef<const SCEV *> Ops) {
+ return StrengthenNoWrapFlags(this, scMulExpr, Ops, OrigFlags);
+ };
+
+ // Limit recursion calls depth.
+ if (Depth > MaxArithDepth || hasHugeExpression(Ops))
+ return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
if (SCEV *S = std::get<0>(findExistingSCEVInCache(scMulExpr, Ops))) {
- static_cast<SCEVMulExpr *>(S)->setNoWrapFlags(Flags);
+ // Don't strengthen flags if we have no new information.
+ SCEVMulExpr *Mul = static_cast<SCEVMulExpr *>(S);
+ if (Mul->getNoWrapFlags(OrigFlags) != OrigFlags)
+ Mul->setNoWrapFlags(ComputeFlags(Ops));
return S;
}
- // If there are any constants, fold them together.
- unsigned Idx = 0;
if (const SCEVConstant *LHSC = dyn_cast<SCEVConstant>(Ops[0])) {
-
- if (Ops.size() == 2)
+ if (Ops.size() == 2) {
// C1*(C2+V) -> C1*C2 + C1*V
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1]))
// If any of Add's ops are Adds or Muls with a constant, apply this
@@ -2707,28 +2874,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::FlagAnyWrap, Depth + 1),
SCEV::FlagAnyWrap, Depth + 1);
- ++Idx;
- while (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(Ops[Idx])) {
- // We found two constants, fold them together!
- ConstantInt *Fold =
- ConstantInt::get(getContext(), LHSC->getAPInt() * RHSC->getAPInt());
- Ops[0] = getConstant(Fold);
- Ops.erase(Ops.begin()+1); // Erase the folded element
- if (Ops.size() == 1) return Ops[0];
- LHSC = cast<SCEVConstant>(Ops[0]);
- }
-
- // If we are left with a constant one being multiplied, strip it off.
- if (cast<SCEVConstant>(Ops[0])->getValue()->isOne()) {
- Ops.erase(Ops.begin());
- --Idx;
- } else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
- // If we have a multiply of zero, it will always be zero.
- return Ops[0];
- } else if (Ops[0]->isAllOnesValue()) {
- // If we have a mul by -1 of an add, try distributing the -1 among the
- // add operands.
- if (Ops.size() == 2) {
+ if (Ops[0]->isAllOnesValue()) {
+ // If we have a mul by -1 of an add, try distributing the -1 among the
+ // add operands.
if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
SmallVector<const SCEV *, 4> NewOps;
bool AnyFolded = false;
@@ -2752,9 +2900,6 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
}
}
}
-
- if (Ops.size() == 1)
- return Ops[0];
}
// Skip over the add expression until we get to a multiply.
@@ -2816,8 +2961,9 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
//
// No self-wrap cannot be guaranteed after changing the step size, but
// will be inferred if either NUW or NSW is true.
- Flags = AddRec->getNoWrapFlags(clearFlags(Flags, SCEV::FlagNW));
- const SCEV *NewRec = getAddRecExpr(NewOps, AddRecLoop, Flags);
+ SCEV::NoWrapFlags Flags = ComputeFlags({Scale, AddRec});
+ const SCEV *NewRec = getAddRecExpr(
+ NewOps, AddRecLoop, AddRec->getNoWrapFlags(Flags));
// If all of the other operands were loop invariant, we are done.
if (Ops.size() == 1) return NewRec;
@@ -2910,7 +3056,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
// Okay, it looks like we really DO need an mul expr. Check to see if we
// already have one, otherwise create a new one.
- return getOrCreateMulExpr(Ops, Flags);
+ return getOrCreateMulExpr(Ops, ComputeFlags(Ops));
}
/// Represents an unsigned remainder expression based on unsigned division.
@@ -3034,8 +3180,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
const SCEV *Op = M->getOperand(i);
const SCEV *Div = getUDivExpr(Op, RHSC);
if (!isa<SCEVUDivExpr>(Div) && getMulExpr(Div, RHSC) == Op) {
- Operands = SmallVector<const SCEV *, 4>(M->op_begin(),
- M->op_end());
+ Operands = SmallVector<const SCEV *, 4>(M->operands());
Operands[i] = Div;
return getMulExpr(Operands);
}
@@ -3129,8 +3274,7 @@ const SCEV *ScalarEvolution::getUDivExactExpr(const SCEV *LHS,
// first element of the mulexpr.
if (const auto *LHSCst = dyn_cast<SCEVConstant>(Mul->getOperand(0))) {
if (LHSCst == RHSCst) {
- SmallVector<const SCEV *, 2> Operands;
- Operands.append(Mul->op_begin() + 1, Mul->op_end());
+ SmallVector<const SCEV *, 2> Operands(drop_begin(Mul->operands()));
return getMulExpr(Operands);
}
@@ -3220,8 +3364,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
? (L->getLoopDepth() < NestedLoop->getLoopDepth())
: (!NestedLoop->contains(L) &&
DT.dominates(L->getHeader(), NestedLoop->getHeader()))) {
- SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
- NestedAR->op_end());
+ SmallVector<const SCEV *, 4> NestedOperands(NestedAR->operands());
Operands[0] = NestedAR->getStart();
// AddRecs require their operands be loop-invariant with respect to their
// loops. Don't perform this transformation if it would break this
@@ -3274,12 +3417,12 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
// flow and the no-overflow bits may not be valid for the expression in any
// context. This can be fixed similarly to how these flags are handled for
// adds.
- SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW
- : SCEV::FlagAnyWrap;
+ SCEV::NoWrapFlags OffsetWrap =
+ GEP->isInBounds() ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
- const SCEV *TotalOffset = getZero(IntIdxTy);
Type *CurTy = GEP->getType();
bool FirstIter = true;
+ SmallVector<const SCEV *, 4> Offsets;
for (const SCEV *IndexExpr : IndexExprs) {
// Compute the (potentially symbolic) offset in bytes for this index.
if (StructType *STy = dyn_cast<StructType>(CurTy)) {
@@ -3287,9 +3430,7 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
ConstantInt *Index = cast<SCEVConstant>(IndexExpr)->getValue();
unsigned FieldNo = Index->getZExtValue();
const SCEV *FieldOffset = getOffsetOfExpr(IntIdxTy, STy, FieldNo);
-
- // Add the field offset to the running total offset.
- TotalOffset = getAddExpr(TotalOffset, FieldOffset);
+ Offsets.push_back(FieldOffset);
// Update CurTy to the type of the field at Index.
CurTy = STy->getTypeAtIndex(Index);
@@ -3309,22 +3450,27 @@ ScalarEvolution::getGEPExpr(GEPOperator *GEP,
IndexExpr = getTruncateOrSignExtend(IndexExpr, IntIdxTy);
// Multiply the index by the element size to compute the element offset.
- const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, Wrap);
-
- // Add the element offset to the running total offset.
- TotalOffset = getAddExpr(TotalOffset, LocalOffset);
+ const SCEV *LocalOffset = getMulExpr(IndexExpr, ElementSize, OffsetWrap);
+ Offsets.push_back(LocalOffset);
}
}
- // Add the total offset from all the GEP indices to the base.
- auto *GEPExpr = getAddExpr(BaseExpr, TotalOffset, Wrap);
- assert(BaseExpr->getType() == GEPExpr->getType() &&
- "GEP should not change type mid-flight.");
- return GEPExpr;
+ // Handle degenerate case of GEP without offsets.
+ if (Offsets.empty())
+ return BaseExpr;
+
+ // Add the offsets together, assuming nsw if inbounds.
+ const SCEV *Offset = getAddExpr(Offsets, OffsetWrap);
+ // Add the base address and the offset. We cannot use the nsw flag, as the
+ // base address is unsigned. However, if we know that the offset is
+ // non-negative, we can use nuw.
+ SCEV::NoWrapFlags BaseWrap = GEP->isInBounds() && isKnownNonNegative(Offset)
+ ? SCEV::FlagNUW : SCEV::FlagAnyWrap;
+ return getAddExpr(BaseExpr, Offset, BaseWrap);
}
std::tuple<SCEV *, FoldingSetNodeID, void *>
-ScalarEvolution::findExistingSCEVInCache(int SCEVType,
+ScalarEvolution::findExistingSCEVInCache(SCEVTypes SCEVType,
ArrayRef<const SCEV *> Ops) {
FoldingSetNodeID ID;
void *IP = nullptr;
@@ -3335,7 +3481,17 @@ ScalarEvolution::findExistingSCEVInCache(int SCEVType,
UniqueSCEVs.FindNodeOrInsertPos(ID, IP), std::move(ID), IP);
}
-const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind,
+const SCEV *ScalarEvolution::getAbsExpr(const SCEV *Op, bool IsNSW) {
+ SCEV::NoWrapFlags Flags = IsNSW ? SCEV::FlagNSW : SCEV::FlagAnyWrap;
+ return getSMaxExpr(Op, getNegativeSCEV(Op, Flags));
+}
+
+const SCEV *ScalarEvolution::getSignumExpr(const SCEV *Op) {
+ Type *Ty = Op->getType();
+ return getSMinExpr(getSMaxExpr(Op, getMinusOne(Ty)), getOne(Ty));
+}
+
+const SCEV *ScalarEvolution::getMinMaxExpr(SCEVTypes Kind,
SmallVectorImpl<const SCEV *> &Ops) {
assert(!Ops.empty() && "Cannot get empty (u|s)(min|max)!");
if (Ops.size() == 1) return Ops[0];
@@ -3459,8 +3615,8 @@ const SCEV *ScalarEvolution::getMinMaxExpr(unsigned Kind,
return ExistingSCEV;
const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
std::uninitialized_copy(Ops.begin(), Ops.end(), O);
- SCEV *S = new (SCEVAllocator) SCEVMinMaxExpr(
- ID.Intern(SCEVAllocator), static_cast<SCEVTypes>(Kind), O, Ops.size());
+ SCEV *S = new (SCEVAllocator)
+ SCEVMinMaxExpr(ID.Intern(SCEVAllocator), Kind, O, Ops.size());
UniqueSCEVs.InsertNode(S, IP);
addToLoopUseLists(S);
@@ -3505,25 +3661,42 @@ const SCEV *ScalarEvolution::getUMinExpr(SmallVectorImpl<const SCEV *> &Ops) {
return getMinMaxExpr(scUMinExpr, Ops);
}
+const SCEV *
+ScalarEvolution::getSizeOfScalableVectorExpr(Type *IntTy,
+ ScalableVectorType *ScalableTy) {
+ Constant *NullPtr = Constant::getNullValue(ScalableTy->getPointerTo());
+ Constant *One = ConstantInt::get(IntTy, 1);
+ Constant *GEP = ConstantExpr::getGetElementPtr(ScalableTy, NullPtr, One);
+ // Note that the expression we created is the final expression, we don't
+ // want to simplify it any further Also, if we call a normal getSCEV(),
+ // we'll end up in an endless recursion. So just create an SCEVUnknown.
+ return getUnknown(ConstantExpr::getPtrToInt(GEP, IntTy));
+}
+
const SCEV *ScalarEvolution::getSizeOfExpr(Type *IntTy, Type *AllocTy) {
- // We can bypass creating a target-independent
- // constant expression and then folding it back into a ConstantInt.
- // This is just a compile-time optimization.
- if (isa<ScalableVectorType>(AllocTy)) {
- Constant *NullPtr = Constant::getNullValue(AllocTy->getPointerTo());
- Constant *One = ConstantInt::get(IntTy, 1);
- Constant *GEP = ConstantExpr::getGetElementPtr(AllocTy, NullPtr, One);
- return getSCEV(ConstantExpr::getPtrToInt(GEP, IntTy));
- }
+ if (auto *ScalableAllocTy = dyn_cast<ScalableVectorType>(AllocTy))
+ return getSizeOfScalableVectorExpr(IntTy, ScalableAllocTy);
+ // We can bypass creating a target-independent constant expression and then
+ // folding it back into a ConstantInt. This is just a compile-time
+ // optimization.
return getConstant(IntTy, getDataLayout().getTypeAllocSize(AllocTy));
}
+const SCEV *ScalarEvolution::getStoreSizeOfExpr(Type *IntTy, Type *StoreTy) {
+ if (auto *ScalableStoreTy = dyn_cast<ScalableVectorType>(StoreTy))
+ return getSizeOfScalableVectorExpr(IntTy, ScalableStoreTy);
+ // We can bypass creating a target-independent constant expression and then
+ // folding it back into a ConstantInt. This is just a compile-time
+ // optimization.
+ return getConstant(IntTy, getDataLayout().getTypeStoreSize(StoreTy));
+}
+
const SCEV *ScalarEvolution::getOffsetOfExpr(Type *IntTy,
StructType *STy,
unsigned FieldNo) {
- // We can bypass creating a target-independent
- // constant expression and then folding it back into a ConstantInt.
- // This is just a compile-time optimization.
+ // We can bypass creating a target-independent constant expression and then
+ // folding it back into a ConstantInt. This is just a compile-time
+ // optimization.
return getConstant(
IntTy, getDataLayout().getStructLayout(STy)->getElementOffset(FieldNo));
}
@@ -3747,8 +3920,7 @@ const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
Type *Ty = V->getType();
Ty = getEffectiveSCEVType(Ty);
- return getMulExpr(
- V, getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty))), Flags);
+ return getMulExpr(V, getMinusOne(Ty), Flags);
}
/// If Expr computes ~A, return A else return nullptr
@@ -3782,9 +3954,8 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
return (const SCEV *)nullptr;
MatchedOperands.push_back(Matched);
}
- return getMinMaxExpr(
- SCEVMinMaxExpr::negate(static_cast<SCEVTypes>(MME->getSCEVType())),
- MatchedOperands);
+ return getMinMaxExpr(SCEVMinMaxExpr::negate(MME->getSCEVType()),
+ MatchedOperands);
};
if (const SCEV *Replaced = MatchMinMaxNegation(MME))
return Replaced;
@@ -3792,9 +3963,7 @@ const SCEV *ScalarEvolution::getNotSCEV(const SCEV *V) {
Type *Ty = V->getType();
Ty = getEffectiveSCEVType(Ty);
- const SCEV *AllOnes =
- getConstant(cast<ConstantInt>(Constant::getAllOnesValue(Ty)));
- return getMinusSCEV(AllOnes, V);
+ return getMinusSCEV(getMinusOne(Ty), V);
}
const SCEV *ScalarEvolution::getMinusSCEV(const SCEV *LHS, const SCEV *RHS,
@@ -3941,6 +4110,7 @@ const SCEV *ScalarEvolution::getUMinFromMismatchedTypes(
MaxType = getWiderType(MaxType, S->getType());
else
MaxType = S->getType();
+ assert(MaxType && "Failed to find maximum type!");
// Extend all ops to max type.
SmallVector<const SCEV *, 2> PromotedOps;
@@ -3957,7 +4127,7 @@ const SCEV *ScalarEvolution::getPointerBase(const SCEV *V) {
return V;
while (true) {
- if (const SCEVCastExpr *Cast = dyn_cast<SCEVCastExpr>(V)) {
+ if (const SCEVIntegralCastExpr *Cast = dyn_cast<SCEVIntegralCastExpr>(V)) {
V = Cast->getOperand();
} else if (const SCEVNAryExpr *NAry = dyn_cast<SCEVNAryExpr>(V)) {
const SCEV *PtrOp = nullptr;
@@ -4260,6 +4430,107 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
return Result;
}
+SCEV::NoWrapFlags
+ScalarEvolution::proveNoSignedWrapViaInduction(const SCEVAddRecExpr *AR) {
+ SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
+
+ if (AR->hasNoSignedWrap())
+ return Result;
+
+ if (!AR->isAffine())
+ return Result;
+
+ const SCEV *Step = AR->getStepRecurrence(*this);
+ const Loop *L = AR->getLoop();
+
+ // Check whether the backedge-taken count is SCEVCouldNotCompute.
+ // Note that this serves two purposes: It filters out loops that are
+ // simply not analyzable, and it covers the case where this code is
+ // being called from within backedge-taken count analysis, such that
+ // attempting to ask for the backedge-taken count would likely result
+ // in infinite recursion. In the later case, the analysis code will
+ // cope with a conservative value, and it will take care to purge
+ // that value once it has finished.
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+
+ if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
+ AC.assumptions().empty())
+ return Result;
+
+ // If the backedge is guarded by a comparison with the pre-inc value the
+ // addrec is safe. Also, if the entry is guarded by a comparison with the
+ // start value and the backedge is guarded by a comparison with the post-inc
+ // value, the addrec is safe.
+ ICmpInst::Predicate Pred;
+ const SCEV *OverflowLimit =
+ getSignedOverflowLimitForStep(Step, &Pred, this);
+ if (OverflowLimit &&
+ (isLoopBackedgeGuardedByCond(L, Pred, AR, OverflowLimit) ||
+ isKnownOnEveryIteration(Pred, AR, OverflowLimit))) {
+ Result = setFlags(Result, SCEV::FlagNSW);
+ }
+ return Result;
+}
+SCEV::NoWrapFlags
+ScalarEvolution::proveNoUnsignedWrapViaInduction(const SCEVAddRecExpr *AR) {
+ SCEV::NoWrapFlags Result = AR->getNoWrapFlags();
+
+ if (AR->hasNoUnsignedWrap())
+ return Result;
+
+ if (!AR->isAffine())
+ return Result;
+
+ const SCEV *Step = AR->getStepRecurrence(*this);
+ unsigned BitWidth = getTypeSizeInBits(AR->getType());
+ const Loop *L = AR->getLoop();
+
+ // Check whether the backedge-taken count is SCEVCouldNotCompute.
+ // Note that this serves two purposes: It filters out loops that are
+ // simply not analyzable, and it covers the case where this code is
+ // being called from within backedge-taken count analysis, such that
+ // attempting to ask for the backedge-taken count would likely result
+ // in infinite recursion. In the later case, the analysis code will
+ // cope with a conservative value, and it will take care to purge
+ // that value once it has finished.
+ const SCEV *MaxBECount = getConstantMaxBackedgeTakenCount(L);
+
+ // Normally, in the cases we can prove no-overflow via a
+ // backedge guarding condition, we can also compute a backedge
+ // taken count for the loop. The exceptions are assumptions and
+ // guards present in the loop -- SCEV is not great at exploiting
+ // these to compute max backedge taken counts, but can still use
+ // these to prove lack of overflow. Use this fact to avoid
+ // doing extra work that may not pay off.
+
+ if (isa<SCEVCouldNotCompute>(MaxBECount) && !HasGuards &&
+ AC.assumptions().empty())
+ return Result;
+
+ // If the backedge is guarded by a comparison with the pre-inc value the
+ // addrec is safe. Also, if the entry is guarded by a comparison with the
+ // start value and the backedge is guarded by a comparison with the post-inc
+ // value, the addrec is safe.
+ if (isKnownPositive(Step)) {
+ const SCEV *N = getConstant(APInt::getMinValue(BitWidth) -
+ getUnsignedRangeMax(Step));
+ if (isLoopBackedgeGuardedByCond(L, ICmpInst::ICMP_ULT, AR, N) ||
+ isKnownOnEveryIteration(ICmpInst::ICMP_ULT, AR, N)) {
+ Result = setFlags(Result, SCEV::FlagNUW);
+ }
+ }
+
+ return Result;
+}
+
namespace {
/// Represents an abstract binary operation. This may exist as a
@@ -4271,6 +4542,7 @@ struct BinaryOp {
Value *RHS;
bool IsNSW = false;
bool IsNUW = false;
+ bool IsExact = false;
/// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
/// constant expression.
@@ -4283,11 +4555,14 @@ struct BinaryOp {
IsNSW = OBO->hasNoSignedWrap();
IsNUW = OBO->hasNoUnsignedWrap();
}
+ if (auto *PEO = dyn_cast<PossiblyExactOperator>(Op))
+ IsExact = PEO->isExact();
}
explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
- bool IsNUW = false)
- : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
+ bool IsNUW = false, bool IsExact = false)
+ : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
+ IsExact(IsExact) {}
};
} // end anonymous namespace
@@ -4984,8 +5259,15 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
bool follow(const SCEV *S) {
switch (S->getSCEVType()) {
- case scConstant: case scTruncate: case scZeroExtend: case scSignExtend:
- case scAddExpr: case scMulExpr: case scUMaxExpr: case scSMaxExpr:
+ case scConstant:
+ case scPtrToInt:
+ case scTruncate:
+ case scZeroExtend:
+ case scSignExtend:
+ case scAddExpr:
+ case scMulExpr:
+ case scUMaxExpr:
+ case scSMaxExpr:
case scUMinExpr:
case scSMinExpr:
// These expressions are available if their operand(s) is/are.
@@ -5023,7 +5305,7 @@ static bool IsAvailableOnEntry(const Loop *L, DominatorTree &DT, const SCEV *S,
// We do not try to smart about these at all.
return setUnavailable();
}
- llvm_unreachable("switch should be fully covered!");
+ llvm_unreachable("Unknown SCEV kind!");
}
bool isDone() { return TraversalDone; }
@@ -5243,6 +5525,9 @@ uint32_t ScalarEvolution::GetMinTrailingZerosImpl(const SCEV *S) {
if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S))
return C->getAPInt().countTrailingZeros();
+ if (const SCEVPtrToIntExpr *I = dyn_cast<SCEVPtrToIntExpr>(S))
+ return GetMinTrailingZeros(I->getOperand());
+
if (const SCEVTruncateExpr *T = dyn_cast<SCEVTruncateExpr>(S))
return std::min(GetMinTrailingZeros(T->getOperand()),
(uint32_t)getTypeSizeInBits(T->getType()));
@@ -5334,6 +5619,15 @@ static Optional<ConstantRange> GetRangeFromMetadata(Value *V) {
return None;
}
+void ScalarEvolution::setNoWrapFlags(SCEVAddRecExpr *AddRec,
+ SCEV::NoWrapFlags Flags) {
+ if (AddRec->getNoWrapFlags(Flags) != Flags) {
+ AddRec->setNoWrapFlags(Flags);
+ UnsignedRanges.erase(AddRec);
+ SignedRanges.erase(AddRec);
+ }
+}
+
/// Determine the range for a particular SCEV. If SignHint is
/// HINT_RANGE_UNSIGNED (resp. HINT_RANGE_SIGNED) then getRange prefers ranges
/// with a "cleaner" unsigned (resp. signed) representation.
@@ -5448,6 +5742,11 @@ ScalarEvolution::getRangeRef(const SCEV *S,
RangeType));
}
+ if (const SCEVPtrToIntExpr *PtrToInt = dyn_cast<SCEVPtrToIntExpr>(S)) {
+ ConstantRange X = getRangeRef(PtrToInt->getOperand(), SignHint);
+ return setRange(PtrToInt, SignHint, X);
+ }
+
if (const SCEVTruncateExpr *Trunc = dyn_cast<SCEVTruncateExpr>(S)) {
ConstantRange X = getRangeRef(Trunc->getOperand(), SignHint);
return setRange(Trunc, SignHint,
@@ -5500,16 +5799,28 @@ ScalarEvolution::getRangeRef(const SCEV *S,
auto RangeFromAffine = getRangeForAffineAR(
AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
BitWidth);
- if (!RangeFromAffine.isFullSet())
- ConservativeResult =
- ConservativeResult.intersectWith(RangeFromAffine, RangeType);
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromAffine, RangeType);
auto RangeFromFactoring = getRangeViaFactoring(
AddRec->getStart(), AddRec->getStepRecurrence(*this), MaxBECount,
BitWidth);
- if (!RangeFromFactoring.isFullSet())
+ ConservativeResult =
+ ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
+ }
+
+ // Now try symbolic BE count and more powerful methods.
+ if (UseExpensiveRangeSharpening) {
+ const SCEV *SymbolicMaxBECount =
+ getSymbolicMaxBackedgeTakenCount(AddRec->getLoop());
+ if (!isa<SCEVCouldNotCompute>(SymbolicMaxBECount) &&
+ getTypeSizeInBits(MaxBECount->getType()) <= BitWidth &&
+ AddRec->hasNoSelfWrap()) {
+ auto RangeFromAffineNew = getRangeForAffineNoSelfWrappingAR(
+ AddRec, SymbolicMaxBECount, BitWidth, SignHint);
ConservativeResult =
- ConservativeResult.intersectWith(RangeFromFactoring, RangeType);
+ ConservativeResult.intersectWith(RangeFromAffineNew, RangeType);
+ }
}
}
@@ -5680,6 +5991,74 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,
return SR.intersectWith(UR, ConstantRange::Smallest);
}
+ConstantRange ScalarEvolution::getRangeForAffineNoSelfWrappingAR(
+ const SCEVAddRecExpr *AddRec, const SCEV *MaxBECount, unsigned BitWidth,
+ ScalarEvolution::RangeSignHint SignHint) {
+ assert(AddRec->isAffine() && "Non-affine AddRecs are not suppored!\n");
+ assert(AddRec->hasNoSelfWrap() &&
+ "This only works for non-self-wrapping AddRecs!");
+ const bool IsSigned = SignHint == HINT_RANGE_SIGNED;
+ const SCEV *Step = AddRec->getStepRecurrence(*this);
+ // Only deal with constant step to save compile time.
+ if (!isa<SCEVConstant>(Step))
+ return ConstantRange::getFull(BitWidth);
+ // Let's make sure that we can prove that we do not self-wrap during
+ // MaxBECount iterations. We need this because MaxBECount is a maximum
+ // iteration count estimate, and we might infer nw from some exit for which we
+ // do not know max exit count (or any other side reasoning).
+ // TODO: Turn into assert at some point.
+ if (getTypeSizeInBits(MaxBECount->getType()) >
+ getTypeSizeInBits(AddRec->getType()))
+ return ConstantRange::getFull(BitWidth);
+ MaxBECount = getNoopOrZeroExtend(MaxBECount, AddRec->getType());
+ const SCEV *RangeWidth = getMinusOne(AddRec->getType());
+ const SCEV *StepAbs = getUMinExpr(Step, getNegativeSCEV(Step));
+ const SCEV *MaxItersWithoutWrap = getUDivExpr(RangeWidth, StepAbs);
+ if (!isKnownPredicateViaConstantRanges(ICmpInst::ICMP_ULE, MaxBECount,
+ MaxItersWithoutWrap))
+ return ConstantRange::getFull(BitWidth);
+
+ ICmpInst::Predicate LEPred =
+ IsSigned ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
+ ICmpInst::Predicate GEPred =
+ IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE;
+ const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
+
+ // We know that there is no self-wrap. Let's take Start and End values and
+ // look at all intermediate values V1, V2, ..., Vn that IndVar takes during
+ // the iteration. They either lie inside the range [Min(Start, End),
+ // Max(Start, End)] or outside it:
+ //
+ // Case 1: RangeMin ... Start V1 ... VN End ... RangeMax;
+ // Case 2: RangeMin Vk ... V1 Start ... End Vn ... Vk + 1 RangeMax;
+ //
+ // No self wrap flag guarantees that the intermediate values cannot be BOTH
+ // outside and inside the range [Min(Start, End), Max(Start, End)]. Using that
+ // knowledge, let's try to prove that we are dealing with Case 1. It is so if
+ // Start <= End and step is positive, or Start >= End and step is negative.
+ const SCEV *Start = AddRec->getStart();
+ ConstantRange StartRange = getRangeRef(Start, SignHint);
+ ConstantRange EndRange = getRangeRef(End, SignHint);
+ ConstantRange RangeBetween = StartRange.unionWith(EndRange);
+ // If they already cover full iteration space, we will know nothing useful
+ // even if we prove what we want to prove.
+ if (RangeBetween.isFullSet())
+ return RangeBetween;
+ // Only deal with ranges that do not wrap (i.e. RangeMin < RangeMax).
+ bool IsWrappedSet = IsSigned ? RangeBetween.isSignWrappedSet()
+ : RangeBetween.isWrappedSet();
+ if (IsWrappedSet)
+ return ConstantRange::getFull(BitWidth);
+
+ if (isKnownPositive(Step) &&
+ isKnownPredicateViaConstantRanges(LEPred, Start, End))
+ return RangeBetween;
+ else if (isKnownNegative(Step) &&
+ isKnownPredicateViaConstantRanges(GEPred, Start, End))
+ return RangeBetween;
+ return ConstantRange::getFull(BitWidth);
+}
+
ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
const SCEV *Step,
const SCEV *MaxBECount,
@@ -5712,7 +6091,7 @@ ConstantRange ScalarEvolution::getRangeViaFactoring(const SCEV *Start,
}
// Peel off a cast operation
- if (auto *SCast = dyn_cast<SCEVCastExpr>(S)) {
+ if (auto *SCast = dyn_cast<SCEVIntegralCastExpr>(S)) {
CastOp = SCast->getSCEVType();
S = SCast->getOperand();
}
@@ -5913,7 +6292,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
const Instruction *Poison = PoisonStack.pop_back_val();
for (auto *PoisonUser : Poison->users()) {
- if (propagatesPoison(cast<Instruction>(PoisonUser))) {
+ if (propagatesPoison(cast<Operator>(PoisonUser))) {
if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
PoisonStack.push_back(cast<Instruction>(PoisonUser));
} else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
@@ -5977,6 +6356,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
} else if (ConstantInt *CI = dyn_cast<ConstantInt>(V))
return getConstant(CI);
else if (isa<ConstantPointerNull>(V))
+ // FIXME: we shouldn't special-case null pointer constant.
return getZero(V->getType());
else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V))
return GA->isInterposable() ? getUnknown(V) : getSCEV(GA->getAliasee());
@@ -6267,6 +6647,15 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
}
}
}
+ if (BO->IsExact) {
+ // Given exact arithmetic in-bounds right-shift by a constant,
+ // we can lower it into: (abs(x) EXACT/u (1<<C)) * signum(x)
+ const SCEV *X = getSCEV(BO->LHS);
+ const SCEV *AbsX = getAbsExpr(X, /*IsNSW=*/false);
+ APInt Mult = APInt::getOneBitSet(BitWidth, AShrAmt);
+ const SCEV *Div = getUDivExactExpr(AbsX, getConstant(Mult));
+ return getMulExpr(Div, getSignumExpr(X), SCEV::FlagNSW);
+ }
break;
}
}
@@ -6303,6 +6692,29 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
return getSCEV(U->getOperand(0));
break;
+ case Instruction::PtrToInt: {
+ // Pointer to integer cast is straight-forward, so do model it.
+ Value *Ptr = U->getOperand(0);
+ const SCEV *Op = getSCEV(Ptr);
+ Type *DstIntTy = U->getType();
+ // SCEV doesn't have constant pointer expression type, but it supports
+ // nullptr constant (and only that one), which is modelled in SCEV as a
+ // zero integer constant. So just skip the ptrtoint cast for constants.
+ if (isa<SCEVConstant>(Op))
+ return getTruncateOrZeroExtend(Op, DstIntTy);
+ Type *PtrTy = Ptr->getType();
+ Type *IntPtrTy = getDataLayout().getIntPtrType(PtrTy);
+ // But only if effective SCEV (integer) type is wide enough to represent
+ // all possible pointer values.
+ if (getDataLayout().getTypeSizeInBits(getEffectiveSCEVType(PtrTy)) !=
+ getDataLayout().getTypeSizeInBits(IntPtrTy))
+ return getUnknown(V);
+ return getPtrToIntExpr(Op, DstIntTy);
+ }
+ case Instruction::IntToPtr:
+ // Just don't deal with inttoptr casts.
+ return getUnknown(V);
+
case Instruction::SDiv:
// If both operands are non-negative, this is just an udiv.
if (isKnownNonNegative(getSCEV(U->getOperand(0))) &&
@@ -6317,11 +6729,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
return getURemExpr(getSCEV(U->getOperand(0)), getSCEV(U->getOperand(1)));
break;
- // It's tempting to handle inttoptr and ptrtoint as no-ops, however this can
- // lead to pointer expressions which cannot safely be expanded to GEPs,
- // because ScalarEvolution doesn't respect the GEP aliasing rules when
- // simplifying integer expressions.
-
case Instruction::GetElementPtr:
return createNodeForGEP(cast<GEPOperator>(U));
@@ -6342,6 +6749,45 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
case Instruction::Invoke:
if (Value *RV = cast<CallBase>(U)->getReturnedArgOperand())
return getSCEV(RV);
+
+ if (auto *II = dyn_cast<IntrinsicInst>(U)) {
+ switch (II->getIntrinsicID()) {
+ case Intrinsic::abs:
+ return getAbsExpr(
+ getSCEV(II->getArgOperand(0)),
+ /*IsNSW=*/cast<ConstantInt>(II->getArgOperand(1))->isOne());
+ case Intrinsic::umax:
+ return getUMaxExpr(getSCEV(II->getArgOperand(0)),
+ getSCEV(II->getArgOperand(1)));
+ case Intrinsic::umin:
+ return getUMinExpr(getSCEV(II->getArgOperand(0)),
+ getSCEV(II->getArgOperand(1)));
+ case Intrinsic::smax:
+ return getSMaxExpr(getSCEV(II->getArgOperand(0)),
+ getSCEV(II->getArgOperand(1)));
+ case Intrinsic::smin:
+ return getSMinExpr(getSCEV(II->getArgOperand(0)),
+ getSCEV(II->getArgOperand(1)));
+ case Intrinsic::usub_sat: {
+ const SCEV *X = getSCEV(II->getArgOperand(0));
+ const SCEV *Y = getSCEV(II->getArgOperand(1));
+ const SCEV *ClampedY = getUMinExpr(X, Y);
+ return getMinusSCEV(X, ClampedY, SCEV::FlagNUW);
+ }
+ case Intrinsic::uadd_sat: {
+ const SCEV *X = getSCEV(II->getArgOperand(0));
+ const SCEV *Y = getSCEV(II->getArgOperand(1));
+ const SCEV *ClampedX = getUMinExpr(X, getNotSCEV(Y));
+ return getAddExpr(ClampedX, Y, SCEV::FlagNUW);
+ }
+ case Intrinsic::start_loop_iterations:
+ // A start_loop_iterations is just equivalent to the first operand for
+ // SCEV purposes.
+ return getSCEV(II->getArgOperand(0));
+ default:
+ break;
+ }
+ }
break;
}
@@ -6374,8 +6820,9 @@ unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L) {
return 0;
}
-unsigned ScalarEvolution::getSmallConstantTripCount(const Loop *L,
- BasicBlock *ExitingBlock) {
+unsigned
+ScalarEvolution::getSmallConstantTripCount(const Loop *L,
+ const BasicBlock *ExitingBlock) {
assert(ExitingBlock && "Must pass a non-null exiting block!");
assert(L->isLoopExiting(ExitingBlock) &&
"Exiting block must actually branch out of the loop!");
@@ -6412,7 +6859,7 @@ unsigned ScalarEvolution::getSmallConstantTripMultiple(const Loop *L) {
/// that control exits the loop via ExitingBlock.
unsigned
ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
- BasicBlock *ExitingBlock) {
+ const BasicBlock *ExitingBlock) {
assert(ExitingBlock && "Must pass a non-null exiting block!");
assert(L->isLoopExiting(ExitingBlock) &&
"Exiting block must actually branch out of the loop!");
@@ -6443,13 +6890,14 @@ ScalarEvolution::getSmallConstantTripMultiple(const Loop *L,
}
const SCEV *ScalarEvolution::getExitCount(const Loop *L,
- BasicBlock *ExitingBlock,
+ const BasicBlock *ExitingBlock,
ExitCountKind Kind) {
switch (Kind) {
case Exact:
+ case SymbolicMaximum:
return getBackedgeTakenInfo(L).getExact(ExitingBlock, this);
case ConstantMaximum:
- return getBackedgeTakenInfo(L).getMax(ExitingBlock, this);
+ return getBackedgeTakenInfo(L).getConstantMax(ExitingBlock, this);
};
llvm_unreachable("Invalid ExitCountKind!");
}
@@ -6466,13 +6914,15 @@ const SCEV *ScalarEvolution::getBackedgeTakenCount(const Loop *L,
case Exact:
return getBackedgeTakenInfo(L).getExact(L, this);
case ConstantMaximum:
- return getBackedgeTakenInfo(L).getMax(this);
+ return getBackedgeTakenInfo(L).getConstantMax(this);
+ case SymbolicMaximum:
+ return getBackedgeTakenInfo(L).getSymbolicMax(L, this);
};
llvm_unreachable("Invalid ExitCountKind!");
}
bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) {
- return getBackedgeTakenInfo(L).isMaxOrZero(this);
+ return getBackedgeTakenInfo(L).isConstantMaxOrZero(this);
}
/// Push PHI nodes in the header of the given loop onto the given Worklist.
@@ -6502,7 +6952,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {
return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
}
-const ScalarEvolution::BackedgeTakenInfo &
+ScalarEvolution::BackedgeTakenInfo &
ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
// Initially insert an invalid entry for this loop. If the insertion
// succeeds, proceed to actually compute a backedge-taken count and
@@ -6526,12 +6976,11 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
const SCEV *BEExact = Result.getExact(L, this);
if (BEExact != getCouldNotCompute()) {
assert(isLoopInvariant(BEExact, L) &&
- isLoopInvariant(Result.getMax(this), L) &&
+ isLoopInvariant(Result.getConstantMax(this), L) &&
"Computed backedge-taken count isn't loop invariant for loop!");
++NumTripCountsComputed;
- }
- else if (Result.getMax(this) == getCouldNotCompute() &&
- isa<PHINode>(L->getHeader()->begin())) {
+ } else if (Result.getConstantMax(this) == getCouldNotCompute() &&
+ isa<PHINode>(L->getHeader()->begin())) {
// Only count loops that have phi nodes as not being computable.
++NumTripCountsNotComputed;
}
@@ -6772,7 +7221,7 @@ ScalarEvolution::BackedgeTakenInfo::getExact(const Loop *L, ScalarEvolution *SE,
/// Get the exact not taken count for this loop exit.
const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
+ScalarEvolution::BackedgeTakenInfo::getExact(const BasicBlock *ExitingBlock,
ScalarEvolution *SE) const {
for (auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
@@ -6781,9 +7230,8 @@ ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,
return SE->getCouldNotCompute();
}
-const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getMax(BasicBlock *ExitingBlock,
- ScalarEvolution *SE) const {
+const SCEV *ScalarEvolution::BackedgeTakenInfo::getConstantMax(
+ const BasicBlock *ExitingBlock, ScalarEvolution *SE) const {
for (auto &ENT : ExitNotTaken)
if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())
return ENT.MaxNotTaken;
@@ -6791,22 +7239,32 @@ ScalarEvolution::BackedgeTakenInfo::getMax(BasicBlock *ExitingBlock,
return SE->getCouldNotCompute();
}
-/// getMax - Get the max backedge taken count for the loop.
+/// getConstantMax - Get the constant max backedge taken count for the loop.
const SCEV *
-ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const {
+ScalarEvolution::BackedgeTakenInfo::getConstantMax(ScalarEvolution *SE) const {
auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
return !ENT.hasAlwaysTruePredicate();
};
- if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax())
+ if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getConstantMax())
return SE->getCouldNotCompute();
- assert((isa<SCEVCouldNotCompute>(getMax()) || isa<SCEVConstant>(getMax())) &&
+ assert((isa<SCEVCouldNotCompute>(getConstantMax()) ||
+ isa<SCEVConstant>(getConstantMax())) &&
"No point in having a non-constant max backedge taken count!");
- return getMax();
+ return getConstantMax();
+}
+
+const SCEV *
+ScalarEvolution::BackedgeTakenInfo::getSymbolicMax(const Loop *L,
+ ScalarEvolution *SE) {
+ if (!SymbolicMax)
+ SymbolicMax = SE->computeSymbolicMaxBackedgeTakenCount(L);
+ return SymbolicMax;
}
-bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const {
+bool ScalarEvolution::BackedgeTakenInfo::isConstantMaxOrZero(
+ ScalarEvolution *SE) const {
auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) {
return !ENT.hasAlwaysTruePredicate();
};
@@ -6815,8 +7273,8 @@ bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const
bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
ScalarEvolution *SE) const {
- if (getMax() && getMax() != SE->getCouldNotCompute() &&
- SE->hasOperand(getMax(), S))
+ if (getConstantMax() && getConstantMax() != SE->getCouldNotCompute() &&
+ SE->hasOperand(getConstantMax(), S))
return true;
for (auto &ENT : ExitNotTaken)
@@ -6869,10 +7327,9 @@ ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E, const SCEV *M,
/// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each
/// computable exit into a persistent ExitNotTakenInfo array.
ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
- ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo>
- ExitCounts,
- bool Complete, const SCEV *MaxCount, bool MaxOrZero)
- : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
+ ArrayRef<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> ExitCounts,
+ bool IsComplete, const SCEV *ConstantMax, bool MaxOrZero)
+ : ConstantMax(ConstantMax), IsComplete(IsComplete), MaxOrZero(MaxOrZero) {
using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
ExitNotTaken.reserve(ExitCounts.size());
@@ -6892,7 +7349,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, EL.MaxNotTaken,
std::move(Predicate));
});
- assert((isa<SCEVCouldNotCompute>(MaxCount) || isa<SCEVConstant>(MaxCount)) &&
+ assert((isa<SCEVCouldNotCompute>(ConstantMax) ||
+ isa<SCEVConstant>(ConstantMax)) &&
"No point in having a non-constant max backedge taken count!");
}
@@ -7081,114 +7539,10 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondCached(
ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
bool ControlsExit, bool AllowPredicates) {
- // Check if the controlling expression for this loop is an And or Or.
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(ExitCond)) {
- if (BO->getOpcode() == Instruction::And) {
- // Recurse on the operands of the and.
- bool EitherMayExit = !ExitIfTrue;
- ExitLimit EL0 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(0), ExitIfTrue,
- ControlsExit && !EitherMayExit, AllowPredicates);
- ExitLimit EL1 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(1), ExitIfTrue,
- ControlsExit && !EitherMayExit, AllowPredicates);
- // Be robust against unsimplified IR for the form "and i1 X, true"
- if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1)))
- return CI->isOne() ? EL0 : EL1;
- if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(0)))
- return CI->isOne() ? EL1 : EL0;
- const SCEV *BECount = getCouldNotCompute();
- const SCEV *MaxBECount = getCouldNotCompute();
- if (EitherMayExit) {
- // Both conditions must be true for the loop to continue executing.
- // Choose the less conservative count.
- if (EL0.ExactNotTaken == getCouldNotCompute() ||
- EL1.ExactNotTaken == getCouldNotCompute())
- BECount = getCouldNotCompute();
- else
- BECount =
- getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
- if (EL0.MaxNotTaken == getCouldNotCompute())
- MaxBECount = EL1.MaxNotTaken;
- else if (EL1.MaxNotTaken == getCouldNotCompute())
- MaxBECount = EL0.MaxNotTaken;
- else
- MaxBECount =
- getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
- } else {
- // Both conditions must be true at the same time for the loop to exit.
- // For now, be conservative.
- if (EL0.MaxNotTaken == EL1.MaxNotTaken)
- MaxBECount = EL0.MaxNotTaken;
- if (EL0.ExactNotTaken == EL1.ExactNotTaken)
- BECount = EL0.ExactNotTaken;
- }
-
- // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
- // to be more aggressive when computing BECount than when computing
- // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
- // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
- // to not.
- if (isa<SCEVCouldNotCompute>(MaxBECount) &&
- !isa<SCEVCouldNotCompute>(BECount))
- MaxBECount = getConstant(getUnsignedRangeMax(BECount));
-
- return ExitLimit(BECount, MaxBECount, false,
- {&EL0.Predicates, &EL1.Predicates});
- }
- if (BO->getOpcode() == Instruction::Or) {
- // Recurse on the operands of the or.
- bool EitherMayExit = ExitIfTrue;
- ExitLimit EL0 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(0), ExitIfTrue,
- ControlsExit && !EitherMayExit, AllowPredicates);
- ExitLimit EL1 = computeExitLimitFromCondCached(
- Cache, L, BO->getOperand(1), ExitIfTrue,
- ControlsExit && !EitherMayExit, AllowPredicates);
- // Be robust against unsimplified IR for the form "or i1 X, true"
- if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1)))
- return CI->isZero() ? EL0 : EL1;
- if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(0)))
- return CI->isZero() ? EL1 : EL0;
- const SCEV *BECount = getCouldNotCompute();
- const SCEV *MaxBECount = getCouldNotCompute();
- if (EitherMayExit) {
- // Both conditions must be false for the loop to continue executing.
- // Choose the less conservative count.
- if (EL0.ExactNotTaken == getCouldNotCompute() ||
- EL1.ExactNotTaken == getCouldNotCompute())
- BECount = getCouldNotCompute();
- else
- BECount =
- getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
- if (EL0.MaxNotTaken == getCouldNotCompute())
- MaxBECount = EL1.MaxNotTaken;
- else if (EL1.MaxNotTaken == getCouldNotCompute())
- MaxBECount = EL0.MaxNotTaken;
- else
- MaxBECount =
- getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
- } else {
- // Both conditions must be false at the same time for the loop to exit.
- // For now, be conservative.
- if (EL0.MaxNotTaken == EL1.MaxNotTaken)
- MaxBECount = EL0.MaxNotTaken;
- if (EL0.ExactNotTaken == EL1.ExactNotTaken)
- BECount = EL0.ExactNotTaken;
- }
- // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
- // to be more aggressive when computing BECount than when computing
- // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
- // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
- // to not.
- if (isa<SCEVCouldNotCompute>(MaxBECount) &&
- !isa<SCEVCouldNotCompute>(BECount))
- MaxBECount = getConstant(getUnsignedRangeMax(BECount));
-
- return ExitLimit(BECount, MaxBECount, false,
- {&EL0.Predicates, &EL1.Predicates});
- }
- }
+ // Handle BinOp conditions (And, Or).
+ if (auto LimitFromBinOp = computeExitLimitFromCondFromBinOp(
+ Cache, L, ExitCond, ExitIfTrue, ControlsExit, AllowPredicates))
+ return *LimitFromBinOp;
// With an icmp, it may be feasible to compute an exact backedge-taken count.
// Proceed to the next level to examine the icmp.
@@ -7220,6 +7574,95 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimitFromCondImpl(
return computeExitCountExhaustively(L, ExitCond, ExitIfTrue);
}
+Optional<ScalarEvolution::ExitLimit>
+ScalarEvolution::computeExitLimitFromCondFromBinOp(
+ ExitLimitCacheTy &Cache, const Loop *L, Value *ExitCond, bool ExitIfTrue,
+ bool ControlsExit, bool AllowPredicates) {
+ // Check if the controlling expression for this loop is an And or Or.
+ Value *Op0, *Op1;
+ bool IsAnd = false;
+ if (match(ExitCond, m_LogicalAnd(m_Value(Op0), m_Value(Op1))))
+ IsAnd = true;
+ else if (match(ExitCond, m_LogicalOr(m_Value(Op0), m_Value(Op1))))
+ IsAnd = false;
+ else
+ return None;
+
+ // EitherMayExit is true in these two cases:
+ // br (and Op0 Op1), loop, exit
+ // br (or Op0 Op1), exit, loop
+ bool EitherMayExit = IsAnd ^ ExitIfTrue;
+ ExitLimit EL0 = computeExitLimitFromCondCached(Cache, L, Op0, ExitIfTrue,
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
+ ExitLimit EL1 = computeExitLimitFromCondCached(Cache, L, Op1, ExitIfTrue,
+ ControlsExit && !EitherMayExit,
+ AllowPredicates);
+
+ // Be robust against unsimplified IR for the form "op i1 X, NeutralElement"
+ const Constant *NeutralElement = ConstantInt::get(ExitCond->getType(), IsAnd);
+ if (isa<ConstantInt>(Op1))
+ return Op1 == NeutralElement ? EL0 : EL1;
+ if (isa<ConstantInt>(Op0))
+ return Op0 == NeutralElement ? EL1 : EL0;
+
+ const SCEV *BECount = getCouldNotCompute();
+ const SCEV *MaxBECount = getCouldNotCompute();
+ if (EitherMayExit) {
+ // Both conditions must be same for the loop to continue executing.
+ // Choose the less conservative count.
+ // If ExitCond is a short-circuit form (select), using
+ // umin(EL0.ExactNotTaken, EL1.ExactNotTaken) is unsafe in general.
+ // To see the detailed examples, please see
+ // test/Analysis/ScalarEvolution/exit-count-select.ll
+ bool PoisonSafe = isa<BinaryOperator>(ExitCond);
+ if (!PoisonSafe)
+ // Even if ExitCond is select, we can safely derive BECount using both
+ // EL0 and EL1 in these cases:
+ // (1) EL0.ExactNotTaken is non-zero
+ // (2) EL1.ExactNotTaken is non-poison
+ // (3) EL0.ExactNotTaken is zero (BECount should be simply zero and
+ // it cannot be umin(0, ..))
+ // The PoisonSafe assignment below is simplified and the assertion after
+ // BECount calculation fully guarantees the condition (3).
+ PoisonSafe = isa<SCEVConstant>(EL0.ExactNotTaken) ||
+ isa<SCEVConstant>(EL1.ExactNotTaken);
+ if (EL0.ExactNotTaken != getCouldNotCompute() &&
+ EL1.ExactNotTaken != getCouldNotCompute() && PoisonSafe) {
+ BECount =
+ getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken);
+
+ // If EL0.ExactNotTaken was zero and ExitCond was a short-circuit form,
+ // it should have been simplified to zero (see the condition (3) above)
+ assert(!isa<BinaryOperator>(ExitCond) || !EL0.ExactNotTaken->isZero() ||
+ BECount->isZero());
+ }
+ if (EL0.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL1.MaxNotTaken;
+ else if (EL1.MaxNotTaken == getCouldNotCompute())
+ MaxBECount = EL0.MaxNotTaken;
+ else
+ MaxBECount = getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);
+ } else {
+ // Both conditions must be same at the same time for the loop to exit.
+ // For now, be conservative.
+ if (EL0.ExactNotTaken == EL1.ExactNotTaken)
+ BECount = EL0.ExactNotTaken;
+ }
+
+ // There are cases (e.g. PR26207) where computeExitLimitFromCond is able
+ // to be more aggressive when computing BECount than when computing
+ // MaxBECount. In these cases it is possible for EL0.ExactNotTaken and
+ // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken
+ // to not.
+ if (isa<SCEVCouldNotCompute>(MaxBECount) &&
+ !isa<SCEVCouldNotCompute>(BECount))
+ MaxBECount = getConstant(getUnsignedRangeMax(BECount));
+
+ return ExitLimit(BECount, MaxBECount, false,
+ { &EL0.Predicates, &EL1.Predicates });
+}
+
ScalarEvolution::ExitLimit
ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
ICmpInst *ExitCond,
@@ -7914,100 +8357,110 @@ const SCEV *ScalarEvolution::getSCEVAtScope(const SCEV *V, const Loop *L) {
/// SCEVConstant, because SCEVConstant is restricted to ConstantInt.
/// Returns NULL if the SCEV isn't representable as a Constant.
static Constant *BuildConstantFromSCEV(const SCEV *V) {
- switch (static_cast<SCEVTypes>(V->getSCEVType())) {
- case scCouldNotCompute:
- case scAddRecExpr:
- break;
- case scConstant:
- return cast<SCEVConstant>(V)->getValue();
- case scUnknown:
- return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
- case scSignExtend: {
- const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
- return ConstantExpr::getSExt(CastOp, SS->getType());
- break;
- }
- case scZeroExtend: {
- const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
- return ConstantExpr::getZExt(CastOp, SZ->getType());
- break;
- }
- case scTruncate: {
- const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
- if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
- return ConstantExpr::getTrunc(CastOp, ST->getType());
- break;
- }
- case scAddExpr: {
- const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
- if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
- if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
- unsigned AS = PTy->getAddressSpace();
+ switch (V->getSCEVType()) {
+ case scCouldNotCompute:
+ case scAddRecExpr:
+ return nullptr;
+ case scConstant:
+ return cast<SCEVConstant>(V)->getValue();
+ case scUnknown:
+ return dyn_cast<Constant>(cast<SCEVUnknown>(V)->getValue());
+ case scSignExtend: {
+ const SCEVSignExtendExpr *SS = cast<SCEVSignExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SS->getOperand()))
+ return ConstantExpr::getSExt(CastOp, SS->getType());
+ return nullptr;
+ }
+ case scZeroExtend: {
+ const SCEVZeroExtendExpr *SZ = cast<SCEVZeroExtendExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(SZ->getOperand()))
+ return ConstantExpr::getZExt(CastOp, SZ->getType());
+ return nullptr;
+ }
+ case scPtrToInt: {
+ const SCEVPtrToIntExpr *P2I = cast<SCEVPtrToIntExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(P2I->getOperand()))
+ return ConstantExpr::getPtrToInt(CastOp, P2I->getType());
+
+ return nullptr;
+ }
+ case scTruncate: {
+ const SCEVTruncateExpr *ST = cast<SCEVTruncateExpr>(V);
+ if (Constant *CastOp = BuildConstantFromSCEV(ST->getOperand()))
+ return ConstantExpr::getTrunc(CastOp, ST->getType());
+ return nullptr;
+ }
+ case scAddExpr: {
+ const SCEVAddExpr *SA = cast<SCEVAddExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SA->getOperand(0))) {
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ unsigned AS = PTy->getAddressSpace();
+ Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ C = ConstantExpr::getBitCast(C, DestPtrTy);
+ }
+ for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
+ if (!C2)
+ return nullptr;
+
+ // First pointer!
+ if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
+ unsigned AS = C2->getType()->getPointerAddressSpace();
+ std::swap(C, C2);
Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
+ // The offsets have been converted to bytes. We can add bytes to an
+ // i8* by GEP with the byte count in the first index.
C = ConstantExpr::getBitCast(C, DestPtrTy);
}
- for (unsigned i = 1, e = SA->getNumOperands(); i != e; ++i) {
- Constant *C2 = BuildConstantFromSCEV(SA->getOperand(i));
- if (!C2) return nullptr;
-
- // First pointer!
- if (!C->getType()->isPointerTy() && C2->getType()->isPointerTy()) {
- unsigned AS = C2->getType()->getPointerAddressSpace();
- std::swap(C, C2);
- Type *DestPtrTy = Type::getInt8PtrTy(C->getContext(), AS);
- // The offsets have been converted to bytes. We can add bytes to an
- // i8* by GEP with the byte count in the first index.
- C = ConstantExpr::getBitCast(C, DestPtrTy);
- }
- // Don't bother trying to sum two pointers. We probably can't
- // statically compute a load that results from it anyway.
- if (C2->getType()->isPointerTy())
- return nullptr;
+ // Don't bother trying to sum two pointers. We probably can't
+ // statically compute a load that results from it anyway.
+ if (C2->getType()->isPointerTy())
+ return nullptr;
- if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
- if (PTy->getElementType()->isStructTy())
- C2 = ConstantExpr::getIntegerCast(
- C2, Type::getInt32Ty(C->getContext()), true);
- C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
- } else
- C = ConstantExpr::getAdd(C, C2);
- }
- return C;
+ if (PointerType *PTy = dyn_cast<PointerType>(C->getType())) {
+ if (PTy->getElementType()->isStructTy())
+ C2 = ConstantExpr::getIntegerCast(
+ C2, Type::getInt32Ty(C->getContext()), true);
+ C = ConstantExpr::getGetElementPtr(PTy->getElementType(), C, C2);
+ } else
+ C = ConstantExpr::getAdd(C, C2);
}
- break;
+ return C;
}
- case scMulExpr: {
- const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
- if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
- // Don't bother with pointers at all.
- if (C->getType()->isPointerTy()) return nullptr;
- for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
- Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
- if (!C2 || C2->getType()->isPointerTy()) return nullptr;
- C = ConstantExpr::getMul(C, C2);
- }
- return C;
+ return nullptr;
+ }
+ case scMulExpr: {
+ const SCEVMulExpr *SM = cast<SCEVMulExpr>(V);
+ if (Constant *C = BuildConstantFromSCEV(SM->getOperand(0))) {
+ // Don't bother with pointers at all.
+ if (C->getType()->isPointerTy())
+ return nullptr;
+ for (unsigned i = 1, e = SM->getNumOperands(); i != e; ++i) {
+ Constant *C2 = BuildConstantFromSCEV(SM->getOperand(i));
+ if (!C2 || C2->getType()->isPointerTy())
+ return nullptr;
+ C = ConstantExpr::getMul(C, C2);
}
- break;
+ return C;
}
- case scUDivExpr: {
- const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
- if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
- if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
- if (LHS->getType() == RHS->getType())
- return ConstantExpr::getUDiv(LHS, RHS);
- break;
- }
- case scSMaxExpr:
- case scUMaxExpr:
- case scSMinExpr:
- case scUMinExpr:
- break; // TODO: smax, umax, smin, umax.
+ return nullptr;
}
- return nullptr;
+ case scUDivExpr: {
+ const SCEVUDivExpr *SU = cast<SCEVUDivExpr>(V);
+ if (Constant *LHS = BuildConstantFromSCEV(SU->getLHS()))
+ if (Constant *RHS = BuildConstantFromSCEV(SU->getRHS()))
+ if (LHS->getType() == RHS->getType())
+ return ConstantExpr::getUDiv(LHS, RHS);
+ return nullptr;
+ }
+ case scSMaxExpr:
+ case scUMaxExpr:
+ case scSMinExpr:
+ case scUMinExpr:
+ return nullptr; // TODO: smax, umax, smin, umax.
+ }
+ llvm_unreachable("Unknown SCEV kind!");
}
const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
@@ -8018,22 +8471,22 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(V)) {
if (Instruction *I = dyn_cast<Instruction>(SU->getValue())) {
if (PHINode *PN = dyn_cast<PHINode>(I)) {
- const Loop *LI = this->LI[I->getParent()];
+ const Loop *CurrLoop = this->LI[I->getParent()];
// Looking for loop exit value.
- if (LI && LI->getParentLoop() == L &&
- PN->getParent() == LI->getHeader()) {
+ if (CurrLoop && CurrLoop->getParentLoop() == L &&
+ PN->getParent() == CurrLoop->getHeader()) {
// Okay, there is no closed form solution for the PHI node. Check
// to see if the loop that contains it has a known backedge-taken
// count. If so, we may be able to force computation of the exit
// value.
- const SCEV *BackedgeTakenCount = getBackedgeTakenCount(LI);
+ const SCEV *BackedgeTakenCount = getBackedgeTakenCount(CurrLoop);
// This trivial case can show up in some degenerate cases where
// the incoming IR has not yet been fully simplified.
if (BackedgeTakenCount->isZero()) {
Value *InitValue = nullptr;
bool MultipleInitValues = false;
for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) {
- if (!LI->contains(PN->getIncomingBlock(i))) {
+ if (!CurrLoop->contains(PN->getIncomingBlock(i))) {
if (!InitValue)
InitValue = PN->getIncomingValue(i);
else if (InitValue != PN->getIncomingValue(i)) {
@@ -8051,17 +8504,18 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
isKnownPositive(BackedgeTakenCount) &&
PN->getNumIncomingValues() == 2) {
- unsigned InLoopPred = LI->contains(PN->getIncomingBlock(0)) ? 0 : 1;
+ unsigned InLoopPred =
+ CurrLoop->contains(PN->getIncomingBlock(0)) ? 0 : 1;
Value *BackedgeVal = PN->getIncomingValue(InLoopPred);
- if (LI->isLoopInvariant(BackedgeVal))
+ if (CurrLoop->isLoopInvariant(BackedgeVal))
return getSCEV(BackedgeVal);
}
if (auto *BTCC = dyn_cast<SCEVConstant>(BackedgeTakenCount)) {
// Okay, we know how many times the containing loop executes. If
// this is a constant evolving PHI node, get the final value at
// the specified iteration number.
- Constant *RV =
- getConstantEvolutionLoopExitValue(PN, BTCC->getAPInt(), LI);
+ Constant *RV = getConstantEvolutionLoopExitValue(
+ PN, BTCC->getAPInt(), CurrLoop);
if (RV) return getSCEV(RV);
}
}
@@ -8117,9 +8571,10 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
if (const CmpInst *CI = dyn_cast<CmpInst>(I))
C = ConstantFoldCompareInstOperands(CI->getPredicate(), Operands[0],
Operands[1], DL, &TLI);
- else if (const LoadInst *LI = dyn_cast<LoadInst>(I)) {
- if (!LI->isVolatile())
- C = ConstantFoldLoadFromConstPtr(Operands[0], LI->getType(), DL);
+ else if (const LoadInst *Load = dyn_cast<LoadInst>(I)) {
+ if (!Load->isVolatile())
+ C = ConstantFoldLoadFromConstPtr(Operands[0], Load->getType(),
+ DL);
} else
C = ConstantFoldInstOperands(I, Operands, DL, &TLI);
if (!C) return V;
@@ -8236,6 +8691,13 @@ const SCEV *ScalarEvolution::computeSCEVAtScope(const SCEV *V, const Loop *L) {
return getTruncateExpr(Op, Cast->getType());
}
+ if (const SCEVPtrToIntExpr *Cast = dyn_cast<SCEVPtrToIntExpr>(V)) {
+ const SCEV *Op = getSCEVAtScope(Cast->getOperand(), L);
+ if (Op == Cast->getOperand())
+ return Cast; // must be loop invariant
+ return getPtrToIntExpr(Op, Cast->getType());
+ }
+
llvm_unreachable("Unknown SCEV type!");
}
@@ -8650,7 +9112,10 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
if (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne()) {
- APInt MaxBECount = getUnsignedRangeMax(Distance);
+ APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+ APInt MaxBECountBase = getUnsignedRangeMax(Distance);
+ if (MaxBECountBase.ult(MaxBECount))
+ MaxBECount = MaxBECountBase;
// When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
// we end up with a loop whose backedge-taken count is n - 1. Detect this
@@ -8715,18 +9180,19 @@ ScalarEvolution::howFarToNonZero(const SCEV *V, const Loop *L) {
return getCouldNotCompute();
}
-std::pair<BasicBlock *, BasicBlock *>
-ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
+std::pair<const BasicBlock *, const BasicBlock *>
+ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(const BasicBlock *BB)
+ const {
// If the block has a unique predecessor, then there is no path from the
// predecessor to the block that does not go through the direct edge
// from the predecessor to the block.
- if (BasicBlock *Pred = BB->getSinglePredecessor())
+ if (const BasicBlock *Pred = BB->getSinglePredecessor())
return {Pred, BB};
// A loop's header is defined to be a block that dominates the loop.
// If the header has a unique predecessor outside the loop, it must be
// a block that has exactly one successor that can reach the loop.
- if (Loop *L = LI.getLoopFor(BB))
+ if (const Loop *L = LI.getLoopFor(BB))
return {L->getLoopPredecessor(), L->getHeader()};
return {nullptr, nullptr};
@@ -9055,6 +9521,14 @@ bool ScalarEvolution::isKnownPredicate(ICmpInst::Predicate Pred,
return isKnownViaNonRecursiveReasoning(Pred, LHS, RHS);
}
+bool ScalarEvolution::isKnownPredicateAt(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const Instruction *Context) {
+ // TODO: Analyze guards and assumes from Context's block.
+ return isKnownPredicate(Pred, LHS, RHS) ||
+ isBasicBlockEntryGuardedByCond(Context->getParent(), Pred, LHS, RHS);
+}
+
bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
const SCEVAddRecExpr *LHS,
const SCEV *RHS) {
@@ -9063,31 +9537,30 @@ bool ScalarEvolution::isKnownOnEveryIteration(ICmpInst::Predicate Pred,
isLoopBackedgeGuardedByCond(L, Pred, LHS->getPostIncExpr(*this), RHS);
}
-bool ScalarEvolution::isMonotonicPredicate(const SCEVAddRecExpr *LHS,
- ICmpInst::Predicate Pred,
- bool &Increasing) {
- bool Result = isMonotonicPredicateImpl(LHS, Pred, Increasing);
+Optional<ScalarEvolution::MonotonicPredicateType>
+ScalarEvolution::getMonotonicPredicateType(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred) {
+ auto Result = getMonotonicPredicateTypeImpl(LHS, Pred);
#ifndef NDEBUG
// Verify an invariant: inverting the predicate should turn a monotonically
// increasing change to a monotonically decreasing one, and vice versa.
- bool IncreasingSwapped;
- bool ResultSwapped = isMonotonicPredicateImpl(
- LHS, ICmpInst::getSwappedPredicate(Pred), IncreasingSwapped);
+ if (Result) {
+ auto ResultSwapped =
+ getMonotonicPredicateTypeImpl(LHS, ICmpInst::getSwappedPredicate(Pred));
- assert(Result == ResultSwapped && "should be able to analyze both!");
- if (ResultSwapped)
- assert(Increasing == !IncreasingSwapped &&
+ assert(ResultSwapped.hasValue() && "should be able to analyze both!");
+ assert(ResultSwapped.getValue() != Result.getValue() &&
"monotonicity should flip as we flip the predicate");
+ }
#endif
return Result;
}
-bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
- ICmpInst::Predicate Pred,
- bool &Increasing) {
-
+Optional<ScalarEvolution::MonotonicPredicateType>
+ScalarEvolution::getMonotonicPredicateTypeImpl(const SCEVAddRecExpr *LHS,
+ ICmpInst::Predicate Pred) {
// A zero step value for LHS means the induction variable is essentially a
// loop invariant value. We don't really depend on the predicate actually
// flipping from false to true (for increasing predicates, and the other way
@@ -9098,56 +9571,46 @@ bool ScalarEvolution::isMonotonicPredicateImpl(const SCEVAddRecExpr *LHS,
// where SCEV can prove X >= 0 but not prove X > 0, so it is helpful to be
// as general as possible.
- switch (Pred) {
- default:
- return false; // Conservative answer
-
- case ICmpInst::ICMP_UGT:
- case ICmpInst::ICMP_UGE:
- case ICmpInst::ICMP_ULT:
- case ICmpInst::ICMP_ULE:
- if (!LHS->hasNoUnsignedWrap())
- return false;
+ // Only handle LE/LT/GE/GT predicates.
+ if (!ICmpInst::isRelational(Pred))
+ return None;
- Increasing = Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE;
- return true;
+ bool IsGreater = ICmpInst::isGE(Pred) || ICmpInst::isGT(Pred);
+ assert((IsGreater || ICmpInst::isLE(Pred) || ICmpInst::isLT(Pred)) &&
+ "Should be greater or less!");
- case ICmpInst::ICMP_SGT:
- case ICmpInst::ICMP_SGE:
- case ICmpInst::ICMP_SLT:
- case ICmpInst::ICMP_SLE: {
+ // Check that AR does not wrap.
+ if (ICmpInst::isUnsigned(Pred)) {
+ if (!LHS->hasNoUnsignedWrap())
+ return None;
+ return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
+ } else {
+ assert(ICmpInst::isSigned(Pred) &&
+ "Relational predicate is either signed or unsigned!");
if (!LHS->hasNoSignedWrap())
- return false;
+ return None;
const SCEV *Step = LHS->getStepRecurrence(*this);
- if (isKnownNonNegative(Step)) {
- Increasing = Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE;
- return true;
- }
+ if (isKnownNonNegative(Step))
+ return IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
- if (isKnownNonPositive(Step)) {
- Increasing = Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE;
- return true;
- }
-
- return false;
- }
+ if (isKnownNonPositive(Step))
+ return !IsGreater ? MonotonicallyIncreasing : MonotonicallyDecreasing;
+ return None;
}
-
- llvm_unreachable("switch has default clause!");
}
-bool ScalarEvolution::isLoopInvariantPredicate(
- ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
- ICmpInst::Predicate &InvariantPred, const SCEV *&InvariantLHS,
- const SCEV *&InvariantRHS) {
+Optional<ScalarEvolution::LoopInvariantPredicate>
+ScalarEvolution::getLoopInvariantPredicate(ICmpInst::Predicate Pred,
+ const SCEV *LHS, const SCEV *RHS,
+ const Loop *L) {
// If there is a loop-invariant, force it into the RHS, otherwise bail out.
if (!isLoopInvariant(RHS, L)) {
if (!isLoopInvariant(LHS, L))
- return false;
+ return None;
std::swap(LHS, RHS);
Pred = ICmpInst::getSwappedPredicate(Pred);
@@ -9155,12 +9618,11 @@ bool ScalarEvolution::isLoopInvariantPredicate(
const SCEVAddRecExpr *ArLHS = dyn_cast<SCEVAddRecExpr>(LHS);
if (!ArLHS || ArLHS->getLoop() != L)
- return false;
-
- bool Increasing;
- if (!isMonotonicPredicate(ArLHS, Pred, Increasing))
- return false;
+ return None;
+ auto MonotonicType = getMonotonicPredicateType(ArLHS, Pred);
+ if (!MonotonicType)
+ return None;
// If the predicate "ArLHS `Pred` RHS" monotonically increases from false to
// true as the loop iterates, and the backedge is control dependent on
// "ArLHS `Pred` RHS" == true then we can reason as follows:
@@ -9178,16 +9640,77 @@ bool ScalarEvolution::isLoopInvariantPredicate(
//
// A similar reasoning applies for a monotonically decreasing predicate, by
// replacing true with false and false with true in the above two bullets.
-
+ bool Increasing = *MonotonicType == ScalarEvolution::MonotonicallyIncreasing;
auto P = Increasing ? Pred : ICmpInst::getInversePredicate(Pred);
if (!isLoopBackedgeGuardedByCond(L, P, LHS, RHS))
- return false;
+ return None;
- InvariantPred = Pred;
- InvariantLHS = ArLHS->getStart();
- InvariantRHS = RHS;
- return true;
+ return ScalarEvolution::LoopInvariantPredicate(Pred, ArLHS->getStart(), RHS);
+}
+
+Optional<ScalarEvolution::LoopInvariantPredicate>
+ScalarEvolution::getLoopInvariantExitCondDuringFirstIterations(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, const Loop *L,
+ const Instruction *Context, const SCEV *MaxIter) {
+ // Try to prove the following set of facts:
+ // - The predicate is monotonic in the iteration space.
+ // - If the check does not fail on the 1st iteration:
+ // - No overflow will happen during first MaxIter iterations;
+ // - It will not fail on the MaxIter'th iteration.
+ // If the check does fail on the 1st iteration, we leave the loop and no
+ // other checks matter.
+
+ // If there is a loop-invariant, force it into the RHS, otherwise bail out.
+ if (!isLoopInvariant(RHS, L)) {
+ if (!isLoopInvariant(LHS, L))
+ return None;
+
+ std::swap(LHS, RHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ auto *AR = dyn_cast<SCEVAddRecExpr>(LHS);
+ if (!AR || AR->getLoop() != L)
+ return None;
+
+ // The predicate must be relational (i.e. <, <=, >=, >).
+ if (!ICmpInst::isRelational(Pred))
+ return None;
+
+ // TODO: Support steps other than +/- 1.
+ const SCEV *Step = AR->getStepRecurrence(*this);
+ auto *One = getOne(Step->getType());
+ auto *MinusOne = getNegativeSCEV(One);
+ if (Step != One && Step != MinusOne)
+ return None;
+
+ // Type mismatch here means that MaxIter is potentially larger than max
+ // unsigned value in start type, which mean we cannot prove no wrap for the
+ // indvar.
+ if (AR->getType() != MaxIter->getType())
+ return None;
+
+ // Value of IV on suggested last iteration.
+ const SCEV *Last = AR->evaluateAtIteration(MaxIter, *this);
+ // Does it still meet the requirement?
+ if (!isLoopBackedgeGuardedByCond(L, Pred, Last, RHS))
+ return None;
+ // Because step is +/- 1 and MaxIter has same type as Start (i.e. it does
+ // not exceed max unsigned value of this type), this effectively proves
+ // that there is no wrap during the iteration. To prove that there is no
+ // signed/unsigned wrap, we need to check that
+ // Start <= Last for step = 1 or Start >= Last for step = -1.
+ ICmpInst::Predicate NoOverflowPred =
+ CmpInst::isSigned(Pred) ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
+ if (Step == MinusOne)
+ NoOverflowPred = CmpInst::getSwappedPredicate(NoOverflowPred);
+ const SCEV *Start = AR->getStart();
+ if (!isKnownPredicateAt(NoOverflowPred, Start, Last, Context))
+ return None;
+
+ // Everything is fine.
+ return ScalarEvolution::LoopInvariantPredicate(Pred, Start, RHS);
}
bool ScalarEvolution::isKnownPredicateViaConstantRanges(
@@ -9272,6 +9795,24 @@ bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
if (MatchBinaryAddToConst(LHS, RHS, C, SCEV::FlagNSW) && C.isNegative())
return true;
break;
+
+ case ICmpInst::ICMP_UGE:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_ULE:
+ // X u<= (X + C)<nuw> for any C
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW))
+ return true;
+ break;
+
+ case ICmpInst::ICMP_UGT:
+ std::swap(LHS, RHS);
+ LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_ULT:
+ // X u< (X + C)<nuw> if C != 0
+ if (MatchBinaryAddToConst(RHS, LHS, C, SCEV::FlagNUW) && !C.isNullValue())
+ return true;
+ break;
}
return false;
@@ -9299,14 +9840,14 @@ bool ScalarEvolution::isKnownPredicateViaSplitting(ICmpInst::Predicate Pred,
isKnownPredicate(CmpInst::ICMP_SLT, LHS, RHS);
}
-bool ScalarEvolution::isImpliedViaGuard(BasicBlock *BB,
+bool ScalarEvolution::isImpliedViaGuard(const BasicBlock *BB,
ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS) {
// No need to even try if we know the module has no guards.
if (!HasGuards)
return false;
- return any_of(*BB, [&](Instruction &I) {
+ return any_of(*BB, [&](const Instruction &I) {
using namespace llvm::PatternMatch;
Value *Condition;
@@ -9429,24 +9970,14 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
return false;
}
-bool
-ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
- ICmpInst::Predicate Pred,
- const SCEV *LHS, const SCEV *RHS) {
- // Interpret a null as meaning no loop, where there is obviously no guard
- // (interprocedural conditions notwithstanding).
- if (!L) return false;
-
+bool ScalarEvolution::isBasicBlockEntryGuardedByCond(const BasicBlock *BB,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
if (VerifyIR)
- assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+ assert(!verifyFunction(*BB->getParent(), &dbgs()) &&
"This cannot be done on broken IR!");
- // Both LHS and RHS must be available at loop entry.
- assert(isAvailableAtLoopEntry(LHS, L) &&
- "LHS is not available at Loop Entry");
- assert(isAvailableAtLoopEntry(RHS, L) &&
- "RHS is not available at Loop Entry");
-
if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
@@ -9470,7 +10001,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
}
// Try to prove (Pred, LHS, RHS) using isImpliedViaGuard.
- auto ProveViaGuard = [&](BasicBlock *Block) {
+ auto ProveViaGuard = [&](const BasicBlock *Block) {
if (isImpliedViaGuard(Block, Pred, LHS, RHS))
return true;
if (ProvingStrictComparison) {
@@ -9487,35 +10018,39 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
};
// Try to prove (Pred, LHS, RHS) using isImpliedCond.
- auto ProveViaCond = [&](Value *Condition, bool Inverse) {
- if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse))
+ auto ProveViaCond = [&](const Value *Condition, bool Inverse) {
+ const Instruction *Context = &BB->front();
+ if (isImpliedCond(Pred, LHS, RHS, Condition, Inverse, Context))
return true;
if (ProvingStrictComparison) {
if (!ProvedNonStrictComparison)
- ProvedNonStrictComparison =
- isImpliedCond(NonStrictPredicate, LHS, RHS, Condition, Inverse);
+ ProvedNonStrictComparison = isImpliedCond(NonStrictPredicate, LHS, RHS,
+ Condition, Inverse, Context);
if (!ProvedNonEquality)
- ProvedNonEquality =
- isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS, Condition, Inverse);
+ ProvedNonEquality = isImpliedCond(ICmpInst::ICMP_NE, LHS, RHS,
+ Condition, Inverse, Context);
if (ProvedNonStrictComparison && ProvedNonEquality)
return true;
}
return false;
};
- // Starting at the loop predecessor, climb up the predecessor chain, as long
+ // Starting at the block's predecessor, climb up the predecessor chain, as long
// as there are predecessors that can be found that have unique successors
- // leading to the original header.
- for (std::pair<BasicBlock *, BasicBlock *>
- Pair(L->getLoopPredecessor(), L->getHeader());
- Pair.first;
- Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
-
+ // leading to the original block.
+ const Loop *ContainingLoop = LI.getLoopFor(BB);
+ const BasicBlock *PredBB;
+ if (ContainingLoop && ContainingLoop->getHeader() == BB)
+ PredBB = ContainingLoop->getLoopPredecessor();
+ else
+ PredBB = BB->getSinglePredecessor();
+ for (std::pair<const BasicBlock *, const BasicBlock *> Pair(PredBB, BB);
+ Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
if (ProveViaGuard(Pair.first))
return true;
- BranchInst *LoopEntryPredicate =
- dyn_cast<BranchInst>(Pair.first->getTerminator());
+ const BranchInst *LoopEntryPredicate =
+ dyn_cast<BranchInst>(Pair.first->getTerminator());
if (!LoopEntryPredicate ||
LoopEntryPredicate->isUnconditional())
continue;
@@ -9530,7 +10065,7 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
if (!AssumeVH)
continue;
auto *CI = cast<CallInst>(AssumeVH);
- if (!DT.dominates(CI, L->getHeader()))
+ if (!DT.dominates(CI, BB))
continue;
if (ProveViaCond(CI->getArgOperand(0), false))
@@ -9540,10 +10075,27 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
return false;
}
-bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
- const SCEV *LHS, const SCEV *RHS,
- Value *FoundCondValue,
- bool Inverse) {
+bool ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
+ ICmpInst::Predicate Pred,
+ const SCEV *LHS,
+ const SCEV *RHS) {
+ // Interpret a null as meaning no loop, where there is obviously no guard
+ // (interprocedural conditions notwithstanding).
+ if (!L)
+ return false;
+
+ // Both LHS and RHS must be available at loop entry.
+ assert(isAvailableAtLoopEntry(LHS, L) &&
+ "LHS is not available at Loop Entry");
+ assert(isAvailableAtLoopEntry(RHS, L) &&
+ "RHS is not available at Loop Entry");
+ return isBasicBlockEntryGuardedByCond(L->getHeader(), Pred, LHS, RHS);
+}
+
+bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
+ const SCEV *RHS,
+ const Value *FoundCondValue, bool Inverse,
+ const Instruction *Context) {
if (!PendingLoopPredicates.insert(FoundCondValue).second)
return false;
@@ -9551,19 +10103,23 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); });
// Recursively handle And and Or conditions.
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
+ if (const BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {
if (BO->getOpcode() == Instruction::And) {
if (!Inverse)
- return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
- isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+ return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+ Context) ||
+ isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+ Context);
} else if (BO->getOpcode() == Instruction::Or) {
if (Inverse)
- return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse) ||
- isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse);
+ return isImpliedCond(Pred, LHS, RHS, BO->getOperand(0), Inverse,
+ Context) ||
+ isImpliedCond(Pred, LHS, RHS, BO->getOperand(1), Inverse,
+ Context);
}
}
- ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
+ const ICmpInst *ICI = dyn_cast<ICmpInst>(FoundCondValue);
if (!ICI) return false;
// Now that we found a conditional branch that dominates the loop or controls
@@ -9577,17 +10133,36 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,
const SCEV *FoundLHS = getSCEV(ICI->getOperand(0));
const SCEV *FoundRHS = getSCEV(ICI->getOperand(1));
- return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS);
+ return isImpliedCond(Pred, LHS, RHS, FoundPred, FoundLHS, FoundRHS, Context);
}
bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
const SCEV *RHS,
ICmpInst::Predicate FoundPred,
- const SCEV *FoundLHS,
- const SCEV *FoundRHS) {
+ const SCEV *FoundLHS, const SCEV *FoundRHS,
+ const Instruction *Context) {
// Balance the types.
if (getTypeSizeInBits(LHS->getType()) <
getTypeSizeInBits(FoundLHS->getType())) {
+ // For unsigned and equality predicates, try to prove that both found
+ // operands fit into narrow unsigned range. If so, try to prove facts in
+ // narrow types.
+ if (!CmpInst::isSigned(FoundPred)) {
+ auto *NarrowType = LHS->getType();
+ auto *WideType = FoundLHS->getType();
+ auto BitWidth = getTypeSizeInBits(NarrowType);
+ const SCEV *MaxValue = getZeroExtendExpr(
+ getConstant(APInt::getMaxValue(BitWidth)), WideType);
+ if (isKnownPredicate(ICmpInst::ICMP_ULE, FoundLHS, MaxValue) &&
+ isKnownPredicate(ICmpInst::ICMP_ULE, FoundRHS, MaxValue)) {
+ const SCEV *TruncFoundLHS = getTruncateExpr(FoundLHS, NarrowType);
+ const SCEV *TruncFoundRHS = getTruncateExpr(FoundRHS, NarrowType);
+ if (isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, TruncFoundLHS,
+ TruncFoundRHS, Context))
+ return true;
+ }
+ }
+
if (CmpInst::isSigned(Pred)) {
LHS = getSignExtendExpr(LHS, FoundLHS->getType());
RHS = getSignExtendExpr(RHS, FoundLHS->getType());
@@ -9605,7 +10180,17 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
FoundRHS = getZeroExtendExpr(FoundRHS, LHS->getType());
}
}
+ return isImpliedCondBalancedTypes(Pred, LHS, RHS, FoundPred, FoundLHS,
+ FoundRHS, Context);
+}
+bool ScalarEvolution::isImpliedCondBalancedTypes(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+ ICmpInst::Predicate FoundPred, const SCEV *FoundLHS, const SCEV *FoundRHS,
+ const Instruction *Context) {
+ assert(getTypeSizeInBits(LHS->getType()) ==
+ getTypeSizeInBits(FoundLHS->getType()) &&
+ "Types should be balanced!");
// Canonicalize the query to match the way instcombine will have
// canonicalized the comparison.
if (SimplifyICmpOperands(Pred, LHS, RHS))
@@ -9628,16 +10213,16 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
// Check whether the found predicate is the same as the desired predicate.
if (FoundPred == Pred)
- return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
// Check whether swapping the found predicate makes it the same as the
// desired predicate.
if (ICmpInst::getSwappedPredicate(FoundPred) == Pred) {
if (isa<SCEVConstant>(RHS))
- return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS);
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundRHS, FoundLHS, Context);
else
- return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred),
- RHS, LHS, FoundLHS, FoundRHS);
+ return isImpliedCondOperands(ICmpInst::getSwappedPredicate(Pred), RHS,
+ LHS, FoundLHS, FoundRHS, Context);
}
// Unsigned comparison is the same as signed comparison when both the operands
@@ -9645,7 +10230,7 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
if (CmpInst::isUnsigned(FoundPred) &&
CmpInst::getSignedPredicate(FoundPred) == Pred &&
isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS))
- return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS);
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context);
// Check if we can make progress by sharpening ranges.
if (FoundPred == ICmpInst::ICMP_NE &&
@@ -9682,8 +10267,8 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
case ICmpInst::ICMP_UGE:
// We know V `Pred` SharperMin. If this implies LHS `Pred`
// RHS, we're done.
- if (isImpliedCondOperands(Pred, LHS, RHS, V,
- getConstant(SharperMin)))
+ if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(SharperMin),
+ Context))
return true;
LLVM_FALLTHROUGH;
@@ -9698,10 +10283,26 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
//
// If V `Pred` Min implies LHS `Pred` RHS, we're done.
- if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min)))
+ if (isImpliedCondOperands(Pred, LHS, RHS, V, getConstant(Min),
+ Context))
+ return true;
+ break;
+
+ // `LHS < RHS` and `LHS <= RHS` are handled in the same way as `RHS > LHS` and `RHS >= LHS` respectively.
+ case ICmpInst::ICMP_SLE:
+ case ICmpInst::ICMP_ULE:
+ if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
+ LHS, V, getConstant(SharperMin), Context))
return true;
LLVM_FALLTHROUGH;
+ case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_ULT:
+ if (isImpliedCondOperands(CmpInst::getSwappedPredicate(Pred), RHS,
+ LHS, V, getConstant(Min), Context))
+ return true;
+ break;
+
default:
// No change
break;
@@ -9712,11 +10313,12 @@ bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred, const SCEV *LHS,
// Check whether the actual condition is beyond sufficient.
if (FoundPred == ICmpInst::ICMP_EQ)
if (ICmpInst::isTrueWhenEqual(Pred))
- if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS))
+ if (isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, FoundRHS, Context))
return true;
if (Pred == ICmpInst::ICMP_NE)
if (!ICmpInst::isTrueWhenEqual(FoundPred))
- if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS))
+ if (isImpliedCondOperands(FoundPred, LHS, RHS, FoundLHS, FoundRHS,
+ Context))
return true;
// Otherwise assume the worst.
@@ -9795,6 +10397,51 @@ Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More,
return None;
}
+bool ScalarEvolution::isImpliedCondOperandsViaAddRecStart(
+ ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
+ const SCEV *FoundLHS, const SCEV *FoundRHS, const Instruction *Context) {
+ // Try to recognize the following pattern:
+ //
+ // FoundRHS = ...
+ // ...
+ // loop:
+ // FoundLHS = {Start,+,W}
+ // context_bb: // Basic block from the same loop
+ // known(Pred, FoundLHS, FoundRHS)
+ //
+ // If some predicate is known in the context of a loop, it is also known on
+ // each iteration of this loop, including the first iteration. Therefore, in
+ // this case, `FoundLHS Pred FoundRHS` implies `Start Pred FoundRHS`. Try to
+ // prove the original pred using this fact.
+ if (!Context)
+ return false;
+ const BasicBlock *ContextBB = Context->getParent();
+ // Make sure AR varies in the context block.
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundLHS)) {
+ const Loop *L = AR->getLoop();
+ // Make sure that context belongs to the loop and executes on 1st iteration
+ // (if it ever executes at all).
+ if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
+ return false;
+ if (!isAvailableAtLoopEntry(FoundRHS, AR->getLoop()))
+ return false;
+ return isImpliedCondOperands(Pred, LHS, RHS, AR->getStart(), FoundRHS);
+ }
+
+ if (auto *AR = dyn_cast<SCEVAddRecExpr>(FoundRHS)) {
+ const Loop *L = AR->getLoop();
+ // Make sure that context belongs to the loop and executes on 1st iteration
+ // (if it ever executes at all).
+ if (!L->contains(ContextBB) || !DT.dominates(ContextBB, L->getLoopLatch()))
+ return false;
+ if (!isAvailableAtLoopEntry(FoundLHS, AR->getLoop()))
+ return false;
+ return isImpliedCondOperands(Pred, LHS, RHS, FoundLHS, AR->getStart());
+ }
+
+ return false;
+}
+
bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(
ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS, const SCEV *FoundRHS) {
@@ -9985,13 +10632,18 @@ bool ScalarEvolution::isImpliedViaMerge(ICmpInst::Predicate Pred,
bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
const SCEV *LHS, const SCEV *RHS,
const SCEV *FoundLHS,
- const SCEV *FoundRHS) {
+ const SCEV *FoundRHS,
+ const Instruction *Context) {
if (isImpliedCondOperandsViaRanges(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
if (isImpliedCondOperandsViaNoOverflow(Pred, LHS, RHS, FoundLHS, FoundRHS))
return true;
+ if (isImpliedCondOperandsViaAddRecStart(Pred, LHS, RHS, FoundLHS, FoundRHS,
+ Context))
+ return true;
+
return isImpliedCondOperandsHelper(Pred, LHS, RHS,
FoundLHS, FoundRHS) ||
// ~x < ~y --> x > y
@@ -10008,7 +10660,7 @@ static bool IsMinMaxConsistingOf(const SCEV *MaybeMinMaxExpr,
if (!MinMaxExpr)
return false;
- return find(MinMaxExpr->operands(), Candidate) != MinMaxExpr->op_end();
+ return is_contained(MinMaxExpr->operands(), Candidate);
}
static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
@@ -10090,13 +10742,31 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
// We want to avoid hurting the compile time with analysis of too big trees.
if (Depth > MaxSCEVOperationsImplicationDepth)
return false;
- // We only want to work with ICMP_SGT comparison so far.
- // TODO: Extend to ICMP_UGT?
- if (Pred == ICmpInst::ICMP_SLT) {
- Pred = ICmpInst::ICMP_SGT;
+
+ // We only want to work with GT comparison so far.
+ if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_SLT) {
+ Pred = CmpInst::getSwappedPredicate(Pred);
std::swap(LHS, RHS);
std::swap(FoundLHS, FoundRHS);
}
+
+ // For unsigned, try to reduce it to corresponding signed comparison.
+ if (Pred == ICmpInst::ICMP_UGT)
+ // We can replace unsigned predicate with its signed counterpart if all
+ // involved values are non-negative.
+ // TODO: We could have better support for unsigned.
+ if (isKnownNonNegative(FoundLHS) && isKnownNonNegative(FoundRHS)) {
+ // Knowing that both FoundLHS and FoundRHS are non-negative, and knowing
+ // FoundLHS >u FoundRHS, we also know that FoundLHS >s FoundRHS. Let us
+ // use this fact to prove that LHS and RHS are non-negative.
+ const SCEV *MinusOne = getMinusOne(LHS->getType());
+ if (isImpliedCondOperands(ICmpInst::ICMP_SGT, LHS, MinusOne, FoundLHS,
+ FoundRHS) &&
+ isImpliedCondOperands(ICmpInst::ICMP_SGT, RHS, MinusOne, FoundLHS,
+ FoundRHS))
+ Pred = ICmpInst::ICMP_SGT;
+ }
+
if (Pred != ICmpInst::ICMP_SGT)
return false;
@@ -10136,7 +10806,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
auto *LL = LHSAddExpr->getOperand(0);
auto *LR = LHSAddExpr->getOperand(1);
- auto *MinusOne = getNegativeSCEV(getOne(RHS->getType()));
+ auto *MinusOne = getMinusOne(RHS->getType());
// Checks that S1 >= 0 && S2 > RHS, trivially or using the found context.
auto IsSumGreaterThanRHS = [&](const SCEV *S1, const SCEV *S2) {
@@ -10209,7 +10879,7 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
// 1. If FoundLHS is negative, then the result is 0.
// 2. If FoundLHS is non-negative, then the result is non-negative.
// Anyways, the result is non-negative.
- auto *MinusOne = getNegativeSCEV(getOne(WTy));
+ auto *MinusOne = getMinusOne(WTy);
auto *NegDenomMinusOne = getMinusSCEV(MinusOne, DenominatorExt);
if (isKnownNegative(RHS) &&
IsSGTViaContext(FoundRHSExt, NegDenomMinusOne))
@@ -10564,7 +11234,13 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS))
BECount = BECountIfBackedgeTaken;
else {
- End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
+ // If we know that RHS >= Start in the context of loop, then we know that
+ // max(RHS, Start) = RHS at this point.
+ if (isLoopEntryGuardedByCond(
+ L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, RHS, Start))
+ End = RHS;
+ else
+ End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start);
BECount = computeBECount(getMinusSCEV(End, Start), Stride, false);
}
@@ -10631,8 +11307,15 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,
const SCEV *Start = IV->getStart();
const SCEV *End = RHS;
- if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS))
- End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
+ if (!isLoopEntryGuardedByCond(L, Cond, getAddExpr(Start, Stride), RHS)) {
+ // If we know that Start >= RHS in the context of loop, then we know that
+ // min(RHS, Start) = RHS at this point.
+ if (isLoopEntryGuardedByCond(
+ L, IsSigned ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE, Start, RHS))
+ End = RHS;
+ else
+ End = IsSigned ? getSMinExpr(RHS, Start) : getUMinExpr(RHS, Start);
+ }
const SCEV *BECount = computeBECount(getMinusSCEV(Start, End), Stride, false);
@@ -10672,7 +11355,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
// If the start is a non-zero constant, shift the range to simplify things.
if (const SCEVConstant *SC = dyn_cast<SCEVConstant>(getStart()))
if (!SC->getValue()->isZero()) {
- SmallVector<const SCEV *, 4> Operands(op_begin(), op_end());
+ SmallVector<const SCEV *, 4> Operands(operands());
Operands[0] = SE.getZero(SC->getType());
const SCEV *Shifted = SE.getAddRecExpr(Operands, getLoop(),
getNoWrapFlags(FlagNW));
@@ -10955,9 +11638,7 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
}
// Remove all SCEVConstants.
- Terms.erase(
- remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }),
- Terms.end());
+ erase_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); });
if (Terms.size() > 0)
if (!findArrayDimensionsRec(SE, Terms, Sizes))
@@ -11285,7 +11966,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
// so that future queries will recompute the expressions using the new
// value.
Value *Old = getValPtr();
- SmallVector<User *, 16> Worklist(Old->user_begin(), Old->user_end());
+ SmallVector<User *, 16> Worklist(Old->users());
SmallPtrSet<User *, 8> Visited;
while (!Worklist.empty()) {
User *U = Worklist.pop_back_val();
@@ -11298,7 +11979,7 @@ void ScalarEvolution::SCEVCallbackVH::allUsesReplacedWith(Value *V) {
if (PHINode *PN = dyn_cast<PHINode>(U))
SE->ConstantEvolutionLoopExitValue.erase(PN);
SE->eraseValueFromMap(U);
- Worklist.insert(Worklist.end(), U->user_begin(), U->user_end());
+ llvm::append_range(Worklist, U->users());
}
// Delete the Old value.
if (PHINode *PN = dyn_cast<PHINode>(Old))
@@ -11580,9 +12261,10 @@ ScalarEvolution::getLoopDisposition(const SCEV *S, const Loop *L) {
ScalarEvolution::LoopDisposition
ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
- switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+ switch (S->getSCEVType()) {
case scConstant:
return LoopInvariant;
+ case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -11687,9 +12369,10 @@ ScalarEvolution::getBlockDisposition(const SCEV *S, const BasicBlock *BB) {
ScalarEvolution::BlockDisposition
ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {
- switch (static_cast<SCEVTypes>(S->getSCEVType())) {
+ switch (S->getSCEVType()) {
case scConstant:
return ProperlyDominatesBlock;
+ case scPtrToInt:
case scTruncate:
case scZeroExtend:
case scSignExtend:
@@ -11861,7 +12544,7 @@ void ScalarEvolution::verify() const {
while (!LoopStack.empty()) {
auto *L = LoopStack.pop_back_val();
- LoopStack.insert(LoopStack.end(), L->begin(), L->end());
+ llvm::append_range(LoopStack, *L);
auto *CurBECount = SCM.visit(
const_cast<ScalarEvolution *>(this)->getBackedgeTakenCount(L));
@@ -11905,6 +12588,25 @@ void ScalarEvolution::verify() const {
std::abort();
}
}
+
+ // Collect all valid loops currently in LoopInfo.
+ SmallPtrSet<Loop *, 32> ValidLoops;
+ SmallVector<Loop *, 32> Worklist(LI.begin(), LI.end());
+ while (!Worklist.empty()) {
+ Loop *L = Worklist.pop_back_val();
+ if (ValidLoops.contains(L))
+ continue;
+ ValidLoops.insert(L);
+ Worklist.append(L->begin(), L->end());
+ }
+ // Check for SCEV expressions referencing invalid/deleted loops.
+ for (auto &KV : ValueExprMap) {
+ auto *AR = dyn_cast<SCEVAddRecExpr>(KV.second);
+ if (!AR)
+ continue;
+ assert(ValidLoops.contains(AR->getLoop()) &&
+ "AddRec references invalid loop");
+ }
}
bool ScalarEvolution::invalidate(
@@ -11937,6 +12639,11 @@ ScalarEvolutionVerifierPass::run(Function &F, FunctionAnalysisManager &AM) {
PreservedAnalyses
ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {
+ // For compatibility with opt's -analyze feature under legacy pass manager
+ // which was not ported to NPM. This keeps tests using
+ // update_analyze_test_checks.py working.
+ OS << "Printing analysis 'Scalar Evolution Analysis' for function '"
+ << F.getName() << "':\n";
AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);
return PreservedAnalyses::all();
}
@@ -12432,11 +13139,24 @@ void PredicatedScalarEvolution::print(raw_ostream &OS, unsigned Depth) const {
}
// Match the mathematical pattern A - (A / B) * B, where A and B can be
-// arbitrary expressions.
+// arbitrary expressions. Also match zext (trunc A to iB) to iY, which is used
+// for URem with constant power-of-2 second operands.
// It's not always easy, as A and B can be folded (imagine A is X / 2, and B is
// 4, A / B becomes X / 8).
bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
const SCEV *&RHS) {
+ // Try to match 'zext (trunc A to iB) to iY', which is used
+ // for URem with constant power-of-2 second operands. Make sure the size of
+ // the operand A matches the size of the whole expressions.
+ if (const auto *ZExt = dyn_cast<SCEVZeroExtendExpr>(Expr))
+ if (const auto *Trunc = dyn_cast<SCEVTruncateExpr>(ZExt->getOperand(0))) {
+ LHS = Trunc->getOperand();
+ if (LHS->getType() != Expr->getType())
+ LHS = getZeroExtendExpr(LHS, Expr->getType());
+ RHS = getConstant(APInt(getTypeSizeInBits(Expr->getType()), 1)
+ << getTypeSizeInBits(Trunc->getType()));
+ return true;
+ }
const auto *Add = dyn_cast<SCEVAddExpr>(Expr);
if (Add == nullptr || Add->getNumOperands() != 2)
return false;
@@ -12470,3 +13190,146 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
MatchURemWithDivisor(getNegativeSCEV(Mul->getOperand(0)));
return false;
}
+
+const SCEV *
+ScalarEvolution::computeSymbolicMaxBackedgeTakenCount(const Loop *L) {
+ SmallVector<BasicBlock*, 16> ExitingBlocks;
+ L->getExitingBlocks(ExitingBlocks);
+
+ // Form an expression for the maximum exit count possible for this loop. We
+ // merge the max and exact information to approximate a version of
+ // getConstantMaxBackedgeTakenCount which isn't restricted to just constants.
+ SmallVector<const SCEV*, 4> ExitCounts;
+ for (BasicBlock *ExitingBB : ExitingBlocks) {
+ const SCEV *ExitCount = getExitCount(L, ExitingBB);
+ if (isa<SCEVCouldNotCompute>(ExitCount))
+ ExitCount = getExitCount(L, ExitingBB,
+ ScalarEvolution::ConstantMaximum);
+ if (!isa<SCEVCouldNotCompute>(ExitCount)) {
+ assert(DT.dominates(ExitingBB, L->getLoopLatch()) &&
+ "We should only have known counts for exiting blocks that "
+ "dominate latch!");
+ ExitCounts.push_back(ExitCount);
+ }
+ }
+ if (ExitCounts.empty())
+ return getCouldNotCompute();
+ return getUMinFromMismatchedTypes(ExitCounts);
+}
+
+/// This rewriter is similar to SCEVParameterRewriter (it replaces SCEVUnknown
+/// components following the Map (Value -> SCEV)), but skips AddRecExpr because
+/// we cannot guarantee that the replacement is loop invariant in the loop of
+/// the AddRec.
+class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
+ ValueToSCEVMapTy &Map;
+
+public:
+ SCEVLoopGuardRewriter(ScalarEvolution &SE, ValueToSCEVMapTy &M)
+ : SCEVRewriteVisitor(SE), Map(M) {}
+
+ const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
+
+ const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+ auto I = Map.find(Expr->getValue());
+ if (I == Map.end())
+ return Expr;
+ return I->second;
+ }
+};
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+ auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
+ const SCEV *RHS, ValueToSCEVMapTy &RewriteMap) {
+ if (!isa<SCEVUnknown>(LHS)) {
+ std::swap(LHS, RHS);
+ Predicate = CmpInst::getSwappedPredicate(Predicate);
+ }
+
+ // For now, limit to conditions that provide information about unknown
+ // expressions.
+ auto *LHSUnknown = dyn_cast<SCEVUnknown>(LHS);
+ if (!LHSUnknown)
+ return;
+
+ // TODO: use information from more predicates.
+ switch (Predicate) {
+ case CmpInst::ICMP_ULT: {
+ if (!containsAddRecurrence(RHS)) {
+ const SCEV *Base = LHS;
+ auto I = RewriteMap.find(LHSUnknown->getValue());
+ if (I != RewriteMap.end())
+ Base = I->second;
+
+ RewriteMap[LHSUnknown->getValue()] =
+ getUMinExpr(Base, getMinusSCEV(RHS, getOne(RHS->getType())));
+ }
+ break;
+ }
+ case CmpInst::ICMP_ULE: {
+ if (!containsAddRecurrence(RHS)) {
+ const SCEV *Base = LHS;
+ auto I = RewriteMap.find(LHSUnknown->getValue());
+ if (I != RewriteMap.end())
+ Base = I->second;
+ RewriteMap[LHSUnknown->getValue()] = getUMinExpr(Base, RHS);
+ }
+ break;
+ }
+ case CmpInst::ICMP_EQ:
+ if (isa<SCEVConstant>(RHS))
+ RewriteMap[LHSUnknown->getValue()] = RHS;
+ break;
+ case CmpInst::ICMP_NE:
+ if (isa<SCEVConstant>(RHS) &&
+ cast<SCEVConstant>(RHS)->getValue()->isNullValue())
+ RewriteMap[LHSUnknown->getValue()] =
+ getUMaxExpr(LHS, getOne(RHS->getType()));
+ break;
+ default:
+ break;
+ }
+ };
+ // Starting at the loop predecessor, climb up the predecessor chain, as long
+ // as there are predecessors that can be found that have unique successors
+ // leading to the original header.
+ // TODO: share this logic with isLoopEntryGuardedByCond.
+ ValueToSCEVMapTy RewriteMap;
+ for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
+ L->getLoopPredecessor(), L->getHeader());
+ Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+
+ const BranchInst *LoopEntryPredicate =
+ dyn_cast<BranchInst>(Pair.first->getTerminator());
+ if (!LoopEntryPredicate || LoopEntryPredicate->isUnconditional())
+ continue;
+
+ // TODO: use information from more complex conditions, e.g. AND expressions.
+ auto *Cmp = dyn_cast<ICmpInst>(LoopEntryPredicate->getCondition());
+ if (!Cmp)
+ continue;
+
+ auto Predicate = Cmp->getPredicate();
+ if (LoopEntryPredicate->getSuccessor(1) == Pair.second)
+ Predicate = CmpInst::getInversePredicate(Predicate);
+ CollectCondition(Predicate, getSCEV(Cmp->getOperand(0)),
+ getSCEV(Cmp->getOperand(1)), RewriteMap);
+ }
+
+ // Also collect information from assumptions dominating the loop.
+ for (auto &AssumeVH : AC.assumptions()) {
+ if (!AssumeVH)
+ continue;
+ auto *AssumeI = cast<CallInst>(AssumeVH);
+ auto *Cmp = dyn_cast<ICmpInst>(AssumeI->getOperand(0));
+ if (!Cmp || !DT.dominates(AssumeI, L->getHeader()))
+ continue;
+ CollectCondition(Cmp->getPredicate(), getSCEV(Cmp->getOperand(0)),
+ getSCEV(Cmp->getOperand(1)), RewriteMap);
+ }
+
+ if (RewriteMap.empty())
+ return Expr;
+ SCEVLoopGuardRewriter Rewriter(*this, RewriteMap);
+ return Rewriter.visit(Expr);
+}