Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp  498
1 file changed, 288 insertions(+), 210 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f38dc436722dc..f1233b62445d0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -897,7 +897,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
// For vectors, we apply the same reasoning on a per-lane basis.
auto *Base = GEPLHS->getPointerOperand();
if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
- int NumElts = GEPLHS->getType()->getVectorNumElements();
+ int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements();
Base = Builder.CreateVectorSplat(NumElts, Base);
}
return new ICmpInst(Cond, Base,
@@ -1330,6 +1330,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
// The inner add was the result of the narrow add, zero extended to the
// wider type. Replace it with the result computed by the intrinsic.
IC.replaceInstUsesWith(*OrigAdd, ZExt);
+ IC.eraseInstFromFunction(*OrigAdd);
// The original icmp gets replaced with the overflow value.
return ExtractValueInst::Create(Call, 1, "sadd.overflow");
@@ -1451,6 +1452,27 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this))
return Res;
+ // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...).
+ Constant *C = dyn_cast<Constant>(Op1);
+ if (!C)
+ return nullptr;
+
+ if (auto *Phi = dyn_cast<PHINode>(Op0))
+ if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+ Type *Ty = Cmp.getType();
+ Builder.SetInsertPoint(Phi);
+ PHINode *NewPhi =
+ Builder.CreatePHI(Ty, Phi->getNumOperands());
+ for (BasicBlock *Predecessor : predecessors(Phi->getParent())) {
+ auto *Input =
+ cast<Constant>(Phi->getIncomingValueForBlock(Predecessor));
+ auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C);
+ NewPhi->addIncoming(BoolInput, Predecessor);
+ }
+ NewPhi->takeName(&Cmp);
+ return replaceInstUsesWith(Cmp, NewPhi);
+ }
+
return nullptr;
}
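
A minimal IR sketch of the new phi fold (the constants 42, 7, and 10 are illustrative; the fold fires only when every incoming value is a constant):

  ; before
  %p = phi i32 [ 42, %bb.t ], [ 7, %bb.f ]
  %r = icmp ult i32 %p, 10
  ; after: the compare is evaluated per incoming constant
  %r = phi i1 [ false, %bb.t ], [ true, %bb.f ]
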
@@ -1575,11 +1597,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
// If the sign bit of the XorCst is not set, there is no change to
// the operation, just stop using the Xor.
- if (!XorC->isNegative()) {
- Cmp.setOperand(0, X);
- Worklist.Add(Xor);
- return &Cmp;
- }
+ if (!XorC->isNegative())
+ return replaceOperand(Cmp, 0, X);
// Emit the opposite comparison.
if (TrueIfSigned)
@@ -1645,51 +1664,53 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
bool IsShl = ShiftOpcode == Instruction::Shl;
const APInt *C3;
if (match(Shift->getOperand(1), m_APInt(C3))) {
- bool CanFold = false;
+ APInt NewAndCst, NewCmpCst;
+ bool AnyCmpCstBitsShiftedOut;
if (ShiftOpcode == Instruction::Shl) {
// For a left shift, we can fold if the comparison is not signed. We can
// also fold a signed comparison if the mask value and comparison value
// are not negative. These constraints may not be obvious, but we can
// prove that they are correct using an SMT solver.
- if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative()))
- CanFold = true;
- } else {
- bool IsAshr = ShiftOpcode == Instruction::AShr;
+ if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative()))
+ return nullptr;
+
+ NewCmpCst = C1.lshr(*C3);
+ NewAndCst = C2.lshr(*C3);
+ AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1;
+ } else if (ShiftOpcode == Instruction::LShr) {
// For a logical right shift, we can fold if the comparison is not signed.
// We can also fold a signed comparison if the shifted mask value and the
// shifted comparison value are not negative. These constraints may not be
// obvious, but we can prove that they are correct using an SMT solver.
- // For an arithmetic shift right we can do the same, if we ensure
- // the And doesn't use any bits being shifted in. Normally these would
- // be turned into lshr by SimplifyDemandedBits, but not if there is an
- // additional user.
- if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) {
- if (!Cmp.isSigned() ||
- (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative()))
- CanFold = true;
- }
+ NewCmpCst = C1.shl(*C3);
+ NewAndCst = C2.shl(*C3);
+ AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1;
+ if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative()))
+ return nullptr;
+ } else {
+ // For an arithmetic shift, check that both constants don't use (in a
+ // signed sense) the top bits being shifted out.
+ assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode");
+ NewCmpCst = C1.shl(*C3);
+ NewAndCst = C2.shl(*C3);
+ AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1;
+ if (NewAndCst.ashr(*C3) != C2)
+ return nullptr;
}
- if (CanFold) {
- APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3);
- APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3);
- // Check to see if we are shifting out any of the bits being compared.
- if (SameAsC1 != C1) {
- // If we shifted bits out, the fold is not going to work out. As a
- // special case, check to see if this means that the result is always
- // true or false now.
- if (Cmp.getPredicate() == ICmpInst::ICMP_EQ)
- return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType()));
- if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
- return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
- } else {
- Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst));
- APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3);
- And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst));
- And->setOperand(0, Shift->getOperand(0));
- Worklist.Add(Shift); // Shift is dead.
- return &Cmp;
- }
+ if (AnyCmpCstBitsShiftedOut) {
+ // If we shifted bits out, the fold is not going to work out. As a
+ // special case, check to see if this means that the result is always
+ // true or false now.
+ if (Cmp.getPredicate() == ICmpInst::ICMP_EQ)
+ return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType()));
+ if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+ return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
+ } else {
+ Value *NewAnd = Builder.CreateAnd(
+ Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst));
+ return new ICmpInst(Cmp.getPredicate(),
+ NewAnd, ConstantInt::get(And->getType(), NewCmpCst));
}
}
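
For the shl arm, a sketch with illustrative constants (assuming AnyCmpCstBitsShiftedOut is false, i.e. no compared bits are lost):

  ; before: C2 = 0xF8, C1 = 0x28, C3 = 3
  %s = shl i8 %x, 3
  %a = and i8 %s, -8        ; 0xF8
  %r = icmp eq i8 %a, 40    ; 0x28
  ; after: both constants shifted right by 3
  %a2 = and i8 %x, 31       ; 0xF8 lshr 3
  %r = icmp eq i8 %a2, 5    ; 0x28 lshr 3
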
@@ -1705,8 +1726,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
// Compute X & (C2 << Y).
Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift);
- Cmp.setOperand(0, NewAnd);
- return &Cmp;
+ return replaceOperand(Cmp, 0, NewAnd);
}
return nullptr;
@@ -1812,8 +1832,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
}
if (NewOr) {
Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName());
- Cmp.setOperand(0, NewAnd);
- return &Cmp;
+ return replaceOperand(Cmp, 0, NewAnd);
}
}
}
@@ -1863,8 +1882,8 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
int32_t ExactLogBase2 = C2->exactLogBase2();
if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
- if (And->getType()->isVectorTy())
- NTy = VectorType::get(NTy, And->getType()->getVectorNumElements());
+ if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
+ NTy = FixedVectorType::get(NTy, AndVTy->getNumElements());
Value *Trunc = Builder.CreateTrunc(X, NTy);
auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
: CmpInst::ICMP_SLT;
@@ -1888,20 +1907,24 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
}
Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1);
- if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) {
- // X | C == C --> X <=u C
- // X | C != C --> X >u C
- // iff C+1 is a power of 2 (C is a bitmask of the low bits)
- if ((C + 1).isPowerOf2()) {
+ const APInt *MaskC;
+ if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) {
+ if (*MaskC == C && (C + 1).isPowerOf2()) {
+ // X | C == C --> X <=u C
+ // X | C != C --> X >u C
+ // iff C+1 is a power of 2 (C is a bitmask of the low bits)
Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
return new ICmpInst(Pred, OrOp0, OrOp1);
}
- // More general: are all bits outside of a mask constant set or not set?
- // X | C == C --> (X & ~C) == 0
- // X | C != C --> (X & ~C) != 0
+
+ // More general: canonicalize 'equality with set bits mask' to
+ // 'equality with clear bits mask'.
+ // (X | MaskC) == C --> (X & ~MaskC) == C ^ MaskC
+ // (X | MaskC) != C --> (X & ~MaskC) != C ^ MaskC
if (Or->hasOneUse()) {
- Value *A = Builder.CreateAnd(OrOp0, ~C);
- return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType()));
+ Value *And = Builder.CreateAnd(OrOp0, ~(*MaskC));
+ Constant *NewC = ConstantInt::get(Or->getType(), C ^ (*MaskC));
+ return new ICmpInst(Pred, And, NewC);
}
}
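
Concretely, with illustrative constants:

  ; before
  %o = or i8 %x, 6
  %r = icmp eq i8 %o, 14
  ; after: (X | 6) == 14 --> (X & ~6) == (14 ^ 6)
  %a = and i8 %x, -7
  %r = icmp eq i8 %a, 8
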
@@ -2149,8 +2172,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
- if (ShType->isVectorTy())
- TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
+ if (auto *ShVTy = dyn_cast<VectorType>(ShType))
+ TruncTy = FixedVectorType::get(TruncTy, ShVTy->getNumElements());
Constant *NewC =
ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
@@ -2763,6 +2786,37 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
if (match(BCSrcOp, m_UIToFP(m_Value(X))))
if (Cmp.isEquality() && match(Op1, m_Zero()))
return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+ // If this is a sign-bit test of a bitcast of a casted FP value, eliminate
+ // the FP extend/truncate because that cast does not change the sign-bit.
+ // This is true for all standard IEEE-754 types and the X86 80-bit type.
+ // The sign-bit is always the most significant bit in those types.
+ const APInt *C;
+ bool TrueIfSigned;
+ if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
+ isSignBitCheck(Pred, *C, TrueIfSigned)) {
+ if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
+ match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
+ // (bitcast (fpext/fptrunc X) to iX) < 0 --> (bitcast X to iY) < 0
+ // (bitcast (fpext/fptrunc X) to iX) > -1 --> (bitcast X to iY) > -1
+ Type *XType = X->getType();
+
+ // We can't currently handle Power style floating point operations here.
+ if (!(XType->isPPC_FP128Ty() || BCSrcOp->getType()->isPPC_FP128Ty())) {
+
+ Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
+ if (auto *XVTy = dyn_cast<VectorType>(XType))
+ NewType = FixedVectorType::get(NewType, XVTy->getNumElements());
+ Value *NewBitcast = Builder.CreateBitCast(X, NewType);
+ if (TrueIfSigned)
+ return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
+ ConstantInt::getNullValue(NewType));
+ else
+ return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast,
+ ConstantInt::getAllOnesValue(NewType));
+ }
+ }
+ }
}
// Test to see if the operands of the icmp are casted versions of other
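
A sketch of the new sign-bit fold (an fpext or fptrunc never changes the sign bit, so the compare can look through it):

  ; before
  %e = fpext float %x to double
  %b = bitcast double %e to i64
  %r = icmp slt i64 %b, 0
  ; after
  %b2 = bitcast float %x to i32
  %r = icmp slt i32 %b2, 0
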
@@ -2792,11 +2846,10 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
return nullptr;
Value *Vec;
- Constant *Mask;
- if (match(BCSrcOp,
- m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) {
+ ArrayRef<int> Mask;
+ if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
// Check whether every element of Mask is the same constant
- if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) {
+ if (is_splat(Mask)) {
auto *VecTy = cast<VectorType>(BCSrcOp->getType());
auto *EltTy = cast<IntegerType>(VecTy->getElementType());
if (C->isSplat(EltTy->getBitWidth())) {
@@ -2805,6 +2858,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
// then:
// => %E = extractelement <N x iK> %vec, i32 Elem
// icmp <pred> iK %SplatVal, <pattern>
+ Value *Elem = Builder.getInt32(Mask[0]);
Value *Extract = Builder.CreateExtractElement(Vec, Elem);
Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
return new ICmpInst(Pred, Extract, NewC);
@@ -2928,12 +2982,9 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
break;
case Instruction::Add: {
// Replace ((add A, B) != C) with (A != C-B) if B & C are constants.
- const APInt *BOC;
- if (match(BOp1, m_APInt(BOC))) {
- if (BO->hasOneUse()) {
- Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1));
- return new ICmpInst(Pred, BOp0, SubC);
- }
+ if (Constant *BOC = dyn_cast<Constant>(BOp1)) {
+ if (BO->hasOneUse())
+ return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC));
} else if (C.isNullValue()) {
// Replace ((add A, B) != 0) with (A != -B) if A or B is
// efficiently invertible, or if the add has just this one use.
@@ -2963,11 +3014,11 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
break;
case Instruction::Sub:
if (BO->hasOneUse()) {
- const APInt *BOC;
- if (match(BOp0, m_APInt(BOC))) {
+ // Only check for constant LHS here, as constant RHS will be canonicalized
+ // to add and use the fold above.
+ if (Constant *BOC = dyn_cast<Constant>(BOp0)) {
// Replace ((sub BOC, B) != C) with (B != BOC-C).
- Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS);
- return new ICmpInst(Pred, BOp1, SubC);
+ return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS));
} else if (C.isNullValue()) {
// Replace ((sub A, B) != 0) with (A != B).
return new ICmpInst(Pred, BOp0, BOp1);
@@ -3028,20 +3079,16 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
unsigned BitWidth = C.getBitWidth();
switch (II->getIntrinsicID()) {
case Intrinsic::bswap:
- Worklist.Add(II);
- Cmp.setOperand(0, II->getArgOperand(0));
- Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap()));
- return &Cmp;
+ // bswap(A) == C -> A == bswap(C)
+ return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+ ConstantInt::get(Ty, C.byteSwap()));
case Intrinsic::ctlz:
case Intrinsic::cttz: {
// ctz(A) == bitwidth(A) -> A == 0 and likewise for !=
- if (C == BitWidth) {
- Worklist.Add(II);
- Cmp.setOperand(0, II->getArgOperand(0));
- Cmp.setOperand(1, ConstantInt::getNullValue(Ty));
- return &Cmp;
- }
+ if (C == BitWidth)
+ return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+ ConstantInt::getNullValue(Ty));
// ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set
// and Mask1 has bits 0..C+1 set. Similar for ctlz, but for high bits.
@@ -3054,10 +3101,9 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
APInt Mask2 = IsTrailing
? APInt::getOneBitSet(BitWidth, Num)
: APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
- Cmp.setOperand(0, Builder.CreateAnd(II->getArgOperand(0), Mask1));
- Cmp.setOperand(1, ConstantInt::get(Ty, Mask2));
- Worklist.Add(II);
- return &Cmp;
+ return new ICmpInst(Cmp.getPredicate(),
+ Builder.CreateAnd(II->getArgOperand(0), Mask1),
+ ConstantInt::get(Ty, Mask2));
}
break;
}
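
The rewritten intrinsic cases build a fresh compare instead of mutating the old one in place; for the bswap case, a sketch:

  %b = call i32 @llvm.bswap.i32(i32 %x)
  %r = icmp eq i32 %b, 1
  ; becomes
  %r = icmp eq i32 %x, 16777216   ; bswap(1) = 0x01000000
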
@@ -3066,14 +3112,10 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
// popcount(A) == 0 -> A == 0 and likewise for !=
// popcount(A) == bitwidth(A) -> A == -1 and likewise for !=
bool IsZero = C.isNullValue();
- if (IsZero || C == BitWidth) {
- Worklist.Add(II);
- Cmp.setOperand(0, II->getArgOperand(0));
- auto *NewOp =
- IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty);
- Cmp.setOperand(1, NewOp);
- return &Cmp;
- }
+ if (IsZero || C == BitWidth)
+ return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+ IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty));
+
break;
}
@@ -3081,9 +3123,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
// uadd.sat(a, b) == 0 -> (a | b) == 0
if (C.isNullValue()) {
Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
- return replaceInstUsesWith(Cmp, Builder.CreateICmp(
- Cmp.getPredicate(), Or, Constant::getNullValue(Ty)));
-
+ return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty));
}
break;
}
@@ -3093,8 +3133,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
if (C.isNullValue()) {
ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ
? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT;
- return ICmpInst::Create(Instruction::ICmp, NewPred,
- II->getArgOperand(0), II->getArgOperand(1));
+ return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
}
break;
}
@@ -3300,30 +3339,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
// x & (-1 >> y) != x -> x u> (-1 >> y)
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
- case ICmpInst::Predicate::ICMP_UGT:
+ case ICmpInst::Predicate::ICMP_ULT:
+ // x & (-1 >> y) u< x -> x u> (-1 >> y)
// x u> x & (-1 >> y) -> x u> (-1 >> y)
- assert(X == I.getOperand(0) && "instsimplify took care of commut. variant");
DstPred = ICmpInst::Predicate::ICMP_UGT;
break;
case ICmpInst::Predicate::ICMP_UGE:
// x & (-1 >> y) u>= x -> x u<= (-1 >> y)
- assert(X == I.getOperand(1) && "instsimplify took care of commut. variant");
- DstPred = ICmpInst::Predicate::ICMP_ULE;
- break;
- case ICmpInst::Predicate::ICMP_ULT:
- // x & (-1 >> y) u< x -> x u> (-1 >> y)
- assert(X == I.getOperand(1) && "instsimplify took care of commut. variant");
- DstPred = ICmpInst::Predicate::ICMP_UGT;
- break;
- case ICmpInst::Predicate::ICMP_ULE:
// x u<= x & (-1 >> y) -> x u<= (-1 >> y)
- assert(X == I.getOperand(0) && "instsimplify took care of commut. variant");
DstPred = ICmpInst::Predicate::ICMP_ULE;
break;
- case ICmpInst::Predicate::ICMP_SGT:
+ case ICmpInst::Predicate::ICMP_SLT:
+ // x & (-1 >> y) s< x -> x s> (-1 >> y)
// x s> x & (-1 >> y) -> x s> (-1 >> y)
- if (X != I.getOperand(0)) // X must be on LHS of comparison!
- return nullptr; // Ignore the other case.
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
return nullptr;
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
@@ -3332,33 +3360,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
break;
case ICmpInst::Predicate::ICMP_SGE:
// x & (-1 >> y) s>= x -> x s<= (-1 >> y)
- if (X != I.getOperand(1)) // X must be on RHS of comparison!
- return nullptr; // Ignore the other case.
+ // x s<= x & (-1 >> y) -> x s<= (-1 >> y)
if (!match(M, m_Constant())) // Can not do this fold with non-constant.
return nullptr;
if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
return nullptr;
DstPred = ICmpInst::Predicate::ICMP_SLE;
break;
- case ICmpInst::Predicate::ICMP_SLT:
- // x & (-1 >> y) s< x -> x s> (-1 >> y)
- if (X != I.getOperand(1)) // X must be on RHS of comparison!
- return nullptr; // Ignore the other case.
- if (!match(M, m_Constant())) // Can not do this fold with non-constant.
- return nullptr;
- if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
- return nullptr;
- DstPred = ICmpInst::Predicate::ICMP_SGT;
- break;
+ case ICmpInst::Predicate::ICMP_SGT:
case ICmpInst::Predicate::ICMP_SLE:
- // x s<= x & (-1 >> y) -> x s<= (-1 >> y)
- if (X != I.getOperand(0)) // X must be on LHS of comparison!
- return nullptr; // Ignore the other case.
- if (!match(M, m_Constant())) // Can not do this fold with non-constant.
- return nullptr;
- if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
- return nullptr;
- DstPred = ICmpInst::Predicate::ICMP_SLE;
+ return nullptr;
+ case ICmpInst::Predicate::ICMP_UGT:
+ case ICmpInst::Predicate::ICMP_ULE:
+ llvm_unreachable("Instsimplify took care of commut. variant");
break;
default:
llvm_unreachable("All possible folds are handled.");
@@ -3370,8 +3384,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
Type *OpTy = M->getType();
auto *VecC = dyn_cast<Constant>(M);
if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
+ auto *OpVTy = cast<VectorType>(OpTy);
Constant *SafeReplacementConstant = nullptr;
- for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) {
+ for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
SafeReplacementConstant = VecC->getAggregateElement(i);
break;
@@ -3494,7 +3509,8 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
Instruction *NarrowestShift = XShift;
Type *WidestTy = WidestShift->getType();
- assert(NarrowestShift->getType() == I.getOperand(0)->getType() &&
+ Type *NarrowestTy = NarrowestShift->getType();
+ assert(NarrowestTy == I.getOperand(0)->getType() &&
"We did not look past any shifts while matching XShift though.");
bool HadTrunc = WidestTy != I.getOperand(0)->getType();
@@ -3533,6 +3549,23 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
if (XShAmt->getType() != YShAmt->getType())
return nullptr;
+ // As input, we have the following pattern:
+ // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
+ // We want to rewrite that as:
+ // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x)
+ // While we know that originally (Q+K) would not overflow
+ // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+ // shift amounts, so it may now overflow in a smaller bitwidth.
+ // To guard against that, check that the total maximal shift amount
+ // is still representable in that smaller bit width.
+ unsigned MaximalPossibleTotalShiftAmount =
+ (WidestTy->getScalarSizeInBits() - 1) +
+ (NarrowestTy->getScalarSizeInBits() - 1);
+ APInt MaximalRepresentableShiftAmount =
+ APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits());
+ if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+ return nullptr;
+
// Can we fold (XShAmt+YShAmt) ?
auto *NewShAmt = dyn_cast_or_null<Constant>(
SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
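
As a worked instance: with a widest type of i64 and a narrowest type of i32, the maximal total shift amount is 63 + 31 = 94, which an i8 shift-amount type (maximum 255) can represent, so the fold proceeds; a hypothetical i4 shift-amount type (maximum 15) could not, and the fold would bail out.
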
@@ -3627,9 +3660,6 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
m_Value(Y)))) {
Mul = nullptr;
- // Canonicalize as-if y was on RHS.
- if (I.getOperand(1) != Y)
- Pred = I.getSwappedPredicate();
// Are we checking that overflow does not happen, or does happen?
switch (Pred) {
@@ -3674,6 +3704,11 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
if (NeedNegation) // This technically increases instruction count.
Res = Builder.CreateNot(Res, "umul.not.ov");
+ // If we replaced the mul, erase it. Do this after all uses of Builder,
+ // as the mul is used as insertion point.
+ if (MulHadOtherUses)
+ eraseInstFromFunction(*Mul);
+
return Res;
}
@@ -4202,9 +4237,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
if (X) { // Build (X^Y) & Z
Op1 = Builder.CreateXor(X, Y);
Op1 = Builder.CreateAnd(Op1, Z);
- I.setOperand(0, Op1);
- I.setOperand(1, Constant::getNullValue(Op1->getType()));
- return &I;
+ return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType()));
}
}
@@ -4613,17 +4646,6 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
case ICmpInst::ICMP_NE:
// Recognize pattern:
// mulval = mul(zext A, zext B)
- // cmp eq/neq mulval, zext trunc mulval
- if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
- if (Zext->hasOneUse()) {
- Value *ZextArg = Zext->getOperand(0);
- if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
- if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
- break; //Recognized
- }
-
- // Recognize pattern:
- // mulval = mul(zext A, zext B)
// cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
ConstantInt *CI;
Value *ValToMask;
@@ -4701,7 +4723,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
Function *F = Intrinsic::getDeclaration(
I.getModule(), Intrinsic::umul_with_overflow, MulType);
CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
- IC.Worklist.Add(MulInstr);
+ IC.Worklist.push(MulInstr);
// If there are uses of mul result other than the comparison, we know that
// they are truncation or binary AND. Change them to use result of
@@ -4723,18 +4745,16 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
APInt ShortMask = CI->getValue().trunc(MulWidth);
Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
- Instruction *Zext =
- cast<Instruction>(Builder.CreateZExt(ShortAnd, BO->getType()));
- IC.Worklist.Add(Zext);
+ Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
IC.replaceInstUsesWith(*BO, Zext);
} else {
llvm_unreachable("Unexpected Binary operation");
}
- IC.Worklist.Add(cast<Instruction>(U));
+ IC.Worklist.push(cast<Instruction>(U));
}
}
if (isa<Instruction>(OtherVal))
- IC.Worklist.Add(cast<Instruction>(OtherVal));
+ IC.Worklist.push(cast<Instruction>(OtherVal));
// The original icmp gets replaced with the overflow value, maybe inverted
// depending on predicate.
@@ -5189,8 +5209,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
return llvm::None;
- } else if (Type->isVectorTy()) {
- unsigned NumElts = Type->getVectorNumElements();
+ } else if (auto *VTy = dyn_cast<VectorType>(Type)) {
+ unsigned NumElts = VTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
@@ -5252,6 +5272,47 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second);
}
+/// If we have a comparison with a non-canonical predicate and we can update
+/// all of its users, invert the predicate and adjust the users accordingly.
+static CmpInst *canonicalizeICmpPredicate(CmpInst &I) {
+ // Is the predicate already canonical?
+ CmpInst::Predicate Pred = I.getPredicate();
+ if (isCanonicalPredicate(Pred))
+ return nullptr;
+
+ // Can all users be adjusted to predicate inversion?
+ if (!canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
+ return nullptr;
+
+ // Ok, we can canonicalize comparison!
+ // Let's first invert the comparison's predicate.
+ I.setPredicate(CmpInst::getInversePredicate(Pred));
+ I.setName(I.getName() + ".not");
+
+ // And now let's adjust every user.
+ for (User *U : I.users()) {
+ switch (cast<Instruction>(U)->getOpcode()) {
+ case Instruction::Select: {
+ auto *SI = cast<SelectInst>(U);
+ SI->swapValues();
+ SI->swapProfMetadata();
+ break;
+ }
+ case Instruction::Br:
+ cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too
+ break;
+ case Instruction::Xor:
+ U->replaceAllUsesWith(&I);
+ break;
+ default:
+ llvm_unreachable("Got unexpected user - out of sync with "
+ "canFreelyInvertAllUsersOf() ?");
+ }
+ }
+
+ return &I;
+}
+
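
A sketch of the function above in action (ule is one of the non-canonical predicates; a select user absorbs the inversion by swapping its arms):

  ; before
  %c = icmp ule i8 %x, %y
  %s = select i1 %c, i8 %a, i8 %b
  ; after
  %c.not = icmp ugt i8 %x, %y
  %s = select i1 %c.not, i8 %b, i8 %a
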
/// Integer compare with boolean values can always be turned into bitwise ops.
static Instruction *canonicalizeICmpBool(ICmpInst &I,
InstCombiner::BuilderTy &Builder) {
@@ -5338,10 +5399,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
Value *X, *Y;
if (match(&Cmp,
m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
- // We want X to be the icmp's second operand, so swap predicate if it isn't.
- if (Cmp.getOperand(0) == X)
- Pred = Cmp.getSwappedPredicate();
-
switch (Pred) {
case ICmpInst::ICMP_ULE:
NewPred = ICmpInst::ICMP_NE;
@@ -5361,10 +5418,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
// The variant with 'add' is not canonical (the variant with 'not' is);
// we only get it because it has extra uses and can't be canonicalized.
- // We want X to be the icmp's second operand, so swap predicate if it isn't.
- if (Cmp.getOperand(0) == X)
- Pred = Cmp.getSwappedPredicate();
-
switch (Pred) {
case ICmpInst::ICMP_ULT:
NewPred = ICmpInst::ICMP_NE;
@@ -5385,21 +5438,45 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
static Instruction *foldVectorCmp(CmpInst &Cmp,
InstCombiner::BuilderTy &Builder) {
- // If both arguments of the cmp are shuffles that use the same mask and
- // shuffle within a single vector, move the shuffle after the cmp.
+ const CmpInst::Predicate Pred = Cmp.getPredicate();
Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1);
Value *V1, *V2;
- Constant *M;
- if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(M))) &&
- match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(M))) &&
- V1->getType() == V2->getType() &&
- (LHS->hasOneUse() || RHS->hasOneUse())) {
- // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M
- CmpInst::Predicate P = Cmp.getPredicate();
- Value *NewCmp = isa<ICmpInst>(Cmp) ? Builder.CreateICmp(P, V1, V2)
- : Builder.CreateFCmp(P, V1, V2);
+ ArrayRef<int> M;
+ if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M))))
+ return nullptr;
+
+ // If both arguments of the cmp are shuffles that use the same mask and
+ // shuffle within a single vector, move the shuffle after the cmp:
+ // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M
+ Type *V1Ty = V1->getType();
+ if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) &&
+ V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) {
+ Value *NewCmp = Builder.CreateCmp(Pred, V1, V2);
return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
}
+
+ // Try to canonicalize compare with splatted operand and splat constant.
+ // TODO: We could generalize this for more than splats. See/use the code in
+ // InstCombiner::foldVectorBinop().
+ Constant *C;
+ if (!LHS->hasOneUse() || !match(RHS, m_Constant(C)))
+ return nullptr;
+
+ // Length-changing splats are ok, so adjust the constants as needed:
+ // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M
+ Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true);
+ int MaskSplatIndex;
+ if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) {
+ // We allow undefs in matching, but this transform removes those for safety.
+ // Demanded elements analysis should be able to recover some/all of that.
+ C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(),
+ ScalarC);
+ SmallVector<int, 8> NewM(M.size(), MaskSplatIndex);
+ Value *NewCmp = Builder.CreateCmp(Pred, V1, C);
+ return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()),
+ NewM);
+ }
+
return nullptr;
}
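
A sketch of the length-changing splat case (lane index and constants are illustrative):

  ; before: lane 1 of a <2 x i8> splatted to 4 lanes, compared with a splat constant
  %s = shufflevector <2 x i8> %v, <2 x i8> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
  %r = icmp sgt <4 x i8> %s, <i8 42, i8 42, i8 42, i8 42>
  ; after: compare at the narrow width, then splat the i1 result
  %c = icmp sgt <2 x i8> %v, <i8 42, i8 42>
  %r = shufflevector <2 x i1> %c, <2 x i1> undef, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
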
@@ -5474,8 +5551,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = canonicalizeICmpBool(I, Builder))
return Res;
- if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
- return NewICmp;
+ if (Instruction *Res = canonicalizeCmpWithConstant(I))
+ return Res;
+
+ if (Instruction *Res = canonicalizeICmpPredicate(I))
+ return Res;
if (Instruction *Res = foldICmpWithConstant(I))
return Res;
@@ -5565,6 +5645,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
if (Instruction *Res = foldICmpBitCast(I, Builder))
return Res;
+ // TODO: Hoist this above the min/max bailout.
if (Instruction *R = foldICmpWithCastOp(I))
return R;
@@ -5600,9 +5681,13 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
isa<IntegerType>(A->getType())) {
Value *Result;
Constant *Overflow;
- if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B,
- *AddI, Result, Overflow)) {
+ // m_UAddWithOverflow can match patterns that do not include an explicit
+ // "add" instruction, so check the opcode of the matched op.
+ if (AddI->getOpcode() == Instruction::Add &&
+ OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI,
+ Result, Overflow)) {
replaceInstUsesWith(*AddI, Result);
+ eraseInstFromFunction(*AddI);
return replaceInstUsesWith(I, Overflow);
}
}
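
The guarded fold rewrites an explicit add-plus-compare overflow idiom into the intrinsic, roughly:

  ; before
  %a = add i32 %x, %y
  %c = icmp ult i32 %a, %x   ; unsigned-overflow check
  ; after
  %uadd = call { i32, i1 } @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)
  %a2   = extractvalue { i32, i1 } %uadd, 0
  %c    = extractvalue { i32, i1 } %uadd, 1
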
@@ -5689,7 +5774,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
// TODO: Can never be -0.0 and other non-representable values
APFloat RHSRoundInt(RHS);
RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven);
- if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) {
+ if (RHS != RHSRoundInt) {
if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ)
return replaceInstUsesWith(I, Builder.getFalse());
@@ -5777,7 +5862,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
APFloat SMax(RHS.getSemantics());
SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
APFloat::rmNearestTiesToEven);
- if (SMax.compare(RHS) == APFloat::cmpLessThan) { // smax < 13123.0
+ if (SMax < RHS) { // smax < 13123.0
if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT ||
Pred == ICmpInst::ICMP_SLE)
return replaceInstUsesWith(I, Builder.getTrue());
@@ -5789,7 +5874,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
APFloat UMax(RHS.getSemantics());
UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
APFloat::rmNearestTiesToEven);
- if (UMax.compare(RHS) == APFloat::cmpLessThan) { // umax < 13123.0
+ if (UMax < RHS) { // umax < 13123.0
if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT ||
Pred == ICmpInst::ICMP_ULE)
return replaceInstUsesWith(I, Builder.getTrue());
@@ -5802,7 +5887,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
APFloat SMin(RHS.getSemantics());
SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true,
APFloat::rmNearestTiesToEven);
- if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0
+ if (SMin > RHS) { // smin > 12312.0
if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT ||
Pred == ICmpInst::ICMP_SGE)
return replaceInstUsesWith(I, Builder.getTrue());
@@ -5810,10 +5895,10 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
}
} else {
// See if the RHS value is < UnsignedMin.
- APFloat SMin(RHS.getSemantics());
- SMin.convertFromAPInt(APInt::getMinValue(IntWidth), true,
+ APFloat UMin(RHS.getSemantics());
+ UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false,
APFloat::rmNearestTiesToEven);
- if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0
+ if (UMin > RHS) { // umin > 12312.0
if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
Pred == ICmpInst::ICMP_UGE)
return replaceInstUsesWith(I, Builder.getTrue());
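
These range checks let a compare of an integer-to-FP cast against an out-of-range constant fold to a known answer; for instance (illustrative types and constant):

  %f = sitofp i8 %x to float
  %r = fcmp olt float %f, 13123.0   ; smax(i8) = 127 < 13123, always true
  ; folds to the constant true
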
@@ -5949,16 +6034,15 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
}
/// Optimize fabs(X) compared with zero.
-static Instruction *foldFabsWithFcmpZero(FCmpInst &I) {
+static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) {
Value *X;
if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) ||
!match(I.getOperand(1), m_PosZeroFP()))
return nullptr;
- auto replacePredAndOp0 = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
+ auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
I->setPredicate(P);
- I->setOperand(0, X);
- return I;
+ return IC.replaceOperand(*I, 0, X);
};
switch (I.getPredicate()) {
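
For reference, one of the folds this helper performs (a sketch of the OGT case):

  %f = call float @llvm.fabs.f32(float %x)
  %r = fcmp ogt float %f, 0.0
  ; becomes
  %r = fcmp one float %x, 0.0
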
@@ -6058,14 +6142,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
// If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
// then canonicalize the operand to 0.0.
if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
- if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) {
- I.setOperand(0, ConstantFP::getNullValue(OpType));
- return &I;
- }
- if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) {
- I.setOperand(1, ConstantFP::getNullValue(OpType));
- return &I;
- }
+ if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI))
+ return replaceOperand(I, 0, ConstantFP::getNullValue(OpType));
+
+ if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI))
+ return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
}
// fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
@@ -6090,10 +6171,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
// The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
// fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
- if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) {
- I.setOperand(1, ConstantFP::getNullValue(OpType));
- return &I;
- }
+ if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
+ return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
// Handle fcmp with instruction LHS and constant RHS.
Instruction *LHSI;
@@ -6128,7 +6207,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
}
}
- if (Instruction *R = foldFabsWithFcmpZero(I))
+ if (Instruction *R = foldFabsWithFcmpZero(I, *this))
return R;
if (match(Op0, m_FNeg(m_Value(X)))) {
@@ -6159,8 +6238,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
APFloat Fabs = TruncC;
Fabs.clearSign();
if (!Lossy &&
- ((Fabs.compare(APFloat::getSmallestNormalized(FPSem)) !=
- APFloat::cmpLessThan) || Fabs.isZero())) {
+ (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) {
Constant *NewC = ConstantFP::get(X->getType(), TruncC);
return new FCmpInst(Pred, X, NewC, "", &I);
}