Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 870
1 file changed, 620 insertions(+), 250 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3a4283ae5406..a9f64feb600c 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -69,34 +69,6 @@ static bool hasBranchUse(ICmpInst &I) {
   return false;
 }
 
-/// Given an exploded icmp instruction, return true if the comparison only
-/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the
-/// result of the comparison is true when the input value is signed.
-static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
-                           bool &TrueIfSigned) {
-  switch (Pred) {
-  case ICmpInst::ICMP_SLT: // True if LHS s< 0
-    TrueIfSigned = true;
-    return RHS.isNullValue();
-  case ICmpInst::ICMP_SLE: // True if LHS s<= RHS and RHS == -1
-    TrueIfSigned = true;
-    return RHS.isAllOnesValue();
-  case ICmpInst::ICMP_SGT: // True if LHS s> -1
-    TrueIfSigned = false;
-    return RHS.isAllOnesValue();
-  case ICmpInst::ICMP_UGT:
-    // True if LHS u> RHS and RHS == high-bit-mask - 1
-    TrueIfSigned = true;
-    return RHS.isMaxSignedValue();
-  case ICmpInst::ICMP_UGE:
-    // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc)
-    TrueIfSigned = true;
-    return RHS.isSignMask();
-  default:
-    return false;
-  }
-}
-
 /// Returns true if the exploded icmp can be expressed as a signed comparison
 /// to zero and updates the predicate accordingly.
 /// The signedness of the comparison is preserved.
@@ -832,6 +804,10 @@ getAsConstantIndexedAddress(Value *V, const DataLayout &DL) {
 static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
                                               ICmpInst::Predicate Cond,
                                               const DataLayout &DL) {
+  // FIXME: Support vector of pointers.
+  if (GEPLHS->getType()->isVectorTy())
+    return nullptr;
+
   if (!GEPLHS->hasAllConstantIndices())
     return nullptr;
 
@@ -882,7 +858,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     RHS = RHS->stripPointerCasts();
 
   Value *PtrBase = GEPLHS->getOperand(0);
-  if (PtrBase == RHS && GEPLHS->isInBounds()) {
+  // FIXME: Support vector pointer GEPs.
+  if (PtrBase == RHS && GEPLHS->isInBounds() &&
+      !GEPLHS->getType()->isVectorTy()) {
     // ((gep Ptr, OFFSET) cmp Ptr)   ---> (OFFSET cmp 0).
     // This transformation (ignoring the base and scales) is valid because we
     // know pointers can't overflow since the gep is inbounds.  See if we can
@@ -894,6 +872,37 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
       Offset = EmitGEPOffset(GEPLHS);
     return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
                         Constant::getNullValue(Offset->getType()));
+  }
+
+  if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
+      isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() &&
+      !NullPointerIsDefined(I.getFunction(),
+                            RHS->getType()->getPointerAddressSpace())) {
+    // For most address spaces, an allocation can't be placed at null, but null
+    // itself is treated as a zero-size allocation by the inbounds rules. Thus,
+    // the only valid inbounds address derived from null is null itself.
+    // That leaves four cases to consider:
+    // 1) Base == nullptr, Offset == 0 -> inbounds, null
+    // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds
+    // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations)
+    // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison)
+    //
+    // (Note if we're indexing a type of size 0, that simply collapses into one
+    // of the buckets above.)
+ // + // In general, we're allowed to make values less poison (i.e. remove + // sources of full UB), so in this case, we just select between the two + // non-poison cases (1 and 4 above). + // + // For vectors, we apply the same reasoning on a per-lane basis. + auto *Base = GEPLHS->getPointerOperand(); + if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) { + int NumElts = GEPLHS->getType()->getVectorNumElements(); + Base = Builder.CreateVectorSplat(NumElts, Base); + } + return new ICmpInst(Cond, Base, + ConstantExpr::getPointerBitCastOrAddrSpaceCast( + cast<Constant>(RHS), Base->getType())); } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) { // If the base pointers are different, but the indices are the same, just // compare the base pointer. @@ -916,11 +925,13 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // If we're comparing GEPs with two base pointers that only differ in type // and both GEPs have only constant indices or just one use, then fold // the compare with the adjusted indices. + // FIXME: Support vector of pointers. if (GEPLHS->isInBounds() && GEPRHS->isInBounds() && (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) && PtrBase->stripPointerCasts() == - GEPRHS->getOperand(0)->stripPointerCasts()) { + GEPRHS->getOperand(0)->stripPointerCasts() && + !GEPLHS->getType()->isVectorTy()) { Value *LOffset = EmitGEPOffset(GEPLHS); Value *ROffset = EmitGEPOffset(GEPRHS); @@ -949,12 +960,14 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, } // If one of the GEPs has all zero indices, recurse. - if (GEPLHS->hasAllZeroIndices()) + // FIXME: Handle vector of pointers. + if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. - if (GEPRHS->hasAllZeroIndices()) + // FIXME: Handle vector of pointers. + if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); @@ -964,15 +977,20 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, unsigned DiffOperand = 0; // The operand that differs. for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { - if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() != - GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) { + Type *LHSType = GEPLHS->getOperand(i)->getType(); + Type *RHSType = GEPRHS->getOperand(i)->getType(); + // FIXME: Better support for vector of pointers. + if (LHSType->getPrimitiveSizeInBits() != + RHSType->getPrimitiveSizeInBits() || + (GEPLHS->getType()->isVectorTy() && + (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) { // Irreconcilable differences. NumDifferences = 2; break; - } else { - if (NumDifferences++) break; - DiffOperand = i; } + + if (NumDifferences++) break; + DiffOperand = i; } if (NumDifferences == 0) // SAME GEP? 
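The four-bucket case analysis above is easy to sanity-check with a toy model that uses wrapping 8-bit integers as stand-in addresses. This is an illustrative sketch of the reasoning only, not LLVM API code; Base and Off are hypothetical names:

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Base = 0; Base < 256; ++Base)
    for (unsigned Off = 0; Off < 256; ++Off) {
      bool ResultIsNull = (uint8_t)(Base + Off) == 0;
      // Cases 2 and 3 yield poison, so the fold may pick either answer there.
      bool Poison = (Base == 0 && Off != 0) || (Base != 0 && ResultIsNull);
      if (Poison)
        continue;
      // In the non-poison cases (1 and 4), "gep inbounds Base, Off == null"
      // agrees with "Base == null", which is exactly what the new code emits.
      assert(ResultIsNull == (Base == 0));
    }
}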
@@ -1317,6 +1335,59 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
+/// If we have:
+///   icmp eq/ne (urem/srem %x, %y), 0
+/// where %y is a power of two, we can replace this with a bit test:
+///   icmp eq/ne (and %x, (add %y, -1)), 0
+Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) {
+  // This fold is only valid for equality predicates.
+  if (!I.isEquality())
+    return nullptr;
+  ICmpInst::Predicate Pred;
+  Value *X, *Y, *Zero;
+  if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))),
+                        m_CombineAnd(m_Zero(), m_Value(Zero)))))
+    return nullptr;
+  if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I))
+    return nullptr;
+  // This may increase instruction count; we don't enforce that Y is a
+  // constant.
+  Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType()));
+  Value *Masked = Builder.CreateAnd(X, Mask);
+  return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero);
+}
+
+/// Fold equality-comparison between zero and any (maybe truncated) right-shift
+/// by one-less-than-bitwidth into a sign test on the original value.
+Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) {
+  Instruction *Val;
+  ICmpInst::Predicate Pred;
+  if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero())))
+    return nullptr;
+
+  Value *X;
+  Type *XTy;
+
+  Constant *C;
+  if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) {
+    XTy = X->getType();
+    unsigned XBitWidth = XTy->getScalarSizeInBits();
+    if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                                     APInt(XBitWidth, XBitWidth - 1))))
+      return nullptr;
+  } else if (isa<BinaryOperator>(Val) &&
+             (X = reassociateShiftAmtsOfTwoSameDirectionShifts(
+                  cast<BinaryOperator>(Val), SQ.getWithInstruction(Val),
+                  /*AnalyzeForSignBitExtraction=*/true))) {
+    XTy = X->getType();
+  } else
+    return nullptr;
+
+  return ICmpInst::Create(Instruction::ICmp,
+                          Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_SGE
+                                                    : ICmpInst::ICMP_SLT,
+                          X, ConstantInt::getNullValue(XTy));
+}
+
 // Handle icmp pred X, 0
 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
   CmpInst::Predicate Pred = Cmp.getPredicate();
@@ -1335,6 +1406,9 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
     }
   }
 
+  if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp))
+    return New;
+
   // Given:
   //   icmp eq/ne (urem %x, %y), 0
   // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
@@ -2179,6 +2253,44 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
+Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp,
+                                                BinaryOperator *SRem,
+                                                const APInt &C) {
+  // Match an 'is positive' or 'is negative' comparison of remainder by a
+  // constant power-of-2 value:
+  // (X % pow2C) sgt/slt 0
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+  if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT)
+    return nullptr;
+
+  // TODO: The one-use check is standard because we do not typically want to
+  //       create longer instruction sequences, but this might be a special
+  //       case because srem is not good for analysis or codegen.
+  if (!SRem->hasOneUse())
+    return nullptr;
+
+  const APInt *DivisorC;
+  if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC)))
+    return nullptr;
+
+  // Mask off the sign bit and the modulo bits (low-bits).
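Both helpers added above rest on small bit-level identities. A standalone brute-force check in plain C++ (assuming a power-of-two divisor for the first identity; not LLVM API code):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 1024; ++X) {
    // foldIRemByPowerOfTwoToBitTest: (X urem Y) == 0  <=>  (X & (Y - 1)) == 0
    for (unsigned Y = 1; Y <= 512; Y *= 2)
      assert(((X % Y) == 0) == ((X & (Y - 1)) == 0));
    // foldSignBitTest: (X lshr (BitWidth - 1)) == 0  <=>  X s>= 0, here at i8
    uint8_t X8 = (uint8_t)X;
    assert(((X8 >> 7) == 0) == ((int8_t)X8 >= 0));
  }
}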
+  Type *Ty = SRem->getType();
+  APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits());
+  Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1));
+  Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC);
+
+  // For 'is positive?' check that the sign-bit is clear and at least 1 masked
+  // bit is set. Example:
+  // (i8 X % 32) s> 0 --> (X & 159) s> 0
+  if (Pred == ICmpInst::ICMP_SGT)
+    return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty));
+
+  // For 'is negative?' check that the sign-bit is set and at least 1 masked
+  // bit is set. Example:
+  // (i16 X % 4) s< 0 --> (X & 32771) u> 32768
+  return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask));
+}
+
 /// Fold icmp (udiv X, Y), C.
 Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp,
                                                 BinaryOperator *UDiv,
@@ -2387,6 +2499,11 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
   const APInt *C2;
   APInt SubResult;
 
+  // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0
+  if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality())
+    return new ICmpInst(Cmp.getPredicate(), Y,
+                        ConstantInt::get(Y->getType(), 0));
+
   // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
   if (match(X, m_APInt(C2)) &&
       ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
@@ -2509,20 +2626,49 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS,
   // TODO: Generalize this to work with other comparison idioms or ensure
   // they get canonicalized into this form.
 
-  // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32
-  // Greater), where Equal, Less and Greater are placeholders for any three
-  // constants.
-  ICmpInst::Predicate PredA, PredB;
-  if (match(SI->getTrueValue(), m_ConstantInt(Equal)) &&
-      match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) &&
-      PredA == ICmpInst::ICMP_EQ &&
-      match(SI->getFalseValue(),
-            m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)),
-                     m_ConstantInt(Less), m_ConstantInt(Greater))) &&
-      PredB == ICmpInst::ICMP_SLT) {
-    return true;
+  // select i1 (a == b),
+  //        i32 Equal,
+  //        i32 (select i1 (a < b), i32 Less, i32 Greater)
+  // where Equal, Less and Greater are placeholders for any three constants.
+  ICmpInst::Predicate PredA;
+  if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) ||
+      !ICmpInst::isEquality(PredA))
+    return false;
+  Value *EqualVal = SI->getTrueValue();
+  Value *UnequalVal = SI->getFalseValue();
+  // We can still get a non-canonical predicate here, so canonicalize.
+  if (PredA == ICmpInst::ICMP_NE)
+    std::swap(EqualVal, UnequalVal);
+  if (!match(EqualVal, m_ConstantInt(Equal)))
+    return false;
+  ICmpInst::Predicate PredB;
+  Value *LHS2, *RHS2;
+  if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)),
+                                  m_ConstantInt(Less), m_ConstantInt(Greater))))
+    return false;
+  // We can get predicate mismatch here, so canonicalize if possible:
+  // First, ensure that 'LHS' matches.
+  if (LHS2 != LHS) {
+    // x sgt y <--> y slt x
+    std::swap(LHS2, RHS2);
+    PredB = ICmpInst::getSwappedPredicate(PredB);
+  }
+  if (LHS2 != LHS)
+    return false;
+  // We also need to canonicalize 'RHS'.
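The two worked examples in the foldICmpSRemConstant comments above can be verified exhaustively; a small standalone check (plain C++, not LLVM API code):

#include <cassert>
#include <cstdint>

int main() {
  // (i8 X % 32) s> 0 --> (X & 159) s> 0, where 159 == signmask | (32 - 1)
  for (int X = -128; X <= 127; ++X)
    assert(((int8_t)X % 32 > 0) == ((int8_t)(X & 159) > 0));
  // (i16 X % 4) s< 0 --> (X & 32771) u> 32768, where 32771 == signmask | 3
  for (int X = -32768; X <= 32767; ++X)
    assert(((int16_t)X % 4 < 0) == ((uint16_t)(X & 32771) > 32768u));
}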
+ if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { + // x sgt C-1 <--> x sge C <--> not(x slt C) + auto FlippedStrictness = + getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2)); + if (!FlippedStrictness) + return false; + assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check"); + RHS2 = FlippedStrictness->second; + // And kind-of perform the result swap. + std::swap(Less, Greater); + PredB = ICmpInst::ICMP_SLT; } - return false; + return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, @@ -2702,6 +2848,10 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C)) + return I; + break; case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; @@ -2926,6 +3076,28 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, } break; } + + case Intrinsic::uadd_sat: { + // uadd.sat(a, b) == 0 -> (a | b) == 0 + if (C.isNullValue()) { + Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); + return replaceInstUsesWith(Cmp, Builder.CreateICmp( + Cmp.getPredicate(), Or, Constant::getNullValue(Ty))); + + } + break; + } + + case Intrinsic::usub_sat: { + // usub.sat(a, b) == 0 -> a <= b + if (C.isNullValue()) { + ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return ICmpInst::Create(Instruction::ICmp, NewPred, + II->getArgOperand(0), II->getArgOperand(1)); + } + break; + } default: break; } @@ -3275,6 +3447,7 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I, // we should move shifts to the same hand of 'and', i.e. rewrite as // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) // We are only interested in opposite logical shifts here. +// One of the shifts can be truncated. // If we can, we want to end up creating 'lshr' shift. static Value * foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, @@ -3284,55 +3457,215 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, return nullptr; auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value()); - auto m_AnyLShr = m_LShr(m_Value(), m_Value()); - - // Look for an 'and' of two (opposite) logical shifts. - // Pick the single-use shift as XShift. - Value *XShift, *YShift; - if (!match(I.getOperand(0), - m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))), - m_CombineAnd(m_AnyLogicalShift, m_Value(YShift))))) + + // Look for an 'and' of two logical shifts, one of which may be truncated. + // We use m_TruncOrSelf() on the RHS to correctly handle commutative case. + Instruction *XShift, *MaybeTruncation, *YShift; + if (!match( + I.getOperand(0), + m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)), + m_CombineAnd(m_TruncOrSelf(m_CombineAnd( + m_AnyLogicalShift, m_Instruction(YShift))), + m_Instruction(MaybeTruncation))))) return nullptr; - // If YShift is a single-use 'lshr', swap the shifts around. - if (match(YShift, m_OneUse(m_AnyLShr))) + // We potentially looked past 'trunc', but only when matching YShift, + // therefore YShift must have the widest type. + Instruction *WidestShift = YShift; + // Therefore XShift must have the shallowest type. + // Or they both have identical types if there was no truncation. 
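The new uadd_sat/usub_sat equalities above are easy to confirm against a scalar model of the i8 intrinsics. A standalone sketch; the two helpers below model the LLVM intrinsics and are not part of the patch:

#include <cassert>
#include <cstdint>

static uint8_t uadd_sat(unsigned A, unsigned B) {
  unsigned S = A + B;
  return S > 255 ? 255 : (uint8_t)S; // clamp on unsigned overflow
}
static uint8_t usub_sat(unsigned A, unsigned B) {
  return A > B ? (uint8_t)(A - B) : 0; // clamp on unsigned underflow
}

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      assert((uadd_sat(A, B) == 0) == ((A | B) == 0)); // uadd.sat(a,b) == 0 -> (a|b) == 0
      assert((usub_sat(A, B) == 0) == (A <= B));       // usub.sat(a,b) == 0 -> a <= b
    }
}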
+ Instruction *NarrowestShift = XShift; + + Type *WidestTy = WidestShift->getType(); + assert(NarrowestShift->getType() == I.getOperand(0)->getType() && + "We did not look past any shifts while matching XShift though."); + bool HadTrunc = WidestTy != I.getOperand(0)->getType(); + + // If YShift is a 'lshr', swap the shifts around. + if (match(YShift, m_LShr(m_Value(), m_Value()))) std::swap(XShift, YShift); // The shifts must be in opposite directions. - Instruction::BinaryOps XShiftOpcode = - cast<BinaryOperator>(XShift)->getOpcode(); - if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode()) + auto XShiftOpcode = XShift->getOpcode(); + if (XShiftOpcode == YShift->getOpcode()) return nullptr; // Do not care about same-direction shifts here. Value *X, *XShAmt, *Y, *YShAmt; - match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt))); - match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt))); + match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt)))); + match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt)))); + + // If one of the values being shifted is a constant, then we will end with + // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not, + // however, we will need to ensure that we won't increase instruction count. + if (!isa<Constant>(X) && !isa<Constant>(Y)) { + // At least one of the hands of the 'and' should be one-use shift. + if (!match(I.getOperand(0), + m_c_And(m_OneUse(m_AnyLogicalShift), m_Value()))) + return nullptr; + if (HadTrunc) { + // Due to the 'trunc', we will need to widen X. For that either the old + // 'trunc' or the shift amt in the non-truncated shift should be one-use. + if (!MaybeTruncation->hasOneUse() && + !NarrowestShift->getOperand(1)->hasOneUse()) + return nullptr; + } + } + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now. + if (XShAmt->getType() != YShAmt->getType()) + return nullptr; // Can we fold (XShAmt+YShAmt) ? - Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt, - SQ.getWithInstruction(&I)); + auto *NewShAmt = dyn_cast_or_null<Constant>( + SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, + /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; + NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy); + unsigned WidestBitWidth = WidestTy->getScalarSizeInBits(); + // Is the new shift amount smaller than the bit width? // FIXME: could also rely on ConstantRange. - unsigned BitWidth = X->getType()->getScalarSizeInBits(); - if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, - APInt(BitWidth, BitWidth)))) + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(WidestBitWidth, WidestBitWidth)))) return nullptr; - // All good, we can do this fold. The shift is the same that was for X. + + // An extra legality check is needed if we had trunc-of-lshr. + if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) { + auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ, + WidestShift]() { + // It isn't obvious whether it's worth it to analyze non-constants here. + // Also, let's basically give up on non-splat cases, pessimizing vectors. + // If *any* of these preconditions matches we can perform the fold. + Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy() + ? NewShAmt->getSplatValue() + : NewShAmt; + // If it's edge-case shift (by 0 or by WidestBitWidth-1) we can fold. 
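The core identity this function relies on, moving both shifts to one hand of the 'and' when comparing against zero, can be brute-forced at i8. An illustrative sketch only; the uint8_t casts stand in for truncation to the shifted type:

#include <cassert>
#include <cstdint>

int main() {
  // icmp ne (and (shl X, Q), (lshr Y, K)), 0
  //   <=> icmp ne (and (shl X, Q + K), Y), 0,  iff (Q + K) u< bitwidth
  for (unsigned Q = 0; Q < 8; ++Q)
    for (unsigned K = 0; Q + K < 8; ++K)
      for (unsigned X = 0; X < 256; ++X)
        for (unsigned Y = 0; Y < 256; ++Y) {
          bool Before = ((uint8_t)(X << Q) & (Y >> K)) != 0;
          bool After = ((uint8_t)(X << (Q + K)) & Y) != 0;
          assert(Before == After);
        }
}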
+ if (NewShAmtSplat && + (NewShAmtSplat->isNullValue() || + NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1)) + return true; + // We consider *min* leading zeros so a single outlier + // blocks the transform as opposed to allowing it. + if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. + unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: NewShAmt u<= countLeadingZeros(C) + if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero)) + return true; + } + if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. + unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C) + if (NewShAmtSplat) { + APInt AdjNewShAmt = + (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger(); + if (AdjNewShAmt.ule(MinLeadZero)) + return true; + } + } + return false; // Can't tell if it's ok. + }; + if (!CanFold()) + return nullptr; + } + + // All good, we can do this fold. + X = Builder.CreateZExt(X, WidestTy); + Y = Builder.CreateZExt(Y, WidestTy); + // The shift is the same that was for X. Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr ? Builder.CreateLShr(X, NewShAmt) : Builder.CreateShl(X, NewShAmt); Value *T1 = Builder.CreateAnd(T0, Y); return Builder.CreateICmp(I.getPredicate(), T1, - Constant::getNullValue(X->getType())); + Constant::getNullValue(WidestTy)); +} + +/// Fold +/// (-1 u/ x) u< y +/// ((x * y) u/ x) != y +/// to +/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// Note that the comparison is commutative, while inverted (u>=, ==) predicate +/// will mean that we are looking for the opposite answer. +Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { + ICmpInst::Predicate Pred; + Value *X, *Y; + Instruction *Mul; + bool NeedNegation; + // Look for: (-1 u/ x) u</u>= y + if (!I.isEquality() && + match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Value(Y)))) { + Mul = nullptr; + // Canonicalize as-if y was on RHS. + if (I.getOperand(1) != Y) + Pred = I.getSwappedPredicate(); + + // Are we checking that overflow does not happen, or does happen? + switch (Pred) { + case ICmpInst::Predicate::ICMP_ULT: + NeedNegation = false; + break; // OK + case ICmpInst::Predicate::ICMP_UGE: + NeedNegation = true; + break; // OK + default: + return nullptr; // Wrong predicate. + } + } else // Look for: ((x * y) u/ x) !=/== y + if (I.isEquality() && + match(&I, m_c_ICmp(Pred, m_Value(Y), + m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + m_Value(X)), + m_Instruction(Mul)), + m_Deferred(X)))))) { + NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; + } else + return nullptr; + + BuilderTy::InsertPointGuard Guard(Builder); + // If the pattern included (x * y), we'll want to insert new instructions + // right before that original multiplication so that we can replace it. 
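Both source patterns named in the doc comment above are restatements of "x * y overflows". A standalone check at i16, modeled in uint32_t so the full product is visible (strides keep the loop quick):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 1; X < 65536; X += 37)
    for (uint32_t Y = 0; Y < 65536; Y += 41) {
      bool Overflows = X * Y > 0xFFFF; // the i16 product would wrap
      // (-1 u/ x) u< y  <=>  overflow
      assert(((0xFFFFu / X) < Y) == Overflows);
      // ((x * y) u/ x) != y  <=>  overflow (product truncated to i16)
      assert(((uint16_t)(X * Y) / X != Y) == Overflows);
    }
}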
+ bool MulHadOtherUses = Mul && !Mul->hasOneUse(); + if (MulHadOtherUses) + Builder.SetInsertPoint(Mul); + + Function *F = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::umul_with_overflow, X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + + // If the multiplication was used elsewhere, to ensure that we don't leave + // "duplicate" instructions, replace uses of that original multiplication + // with the multiplication result from the with.overflow intrinsic. + if (MulHadOtherUses) + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + + Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + if (NeedNegation) // This technically increases instruction count. + Res = Builder.CreateNot(Res, "umul.not.ov"); + + return Res; } /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code /// duplication. -Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // Special logic for binary operators. @@ -3345,13 +3678,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { Value *X; // Convert add-with-unsigned-overflow comparisons into a 'not' with compare. - // (Op1 + X) <u Op1 --> ~Op1 <u X - // Op0 >u (Op0 + X) --> X >u ~Op0 + // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) && - Pred == ICmpInst::ICMP_ULT) + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) return new ICmpInst(Pred, Builder.CreateNot(Op1), X); + // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0 if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) && - Pred == ICmpInst::ICMP_UGT) + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; @@ -3378,21 +3711,21 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow. + // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, Constant::getNullValue(Op1->getType())); - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow. + // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow. if ((C == Op0 || D == Op0) && NoOp1WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), C == Op0 ? D : C); - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow. if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && - NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { + NoOp1WrapProblem) { // Determine Y and Z in the form icmp (X+Y), (X+Z). 
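The widened add-with-unsigned-overflow rewrites above hold because "a + x u< a" is exactly the unsigned overflow condition "x u> ~a". An exhaustive i8 check (standalone sketch, not LLVM API code):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned Op1 = 0; Op1 < 256; ++Op1)
    for (unsigned X = 0; X < 256; ++X) {
      uint8_t Sum = (uint8_t)(Op1 + X);
      uint8_t NotOp1 = (uint8_t)~Op1;
      assert((Sum < Op1) == (NotOp1 < X));   // (Op1 + X) u< Op1 <=> ~Op1 u< X
      assert((Sum >= Op1) == (NotOp1 >= X)); // and the inverted u>= form
    }
}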
     Value *Y, *Z;
     if (A == C) {
@@ -3416,39 +3749,39 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
     return new ICmpInst(Pred, Y, Z);
   }
 
-  // icmp slt (X + -1), Y -> icmp sle X, Y
+  // icmp slt (A + -1), Op1 -> icmp sle A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT &&
       match(B, m_AllOnes()))
     return new ICmpInst(CmpInst::ICMP_SLE, A, Op1);
 
-  // icmp sge (X + -1), Y -> icmp sgt X, Y
+  // icmp sge (A + -1), Op1 -> icmp sgt A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE &&
       match(B, m_AllOnes()))
     return new ICmpInst(CmpInst::ICMP_SGT, A, Op1);
 
-  // icmp sle (X + 1), Y -> icmp slt X, Y
+  // icmp sle (A + 1), Op1 -> icmp slt A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One()))
     return new ICmpInst(CmpInst::ICMP_SLT, A, Op1);
 
-  // icmp sgt (X + 1), Y -> icmp sge X, Y
+  // icmp sgt (A + 1), Op1 -> icmp sge A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One()))
     return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
 
-  // icmp sgt X, (Y + -1) -> icmp sge X, Y
+  // icmp sgt Op0, (C + -1) -> icmp sge Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
       match(D, m_AllOnes()))
     return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
 
-  // icmp sle X, (Y + -1) -> icmp slt X, Y
+  // icmp sle Op0, (C + -1) -> icmp slt Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
       match(D, m_AllOnes()))
     return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
 
-  // icmp sge X, (Y + 1) -> icmp sgt X, Y
+  // icmp sge Op0, (C + 1) -> icmp sgt Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One()))
     return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
 
-  // icmp slt X, (Y + 1) -> icmp sle X, Y
+  // icmp slt Op0, (C + 1) -> icmp sle Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
     return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
 
@@ -3456,33 +3789,33 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
   // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
   // wouldn't happen even if they were implemented.
   //
-  // icmp ult (X - 1), Y -> icmp ule X, Y
-  // icmp uge (X - 1), Y -> icmp ugt X, Y
-  // icmp ugt X, (Y - 1) -> icmp uge X, Y
-  // icmp ule X, (Y - 1) -> icmp ult X, Y
+  // icmp ult (A - 1), Op1 -> icmp ule A, Op1
+  // icmp uge (A - 1), Op1 -> icmp ugt A, Op1
+  // icmp ugt Op0, (C - 1) -> icmp uge Op0, C
+  // icmp ule Op0, (C - 1) -> icmp ult Op0, C
 
-  // icmp ule (X + 1), Y -> icmp ult X, Y
+  // icmp ule (A + 1), Op1 -> icmp ult A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
     return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
 
-  // icmp ugt (X + 1), Y -> icmp uge X, Y
+  // icmp ugt (A + 1), Op1 -> icmp uge A, Op1
   if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
     return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
 
-  // icmp uge X, (Y + 1) -> icmp ugt X, Y
+  // icmp uge Op0, (C + 1) -> icmp ugt Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
     return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
 
-  // icmp ult X, (Y + 1) -> icmp ule X, Y
+  // icmp ult Op0, (C + 1) -> icmp ule Op0, C
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
     return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
 
   // if C1 has greater magnitude than C2:
-  //  icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
+  //  icmp (A + C1), (C + C2) -> icmp (A + C3), C
   // s.t. C3 = C1 - C2
   //
   // if C2 has greater magnitude than C1:
-  //  icmp (X + C1), (Y + C2) -> icmp X, (Y + C3)
+  //  icmp (A + C1), (C + C2) -> icmp A, (C + C3)
   // s.t. C3 = C2 - C1
   if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
       (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned())
@@ -3520,29 +3853,35 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
     D = BO1->getOperand(1);
   }
 
-  // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow.
+  // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow.
   if (A == Op1 && NoOp0WrapProblem)
     return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
-  // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow.
+  // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow.
   if (C == Op0 && NoOp1WrapProblem)
     return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));
 
-  // (A - B) >u A --> A <u B
-  if (A == Op1 && Pred == ICmpInst::ICMP_UGT)
-    return new ICmpInst(ICmpInst::ICMP_ULT, A, B);
-  // C <u (C - D) --> C <u D
-  if (C == Op0 && Pred == ICmpInst::ICMP_ULT)
-    return new ICmpInst(ICmpInst::ICMP_ULT, C, D);
-
-  // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow.
-  if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem &&
-      // Try not to increase register pressure.
-      BO0->hasOneUse() && BO1->hasOneUse())
+  // Convert sub-with-unsigned-overflow comparisons into a comparison of args.
+  // (A - B) u>/u<= A --> B u>/u<= A
+  if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
+    return new ICmpInst(Pred, B, A);
+  // C u</u>= (C - D) --> C u</u>= D
+  if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
+    return new ICmpInst(Pred, C, D);
+  // (A - B) u>=/u< A --> B u>/u<= A  iff B != 0
+  if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
+      isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
+  // C u<=/u> (C - D) --> C u</u>= D  iff D != 0
+  if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
+      isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+    return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);
+
+  // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow.
+  if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
     return new ICmpInst(Pred, A, C);
 
-  // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow.
-  if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem &&
-      // Try not to increase register pressure.
-      BO0->hasOneUse() && BO1->hasOneUse())
+
+  // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
+  if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
     return new ICmpInst(Pred, D, B);
 
   // icmp (0-X) < cst --> x > -cst
@@ -3677,6 +4016,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
     }
   }
 
+  if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+    return replaceInstUsesWith(I, V);
+
   if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
     return replaceInstUsesWith(I, V);
 
@@ -3953,125 +4295,140 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
   return nullptr;
 }
 
-/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so
-/// far.
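The strengthened sub-with-unsigned-overflow comparisons above, including the new variants gated on a known-non-zero subtrahend, check out exhaustively at i8 (standalone sketch, not LLVM API code):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A < 256; ++A)
    for (unsigned B = 0; B < 256; ++B) {
      uint8_t Diff = (uint8_t)(A - B);
      assert((Diff > A) == (B > A)); // (A - B) u> A <=> B u> A; u<= is the negation
      if (B != 0)                    // (A - B) u>= A <=> B u> A, iff B != 0
        assert((Diff >= A) == (B > A));
    }
}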
-Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { - const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); - Value *LHSCIOp = LHSCI->getOperand(0); - Type *SrcTy = LHSCIOp->getType(); - Type *DestTy = LHSCI->getType(); - - // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the - // integer type is the same size as the pointer type. - const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool { - if (isa<VectorType>(SrcTy)) { - SrcTy = cast<VectorType>(SrcTy)->getElementType(); - DestTy = cast<VectorType>(DestTy)->getElementType(); - } - return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); - }; - if (LHSCI->getOpcode() == Instruction::PtrToInt && - CompatibleSizes(SrcTy, DestTy)) { - Value *RHSOp = nullptr; - if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { - Value *RHSCIOp = RHSC->getOperand(0); - if (RHSCIOp->getType()->getPointerAddressSpace() == - LHSCIOp->getType()->getPointerAddressSpace()) { - RHSOp = RHSC->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder.CreateBitCast(RHSOp, LHSCIOp->getType()); - } - } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { - RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } - - if (RHSOp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp); - } - - // The code below only handles extension cast instructions, so far. - // Enforce this. - if (LHSCI->getOpcode() != Instruction::ZExt && - LHSCI->getOpcode() != Instruction::SExt) +static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { + assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0"); + auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0)); + Value *X; + if (!match(CastOp0, m_ZExtOrSExt(m_Value(X)))) return nullptr; - bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICmp.isSigned(); - - if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { - // Not an extension from the same type? - Value *RHSCIOp = CI->getOperand(0); - if (RHSCIOp->getType() != LHSCIOp->getType()) - return nullptr; - + bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; + bool IsSignedCmp = ICmp.isSigned(); + if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) { // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. - if (CI->getOpcode() != LHSCI->getOpcode()) + // TODO: This is too strict. We can handle some predicates (equality?). + if (CastOp0->getOpcode() != CastOp1->getOpcode()) return nullptr; - // Deal with equality cases early. + // Not an extension from the same type? + Value *Y = CastOp1->getOperand(0); + Type *XTy = X->getType(), *YTy = Y->getType(); + if (XTy != YTy) { + // One of the casts must have one use because we are creating a new cast. + if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + return nullptr; + // Extend the narrower operand to the type of the wider operand. 
+ if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) + X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) + Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + else + return nullptr; + } + + // (zext X) == (zext Y) --> X == Y + // (sext X) == (sext Y) --> X == Y if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getPredicate(), X, Y); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedCmp && isSignedExt) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + if (IsSignedCmp && IsSignedExt) + return new ICmpInst(ICmp.getPredicate(), X, Y); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y); } - // If we aren't dealing with a constant on the RHS, exit early. + // Below here, we are only folding a compare with constant. auto *C = dyn_cast<Constant>(ICmp.getOperand(1)); if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // re-extended to DestTy. + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); - Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); + Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... if (Res2 == C) { - // Deal with equality cases early. if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedExt && isSignedCmp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + if (IsSignedExt && IsSignedCmp) + return new ICmpInst(ICmp.getPredicate(), X, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); } // The re-extended constant changed, partly changed (in the case of a vector), // or could not be determined to be equal (in the case of a constant // expression), so the constant cannot be represented in the shorter type. - // Consequently, we cannot emit a simple comparison. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. + if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C)) + return nullptr; + + // Is source op positive? + // icmp ult (sext X), C --> icmp sgt X, -1 + if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) + return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy)); + + // Is source op negative? + // icmp ugt (sext X), C --> icmp slt X, 0 + assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); + return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy)); +} - if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C)) +/// Handle icmp (cast x), (cast or constant). 
+Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { + auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); + if (!CastOp0) + return nullptr; + if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1))) return nullptr; - // Evaluate the comparison for LT (we invert for GT below). LE and GE cases - // should have been folded away previously and not enter in here. + Value *Op0Src = CastOp0->getOperand(0); + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); - // We're performing an unsigned comp with a sign extended value. - // This is true if the input is >= 0. [aka >s -1] - Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder.CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. + auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { + if (isa<VectorType>(SrcTy)) { + SrcTy = cast<VectorType>(SrcTy)->getElementType(); + DestTy = cast<VectorType>(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; + if (CastOp0->getOpcode() == Instruction::PtrToInt && + CompatibleSizes(SrcTy, DestTy)) { + Value *NewOp1 = nullptr; + if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { + Value *PtrSrc = PtrToIntOp1->getOperand(0); + if (PtrSrc->getType()->getPointerAddressSpace() == + Op0Src->getType()->getPointerAddressSpace()) { + NewOp1 = PtrToIntOp1->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (Op0Src->getType() != NewOp1->getType()) + NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); + } + } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { + NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } - // Finally, return the value computed. - if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) - return replaceInstUsesWith(ICmp, Result); + if (NewOp1) + return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); + } - assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); - return BinaryOperator::CreateNot(Result); + return foldICmpWithZextOrSext(ICmp, Builder); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { @@ -4791,41 +5148,35 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return nullptr; } -/// If we have an icmp le or icmp ge instruction with a constant operand, turn -/// it into the appropriate icmp lt or icmp gt instruction. This transform -/// allows them to be folded in visitICmpInst. 
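A standalone spot-check of the extension folds now implemented by foldICmpWithZextOrSext, at i8 widened to i32 (illustrative only; the casts model sext/zext):

#include <cassert>
#include <cstdint>

int main() {
  for (int A = -128; A <= 127; ++A)
    for (int B = -128; B <= 127; ++B) {
      int32_t SA = (int8_t)A, SB = (int8_t)B;    // sext i8 -> i32
      uint32_t ZA = (uint8_t)A, ZB = (uint8_t)B; // zext i8 -> i32
      // (sext A) s< (sext B)  <=>  A s< B (signed cmp of sign extensions)
      assert((SA < SB) == ((int8_t)A < (int8_t)B));
      // (sext A) u< (sext B)  <=>  A u< B (the fold-to-unsigned-cmp case)
      assert(((uint32_t)SA < (uint32_t)SB) == ((uint8_t)A < (uint8_t)B));
      // (zext A) == (zext B)  <=>  A == B
      assert((ZA == ZB) == ((uint8_t)A == (uint8_t)B));
    }
}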
-static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { - ICmpInst::Predicate Pred = I.getPredicate(); - if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && - Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) - return nullptr; +llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, + Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); - Value *Op0 = I.getOperand(0); - Value *Op1 = I.getOperand(1); - auto *Op1C = dyn_cast<Constant>(Op1); - if (!Op1C) - return nullptr; + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; - // Check if the constant operand can be safely incremented/decremented without - // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled - // the edge cases for us, so we just assert on them. For vectors, we must - // handle the edge cases. - Type *Op1Type = Op1->getType(); - bool IsSigned = I.isSigned(); - bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); - auto *CI = dyn_cast<ConstantInt>(Op1C); - if (CI) { - // A <= MAX -> TRUE ; A >= MIN -> TRUE - assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); - } else if (Op1Type->isVectorTy()) { - // TODO? If the edge cases for vectors were guaranteed to be handled as they - // are for scalar, we could remove the min/max checks. However, to do that, - // we would have to use insertelement/shufflevector to replace edge values. - unsigned NumElts = Op1Type->getVectorNumElements(); + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. + auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return llvm::None; + } else if (Type->isVectorTy()) { + unsigned NumElts = Type->getVectorNumElements(); for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = Op1C->getAggregateElement(i); + Constant *Elt = C->getAggregateElement(i); if (!Elt) - return nullptr; + return llvm::None; if (isa<UndefValue>(Elt)) continue; @@ -4833,20 +5184,43 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); - if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) - return nullptr; + if (!CI || !ConstantIsOk(CI)) + return llvm::None; } } else { // ConstantExpr? - return nullptr; + return llvm::None; } - // Increment or decrement the constant and set the new comparison predicate: - // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT - Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true); - CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; - NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; - return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. 
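The new helper relies on the usual strictness flip: a non-strict predicate against C equals the strict one against C +/- 1 whenever C is not the extreme value being guarded against. A quick i8 check of both directions (standalone sketch):

#include <cassert>
#include <cstdint>

int main() {
  // x u<= C  <=>  x u< C + 1, for C != UINT8_MAX (the increment can't wrap)
  for (unsigned C = 0; C < 255; ++C)
    for (unsigned X = 0; X < 256; ++X)
      assert((X <= C) == (X < C + 1));
  // x s>= C  <=>  x s> C - 1, for C != INT8_MIN (the decrement can't wrap)
  for (int C = -127; C <= 127; ++C)
    for (int X = -128; X <= 127; ++X)
      assert((X >= C) == (X > C - 1));
}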
+ Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. This transform +/// allows them to be folded in visitICmpInst. +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) || + isCanonicalPredicate(Pred)) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + if (!FlippedStrictness) + return nullptr; + + return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second); } /// Integer compare with boolean values can always be turned into bitwise ops. @@ -5002,6 +5376,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); unsigned Op0Cplxity = getComplexity(Op0); unsigned Op1Cplxity = getComplexity(Op1); @@ -5016,8 +5391,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, - SQ.getWithInstruction(&I))) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) return replaceInstUsesWith(I, V); // Comparing -val or val with non-zero is the same as just comparing val @@ -5050,6 +5424,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; + if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -5098,6 +5475,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstant(I)) return Res; + // Try to match comparison as a sign bit test. Intentionally do this after + // foldICmpInstWithConstant() to potentially let other folds to happen first. + if (Instruction *New = foldSignBitTest(I)) + return New; + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; @@ -5124,20 +5506,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpBitCast(I, Builder)) return Res; - if (isa<CastInst>(Op0)) { - // Handle the special case of: icmp (cast bool to X), <cst> - // This comes up when you have code like - // int X = A < B; - // if (X) ... - // For generality, we handle any zero-extension of any operand comparison - // with a constant or another cast from the same type. - if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = foldICmpWithCastAndCast(I)) - return R; - } - - if (Instruction *Res = foldICmpBinOp(I)) - return Res; + if (Instruction *R = foldICmpWithCastOp(I)) + return R; if (Instruction *Res = foldICmpWithMinMax(I)) return Res; |