Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 737
1 file changed, 351 insertions(+), 386 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index a8faaecb5c34..3bc7fae77cb1 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,9 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -37,77 +35,30 @@ using namespace PatternMatch; STATISTIC(NumSel, "Number of select opts"); -static ConstantInt *extractElement(Constant *V, Constant *Idx) { - return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); -} - -static bool hasAddOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ult(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().sgt(In1->getValue()); - return Result->getValue().slt(In1->getValue()); -} - /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. -static bool addWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getAdd(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasAddOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } - - return hasAddOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); -} - -static bool hasSubOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ugt(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().slt(In1->getValue()); +static bool addWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.sadd_ov(In2, Overflow); + else + Result = In1.uadd_ov(In2, Overflow); - return Result->getValue().sgt(In1->getValue()); + return Overflow; } /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. 
-static bool subWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getSub(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasSubOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } +static bool subWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.ssub_ov(In2, Overflow); + else + Result = In1.usub_ov(In2, Overflow); - return hasSubOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); + return Overflow; } /// Given an icmp instruction, return true if any use of this comparison is a @@ -473,8 +424,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // Look for an appropriate type: // - The type of Idx if the magic fits - // - The smallest fitting legal type if we have a DataLayout - // - Default to i32 + // - The smallest fitting legal type if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) Ty = Idx->getType(); else @@ -1108,7 +1058,6 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, // because we don't allow ptrtoint. Memcpy and memmove are safe because // we don't allow stores, so src cannot point to V. case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: - case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: continue; default: @@ -1131,8 +1080,7 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, - Value *X, ConstantInt *CI, +Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. Similarly for all other "or equals" @@ -1367,6 +1315,24 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } +// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) +Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + return nullptr; +} + // Fold icmp Pred X, C. 
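The rewritten addWithOverflow/subWithOverflow helpers above now delegate to APInt's overflow-reporting arithmetic instead of building Constant expressions and re-deriving the overflow by hand. A minimal standalone sketch of that APInt API (the 8-bit width and the test values are illustrative assumptions, not taken from the patch):

#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  bool Overflow = false;

  // Signed add: 127 + 1 wraps in an 8-bit signed value, so Overflow is set.
  llvm::APInt A(8, 127), B(8, 1);
  llvm::APInt Sum = A.sadd_ov(B, Overflow);
  assert(Overflow);
  (void)Sum;

  // Unsigned sub: 1 - 2 borrows in 8-bit unsigned arithmetic, so Overflow is set.
  llvm::APInt C(8, 1), D(8, 2);
  llvm::APInt Diff = C.usub_ov(D, Overflow);
  assert(Overflow);
  (void)Diff;
  return 0;
}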
Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); @@ -1398,17 +1364,6 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { return Res; } - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (C->isNullValue() && Pred == ICmpInst::ICMP_SGT) { - SelectPatternResult SPR = matchSelectPattern(X, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, A, Cmp.getOperand(1)); - } - } - // FIXME: Use m_APInt to allow folds for splat constants. ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); if (!CI) @@ -1462,11 +1417,11 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { /// Fold icmp (trunc X, Y), C. Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, - Instruction *Trunc, - const APInt *C) { + TruncInst *Trunc, + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (C->isOneValue() && C->getBitWidth() > 1) { + if (C.isOneValue() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1484,7 +1439,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, // If all the high bits are known, we can do this xform. if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. - APInt NewRHS = C->zext(SrcBits); + APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); } @@ -1496,7 +1451,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, /// Fold icmp (xor X, Y), C. Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, - const APInt *C) { + const APInt &C) { Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1506,8 +1461,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if ((Pred == ICmpInst::ICMP_SLT && C->isNullValue()) || - (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + bool TrueIfSigned = false; + if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) { // If the sign bit of the XorCst is not set, there is no change to // the operation, just stop using the Xor. @@ -1517,17 +1472,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, return &Cmp; } - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; - - // If so, the new one isn't. - isTrueIfPositive ^= true; - - Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + // Emit the opposite comparison. 
+ if (TrueIfSigned) + return new ICmpInst(ICmpInst::ICMP_SGT, X, + ConstantInt::getAllOnesValue(X->getType())); else - return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + return new ICmpInst(ICmpInst::ICMP_SLT, X, + ConstantInt::getNullValue(X->getType())); } if (Xor->hasOneUse()) { @@ -1535,7 +1486,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, if (!Cmp.isEquality() && XorC->isSignMask()) { Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) @@ -1543,18 +1494,18 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); Pred = Cmp.getSwappedPredicate(Pred); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } } // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~C && (C + 1).isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); // (icmp ult (xor X, C), -C) -> (icmp uge X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + if (Pred == ICmpInst::ICMP_ULT && *XorC == -C && C.isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); return nullptr; @@ -1562,7 +1513,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, /// Fold icmp (and (sh X, Y), C2), C1. Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1, const APInt *C2) { + const APInt &C1, const APInt &C2) { BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); if (!Shift || !Shift->isShift()) return nullptr; @@ -1577,32 +1528,35 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, const APInt *C3; if (match(Shift->getOperand(1), m_APInt(C3))) { bool CanFold = false; - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, but nothing - // simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { + if (ShiftOpcode == Instruction::Shl) { // For a left shift, we can fold if the comparison is not signed. We can // also fold a signed comparison if the mask value and comparison value // are not negative. These constraints may not be obvious, but we can // prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative())) CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { + } else { + bool IsAshr = ShiftOpcode == Instruction::AShr; // For a logical right shift, we can fold if the comparison is not signed. // We can also fold a signed comparison if the shifted mask value and the // shifted comparison value are not negative. These constraints may not be // obvious, but we can prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || - (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) - CanFold = true; + // For an arithmetic shift right we can do the same, if we ensure + // the And doesn't use any bits being shifted in. 
Normally these would + // be turned into lshr by SimplifyDemandedBits, but not if there is an + // additional user. + if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) { + if (!Cmp.isSigned() || + (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative())) + CanFold = true; + } } if (CanFold) { - APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3); APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); // Check to see if we are shifting out any of the bits being compared. - if (SameAsC1 != *C1) { + if (SameAsC1 != C1) { // If we shifted bits out, the fold is not going to work out. As a // special case, check to see if this means that the result is always // true or false now. @@ -1612,7 +1566,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); } else { Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); - APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3); + APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3); And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); And->setOperand(0, Shift->getOperand(0)); Worklist.Add(Shift); // Shift is dead. @@ -1624,7 +1578,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. - if (Shift->hasOneUse() && C1->isNullValue() && Cmp.isEquality() && + if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = @@ -1643,12 +1597,12 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, /// Fold icmp (and X, C2), C1. Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1) { + const APInt &C1) { const APInt *C2; if (!match(And->getOperand(1), m_APInt(C2))) return nullptr; - if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + if (!And->hasOneUse()) return nullptr; // If the LHS is an 'and' of a truncate and we can widen the and/compare to @@ -1660,29 +1614,29 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, // set or if it is an equality comparison. Extending a relational comparison // when we're checking the sign bit would not work. Value *W; - if (match(And->getOperand(0), m_Trunc(m_Value(W))) && - (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) && + (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) { // TODO: Is this a good transform for vectors? Wider types may reduce // throughput. Should this transform be limited (even for scalars) by using // shouldChangeType()? 
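For the AShr case that foldICmpAndShift now handles a little above, the added guard is that the mask must not depend on any bits the arithmetic shift would shift in, expressed as C2.shl(C3).lshr(C3) == C2. A small standalone sketch of that precondition (the 8-bit width and the particular constants are assumptions for illustration):

#include "llvm/ADT/APInt.h"
#include <cassert>

// True if mask C2 survives a round trip through shl/lshr by C3, i.e. the 'and'
// ignores every bit an ashr by C3 would shift in.
static bool maskIgnoresShiftedInBits(const llvm::APInt &C2, const llvm::APInt &C3) {
  return C2.shl(C3).lshr(C3) == C2;
}

int main() {
  llvm::APInt Sh(8, 2); // shift amount of 2
  // 0x0F fits in the low 6 bits, so the round trip is lossless and the fold applies.
  assert(maskIgnoresShiftedInBits(llvm::APInt(8, 0x0F), Sh));
  // 0xC0 uses the top bits, the round trip drops them, and the fold is skipped.
  assert(!maskIgnoresShiftedInBits(llvm::APInt(8, 0xC0), Sh));
  return 0;
}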
if (!Cmp.getType()->isVectorTy()) { Type *WideType = W->getType(); unsigned WideScalarBits = WideType->getScalarSizeInBits(); - Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits)); Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName()); return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); } } - if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2)) return I; // (icmp pred (and (or (lshr A, B), A), 1), 0) --> // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && C1->isNullValue() && + if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() && match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); @@ -1716,22 +1670,13 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, } } - // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a - // result greater than C1. - unsigned NumTZ = C2->countTrailingZeros(); - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && - APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { - Constant *Zero = Constant::getNullValue(And->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); - } - return nullptr; } /// Fold icmp (and X, Y), C. Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C) { + const APInt &C) { if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) return I; @@ -1756,7 +1701,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); @@ -1766,7 +1711,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && C->isNullValue() && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); @@ -1784,9 +1729,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, /// Fold icmp (or X, Y), C. Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt *C) { + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (C->isOneValue()) { + if (C.isOneValue()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1798,12 +1743,12 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, // X | C != C --> X >u C // iff C+1 is a power of 2 (C is a bitmask of the low bits) if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && - (*C + 1).isPowerOf2()) { + (C + 1).isPowerOf2()) { Pred = (Pred == CmpInst::ICMP_EQ) ? 
CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); } - if (!Cmp.isEquality() || !C->isNullValue() || !Or->hasOneUse()) + if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -1837,7 +1782,7 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, /// Fold icmp (mul X, Y), C. Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, - const APInt *C) { + const APInt &C) { const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; @@ -1845,7 +1790,7 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, // If this is a test of the sign bit and the multiply is sign-preserving with // a constant operand, use the multiply LHS operand instead. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); return new ICmpInst(Pred, Mul->getOperand(0), @@ -1857,14 +1802,14 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, /// Fold icmp (shl 1, Y), C. static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, - const APInt *C) { + const APInt &C) { Value *Y; if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) return nullptr; Type *ShiftType = Shl->getType(); - uint32_t TypeBits = C->getBitWidth(); - bool CIsPowerOf2 = C->isPowerOf2(); + unsigned TypeBits = C.getBitWidth(); + bool CIsPowerOf2 = C.isPowerOf2(); ICmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isUnsigned()) { // (1 << Y) pred C -> Y pred Log2(C) @@ -1881,7 +1826,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 - unsigned CLog2 = C->logBase2(); + unsigned CLog2 = C.logBase2(); if (CLog2 == TypeBits - 1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; @@ -1891,7 +1836,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C->isAllOnesValue()) { + if (C.isAllOnesValue()) { // (1 << Y) <= -1 -> Y == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); @@ -1899,7 +1844,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) > -1 -> Y != 31 if (Pred == ICmpInst::ICMP_SGT) return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!(*C)) { + } else if (!C) { // (1 << Y) < 0 -> Y == 31 // (1 << Y) <= 0 -> Y == 31 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) @@ -1911,7 +1856,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); } } else if (Cmp.isEquality() && CIsPowerOf2) { - return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); } return nullptr; @@ -1920,10 +1865,10 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, /// Fold icmp (shl X, Y), C. 
Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, - const APInt *C) { + const APInt &C) { const APInt *ShiftVal; if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) @@ -1931,7 +1876,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited, it will be simplified. - unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); if (ShiftAmt->uge(TypeBits)) return nullptr; @@ -1945,15 +1890,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoSignedWrap()) { if (Pred == ICmpInst::ICMP_SGT) { // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the SGT case, but assert the pre-condition // that is needed for this to work with equality predicates. - assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_SLT) { @@ -1961,14 +1906,14 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=s C is equiv to X <=s (C >> S) for all C // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN - assert(!C->isMinSignedValue() && "Unexpected icmp slt"); - APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; + assert(!C.isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } // If this is a signed comparison to 0 and the shift is sign preserving, // use the shift LHS operand instead; isSignTest may change 'Pred', so only // do that if we're sure to not continue on in this function. - if (isSignTest(Pred, *C)) + if (isSignTest(Pred, C)) return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); } @@ -1978,15 +1923,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoUnsignedWrap()) { if (Pred == ICmpInst::ICMP_UGT) { // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the UGT case, but assert the pre-condition // that is needed for this to work with equality predicates. 
- assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_ULT) { @@ -1994,8 +1939,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=u C is equiv to X <=u (C >> S) for all C // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 - assert(C->ugt(0) && "ult 0 should have been eliminated"); - APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; + assert(C.ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } } @@ -2006,13 +1951,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, ShType, APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); - Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); + Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt)); return new ICmpInst(Pred, And, LShrC); } // Otherwise, if this is a comparison of the sign bit, simplify to and/test. bool TrueIfSigned = false; - if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) { // (X << 31) <s 0 --> (X & 1) != 0 Constant *Mask = ConstantInt::get( ShType, @@ -2029,13 +1974,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (ShType->isVectorTy()) TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); Constant *NewC = - ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); } @@ -2045,18 +1990,18 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, /// Fold icmp ({al}shr X, Y), C. Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, - const APInt *C) { + const APInt &C) { // An exact shr only shifts out zero bits, so: // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && - C->isNullValue()) + C.isNullValue()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) @@ -2064,71 +2009,73 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited it will be simplified. 
- unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) return nullptr; bool IsAShr = Shr->getOpcode() == Instruction::AShr; - if (!Cmp.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (Cmp.isSigned() != IsAShr) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; - - // Revisit the shift (to delete it). - Worklist.Add(Shr); - - Constant *DivCst = ConstantInt::get( - Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); - - Value *Tmp = IsAShr ? Builder.CreateSDiv(X, DivCst, "", Shr->isExact()) - : Builder.CreateUDiv(X, DivCst, "", Shr->isExact()); - - Cmp.setOperand(0, Tmp); - - // If the builder folded the binop, just return it. - BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &Cmp; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); - - Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); - assert(Res && "This div/cst should have folded!"); - return Res; + bool IsExact = Shr->isExact(); + Type *ShrTy = Shr->getType(); + // TODO: If we could guarantee that InstSimplify would handle all of the + // constant-value-based preconditions in the folds below, then we could assert + // those conditions rather than checking them. This is difficult because of + // undef/poison (PR34838). + if (IsAShr) { + if (Pred == CmpInst::ICMP_SLT || (Pred == CmpInst::ICMP_SGT && IsExact)) { + // icmp slt (ashr X, ShAmtC), C --> icmp slt X, (C << ShAmtC) + // icmp sgt (ashr exact X, ShAmtC), C --> icmp sgt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.ashr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_SGT) { + // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() && + (ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + } else { + if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { + // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) + // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.lshr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_UGT) { + // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } } + if (!Cmp.isEquality()) + return nullptr; + // Handle equality comparisons of shift-by-constant. // If the comparison constant changes with the shift, the comparison cannot // succeed (bits of the comparison constant cannot match the shifted value). // This should be known by InstSimplify and already be folded to true/false. 
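The new relational handling in foldICmpShrConstant just above rewrites, for example, icmp slt (ashr X, ShAmtC), C as icmp slt X, (C << ShAmtC) whenever the shifted constant survives the ashr round trip. A brute-force check of that equivalence over i8 (the bit width and the shift amount are illustrative assumptions):

#include "llvm/ADT/APInt.h"
#include <cassert>
#include <cstdint>

int main() {
  const unsigned Bits = 8, ShAmt = 3;
  for (int c = -128; c <= 127; ++c) {
    llvm::APInt C(Bits, static_cast<uint64_t>(c), /*isSigned=*/true);
    llvm::APInt ShiftedC = C.shl(ShAmt);
    if (ShiftedC.ashr(ShAmt) != C)
      continue; // (C << ShAmt) does not round-trip, so the fold is not applied
    for (int x = -128; x <= 127; ++x) {
      llvm::APInt X(Bits, static_cast<uint64_t>(x), /*isSigned=*/true);
      // icmp slt (ashr X, ShAmt), C must agree with icmp slt X, (C << ShAmt).
      assert(X.ashr(ShAmt).slt(C) == X.slt(ShiftedC));
    }
  }
  return 0;
}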
- assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || - (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) || + (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) && "Expected icmp+shr simplify did not occur."); - // Check if the bits shifted out are known to be zero. If so, we can compare - // against the unshifted value: + // If the bits shifted out are known zero, compare the unshifted value: // (X & 4) >> 1 == 2 --> (X & 4) == 4. - Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); - if (Shr->hasOneUse()) { - if (Shr->isExact()) - return new ICmpInst(Pred, X, ShiftedCmpRHS); + if (Shr->isExact()) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - // Otherwise strength reduce the shift into an 'and'. + if (Shr->hasOneUse()) { + // Canonicalize the shift into an 'and': + // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Constant *Mask = ConstantInt::get(ShrTy, Val); Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); - return new ICmpInst(Pred, And, ShiftedCmpRHS); + return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal)); } return nullptr; @@ -2137,7 +2084,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, - const APInt *C) { + const APInt &C) { const APInt *C2; if (!match(UDiv->getOperand(0), m_APInt(C2))) return nullptr; @@ -2147,17 +2094,17 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) Value *Y = UDiv->getOperand(1); if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C->isMaxValue() && + assert(!C.isMaxValue() && "icmp ugt X, UINT_MAX should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_ULE, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + ConstantInt::get(Y->getType(), C2->udiv(C + 1))); } // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { - assert(*C != 0 && "icmp ult X, 0 should have been simplified already."); + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_UGT, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C))); + ConstantInt::get(Y->getType(), C2->udiv(C))); } return nullptr; @@ -2166,7 +2113,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, /// Fold icmp ({su}div X, Y), C. Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, - const APInt *C) { + const APInt &C) { // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2197,28 +2144,22 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, (DivIsSigned && C2->isAllOnesValue())) return nullptr; - // TODO: We could do all of the computations below using APInt. - Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); - Constant *DivRHS = cast<Constant>(Div->getOperand(1)); - - // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of - // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // Compute Prod = C * C2. 
We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 and C. // By solving for X, we can turn this into a range check instead of computing // a divide. - Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); + APInt Prod = C * *C2; // Determine if the product overflows by seeing if the product is not equal to // the divide. Make sure we do the same kind of divide as in the LHS // instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) - : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - Constant *RangeSize = - Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; + APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -2228,7 +2169,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, // overflow variable is set to 0 if it's corresponding bound variable is valid // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; - Constant *LoBound = nullptr, *HiBound = nullptr; + APInt LoBound, HiBound; if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) @@ -2240,38 +2181,38 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (C->isNullValue()) { // (X / pos) op 0 + if (C.isNullValue()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + LoBound = -(RangeSize - 1); HiBound = RangeSize; - } else if (C->isStrictlyPositive()) { // (X / pos) op pos + } else if (C.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + APInt DivNeg = -RangeSize; LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } } else if (C2->isNegative()) { // Divisor is < 0. if (Div->isExact()) - RangeSize = ConstantExpr::getNeg(RangeSize); - if (C->isNullValue()) { // (X / neg) op 0 + RangeSize.negate(); + if (C.isNullValue()) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(RangeSize); - HiBound = ConstantExpr::getNeg(RangeSize); - if (HiBound == DivRHS) { // -INTMIN = INTMIN + LoBound = RangeSize + 1; + HiBound = -RangeSize; + if (HiBound == *C2) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (C->isStrictlyPositive()) { // (X / neg) op pos + } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? 
-1 : 0; if (!LoOverflow) LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; @@ -2294,25 +2235,27 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith( - Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), DivIsSigned, true)); + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return replaceInstUsesWith(Cmp, Builder.getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith(Cmp, - insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), + insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: @@ -2320,7 +2263,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getTrue()); if (LoOverflow == -1) // Low bound is less than input range. return replaceInstUsesWith(Cmp, Builder.getFalse()); - return new ICmpInst(Pred, X, LoBound); + return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound)); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. @@ -2328,8 +2271,10 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, if (HiOverflow == -1) // High bound less than input range. return replaceInstUsesWith(Cmp, Builder.getTrue()); if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, + ConstantInt::get(Div->getType(), HiBound)); } return nullptr; @@ -2338,7 +2283,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, /// Fold icmp (sub X, Y), C. 
Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, - const APInt *C) { + const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); @@ -2349,19 +2294,19 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue()) return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isOneValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isOneValue()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2371,14 +2316,14 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, // C2 - Y <u C -> (Y | (C - 1)) == C2 // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (*C2 & (*C - 1)) == (*C - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, *C - 1), X); + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && + (*C2 & (C - 1)) == (C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X); // C2 - Y >u C -> (Y | C) != C2 // iff C2 & C == C and C + 1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, *C), X); + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); return nullptr; } @@ -2386,7 +2331,7 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, /// Fold icmp (add X, Y), C. Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt *C) { + const APInt &C) { Value *Y = Add->getOperand(1); const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) @@ -2403,7 +2348,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, if (Add->hasNoSignedWrap() && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { bool Overflow; - APInt NewC = C->ssub_ov(*C2, Overflow); + APInt NewC = C.ssub_ov(*C2, Overflow); // If there is overflow, the result must be true or false. // TODO: Can we assert there is no overflow because InstSimplify always // handles those cases? 
@@ -2412,7 +2357,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); } - auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2); + auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { @@ -2433,15 +2378,15 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -(*C)), + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C), ConstantExpr::getNeg(cast<Constant>(Y))); // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~(*C)), + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), ConstantExpr::getNeg(cast<Constant>(Y))); return nullptr; @@ -2471,7 +2416,7 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, - Instruction *Select, + SelectInst *Select, ConstantInt *C) { assert(C && "Cmp RHS should be a constant int!"); @@ -2483,8 +2428,8 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, Value *OrigLHS, *OrigRHS; ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; if (Cmp.hasOneUse() && - matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, - C1LessThan, C2Equal, C3GreaterThan)) { + matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal, + C3GreaterThan)) { assert(C1LessThan && C2Equal && C3GreaterThan); bool TrueWhenLessThan = @@ -2525,82 +2470,74 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (!match(Cmp.getOperand(1), m_APInt(C))) return nullptr; - BinaryOperator *BO; - if (match(Cmp.getOperand(0), m_BinOp(BO))) { + if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) { switch (BO->getOpcode()) { case Instruction::Xor: - if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpXorConstant(Cmp, BO, *C)) return I; break; case Instruction::And: - if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAndConstant(Cmp, BO, *C)) return I; break; case Instruction::Or: - if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpOrConstant(Cmp, BO, *C)) return I; break; case Instruction::Mul: - if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpMulConstant(Cmp, BO, *C)) return I; break; case Instruction::Shl: - if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShlConstant(Cmp, BO, *C)) return I; break; case Instruction::LShr: case Instruction::AShr: - if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; case Instruction::UDiv: - if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; LLVM_FALLTHROUGH; case Instruction::SDiv: - if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + if (Instruction 
*I = foldICmpDivConstant(Cmp, BO, *C)) return I; break; case Instruction::Sub: - if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpSubConstant(Cmp, BO, *C)) return I; break; case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAddConstant(Cmp, BO, *C)) return I; break; default: break; } // TODO: These folds could be refactored to be part of the above calls. - if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, *C)) return I; } // Match against CmpInst LHS being instructions other than binary operators. - Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { - switch (LHSI->getOpcode()) { - case Instruction::Select: - { - // For now, we only support constant integers while folding the - // ICMP(SELECT)) pattern. We can extend this to support vector of integers - // similar to the cases handled by binary ops above. - if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) - if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) - return I; - break; - } - case Instruction::Trunc: - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + + if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) { + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) return I; - break; - default: - break; - } } - if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) { + if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) + return I; + } + + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) return I; return nullptr; @@ -2610,7 +2547,7 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { /// icmp eq/ne BO, C. Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, - const APInt *C) { + const APInt &C) { // TODO: Some of these folds could work with arbitrary constants, but this // function is limited to scalar and vector splat constants. if (!Cmp.isEquality()) @@ -2624,7 +2561,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (C->isNullValue() && BO->hasOneUse()) { + if (C.isNullValue() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); @@ -2641,7 +2578,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); return new ICmpInst(Pred, BOp0, SubC); } - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -2662,7 +2599,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // For the xor case, we can xor two constants together, eliminating // the explicit xor. 
return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } @@ -2675,7 +2612,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // Replace ((sub BOC, B) != C) with (B != BOC-C). Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); return new ICmpInst(Pred, BOp1, SubC); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((sub A, B) != 0) with (A != B). return new ICmpInst(Pred, BOp0, BOp1); } @@ -2697,7 +2634,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, const APInt *BOC; if (match(BOp1, m_APInt(BOC))) { // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (C == BOC && C->isPowerOf2()) + if (C == *BOC && C.isPowerOf2()) return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO, Constant::getNullValue(RHS->getType())); @@ -2713,7 +2650,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } // ((X & ~7) == 0) --> X < 8 - if (C->isNullValue() && (~(*BOC) + 1).isPowerOf2()) { + if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) { Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; return new ICmpInst(NewPred, BOp0, NegBOC); @@ -2722,7 +2659,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; } case Instruction::Mul: - if (C->isNullValue() && BO->hasNoSignedWrap()) { + if (C.isNullValue() && BO->hasNoSignedWrap()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) { // The trivial case (mul X, 0) is handled by InstSimplify. @@ -2733,7 +2670,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } break; case Instruction::UDiv: - if (C->isNullValue()) { + if (C.isNullValue()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -2747,7 +2684,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 
Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, - const APInt *C) { + const APInt &C) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); if (!II || !Cmp.isEquality()) return nullptr; @@ -2758,13 +2695,13 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::bswap: Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::get(Ty, C->byteSwap())); + Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap())); return &Cmp; case Intrinsic::ctlz: case Intrinsic::cttz: // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (*C == C->getBitWidth()) { + if (C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); @@ -2775,8 +2712,8 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = C->isNullValue(); - if (IsZero || *C == C->getBitWidth()) { + bool IsZero = C.isNullValue(); + if (IsZero || C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); auto *NewOp = @@ -3924,31 +3861,29 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. -static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, - bool isSignCheck) { - if (isSignCheck) - return APInt::getSignMask(BitWidth); +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { + const APInt *RHS; + if (!match(I.getOperand(1), m_APInt(RHS))) + return APInt::getAllOnesValue(BitWidth); - ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1)); - if (!CI) return APInt::getAllOnesValue(BitWidth); - const APInt &RHS = CI->getValue(); + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool UnusedBit; + if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) + return APInt::getSignMask(BitWidth); switch (I.getPredicate()) { // For a UGT comparison, we don't care about any bits that // correspond to the trailing ones of the comparand. The value of these // bits doesn't impact the outcome of the comparison, because any value // greater than the RHS must differ in a bit higher than these due to carry. - case ICmpInst::ICMP_UGT: { - unsigned trailingOnes = RHS.countTrailingOnes(); - return APInt::getBitsSetFrom(BitWidth, trailingOnes); - } + case ICmpInst::ICMP_UGT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes()); // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. - case ICmpInst::ICMP_ULT: { - unsigned trailingZeros = RHS.countTrailingZeros(); - return APInt::getBitsSetFrom(BitWidth, trailingZeros); - } + case ICmpInst::ICMP_ULT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); default: return APInt::getAllOnesValue(BitWidth); @@ -4122,20 +4057,11 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (!BitWidth) return nullptr; - // If this is a normal comparison, it demands all bits. If it is a sign bit - // comparison, it only demands the sign bit. 
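getDemandedBitsLHSMask, reworked a little above to use m_APInt, drops low bits of the LHS from the demanded set for unsigned comparisons: for icmp ugt, bits aligned with the trailing ones of the RHS cannot change the result, and symmetrically the trailing zeros for icmp ult. A brute-force illustration of the ugt case over i8 (the particular RHS is an assumption for the example):

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t RHS = 0x07;   // three trailing ones
  const unsigned Low = 0x07;  // so the low three bits of X are not demanded
  for (unsigned x = 0; x <= 0xFF; ++x)
    for (unsigned bits = 0; bits <= Low; ++bits) {
      uint8_t a = static_cast<uint8_t>(x);
      uint8_t b = static_cast<uint8_t>((x & ~Low) | bits); // same high bits, arbitrary low bits
      // Changing only the undemanded low bits never flips 'icmp ugt X, RHS'.
      assert((a > RHS) == (b > RHS));
    }
  return 0;
}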
-  bool IsSignBit = false;
-  const APInt *CmpC;
-  if (match(Op1, m_APInt(CmpC))) {
-    bool UnusedBit;
-    IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit);
-  }
-
   KnownBits Op0Known(BitWidth);
   KnownBits Op1Known(BitWidth);
   if (SimplifyDemandedBits(&I, 0,
-                           getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
+                           getDemandedBitsLHSMask(I, BitWidth),
                            Op0Known, 0))
     return &I;
@@ -4233,20 +4159,22 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
     const APInt *CmpC;
     if (match(Op1, m_APInt(CmpC))) {
       // A <u C -> A == C-1 if min(A)+1 == C
-      if (Op1Max == Op0Min + 1) {
-        Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1);
-        return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1);
-      }
+      if (*CmpC == Op0Min + 1)
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            ConstantInt::get(Op1->getType(), *CmpC - 1));
+      // X <u C --> X == 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
+        return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
+                            Constant::getNullValue(Op1->getType()));
     }
     break;
   }
   case ICmpInst::ICMP_UGT: {
     if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B)
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
-
     if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
     if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
@@ -4256,42 +4184,52 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
       if (*CmpC == Op0Max - 1)
         return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
                             ConstantInt::get(Op1->getType(), *CmpC + 1));
+      // X >u C --> X != 0, if the number of zero bits in the bottom of X
+      // exceeds the log2 of C.
+      if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
+        return new ICmpInst(ICmpInst::ICMP_NE, Op0,
+                            Constant::getNullValue(Op1->getType()));
     }
     break;
   }
-  case ICmpInst::ICMP_SLT:
+  case ICmpInst::ICMP_SLT: {
     if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C)
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
     if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
       return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-      if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
+    const APInt *CmpC;
+    if (match(Op1, m_APInt(CmpC))) {
+      if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
         return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            Builder.getInt(CI->getValue() - 1));
+                            ConstantInt::get(Op1->getType(), *CmpC - 1));
     }
     break;
-  case ICmpInst::ICMP_SGT:
+  }
+  case ICmpInst::ICMP_SGT: {
     if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B)
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
-
     if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
       return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) {
-      if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
+    const APInt *CmpC;
+    if (match(Op1, m_APInt(CmpC))) {
+      if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
         return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
-                            Builder.getInt(CI->getValue() + 1));
+                            ConstantInt::get(Op1->getType(), *CmpC + 1));
     }
     break;
+  }
   case ICmpInst::ICMP_SGE:
     assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!");
     if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B)
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B)
+      return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_SLE:
     assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!");
@@ -4299,6 +4237,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
       return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B)
+      return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_UGE:
     assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!");
@@ -4306,6 +4246,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B)
+      return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   case ICmpInst::ICMP_ULE:
     assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!");
@@ -4313,6 +4255,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
      return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
     if (Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B)
       return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
+    if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B)
+      return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1);
     break;
   }
@@ -4478,7 +4422,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
-  // comparing -val or val with non-zero is the same as just comparing val
+  // Comparing -val or val with non-zero is the same as just comparing val
   // ie, abs(val) != 0 -> val != 0
   if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) {
     Value *Cond, *SelectTrue, *SelectFalse;
@@ -4515,11 +4459,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
   // and CodeGen. And in this case, at least one of the comparison
   // operands has at least one user besides the compare (the select),
   // which would often largely negate the benefit of folding anyway.
+  //
+  // Do the same for the other patterns recognized by matchSelectPattern.
   if (I.hasOneUse())
-    if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
-      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
-          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+    if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
+      Value *A, *B;
+      SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+      if (SPR.Flavor != SPF_UNKNOWN)
        return nullptr;
+    }
+
+  // Do this after checking for min/max to prevent infinite looping.
+  if (Instruction *Res = foldICmpWithZero(I))
+    return Res;
   // FIXME: We only do this after checking for min/max to prevent infinite
   // looping caused by a reverse canonicalization of these patterns for min/max.
@@ -4684,11 +4636,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
     Value *X; ConstantInt *Cst;
     // icmp X+Cst, X
     if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X)
-      return foldICmpAddOpConst(I, X, Cst, I.getPredicate());
+      return foldICmpAddOpConst(X, Cst, I.getPredicate());
     // icmp X, X+Cst
     if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X)
-      return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate());
+      return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate());
   }
   return Changed ? &I : nullptr;
 }
@@ -4943,17 +4895,16 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     Changed = true;
   }
+  const CmpInst::Predicate Pred = I.getPredicate();
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
-  if (Value *V =
-      SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(),
-                       SQ.getWithInstruction(&I)))
+  if (Value *V = SimplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
   // Simplify 'fcmp pred X, X'
   if (Op0 == Op1) {
-    switch (I.getPredicate()) {
-    default: llvm_unreachable("Unknown predicate!");
+    switch (Pred) {
+    default: break;
     case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y)
     case FCmpInst::FCMP_ULT: // True if unordered or less than
     case FCmpInst::FCMP_UGT: // True if unordered or greater than
@@ -4974,6 +4925,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     }
   }
+  // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
+  // then canonicalize the operand to 0.0.
+  if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
+    if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) {
+      I.setOperand(0, ConstantFP::getNullValue(Op0->getType()));
+      return &I;
+    }
+    if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) {
+      I.setOperand(1, ConstantFP::getNullValue(Op0->getType()));
+      return &I;
+    }
+  }
+
   // Test if the FCmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
@@ -4982,10 +4946,12 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   // operands has at least one user besides the compare (the select),
   // which would often largely negate the benefit of folding anyway.
   if (I.hasOneUse())
-    if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin()))
-      if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) ||
-          (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1))
+    if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) {
+      Value *A, *B;
+      SelectPatternResult SPR = matchSelectPattern(SI, A, B);
+      if (SPR.Flavor != SPF_UNKNOWN)
        return nullptr;
+    }
   // Handle fcmp with constant RHS
   if (Constant *RHSC = dyn_cast<Constant>(Op1)) {
@@ -5027,7 +4993,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
             ((Fabs.compare(APFloat::getSmallestNormalized(*Sem)) !=
               APFloat::cmpLessThan) || Fabs.isZero()))
-          return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0),
+          return new FCmpInst(Pred, LHSExt->getOperand(0),
                               ConstantFP::get(RHSC->getContext(), F));
         break;
       }
@@ -5072,7 +5038,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
          break;
        // Various optimization for fabs compared with zero.
-        switch (I.getPredicate()) {
+        switch (Pred) {
        default:
          break;
        // fabs(x) < 0 --> false
@@ -5093,7 +5059,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
        case FCmpInst::FCMP_UEQ:
        case FCmpInst::FCMP_ONE:
        case FCmpInst::FCMP_UNE:
-          return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC);
+          return new FCmpInst(Pred, CI->getArgOperand(0), RHSC);
        }
      }
  }
@@ -5108,8 +5074,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   if (FPExtInst *LHSExt = dyn_cast<FPExtInst>(Op0))
     if (FPExtInst *RHSExt = dyn_cast<FPExtInst>(Op1))
       if (LHSExt->getSrcTy() == RHSExt->getSrcTy())
-        return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0),
-                            RHSExt->getOperand(0));
+        return new FCmpInst(Pred, LHSExt->getOperand(0), RHSExt->getOperand(0));
   return Changed ? &I : nullptr;
 }
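
Aside (not part of the patch above): the new known-bits fold added in foldICmpUsingKnownBits ("X <u C --> X == 0" when the low ceil(log2(C)) bits of X are known zero, plus the sibling "X >u C --> X != 0" fold) can be sanity-checked in isolation. The standalone C++ sketch below exhaustively verifies the ult case at an 8-bit width; the harness itself (main, the chosen width, and the comparand value 10) is an illustrative assumption, while the llvm::APInt calls are the same queries the patch relies on.

// Exhaustive check of the reasoning behind the fold: if the low K bits of X
// are known zero, where K = ceil(log2(C)), then any nonzero X is a multiple
// of 2^K >= C, so "icmp ult X, C" can only be true when X == 0.
#include "llvm/ADT/APInt.h"
#include <cassert>

int main() {
  const unsigned BitWidth = 8;
  llvm::APInt C(BitWidth, 10);     // hypothetical comparand C
  unsigned K = C.ceilLogBase2();   // smallest K with 2^K >= C (here 4)

  for (unsigned V = 0; V != (1u << BitWidth); ++V) {
    llvm::APInt X(BitWidth, V);
    // Only consider values satisfying the precondition modelled by
    // Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2().
    if (X.countTrailingZeros() < K)
      continue;
    assert(X.ult(C) == X.isNullValue() && "fold would change the result");
  }
  return 0;
}

The ugt variant follows the same argument with getActiveBits(), i.e. floor(log2(C)) + 1, in place of ceilLogBase2(), which is why the two cases in the patch use different APInt queries.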