summary refs log tree commit diff
path: root/lib/Analysis/ScalarEvolution.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r-- lib/Analysis/ScalarEvolution.cpp 45
1 files changed, 27 insertions, 18 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 44f1a6dde0d21..b3905cc01e84b 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -7032,20 +7032,21 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const APInt &B,
// 3. Compute I: the multiplicative inverse of (A / D) in arithmetic
// modulo (N / D).
//
- // (N / D) may need BW+1 bits in its representation. Hence, we'll use this
- // bit width during computations.
+ // If D == 1, (N / D) == N == 2^BW, so we need one extra bit to represent
+ // (N / D) in general. The inverse itself always fits into BW bits, though,
+ // so we immediately truncate it.
APInt AD = A.lshr(Mult2).zext(BW + 1); // AD = A / D
APInt Mod(BW + 1, 0);
Mod.setBit(BW - Mult2); // Mod = N / D
- APInt I = AD.multiplicativeInverse(Mod);
+ APInt I = AD.multiplicativeInverse(Mod).trunc(BW);
// 4. Compute the minimum unsigned root of the equation:
// I * (B / D) mod (N / D)
- APInt Result = (I * B.lshr(Mult2).zext(BW + 1)).urem(Mod);
+ // To simplify the computation, we factor out the divide by D:
+ // (I * B mod N) / D
+ APInt Result = (I * B).lshr(Mult2);
- // The result is guaranteed to be less than 2^BW so we may truncate it to BW
- // bits.
- return SE.getConstant(Result.trunc(BW));
+ return SE.getConstant(Result);
}
/// Find the roots of the quadratic equation for the given quadratic chrec
@@ -7206,17 +7207,25 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,
// 1*N = -Start; -1*N = Start (mod 2^BW), so:
// N = Distance (as unsigned)
if (StepC->getValue()->equalsInt(1) || StepC->getValue()->isAllOnesValue()) {
- ConstantRange CR = getUnsignedRange(Start);
- const SCEV *MaxBECount;
- if (!CountDown && CR.getUnsignedMin().isMinValue())
- // When counting up, the worst starting value is 1, not 0.
- MaxBECount = CR.getUnsignedMax().isMinValue()
- ? getConstant(APInt::getMinValue(CR.getBitWidth()))
- : getConstant(APInt::getMaxValue(CR.getBitWidth()));
- else
- MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()
- : -CR.getUnsignedMin());
- return ExitLimit(Distance, MaxBECount, false, Predicates);
+ APInt MaxBECount = getUnsignedRange(Distance).getUnsignedMax();
+
+ // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
+ // we end up with a loop whose backedge-taken count is n - 1. Detect this
+ // case, and see if we can improve the bound.
+ //
+ // Explicitly handling this here is necessary because getUnsignedRange
+ // isn't context-sensitive; it doesn't know that we only care about the
+ // range inside the loop.
+ const SCEV *Zero = getZero(Distance->getType());
+ const SCEV *One = getOne(Distance->getType());
+ const SCEV *DistancePlusOne = getAddExpr(Distance, One);
+ if (isLoopEntryGuardedByCond(L, ICmpInst::ICMP_NE, DistancePlusOne, Zero)) {
+ // If Distance + 1 doesn't overflow, we can compute the maximum distance
+ // as "unsigned_max(Distance + 1) - 1".
+ ConstantRange CR = getUnsignedRange(DistancePlusOne);
+ MaxBECount = APIntOps::umin(MaxBECount, CR.getUnsignedMax() - 1);
+ }
+ return ExitLimit(Distance, getConstant(MaxBECount), false, Predicates);
}
// As a special case, handle the instance where Step is a positive power of