diff options
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
| -rw-r--r-- | llvm/lib/Analysis/InstructionSimplify.cpp | 802 |
1 files changed, 471 insertions, 331 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp index 0bfea6140ab5..2a45acf63aa2 100644 --- a/llvm/lib/Analysis/InstructionSimplify.cpp +++ b/llvm/lib/Analysis/InstructionSimplify.cpp @@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, if (IsNUW) return Constant::getNullValue(Op0->getType()); - KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q); if (Known.Zero.isMaxSignedValue()) { // Op1 is either 0 or the minimum signed value. If the sub is NSW, then // Op1 must be 0 because negating the minimum signed value is undefined. @@ -895,7 +895,8 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW, // Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...). if (match(Op0, m_PtrToInt(m_Value(X))) && match(Op1, m_PtrToInt(m_Value(Y)))) if (Constant *Result = computePointerDifference(Q.DL, X, Y)) - return ConstantExpr::getIntegerCast(Result, Op0->getType(), true); + return ConstantFoldIntegerCast(Result, Op0->getType(), /*IsSigned*/ true, + Q.DL); // i1 sub -> xor. if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1)) @@ -1062,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q, // ("computeConstantRangeIncludingKnownBits")? const APInt *C; if (match(Y, m_APInt(C)) && - computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C)) + computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C)) return true; // Try again for any divisor: @@ -1124,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0, if (Op0 == Op1) return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty); - - KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q); // X / 0 -> poison // X % 0 -> poison // If the divisor is known to be zero, just return poison. This can happen in @@ -1194,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1, // less trailing zeros, then the result must be poison. const APInt *DivC; if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) { - KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q); if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero()) return PoisonValue::get(Op0->getType()); } @@ -1354,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // If any bits in the shift amount make that value greater than or equal to // the number of bits in the type, the shift is undefined. - KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q); if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth())) return PoisonValue::get(Op0->getType()); @@ -1367,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // Check for nsw shl leading to a poison value. if (IsNSW) { assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction"); - KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q); KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt); if (KnownVal.Zero.isSignBitSet()) @@ -1403,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0, // The low bit cannot be shifted out of an exact shift if it is set. // TODO: Generalize by counting trailing zeros (see fold for exact division). if (IsExact) { - KnownBits Op0Known = - computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT); + KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q); if (Op0Known.One[0]) return Op0; } @@ -1463,7 +1462,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact, // (X << A) >> A -> X Value *X; - if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1)))) + if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1)))) return X; // ((X << A) | Y) >> A -> X if effective width of Y is not larger than A. @@ -1473,10 +1472,10 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact, // optimizers by supporting a simple but common case in InstSimplify. Value *Y; const APInt *ShRAmt, *ShLAmt; - if (match(Op1, m_APInt(ShRAmt)) && + if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) && match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) && *ShRAmt == *ShLAmt) { - const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q); const unsigned EffWidthY = YKnown.countMaxActiveBits(); if (ShRAmt->uge(EffWidthY)) return X; @@ -1673,43 +1672,6 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; } -static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool IsAnd) { - ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate(); - if (!match(Cmp0->getOperand(1), m_Zero()) || - !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1) - return nullptr; - - if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ)) - return nullptr; - - // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)". - Value *X = Cmp0->getOperand(0); - Value *Y = Cmp1->getOperand(0); - - // If one of the compares is a masked version of a (not) null check, then - // that compare implies the other, so we eliminate the other. Optionally, look - // through a pointer-to-int cast to match a null check of a pointer type. - - // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0 - // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0 - // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0 - // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0 - if (match(Y, m_c_And(m_Specific(X), m_Value())) || - match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value()))) - return Cmp1; - - // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0 - // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0 - // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0 - // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0 - if (match(X, m_c_And(m_Specific(Y), m_Value())) || - match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value()))) - return Cmp0; - - return nullptr; -} - static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, const InstrInfoQuery &IIQ) { // (icmp (add V, C0), C1) & (icmp V, C0) @@ -1757,66 +1719,6 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1, return nullptr; } -/// Try to eliminate compares with signed or unsigned min/max constants. -static Value *simplifyAndOrOfICmpsWithLimitConst(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool IsAnd) { - // Canonicalize an equality compare as Cmp0. - if (Cmp1->isEquality()) - std::swap(Cmp0, Cmp1); - if (!Cmp0->isEquality()) - return nullptr; - - // The non-equality compare must include a common operand (X). Canonicalize - // the common operand as operand 0 (the predicate is swapped if the common - // operand was operand 1). - ICmpInst::Predicate Pred0 = Cmp0->getPredicate(); - Value *X = Cmp0->getOperand(0); - ICmpInst::Predicate Pred1; - bool HasNotOp = match(Cmp1, m_c_ICmp(Pred1, m_Not(m_Specific(X)), m_Value())); - if (!HasNotOp && !match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value()))) - return nullptr; - if (ICmpInst::isEquality(Pred1)) - return nullptr; - - // The equality compare must be against a constant. Flip bits if we matched - // a bitwise not. Convert a null pointer constant to an integer zero value. - APInt MinMaxC; - const APInt *C; - if (match(Cmp0->getOperand(1), m_APInt(C))) - MinMaxC = HasNotOp ? ~*C : *C; - else if (isa<ConstantPointerNull>(Cmp0->getOperand(1))) - MinMaxC = APInt::getZero(8); - else - return nullptr; - - // DeMorganize if this is 'or': P0 || P1 --> !P0 && !P1. - if (!IsAnd) { - Pred0 = ICmpInst::getInversePredicate(Pred0); - Pred1 = ICmpInst::getInversePredicate(Pred1); - } - - // Normalize to unsigned compare and unsigned min/max value. - // Example for 8-bit: -128 + 128 -> 0; 127 + 128 -> 255 - if (ICmpInst::isSigned(Pred1)) { - Pred1 = ICmpInst::getUnsignedPredicate(Pred1); - MinMaxC += APInt::getSignedMinValue(MinMaxC.getBitWidth()); - } - - // (X != MAX) && (X < Y) --> X < Y - // (X == MAX) || (X >= Y) --> X >= Y - if (MinMaxC.isMaxValue()) - if (Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT) - return Cmp1; - - // (X != MIN) && (X > Y) --> X > Y - // (X == MIN) || (X <= Y) --> X <= Y - if (MinMaxC.isMinValue()) - if (Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_UGT) - return Cmp1; - - return nullptr; -} - /// Try to simplify and/or of icmp with ctpop intrinsic. static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd) { @@ -1848,12 +1750,6 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1, if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true)) return X; - if (Value *X = simplifyAndOrOfICmpsWithLimitConst(Op0, Op1, true)) - return X; - - if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) - return X; - if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, true)) return X; if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, true)) @@ -1924,12 +1820,6 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1, if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false)) return X; - if (Value *X = simplifyAndOrOfICmpsWithLimitConst(Op0, Op1, false)) - return X; - - if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) - return X; - if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, false)) return X; if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, false)) @@ -2019,7 +1909,60 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, Value *Op0, // If we looked through casts, we can only handle a constant simplification // because we are not allowed to create a cast instruction here. if (auto *C = dyn_cast<Constant>(V)) - return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType()); + return ConstantFoldCastOperand(Cast0->getOpcode(), C, Cast0->getType(), + Q.DL); + + return nullptr; +} + +static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, + const SimplifyQuery &Q, + bool AllowRefinement, + SmallVectorImpl<Instruction *> *DropFlags, + unsigned MaxRecurse); + +static Value *simplifyAndOrWithICmpEq(unsigned Opcode, Value *Op0, Value *Op1, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + assert((Opcode == Instruction::And || Opcode == Instruction::Or) && + "Must be and/or"); + ICmpInst::Predicate Pred; + Value *A, *B; + if (!match(Op0, m_ICmp(Pred, m_Value(A), m_Value(B))) || + !ICmpInst::isEquality(Pred)) + return nullptr; + + auto Simplify = [&](Value *Res) -> Value * { + Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, Res->getType()); + + // and (icmp eq a, b), x implies (a==b) inside x. + // or (icmp ne a, b), x implies (a==b) inside x. + // If x simplifies to true/false, we can simplify the and/or. + if (Pred == + (Opcode == Instruction::And ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE)) { + if (Res == Absorber) + return Absorber; + if (Res == ConstantExpr::getBinOpIdentity(Opcode, Res->getType())) + return Op0; + return nullptr; + } + + // If we have and (icmp ne a, b), x and for a==b we can simplify x to false, + // then we can drop the icmp, as x will already be false in the case where + // the icmp is false. Similar for or and true. + if (Res == Absorber) + return Op1; + return nullptr; + }; + + if (Value *Res = + simplifyWithOpReplaced(Op1, A, B, Q, /* AllowRefinement */ true, + /* DropFlags */ nullptr, MaxRecurse)) + return Simplify(Res); + if (Value *Res = + simplifyWithOpReplaced(Op1, B, A, Q, /* AllowRefinement */ true, + /* DropFlags */ nullptr, MaxRecurse)) + return Simplify(Res); return nullptr; } @@ -2048,6 +1991,58 @@ static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1, return nullptr; } +// Commutative patterns for and that will be tried with both operand orders. +static Value *simplifyAndCommutative(Value *Op0, Value *Op1, + const SimplifyQuery &Q, + unsigned MaxRecurse) { + // ~A & A = 0 + if (match(Op0, m_Not(m_Specific(Op1)))) + return Constant::getNullValue(Op0->getType()); + + // (A | ?) & A = A + if (match(Op0, m_c_Or(m_Specific(Op1), m_Value()))) + return Op1; + + // (X | ~Y) & (X | Y) --> X + Value *X, *Y; + if (match(Op0, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) && + match(Op1, m_c_Or(m_Deferred(X), m_Deferred(Y)))) + return X; + + // If we have a multiplication overflow check that is being 'and'ed with a + // check that one of the multipliers is not zero, we can omit the 'and', and + // only keep the overflow check. + if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true)) + return Op1; + + // -A & A = A if A is a power of two or zero. + if (match(Op0, m_Neg(m_Specific(Op1))) && + isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Op1; + + // This is a similar pattern used for checking if a value is a power-of-2: + // (A - 1) & A --> 0 (if A is a power-of-2 or 0) + if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) + return Constant::getNullValue(Op1->getType()); + + // (x << N) & ((x << M) - 1) --> 0, where x is known to be a power of 2 and + // M <= N. + const APInt *Shift1, *Shift2; + if (match(Op0, m_Shl(m_Value(X), m_APInt(Shift1))) && + match(Op1, m_Add(m_Shl(m_Specific(X), m_APInt(Shift2)), m_AllOnes())) && + isKnownToBeAPowerOfTwo(X, Q.DL, /*OrZero*/ true, /*Depth*/ 0, Q.AC, + Q.CxtI) && + Shift1->uge(*Shift2)) + return Constant::getNullValue(Op0->getType()); + + if (Value *V = + simplifyAndOrWithICmpEq(Instruction::And, Op0, Op1, Q, MaxRecurse)) + return V; + + return nullptr; +} + /// Given operands for an And, see if we can fold the result. /// If not, this returns null. static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, @@ -2075,26 +2070,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (match(Op1, m_AllOnes())) return Op0; - // A & ~A = ~A & A = 0 - if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0)))) - return Constant::getNullValue(Op0->getType()); - - // (A | ?) & A = A - if (match(Op0, m_c_Or(m_Specific(Op1), m_Value()))) - return Op1; - - // A & (A | ?) = A - if (match(Op1, m_c_Or(m_Specific(Op0), m_Value()))) - return Op0; - - // (X | Y) & (X | ~Y) --> X (commuted 8 ways) - Value *X, *Y; - if (match(Op0, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) && - match(Op1, m_c_Or(m_Deferred(X), m_Deferred(Y)))) - return X; - if (match(Op1, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) && - match(Op0, m_c_Or(m_Deferred(X), m_Deferred(Y)))) - return X; + if (Value *Res = simplifyAndCommutative(Op0, Op1, Q, MaxRecurse)) + return Res; + if (Value *Res = simplifyAndCommutative(Op1, Op0, Q, MaxRecurse)) + return Res; if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::And)) return V; @@ -2102,6 +2081,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // A mask that only clears known zeros of a shifted value is a no-op. const APInt *Mask; const APInt *ShAmt; + Value *X, *Y; if (match(Op1, m_APInt(Mask))) { // If all bits in the inverted and shifted mask are clear: // and (shl X, ShAmt), Mask --> shl X, ShAmt @@ -2116,35 +2096,19 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, return Op0; } - // If we have a multiplication overflow check that is being 'and'ed with a - // check that one of the multipliers is not zero, we can omit the 'and', and - // only keep the overflow check. - if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true)) - return Op1; - if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true)) - return Op0; - - // A & (-A) = A if A is a power of two or zero. - if (match(Op0, m_Neg(m_Specific(Op1))) || - match(Op1, m_Neg(m_Specific(Op0)))) { - if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, - Q.DT)) - return Op0; - if (isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, - Q.DT)) - return Op1; + // and 2^x-1, 2^C --> 0 where x <= C. + const APInt *PowerC; + Value *Shift; + if (match(Op1, m_Power2(PowerC)) && + match(Op0, m_Add(m_Value(Shift), m_AllOnes())) && + isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI, + Q.DT)) { + KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q); + // Use getActiveBits() to make use of the additional power of two knowledge + if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits()) + return ConstantInt::getNullValue(Op1->getType()); } - // This is a similar pattern used for checking if a value is a power-of-2: - // (A - 1) & A --> 0 (if A is a power-of-2 or 0) - // A & (A - 1) --> 0 (if A is a power-of-2 or 0) - if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) && - isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) - return Constant::getNullValue(Op1->getType()); - if (match(Op1, m_Add(m_Specific(Op0), m_AllOnes())) && - isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT)) - return Constant::getNullValue(Op0->getType()); - if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true)) return V; @@ -2197,16 +2161,16 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, // SimplifyDemandedBits in InstCombine can optimize the general case. // This pattern aims to help other passes for a common case. Value *XShifted; - if (match(Op1, m_APInt(Mask)) && + if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(Mask)) && match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)), m_Value(XShifted)), m_Value(Y)))) { const unsigned Width = Op0->getType()->getScalarSizeInBits(); const unsigned ShftCnt = ShAmt->getLimitedValue(Width); - const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q); const unsigned EffWidthY = YKnown.countMaxActiveBits(); if (EffWidthY <= ShftCnt) { - const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q); const unsigned EffWidthX = XKnown.countMaxActiveBits(); const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY); const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt; @@ -2421,6 +2385,13 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op0, m_LShr(m_Specific(X), m_Specific(Y)))) return Op1; + if (Value *V = + simplifyAndOrWithICmpEq(Instruction::Or, Op0, Op1, Q, MaxRecurse)) + return V; + if (Value *V = + simplifyAndOrWithICmpEq(Instruction::Or, Op1, Op0, Q, MaxRecurse)) + return V; + if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false)) return V; @@ -2472,13 +2443,13 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, if (C2->isMask() && // C2 == 0+1+ match(A, m_c_Add(m_Specific(B), m_Value(N)))) { // Add commutes, try both ways. - if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C2, Q)) return A; } // Or commutes, try both ways. if (C1->isMask() && match(B, m_c_Add(m_Specific(A), m_Value(N)))) { // Add commutes, try both ways. - if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (MaskedValueIsZero(N, *C1, Q)) return B; } } @@ -2722,13 +2693,6 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS, const TargetLibraryInfo *TLI = Q.TLI; const DominatorTree *DT = Q.DT; const Instruction *CxtI = Q.CxtI; - const InstrInfoQuery &IIQ = Q.IIQ; - - // A non-null pointer is not equal to a null pointer. - if (isa<ConstantPointerNull>(RHS) && ICmpInst::isEquality(Pred) && - llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr, - IIQ.UseInstrInfo)) - return ConstantInt::get(getCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred)); // We can only fold certain predicates on pointer comparisons. switch (Pred) { @@ -3002,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return getTrue(ITy); break; case ICmpInst::ICMP_SLT: { - KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q); if (LHSKnown.isNegative()) return getTrue(ITy); if (LHSKnown.isNonNegative()) @@ -3010,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, break; } case ICmpInst::ICMP_SLE: { - KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q); if (LHSKnown.isNegative()) return getTrue(ITy); if (LHSKnown.isNonNegative() && @@ -3019,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, break; } case ICmpInst::ICMP_SGE: { - KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q); if (LHSKnown.isNegative()) return getFalse(ITy); if (LHSKnown.isNonNegative()) @@ -3027,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, break; } case ICmpInst::ICMP_SGT: { - KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q); if (LHSKnown.isNegative()) return getFalse(ITy); if (LHSKnown.isNonNegative() && @@ -3079,7 +3043,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, // (mul nuw/nsw X, MulC) != C --> true (if C is not a multiple of MulC) // (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC) const APInt *MulC; - if (ICmpInst::isEquality(Pred) && + if (IIQ.UseInstrInfo && ICmpInst::isEquality(Pred) && ((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) && *MulC != 0 && C->urem(*MulC) != 0) || (match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) && @@ -3104,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, return getTrue(ITy); if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { - KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); - KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q); + KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q); if (RHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); if (RHSKnown.isNegative() || YKnown.isNonNegative()) @@ -3128,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, break; case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: { - KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q); if (!Known.isNonNegative()) break; [[fallthrough]]; @@ -3139,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, return getFalse(ITy); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: { - KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q); if (!Known.isNonNegative()) break; [[fallthrough]]; @@ -3247,9 +3211,9 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred, // *) C2 < C1 && C1 <= 0. // static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS, - Value *RHS) { + Value *RHS, const InstrInfoQuery &IIQ) { // TODO: only support icmp slt for now. - if (Pred != CmpInst::ICMP_SLT) + if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo) return false; // Canonicalize nsw add as RHS. @@ -3318,7 +3282,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow. bool CanSimplify = (NoLHSWrapProblem && NoRHSWrapProblem) || - trySimplifyICmpWithAdds(Pred, LHS, RHS); + trySimplifyICmpWithAdds(Pred, LHS, RHS, Q.IIQ); if (A && C && (A == C || A == D || B == C || B == D) && CanSimplify) { // Determine Y and Z in the form icmp (X+Y), (X+Z). Value *Y, *Z; @@ -3397,10 +3361,10 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, } } - // TODO: This is overly constrained. LHS can be any power-of-2. - // (1 << X) >u 0x8000 --> false - // (1 << X) <=u 0x8000 --> true - if (match(LHS, m_Shl(m_One(), m_Value())) && match(RHS, m_SignMask())) { + // If C is a power-of-2: + // (C << X) >u 0x8000 --> false + // (C << X) <=u 0x8000 --> true + if (match(LHS, m_Shl(m_Power2(), m_Value())) && match(RHS, m_SignMask())) { if (Pred == ICmpInst::ICMP_UGT) return ConstantInt::getFalse(getCompareTy(RHS)); if (Pred == ICmpInst::ICMP_ULE) @@ -3414,7 +3378,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, switch (LBO->getOpcode()) { default: break; - case Instruction::Shl: + case Instruction::Shl: { bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO); bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO); if (!NUW || (ICmpInst::isSigned(Pred) && !NSW) || @@ -3423,6 +3387,38 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(1), RBO->getOperand(1), Q, MaxRecurse - 1)) return V; + break; + } + // If C1 & C2 == C1, A = X and/or C1, B = X and/or C2: + // icmp ule A, B -> true + // icmp ugt A, B -> false + // icmp sle A, B -> true (C1 and C2 are the same sign) + // icmp sgt A, B -> false (C1 and C2 are the same sign) + case Instruction::And: + case Instruction::Or: { + const APInt *C1, *C2; + if (ICmpInst::isRelational(Pred) && + match(LBO->getOperand(1), m_APInt(C1)) && + match(RBO->getOperand(1), m_APInt(C2))) { + if (!C1->isSubsetOf(*C2)) { + std::swap(C1, C2); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + if (C1->isSubsetOf(*C2)) { + if (Pred == ICmpInst::ICMP_ULE) + return ConstantInt::getTrue(getCompareTy(LHS)); + if (Pred == ICmpInst::ICMP_UGT) + return ConstantInt::getFalse(getCompareTy(LHS)); + if (C1->isNonNegative() == C2->isNonNegative()) { + if (Pred == ICmpInst::ICMP_SLE) + return ConstantInt::getTrue(getCompareTy(LHS)); + if (Pred == ICmpInst::ICMP_SGT) + return ConstantInt::getFalse(getCompareTy(LHS)); + } + } + } + break; + } } } @@ -3831,9 +3827,15 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Compute the constant that would happen if we truncated to SrcTy then // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy); - Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C); + Constant *Trunc = + ConstantFoldCastOperand(Instruction::Trunc, C, SrcTy, Q.DL); + assert(Trunc && "Constant-fold of ImmConstant should not fail"); + Constant *RExt = + ConstantFoldCastOperand(CastInst::ZExt, Trunc, DstTy, Q.DL); + assert(RExt && "Constant-fold of ImmConstant should not fail"); + Constant *AnyEq = + ConstantFoldCompareInstOperands(ICmpInst::ICMP_EQ, RExt, C, Q.DL); + assert(AnyEq && "Constant-fold of ImmConstant should not fail"); // If the re-extended constant didn't change any of the elements then // this is effectively also a case of comparing two zero-extended @@ -3864,12 +3866,14 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // is non-negative then LHS <s RHS. case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_SGE: - return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C, - Constant::getNullValue(C->getType())); + return ConstantFoldCompareInstOperands( + ICmpInst::ICMP_SLT, C, Constant::getNullValue(C->getType()), + Q.DL); case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_SLE: - return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C, - Constant::getNullValue(C->getType())); + return ConstantFoldCompareInstOperands( + ICmpInst::ICMP_SGE, C, Constant::getNullValue(C->getType()), + Q.DL); } } } @@ -3897,14 +3901,19 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS, // Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended // too. If not, then try to deduce the result of the comparison. else if (match(RHS, m_ImmConstant())) { - Constant *C = dyn_cast<Constant>(RHS); - assert(C != nullptr); + Constant *C = cast<Constant>(RHS); // Compute the constant that would happen if we truncated to SrcTy then // reextended to DstTy. - Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy); - Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy); - Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C); + Constant *Trunc = + ConstantFoldCastOperand(Instruction::Trunc, C, SrcTy, Q.DL); + assert(Trunc && "Constant-fold of ImmConstant should not fail"); + Constant *RExt = + ConstantFoldCastOperand(CastInst::SExt, Trunc, DstTy, Q.DL); + assert(RExt && "Constant-fold of ImmConstant should not fail"); + Constant *AnyEq = + ConstantFoldCompareInstOperands(ICmpInst::ICMP_EQ, RExt, C, Q.DL); + assert(AnyEq && "Constant-fold of ImmConstant should not fail"); // If the re-extended constant didn't change then this is effectively // also a case of comparing two sign-extended values. @@ -4047,19 +4056,6 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, if (Pred == FCmpInst::FCMP_TRUE) return getTrue(RetTy); - // Fold (un)ordered comparison if we can determine there are no NaNs. - if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD) - if (FMF.noNaNs() || - (isKnownNeverNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) && - isKnownNeverNaN(RHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))) - return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD); - - // NaN is unordered; NaN is not ordered. - assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) && - "Comparison must be either ordered or unordered"); - if (match(RHS, m_NaN())) - return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); - // fcmp pred x, poison and fcmp pred poison, x // fold to poison if (isa<PoisonValue>(LHS) || isa<PoisonValue>(RHS)) @@ -4081,80 +4077,88 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, return getFalse(RetTy); } - // Handle fcmp with constant RHS. - // TODO: Use match with a specific FP value, so these work with vectors with - // undef lanes. - const APFloat *C; - if (match(RHS, m_APFloat(C))) { - // Check whether the constant is an infinity. - if (C->isInfinity()) { - if (C->isNegative()) { - switch (Pred) { - case FCmpInst::FCMP_OLT: - // No value is ordered and less than negative infinity. - return getFalse(RetTy); - case FCmpInst::FCMP_UGE: - // All values are unordered with or at least negative infinity. - return getTrue(RetTy); - default: - break; - } - } else { - switch (Pred) { - case FCmpInst::FCMP_OGT: - // No value is ordered and greater than infinity. - return getFalse(RetTy); - case FCmpInst::FCMP_ULE: - // All values are unordered with and at most infinity. - return getTrue(RetTy); - default: - break; - } - } + // Fold (un)ordered comparison if we can determine there are no NaNs. + // + // This catches the 2 variable input case, constants are handled below as a + // class-like compare. + if (Pred == FCmpInst::FCMP_ORD || Pred == FCmpInst::FCMP_UNO) { + if (FMF.noNaNs() || + (isKnownNeverNaN(RHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) && + isKnownNeverNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))) + return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD); + } - // LHS == Inf - if (Pred == FCmpInst::FCMP_OEQ && - isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) - return getFalse(RetTy); - // LHS != Inf - if (Pred == FCmpInst::FCMP_UNE && - isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) - return getTrue(RetTy); - // LHS == Inf || LHS == NaN - if (Pred == FCmpInst::FCMP_UEQ && - isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) + const APFloat *C = nullptr; + match(RHS, m_APFloatAllowUndef(C)); + std::optional<KnownFPClass> FullKnownClassLHS; + + // Lazily compute the possible classes for LHS. Avoid computing it twice if + // RHS is a 0. + auto computeLHSClass = [=, &FullKnownClassLHS](FPClassTest InterestedFlags = + fcAllFlags) { + if (FullKnownClassLHS) + return *FullKnownClassLHS; + return computeKnownFPClass(LHS, FMF, Q.DL, InterestedFlags, 0, Q.TLI, Q.AC, + Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo); + }; + + if (C && Q.CxtI) { + // Fold out compares that express a class test. + // + // FIXME: Should be able to perform folds without context + // instruction. Always pass in the context function? + + const Function *ParentF = Q.CxtI->getFunction(); + auto [ClassVal, ClassTest] = fcmpToClassTest(Pred, *ParentF, LHS, C); + if (ClassVal) { + FullKnownClassLHS = computeLHSClass(); + if ((FullKnownClassLHS->KnownFPClasses & ClassTest) == fcNone) return getFalse(RetTy); - // LHS != Inf && LHS != NaN - if (Pred == FCmpInst::FCMP_ONE && - isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) + if ((FullKnownClassLHS->KnownFPClasses & ~ClassTest) == fcNone) return getTrue(RetTy); } + } + + // Handle fcmp with constant RHS. + if (C) { + // TODO: If we always required a context function, we wouldn't need to + // special case nans. + if (C->isNaN()) + return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); + + // TODO: Need version fcmpToClassTest which returns implied class when the + // compare isn't a complete class test. e.g. > 1.0 implies fcPositive, but + // isn't implementable as a class call. if (C->isNegative() && !C->isNegZero()) { - assert(!C->isNaN() && "Unexpected NaN constant!"); + FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask; + // TODO: We can catch more cases by using a range check rather than // relying on CannotBeOrderedLessThanZero. switch (Pred) { case FCmpInst::FCMP_UGE: case FCmpInst::FCMP_UGT: - case FCmpInst::FCMP_UNE: + case FCmpInst::FCMP_UNE: { + KnownFPClass KnownClass = computeLHSClass(Interested); + // (X >= 0) implies (X > C) when (C < 0) - if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, - Q.AC, Q.CxtI, Q.DT)) + if (KnownClass.cannotBeOrderedLessThanZero()) return getTrue(RetTy); break; + } case FCmpInst::FCMP_OEQ: case FCmpInst::FCMP_OLE: - case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_OLT: { + KnownFPClass KnownClass = computeLHSClass(Interested); + // (X >= 0) implies !(X < C) when (C < 0) - if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, - Q.DT)) + if (KnownClass.cannotBeOrderedLessThanZero()) return getFalse(RetTy); break; + } default: break; } } - // Check comparison of [minnum/maxnum with constant] with other constant. const APFloat *C2; if ((match(LHS, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_APFloat(C2))) && @@ -4201,13 +4205,17 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, } } + // TODO: Could fold this with above if there were a matcher which returned all + // classes in a non-splat vector. if (match(RHS, m_AnyZeroFP())) { switch (Pred) { case FCmpInst::FCMP_OGE: case FCmpInst::FCMP_ULT: { - FPClassTest Interested = FMF.noNaNs() ? fcNegative : fcNegative | fcNan; - KnownFPClass Known = computeKnownFPClass(LHS, Q.DL, Interested, 0, - Q.TLI, Q.AC, Q.CxtI, Q.DT); + FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask; + if (!FMF.noNaNs()) + Interested |= fcNan; + + KnownFPClass Known = computeLHSClass(Interested); // Positive or zero X >= 0.0 --> true // Positive or zero X < 0.0 --> false @@ -4217,12 +4225,16 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, break; } case FCmpInst::FCMP_UGE: - case FCmpInst::FCMP_OLT: + case FCmpInst::FCMP_OLT: { + FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask; + KnownFPClass Known = computeLHSClass(Interested); + // Positive or zero or nan X >= 0.0 --> true // Positive or zero or nan X < 0.0 --> false - if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)) + if (Known.cannotBeOrderedLessThanZero()) return Pred == FCmpInst::FCMP_UGE ? getTrue(RetTy) : getFalse(RetTy); break; + } default: break; } @@ -4251,6 +4263,7 @@ Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS, static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, bool AllowRefinement, + SmallVectorImpl<Instruction *> *DropFlags, unsigned MaxRecurse) { // Trivial replacement. if (V == Op) @@ -4280,12 +4293,16 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, return nullptr; } + // Don't fold away llvm.is.constant checks based on assumptions. + if (match(I, m_Intrinsic<Intrinsic::is_constant>())) + return nullptr; + // Replace Op with RepOp in instruction operands. SmallVector<Value *, 8> NewOps; bool AnyReplaced = false; for (Value *InstOp : I->operands()) { if (Value *NewInstOp = simplifyWithOpReplaced( - InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) { + InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) { NewOps.push_back(NewInstOp); AnyReplaced = InstOp != NewInstOp; } else { @@ -4312,8 +4329,17 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, // x & x -> x, x | x -> x if ((Opcode == Instruction::And || Opcode == Instruction::Or) && - NewOps[0] == NewOps[1]) + NewOps[0] == NewOps[1]) { + // or disjoint x, x results in poison. + if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BO)) { + if (PDI->isDisjoint()) { + if (!DropFlags) + return nullptr; + DropFlags->push_back(BO); + } + } return NewOps[0]; + } // x - x -> 0, x ^ x -> 0. This is non-refining, because x is non-poison // by assumption and this case never wraps, so nowrap flags can be @@ -4379,16 +4405,30 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, // will be done in InstCombine). // TODO: This may be unsound, because it only catches some forms of // refinement. - if (!AllowRefinement && canCreatePoison(cast<Operator>(I))) - return nullptr; + if (!AllowRefinement) { + if (canCreatePoison(cast<Operator>(I), !DropFlags)) { + // abs cannot create poison if the value is known to never be int_min. + if (auto *II = dyn_cast<IntrinsicInst>(I); + II && II->getIntrinsicID() == Intrinsic::abs) { + if (!ConstOps[0]->isNotMinSignedValue()) + return nullptr; + } else + return nullptr; + } + Constant *Res = ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); + if (DropFlags && Res && I->hasPoisonGeneratingFlagsOrMetadata()) + DropFlags->push_back(I); + return Res; + } return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI); } Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp, const SimplifyQuery &Q, - bool AllowRefinement) { - return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, + bool AllowRefinement, + SmallVectorImpl<Instruction *> *DropFlags) { + return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, DropFlags, RecursionLimit); } @@ -4414,14 +4454,22 @@ static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X, // (X & Y) == 0 ? X | Y : X --> X | Y // (X & Y) != 0 ? X | Y : X --> X if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) + *Y == *C) { + // We can't return the or if it has the disjoint flag. + if (TrueWhenUnset && cast<PossiblyDisjointInst>(TrueVal)->isDisjoint()) + return nullptr; return TrueWhenUnset ? TrueVal : FalseVal; + } // (X & Y) == 0 ? X : X | Y --> X // (X & Y) != 0 ? X : X | Y --> X | Y if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) && - *Y == *C) + *Y == *C) { + // We can't return the or if it has the disjoint flag. + if (!TrueWhenUnset && cast<PossiblyDisjointInst>(FalseVal)->isDisjoint()) + return nullptr; return TrueWhenUnset ? TrueVal : FalseVal; + } } return nullptr; @@ -4521,11 +4569,11 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS, unsigned MaxRecurse) { if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ false, - MaxRecurse) == TrueVal) + /* DropFlags */ nullptr, MaxRecurse) == TrueVal) return FalseVal; if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q, /* AllowRefinement */ true, - MaxRecurse) == FalseVal) + /* DropFlags */ nullptr, MaxRecurse) == FalseVal) return FalseVal; return nullptr; @@ -4888,10 +4936,8 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, // Compute the (pointer) type returned by the GEP instruction. Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Indices); - Type *GEPTy = PointerType::get(LastType, AS); - if (VectorType *VT = dyn_cast<VectorType>(Ptr->getType())) - GEPTy = VectorType::get(GEPTy, VT->getElementCount()); - else { + Type *GEPTy = Ptr->getType(); + if (!GEPTy->isVectorTy()) { for (Value *Op : Indices) { // If one of the operands is a vector, the result type is a vector of // pointers. All vector operands must have the same number of elements. @@ -4918,15 +4964,11 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr, return UndefValue::get(GEPTy); bool IsScalableVec = - isa<ScalableVectorType>(SrcTy) || any_of(Indices, [](const Value *V) { + SrcTy->isScalableTy() || any_of(Indices, [](const Value *V) { return isa<ScalableVectorType>(V->getType()); }); if (Indices.size() == 1) { - // getelementptr P, 0 -> P. - if (match(Indices[0], m_Zero()) && Ptr->getType() == GEPTy) - return Ptr; - Type *Ty = SrcTy; if (!IsScalableVec && Ty->isSized()) { Value *P; @@ -6034,23 +6076,18 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset, if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL)) return nullptr; - Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext()); Type *Int32Ty = Type::getInt32Ty(Ptr->getContext()); - Type *Int32PtrTy = Int32Ty->getPointerTo(); - Type *Int64Ty = Type::getInt64Ty(Ptr->getContext()); auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset); if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64) return nullptr; - uint64_t OffsetInt = OffsetConstInt->getSExtValue(); - if (OffsetInt % 4 != 0) + APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc( + DL.getIndexTypeSizeInBits(Ptr->getType())); + if (OffsetInt.srem(4) != 0) return nullptr; - Constant *C = ConstantExpr::getGetElementPtr( - Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy), - ConstantInt::get(Int64Ty, OffsetInt / 4)); - Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL); + Constant *Loaded = ConstantFoldLoadFromConstPtr(Ptr, Int32Ty, OffsetInt, DL); if (!Loaded) return nullptr; @@ -6080,11 +6117,62 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset, PtrSym != LoadedRHSSym || PtrOffset != LoadedRHSOffset) return nullptr; - return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy); + return LoadedLHSPtr; +} + +// TODO: Need to pass in FastMathFlags +static Value *simplifyLdexp(Value *Op0, Value *Op1, const SimplifyQuery &Q, + bool IsStrict) { + // ldexp(poison, x) -> poison + // ldexp(x, poison) -> poison + if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1)) + return Op0; + + // ldexp(undef, x) -> nan + if (Q.isUndefValue(Op0)) + return ConstantFP::getNaN(Op0->getType()); + + if (!IsStrict) { + // TODO: Could insert a canonicalize for strict + + // ldexp(x, undef) -> x + if (Q.isUndefValue(Op1)) + return Op0; + } + + const APFloat *C = nullptr; + match(Op0, PatternMatch::m_APFloat(C)); + + // These cases should be safe, even with strictfp. + // ldexp(0.0, x) -> 0.0 + // ldexp(-0.0, x) -> -0.0 + // ldexp(inf, x) -> inf + // ldexp(-inf, x) -> -inf + if (C && (C->isZero() || C->isInfinity())) + return Op0; + + // These are canonicalization dropping, could do it if we knew how we could + // ignore denormal flushes and target handling of nan payload bits. + if (IsStrict) + return nullptr; + + // TODO: Could quiet this with strictfp if the exception mode isn't strict. + if (C && C->isNaN()) + return ConstantFP::get(Op0->getType(), C->makeQuiet()); + + // ldexp(x, 0) -> x + + // TODO: Could fold this if we know the exception mode isn't + // strict, we know the denormal mode and other target modes. + if (match(Op1, PatternMatch::m_ZeroInt())) + return Op0; + + return nullptr; } static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, - const SimplifyQuery &Q) { + const SimplifyQuery &Q, + const CallBase *Call) { // Idempotent functions return the same result when called repeatedly. Intrinsic::ID IID = F->getIntrinsicID(); if (isIdempotent(IID)) @@ -6129,31 +6217,37 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, // ctpop(and X, 1) --> and X, 1 unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); if (MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, BitWidth - 1), - Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + Q)) return Op0; break; } case Intrinsic::exp: // exp(log(x)) -> x - if (Q.CxtI->hasAllowReassoc() && + if (Call->hasAllowReassoc() && match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X)))) return X; break; case Intrinsic::exp2: // exp2(log2(x)) -> x - if (Q.CxtI->hasAllowReassoc() && + if (Call->hasAllowReassoc() && match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) return X; break; + case Intrinsic::exp10: + // exp10(log10(x)) -> x + if (Call->hasAllowReassoc() && + match(Op0, m_Intrinsic<Intrinsic::log10>(m_Value(X)))) + return X; + break; case Intrinsic::log: // log(exp(x)) -> x - if (Q.CxtI->hasAllowReassoc() && + if (Call->hasAllowReassoc() && match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) return X; break; case Intrinsic::log2: // log2(exp2(x)) -> x - if (Q.CxtI->hasAllowReassoc() && + if (Call->hasAllowReassoc() && (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) || match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), m_Value(X))))) @@ -6161,8 +6255,11 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0, break; case Intrinsic::log10: // log10(pow(10.0, x)) -> x - if (Q.CxtI->hasAllowReassoc() && - match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X)))) + // log10(exp10(x)) -> x + if (Call->hasAllowReassoc() && + (match(Op0, m_Intrinsic<Intrinsic::exp10>(m_Value(X))) || + match(Op0, + m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X))))) return X; break; case Intrinsic::experimental_vector_reverse: @@ -6260,7 +6357,8 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0, } static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, - const SimplifyQuery &Q) { + const SimplifyQuery &Q, + const CallBase *Call) { Intrinsic::ID IID = F->getIntrinsicID(); Type *ReturnType = F->getReturnType(); unsigned BitWidth = ReturnType->getScalarSizeInBits(); @@ -6287,6 +6385,44 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return Constant::getNullValue(ReturnType); break; } + case Intrinsic::ptrmask: { + if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1)) + return PoisonValue::get(Op0->getType()); + + // NOTE: We can't apply this simplifications based on the value of Op1 + // because we need to preserve provenance. + if (Q.isUndefValue(Op0) || match(Op0, m_Zero())) + return Constant::getNullValue(Op0->getType()); + + assert(Op1->getType()->getScalarSizeInBits() == + Q.DL.getIndexTypeSizeInBits(Op0->getType()) && + "Invalid mask width"); + // If index-width (mask size) is less than pointer-size then mask is + // 1-extended. + if (match(Op1, m_PtrToInt(m_Specific(Op0)))) + return Op0; + + // NOTE: We may have attributes associated with the return value of the + // llvm.ptrmask intrinsic that will be lost when we just return the + // operand. We should try to preserve them. + if (match(Op1, m_AllOnes()) || Q.isUndefValue(Op1)) + return Op0; + + Constant *C; + if (match(Op1, m_ImmConstant(C))) { + KnownBits PtrKnown = computeKnownBits(Op0, /*Depth=*/0, Q); + // See if we only masking off bits we know are already zero due to + // alignment. + APInt IrrelevantPtrBits = + PtrKnown.Zero.zextOrTrunc(C->getType()->getScalarSizeInBits()); + C = ConstantFoldBinaryOpOperands( + Instruction::Or, C, ConstantInt::get(C->getType(), IrrelevantPtrBits), + Q.DL); + if (C != nullptr && C->isAllOnesValue()) + return Op0; + } + break; + } case Intrinsic::smax: case Intrinsic::smin: case Intrinsic::umax: @@ -6426,6 +6562,8 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, return Op0; } break; + case Intrinsic::ldexp: + return simplifyLdexp(Op0, Op1, Q, false); case Intrinsic::copysign: // copysign X, X --> X if (Op0 == Op1) @@ -6480,19 +6618,19 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1, // float, if the ninf flag is set. const APFloat *C; if (match(Op1, m_APFloat(C)) && - (C->isInfinity() || (Q.CxtI->hasNoInfs() && C->isLargest()))) { + (C->isInfinity() || (Call->hasNoInfs() && C->isLargest()))) { // minnum(X, -inf) -> -inf // maxnum(X, +inf) -> +inf // minimum(X, -inf) -> -inf if nnan // maximum(X, +inf) -> +inf if nnan - if (C->isNegative() == IsMin && (!PropagateNaN || Q.CxtI->hasNoNaNs())) + if (C->isNegative() == IsMin && (!PropagateNaN || Call->hasNoNaNs())) return ConstantFP::get(ReturnType, *C); // minnum(X, +inf) -> X if nnan // maxnum(X, -inf) -> X if nnan // minimum(X, +inf) -> X // maximum(X, -inf) -> X - if (C->isNegative() != IsMin && (PropagateNaN || Q.CxtI->hasNoNaNs())) + if (C->isNegative() != IsMin && (PropagateNaN || Call->hasNoNaNs())) return Op0; } @@ -6539,13 +6677,10 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, if (!NumOperands) { switch (IID) { case Intrinsic::vscale: { - auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange); - if (!Attr.isValid()) - return nullptr; - unsigned VScaleMin = Attr.getVScaleRangeMin(); - std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax(); - if (VScaleMax && VScaleMin == VScaleMax) - return ConstantInt::get(F->getReturnType(), VScaleMin); + Type *RetTy = F->getReturnType(); + ConstantRange CR = getVScaleRange(Call->getFunction(), 64); + if (const APInt *C = CR.getSingleElement()) + return ConstantInt::get(RetTy, C->getZExtValue()); return nullptr; } default: @@ -6554,10 +6689,10 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, } if (NumOperands == 1) - return simplifyUnaryIntrinsic(F, Args[0], Q); + return simplifyUnaryIntrinsic(F, Args[0], Q, Call); if (NumOperands == 2) - return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q); + return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q, Call); // Handle intrinsics with 3 or more arguments. switch (IID) { @@ -6692,6 +6827,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee, *FPI->getExceptionBehavior(), *FPI->getRoundingMode()); } + case Intrinsic::experimental_constrained_ldexp: + return simplifyLdexp(Args[0], Args[1], Q, true); default: return nullptr; } @@ -6811,6 +6948,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I, const SimplifyQuery &SQ, unsigned MaxRecurse) { assert(I->getFunction() && "instruction should be inserted in a function"); + assert((!SQ.CxtI || SQ.CxtI->getFunction() == I->getFunction()) && + "context instruction should be in the same function"); + const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I); switch (I->getOpcode()) { |
