Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 498
1 file changed, 288 insertions, 210 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index f38dc436722dc..f1233b62445d0 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -897,7 +897,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     // For vectors, we apply the same reasoning on a per-lane basis.
     auto *Base = GEPLHS->getPointerOperand();
     if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
-      int NumElts = GEPLHS->getType()->getVectorNumElements();
+      int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements();
       Base = Builder.CreateVectorSplat(NumElts, Base);
     }
     return new ICmpInst(Cond, Base,
@@ -1330,6 +1330,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // The inner add was the result of the narrow add, zero extended to the
   // wider type.  Replace it with the result computed by the intrinsic.
   IC.replaceInstUsesWith(*OrigAdd, ZExt);
+  IC.eraseInstFromFunction(*OrigAdd);
 
   // The original icmp gets replaced with the overflow value.
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
@@ -1451,6 +1452,27 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
       if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this))
         return Res;
 
+  // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...).
+  Constant *C = dyn_cast<Constant>(Op1);
+  if (!C)
+    return nullptr;
+
+  if (auto *Phi = dyn_cast<PHINode>(Op0))
+    if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) {
+      Type *Ty = Cmp.getType();
+      Builder.SetInsertPoint(Phi);
+      PHINode *NewPhi =
+          Builder.CreatePHI(Ty, Phi->getNumOperands());
+      for (BasicBlock *Predecessor : predecessors(Phi->getParent())) {
+        auto *Input =
+            cast<Constant>(Phi->getIncomingValueForBlock(Predecessor));
+        auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C);
+        NewPhi->addIncoming(BoolInput, Predecessor);
+      }
+      NewPhi->takeName(&Cmp);
+      return replaceInstUsesWith(Cmp, NewPhi);
+    }
+
   return nullptr;
 }
@@ -1575,11 +1597,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp,
     // If the sign bit of the XorCst is not set, there is no change to
     // the operation, just stop using the Xor.
-    if (!XorC->isNegative()) {
-      Cmp.setOperand(0, X);
-      Worklist.Add(Xor);
-      return &Cmp;
-    }
+    if (!XorC->isNegative())
+      return replaceOperand(Cmp, 0, X);
 
     // Emit the opposite comparison.
     if (TrueIfSigned)
@@ -1645,51 +1664,53 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
   bool IsShl = ShiftOpcode == Instruction::Shl;
   const APInt *C3;
   if (match(Shift->getOperand(1), m_APInt(C3))) {
-    bool CanFold = false;
+    APInt NewAndCst, NewCmpCst;
+    bool AnyCmpCstBitsShiftedOut;
     if (ShiftOpcode == Instruction::Shl) {
       // For a left shift, we can fold if the comparison is not signed. We can
       // also fold a signed comparison if the mask value and comparison value
       // are not negative. These constraints may not be obvious, but we can
      // prove that they are correct using an SMT solver.
-      if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative()))
-        CanFold = true;
-    } else {
-      bool IsAshr = ShiftOpcode == Instruction::AShr;
+      if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative()))
+        return nullptr;
+
+      NewCmpCst = C1.lshr(*C3);
+      NewAndCst = C2.lshr(*C3);
+      AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1;
+    } else if (ShiftOpcode == Instruction::LShr) {
       // For a logical right shift, we can fold if the comparison is not signed.
       // We can also fold a signed comparison if the shifted mask value and the
       // shifted comparison value are not negative. These constraints may not be
       // obvious, but we can prove that they are correct using an SMT solver.
-      // For an arithmetic shift right we can do the same, if we ensure
-      // the And doesn't use any bits being shifted in. Normally these would
-      // be turned into lshr by SimplifyDemandedBits, but not if there is an
-      // additional user.
-      if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) {
-        if (!Cmp.isSigned() ||
-            (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative()))
-          CanFold = true;
-      }
+      NewCmpCst = C1.shl(*C3);
+      NewAndCst = C2.shl(*C3);
+      AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1;
+      if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative()))
+        return nullptr;
+    } else {
+      // For an arithmetic shift, check that both constants don't use (in a
+      // signed sense) the top bits being shifted out.
+      assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode");
+      NewCmpCst = C1.shl(*C3);
+      NewAndCst = C2.shl(*C3);
+      AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1;
+      if (NewAndCst.ashr(*C3) != C2)
+        return nullptr;
     }
 
-    if (CanFold) {
-      APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3);
-      APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3);
-      // Check to see if we are shifting out any of the bits being compared.
-      if (SameAsC1 != C1) {
-        // If we shifted bits out, the fold is not going to work out. As a
-        // special case, check to see if this means that the result is always
-        // true or false now.
-        if (Cmp.getPredicate() == ICmpInst::ICMP_EQ)
-          return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType()));
-        if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
-          return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
-      } else {
-        Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst));
-        APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3);
-        And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst));
-        And->setOperand(0, Shift->getOperand(0));
-        Worklist.Add(Shift); // Shift is dead.
-        return &Cmp;
-      }
+    if (AnyCmpCstBitsShiftedOut) {
+      // If we shifted bits out, the fold is not going to work out. As a
+      // special case, check to see if this means that the result is always
+      // true or false now.
+      if (Cmp.getPredicate() == ICmpInst::ICMP_EQ)
+        return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType()));
+      if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+        return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType()));
+    } else {
+      Value *NewAnd = Builder.CreateAnd(
+          Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst));
+      return new ICmpInst(Cmp.getPredicate(),
+                          NewAnd, ConstantInt::get(And->getType(), NewCmpCst));
     }
   }
 
@@ -1705,8 +1726,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
    // Compute X & (C2 << Y).
     Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift);
-    Cmp.setOperand(0, NewAnd);
-    return &Cmp;
+    return replaceOperand(Cmp, 0, NewAnd);
   }
 
   return nullptr;
@@ -1812,8 +1832,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
       }
       if (NewOr) {
         Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName());
-        Cmp.setOperand(0, NewAnd);
-        return &Cmp;
+        return replaceOperand(Cmp, 0, NewAnd);
       }
     }
   }
@@ -1863,8 +1882,8 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp,
     int32_t ExactLogBase2 = C2->exactLogBase2();
     if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) {
       Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1);
-      if (And->getType()->isVectorTy())
-        NTy = VectorType::get(NTy, And->getType()->getVectorNumElements());
+      if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
+        NTy = FixedVectorType::get(NTy, AndVTy->getNumElements());
       Value *Trunc = Builder.CreateTrunc(X, NTy);
       auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
                                                             : CmpInst::ICMP_SLT;
@@ -1888,20 +1907,24 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
   }
 
   Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1);
-  if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) {
-    // X | C == C --> X <=u C
-    // X | C != C --> X >u C
-    // iff C+1 is a power of 2 (C is a bitmask of the low bits)
-    if ((C + 1).isPowerOf2()) {
+  const APInt *MaskC;
+  if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) {
+    if (*MaskC == C && (C + 1).isPowerOf2()) {
+      // X | C == C --> X <=u C
+      // X | C != C --> X >u C
+      // iff C+1 is a power of 2 (C is a bitmask of the low bits)
      Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
      return new ICmpInst(Pred, OrOp0, OrOp1);
    }
-    // More general: are all bits outside of a mask constant set or not set?
-    // X | C == C --> (X & ~C) == 0
-    // X | C != C --> (X & ~C) != 0
+
+    // More general: canonicalize 'equality with set bits mask' to
+    // 'equality with clear bits mask'.
+    // (X | MaskC) == C --> (X & ~MaskC) == C ^ MaskC
+    // (X | MaskC) != C --> (X & ~MaskC) != C ^ MaskC
     if (Or->hasOneUse()) {
-      Value *A = Builder.CreateAnd(OrOp0, ~C);
-      return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType()));
+      Value *And = Builder.CreateAnd(OrOp0, ~(*MaskC));
+      Constant *NewC = ConstantInt::get(Or->getType(), C ^ (*MaskC));
+      return new ICmpInst(Pred, And, NewC);
     }
   }
 
@@ -2149,8 +2172,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
   if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt &&
       DL.isLegalInteger(TypeBits - Amt)) {
     Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
-    if (ShType->isVectorTy())
-      TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
+    if (auto *ShVTy = dyn_cast<VectorType>(ShType))
+      TruncTy = FixedVectorType::get(TruncTy, ShVTy->getNumElements());
     Constant *NewC =
         ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt));
     return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC);
@@ -2763,6 +2786,37 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
     if (match(BCSrcOp, m_UIToFP(m_Value(X))))
       if (Cmp.isEquality() && match(Op1, m_Zero()))
         return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+    // If this is a sign-bit test of a bitcast of a casted FP value, eliminate
+    // the FP extend/truncate because that cast does not change the sign-bit.
+    // This is true for all standard IEEE-754 types and the X86 80-bit type.
+    // The sign-bit is always the most significant bit in those types.
+    const APInt *C;
+    bool TrueIfSigned;
+    if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() &&
+        isSignBitCheck(Pred, *C, TrueIfSigned)) {
+      if (match(BCSrcOp, m_FPExt(m_Value(X))) ||
+          match(BCSrcOp, m_FPTrunc(m_Value(X)))) {
+        // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0
+        // (bitcast (fpext/fptrunc X)) to iX) > -1 --> (bitcast X to iY) > -1
+        Type *XType = X->getType();
+
+        // We can't currently handle Power style floating point operations here.
+        if (!(XType->isPPC_FP128Ty() || BCSrcOp->getType()->isPPC_FP128Ty())) {
+
+          Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits());
+          if (auto *XVTy = dyn_cast<VectorType>(XType))
+            NewType = FixedVectorType::get(NewType, XVTy->getNumElements());
+          Value *NewBitcast = Builder.CreateBitCast(X, NewType);
+          if (TrueIfSigned)
+            return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast,
+                                ConstantInt::getNullValue(NewType));
+          else
+            return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast,
+                                ConstantInt::getAllOnesValue(NewType));
+        }
+      }
+    }
   }
 
   // Test to see if the operands of the icmp are casted versions of other
@@ -2792,11 +2846,10 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
     return nullptr;
 
   Value *Vec;
-  Constant *Mask;
-  if (match(BCSrcOp,
-            m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) {
+  ArrayRef<int> Mask;
+  if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
     // Check whether every element of Mask is the same constant
-    if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) {
+    if (is_splat(Mask)) {
       auto *VecTy = cast<VectorType>(BCSrcOp->getType());
       auto *EltTy = cast<IntegerType>(VecTy->getElementType());
       if (C->isSplat(EltTy->getBitWidth())) {
@@ -2805,6 +2858,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp,
        // then:
        //   => %E = extractelement <N x iK> %vec, i32 Elem
        //      icmp <pred> iK %SplatVal, <pattern>
+        Value *Elem = Builder.getInt32(Mask[0]);
        Value *Extract = Builder.CreateExtractElement(Vec, Elem);
        Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
        return new ICmpInst(Pred, Extract, NewC);
@@ -2928,12 +2982,9 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
    break;
  case Instruction::Add: {
    // Replace ((add A, B) != C) with (A != C-B) if B & C are constants.
-    const APInt *BOC;
-    if (match(BOp1, m_APInt(BOC))) {
-      if (BO->hasOneUse()) {
-        Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1));
-        return new ICmpInst(Pred, BOp0, SubC);
-      }
+    if (Constant *BOC = dyn_cast<Constant>(BOp1)) {
+      if (BO->hasOneUse())
+        return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC));
    } else if (C.isNullValue()) {
      // Replace ((add A, B) != 0) with (A != -B) if A or B is
      // efficiently invertible, or if the add has just this one use.
@@ -2963,11 +3014,11 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
    break;
  case Instruction::Sub:
    if (BO->hasOneUse()) {
-      const APInt *BOC;
-      if (match(BOp0, m_APInt(BOC))) {
+      // Only check for constant LHS here, as constant RHS will be canonicalized
+      // to add and use the fold above.
+      if (Constant *BOC = dyn_cast<Constant>(BOp0)) {
        // Replace ((sub BOC, B) != C) with (B != BOC-C).
-        Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS);
-        return new ICmpInst(Pred, BOp1, SubC);
+        return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS));
      } else if (C.isNullValue()) {
        // Replace ((sub A, B) != 0) with (A != B).
        return new ICmpInst(Pred, BOp0, BOp1);
@@ -3028,20 +3079,16 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
   unsigned BitWidth = C.getBitWidth();
   switch (II->getIntrinsicID()) {
   case Intrinsic::bswap:
-    Worklist.Add(II);
-    Cmp.setOperand(0, II->getArgOperand(0));
-    Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap()));
-    return &Cmp;
+    // bswap(A) == C -> A == bswap(C)
+    return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+                        ConstantInt::get(Ty, C.byteSwap()));
   case Intrinsic::ctlz:
   case Intrinsic::cttz: {
     // ctz(A) == bitwidth(A) -> A == 0 and likewise for !=
-    if (C == BitWidth) {
-      Worklist.Add(II);
-      Cmp.setOperand(0, II->getArgOperand(0));
-      Cmp.setOperand(1, ConstantInt::getNullValue(Ty));
-      return &Cmp;
-    }
+    if (C == BitWidth)
+      return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+                          ConstantInt::getNullValue(Ty));
     // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set
     // and Mask1 has bits 0..C+1 set. Similar for ctl, but for high bits.
@@ -3054,10 +3101,9 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
       APInt Mask2 = IsTrailing
                         ? APInt::getOneBitSet(BitWidth, Num)
                         : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
-      Cmp.setOperand(0, Builder.CreateAnd(II->getArgOperand(0), Mask1));
-      Cmp.setOperand(1, ConstantInt::get(Ty, Mask2));
-      Worklist.Add(II);
-      return &Cmp;
+      return new ICmpInst(Cmp.getPredicate(),
+                          Builder.CreateAnd(II->getArgOperand(0), Mask1),
+                          ConstantInt::get(Ty, Mask2));
     }
     break;
   }
@@ -3066,14 +3112,10 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
     // popcount(A) == 0 -> A == 0 and likewise for !=
     // popcount(A) == bitwidth(A) -> A == -1 and likewise for !=
     bool IsZero = C.isNullValue();
-    if (IsZero || C == BitWidth) {
-      Worklist.Add(II);
-      Cmp.setOperand(0, II->getArgOperand(0));
-      auto *NewOp =
-          IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty);
-      Cmp.setOperand(1, NewOp);
-      return &Cmp;
-    }
+    if (IsZero || C == BitWidth)
+      return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0),
+          IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty));
+
     break;
   }
@@ -3081,9 +3123,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
     // uadd.sat(a, b) == 0 -> (a | b) == 0
     if (C.isNullValue()) {
       Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1));
-      return replaceInstUsesWith(Cmp, Builder.CreateICmp(
-          Cmp.getPredicate(), Or, Constant::getNullValue(Ty)));
-
+      return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty));
     }
     break;
   }
@@ -3093,8 +3133,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
     if (C.isNullValue()) {
       ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ
                                         ? ICmpInst::ICMP_ULE
                                         : ICmpInst::ICMP_UGT;
-      return ICmpInst::Create(Instruction::ICmp, NewPred,
-                              II->getArgOperand(0), II->getArgOperand(1));
+      return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1));
     }
     break;
   }
@@ -3300,30 +3339,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
    // x & (-1 >> y) != x    ->  x u> (-1 >> y)
    DstPred = ICmpInst::Predicate::ICMP_UGT;
    break;
-  case ICmpInst::Predicate::ICMP_UGT:
+  case ICmpInst::Predicate::ICMP_ULT:
+    // x & (-1 >> y) u< x    ->  x u> (-1 >> y)
    // x u> x & (-1 >> y)    ->  x u> (-1 >> y)
-    assert(X == I.getOperand(0) && "instsimplify took care of commut. variant");
    DstPred = ICmpInst::Predicate::ICMP_UGT;
    break;
  case ICmpInst::Predicate::ICMP_UGE:
    // x & (-1 >> y) u>= x    ->  x u<= (-1 >> y)
-    assert(X == I.getOperand(1) && "instsimplify took care of commut. variant");
-    DstPred = ICmpInst::Predicate::ICMP_ULE;
-    break;
-  case ICmpInst::Predicate::ICMP_ULT:
-    // x & (-1 >> y) u< x    ->  x u> (-1 >> y)
-    assert(X == I.getOperand(1) && "instsimplify took care of commut. variant");
-    DstPred = ICmpInst::Predicate::ICMP_UGT;
-    break;
-  case ICmpInst::Predicate::ICMP_ULE:
    // x u<= x & (-1 >> y)    ->  x u<= (-1 >> y)
-    assert(X == I.getOperand(0) && "instsimplify took care of commut. variant");
    DstPred = ICmpInst::Predicate::ICMP_ULE;
    break;
-  case ICmpInst::Predicate::ICMP_SGT:
+  case ICmpInst::Predicate::ICMP_SLT:
+    // x & (-1 >> y) s< x    ->  x s> (-1 >> y)
    // x s> x & (-1 >> y)    ->  x s> (-1 >> y)
-    if (X != I.getOperand(0)) // X must be on LHS of comparison!
-      return nullptr;         // Ignore the other case.
    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
      return nullptr;
    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
      return nullptr;
@@ -3332,33 +3360,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
    break;
  case ICmpInst::Predicate::ICMP_SGE:
    // x & (-1 >> y) s>= x    ->  x s<= (-1 >> y)
-    if (X != I.getOperand(1)) // X must be on RHS of comparison!
-      return nullptr;         // Ignore the other case.
+    // x s<= x & (-1 >> y)    ->  x s<= (-1 >> y)
    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
      return nullptr;
    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
      return nullptr;
    DstPred = ICmpInst::Predicate::ICMP_SLE;
    break;
-  case ICmpInst::Predicate::ICMP_SLT:
-    // x & (-1 >> y) s< x    ->  x s> (-1 >> y)
-    if (X != I.getOperand(1)) // X must be on RHS of comparison!
-      return nullptr;         // Ignore the other case.
-    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
-      return nullptr;
-    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
-      return nullptr;
-    DstPred = ICmpInst::Predicate::ICMP_SGT;
-    break;
+  case ICmpInst::Predicate::ICMP_SGT:
  case ICmpInst::Predicate::ICMP_SLE:
-    // x s<= x & (-1 >> y)    ->  x s<= (-1 >> y)
-    if (X != I.getOperand(0)) // X must be on LHS of comparison!
-      return nullptr;         // Ignore the other case.
-    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
-      return nullptr;
-    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
-      return nullptr;
-    DstPred = ICmpInst::Predicate::ICMP_SLE;
+    return nullptr;
+  case ICmpInst::Predicate::ICMP_UGT:
+  case ICmpInst::Predicate::ICMP_ULE:
+    llvm_unreachable("Instsimplify took care of commut. variant");
    break;
  default:
    llvm_unreachable("All possible folds are handled.");
@@ -3370,8 +3384,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
  Type *OpTy = M->getType();
  auto *VecC = dyn_cast<Constant>(M);
  if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) {
+    auto *OpVTy = cast<VectorType>(OpTy);
    Constant *SafeReplacementConstant = nullptr;
-    for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) {
+    for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) {
      if (!isa<UndefValue>(VecC->getAggregateElement(i))) {
        SafeReplacementConstant = VecC->getAggregateElement(i);
        break;
@@ -3494,7 +3509,8 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
  Instruction *NarrowestShift = XShift;
 
  Type *WidestTy = WidestShift->getType();
-  assert(NarrowestShift->getType() == I.getOperand(0)->getType() &&
+  Type *NarrowestTy = NarrowestShift->getType();
+  assert(NarrowestTy == I.getOperand(0)->getType() &&
         "We did not look past any shifts while matching XShift though.");
  bool HadTrunc = WidestTy != I.getOperand(0)->getType();
@@ -3533,6 +3549,23 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
  if (XShAmt->getType() != YShAmt->getType())
    return nullptr;
 
+  // As input, we have the following pattern:
+  //   icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
+  // We want to rewrite that as:
+  //   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
+  // While we know that originally (Q+K) would not overflow
+  // (because 2 * (N-1) u<= iN -1), we have looked past extensions of
+  // shift amounts. so it may now overflow in smaller bitwidth.
+  // To ensure that does not happen, we need to ensure that the total maximal
+  // shift amount is still representable in that smaller bit width.
+  unsigned MaximalPossibleTotalShiftAmount =
+      (WidestTy->getScalarSizeInBits() - 1) +
+      (NarrowestTy->getScalarSizeInBits() - 1);
+  APInt MaximalRepresentableShiftAmount =
+      APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits());
+  if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount))
+    return nullptr;
+
  // Can we fold (XShAmt+YShAmt) ?
  auto *NewShAmt = dyn_cast_or_null<Constant>(
      SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false,
@@ -3627,9 +3660,6 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
      match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))),
                         m_Value(Y)))) {
    Mul = nullptr;
-    // Canonicalize as-if y was on RHS.
-    if (I.getOperand(1) != Y)
-      Pred = I.getSwappedPredicate();
 
    // Are we checking that overflow does not happen, or does happen?
    switch (Pred) {
@@ -3674,6 +3704,11 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) {
  if (NeedNegation) // This technically increases instruction count.
    Res = Builder.CreateNot(Res, "umul.not.ov");
 
+  // If we replaced the mul, erase it. Do this after all uses of Builder,
+  // as the mul is used as insertion point.
+  if (MulHadOtherUses)
+    eraseInstFromFunction(*Mul);
+
  return Res;
 }
@@ -4202,9 +4237,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
      if (X) { // Build (X^Y) & Z
        Op1 = Builder.CreateXor(X, Y);
        Op1 = Builder.CreateAnd(Op1, Z);
-        I.setOperand(0, Op1);
-        I.setOperand(1, Constant::getNullValue(Op1->getType()));
-        return &I;
+        return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType()));
      }
    }
@@ -4613,17 +4646,6 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
  case ICmpInst::ICMP_NE:
    // Recognize pattern:
    //   mulval = mul(zext A, zext B)
-    //   cmp eq/neq mulval, zext trunc mulval
-    if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal))
-      if (Zext->hasOneUse()) {
-        Value *ZextArg = Zext->getOperand(0);
-        if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg))
-          if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth)
-            break; //Recognized
-      }
-
-    // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
    //   cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits.
    ConstantInt *CI;
    Value *ValToMask;
@@ -4701,7 +4723,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
  Function *F = Intrinsic::getDeclaration(
      I.getModule(), Intrinsic::umul_with_overflow, MulType);
  CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
-  IC.Worklist.Add(MulInstr);
+  IC.Worklist.push(MulInstr);
 
  // If there are uses of mul result other than the comparison, we know that
  // they are truncation or binary AND. Change them to use result of
@@ -4723,18 +4745,16 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
      ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
      APInt ShortMask = CI->getValue().trunc(MulWidth);
      Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask);
-      Instruction *Zext =
-          cast<Instruction>(Builder.CreateZExt(ShortAnd, BO->getType()));
-      IC.Worklist.Add(Zext);
+      Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType());
      IC.replaceInstUsesWith(*BO, Zext);
    } else {
      llvm_unreachable("Unexpected Binary operation");
    }
-    IC.Worklist.Add(cast<Instruction>(U));
+    IC.Worklist.push(cast<Instruction>(U));
  }
 }
 if (isa<Instruction>(OtherVal))
-  IC.Worklist.Add(cast<Instruction>(OtherVal));
+  IC.Worklist.push(cast<Instruction>(OtherVal));
 
 // The original icmp gets replaced with the overflow value, maybe inverted
 // depending on predicate.
@@ -5189,8 +5209,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
    // Bail out if the constant can't be safely incremented/decremented.
    if (!ConstantIsOk(CI))
      return llvm::None;
-  } else if (Type->isVectorTy()) {
-    unsigned NumElts = Type->getVectorNumElements();
+  } else if (auto *VTy = dyn_cast<VectorType>(Type)) {
+    unsigned NumElts = VTy->getNumElements();
    for (unsigned i = 0; i != NumElts; ++i) {
      Constant *Elt = C->getAggregateElement(i);
      if (!Elt)
@@ -5252,6 +5272,47 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) {
  return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second);
 }
 
+/// If we have a comparison with a non-canonical predicate, if we can update
+/// all the users, invert the predicate and adjust all the users.
+static CmpInst *canonicalizeICmpPredicate(CmpInst &I) {
+  // Is the predicate already canonical?
+  CmpInst::Predicate Pred = I.getPredicate();
+  if (isCanonicalPredicate(Pred))
+    return nullptr;
+
+  // Can all users be adjusted to predicate inversion?
+  if (!canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
+    return nullptr;
+
+  // Ok, we can canonicalize comparison!
+  // Let's first invert the comparison's predicate.
+  I.setPredicate(CmpInst::getInversePredicate(Pred));
+  I.setName(I.getName() + ".not");
+
+  // And now let's adjust every user.
+  for (User *U : I.users()) {
+    switch (cast<Instruction>(U)->getOpcode()) {
+    case Instruction::Select: {
+      auto *SI = cast<SelectInst>(U);
+      SI->swapValues();
+      SI->swapProfMetadata();
+      break;
+    }
+    case Instruction::Br:
+      cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too
+      break;
+    case Instruction::Xor:
+      U->replaceAllUsesWith(&I);
+      break;
+    default:
+      llvm_unreachable("Got unexpected user - out of sync with "
+                       "canFreelyInvertAllUsersOf() ?");
+    }
+  }
+
+  return &I;
+}
+
 /// Integer compare with boolean values can always be turned into bitwise ops.
 static Instruction *canonicalizeICmpBool(ICmpInst &I,
                                          InstCombiner::BuilderTy &Builder) {
@@ -5338,10 +5399,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
  Value *X, *Y;
  if (match(&Cmp,
            m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) {
-    // We want X to be the icmp's second operand, so swap predicate if it isn't.
-    if (Cmp.getOperand(0) == X)
-      Pred = Cmp.getSwappedPredicate();
-
    switch (Pred) {
    case ICmpInst::ICMP_ULE:
      NewPred = ICmpInst::ICMP_NE;
@@ -5361,10 +5418,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
    //  The variant with 'add' is not canonical, (the variant with 'not' is)
    //  we only get it because it has extra uses, and can't be canonicalized,
 
-    // We want X to be the icmp's second operand, so swap predicate if it isn't.
-    if (Cmp.getOperand(0) == X)
-      Pred = Cmp.getSwappedPredicate();
-
    switch (Pred) {
    case ICmpInst::ICMP_ULT:
      NewPred = ICmpInst::ICMP_NE;
@@ -5385,21 +5438,45 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp,
 
 static Instruction *foldVectorCmp(CmpInst &Cmp,
                                   InstCombiner::BuilderTy &Builder) {
-  // If both arguments of the cmp are shuffles that use the same mask and
-  // shuffle within a single vector, move the shuffle after the cmp.
+  const CmpInst::Predicate Pred = Cmp.getPredicate();
  Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1);
  Value *V1, *V2;
-  Constant *M;
-  if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(M))) &&
-      match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(M))) &&
-      V1->getType() == V2->getType() &&
-      (LHS->hasOneUse() || RHS->hasOneUse())) {
-    // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M
-    CmpInst::Predicate P = Cmp.getPredicate();
-    Value *NewCmp = isa<ICmpInst>(Cmp) ? Builder.CreateICmp(P, V1, V2)
-                                       : Builder.CreateFCmp(P, V1, V2);
+  ArrayRef<int> M;
+  if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M))))
+    return nullptr;
+
+  // If both arguments of the cmp are shuffles that use the same mask and
+  // shuffle within a single vector, move the shuffle after the cmp:
+  // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M
+  Type *V1Ty = V1->getType();
+  if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) &&
+      V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) {
+    Value *NewCmp = Builder.CreateCmp(Pred, V1, V2);
    return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
  }
+
+  // Try to canonicalize compare with splatted operand and splat constant.
+  // TODO: We could generalize this for more than splats. See/use the code in
+  // InstCombiner::foldVectorBinop().
+  Constant *C;
+  if (!LHS->hasOneUse() || !match(RHS, m_Constant(C)))
+    return nullptr;
+
+  // Length-changing splats are ok, so adjust the constants as needed:
+  // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M
+  Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true);
+  int MaskSplatIndex;
+  if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) {
+    // We allow undefs in matching, but this transform removes those for safety.
+    // Demanded elements analysis should be able to recover some/all of that.
+    C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(),
+                                 ScalarC);
+    SmallVector<int, 8> NewM(M.size(), MaskSplatIndex);
+    Value *NewCmp = Builder.CreateCmp(Pred, V1, C);
+    return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()),
+                                 NewM);
+  }
+
+  return nullptr;
 }
@@ -5474,8 +5551,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
  if (Instruction *Res = canonicalizeICmpBool(I, Builder))
    return Res;
 
-  if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I))
-    return NewICmp;
+  if (Instruction *Res = canonicalizeCmpWithConstant(I))
+    return Res;
+
+  if (Instruction *Res = canonicalizeICmpPredicate(I))
+    return Res;
 
  if (Instruction *Res = foldICmpWithConstant(I))
    return Res;
@@ -5565,6 +5645,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
  if (Instruction *Res = foldICmpBitCast(I, Builder))
    return Res;
 
+  // TODO: Hoist this above the min/max bailout.
  if (Instruction *R = foldICmpWithCastOp(I))
    return R;
 
@@ -5600,9 +5681,13 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
      isa<IntegerType>(A->getType())) {
    Value *Result;
    Constant *Overflow;
-    if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B,
-                              *AddI, Result, Overflow)) {
+    // m_UAddWithOverflow can match patterns that do not include an explicit
+    // "add" instruction, so check the opcode of the matched op.
+    if (AddI->getOpcode() == Instruction::Add &&
+        OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI,
+                              Result, Overflow)) {
      replaceInstUsesWith(*AddI, Result);
+      eraseInstFromFunction(*AddI);
      return replaceInstUsesWith(I, Overflow);
    }
  }
@@ -5689,7 +5774,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
    // TODO: Can never be -0.0 and other non-representable values
    APFloat RHSRoundInt(RHS);
    RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven);
-    if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) {
+    if (RHS != RHSRoundInt) {
      if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ)
        return replaceInstUsesWith(I, Builder.getFalse());
@@ -5777,7 +5862,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
      APFloat SMax(RHS.getSemantics());
      SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true,
                            APFloat::rmNearestTiesToEven);
-      if (SMax.compare(RHS) == APFloat::cmpLessThan) {  // smax < 13123.0
+      if (SMax < RHS) { // smax < 13123.0
        if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT ||
            Pred == ICmpInst::ICMP_SLE)
          return replaceInstUsesWith(I, Builder.getTrue());
@@ -5789,7 +5874,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
      APFloat UMax(RHS.getSemantics());
      UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false,
                            APFloat::rmNearestTiesToEven);
-      if (UMax.compare(RHS) == APFloat::cmpLessThan) {  // umax < 13123.0
+      if (UMax < RHS) { // umax < 13123.0
        if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT ||
            Pred == ICmpInst::ICMP_ULE)
          return replaceInstUsesWith(I, Builder.getTrue());
@@ -5802,7 +5887,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
      APFloat SMin(RHS.getSemantics());
      SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true,
                            APFloat::rmNearestTiesToEven);
-      if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0
+      if (SMin > RHS) { // smin > 12312.0
        if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT ||
            Pred == ICmpInst::ICMP_SGE)
          return replaceInstUsesWith(I, Builder.getTrue());
@@ -5810,10 +5895,10 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI,
      }
    } else {
      // See if the RHS value is < UnsignedMin.
-      APFloat SMin(RHS.getSemantics());
-      SMin.convertFromAPInt(APInt::getMinValue(IntWidth), true,
+      APFloat UMin(RHS.getSemantics());
+      UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false,
                            APFloat::rmNearestTiesToEven);
-      if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0
+      if (UMin > RHS) { // umin > 12312.0
        if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT ||
            Pred == ICmpInst::ICMP_UGE)
          return replaceInstUsesWith(I, Builder.getTrue());
@@ -5949,16 +6034,15 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
 }
 
 /// Optimize fabs(X) compared with zero.
-static Instruction *foldFabsWithFcmpZero(FCmpInst &I) {
+static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) {
   Value *X;
   if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) ||
       !match(I.getOperand(1), m_PosZeroFP()))
     return nullptr;
 
-  auto replacePredAndOp0 = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
+  auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
     I->setPredicate(P);
-    I->setOperand(0, X);
-    return I;
+    return IC.replaceOperand(*I, 0, X);
   };
 
   switch (I.getPredicate()) {
@@ -6058,14 +6142,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand,
   // then canonicalize the operand to 0.0.
   if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
-    if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) {
-      I.setOperand(0, ConstantFP::getNullValue(OpType));
-      return &I;
-    }
-    if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) {
-      I.setOperand(1, ConstantFP::getNullValue(OpType));
-      return &I;
-    }
+    if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI))
+      return replaceOperand(I, 0, ConstantFP::getNullValue(OpType));
+
+    if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI))
+      return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
   }
 
   // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
@@ -6090,10 +6171,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 
   // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
   // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
-  if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) {
-    I.setOperand(1, ConstantFP::getNullValue(OpType));
-    return &I;
-  }
+  if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
+    return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
 
   // Handle fcmp with instruction LHS and constant RHS.
   Instruction *LHSI;
@@ -6128,7 +6207,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     }
   }
 
-  if (Instruction *R = foldFabsWithFcmpZero(I))
+  if (Instruction *R = foldFabsWithFcmpZero(I, *this))
     return R;
 
   if (match(Op0, m_FNeg(m_Value(X)))) {
@@ -6159,8 +6238,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
       APFloat Fabs = TruncC;
       Fabs.clearSign();
       if (!Lossy &&
-          ((Fabs.compare(APFloat::getSmallestNormalized(FPSem)) !=
-            APFloat::cmpLessThan) || Fabs.isZero())) {
+          (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) {
         Constant *NewC = ConstantFP::get(X->getType(), TruncC);
         return new FCmpInst(Pred, X, NewC, "", &I);
       }
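
A quick way to convince yourself that the new foldICmpOrConstant canonicalization above is sound: the hunk at @@ -1888 rewrites (X | MaskC) ==/!= C into (X & ~MaskC) ==/!= (C ^ MaskC). The following standalone C++ program is not part of the patch or the LLVM tree; it is only an illustrative, exhaustive check of that bit identity on 8-bit values (the names Old/New are made up for the check).

    // Exhaustive 8-bit check of: (x | m) == c  <=>  (x & ~m) == (c ^ m).
    // Illustration only; not LLVM code.
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned X = 0; X < 256; ++X)
        for (unsigned M = 0; M < 256; ++M)
          for (unsigned C = 0; C < 256; ++C) {
            uint8_t x = static_cast<uint8_t>(X);
            uint8_t m = static_cast<uint8_t>(M);
            uint8_t c = static_cast<uint8_t>(C);
            // Original form of the compare.
            bool Old = static_cast<uint8_t>(x | m) == c;
            // Canonicalized form produced by the fold.
            bool New = static_cast<uint8_t>(x & ~m) ==
                       static_cast<uint8_t>(c ^ m);
            assert(Old == New && "canonicalization must preserve the result");
          }
      return 0;
    }

The identity holds because, on the bits covered by MaskC, both forms are true only when C has all of those bits set, and on the remaining bits both forms compare X directly against C.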
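
Similarly, the foldICmpEqIntrinsicWithConstant hunks rewrite cttz(A) == C as a mask test (A & Mask1) == Mask2, where Mask1 is the low C+1 bits and Mask2 is just bit C, provided C is less than the bit width. The sketch below is, again, not from the patch; it is a hypothetical C++20 check of that identity for 8-bit values using std::countr_zero (which returns the bit width for zero, matching the intrinsic's defined behavior in this comparison).

    // Exhaustive 8-bit check of: cttz(a) == C  <=>  (a & Mask1) == Mask2,
    // with Mask1 = low C+1 bits, Mask2 = bit C, for C < 8. Illustration only.
    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned A = 0; A < 256; ++A)
        for (unsigned C = 0; C < 8; ++C) {
          uint8_t a = static_cast<uint8_t>(A);
          bool Old = std::countr_zero(a) == static_cast<int>(C);
          uint8_t Mask1 = static_cast<uint8_t>((1u << (C + 1)) - 1); // bits 0..C
          uint8_t Mask2 = static_cast<uint8_t>(1u << C);             // bit C only
          bool New = static_cast<uint8_t>(a & Mask1) == Mask2;
          assert(Old == New && "cttz mask fold must preserve the result");
        }
      return 0;
    }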