summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineCompares.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineCompares.cpp737
1 files changed, 351 insertions, 386 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index a8faaecb5c34..3bc7fae77cb1 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -17,9 +17,7 @@
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
-#include "llvm/Analysis/MemoryBuiltins.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
-#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
@@ -37,77 +35,30 @@ using namespace PatternMatch;
STATISTIC(NumSel, "Number of select opts");
-static ConstantInt *extractElement(Constant *V, Constant *Idx) {
- return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx));
-}
-
-static bool hasAddOverflow(ConstantInt *Result,
- ConstantInt *In1, ConstantInt *In2,
- bool IsSigned) {
- if (!IsSigned)
- return Result->getValue().ult(In1->getValue());
-
- if (In2->isNegative())
- return Result->getValue().sgt(In1->getValue());
- return Result->getValue().slt(In1->getValue());
-}
-
/// Compute Result = In1+In2, returning true if the result overflowed for this
/// type.
-static bool addWithOverflow(Constant *&Result, Constant *In1,
- Constant *In2, bool IsSigned = false) {
- Result = ConstantExpr::getAdd(In1, In2);
-
- if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
- for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
- Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
- if (hasAddOverflow(extractElement(Result, Idx),
- extractElement(In1, Idx),
- extractElement(In2, Idx),
- IsSigned))
- return true;
- }
- return false;
- }
-
- return hasAddOverflow(cast<ConstantInt>(Result),
- cast<ConstantInt>(In1), cast<ConstantInt>(In2),
- IsSigned);
-}
-
-static bool hasSubOverflow(ConstantInt *Result,
- ConstantInt *In1, ConstantInt *In2,
- bool IsSigned) {
- if (!IsSigned)
- return Result->getValue().ugt(In1->getValue());
-
- if (In2->isNegative())
- return Result->getValue().slt(In1->getValue());
+static bool addWithOverflow(APInt &Result, const APInt &In1,
+ const APInt &In2, bool IsSigned = false) {
+ bool Overflow;
+ if (IsSigned)
+ Result = In1.sadd_ov(In2, Overflow);
+ else
+ Result = In1.uadd_ov(In2, Overflow);
- return Result->getValue().sgt(In1->getValue());
+ return Overflow;
}
/// Compute Result = In1-In2, returning true if the result overflowed for this
/// type.
-static bool subWithOverflow(Constant *&Result, Constant *In1,
- Constant *In2, bool IsSigned = false) {
- Result = ConstantExpr::getSub(In1, In2);
-
- if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) {
- for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) {
- Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i);
- if (hasSubOverflow(extractElement(Result, Idx),
- extractElement(In1, Idx),
- extractElement(In2, Idx),
- IsSigned))
- return true;
- }
- return false;
- }
+static bool subWithOverflow(APInt &Result, const APInt &In1,
+ const APInt &In2, bool IsSigned = false) {
+ bool Overflow;
+ if (IsSigned)
+ Result = In1.ssub_ov(In2, Overflow);
+ else
+ Result = In1.usub_ov(In2, Overflow);
- return hasSubOverflow(cast<ConstantInt>(Result),
- cast<ConstantInt>(In1), cast<ConstantInt>(In2),
- IsSigned);
+ return Overflow;
}
/// Given an icmp instruction, return true if any use of this comparison is a
@@ -473,8 +424,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
// Look for an appropriate type:
// - The type of Idx if the magic fits
- // - The smallest fitting legal type if we have a DataLayout
- // - Default to i32
+ // - The smallest fitting legal type
if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth())
Ty = Idx->getType();
else
@@ -1108,7 +1058,6 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI,
// because we don't allow ptrtoint. Memcpy and memmove are safe because
// we don't allow stores, so src cannot point to V.
case Intrinsic::lifetime_start: case Intrinsic::lifetime_end:
- case Intrinsic::dbg_declare: case Intrinsic::dbg_value:
case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset:
continue;
default:
@@ -1131,8 +1080,7 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI,
}
/// Fold "icmp pred (X+CI), X".
-Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI,
- Value *X, ConstantInt *CI,
+Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI,
ICmpInst::Predicate Pred) {
// From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0,
// so the values can never be equal. Similarly for all other "or equals"
@@ -1367,6 +1315,24 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
}
+// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
+ CmpInst::Predicate Pred = Cmp.getPredicate();
+ Value *X = Cmp.getOperand(0);
+
+ if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) {
+ Value *A, *B;
+ SelectPatternResult SPR = matchSelectPattern(X, A, B);
+ if (SPR.Flavor == SPF_SMIN) {
+ if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
+ return new ICmpInst(Pred, B, Cmp.getOperand(1));
+ if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT))
+ return new ICmpInst(Pred, A, Cmp.getOperand(1));
+ }
+ }
+ return nullptr;
+}
+
// Fold icmp Pred X, C.
Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
CmpInst::Predicate Pred = Cmp.getPredicate();
@@ -1398,17 +1364,6 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
return Res;
}
- // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
- if (C->isNullValue() && Pred == ICmpInst::ICMP_SGT) {
- SelectPatternResult SPR = matchSelectPattern(X, A, B);
- if (SPR.Flavor == SPF_SMIN) {
- if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
- return new ICmpInst(Pred, B, Cmp.getOperand(1));
- if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT))
- return new ICmpInst(Pred, A, Cmp.getOperand(1));
- }
- }
-
// FIXME: Use m_APInt to allow folds for splat constants.
ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1));
if (!CI)
@@ -1462,11 +1417,11 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
/// Fold icmp (trunc X, Y), C.
Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
- Instruction *Trunc,
- const APInt *C) {
+ TruncInst *Trunc,
+ const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Trunc->getOperand(0);
- if (C->isOneValue() && C->getBitWidth() > 1) {
+ if (C.isOneValue() && C.getBitWidth() > 1) {
// icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
Value *V = nullptr;
if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V))))
@@ -1484,7 +1439,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
// If all the high bits are known, we can do this xform.
if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) {
// Pull in the high bits from known-ones set.
- APInt NewRHS = C->zext(SrcBits);
+ APInt NewRHS = C.zext(SrcBits);
NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS));
}
@@ -1496,7 +1451,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp,
/// Fold icmp (xor X, Y), C.
Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
BinaryOperator *Xor,
- const APInt *C) {
+ const APInt &C) {
Value *X = Xor->getOperand(0);
Value *Y = Xor->getOperand(1);
const APInt *XorC;
@@ -1506,8 +1461,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
// If this is a comparison that tests the signbit (X < 0) or (x > -1),
// fold the xor.
ICmpInst::Predicate Pred = Cmp.getPredicate();
- if ((Pred == ICmpInst::ICMP_SLT && C->isNullValue()) ||
- (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) {
+ bool TrueIfSigned = false;
+ if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) {
// If the sign bit of the XorCst is not set, there is no change to
// the operation, just stop using the Xor.
@@ -1517,17 +1472,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
return &Cmp;
}
- // Was the old condition true if the operand is positive?
- bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT;
-
- // If so, the new one isn't.
- isTrueIfPositive ^= true;
-
- Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1));
- if (isTrueIfPositive)
- return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant));
+ // Emit the opposite comparison.
+ if (TrueIfSigned)
+ return new ICmpInst(ICmpInst::ICMP_SGT, X,
+ ConstantInt::getAllOnesValue(X->getType()));
else
- return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant));
+ return new ICmpInst(ICmpInst::ICMP_SLT, X,
+ ConstantInt::getNullValue(X->getType()));
}
if (Xor->hasOneUse()) {
@@ -1535,7 +1486,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
if (!Cmp.isEquality() && XorC->isSignMask()) {
Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate()
: Cmp.getSignedPredicate();
- return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC));
+ return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
}
// (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask))
@@ -1543,18 +1494,18 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate()
: Cmp.getSignedPredicate();
Pred = Cmp.getSwappedPredicate(Pred);
- return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC));
+ return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC));
}
}
// (icmp ugt (xor X, C), ~C) -> (icmp ult X, C)
// iff -C is a power of 2
- if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2())
+ if (Pred == ICmpInst::ICMP_UGT && *XorC == ~C && (C + 1).isPowerOf2())
return new ICmpInst(ICmpInst::ICMP_ULT, X, Y);
// (icmp ult (xor X, C), -C) -> (icmp uge X, C)
// iff -C is a power of 2
- if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2())
+ if (Pred == ICmpInst::ICMP_ULT && *XorC == -C && C.isPowerOf2())
return new ICmpInst(ICmpInst::ICMP_UGE, X, Y);
return nullptr;
@@ -1562,7 +1513,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
/// Fold icmp (and (sh X, Y), C2), C1.
Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
- const APInt *C1, const APInt *C2) {
+ const APInt &C1, const APInt &C2) {
BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0));
if (!Shift || !Shift->isShift())
return nullptr;
@@ -1577,32 +1528,35 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
const APInt *C3;
if (match(Shift->getOperand(1), m_APInt(C3))) {
bool CanFold = false;
- if (ShiftOpcode == Instruction::AShr) {
- // There may be some constraints that make this possible, but nothing
- // simple has been discovered yet.
- CanFold = false;
- } else if (ShiftOpcode == Instruction::Shl) {
+ if (ShiftOpcode == Instruction::Shl) {
// For a left shift, we can fold if the comparison is not signed. We can
// also fold a signed comparison if the mask value and comparison value
// are not negative. These constraints may not be obvious, but we can
// prove that they are correct using an SMT solver.
- if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative()))
+ if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative()))
CanFold = true;
- } else if (ShiftOpcode == Instruction::LShr) {
+ } else {
+ bool IsAshr = ShiftOpcode == Instruction::AShr;
// For a logical right shift, we can fold if the comparison is not signed.
// We can also fold a signed comparison if the shifted mask value and the
// shifted comparison value are not negative. These constraints may not be
// obvious, but we can prove that they are correct using an SMT solver.
- if (!Cmp.isSigned() ||
- (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative()))
- CanFold = true;
+ // For an arithmetic shift right we can do the same, if we ensure
+ // the And doesn't use any bits being shifted in. Normally these would
+ // be turned into lshr by SimplifyDemandedBits, but not if there is an
+ // additional user.
+ if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) {
+ if (!Cmp.isSigned() ||
+ (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative()))
+ CanFold = true;
+ }
}
if (CanFold) {
- APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3);
+ APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3);
APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3);
// Check to see if we are shifting out any of the bits being compared.
- if (SameAsC1 != *C1) {
+ if (SameAsC1 != C1) {
// If we shifted bits out, the fold is not going to work out. As a
// special case, check to see if this means that the result is always
// true or false now.
@@ -1612,7 +1566,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
} else {
Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst));
- APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3);
+ APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3);
And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst));
And->setOperand(0, Shift->getOperand(0));
Worklist.Add(Shift); // Shift is dead.
@@ -1624,7 +1578,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
// Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is
// preferable because it allows the C2 << Y expression to be hoisted out of a
// loop if Y is invariant and X is not.
- if (Shift->hasOneUse() && C1->isNullValue() && Cmp.isEquality() &&
+ if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() &&
!Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) {
// Compute C2 << Y.
Value *NewShift =
@@ -1643,12 +1597,12 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
/// Fold icmp (and X, C2), C1.
Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
BinaryOperator *And,
- const APInt *C1) {
+ const APInt &C1) {
const APInt *C2;
if (!match(And->getOperand(1), m_APInt(C2)))
return nullptr;
- if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse())
+ if (!And->hasOneUse())
return nullptr;
// If the LHS is an 'and' of a truncate and we can widen the and/compare to
@@ -1660,29 +1614,29 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
// set or if it is an equality comparison. Extending a relational comparison
// when we're checking the sign bit would not work.
Value *W;
- if (match(And->getOperand(0), m_Trunc(m_Value(W))) &&
- (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {
+ if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) &&
+ (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) {
// TODO: Is this a good transform for vectors? Wider types may reduce
// throughput. Should this transform be limited (even for scalars) by using
// shouldChangeType()?
if (!Cmp.getType()->isVectorTy()) {
Type *WideType = W->getType();
unsigned WideScalarBits = WideType->getScalarSizeInBits();
- Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits));
+ Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits));
Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits));
Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName());
return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1);
}
}
- if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2))
+ if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2))
return I;
// (icmp pred (and (or (lshr A, B), A), 1), 0) -->
// (icmp pred (and A, (or (shl 1, B), 1), 0))
//
// iff pred isn't signed
- if (!Cmp.isSigned() && C1->isNullValue() &&
+ if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() &&
match(And->getOperand(1), m_One())) {
Constant *One = cast<Constant>(And->getOperand(1));
Value *Or = And->getOperand(0);
@@ -1716,22 +1670,13 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
}
}
- // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a
- // result greater than C1.
- unsigned NumTZ = C2->countTrailingZeros();
- if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() &&
- APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) {
- Constant *Zero = Constant::getNullValue(And->getType());
- return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
- }
-
return nullptr;
}
/// Fold icmp (and X, Y), C.
Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
BinaryOperator *And,
- const APInt *C) {
+ const APInt &C) {
if (Instruction *I = foldICmpAndConstConst(Cmp, And, C))
return I;
@@ -1756,7 +1701,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
// X & -C == -C -> X > u ~C
// X & -C != -C -> X <= u ~C
// iff C is a power of 2
- if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) {
+ if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) {
auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT
: CmpInst::ICMP_ULE;
return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1))));
@@ -1766,7 +1711,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
// (X & C2) != 0 -> (trunc X) < 0
// iff C2 is a power of 2 and it masks the sign bit of a legal integer type.
const APInt *C2;
- if (And->hasOneUse() && C->isNullValue() && match(Y, m_APInt(C2))) {
+ if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) {
int32_t ExactLogBase2 = C2->exactLogBase2();
if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
@@ -1784,9 +1729,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
/// Fold icmp (or X, Y), C.
Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
- const APInt *C) {
+ const APInt &C) {
ICmpInst::Predicate Pred = Cmp.getPredicate();
- if (C->isOneValue()) {
+ if (C.isOneValue()) {
// icmp slt signum(V) 1 --> icmp slt V, 1
Value *V = nullptr;
if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V))))
@@ -1798,12 +1743,12 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
// X | C != C --> X >u C
// iff C+1 is a power of 2 (C is a bitmask of the low bits)
if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
- (*C + 1).isPowerOf2()) {
+ (C + 1).isPowerOf2()) {
Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
}
- if (!Cmp.isEquality() || !C->isNullValue() || !Or->hasOneUse())
+ if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse())
return nullptr;
Value *P, *Q;
@@ -1837,7 +1782,7 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
/// Fold icmp (mul X, Y), C.
Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
BinaryOperator *Mul,
- const APInt *C) {
+ const APInt &C) {
const APInt *MulC;
if (!match(Mul->getOperand(1), m_APInt(MulC)))
return nullptr;
@@ -1845,7 +1790,7 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
// If this is a test of the sign bit and the multiply is sign-preserving with
// a constant operand, use the multiply LHS operand instead.
ICmpInst::Predicate Pred = Cmp.getPredicate();
- if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) {
+ if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) {
if (MulC->isNegative())
Pred = ICmpInst::getSwappedPredicate(Pred);
return new ICmpInst(Pred, Mul->getOperand(0),
@@ -1857,14 +1802,14 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp,
/// Fold icmp (shl 1, Y), C.
static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
- const APInt *C) {
+ const APInt &C) {
Value *Y;
if (!match(Shl, m_Shl(m_One(), m_Value(Y))))
return nullptr;
Type *ShiftType = Shl->getType();
- uint32_t TypeBits = C->getBitWidth();
- bool CIsPowerOf2 = C->isPowerOf2();
+ unsigned TypeBits = C.getBitWidth();
+ bool CIsPowerOf2 = C.isPowerOf2();
ICmpInst::Predicate Pred = Cmp.getPredicate();
if (Cmp.isUnsigned()) {
// (1 << Y) pred C -> Y pred Log2(C)
@@ -1881,7 +1826,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
// (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31
// (1 << Y) < 2147483648 -> Y < 31 -> Y != 31
- unsigned CLog2 = C->logBase2();
+ unsigned CLog2 = C.logBase2();
if (CLog2 == TypeBits - 1) {
if (Pred == ICmpInst::ICMP_UGE)
Pred = ICmpInst::ICMP_EQ;
@@ -1891,7 +1836,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
} else if (Cmp.isSigned()) {
Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
- if (C->isAllOnesValue()) {
+ if (C.isAllOnesValue()) {
// (1 << Y) <= -1 -> Y == 31
if (Pred == ICmpInst::ICMP_SLE)
return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
@@ -1899,7 +1844,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
// (1 << Y) > -1 -> Y != 31
if (Pred == ICmpInst::ICMP_SGT)
return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
- } else if (!(*C)) {
+ } else if (!C) {
// (1 << Y) < 0 -> Y == 31
// (1 << Y) <= 0 -> Y == 31
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
@@ -1911,7 +1856,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
}
} else if (Cmp.isEquality() && CIsPowerOf2) {
- return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2()));
+ return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2()));
}
return nullptr;
@@ -1920,10 +1865,10 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
/// Fold icmp (shl X, Y), C.
Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
BinaryOperator *Shl,
- const APInt *C) {
+ const APInt &C) {
const APInt *ShiftVal;
if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal)))
- return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal);
+ return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal);
const APInt *ShiftAmt;
if (!match(Shl->getOperand(1), m_APInt(ShiftAmt)))
@@ -1931,7 +1876,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
// Check that the shift amount is in range. If not, don't perform undefined
// shifts. When the shift is visited, it will be simplified.
- unsigned TypeBits = C->getBitWidth();
+ unsigned TypeBits = C.getBitWidth();
if (ShiftAmt->uge(TypeBits))
return nullptr;
@@ -1945,15 +1890,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasNoSignedWrap()) {
if (Pred == ICmpInst::ICMP_SGT) {
// icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt)
- APInt ShiftedC = C->ashr(*ShiftAmt);
+ APInt ShiftedC = C.ashr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
// This is the same code as the SGT case, but assert the pre-condition
// that is needed for this to work with equality predicates.
- assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C &&
"Compare known true or false was not folded");
- APInt ShiftedC = C->ashr(*ShiftAmt);
+ APInt ShiftedC = C.ashr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
if (Pred == ICmpInst::ICMP_SLT) {
@@ -1961,14 +1906,14 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
// (X << S) <=s C is equiv to X <=s (C >> S) for all C
// (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX
// (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN
- assert(!C->isMinSignedValue() && "Unexpected icmp slt");
- APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1;
+ assert(!C.isMinSignedValue() && "Unexpected icmp slt");
+ APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1;
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
// If this is a signed comparison to 0 and the shift is sign preserving,
// use the shift LHS operand instead; isSignTest may change 'Pred', so only
// do that if we're sure to not continue on in this function.
- if (isSignTest(Pred, *C))
+ if (isSignTest(Pred, C))
return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
}
@@ -1978,15 +1923,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasNoUnsignedWrap()) {
if (Pred == ICmpInst::ICMP_UGT) {
// icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt)
- APInt ShiftedC = C->lshr(*ShiftAmt);
+ APInt ShiftedC = C.lshr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
// This is the same code as the UGT case, but assert the pre-condition
// that is needed for this to work with equality predicates.
- assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C &&
"Compare known true or false was not folded");
- APInt ShiftedC = C->lshr(*ShiftAmt);
+ APInt ShiftedC = C.lshr(*ShiftAmt);
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
if (Pred == ICmpInst::ICMP_ULT) {
@@ -1994,8 +1939,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
// (X << S) <=u C is equiv to X <=u (C >> S) for all C
// (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
// (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
- assert(C->ugt(0) && "ult 0 should have been eliminated");
- APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1;
+ assert(C.ugt(0) && "ult 0 should have been eliminated");
+ APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1;
return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
}
@@ -2006,13 +1951,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
ShType,
APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask");
- Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt));
+ Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt));
return new ICmpInst(Pred, And, LShrC);
}
// Otherwise, if this is a comparison of the sign bit, simplify to and/test.
bool TrueIfSigned = false;
- if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) {
+ if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) {
// (X << 31) <s 0 --> (X & 1) != 0
Constant *Mask = ConstantInt::get(
ShType,
@@ -2029,13 +1974,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
// free on the target. It has the additional benefit of comparing to a
// smaller constant that may be more target-friendly.
unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1);
- if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt &&
+ if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
if (ShType->isVectorTy())
TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
Constant *NewC =
- ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt));
+ ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
}
@@ -2045,18 +1990,18 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
/// Fold icmp ({al}shr X, Y), C.
Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
BinaryOperator *Shr,
- const APInt *C) {
+ const APInt &C) {
// An exact shr only shifts out zero bits, so:
// icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0
Value *X = Shr->getOperand(0);
CmpInst::Predicate Pred = Cmp.getPredicate();
if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() &&
- C->isNullValue())
+ C.isNullValue())
return new ICmpInst(Pred, X, Cmp.getOperand(1));
const APInt *ShiftVal;
if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal)))
- return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal);
+ return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal);
const APInt *ShiftAmt;
if (!match(Shr->getOperand(1), m_APInt(ShiftAmt)))
@@ -2064,71 +2009,73 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
// Check that the shift amount is in range. If not, don't perform undefined
// shifts. When the shift is visited it will be simplified.
- unsigned TypeBits = C->getBitWidth();
+ unsigned TypeBits = C.getBitWidth();
unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits);
if (ShAmtVal >= TypeBits || ShAmtVal == 0)
return nullptr;
bool IsAShr = Shr->getOpcode() == Instruction::AShr;
- if (!Cmp.isEquality()) {
- // If we have an unsigned comparison and an ashr, we can't simplify this.
- // Similarly for signed comparisons with lshr.
- if (Cmp.isSigned() != IsAShr)
- return nullptr;
-
- // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv
- // by a power of 2. Since we already have logic to simplify these,
- // transform to div and then simplify the resultant comparison.
- if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1))
- return nullptr;
-
- // Revisit the shift (to delete it).
- Worklist.Add(Shr);
-
- Constant *DivCst = ConstantInt::get(
- Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal));
-
- Value *Tmp = IsAShr ? Builder.CreateSDiv(X, DivCst, "", Shr->isExact())
- : Builder.CreateUDiv(X, DivCst, "", Shr->isExact());
-
- Cmp.setOperand(0, Tmp);
-
- // If the builder folded the binop, just return it.
- BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp);
- if (!TheDiv)
- return &Cmp;
-
- // Otherwise, fold this div/compare.
- assert(TheDiv->getOpcode() == Instruction::SDiv ||
- TheDiv->getOpcode() == Instruction::UDiv);
-
- Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C);
- assert(Res && "This div/cst should have folded!");
- return Res;
+ bool IsExact = Shr->isExact();
+ Type *ShrTy = Shr->getType();
+ // TODO: If we could guarantee that InstSimplify would handle all of the
+ // constant-value-based preconditions in the folds below, then we could assert
+ // those conditions rather than checking them. This is difficult because of
+ // undef/poison (PR34838).
+ if (IsAShr) {
+ if (Pred == CmpInst::ICMP_SLT || (Pred == CmpInst::ICMP_SGT && IsExact)) {
+ // icmp slt (ashr X, ShAmtC), C --> icmp slt X, (C << ShAmtC)
+ // icmp sgt (ashr exact X, ShAmtC), C --> icmp sgt X, (C << ShAmtC)
+ APInt ShiftedC = C.shl(ShAmtVal);
+ if (ShiftedC.ashr(ShAmtVal) == C)
+ return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
+ }
+ if (Pred == CmpInst::ICMP_SGT) {
+ // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1
+ APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
+ if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() &&
+ (ShiftedC + 1).ashr(ShAmtVal) == (C + 1))
+ return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
+ }
+ } else {
+ if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) {
+ // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC)
+ // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC)
+ APInt ShiftedC = C.shl(ShAmtVal);
+ if (ShiftedC.lshr(ShAmtVal) == C)
+ return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
+ }
+ if (Pred == CmpInst::ICMP_UGT) {
+ // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1
+ APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1;
+ if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1))
+ return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC));
+ }
}
+ if (!Cmp.isEquality())
+ return nullptr;
+
// Handle equality comparisons of shift-by-constant.
// If the comparison constant changes with the shift, the comparison cannot
// succeed (bits of the comparison constant cannot match the shifted value).
// This should be known by InstSimplify and already be folded to true/false.
- assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) ||
- (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) &&
+ assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) ||
+ (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) &&
"Expected icmp+shr simplify did not occur.");
- // Check if the bits shifted out are known to be zero. If so, we can compare
- // against the unshifted value:
+ // If the bits shifted out are known zero, compare the unshifted value:
// (X & 4) >> 1 == 2 --> (X & 4) == 4.
- Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal);
- if (Shr->hasOneUse()) {
- if (Shr->isExact())
- return new ICmpInst(Pred, X, ShiftedCmpRHS);
+ if (Shr->isExact())
+ return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal));
- // Otherwise strength reduce the shift into an 'and'.
+ if (Shr->hasOneUse()) {
+ // Canonicalize the shift into an 'and':
+ // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt)
APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal));
- Constant *Mask = ConstantInt::get(Shr->getType(), Val);
+ Constant *Mask = ConstantInt::get(ShrTy, Val);
Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask");
- return new ICmpInst(Pred, And, ShiftedCmpRHS);
+ return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal));
}
return nullptr;
@@ -2137,7 +2084,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
/// Fold icmp (udiv X, Y), C.
Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
BinaryOperator *UDiv,
- const APInt *C) {
+ const APInt &C) {
const APInt *C2;
if (!match(UDiv->getOperand(0), m_APInt(C2)))
return nullptr;
@@ -2147,17 +2094,17 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
// (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1))
Value *Y = UDiv->getOperand(1);
if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) {
- assert(!C->isMaxValue() &&
+ assert(!C.isMaxValue() &&
"icmp ugt X, UINT_MAX should have been simplified already.");
return new ICmpInst(ICmpInst::ICMP_ULE, Y,
- ConstantInt::get(Y->getType(), C2->udiv(*C + 1)));
+ ConstantInt::get(Y->getType(), C2->udiv(C + 1)));
}
// (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C)
if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) {
- assert(*C != 0 && "icmp ult X, 0 should have been simplified already.");
+ assert(C != 0 && "icmp ult X, 0 should have been simplified already.");
return new ICmpInst(ICmpInst::ICMP_UGT, Y,
- ConstantInt::get(Y->getType(), C2->udiv(*C)));
+ ConstantInt::get(Y->getType(), C2->udiv(C)));
}
return nullptr;
@@ -2166,7 +2113,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
/// Fold icmp ({su}div X, Y), C.
Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
BinaryOperator *Div,
- const APInt *C) {
+ const APInt &C) {
// Fold: icmp pred ([us]div X, C2), C -> range test
// Fold this div into the comparison, producing a range check.
// Determine, based on the divide type, what the range is being
@@ -2197,28 +2144,22 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
(DivIsSigned && C2->isAllOnesValue()))
return nullptr;
- // TODO: We could do all of the computations below using APInt.
- Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1));
- Constant *DivRHS = cast<Constant>(Div->getOperand(1));
-
- // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of
- // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS).
+ // Compute Prod = C * C2. We are essentially solving an equation of
+ // form X / C2 = C. We solve for X by multiplying C2 and C.
// By solving for X, we can turn this into a range check instead of computing
// a divide.
- Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS);
+ APInt Prod = C * *C2;
// Determine if the product overflows by seeing if the product is not equal to
// the divide. Make sure we do the same kind of divide as in the LHS
// instruction that we're folding.
- bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS)
- : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS;
+ bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C;
ICmpInst::Predicate Pred = Cmp.getPredicate();
// If the division is known to be exact, then there is no remainder from the
// divide, so the covered range size is unit, otherwise it is the divisor.
- Constant *RangeSize =
- Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS;
+ APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2;
// Figure out the interval that is being checked. For example, a comparison
// like "X /u 5 == 0" is really checking that X is in the interval [0, 5).
@@ -2228,7 +2169,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
// overflow variable is set to 0 if it's corresponding bound variable is valid
// -1 if overflowed off the bottom end, or +1 if overflowed off the top end.
int LoOverflow = 0, HiOverflow = 0;
- Constant *LoBound = nullptr, *HiBound = nullptr;
+ APInt LoBound, HiBound;
if (!DivIsSigned) { // udiv
// e.g. X/5 op 3 --> [15, 20)
@@ -2240,38 +2181,38 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false);
}
} else if (C2->isStrictlyPositive()) { // Divisor is > 0.
- if (C->isNullValue()) { // (X / pos) op 0
+ if (C.isNullValue()) { // (X / pos) op 0
// Can't overflow. e.g. X/2 op 0 --> [-1, 2)
- LoBound = ConstantExpr::getNeg(SubOne(RangeSize));
+ LoBound = -(RangeSize - 1);
HiBound = RangeSize;
- } else if (C->isStrictlyPositive()) { // (X / pos) op pos
+ } else if (C.isStrictlyPositive()) { // (X / pos) op pos
LoBound = Prod; // e.g. X/5 op 3 --> [15, 20)
HiOverflow = LoOverflow = ProdOV;
if (!HiOverflow)
HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
} else { // (X / pos) op neg
// e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
- HiBound = AddOne(Prod);
+ HiBound = Prod + 1;
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
if (!LoOverflow) {
- Constant *DivNeg = ConstantExpr::getNeg(RangeSize);
+ APInt DivNeg = -RangeSize;
LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0;
}
}
} else if (C2->isNegative()) { // Divisor is < 0.
if (Div->isExact())
- RangeSize = ConstantExpr::getNeg(RangeSize);
- if (C->isNullValue()) { // (X / neg) op 0
+ RangeSize.negate();
+ if (C.isNullValue()) { // (X / neg) op 0
// e.g. X/-5 op 0 --> [-4, 5)
- LoBound = AddOne(RangeSize);
- HiBound = ConstantExpr::getNeg(RangeSize);
- if (HiBound == DivRHS) { // -INTMIN = INTMIN
+ LoBound = RangeSize + 1;
+ HiBound = -RangeSize;
+ if (HiBound == *C2) { // -INTMIN = INTMIN
HiOverflow = 1; // [INTMIN+1, overflow)
- HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN
+ HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN
}
- } else if (C->isStrictlyPositive()) { // (X / neg) op pos
+ } else if (C.isStrictlyPositive()) { // (X / neg) op pos
// e.g. X/-5 op 3 --> [-19, -14)
- HiBound = AddOne(Prod);
+ HiBound = Prod + 1;
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
@@ -2294,25 +2235,27 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
return replaceInstUsesWith(Cmp, Builder.getFalse());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
- ICmpInst::ICMP_UGE, X, LoBound);
+ ICmpInst::ICMP_UGE, X,
+ ConstantInt::get(Div->getType(), LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
- ICmpInst::ICMP_ULT, X, HiBound);
+ ICmpInst::ICMP_ULT, X,
+ ConstantInt::get(Div->getType(), HiBound));
return replaceInstUsesWith(
- Cmp, insertRangeTest(X, LoBound->getUniqueInteger(),
- HiBound->getUniqueInteger(), DivIsSigned, true));
+ Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true));
case ICmpInst::ICMP_NE:
if (LoOverflow && HiOverflow)
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (HiOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT :
- ICmpInst::ICMP_ULT, X, LoBound);
+ ICmpInst::ICMP_ULT, X,
+ ConstantInt::get(Div->getType(), LoBound));
if (LoOverflow)
return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE :
- ICmpInst::ICMP_UGE, X, HiBound);
+ ICmpInst::ICMP_UGE, X,
+ ConstantInt::get(Div->getType(), HiBound));
return replaceInstUsesWith(Cmp,
- insertRangeTest(X, LoBound->getUniqueInteger(),
- HiBound->getUniqueInteger(),
+ insertRangeTest(X, LoBound, HiBound,
DivIsSigned, false));
case ICmpInst::ICMP_ULT:
case ICmpInst::ICMP_SLT:
@@ -2320,7 +2263,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (LoOverflow == -1) // Low bound is less than input range.
return replaceInstUsesWith(Cmp, Builder.getFalse());
- return new ICmpInst(Pred, X, LoBound);
+ return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound));
case ICmpInst::ICMP_UGT:
case ICmpInst::ICMP_SGT:
if (HiOverflow == +1) // High bound greater than input range.
@@ -2328,8 +2271,10 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
if (HiOverflow == -1) // High bound less than input range.
return replaceInstUsesWith(Cmp, Builder.getTrue());
if (Pred == ICmpInst::ICMP_UGT)
- return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound);
- return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound);
+ return new ICmpInst(ICmpInst::ICMP_UGE, X,
+ ConstantInt::get(Div->getType(), HiBound));
+ return new ICmpInst(ICmpInst::ICMP_SGE, X,
+ ConstantInt::get(Div->getType(), HiBound));
}
return nullptr;
@@ -2338,7 +2283,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp,
/// Fold icmp (sub X, Y), C.
Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
BinaryOperator *Sub,
- const APInt *C) {
+ const APInt &C) {
Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
ICmpInst::Predicate Pred = Cmp.getPredicate();
@@ -2349,19 +2294,19 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
if (Sub->hasNoSignedWrap()) {
// (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y)
- if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())
+ if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue())
return new ICmpInst(ICmpInst::ICMP_SGE, X, Y);
// (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y)
- if (Pred == ICmpInst::ICMP_SGT && C->isNullValue())
+ if (Pred == ICmpInst::ICMP_SGT && C.isNullValue())
return new ICmpInst(ICmpInst::ICMP_SGT, X, Y);
// (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y)
- if (Pred == ICmpInst::ICMP_SLT && C->isNullValue())
+ if (Pred == ICmpInst::ICMP_SLT && C.isNullValue())
return new ICmpInst(ICmpInst::ICMP_SLT, X, Y);
// (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y)
- if (Pred == ICmpInst::ICMP_SLT && C->isOneValue())
+ if (Pred == ICmpInst::ICMP_SLT && C.isOneValue())
return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
}
@@ -2371,14 +2316,14 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
// C2 - Y <u C -> (Y | (C - 1)) == C2
// iff (C2 & (C - 1)) == C - 1 and C is a power of 2
- if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() &&
- (*C2 & (*C - 1)) == (*C - 1))
- return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, *C - 1), X);
+ if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() &&
+ (*C2 & (C - 1)) == (C - 1))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X);
// C2 - Y >u C -> (Y | C) != C2
// iff C2 & C == C and C + 1 is a power of 2
- if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C)
- return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, *C), X);
+ if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C)
+ return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X);
return nullptr;
}
@@ -2386,7 +2331,7 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
/// Fold icmp (add X, Y), C.
Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
BinaryOperator *Add,
- const APInt *C) {
+ const APInt &C) {
Value *Y = Add->getOperand(1);
const APInt *C2;
if (Cmp.isEquality() || !match(Y, m_APInt(C2)))
@@ -2403,7 +2348,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
if (Add->hasNoSignedWrap() &&
(Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) {
bool Overflow;
- APInt NewC = C->ssub_ov(*C2, Overflow);
+ APInt NewC = C.ssub_ov(*C2, Overflow);
// If there is overflow, the result must be true or false.
// TODO: Can we assert there is no overflow because InstSimplify always
// handles those cases?
@@ -2412,7 +2357,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
}
- auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2);
+ auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2);
const APInt &Upper = CR.getUpper();
const APInt &Lower = CR.getLower();
if (Cmp.isSigned()) {
@@ -2433,15 +2378,15 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
// X+C <u C2 -> (X & -C2) == C
// iff C & (C2-1) == 0
// C2 is a power of 2
- if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0)
- return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -(*C)),
+ if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C),
ConstantExpr::getNeg(cast<Constant>(Y)));
// X+C >u C2 -> (X & ~C2) != C
// iff C & C2 == 0
// C2+1 is a power of 2
- if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0)
- return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~(*C)),
+ if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0)
+ return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C),
ConstantExpr::getNeg(cast<Constant>(Y)));
return nullptr;
@@ -2471,7 +2416,7 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
}
Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
- Instruction *Select,
+ SelectInst *Select,
ConstantInt *C) {
assert(C && "Cmp RHS should be a constant int!");
@@ -2483,8 +2428,8 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
Value *OrigLHS, *OrigRHS;
ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan;
if (Cmp.hasOneUse() &&
- matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS,
- C1LessThan, C2Equal, C3GreaterThan)) {
+ matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal,
+ C3GreaterThan)) {
assert(C1LessThan && C2Equal && C3GreaterThan);
bool TrueWhenLessThan =
@@ -2525,82 +2470,74 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
if (!match(Cmp.getOperand(1), m_APInt(C)))
return nullptr;
- BinaryOperator *BO;
- if (match(Cmp.getOperand(0), m_BinOp(BO))) {
+ if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) {
switch (BO->getOpcode()) {
case Instruction::Xor:
- if (Instruction *I = foldICmpXorConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpXorConstant(Cmp, BO, *C))
return I;
break;
case Instruction::And:
- if (Instruction *I = foldICmpAndConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpAndConstant(Cmp, BO, *C))
return I;
break;
case Instruction::Or:
- if (Instruction *I = foldICmpOrConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpOrConstant(Cmp, BO, *C))
return I;
break;
case Instruction::Mul:
- if (Instruction *I = foldICmpMulConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpMulConstant(Cmp, BO, *C))
return I;
break;
case Instruction::Shl:
- if (Instruction *I = foldICmpShlConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpShlConstant(Cmp, BO, *C))
return I;
break;
case Instruction::LShr:
case Instruction::AShr:
- if (Instruction *I = foldICmpShrConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C))
return I;
break;
case Instruction::UDiv:
- if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C))
return I;
LLVM_FALLTHROUGH;
case Instruction::SDiv:
- if (Instruction *I = foldICmpDivConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpDivConstant(Cmp, BO, *C))
return I;
break;
case Instruction::Sub:
- if (Instruction *I = foldICmpSubConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpSubConstant(Cmp, BO, *C))
return I;
break;
case Instruction::Add:
- if (Instruction *I = foldICmpAddConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpAddConstant(Cmp, BO, *C))
return I;
break;
default:
break;
}
// TODO: These folds could be refactored to be part of the above calls.
- if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C))
+ if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, *C))
return I;
}
// Match against CmpInst LHS being instructions other than binary operators.
- Instruction *LHSI;
- if (match(Cmp.getOperand(0), m_Instruction(LHSI))) {
- switch (LHSI->getOpcode()) {
- case Instruction::Select:
- {
- // For now, we only support constant integers while folding the
- // ICMP(SELECT)) pattern. We can extend this to support vector of integers
- // similar to the cases handled by binary ops above.
- if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
- if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS))
- return I;
- break;
- }
- case Instruction::Trunc:
- if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C))
+
+ if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) {
+ // For now, we only support constant integers while folding the
+ // ICMP(SELECT)) pattern. We can extend this to support vector of integers
+ // similar to the cases handled by binary ops above.
+ if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1)))
+ if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS))
return I;
- break;
- default:
- break;
- }
}
- if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C))
+ if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) {
+ if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C))
+ return I;
+ }
+
+ if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C))
return I;
return nullptr;
@@ -2610,7 +2547,7 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
/// icmp eq/ne BO, C.
Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
BinaryOperator *BO,
- const APInt *C) {
+ const APInt &C) {
// TODO: Some of these folds could work with arbitrary constants, but this
// function is limited to scalar and vector splat constants.
if (!Cmp.isEquality())
@@ -2624,7 +2561,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
switch (BO->getOpcode()) {
case Instruction::SRem:
// If we have a signed (X % (2^c)) == 0, turn it into an unsigned one.
- if (C->isNullValue() && BO->hasOneUse()) {
+ if (C.isNullValue() && BO->hasOneUse()) {
const APInt *BOC;
if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) {
Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName());
@@ -2641,7 +2578,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1));
return new ICmpInst(Pred, BOp0, SubC);
}
- } else if (C->isNullValue()) {
+ } else if (C.isNullValue()) {
// Replace ((add A, B) != 0) with (A != -B) if A or B is
// efficiently invertible, or if the add has just this one use.
if (Value *NegVal = dyn_castNegVal(BOp1))
@@ -2662,7 +2599,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
// For the xor case, we can xor two constants together, eliminating
// the explicit xor.
return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC));
- } else if (C->isNullValue()) {
+ } else if (C.isNullValue()) {
// Replace ((xor A, B) != 0) with (A != B)
return new ICmpInst(Pred, BOp0, BOp1);
}
@@ -2675,7 +2612,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
// Replace ((sub BOC, B) != C) with (B != BOC-C).
Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS);
return new ICmpInst(Pred, BOp1, SubC);
- } else if (C->isNullValue()) {
+ } else if (C.isNullValue()) {
// Replace ((sub A, B) != 0) with (A != B).
return new ICmpInst(Pred, BOp0, BOp1);
}
@@ -2697,7 +2634,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
const APInt *BOC;
if (match(BOp1, m_APInt(BOC))) {
// If we have ((X & C) == C), turn it into ((X & C) != 0).
- if (C == BOC && C->isPowerOf2())
+ if (C == *BOC && C.isPowerOf2())
return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
BO, Constant::getNullValue(RHS->getType()));
@@ -2713,7 +2650,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
}
// ((X & ~7) == 0) --> X < 8
- if (C->isNullValue() && (~(*BOC) + 1).isPowerOf2()) {
+ if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) {
Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1));
auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
return new ICmpInst(NewPred, BOp0, NegBOC);
@@ -2722,7 +2659,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
break;
}
case Instruction::Mul:
- if (C->isNullValue() && BO->hasNoSignedWrap()) {
+ if (C.isNullValue() && BO->hasNoSignedWrap()) {
const APInt *BOC;
if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) {
// The trivial case (mul X, 0) is handled by InstSimplify.
@@ -2733,7 +2670,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
}
break;
case Instruction::UDiv:
- if (C->isNullValue()) {
+ if (C.isNullValue()) {
// (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A)
auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
return new ICmpInst(NewPred, BOp1, BOp0);
@@ -2747,7 +2684,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
- const APInt *C) {
+ const APInt &C) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0));
if (!II || !Cmp.isEquality())
return nullptr;
@@ -2758,13 +2695,13 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
case Intrinsic::bswap:
Worklist.Add(II);
Cmp.setOperand(0, II->getArgOperand(0));
- Cmp.setOperand(1, ConstantInt::get(Ty, C->byteSwap()));
+ Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap()));
return &Cmp;
case Intrinsic::ctlz:
case Intrinsic::cttz:
// ctz(A) == bitwidth(A) -> A == 0 and likewise for !=
- if (*C == C->getBitWidth()) {
+ if (C == C.getBitWidth()) {
Worklist.Add(II);
Cmp.setOperand(0, II->getArgOperand(0));
Cmp.setOperand(1, ConstantInt::getNullValue(Ty));
@@ -2775,8 +2712,8 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
case Intrinsic::ctpop: {
// popcount(A) == 0 -> A == 0 and likewise for !=
// popcount(A) == bitwidth(A) -> A == -1 and likewise for !=
- bool IsZero = C->isNullValue();
- if (IsZero || *C == C->getBitWidth()) {
+ bool IsZero = C.isNullValue();
+ if (IsZero || C == C.getBitWidth()) {
Worklist.Add(II);
Cmp.setOperand(0, II->getArgOperand(0));
auto *NewOp =
@@ -3924,31 +3861,29 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
/// When performing a comparison against a constant, it is possible that not all
/// the bits in the LHS are demanded. This helper method computes the mask that
/// IS demanded.
-static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,
- bool isSignCheck) {
- if (isSignCheck)
- return APInt::getSignMask(BitWidth);
+static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
+ const APInt *RHS;
+ if (!match(I.getOperand(1), m_APInt(RHS)))
+ return APInt::getAllOnesValue(BitWidth);
- ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1));
- if (!CI) return APInt::getAllOnesValue(BitWidth);
- const APInt &RHS = CI->getValue();
+ // If this is a normal comparison, it demands all bits. If it is a sign bit
+ // comparison, it only demands the sign bit.
+ bool UnusedBit;
+ if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit))
+ return APInt::getSignMask(BitWidth);
switch (I.getPredicate()) {
// For a UGT comparison, we don't care about any bits that
// correspond to the trailing ones of the comparand. The value of these
// bits doesn't impact the outcome of the comparison, because any value
// greater than the RHS must differ in a bit higher than these due to carry.
- case ICmpInst::ICMP_UGT: {
- unsigned trailingOnes = RHS.countTrailingOnes();
- return APInt::getBitsSetFrom(BitWidth, trailingOnes);
- }
+ case ICmpInst::ICMP_UGT:
+ return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes());
// Similarly, for a ULT comparison, we don't care about the trailing zeros.
// Any value less than the RHS must differ in a higher bit because of carries.
- case ICmpInst::ICMP_ULT: {
- unsigned trailingZeros = RHS.countTrailingZeros();
- return APInt::getBitsSetFrom(BitWidth, trailingZeros);
- }
+ case ICmpInst::ICMP_ULT:
+ return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros());
default:
return APInt::getAllOnesValue(BitWidth);
@@ -4122,20 +4057,11 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
if (!BitWidth)
return nullptr;
- // If this is a normal comparison, it demands all bits. If it is a sign bit
- // comparison, it only demands the sign bit.
- bool IsSignBit = false;
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- bool UnusedBit;
- IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit);
- }
-
KnownBits Op0Known(BitWidth);
KnownBits Op1Known(BitWidth);
if (SimplifyDemandedBits(&I, 0,
- getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
+ getDemandedBitsLHSMask(I, BitWidth),
Op0Known, 0))
return &I;
@@ -4233,20 +4159,22 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
const APInt *CmpC;
if (match(Op1, m_APInt(CmpC))) {
// A <u C -> A == C-1 if min(A)+1 == C
- if (Op1Max == Op0Min + 1) {
- Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
- }
+ if (*CmpC == Op0Min + 1)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ ConstantInt::get(Op1->getType(), *CmpC - 1));
+ // X <u C --> X == 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+ Constant::getNullValue(Op1->getType()));
}
break;
}
case ICmpInst::ICMP_UGT: {
if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
@@ -4256,42 +4184,52 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
if (*CmpC == Op0Max - 1)
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
ConstantInt::get(Op1->getType(), *CmpC + 1));
+ // X >u C --> X != 0, if the number of zero bits in the bottom of X
+ // exceeds the log2 of C.
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
+ return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Op1->getType()));
}
break;
}
- case ICmpInst::ICMP_SLT:
+ case ICmpInst::ICMP_SLT: {
if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
- if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Builder.getInt(CI->getValue() - 1));
+ ConstantInt::get(Op1->getType(), *CmpC - 1));
}
break;
- case ICmpInst::ICMP_SGT:
+ }
+ case ICmpInst::ICMP_SGT: {
if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
- if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+ const APInt *CmpC;
+ if (match(Op1, m_APInt(CmpC))) {
+ if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Builder.getInt(CI->getValue() + 1));
+ ConstantInt::get(Op1->getType(), *CmpC + 1));
}
break;
+ }
case ICmpInst::ICMP_SGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_SLE:
assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
@@ -4299,6 +4237,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_UGE:
assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
@@ -4306,6 +4246,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
case ICmpInst::ICMP_ULE:
assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
@@ -4313,6 +4255,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+ if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
+ return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
break;
}
@@ -4478,7 +4422,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- // comparing -val or val with non-zero is the same as just comparing val
+ // Comparing -val or val with non-zero is the same as just comparing val
// ie, abs(val) != 0 -> val != 0
if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) {
Value *Cond, *SelectTrue, *SelectFalse;
@@ -4515,11 +4459,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
// and CodeGen. And in this case, at least one of the comparison
// operands has at least one user besides the compare (the select),
// which would often largely negate the benefit of folding anyway.
+ //
+ // Do the same for the other patterns recognized by matchSelectPattern.
if (I.hasOneUse())
- if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
- if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
- (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+ if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
+ Value *A, *B;
+ SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+ if (SPR.Flavor != SPF_UNKNOWN)
return nullptr;
+ }
+
+ // Do this after checking for min/max to prevent infinite looping.
+ if (Instruction *Res = foldICmpWithZero(I))
+ return Res;
// FIXME: We only do this after checking for min/max to prevent infinite
// looping caused by a reverse canonicalization of these patterns for min/max.
@@ -4684,11 +4636,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
Value *X; ConstantInt *Cst;
// icmp X+Cst, X
if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X)
- return foldICmpAddOpConst(I, X, Cst, I.getPredicate());
+ return foldICmpAddOpConst(X, Cst, I.getPredicate());
// icmp X, X+Cst
if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X)
- return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate());
+ return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate());
}
return Changed ? &I : nullptr;
}
@@ -4943,17 +4895,16 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
Changed = true;
}
+ const CmpInst::Predicate Pred = I.getPredicate();
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (Value *V =
- SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(),
- SQ.getWithInstruction(&I)))
+ if (Value *V = SimplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
+ SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
// Simplify 'fcmp pred X, X'
if (Op0 == Op1) {
- switch (I.getPredicate()) {
- default: llvm_unreachable("Unknown predicate!");
+ switch (Pred) {
+ default: break;
case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y)
case FCmpInst::FCMP_ULT: // True if unordered or less than
case FCmpInst::FCMP_UGT: // True if unordered or greater than
@@ -4974,6 +4925,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
}
}
+ // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
+ // then canonicalize the operand to 0.0.
+ if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
+ if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) {
+ I.setOperand(0, ConstantFP::getNullValue(Op0->getType()));
+ return &I;
+ }
+ if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) {
+ I.setOperand(1, ConstantFP::getNullValue(Op0->getType()));
+ return &I;
+ }
+ }
+
// Test if the FCmpInst instruction is used exclusively by a select as
// part of a minimum or maximum operation. If so, refrain from doing
// any other folding. This helps out other analyses which understand
@@ -4982,10 +4946,12 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
// operands has at least one user besides the compare (the select),
// which would often largely negate the benefit of folding anyway.
if (I.hasOneUse())
- if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
- if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
- (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+ if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
+ Value *A, *B;
+ SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+ if (SPR.Flavor != SPF_UNKNOWN)
return nullptr;
+ }
// Handle fcmp with constant RHS
if (Constant *RHSC = dyn_cast<Constant>(Op1)) {
@@ -5027,7 +4993,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
((Fabs.compare(APFloat::getSmallestNormalized(*Sem)) !=
APFloat::cmpLessThan) || Fabs.isZero()))
- return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0),
+ return new FCmpInst(Pred, LHSExt->getOperand(0),
ConstantFP::get(RHSC->getContext(), F));
break;
}
@@ -5072,7 +5038,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
break;
// Various optimization for fabs compared with zero.
- switch (I.getPredicate()) {
+ switch (Pred) {
default:
break;
// fabs(x) < 0 --> false
@@ -5093,7 +5059,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
case FCmpInst::FCMP_UEQ:
case FCmpInst::FCMP_ONE:
case FCmpInst::FCMP_UNE:
- return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC);
+ return new FCmpInst(Pred, CI->getArgOperand(0), RHSC);
}
}
}
@@ -5108,8 +5074,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
if (FPExtInst *LHSExt = dyn_cast<FPExtInst>(Op0))
if (FPExtInst *RHSExt = dyn_cast<FPExtInst>(Op1))
if (LHSExt->getSrcTy() == RHSExt->getSrcTy())
- return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0),
- RHSExt->getOperand(0));
+ return new FCmpInst(Pred, LHSExt->getOperand(0), RHSExt->getOperand(0));
return Changed ? &I : nullptr;
}