diff options
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 480 |
1 files changed, 320 insertions, 160 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 0e715b8814ff..e5134f2eeda9 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -112,6 +112,7 @@ #include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" @@ -162,6 +163,11 @@ static cl::opt<bool> cl::desc("Verify no dangling value in ScalarEvolution's " "ExprValueMap (slow)")); +static cl::opt<bool> VerifyIR( + "scev-verify-ir", cl::Hidden, + cl::desc("Verify IR correctness when making sensitive SCEV queries (slow)"), + cl::init(false)); + static cl::opt<unsigned> MulOpsInlineThreshold( "scev-mulops-inline-threshold", cl::Hidden, cl::desc("Threshold for inlining multiplication operands into a SCEV"), @@ -204,7 +210,7 @@ static cl::opt<unsigned> static cl::opt<unsigned> MaxAddRecSize("scalar-evolution-max-add-rec-size", cl::Hidden, cl::desc("Max coefficients in AddRec during evolving"), - cl::init(16)); + cl::init(8)); //===----------------------------------------------------------------------===// // SCEV class definitions @@ -692,10 +698,6 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; - // Compare NoWrap flags. - if (LA->getNoWrapFlags() != RA->getNoWrapFlags()) - return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags(); - // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, @@ -720,10 +722,6 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; - // Compare NoWrap flags. 
- if (LC->getNoWrapFlags() != RC->getNoWrapFlags()) - return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags(); - for (unsigned i = 0; i != LNumOps; ++i) { int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getOperand(i), RC->getOperand(i), DT, @@ -2767,6 +2765,29 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, } const SCEV * +ScalarEvolution::getOrCreateAddRecExpr(SmallVectorImpl<const SCEV *> &Ops, + const Loop *L, SCEV::NoWrapFlags Flags) { + FoldingSetNodeID ID; + ID.AddInteger(scAddRecExpr); + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + ID.AddPointer(Ops[i]); + ID.AddPointer(L); + void *IP = nullptr; + SCEVAddRecExpr *S = + static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); + if (!S) { + const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Ops.size()); + std::uninitialized_copy(Ops.begin(), Ops.end(), O); + S = new (SCEVAllocator) + SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Ops.size(), L); + UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); + } + S->setNoWrapFlags(Flags); + return S; +} + +const SCEV * ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; @@ -3045,7 +3066,7 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, SmallVector<const SCEV*, 7> AddRecOps; for (int x = 0, xe = AddRec->getNumOperands() + OtherAddRec->getNumOperands() - 1; x != xe && !Overflow; ++x) { - const SCEV *Term = getZero(Ty); + SmallVector <const SCEV *, 7> SumOps; for (int y = x, ye = 2*x+1; y != ye && !Overflow; ++y) { uint64_t Coeff1 = Choose(x, 2*x - y, Overflow); for (int z = std::max(y-x, y-(int)AddRec->getNumOperands()+1), @@ -3060,12 +3081,13 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEV *CoeffTerm = getConstant(Ty, Coeff); const SCEV *Term1 = AddRec->getOperand(y-z); const SCEV *Term2 = OtherAddRec->getOperand(z); - Term = getAddExpr(Term, 
getMulExpr(CoeffTerm, Term1, Term2, - SCEV::FlagAnyWrap, Depth + 1), - SCEV::FlagAnyWrap, Depth + 1); + SumOps.push_back(getMulExpr(CoeffTerm, Term1, Term2, + SCEV::FlagAnyWrap, Depth + 1)); } } - AddRecOps.push_back(Term); + if (SumOps.empty()) + SumOps.push_back(getZero(Ty)); + AddRecOps.push_back(getAddExpr(SumOps, SCEV::FlagAnyWrap, Depth + 1)); } if (!Overflow) { const SCEV *NewAddRec = getAddRecExpr(AddRecOps, AddRec->getLoop(), @@ -3416,24 +3438,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, // Okay, it looks like we really DO need an addrec expr. Check to see if we // already have one, otherwise create a new one. - FoldingSetNodeID ID; - ID.AddInteger(scAddRecExpr); - for (unsigned i = 0, e = Operands.size(); i != e; ++i) - ID.AddPointer(Operands[i]); - ID.AddPointer(L); - void *IP = nullptr; - SCEVAddRecExpr *S = - static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); - if (!S) { - const SCEV **O = SCEVAllocator.Allocate<const SCEV *>(Operands.size()); - std::uninitialized_copy(Operands.begin(), Operands.end(), O); - S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), - O, Operands.size(), L); - UniqueSCEVs.InsertNode(S, IP); - addToLoopUseLists(S); - } - S->setNoWrapFlags(Flags); - return S; + return getOrCreateAddRecExpr(Operands, L, Flags); } const SCEV * @@ -7080,7 +7085,7 @@ ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, return getCouldNotCompute(); bool IsOnlyExit = (L->getExitingBlock() != nullptr); - TerminatorInst *Term = ExitingBlock->getTerminator(); + Instruction *Term = ExitingBlock->getTerminator(); if (BranchInst *BI = dyn_cast<BranchInst>(Term)) { assert(BI->isConditional() && "If unconditional, it can't be in loop!"); bool ExitIfTrue = !L->contains(BI->getSuccessor(0)); @@ -8344,69 +8349,273 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, return SE.getUDivExactExpr(SE.getMulExpr(B, SE.getConstant(I)), D); } -/// Find 
the roots of the quadratic equation for the given quadratic chrec -/// {L,+,M,+,N}. This returns either the two roots (which might be the same) or -/// two SCEVCouldNotCompute objects. -static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> -SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { +/// For a given quadratic addrec, generate coefficients of the corresponding +/// quadratic equation, multiplied by a common value to ensure that they are +/// integers. +/// The returned value is a tuple { A, B, C, M, BitWidth }, where +/// Ax^2 + Bx + C is the quadratic function, M is the value that A, B and C +/// were multiplied by, and BitWidth is the bit width of the original addrec +/// coefficients. +/// This function returns None if the addrec coefficients are not compile- +/// time constants. +static Optional<std::tuple<APInt, APInt, APInt, APInt, unsigned>> +GetQuadraticEquation(const SCEVAddRecExpr *AddRec) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); const SCEVConstant *LC = dyn_cast<SCEVConstant>(AddRec->getOperand(0)); const SCEVConstant *MC = dyn_cast<SCEVConstant>(AddRec->getOperand(1)); const SCEVConstant *NC = dyn_cast<SCEVConstant>(AddRec->getOperand(2)); + LLVM_DEBUG(dbgs() << __func__ << ": analyzing quadratic addrec: " + << *AddRec << '\n'); // We currently can only solve this if the coefficients are constants. 
- if (!LC || !MC || !NC) + if (!LC || !MC || !NC) { + LLVM_DEBUG(dbgs() << __func__ << ": coefficients are not constant\n"); return None; + } - uint32_t BitWidth = LC->getAPInt().getBitWidth(); - const APInt &L = LC->getAPInt(); - const APInt &M = MC->getAPInt(); - const APInt &N = NC->getAPInt(); - APInt Two(BitWidth, 2); - - // Convert from chrec coefficients to polynomial coefficients AX^2+BX+C + APInt L = LC->getAPInt(); + APInt M = MC->getAPInt(); + APInt N = NC->getAPInt(); + assert(!N.isNullValue() && "This is not a quadratic addrec"); + + unsigned BitWidth = LC->getAPInt().getBitWidth(); + unsigned NewWidth = BitWidth + 1; + LLVM_DEBUG(dbgs() << __func__ << ": addrec coeff bw: " + << BitWidth << '\n'); + // The sign-extension (as opposed to a zero-extension) here matches the + // extension used in SolveQuadraticEquationWrap (with the same motivation). + N = N.sext(NewWidth); + M = M.sext(NewWidth); + L = L.sext(NewWidth); + + // The increments are M, M+N, M+2N, ..., so the accumulated values are + // L+M, (L+M)+(M+N), (L+M)+(M+N)+(M+2N), ..., that is, + // L+M, L+2M+N, L+3M+3N, ... + // After n iterations the accumulated value Acc is L + nM + n(n-1)/2 N. + // + // The equation Acc = 0 is then + // L + nM + n(n-1)/2 N = 0, or 2L + 2M n + n(n-1) N = 0. + // In a quadratic form it becomes: + // N n^2 + (2M-N) n + 2L = 0. + + APInt A = N; + APInt B = 2 * M - A; + APInt C = 2 * L; + APInt T = APInt(NewWidth, 2); + LLVM_DEBUG(dbgs() << __func__ << ": equation " << A << "x^2 + " << B + << "x + " << C << ", coeff bw: " << NewWidth + << ", multiplied by " << T << '\n'); + return std::make_tuple(A, B, C, T, BitWidth); +} + +/// Helper function to compare optional APInts: +/// (a) if X and Y both exist, return min(X, Y), +/// (b) if neither X nor Y exist, return None, +/// (c) if exactly one of X and Y exists, return that value. 
+static Optional<APInt> MinOptional(Optional<APInt> X, Optional<APInt> Y) { + if (X.hasValue() && Y.hasValue()) { + unsigned W = std::max(X->getBitWidth(), Y->getBitWidth()); + APInt XW = X->sextOrSelf(W); + APInt YW = Y->sextOrSelf(W); + return XW.slt(YW) ? *X : *Y; + } + if (!X.hasValue() && !Y.hasValue()) + return None; + return X.hasValue() ? *X : *Y; +} - // The A coefficient is N/2 - APInt A = N.sdiv(Two); +/// Helper function to truncate an optional APInt to a given BitWidth. +/// When solving addrec-related equations, it is preferable to return a value +/// that has the same bit width as the original addrec's coefficients. If the +/// solution fits in the original bit width, truncate it (except for i1). +/// Returning a value of a different bit width may inhibit some optimizations. +/// +/// In general, a solution to a quadratic equation generated from an addrec +/// may require BW+1 bits, where BW is the bit width of the addrec's +/// coefficients. The reason is that the coefficients of the quadratic +/// equation are BW+1 bits wide (to avoid truncation when converting from +/// the addrec to the equation). +static Optional<APInt> TruncIfPossible(Optional<APInt> X, unsigned BitWidth) { + if (!X.hasValue()) + return None; + unsigned W = X->getBitWidth(); + if (BitWidth > 1 && BitWidth < W && X->isIntN(BitWidth)) + return X->trunc(BitWidth); + return X; +} + +/// Let c(n) be the value of the quadratic chrec {L,+,M,+,N} after n +/// iterations. The values L, M, N are assumed to be signed, and they +/// should all have the same bit widths. +/// Find the least n >= 0 such that c(n) = 0 in the arithmetic modulo 2^BW, +/// where BW is the bit width of the addrec's coefficients. +/// If the calculated value is a BW-bit integer (for BW > 1), it will be +/// returned as such, otherwise the bit width of the returned value may +/// be greater than BW. 
+/// +/// This function returns None if +/// (a) the addrec coefficients are not constant, or +/// (b) SolveQuadraticEquationWrap was unable to find a solution. For cases +/// like x^2 = 5, no integer solutions exist, in other cases an integer +/// solution may exist, but SolveQuadraticEquationWrap may fail to find it. +static Optional<APInt> +SolveQuadraticAddRecExact(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { + APInt A, B, C, M; + unsigned BitWidth; + auto T = GetQuadraticEquation(AddRec); + if (!T.hasValue()) + return None; - // The B coefficient is M-N/2 - APInt B = M; - B -= A; // A is the same as N/2. + std::tie(A, B, C, M, BitWidth) = *T; + LLVM_DEBUG(dbgs() << __func__ << ": solving for unsigned overflow\n"); + Optional<APInt> X = APIntOps::SolveQuadraticEquationWrap(A, B, C, BitWidth+1); + if (!X.hasValue()) + return None; - // The C coefficient is L. - const APInt& C = L; + ConstantInt *CX = ConstantInt::get(SE.getContext(), *X); + ConstantInt *V = EvaluateConstantChrecAtConstant(AddRec, CX, SE); + if (!V->isZero()) + return None; - // Compute the B^2-4ac term. - APInt SqrtTerm = B; - SqrtTerm *= B; - SqrtTerm -= 4 * (A * C); + return TruncIfPossible(X, BitWidth); +} - if (SqrtTerm.isNegative()) { - // The loop is provably infinite. +/// Let c(n) be the value of the quadratic chrec {0,+,M,+,N} after n +/// iterations. The values M, N are assumed to be signed, and they +/// should all have the same bit widths. +/// Find the least n such that c(n) does not belong to the given range, +/// while c(n-1) does. +/// +/// This function returns None if +/// (a) the addrec coefficients are not constant, or +/// (b) SolveQuadraticEquationWrap was unable to find a solution for the +/// bounds of the range. 
+static Optional<APInt> +SolveQuadraticAddRecRange(const SCEVAddRecExpr *AddRec, + const ConstantRange &Range, ScalarEvolution &SE) { + assert(AddRec->getOperand(0)->isZero() && + "Starting value of addrec should be 0"); + LLVM_DEBUG(dbgs() << __func__ << ": solving boundary crossing for range " + << Range << ", addrec " << *AddRec << '\n'); + // This case is handled in getNumIterationsInRange. Here we can assume that + // we start in the range. + assert(Range.contains(APInt(SE.getTypeSizeInBits(AddRec->getType()), 0)) && + "Addrec's initial value should be in range"); + + APInt A, B, C, M; + unsigned BitWidth; + auto T = GetQuadraticEquation(AddRec); + if (!T.hasValue()) return None; - } - // Compute sqrt(B^2-4ac). This is guaranteed to be the nearest - // integer value or else APInt::sqrt() will assert. - APInt SqrtVal = SqrtTerm.sqrt(); + // Be careful about the return value: there can be two reasons for not + // returning an actual number. First, if no solutions to the equations + // were found, and second, if the solutions don't leave the given range. + // The first case means that the actual solution is "unknown", the second + // means that it's known, but not valid. If the solution is unknown, we + // cannot make any conclusions. + // Return a pair: the optional solution and a flag indicating if the + // solution was found. + auto SolveForBoundary = [&](APInt Bound) -> std::pair<Optional<APInt>,bool> { + // Solve for signed overflow and unsigned overflow, pick the lower + // solution. + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: checking boundary " + << Bound << " (before multiplying by " << M << ")\n"); + Bound *= M; // The quadratic equation multiplier. 
+ + Optional<APInt> SO = None; + if (BitWidth > 1) { + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for " + "signed overflow\n"); + SO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, BitWidth); + } + LLVM_DEBUG(dbgs() << "SolveQuadraticAddRecRange: solving for " + "unsigned overflow\n"); + Optional<APInt> UO = APIntOps::SolveQuadraticEquationWrap(A, B, -Bound, + BitWidth+1); + + auto LeavesRange = [&] (const APInt &X) { + ConstantInt *C0 = ConstantInt::get(SE.getContext(), X); + ConstantInt *V0 = EvaluateConstantChrecAtConstant(AddRec, C0, SE); + if (Range.contains(V0->getValue())) + return false; + // X should be at least 1, so X-1 is non-negative. + ConstantInt *C1 = ConstantInt::get(SE.getContext(), X-1); + ConstantInt *V1 = EvaluateConstantChrecAtConstant(AddRec, C1, SE); + if (Range.contains(V1->getValue())) + return true; + return false; + }; - // Compute the two solutions for the quadratic formula. - // The divisions must be performed as signed divisions. - APInt NegB = -std::move(B); - APInt TwoA = std::move(A); - TwoA <<= 1; - if (TwoA.isNullValue()) - return None; + // If SolveQuadraticEquationWrap returns None, it means that there can + // be a solution, but the function failed to find it. We cannot treat it + // as "no solution". + if (!SO.hasValue() || !UO.hasValue()) + return { None, false }; + + // Check the smaller value first to see if it leaves the range. + // At this point, both SO and UO must have values. + Optional<APInt> Min = MinOptional(SO, UO); + if (LeavesRange(*Min)) + return { Min, true }; + Optional<APInt> Max = Min == SO ? UO : SO; + if (LeavesRange(*Max)) + return { Max, true }; + + // Solutions were found, but were eliminated, hence the "true". + return { None, true }; + }; - LLVMContext &Context = SE.getContext(); + std::tie(A, B, C, M, BitWidth) = *T; + // Lower bound is inclusive, subtract 1 to represent the exiting value. 
+ APInt Lower = Range.getLower().sextOrSelf(A.getBitWidth()) - 1; + APInt Upper = Range.getUpper().sextOrSelf(A.getBitWidth()); + auto SL = SolveForBoundary(Lower); + auto SU = SolveForBoundary(Upper); + // If any of the solutions was unknown, no meaningful conclusions can + // be made. + if (!SL.second || !SU.second) + return None; - ConstantInt *Solution1 = - ConstantInt::get(Context, (NegB + SqrtVal).sdiv(TwoA)); - ConstantInt *Solution2 = - ConstantInt::get(Context, (NegB - SqrtVal).sdiv(TwoA)); + // Claim: The correct solution is not some value between Min and Max. + // + // Justification: Assuming that Min and Max are different values, one of + // them is when the first signed overflow happens, the other is when the + // first unsigned overflow happens. Crossing the range boundary is only + // possible via an overflow (treating 0 as a special case of it, modeling + // an overflow as crossing k*2^W for some k). + // + // The interesting case here is when Min was eliminated as an invalid + // solution, but Max was not. The argument is that if there was another + // overflow between Min and Max, it would also have been eliminated if + // it was considered. + // + // For a given boundary, it is possible to have two overflows of the same + // type (signed/unsigned) without having the other type in between: this + // can happen when the vertex of the parabola is between the iterations + // corresponding to the overflows. This is only possible when the two + // overflows cross k*2^W for the same k. In such case, if the second one + // left the range (and was the first one to do so), the first overflow + // would have to enter the range, which would mean that either we had left + // the range before or that we started outside of it. Both of these cases + // are contradictions. + // + // Claim: In the case where SolveForBoundary returns None, the correct + // solution is not some value between the Max for this boundary and the + // Min of the other boundary. 
+ // + // Justification: Assume that we had such Max_A and Min_B corresponding + // to range boundaries A and B and such that Max_A < Min_B. If there was + // a solution between Max_A and Min_B, it would have to be caused by an + // overflow corresponding to either A or B. It cannot correspond to B, + // since Min_B is the first occurrence of such an overflow. If it + // corresponded to A, it would have to be either a signed or an unsigned + // overflow that is larger than both eliminated overflows for A. But + // between the eliminated overflows and this overflow, the values would + // cover the entire value space, thus crossing the other boundary, which + // is a contradiction. - return std::make_pair(cast<SCEVConstant>(SE.getConstant(Solution1)), - cast<SCEVConstant>(SE.getConstant(Solution2))); + return TruncIfPossible(MinOptional(SL.first, SU.first), BitWidth); } ScalarEvolution::ExitLimit @@ -8441,23 +8650,12 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of // the quadratic equation to solve it. if (AddRec->isQuadratic() && AddRec->getType()->isIntegerTy()) { - if (auto Roots = SolveQuadraticEquation(AddRec, *this)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( - CmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - - // We can only use this value if the chrec ends up with an exact zero - // value at this index. When solving for "X*X != 5", for example, we - // should not accept a root of 2. - const SCEV *Val = AddRec->evaluateAtIteration(R1, *this); - if (Val->isZero()) - // We found a quadratic root! 
- return ExitLimit(R1, R1, false, Predicates); - } + // We can only use this value if the chrec ends up with an exact zero + // value at this index. When solving for "X*X != 5", for example, we + // should not accept a root of 2. + if (auto S = SolveQuadraticAddRecExact(AddRec, *this)) { + const auto *R = cast<SCEVConstant>(getConstant(S.getValue())); + return ExitLimit(R, R, false, Predicates); } return getCouldNotCompute(); } @@ -8617,7 +8815,13 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, const SCEV *&LHS, const SCEV *&RHS, unsigned Depth) { bool Changed = false; - + // Simplifies ICMP to trivial true or false by turning it into '0 == 0' or + // '0 != 0'. + auto TrivialCase = [&](bool TriviallyTrue) { + LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); + Pred = TriviallyTrue ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; + return true; + }; // If we hit the max recursion limit bail out. if (Depth >= 3) return false; @@ -8629,9 +8833,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, if (ConstantExpr::getICmp(Pred, LHSC->getValue(), RHSC->getValue())->isNullValue()) - goto trivially_false; + return TrivialCase(false); else - goto trivially_true; + return TrivialCase(true); } // Otherwise swap the operands to put the constant on the right. std::swap(LHS, RHS); @@ -8661,9 +8865,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, if (!ICmpInst::isEquality(Pred)) { ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); if (ExactCR.isFullSet()) - goto trivially_true; + return TrivialCase(true); else if (ExactCR.isEmptySet()) - goto trivially_false; + return TrivialCase(false); APInt NewRHS; CmpInst::Predicate NewPred; @@ -8699,7 +8903,7 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // The "Should have been caught earlier!" 
messages refer to the fact // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above // should have fired on the corresponding cases, and canonicalized the - // check to trivially_true or trivially_false. + // check to trivial case. case ICmpInst::ICMP_UGE: assert(!RA.isMinValue() && "Should have been caught earlier!"); @@ -8732,9 +8936,9 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, // Check for obvious equality. if (HasSameValue(LHS, RHS)) { if (ICmpInst::isTrueWhenEqual(Pred)) - goto trivially_true; + return TrivialCase(true); if (ICmpInst::isFalseWhenEqual(Pred)) - goto trivially_false; + return TrivialCase(false); } // If possible, canonicalize GE/LE comparisons to GT/LT comparisons, by @@ -8802,18 +9006,6 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred, return SimplifyICmpOperands(Pred, LHS, RHS, Depth+1); return Changed; - -trivially_true: - // Return 0 == 0. - LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); - Pred = ICmpInst::ICMP_EQ; - return true; - -trivially_false: - // Return 0 != 0. - LHS = RHS = getConstant(ConstantInt::getFalse(getContext())); - Pred = ICmpInst::ICMP_NE; - return true; } bool ScalarEvolution::isKnownNegative(const SCEV *S) { @@ -9184,6 +9376,11 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return true; + if (VerifyIR) + assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) && + "This cannot be done on broken IR!"); + + if (isKnownViaNonRecursiveReasoning(Pred, LHS, RHS)) return true; @@ -9289,6 +9486,10 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L, // (interprocedural conditions notwithstanding). if (!L) return false; + if (VerifyIR) + assert(!verifyFunction(*L->getHeader()->getParent(), &dbgs()) && + "This cannot be done on broken IR!"); + // Both LHS and RHS must be available at loop entry. 
assert(isAvailableAtLoopEntry(LHS, L) && "LHS is not available at Loop Entry"); @@ -10565,52 +10766,11 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, ConstantInt::get(SE.getContext(), ExitVal - 1), SE)->getValue()) && "Linear scev computation is off in a bad way!"); return SE.getConstant(ExitValue); - } else if (isQuadratic()) { - // If this is a quadratic (3-term) AddRec {L,+,M,+,N}, find the roots of the - // quadratic equation to solve it. To do this, we must frame our problem in - // terms of figuring out when zero is crossed, instead of when - // Range.getUpper() is crossed. - SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end()); - NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); - const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap); - - // Next, solve the constructed addrec - if (auto Roots = - SolveQuadraticEquation(cast<SCEVAddRecExpr>(NewAddRec), SE)) { - const SCEVConstant *R1 = Roots->first; - const SCEVConstant *R2 = Roots->second; - // Pick the smallest positive root value. - if (ConstantInt *CB = dyn_cast<ConstantInt>(ConstantExpr::getICmp( - ICmpInst::ICMP_ULT, R1->getValue(), R2->getValue()))) { - if (!CB->getZExtValue()) - std::swap(R1, R2); // R1 is the minimum root now. - - // Make sure the root is not off by one. The returned iteration should - // not be in the range, but the previous one should be. When solving - // for "X*X < 5", for example, we should not return a root of 2. - ConstantInt *R1Val = - EvaluateConstantChrecAtConstant(this, R1->getValue(), SE); - if (Range.contains(R1Val->getValue())) { - // The next iteration must be out of the range... 
- ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() + 1); - - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (!Range.contains(R1Val->getValue())) - return SE.getConstant(NextVal); - return SE.getCouldNotCompute(); // Something strange happened - } + } - // If R1 was not in the range, then it is a good return value. Make - // sure that R1-1 WAS in the range though, just in case. - ConstantInt *NextVal = - ConstantInt::get(SE.getContext(), R1->getAPInt() - 1); - R1Val = EvaluateConstantChrecAtConstant(this, NextVal, SE); - if (Range.contains(R1Val->getValue())) - return R1; - return SE.getCouldNotCompute(); // Something strange happened - } - } + if (isQuadratic()) { + if (auto S = SolveQuadraticAddRecRange(this, Range, SE)) + return SE.getConstant(S.getValue()); } return SE.getCouldNotCompute(); @@ -10920,7 +11080,7 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms, Terms.erase(std::unique(Terms.begin(), Terms.end()), Terms.end()); // Put larger terms first. - llvm::sort(Terms.begin(), Terms.end(), [](const SCEV *LHS, const SCEV *RHS) { + llvm::sort(Terms, [](const SCEV *LHS, const SCEV *RHS) { return numberOfTerms(LHS) > numberOfTerms(RHS); }); |