Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--  lib/Transforms/InstCombine/InstCombineCompares.cpp | 643
1 file changed, 416 insertions(+), 227 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index b5bbb09935e2..3a4283ae5406 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -1,9 +1,8 @@
 //===- InstCombineCompares.cpp --------------------------------------------===//
 //
-//                     The LLVM Compiler Infrastructure
-//
-// This file is distributed under the University of Illinois Open Source
-// License. See LICENSE.TXT for details.
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 //
 //===----------------------------------------------------------------------===//
 //
@@ -704,7 +703,10 @@ static Value *rewriteGEPAsOffset(Value *Start, Value *Base,
       continue;
 
     if (auto *CI = dyn_cast<CastInst>(Val)) {
-      NewInsts[CI] = NewInsts[CI->getOperand(0)];
+      // Don't get rid of the intermediate variable here; the store can grow
+      // the map which will invalidate the reference to the input value.
+      Value *V = NewInsts[CI->getOperand(0)];
+      NewInsts[CI] = V;
       continue;
     }
     if (auto *GEP = dyn_cast<GEPOperator>(Val)) {
@@ -1292,8 +1294,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   // use the sadd_with_overflow intrinsic to efficiently compute both the
   // result and the overflow bit.
   Type *NewType = IntegerType::get(OrigAdd->getContext(), NewWidth);
-  Value *F = Intrinsic::getDeclaration(I.getModule(),
-                                       Intrinsic::sadd_with_overflow, NewType);
+  Function *F = Intrinsic::getDeclaration(
+      I.getModule(), Intrinsic::sadd_with_overflow, NewType);
 
   InstCombiner::BuilderTy &Builder = IC.Builder;
 
@@ -1315,14 +1317,16 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B,
   return ExtractValueInst::Create(Call, 1, "sadd.overflow");
 }
 
-// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+// Handle  icmp pred X, 0
 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
   CmpInst::Predicate Pred = Cmp.getPredicate();
-  Value *X = Cmp.getOperand(0);
+  if (!match(Cmp.getOperand(1), m_Zero()))
+    return nullptr;
 
-  if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) {
+  // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0)
+  if (Pred == ICmpInst::ICMP_SGT) {
     Value *A, *B;
-    SelectPatternResult SPR = matchSelectPattern(X, A, B);
+    SelectPatternResult SPR = matchSelectPattern(Cmp.getOperand(0), A, B);
     if (SPR.Flavor == SPF_SMIN) {
       if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT))
         return new ICmpInst(Pred, B, Cmp.getOperand(1));
@@ -1330,6 +1334,20 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) {
         return new ICmpInst(Pred, A, Cmp.getOperand(1));
     }
   }
+
+  // Given:
+  //   icmp eq/ne (urem %x, %y), 0
+  // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem':
+  //   icmp eq/ne %x, 0
+  Value *X, *Y;
+  if (match(Cmp.getOperand(0), m_URem(m_Value(X), m_Value(Y))) &&
+      ICmpInst::isEquality(Pred)) {
+    KnownBits XKnown = computeKnownBits(X, 0, &Cmp);
+    KnownBits YKnown = computeKnownBits(Y, 0, &Cmp);
+    if (XKnown.countMaxPopulation() == 1 && YKnown.countMinPopulation() >= 2)
+      return new ICmpInst(Pred, X, Cmp.getOperand(1));
+  }
+
   return nullptr;
 }
@@ -1624,20 +1642,43 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
 Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
                                                  BinaryOperator *And,
                                                  const APInt &C1) {
+  bool isICMP_NE = Cmp.getPredicate() == ICmpInst::ICMP_NE;
+
   // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1
   // TODO: We canonicalize to the longer form for scalars because we have
   // better analysis/folds for icmp, and codegen may be better with icmp.
-  if (Cmp.getPredicate() == CmpInst::ICMP_NE && Cmp.getType()->isVectorTy() &&
-      C1.isNullValue() && match(And->getOperand(1), m_One()))
+  if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() &&
+      match(And->getOperand(1), m_One()))
     return new TruncInst(And->getOperand(0), Cmp.getType());
 
   const APInt *C2;
-  if (!match(And->getOperand(1), m_APInt(C2)))
+  Value *X;
+  if (!match(And, m_And(m_Value(X), m_APInt(C2))))
     return nullptr;
 
+  // Don't perform the following transforms if the AND has multiple uses
   if (!And->hasOneUse())
     return nullptr;
 
+  if (Cmp.isEquality() && C1.isNullValue()) {
+    // Restrict this fold to single-use 'and' (PR10267).
+    // Replace (and X, (1 << size(X)-1) != 0) with X s< 0
+    if (C2->isSignMask()) {
+      Constant *Zero = Constant::getNullValue(X->getType());
+      auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
+      return new ICmpInst(NewPred, X, Zero);
+    }
+
+    // Restrict this fold only for single-use 'and' (PR10267).
+    // ((%x & C) == 0) --> %x u< (-C)  iff (-C) is power of two.
+    if ((~(*C2) + 1).isPowerOf2()) {
+      Constant *NegBOC =
+          ConstantExpr::getNeg(cast<Constant>(And->getOperand(1)));
+      auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
+      return new ICmpInst(NewPred, X, NegBOC);
+    }
+  }
+
   // If the LHS is an 'and' of a truncate and we can widen the and/compare to
   // the input width without changing the value produced, eliminate the cast:
   //
@@ -1772,13 +1813,22 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
                           ConstantInt::get(V->getType(), 1));
   }
 
-  // X | C == C --> X <=u C
-  // X | C != C --> X  >u C
-  //   iff C+1 is a power of 2 (C is a bitmask of the low bits)
-  if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
-      (C + 1).isPowerOf2()) {
-    Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
-    return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
+  Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1);
+  if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) {
+    // X | C == C --> X <=u C
+    // X | C != C --> X  >u C
+    //   iff C+1 is a power of 2 (C is a bitmask of the low bits)
+    if ((C + 1).isPowerOf2()) {
+      Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
+      return new ICmpInst(Pred, OrOp0, OrOp1);
+    }
+    // More general: are all bits outside of a mask constant set or not set?
+    // X | C == C --> (X & ~C) == 0
+    // X | C != C --> (X & ~C) != 0
+    if (Or->hasOneUse()) {
+      Value *A = Builder.CreateAnd(OrOp0, ~C);
+      return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType()));
+    }
   }
 
   if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse())
@@ -1799,8 +1849,8 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
   // Are we using xors to bitwise check for a pair of (in)equalities? Convert to
   // a shorter form that has more potential to be folded even further.
   Value *X1, *X2, *X3, *X4;
-  if (match(Or->getOperand(0), m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
-      match(Or->getOperand(1), m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
+  if (match(OrOp0, m_OneUse(m_Xor(m_Value(X1), m_Value(X2)))) &&
+      match(OrOp1, m_OneUse(m_Xor(m_Value(X3), m_Value(X4))))) {
     // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4)
     // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4)
     Value *Cmp12 = Builder.CreateICmp(Pred, X1, X2);
@@ -1994,6 +2044,27 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
                         And, Constant::getNullValue(ShType));
   }
 
+  // Simplify 'shl' inequality test into 'and' equality test.
+  if (Cmp.isUnsigned() && Shl->hasOneUse()) {
+    // (X l<< C2) u<=/u> C1 iff C1+1 is power of two -> X & (~C1 l>> C2) ==/!= 0
+    if ((C + 1).isPowerOf2() &&
+        (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)) {
+      Value *And = Builder.CreateAnd(X, (~C).lshr(ShiftAmt->getZExtValue()));
+      return new ICmpInst(Pred == ICmpInst::ICMP_ULE ? ICmpInst::ICMP_EQ
+                                                     : ICmpInst::ICMP_NE,
+                          And, Constant::getNullValue(ShType));
+    }
+    // (X l<< C2) u</u>= C1 iff C1 is power of two -> X & (-C1 l>> C2) ==/!= 0
+    if (C.isPowerOf2() &&
+        (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) {
+      Value *And =
+          Builder.CreateAnd(X, (~(C - 1)).lshr(ShiftAmt->getZExtValue()));
+      return new ICmpInst(Pred == ICmpInst::ICMP_ULT ? ICmpInst::ICMP_EQ
+                                                     : ICmpInst::ICMP_NE,
+                          And, Constant::getNullValue(ShType));
+    }
+  }
+
   // Transform (icmp pred iM (shl iM %v, N), C)
   // -> (icmp pred i(M-N) (trunc %v iM to i(M-N)), (trunc (C>>N))
   // Transform the shl to a trunc if (trunc (C>>N)) has no loss and M-N.
@@ -2313,6 +2384,16 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
                                                const APInt &C) {
   Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
   ICmpInst::Predicate Pred = Cmp.getPredicate();
+  const APInt *C2;
+  APInt SubResult;
+
+  // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
+  if (match(X, m_APInt(C2)) &&
+      ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
+       (Cmp.isSigned() && Sub->hasNoSignedWrap())) &&
+      !subWithOverflow(SubResult, *C2, C, Cmp.isSigned()))
+    return new ICmpInst(Cmp.getSwappedPredicate(), Y,
+                        ConstantInt::get(Y->getType(), SubResult));
 
   // The following transforms are only worth it if the only user of the subtract
   // is the icmp.
@@ -2337,7 +2418,6 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp,
       return new ICmpInst(ICmpInst::ICMP_SLE, X, Y);
   }
 
-  const APInt *C2;
   if (!match(X, m_APInt(C2)))
     return nullptr;
 
@@ -2482,20 +2562,76 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp,
     // the entire original Cmp can be simplified to a false.
     Value *Cond = Builder.getFalse();
     if (TrueWhenLessThan)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SLT,
+                                                       OrigLHS, OrigRHS));
     if (TrueWhenEqual)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_EQ,
+                                                       OrigLHS, OrigRHS));
    if (TrueWhenGreaterThan)
-      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT, OrigLHS, OrigRHS));
+      Cond = Builder.CreateOr(Cond, Builder.CreateICmp(ICmpInst::ICMP_SGT,
+                                                       OrigLHS, OrigRHS));
 
     return replaceInstUsesWith(Cmp, Cond);
   }
   return nullptr;
 }
 
-Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp,
-                                                   BitCastInst *Bitcast,
-                                                   const APInt &C) {
+static Instruction *foldICmpBitCast(ICmpInst &Cmp,
+                                    InstCombiner::BuilderTy &Builder) {
+  auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0));
+  if (!Bitcast)
+    return nullptr;
+
+  ICmpInst::Predicate Pred = Cmp.getPredicate();
+  Value *Op1 = Cmp.getOperand(1);
+  Value *BCSrcOp = Bitcast->getOperand(0);
+
+  // Make sure the bitcast doesn't change the number of vector elements.
+  if (Bitcast->getSrcTy()->getScalarSizeInBits() ==
+          Bitcast->getDestTy()->getScalarSizeInBits()) {
+    // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
+    Value *X;
+    if (match(BCSrcOp, m_SIToFP(m_Value(X)))) {
+      // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
+      // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
+      // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
+      // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
+      if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
+           Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
+          match(Op1, m_Zero()))
+        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+
+      // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
+      if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
+        return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
+
+      // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
+      if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
+        return new ICmpInst(Pred, X,
+                            ConstantInt::getAllOnesValue(X->getType()));
+    }
+
+    // Zero-equality checks are preserved through unsigned floating-point casts:
+    // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
+    // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
+    if (match(BCSrcOp, m_UIToFP(m_Value(X))))
+      if (Cmp.isEquality() && match(Op1, m_Zero()))
+        return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
+  }
+
+  // Test to see if the operands of the icmp are casted versions of other
+  // values. If the ptr->ptr cast can be stripped off both arguments, do so.
+  if (Bitcast->getType()->isPointerTy() &&
+      (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
+    // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
+    // so eliminate it as well.
+    if (auto *BC2 = dyn_cast<BitCastInst>(Op1))
+      Op1 = BC2->getOperand(0);
+
+    Op1 = Builder.CreateBitCast(Op1, BCSrcOp->getType());
+    return new ICmpInst(Pred, BCSrcOp, Op1);
+  }
+
   // Folding: icmp <pred> iN X, C
   //  where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN
   //    and C is a splat of a K-bit pattern
@@ -2503,28 +2639,28 @@ Instruction *InstCombiner::foldICmpBitCastConstant(ICmpInst &Cmp,
   // Into:
   //   %E = extractelement <M x iK> %vec, i32 C'
   //   icmp <pred> iK %E, trunc(C)
-  if (!Bitcast->getType()->isIntegerTy() ||
+  const APInt *C;
+  if (!match(Cmp.getOperand(1), m_APInt(C)) ||
+      !Bitcast->getType()->isIntegerTy() ||
       !Bitcast->getSrcTy()->isIntOrIntVectorTy())
     return nullptr;
 
-  Value *BCIOp = Bitcast->getOperand(0);
-  Value *Vec = nullptr;     // 1st vector arg of the shufflevector
-  Constant *Mask = nullptr; // Mask arg of the shufflevector
-  if (match(BCIOp,
+  Value *Vec;
+  Constant *Mask;
+  if (match(BCSrcOp,
            m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) {
     // Check whether every element of Mask is the same constant
     if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) {
-      auto *VecTy = cast<VectorType>(BCIOp->getType());
+      auto *VecTy = cast<VectorType>(BCSrcOp->getType());
       auto *EltTy = cast<IntegerType>(VecTy->getElementType());
-      auto Pred = Cmp.getPredicate();
-      if (C.isSplat(EltTy->getBitWidth())) {
+      if (C->isSplat(EltTy->getBitWidth())) {
         // Fold the icmp based on the value of C
         // If C is M copies of an iK sized bit pattern,
         // then:
         //   =>  %E = extractelement <N x iK> %vec, i32 Elem
         //       icmp <pred> iK %SplatVal, <pattern>
         Value *Extract = Builder.CreateExtractElement(Vec, Elem);
-        Value *NewC = ConstantInt::get(EltTy, C.trunc(EltTy->getBitWidth()));
+        Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth()));
         return new ICmpInst(Pred, Extract, NewC);
       }
     }
@@ -2606,13 +2742,9 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) {
       return I;
   }
 
-  if (auto *BCI = dyn_cast<BitCastInst>(Cmp.getOperand(0))) {
-    if (Instruction *I = foldICmpBitCastConstant(Cmp, BCI, *C))
+  if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
+    if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C))
       return I;
-  }
-
-  if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C))
-    return I;
 
   return nullptr;
 }
@@ -2711,24 +2843,6 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
       if (C == *BOC && C.isPowerOf2())
         return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE,
                             BO, Constant::getNullValue(RHS->getType()));
-
-      // Don't perform the following transforms if the AND has multiple uses
-      if (!BO->hasOneUse())
-        break;
-
-      // Replace (and X, (1 << size(X)-1) != 0) with x s< 0
-      if (BOC->isSignMask()) {
-        Constant *Zero = Constant::getNullValue(BOp0->getType());
-        auto NewPred = isICMP_NE ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGE;
-        return new ICmpInst(NewPred, BOp0, Zero);
-      }
-
-      // ((X & ~7) == 0) --> X < 8
-      if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) {
-        Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1));
-        auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
-        return new ICmpInst(NewPred, BOp0, NegBOC);
-      }
     }
     break;
   }
@@ -2756,14 +2870,10 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
-/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
-Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
-                                                         const APInt &C) {
-  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0));
-  if (!II || !Cmp.isEquality())
-    return nullptr;
-
-  // Handle icmp {eq|ne} <intrinsic>, Constant.
+/// Fold an equality icmp with LLVM intrinsic and constant operand.
+Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp,
+                                                           IntrinsicInst *II,
+                                                           const APInt &C) {
   Type *Ty = II->getType();
   unsigned BitWidth = C.getBitWidth();
   switch (II->getIntrinsicID()) {
@@ -2823,6 +2933,65 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
   return nullptr;
 }
 
+/// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C.
+Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp,
+                                                         IntrinsicInst *II,
+                                                         const APInt &C) {
+  if (Cmp.isEquality())
+    return foldICmpEqIntrinsicWithConstant(Cmp, II, C);
+
+  Type *Ty = II->getType();
+  unsigned BitWidth = C.getBitWidth();
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::ctlz: {
+    // ctlz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX < 0b00010000
+    if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+      unsigned Num = C.getLimitedValue();
+      APInt Limit = APInt::getOneBitSet(BitWidth, BitWidth - Num - 1);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_ULT,
+                             II->getArgOperand(0), ConstantInt::get(Ty, Limit));
+    }
+
+    // ctlz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX > 0b00011111
+    if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
+        C.uge(1) && C.ule(BitWidth)) {
+      unsigned Num = C.getLimitedValue();
+      APInt Limit = APInt::getLowBitsSet(BitWidth, BitWidth - Num);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_UGT,
+                             II->getArgOperand(0), ConstantInt::get(Ty, Limit));
+    }
+    break;
+  }
+  case Intrinsic::cttz: {
+    // Limit to one use to ensure we don't increase instruction count.
+    if (!II->hasOneUse())
+      return nullptr;
+
+    // cttz(0bXXXXXXXX) > 3 -> 0bXXXXXXXX & 0b00001111 == 0
+    if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && C.ult(BitWidth)) {
+      APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue() + 1);
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_EQ,
+                             Builder.CreateAnd(II->getArgOperand(0), Mask),
+                             ConstantInt::getNullValue(Ty));
+    }
+
+    // cttz(0bXXXXXXXX) < 3 -> 0bXXXXXXXX & 0b00000111 != 0
+    if (Cmp.getPredicate() == ICmpInst::ICMP_ULT &&
+        C.uge(1) && C.ule(BitWidth)) {
+      APInt Mask = APInt::getLowBitsSet(BitWidth, C.getLimitedValue());
+      return CmpInst::Create(Instruction::ICmp, ICmpInst::ICMP_NE,
+                             Builder.CreateAnd(II->getArgOperand(0), Mask),
+                             ConstantInt::getNullValue(Ty));
+    }
+    break;
+  }
+  default:
+    break;
+  }
+
+  return nullptr;
+}
+
 /// Handle icmp with constant (but not simple integer constant) RHS.
 Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
@@ -2983,6 +3152,10 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I,
     //  x s> x & (-1 >> y)    ->    x s> (-1 >> y)
     if (X != I.getOperand(0)) // X must be on LHS of comparison!
       return nullptr;         // Ignore the other case.
+    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
+      return nullptr;
+    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
+      return nullptr;
     DstPred = ICmpInst::Predicate::ICMP_SGT;
     break;
   case ICmpInst::Predicate::ICMP_SGE:
@@ -3009,6 +3182,10 @@
     //  x s<= x & (-1 >> y)    ->    x s<= (-1 >> y)
     if (X != I.getOperand(0)) // X must be on LHS of comparison!
       return nullptr;         // Ignore the other case.
+    if (!match(M, m_Constant())) // Can not do this fold with non-constant.
+      return nullptr;
+    if (!match(M, m_NonNegative())) // Must not have any -1 vector elements.
+      return nullptr;
     DstPred = ICmpInst::Predicate::ICMP_SLE;
     break;
   default:
@@ -3093,6 +3270,64 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I,
   return T1;
 }
 
+// Given pattern:
+//   icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0
+// we should move shifts to the same hand of 'and', i.e. rewrite as
+//   icmp eq/ne (and (x shift (Q+K)), y), 0  iff (Q+K) u< bitwidth(x)
+// We are only interested in opposite logical shifts here.
+// If we can, we want to end up creating 'lshr' shift.
+static Value *
+foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ,
+                                           InstCombiner::BuilderTy &Builder) {
+  if (!I.isEquality() || !match(I.getOperand(1), m_Zero()) ||
+      !I.getOperand(0)->hasOneUse())
+    return nullptr;
+
+  auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value());
+  auto m_AnyLShr = m_LShr(m_Value(), m_Value());
+
+  // Look for an 'and' of two (opposite) logical shifts.
+  // Pick the single-use shift as XShift.
+  Value *XShift, *YShift;
+  if (!match(I.getOperand(0),
+             m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))),
+                     m_CombineAnd(m_AnyLogicalShift, m_Value(YShift)))))
+    return nullptr;
+
+  // If YShift is a single-use 'lshr', swap the shifts around.
+  if (match(YShift, m_OneUse(m_AnyLShr)))
+    std::swap(XShift, YShift);
+
+  // The shifts must be in opposite directions.
+  Instruction::BinaryOps XShiftOpcode =
+      cast<BinaryOperator>(XShift)->getOpcode();
+  if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode())
+    return nullptr; // Do not care about same-direction shifts here.
+
+  Value *X, *XShAmt, *Y, *YShAmt;
+  match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt)));
+  match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt)));
+
+  // Can we fold (XShAmt+YShAmt) ?
+  Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt,
+                                  SQ.getWithInstruction(&I));
+  if (!NewShAmt)
+    return nullptr;
+  // Is the new shift amount smaller than the bit width?
+  // FIXME: could also rely on ConstantRange.
+  unsigned BitWidth = X->getType()->getScalarSizeInBits();
+  if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
+                                          APInt(BitWidth, BitWidth))))
+    return nullptr;
+  // All good, we can do this fold. The shift is the same that was for X.
+  Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr
+                  ? Builder.CreateLShr(X, NewShAmt)
+                  : Builder.CreateShl(X, NewShAmt);
+  Value *T1 = Builder.CreateAnd(T0, Y);
+  return Builder.CreateICmp(I.getPredicate(), T1,
+                            Constant::getNullValue(X->getType()));
+}
+
 /// Try to fold icmp (binop), X or icmp X, (binop).
 /// TODO: A large part of this logic is duplicated in InstSimplify's
 /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code
@@ -3448,6 +3683,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
   if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder))
     return replaceInstUsesWith(I, V);
 
+  if (Value *V = foldShiftIntoShiftInAnotherHandOfAndInICmp(I, SQ, Builder))
+    return replaceInstUsesWith(I, V);
+
   return nullptr;
 }
 
@@ -3688,6 +3926,30 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
        match(Op1, m_BitReverse(m_Value(B)))))
     return new ICmpInst(Pred, A, B);
 
+  // Canonicalize checking for a power-of-2-or-zero value:
+  // (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
+  // ((A-1) & A) != 0 --> ctpop(A) > 1 (two commuted variants)
+  if (!match(Op0, m_OneUse(m_c_And(m_Add(m_Value(A), m_AllOnes()),
+                                   m_Deferred(A)))) ||
+      !match(Op1, m_ZeroInt()))
+    A = nullptr;
+
+  // (A & -A) == A --> ctpop(A) < 2 (four commuted variants)
+  // (-A & A) != A --> ctpop(A) > 1 (four commuted variants)
+  if (match(Op0, m_OneUse(m_c_And(m_Neg(m_Specific(Op1)), m_Specific(Op1)))))
+    A = Op1;
+  else if (match(Op1,
+                 m_OneUse(m_c_And(m_Neg(m_Specific(Op0)), m_Specific(Op0)))))
+    A = Op0;
+
+  if (A) {
+    Type *Ty = A->getType();
+    CallInst *CtPop = Builder.CreateUnaryIntrinsic(Intrinsic::ctpop, A);
+    return Pred == ICmpInst::ICMP_EQ
+        ? new ICmpInst(ICmpInst::ICMP_ULT, CtPop, ConstantInt::get(Ty, 2))
+        : new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
+  }
+
   return nullptr;
 }
 
@@ -3698,7 +3960,6 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
   Value *LHSCIOp        = LHSCI->getOperand(0);
   Type *SrcTy     = LHSCIOp->getType();
   Type *DestTy    = LHSCI->getType();
-  Value *RHSCIOp;
 
   // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the
   // integer type is the same size as the pointer type.
@@ -3740,7 +4001,7 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
   if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) {
     // Not an extension from the same type?
-    RHSCIOp = CI->getOperand(0);
+    Value *RHSCIOp = CI->getOperand(0);
     if (RHSCIOp->getType() != LHSCIOp->getType())
       return nullptr;
 
@@ -3813,104 +4074,83 @@ Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) {
   return BinaryOperator::CreateNot(Result);
 }
 
-bool InstCombiner::OptimizeOverflowCheck(OverflowCheckFlavor OCF, Value *LHS,
-                                         Value *RHS, Instruction &OrigI,
-                                         Value *&Result, Constant *&Overflow) {
+static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
+  switch (BinaryOp) {
+    default:
+      llvm_unreachable("Unsupported binary op");
+    case Instruction::Add:
+    case Instruction::Sub:
+      return match(RHS, m_Zero());
+    case Instruction::Mul:
+      return match(RHS, m_One());
+  }
+}
+
+OverflowResult InstCombiner::computeOverflow(
+    Instruction::BinaryOps BinaryOp, bool IsSigned,
+    Value *LHS, Value *RHS, Instruction *CxtI) const {
+  switch (BinaryOp) {
+    default:
+      llvm_unreachable("Unsupported binary op");
+    case Instruction::Add:
+      if (IsSigned)
+        return computeOverflowForSignedAdd(LHS, RHS, CxtI);
+      else
+        return computeOverflowForUnsignedAdd(LHS, RHS, CxtI);
+    case Instruction::Sub:
+      if (IsSigned)
+        return computeOverflowForSignedSub(LHS, RHS, CxtI);
+      else
+        return computeOverflowForUnsignedSub(LHS, RHS, CxtI);
+    case Instruction::Mul:
+      if (IsSigned)
+        return computeOverflowForSignedMul(LHS, RHS, CxtI);
+      else
+        return computeOverflowForUnsignedMul(LHS, RHS, CxtI);
+  }
+}
+
+bool InstCombiner::OptimizeOverflowCheck(
+    Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS,
+    Instruction &OrigI, Value *&Result, Constant *&Overflow) {
   if (OrigI.isCommutative() && isa<Constant>(LHS) && !isa<Constant>(RHS))
     std::swap(LHS, RHS);
 
-  auto SetResult = [&](Value *OpResult, Constant *OverflowVal, bool ReuseName) {
-    Result = OpResult;
-    Overflow = OverflowVal;
-    if (ReuseName)
-      Result->takeName(&OrigI);
-    return true;
-  };
-
   // If the overflow check was an add followed by a compare, the insertion point
   // may be pointing to the compare.  We want to insert the new instructions
   // before the add in case there are uses of the add between the add and the
   // compare.
   Builder.SetInsertPoint(&OrigI);
 
-  switch (OCF) {
-  case OCF_INVALID:
-    llvm_unreachable("bad overflow check kind!");
-
-  case OCF_UNSIGNED_ADD: {
-    OverflowResult OR = computeOverflowForUnsignedAdd(LHS, RHS, &OrigI);
-    if (OR == OverflowResult::NeverOverflows)
-      return SetResult(Builder.CreateNUWAdd(LHS, RHS), Builder.getFalse(),
-                       true);
-
-    if (OR == OverflowResult::AlwaysOverflows)
-      return SetResult(Builder.CreateAdd(LHS, RHS), Builder.getTrue(), true);
-
-    // Fall through uadd into sadd
-    LLVM_FALLTHROUGH;
-  }
-  case OCF_SIGNED_ADD: {
-    // X + 0 -> {X, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(LHS, Builder.getFalse(), false);
-
-    // We can strength reduce this signed add into a regular add if we can prove
-    // that it will never overflow.
-    if (OCF == OCF_SIGNED_ADD)
-      if (willNotOverflowSignedAdd(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWAdd(LHS, RHS), Builder.getFalse(),
-                         true);
-    break;
-  }
-
-  case OCF_UNSIGNED_SUB:
-  case OCF_SIGNED_SUB: {
-    // X - 0 -> {X, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(LHS, Builder.getFalse(), false);
-
-    if (OCF == OCF_SIGNED_SUB) {
-      if (willNotOverflowSignedSub(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWSub(LHS, RHS), Builder.getFalse(),
-                         true);
-    } else {
-      if (willNotOverflowUnsignedSub(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNUWSub(LHS, RHS), Builder.getFalse(),
-                         true);
-    }
-    break;
-  }
-
-  case OCF_UNSIGNED_MUL: {
-    OverflowResult OR = computeOverflowForUnsignedMul(LHS, RHS, &OrigI);
-    if (OR == OverflowResult::NeverOverflows)
-      return SetResult(Builder.CreateNUWMul(LHS, RHS), Builder.getFalse(),
-                       true);
-    if (OR == OverflowResult::AlwaysOverflows)
-      return SetResult(Builder.CreateMul(LHS, RHS), Builder.getTrue(), true);
-    LLVM_FALLTHROUGH;
+  if (isNeutralValue(BinaryOp, RHS)) {
+    Result = LHS;
+    Overflow = Builder.getFalse();
+    return true;
   }
-  case OCF_SIGNED_MUL:
-    // X * undef -> undef
-    if (isa<UndefValue>(RHS))
-      return SetResult(RHS, UndefValue::get(Builder.getInt1Ty()), false);
-
-    // X * 0 -> {0, false}
-    if (match(RHS, m_Zero()))
-      return SetResult(RHS, Builder.getFalse(), false);
-
-    // X * 1 -> {X, false}
-    if (match(RHS, m_One()))
-      return SetResult(LHS, Builder.getFalse(), false);
 
-    if (OCF == OCF_SIGNED_MUL)
-      if (willNotOverflowSignedMul(LHS, RHS, OrigI))
-        return SetResult(Builder.CreateNSWMul(LHS, RHS), Builder.getFalse(),
-                         true);
-    break;
+  switch (computeOverflow(BinaryOp, IsSigned, LHS, RHS, &OrigI)) {
+    case OverflowResult::MayOverflow:
+      return false;
+    case OverflowResult::AlwaysOverflowsLow:
+    case OverflowResult::AlwaysOverflowsHigh:
+      Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
+      Result->takeName(&OrigI);
+      Overflow = Builder.getTrue();
+      return true;
+    case OverflowResult::NeverOverflows:
+      Result = Builder.CreateBinOp(BinaryOp, LHS, RHS);
+      Result->takeName(&OrigI);
+      Overflow = Builder.getFalse();
+      if (auto *Inst = dyn_cast<Instruction>(Result)) {
+        if (IsSigned)
+          Inst->setHasNoSignedWrap();
+        else
+          Inst->setHasNoUnsignedWrap();
+      }
+      return true;
   }
-  return false;
+  llvm_unreachable("Unexpected overflow result");
 }
 
 /// Recognize and process idiom involving test for multiplication
@@ -4084,8 +4324,8 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
     MulA = Builder.CreateZExt(A, MulType);
   if (WidthB < MulWidth)
     MulB = Builder.CreateZExt(B, MulType);
-  Value *F = Intrinsic::getDeclaration(I.getModule(),
-                                       Intrinsic::umul_with_overflow, MulType);
+  Function *F = Intrinsic::getDeclaration(
+      I.getModule(), Intrinsic::umul_with_overflow, MulType);
   CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul");
   IC.Worklist.Add(MulInstr);
 
@@ -4881,61 +5121,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         return New;
   }
 
-  // Zero-equality and sign-bit checks are preserved through sitofp + bitcast.
-  Value *X;
-  if (match(Op0, m_BitCast(m_SIToFP(m_Value(X))))) {
-    // icmp  eq (bitcast (sitofp X)), 0 --> icmp  eq X, 0
-    // icmp  ne (bitcast (sitofp X)), 0 --> icmp  ne X, 0
-    // icmp slt (bitcast (sitofp X)), 0 --> icmp slt X, 0
-    // icmp sgt (bitcast (sitofp X)), 0 --> icmp sgt X, 0
-    if ((Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_SLT ||
-         Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT) &&
-        match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-    // icmp slt (bitcast (sitofp X)), 1 --> icmp slt X, 1
-    if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_One()))
-      return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), 1));
-
-    // icmp sgt (bitcast (sitofp X)), -1 --> icmp sgt X, -1
-    if (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))
-      return new ICmpInst(Pred, X, ConstantInt::getAllOnesValue(X->getType()));
-  }
-
-  // Zero-equality checks are preserved through unsigned floating-point casts:
-  // icmp eq (bitcast (uitofp X)), 0 --> icmp eq X, 0
-  // icmp ne (bitcast (uitofp X)), 0 --> icmp ne X, 0
-  if (match(Op0, m_BitCast(m_UIToFP(m_Value(X)))))
-    if (I.isEquality() && match(Op1, m_Zero()))
-      return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType()));
-
-  // Test to see if the operands of the icmp are casted versions of other
-  // values.  If the ptr->ptr cast can be stripped off both arguments, we do so
-  // now.
-  if (BitCastInst *CI = dyn_cast<BitCastInst>(Op0)) {
-    if (Op0->getType()->isPointerTy() &&
-        (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) {
-      // We keep moving the cast from the left operand over to the right
-      // operand, where it can often be eliminated completely.
-      Op0 = CI->getOperand(0);
-
-      // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast
-      // so eliminate it as well.
-      if (BitCastInst *CI2 = dyn_cast<BitCastInst>(Op1))
-        Op1 = CI2->getOperand(0);
-
-      // If Op1 is a constant, we can fold the cast into the constant.
-      if (Op0->getType() != Op1->getType()) {
-        if (Constant *Op1C = dyn_cast<Constant>(Op1)) {
-          Op1 = ConstantExpr::getBitCast(Op1C, Op0->getType());
-        } else {
-          // Otherwise, cast the RHS right before the icmp
-          Op1 = Builder.CreateBitCast(Op1, Op0->getType());
-        }
-      }
-      return new ICmpInst(I.getPredicate(), Op0, Op1);
-    }
-  }
+  if (Instruction *Res = foldICmpBitCast(I, Builder))
+    return Res;
 
   if (isa<CastInst>(Op0)) {
     // Handle the special case of: icmp (cast bool to X), <cst>
@@ -4984,8 +5171,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) {
         isa<IntegerType>(A->getType())) {
       Value *Result;
      Constant *Overflow;
-      if (OptimizeOverflowCheck(OCF_UNSIGNED_ADD, A, B, *AddI, Result,
-                                Overflow)) {
+      if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B,
+                                *AddI, Result, Overflow)) {
        replaceInstUsesWith(*AddI, Result);
         return replaceInstUsesWith(I, Overflow);
       }
@@ -5411,6 +5598,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
     return replaceInstUsesWith(I, V);
 
   // Simplify 'fcmp pred X, X'
+  Type *OpType = Op0->getType();
+  assert(OpType == Op1->getType() && "fcmp with different-typed operands?");
   if (Op0 == Op1) {
     switch (Pred) {
       default: break;
@@ -5420,7 +5609,7 @@
     case FCmpInst::FCMP_UNE:    // True if unordered or not equal
       // Canonicalize these to be 'fcmp uno %X, 0.0'.
       I.setPredicate(FCmpInst::FCMP_UNO);
-      I.setOperand(1, Constant::getNullValue(Op0->getType()));
+      I.setOperand(1, Constant::getNullValue(OpType));
       return &I;
 
     case FCmpInst::FCMP_ORD:    // True if ordered (no nans)
@@ -5429,7 +5618,7 @@
     case FCmpInst::FCMP_OLE:    // True if ordered and less than or equal
       // Canonicalize these to be 'fcmp ord %X, 0.0'.
       I.setPredicate(FCmpInst::FCMP_ORD);
-      I.setOperand(1, Constant::getNullValue(Op0->getType()));
+      I.setOperand(1, Constant::getNullValue(OpType));
       return &I;
     }
   }
@@ -5438,15 +5627,20 @@
   // then canonicalize the operand to 0.0.
   if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) {
     if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) {
-      I.setOperand(0, ConstantFP::getNullValue(Op0->getType()));
+      I.setOperand(0, ConstantFP::getNullValue(OpType));
       return &I;
     }
 
     if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) {
-      I.setOperand(1, ConstantFP::getNullValue(Op0->getType()));
+      I.setOperand(1, ConstantFP::getNullValue(OpType));
       return &I;
     }
   }
 
+  // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
+  Value *X, *Y;
+  if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
+    return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
+
   // Test if the FCmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
@@ -5465,7 +5659,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0:
   // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0
   if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) {
-    I.setOperand(1, ConstantFP::getNullValue(Op1->getType()));
+    I.setOperand(1, ConstantFP::getNullValue(OpType));
     return &I;
   }
 
@@ -5505,12 +5699,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
   if (Instruction *R = foldFabsWithFcmpZero(I))
     return R;
 
-  Value *X, *Y;
   if (match(Op0, m_FNeg(m_Value(X)))) {
-    // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y
-    if (match(Op1, m_FNeg(m_Value(Y))))
-      return new FCmpInst(I.getSwappedPredicate(), X, Y, "", &I);
-
     // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
     Constant *C;
     if (match(Op1, m_Constant(C))) {
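
The new ctlz/cttz range folds above are easy to sanity-check by brute force at a small bit width. The following standalone C++20 program (illustrative only, not part of the patch) exhaustively verifies the four rewrites for i8 with C = 3, using the same masks the patch builds via APInt::getOneBitSet and APInt::getLowBitsSet:

// Exhaustive i8 check of the ctlz/cttz folds added to
// foldICmpIntrinsicWithConstant, with C = 3. Requires C++20 <bit>.
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V <= 0xFF; ++V) {
    auto X = static_cast<uint8_t>(V);
    // ctlz(X) u> 3  <->  X u< 0b00010000   (getOneBitSet(8, 8 - 3 - 1))
    assert((std::countl_zero(X) > 3) == (X < 0b00010000u));
    // ctlz(X) u< 3  <->  X u> 0b00011111   (getLowBitsSet(8, 8 - 3))
    assert((std::countl_zero(X) < 3) == (X > 0b00011111u));
    // cttz(X) u> 3  <->  (X & 0b00001111) == 0   (low C + 1 bits)
    assert((std::countr_zero(X) > 3) == ((X & 0b00001111u) == 0));
    // cttz(X) u< 3  <->  (X & 0b00000111) != 0   (low C bits)
    assert((std::countr_zero(X) < 3) == ((X & 0b00000111u) != 0));
  }
}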

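A similar exhaustive check (again an illustrative sketch, not part of the patch) covers the new urem case in foldICmpWithZero and the power-of-2-or-zero canonicalization in foldICmpEquality:

// i8 check: (urem X, Y) ==/!= 0 folds to X ==/!= 0 when X has at most one
// bit set and Y has at least two; and both (A & (A-1)) == 0 and
// (A & -A) == A are equivalent to ctpop(A) u< 2. Requires C++20 <bit>.
#include <bit>
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V <= 0xFF; ++V) {
    auto A = static_cast<uint8_t>(V);
    bool PowerOf2OrZero = std::popcount(A) < 2;
    assert(((A & (A - 1)) == 0) == PowerOf2OrZero);
    assert(((A & static_cast<uint8_t>(-A)) == A) == PowerOf2OrZero);
    for (unsigned W = 1; W <= 0xFF; ++W) {
      auto Y = static_cast<uint8_t>(W);
      if (std::popcount(A) <= 1 && std::popcount(Y) >= 2)
        assert((A % Y == 0) == (A == 0));
    }
  }
}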