summaryrefslogtreecommitdiff
path: root/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2019-01-19 10:01:25 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2019-01-19 10:01:25 +0000
commit d8e91e46262bc44006913e6796843909f1ac7bcd (patch)
tree 7d0c143d9b38190e0fa0180805389da22cd834c5 /lib/Analysis/ScalarEvolution.cpp
parent b7eb8e35e481a74962664b63dfb09483b200209a (diff)
Notes
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--lib/Analysis/ScalarEvolution.cpp480
1 files changed, 320 insertions, 160 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 0e715b8814ff..e5134f2eeda9 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -112,6 +112,7 @@
#include "llvm/IR/Use.h"
#include "llvm/IR/User.h"
#include "llvm/IR/Value.h"
+#include "llvm/IR/Verifier.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
@@ -162,6 +163,11 @@ static cl::opt<bool>
cl::desc("Verify no dangling value in ScalarEvolution's "
"ExprValueMap (slow)"));
+static cl::opt<bool> VerifyIR(
+ "scev-verify-ir", cl::Hidden,
+ cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"),
+ cl::init(false));
+
static cl::opt<unsigned> MulOpsInlineThreshold(
"scev-mulops-inline-threshold", cl::Hidden,
cl::desc("Threshold for inlining multiplication operands into a SCEV"),
@@ -204,7 +210,7 @@ static cl::opt<unsigned>
static cl::opt<unsigned>
MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden,
cl::desc("Max coefficients in AddRec during evolving"),
- cl::init(16));
+ cl::init(8));
//===----------------------------------------------------------------------===//
// SCEV class definitions
@@ -692,10 +698,6 @@ static int CompareSCEVComplexity(
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
- // Compare NoWrap flags.
- if (LA->getNoWrapFlags() != RA->getNoWrapFlags())
- return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags();
-
// Lexicographically compare.
for (unsigned i = 0; i != LNumOps; ++i) {
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
@@ -720,10 +722,6 @@ static int CompareSCEVComplexity(
if (LNumOps != RNumOps)
return (int)LNumOps - (int)RNumOps;
- // Compare NoWrap flags.
- if (LC->getNoWrapFlags() != RC->getNoWrapFlags())
- return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags();
-
for (unsigned i = 0; i != LNumOps; ++i) {
int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
LC->getOperand(i), RC->getOperand(i), DT,
@@ -2767,6 +2765,29 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
}
const SCEV *
+ScalarEvolution::getOrCreateAddRecExpr(SmallVectorImpl<const SCEV *> &Ops,
+ const Loop *L, SCEV::NoWrapFlags Flags) {
+ FoldingSetNodeID ID;
+ ID.AddInteger(scAddRecExpr);
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ ID.AddPointer(Ops[i]);
+ ID.AddPointer(L);
+ void *IP = nullptr;
+ SCEVAddRecExpr *S =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size());
+ std::uninitialized_copy(Ops.begin(), Ops.end(), O);
+ S = new (SCEVAllocator)
+ SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L);
+ UniqueSCEVs.InsertNode(S, IP);
+ addToLoopUseLists(S);
+ }
+ S->setNoWrapFlags(Flags);
+ return S;
+}
+
+const SCEV *
ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SCEV::NoWrapFlags Flags) {
FoldingSetNodeID ID;
@@ -3045,7 +3066,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
SmallVector<const SCEV*, 7> AddRecOps;
for (int x = 0, xe = AddRec->getNumOperands() +
OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) {
- const SCEV *Term = getZero(Ty);
+ SmallVector <const SCEV *, 7> SumOps;
for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) {
uint64_t Coeff1 = Choose(x, 2*x - y, Overflow);
for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1),
@@ -3060,12 +3081,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
const SCEV *CoeffTerm = getConstant(Ty, Coeff);
const SCEV *Term1 = AddRec->getOperand(y-z);
const SCEV *Term2 = OtherAddRec->getOperand(z);
- Term = getAddExpr(Term, getMulExpr(CoeffTerm, Term1, Term2,
- SCEV::FlagAnyWrap, Depth + 1),
- SCEV::FlagAnyWrap, Depth + 1);
+ SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2,
+ SCEV::FlagAnyWrap, Depth + 1));
}
}
- AddRecOps.push_back(Term);
+ if (SumOps.empty())
+ SumOps.push_back(getZero(Ty));
+ AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1));
}
if (!Overflow) {
const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(),
@@ -3416,24 +3438,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
// Okay, it looks like we really DO need an addrec expr. Check to see if we
// already have one, otherwise create a new one.
- FoldingSetNodeID ID;
- ID.AddInteger(scAddRecExpr);
- for (unsigned i = 0, e = Operands.size(); i != e; ++i)
- ID.AddPointer(Operands[i]);
- ID.AddPointer(L);
- void *IP = nullptr;
- SCEVAddRecExpr *S =
- static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
- if (!S) {
- const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size());
- std::uninitialized_copy(Operands.begin(), Operands.end(), O);
- S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
- O, Operands.size(), L);
- UniqueSCEVs.InsertNode(S, IP);
- addToLoopUseLists(S);
- }
- S->setNoWrapFlags(Flags);
- return S;
+ return getOrCreateAddRecExpr(Operands, L, Flags);
}
const SCEV *
@@ -7080,7 +7085,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
return getCouldNotCompute();
bool IsOnlyExit = (L->getExitingBlock() != nullptr);
- TerminatorInst *Term = ExitingBlock->getTerminator();
+ Instruction *Term = ExitingBlock->getTerminator();
if (BranchInst *BI = dyn_cast<BranchInst>(Term)) {
assert(BI->isConditional() && "If unconditional, it can't be in loop!");
bool ExitIfTrue = !L->contains(BI->getSuccessor(0));
@@ -8344,69 +8349,273 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D);
}
-/// Find the roots of the quadratic equation for the given quadratic chrec
-/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or
-/// two SCEVCouldNotCompute objects.
-static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
-SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
+/// For a given quadratic addrec, generate coefficients of the corresponding
+/// quadratic equation, multiplied by a common value to ensure that they are
+/// integers.
+/// The returned value is a tuple { A, B, C, M, BitWidth }, where
+/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C
+/// were multiplied by, and BitWidth is the bit width of the original addrec
+/// coefficients.
+/// This function returns None if the addrec coefficients are not compile-
+/// time constants.
+static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>>
+GetQuadraticEquation(const SCEVAddRecExpr *AddRec) {
assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0));
const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1));
const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2));
+ LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: "
+ << *AddRec << '\n');
// We currently can only solve this if the coefficients are constants.
- if (!LC || !MC || !NC)
+ if (!LC || !MC || !NC) {
+ LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n");
return None;
+ }
- uint32_t BitWidth = LC->getAPInt().getBitWidth();
- const APInt &L = LC->getAPInt();
- const APInt &M = MC->getAPInt();
- const APInt &N = NC->getAPInt();
- APInt Two(BitWidth, 2);
-
- // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C
+ APInt L = LC->getAPInt();
+ APInt M = MC->getAPInt();
+ APInt N = NC->getAPInt();
+ assert(!N.isNullValue() && "This is not a quadratic addrec");
+
+ unsigned BitWidth = LC->getAPInt().getBitWidth();
+ unsigned NewWidth = BitWidth + 1;
+ LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: "
+ << BitWidth << '\n');
+ // The sign-extension (as opposed to a zero-extension) here matches the
+ // extension used in SolveQuadraticEquationWrap (with the same motivation).
+ N = N.sext(NewWidth);
+ M = M.sext(NewWidth);
+ L = L.sext(NewWidth);
+
+ // The increments are M, M+N, M+2N, ..., so the accumulated values are
+ // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is,
+ // L+M, L+2M+N, L+3M+3N, ...
+ // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N.
+ //
+ // The equation Acc = 0 is then
+ // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0.
+ // In a quadratic form it becomes:
+ // N n^2 + (2M-N) n + 2L = 0.
+
+ APInt A = N;
+ APInt B = 2 * M - A;
+ APInt C = 2 * L;
+ APInt T = APInt(NewWidth, 2);
+ LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B
+ << "x + " << C << ", coeff bw: " << NewWidth
+ << ", multiplied by " << T << '\n');
+ return std::make_tuple(A, B, C, T, BitWidth);
+}
+
+/// Helper function to select the smaller of two optional APInts:
+/// (a) if X and Y both exist, return min(X, Y),
+/// (b) if neither X nor Y exist, return None,
+/// (c) if exactly one of X and Y exists, return that value.
+static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) {
+ if (X.hasValue() && Y.hasValue()) {
+ unsigned W = std::max(X->getBitWidth(), Y->getBitWidth());
+ APInt XW = X->sextOrSelf(W);
+ APInt YW = Y->sextOrSelf(W);
+ return XW.slt(YW) ? *X : *Y;
+ }
+ if (!X.hasValue() && !Y.hasValue())
+ return None;
+ return X.hasValue() ? *X : *Y;
+}
- // The A coefficient is N/2
- APInt A = N.sdiv(Two);
+/// Helper function to truncate an optional APInt to a given BitWidth.
+/// When solving addrec-related equations, it is preferable to return a value
+/// that has the same bit width as the original addrec's coefficients. If the
+/// solution fits in the original bit width, truncate it (except for i1).
+/// Returning a value of a different bit width may inhibit some optimizations.
+///
+/// In general, a solution to a quadratic equation generated from an addrec
+/// may require BW+1 bits, where BW is the bit width of the addrec's
+/// coefficients. The reason is that the coefficients of the quadratic
+/// equation are BW+1 bits wide (to avoid truncation when converting from
+/// the addrec to the equation).
+static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) {
+ if (!X.hasValue())
+ return None;
+ unsigned W = X->getBitWidth();
+ if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth))
+ return X->trunc(BitWidth);
+ return X;
+}
+
+/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n
+/// iterations. The values L, M, N are assumed to be signed, and they
+/// should all have the same bit width.
+/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW,
+/// where BW is the bit width of the addrec's coefficients.
+/// If the calculated value is a BW-bit integer (for BW > 1), it will be
+/// returned as such, otherwise the bit width of the returned value may
+/// be greater than BW.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases
+/// like x^2 = 5, no integer solutions exist, in other cases an integer
+/// solution may exist, but SolveQuadraticEquationWrap may fail to find it.
+static Optional<APInt>
+SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
+ return None;
- // The B coefficient is M-N/2
- APInt B = M;
- B -= A; // A is the same as N/2.
+ std::tie(A, B, C, M, BitWidth) = *T;
+ LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n");
+ Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1);
+ if (!X.hasValue())
+ return None;
- // The C coefficient is L.
- const APInt& C = L;
+ ConstantInt *CX = ConstantInt::get(SE.getContext(), *X);
+ ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE);
+ if (!V->isZero())
+ return None;
- // Compute the B^2-4ac term.
- APInt SqrtTerm = B;
- SqrtTerm *= B;
- SqrtTerm -= 4 * (A * C);
+ return TruncIfPossible(X, BitWidth);
+}
- if (SqrtTerm.isNegative()) {
- // The loop is provably infinite.
+/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n
+/// iterations. The values M, N are assumed to be signed, and they
+/// should all have the same bit width.
+/// Find the least n such that c(n) does not belong to the given range,
+/// while c(n-1) does.
+///
+/// This function returns None if
+/// (a) the addrec coefficients are not constant, or
+/// (b) SolveQuadraticEquationWrap was unable to find a solution for the
+/// bounds of the range.
+static Optional<APInt>
+SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec,
+ const ConstantRange &Range, ScalarEvolution &SE) {
+ assert(AddRec->getOperand(0)->isZero() &&
+ "Starting value of addrec should be 0");
+ LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range "
+ << Range << ", addrec " << *AddRec << '\n');
+ // This case is handled in getNumIterationsInRange. Here we can assume that
+ // we start in the range.
+ assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) &&
+ "Addrec's initial value should be in range");
+
+ APInt A, B, C, M;
+ unsigned BitWidth;
+ auto T = GetQuadraticEquation(AddRec);
+ if (!T.hasValue())
return None;
- }
- // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest
- // integer value or else APInt::sqrt() will assert.
- APInt SqrtVal = SqrtTerm.sqrt();
+ // Be careful about the return value: there can be two reasons for not
+ // returning an actual number. First, if no solutions to the equations
+ // were found, and second, if the solutions don't leave the given range.
+ // The first case means that the actual solution is "unknown", the second
+ // means that it's known, but not valid. If the solution is unknown, we
+ // cannot make any conclusions.
+ // Return a pair: the optional solution and a flag indicating if the
+ // solution was found.
+ auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> {
+ // Solve for signed overflow and unsigned overflow, pick the lower
+ // solution.
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary "
+ << Bound << " (before multiplying by " << M << ")\n");
+ Bound *= M; // The quadratic equation multiplier.
+
+ Optional<APInt> SO = None;
+ if (BitWidth > 1) {
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "signed overflow\n");
+ SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth);
+ }
+ LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for "
+ "unsigned overflow\n");
+ Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound,
+ BitWidth+1);
+
+ auto LeavesRange = [&] (const APInt &X) {
+ ConstantInt *C0 = ConstantInt::get(SE.getContext(), X);
+ ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE);
+ if (Range.contains(V0->getValue()))
+ return false;
+ // X should be at least 1, so X-1 is non-negative.
+ ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1);
+ ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE);
+ if (Range.contains(V1->getValue()))
+ return true;
+ return false;
+ };
- // Compute the two solutions for the quadratic formula.
- // The divisions must be performed as signed divisions.
- APInt NegB = -std::move(B);
- APInt TwoA = std::move(A);
- TwoA <<= 1;
- if (TwoA.isNullValue())
- return None;
+ // If SolveQuadraticEquationWrap returns None, it means that there can
+ // be a solution, but the function failed to find it. We cannot treat it
+ // as "no solution".
+ if (!SO.hasValue() || !UO.hasValue())
+ return { None, false };
+
+ // Check the smaller value first to see if it leaves the range.
+ // At this point, both SO and UO must have values.
+ Optional<APInt> Min = MinOptional(SO, UO);
+ if (LeavesRange(*Min))
+ return { Min, true };
+ Optional<APInt> Max = Min == SO ? UO : SO;
+ if (LeavesRange(*Max))
+ return { Max, true };
+
+ // Solutions were found, but were eliminated, hence the "true".
+ return { None, true };
+ };
- LLVMContext &Context = SE.getContext();
+ std::tie(A, B, C, M, BitWidth) = *T;
+ // Lower bound is inclusive, subtract 1 to represent the exiting value.
+ APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1;
+ APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth());
+ auto SL = SolveForBoundary(Lower);
+ auto SU = SolveForBoundary(Upper);
+ // If any of the solutions was unknown, no meaningful conclusions can
+ // be made.
+ if (!SL.second || !SU.second)
+ return None;
- ConstantInt *Solution1 =
- ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA));
- ConstantInt *Solution2 =
- ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA));
+ // Claim: The correct solution is not some value between Min and Max.
+ //
+ // Justification: Assuming that Min and Max are different values, one of
+ // them is when the first signed overflow happens, the other is when the
+ // first unsigned overflow happens. Crossing the range boundary is only
+ // possible via an overflow (treating 0 as a special case of it, modeling
+ // an overflow as crossing k*2^W for some k).
+ //
+ // The interesting case here is when Min was eliminated as an invalid
+ // solution, but Max was not. The argument is that if there was another
+ // overflow between Min and Max, it would also have been eliminated if
+ // it was considered.
+ //
+ // For a given boundary, it is possible to have two overflows of the same
+ // type (signed/unsigned) without having the other type in between: this
+ // can happen when the vertex of the parabola is between the iterations
+ // corresponding to the overflows. This is only possible when the two
+ // overflows cross k*2^W for the same k. In such case, if the second one
+ // left the range (and was the first one to do so), the first overflow
+ // would have to enter the range, which would mean that either we had left
+ // the range before or that we started outside of it. Both of these cases
+ // are contradictions.
+ //
+ // Claim: In the case where SolveForBoundary returns None, the correct
+ // solution is not some value between the Max for this boundary and the
+ // Min of the other boundary.
+ //
+ // Justification: Assume that we had such Max_A and Min_B corresponding
+ // to range boundaries A and B and such that Max_A < Min_B. If there was
+ // a solution between Max_A and Min_B, it would have to be caused by an
+ // overflow corresponding to either A or B. It cannot correspond to B,
+ // since Min_B is the first occurrence of such an overflow. If it
+ // corresponded to A, it would have to be either a signed or an unsigned
+ // overflow that is larger than both eliminated overflows for A. But
+ // between the eliminated overflows and this overflow, the values would
+ // cover the entire value space, thus crossing the other boundary, which
+ // is a contradiction.
- return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)),
- cast<SCEVConstant>(SE.getConstant(Solution2)));
+ return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth);
}
ScalarEvolution::ExitLimit
@@ -8441,23 +8650,12 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of
// the quadratic equation to solve it.
if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) {
- if (auto Roots = SolveQuadraticEquation(AddRec, *this)) {
- const SCEVConstant *R1 = Roots->first;
- const SCEVConstant *R2 = Roots->second;
- // Pick the smallest positive root value.
- if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
- CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
- if (!CB->getZExtValue())
- std::swap(R1, R2); // R1 is the minimum root now.
-
- // We can only use this value if the chrec ends up with an exact zero
- // value at this index. When solving for "X*X != 5", for example, we
- // should not accept a root of 2.
- const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);
- if (Val->isZero())
- // We found a quadratic root!
- return ExitLimit(R1, R1, false, Predicates);
- }
+ // We can only use this value if the chrec ends up with an exact zero
+ // value at this index. When solving for "X*X != 5", for example, we
+ // should not accept a root of 2.
+ if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) {
+ const auto *R = cast<SCEVConstant>(getConstant(S.getValue()));
+ return ExitLimit(R, R, false, Predicates);
}
return getCouldNotCompute();
}
@@ -8617,7 +8815,13 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
const SCEV *&LHS, const SCEV *&RHS,
unsigned Depth) {
bool Changed = false;
-
+ // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or
+ // '0 != 0'.
+ auto TrivialCase = [&](bool TriviallyTrue) {
+ LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
+ Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
+ return true;
+ };
// If we hit the max recursion limit bail out.
if (Depth >= 3)
return false;
@@ -8629,9 +8833,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
if (ConstantExpr::getICmp(Pred,
LHSC->getValue(),
RHSC->getValue())->isNullValue())
- goto trivially_false;
+ return TrivialCase(false);
else
- goto trivially_true;
+ return TrivialCase(true);
}
// Otherwise swap the operands to put the constant on the right.
std::swap(LHS, RHS);
@@ -8661,9 +8865,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
if (!ICmpInst::isEquality(Pred)) {
ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA);
if (ExactCR.isFullSet())
- goto trivially_true;
+ return TrivialCase(true);
else if (ExactCR.isEmptySet())
- goto trivially_false;
+ return TrivialCase(false);
APInt NewRHS;
CmpInst::Predicate NewPred;
@@ -8699,7 +8903,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
// The "Should have been caught earlier!" messages refer to the fact
// that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above
// should have fired on the corresponding cases, and canonicalized the
- // check to trivially_true or trivially_false.
+ // check to a trivial case.
case ICmpInst::ICMP_UGE:
assert(!RA.isMinValue() && "Should have been caught earlier!");
@@ -8732,9 +8936,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
// Check for obvious equality.
if (HasSameValue(LHS, RHS)) {
if (ICmpInst::isTrueWhenEqual(Pred))
- goto trivially_true;
+ return TrivialCase(true);
if (ICmpInst::isFalseWhenEqual(Pred))
- goto trivially_false;
+ return TrivialCase(false);
}
// If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by
@@ -8802,18 +9006,6 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,
return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1);
return Changed;
-
-trivially_true:
- // Return 0 == 0.
- LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
- Pred = ICmpInst::ICMP_EQ;
- return true;
-
-trivially_false:
- // Return 0 != 0.
- LHS = RHS = getConstant(ConstantInt::getFalse(getContext()));
- Pred = ICmpInst::ICMP_NE;
- return true;
}
bool ScalarEvolution::isKnownNegative(const SCEV *S) {
@@ -9184,6 +9376,11 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return true;
+ if (VerifyIR)
+ assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+ "This cannot be done on broken IR!");
+
+
if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS))
return true;
@@ -9289,6 +9486,10 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,
// (interprocedural conditions notwithstanding).
if (!L) return false;
+ if (VerifyIR)
+ assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) &&
+ "This cannot be done on broken IR!");
+
// Both LHS and RHS must be available at loop entry.
assert(isAvailableAtLoopEntry(LHS, L) &&
"LHS is not available at Loop Entry");
@@ -10565,52 +10766,11 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,
ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) &&
"Linear scev computation is off in a bad way!");
return SE.getConstant(ExitValue);
- } else if (isQuadratic()) {
- // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the
- // quadratic equation to solve it. To do this, we must frame our problem in
- // terms of figuring out when zero is crossed, instead of when
- // Range.getUpper() is crossed.
- SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());
- NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper()));
- const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap);
-
- // Next, solve the constructed addrec
- if (auto Roots =
- SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) {
- const SCEVConstant *R1 = Roots->first;
- const SCEVConstant *R2 = Roots->second;
- // Pick the smallest positive root value.
- if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp(
- ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) {
- if (!CB->getZExtValue())
- std::swap(R1, R2); // R1 is the minimum root now.
-
- // Make sure the root is not off by one. The returned iteration should
- // not be in the range, but the previous one should be. When solving
- // for "X*X < 5", for example, we should not return a root of 2.
- ConstantInt *R1Val =
- EvaluateConstantChrecAtConstant(this, R1->getValue(), SE);
- if (Range.contains(R1Val->getValue())) {
- // The next iteration must be out of the range...
- ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getAPInt() + 1);
-
- R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
- if (!Range.contains(R1Val->getValue()))
- return SE.getConstant(NextVal);
- return SE.getCouldNotCompute(); // Something strange happened
- }
+ }
- // If R1 was not in the range, then it is a good return value. Make
- // sure that R1-1 WAS in the range though, just in case.
- ConstantInt *NextVal =
- ConstantInt::get(SE.getContext(), R1->getAPInt() - 1);
- R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE);
- if (Range.contains(R1Val->getValue()))
- return R1;
- return SE.getCouldNotCompute(); // Something strange happened
- }
- }
+ if (isQuadratic()) {
+ if (auto S = SolveQuadraticAddRecRange(this, Range, SE))
+ return SE.getConstant(S.getValue());
}
return SE.getCouldNotCompute();
@@ -10920,7 +11080,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end());
// Put larger terms first.
- llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) {
+ llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) {
return numberOfTerms(LHS) > numberOfTerms(RHS);
});