diff options
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
| -rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 1311 |
1 files changed, 789 insertions, 522 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 0975a65d183e..c40e5c36cdc7 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -228,64 +228,56 @@ static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { return false; } -/// Simplify "A op (B op' C)" by distributing op over op', turning it into -/// "(A op B) op' (A op C)". Here "op" is given by Opcode and "op'" is -/// given by OpcodeToExpand, while "A" corresponds to LHS and "B op' C" to RHS. -/// Also performs the transform "(A op' B) op C" -> "(A op C) op' (B op C)". -/// Returns the simplified value, or null if no simplification was performed. -static Value *ExpandBinOp(Instruction::BinaryOps Opcode, Value *LHS, Value *RHS, - Instruction::BinaryOps OpcodeToExpand, +/// Try to simplify a binary operator of form "V op OtherOp" where V is +/// "(B0 opex B1)" by distributing 'op' across 'opex' as +/// "(B0 op OtherOp) opex (B1 op OtherOp)". +static Value *expandBinOp(Instruction::BinaryOps Opcode, Value *V, + Value *OtherOp, Instruction::BinaryOps OpcodeToExpand, const SimplifyQuery &Q, unsigned MaxRecurse) { - // Recursion is always used, so bail out at once if we already hit the limit. - if (!MaxRecurse--) + auto *B = dyn_cast<BinaryOperator>(V); + if (!B || B->getOpcode() != OpcodeToExpand) + return nullptr; + Value *B0 = B->getOperand(0), *B1 = B->getOperand(1); + Value *L = SimplifyBinOp(Opcode, B0, OtherOp, Q.getWithoutUndef(), + MaxRecurse); + if (!L) + return nullptr; + Value *R = SimplifyBinOp(Opcode, B1, OtherOp, Q.getWithoutUndef(), + MaxRecurse); + if (!R) return nullptr; - // Check whether the expression has the form "(A op' B) op C". - if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) - if (Op0->getOpcode() == OpcodeToExpand) { - // It does! Try turning it into "(A op C) op' (B op C)". - Value *A = Op0->getOperand(0), *B = Op0->getOperand(1), *C = RHS; - // Do "A op C" and "B op C" both simplify? - if (Value *L = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) - if (Value *R = SimplifyBinOp(Opcode, B, C, Q, MaxRecurse)) { - // They do! Return "L op' R" if it simplifies or is already available. - // If "L op' R" equals "A op' B" then "L op' R" is just the LHS. - if ((L == A && R == B) || (Instruction::isCommutative(OpcodeToExpand) - && L == B && R == A)) { - ++NumExpand; - return LHS; - } - // Otherwise return "L op' R" if it simplifies. - if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { - ++NumExpand; - return V; - } - } - } + // Does the expanded pair of binops simplify to the existing binop? + if ((L == B0 && R == B1) || + (Instruction::isCommutative(OpcodeToExpand) && L == B1 && R == B0)) { + ++NumExpand; + return B; + } - // Check whether the expression has the form "A op (B op' C)". - if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) - if (Op1->getOpcode() == OpcodeToExpand) { - // It does! Try turning it into "(A op B) op' (A op C)". - Value *A = LHS, *B = Op1->getOperand(0), *C = Op1->getOperand(1); - // Do "A op B" and "A op C" both simplify? - if (Value *L = SimplifyBinOp(Opcode, A, B, Q, MaxRecurse)) - if (Value *R = SimplifyBinOp(Opcode, A, C, Q, MaxRecurse)) { - // They do! Return "L op' R" if it simplifies or is already available. - // If "L op' R" equals "B op' C" then "L op' R" is just the RHS. - if ((L == B && R == C) || (Instruction::isCommutative(OpcodeToExpand) - && L == C && R == B)) { - ++NumExpand; - return RHS; - } - // Otherwise return "L op' R" if it simplifies. - if (Value *V = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse)) { - ++NumExpand; - return V; - } - } - } + // Otherwise, return "L op' R" if it simplifies. + Value *S = SimplifyBinOp(OpcodeToExpand, L, R, Q, MaxRecurse); + if (!S) + return nullptr; + + ++NumExpand; + return S; +} +/// Try to simplify binops of form "A op (B op' C)" or the commuted variant by +/// distributing op over op'. +static Value *expandCommutativeBinOp(Instruction::BinaryOps Opcode, + Value *L, Value *R, + Instruction::BinaryOps OpcodeToExpand, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + // Recursion is always used, so bail out at once if we already hit the limit. + if (!MaxRecurse--) + return nullptr; + + if (Value *V = expandBinOp(Opcode, L, R, OpcodeToExpand, Q, MaxRecurse)) + return V; + if (Value *V = expandBinOp(Opcode, R, L, OpcodeToExpand, Q, MaxRecurse)) + return V; return nullptr; } @@ -423,9 +415,9 @@ static Value *ThreadBinOpOverSelect(Instruction::BinaryOps Opcode, Value *LHS, return TV; // If one branch simplified to undef, return the other one. - if (TV && isa<UndefValue>(TV)) + if (TV && Q.isUndefValue(TV)) return FV; - if (FV && isa<UndefValue>(FV)) + if (FV && Q.isUndefValue(FV)) return TV; // If applying the operation did not change the true and false select values, @@ -620,7 +612,7 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, return C; // X + undef -> undef - if (match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op1; // X + 0 -> X @@ -740,7 +732,7 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // X - undef -> undef // undef - X -> undef - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return UndefValue::get(Op0->getType()); // X - 0 -> X @@ -875,7 +867,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // X * undef -> 0 // X * 0 -> 0 - if (match(Op1, m_CombineOr(m_Undef(), m_Zero()))) + if (Q.isUndefValue(Op1) || match(Op1, m_Zero())) return Constant::getNullValue(Op0->getType()); // X * 1 -> X @@ -901,8 +893,8 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return V; // Mul distributes over Add. Try some generic simplifications based on this. - if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add, - Q, MaxRecurse)) + if (Value *V = expandCommutativeBinOp(Instruction::Mul, Op0, Op1, + Instruction::Add, Q, MaxRecurse)) return V; // If the operation is with the result of a select instruction, check whether @@ -928,36 +920,37 @@ Value *llvm::SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { /// Check for common or similar folds of integer division or integer remainder. /// This applies to all 4 opcodes (sdiv/udiv/srem/urem). -static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) { +static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv, + const SimplifyQuery &Q) { Type *Ty = Op0->getType(); - // X / undef -> undef - // X % undef -> undef - if (match(Op1, m_Undef())) - return Op1; + // X / undef -> poison + // X % undef -> poison + if (Q.isUndefValue(Op1)) + return PoisonValue::get(Ty); - // X / 0 -> undef - // X % 0 -> undef + // X / 0 -> poison + // X % 0 -> poison // We don't need to preserve faults! if (match(Op1, m_Zero())) - return UndefValue::get(Ty); + return PoisonValue::get(Ty); - // If any element of a constant divisor fixed width vector is zero or undef, - // the whole op is undef. + // If any element of a constant divisor fixed width vector is zero or undef + // the behavior is undefined and we can fold the whole op to poison. auto *Op1C = dyn_cast<Constant>(Op1); auto *VTy = dyn_cast<FixedVectorType>(Ty); if (Op1C && VTy) { unsigned NumElts = VTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = Op1C->getAggregateElement(i); - if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt))) - return UndefValue::get(Ty); + if (Elt && (Elt->isNullValue() || Q.isUndefValue(Elt))) + return PoisonValue::get(Ty); } } // undef / X -> 0 // undef % X -> 0 - if (match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return Constant::getNullValue(Ty); // 0 / X -> 0 @@ -1051,7 +1044,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, true)) + if (Value *V = simplifyDivRem(Op0, Op1, true, Q)) return V; bool IsSigned = Opcode == Instruction::SDiv; @@ -1109,7 +1102,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q)) return C; - if (Value *V = simplifyDivRem(Op0, Op1, false)) + if (Value *V = simplifyDivRem(Op0, Op1, false, Q)) return V; // (X % Y) % Y -> X % Y @@ -1204,14 +1197,14 @@ Value *llvm::SimplifyURemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) { return ::SimplifyURemInst(Op0, Op1, Q, RecursionLimit); } -/// Returns true if a shift by \c Amount always yields undef. -static bool isUndefShift(Value *Amount) { +/// Returns true if a shift by \c Amount always yields poison. +static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) { Constant *C = dyn_cast<Constant>(Amount); if (!C) return false; - // X shift by undef -> undef because it may shift by the bitwidth. - if (isa<UndefValue>(C)) + // X shift by undef -> poison because it may shift by the bitwidth. + if (Q.isUndefValue(C)) return true; // Shifting by the bitwidth or more is undefined. @@ -1222,9 +1215,10 @@ static bool isUndefShift(Value *Amount) { // If all lanes of a vector shift are undefined the whole shift is. if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) { - for (unsigned I = 0, E = cast<VectorType>(C->getType())->getNumElements(); + for (unsigned I = 0, + E = cast<FixedVectorType>(C->getType())->getNumElements(); I != E; ++I) - if (!isUndefShift(C->getAggregateElement(I))) + if (!isPoisonShift(C->getAggregateElement(I), Q)) return false; return true; } @@ -1252,8 +1246,8 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, return Op0; // Fold undefined shifts. - if (isUndefShift(Op1)) - return UndefValue::get(Op0->getType()); + if (isPoisonShift(Op1, Q)) + return PoisonValue::get(Op0->getType()); // If the operation is with the result of a select instruction, check whether // operating on either branch of the select always yields the same value. @@ -1271,7 +1265,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // the number of bits in the type, the shift is undefined. KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); if (Known.One.getLimitedValue() >= Known.getBitWidth()) - return UndefValue::get(Op0->getType()); + return PoisonValue::get(Op0->getType()); // If all valid bits in the shift amount are known zero, the first operand is // unchanged. @@ -1296,7 +1290,7 @@ static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, // undef >> X -> 0 // undef >> X -> undef (if it's exact) - if (match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return isExact ? Op0 : Constant::getNullValue(Op0->getType()); // The low bit cannot be shifted out of an exact shift if it is set. @@ -1318,7 +1312,7 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, // undef << X -> 0 // undef << X -> undef if (if it's NSW/NUW) - if (match(Op0, m_Undef())) + if (Q.isUndefValue(Op0)) return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType()); // (X >> A) << A -> X @@ -1704,25 +1698,27 @@ static Value *simplifyAndOrOfICmpsWithLimitConst(ICmpInst *Cmp0, ICmpInst *Cmp1, if (!Cmp0->isEquality()) return nullptr; - // The equality compare must be against a constant. Convert the 'null' pointer - // constant to an integer zero value. - APInt MinMaxC; - const APInt *C; - if (match(Cmp0->getOperand(1), m_APInt(C))) - MinMaxC = *C; - else if (isa<ConstantPointerNull>(Cmp0->getOperand(1))) - MinMaxC = APInt::getNullValue(8); - else - return nullptr; - // The non-equality compare must include a common operand (X). Canonicalize // the common operand as operand 0 (the predicate is swapped if the common // operand was operand 1). ICmpInst::Predicate Pred0 = Cmp0->getPredicate(); Value *X = Cmp0->getOperand(0); ICmpInst::Predicate Pred1; - if (!match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value())) || - ICmpInst::isEquality(Pred1)) + bool HasNotOp = match(Cmp1, m_c_ICmp(Pred1, m_Not(m_Specific(X)), m_Value())); + if (!HasNotOp && !match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value()))) + return nullptr; + if (ICmpInst::isEquality(Pred1)) + return nullptr; + + // The equality compare must be against a constant. Flip bits if we matched + // a bitwise not. Convert a null pointer constant to an integer zero value. + APInt MinMaxC; + const APInt *C; + if (match(Cmp0->getOperand(1), m_APInt(C))) + MinMaxC = HasNotOp ? ~*C : *C; + else if (isa<ConstantPointerNull>(Cmp0->getOperand(1))) + MinMaxC = APInt::getNullValue(8); + else return nullptr; // DeMorganize if this is 'or': P0 || P1 --> !P0 && !P1. @@ -2003,6 +1999,30 @@ static Value *omitCheckForZeroBeforeInvertedMulWithOverflow(Value *Op0, return NotOp1; } +/// Given a bitwise logic op, check if the operands are add/sub with a common +/// source value and inverted constant (identity: C - X -> ~(X + ~C)). +static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, + Instruction::BinaryOps Opcode) { + assert(Op0->getType() == Op1->getType() && "Mismatched binop types"); + assert(BinaryOperator::isBitwiseLogicOp(Opcode) && "Expected logic op"); + Value *X; + Constant *C1, *C2; + if ((match(Op0, m_Add(m_Value(X), m_Constant(C1))) && + match(Op1, m_Sub(m_Constant(C2), m_Specific(X)))) || + (match(Op1, m_Add(m_Value(X), m_Constant(C1))) && + match(Op0, m_Sub(m_Constant(C2), m_Specific(X))))) { + if (ConstantExpr::getNot(C1) == C2) { + // (X + C) & (~C - X) --> (X + C) & ~(X + C) --> 0 + // (X + C) | (~C - X) --> (X + C) | ~(X + C) --> -1 + // (X + C) ^ (~C - X) --> (X + C) ^ ~(X + C) --> -1 + Type *Ty = Op0->getType(); + return Opcode == Instruction::And ? ConstantInt::getNullValue(Ty) + : ConstantInt::getAllOnesValue(Ty); + } + } + return nullptr; +} + /// Given operands for an And, see if we can fold the result. /// If not, this returns null. static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, @@ -2011,7 +2031,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return C; // X & undef -> 0 - if (match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Constant::getNullValue(Op0->getType()); // X & X = X @@ -2039,6 +2059,9 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (match(Op1, m_c_Or(m_Specific(Op0), m_Value()))) return Op0; + if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::And)) + return V; + // A mask that only clears known zeros of a shifted value is a no-op. Value *X; const APInt *Mask; @@ -2095,21 +2118,30 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return V; // And distributes over Or. Try some generic simplifications based on this. - if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Or, - Q, MaxRecurse)) + if (Value *V = expandCommutativeBinOp(Instruction::And, Op0, Op1, + Instruction::Or, Q, MaxRecurse)) return V; // And distributes over Xor. Try some generic simplifications based on this. - if (Value *V = ExpandBinOp(Instruction::And, Op0, Op1, Instruction::Xor, - Q, MaxRecurse)) + if (Value *V = expandCommutativeBinOp(Instruction::And, Op0, Op1, + Instruction::Xor, Q, MaxRecurse)) return V; - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) { + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // A & (A && B) -> A && B + if (match(Op1, m_Select(m_Specific(Op0), m_Value(), m_Zero()))) + return Op1; + else if (match(Op0, m_Select(m_Specific(Op1), m_Value(), m_Zero()))) + return Op0; + } + // If the operation is with the result of a select instruction, check + // whether operating on either branch of the select always yields the same + // value. if (Value *V = ThreadBinOpOverSelect(Instruction::And, Op0, Op1, Q, MaxRecurse)) return V; + } // If the operation is with the result of a phi instruction, check whether // operating on all incoming values of the phi always yields the same value. @@ -2169,7 +2201,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // X | undef -> -1 // X | -1 = -1 // Do not return Op1 because it may contain undef elements if it's a vector. - if (match(Op1, m_Undef()) || match(Op1, m_AllOnes())) + if (Q.isUndefValue(Op1) || match(Op1, m_AllOnes())) return Constant::getAllOnesValue(Op0->getType()); // X | X = X @@ -2198,7 +2230,10 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (match(Op1, m_Not(m_c_And(m_Specific(Op0), m_Value())))) return Constant::getAllOnesValue(Op0->getType()); - Value *A, *B; + if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::Or)) + return V; + + Value *A, *B, *NotA; // (A & ~B) | (A ^ B) -> (A ^ B) // (~B & A) | (A ^ B) -> (A ^ B) // (A & ~B) | (B ^ A) -> (B ^ A) @@ -2227,6 +2262,7 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) return Op1; + // Commute the 'or' operands. // (~A ^ B) | (A & B) -> (~A ^ B) // (~A ^ B) | (B & A) -> (~A ^ B) // (B ^ ~A) | (A & B) -> (B ^ ~A) @@ -2236,6 +2272,25 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) return Op0; + // (~A & B) | ~(A | B) --> ~A + // (~A & B) | ~(B | A) --> ~A + // (B & ~A) | ~(A | B) --> ~A + // (B & ~A) | ~(B | A) --> ~A + if (match(Op0, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))), + m_Value(B))) && + match(Op1, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return NotA; + + // Commute the 'or' operands. + // ~(A | B) | (~A & B) --> ~A + // ~(B | A) | (~A & B) --> ~A + // ~(A | B) | (B & ~A) --> ~A + // ~(B | A) | (B & ~A) --> ~A + if (match(Op1, m_c_And(m_CombineAnd(m_Value(NotA), m_Not(m_Value(A))), + m_Value(B))) && + match(Op0, m_Not(m_c_Or(m_Specific(A), m_Specific(B))))) + return NotA; + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false)) return V; @@ -2253,16 +2308,25 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return V; // Or distributes over And. Try some generic simplifications based on this. - if (Value *V = ExpandBinOp(Instruction::Or, Op0, Op1, Instruction::And, Q, - MaxRecurse)) + if (Value *V = expandCommutativeBinOp(Instruction::Or, Op0, Op1, + Instruction::And, Q, MaxRecurse)) return V; - // If the operation is with the result of a select instruction, check whether - // operating on either branch of the select always yields the same value. - if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) + if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) { + if (Op0->getType()->isIntOrIntVectorTy(1)) { + // A | (A || B) -> A || B + if (match(Op1, m_Select(m_Specific(Op0), m_One(), m_Value()))) + return Op1; + else if (match(Op0, m_Select(m_Specific(Op1), m_One(), m_Value()))) + return Op0; + } + // If the operation is with the result of a select instruction, check + // whether operating on either branch of the select always yields the same + // value. if (Value *V = ThreadBinOpOverSelect(Instruction::Or, Op0, Op1, Q, MaxRecurse)) return V; + } // (A & C1)|(B & C2) const APInt *C1, *C2; @@ -2311,7 +2375,7 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return C; // A ^ undef -> undef - if (match(Op1, m_Undef())) + if (Q.isUndefValue(Op1)) return Op1; // A ^ 0 = A @@ -2327,6 +2391,9 @@ static Value *SimplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op1, m_Not(m_Specific(Op0)))) return Constant::getAllOnesValue(Op0->getType()); + if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::Xor)) + return V; + // Try some generic simplifications for associative operations. if (Value *V = SimplifyAssociativeBinOp(Instruction::Xor, Op0, Op1, Q, MaxRecurse)) @@ -2533,8 +2600,8 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI, // memory within the lifetime of the current function (allocas, byval // arguments, globals), then determine the comparison result here. SmallVector<const Value *, 8> LHSUObjs, RHSUObjs; - GetUnderlyingObjects(LHS, LHSUObjs, DL); - GetUnderlyingObjects(RHS, RHSUObjs, DL); + getUnderlyingObjects(LHS, LHSUObjs); + getUnderlyingObjects(RHS, RHSUObjs); // Is the set of underlying objects all noalias calls? auto IsNAC = [](ArrayRef<const Value *> Objects) { @@ -2741,7 +2808,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, } const APInt *C; - if (!match(RHS, m_APInt(C))) + if (!match(RHS, m_APIntAllowUndef(C))) return nullptr; // Rule out tautological comparisons (eg., ult 0 or uge 0). @@ -2759,17 +2826,166 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return ConstantInt::getFalse(ITy); } + // (mul nuw/nsw X, MulC) != C --> true (if C is not a multiple of MulC) + // (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC) + const APInt *MulC; + if (ICmpInst::isEquality(Pred) && + ((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) && + *MulC != 0 && C->urem(*MulC) != 0) || + (match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) && + *MulC != 0 && C->srem(*MulC) != 0))) + return ConstantInt::get(ITy, Pred == ICmpInst::ICMP_NE); + return nullptr; } +static Value *simplifyICmpWithBinOpOnLHS( + CmpInst::Predicate Pred, BinaryOperator *LBO, Value *RHS, + const SimplifyQuery &Q, unsigned MaxRecurse) { + Type *ITy = GetCompareTy(RHS); // The return type. + + Value *Y = nullptr; + // icmp pred (or X, Y), X + if (match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { + if (Pred == ICmpInst::ICMP_ULT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_UGE) + return getTrue(ITy); + + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { + KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (RHSKnown.isNonNegative() && YKnown.isNegative()) + return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); + if (RHSKnown.isNegative() || YKnown.isNonNegative()) + return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); + } + } + + // icmp pred (and X, Y), X + if (match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + // icmp pred (urem X, Y), Y + if (match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_UGE: + return getFalse(ITy); + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) + break; + LLVM_FALLTHROUGH; + } + case ICmpInst::ICMP_NE: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_ULE: + return getTrue(ITy); + } + } + + // icmp pred (urem X, Y), X + if (match(LBO, m_URem(m_Specific(RHS), m_Value()))) { + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + } + + // x >> y <=u x + // x udiv y <=u x. + if (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || + match(LBO, m_UDiv(m_Specific(RHS), m_Value()))) { + // icmp pred (X op Y), X + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + // (x*C1)/C2 <= x for C1 <= C2. + // This holds even if the multiplication overflows: Assume that x != 0 and + // arithmetic is modulo M. For overflow to occur we must have C1 >= M/x and + // thus C2 >= M/x. It follows that (x*C1)/C2 <= (M-1)/C2 <= ((M-1)*x)/M < x. + // + // Additionally, either the multiplication and division might be represented + // as shifts: + // (x*C1)>>C2 <= x for C1 < 2**C2. + // (x<<C1)/C2 <= x for 2**C1 < C2. + const APInt *C1, *C2; + if ((match(LBO, m_UDiv(m_Mul(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && + C1->ule(*C2)) || + (match(LBO, m_LShr(m_Mul(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && + C1->ule(APInt(C2->getBitWidth(), 1) << *C2)) || + (match(LBO, m_UDiv(m_Shl(m_Specific(RHS), m_APInt(C1)), m_APInt(C2))) && + (APInt(C1->getBitWidth(), 1) << *C1).ule(*C2))) { + if (Pred == ICmpInst::ICMP_UGT) + return getFalse(ITy); + if (Pred == ICmpInst::ICMP_ULE) + return getTrue(ITy); + } + + return nullptr; +} + + +// If only one of the icmp's operands has NSW flags, try to prove that: +// +// icmp slt (x + C1), (x +nsw C2) +// +// is equivalent to: +// +// icmp slt C1, C2 +// +// which is true if x + C2 has the NSW flags set and: +// *) C1 < C2 && C1 >= 0, or +// *) C2 < C1 && C1 <= 0. +// +static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, + Value *RHS) { + // TODO: only support icmp slt for now. + if (Pred != CmpInst::ICMP_SLT) + return false; + + // Canonicalize nsw add as RHS. + if (!match(RHS, m_NSWAdd(m_Value(), m_Value()))) + std::swap(LHS, RHS); + if (!match(RHS, m_NSWAdd(m_Value(), m_Value()))) + return false; + + Value *X; + const APInt *C1, *C2; + if (!match(LHS, m_c_Add(m_Value(X), m_APInt(C1))) || + !match(RHS, m_c_Add(m_Specific(X), m_APInt(C2)))) + return false; + + return (C1->slt(*C2) && C1->isNonNegative()) || + (C2->slt(*C1) && C1->isNonPositive()); +} + + /// TODO: A large part of this logic is duplicated in InstCombine's /// foldICmpBinOp(). We should be able to share that and avoid the code /// duplication. static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { - Type *ITy = GetCompareTy(LHS); // The return type. - BinaryOperator *LBO = dyn_cast<BinaryOperator>(LHS); BinaryOperator *RBO = dyn_cast<BinaryOperator>(RHS); if (MaxRecurse && (LBO || RBO)) { @@ -2813,8 +3029,9 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return V; // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. - if (A && C && (A == C || A == D || B == C || B == D) && NoLHSWrapProblem && - NoRHSWrapProblem) { + bool CanSimplify = (NoLHSWrapProblem && NoRHSWrapProblem) || + trySimplifyICmpWithAdds(Pred, LHS, RHS); + if (A && C && (A == C || A == D || B == C || B == D) && CanSimplify) { // Determine Y and Z in the form icmp (X+Y), (X+Z). Value *Y, *Z; if (A == C) { @@ -2840,195 +3057,66 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, } } - { - Value *Y = nullptr; - // icmp pred (or X, Y), X - if (LBO && match(LBO, m_c_Or(m_Value(Y), m_Specific(RHS)))) { - if (Pred == ICmpInst::ICMP_ULT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_UGE) - return getTrue(ITy); - - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { - KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (RHSKnown.isNonNegative() && YKnown.isNegative()) - return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); - if (RHSKnown.isNegative() || YKnown.isNonNegative()) - return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); - } - } - // icmp pred X, (or X, Y) - if (RBO && match(RBO, m_c_Or(m_Value(Y), m_Specific(LHS)))) { - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); - - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { - KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (LHSKnown.isNonNegative() && YKnown.isNegative()) - return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); - if (LHSKnown.isNegative() || YKnown.isNonNegative()) - return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); - } - } - } + if (LBO) + if (Value *V = simplifyICmpWithBinOpOnLHS(Pred, LBO, RHS, Q, MaxRecurse)) + return V; - // icmp pred (and X, Y), X - if (LBO && match(LBO, m_c_And(m_Value(), m_Specific(RHS)))) { - if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - // icmp pred X, (and X, Y) - if (RBO && match(RBO, m_c_And(m_Value(), m_Specific(LHS)))) { - if (Pred == ICmpInst::ICMP_UGE) - return getTrue(ITy); - if (Pred == ICmpInst::ICMP_ULT) - return getFalse(ITy); - } + if (RBO) + if (Value *V = simplifyICmpWithBinOpOnLHS( + ICmpInst::getSwappedPredicate(Pred), RBO, LHS, Q, MaxRecurse)) + return V; // 0 - (zext X) pred C if (!CmpInst::isUnsigned(Pred) && match(LHS, m_Neg(m_ZExt(m_Value())))) { - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS)) { - if (RHSC->getValue().isStrictlyPositive()) { - if (Pred == ICmpInst::ICMP_SLT) - return ConstantInt::getTrue(RHSC->getContext()); - if (Pred == ICmpInst::ICMP_SGE) - return ConstantInt::getFalse(RHSC->getContext()); - if (Pred == ICmpInst::ICMP_EQ) - return ConstantInt::getFalse(RHSC->getContext()); - if (Pred == ICmpInst::ICMP_NE) - return ConstantInt::getTrue(RHSC->getContext()); + const APInt *C; + if (match(RHS, m_APInt(C))) { + if (C->isStrictlyPositive()) { + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(GetCompareTy(RHS)); + if (Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(GetCompareTy(RHS)); } - if (RHSC->getValue().isNonNegative()) { + if (C->isNonNegative()) { if (Pred == ICmpInst::ICMP_SLE) - return ConstantInt::getTrue(RHSC->getContext()); + return ConstantInt::getTrue(GetCompareTy(RHS)); if (Pred == ICmpInst::ICMP_SGT) - return ConstantInt::getFalse(RHSC->getContext()); + return ConstantInt::getFalse(GetCompareTy(RHS)); } } } - // icmp pred (urem X, Y), Y - if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { - switch (Pred) { - default: - break; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: { - KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (!Known.isNonNegative()) - break; - LLVM_FALLTHROUGH; - } - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - return getFalse(ITy); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: { - KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (!Known.isNonNegative()) - break; - LLVM_FALLTHROUGH; - } - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - return getTrue(ITy); - } - } - - // icmp pred X, (urem Y, X) - if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { - switch (Pred) { - default: - break; - case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: { - KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (!Known.isNonNegative()) - break; - LLVM_FALLTHROUGH; - } - case ICmpInst::ICMP_NE: - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - return getTrue(ITy); - case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: { - KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - if (!Known.isNonNegative()) - break; - LLVM_FALLTHROUGH; - } - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - return getFalse(ITy); + // If C2 is a power-of-2 and C is not: + // (C2 << X) == C --> false + // (C2 << X) != C --> true + const APInt *C; + if (match(LHS, m_Shl(m_Power2(), m_Value())) && + match(RHS, m_APIntAllowUndef(C)) && !C->isPowerOf2()) { + // C2 << X can equal zero in some circumstances. + // This simplification might be unsafe if C is zero. + // + // We know it is safe if: + // - The shift is nsw. We can't shift out the one bit. + // - The shift is nuw. We can't shift out the one bit. + // - C2 is one. + // - C isn't zero. + if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || + Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || + match(LHS, m_Shl(m_One(), m_Value())) || !C->isNullValue()) { + if (Pred == ICmpInst::ICMP_EQ) + return ConstantInt::getFalse(GetCompareTy(RHS)); + if (Pred == ICmpInst::ICMP_NE) + return ConstantInt::getTrue(GetCompareTy(RHS)); } } - // x >> y <=u x - // x udiv y <=u x. - if (LBO && (match(LBO, m_LShr(m_Specific(RHS), m_Value())) || - match(LBO, m_UDiv(m_Specific(RHS), m_Value())))) { - // icmp pred (X op Y), X + // TODO: This is overly constrained. LHS can be any power-of-2. + // (1 << X) >u 0x8000 --> false + // (1 << X) <=u 0x8000 --> true + if (match(LHS, m_Shl(m_One(), m_Value())) && match(RHS, m_SignMask())) { if (Pred == ICmpInst::ICMP_UGT) - return getFalse(ITy); + return ConstantInt::getFalse(GetCompareTy(RHS)); if (Pred == ICmpInst::ICMP_ULE) - return getTrue(ITy); - } - - // x >=u x >> y - // x >=u x udiv y. - if (RBO && (match(RBO, m_LShr(m_Specific(LHS), m_Value())) || - match(RBO, m_UDiv(m_Specific(LHS), m_Value())))) { - // icmp pred X, (X op Y) - if (Pred == ICmpInst::ICMP_ULT) - return getFalse(ITy); - if (Pred == ICmpInst::ICMP_UGE) - return getTrue(ITy); - } - - // handle: - // CI2 << X == CI - // CI2 << X != CI - // - // where CI2 is a power of 2 and CI isn't - if (auto *CI = dyn_cast<ConstantInt>(RHS)) { - const APInt *CI2Val, *CIVal = &CI->getValue(); - if (LBO && match(LBO, m_Shl(m_APInt(CI2Val), m_Value())) && - CI2Val->isPowerOf2()) { - if (!CIVal->isPowerOf2()) { - // CI2 << X can equal zero in some circumstances, - // this simplification is unsafe if CI is zero. - // - // We know it is safe if: - // - The shift is nsw, we can't shift out the one bit. - // - The shift is nuw, we can't shift out the one bit. - // - CI2 is one - // - CI isn't zero - if (Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(LBO)) || - Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(LBO)) || - CI2Val->isOneValue() || !CI->isZero()) { - if (Pred == ICmpInst::ICMP_EQ) - return ConstantInt::getFalse(RHS->getContext()); - if (Pred == ICmpInst::ICMP_NE) - return ConstantInt::getTrue(RHS->getContext()); - } - } - if (CIVal->isSignMask() && CI2Val->isOneValue()) { - if (Pred == ICmpInst::ICMP_UGT) - return ConstantInt::getFalse(RHS->getContext()); - if (Pred == ICmpInst::ICMP_ULE) - return ConstantInt::getTrue(RHS->getContext()); - } - } + return ConstantInt::getTrue(GetCompareTy(RHS)); } if (MaxRecurse && LBO && RBO && LBO->getOpcode() == RBO->getOpcode() && @@ -3226,55 +3314,38 @@ static Value *simplifyICmpWithMinMax(CmpInst::Predicate Pred, Value *LHS, break; } case CmpInst::ICMP_UGE: - // Always true. return getTrue(ITy); case CmpInst::ICMP_ULT: - // Always false. return getFalse(ITy); } } - // Variants on "max(x,y) >= min(x,z)". + // Comparing 1 each of min/max with a common operand? + // Canonicalize min operand to RHS. + if (match(LHS, m_UMin(m_Value(), m_Value())) || + match(LHS, m_SMin(m_Value(), m_Value()))) { + std::swap(LHS, RHS); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + Value *C, *D; if (match(LHS, m_SMax(m_Value(A), m_Value(B))) && match(RHS, m_SMin(m_Value(C), m_Value(D))) && (A == C || A == D || B == C || B == D)) { - // max(x, ?) pred min(x, ?). + // smax(A, B) >=s smin(A, D) --> true if (Pred == CmpInst::ICMP_SGE) - // Always true. return getTrue(ITy); + // smax(A, B) <s smin(A, D) --> false if (Pred == CmpInst::ICMP_SLT) - // Always false. - return getFalse(ITy); - } else if (match(LHS, m_SMin(m_Value(A), m_Value(B))) && - match(RHS, m_SMax(m_Value(C), m_Value(D))) && - (A == C || A == D || B == C || B == D)) { - // min(x, ?) pred max(x, ?). - if (Pred == CmpInst::ICMP_SLE) - // Always true. - return getTrue(ITy); - if (Pred == CmpInst::ICMP_SGT) - // Always false. return getFalse(ITy); } else if (match(LHS, m_UMax(m_Value(A), m_Value(B))) && match(RHS, m_UMin(m_Value(C), m_Value(D))) && (A == C || A == D || B == C || B == D)) { - // max(x, ?) pred min(x, ?). + // umax(A, B) >=u umin(A, D) --> true if (Pred == CmpInst::ICMP_UGE) - // Always true. return getTrue(ITy); + // umax(A, B) <u umin(A, D) --> false if (Pred == CmpInst::ICMP_ULT) - // Always false. - return getFalse(ITy); - } else if (match(LHS, m_UMin(m_Value(A), m_Value(B))) && - match(RHS, m_UMax(m_Value(C), m_Value(D))) && - (A == C || A == D || B == C || B == D)) { - // min(x, ?) pred max(x, ?). - if (Pred == CmpInst::ICMP_ULE) - // Always true. - return getTrue(ITy); - if (Pred == CmpInst::ICMP_UGT) - // Always false. return getFalse(ITy); } @@ -3327,12 +3398,12 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // For EQ and NE, we can always pick a value for the undef to make the // predicate pass or fail, so we can return undef. // Matches behavior in llvm::ConstantFoldCompareInstruction. - if (isa<UndefValue>(RHS) && ICmpInst::isEquality(Pred)) + if (Q.isUndefValue(RHS) && ICmpInst::isEquality(Pred)) return UndefValue::get(ITy); // icmp X, X -> true/false // icmp X, undef -> true/false because undef could be X. - if (LHS == RHS || isa<UndefValue>(RHS)) + if (LHS == RHS || Q.isUndefValue(RHS)) return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); if (Value *V = simplifyICmpOfBools(Pred, LHS, RHS, Q)) @@ -3588,7 +3659,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // expression GEP with the same indices and a null base pointer to see // what constant folding can make out of it. Constant *Null = Constant::getNullValue(GLHS->getPointerOperandType()); - SmallVector<Value *, 4> IndicesLHS(GLHS->idx_begin(), GLHS->idx_end()); + SmallVector<Value *, 4> IndicesLHS(GLHS->indices()); Constant *NewLHS = ConstantExpr::getGetElementPtr( GLHS->getSourceElementType(), Null, IndicesLHS); @@ -3659,7 +3730,7 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, // fcmp pred x, undef and fcmp pred undef, x // fold to true if unordered, false if ordered - if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { + if (Q.isUndefValue(LHS) || Q.isUndefValue(RHS)) { // Choosing NaN for the undef will always make unordered comparison succeed // and ordered comparison fail. return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); @@ -3703,6 +3774,21 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; } } + + // LHS == Inf + if (Pred == FCmpInst::FCMP_OEQ && isKnownNeverInfinity(LHS, Q.TLI)) + return getFalse(RetTy); + // LHS != Inf + if (Pred == FCmpInst::FCMP_UNE && isKnownNeverInfinity(LHS, Q.TLI)) + return getTrue(RetTy); + // LHS == Inf || LHS == NaN + if (Pred == FCmpInst::FCMP_UEQ && isKnownNeverInfinity(LHS, Q.TLI) && + isKnownNeverNaN(LHS, Q.TLI)) + return getFalse(RetTy); + // LHS != Inf && LHS != NaN + if (Pred == FCmpInst::FCMP_ONE && isKnownNeverInfinity(LHS, Q.TLI) && + isKnownNeverNaN(LHS, Q.TLI)) + return getTrue(RetTy); } if (C->isNegative() && !C->isNegZero()) { assert(!C->isNaN() && "Unexpected NaN constant!"); @@ -3810,10 +3896,10 @@ Value *llvm::SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, return ::SimplifyFCmpInst(Predicate, LHS, RHS, FMF, Q, RecursionLimit); } -/// See if V simplifies when its operand Op is replaced with RepOp. -static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, - const SimplifyQuery &Q, - unsigned MaxRecurse) { +static Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const SimplifyQuery &Q, + bool AllowRefinement, + unsigned MaxRecurse) { // Trivial replacement. if (V == Op) return RepOp; @@ -3826,30 +3912,41 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (!I) return nullptr; + // Consider: + // %cmp = icmp eq i32 %x, 2147483647 + // %add = add nsw i32 %x, 1 + // %sel = select i1 %cmp, i32 -2147483648, i32 %add + // + // We can't replace %sel with %add unless we strip away the flags (which will + // be done in InstCombine). + // TODO: This is unsound, because it only catches some forms of refinement. + if (!AllowRefinement && canCreatePoison(cast<Operator>(I))) + return nullptr; + + // The simplification queries below may return the original value. Consider: + // %div = udiv i32 %arg, %arg2 + // %mul = mul nsw i32 %div, %arg2 + // %cmp = icmp eq i32 %mul, %arg + // %sel = select i1 %cmp, i32 %div, i32 undef + // Replacing %arg by %mul, %div becomes "udiv i32 %mul, %arg2", which + // simplifies back to %arg. This can only happen because %mul does not + // dominate %div. To ensure a consistent return value contract, we make sure + // that this case returns nullptr as well. + auto PreventSelfSimplify = [V](Value *Simplified) { + return Simplified != V ? Simplified : nullptr; + }; + // If this is a binary operator, try to simplify it with the replaced op. if (auto *B = dyn_cast<BinaryOperator>(I)) { - // Consider: - // %cmp = icmp eq i32 %x, 2147483647 - // %add = add nsw i32 %x, 1 - // %sel = select i1 %cmp, i32 -2147483648, i32 %add - // - // We can't replace %sel with %add unless we strip away the flags. - // TODO: This is an unusual limitation because better analysis results in - // worse simplification. InstCombine can do this fold more generally - // by dropping the flags. Remove this fold to save compile-time? - if (isa<OverflowingBinaryOperator>(B)) - if (Q.IIQ.hasNoSignedWrap(B) || Q.IIQ.hasNoUnsignedWrap(B)) - return nullptr; - if (isa<PossiblyExactOperator>(B) && Q.IIQ.isExact(B)) - return nullptr; - if (MaxRecurse) { if (B->getOperand(0) == Op) - return SimplifyBinOp(B->getOpcode(), RepOp, B->getOperand(1), Q, - MaxRecurse - 1); + return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), RepOp, + B->getOperand(1), Q, + MaxRecurse - 1)); if (B->getOperand(1) == Op) - return SimplifyBinOp(B->getOpcode(), B->getOperand(0), RepOp, Q, - MaxRecurse - 1); + return PreventSelfSimplify(SimplifyBinOp(B->getOpcode(), + B->getOperand(0), RepOp, Q, + MaxRecurse - 1)); } } @@ -3857,11 +3954,13 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, if (CmpInst *C = dyn_cast<CmpInst>(I)) { if (MaxRecurse) { if (C->getOperand(0) == Op) - return SimplifyCmpInst(C->getPredicate(), RepOp, C->getOperand(1), Q, - MaxRecurse - 1); + return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), RepOp, + C->getOperand(1), Q, + MaxRecurse - 1)); if (C->getOperand(1) == Op) - return SimplifyCmpInst(C->getPredicate(), C->getOperand(0), RepOp, Q, - MaxRecurse - 1); + return PreventSelfSimplify(SimplifyCmpInst(C->getPredicate(), + C->getOperand(0), RepOp, Q, + MaxRecurse - 1)); } } @@ -3871,8 +3970,8 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, SmallVector<Value *, 8> NewOps(GEP->getNumOperands()); transform(GEP->operands(), NewOps.begin(), [&](Value *V) { return V == Op ? RepOp : V; }); - return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, - MaxRecurse - 1); + return PreventSelfSimplify(SimplifyGEPInst(GEP->getSourceElementType(), + NewOps, Q, MaxRecurse - 1)); } } @@ -3909,6 +4008,13 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; } +Value *llvm::SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const SimplifyQuery &Q, + bool AllowRefinement) { + return ::SimplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, + RecursionLimit); +} + /// Try to simplify a select instruction when its condition operand is an /// integer comparison where one operand of the compare is a constant. static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, @@ -3968,29 +4074,27 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS)))) return nullptr; - if (ICmpInst::isEquality(Pred) && match(CmpRHS, m_Zero())) { + // Canonicalize ne to eq predicate. + if (Pred == ICmpInst::ICMP_NE) { + Pred = ICmpInst::ICMP_EQ; + std::swap(TrueVal, FalseVal); + } + + if (Pred == ICmpInst::ICMP_EQ && match(CmpRHS, m_Zero())) { Value *X; const APInt *Y; if (match(CmpLHS, m_And(m_Value(X), m_APInt(Y)))) if (Value *V = simplifySelectBitTest(TrueVal, FalseVal, X, Y, - Pred == ICmpInst::ICMP_EQ)) + /*TrueWhenUnset=*/true)) return V; // Test for a bogus zero-shift-guard-op around funnel-shift or rotate. Value *ShAmt; - auto isFsh = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), m_Value(), - m_Value(ShAmt)), - m_Intrinsic<Intrinsic::fshr>(m_Value(), m_Value(X), - m_Value(ShAmt))); + auto isFsh = m_CombineOr(m_FShl(m_Value(X), m_Value(), m_Value(ShAmt)), + m_FShr(m_Value(), m_Value(X), m_Value(ShAmt))); // (ShAmt == 0) ? fshl(X, *, ShAmt) : X --> X // (ShAmt == 0) ? fshr(*, X, ShAmt) : X --> X - if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt && - Pred == ICmpInst::ICMP_EQ) - return X; - // (ShAmt != 0) ? X : fshl(X, *, ShAmt) --> X - // (ShAmt != 0) ? X : fshr(*, X, ShAmt) --> X - if (match(FalseVal, isFsh) && TrueVal == X && CmpLHS == ShAmt && - Pred == ICmpInst::ICMP_NE) + if (match(TrueVal, isFsh) && FalseVal == X && CmpLHS == ShAmt) return X; // Test for a zero-shift-guard-op around rotates. These are used to @@ -3998,22 +4102,24 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // intrinsics do not have that problem. // We do not allow this transform for the general funnel shift case because // that would not preserve the poison safety of the original code. - auto isRotate = m_CombineOr(m_Intrinsic<Intrinsic::fshl>(m_Value(X), - m_Deferred(X), - m_Value(ShAmt)), - m_Intrinsic<Intrinsic::fshr>(m_Value(X), - m_Deferred(X), - m_Value(ShAmt))); - // (ShAmt != 0) ? fshl(X, X, ShAmt) : X --> fshl(X, X, ShAmt) - // (ShAmt != 0) ? fshr(X, X, ShAmt) : X --> fshr(X, X, ShAmt) - if (match(TrueVal, isRotate) && FalseVal == X && CmpLHS == ShAmt && - Pred == ICmpInst::ICMP_NE) - return TrueVal; + auto isRotate = + m_CombineOr(m_FShl(m_Value(X), m_Deferred(X), m_Value(ShAmt)), + m_FShr(m_Value(X), m_Deferred(X), m_Value(ShAmt))); // (ShAmt == 0) ? X : fshl(X, X, ShAmt) --> fshl(X, X, ShAmt) // (ShAmt == 0) ? X : fshr(X, X, ShAmt) --> fshr(X, X, ShAmt) if (match(FalseVal, isRotate) && TrueVal == X && CmpLHS == ShAmt && Pred == ICmpInst::ICMP_EQ) return FalseVal; + + // X == 0 ? abs(X) : -abs(X) --> -abs(X) + // X == 0 ? -abs(X) : abs(X) --> abs(X) + if (match(TrueVal, m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS))) && + match(FalseVal, m_Neg(m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS))))) + return FalseVal; + if (match(TrueVal, + m_Neg(m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS)))) && + match(FalseVal, m_Intrinsic<Intrinsic::abs>(m_Specific(CmpLHS)))) + return FalseVal; } // Check for other compares that behave like bit test. @@ -4025,27 +4131,20 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal, // arms of the select. See if substituting this value into the arm and // simplifying the result yields the same value as the other arm. if (Pred == ICmpInst::ICMP_EQ) { - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, + /* AllowRefinement */ false, MaxRecurse) == TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, + /* AllowRefinement */ false, MaxRecurse) == TrueVal) return FalseVal; - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == + if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, + /* AllowRefinement */ true, MaxRecurse) == FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == + SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, + /* AllowRefinement */ true, MaxRecurse) == FalseVal) return FalseVal; - } else if (Pred == ICmpInst::ICMP_NE) { - if (SimplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - FalseVal || - SimplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - FalseVal) - return TrueVal; - if (SimplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, MaxRecurse) == - TrueVal || - SimplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q, MaxRecurse) == - TrueVal) - return TrueVal; } return nullptr; @@ -4092,7 +4191,7 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); // select undef, X, Y -> X or Y - if (isa<UndefValue>(CondC)) + if (Q.isUndefValue(CondC)) return isa<Constant>(FalseVal) ? FalseVal : TrueVal; // TODO: Vector constants with undef elements don't simplify. @@ -4121,19 +4220,21 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // If the true or false value is undef, we can fold to the other value as // long as the other value isn't poison. // select ?, undef, X -> X - if (isa<UndefValue>(TrueVal) && - isGuaranteedNotToBeUndefOrPoison(FalseVal, Q.CxtI, Q.DT)) + if (Q.isUndefValue(TrueVal) && + isGuaranteedNotToBeUndefOrPoison(FalseVal, Q.AC, Q.CxtI, Q.DT)) return FalseVal; // select ?, X, undef -> X - if (isa<UndefValue>(FalseVal) && - isGuaranteedNotToBeUndefOrPoison(TrueVal, Q.CxtI, Q.DT)) + if (Q.isUndefValue(FalseVal) && + isGuaranteedNotToBeUndefOrPoison(TrueVal, Q.AC, Q.CxtI, Q.DT)) return TrueVal; // Deal with partial undef vector constants: select ?, VecC, VecC' --> VecC'' Constant *TrueC, *FalseC; - if (TrueVal->getType()->isVectorTy() && match(TrueVal, m_Constant(TrueC)) && + if (isa<FixedVectorType>(TrueVal->getType()) && + match(TrueVal, m_Constant(TrueC)) && match(FalseVal, m_Constant(FalseC))) { - unsigned NumElts = cast<VectorType>(TrueC->getType())->getNumElements(); + unsigned NumElts = + cast<FixedVectorType>(TrueC->getType())->getNumElements(); SmallVector<Constant *, 16> NewC; for (unsigned i = 0; i != NumElts; ++i) { // Bail out on incomplete vector constants. @@ -4146,10 +4247,10 @@ static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, // one element is undef, choose the defined element as the safe result. if (TEltC == FEltC) NewC.push_back(TEltC); - else if (isa<UndefValue>(TEltC) && + else if (Q.isUndefValue(TEltC) && isGuaranteedNotToBeUndefOrPoison(FEltC)) NewC.push_back(FEltC); - else if (isa<UndefValue>(FEltC) && + else if (Q.isUndefValue(FEltC) && isGuaranteedNotToBeUndefOrPoison(TEltC)) NewC.push_back(TEltC); else @@ -4201,7 +4302,12 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, else if (VectorType *VT = dyn_cast<VectorType>(Ops[1]->getType())) GEPTy = VectorType::get(GEPTy, VT->getElementCount()); - if (isa<UndefValue>(Ops[0])) + // getelementptr poison, idx -> poison + // getelementptr baseptr, poison -> poison + if (any_of(Ops, [](const auto *V) { return isa<PoisonValue>(V); })) + return PoisonValue::get(GEPTy); + + if (Q.isUndefValue(Ops[0])) return UndefValue::get(GEPTy); bool IsScalableVec = isa<ScalableVectorType>(SrcTy); @@ -4224,9 +4330,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, // doesn't truncate the pointers. if (Ops[1]->getType()->getScalarSizeInBits() == Q.DL.getPointerSizeInBits(AS)) { - auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * { - if (match(P, m_Zero())) - return Constant::getNullValue(GEPTy); + auto PtrToInt = [GEPTy](Value *P) -> Value * { Value *Temp; if (match(P, m_PtrToInt(m_Value(Temp)))) if (Temp->getType() == GEPTy) @@ -4234,10 +4338,14 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, return nullptr; }; + // FIXME: The following transforms are only legal if P and V have the + // same provenance (PR44403). Check whether getUnderlyingObject() is + // the same? + // getelementptr V, (sub P, V) -> P if P points to a type of size 1. if (TyAllocSize == 1 && match(Ops[1], m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))))) - if (Value *R = PtrToIntOrZero(P)) + if (Value *R = PtrToInt(P)) return R; // getelementptr V, (ashr (sub P, V), C) -> Q @@ -4246,7 +4354,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, m_AShr(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), m_ConstantInt(C))) && TyAllocSize == 1ULL << C) - if (Value *R = PtrToIntOrZero(P)) + if (Value *R = PtrToInt(P)) return R; // getelementptr V, (sdiv (sub P, V), C) -> Q @@ -4254,7 +4362,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, if (match(Ops[1], m_SDiv(m_Sub(m_Value(P), m_PtrToInt(m_Specific(Ops[0]))), m_SpecificInt(TyAllocSize)))) - if (Value *R = PtrToIntOrZero(P)) + if (Value *R = PtrToInt(P)) return R; } } @@ -4271,15 +4379,21 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops, Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL, BasePtrOffset); + // Avoid creating inttoptr of zero here: While LLVMs treatment of + // inttoptr is generally conservative, this particular case is folded to + // a null pointer, which will have incorrect provenance. + // gep (gep V, C), (sub 0, V) -> C if (match(Ops.back(), - m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr))))) { + m_Sub(m_Zero(), m_PtrToInt(m_Specific(StrippedBasePtr)))) && + !BasePtrOffset.isNullValue()) { auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset); return ConstantExpr::getIntToPtr(CI, GEPTy); } // gep (gep V, C), (xor V, -1) -> C-1 if (match(Ops.back(), - m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes()))) { + m_Xor(m_PtrToInt(m_Specific(StrippedBasePtr)), m_AllOnes())) && + !BasePtrOffset.isOneValue()) { auto *CI = ConstantInt::get(GEPTy->getContext(), BasePtrOffset - 1); return ConstantExpr::getIntToPtr(CI, GEPTy); } @@ -4310,7 +4424,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs); // insertvalue x, undef, n -> x - if (match(Val, m_Undef())) + if (Q.isUndefValue(Val)) return Agg; // insertvalue x, (extractvalue y, n), n @@ -4318,7 +4432,7 @@ static Value *SimplifyInsertValueInst(Value *Agg, Value *Val, if (EV->getAggregateOperand()->getType() == Agg->getType() && EV->getIndices() == Idxs) { // insertvalue undef, (extractvalue y, n), n -> y - if (match(Agg, m_Undef())) + if (Q.isUndefValue(Agg)) return EV->getAggregateOperand(); // insertvalue y, (extractvalue y, n), n -> y @@ -4342,22 +4456,23 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx, auto *ValC = dyn_cast<Constant>(Val); auto *IdxC = dyn_cast<Constant>(Idx); if (VecC && ValC && IdxC) - return ConstantFoldInsertElementInstruction(VecC, ValC, IdxC); + return ConstantExpr::getInsertElement(VecC, ValC, IdxC); - // For fixed-length vector, fold into undef if index is out of bounds. + // For fixed-length vector, fold into poison if index is out of bounds. if (auto *CI = dyn_cast<ConstantInt>(Idx)) { if (isa<FixedVectorType>(Vec->getType()) && CI->uge(cast<FixedVectorType>(Vec->getType())->getNumElements())) - return UndefValue::get(Vec->getType()); + return PoisonValue::get(Vec->getType()); } // If index is undef, it might be out of bounds (see above case) - if (isa<UndefValue>(Idx)) - return UndefValue::get(Vec->getType()); + if (Q.isUndefValue(Idx)) + return PoisonValue::get(Vec->getType()); - // If the scalar is undef, and there is no risk of propagating poison from the - // vector value, simplify to the vector value. - if (isa<UndefValue>(Val) && isGuaranteedNotToBeUndefOrPoison(Vec)) + // If the scalar is poison, or it is undef and there is no risk of + // propagating poison from the vector value, simplify to the vector value. + if (isa<PoisonValue>(Val) || + (Q.isUndefValue(Val) && isGuaranteedNotToBePoison(Vec))) return Vec; // If we are extracting a value from a vector, then inserting it into the same @@ -4401,18 +4516,18 @@ Value *llvm::SimplifyExtractValueInst(Value *Agg, ArrayRef<unsigned> Idxs, /// Given operands for an ExtractElementInst, see if we can fold the result. /// If not, this returns null. -static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQuery &, - unsigned) { +static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, + const SimplifyQuery &Q, unsigned) { auto *VecVTy = cast<VectorType>(Vec->getType()); if (auto *CVec = dyn_cast<Constant>(Vec)) { if (auto *CIdx = dyn_cast<Constant>(Idx)) - return ConstantFoldExtractElementInstruction(CVec, CIdx); + return ConstantExpr::getExtractElement(CVec, CIdx); // The index is not relevant if our vector is a splat. if (auto *Splat = CVec->getSplatValue()) return Splat; - if (isa<UndefValue>(Vec)) + if (Q.isUndefValue(Vec)) return UndefValue::get(VecVTy->getElementType()); } @@ -4421,16 +4536,16 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQ if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) { // For fixed-length vector, fold into undef if index is out of bounds. if (isa<FixedVectorType>(VecVTy) && - IdxC->getValue().uge(VecVTy->getNumElements())) - return UndefValue::get(VecVTy->getElementType()); + IdxC->getValue().uge(cast<FixedVectorType>(VecVTy)->getNumElements())) + return PoisonValue::get(VecVTy->getElementType()); if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) return Elt; } // An undef extract index can be arbitrarily chosen to be an out-of-range - // index value, which would result in the instruction being undef. - if (isa<UndefValue>(Idx)) - return UndefValue::get(VecVTy->getElementType()); + // index value, which would result in the instruction being poison. + if (Q.isUndefValue(Idx)) + return PoisonValue::get(VecVTy->getElementType()); return nullptr; } @@ -4442,6 +4557,10 @@ Value *llvm::SimplifyExtractElementInst(Value *Vec, Value *Idx, /// See if we can fold the given phi. If not, returns null. static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { + // WARNING: no matter how worthwhile it may seem, we can not perform PHI CSE + // here, because the PHI we may succeed simplifying to was not + // def-reachable from the original PHI! + // If all of the PHI's incoming values are the same then replace the PHI node // with the common value. Value *CommonValue = nullptr; @@ -4449,7 +4568,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) { for (Value *Incoming : PN->incoming_values()) { // If the incoming value is the phi node itself, it can safely be skipped. if (Incoming == PN) continue; - if (isa<UndefValue>(Incoming)) { + if (Q.isUndefValue(Incoming)) { // Remember that we saw an undef value, but otherwise ignore them. HasUndefInput = true; continue; @@ -4527,7 +4646,7 @@ static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, return nullptr; // The mask value chooses which source operand we need to look at next. - int InVecNumElts = cast<VectorType>(Op0->getType())->getNumElements(); + int InVecNumElts = cast<FixedVectorType>(Op0->getType())->getNumElements(); int RootElt = MaskVal; Value *SourceOp = Op0; if (MaskVal >= InVecNumElts) { @@ -4574,16 +4693,16 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, unsigned MaskNumElts = Mask.size(); ElementCount InVecEltCount = InVecTy->getElementCount(); - bool Scalable = InVecEltCount.Scalable; + bool Scalable = InVecEltCount.isScalable(); SmallVector<int, 32> Indices; Indices.assign(Mask.begin(), Mask.end()); // Canonicalization: If mask does not select elements from an input vector, - // replace that input vector with undef. + // replace that input vector with poison. if (!Scalable) { bool MaskSelects0 = false, MaskSelects1 = false; - unsigned InVecNumElts = InVecEltCount.Min; + unsigned InVecNumElts = InVecEltCount.getKnownMinValue(); for (unsigned i = 0; i != MaskNumElts; ++i) { if (Indices[i] == -1) continue; @@ -4593,9 +4712,9 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, MaskSelects1 = true; } if (!MaskSelects0) - Op0 = UndefValue::get(InVecTy); + Op0 = PoisonValue::get(InVecTy); if (!MaskSelects1) - Op1 = UndefValue::get(InVecTy); + Op1 = PoisonValue::get(InVecTy); } auto *Op0Const = dyn_cast<Constant>(Op0); @@ -4604,15 +4723,16 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, // If all operands are constant, constant fold the shuffle. This // transformation depends on the value of the mask which is not known at // compile time for scalable vectors - if (!Scalable && Op0Const && Op1Const) - return ConstantFoldShuffleVectorInstruction(Op0Const, Op1Const, Mask); + if (Op0Const && Op1Const) + return ConstantExpr::getShuffleVector(Op0Const, Op1Const, Mask); // Canonicalization: if only one input vector is constant, it shall be the // second one. This transformation depends on the value of the mask which // is not known at compile time for scalable vectors if (!Scalable && Op0Const && !Op1Const) { std::swap(Op0, Op1); - ShuffleVectorInst::commuteShuffleMask(Indices, InVecEltCount.Min); + ShuffleVectorInst::commuteShuffleMask(Indices, + InVecEltCount.getKnownMinValue()); } // A splat of an inserted scalar constant becomes a vector constant: @@ -4644,7 +4764,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, // A shuffle of a splat is always the splat itself. Legal if the shuffle's // value type is same as the input vectors' type. if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0)) - if (isa<UndefValue>(Op1) && RetTy == InVecTy && + if (Q.isUndefValue(Op1) && RetTy == InVecTy && is_splat(OpShuf->getShuffleMask())) return Op0; @@ -4656,7 +4776,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, // Don't fold a shuffle with undef mask elements. This may get folded in a // better way using demanded bits or other analysis. // TODO: Should we allow this? - if (find(Indices, -1) != Indices.end()) + if (is_contained(Indices, -1)) return nullptr; // Check if every element of this shuffle can be mapped back to the @@ -4725,19 +4845,20 @@ static Constant *propagateNaN(Constant *In) { /// transforms based on undef/NaN because the operation itself makes no /// difference to the result. static Constant *simplifyFPOp(ArrayRef<Value *> Ops, - FastMathFlags FMF = FastMathFlags()) { + FastMathFlags FMF, + const SimplifyQuery &Q) { for (Value *V : Ops) { bool IsNan = match(V, m_NaN()); bool IsInf = match(V, m_Inf()); - bool IsUndef = match(V, m_Undef()); + bool IsUndef = Q.isUndefValue(V); // If this operation has 'nnan' or 'ninf' and at least 1 disallowed operand // (an undef operand can be chosen to be Nan/Inf), then the result of - // this operation is poison. That result can be relaxed to undef. + // this operation is poison. if (FMF.noNaNs() && (IsNan || IsUndef)) - return UndefValue::get(V->getType()); + return PoisonValue::get(V->getType()); if (FMF.noInfs() && (IsInf || IsUndef)) - return UndefValue::get(V->getType()); + return PoisonValue::get(V->getType()); if (IsUndef || IsNan) return propagateNaN(cast<Constant>(V)); @@ -4752,7 +4873,7 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fadd X, -0 ==> X @@ -4799,7 +4920,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fsub X, +0 ==> X @@ -4841,7 +4962,7 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF, static Value *SimplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF, const SimplifyQuery &Q, unsigned MaxRecurse) { - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // fmul X, 1.0 ==> X @@ -4908,7 +5029,7 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // X / 1.0 -> X @@ -4953,7 +5074,7 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF, if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q)) return C; - if (Constant *C = simplifyFPOp({Op0, Op1}, FMF)) + if (Constant *C = simplifyFPOp({Op0, Op1}, FMF, Q)) return C; // Unlike fdiv, the result of frem always matches the sign of the dividend. @@ -5198,6 +5319,15 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, // bitreverse(bitreverse(x)) -> x if (match(Op0, m_BitReverse(m_Value(X)))) return X; break; + case Intrinsic::ctpop: { + // If everything but the lowest bit is zero, that bit is the pop-count. Ex: + // ctpop(and X, 1) --> and X, 1 + unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + if (MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, BitWidth - 1), + Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + return Op0; + break; + } case Intrinsic::exp: // exp(log(x)) -> x if (Q.CxtI->hasAllowReassoc() && @@ -5250,27 +5380,156 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, return nullptr; } +static Intrinsic::ID getMaxMinOpposite(Intrinsic::ID IID) { + switch (IID) { + case Intrinsic::smax: return Intrinsic::smin; + case Intrinsic::smin: return Intrinsic::smax; + case Intrinsic::umax: return Intrinsic::umin; + case Intrinsic::umin: return Intrinsic::umax; + default: llvm_unreachable("Unexpected intrinsic"); + } +} + +static APInt getMaxMinLimit(Intrinsic::ID IID, unsigned BitWidth) { + switch (IID) { + case Intrinsic::smax: return APInt::getSignedMaxValue(BitWidth); + case Intrinsic::smin: return APInt::getSignedMinValue(BitWidth); + case Intrinsic::umax: return APInt::getMaxValue(BitWidth); + case Intrinsic::umin: return APInt::getMinValue(BitWidth); + default: llvm_unreachable("Unexpected intrinsic"); + } +} + +static ICmpInst::Predicate getMaxMinPredicate(Intrinsic::ID IID) { + switch (IID) { + case Intrinsic::smax: return ICmpInst::ICMP_SGE; + case Intrinsic::smin: return ICmpInst::ICMP_SLE; + case Intrinsic::umax: return ICmpInst::ICMP_UGE; + case Intrinsic::umin: return ICmpInst::ICMP_ULE; + default: llvm_unreachable("Unexpected intrinsic"); + } +} + +/// Given a min/max intrinsic, see if it can be removed based on having an +/// operand that is another min/max intrinsic with shared operand(s). The caller +/// is expected to swap the operand arguments to handle commutation. +static Value *foldMinMaxSharedOp(Intrinsic::ID IID, Value *Op0, Value *Op1) { + Value *X, *Y; + if (!match(Op0, m_MaxOrMin(m_Value(X), m_Value(Y)))) + return nullptr; + + auto *MM0 = dyn_cast<IntrinsicInst>(Op0); + if (!MM0) + return nullptr; + Intrinsic::ID IID0 = MM0->getIntrinsicID(); + + if (Op1 == X || Op1 == Y || + match(Op1, m_c_MaxOrMin(m_Specific(X), m_Specific(Y)))) { + // max (max X, Y), X --> max X, Y + if (IID0 == IID) + return MM0; + // max (min X, Y), X --> X + if (IID0 == getMaxMinOpposite(IID)) + return Op1; + } + return nullptr; +} + static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, const SimplifyQuery &Q) { Intrinsic::ID IID = F->getIntrinsicID(); Type *ReturnType = F->getReturnType(); + unsigned BitWidth = ReturnType->getScalarSizeInBits(); switch (IID) { + case Intrinsic::abs: + // abs(abs(x)) -> abs(x). We don't need to worry about the nsw arg here. + // It is always ok to pick the earlier abs. We'll just lose nsw if its only + // on the outer abs. + if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(), m_Value()))) + return Op0; + break; + + case Intrinsic::smax: + case Intrinsic::smin: + case Intrinsic::umax: + case Intrinsic::umin: { + // If the arguments are the same, this is a no-op. + if (Op0 == Op1) + return Op0; + + // Canonicalize constant operand as Op1. + if (isa<Constant>(Op0)) + std::swap(Op0, Op1); + + // Assume undef is the limit value. + if (Q.isUndefValue(Op1)) + return ConstantInt::get(ReturnType, getMaxMinLimit(IID, BitWidth)); + + const APInt *C; + if (match(Op1, m_APIntAllowUndef(C))) { + // Clamp to limit value. For example: + // umax(i8 %x, i8 255) --> 255 + if (*C == getMaxMinLimit(IID, BitWidth)) + return ConstantInt::get(ReturnType, *C); + + // If the constant op is the opposite of the limit value, the other must + // be larger/smaller or equal. For example: + // umin(i8 %x, i8 255) --> %x + if (*C == getMaxMinLimit(getMaxMinOpposite(IID), BitWidth)) + return Op0; + + // Remove nested call if constant operands allow it. Example: + // max (max X, 7), 5 -> max X, 7 + auto *MinMax0 = dyn_cast<IntrinsicInst>(Op0); + if (MinMax0 && MinMax0->getIntrinsicID() == IID) { + // TODO: loosen undef/splat restrictions for vector constants. + Value *M00 = MinMax0->getOperand(0), *M01 = MinMax0->getOperand(1); + const APInt *InnerC; + if ((match(M00, m_APInt(InnerC)) || match(M01, m_APInt(InnerC))) && + ((IID == Intrinsic::smax && InnerC->sge(*C)) || + (IID == Intrinsic::smin && InnerC->sle(*C)) || + (IID == Intrinsic::umax && InnerC->uge(*C)) || + (IID == Intrinsic::umin && InnerC->ule(*C)))) + return Op0; + } + } + + if (Value *V = foldMinMaxSharedOp(IID, Op0, Op1)) + return V; + if (Value *V = foldMinMaxSharedOp(IID, Op1, Op0)) + return V; + + ICmpInst::Predicate Pred = getMaxMinPredicate(IID); + if (isICmpTrue(Pred, Op0, Op1, Q.getWithoutUndef(), RecursionLimit)) + return Op0; + if (isICmpTrue(Pred, Op1, Op0, Q.getWithoutUndef(), RecursionLimit)) + return Op1; + + if (Optional<bool> Imp = + isImpliedByDomCondition(Pred, Op0, Op1, Q.CxtI, Q.DL)) + return *Imp ? Op0 : Op1; + if (Optional<bool> Imp = + isImpliedByDomCondition(Pred, Op1, Op0, Q.CxtI, Q.DL)) + return *Imp ? Op1 : Op0; + + break; + } case Intrinsic::usub_with_overflow: case Intrinsic::ssub_with_overflow: // X - X -> { 0, false } - if (Op0 == Op1) + // X - undef -> { 0, false } + // undef - X -> { 0, false } + if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getNullValue(ReturnType); - LLVM_FALLTHROUGH; + break; case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: - // X - undef -> { undef, false } - // undef - X -> { undef, false } - // X + undef -> { undef, false } - // undef + x -> { undef, false } - if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) { + // X + undef -> { -1, false } + // undef + x -> { -1, false } + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) { return ConstantStruct::get( cast<StructType>(ReturnType), - {UndefValue::get(ReturnType->getStructElementType(0)), + {Constant::getAllOnesValue(ReturnType->getStructElementType(0)), Constant::getNullValue(ReturnType->getStructElementType(1))}); } break; @@ -5282,7 +5541,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return Constant::getNullValue(ReturnType); // undef * X -> { 0, false } // X * undef -> { 0, false } - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getNullValue(ReturnType); break; case Intrinsic::uadd_sat: @@ -5296,7 +5555,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // sat(undef + X) -> -1 // For unsigned: Assume undef is MAX, thus we saturate to MAX (-1). // For signed: Assume undef is ~X, in which case X + ~X = -1. - if (match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getAllOnesValue(ReturnType); // X + 0 -> X @@ -5313,7 +5572,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, LLVM_FALLTHROUGH; case Intrinsic::ssub_sat: // X - X -> 0, X - undef -> 0, undef - X -> 0 - if (Op0 == Op1 || match(Op0, m_Undef()) || match(Op1, m_Undef())) + if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1)) return Constant::getNullValue(ReturnType); // X - 0 -> X if (match(Op1, m_Zero())) @@ -5351,18 +5610,43 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // If the arguments are the same, this is a no-op. if (Op0 == Op1) return Op0; - // If one argument is undef, return the other argument. - if (match(Op0, m_Undef())) - return Op1; - if (match(Op1, m_Undef())) + // Canonicalize constant operand as Op1. + if (isa<Constant>(Op0)) + std::swap(Op0, Op1); + + // If an argument is undef, return the other argument. + if (Q.isUndefValue(Op1)) return Op0; - // If one argument is NaN, return other or NaN appropriately. bool PropagateNaN = IID == Intrinsic::minimum || IID == Intrinsic::maximum; - if (match(Op0, m_NaN())) - return PropagateNaN ? Op0 : Op1; + bool IsMin = IID == Intrinsic::minimum || IID == Intrinsic::minnum; + + // minnum(X, nan) -> X + // maxnum(X, nan) -> X + // minimum(X, nan) -> nan + // maximum(X, nan) -> nan if (match(Op1, m_NaN())) - return PropagateNaN ? Op1 : Op0; + return PropagateNaN ? propagateNaN(cast<Constant>(Op1)) : Op0; + + // In the following folds, inf can be replaced with the largest finite + // float, if the ninf flag is set. + const APFloat *C; + if (match(Op1, m_APFloat(C)) && + (C->isInfinity() || (Q.CxtI->hasNoInfs() && C->isLargest()))) { + // minnum(X, -inf) -> -inf + // maxnum(X, +inf) -> +inf + // minimum(X, -inf) -> -inf if nnan + // maximum(X, +inf) -> +inf if nnan + if (C->isNegative() == IsMin && (!PropagateNaN || Q.CxtI->hasNoNaNs())) + return ConstantFP::get(ReturnType, *C); + + // minnum(X, +inf) -> X if nnan + // maxnum(X, -inf) -> X if nnan + // minimum(X, +inf) -> X + // maximum(X, -inf) -> X + if (C->isNegative() != IsMin && (PropagateNaN || Q.CxtI->hasNoNaNs())) + return Op0; + } // Min/max of the same operation with common operand: // m(m(X, Y)), X --> m(X, Y) (4 commuted variants) @@ -5375,20 +5659,6 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, (M1->getOperand(0) == Op0 || M1->getOperand(1) == Op0)) return Op1; - // min(X, -Inf) --> -Inf (and commuted variant) - // max(X, +Inf) --> +Inf (and commuted variant) - bool UseNegInf = IID == Intrinsic::minnum || IID == Intrinsic::minimum; - const APFloat *C; - if ((match(Op0, m_APFloat(C)) && C->isInfinity() && - C->isNegative() == UseNegInf) || - (match(Op1, m_APFloat(C)) && C->isInfinity() && - C->isNegative() == UseNegInf)) - return ConstantFP::getInfinity(ReturnType, UseNegInf); - - // TODO: minnum(nnan x, inf) -> x - // TODO: minnum(nnan ninf x, flt_max) -> x - // TODO: maxnum(nnan x, -inf) -> x - // TODO: maxnum(nnan ninf x, -flt_max) -> x break; } default: @@ -5431,11 +5701,11 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { *ShAmtArg = Call->getArgOperand(2); // If both operands are undef, the result is undef. - if (match(Op0, m_Undef()) && match(Op1, m_Undef())) + if (Q.isUndefValue(Op0) && Q.isUndefValue(Op1)) return UndefValue::get(F->getReturnType()); // If shift amount is undef, assume it is zero. - if (match(ShAmtArg, m_Undef())) + if (Q.isUndefValue(ShAmtArg)) return Call->getArgOperand(IID == Intrinsic::fshl ? 0 : 1); const APInt *ShAmtC; @@ -5452,7 +5722,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { Value *Op0 = Call->getArgOperand(0); Value *Op1 = Call->getArgOperand(1); Value *Op2 = Call->getArgOperand(2); - if (Value *V = simplifyFPOp({ Op0, Op1, Op2 })) + if (Value *V = simplifyFPOp({ Op0, Op1, Op2 }, {}, Q)) return V; return nullptr; } @@ -5461,28 +5731,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) { } } -Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { - Value *Callee = Call->getCalledOperand(); - - // musttail calls can only be simplified if they are also DCEd. - // As we can't guarantee this here, don't simplify them. - if (Call->isMustTailCall()) - return nullptr; - - // call undef -> undef - // call null -> undef - if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) - return UndefValue::get(Call->getType()); - - Function *F = dyn_cast<Function>(Callee); - if (!F) - return nullptr; - - if (F->isIntrinsic()) - if (Value *Ret = simplifyIntrinsic(Call, Q)) - return Ret; - - if (!canConstantFoldCallTo(Call, F)) +static Value *tryConstantFoldCall(CallBase *Call, const SimplifyQuery &Q) { + auto *F = dyn_cast<Function>(Call->getCalledOperand()); + if (!F || !canConstantFoldCallTo(Call, F)) return nullptr; SmallVector<Constant *, 4> ConstantArgs; @@ -5501,10 +5752,33 @@ Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { return ConstantFoldCall(Call, F, ConstantArgs, Q.TLI); } +Value *llvm::SimplifyCall(CallBase *Call, const SimplifyQuery &Q) { + // musttail calls can only be simplified if they are also DCEd. + // As we can't guarantee this here, don't simplify them. + if (Call->isMustTailCall()) + return nullptr; + + // call undef -> poison + // call null -> poison + Value *Callee = Call->getCalledOperand(); + if (isa<UndefValue>(Callee) || isa<ConstantPointerNull>(Callee)) + return PoisonValue::get(Call->getType()); + + if (Value *V = tryConstantFoldCall(Call, Q)) + return V; + + auto *F = dyn_cast<Function>(Callee); + if (F && F->isIntrinsic()) + if (Value *Ret = simplifyIntrinsic(Call, Q)) + return Ret; + + return nullptr; +} + /// Given operands for a Freeze, see if we can fold the result. static Value *SimplifyFreezeInst(Value *Op0, const SimplifyQuery &Q) { // Use a utility function defined in ValueTracking. - if (llvm::isGuaranteedNotToBeUndefOrPoison(Op0, Q.CxtI, Q.DT)) + if (llvm::isGuaranteedNotToBeUndefOrPoison(Op0, Q.AC, Q.CxtI, Q.DT)) return Op0; // We have room for improvement. return nullptr; @@ -5613,7 +5887,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ, I->getOperand(2), Q); break; case Instruction::GetElementPtr: { - SmallVector<Value *, 8> Ops(I->op_begin(), I->op_end()); + SmallVector<Value *, 8> Ops(I->operands()); Result = SimplifyGEPInst(cast<GetElementPtrInst>(I)->getSourceElementType(), Ops, Q); break; @@ -5750,13 +6024,6 @@ static bool replaceAndRecursivelySimplifyImpl( return Simplified; } -bool llvm::recursivelySimplifyInstruction(Instruction *I, - const TargetLibraryInfo *TLI, - const DominatorTree *DT, - AssumptionCache *AC) { - return replaceAndRecursivelySimplifyImpl(I, nullptr, TLI, DT, AC, nullptr); -} - bool llvm::replaceAndRecursivelySimplify( Instruction *I, Value *SimpleV, const TargetLibraryInfo *TLI, const DominatorTree *DT, AssumptionCache *AC, |
