Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--	contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp	493
1 file changed, 353 insertions, 140 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index c5e14ebf3ae3..7a9e177f19da 100644
--- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -78,15 +78,15 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) {
   if (!ICmpInst::isSigned(Pred))
     return false;
 
-  if (C.isNullValue())
+  if (C.isZero())
     return ICmpInst::isRelational(Pred);
 
-  if (C.isOneValue()) {
+  if (C.isOne()) {
     if (Pred == ICmpInst::ICMP_SLT) {
       Pred = ICmpInst::ICMP_SLE;
       return true;
     }
-  } else if (C.isAllOnesValue()) {
+  } else if (C.isAllOnes()) {
     if (Pred == ICmpInst::ICMP_SGT) {
       Pred = ICmpInst::ICMP_SGE;
       return true;
@@ -541,7 +541,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
       if (!CI->isNoopCast(DL))
         return false;
 
-      if (Explored.count(CI->getOperand(0)) == 0)
+      if (!Explored.contains(CI->getOperand(0)))
         WorkList.push_back(CI->getOperand(0));
     }
 
@@ -553,7 +553,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
           GEP->getType() != Start->getType())
         return false;
 
-      if (Explored.count(GEP->getOperand(0)) == 0)
+      if (!Explored.contains(GEP->getOperand(0)))
         WorkList.push_back(GEP->getOperand(0));
     }
 
@@ -575,7 +575,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
     // Explore the PHI nodes further.
     for (auto *PN : PHIs)
       for (Value *Op : PN->incoming_values())
-        if (Explored.count(Op) == 0)
+        if (!Explored.contains(Op))
           WorkList.push_back(Op);
   }
 
@@ -589,7 +589,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base,
       auto *Inst = dyn_cast<Instruction>(Val);
 
       if (Inst == Base || Inst == PHI || !Inst || !PHI ||
-          Explored.count(PHI) == 0)
+          !Explored.contains(PHI))
         continue;
 
       if (PHI->getParent() == Inst->getParent())
@@ -1147,12 +1147,12 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
   };
 
   // Don't bother doing any work for cases which InstSimplify handles.
-  if (AP2.isNullValue())
+  if (AP2.isZero())
     return nullptr;
 
   bool IsAShr = isa<AShrOperator>(I.getOperand(0));
   if (IsAShr) {
-    if (AP2.isAllOnesValue())
+    if (AP2.isAllOnes())
       return nullptr;
     if (AP2.isNegative() != AP1.isNegative())
       return nullptr;
@@ -1178,7 +1178,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A,
   if (IsAShr && AP1 == AP2.ashr(Shift)) {
     // There are multiple solutions if we are comparing against -1 and the LHS
     // of the ashr is not a power of two.
-    if (AP1.isAllOnesValue() && !AP2.isPowerOf2())
+    if (AP1.isAllOnes() && !AP2.isPowerOf2())
       return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift));
     return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift));
   } else if (AP1 == AP2.lshr(Shift)) {
@@ -1206,7 +1206,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A,
   };
 
   // Don't bother doing any work for cases which InstSimplify handles.
-  if (AP2.isNullValue())
+  if (AP2.isZero())
     return nullptr;
 
   unsigned AP2TrailingZeros = AP2.countTrailingZeros();
@@ -1270,9 +1270,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // This is only really a signed overflow check if the inputs have been
   // sign-extended; check for that condition. For example, if CI2 is 2^31 and
   // the operands of the add are 64 bits wide, we need at least 33 sign bits.
-  unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1;
-  if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits ||
-      IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits)
+  if (IC.ComputeMinSignedBits(A, 0, &I) > NewWidth ||
+      IC.ComputeMinSignedBits(B, 0, &I) > NewWidth)
     return nullptr;
 
   // In order to replace the original add with a narrower
@@ -1544,7 +1543,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
                                                      const APInt &C) {
   ICmpInst::Predicate Pred = Cmp.getPredicate();
   Value *X = Trunc->getOperand(0);
-  if (C.isOneValue() && C.getBitWidth() > 1) {
+  if (C.isOne() && C.getBitWidth() > 1) {
     // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1
     Value *V = nullptr;
     if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V))))
@@ -1725,7 +1724,7 @@ Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
   // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is
   // preferable because it allows the C2 << Y expression to be hoisted out of a
   // loop if Y is invariant and X is not.
-  if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() &&
+  if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() &&
       !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) {
     // Compute C2 << Y.
     Value *NewShift =
@@ -1749,7 +1748,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1
   // TODO: We canonicalize to the longer form for scalars because we have
   // better analysis/folds for icmp, and codegen may be better with icmp.
-  if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() &&
+  if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() &&
       match(And->getOperand(1), m_One()))
     return new TruncInst(And->getOperand(0), Cmp.getType());
 
@@ -1762,7 +1761,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   if (!And->hasOneUse())
     return nullptr;
 
-  if (Cmp.isEquality() && C1.isNullValue()) {
+  if (Cmp.isEquality() && C1.isZero()) {
     // Restrict this fold to single-use 'and' (PR10267).
     // Replace (and X, (1 << size(X)-1) != 0) with X s< 0
     if (C2->isSignMask()) {
@@ -1812,7 +1811,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
   //    (icmp pred (and A, (or (shl 1, B), 1), 0))
   //
   // iff pred isn't signed
-  if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() &&
+  if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() &&
      match(And->getOperand(1), m_One())) {
    Constant *One = cast<Constant>(And->getOperand(1));
    Value *Or = And->getOperand(0);
@@ -1889,7 +1888,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
  // X & -C == -C -> X >  u ~C
  // X & -C != -C -> X <= u ~C
  //   iff C is a power of 2
-  if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) {
+  if (Cmp.getOperand(1) == Y && C.isNegatedPowerOf2()) {
    auto NewPred =
        Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE;
    return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1))));
@@ -1899,7 +1898,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
  // (X & C2) != 0 -> (trunc X) < 0
  //   iff C2 is a power of 2 and it masks the sign bit of a legal integer type.
   const APInt *C2;
-  if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) {
+  if (And->hasOneUse() && C.isZero() && match(Y, m_APInt(C2))) {
     int32_t ExactLogBase2 = C2->exactLogBase2();
     if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
       Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
@@ -1920,7 +1919,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
                                                   BinaryOperator *Or,
                                                   const APInt &C) {
   ICmpInst::Predicate Pred = Cmp.getPredicate();
-  if (C.isOneValue()) {
+  if (C.isOne()) {
     // icmp slt signum(V) 1 --> icmp slt V, 1
     Value *V = nullptr;
     if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V))))
@@ -1950,7 +1949,18 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
     }
   }
 
-  if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse())
+  // (X | (X-1)) s<  0 --> X s< 1
+  // (X | (X-1)) s> -1 --> X s> 0
+  Value *X;
+  bool TrueIfSigned;
+  if (isSignBitCheck(Pred, C, TrueIfSigned) &&
+      match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) {
+    auto NewPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT;
+    Constant *NewC = ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0);
+    return new ICmpInst(NewPred, X, NewC);
+  }
+
+  if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse())
     return nullptr;
 
   Value *P, *Q;
@@ -2001,14 +2011,14 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
 
   // If the multiply does not wrap, try to divide the compare constant by the
   // multiplication factor.
-  if (Cmp.isEquality() && !MulC->isNullValue()) {
+  if (Cmp.isEquality() && !MulC->isZero()) {
     // (mul nsw X, MulC) == C --> X == C /s MulC
-    if (Mul->hasNoSignedWrap() && C.srem(*MulC).isNullValue()) {
+    if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
       Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC));
       return new ICmpInst(Pred, Mul->getOperand(0), NewC);
     }
     // (mul nuw X, MulC) == C --> X == C /u MulC
-    if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isNullValue()) {
+    if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) {
       Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC));
       return new ICmpInst(Pred, Mul->getOperand(0), NewC);
     }
@@ -2053,7 +2063,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
     return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
   } else if (Cmp.isSigned()) {
     Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
-    if (C.isAllOnesValue()) {
+    if (C.isAllOnes()) {
       // (1 << Y) <= -1 -> Y == 31
       if (Pred == ICmpInst::ICMP_SLE)
         return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
@@ -2227,8 +2237,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
   // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0
   Value *X = Shr->getOperand(0);
   CmpInst::Predicate Pred = Cmp.getPredicate();
-  if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() &&
-      C.isNullValue())
+  if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && C.isZero())
     return new ICmpInst(Pred, X, Cmp.getOperand(1));
 
   const APInt *ShiftVal;
@@ -2316,7 +2325,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
     if (Shr->isExact())
       return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal));
 
-    if (C.isNullValue()) {
+    if (C.isZero()) {
       // == 0 is u< 1.
       if (Pred == CmpInst::ICMP_EQ)
         return new ICmpInst(CmpInst::ICMP_ULT, X,
@@ -2355,7 +2364,7 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp,
     return nullptr;
 
   const APInt *DivisorC;
-  if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC)))
+  if (!C.isZero() || !match(SRem->getOperand(1), m_Power2(DivisorC)))
     return nullptr;
 
   // Mask off the sign bit and the modulo bits (low-bits).
@@ -2435,8 +2444,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
  // INT_MIN will also fail if the divisor is 1. Although folds of all these
  // division-by-constant cases should be present, we can not assert that they
  // have happened before we reach this icmp instruction.
-  if (C2->isNullValue() || C2->isOneValue() ||
-      (DivIsSigned && C2->isAllOnesValue()))
+  if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes()))
     return nullptr;
 
   // Compute Prod = C * C2. We are essentially solving an equation of
@@ -2476,16 +2484,16 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
       HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false);
     }
   } else if (C2->isStrictlyPositive()) { // Divisor is > 0.
-    if (C.isNullValue()) {       // (X / pos) op 0
+    if (C.isZero()) {       // (X / pos) op 0
       // Can't overflow.  e.g.  X/2 op 0 --> [-1, 2)
       LoBound = -(RangeSize - 1);
       HiBound = RangeSize;
-    } else if (C.isStrictlyPositive()) {   // (X / pos) op pos
+    } else if (C.isStrictlyPositive()) { // (X / pos) op pos
       LoBound = Prod;     // e.g.   X/5 op 3 --> [15, 20)
       HiOverflow = LoOverflow = ProdOV;
       if (!HiOverflow)
         HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true);
-    } else {                       // (X / pos) op neg
+    } else { // (X / pos) op neg
       // e.g. X/5 op -3  --> [-15-4, -15+1) --> [-19, -14)
       HiBound = Prod + 1;
       LoOverflow = HiOverflow = ProdOV ? -1 : 0;
@@ -2497,7 +2505,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
   } else if (C2->isNegative()) { // Divisor is < 0.
     if (Div->isExact())
       RangeSize.negate();
-    if (C.isNullValue()) { // (X / neg) op 0
+    if (C.isZero()) { // (X / neg) op 0
       // e.g. X/-5 op 0  --> [-4, 5)
       LoBound = RangeSize + 1;
       HiBound = -RangeSize;
@@ -2505,13 +2513,13 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp,
         HiOverflow = 1;            // [INTMIN+1, overflow)
         HiBound = APInt();         // e.g. X/INTMIN = 0 --> X > INTMIN
       }
-    } else if (C.isStrictlyPositive()) {   // (X / neg) op pos
+    } else if (C.isStrictlyPositive()) { // (X / neg) op pos
       // e.g. X/-5 op 3  --> [-19, -14)
       HiBound = Prod + 1;
       HiOverflow = LoOverflow = ProdOV ? -1 : 0;
       if (!LoOverflow)
         LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
-    } else {                       // (X / neg) op neg
+    } else { // (X / neg) op neg
       LoBound = Prod;       // e.g. X/-5 op -3 --> [15, 20)
       LoOverflow = HiOverflow = ProdOV;
       if (!HiOverflow)
@@ -2581,42 +2589,54 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
                                                   const APInt &C) {
   Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
   ICmpInst::Predicate Pred = Cmp.getPredicate();
-  const APInt *C2;
-  APInt SubResult;
+  Type *Ty = Sub->getType();
 
-  // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0
-  if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality())
-    return new ICmpInst(Cmp.getPredicate(), Y,
-                        ConstantInt::get(Y->getType(), 0));
+  // (SubC - Y) == C) --> Y == (SubC - C)
+  // (SubC - Y) != C) --> Y != (SubC - C)
+  Constant *SubC;
+  if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) {
+    return new ICmpInst(Pred, Y,
+                        ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C)));
+  }
 
   // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
+  const APInt *C2;
+  APInt SubResult;
+  ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate();
+  bool HasNSW = Sub->hasNoSignedWrap();
+  bool HasNUW = Sub->hasNoUnsignedWrap();
   if (match(X, m_APInt(C2)) &&
-      ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
-       (Cmp.isSigned() && Sub->hasNoSignedWrap())) &&
+      ((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) &&
       !subWithOverflow(SubResult, *C2, C, Cmp.isSigned()))
-    return new ICmpInst(Cmp.getSwappedPredicate(), Y,
-                        ConstantInt::get(Y->getType(), SubResult));
+    return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult));
 
   // The following transforms are only worth it if the only user of the subtract
   // is the icmp.
+  // TODO: This is an artificial restriction for all of the transforms below
+  // that only need a single replacement icmp.
   if (!Sub->hasOneUse())
     return nullptr;
 
+  // X - Y == 0 --> X == Y.
+  // X - Y != 0 --> X != Y.
+  if (Cmp.isEquality() && C.isZero())
+    return new ICmpInst(Pred, X, Y);
+
   if (Sub->hasNoSignedWrap()) {
     // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y)
-    if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue())
+    if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes())
       return new ICmpInst(ICmpInst::ICMP_SGE, X, Y);
 
     // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y)
-    if (Pred == ICmpInst::ICMP_SGT && C.isNullValue())
+    if (Pred == ICmpInst::ICMP_SGT && C.isZero())
      return new ICmpInst(ICmpInst::ICMP_SGT, X, Y);
 
     // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y)
-    if (Pred == ICmpInst::ICMP_SLT && C.isNullValue())
+    if (Pred == ICmpInst::ICMP_SLT && C.isZero())
       return new ICmpInst(ICmpInst::ICMP_SLT, X, Y);
 
     // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y)
-    if (Pred == ICmpInst::ICMP_SLT && C.isOneValue())
+    if (Pred == ICmpInst::ICMP_SLT && C.isOne())
       return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
   }
@@ -2634,7 +2654,12 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
   if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C)
     return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X);
 
-  return nullptr;
+  // We have handled special cases that reduce.
+  // Canonicalize any remaining sub to add as:
+  // (C2 - Y) > C --> (Y + ~C2) < ~C
+  Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub",
+                                 HasNUW, HasNSW);
+  return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C));
 }
 
 /// Fold icmp (add X, Y), C.
@@ -2723,6 +2748,14 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
     return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C),
                         ConstantExpr::getNeg(cast<Constant>(Y)));
 
+  // The range test idiom can use either ult or ugt. Arbitrarily canonicalize
+  // to the ult form.
+  // X+C2 >u C -> X+(C2-C-1) <u ~C
+  if (Pred == ICmpInst::ICMP_UGT)
+    return new ICmpInst(ICmpInst::ICMP_ULT,
+                        Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)),
+                        ConstantInt::get(Ty, ~C));
+
   return nullptr;
 }
 
@@ -2830,8 +2863,7 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
-static Instruction *foldICmpBitCast(ICmpInst &Cmp,
-                                    InstCombiner::BuilderTy &Builder) {
+Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
   auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0));
   if (!Bitcast)
     return nullptr;
@@ -2917,6 +2949,39 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
     return new ICmpInst(Pred, BCSrcOp, Op1);
   }
 
+  const APInt *C;
+  if (!match(Cmp.getOperand(1), m_APInt(C)) ||
+      !Bitcast->getType()->isIntegerTy() ||
+      !Bitcast->getSrcTy()->isIntOrIntVectorTy())
+    return nullptr;
+
+  // If this is checking if all elements of a vector compare are set or not,
+  // invert the casted vector equality compare and test if all compare
+  // elements are clear or not. Compare against zero is generally easier for
+  // analysis and codegen.
+  // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0
+  // Example: are all elements equal? --> are zero elements not equal?
+  // TODO: Try harder to reduce compare of 2 freely invertible operands?
+  if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() &&
+      isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) {
+    Type *ScalarTy = Bitcast->getType();
+    Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), ScalarTy);
+    return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(ScalarTy));
+  }
+
+  // If this is checking if all elements of an extended vector are clear or not,
+  // compare in a narrow type to eliminate the extend:
+  // icmp eq/ne (bitcast (ext X) to iN), 0 --> icmp eq/ne (bitcast X to iM), 0
+  Value *X;
+  if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() &&
+      match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) {
+    if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) {
+      Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits());
+      Value *NewCast = Builder.CreateBitCast(X, NewType);
+      return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType));
+    }
+  }
+
   // Folding: icmp <pred> iN X, C
   //  where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN
   //    and C is a splat of a K-bit pattern
@@ -2924,12 +2989,6 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
   // Into:
   //   %E = extractelement <M x iK> %vec, i32 C'
   //   icmp <pred> iK %E, trunc(C)
-  const APInt *C;
-  if (!match(Cmp.getOperand(1), m_APInt(C)) ||
-      !Bitcast->getType()->isIntegerTy() ||
-      !Bitcast->getSrcTy()->isIntOrIntVectorTy())
-    return nullptr;
-
   Value *Vec;
   ArrayRef<int> Mask;
   if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
@@ -3055,7 +3114,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
   switch (BO->getOpcode()) {
   case Instruction::SRem:
     // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one.
-    if (C.isNullValue() && BO->hasOneUse()) {
+    if (C.isZero() && BO->hasOneUse()) {
       const APInt *BOC;
       if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) {
         Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName());
@@ -3069,7 +3128,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
     if (Constant *BOC = dyn_cast<Constant>(BOp1)) {
       if (BO->hasOneUse())
         return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC));
-    } else if (C.isNullValue()) {
+    } else if (C.isZero()) {
       // Replace ((add A, B) != 0) with (A != -B) if A or B is
       // efficiently invertible, or if the add has just this one use.
       if (Value *NegVal = dyn_castNegVal(BOp1))
@@ -3090,25 +3149,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
         // For the xor case, we can xor two constants together, eliminating
         // the explicit xor.
         return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC));
-      } else if (C.isNullValue()) {
+      } else if (C.isZero()) {
         // Replace ((xor A, B) != 0) with (A != B)
         return new ICmpInst(Pred, BOp0, BOp1);
       }
     }
     break;
-  case Instruction::Sub:
-    if (BO->hasOneUse()) {
-      // Only check for constant LHS here, as constant RHS will be canonicalized
-      // to add and use the fold above.
-      if (Constant *BOC = dyn_cast<Constant>(BOp0)) {
-        // Replace ((sub BOC, B) != C) with (B != BOC-C).
-        return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS));
-      } else if (C.isNullValue()) {
-        // Replace ((sub A, B) != 0) with (A != B).
-        return new ICmpInst(Pred, BOp0, BOp1);
-      }
-    }
-    break;
   case Instruction::Or: {
     const APInt *BOC;
     if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) {
@@ -3132,7 +3178,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
     break;
   }
   case Instruction::UDiv:
-    if (C.isNullValue()) {
+    if (C.isZero()) {
       // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A)
       auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
       return new ICmpInst(NewPred, BOp1, BOp0);
@@ -3149,25 +3195,26 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
     ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) {
   Type *Ty = II->getType();
   unsigned BitWidth = C.getBitWidth();
+  const ICmpInst::Predicate Pred = Cmp.getPredicate();
+
   switch (II->getIntrinsicID()) {
   case Intrinsic::abs:
     // abs(A) == 0  ->  A == 0
    // abs(A) == INT_MIN  ->  A == INT_MIN
-    if (C.isNullValue() || C.isMinSignedValue())
-      return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
-                          ConstantInt::get(Ty, C));
+    if (C.isZero() || C.isMinSignedValue())
+      return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C));
     break;
 
   case Intrinsic::bswap:
     // bswap(A) == C  ->  A == bswap(C)
-    return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+    return new ICmpInst(Pred, II->getArgOperand(0),
                         ConstantInt::get(Ty, C.byteSwap()));
 
   case Intrinsic::ctlz:
   case Intrinsic::cttz: {
     // ctz(A) == bitwidth(A)  ->  A == 0 and likewise for !=
     if (C == BitWidth)
-      return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+      return new ICmpInst(Pred, II->getArgOperand(0),
                           ConstantInt::getNullValue(Ty));
 
     // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set
@@ -3181,9 +3228,8 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
       APInt Mask2 = IsTrailing ?
           APInt::getOneBitSet(BitWidth, Num) :
           APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
-      return new ICmpInst(Cmp.getPredicate(),
-                          Builder.CreateAnd(II->getArgOperand(0), Mask1),
-                          ConstantInt::get(Ty, Mask2));
+      return new ICmpInst(Pred, Builder.CreateAnd(II->getArgOperand(0), Mask1),
+                          ConstantInt::get(Ty, Mask2));
     }
     break;
   }
 
@@ -3191,28 +3237,49 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
   case Intrinsic::ctpop: {
     // popcount(A) == 0  ->  A == 0 and likewise for !=
     // popcount(A) == bitwidth(A)  ->  A == -1 and likewise for !=
-    bool IsZero = C.isNullValue();
+    bool IsZero = C.isZero();
     if (IsZero || C == BitWidth)
-      return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
-        IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty));
+      return new ICmpInst(Pred, II->getArgOperand(0),
+                          IsZero ? Constant::getNullValue(Ty)
+                                 : Constant::getAllOnesValue(Ty));
 
     break;
   }
 
+  case Intrinsic::fshl:
+  case Intrinsic::fshr:
+    if (II->getArgOperand(0) == II->getArgOperand(1)) {
+      // (rot X, ?) == 0/-1 --> X == 0/-1
+      // TODO: This transform is safe to re-use undef elts in a vector, but
+      //       the constant value passed in by the caller doesn't allow that.
+      if (C.isZero() || C.isAllOnes())
+        return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1));
+
+      const APInt *RotAmtC;
+      // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC)
+      // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC)
+      if (match(II->getArgOperand(2), m_APInt(RotAmtC)))
+        return new ICmpInst(Pred, II->getArgOperand(0),
+                            II->getIntrinsicID() == Intrinsic::fshl
+                                ? ConstantInt::get(Ty, C.rotr(*RotAmtC))
+                                : ConstantInt::get(Ty, C.rotl(*RotAmtC)));
+    }
+    break;
+
   case Intrinsic::uadd_sat: {
     // uadd.sat(a, b) == 0  ->  (a | b) == 0
-    if (C.isNullValue()) {
+    if (C.isZero()) {
       Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
-      return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty));
+      return new ICmpInst(Pred, Or, Constant::getNullValue(Ty));
     }
     break;
   }
 
   case Intrinsic::usub_sat: {
     // usub.sat(a, b) == 0  ->  a <= b
-    if (C.isNullValue()) {
-      ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ
-          ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
+    if (C.isZero()) {
+      ICmpInst::Predicate NewPred =
+          Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
       return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
     }
     break;
@@ -3224,6 +3291,42 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant(
   return nullptr;
 }
 
+/// Fold an icmp with LLVM intrinsics
+static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) {
+  assert(Cmp.isEquality());
+
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *Op0 = Cmp.getOperand(0);
+  Value *Op1 = Cmp.getOperand(1);
+  const auto *IIOp0 = dyn_cast<IntrinsicInst>(Op0);
+  const auto *IIOp1 = dyn_cast<IntrinsicInst>(Op1);
+  if (!IIOp0 || !IIOp1 || IIOp0->getIntrinsicID() != IIOp1->getIntrinsicID())
+    return nullptr;
+
+  switch (IIOp0->getIntrinsicID()) {
+  case Intrinsic::bswap:
+  case Intrinsic::bitreverse:
+    // If both operands are byte-swapped or bit-reversed, just compare the
+    // original values.
+    return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
+  case Intrinsic::fshl:
+  case Intrinsic::fshr:
+    // If both operands are rotated by same amount, just compare the
+    // original values.
+    if (IIOp0->getOperand(0) != IIOp0->getOperand(1))
+      break;
+    if (IIOp1->getOperand(0) != IIOp1->getOperand(1))
+      break;
+    if (IIOp0->getOperand(2) != IIOp1->getOperand(2))
+      break;
+    return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0));
+  default:
+    break;
+  }
+
+  return nullptr;
+}
+
 /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
 Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
                                                              IntrinsicInst *II,
@@ -3663,7 +3766,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
       (WidestTy->getScalarSizeInBits() - 1) +
       (NarrowestTy->getScalarSizeInBits() - 1);
   APInt MaximalRepresentableShiftAmount =
-      APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits());
+      APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits());
   if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
     return nullptr;
 
@@ -3746,19 +3849,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
 
 /// Fold
 ///   (-1 u/ x) u< y
-///   ((x * y) u/ x) != y
+///   ((x * y) ?/ x) != y
 /// to
-///   @llvm.umul.with.overflow(x, y) plus extraction of overflow bit
+///   @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit
 /// Note that the comparison is commutative, while inverted (u>=, ==) predicate
 /// will mean that we are looking for the opposite answer.
-Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
+Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) {
   ICmpInst::Predicate Pred;
   Value *X, *Y;
   Instruction *Mul;
+  Instruction *Div;
   bool NeedNegation;
   // Look for: (-1 u/ x) u</u>= y
   if (!I.isEquality() &&
-      match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
+      match(&I,
+            m_c_ICmp(Pred,
+                     m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
+                                  m_Instruction(Div)),
                      m_Value(Y)))) {
     Mul = nullptr;
 
@@ -3773,13 +3879,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
     default:
       return nullptr; // Wrong predicate.
     }
-  } else // Look for: ((x * y) u/ x) !=/== y
+  } else // Look for: ((x * y) / x) !=/== y
     if (I.isEquality() &&
-        match(&I, m_c_ICmp(Pred, m_Value(Y),
-                           m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
-                                                                m_Value(X)),
-                                                        m_Instruction(Mul)),
-                                           m_Deferred(X)))))) {
+        match(&I,
+              m_c_ICmp(Pred, m_Value(Y),
+                       m_CombineAnd(
+                           m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y),
+                                                                m_Value(X)),
+                                                        m_Instruction(Mul)),
+                                           m_Deferred(X))),
+                           m_Instruction(Div))))) {
       NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ;
     } else
       return nullptr;
 
@@ -3791,19 +3900,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
   if (MulHadOtherUses)
     Builder.SetInsertPoint(Mul);
 
-  Function *F = Intrinsic::getDeclaration(
-      I.getModule(), Intrinsic::umul_with_overflow, X->getType());
-  CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul");
+  Function *F = Intrinsic::getDeclaration(I.getModule(),
+                                          Div->getOpcode() == Instruction::UDiv
+                                              ? Intrinsic::umul_with_overflow
+                                              : Intrinsic::smul_with_overflow,
+                                          X->getType());
+  CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul");
 
   // If the multiplication was used elsewhere, to ensure that we don't leave
   // "duplicate" instructions, replace uses of that original multiplication
   // with the multiplication result from the with.overflow intrinsic.
   if (MulHadOtherUses)
-    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val"));
+    replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val"));
 
-  Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov");
+  Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov");
   if (NeedNegation) // This technically increases instruction count.
-    Res = Builder.CreateNot(Res, "umul.not.ov");
+    Res = Builder.CreateNot(Res, "mul.not.ov");
 
   // If we replaced the mul, erase it. Do this after all uses of Builder,
   // as the mul is used as insertion point.
@@ -4079,8 +4191,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
     if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 &&
         match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality())
       if (!C->countTrailingZeros() ||
-          (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
-          (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
+          (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) ||
+          (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap()))
         return new ICmpInst(Pred, X, Y);
   }
 
@@ -4146,8 +4258,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
       break;
 
     const APInt *C;
-    if (match(BO0->getOperand(1), m_APInt(C)) && !C->isNullValue() &&
-        !C->isOneValue()) {
+    if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() &&
+        !C->isOne()) {
       // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask)
      // Mask = -1 >> count-trailing-zeros(C).
      if (unsigned TZs = C->countTrailingZeros()) {
@@ -4200,7 +4312,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
     }
   }
 
-  if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+  if (Value *V = foldMultiplicationOverflowCheck(I))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
@@ -4373,6 +4485,19 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
     }
   }
 
+  {
+    // Similar to above, but specialized for constant because invert is needed:
+    // (X | C) == (Y | C) --> (X ^ Y) & ~C == 0
+    Value *X, *Y;
+    Constant *C;
+    if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) &&
+        match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) {
+      Value *Xor = Builder.CreateXor(X, Y);
+      Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C));
+      return new ICmpInst(Pred, And, Constant::getNullValue(And->getType()));
+    }
+  }
+
   // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B)
   // and       (B & (1<<X)-1) == (zext A) --> A == (trunc B)
   ConstantInt *Cst1;
@@ -4441,14 +4566,8 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
     }
   }
 
-  // If both operands are byte-swapped or bit-reversed, just compare the
-  // original values.
-  // TODO: Move this to a function similar to foldICmpIntrinsicWithConstant()
-  // and handle more intrinsics.
-  if ((match(Op0, m_BSwap(m_Value(A))) && match(Op1, m_BSwap(m_Value(B)))) ||
-      (match(Op0, m_BitReverse(m_Value(A))) &&
-       match(Op1, m_BitReverse(m_Value(B)))))
-    return new ICmpInst(Pred, A, B);
+  if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I))
+    return ICmp;
 
   // Canonicalize checking for a power-of-2-or-zero value:
   // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
@@ -4474,6 +4593,74 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
                : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
   }
 
+  // Match icmp eq (trunc (lshr A, BW), (ashr (trunc A), BW-1)), which checks the
Create "A >=s INT_MIN && A <=s INT_MAX", + // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps + // of instcombine. + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) && + match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) && + A->getType()->getScalarSizeInBits() == BitWidth * 2 && + (I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) { + APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1); + Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C)); + return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT + : ICmpInst::ICMP_UGE, + Add, ConstantInt::get(A->getType(), C.shl(1))); + } + + return nullptr; +} + +static Instruction *foldICmpWithTrunc(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { + const ICmpInst::Predicate Pred = ICmp.getPredicate(); + Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1); + + // Try to canonicalize trunc + compare-to-constant into a mask + cmp. + // The trunc masks high bits while the compare may effectively mask low bits. + Value *X; + const APInt *C; + if (!match(Op0, m_OneUse(m_Trunc(m_Value(X)))) || !match(Op1, m_APInt(C))) + return nullptr; + + unsigned SrcBits = X->getType()->getScalarSizeInBits(); + if (Pred == ICmpInst::ICMP_ULT) { + if (C->isPowerOf2()) { + // If C is a power-of-2 (one set bit): + // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?) + Constant *MaskC = ConstantInt::get(X->getType(), (-*C).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + Constant *Zero = ConstantInt::getNullValue(X->getType()); + return new ICmpInst(ICmpInst::ICMP_EQ, And, Zero); + } + // If C is a negative power-of-2 (high-bit mask): + // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?) + if (C->isNegatedPowerOf2()) { + Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC); + } + } + + if (Pred == ICmpInst::ICMP_UGT) { + // If C is a low-bit-mask (C+1 is a power-of-2): + // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?) + if (C->isMask()) { + Constant *MaskC = ConstantInt::get(X->getType(), (~*C).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + Constant *Zero = ConstantInt::getNullValue(X->getType()); + return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); + } + // If C is not-of-power-of-2 (one clear bit): + // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?) + if ((~*C).isPowerOf2()) { + Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits)); + Value *And = Builder.CreateAnd(X, MaskC); + return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC); + } + } + return nullptr; } @@ -4620,6 +4807,9 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); } + if (Instruction *R = foldICmpWithTrunc(ICmp, Builder)) + return R; + return foldICmpWithZextOrSext(ICmp, Builder); } @@ -4943,7 +5133,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { const APInt *RHS; if (!match(I.getOperand(1), m_APInt(RHS))) - return APInt::getAllOnesValue(BitWidth); + return APInt::getAllOnes(BitWidth); // If this is a normal comparison, it demands all bits. If it is a sign bit // comparison, it only demands the sign bit. 
@@ -4965,7 +5155,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
     return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros());
 
   default:
-    return APInt::getAllOnesValue(BitWidth);
+    return APInt::getAllOnes(BitWidth);
   }
 }
 
@@ -5129,8 +5319,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
                            Op0Known, 0))
     return &I;
 
-  if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
-                           Op1Known, 0))
+  if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
     return &I;
 
   // Given the known and unknown bits, compute a range that the LHS could be
@@ -5280,7 +5469,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
 
     // Check if the LHS is 8 >>u x and the result is a power of 2 like 1.
     const APInt *CI;
-    if (Op0KnownZeroInverted.isOneValue() &&
+    if (Op0KnownZeroInverted.isOne() &&
         match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) {
       // ((8 >>u X) & 1) == 0 -> X != 3
       // ((8 >>u X) & 1) != 0 -> X == 3
@@ -5618,7 +5807,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
   if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) &&
       V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) {
     Value *NewCmp = Builder.CreateCmp(Pred, V1, V2);
-    return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
+    return new ShuffleVectorInst(NewCmp, M);
   }
 
   // Try to canonicalize compare with splatted operand and splat constant.
@@ -5639,8 +5828,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
                                ScalarC);
     SmallVector<int, 8> NewM(M.size(), MaskSplatIndex);
     Value *NewCmp = Builder.CreateCmp(Pred, V1, C);
-    return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()),
-                                 NewM);
+    return new ShuffleVectorInst(NewCmp, NewM);
   }
 
   return nullptr;
@@ -5676,6 +5864,23 @@ static Instruction *foldICmpOfUAddOv(ICmpInst &I) {
   return ExtractValueInst::Create(UAddOv, 1);
 }
 
+static Instruction *foldICmpInvariantGroup(ICmpInst &I) {
+  if (!I.getOperand(0)->getType()->isPointerTy() ||
+      NullPointerIsDefined(
+          I.getParent()->getParent(),
+          I.getOperand(0)->getType()->getPointerAddressSpace())) {
+    return nullptr;
+  }
+  Instruction *Op;
+  if (match(I.getOperand(0), m_Instruction(Op)) &&
+      match(I.getOperand(1), m_Zero()) &&
+      Op->isLaunderOrStripInvariantGroup()) {
+    return ICmpInst::Create(Instruction::ICmp, I.getPredicate(),
+                            Op->getOperand(0), I.getOperand(1));
+  }
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   bool Changed = false;
   const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -5729,9 +5934,6 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpWithDominatingICmp(I))
     return Res;
 
-  if (Instruction *Res = foldICmpBinOp(I, Q))
-    return Res;
-
   if (Instruction *Res = foldICmpUsingKnownBits(I))
     return Res;
 
@@ -5777,6 +5979,15 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     }
   }
 
+  // The folds in here may rely on wrapping flags and special constants, so
+  // they can break up min/max idioms in some cases but not seemingly similar
+  // patterns.
+  // FIXME: It may be possible to enhance select folding to make this
+  // unnecessary. It may also be moot if we canonicalize to min/max
+  // intrinsics.
+  if (Instruction *Res = foldICmpBinOp(I, Q))
+    return Res;
+
   if (Instruction *Res = foldICmpInstWithConstant(I))
     return Res;
 
@@ -5788,13 +5999,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpInstWithConstantNotInt(I))
     return Res;
 
-  // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now.
-  if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0))
+  // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'.
+  if (auto *GEP = dyn_cast<GEPOperator>(Op0))
     if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I))
       return NI;
-  if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1))
-    if (Instruction *NI = foldGEPICmp(GEP, Op0,
-                                      ICmpInst::getSwappedPredicate(I.getPredicate()), I))
+  if (auto *GEP = dyn_cast<GEPOperator>(Op1))
+    if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I))
      return NI;
 
   // Try to optimize equality comparisons against alloca-based pointers.
@@ -5808,7 +6018,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
       return New;
   }
 
-  if (Instruction *Res = foldICmpBitCast(I, Builder))
+  if (Instruction *Res = foldICmpBitCast(I))
     return Res;
 
   // TODO: Hoist this above the min/max bailout.
@@ -5910,6 +6120,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldVectorCmp(I, Builder))
     return Res;
 
+  if (Instruction *Res = foldICmpInvariantGroup(I))
+    return Res;
+
   return Changed ? &I : nullptr;
 }
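Editor's note: the new foldICmpWithTrunc canonicalization added by this diff rewrites an unsigned compare of a truncated value as a mask test on the wide source value. Those equivalences are easy to sanity-check exhaustively for a small type. The following standalone C++ sketch is illustrative only and not part of the patch; the sample constants 0x10, 0xF0, 0x0F, and 0xEF were chosen (an assumption, not taken from the diff) to exercise the power-of-2, negated-power-of-2, low-bit-mask, and one-clear-bit branches of the i16 -> i8 case:

// Exhaustive check of the four (trunc X) u</u> C mask rewrites for i16 -> i8.
// Unsigned wrapping arithmetic models LLVM's integer semantics.
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t i = 0; i <= 0xFFFF; ++i) {
    uint16_t x = static_cast<uint16_t>(i);
    uint8_t t = static_cast<uint8_t>(x); // trunc i16 X to i8

    // (trunc X) u< 0x10 --> (X & zext(-0x10)) == 0       (C is a power of 2)
    assert((t < 0x10) == ((x & 0x00F0) == 0));
    // (trunc X) u< 0xF0 --> (X & zext(0xF0)) != zext(0xF0) (C is -(2^k))
    assert((t < 0xF0) == ((x & 0x00F0) != 0x00F0));
    // (trunc X) u> 0x0F --> (X & zext(~0x0F)) != 0        (C is a low-bit mask)
    assert((t > 0x0F) == ((x & 0x00F0) != 0));
    // (trunc X) u> 0xEF --> (X & zext(0xEF + 1)) == zext(0xEF + 1) (~C is 2^k)
    assert((t > 0xEF) == ((x & 0x00F0) == 0x00F0));
  }
  return 0;
}

All four rewrites happen to reduce to the same 0x00F0 mask here because the chosen constants pivot on bit 4; the payoff of the transform is that the compare now happens in the wider source type, eliminating the trunc.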

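The sub-to-add canonicalization added in foldICmpSubConstant, (C2 - Y) pred C --> (Y + ~C2) swap(pred) ~C, rests on two identities: Y + ~C2 == ~(C2 - Y) in wrapping arithmetic, and bitwise-not reversing the unsigned order, so a u> b iff ~a u< ~b. A minimal exhaustive check over i8, again an illustrative sketch with arbitrarily chosen constants rather than code from the patch:

// Exhaustive i8 check that (C2 - Y) u> C  <=>  (Y + ~C2) u< ~C.
#include <cassert>
#include <cstdint>

int main() {
  const uint8_t C2 = 10, C = 3; // arbitrary sample constants
  for (uint32_t i = 0; i <= 0xFF; ++i) {
    uint8_t y = static_cast<uint8_t>(i);
    uint8_t sub = static_cast<uint8_t>(C2 - y);                        // original sub form
    uint8_t add = static_cast<uint8_t>(y + static_cast<uint8_t>(~C2)); // canonical add form
    assert(add == static_cast<uint8_t>(~sub));             // Y + ~C2 == ~(C2 - Y)
    assert((sub > C) == (add < static_cast<uint8_t>(~C))); // ugt becomes ult
  }
  return 0;
}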