diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-06-13 19:31:46 +0000 | 
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-06-13 19:37:19 +0000 | 
| commit | e8d8bef961a50d4dc22501cde4fb9fb0be1b2532 (patch) | |
| tree | 94f04805f47bb7c59ae29690d8952b6074fff602 /contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp | |
| parent | bb130ff39747b94592cb26d71b7cb097b9a4ea6b (diff) | |
| parent | b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff) | |
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp | 1230 | 
1 files changed, 757 insertions, 473 deletions
diff --git a/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp b/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp index e744a966a104..c40e5c36cdc7 100644 --- a/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/contrib/llvm-project/llvm/lib/Analysis/InstructionSimplify.cpp @@ -228,64 +228,56 @@ static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {    return false;  } -/// Simplify "A op (B op' C)" by distributing op over op', turning it into -/// "(A op B) op' (A op C)".  Here "op" is given by Opcode and "op'" is -/// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. -/// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". -/// Returns the simplified value, or null if no simplification was performed. -static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, -                          Instruction::BinaryOps OpcodeToExpand, +/// Try to simplify a binary operator of form "V op OtherOp" where V is +/// "(B0 opex B1)" by distributing 'op' across 'opex' as +/// "(B0 op OtherOp) opex (B1 op OtherOp)". +static Value *expandBinOp(Instruction::BinaryOps Opcode, Value *V, +                          Value *OtherOp, Instruction::BinaryOps OpcodeToExpand,                            const SimplifyQuery &Q, unsigned MaxRecurse) { -  // Recursion is always used, so bail out at once if we already hit the limit. -  if (!MaxRecurse--) +  auto *B = dyn_cast<BinaryOperator>(V); +  if (!B || B->getOpcode() != OpcodeToExpand) +    return nullptr; +  Value *B0 = B->getOperand(0), *B1 = B->getOperand(1); +  Value *L = SimplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), +                           MaxRecurse); +  if (!L) +    return nullptr; +  Value *R = SimplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), +                           MaxRecurse); +  if (!R)      return nullptr; -  // Check whether the expression has the form "(A op' B) op C". -  if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) -    if (Op0->getOpcode() == OpcodeToExpand) { -      // It does!  Try turning it into "(A op C) op' (B op C)". -      Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; -      // Do "A op C" and "B op C" both simplify? -      if (Value *L = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) -        if (Value *R = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { -          // They do! Return "L op' R" if it simplifies or is already available. -          // If "L op' R" equals "A op' B" then "L op' R" is just the LHS. -          if ((L == A && R == B) || (Instruction::isCommutative(OpcodeToExpand) -                                     && L == B && R == A)) { -            ++NumExpand; -            return LHS; -          } -          // Otherwise return "L op' R" if it simplifies. -          if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { -            ++NumExpand; -            return V; -          } -        } -    } +  // Does the expanded pair of binops simplify to the existing binop? +  if ((L == B0 && R == B1) || +      (Instruction::isCommutative(OpcodeToExpand) && L == B1 && R == B0)) { +    ++NumExpand; +    return B; +  } -  // Check whether the expression has the form "A op (B op' C)". -  if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) -    if (Op1->getOpcode() == OpcodeToExpand) { -      // It does!  Try turning it into "(A op B) op' (A op C)". -      Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); -      // Do "A op B" and "A op C" both simplify? -      if (Value *L = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) -        if (Value *R = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) { -          // They do! Return "L op' R" if it simplifies or is already available. -          // If "L op' R" equals "B op' C" then "L op' R" is just the RHS. -          if ((L == B && R == C) || (Instruction::isCommutative(OpcodeToExpand) -                                     && L == C && R == B)) { -            ++NumExpand; -            return RHS; -          } -          // Otherwise return "L op' R" if it simplifies. -          if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { -            ++NumExpand; -            return V; -          } -        } -    } +  // Otherwise, return "L op' R" if it simplifies. +  Value *S = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse); +  if (!S) +    return nullptr; +  ++NumExpand; +  return S; +} + +/// Try to simplify binops of form "A op (B op' C)" or the commuted variant by +/// distributing op over op'. +static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, +                                     Value *L, Value *R, +                                     Instruction::BinaryOps OpcodeToExpand, +                                     const SimplifyQuery &Q, +                                     unsigned MaxRecurse) { +  // Recursion is always used, so bail out at once if we already hit the limit. +  if (!MaxRecurse--) +    return nullptr; + +  if (Value *V = expandBinOp(Opcode, L, R, OpcodeToExpand, Q, MaxRecurse)) +    return V; +  if (Value *V = expandBinOp(Opcode, R, L, OpcodeToExpand, Q, MaxRecurse)) +    return V;    return nullptr;  } @@ -423,9 +415,9 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS,      return TV;    // If one branch simplified to undef, return the other one. -  if (TV && isa<UndefValue>(TV)) +  if (TV && Q.isUndefValue(TV))      return FV; -  if (FV && isa<UndefValue>(FV)) +  if (FV && Q.isUndefValue(FV))      return TV;    // If applying the operation did not change the true and false select values, @@ -620,7 +612,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,      return C;    // X + undef -> undef -  if (match(Op1, m_Undef())) +  if (Q.isUndefValue(Op1))      return Op1;    // X + 0 -> X @@ -740,7 +732,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    // X - undef -> undef    // undef - X -> undef -  if (match(Op0, m_Undef()) || match(Op1, m_Undef())) +  if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1))      return UndefValue::get(Op0->getType());    // X - 0 -> X @@ -875,7 +867,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,    // X * undef -> 0    // X * 0 -> 0 -  if (match(Op1, m_CombineOr(m_Undef(), m_Zero()))) +  if (Q.isUndefValue(Op1) || match(Op1, m_Zero()))      return Constant::getNullValue(Op0->getType());    // X * 1 -> X @@ -901,8 +893,8 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return V;    // Mul distributes over Add. Try some generic simplifications based on this. -  if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add, -                             Q, MaxRecurse)) +  if (Value *V = expandCommutativeBinOp(Instruction::Mul, Op0, Op1, +                                        Instruction::Add, Q, MaxRecurse))      return V;    // If the operation is with the result of a select instruction, check whether @@ -928,36 +920,37 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {  /// Check for common or similar folds of integer division or integer remainder.  /// This applies to all 4 opcodes (sdiv/udiv/srem/urem). -static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, +                             const SimplifyQuery &Q) {    Type *Ty = Op0->getType(); -  // X / undef -> undef -  // X % undef -> undef -  if (match(Op1, m_Undef())) -    return Op1; +  // X / undef -> poison +  // X % undef -> poison +  if (Q.isUndefValue(Op1)) +    return PoisonValue::get(Ty); -  // X / 0 -> undef -  // X % 0 -> undef +  // X / 0 -> poison +  // X % 0 -> poison    // We don't need to preserve faults!    if (match(Op1, m_Zero())) -    return UndefValue::get(Ty); +    return PoisonValue::get(Ty); -  // If any element of a constant divisor fixed width vector is zero or undef, -  // the whole op is undef. +  // If any element of a constant divisor fixed width vector is zero or undef +  // the behavior is undefined and we can fold the whole op to poison.    auto *Op1C = dyn_cast<Constant>(Op1);    auto *VTy = dyn_cast<FixedVectorType>(Ty);    if (Op1C && VTy) {      unsigned NumElts = VTy->getNumElements();      for (unsigned i = 0; i != NumElts; ++i) {        Constant *Elt = Op1C->getAggregateElement(i); -      if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt))) -        return UndefValue::get(Ty); +      if (Elt && (Elt->isNullValue() || Q.isUndefValue(Elt))) +        return PoisonValue::get(Ty);      }    }    // undef / X -> 0    // undef % X -> 0 -  if (match(Op0, m_Undef())) +  if (Q.isUndefValue(Op0))      return Constant::getNullValue(Ty);    // 0 / X -> 0 @@ -1051,7 +1044,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,    if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))      return C; -  if (Value *V = simplifyDivRem(Op0, Op1, true)) +  if (Value *V = simplifyDivRem(Op0, Op1, true, Q))      return V;    bool IsSigned = Opcode == Instruction::SDiv; @@ -1109,7 +1102,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,    if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))      return C; -  if (Value *V = simplifyDivRem(Op0, Op1, false)) +  if (Value *V = simplifyDivRem(Op0, Op1, false, Q))      return V;    // (X % Y) % Y -> X % Y @@ -1204,14 +1197,14 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {    return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit);  } -/// Returns true if a shift by \c Amount always yields undef. -static bool isUndefShift(Value *Amount) { +/// Returns true if a shift by \c Amount always yields poison. +static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) {    Constant *C = dyn_cast<Constant>(Amount);    if (!C)      return false; -  // X shift by undef -> undef because it may shift by the bitwidth. -  if (isa<UndefValue>(C)) +  // X shift by undef -> poison because it may shift by the bitwidth. +  if (Q.isUndefValue(C))      return true;    // Shifting by the bitwidth or more is undefined. @@ -1222,9 +1215,10 @@ static bool isUndefShift(Value *Amount) {    // If all lanes of a vector shift are undefined the whole shift is.    if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) { -    for (unsigned I = 0, E = cast<VectorType>(C->getType())->getNumElements(); +    for (unsigned I = 0, +                  E = cast<FixedVectorType>(C->getType())->getNumElements();           I != E; ++I) -      if (!isUndefShift(C->getAggregateElement(I))) +      if (!isPoisonShift(C->getAggregateElement(I), Q))          return false;      return true;    } @@ -1252,8 +1246,8 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,      return Op0;    // Fold undefined shifts. -  if (isUndefShift(Op1)) -    return UndefValue::get(Op0->getType()); +  if (isPoisonShift(Op1, Q)) +    return PoisonValue::get(Op0->getType());    // If the operation is with the result of a select instruction, check whether    // operating on either branch of the select always yields the same value. @@ -1271,7 +1265,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,    // the number of bits in the type, the shift is undefined.    KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);    if (Known.One.getLimitedValue() >= Known.getBitWidth()) -    return UndefValue::get(Op0->getType()); +    return PoisonValue::get(Op0->getType());    // If all valid bits in the shift amount are known zero, the first operand is    // unchanged. @@ -1296,7 +1290,7 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,    // undef >> X -> 0    // undef >> X -> undef (if it's exact) -  if (match(Op0, m_Undef())) +  if (Q.isUndefValue(Op0))      return isExact ? Op0 : Constant::getNullValue(Op0->getType());    // The low bit cannot be shifted out of an exact shift if it is set. @@ -1318,7 +1312,7 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    // undef << X -> 0    // undef << X -> undef if (if it's NSW/NUW) -  if (match(Op0, m_Undef())) +  if (Q.isUndefValue(Op0))      return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());    // (X >> A) << A -> X @@ -1704,25 +1698,27 @@ static Value *simplifyAndOrOfICmpsWithLimitConst(ICmpInst *Cmp0, ICmpInst *Cmp1,    if (!Cmp0->isEquality())      return nullptr; -  // The equality compare must be against a constant. Convert the 'null' pointer -  // constant to an integer zero value. -  APInt MinMaxC; -  const APInt *C; -  if (match(Cmp0->getOperand(1), m_APInt(C))) -    MinMaxC = *C; -  else if (isa<ConstantPointerNull>(Cmp0->getOperand(1))) -    MinMaxC = APInt::getNullValue(8); -  else -    return nullptr; -    // The non-equality compare must include a common operand (X). Canonicalize    // the common operand as operand 0 (the predicate is swapped if the common    // operand was operand 1).    ICmpInst::Predicate Pred0 = Cmp0->getPredicate();    Value *X = Cmp0->getOperand(0);    ICmpInst::Predicate Pred1; -  if (!match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value())) || -      ICmpInst::isEquality(Pred1)) +  bool HasNotOp = match(Cmp1, m_c_ICmp(Pred1, m_Not(m_Specific(X)), m_Value())); +  if (!HasNotOp && !match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value()))) +    return nullptr; +  if (ICmpInst::isEquality(Pred1)) +    return nullptr; + +  // The equality compare must be against a constant. Flip bits if we matched +  // a bitwise not. Convert a null pointer constant to an integer zero value. +  APInt MinMaxC; +  const APInt *C; +  if (match(Cmp0->getOperand(1), m_APInt(C))) +    MinMaxC = HasNotOp ? ~*C : *C; +  else if (isa<ConstantPointerNull>(Cmp0->getOperand(1))) +    MinMaxC = APInt::getNullValue(8); +  else      return nullptr;    // DeMorganize if this is 'or': P0 || P1 --> !P0 && !P1. @@ -2003,6 +1999,30 @@ static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0,    return NotOp1;  } +/// Given a bitwise logic op, check if the operands are add/sub with a common +/// source value and inverted constant (identity: C - X -> ~(X + ~C)). +static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, +                                    Instruction::BinaryOps Opcode) { +  assert(Op0->getType() == Op1->getType() && "Mismatched binop types"); +  assert(BinaryOperator::isBitwiseLogicOp(Opcode) && "Expected logic op"); +  Value *X; +  Constant *C1, *C2; +  if ((match(Op0, m_Add(m_Value(X), m_Constant(C1))) && +       match(Op1, m_Sub(m_Constant(C2), m_Specific(X)))) || +      (match(Op1, m_Add(m_Value(X), m_Constant(C1))) && +       match(Op0, m_Sub(m_Constant(C2), m_Specific(X))))) { +    if (ConstantExpr::getNot(C1) == C2) { +      // (X + C) & (~C - X) --> (X + C) & ~(X + C) --> 0 +      // (X + C) | (~C - X) --> (X + C) | ~(X + C) --> -1 +      // (X + C) ^ (~C - X) --> (X + C) ^ ~(X + C) --> -1 +      Type *Ty = Op0->getType(); +      return Opcode == Instruction::And ? ConstantInt::getNullValue(Ty) +                                        : ConstantInt::getAllOnesValue(Ty); +    } +  } +  return nullptr; +} +  /// Given operands for an And, see if we can fold the result.  /// If not, this returns null.  static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, @@ -2011,7 +2031,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return C;    // X & undef -> 0 -  if (match(Op1, m_Undef())) +  if (Q.isUndefValue(Op1))      return Constant::getNullValue(Op0->getType());    // X & X = X @@ -2039,6 +2059,9 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,    if (match(Op1, m_c_Or(m_Specific(Op0), m_Value())))      return Op0; +  if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::And)) +    return V; +    // A mask that only clears known zeros of a shifted value is a no-op.    Value *X;    const APInt *Mask; @@ -2095,21 +2118,30 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return V;    // And distributes over Or.  Try some generic simplifications based on this. -  if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Or, -                             Q, MaxRecurse)) +  if (Value *V = expandCommutativeBinOp(Instruction::And, Op0, Op1, +                                        Instruction::Or, Q, MaxRecurse))      return V;    // And distributes over Xor.  Try some generic simplifications based on this. -  if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Xor, -                             Q, MaxRecurse)) +  if (Value *V = expandCommutativeBinOp(Instruction::And, Op0, Op1, +                                        Instruction::Xor, Q, MaxRecurse))      return V; -  // If the operation is with the result of a select instruction, check whether -  // operating on either branch of the select always yields the same value. -  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) { +    if (Op0->getType()->isIntOrIntVectorTy(1)) { +      // A & (A && B) -> A && B +      if (match(Op1, m_Select(m_Specific(Op0), m_Value(), m_Zero()))) +        return Op1; +      else if (match(Op0, m_Select(m_Specific(Op1), m_Value(), m_Zero()))) +        return Op0; +    } +    // If the operation is with the result of a select instruction, check +    // whether operating on either branch of the select always yields the same +    // value.      if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q,                                           MaxRecurse))        return V; +  }    // If the operation is with the result of a phi instruction, check whether    // operating on all incoming values of the phi always yields the same value. @@ -2169,7 +2201,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,    // X | undef -> -1    // X | -1 = -1    // Do not return Op1 because it may contain undef elements if it's a vector. -  if (match(Op1, m_Undef()) || match(Op1, m_AllOnes())) +  if (Q.isUndefValue(Op1) || match(Op1, m_AllOnes()))      return Constant::getAllOnesValue(Op0->getType());    // X | X = X @@ -2198,7 +2230,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,    if (match(Op1, m_Not(m_c_And(m_Specific(Op0), m_Value()))))      return Constant::getAllOnesValue(Op0->getType()); -  Value *A, *B; +  if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::Or)) +    return V; + +  Value *A, *B, *NotA;    // (A & ~B) | (A ^ B) -> (A ^ B)    // (~B & A) | (A ^ B) -> (A ^ B)    // (A & ~B) | (B ^ A) -> (B ^ A) @@ -2227,6 +2262,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,         match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))      return Op1; +  // Commute the 'or' operands.    // (~A ^ B) | (A & B) -> (~A ^ B)    // (~A ^ B) | (B & A) -> (~A ^ B)    // (B ^ ~A) | (A & B) -> (B ^ ~A) @@ -2236,6 +2272,25 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,         match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B)))))      return Op0; +  // (~A & B) | ~(A | B) --> ~A +  // (~A & B) | ~(B | A) --> ~A +  // (B & ~A) | ~(A | B) --> ~A +  // (B & ~A) | ~(B | A) --> ~A +  if (match(Op0, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))), +                         m_Value(B))) && +      match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) +    return NotA; + +  // Commute the 'or' operands. +  // ~(A | B) | (~A & B) --> ~A +  // ~(B | A) | (~A & B) --> ~A +  // ~(A | B) | (B & ~A) --> ~A +  // ~(B | A) | (B & ~A) --> ~A +  if (match(Op1, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))), +                         m_Value(B))) && +      match(Op0, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) +    return NotA; +    if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false))      return V; @@ -2253,16 +2308,25 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return V;    // Or distributes over And.  Try some generic simplifications based on this. -  if (Value *V = ExpandBinOp(Instruction::Or, Op0, Op1, Instruction::And, Q, -                             MaxRecurse)) +  if (Value *V = expandCommutativeBinOp(Instruction::Or, Op0, Op1, +                                        Instruction::And, Q, MaxRecurse))      return V; -  // If the operation is with the result of a select instruction, check whether -  // operating on either branch of the select always yields the same value. -  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) +  if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) { +    if (Op0->getType()->isIntOrIntVectorTy(1)) { +      // A | (A || B) -> A || B +      if (match(Op1, m_Select(m_Specific(Op0), m_One(), m_Value()))) +        return Op1; +      else if (match(Op0, m_Select(m_Specific(Op1), m_One(), m_Value()))) +        return Op0; +    } +    // If the operation is with the result of a select instruction, check +    // whether operating on either branch of the select always yields the same +    // value.      if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q,                                           MaxRecurse))        return V; +  }    // (A & C1)|(B & C2)    const APInt *C1, *C2; @@ -2311,7 +2375,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return C;    // A ^ undef -> undef -  if (match(Op1, m_Undef())) +  if (Q.isUndefValue(Op1))      return Op1;    // A ^ 0 = A @@ -2327,6 +2391,9 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,        match(Op1, m_Not(m_Specific(Op0))))      return Constant::getAllOnesValue(Op0->getType()); +  if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::Xor)) +    return V; +    // Try some generic simplifications for associative operations.    if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q,                                            MaxRecurse)) @@ -2533,8 +2600,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,      // memory within the lifetime of the current function (allocas, byval      // arguments, globals), then determine the comparison result here.      SmallVector<const Value *, 8> LHSUObjs, RHSUObjs; -    GetUnderlyingObjects(LHS, LHSUObjs, DL); -    GetUnderlyingObjects(RHS, RHSUObjs, DL); +    getUnderlyingObjects(LHS, LHSUObjs); +    getUnderlyingObjects(RHS, RHSUObjs);      // Is the set of underlying objects all noalias calls?      auto IsNAC = [](ArrayRef<const Value *> Objects) { @@ -2741,7 +2808,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,    }    const APInt *C; -  if (!match(RHS, m_APInt(C))) +  if (!match(RHS, m_APIntAllowUndef(C)))      return nullptr;    // Rule out tautological comparisons (eg., ult 0 or uge 0). @@ -2759,17 +2826,166 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,        return ConstantInt::getFalse(ITy);    } +  // (mul nuw/nsw X, MulC) != C --> true  (if C is not a multiple of MulC) +  // (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC) +  const APInt *MulC; +  if (ICmpInst::isEquality(Pred) && +      ((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) && +        *MulC != 0 && C->urem(*MulC) != 0) || +       (match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) && +        *MulC != 0 && C->srem(*MulC) != 0))) +    return ConstantInt::get(ITy, Pred == ICmpInst::ICMP_NE); +    return nullptr;  } +static Value *simplifyICmpWithBinOpOnLHS( +    CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS, +    const SimplifyQuery &Q, unsigned MaxRecurse) { +  Type *ITy = GetCompareTy(RHS); // The return type. + +  Value *Y = nullptr; +  // icmp pred (or X, Y), X +  if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { +    if (Pred == ICmpInst::ICMP_ULT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_UGE) +      return getTrue(ITy); + +    if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { +      KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (RHSKnown.isNonNegative() && YKnown.isNegative()) +        return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); +      if (RHSKnown.isNegative() || YKnown.isNonNegative()) +        return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); +    } +  } + +  // icmp pred (and X, Y), X +  if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +  } + +  // icmp pred (urem X, Y), Y +  if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { +    switch (Pred) { +    default: +      break; +    case ICmpInst::ICMP_SGT: +    case ICmpInst::ICMP_SGE: { +      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_EQ: +    case ICmpInst::ICMP_UGT: +    case ICmpInst::ICMP_UGE: +      return getFalse(ITy); +    case ICmpInst::ICMP_SLT: +    case ICmpInst::ICMP_SLE: { +      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); +      if (!Known.isNonNegative()) +        break; +      LLVM_FALLTHROUGH; +    } +    case ICmpInst::ICMP_NE: +    case ICmpInst::ICMP_ULT: +    case ICmpInst::ICMP_ULE: +      return getTrue(ITy); +    } +  } + +  // icmp pred (urem X, Y), X +  if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) { +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +  } + +  // x >> y <=u x +  // x udiv y <=u x. +  if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || +      match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) { +    // icmp pred (X op Y), X +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +  } + +  // (x*C1)/C2 <= x for C1 <= C2. +  // This holds even if the multiplication overflows: Assume that x != 0 and +  // arithmetic is modulo M. For overflow to occur we must have C1 >= M/x and +  // thus C2 >= M/x. It follows that (x*C1)/C2 <= (M-1)/C2 <= ((M-1)*x)/M < x. +  // +  // Additionally, either the multiplication and division might be represented +  // as shifts: +  // (x*C1)>>C2 <= x for C1 < 2**C2. +  // (x<<C1)/C2 <= x for 2**C1 < C2. +  const APInt *C1, *C2; +  if ((match(LBO, m_UDiv(m_Mul(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && +       C1->ule(*C2)) || +      (match(LBO, m_LShr(m_Mul(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && +       C1->ule(APInt(C2->getBitWidth(), 1) << *C2)) || +      (match(LBO, m_UDiv(m_Shl(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && +       (APInt(C1->getBitWidth(), 1) << *C1).ule(*C2))) { +    if (Pred == ICmpInst::ICMP_UGT) +      return getFalse(ITy); +    if (Pred == ICmpInst::ICMP_ULE) +      return getTrue(ITy); +  } + +  return nullptr; +} + + +// If only one of the icmp's operands has NSW flags, try to prove that: +// +//   icmp slt (x + C1), (x +nsw C2) +// +// is equivalent to: +// +//   icmp slt C1, C2 +// +// which is true if x + C2 has the NSW flags set and: +// *) C1 < C2 && C1 >= 0, or +// *) C2 < C1 && C1 <= 0. +// +static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, +                                    Value *RHS) { +  // TODO: only support icmp slt for now. +  if (Pred != CmpInst::ICMP_SLT) +    return false; + +  // Canonicalize nsw add as RHS. +  if (!match(RHS, m_NSWAdd(m_Value(), m_Value()))) +    std::swap(LHS, RHS); +  if (!match(RHS, m_NSWAdd(m_Value(), m_Value()))) +    return false; + +  Value *X; +  const APInt *C1, *C2; +  if (!match(LHS, m_c_Add(m_Value(X), m_APInt(C1))) || +      !match(RHS, m_c_Add(m_Specific(X), m_APInt(C2)))) +    return false; + +  return (C1->slt(*C2) && C1->isNonNegative()) || +         (C2->slt(*C1) && C1->isNonPositive()); +} + +  /// TODO: A large part of this logic is duplicated in InstCombine's  /// foldICmpBinOp(). We should be able to share that and avoid the code  /// duplication.  static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,                                      Value *RHS, const SimplifyQuery &Q,                                      unsigned MaxRecurse) { -  Type *ITy = GetCompareTy(LHS); // The return type. -    BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS);    BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS);    if (MaxRecurse && (LBO || RBO)) { @@ -2813,8 +3029,9 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,          return V;      // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. -    if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && -        NoRHSWrapProblem) { +    bool CanSimplify = (NoLHSWrapProblem && NoRHSWrapProblem) || +                       trySimplifyICmpWithAdds(Pred, LHS, RHS); +    if (A && C && (A == C || A == D || B == C || B == D) && CanSimplify) {        // Determine Y and Z in the form icmp (X+Y), (X+Z).        Value *Y, *Z;        if (A == C) { @@ -2840,195 +3057,66 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,      }    } -  { -    Value *Y = nullptr; -    // icmp pred (or X, Y), X -    if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { -      if (Pred == ICmpInst::ICMP_ULT) -        return getFalse(ITy); -      if (Pred == ICmpInst::ICMP_UGE) -        return getTrue(ITy); - -      if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { -        KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -        KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -        if (RHSKnown.isNonNegative() && YKnown.isNegative()) -          return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); -        if (RHSKnown.isNegative() || YKnown.isNonNegative()) -          return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); -      } -    } -    // icmp pred X, (or X, Y) -    if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) { -      if (Pred == ICmpInst::ICMP_ULE) -        return getTrue(ITy); -      if (Pred == ICmpInst::ICMP_UGT) -        return getFalse(ITy); - -      if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { -        KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -        KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -        if (LHSKnown.isNonNegative() && YKnown.isNegative()) -          return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); -        if (LHSKnown.isNegative() || YKnown.isNonNegative()) -          return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); -      } -    } -  } +  if (LBO) +    if (Value *V = simplifyICmpWithBinOpOnLHS(Pred, LBO, RHS, Q, MaxRecurse)) +      return V; -  // icmp pred (and X, Y), X -  if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { -    if (Pred == ICmpInst::ICMP_UGT) -      return getFalse(ITy); -    if (Pred == ICmpInst::ICMP_ULE) -      return getTrue(ITy); -  } -  // icmp pred X, (and X, Y) -  if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) { -    if (Pred == ICmpInst::ICMP_UGE) -      return getTrue(ITy); -    if (Pred == ICmpInst::ICMP_ULT) -      return getFalse(ITy); -  } +  if (RBO) +    if (Value *V = simplifyICmpWithBinOpOnLHS( +            ICmpInst::getSwappedPredicate(Pred), RBO, LHS, Q, MaxRecurse)) +      return V;    // 0 - (zext X) pred C    if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) { -    if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { -      if (RHSC->getValue().isStrictlyPositive()) { -        if (Pred == ICmpInst::ICMP_SLT) -          return ConstantInt::getTrue(RHSC->getContext()); -        if (Pred == ICmpInst::ICMP_SGE) -          return ConstantInt::getFalse(RHSC->getContext()); -        if (Pred == ICmpInst::ICMP_EQ) -          return ConstantInt::getFalse(RHSC->getContext()); -        if (Pred == ICmpInst::ICMP_NE) -          return ConstantInt::getTrue(RHSC->getContext()); +    const APInt *C; +    if (match(RHS, m_APInt(C))) { +      if (C->isStrictlyPositive()) { +        if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_NE) +          return ConstantInt::getTrue(GetCompareTy(RHS)); +        if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_EQ) +          return ConstantInt::getFalse(GetCompareTy(RHS));        } -      if (RHSC->getValue().isNonNegative()) { +      if (C->isNonNegative()) {          if (Pred == ICmpInst::ICMP_SLE) -          return ConstantInt::getTrue(RHSC->getContext()); +          return ConstantInt::getTrue(GetCompareTy(RHS));          if (Pred == ICmpInst::ICMP_SGT) -          return ConstantInt::getFalse(RHSC->getContext()); +          return ConstantInt::getFalse(GetCompareTy(RHS));        }      }    } -  // icmp pred (urem X, Y), Y -  if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { -    switch (Pred) { -    default: -      break; -    case ICmpInst::ICMP_SGT: -    case ICmpInst::ICMP_SGE: { -      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -      if (!Known.isNonNegative()) -        break; -      LLVM_FALLTHROUGH; -    } -    case ICmpInst::ICMP_EQ: -    case ICmpInst::ICMP_UGT: -    case ICmpInst::ICMP_UGE: -      return getFalse(ITy); -    case ICmpInst::ICMP_SLT: -    case ICmpInst::ICMP_SLE: { -      KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -      if (!Known.isNonNegative()) -        break; -      LLVM_FALLTHROUGH; -    } -    case ICmpInst::ICMP_NE: -    case ICmpInst::ICMP_ULT: -    case ICmpInst::ICMP_ULE: -      return getTrue(ITy); -    } -  } - -  // icmp pred X, (urem Y, X) -  if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { -    switch (Pred) { -    default: -      break; -    case ICmpInst::ICMP_SGT: -    case ICmpInst::ICMP_SGE: { -      KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -      if (!Known.isNonNegative()) -        break; -      LLVM_FALLTHROUGH; -    } -    case ICmpInst::ICMP_NE: -    case ICmpInst::ICMP_UGT: -    case ICmpInst::ICMP_UGE: -      return getTrue(ITy); -    case ICmpInst::ICMP_SLT: -    case ICmpInst::ICMP_SLE: { -      KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); -      if (!Known.isNonNegative()) -        break; -      LLVM_FALLTHROUGH; -    } -    case ICmpInst::ICMP_EQ: -    case ICmpInst::ICMP_ULT: -    case ICmpInst::ICMP_ULE: -      return getFalse(ITy); +  //   If C2 is a power-of-2 and C is not: +  //   (C2 << X) == C --> false +  //   (C2 << X) != C --> true +  const APInt *C; +  if (match(LHS, m_Shl(m_Power2(), m_Value())) && +      match(RHS, m_APIntAllowUndef(C)) && !C->isPowerOf2()) { +    // C2 << X can equal zero in some circumstances. +    // This simplification might be unsafe if C is zero. +    // +    // We know it is safe if: +    // - The shift is nsw. We can't shift out the one bit. +    // - The shift is nuw. We can't shift out the one bit. +    // - C2 is one. +    // - C isn't zero. +    if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || +        Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || +        match(LHS, m_Shl(m_One(), m_Value())) || !C->isNullValue()) { +      if (Pred == ICmpInst::ICMP_EQ) +        return ConstantInt::getFalse(GetCompareTy(RHS)); +      if (Pred == ICmpInst::ICMP_NE) +        return ConstantInt::getTrue(GetCompareTy(RHS));      }    } -  // x >> y <=u x -  // x udiv y <=u x. -  if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || -              match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) { -    // icmp pred (X op Y), X +  // TODO: This is overly constrained. LHS can be any power-of-2. +  // (1 << X)  >u 0x8000 --> false +  // (1 << X) <=u 0x8000 --> true +  if (match(LHS, m_Shl(m_One(), m_Value())) && match(RHS, m_SignMask())) {      if (Pred == ICmpInst::ICMP_UGT) -      return getFalse(ITy); +      return ConstantInt::getFalse(GetCompareTy(RHS));      if (Pred == ICmpInst::ICMP_ULE) -      return getTrue(ITy); -  } - -  // x >=u x >> y -  // x >=u x udiv y. -  if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || -              match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { -    // icmp pred X, (X op Y) -    if (Pred == ICmpInst::ICMP_ULT) -      return getFalse(ITy); -    if (Pred == ICmpInst::ICMP_UGE) -      return getTrue(ITy); -  } - -  // handle: -  //   CI2 << X == CI -  //   CI2 << X != CI -  // -  //   where CI2 is a power of 2 and CI isn't -  if (auto *CI = dyn_cast<ConstantInt>(RHS)) { -    const APInt *CI2Val, *CIVal = &CI->getValue(); -    if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && -        CI2Val->isPowerOf2()) { -      if (!CIVal->isPowerOf2()) { -        // CI2 << X can equal zero in some circumstances, -        // this simplification is unsafe if CI is zero. -        // -        // We know it is safe if: -        // - The shift is nsw, we can't shift out the one bit. -        // - The shift is nuw, we can't shift out the one bit. -        // - CI2 is one -        // - CI isn't zero -        if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || -            Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || -            CI2Val->isOneValue() || !CI->isZero()) { -          if (Pred == ICmpInst::ICMP_EQ) -            return ConstantInt::getFalse(RHS->getContext()); -          if (Pred == ICmpInst::ICMP_NE) -            return ConstantInt::getTrue(RHS->getContext()); -        } -      } -      if (CIVal->isSignMask() && CI2Val->isOneValue()) { -        if (Pred == ICmpInst::ICMP_UGT) -          return ConstantInt::getFalse(RHS->getContext()); -        if (Pred == ICmpInst::ICMP_ULE) -          return ConstantInt::getTrue(RHS->getContext()); -      } -    } +      return ConstantInt::getTrue(GetCompareTy(RHS));    }    if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && @@ -3226,55 +3314,38 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS,        break;      }      case CmpInst::ICMP_UGE: -      // Always true.        return getTrue(ITy);      case CmpInst::ICMP_ULT: -      // Always false.        return getFalse(ITy);      }    } -  // Variants on "max(x,y) >= min(x,z)". +  // Comparing 1 each of min/max with a common operand? +  // Canonicalize min operand to RHS. +  if (match(LHS, m_UMin(m_Value(), m_Value())) || +      match(LHS, m_SMin(m_Value(), m_Value()))) { +    std::swap(LHS, RHS); +    Pred = ICmpInst::getSwappedPredicate(Pred); +  } +    Value *C, *D;    if (match(LHS, m_SMax(m_Value(A), m_Value(B))) &&        match(RHS, m_SMin(m_Value(C), m_Value(D))) &&        (A == C || A == D || B == C || B == D)) { -    // max(x, ?) pred min(x, ?). +    // smax(A, B) >=s smin(A, D) --> true      if (Pred == CmpInst::ICMP_SGE) -      // Always true.        return getTrue(ITy); +    // smax(A, B) <s smin(A, D) --> false      if (Pred == CmpInst::ICMP_SLT) -      // Always false. -      return getFalse(ITy); -  } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && -             match(RHS, m_SMax(m_Value(C), m_Value(D))) && -             (A == C || A == D || B == C || B == D)) { -    // min(x, ?) pred max(x, ?). -    if (Pred == CmpInst::ICMP_SLE) -      // Always true. -      return getTrue(ITy); -    if (Pred == CmpInst::ICMP_SGT) -      // Always false.        return getFalse(ITy);    } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) &&               match(RHS, m_UMin(m_Value(C), m_Value(D))) &&               (A == C || A == D || B == C || B == D)) { -    // max(x, ?) pred min(x, ?). +    // umax(A, B) >=u umin(A, D) --> true      if (Pred == CmpInst::ICMP_UGE) -      // Always true.        return getTrue(ITy); +    // umax(A, B) <u umin(A, D) --> false      if (Pred == CmpInst::ICMP_ULT) -      // Always false. -      return getFalse(ITy); -  } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && -             match(RHS, m_UMax(m_Value(C), m_Value(D))) && -             (A == C || A == D || B == C || B == D)) { -    // min(x, ?) pred max(x, ?). -    if (Pred == CmpInst::ICMP_ULE) -      // Always true. -      return getTrue(ITy); -    if (Pred == CmpInst::ICMP_UGT) -      // Always false.        return getFalse(ITy);    } @@ -3327,12 +3398,12 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,    // For EQ and NE, we can always pick a value for the undef to make the    // predicate pass or fail, so we can return undef.    // Matches behavior in llvm::ConstantFoldCompareInstruction. -  if (isa<UndefValue>(RHS) && ICmpInst::isEquality(Pred)) +  if (Q.isUndefValue(RHS) && ICmpInst::isEquality(Pred))      return UndefValue::get(ITy);    // icmp X, X -> true/false    // icmp X, undef -> true/false because undef could be X. -  if (LHS == RHS || isa<UndefValue>(RHS)) +  if (LHS == RHS || Q.isUndefValue(RHS))      return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred));    if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) @@ -3588,7 +3659,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,          // expression GEP with the same indices and a null base pointer to see          // what constant folding can make out of it.          Constant *Null = Constant::getNullValue(GLHS->getPointerOperandType()); -        SmallVector<Value *, 4> IndicesLHS(GLHS->idx_begin(), GLHS->idx_end()); +        SmallVector<Value *, 4> IndicesLHS(GLHS->indices());          Constant *NewLHS = ConstantExpr::getGetElementPtr(              GLHS->getSourceElementType(), Null, IndicesLHS); @@ -3659,7 +3730,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,    // fcmp pred x, undef  and  fcmp pred undef, x    // fold to true if unordered, false if ordered -  if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { +  if (Q.isUndefValue(LHS) || Q.isUndefValue(RHS)) {      // Choosing NaN for the undef will always make unordered comparison succeed      // and ordered comparison fail.      return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); @@ -3703,6 +3774,21 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,            break;          }        } + +      // LHS == Inf +      if (Pred == FCmpInst::FCMP_OEQ && isKnownNeverInfinity(LHS, Q.TLI)) +        return getFalse(RetTy); +      // LHS != Inf +      if (Pred == FCmpInst::FCMP_UNE && isKnownNeverInfinity(LHS, Q.TLI)) +        return getTrue(RetTy); +      // LHS == Inf || LHS == NaN +      if (Pred == FCmpInst::FCMP_UEQ && isKnownNeverInfinity(LHS, Q.TLI) && +          isKnownNeverNaN(LHS, Q.TLI)) +        return getFalse(RetTy); +      // LHS != Inf && LHS != NaN +      if (Pred == FCmpInst::FCMP_ONE && isKnownNeverInfinity(LHS, Q.TLI) && +          isKnownNeverNaN(LHS, Q.TLI)) +        return getTrue(RetTy);      }      if (C->isNegative() && !C->isNegZero()) {        assert(!C->isNaN() && "Unexpected NaN constant!"); @@ -3834,18 +3920,33 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,    // We can't replace %sel with %add unless we strip away the flags (which will    // be done in InstCombine).    // TODO: This is unsound, because it only catches some forms of refinement. -  if (!AllowRefinement && canCreatePoison(I)) +  if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))      return nullptr; +  // The simplification queries below may return the original value. Consider: +  //   %div = udiv i32 %arg, %arg2 +  //   %mul = mul nsw i32 %div, %arg2 +  //   %cmp = icmp eq i32 %mul, %arg +  //   %sel = select i1 %cmp, i32 %div, i32 undef +  // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which +  // simplifies back to %arg. This can only happen because %mul does not +  // dominate %div. To ensure a consistent return value contract, we make sure +  // that this case returns nullptr as well. +  auto PreventSelfSimplify = [V](Value *Simplified) { +    return Simplified != V ? Simplified : nullptr; +  }; +    // If this is a binary operator, try to simplify it with the replaced op.    if (auto *B = dyn_cast<BinaryOperator>(I)) {      if (MaxRecurse) {        if (B->getOperand(0) == Op) -        return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q, -                             MaxRecurse - 1); +        return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), RepOp, +                                                 B->getOperand(1), Q, +                                                 MaxRecurse - 1));        if (B->getOperand(1) == Op) -        return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, Q, -                             MaxRecurse - 1); +        return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), +                                                 B->getOperand(0), RepOp, Q, +                                                 MaxRecurse - 1));      }    } @@ -3853,11 +3954,13 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,    if (CmpInst *C = dyn_cast<CmpInst>(I)) {      if (MaxRecurse) {        if (C->getOperand(0) == Op) -        return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), Q, -                               MaxRecurse - 1); +        return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), RepOp, +                                                   C->getOperand(1), Q, +                                                   MaxRecurse - 1));        if (C->getOperand(1) == Op) -        return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, Q, -                               MaxRecurse - 1); +        return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), +                                                   C->getOperand(0), RepOp, Q, +                                                   MaxRecurse - 1));      }    } @@ -3867,8 +3970,8 @@ static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,        SmallVector<Value *, 8> NewOps(GEP->getNumOperands());        transform(GEP->operands(), NewOps.begin(),                  [&](Value *V) { return V == Op ? RepOp : V; }); -      return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, -                             MaxRecurse - 1); +      return PreventSelfSimplify(SimplifyGEPInst(GEP->getSourceElementType(), +                                                 NewOps, Q, MaxRecurse - 1));      }    } @@ -3987,10 +4090,8 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,      // Test for a bogus zero-shift-guard-op around funnel-shift or rotate.      Value *ShAmt; -    auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), -                                                          m_Value(ShAmt)), -                             m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), -                                                          m_Value(ShAmt))); +    auto isFsh = m_CombineOr(m_FShl(m_Value(X), m_Value(), m_Value(ShAmt)), +                             m_FShr(m_Value(), m_Value(X), m_Value(ShAmt)));      // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X      // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X      if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt) @@ -4001,17 +4102,24 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,      // intrinsics do not have that problem.      // We do not allow this transform for the general funnel shift case because      // that would not preserve the poison safety of the original code. -    auto isRotate = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), -                                                             m_Deferred(X), -                                                             m_Value(ShAmt)), -                                m_Intrinsic<Intrinsic::fshr>(m_Value(X), -                                                             m_Deferred(X), -                                                             m_Value(ShAmt))); +    auto isRotate = +        m_CombineOr(m_FShl(m_Value(X), m_Deferred(X), m_Value(ShAmt)), +                    m_FShr(m_Value(X), m_Deferred(X), m_Value(ShAmt)));      // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt)      // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt)      if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt &&          Pred == ICmpInst::ICMP_EQ)        return FalseVal; + +    // X == 0 ? abs(X) : -abs(X) --> -abs(X) +    // X == 0 ? -abs(X) : abs(X) --> abs(X) +    if (match(TrueVal, m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS))) && +        match(FalseVal, m_Neg(m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS))))) +      return FalseVal; +    if (match(TrueVal, +              m_Neg(m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS)))) && +        match(FalseVal, m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS)))) +      return FalseVal;    }    // Check for other compares that behave like bit test. @@ -4083,7 +4191,7 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,          return ConstantFoldSelectInstruction(CondC, TrueC, FalseC);      // select undef, X, Y -> X or Y -    if (isa<UndefValue>(CondC)) +    if (Q.isUndefValue(CondC))        return isa<Constant>(FalseVal) ? FalseVal : TrueVal;      // TODO: Vector constants with undef elements don't simplify. @@ -4109,16 +4217,24 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,    if (TrueVal == FalseVal)      return TrueVal; -  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X +  // If the true or false value is undef, we can fold to the other value as +  // long as the other value isn't poison. +  // select ?, undef, X -> X +  if (Q.isUndefValue(TrueVal) && +      isGuaranteedNotToBeUndefOrPoison(FalseVal, Q.AC, Q.CxtI, Q.DT))      return FalseVal; -  if (isa<UndefValue>(FalseVal))   // select ?, X, undef -> X +  // select ?, X, undef -> X +  if (Q.isUndefValue(FalseVal) && +      isGuaranteedNotToBeUndefOrPoison(TrueVal, Q.AC, Q.CxtI, Q.DT))      return TrueVal;    // Deal with partial undef vector constants: select ?, VecC, VecC' --> VecC''    Constant *TrueC, *FalseC; -  if (TrueVal->getType()->isVectorTy() && match(TrueVal, m_Constant(TrueC)) && +  if (isa<FixedVectorType>(TrueVal->getType()) && +      match(TrueVal, m_Constant(TrueC)) &&        match(FalseVal, m_Constant(FalseC))) { -    unsigned NumElts = cast<VectorType>(TrueC->getType())->getNumElements(); +    unsigned NumElts = +        cast<FixedVectorType>(TrueC->getType())->getNumElements();      SmallVector<Constant *, 16> NewC;      for (unsigned i = 0; i != NumElts; ++i) {        // Bail out on incomplete vector constants. @@ -4131,9 +4247,11 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,        // one element is undef, choose the defined element as the safe result.        if (TEltC == FEltC)          NewC.push_back(TEltC); -      else if (isa<UndefValue>(TEltC)) +      else if (Q.isUndefValue(TEltC) && +               isGuaranteedNotToBeUndefOrPoison(FEltC))          NewC.push_back(FEltC); -      else if (isa<UndefValue>(FEltC)) +      else if (Q.isUndefValue(FEltC) && +               isGuaranteedNotToBeUndefOrPoison(TEltC))          NewC.push_back(TEltC);        else          break; @@ -4184,7 +4302,12 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,    else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType()))      GEPTy = VectorType::get(GEPTy, VT->getElementCount()); -  if (isa<UndefValue>(Ops[0])) +  // getelementptr poison, idx -> poison +  // getelementptr baseptr, poison -> poison +  if (any_of(Ops, [](const auto *V) { return isa<PoisonValue>(V); })) +    return PoisonValue::get(GEPTy); + +  if (Q.isUndefValue(Ops[0]))      return UndefValue::get(GEPTy);    bool IsScalableVec = isa<ScalableVectorType>(SrcTy); @@ -4207,9 +4330,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,        // doesn't truncate the pointers.        if (Ops[1]->getType()->getScalarSizeInBits() ==            Q.DL.getPointerSizeInBits(AS)) { -        auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { -          if (match(P, m_Zero())) -            return Constant::getNullValue(GEPTy); +        auto PtrToInt = [GEPTy](Value *P) -> Value * {            Value *Temp;            if (match(P, m_PtrToInt(m_Value(Temp))))              if (Temp->getType() == GEPTy) @@ -4217,10 +4338,14 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,            return nullptr;          }; +        // FIXME: The following transforms are only legal if P and V have the +        // same provenance (PR44403). Check whether getUnderlyingObject() is +        // the same? +          // getelementptr V, (sub P, V) -> P if P points to a type of size 1.          if (TyAllocSize == 1 &&              match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) -          if (Value *R = PtrToIntOrZero(P)) +          if (Value *R = PtrToInt(P))              return R;          // getelementptr V, (ashr (sub P, V), C) -> Q @@ -4229,7 +4354,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,                    m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),                           m_ConstantInt(C))) &&              TyAllocSize == 1ULL << C) -          if (Value *R = PtrToIntOrZero(P)) +          if (Value *R = PtrToInt(P))              return R;          // getelementptr V, (sdiv (sub P, V), C) -> Q @@ -4237,7 +4362,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,          if (match(Ops[1],                    m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))),                           m_SpecificInt(TyAllocSize)))) -          if (Value *R = PtrToIntOrZero(P)) +          if (Value *R = PtrToInt(P))              return R;        }      } @@ -4254,15 +4379,21 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,            Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL,                                                              BasePtrOffset); +      // Avoid creating inttoptr of zero here: While LLVMs treatment of +      // inttoptr is generally conservative, this particular case is folded to +      // a null pointer, which will have incorrect provenance. +        // gep (gep V, C), (sub 0, V) -> C        if (match(Ops.back(), -                m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { +                m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr)))) && +          !BasePtrOffset.isNullValue()) {          auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset);          return ConstantExpr::getIntToPtr(CI, GEPTy);        }        // gep (gep V, C), (xor V, -1) -> C-1        if (match(Ops.back(), -                m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { +                m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes())) && +          !BasePtrOffset.isOneValue()) {          auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1);          return ConstantExpr::getIntToPtr(CI, GEPTy);        } @@ -4293,7 +4424,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val,        return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs);    // insertvalue x, undef, n -> x -  if (match(Val, m_Undef())) +  if (Q.isUndefValue(Val))      return Agg;    // insertvalue x, (extractvalue y, n), n @@ -4301,7 +4432,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val,      if (EV->getAggregateOperand()->getType() == Agg->getType() &&          EV->getIndices() == Idxs) {        // insertvalue undef, (extractvalue y, n), n -> y -      if (match(Agg, m_Undef())) +      if (Q.isUndefValue(Agg))          return EV->getAggregateOperand();        // insertvalue y, (extractvalue y, n), n -> y @@ -4325,22 +4456,23 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx,    auto *ValC = dyn_cast<Constant>(Val);    auto *IdxC = dyn_cast<Constant>(Idx);    if (VecC && ValC && IdxC) -    return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC); +    return ConstantExpr::getInsertElement(VecC, ValC, IdxC); -  // For fixed-length vector, fold into undef if index is out of bounds. +  // For fixed-length vector, fold into poison if index is out of bounds.    if (auto *CI = dyn_cast<ConstantInt>(Idx)) {      if (isa<FixedVectorType>(Vec->getType()) &&          CI->uge(cast<FixedVectorType>(Vec->getType())->getNumElements())) -      return UndefValue::get(Vec->getType()); +      return PoisonValue::get(Vec->getType());    }    // If index is undef, it might be out of bounds (see above case) -  if (isa<UndefValue>(Idx)) -    return UndefValue::get(Vec->getType()); +  if (Q.isUndefValue(Idx)) +    return PoisonValue::get(Vec->getType()); -  // If the scalar is undef, and there is no risk of propagating poison from the -  // vector value, simplify to the vector value. -  if (isa<UndefValue>(Val) && isGuaranteedNotToBeUndefOrPoison(Vec)) +  // If the scalar is poison, or it is undef and there is no risk of +  // propagating poison from the vector value, simplify to the vector value. +  if (isa<PoisonValue>(Val) || +      (Q.isUndefValue(Val) && isGuaranteedNotToBePoison(Vec)))      return Vec;    // If we are extracting a value from a vector, then inserting it into the same @@ -4384,18 +4516,18 @@ Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs,  /// Given operands for an ExtractElementInst, see if we can fold the result.  /// If not, this returns null. -static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &, -                                         unsigned) { +static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, +                                         const SimplifyQuery &Q, unsigned) {    auto *VecVTy = cast<VectorType>(Vec->getType());    if (auto *CVec = dyn_cast<Constant>(Vec)) {      if (auto *CIdx = dyn_cast<Constant>(Idx)) -      return ConstantFoldExtractElementInstruction(CVec, CIdx); +      return ConstantExpr::getExtractElement(CVec, CIdx);      // The index is not relevant if our vector is a splat.      if (auto *Splat = CVec->getSplatValue())        return Splat; -    if (isa<UndefValue>(Vec)) +    if (Q.isUndefValue(Vec))        return UndefValue::get(VecVTy->getElementType());    } @@ -4404,16 +4536,16 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQ    if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) {      // For fixed-length vector, fold into undef if index is out of bounds.      if (isa<FixedVectorType>(VecVTy) && -        IdxC->getValue().uge(VecVTy->getNumElements())) -      return UndefValue::get(VecVTy->getElementType()); +        IdxC->getValue().uge(cast<FixedVectorType>(VecVTy)->getNumElements())) +      return PoisonValue::get(VecVTy->getElementType());      if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))        return Elt;    }    // An undef extract index can be arbitrarily chosen to be an out-of-range -  // index value, which would result in the instruction being undef. -  if (isa<UndefValue>(Idx)) -    return UndefValue::get(VecVTy->getElementType()); +  // index value, which would result in the instruction being poison. +  if (Q.isUndefValue(Idx)) +    return PoisonValue::get(VecVTy->getElementType());    return nullptr;  } @@ -4425,6 +4557,10 @@ Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx,  /// See if we can fold the given phi. If not, returns null.  static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { +  // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE +  //          here, because the PHI we may succeed simplifying to was not +  //          def-reachable from the original PHI! +    // If all of the PHI's incoming values are the same then replace the PHI node    // with the common value.    Value *CommonValue = nullptr; @@ -4432,7 +4568,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {    for (Value *Incoming : PN->incoming_values()) {      // If the incoming value is the phi node itself, it can safely be skipped.      if (Incoming == PN) continue; -    if (isa<UndefValue>(Incoming)) { +    if (Q.isUndefValue(Incoming)) {        // Remember that we saw an undef value, but otherwise ignore them.        HasUndefInput = true;        continue; @@ -4510,7 +4646,7 @@ static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1,      return nullptr;    // The mask value chooses which source operand we need to look at next. -  int InVecNumElts = cast<VectorType>(Op0->getType())->getNumElements(); +  int InVecNumElts = cast<FixedVectorType>(Op0->getType())->getNumElements();    int RootElt = MaskVal;    Value *SourceOp = Op0;    if (MaskVal >= InVecNumElts) { @@ -4557,16 +4693,16 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1,    unsigned MaskNumElts = Mask.size();    ElementCount InVecEltCount = InVecTy->getElementCount(); -  bool Scalable = InVecEltCount.Scalable; +  bool Scalable = InVecEltCount.isScalable();    SmallVector<int, 32> Indices;    Indices.assign(Mask.begin(), Mask.end());    // Canonicalization: If mask does not select elements from an input vector, -  // replace that input vector with undef. +  // replace that input vector with poison.    if (!Scalable) {      bool MaskSelects0 = false, MaskSelects1 = false; -    unsigned InVecNumElts = InVecEltCount.Min; +    unsigned InVecNumElts = InVecEltCount.getKnownMinValue();      for (unsigned i = 0; i != MaskNumElts; ++i) {        if (Indices[i] == -1)          continue; @@ -4576,9 +4712,9 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1,          MaskSelects1 = true;      }      if (!MaskSelects0) -      Op0 = UndefValue::get(InVecTy); +      Op0 = PoisonValue::get(InVecTy);      if (!MaskSelects1) -      Op1 = UndefValue::get(InVecTy); +      Op1 = PoisonValue::get(InVecTy);    }    auto *Op0Const = dyn_cast<Constant>(Op0); @@ -4587,15 +4723,16 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1,    // If all operands are constant, constant fold the shuffle. This    // transformation depends on the value of the mask which is not known at    // compile time for scalable vectors -  if (!Scalable && Op0Const && Op1Const) -    return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); +  if (Op0Const && Op1Const) +    return ConstantExpr::getShuffleVector(Op0Const, Op1Const, Mask);    // Canonicalization: if only one input vector is constant, it shall be the    // second one. This transformation depends on the value of the mask which    // is not known at compile time for scalable vectors    if (!Scalable && Op0Const && !Op1Const) {      std::swap(Op0, Op1); -    ShuffleVectorInst::commuteShuffleMask(Indices, InVecEltCount.Min); +    ShuffleVectorInst::commuteShuffleMask(Indices, +                                          InVecEltCount.getKnownMinValue());    }    // A splat of an inserted scalar constant becomes a vector constant: @@ -4627,7 +4764,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1,    // A shuffle of a splat is always the splat itself. Legal if the shuffle's    // value type is same as the input vectors' type.    if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) -    if (isa<UndefValue>(Op1) && RetTy == InVecTy && +    if (Q.isUndefValue(Op1) && RetTy == InVecTy &&          is_splat(OpShuf->getShuffleMask()))        return Op0; @@ -4639,7 +4776,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1,    // Don't fold a shuffle with undef mask elements. This may get folded in a    // better way using demanded bits or other analysis.    // TODO: Should we allow this? -  if (find(Indices, -1) != Indices.end()) +  if (is_contained(Indices, -1))      return nullptr;    // Check if every element of this shuffle can be mapped back to the @@ -4708,19 +4845,20 @@ static Constant *propagateNaN(Constant *In) {  /// transforms based on undef/NaN because the operation itself makes no  /// difference to the result.  static Constant *simplifyFPOp(ArrayRef<Value *> Ops, -                              FastMathFlags FMF = FastMathFlags()) { +                              FastMathFlags FMF, +                              const SimplifyQuery &Q) {    for (Value *V : Ops) {      bool IsNan = match(V, m_NaN());      bool IsInf = match(V, m_Inf()); -    bool IsUndef = match(V, m_Undef()); +    bool IsUndef = Q.isUndefValue(V);      // If this operation has 'nnan' or 'ninf' and at least 1 disallowed operand      // (an undef operand can be chosen to be Nan/Inf), then the result of -    // this operation is poison. That result can be relaxed to undef. +    // this operation is poison.      if (FMF.noNaNs() && (IsNan || IsUndef)) -      return UndefValue::get(V->getType()); +      return PoisonValue::get(V->getType());      if (FMF.noInfs() && (IsInf || IsUndef)) -      return UndefValue::get(V->getType()); +      return PoisonValue::get(V->getType());      if (IsUndef || IsNan)        return propagateNaN(cast<Constant>(V)); @@ -4735,7 +4873,7 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))      return C; -  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) +  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))      return C;    // fadd X, -0 ==> X @@ -4782,7 +4920,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))      return C; -  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) +  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))      return C;    // fsub X, +0 ==> X @@ -4824,7 +4962,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,  static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,                                const SimplifyQuery &Q, unsigned MaxRecurse) { -  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) +  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))      return C;    // fmul X, 1.0 ==> X @@ -4891,7 +5029,7 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))      return C; -  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) +  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))      return C;    // X / 1.0 -> X @@ -4936,7 +5074,7 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))      return C; -  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) +  if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q))      return C;    // Unlike fdiv, the result of frem always matches the sign of the dividend. @@ -5181,6 +5319,15 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,      // bitreverse(bitreverse(x)) -> x      if (match(Op0, m_BitReverse(m_Value(X)))) return X;      break; +  case Intrinsic::ctpop: { +    // If everything but the lowest bit is zero, that bit is the pop-count. Ex: +    // ctpop(and X, 1) --> and X, 1 +    unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); +    if (MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, BitWidth - 1), +                          Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) +      return Op0; +    break; +  }    case Intrinsic::exp:      // exp(log(x)) -> x      if (Q.CxtI->hasAllowReassoc() && @@ -5233,27 +5380,156 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,    return nullptr;  } +static Intrinsic::ID getMaxMinOpposite(Intrinsic::ID IID) { +  switch (IID) { +  case Intrinsic::smax: return Intrinsic::smin; +  case Intrinsic::smin: return Intrinsic::smax; +  case Intrinsic::umax: return Intrinsic::umin; +  case Intrinsic::umin: return Intrinsic::umax; +  default: llvm_unreachable("Unexpected intrinsic"); +  } +} + +static APInt getMaxMinLimit(Intrinsic::ID IID, unsigned BitWidth) { +  switch (IID) { +  case Intrinsic::smax: return APInt::getSignedMaxValue(BitWidth); +  case Intrinsic::smin: return APInt::getSignedMinValue(BitWidth); +  case Intrinsic::umax: return APInt::getMaxValue(BitWidth); +  case Intrinsic::umin: return APInt::getMinValue(BitWidth); +  default: llvm_unreachable("Unexpected intrinsic"); +  } +} + +static ICmpInst::Predicate getMaxMinPredicate(Intrinsic::ID IID) { +  switch (IID) { +  case Intrinsic::smax: return ICmpInst::ICMP_SGE; +  case Intrinsic::smin: return ICmpInst::ICMP_SLE; +  case Intrinsic::umax: return ICmpInst::ICMP_UGE; +  case Intrinsic::umin: return ICmpInst::ICMP_ULE; +  default: llvm_unreachable("Unexpected intrinsic"); +  } +} + +/// Given a min/max intrinsic, see if it can be removed based on having an +/// operand that is another min/max intrinsic with shared operand(s). The caller +/// is expected to swap the operand arguments to handle commutation. +static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { +  Value *X, *Y; +  if (!match(Op0, m_MaxOrMin(m_Value(X), m_Value(Y)))) +    return nullptr; + +  auto *MM0 = dyn_cast<IntrinsicInst>(Op0); +  if (!MM0) +    return nullptr; +  Intrinsic::ID IID0 = MM0->getIntrinsicID(); + +  if (Op1 == X || Op1 == Y || +      match(Op1, m_c_MaxOrMin(m_Specific(X), m_Specific(Y)))) { +    // max (max X, Y), X --> max X, Y +    if (IID0 == IID) +      return MM0; +    // max (min X, Y), X --> X +    if (IID0 == getMaxMinOpposite(IID)) +      return Op1; +  } +  return nullptr; +} +  static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,                                        const SimplifyQuery &Q) {    Intrinsic::ID IID = F->getIntrinsicID();    Type *ReturnType = F->getReturnType(); +  unsigned BitWidth = ReturnType->getScalarSizeInBits();    switch (IID) { +  case Intrinsic::abs: +    // abs(abs(x)) -> abs(x). We don't need to worry about the nsw arg here. +    // It is always ok to pick the earlier abs. We'll just lose nsw if its only +    // on the outer abs. +    if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(), m_Value()))) +      return Op0; +    break; + +  case Intrinsic::smax: +  case Intrinsic::smin: +  case Intrinsic::umax: +  case Intrinsic::umin: { +    // If the arguments are the same, this is a no-op. +    if (Op0 == Op1) +      return Op0; + +    // Canonicalize constant operand as Op1. +    if (isa<Constant>(Op0)) +      std::swap(Op0, Op1); + +    // Assume undef is the limit value. +    if (Q.isUndefValue(Op1)) +      return ConstantInt::get(ReturnType, getMaxMinLimit(IID, BitWidth)); + +    const APInt *C; +    if (match(Op1, m_APIntAllowUndef(C))) { +      // Clamp to limit value. For example: +      // umax(i8 %x, i8 255) --> 255 +      if (*C == getMaxMinLimit(IID, BitWidth)) +        return ConstantInt::get(ReturnType, *C); + +      // If the constant op is the opposite of the limit value, the other must +      // be larger/smaller or equal. For example: +      // umin(i8 %x, i8 255) --> %x +      if (*C == getMaxMinLimit(getMaxMinOpposite(IID), BitWidth)) +        return Op0; + +      // Remove nested call if constant operands allow it. Example: +      // max (max X, 7), 5 -> max X, 7 +      auto *MinMax0 = dyn_cast<IntrinsicInst>(Op0); +      if (MinMax0 && MinMax0->getIntrinsicID() == IID) { +        // TODO: loosen undef/splat restrictions for vector constants. +        Value *M00 = MinMax0->getOperand(0), *M01 = MinMax0->getOperand(1); +        const APInt *InnerC; +        if ((match(M00, m_APInt(InnerC)) || match(M01, m_APInt(InnerC))) && +            ((IID == Intrinsic::smax && InnerC->sge(*C)) || +             (IID == Intrinsic::smin && InnerC->sle(*C)) || +             (IID == Intrinsic::umax && InnerC->uge(*C)) || +             (IID == Intrinsic::umin && InnerC->ule(*C)))) +          return Op0; +      } +    } + +    if (Value *V = foldMinMaxSharedOp(IID, Op0, Op1)) +      return V; +    if (Value *V = foldMinMaxSharedOp(IID, Op1, Op0)) +      return V; + +    ICmpInst::Predicate Pred = getMaxMinPredicate(IID); +    if (isICmpTrue(Pred, Op0, Op1, Q.getWithoutUndef(), RecursionLimit)) +      return Op0; +    if (isICmpTrue(Pred, Op1, Op0, Q.getWithoutUndef(), RecursionLimit)) +      return Op1; + +    if (Optional<bool> Imp = +            isImpliedByDomCondition(Pred, Op0, Op1, Q.CxtI, Q.DL)) +      return *Imp ? Op0 : Op1; +    if (Optional<bool> Imp = +            isImpliedByDomCondition(Pred, Op1, Op0, Q.CxtI, Q.DL)) +      return *Imp ? Op1 : Op0; + +    break; +  }    case Intrinsic::usub_with_overflow:    case Intrinsic::ssub_with_overflow:      // X - X -> { 0, false } -    if (Op0 == Op1) +    // X - undef -> { 0, false } +    // undef - X -> { 0, false } +    if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1))        return Constant::getNullValue(ReturnType); -    LLVM_FALLTHROUGH; +    break;    case Intrinsic::uadd_with_overflow:    case Intrinsic::sadd_with_overflow: -    // X - undef -> { undef, false } -    // undef - X -> { undef, false } -    // X + undef -> { undef, false } -    // undef + x -> { undef, false } -    if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) { +    // X + undef -> { -1, false } +    // undef + x -> { -1, false } +    if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) {        return ConstantStruct::get(            cast<StructType>(ReturnType), -          {UndefValue::get(ReturnType->getStructElementType(0)), +          {Constant::getAllOnesValue(ReturnType->getStructElementType(0)),             Constant::getNullValue(ReturnType->getStructElementType(1))});      }      break; @@ -5265,7 +5541,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,        return Constant::getNullValue(ReturnType);      // undef * X -> { 0, false }      // X * undef -> { 0, false } -    if (match(Op0, m_Undef()) || match(Op1, m_Undef())) +    if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1))        return Constant::getNullValue(ReturnType);      break;    case Intrinsic::uadd_sat: @@ -5279,7 +5555,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,      // sat(undef + X) -> -1      // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1).      // For signed: Assume undef is ~X, in which case X + ~X = -1. -    if (match(Op0, m_Undef()) || match(Op1, m_Undef())) +    if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1))        return Constant::getAllOnesValue(ReturnType);      // X + 0 -> X @@ -5296,7 +5572,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,      LLVM_FALLTHROUGH;    case Intrinsic::ssub_sat:      // X - X -> 0, X - undef -> 0, undef - X -> 0 -    if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef())) +    if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1))        return Constant::getNullValue(ReturnType);      // X - 0 -> X      if (match(Op1, m_Zero())) @@ -5334,18 +5610,43 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,      // If the arguments are the same, this is a no-op.      if (Op0 == Op1) return Op0; -    // If one argument is undef, return the other argument. -    if (match(Op0, m_Undef())) -      return Op1; -    if (match(Op1, m_Undef())) +    // Canonicalize constant operand as Op1. +    if (isa<Constant>(Op0)) +      std::swap(Op0, Op1); + +    // If an argument is undef, return the other argument. +    if (Q.isUndefValue(Op1))        return Op0; -    // If one argument is NaN, return other or NaN appropriately.      bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; -    if (match(Op0, m_NaN())) -      return PropagateNaN ? Op0 : Op1; +    bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum; + +    // minnum(X, nan) -> X +    // maxnum(X, nan) -> X +    // minimum(X, nan) -> nan +    // maximum(X, nan) -> nan      if (match(Op1, m_NaN())) -      return PropagateNaN ? Op1 : Op0; +      return PropagateNaN ? propagateNaN(cast<Constant>(Op1)) : Op0; + +    // In the following folds, inf can be replaced with the largest finite +    // float, if the ninf flag is set. +    const APFloat *C; +    if (match(Op1, m_APFloat(C)) && +        (C->isInfinity() || (Q.CxtI->hasNoInfs() && C->isLargest()))) { +      // minnum(X, -inf) -> -inf +      // maxnum(X, +inf) -> +inf +      // minimum(X, -inf) -> -inf if nnan +      // maximum(X, +inf) -> +inf if nnan +      if (C->isNegative() == IsMin && (!PropagateNaN || Q.CxtI->hasNoNaNs())) +        return ConstantFP::get(ReturnType, *C); + +      // minnum(X, +inf) -> X if nnan +      // maxnum(X, -inf) -> X if nnan +      // minimum(X, +inf) -> X +      // maximum(X, -inf) -> X +      if (C->isNegative() != IsMin && (PropagateNaN || Q.CxtI->hasNoNaNs())) +        return Op0; +    }      // Min/max of the same operation with common operand:      // m(m(X, Y)), X --> m(X, Y) (4 commuted variants) @@ -5358,20 +5659,6 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,            (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0))          return Op1; -    // min(X, -Inf) --> -Inf (and commuted variant) -    // max(X, +Inf) --> +Inf (and commuted variant) -    bool UseNegInf = IID == Intrinsic::minnum || IID == Intrinsic::minimum; -    const APFloat *C; -    if ((match(Op0, m_APFloat(C)) && C->isInfinity() && -         C->isNegative() == UseNegInf) || -        (match(Op1, m_APFloat(C)) && C->isInfinity() && -         C->isNegative() == UseNegInf)) -      return ConstantFP::getInfinity(ReturnType, UseNegInf); - -    // TODO: minnum(nnan x, inf) -> x -    // TODO: minnum(nnan ninf x, flt_max) -> x -    // TODO: maxnum(nnan x, -inf) -> x -    // TODO: maxnum(nnan ninf x, -flt_max) -> x      break;    }    default: @@ -5414,11 +5701,11 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {            *ShAmtArg = Call->getArgOperand(2);      // If both operands are undef, the result is undef. -    if (match(Op0, m_Undef()) && match(Op1, m_Undef())) +    if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1))        return UndefValue::get(F->getReturnType());      // If shift amount is undef, assume it is zero. -    if (match(ShAmtArg, m_Undef())) +    if (Q.isUndefValue(ShAmtArg))        return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1);      const APInt *ShAmtC; @@ -5435,7 +5722,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {      Value *Op0 = Call->getArgOperand(0);      Value *Op1 = Call->getArgOperand(1);      Value *Op2 = Call->getArgOperand(2); -    if (Value *V = simplifyFPOp({ Op0, Op1, Op2 })) +    if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q))        return V;      return nullptr;    } @@ -5444,28 +5731,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {    }  } -Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { -  Value *Callee = Call->getCalledOperand(); - -  // musttail calls can only be simplified if they are also DCEd. -  // As we can't guarantee this here, don't simplify them. -  if (Call->isMustTailCall()) -    return nullptr; - -  // call undef -> undef -  // call null -> undef -  if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) -    return UndefValue::get(Call->getType()); - -  Function *F = dyn_cast<Function>(Callee); -  if (!F) -    return nullptr; - -  if (F->isIntrinsic()) -    if (Value *Ret = simplifyIntrinsic(Call, Q)) -      return Ret; - -  if (!canConstantFoldCallTo(Call, F)) +static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { +  auto *F = dyn_cast<Function>(Call->getCalledOperand()); +  if (!F || !canConstantFoldCallTo(Call, F))      return nullptr;    SmallVector<Constant *, 4> ConstantArgs; @@ -5484,10 +5752,33 @@ Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) {    return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI);  } +Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { +  // musttail calls can only be simplified if they are also DCEd. +  // As we can't guarantee this here, don't simplify them. +  if (Call->isMustTailCall()) +    return nullptr; + +  // call undef -> poison +  // call null -> poison +  Value *Callee = Call->getCalledOperand(); +  if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) +    return PoisonValue::get(Call->getType()); + +  if (Value *V = tryConstantFoldCall(Call, Q)) +    return V; + +  auto *F = dyn_cast<Function>(Callee); +  if (F && F->isIntrinsic()) +    if (Value *Ret = simplifyIntrinsic(Call, Q)) +      return Ret; + +  return nullptr; +} +  /// Given operands for a Freeze, see if we can fold the result.  static Value *SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) {    // Use a utility function defined in ValueTracking. -  if (llvm::isGuaranteedNotToBeUndefOrPoison(Op0, Q.CxtI, Q.DT)) +  if (llvm::isGuaranteedNotToBeUndefOrPoison(Op0, Q.AC, Q.CxtI, Q.DT))      return Op0;    // We have room for improvement.    return nullptr; @@ -5596,7 +5887,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,                                  I->getOperand(2), Q);      break;    case Instruction::GetElementPtr: { -    SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); +    SmallVector<Value *, 8> Ops(I->operands());      Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(),                               Ops, Q);      break; @@ -5733,13 +6024,6 @@ static bool replaceAndRecursivelySimplifyImpl(    return Simplified;  } -bool llvm::recursivelySimplifyInstruction(Instruction *I, -                                          const TargetLibraryInfo *TLI, -                                          const DominatorTree *DT, -                                          AssumptionCache *AC) { -  return replaceAndRecursivelySimplifyImpl(I, nullptr, TLI, DT, AC, nullptr); -} -  bool llvm::replaceAndRecursivelySimplify(      Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI,      const DominatorTree *DT, AssumptionCache *AC,  | 
