aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Analysis/InstructionSimplify.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r--llvm/lib/Analysis/InstructionSimplify.cpp802
1 files changed, 471 insertions, 331 deletions
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 0bfea6140ab5..2a45acf63aa2 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -811,7 +811,7 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
if (IsNUW)
return Constant::getNullValue(Op0->getType());
- KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
if (Known.Zero.isMaxSignedValue()) {
// Op1 is either 0 or the minimum signed value. If the sub is NSW, then
// Op1 must be 0 because negating the minimum signed value is undefined.
@@ -895,7 +895,8 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
// Variations on GEP(base, I, ...) - GEP(base, i, ...) -> GEP(null, I-i, ...).
if (match(Op0, m_PtrToInt(m_Value(X))) && match(Op1, m_PtrToInt(m_Value(Y))))
if (Constant *Result = computePointerDifference(Q.DL, X, Y))
- return ConstantExpr::getIntegerCast(Result, Op0->getType(), true);
+ return ConstantFoldIntegerCast(Result, Op0->getType(), /*IsSigned*/ true,
+ Q.DL);
// i1 sub -> xor.
if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
@@ -1062,7 +1063,7 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
// ("computeConstantRangeIncludingKnownBits")?
const APInt *C;
if (match(Y, m_APInt(C)) &&
- computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT).getMaxValue().ult(*C))
+ computeKnownBits(X, /* Depth */ 0, Q).getMaxValue().ult(*C))
return true;
// Try again for any divisor:
@@ -1124,8 +1125,7 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
if (Op0 == Op1)
return IsDiv ? ConstantInt::get(Ty, 1) : Constant::getNullValue(Ty);
-
- KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(Op1, /* Depth */ 0, Q);
// X / 0 -> poison
// X % 0 -> poison
// If the divisor is known to be zero, just return poison. This can happen in
@@ -1194,7 +1194,7 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
// less trailing zeros, then the result must be poison.
const APInt *DivC;
if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countr_zero()) {
- KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownOp0 = computeKnownBits(Op0, /* Depth */ 0, Q);
if (KnownOp0.countMaxTrailingZeros() < DivC->countr_zero())
return PoisonValue::get(Op0->getType());
}
@@ -1354,7 +1354,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// If any bits in the shift amount make that value greater than or equal to
// the number of bits in the type, the shift is undefined.
- KnownBits KnownAmt = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownAmt = computeKnownBits(Op1, /* Depth */ 0, Q);
if (KnownAmt.getMinValue().uge(KnownAmt.getBitWidth()))
return PoisonValue::get(Op0->getType());
@@ -1367,7 +1367,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
// Check for nsw shl leading to a poison value.
if (IsNSW) {
assert(Opcode == Instruction::Shl && "Expected shl for nsw instruction");
- KnownBits KnownVal = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits KnownVal = computeKnownBits(Op0, /* Depth */ 0, Q);
KnownBits KnownShl = KnownBits::shl(KnownVal, KnownAmt);
if (KnownVal.Zero.isSignBitSet())
@@ -1403,8 +1403,7 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
// The low bit cannot be shifted out of an exact shift if it is set.
// TODO: Generalize by counting trailing zeros (see fold for exact division).
if (IsExact) {
- KnownBits Op0Known =
- computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Op0Known = computeKnownBits(Op0, /* Depth */ 0, Q);
if (Op0Known.One[0])
return Op0;
}
@@ -1463,7 +1462,7 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
// (X << A) >> A -> X
Value *X;
- if (match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
+ if (Q.IIQ.UseInstrInfo && match(Op0, m_NUWShl(m_Value(X), m_Specific(Op1))))
return X;
// ((X << A) | Y) >> A -> X if effective width of Y is not larger than A.
@@ -1473,10 +1472,10 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
// optimizers by supporting a simple but common case in InstSimplify.
Value *Y;
const APInt *ShRAmt, *ShLAmt;
- if (match(Op1, m_APInt(ShRAmt)) &&
+ if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(ShRAmt)) &&
match(Op0, m_c_Or(m_NUWShl(m_Value(X), m_APInt(ShLAmt)), m_Value(Y))) &&
*ShRAmt == *ShLAmt) {
- const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (ShRAmt->uge(EffWidthY))
return X;
@@ -1673,43 +1672,6 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,
return nullptr;
}
-static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1,
- bool IsAnd) {
- ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate();
- if (!match(Cmp0->getOperand(1), m_Zero()) ||
- !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1)
- return nullptr;
-
- if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ))
- return nullptr;
-
- // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)".
- Value *X = Cmp0->getOperand(0);
- Value *Y = Cmp1->getOperand(0);
-
- // If one of the compares is a masked version of a (not) null check, then
- // that compare implies the other, so we eliminate the other. Optionally, look
- // through a pointer-to-int cast to match a null check of a pointer type.
-
- // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0
- // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0
- // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0
- // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0
- if (match(Y, m_c_And(m_Specific(X), m_Value())) ||
- match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value())))
- return Cmp1;
-
- // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0
- // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0
- // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0
- // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0
- if (match(X, m_c_And(m_Specific(Y), m_Value())) ||
- match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value())))
- return Cmp0;
-
- return nullptr;
-}
-
static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
const InstrInfoQuery &IIQ) {
// (icmp (add V, C0), C1) & (icmp V, C0)
@@ -1757,66 +1719,6 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
return nullptr;
}
-/// Try to eliminate compares with signed or unsigned min/max constants.
-static Value *simplifyAndOrOfICmpsWithLimitConst(ICmpInst *Cmp0, ICmpInst *Cmp1,
- bool IsAnd) {
- // Canonicalize an equality compare as Cmp0.
- if (Cmp1->isEquality())
- std::swap(Cmp0, Cmp1);
- if (!Cmp0->isEquality())
- return nullptr;
-
- // The non-equality compare must include a common operand (X). Canonicalize
- // the common operand as operand 0 (the predicate is swapped if the common
- // operand was operand 1).
- ICmpInst::Predicate Pred0 = Cmp0->getPredicate();
- Value *X = Cmp0->getOperand(0);
- ICmpInst::Predicate Pred1;
- bool HasNotOp = match(Cmp1, m_c_ICmp(Pred1, m_Not(m_Specific(X)), m_Value()));
- if (!HasNotOp && !match(Cmp1, m_c_ICmp(Pred1, m_Specific(X), m_Value())))
- return nullptr;
- if (ICmpInst::isEquality(Pred1))
- return nullptr;
-
- // The equality compare must be against a constant. Flip bits if we matched
- // a bitwise not. Convert a null pointer constant to an integer zero value.
- APInt MinMaxC;
- const APInt *C;
- if (match(Cmp0->getOperand(1), m_APInt(C)))
- MinMaxC = HasNotOp ? ~*C : *C;
- else if (isa<ConstantPointerNull>(Cmp0->getOperand(1)))
- MinMaxC = APInt::getZero(8);
- else
- return nullptr;
-
- // DeMorganize if this is 'or': P0 || P1 --> !P0 && !P1.
- if (!IsAnd) {
- Pred0 = ICmpInst::getInversePredicate(Pred0);
- Pred1 = ICmpInst::getInversePredicate(Pred1);
- }
-
- // Normalize to unsigned compare and unsigned min/max value.
- // Example for 8-bit: -128 + 128 -> 0; 127 + 128 -> 255
- if (ICmpInst::isSigned(Pred1)) {
- Pred1 = ICmpInst::getUnsignedPredicate(Pred1);
- MinMaxC += APInt::getSignedMinValue(MinMaxC.getBitWidth());
- }
-
- // (X != MAX) && (X < Y) --> X < Y
- // (X == MAX) || (X >= Y) --> X >= Y
- if (MinMaxC.isMaxValue())
- if (Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_ULT)
- return Cmp1;
-
- // (X != MIN) && (X > Y) --> X > Y
- // (X == MIN) || (X <= Y) --> X <= Y
- if (MinMaxC.isMinValue())
- if (Pred0 == ICmpInst::ICMP_NE && Pred1 == ICmpInst::ICMP_UGT)
- return Cmp1;
-
- return nullptr;
-}
-
/// Try to simplify and/or of icmp with ctpop intrinsic.
static Value *simplifyAndOrOfICmpsWithCtpop(ICmpInst *Cmp0, ICmpInst *Cmp1,
bool IsAnd) {
@@ -1848,12 +1750,6 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))
return X;
- if (Value *X = simplifyAndOrOfICmpsWithLimitConst(Op0, Op1, true))
- return X;
-
- if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true))
- return X;
-
if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, true))
return X;
if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, true))
@@ -1924,12 +1820,6 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1,
if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))
return X;
- if (Value *X = simplifyAndOrOfICmpsWithLimitConst(Op0, Op1, false))
- return X;
-
- if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false))
- return X;
-
if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op0, Op1, false))
return X;
if (Value *X = simplifyAndOrOfICmpsWithCtpop(Op1, Op0, false))
@@ -2019,7 +1909,60 @@ static Value *simplifyAndOrOfCmps(const SimplifyQuery &Q, Value *Op0,
// If we looked through casts, we can only handle a constant simplification
// because we are not allowed to create a cast instruction here.
if (auto *C = dyn_cast<Constant>(V))
- return ConstantExpr::getCast(Cast0->getOpcode(), C, Cast0->getType());
+ return ConstantFoldCastOperand(Cast0->getOpcode(), C, Cast0->getType(),
+ Q.DL);
+
+ return nullptr;
+}
+
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
+ const SimplifyQuery &Q,
+ bool AllowRefinement,
+ SmallVectorImpl<Instruction *> *DropFlags,
+ unsigned MaxRecurse);
+
+static Value *simplifyAndOrWithICmpEq(unsigned Opcode, Value *Op0, Value *Op1,
+ const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
+ "Must be and/or");
+ ICmpInst::Predicate Pred;
+ Value *A, *B;
+ if (!match(Op0, m_ICmp(Pred, m_Value(A), m_Value(B))) ||
+ !ICmpInst::isEquality(Pred))
+ return nullptr;
+
+ auto Simplify = [&](Value *Res) -> Value * {
+ Constant *Absorber = ConstantExpr::getBinOpAbsorber(Opcode, Res->getType());
+
+ // and (icmp eq a, b), x implies (a==b) inside x.
+ // or (icmp ne a, b), x implies (a==b) inside x.
+ // If x simplifies to true/false, we can simplify the and/or.
+ if (Pred ==
+ (Opcode == Instruction::And ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE)) {
+ if (Res == Absorber)
+ return Absorber;
+ if (Res == ConstantExpr::getBinOpIdentity(Opcode, Res->getType()))
+ return Op0;
+ return nullptr;
+ }
+
+ // If we have and (icmp ne a, b), x and for a==b we can simplify x to false,
+ // then we can drop the icmp, as x will already be false in the case where
+ // the icmp is false. Similar for or and true.
+ if (Res == Absorber)
+ return Op1;
+ return nullptr;
+ };
+
+ if (Value *Res =
+ simplifyWithOpReplaced(Op1, A, B, Q, /* AllowRefinement */ true,
+ /* DropFlags */ nullptr, MaxRecurse))
+ return Simplify(Res);
+ if (Value *Res =
+ simplifyWithOpReplaced(Op1, B, A, Q, /* AllowRefinement */ true,
+ /* DropFlags */ nullptr, MaxRecurse))
+ return Simplify(Res);
return nullptr;
}
@@ -2048,6 +1991,58 @@ static Value *simplifyLogicOfAddSub(Value *Op0, Value *Op1,
return nullptr;
}
+// Commutative patterns for and that will be tried with both operand orders.
+static Value *simplifyAndCommutative(Value *Op0, Value *Op1,
+ const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
+ // ~A & A = 0
+ if (match(Op0, m_Not(m_Specific(Op1))))
+ return Constant::getNullValue(Op0->getType());
+
+ // (A | ?) & A = A
+ if (match(Op0, m_c_Or(m_Specific(Op1), m_Value())))
+ return Op1;
+
+ // (X | ~Y) & (X | Y) --> X
+ Value *X, *Y;
+ if (match(Op0, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) &&
+ match(Op1, m_c_Or(m_Deferred(X), m_Deferred(Y))))
+ return X;
+
+ // If we have a multiplication overflow check that is being 'and'ed with a
+ // check that one of the multipliers is not zero, we can omit the 'and', and
+ // only keep the overflow check.
+ if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true))
+ return Op1;
+
+ // -A & A = A if A is a power of two or zero.
+ if (match(Op0, m_Neg(m_Specific(Op1))) &&
+ isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
+ return Op1;
+
+ // This is a similar pattern used for checking if a value is a power-of-2:
+ // (A - 1) & A --> 0 (if A is a power-of-2 or 0)
+ if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) &&
+ isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
+ return Constant::getNullValue(Op1->getType());
+
+ // (x << N) & ((x << M) - 1) --> 0, where x is known to be a power of 2 and
+ // M <= N.
+ const APInt *Shift1, *Shift2;
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(Shift1))) &&
+ match(Op1, m_Add(m_Shl(m_Specific(X), m_APInt(Shift2)), m_AllOnes())) &&
+ isKnownToBeAPowerOfTwo(X, Q.DL, /*OrZero*/ true, /*Depth*/ 0, Q.AC,
+ Q.CxtI) &&
+ Shift1->uge(*Shift2))
+ return Constant::getNullValue(Op0->getType());
+
+ if (Value *V =
+ simplifyAndOrWithICmpEq(Instruction::And, Op0, Op1, Q, MaxRecurse))
+ return V;
+
+ return nullptr;
+}
+
/// Given operands for an And, see if we can fold the result.
/// If not, this returns null.
static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
@@ -2075,26 +2070,10 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
if (match(Op1, m_AllOnes()))
return Op0;
- // A & ~A = ~A & A = 0
- if (match(Op0, m_Not(m_Specific(Op1))) || match(Op1, m_Not(m_Specific(Op0))))
- return Constant::getNullValue(Op0->getType());
-
- // (A | ?) & A = A
- if (match(Op0, m_c_Or(m_Specific(Op1), m_Value())))
- return Op1;
-
- // A & (A | ?) = A
- if (match(Op1, m_c_Or(m_Specific(Op0), m_Value())))
- return Op0;
-
- // (X | Y) & (X | ~Y) --> X (commuted 8 ways)
- Value *X, *Y;
- if (match(Op0, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) &&
- match(Op1, m_c_Or(m_Deferred(X), m_Deferred(Y))))
- return X;
- if (match(Op1, m_c_Or(m_Value(X), m_Not(m_Value(Y)))) &&
- match(Op0, m_c_Or(m_Deferred(X), m_Deferred(Y))))
- return X;
+ if (Value *Res = simplifyAndCommutative(Op0, Op1, Q, MaxRecurse))
+ return Res;
+ if (Value *Res = simplifyAndCommutative(Op1, Op0, Q, MaxRecurse))
+ return Res;
if (Value *V = simplifyLogicOfAddSub(Op0, Op1, Instruction::And))
return V;
@@ -2102,6 +2081,7 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// A mask that only clears known zeros of a shifted value is a no-op.
const APInt *Mask;
const APInt *ShAmt;
+ Value *X, *Y;
if (match(Op1, m_APInt(Mask))) {
// If all bits in the inverted and shifted mask are clear:
// and (shl X, ShAmt), Mask --> shl X, ShAmt
@@ -2116,35 +2096,19 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
return Op0;
}
- // If we have a multiplication overflow check that is being 'and'ed with a
- // check that one of the multipliers is not zero, we can omit the 'and', and
- // only keep the overflow check.
- if (isCheckForZeroAndMulWithOverflow(Op0, Op1, true))
- return Op1;
- if (isCheckForZeroAndMulWithOverflow(Op1, Op0, true))
- return Op0;
-
- // A & (-A) = A if A is a power of two or zero.
- if (match(Op0, m_Neg(m_Specific(Op1))) ||
- match(Op1, m_Neg(m_Specific(Op0)))) {
- if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI,
- Q.DT))
- return Op0;
- if (isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI,
- Q.DT))
- return Op1;
+ // and 2^x-1, 2^C --> 0 where x <= C.
+ const APInt *PowerC;
+ Value *Shift;
+ if (match(Op1, m_Power2(PowerC)) &&
+ match(Op0, m_Add(m_Value(Shift), m_AllOnes())) &&
+ isKnownToBeAPowerOfTwo(Shift, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
+ Q.DT)) {
+ KnownBits Known = computeKnownBits(Shift, /* Depth */ 0, Q);
+ // Use getActiveBits() to make use of the additional power of two knowledge
+ if (PowerC->getActiveBits() >= Known.getMaxValue().getActiveBits())
+ return ConstantInt::getNullValue(Op1->getType());
}
- // This is a similar pattern used for checking if a value is a power-of-2:
- // (A - 1) & A --> 0 (if A is a power-of-2 or 0)
- // A & (A - 1) --> 0 (if A is a power-of-2 or 0)
- if (match(Op0, m_Add(m_Specific(Op1), m_AllOnes())) &&
- isKnownToBeAPowerOfTwo(Op1, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
- return Constant::getNullValue(Op1->getType());
- if (match(Op1, m_Add(m_Specific(Op0), m_AllOnes())) &&
- isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ true, 0, Q.AC, Q.CxtI, Q.DT))
- return Constant::getNullValue(Op0->getType());
-
if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, true))
return V;
@@ -2197,16 +2161,16 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// SimplifyDemandedBits in InstCombine can optimize the general case.
// This pattern aims to help other passes for a common case.
Value *XShifted;
- if (match(Op1, m_APInt(Mask)) &&
+ if (Q.IIQ.UseInstrInfo && match(Op1, m_APInt(Mask)) &&
match(Op0, m_c_Or(m_CombineAnd(m_NUWShl(m_Value(X), m_APInt(ShAmt)),
m_Value(XShifted)),
m_Value(Y)))) {
const unsigned Width = Op0->getType()->getScalarSizeInBits();
const unsigned ShftCnt = ShAmt->getLimitedValue(Width);
- const KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
const unsigned EffWidthY = YKnown.countMaxActiveBits();
if (EffWidthY <= ShftCnt) {
- const KnownBits XKnown = computeKnownBits(X, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ const KnownBits XKnown = computeKnownBits(X, /* Depth */ 0, Q);
const unsigned EffWidthX = XKnown.countMaxActiveBits();
const APInt EffBitsY = APInt::getLowBitsSet(Width, EffWidthY);
const APInt EffBitsX = APInt::getLowBitsSet(Width, EffWidthX) << ShftCnt;
@@ -2421,6 +2385,13 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
match(Op0, m_LShr(m_Specific(X), m_Specific(Y))))
return Op1;
+ if (Value *V =
+ simplifyAndOrWithICmpEq(Instruction::Or, Op0, Op1, Q, MaxRecurse))
+ return V;
+ if (Value *V =
+ simplifyAndOrWithICmpEq(Instruction::Or, Op1, Op0, Q, MaxRecurse))
+ return V;
+
if (Value *V = simplifyAndOrOfCmps(Q, Op0, Op1, false))
return V;
@@ -2472,13 +2443,13 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
if (C2->isMask() && // C2 == 0+1+
match(A, m_c_Add(m_Specific(B), m_Value(N)))) {
// Add commutes, try both ways.
- if (MaskedValueIsZero(N, *C2, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ if (MaskedValueIsZero(N, *C2, Q))
return A;
}
// Or commutes, try both ways.
if (C1->isMask() && match(B, m_c_Add(m_Specific(A), m_Value(N)))) {
// Add commutes, try both ways.
- if (MaskedValueIsZero(N, *C1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ if (MaskedValueIsZero(N, *C1, Q))
return B;
}
}
@@ -2722,13 +2693,6 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
const TargetLibraryInfo *TLI = Q.TLI;
const DominatorTree *DT = Q.DT;
const Instruction *CxtI = Q.CxtI;
- const InstrInfoQuery &IIQ = Q.IIQ;
-
- // A non-null pointer is not equal to a null pointer.
- if (isa<ConstantPointerNull>(RHS) && ICmpInst::isEquality(Pred) &&
- llvm::isKnownNonZero(LHS, DL, 0, nullptr, nullptr, nullptr,
- IIQ.UseInstrInfo))
- return ConstantInt::get(getCompareTy(LHS), !CmpInst::isTrueWhenEqual(Pred));
// We can only fold certain predicates on pointer comparisons.
switch (Pred) {
@@ -3002,7 +2966,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
return getTrue(ITy);
break;
case ICmpInst::ICMP_SLT: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative())
@@ -3010,7 +2974,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SLE: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getTrue(ITy);
if (LHSKnown.isNonNegative() &&
@@ -3019,7 +2983,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SGE: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative())
@@ -3027,7 +2991,7 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS,
break;
}
case ICmpInst::ICMP_SGT: {
- KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits LHSKnown = computeKnownBits(LHS, /* Depth */ 0, Q);
if (LHSKnown.isNegative())
return getFalse(ITy);
if (LHSKnown.isNonNegative() &&
@@ -3079,7 +3043,7 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,
// (mul nuw/nsw X, MulC) != C --> true (if C is not a multiple of MulC)
// (mul nuw/nsw X, MulC) == C --> false (if C is not a multiple of MulC)
const APInt *MulC;
- if (ICmpInst::isEquality(Pred) &&
+ if (IIQ.UseInstrInfo && ICmpInst::isEquality(Pred) &&
((match(LHS, m_NUWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
*MulC != 0 && C->urem(*MulC) != 0) ||
(match(LHS, m_NSWMul(m_Value(), m_APIntAllowUndef(MulC))) &&
@@ -3104,8 +3068,8 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getTrue(ITy);
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) {
- KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
- KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits RHSKnown = computeKnownBits(RHS, /* Depth */ 0, Q);
+ KnownBits YKnown = computeKnownBits(Y, /* Depth */ 0, Q);
if (RHSKnown.isNonNegative() && YKnown.isNegative())
return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy);
if (RHSKnown.isNegative() || YKnown.isNonNegative())
@@ -3128,7 +3092,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
break;
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE: {
- KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
@@ -3139,7 +3103,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getFalse(ITy);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE: {
- KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ KnownBits Known = computeKnownBits(RHS, /* Depth */ 0, Q);
if (!Known.isNonNegative())
break;
[[fallthrough]];
@@ -3247,9 +3211,9 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
// *) C2 < C1 && C1 <= 0.
//
static bool trySimplifyICmpWithAdds(CmpInst::Predicate Pred, Value *LHS,
- Value *RHS) {
+ Value *RHS, const InstrInfoQuery &IIQ) {
// TODO: only support icmp slt for now.
- if (Pred != CmpInst::ICMP_SLT)
+ if (Pred != CmpInst::ICMP_SLT || !IIQ.UseInstrInfo)
return false;
// Canonicalize nsw add as RHS.
@@ -3318,7 +3282,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
// icmp (X+Y), (X+Z) -> icmp Y,Z for equalities or if there is no overflow.
bool CanSimplify = (NoLHSWrapProblem && NoRHSWrapProblem) ||
- trySimplifyICmpWithAdds(Pred, LHS, RHS);
+ trySimplifyICmpWithAdds(Pred, LHS, RHS, Q.IIQ);
if (A && C && (A == C || A == D || B == C || B == D) && CanSimplify) {
// Determine Y and Z in the form icmp (X+Y), (X+Z).
Value *Y, *Z;
@@ -3397,10 +3361,10 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
}
}
- // TODO: This is overly constrained. LHS can be any power-of-2.
- // (1 << X) >u 0x8000 --> false
- // (1 << X) <=u 0x8000 --> true
- if (match(LHS, m_Shl(m_One(), m_Value())) && match(RHS, m_SignMask())) {
+ // If C is a power-of-2:
+ // (C << X) >u 0x8000 --> false
+ // (C << X) <=u 0x8000 --> true
+ if (match(LHS, m_Shl(m_Power2(), m_Value())) && match(RHS, m_SignMask())) {
if (Pred == ICmpInst::ICMP_UGT)
return ConstantInt::getFalse(getCompareTy(RHS));
if (Pred == ICmpInst::ICMP_ULE)
@@ -3414,7 +3378,7 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
switch (LBO->getOpcode()) {
default:
break;
- case Instruction::Shl:
+ case Instruction::Shl: {
bool NUW = Q.IIQ.hasNoUnsignedWrap(LBO) && Q.IIQ.hasNoUnsignedWrap(RBO);
bool NSW = Q.IIQ.hasNoSignedWrap(LBO) && Q.IIQ.hasNoSignedWrap(RBO);
if (!NUW || (ICmpInst::isSigned(Pred) && !NSW) ||
@@ -3423,6 +3387,38 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS,
if (Value *V = simplifyICmpInst(Pred, LBO->getOperand(1),
RBO->getOperand(1), Q, MaxRecurse - 1))
return V;
+ break;
+ }
+ // If C1 & C2 == C1, A = X and/or C1, B = X and/or C2:
+ // icmp ule A, B -> true
+ // icmp ugt A, B -> false
+ // icmp sle A, B -> true (C1 and C2 are the same sign)
+ // icmp sgt A, B -> false (C1 and C2 are the same sign)
+ case Instruction::And:
+ case Instruction::Or: {
+ const APInt *C1, *C2;
+ if (ICmpInst::isRelational(Pred) &&
+ match(LBO->getOperand(1), m_APInt(C1)) &&
+ match(RBO->getOperand(1), m_APInt(C2))) {
+ if (!C1->isSubsetOf(*C2)) {
+ std::swap(C1, C2);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+ if (C1->isSubsetOf(*C2)) {
+ if (Pred == ICmpInst::ICMP_ULE)
+ return ConstantInt::getTrue(getCompareTy(LHS));
+ if (Pred == ICmpInst::ICMP_UGT)
+ return ConstantInt::getFalse(getCompareTy(LHS));
+ if (C1->isNonNegative() == C2->isNonNegative()) {
+ if (Pred == ICmpInst::ICMP_SLE)
+ return ConstantInt::getTrue(getCompareTy(LHS));
+ if (Pred == ICmpInst::ICMP_SGT)
+ return ConstantInt::getFalse(getCompareTy(LHS));
+ }
+ }
+ }
+ break;
+ }
}
}
@@ -3831,9 +3827,15 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
// Compute the constant that would happen if we truncated to SrcTy then
// reextended to DstTy.
- Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
- Constant *RExt = ConstantExpr::getCast(CastInst::ZExt, Trunc, DstTy);
- Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
+ Constant *Trunc =
+ ConstantFoldCastOperand(Instruction::Trunc, C, SrcTy, Q.DL);
+ assert(Trunc && "Constant-fold of ImmConstant should not fail");
+ Constant *RExt =
+ ConstantFoldCastOperand(CastInst::ZExt, Trunc, DstTy, Q.DL);
+ assert(RExt && "Constant-fold of ImmConstant should not fail");
+ Constant *AnyEq =
+ ConstantFoldCompareInstOperands(ICmpInst::ICMP_EQ, RExt, C, Q.DL);
+ assert(AnyEq && "Constant-fold of ImmConstant should not fail");
// If the re-extended constant didn't change any of the elements then
// this is effectively also a case of comparing two zero-extended
@@ -3864,12 +3866,14 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
// is non-negative then LHS <s RHS.
case ICmpInst::ICMP_SGT:
case ICmpInst::ICMP_SGE:
- return ConstantExpr::getICmp(ICmpInst::ICMP_SLT, C,
- Constant::getNullValue(C->getType()));
+ return ConstantFoldCompareInstOperands(
+ ICmpInst::ICMP_SLT, C, Constant::getNullValue(C->getType()),
+ Q.DL);
case ICmpInst::ICMP_SLT:
case ICmpInst::ICMP_SLE:
- return ConstantExpr::getICmp(ICmpInst::ICMP_SGE, C,
- Constant::getNullValue(C->getType()));
+ return ConstantFoldCompareInstOperands(
+ ICmpInst::ICMP_SGE, C, Constant::getNullValue(C->getType()),
+ Q.DL);
}
}
}
@@ -3897,14 +3901,19 @@ static Value *simplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,
// Turn icmp (sext X), Cst into a compare of X and Cst if Cst is extended
// too. If not, then try to deduce the result of the comparison.
else if (match(RHS, m_ImmConstant())) {
- Constant *C = dyn_cast<Constant>(RHS);
- assert(C != nullptr);
+ Constant *C = cast<Constant>(RHS);
// Compute the constant that would happen if we truncated to SrcTy then
// reextended to DstTy.
- Constant *Trunc = ConstantExpr::getTrunc(C, SrcTy);
- Constant *RExt = ConstantExpr::getCast(CastInst::SExt, Trunc, DstTy);
- Constant *AnyEq = ConstantExpr::getICmp(ICmpInst::ICMP_EQ, RExt, C);
+ Constant *Trunc =
+ ConstantFoldCastOperand(Instruction::Trunc, C, SrcTy, Q.DL);
+ assert(Trunc && "Constant-fold of ImmConstant should not fail");
+ Constant *RExt =
+ ConstantFoldCastOperand(CastInst::SExt, Trunc, DstTy, Q.DL);
+ assert(RExt && "Constant-fold of ImmConstant should not fail");
+ Constant *AnyEq =
+ ConstantFoldCompareInstOperands(ICmpInst::ICMP_EQ, RExt, C, Q.DL);
+ assert(AnyEq && "Constant-fold of ImmConstant should not fail");
// If the re-extended constant didn't change then this is effectively
// also a case of comparing two sign-extended values.
@@ -4047,19 +4056,6 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
if (Pred == FCmpInst::FCMP_TRUE)
return getTrue(RetTy);
- // Fold (un)ordered comparison if we can determine there are no NaNs.
- if (Pred == FCmpInst::FCMP_UNO || Pred == FCmpInst::FCMP_ORD)
- if (FMF.noNaNs() ||
- (isKnownNeverNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) &&
- isKnownNeverNaN(RHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)))
- return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD);
-
- // NaN is unordered; NaN is not ordered.
- assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) &&
- "Comparison must be either ordered or unordered");
- if (match(RHS, m_NaN()))
- return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
-
// fcmp pred x, poison and fcmp pred poison, x
// fold to poison
if (isa<PoisonValue>(LHS) || isa<PoisonValue>(RHS))
@@ -4081,80 +4077,88 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
return getFalse(RetTy);
}
- // Handle fcmp with constant RHS.
- // TODO: Use match with a specific FP value, so these work with vectors with
- // undef lanes.
- const APFloat *C;
- if (match(RHS, m_APFloat(C))) {
- // Check whether the constant is an infinity.
- if (C->isInfinity()) {
- if (C->isNegative()) {
- switch (Pred) {
- case FCmpInst::FCMP_OLT:
- // No value is ordered and less than negative infinity.
- return getFalse(RetTy);
- case FCmpInst::FCMP_UGE:
- // All values are unordered with or at least negative infinity.
- return getTrue(RetTy);
- default:
- break;
- }
- } else {
- switch (Pred) {
- case FCmpInst::FCMP_OGT:
- // No value is ordered and greater than infinity.
- return getFalse(RetTy);
- case FCmpInst::FCMP_ULE:
- // All values are unordered with and at most infinity.
- return getTrue(RetTy);
- default:
- break;
- }
- }
+ // Fold (un)ordered comparison if we can determine there are no NaNs.
+ //
+ // This catches the 2-variable input case; constants are handled below as a
+ // class-like compare.
+ if (Pred == FCmpInst::FCMP_ORD || Pred == FCmpInst::FCMP_UNO) {
+ if (FMF.noNaNs() ||
+ (isKnownNeverNaN(RHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT) &&
+ isKnownNeverNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT)))
+ return ConstantInt::get(RetTy, Pred == FCmpInst::FCMP_ORD);
+ }
- // LHS == Inf
- if (Pred == FCmpInst::FCMP_OEQ &&
- isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))
- return getFalse(RetTy);
- // LHS != Inf
- if (Pred == FCmpInst::FCMP_UNE &&
- isKnownNeverInfinity(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))
- return getTrue(RetTy);
- // LHS == Inf || LHS == NaN
- if (Pred == FCmpInst::FCMP_UEQ &&
- isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))
+ const APFloat *C = nullptr;
+ match(RHS, m_APFloatAllowUndef(C));
+ std::optional<KnownFPClass> FullKnownClassLHS;
+
+ // Lazily compute the possible classes for LHS. Avoid computing it twice if
+ // RHS is a 0.
+ auto computeLHSClass = [=, &FullKnownClassLHS](FPClassTest InterestedFlags =
+ fcAllFlags) {
+ if (FullKnownClassLHS)
+ return *FullKnownClassLHS;
+ return computeKnownFPClass(LHS, FMF, Q.DL, InterestedFlags, 0, Q.TLI, Q.AC,
+ Q.CxtI, Q.DT, Q.IIQ.UseInstrInfo);
+ };
+
+ if (C && Q.CxtI) {
+ // Fold out compares that express a class test.
+ //
+ // FIXME: Should be able to perform folds without context
+ // instruction. Always pass in the context function?
+
+ const Function *ParentF = Q.CxtI->getFunction();
+ auto [ClassVal, ClassTest] = fcmpToClassTest(Pred, *ParentF, LHS, C);
+ if (ClassVal) {
+ FullKnownClassLHS = computeLHSClass();
+ if ((FullKnownClassLHS->KnownFPClasses & ClassTest) == fcNone)
return getFalse(RetTy);
- // LHS != Inf && LHS != NaN
- if (Pred == FCmpInst::FCMP_ONE &&
- isKnownNeverInfOrNaN(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))
+ if ((FullKnownClassLHS->KnownFPClasses & ~ClassTest) == fcNone)
return getTrue(RetTy);
}
+ }
+
+ // Handle fcmp with constant RHS.
+ if (C) {
+ // TODO: If we always required a context function, we wouldn't need to
+ // special case nans.
+ if (C->isNaN())
+ return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred));
+
+ // TODO: Need a version of fcmpToClassTest which returns the implied class when
+ // the compare isn't a complete class test, e.g. > 1.0 implies fcPositive but
+ // isn't implementable as a class call.
if (C->isNegative() && !C->isNegZero()) {
- assert(!C->isNaN() && "Unexpected NaN constant!");
+ FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask;
+
// TODO: We can catch more cases by using a range check rather than
// relying on CannotBeOrderedLessThanZero.
switch (Pred) {
case FCmpInst::FCMP_UGE:
case FCmpInst::FCMP_UGT:
- case FCmpInst::FCMP_UNE:
+ case FCmpInst::FCMP_UNE: {
+ KnownFPClass KnownClass = computeLHSClass(Interested);
+
// (X >= 0) implies (X > C) when (C < 0)
- if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0,
- Q.AC, Q.CxtI, Q.DT))
+ if (KnownClass.cannotBeOrderedLessThanZero())
return getTrue(RetTy);
break;
+ }
case FCmpInst::FCMP_OEQ:
case FCmpInst::FCMP_OLE:
- case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_OLT: {
+ KnownFPClass KnownClass = computeLHSClass(Interested);
+
// (X >= 0) implies !(X < C) when (C < 0)
- if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI,
- Q.DT))
+ if (KnownClass.cannotBeOrderedLessThanZero())
return getFalse(RetTy);
break;
+ }
default:
break;
}
}
-
// Check comparison of [minnum/maxnum with constant] with other constant.
const APFloat *C2;
if ((match(LHS, m_Intrinsic<Intrinsic::minnum>(m_Value(), m_APFloat(C2))) &&
@@ -4201,13 +4205,17 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
}
}
+ // TODO: Could fold this with above if there were a matcher which returned all
+ // classes in a non-splat vector.
if (match(RHS, m_AnyZeroFP())) {
switch (Pred) {
case FCmpInst::FCMP_OGE:
case FCmpInst::FCMP_ULT: {
- FPClassTest Interested = FMF.noNaNs() ? fcNegative : fcNegative | fcNan;
- KnownFPClass Known = computeKnownFPClass(LHS, Q.DL, Interested, 0,
- Q.TLI, Q.AC, Q.CxtI, Q.DT);
+ FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask;
+ if (!FMF.noNaNs())
+ Interested |= fcNan;
+
+ KnownFPClass Known = computeLHSClass(Interested);
// Positive or zero X >= 0.0 --> true
// Positive or zero X < 0.0 --> false
@@ -4217,12 +4225,16 @@ static Value *simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
break;
}
case FCmpInst::FCMP_UGE:
- case FCmpInst::FCMP_OLT:
+ case FCmpInst::FCMP_OLT: {
+ FPClassTest Interested = KnownFPClass::OrderedLessThanZeroMask;
+ KnownFPClass Known = computeLHSClass(Interested);
+
// Positive or zero or nan X >= 0.0 --> true
// Positive or zero or nan X < 0.0 --> false
- if (cannotBeOrderedLessThanZero(LHS, Q.DL, Q.TLI, 0, Q.AC, Q.CxtI, Q.DT))
+ if (Known.cannotBeOrderedLessThanZero())
return Pred == FCmpInst::FCMP_UGE ? getTrue(RetTy) : getFalse(RetTy);
break;
+ }
default:
break;
}
@@ -4251,6 +4263,7 @@ Value *llvm::simplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,
static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
+ SmallVectorImpl<Instruction *> *DropFlags,
unsigned MaxRecurse) {
// Trivial replacement.
if (V == Op)
@@ -4280,12 +4293,16 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
return nullptr;
}
+ // Don't fold away llvm.is.constant checks based on assumptions.
+ if (match(I, m_Intrinsic<Intrinsic::is_constant>()))
+ return nullptr;
+
// Replace Op with RepOp in instruction operands.
SmallVector<Value *, 8> NewOps;
bool AnyReplaced = false;
for (Value *InstOp : I->operands()) {
if (Value *NewInstOp = simplifyWithOpReplaced(
- InstOp, Op, RepOp, Q, AllowRefinement, MaxRecurse)) {
+ InstOp, Op, RepOp, Q, AllowRefinement, DropFlags, MaxRecurse)) {
NewOps.push_back(NewInstOp);
AnyReplaced = InstOp != NewInstOp;
} else {
@@ -4312,8 +4329,17 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// x & x -> x, x | x -> x
if ((Opcode == Instruction::And || Opcode == Instruction::Or) &&
- NewOps[0] == NewOps[1])
+ NewOps[0] == NewOps[1]) {
+ // or disjoint x, x results in poison.
+ if (auto *PDI = dyn_cast<PossiblyDisjointInst>(BO)) {
+ if (PDI->isDisjoint()) {
+ if (!DropFlags)
+ return nullptr;
+ DropFlags->push_back(BO);
+ }
+ }
return NewOps[0];
+ }
// x - x -> 0, x ^ x -> 0. This is non-refining, because x is non-poison
// by assumption and this case never wraps, so nowrap flags can be
@@ -4379,16 +4405,30 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
// will be done in InstCombine).
// TODO: This may be unsound, because it only catches some forms of
// refinement.
- if (!AllowRefinement && canCreatePoison(cast<Operator>(I)))
- return nullptr;
+ if (!AllowRefinement) {
+ if (canCreatePoison(cast<Operator>(I), !DropFlags)) {
+ // abs cannot create poison if the value is known to never be int_min.
+ if (auto *II = dyn_cast<IntrinsicInst>(I);
+ II && II->getIntrinsicID() == Intrinsic::abs) {
+ if (!ConstOps[0]->isNotMinSignedValue())
+ return nullptr;
+ } else
+ return nullptr;
+ }
+ Constant *Res = ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
+ if (DropFlags && Res && I->hasPoisonGeneratingFlagsOrMetadata())
+ DropFlags->push_back(I);
+ return Res;
+ }
return ConstantFoldInstOperands(I, ConstOps, Q.DL, Q.TLI);
}
Value *llvm::simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
- bool AllowRefinement) {
- return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement,
+ bool AllowRefinement,
+ SmallVectorImpl<Instruction *> *DropFlags) {
+ return ::simplifyWithOpReplaced(V, Op, RepOp, Q, AllowRefinement, DropFlags,
RecursionLimit);
}
@@ -4414,14 +4454,22 @@ static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
// (X & Y) == 0 ? X | Y : X --> X | Y
// (X & Y) != 0 ? X | Y : X --> X
if (FalseVal == X && match(TrueVal, m_Or(m_Specific(X), m_APInt(C))) &&
- *Y == *C)
+ *Y == *C) {
+ // We can't return the or if it has the disjoint flag.
+ if (TrueWhenUnset && cast<PossiblyDisjointInst>(TrueVal)->isDisjoint())
+ return nullptr;
return TrueWhenUnset ? TrueVal : FalseVal;
+ }
// (X & Y) == 0 ? X : X | Y --> X
// (X & Y) != 0 ? X : X | Y --> X | Y
if (TrueVal == X && match(FalseVal, m_Or(m_Specific(X), m_APInt(C))) &&
- *Y == *C)
+ *Y == *C) {
+ // We can't return the or if it has the disjoint flag.
+ if (!TrueWhenUnset && cast<PossiblyDisjointInst>(FalseVal)->isDisjoint())
+ return nullptr;
return TrueWhenUnset ? TrueVal : FalseVal;
+ }
}
return nullptr;
@@ -4521,11 +4569,11 @@ static Value *simplifySelectWithICmpEq(Value *CmpLHS, Value *CmpRHS,
unsigned MaxRecurse) {
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ false,
- MaxRecurse) == TrueVal)
+ /* DropFlags */ nullptr, MaxRecurse) == TrueVal)
return FalseVal;
if (simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ true,
- MaxRecurse) == FalseVal)
+ /* DropFlags */ nullptr, MaxRecurse) == FalseVal)
return FalseVal;
return nullptr;
@@ -4888,10 +4936,8 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
// Compute the (pointer) type returned by the GEP instruction.
Type *LastType = GetElementPtrInst::getIndexedType(SrcTy, Indices);
- Type *GEPTy = PointerType::get(LastType, AS);
- if (VectorType *VT = dyn_cast<VectorType>(Ptr->getType()))
- GEPTy = VectorType::get(GEPTy, VT->getElementCount());
- else {
+ Type *GEPTy = Ptr->getType();
+ if (!GEPTy->isVectorTy()) {
for (Value *Op : Indices) {
// If one of the operands is a vector, the result type is a vector of
// pointers. All vector operands must have the same number of elements.
@@ -4918,15 +4964,11 @@ static Value *simplifyGEPInst(Type *SrcTy, Value *Ptr,
return UndefValue::get(GEPTy);
bool IsScalableVec =
- isa<ScalableVectorType>(SrcTy) || any_of(Indices, [](const Value *V) {
+ SrcTy->isScalableTy() || any_of(Indices, [](const Value *V) {
return isa<ScalableVectorType>(V->getType());
});
if (Indices.size() == 1) {
- // getelementptr P, 0 -> P.
- if (match(Indices[0], m_Zero()) && Ptr->getType() == GEPTy)
- return Ptr;
-
Type *Ty = SrcTy;
if (!IsScalableVec && Ty->isSized()) {
Value *P;
@@ -6034,23 +6076,18 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
if (!IsConstantOffsetFromGlobal(Ptr, PtrSym, PtrOffset, DL))
return nullptr;
- Type *Int8PtrTy = Type::getInt8PtrTy(Ptr->getContext());
Type *Int32Ty = Type::getInt32Ty(Ptr->getContext());
- Type *Int32PtrTy = Int32Ty->getPointerTo();
- Type *Int64Ty = Type::getInt64Ty(Ptr->getContext());
auto *OffsetConstInt = dyn_cast<ConstantInt>(Offset);
if (!OffsetConstInt || OffsetConstInt->getType()->getBitWidth() > 64)
return nullptr;
- uint64_t OffsetInt = OffsetConstInt->getSExtValue();
- if (OffsetInt % 4 != 0)
+ APInt OffsetInt = OffsetConstInt->getValue().sextOrTrunc(
+ DL.getIndexTypeSizeInBits(Ptr->getType()));
+ if (OffsetInt.srem(4) != 0)
return nullptr;
- Constant *C = ConstantExpr::getGetElementPtr(
- Int32Ty, ConstantExpr::getBitCast(Ptr, Int32PtrTy),
- ConstantInt::get(Int64Ty, OffsetInt / 4));
- Constant *Loaded = ConstantFoldLoadFromConstPtr(C, Int32Ty, DL);
+ Constant *Loaded = ConstantFoldLoadFromConstPtr(Ptr, Int32Ty, OffsetInt, DL);
if (!Loaded)
return nullptr;
@@ -6080,11 +6117,62 @@ static Value *simplifyRelativeLoad(Constant *Ptr, Constant *Offset,
PtrSym != LoadedRHSSym || PtrOffset != LoadedRHSOffset)
return nullptr;
- return ConstantExpr::getBitCast(LoadedLHSPtr, Int8PtrTy);
+ return LoadedLHSPtr;
+}
+
+// TODO: Need to pass in FastMathFlags
+static Value *simplifyLdexp(Value *Op0, Value *Op1, const SimplifyQuery &Q,
+ bool IsStrict) {
+ // ldexp(poison, x) -> poison
+ // ldexp(x, poison) -> poison
+ if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1))
+ return Op0;
+
+ // ldexp(undef, x) -> nan
+ if (Q.isUndefValue(Op0))
+ return ConstantFP::getNaN(Op0->getType());
+
+ if (!IsStrict) {
+ // TODO: Could insert a canonicalize for strict
+
+ // ldexp(x, undef) -> x
+ if (Q.isUndefValue(Op1))
+ return Op0;
+ }
+
+ const APFloat *C = nullptr;
+ match(Op0, PatternMatch::m_APFloat(C));
+
+ // These cases should be safe, even with strictfp.
+ // ldexp(0.0, x) -> 0.0
+ // ldexp(-0.0, x) -> -0.0
+ // ldexp(inf, x) -> inf
+ // ldexp(-inf, x) -> -inf
+ if (C && (C->isZero() || C->isInfinity()))
+ return Op0;
+
+ // These folds would drop a canonicalization; we could do them if we knew how
+ // to ignore denormal flushes and target handling of NaN payload bits.
+ if (IsStrict)
+ return nullptr;
+
+ // TODO: Could quiet this with strictfp if the exception mode isn't strict.
+ if (C && C->isNaN())
+ return ConstantFP::get(Op0->getType(), C->makeQuiet());
+
+ // ldexp(x, 0) -> x
+
+ // TODO: Could fold this if we know the exception mode isn't
+ // strict, we know the denormal mode and other target modes.
+ if (match(Op1, PatternMatch::m_ZeroInt()))
+ return Op0;
+
+ return nullptr;
}
static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
- const SimplifyQuery &Q) {
+ const SimplifyQuery &Q,
+ const CallBase *Call) {
// Idempotent functions return the same result when called repeatedly.
Intrinsic::ID IID = F->getIntrinsicID();
if (isIdempotent(IID))
@@ -6129,31 +6217,37 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
// ctpop(and X, 1) --> and X, 1
unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
if (MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, BitWidth - 1),
- Q.DL, 0, Q.AC, Q.CxtI, Q.DT))
+ Q))
return Op0;
break;
}
case Intrinsic::exp:
// exp(log(x)) -> x
- if (Q.CxtI->hasAllowReassoc() &&
+ if (Call->hasAllowReassoc() &&
match(Op0, m_Intrinsic<Intrinsic::log>(m_Value(X))))
return X;
break;
case Intrinsic::exp2:
// exp2(log2(x)) -> x
- if (Q.CxtI->hasAllowReassoc() &&
+ if (Call->hasAllowReassoc() &&
match(Op0, m_Intrinsic<Intrinsic::log2>(m_Value(X))))
return X;
break;
+ case Intrinsic::exp10:
+ // exp10(log10(x)) -> x
+ if (Call->hasAllowReassoc() &&
+ match(Op0, m_Intrinsic<Intrinsic::log10>(m_Value(X))))
+ return X;
+ break;
case Intrinsic::log:
// log(exp(x)) -> x
- if (Q.CxtI->hasAllowReassoc() &&
+ if (Call->hasAllowReassoc() &&
match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))))
return X;
break;
case Intrinsic::log2:
// log2(exp2(x)) -> x
- if (Q.CxtI->hasAllowReassoc() &&
+ if (Call->hasAllowReassoc() &&
(match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) ||
match(Op0,
m_Intrinsic<Intrinsic::pow>(m_SpecificFP(2.0), m_Value(X)))))
@@ -6161,8 +6255,11 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
break;
case Intrinsic::log10:
// log10(pow(10.0, x)) -> x
- if (Q.CxtI->hasAllowReassoc() &&
- match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X))))
+ // log10(exp10(x)) -> x
+ if (Call->hasAllowReassoc() &&
+ (match(Op0, m_Intrinsic<Intrinsic::exp10>(m_Value(X))) ||
+ match(Op0,
+ m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X)))))
return X;
break;
case Intrinsic::experimental_vector_reverse:
@@ -6260,7 +6357,8 @@ static Value *foldMinimumMaximumSharedOp(Intrinsic::ID IID, Value *Op0,
}
static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
- const SimplifyQuery &Q) {
+ const SimplifyQuery &Q,
+ const CallBase *Call) {
Intrinsic::ID IID = F->getIntrinsicID();
Type *ReturnType = F->getReturnType();
unsigned BitWidth = ReturnType->getScalarSizeInBits();
@@ -6287,6 +6385,44 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
return Constant::getNullValue(ReturnType);
break;
}
+ case Intrinsic::ptrmask: {
+ if (isa<PoisonValue>(Op0) || isa<PoisonValue>(Op1))
+ return PoisonValue::get(Op0->getType());
+
+ // NOTE: We can't apply these simplifications based on the value of Op1
+ // because we need to preserve provenance.
+ if (Q.isUndefValue(Op0) || match(Op0, m_Zero()))
+ return Constant::getNullValue(Op0->getType());
+
+ assert(Op1->getType()->getScalarSizeInBits() ==
+ Q.DL.getIndexTypeSizeInBits(Op0->getType()) &&
+ "Invalid mask width");
+ // If index-width (mask size) is less than pointer-size then mask is
+ // 1-extended.
+ if (match(Op1, m_PtrToInt(m_Specific(Op0))))
+ return Op0;
+
+ // NOTE: We may have attributes associated with the return value of the
+ // llvm.ptrmask intrinsic that will be lost when we just return the
+ // operand. We should try to preserve them.
+ if (match(Op1, m_AllOnes()) || Q.isUndefValue(Op1))
+ return Op0;
+
+ Constant *C;
+ if (match(Op1, m_ImmConstant(C))) {
+ KnownBits PtrKnown = computeKnownBits(Op0, /*Depth=*/0, Q);
+ // See if we only masking off bits we know are already zero due to
+ // alignment.
+ APInt IrrelevantPtrBits =
+ PtrKnown.Zero.zextOrTrunc(C->getType()->getScalarSizeInBits());
+ C = ConstantFoldBinaryOpOperands(
+ Instruction::Or, C, ConstantInt::get(C->getType(), IrrelevantPtrBits),
+ Q.DL);
+ if (C != nullptr && C->isAllOnesValue())
+ return Op0;
+ }
+ break;
+ }
case Intrinsic::smax:
case Intrinsic::smin:
case Intrinsic::umax:
@@ -6426,6 +6562,8 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
return Op0;
}
break;
+ case Intrinsic::ldexp:
+ return simplifyLdexp(Op0, Op1, Q, false);
case Intrinsic::copysign:
// copysign X, X --> X
if (Op0 == Op1)
@@ -6480,19 +6618,19 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
// float, if the ninf flag is set.
const APFloat *C;
if (match(Op1, m_APFloat(C)) &&
- (C->isInfinity() || (Q.CxtI->hasNoInfs() && C->isLargest()))) {
+ (C->isInfinity() || (Call->hasNoInfs() && C->isLargest()))) {
// minnum(X, -inf) -> -inf
// maxnum(X, +inf) -> +inf
// minimum(X, -inf) -> -inf if nnan
// maximum(X, +inf) -> +inf if nnan
- if (C->isNegative() == IsMin && (!PropagateNaN || Q.CxtI->hasNoNaNs()))
+ if (C->isNegative() == IsMin && (!PropagateNaN || Call->hasNoNaNs()))
return ConstantFP::get(ReturnType, *C);
// minnum(X, +inf) -> X if nnan
// maxnum(X, -inf) -> X if nnan
// minimum(X, +inf) -> X
// maximum(X, -inf) -> X
- if (C->isNegative() != IsMin && (PropagateNaN || Q.CxtI->hasNoNaNs()))
+ if (C->isNegative() != IsMin && (PropagateNaN || Call->hasNoNaNs()))
return Op0;
}
@@ -6539,13 +6677,10 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
if (!NumOperands) {
switch (IID) {
case Intrinsic::vscale: {
- auto Attr = Call->getFunction()->getFnAttribute(Attribute::VScaleRange);
- if (!Attr.isValid())
- return nullptr;
- unsigned VScaleMin = Attr.getVScaleRangeMin();
- std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax();
- if (VScaleMax && VScaleMin == VScaleMax)
- return ConstantInt::get(F->getReturnType(), VScaleMin);
+ Type *RetTy = F->getReturnType();
+ ConstantRange CR = getVScaleRange(Call->getFunction(), 64);
+ if (const APInt *C = CR.getSingleElement())
+ return ConstantInt::get(RetTy, C->getZExtValue());
return nullptr;
}
default:
@@ -6554,10 +6689,10 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
}
if (NumOperands == 1)
- return simplifyUnaryIntrinsic(F, Args[0], Q);
+ return simplifyUnaryIntrinsic(F, Args[0], Q, Call);
if (NumOperands == 2)
- return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q);
+ return simplifyBinaryIntrinsic(F, Args[0], Args[1], Q, Call);
// Handle intrinsics with 3 or more arguments.
switch (IID) {
@@ -6692,6 +6827,8 @@ static Value *simplifyIntrinsic(CallBase *Call, Value *Callee,
*FPI->getExceptionBehavior(),
*FPI->getRoundingMode());
}
+ case Intrinsic::experimental_constrained_ldexp:
+ return simplifyLdexp(Args[0], Args[1], Q, true);
default:
return nullptr;
}
@@ -6811,6 +6948,9 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
const SimplifyQuery &SQ,
unsigned MaxRecurse) {
assert(I->getFunction() && "instruction should be inserted in a function");
+ assert((!SQ.CxtI || SQ.CxtI->getFunction() == I->getFunction()) &&
+ "context instruction should be in the same function");
+
const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I);
switch (I->getOpcode()) {