summaryrefslogtreecommitdiff
path: root/lib/Support/APInt.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2019-01-19 10:01:25 +0000
committerDimitry Andric <dim@FreeBSD.org>2019-01-19 10:01:25 +0000
commitd8e91e46262bc44006913e6796843909f1ac7bcd (patch)
tree7d0c143d9b38190e0fa0180805389da22cd834c5 /lib/Support/APInt.cpp
parentb7eb8e35e481a74962664b63dfb09483b200209a (diff)
Notes
Diffstat (limited to 'lib/Support/APInt.cpp')
-rw-r--r--lib/Support/APInt.cpp336
1 files changed, 276 insertions, 60 deletions
diff --git a/lib/Support/APInt.cpp b/lib/Support/APInt.cpp
index 1fae0e9b8d6d..a5f4f98c489a 100644
--- a/lib/Support/APInt.cpp
+++ b/lib/Support/APInt.cpp
@@ -16,8 +16,10 @@
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/FoldingSet.h"
#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/bit.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
@@ -78,7 +80,7 @@ void APInt::initSlowCase(uint64_t val, bool isSigned) {
U.pVal[0] = val;
if (isSigned && int64_t(val) < 0)
for (unsigned i = 1; i < getNumWords(); ++i)
- U.pVal[i] = WORD_MAX;
+ U.pVal[i] = WORDTYPE_MAX;
clearUnusedBits();
}
@@ -304,13 +306,13 @@ void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) {
unsigned hiWord = whichWord(hiBit);
// Create an initial mask for the low word with zeros below loBit.
- uint64_t loMask = WORD_MAX << whichBit(loBit);
+ uint64_t loMask = WORDTYPE_MAX << whichBit(loBit);
// If hiBit is not aligned, we need a high mask.
unsigned hiShiftAmt = whichBit(hiBit);
if (hiShiftAmt != 0) {
// Create a high mask with zeros above hiBit.
- uint64_t hiMask = WORD_MAX >> (APINT_BITS_PER_WORD - hiShiftAmt);
+ uint64_t hiMask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - hiShiftAmt);
// If loWord and hiWord are equal, then we combine the masks. Otherwise,
// set the bits in hiWord.
if (hiWord == loWord)
@@ -323,7 +325,7 @@ void APInt::setBitsSlowCase(unsigned loBit, unsigned hiBit) {
// Fill any words between loWord and hiWord with all ones.
for (unsigned word = loWord + 1; word < hiWord; ++word)
- U.pVal[word] = WORD_MAX;
+ U.pVal[word] = WORDTYPE_MAX;
}
/// Toggle every bit to its opposite value.
@@ -354,7 +356,7 @@ void APInt::insertBits(const APInt &subBits, unsigned bitPosition) {
// Single word result can be done as a direct bitmask.
if (isSingleWord()) {
- uint64_t mask = WORD_MAX >> (APINT_BITS_PER_WORD - subBitWidth);
+ uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - subBitWidth);
U.VAL &= ~(mask << bitPosition);
U.VAL |= (subBits.U.VAL << bitPosition);
return;
@@ -366,7 +368,7 @@ void APInt::insertBits(const APInt &subBits, unsigned bitPosition) {
// Insertion within a single word can be done as a direct bitmask.
if (loWord == hi1Word) {
- uint64_t mask = WORD_MAX >> (APINT_BITS_PER_WORD - subBitWidth);
+ uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - subBitWidth);
U.pVal[loWord] &= ~(mask << loBit);
U.pVal[loWord] |= (subBits.U.VAL << loBit);
return;
@@ -382,7 +384,7 @@ void APInt::insertBits(const APInt &subBits, unsigned bitPosition) {
// Mask+insert remaining bits.
unsigned remainingBits = subBitWidth % APINT_BITS_PER_WORD;
if (remainingBits != 0) {
- uint64_t mask = WORD_MAX >> (APINT_BITS_PER_WORD - remainingBits);
+ uint64_t mask = WORDTYPE_MAX >> (APINT_BITS_PER_WORD - remainingBits);
U.pVal[hi1Word] &= ~mask;
U.pVal[hi1Word] |= subBits.getWord(subBitWidth - 1);
}
@@ -558,7 +560,7 @@ unsigned APInt::countLeadingOnesSlowCase() const {
unsigned Count = llvm::countLeadingOnes(U.pVal[i] << shift);
if (Count == highWordBits) {
for (i--; i >= 0; --i) {
- if (U.pVal[i] == WORD_MAX)
+ if (U.pVal[i] == WORDTYPE_MAX)
Count += APINT_BITS_PER_WORD;
else {
Count += llvm::countLeadingOnes(U.pVal[i]);
@@ -582,7 +584,7 @@ unsigned APInt::countTrailingZerosSlowCase() const {
unsigned APInt::countTrailingOnesSlowCase() const {
unsigned Count = 0;
unsigned i = 0;
- for (; i < getNumWords() && U.pVal[i] == WORD_MAX; ++i)
+ for (; i < getNumWords() && U.pVal[i] == WORDTYPE_MAX; ++i)
Count += APINT_BITS_PER_WORD;
if (i < getNumWords())
Count += llvm::countTrailingOnes(U.pVal[i]);
@@ -711,24 +713,20 @@ APInt llvm::APIntOps::GreatestCommonDivisor(APInt A, APInt B) {
}
APInt llvm::APIntOps::RoundDoubleToAPInt(double Double, unsigned width) {
- union {
- double D;
- uint64_t I;
- } T;
- T.D = Double;
+ uint64_t I = bit_cast<uint64_t>(Double);
// Get the sign bit from the highest order bit
- bool isNeg = T.I >> 63;
+ bool isNeg = I >> 63;
// Get the 11-bit exponent and adjust for the 1023 bit bias
- int64_t exp = ((T.I >> 52) & 0x7ff) - 1023;
+ int64_t exp = ((I >> 52) & 0x7ff) - 1023;
// If the exponent is negative, the magnitude of the value is < 1, so just return 0.
if (exp < 0)
return APInt(width, 0u);
// Extract the mantissa by clearing the top 12 bits (sign + exponent).
- uint64_t mantissa = (T.I & (~0ULL >> 12)) | 1ULL << 52;
+ uint64_t mantissa = (I & (~0ULL >> 12)) | 1ULL << 52;
// If the exponent doesn't shift all bits out of the mantissa
if (exp < 52)
@@ -805,12 +803,8 @@ double APInt::roundToDouble(bool isSigned) const {
// The leading bit of mantissa is implicit, so get rid of it.
uint64_t sign = isNeg ? (1ULL << (APINT_BITS_PER_WORD - 1)) : 0;
- union {
- double D;
- uint64_t I;
- } T;
- T.I = sign | (exp << 52) | mantissa;
- return T.D;
+ uint64_t I = sign | (exp << 52) | mantissa;
+ return bit_cast<double>(I);
}
// Truncate to new width.
@@ -1253,20 +1247,18 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
// The DEBUG macros here tend to be spam in the debug output if you're not
// debugging this code. Disable them unless KNUTH_DEBUG is defined.
-#pragma push_macro("LLVM_DEBUG")
-#ifndef KNUTH_DEBUG
-#undef LLVM_DEBUG
-#define LLVM_DEBUG(X) \
- do { \
- } while (false)
+#ifdef KNUTH_DEBUG
+#define DEBUG_KNUTH(X) LLVM_DEBUG(X)
+#else
+#define DEBUG_KNUTH(X) do {} while(false)
#endif
- LLVM_DEBUG(dbgs() << "KnuthDiv: m=" << m << " n=" << n << '\n');
- LLVM_DEBUG(dbgs() << "KnuthDiv: original:");
- LLVM_DEBUG(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
- LLVM_DEBUG(dbgs() << " by");
- LLVM_DEBUG(for (int i = n; i > 0; i--) dbgs() << " " << v[i - 1]);
- LLVM_DEBUG(dbgs() << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: m=" << m << " n=" << n << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: original:");
+ DEBUG_KNUTH(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
+ DEBUG_KNUTH(dbgs() << " by");
+ DEBUG_KNUTH(for (int i = n; i > 0; i--) dbgs() << " " << v[i - 1]);
+ DEBUG_KNUTH(dbgs() << '\n');
// D1. [Normalize.] Set d = b / (v[n-1] + 1) and multiply all the digits of
// u and v by d. Note that we have taken Knuth's advice here to use a power
// of 2 value for d such that d * v[n-1] >= b/2 (b is the base). A power of
@@ -1292,16 +1284,16 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
}
u[m+n] = u_carry;
- LLVM_DEBUG(dbgs() << "KnuthDiv: normal:");
- LLVM_DEBUG(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
- LLVM_DEBUG(dbgs() << " by");
- LLVM_DEBUG(for (int i = n; i > 0; i--) dbgs() << " " << v[i - 1]);
- LLVM_DEBUG(dbgs() << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: normal:");
+ DEBUG_KNUTH(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
+ DEBUG_KNUTH(dbgs() << " by");
+ DEBUG_KNUTH(for (int i = n; i > 0; i--) dbgs() << " " << v[i - 1]);
+ DEBUG_KNUTH(dbgs() << '\n');
// D2. [Initialize j.] Set j to m. This is the loop counter over the places.
int j = m;
do {
- LLVM_DEBUG(dbgs() << "KnuthDiv: quotient digit #" << j << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: quotient digit #" << j << '\n');
// D3. [Calculate q'.].
// Set qp = (u[j+n]*b + u[j+n-1]) / v[n-1]. (qp=qprime=q')
// Set rp = (u[j+n]*b + u[j+n-1]) % v[n-1]. (rp=rprime=r')
@@ -1311,7 +1303,7 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
// value qp is one too large, and it eliminates all cases where qp is two
// too large.
uint64_t dividend = Make_64(u[j+n], u[j+n-1]);
- LLVM_DEBUG(dbgs() << "KnuthDiv: dividend == " << dividend << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: dividend == " << dividend << '\n');
uint64_t qp = dividend / v[n-1];
uint64_t rp = dividend % v[n-1];
if (qp == b || qp*v[n-2] > b*rp + u[j+n-2]) {
@@ -1320,7 +1312,7 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
if (rp < b && (qp == b || qp*v[n-2] > b*rp + u[j+n-2]))
qp--;
}
- LLVM_DEBUG(dbgs() << "KnuthDiv: qp == " << qp << ", rp == " << rp << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: qp == " << qp << ", rp == " << rp << '\n');
// D4. [Multiply and subtract.] Replace (u[j+n]u[j+n-1]...u[j]) with
// (u[j+n]u[j+n-1]..u[j]) - qp * (v[n-1]...v[1]v[0]). This computation
@@ -1336,15 +1328,15 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
int64_t subres = int64_t(u[j+i]) - borrow - Lo_32(p);
u[j+i] = Lo_32(subres);
borrow = Hi_32(p) - Hi_32(subres);
- LLVM_DEBUG(dbgs() << "KnuthDiv: u[j+i] = " << u[j + i]
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: u[j+i] = " << u[j + i]
<< ", borrow = " << borrow << '\n');
}
bool isNeg = u[j+n] < borrow;
u[j+n] -= Lo_32(borrow);
- LLVM_DEBUG(dbgs() << "KnuthDiv: after subtraction:");
- LLVM_DEBUG(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
- LLVM_DEBUG(dbgs() << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: after subtraction:");
+ DEBUG_KNUTH(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
+ DEBUG_KNUTH(dbgs() << '\n');
// D5. [Test remainder.] Set q[j] = qp. If the result of step D4 was
// negative, go to step D6; otherwise go on to step D7.
@@ -1365,16 +1357,16 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
}
u[j+n] += carry;
}
- LLVM_DEBUG(dbgs() << "KnuthDiv: after correction:");
- LLVM_DEBUG(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
- LLVM_DEBUG(dbgs() << "\nKnuthDiv: digit result = " << q[j] << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: after correction:");
+ DEBUG_KNUTH(for (int i = m + n; i >= 0; i--) dbgs() << " " << u[i]);
+ DEBUG_KNUTH(dbgs() << "\nKnuthDiv: digit result = " << q[j] << '\n');
// D7. [Loop on j.] Decrease j by one. Now if j >= 0, go back to D3.
} while (--j >= 0);
- LLVM_DEBUG(dbgs() << "KnuthDiv: quotient:");
- LLVM_DEBUG(for (int i = m; i >= 0; i--) dbgs() << " " << q[i]);
- LLVM_DEBUG(dbgs() << '\n');
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: quotient:");
+ DEBUG_KNUTH(for (int i = m; i >= 0; i--) dbgs() << " " << q[i]);
+ DEBUG_KNUTH(dbgs() << '\n');
// D8. [Unnormalize]. Now q[...] is the desired quotient, and the desired
// remainder may be obtained by dividing u[...] by d. If r is non-null we
@@ -1385,23 +1377,21 @@ static void KnuthDiv(uint32_t *u, uint32_t *v, uint32_t *q, uint32_t* r,
// shift right here.
if (shift) {
uint32_t carry = 0;
- LLVM_DEBUG(dbgs() << "KnuthDiv: remainder:");
+ DEBUG_KNUTH(dbgs() << "KnuthDiv: remainder:");
for (int i = n-1; i >= 0; i--) {
r[i] = (u[i] >> shift) | carry;
carry = u[i] << (32 - shift);
- LLVM_DEBUG(dbgs() << " " << r[i]);
+ DEBUG_KNUTH(dbgs() << " " << r[i]);
}
} else {
for (int i = n-1; i >= 0; i--) {
r[i] = u[i];
- LLVM_DEBUG(dbgs() << " " << r[i]);
+ DEBUG_KNUTH(dbgs() << " " << r[i]);
}
}
- LLVM_DEBUG(dbgs() << '\n');
+ DEBUG_KNUTH(dbgs() << '\n');
}
- LLVM_DEBUG(dbgs() << '\n');
-
-#pragma pop_macro("LLVM_DEBUG")
+ DEBUG_KNUTH(dbgs() << '\n');
}
void APInt::divide(const WordType *LHS, unsigned lhsWords, const WordType *RHS,
@@ -1957,7 +1947,43 @@ APInt APInt::ushl_ov(const APInt &ShAmt, bool &Overflow) const {
return *this << ShAmt;
}
+APInt APInt::sadd_sat(const APInt &RHS) const {
+  // Signed saturating addition: when sadd_ov reports overflow, clamp to the
+  // signed extreme selected by the sign of *this; otherwise return the sum.
+  bool Wrapped;
+  APInt Sum = sadd_ov(RHS, Wrapped);
+  if (Wrapped) {
+    if (isNegative())
+      return APInt::getSignedMinValue(BitWidth);
+    return APInt::getSignedMaxValue(BitWidth);
+  }
+  return Sum;
+}
+
+APInt APInt::uadd_sat(const APInt &RHS) const {
+  // Unsigned saturating addition: when uadd_ov reports overflow, return
+  // getMaxValue for this bit width; otherwise return the sum.
+  bool Wrapped;
+  APInt Sum = uadd_ov(RHS, Wrapped);
+  if (Wrapped)
+    return APInt::getMaxValue(BitWidth);
+  return Sum;
+}
+
+APInt APInt::ssub_sat(const APInt &RHS) const {
+  // Signed saturating subtraction: when ssub_ov reports overflow, clamp to
+  // the signed extreme selected by the sign of *this; otherwise return the
+  // difference.
+  bool Wrapped;
+  APInt Diff = ssub_ov(RHS, Wrapped);
+  if (Wrapped) {
+    if (isNegative())
+      return APInt::getSignedMinValue(BitWidth);
+    return APInt::getSignedMaxValue(BitWidth);
+  }
+  return Diff;
+}
+
+APInt APInt::usub_sat(const APInt &RHS) const {
+  // Unsigned saturating subtraction: when usub_ov reports overflow, return
+  // zero at this bit width; otherwise return the difference.
+  bool Wrapped;
+  APInt Diff = usub_ov(RHS, Wrapped);
+  if (Wrapped)
+    return APInt(BitWidth, 0);
+  return Diff;
+}
void APInt::fromString(unsigned numbits, StringRef str, uint8_t radix) {
@@ -2707,3 +2733,193 @@ APInt llvm::APIntOps::RoundingSDiv(const APInt &A, const APInt &B,
}
llvm_unreachable("Unknown APInt::Rounding enum");
}
+/* Solve A*x^2 + B*x + C = 0 where values wrap at 2^RangeWidth.
+ *
+ * A, B and C must all have the same bit width, and RangeWidth must be
+ * greater than 1 and no larger than that width (all three are asserted).
+ *
+ * Result:
+ *  - If C, sign-extended or truncated to RangeWidth bits, is zero, a zero
+ *    value of the coefficients' original bit width is returned.
+ *  - Otherwise the coefficients are sign-extended to three times their
+ *    width, C is shifted by a multiple of R (see the comments in the body),
+ *    and a non-negative X is computed with the quadratic formula. X is
+ *    returned directly when the square root is exact and the division
+ *    leaves no remainder.
+ *  - Otherwise X+1 is returned if the value of the polynomial changes sign
+ *    (or becomes/stops being zero) between X and X+1.
+ *  - None is returned when no such change is found.
+ *
+ * NOTE(review): the early "zero solution" is built with the original
+ * coefficient width, while the later results are computed from the widened
+ * operands; callers presumably must not rely on a fixed result width --
+ * confirm against the declaration in the header.
+ */
+Optional<APInt>
+llvm::APIntOps::SolveQuadraticEquationWrap(APInt A, APInt B, APInt C,
+ unsigned RangeWidth) {
+ unsigned CoeffWidth = A.getBitWidth();
+ assert(CoeffWidth == B.getBitWidth() && CoeffWidth == C.getBitWidth());
+ assert(RangeWidth <= CoeffWidth &&
+ "Value range width should be less than coefficient width");
+ assert(RangeWidth > 1 && "Value range bit width should be > 1");
+
+ LLVM_DEBUG(dbgs() << __func__ << ": solving " << A << "x^2 + " << B
+ << "x + " << C << ", rw:" << RangeWidth << '\n');
+
+ // Identify 0 as a (non)solution immediately.
+ if (C.sextOrTrunc(RangeWidth).isNullValue() ) {
+ LLVM_DEBUG(dbgs() << __func__ << ": zero solution\n");
+ return APInt(CoeffWidth, 0);
+ }
+
+ // The result of APInt arithmetic has the same bit width as the operands,
+ // so it can actually lose high bits. A product of two n-bit integers needs
+ // 2n-1 bits to represent the full value.
+ // The operation done below (on quadratic coefficients) that can produce
+ // the largest value is the evaluation of the equation during bisection,
+ // which needs 3 times the bitwidth of the coefficient, so the total number
+ // of required bits is 3n.
+ //
+ // The purpose of this extension is to simulate the set Z of all integers,
+ // where n+1 > n for all n in Z. In Z it makes sense to talk about positive
+ // and negative numbers (not so much in a modulo arithmetic). The method
+ // used to solve the equation is based on the standard formula for real
+ // numbers, and uses the concepts of "positive" and "negative" with their
+ // usual meanings.
+ CoeffWidth *= 3;
+ A = A.sext(CoeffWidth);
+ B = B.sext(CoeffWidth);
+ C = C.sext(CoeffWidth);
+
+ // Make A > 0 for simplicity. Negate cannot overflow at this point because
+ // the bit width has increased.
+ if (A.isNegative()) {
+ A.negate();
+ B.negate();
+ C.negate();
+ }
+
+ // Solving an equation q(x) = 0 with coefficients in modular arithmetic
+ // is really solving a set of equations q(x) = kR for k = 0, 1, 2, ...,
+ // and R = 2^BitWidth.
+ // Since we're trying not only to find exact solutions, but also values
+ // that "wrap around", such a set will always have a solution, i.e. an x
+ // that satisfies at least one of the equations, or such that |q(x)|
+ // exceeds kR, while |q(x-1)| for the same k does not.
+ //
+ // We need to find a value k, such that Ax^2 + Bx + C = kR will have a
+ // positive solution n (in the above sense), and also such that the n
+ // will be the least among all solutions corresponding to k = 0, 1, ...
+ // (more precisely, the least element in the set
+ // { n(k) | k is such that a solution n(k) exists }).
+ //
+ // Consider the parabola (over real numbers) that corresponds to the
+ // quadratic equation. Since A > 0, the arms of the parabola will point
+ // up. Picking different values of k will shift it up and down by R.
+ //
+ // We want to shift the parabola in such a way as to reduce the problem
+ // of solving q(x) = kR to solving shifted_q(x) = 0.
+ // (The interesting solutions are the ceilings of the real number
+ // solutions.)
+ APInt R = APInt::getOneBitSet(CoeffWidth, RangeWidth);
+ APInt TwoA = 2 * A;
+ APInt SqrB = B * B;
+ bool PickLow;
+
+ // Helper: returns V unchanged when |V| is a multiple of A; otherwise moves
+ // V towards +inf by the amount needed to make it a multiple of A.
+ auto RoundUp = [] (const APInt &V, const APInt &A) -> APInt {
+ assert(A.isStrictlyPositive());
+ APInt T = V.abs().urem(A);
+ if (T.isNullValue())
+ return V;
+ return V.isNegative() ? V+T : V+(A-T);
+ };
+
+ // The vertex of the parabola is at -B/2A, but since A > 0, it's negative
+ // iff B is positive.
+ if (B.isNonNegative()) {
+ // If B >= 0, the vertex is at a negative location (or at 0), so in
+ // order to have a non-negative solution we need to pick k that makes
+ // C-kR negative. To satisfy all the requirements for the solution
+ // that we are looking for, it needs to be closest to 0 of all k.
+ C = C.srem(R);
+ if (C.isStrictlyPositive())
+ C -= R;
+ // Pick the greater solution.
+ PickLow = false;
+ } else {
+ // If B < 0, the vertex is at a positive location. For any solution
+ // to exist, the discriminant must be non-negative. This means that
+ // C-kR <= B^2/4A is a necessary condition for k, i.e. there is a
+ // lower bound on values of k: kR >= C - B^2/4A.
+ APInt LowkR = C - SqrB.udiv(2*TwoA); // udiv because all values > 0.
+ // Round LowkR up (towards +inf) to the nearest kR.
+ LowkR = RoundUp(LowkR, R);
+
+ // If there exists k meeting the condition above, and such that
+ // C-kR > 0, there will be two positive real number solutions of
+ // q(x) = kR. Out of all such values of k, pick the one that makes
+ // C-kR closest to 0, (i.e. pick maximum k such that C-kR > 0).
+ // In other words, find maximum k such that LowkR <= kR < C.
+ if (C.sgt(LowkR)) {
+ // If LowkR < C, then such a k is guaranteed to exist because
+ // LowkR itself is a multiple of R.
+ C -= -RoundUp(-C, R); // C = C - RoundDown(C, R)
+ // Pick the smaller solution.
+ PickLow = true;
+ } else {
+ // If C-kR < 0 for all potential k's, it means that one solution
+ // will be negative, while the other will be positive. The positive
+ // solution will shift towards 0 if the parabola is moved up.
+ // Pick the kR closest to the lower bound (i.e. make C-kR closest
+ // to 0, or in other words, out of all parabolas that have solutions,
+ // pick the one that is the farthest "up").
+ // Since LowkR is itself a multiple of R, simply take C-LowkR.
+ C -= LowkR;
+ // Pick the greater solution.
+ PickLow = false;
+ }
+ }
+
+ LLVM_DEBUG(dbgs() << __func__ << ": updated coefficients " << A << "x^2 + "
+ << B << "x + " << C << ", rw:" << RangeWidth << '\n');
+
+ APInt D = SqrB - 4*A*C;
+ assert(D.isNonNegative() && "Negative discriminant");
+ APInt SQ = D.sqrt();
+
+ APInt Q = SQ * SQ;
+ bool InexactSQ = Q != D;
+ // The calculated SQ may actually be greater than the exact (non-integer)
+ // value. If that's the case, decrement SQ to get a value that is lower.
+ if (Q.sgt(D))
+ SQ -= 1;
+
+ APInt X;
+ APInt Rem;
+
+ // SQ is rounded down (i.e SQ * SQ <= D), so the roots may be inexact.
+ // When using the quadratic formula directly, the calculated low root
+ // may be greater than the exact one, since we would be subtracting SQ.
+ // To make sure that the calculated root is not greater than the exact
+ // one, subtract SQ+1 when calculating the low root (for inexact value
+ // of SQ).
+ if (PickLow)
+ APInt::sdivrem(-B - (SQ+InexactSQ), TwoA, X, Rem);
+ else
+ APInt::sdivrem(-B + SQ, TwoA, X, Rem);
+
+ // The updated coefficients should be such that the (exact) solution is
+ // positive. Since APInt division rounds towards 0, the calculated one
+ // can be 0, but cannot be negative.
+ assert(X.isNonNegative() && "Solution should be non-negative");
+
+ if (!InexactSQ && Rem.isNullValue()) {
+ LLVM_DEBUG(dbgs() << __func__ << ": solution (root): " << X << '\n');
+ return X;
+ }
+
+ assert((SQ*SQ).sle(D) && "SQ = |_sqrt(D)_|, so SQ*SQ <= D");
+ // The exact value of the square root of D should be between SQ and SQ+1.
+ // This implies that the solution should be between that corresponding to
+ // SQ (i.e. X) and that corresponding to SQ+1.
+ //
+ // The calculated X cannot be greater than the exact (real) solution.
+ // Actually it must be strictly less than the exact solution, while
+ // X+1 will be greater than or equal to it.
+
+ APInt VX = (A*X + B)*X + C; // q(X), evaluated in Horner form.
+ APInt VY = VX + TwoA*X + A + B; // q(X+1) = q(X) + 2*A*X + A + B.
+ bool SignChange = VX.isNegative() != VY.isNegative() ||
+ VX.isNullValue() != VY.isNullValue();
+ // If the sign did not change between X and X+1, X is not a valid solution.
+ // This could happen when the actual (exact) roots don't have an integer
+ // between them, so they would both be contained between X and X+1.
+ if (!SignChange) {
+ LLVM_DEBUG(dbgs() << __func__ << ": no valid solution\n");
+ return None;
+ }
+
+ X += 1;
+ LLVM_DEBUG(dbgs() << __func__ << ": solution (wrap): " << X << '\n');
+ return X;
+}