path: root/llvm/lib/Analysis/InstructionSimplify.cpp
author    Dimitry Andric <dim@FreeBSD.org>  2023-02-11 12:38:04 +0000
committer Dimitry Andric <dim@FreeBSD.org>  2023-02-11 12:38:11 +0000
commit    e3b557809604d036af6e00c60f012c2025b59a5e (patch)
tree      8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Analysis/InstructionSimplify.cpp
parent    08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff)
Diffstat (limited to 'llvm/lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- llvm/lib/Analysis/InstructionSimplify.cpp | 865
1 file changed, 561 insertions(+), 304 deletions(-)
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
index 21fe448218bc..c83eb96bbc69 100644
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -41,6 +41,7 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>
+#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -741,9 +742,45 @@ static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,
return Res;
}
+/// Test if there is a dominating equivalence condition for the
+/// two operands. If there is, try to reduce the binary operation
+/// between the two operands.
+/// Example: Op0 - Op1 --> 0 when Op0 == Op1
+static Value *simplifyByDomEq(unsigned Opcode, Value *Op0, Value *Op1,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ // A recursive run cannot gain any benefit, so only try this at the top
+ // level.
+ if (MaxRecurse != RecursionLimit)
+ return nullptr;
+
+ std::optional<bool> Imp =
+ isImpliedByDomCondition(CmpInst::ICMP_EQ, Op0, Op1, Q.CxtI, Q.DL);
+ if (Imp && *Imp) {
+ Type *Ty = Op0->getType();
+ switch (Opcode) {
+ case Instruction::Sub:
+ case Instruction::Xor:
+ case Instruction::URem:
+ case Instruction::SRem:
+ return Constant::getNullValue(Ty);
+
+ case Instruction::SDiv:
+ case Instruction::UDiv:
+ return ConstantInt::get(Ty, 1);
+
+ case Instruction::And:
+ case Instruction::Or:
+ // Could be either one - choose Op1 since that's more likely a constant.
+ return Op1;
+ default:
+ break;
+ }
+ }
+ return nullptr;
+}
+
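
An illustrative IR sketch (hypothetical function, not part of this patch) where the new fold fires: the subtraction is dominated by an equality test on its operands, so it folds to 0.

  define i32 @dom_eq_sub(i32 %a, i32 %b) {
  entry:
    %cmp = icmp eq i32 %a, %b
    br i1 %cmp, label %then, label %else
  then:                          ; here %a == %b is known to hold
    %d = sub i32 %a, %b          ; simplifyByDomEq folds %d to 0
    ret i32 %d
  else:
    ret i32 -1
  }
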
/// Given operands for a Sub, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+static Value *simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
const SimplifyQuery &Q, unsigned MaxRecurse) {
if (Constant *C = foldOrCommuteConstant(Instruction::Sub, Op0, Op1, Q))
return C;
@@ -769,14 +806,14 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
// Is this a negation?
if (match(Op0, m_Zero())) {
// 0 - X -> 0 if the sub is NUW.
- if (isNUW)
+ if (IsNUW)
return Constant::getNullValue(Op0->getType());
KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (Known.Zero.isMaxSignedValue()) {
// Op1 is either 0 or the minimum signed value. If the sub is NSW, then
// Op1 must be 0 because negating the minimum signed value is undefined.
- if (isNSW)
+ if (IsNSW)
return Constant::getNullValue(Op0->getType());
// 0 - X -> X if X is 0 or the minimum signed value.
@@ -872,18 +909,21 @@ static Value *simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
// "A-B" and "A-C" thus gains nothing, but costs compile time. Similarly
// for threading over phi nodes.
+ if (Value *V = simplifyByDomEq(Instruction::Sub, Op0, Op1, Q, MaxRecurse))
+ return V;
+
return nullptr;
}
-Value *llvm::simplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+Value *llvm::simplifySubInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
const SimplifyQuery &Q) {
- return ::simplifySubInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+ return ::simplifySubInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit);
}
/// Given operands for a Mul, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
- unsigned MaxRecurse) {
+static Value *simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
if (Constant *C = foldOrCommuteConstant(Instruction::Mul, Op0, Op1, Q))
return C;
@@ -908,10 +948,17 @@ static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
match(Op1, m_Exact(m_IDiv(m_Value(X), m_Specific(Op0)))))) // Y * (X / Y)
return X;
- // i1 mul -> and.
- if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))
- if (Value *V = simplifyAndInst(Op0, Op1, Q, MaxRecurse - 1))
- return V;
+ if (Op0->getType()->isIntOrIntVectorTy(1)) {
+ // mul i1 nsw is a special case because -1 * -1 is poison (+1 is not
+ // representable). All other cases reduce to 0, so just return 0.
+ if (IsNSW)
+ return ConstantInt::getNullValue(Op0->getType());
+
+ // Treat "mul i1" as "and i1".
+ if (MaxRecurse)
+ if (Value *V = simplifyAndInst(Op0, Op1, Q, MaxRecurse - 1))
+ return V;
+ }
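
As a sketch (assumed operand names), the i1 cases behave as follows:

  %r0 = mul i1 %x, %y       ; simplifies like: and i1 %x, %y
  %r1 = mul nsw i1 %x, %y   ; only 0 is representable without signed wrap
                            ; (-1 * -1 would be +1), so this folds to false
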
// Try some generic simplifications for associative operations.
if (Value *V =
@@ -940,14 +987,16 @@ static Value *simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
return nullptr;
}
-Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
- return ::simplifyMulInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifyMulInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
+ const SimplifyQuery &Q) {
+ return ::simplifyMulInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit);
}
/// Check for common or similar folds of integer division or integer remainder.
/// This applies to all 4 opcodes (sdiv/udiv/srem/urem).
static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
- Value *Op1, const SimplifyQuery &Q) {
+ Value *Op1, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
bool IsDiv = (Opcode == Instruction::SDiv || Opcode == Instruction::UDiv);
bool IsSigned = (Opcode == Instruction::SDiv || Opcode == Instruction::SRem);
@@ -1022,6 +1071,9 @@ static Value *simplifyDivRem(Instruction::BinaryOps Opcode, Value *Op0,
}
}
+ if (Value *V = simplifyByDomEq(Opcode, Op0, Op1, Q, MaxRecurse))
+ return V;
+
return nullptr;
}
@@ -1099,13 +1151,24 @@ static bool isDivZero(Value *X, Value *Y, const SimplifyQuery &Q,
/// These are simplifications common to SDiv and UDiv.
static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
- const SimplifyQuery &Q, unsigned MaxRecurse) {
+ bool IsExact, const SimplifyQuery &Q,
+ unsigned MaxRecurse) {
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
- if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
+ if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse))
return V;
+ // If this is an exact divide by a constant, then the dividend (Op0) must
+ // have at least as many trailing zeros as the divisor to divide evenly. If
+ // it has fewer trailing zeros, then the result must be poison.
+ const APInt *DivC;
+ if (IsExact && match(Op1, m_APInt(DivC)) && DivC->countTrailingZeros()) {
+ KnownBits KnownOp0 = computeKnownBits(Op0, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
+ if (KnownOp0.countMaxTrailingZeros() < DivC->countTrailingZeros())
+ return PoisonValue::get(Op0->getType());
+ }
+
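
A hypothetical example of this trailing-zeros check: the dividend is known odd, but an exact division by 4 needs two trailing zeros, so the result is poison.

  %odd = or i8 %x, 1             ; low bit known set: no trailing zeros
  %r   = udiv exact i8 %odd, 4   ; folds to poison
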
bool IsSigned = Opcode == Instruction::SDiv;
// (X rem Y) / Y -> 0
@@ -1147,7 +1210,7 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
if (Constant *C = foldOrCommuteConstant(Opcode, Op0, Op1, Q))
return C;
- if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q))
+ if (Value *V = simplifyDivRem(Opcode, Op0, Op1, Q, MaxRecurse))
return V;
// (X % Y) % Y -> X % Y
@@ -1186,28 +1249,30 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,
/// Given operands for an SDiv, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
- unsigned MaxRecurse) {
+static Value *simplifySDivInst(Value *Op0, Value *Op1, bool IsExact,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
// If two operands are negated and no signed overflow, return -1.
if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true))
return Constant::getAllOnesValue(Op0->getType());
- return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);
+ return simplifyDiv(Instruction::SDiv, Op0, Op1, IsExact, Q, MaxRecurse);
}
-Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
- return ::simplifySDivInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifySDivInst(Value *Op0, Value *Op1, bool IsExact,
+ const SimplifyQuery &Q) {
+ return ::simplifySDivInst(Op0, Op1, IsExact, Q, RecursionLimit);
}
/// Given operands for a UDiv, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
- unsigned MaxRecurse) {
- return simplifyDiv(Instruction::UDiv, Op0, Op1, Q, MaxRecurse);
+static Value *simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact,
+ const SimplifyQuery &Q, unsigned MaxRecurse) {
+ return simplifyDiv(Instruction::UDiv, Op0, Op1, IsExact, Q, MaxRecurse);
}
-Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
- return ::simplifyUDivInst(Op0, Op1, Q, RecursionLimit);
+Value *llvm::simplifyUDivInst(Value *Op0, Value *Op1, bool IsExact,
+ const SimplifyQuery &Q) {
+ return ::simplifyUDivInst(Op0, Op1, IsExact, Q, RecursionLimit);
}
/// Given operands for an SRem, see if we can fold the result.
@@ -1252,12 +1317,14 @@ static bool isPoisonShift(Value *Amount, const SimplifyQuery &Q) {
if (Q.isUndefValue(C))
return true;
- // Shifting by the bitwidth or more is undefined.
- if (ConstantInt *CI = dyn_cast<ConstantInt>(C))
- if (CI->getValue().uge(CI->getType()->getScalarSizeInBits()))
- return true;
+ // Shifting by the bitwidth or more is poison. This covers scalars and
+ // fixed/scalable vectors with splat constants.
+ const APInt *AmountC;
+ if (match(C, m_APInt(AmountC)) && AmountC->uge(AmountC->getBitWidth()))
+ return true;
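
Illustrative shifts (assumed values) now caught by the m_APInt match, which also handles splat vector constants:

  %s = shl i8 %x, 8                         ; amount == bitwidth -> poison
  %v = lshr <2 x i16> %y, <i16 16, i16 16>  ; splat amount == bitwidth -> poison
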
- // If all lanes of a vector shift are undefined the whole shift is.
+ // Try harder for fixed-length vectors:
+ // If all lanes of a vector shift are poison, the whole shift is poison.
if (isa<ConstantVector>(C) || isa<ConstantDataVector>(C)) {
for (unsigned I = 0,
E = cast<FixedVectorType>(C->getType())->getNumElements();
@@ -1343,7 +1410,7 @@ static Value *simplifyShift(Instruction::BinaryOps Opcode, Value *Op0,
/// Given operands for an Shl, LShr or AShr, see if we can
/// fold the result. If not, this returns null.
static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
- Value *Op1, bool isExact,
+ Value *Op1, bool IsExact,
const SimplifyQuery &Q, unsigned MaxRecurse) {
if (Value *V =
simplifyShift(Opcode, Op0, Op1, /*IsNSW*/ false, Q, MaxRecurse))
@@ -1356,10 +1423,11 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
// undef >> X -> 0
// undef >> X -> undef (if it's exact)
if (Q.isUndefValue(Op0))
- return isExact ? Op0 : Constant::getNullValue(Op0->getType());
+ return IsExact ? Op0 : Constant::getNullValue(Op0->getType());
// The low bit cannot be shifted out of an exact shift if it is set.
- if (isExact) {
+ // TODO: Generalize by counting trailing zeros (see fold for exact division).
+ if (IsExact) {
KnownBits Op0Known =
computeKnownBits(Op0, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
if (Op0Known.One[0])
@@ -1371,16 +1439,16 @@ static Value *simplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,
/// Given operands for an Shl, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+static Value *simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
const SimplifyQuery &Q, unsigned MaxRecurse) {
if (Value *V =
- simplifyShift(Instruction::Shl, Op0, Op1, isNSW, Q, MaxRecurse))
+ simplifyShift(Instruction::Shl, Op0, Op1, IsNSW, Q, MaxRecurse))
return V;
// undef << X -> 0
// undef << X -> undef (if it's NSW/NUW)
if (Q.isUndefValue(Op0))
- return isNSW || isNUW ? Op0 : Constant::getNullValue(Op0->getType());
+ return IsNSW || IsNUW ? Op0 : Constant::getNullValue(Op0->getType());
// (X >> A) << A -> X
Value *X;
@@ -1389,7 +1457,7 @@ static Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
return X;
// shl nuw i8 C, %x -> C iff C has sign bit set.
- if (isNUW && match(Op0, m_Negative()))
+ if (IsNUW && match(Op0, m_Negative()))
return Op0;
// NOTE: could use computeKnownBits() / LazyValueInfo,
// but the cost-benefit analysis suggests it isn't worth it.
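
For example (sketch): with the sign bit of the constant set, any nonzero shift amount would shift out a 1 bit and wrap unsigned, so the only defined result is the constant itself.

  %r = shl nuw i8 -128, %x   ; folds to -128
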
@@ -1397,16 +1465,16 @@ static Value *simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
return nullptr;
}
-Value *llvm::simplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,
+Value *llvm::simplifyShlInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,
const SimplifyQuery &Q) {
- return ::simplifyShlInst(Op0, Op1, isNSW, isNUW, Q, RecursionLimit);
+ return ::simplifyShlInst(Op0, Op1, IsNSW, IsNUW, Q, RecursionLimit);
}
/// Given operands for an LShr, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Value *V = simplifyRightShift(Instruction::LShr, Op0, Op1, isExact, Q,
+ if (Value *V = simplifyRightShift(Instruction::LShr, Op0, Op1, IsExact, Q,
MaxRecurse))
return V;
@@ -1434,16 +1502,16 @@ static Value *simplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
return nullptr;
}
-Value *llvm::simplifyLShrInst(Value *Op0, Value *Op1, bool isExact,
+Value *llvm::simplifyLShrInst(Value *Op0, Value *Op1, bool IsExact,
const SimplifyQuery &Q) {
- return ::simplifyLShrInst(Op0, Op1, isExact, Q, RecursionLimit);
+ return ::simplifyLShrInst(Op0, Op1, IsExact, Q, RecursionLimit);
}
/// Given operands for an AShr, see if we can fold the result.
/// If not, this returns null.
-static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
+static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
const SimplifyQuery &Q, unsigned MaxRecurse) {
- if (Value *V = simplifyRightShift(Instruction::AShr, Op0, Op1, isExact, Q,
+ if (Value *V = simplifyRightShift(Instruction::AShr, Op0, Op1, IsExact, Q,
MaxRecurse))
return V;
@@ -1467,9 +1535,9 @@ static Value *simplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
return nullptr;
}
-Value *llvm::simplifyAShrInst(Value *Op0, Value *Op1, bool isExact,
+Value *llvm::simplifyAShrInst(Value *Op0, Value *Op1, bool IsExact,
const SimplifyQuery &Q) {
- return ::simplifyAShrInst(Op0, Op1, isExact, Q, RecursionLimit);
+ return ::simplifyAShrInst(Op0, Op1, IsExact, Q, RecursionLimit);
}
/// Commuted variants are assumed to be handled by calling this function again
@@ -1582,45 +1650,6 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,
return nullptr;
}
-/// Commuted variants are assumed to be handled by calling this function again
-/// with the parameters swapped.
-static Value *simplifyAndOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) {
- ICmpInst::Predicate Pred0, Pred1;
- Value *A, *B;
- if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) ||
- !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B))))
- return nullptr;
-
- // Check for any combination of predicates that are guaranteed to be disjoint.
- if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) ||
- (Pred0 == ICmpInst::ICMP_EQ && ICmpInst::isFalseWhenEqual(Pred1)) ||
- (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT) ||
- (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT))
- return getFalse(Op0->getType());
-
- return nullptr;
-}
-
-/// Commuted variants are assumed to be handled by calling this function again
-/// with the parameters swapped.
-static Value *simplifyOrOfICmpsWithSameOperands(ICmpInst *Op0, ICmpInst *Op1) {
- ICmpInst::Predicate Pred0, Pred1;
- Value *A, *B;
- if (!match(Op0, m_ICmp(Pred0, m_Value(A), m_Value(B))) ||
- !match(Op1, m_ICmp(Pred1, m_Specific(A), m_Specific(B))))
- return nullptr;
-
- // Check for any combination of predicates that cover the entire range of
- // possibilities.
- if ((Pred0 == ICmpInst::getInversePredicate(Pred1)) ||
- (Pred0 == ICmpInst::ICMP_NE && ICmpInst::isTrueWhenEqual(Pred1)) ||
- (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGE) ||
- (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_UGE))
- return getTrue(Op0->getType());
-
- return nullptr;
-}
-
/// Test if a pair of compares with a shared operand and 2 constants has an
/// empty set intersection, full set union, or if one compare is a superset of
/// the other.
@@ -1715,25 +1744,25 @@ static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
return nullptr;
Type *ITy = Op0->getType();
- bool isNSW = IIQ.hasNoSignedWrap(AddInst);
- bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
+ bool IsNSW = IIQ.hasNoSignedWrap(AddInst);
+ bool IsNUW = IIQ.hasNoUnsignedWrap(AddInst);
const APInt Delta = *C1 - *C0;
if (C0->isStrictlyPositive()) {
if (Delta == 2) {
if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_SGT)
return getFalse(ITy);
- if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && isNSW)
+ if (Pred0 == ICmpInst::ICMP_SLT && Pred1 == ICmpInst::ICMP_SGT && IsNSW)
return getFalse(ITy);
}
if (Delta == 1) {
if (Pred0 == ICmpInst::ICMP_ULE && Pred1 == ICmpInst::ICMP_SGT)
return getFalse(ITy);
- if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && isNSW)
+ if (Pred0 == ICmpInst::ICMP_SLE && Pred1 == ICmpInst::ICMP_SGT && IsNSW)
return getFalse(ITy);
}
}
- if (C0->getBoolValue() && isNUW) {
+ if (C0->getBoolValue() && IsNUW) {
if (Delta == 2)
if (Pred0 == ICmpInst::ICMP_ULT && Pred1 == ICmpInst::ICMP_UGT)
return getFalse(ITy);
@@ -1833,11 +1862,6 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1,
if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/true, Q))
return X;
- if (Value *X = simplifyAndOfICmpsWithSameOperands(Op0, Op1))
- return X;
- if (Value *X = simplifyAndOfICmpsWithSameOperands(Op1, Op0))
- return X;
-
if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))
return X;
@@ -1877,25 +1901,25 @@ static Value *simplifyOrOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1,
return nullptr;
Type *ITy = Op0->getType();
- bool isNSW = IIQ.hasNoSignedWrap(AddInst);
- bool isNUW = IIQ.hasNoUnsignedWrap(AddInst);
+ bool IsNSW = IIQ.hasNoSignedWrap(AddInst);
+ bool IsNUW = IIQ.hasNoUnsignedWrap(AddInst);
const APInt Delta = *C1 - *C0;
if (C0->isStrictlyPositive()) {
if (Delta == 2) {
if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_SLE)
return getTrue(ITy);
- if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && isNSW)
+ if (Pred0 == ICmpInst::ICMP_SGE && Pred1 == ICmpInst::ICMP_SLE && IsNSW)
return getTrue(ITy);
}
if (Delta == 1) {
if (Pred0 == ICmpInst::ICMP_UGT && Pred1 == ICmpInst::ICMP_SLE)
return getTrue(ITy);
- if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && isNSW)
+ if (Pred0 == ICmpInst::ICMP_SGT && Pred1 == ICmpInst::ICMP_SLE && IsNSW)
return getTrue(ITy);
}
}
- if (C0->getBoolValue() && isNUW) {
+ if (C0->getBoolValue() && IsNUW) {
if (Delta == 2)
if (Pred0 == ICmpInst::ICMP_UGE && Pred1 == ICmpInst::ICMP_ULE)
return getTrue(ITy);
@@ -1914,11 +1938,6 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1,
if (Value *X = simplifyUnsignedRangeCheck(Op1, Op0, /*IsAnd=*/false, Q))
return X;
- if (Value *X = simplifyOrOfICmpsWithSameOperands(Op0, Op1))
- return X;
- if (Value *X = simplifyOrOfICmpsWithSameOperands(Op1, Op0))
- return X;
-
if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))
return X;
@@ -2220,14 +2239,27 @@ static Value *simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
return Constant::getNullValue(Op0->getType());
if (Op0->getType()->isIntOrIntVectorTy(1)) {
- // Op0&Op1 -> Op0 where Op0 implies Op1
- if (isImpliedCondition(Op0, Op1, Q.DL).value_or(false))
- return Op0;
- // Op0&Op1 -> Op1 where Op1 implies Op0
- if (isImpliedCondition(Op1, Op0, Q.DL).value_or(false))
- return Op1;
+ if (std::optional<bool> Implied = isImpliedCondition(Op0, Op1, Q.DL)) {
+ // If Op0 is true implies Op1 is true, then Op0 is a subset of Op1.
+ if (*Implied == true)
+ return Op0;
+ // If Op0 is true implies Op1 is false, then they are not true together.
+ if (*Implied == false)
+ return ConstantInt::getFalse(Op0->getType());
+ }
+ if (std::optional<bool> Implied = isImpliedCondition(Op1, Op0, Q.DL)) {
+ // If Op1 is true implies Op0 is true, then Op1 is a subset of Op0.
+ if (*Implied)
+ return Op1;
+ // If Op1 is true implies Op0 is false, then they are not true together.
+ if (!*Implied)
+ return ConstantInt::getFalse(Op1->getType());
+ }
}
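
Hypothetical IR (assumed comparisons) showing both outcomes of the extended implication check:

  %a  = icmp ult i32 %x, 8
  %b  = icmp ult i32 %x, 16
  %r0 = and i1 %a, %b        ; %a true implies %b true  -> %a
  %c  = icmp ugt i32 %x, 16
  %r1 = and i1 %a, %c        ; %a true implies %c false -> false
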
+ if (Value *V = simplifyByDomEq(Instruction::And, Op0, Op1, Q, MaxRecurse))
+ return V;
+
return nullptr;
}
@@ -2235,6 +2267,7 @@ Value *llvm::simplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {
return ::simplifyAndInst(Op0, Op1, Q, RecursionLimit);
}
+// TODO: Many of these folds could use LogicalAnd/LogicalOr.
static Value *simplifyOrLogic(Value *X, Value *Y) {
assert(X->getType() == Y->getType() && "Expected same type for 'or' ops");
Type *Ty = X->getType();
@@ -2277,7 +2310,7 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
// (B ^ ~A) | (A & B) --> B ^ ~A
// (~A ^ B) | (B & A) --> ~A ^ B
// (B ^ ~A) | (B & A) --> B ^ ~A
- if (match(X, m_c_Xor(m_Not(m_Value(A)), m_Value(B))) &&
+ if (match(X, m_c_Xor(m_NotForbidUndef(m_Value(A)), m_Value(B))) &&
match(Y, m_c_And(m_Specific(A), m_Specific(B))))
return X;
@@ -2299,6 +2332,14 @@ static Value *simplifyOrLogic(Value *X, Value *Y) {
m_Value(B))) &&
match(Y, m_Not(m_c_Or(m_Specific(A), m_Specific(B)))))
return NotA;
+ // The same is true of Logical And
+ // TODO: This could share the logic of the version above if there was a
+ // version of LogicalAnd that allowed more than just i1 types.
+ if (match(X, m_c_LogicalAnd(
+ m_CombineAnd(m_Value(NotA), m_NotForbidUndef(m_Value(A))),
+ m_Value(B))) &&
+ match(Y, m_Not(m_c_LogicalOr(m_Specific(A), m_Specific(B)))))
+ return NotA;
// ~(A ^ B) | (A & B) --> ~(A ^ B)
// ~(A ^ B) | (B & A) --> ~(A ^ B)
@@ -2460,14 +2501,29 @@ static Value *simplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
return V;
if (Op0->getType()->isIntOrIntVectorTy(1)) {
- // Op0|Op1 -> Op1 where Op0 implies Op1
- if (isImpliedCondition(Op0, Op1, Q.DL).value_or(false))
- return Op1;
- // Op0|Op1 -> Op0 where Op1 implies Op0
- if (isImpliedCondition(Op1, Op0, Q.DL).value_or(false))
- return Op0;
+ if (std::optional<bool> Implied =
+ isImpliedCondition(Op0, Op1, Q.DL, false)) {
+ // If Op0 is false implies Op1 is false, then Op1 is a subset of Op0.
+ if (*Implied == false)
+ return Op0;
+ // If Op0 is false implies Op1 is true, then at least one is always true.
+ if (*Implied == true)
+ return ConstantInt::getTrue(Op0->getType());
+ }
+ if (std::optional<bool> Implied =
+ isImpliedCondition(Op1, Op0, Q.DL, false)) {
+ // If Op1 is false implies Op0 is false, then Op0 is a subset of Op1.
+ if (*Implied == false)
+ return Op1;
+ // If Op1 is false implies Op0 is true, then at least one is always true.
+ if (*Implied == true)
+ return ConstantInt::getTrue(Op1->getType());
+ }
}
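
A matching sketch for the or case, which reasons from the first operand being false:

  %a  = icmp ult i32 %x, 16
  %b  = icmp ult i32 %x, 8
  %r0 = or i1 %a, %b         ; %a false implies %b false -> %a
  %c  = icmp ugt i32 %x, 7
  %r1 = or i1 %a, %c         ; %a false implies %c true  -> true
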
+ if (Value *V = simplifyByDomEq(Instruction::Or, Op0, Op1, Q, MaxRecurse))
+ return V;
+
return nullptr;
}
@@ -2543,6 +2599,9 @@ static Value *simplifyXorInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,
// "A^B" and "A^C" thus gains nothing, but costs compile time. Similarly
// for threading over phi nodes.
+ if (Value *V = simplifyByDomEq(Instruction::Xor, Op0, Op1, Q, MaxRecurse))
+ return V;
+
return nullptr;
}
@@ -2689,7 +2748,7 @@ static Constant *computePointerICmp(CmpInst::Predicate Pred, Value *LHS,
default:
return nullptr;
- // Equality comaprisons are easy to fold.
+ // Equality comparisons are easy to fold.
case CmpInst::ICMP_EQ:
case CmpInst::ICMP_NE:
break;
@@ -2895,6 +2954,11 @@ static Value *simplifyICmpOfBools(CmpInst::Predicate Pred, Value *LHS,
if (isImpliedCondition(LHS, RHS, Q.DL).value_or(false))
return getTrue(ITy);
break;
+ case ICmpInst::ICMP_SLE:
+    // SLE follows the same logic as SGE with the LHS and RHS swapped.
+ if (isImpliedCondition(RHS, LHS, Q.DL).value_or(false))
+ return getTrue(ITy);
+ break;
}
return nullptr;
@@ -3054,7 +3118,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case ICmpInst::ICMP_EQ:
case ICmpInst::ICMP_UGT:
@@ -3065,7 +3129,7 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);
if (!Known.isNonNegative())
break;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case ICmpInst::ICMP_NE:
case ICmpInst::ICMP_ULT:
@@ -3148,6 +3212,12 @@ static Value *simplifyICmpWithBinOpOnLHS(CmpInst::Predicate Pred,
return getTrue(ITy);
}
+ // (sub C, X) == X, C is odd --> false
+ // (sub C, X) != X, C is odd --> true
+ if (match(LBO, m_Sub(m_APIntAllowUndef(C), m_Specific(RHS))) &&
+ (*C & 1) == 1 && ICmpInst::isEquality(Pred))
+ return (Pred == ICmpInst::ICMP_EQ) ? getFalse(ITy) : getTrue(ITy);
+
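
For instance (assumed constant): 7 - %x == %x would require 2 * %x == 7, which no integer satisfies, so the equality folds to false.

  %s = sub i32 7, %x
  %c = icmp eq i32 %s, %x    ; folds to false
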
return nullptr;
}
@@ -3570,8 +3640,8 @@ static Value *simplifyICmpWithDominatingAssume(CmpInst::Predicate Predicate,
continue;
CallInst *Assume = cast<CallInst>(AssumeVH);
- if (Optional<bool> Imp = isImpliedCondition(Assume->getArgOperand(0),
- Predicate, LHS, RHS, Q.DL))
+ if (std::optional<bool> Imp = isImpliedCondition(
+ Assume->getArgOperand(0), Predicate, LHS, RHS, Q.DL))
if (isValidAssumeForContext(Assume, Q.CxtI, Q.DT))
return ConstantInt::get(getCompareTy(LHS), *Imp);
}
@@ -4098,8 +4168,6 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
const SimplifyQuery &Q,
bool AllowRefinement,
unsigned MaxRecurse) {
- assert(!Op->getType()->isVectorTy() && "This is not safe for vectors");
-
// Trivial replacement.
if (V == Op)
return RepOp;
@@ -4112,6 +4180,14 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (!I || !is_contained(I->operands(), Op))
return nullptr;
+ if (Op->getType()->isVectorTy()) {
+ // For vector types, the simplification must hold per-lane, so forbid
+ // potentially cross-lane operations like shufflevector.
+ assert(I->getType()->isVectorTy() && "Vector type mismatch");
+ if (isa<ShuffleVectorInst>(I) || isa<CallBase>(I))
+ return nullptr;
+ }
+
// Replace Op with RepOp in instruction operands.
SmallVector<Value *, 8> NewOps(I->getNumOperands());
transform(I->operands(), NewOps.begin(),
@@ -4167,7 +4243,7 @@ static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,
if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
return PreventSelfSimplify(simplifyGEPInst(
- GEP->getSourceElementType(), NewOps[0], makeArrayRef(NewOps).slice(1),
+ GEP->getSourceElementType(), NewOps[0], ArrayRef(NewOps).slice(1),
GEP->isInBounds(), Q, MaxRecurse - 1));
if (isa<SelectInst>(I))
@@ -4243,6 +4319,78 @@ static Value *simplifySelectBitTest(Value *TrueVal, Value *FalseVal, Value *X,
return nullptr;
}
+static Value *simplifyCmpSelOfMaxMin(Value *CmpLHS, Value *CmpRHS,
+ ICmpInst::Predicate Pred, Value *TVal,
+ Value *FVal) {
+ // Canonicalize common cmp+sel operand as CmpLHS.
+ if (CmpRHS == TVal || CmpRHS == FVal) {
+ std::swap(CmpLHS, CmpRHS);
+ Pred = ICmpInst::getSwappedPredicate(Pred);
+ }
+
+ // Canonicalize common cmp+sel operand as TVal.
+ if (CmpLHS == FVal) {
+ std::swap(TVal, FVal);
+ Pred = ICmpInst::getInversePredicate(Pred);
+ }
+
+ // A vector select may be shuffling together elements that are equivalent
+ // based on the max/min/select relationship.
+ Value *X = CmpLHS, *Y = CmpRHS;
+ bool PeekedThroughSelectShuffle = false;
+ auto *Shuf = dyn_cast<ShuffleVectorInst>(FVal);
+ if (Shuf && Shuf->isSelect()) {
+ if (Shuf->getOperand(0) == Y)
+ FVal = Shuf->getOperand(1);
+ else if (Shuf->getOperand(1) == Y)
+ FVal = Shuf->getOperand(0);
+ else
+ return nullptr;
+ PeekedThroughSelectShuffle = true;
+ }
+
+ // (X pred Y) ? X : max/min(X, Y)
+ auto *MMI = dyn_cast<MinMaxIntrinsic>(FVal);
+ if (!MMI || TVal != X ||
+ !match(FVal, m_c_MaxOrMin(m_Specific(X), m_Specific(Y))))
+ return nullptr;
+
+ // (X > Y) ? X : max(X, Y) --> max(X, Y)
+ // (X >= Y) ? X : max(X, Y) --> max(X, Y)
+ // (X < Y) ? X : min(X, Y) --> min(X, Y)
+ // (X <= Y) ? X : min(X, Y) --> min(X, Y)
+ //
+ // The equivalence allows a vector select (shuffle) of max/min and Y. Ex:
+ // (X > Y) ? X : (Z ? max(X, Y) : Y)
+ // If Z is true, this reduces as above, and if Z is false:
+ // (X > Y) ? X : Y --> max(X, Y)
+ ICmpInst::Predicate MMPred = MMI->getPredicate();
+ if (MMPred == CmpInst::getStrictPredicate(Pred))
+ return MMI;
+
+ // Other transforms are not valid with a shuffle.
+ if (PeekedThroughSelectShuffle)
+ return nullptr;
+
+ // (X == Y) ? X : max/min(X, Y) --> max/min(X, Y)
+ if (Pred == CmpInst::ICMP_EQ)
+ return MMI;
+
+ // (X != Y) ? X : max/min(X, Y) --> X
+ if (Pred == CmpInst::ICMP_NE)
+ return X;
+
+ // (X < Y) ? X : max(X, Y) --> X
+ // (X <= Y) ? X : max(X, Y) --> X
+ // (X > Y) ? X : min(X, Y) --> X
+ // (X >= Y) ? X : min(X, Y) --> X
+ ICmpInst::Predicate InvPred = CmpInst::getInversePredicate(Pred);
+ if (MMPred == CmpInst::getStrictPredicate(InvPred))
+ return X;
+
+ return nullptr;
+}
+
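
An IR sketch (hypothetical) of the basic pattern this handles:

  %cmp = icmp sgt i32 %x, %y
  %max = call i32 @llvm.smax.i32(i32 %x, i32 %y)
  %r   = select i1 %cmp, i32 %x, i32 %max   ; (X > Y) ? X : max(X, Y) -> %max
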
/// An alternative way to test if a bit is set or not uses sgt/slt instead of
/// eq/ne.
static Value *simplifySelectWithFakeICmpEq(Value *CmpLHS, Value *CmpRHS,
@@ -4268,6 +4416,9 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
if (!match(CondVal, m_ICmp(Pred, m_Value(CmpLHS), m_Value(CmpRHS))))
return nullptr;
+ if (Value *V = simplifyCmpSelOfMaxMin(CmpLHS, CmpRHS, Pred, TrueVal, FalseVal))
+ return V;
+
// Canonicalize ne to eq predicate.
if (Pred == ICmpInst::ICMP_NE) {
Pred = ICmpInst::ICMP_EQ;
@@ -4341,9 +4492,7 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,
// If we have a scalar equality comparison, then we know the value in one of
// the arms of the select. See if substituting this value into the arm and
// simplifying the result yields the same value as the other arm.
- // Note that the equivalence/replacement opportunity does not hold for vectors
- // because each element of a vector select is chosen independently.
- if (Pred == ICmpInst::ICMP_EQ && !CondVal->getType()->isVectorTy()) {
+ if (Pred == ICmpInst::ICMP_EQ) {
if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q,
/* AllowRefinement */ false,
MaxRecurse) == TrueVal ||
@@ -4431,9 +4580,34 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
if (match(TrueVal, m_One()) && match(FalseVal, m_ZeroInt()))
return Cond;
- // (X || Y) && (X || !Y) --> X (commuted 8 ways)
- Value *X, *Y;
+ // (X && Y) ? X : Y --> Y (commuted 2 ways)
+ if (match(Cond, m_c_LogicalAnd(m_Specific(TrueVal), m_Specific(FalseVal))))
+ return FalseVal;
+
+ // (X || Y) ? X : Y --> X (commuted 2 ways)
+ if (match(Cond, m_c_LogicalOr(m_Specific(TrueVal), m_Specific(FalseVal))))
+ return TrueVal;
+
+ // (X || Y) ? false : X --> false (commuted 2 ways)
+ if (match(Cond, m_c_LogicalOr(m_Specific(FalseVal), m_Value())) &&
+ match(TrueVal, m_ZeroInt()))
+ return ConstantInt::getFalse(Cond->getType());
+
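
Recall that select with i1 operands encodes the poison-safe logical forms, so these folds read like boolean identities. A sketch (assumed values):

  %and = select i1 %x, i1 %y, i1 false   ; X && Y
  %r0  = select i1 %and, i1 %x, i1 %y    ; (X && Y) ? X : Y -> %y
  %or  = select i1 %x, i1 true, i1 %y    ; X || Y
  %r1  = select i1 %or, i1 %x, i1 %y     ; (X || Y) ? X : Y -> %x
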
+ // Match patterns that end in logical-and.
if (match(FalseVal, m_ZeroInt())) {
+ // !(X || Y) && X --> false (commuted 2 ways)
+ if (match(Cond, m_Not(m_c_LogicalOr(m_Specific(TrueVal), m_Value()))))
+ return ConstantInt::getFalse(Cond->getType());
+
+ // (X || Y) && Y --> Y (commuted 2 ways)
+ if (match(Cond, m_c_LogicalOr(m_Specific(TrueVal), m_Value())))
+ return TrueVal;
+ // Y && (X || Y) --> Y (commuted 2 ways)
+ if (match(TrueVal, m_c_LogicalOr(m_Specific(Cond), m_Value())))
+ return Cond;
+
+ // (X || Y) && (X || !Y) --> X (commuted 8 ways)
+ Value *X, *Y;
if (match(Cond, m_c_LogicalOr(m_Value(X), m_Not(m_Value(Y)))) &&
match(TrueVal, m_c_LogicalOr(m_Specific(X), m_Specific(Y))))
return X;
@@ -4441,12 +4615,39 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
match(Cond, m_c_LogicalOr(m_Specific(X), m_Specific(Y))))
return X;
}
+
+ // Match patterns that end in logical-or.
+ if (match(TrueVal, m_One())) {
+ // (X && Y) || Y --> Y (commuted 2 ways)
+ if (match(Cond, m_c_LogicalAnd(m_Specific(FalseVal), m_Value())))
+ return FalseVal;
+ // Y || (X && Y) --> Y (commuted 2 ways)
+ if (match(FalseVal, m_c_LogicalAnd(m_Specific(Cond), m_Value())))
+ return Cond;
+ }
}
// select ?, X, X -> X
if (TrueVal == FalseVal)
return TrueVal;
+ if (Cond == TrueVal) {
+ // select i1 X, i1 X, i1 false --> X (logical-and)
+ if (match(FalseVal, m_ZeroInt()))
+ return Cond;
+ // select i1 X, i1 X, i1 true --> true
+ if (match(FalseVal, m_One()))
+ return ConstantInt::getTrue(Cond->getType());
+ }
+ if (Cond == FalseVal) {
+ // select i1 X, i1 true, i1 X --> X (logical-or)
+ if (match(TrueVal, m_One()))
+ return Cond;
+ // select i1 X, i1 false, i1 X --> false
+ if (match(TrueVal, m_ZeroInt()))
+ return ConstantInt::getFalse(Cond->getType());
+ }
+
// If the true or false value is poison, we can fold to the other value.
// If the true or false value is undef, we can fold to the other value as
// long as the other value isn't poison.
@@ -4505,7 +4706,7 @@ static Value *simplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal,
if (Value *V = foldSelectWithBinaryOp(Cond, TrueVal, FalseVal))
return V;
- Optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
+ std::optional<bool> Imp = isImpliedByDomCondition(Cond, Q.CxtI, Q.DL);
if (Imp)
return *Imp ? TrueVal : FalseVal;
@@ -4671,8 +4872,10 @@ static Value *simplifyInsertValueInst(Value *Agg, Value *Val,
if (Constant *CVal = dyn_cast<Constant>(Val))
return ConstantFoldInsertValueInstruction(CAgg, CVal, Idxs);
- // insertvalue x, undef, n -> x
- if (Q.isUndefValue(Val))
+ // insertvalue x, poison, n -> x
+ // insertvalue x, undef, n -> x if x cannot be poison
+ if (isa<PoisonValue>(Val) ||
+ (Q.isUndefValue(Val) && isGuaranteedNotToBePoison(Agg)))
return Agg;
// insertvalue x, (extractvalue y, n), n
@@ -4794,6 +4997,14 @@ static Value *simplifyExtractElementInst(Value *Vec, Value *Idx,
if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue()))
return Elt;
} else {
+ // extractelt (insertelt y, elt, n), n -> elt
+ // If the possibly-variable indices are trivially known to be equal
+ // (because they are the same operand) then use the value that was
+ // inserted directly.
+ auto *IE = dyn_cast<InsertElementInst>(Vec);
+ if (IE && IE->getOperand(2) == Idx)
+ return IE->getOperand(1);
+
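
For example (hypothetical values): the same %i selects the lane that was just written, even though %i is not a constant.

  %v = insertelement <4 x float> %vec, float %f, i32 %i
  %e = extractelement <4 x float> %v, i32 %i   ; folds to %f
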
// The index is not relevant if our vector is a splat.
if (Value *Splat = getSplatValue(Vec))
return Splat;
@@ -5019,7 +5230,7 @@ static Value *simplifyShuffleVectorInst(Value *Op0, Value *Op1,
// value type is same as the input vectors' type.
if (auto *OpShuf = dyn_cast<ShuffleVectorInst>(Op0))
if (Q.isUndefValue(Op1) && RetTy == InVecTy &&
- is_splat(OpShuf->getShuffleMask()))
+ all_equal(OpShuf->getShuffleMask()))
return Op0;
// All remaining transformation depend on the value of the mask, which is
@@ -5085,8 +5296,25 @@ Value *llvm::simplifyFNegInst(Value *Op, FastMathFlags FMF,
return ::simplifyFNegInst(Op, FMF, Q, RecursionLimit);
}
+/// Try to propagate existing NaN values when possible. If not, replace the
+/// constant or elements in the constant with a canonical NaN.
static Constant *propagateNaN(Constant *In) {
- // If the input is a vector with undef elements, just return a default NaN.
+ if (auto *VecTy = dyn_cast<FixedVectorType>(In->getType())) {
+ unsigned NumElts = VecTy->getNumElements();
+ SmallVector<Constant *, 32> NewC(NumElts);
+ for (unsigned i = 0; i != NumElts; ++i) {
+ Constant *EltC = In->getAggregateElement(i);
+ // Poison and existing NaN elements propagate.
+ // Replace unknown or undef elements with canonical NaN.
+ if (EltC && (isa<PoisonValue>(EltC) || EltC->isNaN()))
+ NewC[i] = EltC;
+ else
+ NewC[i] = (ConstantFP::getNaN(VecTy->getElementType()));
+ }
+ return ConstantVector::get(NewC);
+ }
+
+ // Not a fixed vector: replace anything that is not already a NaN with a
+ // canonical NaN; an existing NaN propagates as-is.
if (!In->isNaN())
return ConstantFP::getNaN(In->getType());
@@ -5121,7 +5349,13 @@ static Constant *simplifyFPOp(ArrayRef<Value *> Ops, FastMathFlags FMF,
return PoisonValue::get(V->getType());
if (isDefaultFPEnvironment(ExBehavior, Rounding)) {
- if (IsUndef || IsNan)
+ // Undef does not propagate because undef means that all bits can take on
+ // any value. If this is undef * NaN for example, then the result values
+ // (at least the exponent bits) are limited. Assume the undef is a
+ // canonical NaN and propagate that.
+ if (IsUndef)
+ return ConstantFP::getNaN(V->getType());
+ if (IsNan)
return propagateNaN(cast<Constant>(V));
} else if (ExBehavior != fp::ebStrict) {
if (IsNan)
@@ -5165,14 +5399,18 @@ simplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,
if (!isDefaultFPEnvironment(ExBehavior, Rounding))
return nullptr;
- // With nnan: -X + X --> 0.0 (and commuted variant)
- // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN.
- // Negative zeros are allowed because we always end up with positive zero:
- // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
- // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
- // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0
- // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0
if (FMF.noNaNs()) {
+ // With nnan: X + {+/-}Inf --> {+/-}Inf
+ if (match(Op1, m_Inf()))
+ return Op1;
+
+ // With nnan: -X + X --> 0.0 (and commuted variant)
+ // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN.
+ // Negative zeros are allowed because we always end up with positive zero:
+ // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+ // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0
+ // X = 0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0
+ // X = 0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0
if (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) ||
match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))
return ConstantFP::getNullValue(Op0->getType());
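
A sketch of the new infinity fold (assumed operand): with nnan, the -Inf + +Inf case is excluded, and Inf plus any other value stays Inf.

  %r = fadd nnan double %x, 0x7FF0000000000000   ; X + +Inf -> +Inf
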
@@ -5227,19 +5465,30 @@ simplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,
if (match(Op0, m_NegZeroFP()) && match(Op1, m_FNeg(m_Value(X))))
return X;
+ // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored.
+ // fsub 0.0, (fneg X) ==> X if signed zeros are ignored.
+ if (canIgnoreSNaN(ExBehavior, FMF))
+ if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) &&
+ (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) ||
+ match(Op1, m_FNeg(m_Value(X)))))
+ return X;
+
if (!isDefaultFPEnvironment(ExBehavior, Rounding))
return nullptr;
- // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored.
- // fsub 0.0, (fneg X) ==> X if signed zeros are ignored.
- if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) &&
- (match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))) ||
- match(Op1, m_FNeg(m_Value(X)))))
- return X;
+ if (FMF.noNaNs()) {
+ // fsub nnan x, x ==> 0.0
+ if (Op0 == Op1)
+ return Constant::getNullValue(Op0->getType());
- // fsub nnan x, x ==> 0.0
- if (FMF.noNaNs() && Op0 == Op1)
- return Constant::getNullValue(Op0->getType());
+ // With nnan: {+/-}Inf - X --> {+/-}Inf
+ if (match(Op0, m_Inf()))
+ return Op0;
+
+ // With nnan: X - {+/-}Inf --> {-/+}Inf
+ if (match(Op1, m_Inf()))
+ return foldConstant(Instruction::FNeg, Op1, Q);
+ }
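
Sketches of the corresponding fsub folds (assumed operands):

  %r0 = fsub nnan double 0x7FF0000000000000, %x   ; +Inf - X -> +Inf
  %r1 = fsub nnan double %x, 0x7FF0000000000000   ; X - +Inf -> -Inf
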
// Y - (Y - X) --> X
// (X + Y) - Y --> X
@@ -5261,21 +5510,24 @@ static Value *simplifyFMAFMul(Value *Op0, Value *Op1, FastMathFlags FMF,
if (!isDefaultFPEnvironment(ExBehavior, Rounding))
return nullptr;
- // fmul X, 1.0 ==> X
+ // Canonicalize special constants as operand 1.
+ if (match(Op0, m_FPOne()) || match(Op0, m_AnyZeroFP()))
+ std::swap(Op0, Op1);
+
+ // X * 1.0 --> X
if (match(Op1, m_FPOne()))
return Op0;
- // fmul 1.0, X ==> X
- if (match(Op0, m_FPOne()))
- return Op1;
-
- // fmul nnan nsz X, 0 ==> 0
- if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP()))
- return ConstantFP::getNullValue(Op0->getType());
+ if (match(Op1, m_AnyZeroFP())) {
+ // X * 0.0 --> 0.0 (with nnan and nsz)
+ if (FMF.noNaNs() && FMF.noSignedZeros())
+ return ConstantFP::getNullValue(Op0->getType());
- // fmul nnan nsz 0, X ==> 0
- if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()))
- return ConstantFP::getNullValue(Op1->getType());
+ // +normal number * (-)0.0 --> (-)0.0
+ if (isKnownNeverInfinity(Op0, Q.TLI) && isKnownNeverNaN(Op0, Q.TLI) &&
+ SignBitMustBeZero(Op0, Q.TLI))
+ return Op1;
+ }
// sqrt(X) * sqrt(X) --> X, if we can:
// 1. Remove the intermediate rounding (reassociate).
@@ -5377,6 +5629,10 @@ simplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,
if (match(Op0, m_FNegNSZ(m_Specific(Op1))) ||
match(Op1, m_FNegNSZ(m_Specific(Op0))))
return ConstantFP::get(Op0->getType(), -1.0);
+
+ // nnan ninf X / [-]0.0 -> poison
+ if (FMF.noInfs() && match(Op1, m_AnyZeroFP()))
+ return PoisonValue::get(Op1->getType());
}
return nullptr;
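
For example (hypothetical): a division by zero can only produce NaN or an infinity, so under nnan and ninf every possible result is excluded and the value is poison.

  %r = fdiv nnan ninf double %x, 0.0   ; folds to poison
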
@@ -5471,25 +5727,29 @@ static Value *simplifyBinOp(unsigned Opcode, Value *LHS, Value *RHS,
const SimplifyQuery &Q, unsigned MaxRecurse) {
switch (Opcode) {
case Instruction::Add:
- return simplifyAddInst(LHS, RHS, false, false, Q, MaxRecurse);
+ return simplifyAddInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q,
+ MaxRecurse);
case Instruction::Sub:
- return simplifySubInst(LHS, RHS, false, false, Q, MaxRecurse);
+ return simplifySubInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q,
+ MaxRecurse);
case Instruction::Mul:
- return simplifyMulInst(LHS, RHS, Q, MaxRecurse);
+ return simplifyMulInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q,
+ MaxRecurse);
case Instruction::SDiv:
- return simplifySDivInst(LHS, RHS, Q, MaxRecurse);
+ return simplifySDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
case Instruction::UDiv:
- return simplifyUDivInst(LHS, RHS, Q, MaxRecurse);
+ return simplifyUDivInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
case Instruction::SRem:
return simplifySRemInst(LHS, RHS, Q, MaxRecurse);
case Instruction::URem:
return simplifyURemInst(LHS, RHS, Q, MaxRecurse);
case Instruction::Shl:
- return simplifyShlInst(LHS, RHS, false, false, Q, MaxRecurse);
+ return simplifyShlInst(LHS, RHS, /* IsNSW */ false, /* IsNUW */ false, Q,
+ MaxRecurse);
case Instruction::LShr:
- return simplifyLShrInst(LHS, RHS, false, Q, MaxRecurse);
+ return simplifyLShrInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
case Instruction::AShr:
- return simplifyAShrInst(LHS, RHS, false, Q, MaxRecurse);
+ return simplifyAShrInst(LHS, RHS, /* IsExact */ false, Q, MaxRecurse);
case Instruction::And:
return simplifyAndInst(LHS, RHS, Q, MaxRecurse);
case Instruction::Or:
@@ -5569,6 +5829,25 @@ static bool isIdempotent(Intrinsic::ID ID) {
case Intrinsic::round:
case Intrinsic::roundeven:
case Intrinsic::canonicalize:
+ case Intrinsic::arithmetic_fence:
+ return true;
+ }
+}
+
+/// Return true if the intrinsic rounds a floating-point value to an integral
+/// floating-point value (not an integer type).
+static bool removesFPFraction(Intrinsic::ID ID) {
+ switch (ID) {
+ default:
+ return false;
+
+ case Intrinsic::floor:
+ case Intrinsic::ceil:
+ case Intrinsic::trunc:
+ case Intrinsic::rint:
+ case Intrinsic::nearbyint:
+ case Intrinsic::round:
+ case Intrinsic::roundeven:
return true;
}
}
@@ -5638,6 +5917,18 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
if (II->getIntrinsicID() == IID)
return II;
+ if (removesFPFraction(IID)) {
+ // Converting from int or calling a rounding function always results in a
+ // finite integral number or infinity. For those inputs, rounding functions
+ // always return the same value, so the (2nd) rounding is eliminated. Ex:
+ // floor (sitofp x) -> sitofp x
+ // round (ceil x) -> ceil x
+ auto *II = dyn_cast<IntrinsicInst>(Op0);
+ if ((II && removesFPFraction(II->getIntrinsicID())) ||
+ match(Op0, m_SIToFP(m_Value())) || match(Op0, m_UIToFP(m_Value())))
+ return Op0;
+ }
+
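
A sketch of the redundant-rounding fold (assumed calls): the inner operation already produces an integral value, so the outer rounding intrinsic is a no-op.

  %i = sitofp i32 %n to double
  %f = call double @llvm.floor.f64(double %i)   ; folds to %i
  %t = call double @llvm.trunc.f64(double %x)
  %c = call double @llvm.ceil.f64(double %t)    ; folds to %t
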
Value *X;
switch (IID) {
case Intrinsic::fabs:
@@ -5655,6 +5946,10 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
return X;
break;
case Intrinsic::ctpop: {
+ // ctpop(X) -> 1 iff X is a non-zero power of 2.
+ if (isKnownToBeAPowerOfTwo(Op0, Q.DL, /*OrZero*/ false, 0, Q.AC, Q.CxtI,
+ Q.DT))
+ return ConstantInt::get(Op0->getType(), 1);
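
For instance (assumed value): a left shift of 1 is a power of 2 in every defined execution, so its population count is exactly one bit.

  %p = shl i32 1, %n
  %r = call i32 @llvm.ctpop.i32(i32 %p)   ; folds to 1
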
// If everything but the lowest bit is zero, that bit is the pop-count. Ex:
// ctpop(and X, 1) --> and X, 1
unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
@@ -5695,27 +5990,9 @@ static Value *simplifyUnaryIntrinsic(Function *F, Value *Op0,
match(Op0, m_Intrinsic<Intrinsic::pow>(m_SpecificFP(10.0), m_Value(X))))
return X;
break;
- case Intrinsic::floor:
- case Intrinsic::trunc:
- case Intrinsic::ceil:
- case Intrinsic::round:
- case Intrinsic::roundeven:
- case Intrinsic::nearbyint:
- case Intrinsic::rint: {
- // floor (sitofp x) -> sitofp x
- // floor (uitofp x) -> uitofp x
- //
- // Converting from int always results in a finite integral number or
- // infinity. For either of those inputs, these rounding functions always
- // return the same value, so the rounding can be eliminated.
- if (match(Op0, m_SIToFP(m_Value())) || match(Op0, m_UIToFP(m_Value())))
- return Op0;
- break;
- }
case Intrinsic::experimental_vector_reverse:
// experimental.vector.reverse(experimental.vector.reverse(x)) -> x
- if (match(Op0,
- m_Intrinsic<Intrinsic::experimental_vector_reverse>(m_Value(X))))
+ if (match(Op0, m_VecReverse(m_Value(X))))
return X;
// experimental.vector.reverse(splat(X)) -> splat(X)
if (isSplatValue(Op0))
@@ -5789,8 +6066,8 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
if (Op0 == Op1)
return Op0;
- // Canonicalize constant operand as Op1.
- if (isa<Constant>(Op0))
+ // Canonicalize immediate constant operand as Op1.
+ if (match(Op0, m_ImmConstant()))
std::swap(Op0, Op1);
// Assume undef is the limit value.
@@ -5839,10 +6116,10 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
if (isICmpTrue(Pred, Op1, Op0, Q.getWithoutUndef(), RecursionLimit))
return Op1;
- if (Optional<bool> Imp =
+ if (std::optional<bool> Imp =
isImpliedByDomCondition(Pred, Op0, Op1, Q.CxtI, Q.DL))
return *Imp ? Op0 : Op1;
- if (Optional<bool> Imp =
+ if (std::optional<bool> Imp =
isImpliedByDomCondition(Pred, Op1, Op0, Q.CxtI, Q.DL))
return *Imp ? Op1 : Op0;
@@ -5883,7 +6160,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
// sat(X + MAX) -> MAX
if (match(Op0, m_AllOnes()) || match(Op1, m_AllOnes()))
return Constant::getAllOnesValue(ReturnType);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Intrinsic::sadd_sat:
// sat(X + undef) -> -1
// sat(undef + X) -> -1
@@ -5903,7 +6180,7 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
// sat(0 - X) -> 0, sat(X - MAX) -> 0
if (match(Op0, m_Zero()) || match(Op1, m_AllOnes()))
return Constant::getNullValue(ReturnType);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Intrinsic::ssub_sat:
// X - X -> 0, X - undef -> 0, undef - X -> 0
if (Op0 == Op1 || Q.isUndefValue(Op0) || Q.isUndefValue(Op1))
@@ -5937,6 +6214,20 @@ static Value *simplifyBinaryIntrinsic(Function *F, Value *Op0, Value *Op1,
match(Op1, m_FNeg(m_Specific(Op0))))
return Op1;
break;
+ case Intrinsic::is_fpclass: {
+ if (isa<PoisonValue>(Op0))
+ return PoisonValue::get(ReturnType);
+
+ uint64_t Mask = cast<ConstantInt>(Op1)->getZExtValue();
+ // If all tests are made, it doesn't matter what the value is.
+ if ((Mask & fcAllFlags) == fcAllFlags)
+ return ConstantInt::get(ReturnType, true);
+ if ((Mask & fcAllFlags) == 0)
+ return ConstantInt::get(ReturnType, false);
+ if (Q.isUndefValue(Op0))
+ return UndefValue::get(ReturnType);
+ break;
+ }
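
Sketches with assumed mask values (fcAllFlags covers all ten class bits, i.e. 0x3ff):

  %t = call i1 @llvm.is.fpclass.f32(float %x, i32 1023)  ; all classes -> true
  %f = call i1 @llvm.is.fpclass.f32(float %x, i32 0)     ; no classes  -> false
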
case Intrinsic::maxnum:
case Intrinsic::minnum:
case Intrinsic::maximum:
@@ -6034,7 +6325,7 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
if (!Attr.isValid())
return nullptr;
unsigned VScaleMin = Attr.getVScaleRangeMin();
- Optional<unsigned> VScaleMax = Attr.getVScaleRangeMax();
+ std::optional<unsigned> VScaleMax = Attr.getVScaleRangeMax();
if (VScaleMax && VScaleMin == VScaleMax)
return ConstantInt::get(F->getReturnType(), VScaleMin);
return nullptr;
@@ -6098,9 +6389,9 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
Value *Op1 = Call->getArgOperand(1);
Value *Op2 = Call->getArgOperand(2);
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
- if (Value *V = simplifyFPOp({Op0, Op1, Op2}, {}, Q,
- FPI->getExceptionBehavior().value(),
- FPI->getRoundingMode().value()))
+ if (Value *V =
+ simplifyFPOp({Op0, Op1, Op2}, {}, Q, *FPI->getExceptionBehavior(),
+ *FPI->getRoundingMode()))
return V;
return nullptr;
}
@@ -6166,31 +6457,31 @@ static Value *simplifyIntrinsic(CallBase *Call, const SimplifyQuery &Q) {
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
return simplifyFAddInst(
FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
- Q, FPI->getExceptionBehavior().value(), FPI->getRoundingMode().value());
+ Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
}
case Intrinsic::experimental_constrained_fsub: {
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
return simplifyFSubInst(
FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
- Q, FPI->getExceptionBehavior().value(), FPI->getRoundingMode().value());
+ Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
}
case Intrinsic::experimental_constrained_fmul: {
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
return simplifyFMulInst(
FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
- Q, FPI->getExceptionBehavior().value(), FPI->getRoundingMode().value());
+ Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
}
case Intrinsic::experimental_constrained_fdiv: {
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
return simplifyFDivInst(
FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
- Q, FPI->getExceptionBehavior().value(), FPI->getRoundingMode().value());
+ Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
}
case Intrinsic::experimental_constrained_frem: {
auto *FPI = cast<ConstrainedFPIntrinsic>(Call);
return simplifyFRemInst(
FPI->getArgOperand(0), FPI->getArgOperand(1), FPI->getFastMathFlags(),
- Q, FPI->getExceptionBehavior().value(), FPI->getRoundingMode().value());
+ Q, *FPI->getExceptionBehavior(), *FPI->getRoundingMode());
}
default:
return nullptr;
@@ -6295,7 +6586,6 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
const SimplifyQuery &SQ,
OptimizationRemarkEmitter *ORE) {
const SimplifyQuery Q = SQ.CxtI ? SQ : SQ.getWithInstruction(I);
- Value *Result = nullptr;
switch (I->getOpcode()) {
default:
@@ -6303,145 +6593,107 @@ static Value *simplifyInstructionWithOperands(Instruction *I,
SmallVector<Constant *, 8> NewConstOps(NewOps.size());
transform(NewOps, NewConstOps.begin(),
[](Value *V) { return cast<Constant>(V); });
- Result = ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI);
+ return ConstantFoldInstOperands(I, NewConstOps, Q.DL, Q.TLI);
}
- break;
+ return nullptr;
case Instruction::FNeg:
- Result = simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q);
- break;
+ return simplifyFNegInst(NewOps[0], I->getFastMathFlags(), Q);
case Instruction::FAdd:
- Result = simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFAddInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
case Instruction::Add:
- Result = simplifyAddInst(
- NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
- break;
+ return simplifyAddInst(NewOps[0], NewOps[1],
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
case Instruction::FSub:
- Result = simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFSubInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
case Instruction::Sub:
- Result = simplifySubInst(
- NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
- break;
+ return simplifySubInst(NewOps[0], NewOps[1],
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
case Instruction::FMul:
- Result = simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFMulInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
case Instruction::Mul:
- Result = simplifyMulInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyMulInst(NewOps[0], NewOps[1],
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
case Instruction::SDiv:
- Result = simplifySDivInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifySDivInst(NewOps[0], NewOps[1],
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
case Instruction::UDiv:
- Result = simplifyUDivInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyUDivInst(NewOps[0], NewOps[1],
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
case Instruction::FDiv:
- Result = simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFDivInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
case Instruction::SRem:
- Result = simplifySRemInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifySRemInst(NewOps[0], NewOps[1], Q);
case Instruction::URem:
- Result = simplifyURemInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyURemInst(NewOps[0], NewOps[1], Q);
case Instruction::FRem:
- Result = simplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFRemInst(NewOps[0], NewOps[1], I->getFastMathFlags(), Q);
case Instruction::Shl:
- Result = simplifyShlInst(
- NewOps[0], NewOps[1], Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
- Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
- break;
+ return simplifyShlInst(NewOps[0], NewOps[1],
+ Q.IIQ.hasNoSignedWrap(cast<BinaryOperator>(I)),
+ Q.IIQ.hasNoUnsignedWrap(cast<BinaryOperator>(I)), Q);
case Instruction::LShr:
- Result = simplifyLShrInst(NewOps[0], NewOps[1],
- Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
- break;
+ return simplifyLShrInst(NewOps[0], NewOps[1],
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
case Instruction::AShr:
- Result = simplifyAShrInst(NewOps[0], NewOps[1],
- Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
- break;
+ return simplifyAShrInst(NewOps[0], NewOps[1],
+ Q.IIQ.isExact(cast<BinaryOperator>(I)), Q);
case Instruction::And:
- Result = simplifyAndInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyAndInst(NewOps[0], NewOps[1], Q);
case Instruction::Or:
- Result = simplifyOrInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyOrInst(NewOps[0], NewOps[1], Q);
case Instruction::Xor:
- Result = simplifyXorInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyXorInst(NewOps[0], NewOps[1], Q);
case Instruction::ICmp:
- Result = simplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0],
- NewOps[1], Q);
- break;
+ return simplifyICmpInst(cast<ICmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1], Q);
case Instruction::FCmp:
- Result = simplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0],
- NewOps[1], I->getFastMathFlags(), Q);
- break;
+ return simplifyFCmpInst(cast<FCmpInst>(I)->getPredicate(), NewOps[0],
+ NewOps[1], I->getFastMathFlags(), Q);
case Instruction::Select:
- Result = simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q);
+ return simplifySelectInst(NewOps[0], NewOps[1], NewOps[2], Q);
-    break;
case Instruction::GetElementPtr: {
auto *GEPI = cast<GetElementPtrInst>(I);
- Result =
- simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0],
- makeArrayRef(NewOps).slice(1), GEPI->isInBounds(), Q);
- break;
+ return simplifyGEPInst(GEPI->getSourceElementType(), NewOps[0],
+ ArrayRef(NewOps).slice(1), GEPI->isInBounds(), Q);
}
case Instruction::InsertValue: {
InsertValueInst *IV = cast<InsertValueInst>(I);
- Result = simplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q);
- break;
- }
- case Instruction::InsertElement: {
- Result = simplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q);
- break;
+ return simplifyInsertValueInst(NewOps[0], NewOps[1], IV->getIndices(), Q);
}
+ case Instruction::InsertElement:
+ return simplifyInsertElementInst(NewOps[0], NewOps[1], NewOps[2], Q);
case Instruction::ExtractValue: {
auto *EVI = cast<ExtractValueInst>(I);
- Result = simplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q);
- break;
- }
- case Instruction::ExtractElement: {
- Result = simplifyExtractElementInst(NewOps[0], NewOps[1], Q);
- break;
+ return simplifyExtractValueInst(NewOps[0], EVI->getIndices(), Q);
}
+ case Instruction::ExtractElement:
+ return simplifyExtractElementInst(NewOps[0], NewOps[1], Q);
case Instruction::ShuffleVector: {
auto *SVI = cast<ShuffleVectorInst>(I);
- Result = simplifyShuffleVectorInst(
- NewOps[0], NewOps[1], SVI->getShuffleMask(), SVI->getType(), Q);
- break;
+ return simplifyShuffleVectorInst(NewOps[0], NewOps[1],
+ SVI->getShuffleMask(), SVI->getType(), Q);
}
case Instruction::PHI:
- Result = simplifyPHINode(cast<PHINode>(I), NewOps, Q);
- break;
- case Instruction::Call: {
+ return simplifyPHINode(cast<PHINode>(I), NewOps, Q);
+ case Instruction::Call:
// TODO: Use NewOps
- Result = simplifyCall(cast<CallInst>(I), Q);
- break;
- }
+ return simplifyCall(cast<CallInst>(I), Q);
case Instruction::Freeze:
- Result = llvm::simplifyFreezeInst(NewOps[0], Q);
- break;
+ return llvm::simplifyFreezeInst(NewOps[0], Q);
#define HANDLE_CAST_INST(num, opc, clas) case Instruction::opc:
#include "llvm/IR/Instruction.def"
#undef HANDLE_CAST_INST
- Result = simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q);
- break;
+ return simplifyCastInst(I->getOpcode(), NewOps[0], I->getType(), Q);
case Instruction::Alloca:
// No simplifications for Alloca and it can't be constant folded.
- Result = nullptr;
- break;
+ return nullptr;
case Instruction::Load:
- Result = simplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q);
- break;
+ return simplifyLoadInst(cast<LoadInst>(I), NewOps[0], Q);
}
-
- /// If called on unreachable code, the above logic may report that the
- /// instruction simplified to itself. Make life easier for users by
- /// detecting that case here, returning a safe value instead.
- return Result == I ? UndefValue::get(I->getType()) : Result;
}
Value *llvm::simplifyInstructionWithOperands(Instruction *I,
@@ -6456,7 +6708,12 @@ Value *llvm::simplifyInstructionWithOperands(Instruction *I,
Value *llvm::simplifyInstruction(Instruction *I, const SimplifyQuery &SQ,
OptimizationRemarkEmitter *ORE) {
SmallVector<Value *, 8> Ops(I->operands());
- return ::simplifyInstructionWithOperands(I, Ops, SQ, ORE);
+ Value *Result = ::simplifyInstructionWithOperands(I, Ops, SQ, ORE);
+
+ /// If called on unreachable code, the instruction may simplify to itself.
+ /// Make life easier for users by detecting that case here, and returning a
+ /// safe value instead.
+ return Result == I ? UndefValue::get(I->getType()) : Result;
}
/// Implementation of recursive simplification through an instruction's