diff options
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 45 |
1 files changed, 27 insertions, 18 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 44f1a6dde0d21..b3905cc01e84b 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -7032,20 +7032,21 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B, // 3. Compute I: the multiplicative inverse of (A / D) in arithmetic // modulo (N / D). // - // (N / D) may need BW+1 bits in its representation. Hence, we'll use this - // bit width during computations. + // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent + // (N / D) in general. The inverse itself always fits into BW bits, though, + // so we immediately truncate it. APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D APInt Mod(BW + 1, 0); Mod.setBit(BW - Mult2); // Mod = N / D - APInt I = AD.multiplicativeInverse(Mod); + APInt I = AD.multiplicativeInverse(Mod).trunc(BW); // 4. Compute the minimum unsigned root of the equation: // I * (B / D) mod (N / D) - APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod); + // To simplify the computation, we factor out the divide by D: + // (I * B mod N) / D + APInt Result = (I * B).lshr(Mult2); - // The result is guaranteed to be less than 2^BW so we may truncate it to BW - // bits. - return SE.getConstant(Result.trunc(BW)); + return SE.getConstant(Result); } /// Find the roots of the quadratic equation for the given quadratic chrec @@ -7206,17 +7207,25 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit, // 1*N = -Start; -1*N = Start (mod 2^BW), so: // N = Distance (as unsigned) if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) { - ConstantRange CR = getUnsignedRange(Start); - const SCEV *MaxBECount; - if (!CountDown && CR.getUnsignedMin().isMinValue()) - // When counting up, the worst starting value is 1, not 0. - MaxBECount = CR.getUnsignedMax().isMinValue() - ? getConstant(APInt::getMinValue(CR.getBitWidth())) - : getConstant(APInt::getMaxValue(CR.getBitWidth())); - else - MaxBECount = getConstant(CountDown ? CR.getUnsignedMax() - : -CR.getUnsignedMin()); - return ExitLimit(Distance, MaxBECount, false, Predicates); + APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax(); + + // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated, + // we end up with a loop whose backedge-taken count is n - 1. Detect this + // case, and see if we can improve the bound. + // + // Explicitly handling this here is necessary because getUnsignedRange + // isn't context-sensitive; it doesn't know that we only care about the + // range inside the loop. + const SCEV *Zero = getZero(Distance->getType()); + const SCEV *One = getOne(Distance->getType()); + const SCEV *DistancePlusOne = getAddExpr(Distance, One); + if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) { + // If Distance + 1 doesn't overflow, we can compute the maximum distance + // as "unsigned_max(Distance + 1) - 1". + ConstantRange CR = getUnsignedRange(DistancePlusOne); + MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1); + } + return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates); } // As a special case, handle the instance where Step is a positive power of |