summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r--lib/Transforms/InstCombine/InstCombineAddSub.cpp135
-rw-r--r--lib/Transforms/InstCombine/InstCombineAndOrXor.cpp1192
-rw-r--r--lib/Transforms/InstCombine/InstCombineCalls.cpp1148
-rw-r--r--lib/Transforms/InstCombine/InstCombineCasts.cpp169
-rw-r--r--lib/Transforms/InstCombine/InstCombineCompares.cpp212
-rw-r--r--lib/Transforms/InstCombine/InstCombineInternal.h61
-rw-r--r--lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp145
-rw-r--r--lib/Transforms/InstCombine/InstCombineMulDivRem.cpp58
-rw-r--r--lib/Transforms/InstCombine/InstCombinePHI.cpp6
-rw-r--r--lib/Transforms/InstCombine/InstCombineSelect.cpp86
-rw-r--r--lib/Transforms/InstCombine/InstCombineShifts.cpp703
-rw-r--r--lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp562
-rw-r--r--lib/Transforms/InstCombine/InstCombineVectorOps.cpp44
-rw-r--r--lib/Transforms/InstCombine/InstructionCombining.cpp285
14 files changed, 3029 insertions, 1777 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 2d34c1cc74bd..174ec8036274 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -902,7 +902,7 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS,
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
- // Addition of two 2's compliment numbers having opposite signs will never
+ // Addition of two 2's complement numbers having opposite signs will never
// overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1]))
@@ -939,7 +939,7 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS,
APInt RHSKnownOne(BitWidth, 0);
computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI);
- // Subtraction of two 2's compliment numbers having identical signs will
+ // Subtraction of two 2's complement numbers having identical signs will
// never overflow.
if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) ||
(LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1]))
@@ -1042,43 +1042,42 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
- const APInt *Val;
- if (match(RHS, m_APInt(Val))) {
- // X + (signbit) --> X ^ signbit
- if (Val->isSignBit())
+ const APInt *RHSC;
+ if (match(RHS, m_APInt(RHSC))) {
+ if (RHSC->isSignBit()) {
+ // If wrapping is not allowed, then the addition must set the sign bit:
+ // X + (signbit) --> X | signbit
+ if (I.hasNoSignedWrap() || I.hasNoUnsignedWrap())
+ return BinaryOperator::CreateOr(LHS, RHS);
+
+ // If wrapping is allowed, then the addition flips the sign bit of LHS:
+ // X + (signbit) --> X ^ signbit
return BinaryOperator::CreateXor(LHS, RHS);
+ }
// Is this add the last step in a convoluted sext?
Value *X;
const APInt *C;
if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) &&
C->isMinSignedValue() &&
- C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) {
+ C->sext(LHS->getType()->getScalarSizeInBits()) == *RHSC) {
// add(zext(xor i16 X, -32768), -32768) --> sext X
return CastInst::Create(Instruction::SExt, X, LHS->getType());
}
- if (Val->isNegative() &&
+ if (RHSC->isNegative() &&
match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) &&
- Val->sge(-C->sext(Val->getBitWidth()))) {
+ RHSC->sge(-C->sext(RHSC->getBitWidth()))) {
// (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val))
- return CastInst::Create(
- Instruction::ZExt,
- Builder->CreateNUWAdd(
- X, Constant::getIntegerValue(X->getType(),
- *C + Val->trunc(C->getBitWidth()))),
- I.getType());
+ Constant *NewC =
+ ConstantInt::get(X->getType(), *C + RHSC->trunc(C->getBitWidth()));
+ return new ZExtInst(Builder->CreateNUWAdd(X, NewC), I.getType());
}
}
// FIXME: Use the match above instead of dyn_cast to allow these transforms
// for splat vectors.
if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) {
- // See if SimplifyDemandedBits can simplify this. This handles stuff like
- // (X & 254)+1 -> (X&254)|1
- if (SimplifyDemandedInstructionBits(I))
- return &I;
-
// zext(bool) + C -> bool ? C + 1 : C
if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS))
if (ZI->getSrcTy()->isIntegerTy(1))
@@ -1129,8 +1128,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
}
}
- if (isa<Constant>(RHS) && isa<PHINode>(LHS))
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (isa<Constant>(RHS))
+ if (Instruction *NV = foldOpWithConstantIntoOperand(I))
return NV;
if (I.getType()->getScalarType()->isIntegerTy(1))
@@ -1201,11 +1200,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(NewAdd, C2);
}
}
-
- // Try to fold constant add into select arguments.
- if (SelectInst *SI = dyn_cast<SelectInst>(LHS))
- if (Instruction *R = FoldOpIntoSelect(I, SI))
- return R;
}
// add (select X 0 (sub n A)) A --> select X A n
@@ -1253,7 +1247,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// (add (sext x), (sext y)) --> (sext (add int x, y))
if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of sexts), and if the
// integer add will not overflow.
if (LHSConv->getOperand(0)->getType() ==
@@ -1290,7 +1284,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
// (add (zext x), (zext y)) --> (zext (add int x, y))
if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of zexts), and if the
// integer add will not overflow.
if (LHSConv->getOperand(0)->getType() ==
@@ -1311,13 +1305,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
{
Value *A = nullptr, *B = nullptr;
if (match(RHS, m_Xor(m_Value(A), m_Value(B))) &&
- (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(LHS, m_And(m_Specific(B), m_Specific(A)))))
+ match(LHS, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
if (match(LHS, m_Xor(m_Value(A), m_Value(B))) &&
- (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(RHS, m_And(m_Specific(B), m_Specific(A)))))
+ match(RHS, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
}
@@ -1325,8 +1317,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
{
Value *A = nullptr, *B = nullptr;
if (match(RHS, m_Or(m_Value(A), m_Value(B))) &&
- (match(LHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(LHS, m_And(m_Specific(B), m_Specific(A))))) {
+ match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) {
auto *New = BinaryOperator::CreateAdd(A, B);
New->setHasNoSignedWrap(I.hasNoSignedWrap());
New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
@@ -1334,8 +1325,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
}
if (match(LHS, m_Or(m_Value(A), m_Value(B))) &&
- (match(RHS, m_And(m_Specific(A), m_Specific(B))) ||
- match(RHS, m_And(m_Specific(B), m_Specific(A))))) {
+ match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) {
auto *New = BinaryOperator::CreateAdd(A, B);
New->setHasNoSignedWrap(I.hasNoSignedWrap());
New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
@@ -1394,6 +1384,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// Check for (fadd double (sitofp x), y), see if we can merge this into an
// integer add followed by a promotion.
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
+ Value *LHSIntVal = LHSConv->getOperand(0);
+
// (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst))
// ... if the constant fits in the integer value. This is useful for things
// like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer
@@ -1401,12 +1393,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// instcombined.
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) {
Constant *CI =
- ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType());
+ ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
if (LHSConv->hasOneUse() &&
ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
- WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) {
+ WillNotOverflowSignedAdd(LHSIntVal, CI, I)) {
// Insert the new integer add.
- Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
+ Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal,
CI, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
@@ -1414,17 +1406,17 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
// (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y))
if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) {
- // Only do this if x/y have the same type, if at last one of them has a
+ Value *RHSIntVal = RHSConv->getOperand(0);
+
+ // Only do this if x/y have the same type, if at least one of them has a
// single use (so we don't increase the number of int->fp conversions),
// and if the integer add will not overflow.
- if (LHSConv->getOperand(0)->getType() ==
- RHSConv->getOperand(0)->getType() &&
+ if (LHSIntVal->getType() == RHSIntVal->getType() &&
(LHSConv->hasOneUse() || RHSConv->hasOneUse()) &&
- WillNotOverflowSignedAdd(LHSConv->getOperand(0),
- RHSConv->getOperand(0), I)) {
+ WillNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) {
// Insert the new integer add.
- Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0),
- RHSConv->getOperand(0),"addconv");
+ Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal,
+ RHSIntVal, "addconv");
return new SIToFPInst(NewAdd, I.getType());
}
}
@@ -1562,7 +1554,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return Res;
}
- if (I.getType()->isIntegerTy(1))
+ if (I.getType()->getScalarType()->isIntegerTy(1))
return BinaryOperator::CreateXor(Op0, Op1);
// Replace (-1 - A) with (~A).
@@ -1580,14 +1572,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
+ // Try to fold constant sub into PHI values.
+ if (PHINode *PN = dyn_cast<PHINode>(Op1))
+ if (Instruction *R = foldOpIntoPhi(I, PN))
+ return R;
+
// C-(X+C2) --> (C-C2)-X
Constant *C2;
if (match(Op1, m_Add(m_Value(X), m_Constant(C2))))
return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
- if (SimplifyDemandedInstructionBits(I))
- return &I;
-
// Fold (sub 0, (zext bool to B)) --> (sext bool to B)
if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X))))
if (X->getType()->getScalarType()->isIntegerTy(1))
@@ -1622,11 +1616,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
// zero.
- if ((*Op0C + 1).isPowerOf2()) {
- APInt KnownZero(BitWidth, 0);
- APInt KnownOne(BitWidth, 0);
- computeKnownBits(&I, KnownZero, KnownOne, 0, &I);
- if ((*Op0C | KnownZero).isAllOnesValue())
+ if (Op0C->isMask()) {
+ APInt RHSKnownZero(BitWidth, 0);
+ APInt RHSKnownOne(BitWidth, 0);
+ computeKnownBits(Op1, RHSKnownZero, RHSKnownOne, 0, &I);
+ if ((*Op0C | RHSKnownZero).isAllOnesValue())
return BinaryOperator::CreateXor(Op1, Op0);
}
}
@@ -1634,8 +1628,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
{
Value *Y;
// X-(X+Y) == -Y X-(Y+X) == -Y
- if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) ||
- match(Op1, m_Add(m_Value(Y), m_Specific(Op0))))
+ if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y))))
return BinaryOperator::CreateNeg(Y);
// (X-Y)-X == -Y
@@ -1645,18 +1638,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// (sub (or A, B) (xor A, B)) --> (and A, B)
{
- Value *A = nullptr, *B = nullptr;
+ Value *A, *B;
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) &&
- (match(Op0, m_Or(m_Specific(A), m_Specific(B))) ||
- match(Op0, m_Or(m_Specific(B), m_Specific(A)))))
+ match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateAnd(A, B);
}
- if (Op0->hasOneUse()) {
- Value *Y = nullptr;
+ {
+ Value *Y;
// ((X | Y) - X) --> (~X & Y)
- if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) ||
- match(Op0, m_Or(m_Specific(Op1), m_Value(Y))))
+ if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1)))))
return BinaryOperator::CreateAnd(
Y, Builder->CreateNot(Op1, Op1->getName() + ".not"));
}
@@ -1664,7 +1655,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
if (Op1->hasOneUse()) {
Value *X = nullptr, *Y = nullptr, *Z = nullptr;
Constant *C = nullptr;
- Constant *CI = nullptr;
// (X - (Y - Z)) --> (X + (Z - Y)).
if (match(Op1, m_Sub(m_Value(Y), m_Value(Z))))
@@ -1673,8 +1663,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// (X - (X & Y)) --> (X & ~Y)
//
- if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) ||
- match(Op1, m_And(m_Specific(Op0), m_Value(Y))))
+ if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0))))
return BinaryOperator::CreateAnd(Op0,
Builder->CreateNot(Y, Y->getName() + ".not"));
@@ -1702,14 +1691,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
// X - A*-B -> X + A*B
// X - -A*B -> X + A*B
Value *A, *B;
- if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) ||
- match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B))))
+ Constant *CI;
+ if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B)))))
return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B));
// X - A*CI -> X + A*-CI
- // X - CI*A -> X + A*-CI
- if (match(Op1, m_Mul(m_Value(A), m_Constant(CI))) ||
- match(Op1, m_Mul(m_Constant(CI), m_Value(A)))) {
+ // No need to handle commuted multiply because multiply handling will
+ // ensure constant will be move to the right hand side.
+ if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) {
Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI));
return BinaryOperator::CreateAdd(Op0, NewMul);
}
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index da5384a86aac..b2a41c699202 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -137,9 +137,8 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) {
}
/// This handles expressions of the form ((val OP C1) & C2). Where
-/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is
-/// guaranteed to be a binary operator.
-Instruction *InstCombiner::OptAndOp(Instruction *Op,
+/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'.
+Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
ConstantInt *OpRHS,
ConstantInt *AndRHS,
BinaryOperator &TheAnd) {
@@ -149,6 +148,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,
Together = ConstantExpr::getAnd(AndRHS, OpRHS);
switch (Op->getOpcode()) {
+ default: break;
case Instruction::Xor:
if (Op->hasOneUse()) {
// (X ^ C1) & C2 --> (X & C2) ^ (C1&C2)
@@ -159,13 +159,6 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op,
break;
case Instruction::Or:
if (Op->hasOneUse()){
- if (Together != OpRHS) {
- // (X | C1) & C2 --> (X | (C1&C2)) & C2
- Value *Or = Builder->CreateOr(X, Together);
- Or->takeName(Op);
- return BinaryOperator::CreateAnd(Or, AndRHS);
- }
-
ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together);
if (TogetherCI && !TogetherCI->isZero()){
// (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1
@@ -302,178 +295,91 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo);
}
-/// Returns true iff Val consists of one contiguous run of 1s with any number
-/// of 0s on either side. The 1s are allowed to wrap from LSB to MSB,
-/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is
-/// not, since all 1s are not contiguous.
-static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) {
- const APInt& V = Val->getValue();
- uint32_t BitWidth = Val->getType()->getBitWidth();
- if (!APIntOps::isShiftedMask(BitWidth, V)) return false;
-
- // look for the first zero bit after the run of ones
- MB = BitWidth - ((V - 1) ^ V).countLeadingZeros();
- // look for the first non-zero bit
- ME = V.getActiveBits();
- return true;
-}
-
-/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines
-/// whether the operator is a sub. If we can fold one of the following xforms:
+/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns
+/// that can be simplified.
+/// One of A and B is considered the mask. The other is the value. This is
+/// described as the "AMask" or "BMask" part of the enum. If the enum contains
+/// only "Mask", then both A and B can be considered masks. If A is the mask,
+/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0.
+/// If both A and C are constants, this proof is also easy.
+/// For the following explanations, we assume that A is the mask.
///
-/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask
-/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
-/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0
+/// "AllOnes" declares that the comparison is true only if (A & B) == A or all
+/// bits of A are set in B.
+/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes
///
-/// return (A +/- B).
+/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all
+/// bits of A are cleared in B.
+/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeroes
+///
+/// "Mixed" declares that (A & B) == C and C might or might not contain any
+/// number of one bits and zero bits.
+/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed
+///
+/// "Not" means that in above descriptions "==" should be replaced by "!=".
+/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes
///
-Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS,
- ConstantInt *Mask, bool isSub,
- Instruction &I) {
- Instruction *LHSI = dyn_cast<Instruction>(LHS);
- if (!LHSI || LHSI->getNumOperands() != 2 ||
- !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr;
-
- ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1));
-
- switch (LHSI->getOpcode()) {
- default: return nullptr;
- case Instruction::And:
- if (ConstantExpr::getAnd(N, Mask) == Mask) {
- // If the AndRHS is a power of two minus one (0+1+), this is simple.
- if ((Mask->getValue().countLeadingZeros() +
- Mask->getValue().countPopulation()) ==
- Mask->getValue().getBitWidth())
- break;
-
- // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+
- // part, we don't need any explicit masks to take them out of A. If that
- // is all N is, ignore it.
- uint32_t MB = 0, ME = 0;
- if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive
- uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth();
- APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1));
- if (MaskedValueIsZero(RHS, Mask, 0, &I))
- break;
- }
- }
- return nullptr;
- case Instruction::Or:
- case Instruction::Xor:
- // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0
- if ((Mask->getValue().countLeadingZeros() +
- Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth()
- && ConstantExpr::getAnd(N, Mask)->isNullValue())
- break;
- return nullptr;
- }
-
- if (isSub)
- return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold");
- return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold");
-}
-
-/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C)
-/// One of A and B is considered the mask, the other the value. This is
-/// described as the "AMask" or "BMask" part of the enum. If the enum
-/// contains only "Mask", then both A and B can be considered masks.
-/// If A is the mask, then it was proven, that (A & C) == C. This
-/// is trivial if C == A, or C == 0. If both A and C are constants, this
-/// proof is also easy.
-/// For the following explanations we assume that A is the mask.
-/// The part "AllOnes" declares, that the comparison is true only
-/// if (A & B) == A, or all bits of A are set in B.
-/// Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes
-/// The part "AllZeroes" declares, that the comparison is true only
-/// if (A & B) == 0, or all bits of A are cleared in B.
-/// Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes
-/// The part "Mixed" declares, that (A & B) == C and C might or might not
-/// contain any number of one bits and zero bits.
-/// Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed
-/// The Part "Not" means, that in above descriptions "==" should be replaced
-/// by "!=".
-/// Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes
/// If the mask A contains a single bit, then the following is equivalent:
/// (icmp eq (A & B), A) equals (icmp ne (A & B), 0)
/// (icmp ne (A & B), A) equals (icmp eq (A & B), 0)
enum MaskedICmpType {
- FoldMskICmp_AMask_AllOnes = 1,
- FoldMskICmp_AMask_NotAllOnes = 2,
- FoldMskICmp_BMask_AllOnes = 4,
- FoldMskICmp_BMask_NotAllOnes = 8,
- FoldMskICmp_Mask_AllZeroes = 16,
- FoldMskICmp_Mask_NotAllZeroes = 32,
- FoldMskICmp_AMask_Mixed = 64,
- FoldMskICmp_AMask_NotMixed = 128,
- FoldMskICmp_BMask_Mixed = 256,
- FoldMskICmp_BMask_NotMixed = 512
+ AMask_AllOnes = 1,
+ AMask_NotAllOnes = 2,
+ BMask_AllOnes = 4,
+ BMask_NotAllOnes = 8,
+ Mask_AllZeros = 16,
+ Mask_NotAllZeros = 32,
+ AMask_Mixed = 64,
+ AMask_NotMixed = 128,
+ BMask_Mixed = 256,
+ BMask_NotMixed = 512
};
-/// Return the set of pattern classes (from MaskedICmpType)
-/// that (icmp SCC (A & B), C) satisfies.
-static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
- ICmpInst::Predicate SCC)
-{
+/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C)
+/// satisfies.
+static unsigned getMaskedICmpType(Value *A, Value *B, Value *C,
+ ICmpInst::Predicate Pred) {
ConstantInt *ACst = dyn_cast<ConstantInt>(A);
ConstantInt *BCst = dyn_cast<ConstantInt>(B);
ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- bool icmp_eq = (SCC == ICmpInst::ICMP_EQ);
- bool icmp_abit = (ACst && !ACst->isZero() &&
- ACst->getValue().isPowerOf2());
- bool icmp_bbit = (BCst && !BCst->isZero() &&
- BCst->getValue().isPowerOf2());
- unsigned result = 0;
+ bool IsEq = (Pred == ICmpInst::ICMP_EQ);
+ bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2());
+ bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2());
+ unsigned MaskVal = 0;
if (CCst && CCst->isZero()) {
// if C is zero, then both A and B qualify as mask
- result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_AMask_Mixed |
- FoldMskICmp_BMask_Mixed)
- : (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_AMask_NotMixed |
- FoldMskICmp_BMask_NotMixed));
- if (icmp_abit)
- result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes |
- FoldMskICmp_AMask_NotMixed)
- : (FoldMskICmp_AMask_AllOnes |
- FoldMskICmp_AMask_Mixed));
- if (icmp_bbit)
- result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_BMask_NotMixed)
- : (FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_BMask_Mixed));
- return result;
+ MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed)
+ : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed));
+ if (IsAPow2)
+ MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed)
+ : (AMask_AllOnes | AMask_Mixed));
+ if (IsBPow2)
+ MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed)
+ : (BMask_AllOnes | BMask_Mixed));
+ return MaskVal;
}
+
if (A == C) {
- result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes |
- FoldMskICmp_AMask_Mixed)
- : (FoldMskICmp_AMask_NotAllOnes |
- FoldMskICmp_AMask_NotMixed));
- if (icmp_abit)
- result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_AMask_NotMixed)
- : (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_AMask_Mixed));
- } else if (ACst && CCst &&
- ConstantExpr::getAnd(ACst, CCst) == CCst) {
- result |= (icmp_eq ? FoldMskICmp_AMask_Mixed
- : FoldMskICmp_AMask_NotMixed);
+ MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed)
+ : (AMask_NotAllOnes | AMask_NotMixed));
+ if (IsAPow2)
+ MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed)
+ : (Mask_AllZeros | AMask_Mixed));
+ } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) {
+ MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed);
}
+
if (B == C) {
- result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_BMask_Mixed)
- : (FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_BMask_NotMixed));
- if (icmp_bbit)
- result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes |
- FoldMskICmp_BMask_NotMixed)
- : (FoldMskICmp_Mask_AllZeroes |
- FoldMskICmp_BMask_Mixed));
- } else if (BCst && CCst &&
- ConstantExpr::getAnd(BCst, CCst) == CCst) {
- result |= (icmp_eq ? FoldMskICmp_BMask_Mixed
- : FoldMskICmp_BMask_NotMixed);
- }
- return result;
+ MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed)
+ : (BMask_NotAllOnes | BMask_NotMixed));
+ if (IsBPow2)
+ MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed)
+ : (Mask_AllZeros | BMask_Mixed));
+ } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) {
+ MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed);
+ }
+
+ return MaskVal;
}
/// Convert an analysis of a masked ICmp into its equivalent if all boolean
@@ -482,32 +388,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C,
/// involves swapping those bits over.
static unsigned conjugateICmpMask(unsigned Mask) {
unsigned NewMask;
- NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes |
- FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed |
- FoldMskICmp_BMask_Mixed))
+ NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros |
+ AMask_Mixed | BMask_Mixed))
<< 1;
- NewMask |=
- (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes |
- FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed |
- FoldMskICmp_BMask_NotMixed))
- >> 1;
+ NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros |
+ AMask_NotMixed | BMask_NotMixed))
+ >> 1;
return NewMask;
}
-/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E)
-/// Return the set of pattern classes (from MaskedICmpType)
-/// that both LHS and RHS satisfy.
-static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
- Value*& B, Value*& C,
- Value*& D, Value*& E,
- ICmpInst *LHS, ICmpInst *RHS,
- ICmpInst::Predicate &LHSCC,
- ICmpInst::Predicate &RHSCC) {
- if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0;
+/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E).
+/// Return the set of pattern classes (from MaskedICmpType) that both LHS and
+/// RHS satisfy.
+static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
+ Value *&D, Value *&E, ICmpInst *LHS,
+ ICmpInst *RHS,
+ ICmpInst::Predicate &PredL,
+ ICmpInst::Predicate &PredR) {
+ if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType())
+ return 0;
// vectors are not (yet?) supported
- if (LHS->getOperand(0)->getType()->isVectorTy()) return 0;
+ if (LHS->getOperand(0)->getType()->isVectorTy())
+ return 0;
// Here comes the tricky part:
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -517,9 +421,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
// above.
Value *L1 = LHS->getOperand(0);
Value *L2 = LHS->getOperand(1);
- Value *L11,*L12,*L21,*L22;
+ Value *L11, *L12, *L21, *L22;
// Check whether the icmp can be decomposed into a bit test.
- if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) {
+ if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) {
L21 = L22 = L1 = nullptr;
} else {
// Look for ANDs in the LHS icmp.
@@ -543,22 +447,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
}
// Bail if LHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(LHSCC))
+ if (!ICmpInst::isEquality(PredL))
return 0;
Value *R1 = RHS->getOperand(0);
Value *R2 = RHS->getOperand(1);
- Value *R11,*R12;
- bool ok = false;
- if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) {
+ Value *R11, *R12;
+ bool Ok = false;
+ if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) {
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12;
+ A = R11;
+ D = R12;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11;
+ A = R12;
+ D = R11;
} else {
return 0;
}
- E = R2; R1 = nullptr; ok = true;
+ E = R2;
+ R1 = nullptr;
+ Ok = true;
} else if (R1->getType()->isIntegerTy()) {
if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) {
// As before, model no mask as a trivial mask if it'll let us do an
@@ -568,46 +476,62 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
}
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12; E = R2; ok = true;
+ A = R11;
+ D = R12;
+ E = R2;
+ Ok = true;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11; E = R2; ok = true;
+ A = R12;
+ D = R11;
+ E = R2;
+ Ok = true;
}
}
// Bail if RHS was a icmp that can't be decomposed into an equality.
- if (!ICmpInst::isEquality(RHSCC))
+ if (!ICmpInst::isEquality(PredR))
return 0;
// Look for ANDs on the right side of the RHS icmp.
- if (!ok && R2->getType()->isIntegerTy()) {
+ if (!Ok && R2->getType()->isIntegerTy()) {
if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) {
R11 = R2;
R12 = Constant::getAllOnesValue(R2->getType());
}
if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) {
- A = R11; D = R12; E = R1; ok = true;
+ A = R11;
+ D = R12;
+ E = R1;
+ Ok = true;
} else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) {
- A = R12; D = R11; E = R1; ok = true;
+ A = R12;
+ D = R11;
+ E = R1;
+ Ok = true;
} else {
return 0;
}
}
- if (!ok)
+ if (!Ok)
return 0;
if (L11 == A) {
- B = L12; C = L2;
+ B = L12;
+ C = L2;
} else if (L12 == A) {
- B = L11; C = L2;
+ B = L11;
+ C = L2;
} else if (L21 == A) {
- B = L22; C = L1;
+ B = L22;
+ C = L1;
} else if (L22 == A) {
- B = L21; C = L1;
+ B = L21;
+ C = L1;
}
- unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC);
- unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC);
+ unsigned LeftType = getMaskedICmpType(A, B, C, PredL);
+ unsigned RightType = getMaskedICmpType(A, D, E, PredR);
return LeftType & RightType;
}
@@ -616,12 +540,14 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A,
static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
llvm::InstCombiner::BuilderTy *Builder) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
- unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS,
- LHSCC, RHSCC);
- if (Mask == 0) return nullptr;
- assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) &&
- "foldLogOpOfMaskedICmpsHelper must return an equality predicate.");
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+ unsigned Mask =
+ getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
+ if (Mask == 0)
+ return nullptr;
+
+ assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) &&
+ "Expected equality predicates for masked type of icmps.");
// In full generality:
// (icmp (A & B) Op C) | (icmp (A & D) Op E)
@@ -642,7 +568,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
Mask = conjugateICmpMask(Mask);
}
- if (Mask & FoldMskICmp_Mask_AllZeroes) {
+ if (Mask & Mask_AllZeros) {
// (icmp eq (A & B), 0) & (icmp eq (A & D), 0)
// -> (icmp eq (A & (B|D)), 0)
Value *NewOr = Builder->CreateOr(B, D);
@@ -653,14 +579,14 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
Value *Zero = Constant::getNullValue(A->getType());
return Builder->CreateICmp(NewCC, NewAnd, Zero);
}
- if (Mask & FoldMskICmp_BMask_AllOnes) {
+ if (Mask & BMask_AllOnes) {
// (icmp eq (A & B), B) & (icmp eq (A & D), D)
// -> (icmp eq (A & (B|D)), (B|D))
Value *NewOr = Builder->CreateOr(B, D);
Value *NewAnd = Builder->CreateAnd(A, NewOr);
return Builder->CreateICmp(NewCC, NewAnd, NewOr);
}
- if (Mask & FoldMskICmp_AMask_AllOnes) {
+ if (Mask & AMask_AllOnes) {
// (icmp eq (A & B), A) & (icmp eq (A & D), A)
// -> (icmp eq (A & (B&D)), A)
Value *NewAnd1 = Builder->CreateAnd(B, D);
@@ -672,11 +598,13 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// their actual values. This isn't strictly necessary, just a "handle the
// easy cases for now" decision.
ConstantInt *BCst = dyn_cast<ConstantInt>(B);
- if (!BCst) return nullptr;
+ if (!BCst)
+ return nullptr;
ConstantInt *DCst = dyn_cast<ConstantInt>(D);
- if (!DCst) return nullptr;
+ if (!DCst)
+ return nullptr;
- if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) {
+ if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) {
// (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and
// (icmp ne (A & B), B) & (icmp ne (A & D), D)
// -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0)
@@ -689,7 +617,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
else if (NewMask == DCst->getValue())
return RHS;
}
- if (Mask & FoldMskICmp_AMask_NotAllOnes) {
+
+ if (Mask & AMask_NotAllOnes) {
// (icmp ne (A & B), B) & (icmp ne (A & D), D)
// -> (icmp ne (A & B), A) or (icmp ne (A & D), A)
// Only valid if one of the masks is a superset of the other (check "B|D" is
@@ -701,7 +630,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
else if (NewMask == DCst->getValue())
return RHS;
}
- if (Mask & FoldMskICmp_BMask_Mixed) {
+
+ if (Mask & BMask_Mixed) {
// (icmp eq (A & B), C) & (icmp eq (A & D), E)
// We already know that B & C == C && D & E == E.
// If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of
@@ -713,23 +643,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
// (icmp ne (A & B), B) & (icmp eq (A & D), D)
// with B and D, having a single bit set.
ConstantInt *CCst = dyn_cast<ConstantInt>(C);
- if (!CCst) return nullptr;
+ if (!CCst)
+ return nullptr;
ConstantInt *ECst = dyn_cast<ConstantInt>(E);
- if (!ECst) return nullptr;
- if (LHSCC != NewCC)
+ if (!ECst)
+ return nullptr;
+ if (PredL != NewCC)
CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst));
- if (RHSCC != NewCC)
+ if (PredR != NewCC)
ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst));
+
// If there is a conflict, we should actually return a false for the
// whole construct.
if (((BCst->getValue() & DCst->getValue()) &
(CCst->getValue() ^ ECst->getValue())) != 0)
return ConstantInt::get(LHS->getType(), !IsAnd);
+
Value *NewOr1 = Builder->CreateOr(B, D);
Value *NewOr2 = ConstantExpr::getOr(CCst, ECst);
Value *NewAnd = Builder->CreateAnd(A, NewOr1);
return Builder->CreateICmp(NewCC, NewAnd, NewOr2);
}
+
return nullptr;
}
@@ -789,12 +724,67 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1,
return Builder->CreateICmp(NewPred, Input, RangeEnd);
}
+static Value *
+foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS,
+ bool JoinedByAnd,
+ InstCombiner::BuilderTy *Builder) {
+ Value *X = LHS->getOperand(0);
+ if (X != RHS->getOperand(0))
+ return nullptr;
+
+ const APInt *C1, *C2;
+ if (!match(LHS->getOperand(1), m_APInt(C1)) ||
+ !match(RHS->getOperand(1), m_APInt(C2)))
+ return nullptr;
+
+ // We only handle (X != C1 && X != C2) and (X == C1 || X == C2).
+ ICmpInst::Predicate Pred = LHS->getPredicate();
+ if (Pred != RHS->getPredicate())
+ return nullptr;
+ if (JoinedByAnd && Pred != ICmpInst::ICMP_NE)
+ return nullptr;
+ if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ)
+ return nullptr;
+
+ // The larger unsigned constant goes on the right.
+ if (C1->ugt(*C2))
+ std::swap(C1, C2);
+
+ APInt Xor = *C1 ^ *C2;
+ if (Xor.isPowerOf2()) {
+ // If LHSC and RHSC differ by only one bit, then set that bit in X and
+ // compare against the larger constant:
+ // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2
+ // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2
+ // We choose an 'or' with a Pow2 constant rather than the inverse mask with
+ // 'and' because that may lead to smaller codegen from a smaller constant.
+ Value *Or = Builder->CreateOr(X, ConstantInt::get(X->getType(), Xor));
+ return Builder->CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2));
+ }
+
+ // Special case: get the ordering right when the values wrap around zero.
+ // Ie, we assumed the constants were unsigned when swapping earlier.
+ if (*C1 == 0 && C2->isAllOnesValue())
+ std::swap(C1, C2);
+
+ if (*C1 == *C2 - 1) {
+ // (X == 13 || X == 14) --> X - 13 <=u 1
+ // (X != 13 && X != 14) --> X - 13 >u 1
+ // An 'add' is the canonical IR form, so favor that over a 'sub'.
+ Value *Add = Builder->CreateAdd(X, ConstantInt::get(X->getType(), -(*C1)));
+ auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE;
+ return Builder->CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1));
+ }
+
+ return nullptr;
+}
+
/// Fold (icmp)&(icmp) if possible.
Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
// (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B)
- if (PredicatesFoldable(LHSCC, RHSCC)) {
+ if (PredicatesFoldable(PredL, PredR)) {
if (LHS->getOperand(0) == RHS->getOperand(1) &&
LHS->getOperand(1) == RHS->getOperand(0))
LHS->swapOperands();
@@ -819,86 +809,90 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false))
return V;
+ if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder))
+ return V;
+
// This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
- ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (!LHSCst || !RHSCst) return nullptr;
+ Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ if (!LHSC || !RHSC)
+ return nullptr;
- if (LHSCst == RHSCst && LHSCC == RHSCC) {
+ if (LHSC == RHSC && PredL == PredR) {
// (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C)
// where C is a power of 2 or
// (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0)
- if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) ||
- (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) {
- Value *NewOr = Builder->CreateOr(Val, Val2);
- return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
+ if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) ||
+ (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) {
+ Value *NewOr = Builder->CreateOr(LHS0, RHS0);
+ return Builder->CreateICmp(PredL, NewOr, LHSC);
}
}
// (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2
// where CMAX is the all ones value for the truncated type,
// iff the lower bits of C2 and CA are zero.
- if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC &&
- LHS->hasOneUse() && RHS->hasOneUse()) {
+ if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() &&
+ RHS->hasOneUse()) {
Value *V;
- ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr;
+ ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr;
// (trunc x) == C1 & (and x, CA) == C2
// (and x, CA) == C2 & (trunc x) == C1
- if (match(Val2, m_Trunc(m_Value(V))) &&
- match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) {
- SmallCst = RHSCst;
- BigCst = LHSCst;
- } else if (match(Val, m_Trunc(m_Value(V))) &&
- match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) {
- SmallCst = LHSCst;
- BigCst = RHSCst;
+ if (match(RHS0, m_Trunc(m_Value(V))) &&
+ match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) {
+ SmallC = RHSC;
+ BigC = LHSC;
+ } else if (match(LHS0, m_Trunc(m_Value(V))) &&
+ match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) {
+ SmallC = LHSC;
+ BigC = RHSC;
}
- if (SmallCst && BigCst) {
- unsigned BigBitSize = BigCst->getType()->getBitWidth();
- unsigned SmallBitSize = SmallCst->getType()->getBitWidth();
+ if (SmallC && BigC) {
+ unsigned BigBitSize = BigC->getType()->getBitWidth();
+ unsigned SmallBitSize = SmallC->getType()->getBitWidth();
// Check that the low bits are zero.
APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize);
- if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) {
- Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue());
- APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue();
- Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N);
- return Builder->CreateICmp(LHSCC, NewAnd, NewVal);
+ if ((Low & AndC->getValue()) == 0 && (Low & BigC->getValue()) == 0) {
+ Value *NewAnd = Builder->CreateAnd(V, Low | AndC->getValue());
+ APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue();
+ Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N);
+ return Builder->CreateICmp(PredL, NewAnd, NewVal);
}
}
}
// From here on, we only handle:
// (icmp1 A, C1) & (icmp2 A, C2) --> something simpler.
- if (Val != Val2) return nullptr;
+ if (LHS0 != RHS0)
+ return nullptr;
- // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere.
- if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
- RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
- LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
- RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere.
+ if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE ||
+ PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE ||
+ PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE ||
+ PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)
return nullptr;
// We can't fold (ugt x, C) & (sgt x, C2).
- if (!PredicatesFoldable(LHSCC, RHSCC))
+ if (!PredicatesFoldable(PredL, PredR))
return nullptr;
// Ensure that the larger constant is on the RHS.
bool ShouldSwap;
- if (CmpInst::isSigned(LHSCC) ||
- (ICmpInst::isEquality(LHSCC) &&
- CmpInst::isSigned(RHSCC)))
- ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ if (CmpInst::isSigned(PredL) ||
+ (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR)))
+ ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());
else
- ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+ ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());
if (ShouldSwap) {
std::swap(LHS, RHS);
- std::swap(LHSCst, RHSCst);
- std::swap(LHSCC, RHSCC);
+ std::swap(LHSC, RHSC);
+ std::swap(PredL, PredR);
}
// At this point, we know we have two icmp instructions
@@ -907,113 +901,95 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
// icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know
// (from the icmp folding check above), that the two constants
// are not equal and that the larger constant is on the RHS
- assert(LHSCst != RHSCst && "Compares not folded above?");
+ assert(LHSC != RHSC && "Compares not folded above?");
- switch (LHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredL) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13
- case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13
- case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13
+ case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13
+ case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13
return LHS;
}
case ICmpInst::ICMP_NE:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_ULT:
- if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13
- return Builder->CreateICmpULT(Val, LHSCst);
- if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
+ if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13
+ return Builder->CreateICmpULT(LHS0, LHSC);
+ if (LHSC->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
- break; // (X != 13 & X u< 15) -> no change
+ break; // (X != 13 & X u< 15) -> no change
case ICmpInst::ICMP_SLT:
- if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13
- return Builder->CreateICmpSLT(Val, LHSCst);
- break; // (X != 13 & X s< 15) -> no change
- case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15
- case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15
+ if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13
+ return Builder->CreateICmpSLT(LHS0, LHSC);
+ break; // (X != 13 & X s< 15) -> no change
+ case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15
+ case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15
return RHS;
case ICmpInst::ICMP_NE:
- // Special case to get the ordering right when the values wrap around
- // zero.
- if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue())
- std::swap(LHSCst, RHSCst);
- if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1
- Constant *AddCST = ConstantExpr::getNeg(LHSCst);
- Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off");
- return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1),
- Val->getName()+".cmp");
- }
- break; // (X != 13 & X != 15) -> no change
+ // Potential folds for this case should already be handled.
+ break;
}
break;
case ICmpInst::ICMP_ULT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false
- case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false
+ case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false
return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0);
- case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13
- case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13
+ case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13
+ case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13
return LHS;
- case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SLT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13
- case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13
+ case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13
return LHS;
- case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_UGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15
return RHS;
- case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change
- break;
case ICmpInst::ICMP_NE:
- if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14
- return Builder->CreateICmp(LHSCC, Val, RHSCst);
- break; // (X u> 13 & X != 15) -> no change
- case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
+ if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14
+ return Builder->CreateICmp(PredL, LHS0, RHSC);
+ break; // (X u> 13 & X != 15) -> no change
+ case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
false, true);
- case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15
- case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15
+ case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15
return RHS;
- case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change
- break;
case ICmpInst::ICMP_NE:
- if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14
- return Builder->CreateICmp(LHSCC, Val, RHSCst);
- break; // (X s> 13 & X != 15) -> no change
- case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
- return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(),
- true, true);
- case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change
- break;
+ if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14
+ return Builder->CreateICmp(PredL, LHS0, RHSC);
+ break; // (X s> 13 & X != 15) -> no change
+ case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
+ return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true,
+ true);
}
break;
}
@@ -1314,39 +1290,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
break;
}
- case Instruction::Add:
- // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS.
- // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0
- // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0
- if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I))
- return BinaryOperator::CreateAnd(V, AndRHS);
- if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I))
- return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes
- break;
-
case Instruction::Sub:
- // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS.
- // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0
- // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0
- if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I))
- return BinaryOperator::CreateAnd(V, AndRHS);
-
// -x & 1 -> x & 1
if (AndRHSMask == 1 && match(Op0LHS, m_Zero()))
return BinaryOperator::CreateAnd(Op0RHS, AndRHS);
- // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS
- // has 1's for all bits that the subtraction with A might affect.
- if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) {
- uint32_t BitWidth = AndRHSMask.getBitWidth();
- uint32_t Zeros = AndRHSMask.countLeadingZeros();
- APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros);
-
- if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) {
- Value *NewNeg = Builder->CreateNeg(Op0RHS);
- return BinaryOperator::CreateAnd(NewNeg, AndRHS);
- }
- }
break;
case Instruction::Shl:
@@ -1361,6 +1309,33 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
break;
}
+ // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth
+ // of X and OP behaves well when given trunc(C1) and X.
+ switch (Op0I->getOpcode()) {
+ default:
+ break;
+ case Instruction::Xor:
+ case Instruction::Or:
+ case Instruction::Mul:
+ case Instruction::Add:
+ case Instruction::Sub:
+ Value *X;
+ ConstantInt *C1;
+ if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) {
+ if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) {
+ auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType());
+ Value *BinOp;
+ if (isa<ZExtInst>(Op0LHS))
+ BinOp = Builder->CreateBinOp(Op0I->getOpcode(), X, TruncC1);
+ else
+ BinOp = Builder->CreateBinOp(Op0I->getOpcode(), TruncC1, X);
+ auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType());
+ auto *And = Builder->CreateAnd(BinOp, TruncC2);
+ return new ZExtInst(And, I.getType());
+ }
+ }
+ }
+
if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1)))
if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I))
return Res;
@@ -1381,10 +1356,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(NewCast, C3);
}
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
return DeMorgan;
@@ -1630,15 +1606,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D,
/// Fold (icmp)|(icmp) if possible.
Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
Instruction *CxtI) {
- ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate();
+ ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
// Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2)
// if K1 and K2 are a one-bit mask.
- ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1));
- ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1));
+ ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
+ ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1));
- if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() &&
- RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) {
+ if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() &&
+ RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0));
BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0));
@@ -1680,52 +1656,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask.
// This implies all values in the two ranges differ by exactly one bit.
- if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) &&
- LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() &&
- RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() &&
- LHSCst->getValue() == (RHSCst->getValue())) {
+ if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) &&
+ PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() &&
+ LHSC->getType() == RHSC->getType() &&
+ LHSC->getValue() == (RHSC->getValue())) {
Value *LAdd = LHS->getOperand(0);
Value *RAdd = RHS->getOperand(0);
Value *LAddOpnd, *RAddOpnd;
- ConstantInt *LAddCst, *RAddCst;
- if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) &&
- match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) &&
- LAddCst->getValue().ugt(LHSCst->getValue()) &&
- RAddCst->getValue().ugt(LHSCst->getValue())) {
-
- APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue();
- if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) {
- ConstantInt *MaxAddCst = nullptr;
- if (LAddCst->getValue().ult(RAddCst->getValue()))
- MaxAddCst = RAddCst;
+ ConstantInt *LAddC, *RAddC;
+ if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) &&
+ match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) &&
+ LAddC->getValue().ugt(LHSC->getValue()) &&
+ RAddC->getValue().ugt(LHSC->getValue())) {
+
+ APInt DiffC = LAddC->getValue() ^ RAddC->getValue();
+ if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) {
+ ConstantInt *MaxAddC = nullptr;
+ if (LAddC->getValue().ult(RAddC->getValue()))
+ MaxAddC = RAddC;
else
- MaxAddCst = LAddCst;
+ MaxAddC = LAddC;
- APInt RRangeLow = -RAddCst->getValue();
- APInt RRangeHigh = RRangeLow + LHSCst->getValue();
- APInt LRangeLow = -LAddCst->getValue();
- APInt LRangeHigh = LRangeLow + LHSCst->getValue();
+ APInt RRangeLow = -RAddC->getValue();
+ APInt RRangeHigh = RRangeLow + LHSC->getValue();
+ APInt LRangeLow = -LAddC->getValue();
+ APInt LRangeHigh = LRangeLow + LHSC->getValue();
APInt LowRangeDiff = RRangeLow ^ LRangeLow;
APInt HighRangeDiff = RRangeHigh ^ LRangeHigh;
APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow
: RRangeLow - LRangeLow;
if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff &&
- RangeDiff.ugt(LHSCst->getValue())) {
- Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst);
+ RangeDiff.ugt(LHSC->getValue())) {
+ Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC);
- Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst);
- Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst);
- return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst));
+ Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskC);
+ Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddC);
+ return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSC));
}
}
}
}
// (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B)
- if (PredicatesFoldable(LHSCC, RHSCC)) {
+ if (PredicatesFoldable(PredL, PredR)) {
if (LHS->getOperand(0) == RHS->getOperand(1) &&
LHS->getOperand(1) == RHS->getOperand(0))
LHS->swapOperands();
@@ -1743,25 +1719,25 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder))
return V;
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0);
+ Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
if (LHS->hasOneUse() || RHS->hasOneUse()) {
// (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1)
// (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1)
Value *A = nullptr, *B = nullptr;
- if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) {
- B = Val;
- if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1))
- A = Val2;
- else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2)
+ if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) {
+ B = LHS0;
+ if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1))
+ A = RHS0;
+ else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0)
A = RHS->getOperand(1);
}
// (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1)
// (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1)
- else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) {
- B = Val2;
- if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1))
- A = Val;
- else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val)
+ else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) {
+ B = RHS0;
+ if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1))
+ A = LHS0;
+ else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0)
A = LHS->getOperand(1);
}
if (A && B)
@@ -1778,54 +1754,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true))
return V;
+ if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder))
+ return V;
+
// This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
- if (!LHSCst || !RHSCst) return nullptr;
+ if (!LHSC || !RHSC)
+ return nullptr;
- if (LHSCst == RHSCst && LHSCC == RHSCC) {
+ if (LHSC == RHSC && PredL == PredR) {
// (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0)
- if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) {
- Value *NewOr = Builder->CreateOr(Val, Val2);
- return Builder->CreateICmp(LHSCC, NewOr, LHSCst);
+ if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) {
+ Value *NewOr = Builder->CreateOr(LHS0, RHS0);
+ return Builder->CreateICmp(PredL, NewOr, LHSC);
}
}
// (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1)
// iff C2 + CA == C1.
- if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) {
- ConstantInt *AddCst;
- if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst))))
- if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue())
- return Builder->CreateICmpULE(Val, LHSCst);
+ if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) {
+ ConstantInt *AddC;
+ if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC))))
+ if (RHSC->getValue() + AddC->getValue() == LHSC->getValue())
+ return Builder->CreateICmpULE(LHS0, LHSC);
}
// From here on, we only handle:
// (icmp1 A, C1) | (icmp2 A, C2) --> something simpler.
- if (Val != Val2) return nullptr;
+ if (LHS0 != RHS0)
+ return nullptr;
- // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere.
- if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE ||
- RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE ||
- LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE ||
- RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE)
+ // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere.
+ if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE ||
+ PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE ||
+ PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE ||
+ PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE)
return nullptr;
// We can't fold (ugt x, C) | (sgt x, C2).
- if (!PredicatesFoldable(LHSCC, RHSCC))
+ if (!PredicatesFoldable(PredL, PredR))
return nullptr;
// Ensure that the larger constant is on the RHS.
bool ShouldSwap;
- if (CmpInst::isSigned(LHSCC) ||
- (ICmpInst::isEquality(LHSCC) &&
- CmpInst::isSigned(RHSCC)))
- ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue());
+ if (CmpInst::isSigned(PredL) ||
+ (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR)))
+ ShouldSwap = LHSC->getValue().sgt(RHSC->getValue());
else
- ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue());
+ ShouldSwap = LHSC->getValue().ugt(RHSC->getValue());
if (ShouldSwap) {
std::swap(LHS, RHS);
- std::swap(LHSCst, RHSCst);
- std::swap(LHSCC, RHSCC);
+ std::swap(LHSC, RHSC);
+ std::swap(PredL, PredR);
}
// At this point, we know we have two icmp instructions
@@ -1834,127 +1814,98 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
// ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the
// icmp folding check above), that the two constants are not
// equal.
- assert(LHSCst != RHSCst && "Compares not folded above?");
+ assert(LHSC != RHSC && "Compares not folded above?");
- switch (LHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredL) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
case ICmpInst::ICMP_EQ:
- if (LHS->getOperand(0) == RHS->getOperand(0)) {
- // if LHSCst and RHSCst differ only by one bit:
- // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2
- assert(LHSCst->getValue().ule(LHSCst->getValue()));
-
- APInt Xor = LHSCst->getValue() ^ RHSCst->getValue();
- if (Xor.isPowerOf2()) {
- Value *Cst = Builder->getInt(Xor);
- Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst);
- return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst);
- }
- }
-
- if (LHSCst == SubOne(RHSCst)) {
- // (X == 13 | X == 14) -> X-13 <u 2
- Constant *AddCST = ConstantExpr::getNeg(LHSCst);
- Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off");
- AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst);
- return Builder->CreateICmpULT(Add, AddCST);
- }
-
- break; // (X == 13 | X == 15) -> no change
- case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
- case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
+ // Potential folds for this case should already be handled.
+ break;
+ case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
+ case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
break;
- case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15
- case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15
+ case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15
+ case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15
return RHS;
}
break;
case ICmpInst::ICMP_NE:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13
- case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13
- case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13
+ case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13
+ case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13
return LHS;
- case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true
- case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true
- case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true
+ case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true
+ case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true
+ case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true
return Builder->getTrue();
}
case ICmpInst::ICMP_ULT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
break;
- case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
- // If RHSCst is [us]MAXINT, it is always false. Not handling
+ case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
+ // If RHSC is [us]MAXINT, it is always false. Not handling
// this can cause overflow.
- if (RHSCst->isMaxValue(false))
+ if (RHSC->isMaxValue(false))
return LHS;
- return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1,
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1,
false, false);
- case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15
+ case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15
return RHS;
- case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SLT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
break;
- case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
- // If RHSCst is [us]MAXINT, it is always false. Not handling
+ case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
+ // If RHSC is [us]MAXINT, it is always false. Not handling
// this can cause overflow.
- if (RHSCst->isMaxValue(true))
+ if (RHSC->isMaxValue(true))
return LHS;
- return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1,
- true, false);
- case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15
- case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15
+ return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true,
+ false);
+ case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15
+ case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15
return RHS;
- case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_UGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13
- case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13
+ case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13
return LHS;
- case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true
- case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true
+ case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true
+ case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true
return Builder->getTrue();
- case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change
- break;
}
break;
case ICmpInst::ICMP_SGT:
- switch (RHSCC) {
- default: llvm_unreachable("Unknown integer condition code!");
- case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13
- case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13
+ switch (PredR) {
+ default:
+ llvm_unreachable("Unknown integer condition code!");
+ case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13
+ case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13
return LHS;
- case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change
- break;
- case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true
- case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true
+ case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true
+ case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true
return Builder->getTrue();
- case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change
- break;
}
break;
}
@@ -2100,17 +2051,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
ConstantInt *C1 = nullptr; Value *X = nullptr;
- // (X & C1) | C2 --> (X | C2) & (C1|C2)
- // iff (C1 & C2) == 0.
- if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) &&
- (RHS->getValue() & C1->getValue()) != 0 &&
- Op0->hasOneUse()) {
- Value *Or = Builder->CreateOr(X, RHS);
- Or->takeName(Op0);
- return BinaryOperator::CreateAnd(Or,
- Builder->getInt(RHS->getValue() | C1->getValue()));
- }
-
// (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2)
if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) &&
Op0->hasOneUse()) {
@@ -2119,45 +2059,51 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
return BinaryOperator::CreateXor(Or,
Builder->getInt(C1->getValue() & ~RHS->getValue()));
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
// Given an OR instruction, check to see if this is a bswap.
if (Instruction *BSwap = MatchBSwap(I))
return BSwap;
- Value *A = nullptr, *B = nullptr;
- ConstantInt *C1 = nullptr, *C2 = nullptr;
+ {
+ Value *A;
+ const APInt *C;
+ // (X^C)|Y -> (X|Y)^C iff Y&C == 0
+ if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) &&
+ MaskedValueIsZero(Op1, *C, 0, &I)) {
+ Value *NOr = Builder->CreateOr(A, Op1);
+ NOr->takeName(Op0);
+ return BinaryOperator::CreateXor(NOr,
+ cast<Instruction>(Op0)->getOperand(1));
+ }
- // (X^C)|Y -> (X|Y)^C iff Y&C == 0
- if (Op0->hasOneUse() &&
- match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
- MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) {
- Value *NOr = Builder->CreateOr(A, Op1);
- NOr->takeName(Op0);
- return BinaryOperator::CreateXor(NOr, C1);
+ // Y|(X^C) -> (X|Y)^C iff Y&C == 0
+ if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) &&
+ MaskedValueIsZero(Op0, *C, 0, &I)) {
+ Value *NOr = Builder->CreateOr(A, Op0);
+ NOr->takeName(Op0);
+ return BinaryOperator::CreateXor(NOr,
+ cast<Instruction>(Op1)->getOperand(1));
+ }
}
- // Y|(X^C) -> (X|Y)^C iff Y&C == 0
- if (Op1->hasOneUse() &&
- match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) &&
- MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) {
- Value *NOr = Builder->CreateOr(A, Op0);
- NOr->takeName(Op0);
- return BinaryOperator::CreateXor(NOr, C1);
- }
+ Value *A, *B;
// ((~A & B) | A) -> (A | B)
- if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1, m_Specific(A)))
- return BinaryOperator::CreateOr(A, B);
+ if (match(Op0, m_c_And(m_Not(m_Specific(Op1)), m_Value(A))))
+ return BinaryOperator::CreateOr(A, Op1);
+ if (match(Op1, m_c_And(m_Not(m_Specific(Op0)), m_Value(A))))
+ return BinaryOperator::CreateOr(Op0, A);
// ((A & B) | ~A) -> (~A | B)
- if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
- match(Op1, m_Not(m_Specific(A))))
- return BinaryOperator::CreateOr(Builder->CreateNot(A), B);
+ // The NOT is guaranteed to be in the RHS by complexity ordering.
+ if (match(Op1, m_Not(m_Value(A))) &&
+ match(Op0, m_c_And(m_Specific(A), m_Value(B))))
+ return BinaryOperator::CreateOr(Op1, B);
// (A & ~B) | (A ^ B) -> (A ^ B)
// (~B & A) | (A ^ B) -> (A ^ B)
@@ -2177,8 +2123,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
match(Op1, m_And(m_Value(B), m_Value(D)))) {
Value *V1 = nullptr, *V2 = nullptr;
- C1 = dyn_cast<ConstantInt>(C);
- C2 = dyn_cast<ConstantInt>(D);
+ ConstantInt *C1 = dyn_cast<ConstantInt>(C);
+ ConstantInt *C2 = dyn_cast<ConstantInt>(D);
if (C1 && C2) { // (A & C1)|(B & C2)
if ((C1->getValue() & C2->getValue()) == 0) {
// ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2)
@@ -2403,6 +2349,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
// be simplified by a later pass either, so we try swapping the inner/outer
// ORs in the hopes that we'll be able to simplify it this way.
// (X|C) | V --> (X|V) | C
+ ConstantInt *C1;
if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) &&
match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) {
Value *Inner = Builder->CreateOr(A, Op1);
@@ -2493,23 +2440,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
- if (Constant *RHS = dyn_cast<Constant>(Op1)) {
- if (RHS->isAllOnesValue() && Op0->hasOneUse())
- // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B
- if (CmpInst *CI = dyn_cast<CmpInst>(Op0))
- return CmpInst::Create(CI->getOpcode(),
- CI->getInversePredicate(),
- CI->getOperand(0), CI->getOperand(1));
+ // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B
+ ICmpInst::Predicate Pred;
+ if (match(Op0, m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))) &&
+ match(Op1, m_AllOnes())) {
+ cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred));
+ return replaceInstUsesWith(I, Op0);
}
- if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
+ if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) {
// fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp).
if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) {
if (CI->hasOneUse() && Op0C->hasOneUse()) {
Instruction::CastOps Opcode = Op0C->getOpcode();
if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) &&
- (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(),
+ (RHSC == ConstantExpr::getCast(Opcode, Builder->getTrue(),
Op0C->getDestTy()))) {
CI->setPredicate(CI->getInversePredicate());
return CastInst::Create(Opcode, CI, Op0C->getType());
@@ -2520,26 +2466,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
// ~(c-X) == X-c-1 == X+(-c-1)
- if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue())
+ if (Op0I->getOpcode() == Instruction::Sub && RHSC->isAllOnesValue())
if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) {
Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C);
- Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C,
- ConstantInt::get(I.getType(), 1));
- return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS);
+ return BinaryOperator::CreateAdd(Op0I->getOperand(1),
+ SubOne(NegOp0I0C));
}
if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
if (Op0I->getOpcode() == Instruction::Add) {
// ~(X-c) --> (-c-1)-X
- if (RHS->isAllOnesValue()) {
+ if (RHSC->isAllOnesValue()) {
Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI);
- return BinaryOperator::CreateSub(
- ConstantExpr::getSub(NegOp0CI,
- ConstantInt::get(I.getType(), 1)),
- Op0I->getOperand(0));
- } else if (RHS->getValue().isSignBit()) {
+ return BinaryOperator::CreateSub(SubOne(NegOp0CI),
+ Op0I->getOperand(0));
+ } else if (RHSC->getValue().isSignBit()) {
// (X + C) ^ signbit -> (X + C + signbit)
- Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue());
+ Constant *C = Builder->getInt(RHSC->getValue() + Op0CI->getValue());
return BinaryOperator::CreateAdd(Op0I->getOperand(0), C);
}
@@ -2547,10 +2490,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
// (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0
if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(),
0, &I)) {
- Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS);
+ Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC);
// Anything in both C1 and C2 is known to be zero, remove it from
// NewRHS.
- Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS);
+ Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC);
NewRHS = ConstantExpr::getAnd(NewRHS,
ConstantExpr::getNot(CommonBits));
Worklist.Add(Op0I);
@@ -2568,7 +2511,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
E1->getOpcode() == Instruction::Xor &&
(C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) {
// fold (C1 >> C2) ^ C3
- ConstantInt *C2 = Op0CI, *C3 = RHS;
+ ConstantInt *C2 = Op0CI, *C3 = RHSC;
APInt FoldConst = C1->getValue().lshr(C2->getValue());
FoldConst ^= C3->getValue();
// Prepare the two operands.
@@ -2582,27 +2525,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
}
+ }
+ if (isa<Constant>(Op1))
if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I))
return FoldedLogic;
- }
- BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1);
- if (Op1I) {
+ {
Value *A, *B;
- if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
- if (A == Op0) { // B^(B|A) == (A|B)^B
- Op1I->swapOperands();
- I.swapOperands();
- std::swap(Op0, Op1);
- } else if (B == Op0) { // B^(A|B) == (A|B)^B
+ if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) {
+ if (A == Op0) { // A^(A|B) == A^(B|A)
+ cast<BinaryOperator>(Op1)->swapOperands();
+ std::swap(A, B);
+ }
+ if (B == Op0) { // A^(B|A) == (B|A)^A
I.swapOperands(); // Simplified below.
std::swap(Op0, Op1);
}
- } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) &&
- Op1I->hasOneUse()){
+ } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) {
if (A == Op0) { // A^(A&B) -> A^(B&A)
- Op1I->swapOperands();
+ cast<BinaryOperator>(Op1)->swapOperands();
std::swap(A, B);
}
if (B == Op0) { // A^(B&A) -> (B&A)^A
@@ -2612,65 +2554,63 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
}
}
- BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0);
- if (Op0I) {
+ {
Value *A, *B;
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- Op0I->hasOneUse()) {
+ if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) {
if (A == Op1) // (B|A)^B == (A|B)^B
std::swap(A, B);
if (B == Op1) // (A|B)^B == A & ~B
return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1));
- } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- Op0I->hasOneUse()){
+ } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) {
if (A == Op1) // (A&B)^A -> (B&A)^A
std::swap(A, B);
+ const APInt *C;
if (B == Op1 && // (B&A)^A == ~B & A
- !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C
+ !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C
return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1);
}
}
}
- if (Op0I && Op1I) {
+ {
Value *A, *B, *C, *D;
// (A & B)^(A | B) -> A ^ B
- if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Or(m_Value(C), m_Value(D)))) {
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Or(m_Value(C), m_Value(D)))) {
if ((A == C && B == D) || (A == D && B == C))
return BinaryOperator::CreateXor(A, B);
}
// (A | B)^(A & B) -> A ^ B
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- match(Op1I, m_And(m_Value(C), m_Value(D)))) {
+ if (match(Op0, m_Or(m_Value(A), m_Value(B))) &&
+ match(Op1, m_And(m_Value(C), m_Value(D)))) {
if ((A == C && B == D) || (A == D && B == C))
return BinaryOperator::CreateXor(A, B);
}
// (A | ~B) ^ (~A | B) -> A ^ B
// (~B | A) ^ (~A | B) -> A ^ B
- if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B))))
+ if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) &&
+ match(Op1, m_Or(m_Not(m_Specific(A)), m_Specific(B))))
return BinaryOperator::CreateXor(A, B);
// (~A | B) ^ (A | ~B) -> A ^ B
- if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {
+ if (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) &&
+ match(Op1, m_Or(m_Specific(A), m_Not(m_Specific(B))))) {
return BinaryOperator::CreateXor(A, B);
}
// (A & ~B) ^ (~A & B) -> A ^ B
// (~B & A) ^ (~A & B) -> A ^ B
- if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) &&
- match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B))))
+ if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) &&
+ match(Op1, m_And(m_Not(m_Specific(A)), m_Specific(B))))
return BinaryOperator::CreateXor(A, B);
// (~A & B) ^ (A & ~B) -> A ^ B
- if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) &&
- match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) {
+ if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) &&
+ match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) {
return BinaryOperator::CreateXor(A, B);
}
// (A ^ C)^(A | B) -> ((~A) & B) ^ C
- if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) &&
- match(Op1I, m_Or(m_Value(A), m_Value(B)))) {
+ if (match(Op0, m_Xor(m_Value(D), m_Value(C))) &&
+ match(Op1, m_Or(m_Value(A), m_Value(B)))) {
if (D == A)
return BinaryOperator::CreateXor(
Builder->CreateAnd(Builder->CreateNot(A), B), C);
@@ -2679,8 +2619,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
Builder->CreateAnd(Builder->CreateNot(B), A), C);
}
// (A | B)^(A ^ C) -> ((~A) & B) ^ C
- if (match(Op0I, m_Or(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Xor(m_Value(D), m_Value(C)))) {
+ if (match(Op0, m_Or(m_Value(A), m_Value(B))) &&
+ match(Op1, m_Xor(m_Value(D), m_Value(C)))) {
if (D == A)
return BinaryOperator::CreateXor(
Builder->CreateAnd(Builder->CreateNot(A), B), C);
@@ -2689,12 +2629,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
Builder->CreateAnd(Builder->CreateNot(B), A), C);
}
// (A & B) ^ (A ^ B) -> (A | B)
- if (match(Op0I, m_And(m_Value(A), m_Value(B))) &&
- match(Op1I, m_Xor(m_Specific(A), m_Specific(B))))
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Xor(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
// (A ^ B) ^ (A & B) -> (A | B)
- if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) &&
- match(Op1I, m_And(m_Specific(A), m_Specific(B))))
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_And(m_Specific(A), m_Specific(B))))
return BinaryOperator::CreateOr(A, B);
}
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 2ef82ba3ed8c..69484f47223f 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -60,6 +60,12 @@ using namespace PatternMatch;
STATISTIC(NumSimplified, "Number of library calls simplified");
+static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements(
+ "unfold-element-atomic-memcpy-max-elements",
+ cl::init(16),
+ cl::desc("Maximum number of elements in atomic memcpy the optimizer is "
+ "allowed to unfold"));
+
/// Return the specified type promoted as it would be to pass though a va_arg
/// area.
static Type *getPromotedType(Type *Ty) {
@@ -70,27 +76,6 @@ static Type *getPromotedType(Type *Ty) {
return Ty;
}
-/// Given an aggregate type which ultimately holds a single scalar element,
-/// like {{{type}}} or [1 x type], return type.
-static Type *reduceToSingleValueType(Type *T) {
- while (!T->isSingleValueType()) {
- if (StructType *STy = dyn_cast<StructType>(T)) {
- if (STy->getNumElements() == 1)
- T = STy->getElementType(0);
- else
- break;
- } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) {
- if (ATy->getNumElements() == 1)
- T = ATy->getElementType();
- else
- break;
- } else
- break;
- }
-
- return T;
-}
-
/// Return a constant boolean vector that has true elements in all positions
/// where the input constant data vector has an element with the sign bit set.
static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
@@ -108,6 +93,78 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
return ConstantVector::get(BoolVec);
}
+Instruction *
+InstCombiner::SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI) {
+ // Try to unfold this intrinsic into sequence of explicit atomic loads and
+ // stores.
+ // First check that number of elements is compile time constant.
+ auto *NumElementsCI = dyn_cast<ConstantInt>(AMI->getNumElements());
+ if (!NumElementsCI)
+ return nullptr;
+
+ // Check that there are not too many elements.
+ uint64_t NumElements = NumElementsCI->getZExtValue();
+ if (NumElements >= UnfoldElementAtomicMemcpyMaxElements)
+ return nullptr;
+
+ // Don't unfold into illegal integers
+ uint64_t ElementSizeInBytes = AMI->getElementSizeInBytes() * 8;
+ if (!getDataLayout().isLegalInteger(ElementSizeInBytes))
+ return nullptr;
+
+ // Cast source and destination to the correct type. Intrinsic input arguments
+ // are usually represented as i8*.
+ // Often operands will be explicitly casted to i8* and we can just strip
+ // those casts instead of inserting new ones. However it's easier to rely on
+ // other InstCombine rules which will cover trivial cases anyway.
+ Value *Src = AMI->getRawSource();
+ Value *Dst = AMI->getRawDest();
+ Type *ElementPointerType = Type::getIntNPtrTy(
+ AMI->getContext(), ElementSizeInBytes, Src->getType()->getPointerAddressSpace());
+
+ Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType,
+ "memcpy_unfold.src_casted");
+ Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType,
+ "memcpy_unfold.dst_casted");
+
+ for (uint64_t i = 0; i < NumElements; ++i) {
+ // Get current element addresses
+ ConstantInt *ElementIdxCI =
+ ConstantInt::get(AMI->getContext(), APInt(64, i));
+ Value *SrcElementAddr =
+ Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
+ Value *DstElementAddr =
+ Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
+
+ // Load from the source. Transfer alignment information and mark load as
+ // unordered atomic.
+ LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val");
+ Load->setOrdering(AtomicOrdering::Unordered);
+ // We know alignment of the first element. It is also guaranteed by the
+ // verifier that element size is less or equal than first element alignment
+ // and both of this values are powers of two.
+ // This means that all subsequent accesses are at least element size
+ // aligned.
+ // TODO: We can infer better alignment but there is no evidence that this
+ // will matter.
+ Load->setAlignment(i == 0 ? AMI->getSrcAlignment()
+ : AMI->getElementSizeInBytes());
+ Load->setDebugLoc(AMI->getDebugLoc());
+
+ // Store loaded value via unordered atomic store.
+ StoreInst *Store = Builder->CreateStore(Load, DstElementAddr);
+ Store->setOrdering(AtomicOrdering::Unordered);
+ Store->setAlignment(i == 0 ? AMI->getDstAlignment()
+ : AMI->getElementSizeInBytes());
+ Store->setDebugLoc(AMI->getDebugLoc());
+ }
+
+ // Set the number of elements of the copy to 0, it will be deleted on the
+ // next iteration.
+ AMI->setNumElements(Constant::getNullValue(NumElementsCI->getType()));
+ return AMI;
+}
+
Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT);
unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT);
@@ -144,41 +201,19 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) {
Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);
- // Memcpy forces the use of i8* for the source and destination. That means
- // that if you're using memcpy to move one double around, you'll get a cast
- // from double* to i8*. We'd much rather use a double load+store rather than
- // an i64 load+store, here because this improves the odds that the source or
- // dest address will be promotable. See if we can find a better type than the
- // integer datatype.
- Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts();
+ // If the memcpy has metadata describing the members, see if we can get the
+ // TBAA tag describing our copy.
MDNode *CopyMD = nullptr;
- if (StrippedDest != MI->getArgOperand(0)) {
- Type *SrcETy = cast<PointerType>(StrippedDest->getType())
- ->getElementType();
- if (SrcETy->isSized() && DL.getTypeStoreSize(SrcETy) == Size) {
- // The SrcETy might be something like {{{double}}} or [1 x double]. Rip
- // down through these levels if so.
- SrcETy = reduceToSingleValueType(SrcETy);
-
- if (SrcETy->isSingleValueType()) {
- NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp);
- NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp);
-
- // If the memcpy has metadata describing the members, see if we can
- // get the TBAA tag describing our copy.
- if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) {
- if (M->getNumOperands() == 3 && M->getOperand(0) &&
- mdconst::hasa<ConstantInt>(M->getOperand(0)) &&
- mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() &&
- M->getOperand(1) &&
- mdconst::hasa<ConstantInt>(M->getOperand(1)) &&
- mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() ==
- Size &&
- M->getOperand(2) && isa<MDNode>(M->getOperand(2)))
- CopyMD = cast<MDNode>(M->getOperand(2));
- }
- }
- }
+ if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) {
+ if (M->getNumOperands() == 3 && M->getOperand(0) &&
+ mdconst::hasa<ConstantInt>(M->getOperand(0)) &&
+ mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() &&
+ M->getOperand(1) &&
+ mdconst::hasa<ConstantInt>(M->getOperand(1)) &&
+ mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() ==
+ Size &&
+ M->getOperand(2) && isa<MDNode>(M->getOperand(2)))
+ CopyMD = cast<MDNode>(M->getOperand(2));
}
// If the memcpy/memmove provides better alignment info than we can
@@ -510,6 +545,131 @@ static Value *simplifyX86varShift(const IntrinsicInst &II,
return Builder.CreateAShr(Vec, ShiftVec);
}
+static Value *simplifyX86muldq(const IntrinsicInst &II,
+ InstCombiner::BuilderTy &Builder) {
+ Value *Arg0 = II.getArgOperand(0);
+ Value *Arg1 = II.getArgOperand(1);
+ Type *ResTy = II.getType();
+ assert(Arg0->getType()->getScalarSizeInBits() == 32 &&
+ Arg1->getType()->getScalarSizeInBits() == 32 &&
+ ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types");
+
+ // muldq/muludq(undef, undef) -> zero (matches generic mul behavior)
+ if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1))
+ return ConstantAggregateZero::get(ResTy);
+
+ // Constant folding.
+ // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)),
+ // vXi64 sext(shuffle<0,2,..>(Arg1))))
+ // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)),
+ // vXi64 zext(shuffle<0,2,..>(Arg1))))
+ if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1))
+ return nullptr;
+
+ unsigned NumElts = ResTy->getVectorNumElements();
+ assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) &&
+ Arg1->getType()->getVectorNumElements() == (2 * NumElts) &&
+ "Unexpected muldq/muludq types");
+
+ unsigned IntrinsicID = II.getIntrinsicID();
+ bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID ||
+ Intrinsic::x86_avx2_pmul_dq == IntrinsicID ||
+ Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID);
+
+ SmallVector<unsigned, 16> ShuffleMask;
+ for (unsigned i = 0; i != NumElts; ++i)
+ ShuffleMask.push_back(i * 2);
+
+ auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask);
+ auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask);
+
+ if (IsSigned) {
+ LHS = Builder.CreateSExt(LHS, ResTy);
+ RHS = Builder.CreateSExt(RHS, ResTy);
+ } else {
+ LHS = Builder.CreateZExt(LHS, ResTy);
+ RHS = Builder.CreateZExt(RHS, ResTy);
+ }
+
+ return Builder.CreateMul(LHS, RHS);
+}
+
+static Value *simplifyX86pack(IntrinsicInst &II, InstCombiner &IC,
+ InstCombiner::BuilderTy &Builder, bool IsSigned) {
+ Value *Arg0 = II.getArgOperand(0);
+ Value *Arg1 = II.getArgOperand(1);
+ Type *ResTy = II.getType();
+
+ // Fast all undef handling.
+ if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1))
+ return UndefValue::get(ResTy);
+
+ Type *ArgTy = Arg0->getType();
+ unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128;
+ unsigned NumDstElts = ResTy->getVectorNumElements();
+ unsigned NumSrcElts = ArgTy->getVectorNumElements();
+ assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types");
+
+ unsigned NumDstEltsPerLane = NumDstElts / NumLanes;
+ unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes;
+ unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits();
+ assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) &&
+ "Unexpected packing types");
+
+ // Constant folding.
+ auto *Cst0 = dyn_cast<Constant>(Arg0);
+ auto *Cst1 = dyn_cast<Constant>(Arg1);
+ if (!Cst0 || !Cst1)
+ return nullptr;
+
+ SmallVector<Constant *, 32> Vals;
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) {
+ unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane;
+ auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0;
+ auto *COp = Cst->getAggregateElement(SrcIdx);
+ if (COp && isa<UndefValue>(COp)) {
+ Vals.push_back(UndefValue::get(ResTy->getScalarType()));
+ continue;
+ }
+
+ auto *CInt = dyn_cast_or_null<ConstantInt>(COp);
+ if (!CInt)
+ return nullptr;
+
+ APInt Val = CInt->getValue();
+ assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() &&
+ "Unexpected constant bitwidth");
+
+ if (IsSigned) {
+ // PACKSS: Truncate signed value with signed saturation.
+ // Source values less than dst minint are saturated to minint.
+ // Source values greater than dst maxint are saturated to maxint.
+ if (Val.isSignedIntN(DstScalarSizeInBits))
+ Val = Val.trunc(DstScalarSizeInBits);
+ else if (Val.isNegative())
+ Val = APInt::getSignedMinValue(DstScalarSizeInBits);
+ else
+ Val = APInt::getSignedMaxValue(DstScalarSizeInBits);
+ } else {
+ // PACKUS: Truncate signed value with unsigned saturation.
+ // Source values less than zero are saturated to zero.
+ // Source values greater than dst maxuint are saturated to maxuint.
+ if (Val.isIntN(DstScalarSizeInBits))
+ Val = Val.trunc(DstScalarSizeInBits);
+ else if (Val.isNegative())
+ Val = APInt::getNullValue(DstScalarSizeInBits);
+ else
+ Val = APInt::getAllOnesValue(DstScalarSizeInBits);
+ }
+
+ Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val));
+ }
+ }
+
+ return ConstantVector::get(Vals);
+}
+
static Value *simplifyX86movmsk(const IntrinsicInst &II,
InstCombiner::BuilderTy &Builder) {
Value *Arg = II.getArgOperand(0);
@@ -1330,6 +1490,27 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) {
return true;
}
+// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs.
+//
+// A single NaN input is folded to minnum, so we rely on that folding for
+// handling NaNs.
+static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1,
+ const APFloat &Src2) {
+ APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2);
+
+ APFloat::cmpResult Cmp0 = Max3.compare(Src0);
+ assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately");
+ if (Cmp0 == APFloat::cmpEqual)
+ return maxnum(Src1, Src2);
+
+ APFloat::cmpResult Cmp1 = Max3.compare(Src1);
+ assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately");
+ if (Cmp1 == APFloat::cmpEqual)
+ return maxnum(Src0, Src2);
+
+ return maxnum(Src0, Src1);
+}
+
// Returns true iff the 2 intrinsics have the same operands, limiting the
// comparison to the first NumOperands.
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E,
@@ -1373,6 +1554,254 @@ static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID,
return false;
}
+// Convert NVVM intrinsics to target-generic LLVM code where possible.
+static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) {
+ // Each NVVM intrinsic we can simplify can be replaced with one of:
+ //
+ // * an LLVM intrinsic,
+ // * an LLVM cast operation,
+ // * an LLVM binary operation, or
+ // * ad-hoc LLVM IR for the particular operation.
+
+ // Some transformations are only valid when the module's
+ // flush-denormals-to-zero (ftz) setting is true/false, whereas other
+ // transformations are valid regardless of the module's ftz setting.
+ enum FtzRequirementTy {
+ FTZ_Any, // Any ftz setting is ok.
+ FTZ_MustBeOn, // Transformation is valid only if ftz is on.
+ FTZ_MustBeOff, // Transformation is valid only if ftz is off.
+ };
+ // Classes of NVVM intrinsics that can't be replaced one-to-one with a
+ // target-generic intrinsic, cast op, or binary op but that we can nonetheless
+ // simplify.
+ enum SpecialCase {
+ SPC_Reciprocal,
+ };
+
+ // SimplifyAction is a poor-man's variant (plus an additional flag) that
+ // represents how to replace an NVVM intrinsic with target-generic LLVM IR.
+ struct SimplifyAction {
+ // Invariant: At most one of these Optionals has a value.
+ Optional<Intrinsic::ID> IID;
+ Optional<Instruction::CastOps> CastOp;
+ Optional<Instruction::BinaryOps> BinaryOp;
+ Optional<SpecialCase> Special;
+
+ FtzRequirementTy FtzRequirement = FTZ_Any;
+
+ SimplifyAction() = default;
+
+ SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq)
+ : IID(IID), FtzRequirement(FtzReq) {}
+
+ // Cast operations don't have anything to do with FTZ, so we skip that
+ // argument.
+ SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {}
+
+ SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq)
+ : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {}
+
+ SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq)
+ : Special(Special), FtzRequirement(FtzReq) {}
+ };
+
+ // Try to generate a SimplifyAction describing how to replace our
+ // IntrinsicInstr with target-generic LLVM IR.
+ const SimplifyAction Action = [II]() -> SimplifyAction {
+ switch (II->getIntrinsicID()) {
+
+ // NVVM intrinsics that map directly to LLVM intrinsics.
+ case Intrinsic::nvvm_ceil_d:
+ return {Intrinsic::ceil, FTZ_Any};
+ case Intrinsic::nvvm_ceil_f:
+ return {Intrinsic::ceil, FTZ_MustBeOff};
+ case Intrinsic::nvvm_ceil_ftz_f:
+ return {Intrinsic::ceil, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fabs_d:
+ return {Intrinsic::fabs, FTZ_Any};
+ case Intrinsic::nvvm_fabs_f:
+ return {Intrinsic::fabs, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fabs_ftz_f:
+ return {Intrinsic::fabs, FTZ_MustBeOn};
+ case Intrinsic::nvvm_floor_d:
+ return {Intrinsic::floor, FTZ_Any};
+ case Intrinsic::nvvm_floor_f:
+ return {Intrinsic::floor, FTZ_MustBeOff};
+ case Intrinsic::nvvm_floor_ftz_f:
+ return {Intrinsic::floor, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fma_rn_d:
+ return {Intrinsic::fma, FTZ_Any};
+ case Intrinsic::nvvm_fma_rn_f:
+ return {Intrinsic::fma, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ return {Intrinsic::fma, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fmax_d:
+ return {Intrinsic::maxnum, FTZ_Any};
+ case Intrinsic::nvvm_fmax_f:
+ return {Intrinsic::maxnum, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fmax_ftz_f:
+ return {Intrinsic::maxnum, FTZ_MustBeOn};
+ case Intrinsic::nvvm_fmin_d:
+ return {Intrinsic::minnum, FTZ_Any};
+ case Intrinsic::nvvm_fmin_f:
+ return {Intrinsic::minnum, FTZ_MustBeOff};
+ case Intrinsic::nvvm_fmin_ftz_f:
+ return {Intrinsic::minnum, FTZ_MustBeOn};
+ case Intrinsic::nvvm_round_d:
+ return {Intrinsic::round, FTZ_Any};
+ case Intrinsic::nvvm_round_f:
+ return {Intrinsic::round, FTZ_MustBeOff};
+ case Intrinsic::nvvm_round_ftz_f:
+ return {Intrinsic::round, FTZ_MustBeOn};
+ case Intrinsic::nvvm_sqrt_rn_d:
+ return {Intrinsic::sqrt, FTZ_Any};
+ case Intrinsic::nvvm_sqrt_f:
+ // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the
+ // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts
+ // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are
+ // the versions with explicit ftz-ness.
+ return {Intrinsic::sqrt, FTZ_Any};
+ case Intrinsic::nvvm_sqrt_rn_f:
+ return {Intrinsic::sqrt, FTZ_MustBeOff};
+ case Intrinsic::nvvm_sqrt_rn_ftz_f:
+ return {Intrinsic::sqrt, FTZ_MustBeOn};
+ case Intrinsic::nvvm_trunc_d:
+ return {Intrinsic::trunc, FTZ_Any};
+ case Intrinsic::nvvm_trunc_f:
+ return {Intrinsic::trunc, FTZ_MustBeOff};
+ case Intrinsic::nvvm_trunc_ftz_f:
+ return {Intrinsic::trunc, FTZ_MustBeOn};
+
+ // NVVM intrinsics that map to LLVM cast operations.
+ //
+ // Note that llvm's target-generic conversion operators correspond to the rz
+ // (round to zero) versions of the nvvm conversion intrinsics, even though
+ // most everything else here uses the rn (round to nearest even) nvvm ops.
+ case Intrinsic::nvvm_d2i_rz:
+ case Intrinsic::nvvm_f2i_rz:
+ case Intrinsic::nvvm_d2ll_rz:
+ case Intrinsic::nvvm_f2ll_rz:
+ return {Instruction::FPToSI};
+ case Intrinsic::nvvm_d2ui_rz:
+ case Intrinsic::nvvm_f2ui_rz:
+ case Intrinsic::nvvm_d2ull_rz:
+ case Intrinsic::nvvm_f2ull_rz:
+ return {Instruction::FPToUI};
+ case Intrinsic::nvvm_i2d_rz:
+ case Intrinsic::nvvm_i2f_rz:
+ case Intrinsic::nvvm_ll2d_rz:
+ case Intrinsic::nvvm_ll2f_rz:
+ return {Instruction::SIToFP};
+ case Intrinsic::nvvm_ui2d_rz:
+ case Intrinsic::nvvm_ui2f_rz:
+ case Intrinsic::nvvm_ull2d_rz:
+ case Intrinsic::nvvm_ull2f_rz:
+ return {Instruction::UIToFP};
+
+ // NVVM intrinsics that map to LLVM binary ops.
+ case Intrinsic::nvvm_add_rn_d:
+ return {Instruction::FAdd, FTZ_Any};
+ case Intrinsic::nvvm_add_rn_f:
+ return {Instruction::FAdd, FTZ_MustBeOff};
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return {Instruction::FAdd, FTZ_MustBeOn};
+ case Intrinsic::nvvm_mul_rn_d:
+ return {Instruction::FMul, FTZ_Any};
+ case Intrinsic::nvvm_mul_rn_f:
+ return {Instruction::FMul, FTZ_MustBeOff};
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return {Instruction::FMul, FTZ_MustBeOn};
+ case Intrinsic::nvvm_div_rn_d:
+ return {Instruction::FDiv, FTZ_Any};
+ case Intrinsic::nvvm_div_rn_f:
+ return {Instruction::FDiv, FTZ_MustBeOff};
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return {Instruction::FDiv, FTZ_MustBeOn};
+
+ // The remainder of cases are NVVM intrinsics that map to LLVM idioms, but
+ // need special handling.
+ //
+ // We seem to be mising intrinsics for rcp.approx.{ftz.}f32, which is just
+ // as well.
+ case Intrinsic::nvvm_rcp_rn_d:
+ return {SPC_Reciprocal, FTZ_Any};
+ case Intrinsic::nvvm_rcp_rn_f:
+ return {SPC_Reciprocal, FTZ_MustBeOff};
+ case Intrinsic::nvvm_rcp_rn_ftz_f:
+ return {SPC_Reciprocal, FTZ_MustBeOn};
+
+ // We do not currently simplify intrinsics that give an approximate answer.
+ // These include:
+ //
+ // - nvvm_cos_approx_{f,ftz_f}
+ // - nvvm_ex2_approx_{d,f,ftz_f}
+ // - nvvm_lg2_approx_{d,f,ftz_f}
+ // - nvvm_sin_approx_{f,ftz_f}
+ // - nvvm_sqrt_approx_{f,ftz_f}
+ // - nvvm_rsqrt_approx_{d,f,ftz_f}
+ // - nvvm_div_approx_{ftz_d,ftz_f,f}
+ // - nvvm_rcp_approx_ftz_d
+ //
+ // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast"
+ // means that fastmath is enabled in the intrinsic. Unfortunately only
+ // binary operators (currently) have a fastmath bit in SelectionDAG, so this
+ // information gets lost and we can't select on it.
+ //
+ // TODO: div and rcp are lowered to a binary op, so these we could in theory
+ // lower them to "fast fdiv".
+
+ default:
+ return {};
+ }
+ }();
+
+ // If Action.FtzRequirementTy is not satisfied by the module's ftz state, we
+ // can bail out now. (Notice that in the case that IID is not an NVVM
+ // intrinsic, we don't have to look up any module metadata, as
+ // FtzRequirementTy will be FTZ_Any.)
+ if (Action.FtzRequirement != FTZ_Any) {
+ bool FtzEnabled =
+ II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() ==
+ "true";
+
+ if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn))
+ return nullptr;
+ }
+
+ // Simplify to target-generic intrinsic.
+ if (Action.IID) {
+ SmallVector<Value *, 4> Args(II->arg_operands());
+ // All the target-generic intrinsics currently of interest to us have one
+ // type argument, equal to that of the nvvm intrinsic's argument.
+ Type *Tys[] = {II->getArgOperand(0)->getType()};
+ return CallInst::Create(
+ Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args);
+ }
+
+ // Simplify to target-generic binary op.
+ if (Action.BinaryOp)
+ return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0),
+ II->getArgOperand(1), II->getName());
+
+ // Simplify to target-generic cast op.
+ if (Action.CastOp)
+ return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(),
+ II->getName());
+
+ // All that's left are the special cases.
+ if (!Action.Special)
+ return nullptr;
+
+ switch (*Action.Special) {
+ case SPC_Reciprocal:
+ // Simplify reciprocal.
+ return BinaryOperator::Create(
+ Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1),
+ II->getArgOperand(0), II->getName());
+ }
+ llvm_unreachable("All SpecialCase enumerators should be handled in switch.");
+}
+
Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) {
removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this);
return nullptr;
@@ -1462,6 +1891,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
if (Changed) return II;
}
+ if (auto *AMI = dyn_cast<ElementAtomicMemCpyInst>(II)) {
+ if (Constant *C = dyn_cast<Constant>(AMI->getNumElements()))
+ if (C->isNullValue())
+ return eraseInstFromFunction(*AMI);
+
+ if (Instruction *I = SimplifyElementAtomicMemCpy(AMI))
+ return I;
+ }
+
+ if (Instruction *I = SimplifyNVVMIntrinsic(II, *this))
+ return I;
+
auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width,
unsigned DemandedWidth) {
APInt UndefElts(Width, 0);
@@ -1581,8 +2022,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(*II, V);
break;
}
- case Intrinsic::fma:
case Intrinsic::fmuladd: {
+ // Canonicalize fast fmuladd to the separate fmul + fadd.
+ if (II->hasUnsafeAlgebra()) {
+ BuilderTy::FastMathFlagGuard Guard(*Builder);
+ Builder->setFastMathFlags(II->getFastMathFlags());
+ Value *Mul = Builder->CreateFMul(II->getArgOperand(0),
+ II->getArgOperand(1));
+ Value *Add = Builder->CreateFAdd(Mul, II->getArgOperand(2));
+ Add->takeName(II);
+ return replaceInstUsesWith(*II, Add);
+ }
+
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::fma: {
Value *Src0 = II->getArgOperand(0);
Value *Src1 = II->getArgOperand(1);
@@ -1631,6 +2085,26 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return SelectInst::Create(Cond, Call0, Call1);
}
+ LLVM_FALLTHROUGH;
+ }
+ case Intrinsic::ceil:
+ case Intrinsic::floor:
+ case Intrinsic::round:
+ case Intrinsic::nearbyint:
+ case Intrinsic::rint:
+ case Intrinsic::trunc: {
+ Value *ExtSrc;
+ if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) &&
+ II->getArgOperand(0)->hasOneUse()) {
+ // fabs (fpext x) -> fpext (fabs x)
+ Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(),
+ { ExtSrc->getType() });
+ CallInst *NewFabs = Builder->CreateCall(F, ExtSrc);
+ NewFabs->copyFastMathFlags(II);
+ NewFabs->takeName(II);
+ return new FPExtInst(NewFabs, II->getType());
+ }
+
break;
}
case Intrinsic::cos:
@@ -1863,6 +2337,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return II;
break;
}
+ case Intrinsic::x86_avx512_mask_cmp_pd_128:
+ case Intrinsic::x86_avx512_mask_cmp_pd_256:
+ case Intrinsic::x86_avx512_mask_cmp_pd_512:
+ case Intrinsic::x86_avx512_mask_cmp_ps_128:
+ case Intrinsic::x86_avx512_mask_cmp_ps_256:
+ case Intrinsic::x86_avx512_mask_cmp_ps_512: {
+ // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a)
+ Value *Arg0 = II->getArgOperand(0);
+ Value *Arg1 = II->getArgOperand(1);
+ bool Arg0IsZero = match(Arg0, m_Zero());
+ if (Arg0IsZero)
+ std::swap(Arg0, Arg1);
+ Value *A, *B;
+ // This fold requires only the NINF(not +/- inf) since inf minus
+ // inf is nan.
+ // NSZ(No Signed Zeros) is not needed because zeros of any sign are
+ // equal for both compares.
+ // NNAN is not needed because nans compare the same for both compares.
+ // The compare intrinsic uses the above assumptions and therefore
+ // doesn't require additional flags.
+ if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) &&
+ match(Arg1, m_Zero()) &&
+ cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) {
+ if (Arg0IsZero)
+ std::swap(A, B);
+ II->setArgOperand(0, A);
+ II->setArgOperand(1, B);
+ return II;
+ }
+ break;
+ }
case Intrinsic::x86_avx512_mask_add_ps_512:
case Intrinsic::x86_avx512_mask_div_ps_512:
@@ -2130,6 +2635,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
case Intrinsic::x86_avx2_pmulu_dq:
case Intrinsic::x86_avx512_pmul_dq_512:
case Intrinsic::x86_avx512_pmulu_dq_512: {
+ if (Value *V = simplifyX86muldq(*II, *Builder))
+ return replaceInstUsesWith(*II, V);
+
unsigned VWidth = II->getType()->getVectorNumElements();
APInt UndefElts(VWidth, 0);
APInt DemandedElts = APInt::getAllOnesValue(VWidth);
@@ -2141,6 +2649,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::x86_sse2_packssdw_128:
+ case Intrinsic::x86_sse2_packsswb_128:
+ case Intrinsic::x86_avx2_packssdw:
+ case Intrinsic::x86_avx2_packsswb:
+ case Intrinsic::x86_avx512_packssdw_512:
+ case Intrinsic::x86_avx512_packsswb_512:
+ if (Value *V = simplifyX86pack(*II, *this, *Builder, true))
+ return replaceInstUsesWith(*II, V);
+ break;
+
+ case Intrinsic::x86_sse2_packuswb_128:
+ case Intrinsic::x86_sse41_packusdw:
+ case Intrinsic::x86_avx2_packusdw:
+ case Intrinsic::x86_avx2_packuswb:
+ case Intrinsic::x86_avx512_packusdw_512:
+ case Intrinsic::x86_avx512_packuswb_512:
+ if (Value *V = simplifyX86pack(*II, *this, *Builder, false))
+ return replaceInstUsesWith(*II, V);
+ break;
+
+ case Intrinsic::x86_pclmulqdq: {
+ if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) {
+ unsigned Imm = C->getZExtValue();
+
+ bool MadeChange = false;
+ Value *Arg0 = II->getArgOperand(0);
+ Value *Arg1 = II->getArgOperand(1);
+ unsigned VWidth = Arg0->getType()->getVectorNumElements();
+ APInt DemandedElts(VWidth, 0);
+
+ APInt UndefElts1(VWidth, 0);
+ DemandedElts = (Imm & 0x01) ? 2 : 1;
+ if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts,
+ UndefElts1)) {
+ II->setArgOperand(0, V);
+ MadeChange = true;
+ }
+
+ APInt UndefElts2(VWidth, 0);
+ DemandedElts = (Imm & 0x10) ? 2 : 1;
+ if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts,
+ UndefElts2)) {
+ II->setArgOperand(1, V);
+ MadeChange = true;
+ }
+
+ // If both input elements are undef, the result is undef.
+ if (UndefElts1[(Imm & 0x01) ? 1 : 0] ||
+ UndefElts2[(Imm & 0x10) ? 1 : 0])
+ return replaceInstUsesWith(*II,
+ ConstantAggregateZero::get(II->getType()));
+
+ if (MadeChange)
+ return II;
+ }
+ break;
+ }
+
case Intrinsic::x86_sse41_insertps:
if (Value *V = simplifyX86insertps(*II, *Builder))
return replaceInstUsesWith(*II, V);
@@ -2531,9 +3097,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
-
case Intrinsic::amdgcn_rcp: {
- if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) {
+ Value *Src = II->getArgOperand(0);
+
+ // TODO: Move to ConstantFolding/InstSimplify?
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(CI, Src);
+
+ if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) {
const APFloat &ArgVal = C->getValueAPF();
APFloat Val(ArgVal.getSemantics(), 1.0);
APFloat::opStatus Status = Val.divide(ArgVal,
@@ -2546,6 +3117,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::amdgcn_rsq: {
+ Value *Src = II->getArgOperand(0);
+
+ // TODO: Move to ConstantFolding/InstSimplify?
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(CI, Src);
+ break;
+ }
case Intrinsic::amdgcn_frexp_mant:
case Intrinsic::amdgcn_frexp_exp: {
Value *Src = II->getArgOperand(0);
@@ -2650,6 +3229,274 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result));
}
+ case Intrinsic::amdgcn_cvt_pkrtz: {
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+ if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
+ if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
+ const fltSemantics &HalfSem
+ = II->getType()->getScalarType()->getFltSemantics();
+ bool LosesInfo;
+ APFloat Val0 = C0->getValueAPF();
+ APFloat Val1 = C1->getValueAPF();
+ Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
+ Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo);
+
+ Constant *Folded = ConstantVector::get({
+ ConstantFP::get(II->getContext(), Val0),
+ ConstantFP::get(II->getContext(), Val1) });
+ return replaceInstUsesWith(*II, Folded);
+ }
+ }
+
+ if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1))
+ return replaceInstUsesWith(*II, UndefValue::get(II->getType()));
+
+ break;
+ }
+ case Intrinsic::amdgcn_ubfe:
+ case Intrinsic::amdgcn_sbfe: {
+ // Decompose simple cases into standard shifts.
+ Value *Src = II->getArgOperand(0);
+ if (isa<UndefValue>(Src))
+ return replaceInstUsesWith(*II, Src);
+
+ unsigned Width;
+ Type *Ty = II->getType();
+ unsigned IntSize = Ty->getIntegerBitWidth();
+
+ ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2));
+ if (CWidth) {
+ Width = CWidth->getZExtValue();
+ if ((Width & (IntSize - 1)) == 0)
+ return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty));
+
+ if (Width >= IntSize) {
+ // Hardware ignores high bits, so remove those.
+ II->setArgOperand(2, ConstantInt::get(CWidth->getType(),
+ Width & (IntSize - 1)));
+ return II;
+ }
+ }
+
+ unsigned Offset;
+ ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1));
+ if (COffset) {
+ Offset = COffset->getZExtValue();
+ if (Offset >= IntSize) {
+ II->setArgOperand(1, ConstantInt::get(COffset->getType(),
+ Offset & (IntSize - 1)));
+ return II;
+ }
+ }
+
+ bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe;
+
+ // TODO: Also emit sub if only width is constant.
+ if (!CWidth && COffset && Offset == 0) {
+ Constant *KSize = ConstantInt::get(COffset->getType(), IntSize);
+ Value *ShiftVal = Builder->CreateSub(KSize, II->getArgOperand(2));
+ ShiftVal = Builder->CreateZExt(ShiftVal, II->getType());
+
+ Value *Shl = Builder->CreateShl(Src, ShiftVal);
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Shl, ShiftVal) :
+ Builder->CreateLShr(Shl, ShiftVal);
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
+
+ if (!CWidth || !COffset)
+ break;
+
+ // TODO: This allows folding to undef when the hardware has specific
+ // behavior?
+ if (Offset + Width < IntSize) {
+ Value *Shl = Builder->CreateShl(Src, IntSize - Offset - Width);
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Shl, IntSize - Width) :
+ Builder->CreateLShr(Shl, IntSize - Width);
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
+
+ Value *RightShift = Signed ?
+ Builder->CreateAShr(Src, Offset) :
+ Builder->CreateLShr(Src, Offset);
+
+ RightShift->takeName(II);
+ return replaceInstUsesWith(*II, RightShift);
+ }
+ case Intrinsic::amdgcn_exp:
+ case Intrinsic::amdgcn_exp_compr: {
+ ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1));
+ if (!En) // Illegal.
+ break;
+
+ unsigned EnBits = En->getZExtValue();
+ if (EnBits == 0xf)
+ break; // All inputs enabled.
+
+ bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr;
+ bool Changed = false;
+ for (int I = 0; I < (IsCompr ? 2 : 4); ++I) {
+ if ((!IsCompr && (EnBits & (1 << I)) == 0) ||
+ (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) {
+ Value *Src = II->getArgOperand(I + 2);
+ if (!isa<UndefValue>(Src)) {
+ II->setArgOperand(I + 2, UndefValue::get(Src->getType()));
+ Changed = true;
+ }
+ }
+ }
+
+ if (Changed)
+ return II;
+
+ break;
+
+ }
+ case Intrinsic::amdgcn_fmed3: {
+ // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled
+ // for the shader.
+
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+ Value *Src2 = II->getArgOperand(2);
+
+ bool Swap = false;
+ // Canonicalize constants to RHS operands.
+ //
+ // fmed3(c0, x, c1) -> fmed3(x, c0, c1)
+ if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
+ std::swap(Src0, Src1);
+ Swap = true;
+ }
+
+ if (isa<Constant>(Src1) && !isa<Constant>(Src2)) {
+ std::swap(Src1, Src2);
+ Swap = true;
+ }
+
+ if (isa<Constant>(Src0) && !isa<Constant>(Src1)) {
+ std::swap(Src0, Src1);
+ Swap = true;
+ }
+
+ if (Swap) {
+ II->setArgOperand(0, Src0);
+ II->setArgOperand(1, Src1);
+ II->setArgOperand(2, Src2);
+ return II;
+ }
+
+ if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) {
+ CallInst *NewCall = Builder->CreateMinNum(Src0, Src1);
+ NewCall->copyFastMathFlags(II);
+ NewCall->takeName(II);
+ return replaceInstUsesWith(*II, NewCall);
+ }
+
+ if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) {
+ if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) {
+ if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) {
+ APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(),
+ C2->getValueAPF());
+ return replaceInstUsesWith(*II,
+ ConstantFP::get(Builder->getContext(), Result));
+ }
+ }
+ }
+
+ break;
+ }
+ case Intrinsic::amdgcn_icmp:
+ case Intrinsic::amdgcn_fcmp: {
+ const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2));
+ if (!CC)
+ break;
+
+ // Guard against invalid arguments.
+ int64_t CCVal = CC->getZExtValue();
+ bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp;
+ if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE ||
+ CCVal > CmpInst::LAST_ICMP_PREDICATE)) ||
+ (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE ||
+ CCVal > CmpInst::LAST_FCMP_PREDICATE)))
+ break;
+
+ Value *Src0 = II->getArgOperand(0);
+ Value *Src1 = II->getArgOperand(1);
+
+ if (auto *CSrc0 = dyn_cast<Constant>(Src0)) {
+ if (auto *CSrc1 = dyn_cast<Constant>(Src1)) {
+ Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1);
+ return replaceInstUsesWith(*II,
+ ConstantExpr::getSExt(CCmp, II->getType()));
+ }
+
+ // Canonicalize constants to RHS.
+ CmpInst::Predicate SwapPred
+ = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
+ II->setArgOperand(0, Src1);
+ II->setArgOperand(1, Src0);
+ II->setArgOperand(2, ConstantInt::get(CC->getType(),
+ static_cast<int>(SwapPred)));
+ return II;
+ }
+
+ if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
+ break;
+
+ // Canonicalize compare eq with true value to compare != 0
+ // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
+ // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
+ // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
+ // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
+ Value *ExtSrc;
+ if (CCVal == CmpInst::ICMP_EQ &&
+ ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) ||
+ (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) &&
+ ExtSrc->getType()->isIntegerTy(1)) {
+ II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType()));
+ II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
+ return II;
+ }
+
+ CmpInst::Predicate SrcPred;
+ Value *SrcLHS;
+ Value *SrcRHS;
+
+ // Fold compare eq/ne with 0 from a compare result as the predicate to the
+ // intrinsic. The typical use is a wave vote function in the library, which
+ // will be fed from a user code condition compared with 0. Fold in the
+ // redundant compare.
+
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
+ // -> llvm.amdgcn.[if]cmp(a, b, pred)
+ //
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
+ // -> llvm.amdgcn.[if]cmp(a, b, inv pred)
+ if (match(Src1, m_Zero()) &&
+ match(Src0,
+ m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) {
+ if (CCVal == CmpInst::ICMP_EQ)
+ SrcPred = CmpInst::getInversePredicate(SrcPred);
+
+ Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ?
+ Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp;
+
+ Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID,
+ SrcLHS->getType());
+ Value *Args[] = { SrcLHS, SrcRHS,
+ ConstantInt::get(CC->getType(), SrcPred) };
+ CallInst *NewCall = Builder->CreateCall(NewF, Args);
+ NewCall->takeName(II);
+ return replaceInstUsesWith(*II, NewCall);
+ }
+
+ break;
+ }
case Intrinsic::stackrestore: {
// If the save is right next to the restore, remove the restore. This can
// happen when variable allocas are DCE'd.
@@ -2790,7 +3637,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// isKnownNonNull -> nonnull attribute
if (isKnownNonNullAt(DerivedPtr, II, &DT))
- II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
+ II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
}
// TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
@@ -2799,11 +3646,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
// TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
break;
}
- }
+ case Intrinsic::experimental_guard: {
+ // Is this guard followed by another guard?
+ Instruction *NextInst = II->getNextNode();
+ Value *NextCond = nullptr;
+ if (match(NextInst,
+ m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) {
+ Value *CurrCond = II->getArgOperand(0);
+
+ // Remove a guard that it is immediately preceded by an identical guard.
+ if (CurrCond == NextCond)
+ return eraseInstFromFunction(*NextInst);
+
+ // Otherwise canonicalize guard(a); guard(b) -> guard(a & b).
+ II->setArgOperand(0, Builder->CreateAnd(CurrCond, NextCond));
+ return eraseInstFromFunction(*NextInst);
+ }
+ break;
+ }
+ }
return visitCallSite(II);
}
+// Fence instruction simplification
+Instruction *InstCombiner::visitFenceInst(FenceInst &FI) {
+ // Remove identical consecutive fences.
+ if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode()))
+ if (FI.isIdenticalTo(NFI))
+ return eraseInstFromFunction(FI);
+ return nullptr;
+}
+
// InvokeInst simplification
//
Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) {
@@ -2950,7 +3824,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
for (Value *V : CS.args()) {
if (V->getType()->isPointerTy() &&
- !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) &&
+ !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
isKnownNonNullAt(V, CS.getInstruction(), &DT))
Indices.push_back(ArgNo + 1);
ArgNo++;
@@ -2959,7 +3833,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) {
assert(ArgNo == CS.arg_size() && "sanity check");
if (!Indices.empty()) {
- AttributeSet AS = CS.getAttributes();
+ AttributeList AS = CS.getAttributes();
LLVMContext &Ctx = CS.getInstruction()->getContext();
AS = AS.addAttribute(Ctx, Indices,
Attribute::get(Ctx, Attribute::NonNull));
@@ -3081,7 +3955,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
return false;
Instruction *Caller = CS.getInstruction();
- const AttributeSet &CallerPAL = CS.getAttributes();
+ const AttributeList &CallerPAL = CS.getAttributes();
// Okay, this is a cast from a function to a different type. Unless doing so
// would cause a type conversion of one of our arguments, change this call to
@@ -3108,7 +3982,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
}
if (!CallerPAL.isEmpty() && !Caller->use_empty()) {
- AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex);
+ AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);
if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy)))
return false; // Attribute not compatible with transformed value.
}
@@ -3149,8 +4023,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL))
return false; // Cannot transform this parameter value.
- if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1).
- overlaps(AttributeFuncs::typeIncompatible(ParamTy)))
+ if (AttrBuilder(CallerPAL.getParamAttributes(i))
+ .overlaps(AttributeFuncs::typeIncompatible(ParamTy)))
return false; // Attribute not compatible with transformed value.
if (CS.isInAllocaArgument(i))
@@ -3158,9 +4032,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// If the parameter is passed as a byval argument, then we have to have a
// sized type and the sized type has to have the same size as the old type.
- if (ParamTy != ActTy &&
- CallerPAL.getParamAttributes(i + 1).hasAttribute(i + 1,
- Attribute::ByVal)) {
+ if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) {
PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy);
if (!ParamPTy || !ParamPTy->getElementType()->isSized())
return false;
@@ -3205,7 +4077,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
break;
// Check if it has an attribute that's incompatible with varargs.
- AttributeSet PAttrs = CallerPAL.getSlotAttributes(i - 1);
+ AttributeList PAttrs = CallerPAL.getSlotAttributes(i - 1);
if (PAttrs.hasAttribute(Index, Attribute::StructRet))
return false;
}
@@ -3213,44 +4085,37 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// Okay, we decided that this is a safe thing to do: go ahead and start
// inserting cast instructions as necessary.
- std::vector<Value*> Args;
+ SmallVector<Value *, 8> Args;
+ SmallVector<AttributeSet, 8> ArgAttrs;
Args.reserve(NumActualArgs);
- SmallVector<AttributeSet, 8> attrVec;
- attrVec.reserve(NumCommonArgs);
+ ArgAttrs.reserve(NumActualArgs);
// Get any return attributes.
- AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex);
+ AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex);
// If the return value is not being used, the type may not be compatible
// with the existing attributes. Wipe out any problematic attributes.
RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy));
- // Add the new return attributes.
- if (RAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(Caller->getContext(),
- AttributeSet::ReturnIndex, RAttrs));
-
AI = CS.arg_begin();
for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) {
Type *ParamTy = FT->getParamType(i);
- if ((*AI)->getType() == ParamTy) {
- Args.push_back(*AI);
- } else {
- Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy));
- }
+ Value *NewArg = *AI;
+ if ((*AI)->getType() != ParamTy)
+ NewArg = Builder->CreateBitOrPointerCast(*AI, ParamTy);
+ Args.push_back(NewArg);
// Add any parameter attributes.
- AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1);
- if (PAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(Caller->getContext(), i + 1,
- PAttrs));
+ ArgAttrs.push_back(CallerPAL.getParamAttributes(i));
}
// If the function takes more arguments than the call was taking, add them
// now.
- for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i)
+ for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) {
Args.push_back(Constant::getNullValue(FT->getParamType(i)));
+ ArgAttrs.push_back(AttributeSet());
+ }
// If we are removing arguments to the function, emit an obnoxious warning.
if (FT->getNumParams() < NumActualArgs) {
@@ -3259,54 +4124,56 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) {
// Add all of the arguments in their promoted form to the arg list.
for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) {
Type *PTy = getPromotedType((*AI)->getType());
+ Value *NewArg = *AI;
if (PTy != (*AI)->getType()) {
// Must promote to pass through va_arg area!
Instruction::CastOps opcode =
CastInst::getCastOpcode(*AI, false, PTy, false);
- Args.push_back(Builder->CreateCast(opcode, *AI, PTy));
- } else {
- Args.push_back(*AI);
+ NewArg = Builder->CreateCast(opcode, *AI, PTy);
}
+ Args.push_back(NewArg);
// Add any parameter attributes.
- AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1);
- if (PAttrs.hasAttributes())
- attrVec.push_back(AttributeSet::get(FT->getContext(), i + 1,
- PAttrs));
+ ArgAttrs.push_back(CallerPAL.getParamAttributes(i));
}
}
}
AttributeSet FnAttrs = CallerPAL.getFnAttributes();
- if (CallerPAL.hasAttributes(AttributeSet::FunctionIndex))
- attrVec.push_back(AttributeSet::get(Callee->getContext(), FnAttrs));
if (NewRetTy->isVoidTy())
Caller->setName(""); // Void type should not have a name.
- const AttributeSet &NewCallerPAL = AttributeSet::get(Callee->getContext(),
- attrVec);
+ assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) &&
+ "missing argument attributes");
+ LLVMContext &Ctx = Callee->getContext();
+ AttributeList NewCallerPAL = AttributeList::get(
+ Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs);
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
- Instruction *NC;
+ CallSite NewCS;
if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
- NC = Builder->CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(),
- Args, OpBundles);
- NC->takeName(II);
- cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv());
- cast<InvokeInst>(NC)->setAttributes(NewCallerPAL);
+ NewCS = Builder->CreateInvoke(Callee, II->getNormalDest(),
+ II->getUnwindDest(), Args, OpBundles);
} else {
- CallInst *CI = cast<CallInst>(Caller);
- NC = Builder->CreateCall(Callee, Args, OpBundles);
- NC->takeName(CI);
- cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind());
- cast<CallInst>(NC)->setCallingConv(CI->getCallingConv());
- cast<CallInst>(NC)->setAttributes(NewCallerPAL);
+ NewCS = Builder->CreateCall(Callee, Args, OpBundles);
+ cast<CallInst>(NewCS.getInstruction())
+ ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind());
}
+ NewCS->takeName(Caller);
+ NewCS.setCallingConv(CS.getCallingConv());
+ NewCS.setAttributes(NewCallerPAL);
+
+ // Preserve the weight metadata for the new call instruction. The metadata
+ // is used by SamplePGO to check callsite's hotness.
+ uint64_t W;
+ if (Caller->extractProfTotalWeight(W))
+ NewCS->setProfWeight(W);
// Insert a cast of the return type as necessary.
+ Instruction *NC = NewCS.getInstruction();
Value *NV = NC;
if (OldRetTy != NV->getType() && !Caller->use_empty()) {
if (!NV->getType()->isVoidTy()) {
@@ -3351,7 +4218,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
Value *Callee = CS.getCalledValue();
PointerType *PTy = cast<PointerType>(Callee->getType());
FunctionType *FTy = cast<FunctionType>(PTy->getElementType());
- const AttributeSet &Attrs = CS.getAttributes();
+ AttributeList Attrs = CS.getAttributes();
// If the call already has the 'nest' attribute somewhere then give up -
// otherwise 'nest' would occur twice after splicing in the chain.
@@ -3364,50 +4231,46 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts());
FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType());
- const AttributeSet &NestAttrs = NestF->getAttributes();
+ AttributeList NestAttrs = NestF->getAttributes();
if (!NestAttrs.isEmpty()) {
- unsigned NestIdx = 1;
+ unsigned NestArgNo = 0;
Type *NestTy = nullptr;
AttributeSet NestAttr;
// Look for a parameter marked with the 'nest' attribute.
for (FunctionType::param_iterator I = NestFTy->param_begin(),
- E = NestFTy->param_end(); I != E; ++NestIdx, ++I)
- if (NestAttrs.hasAttribute(NestIdx, Attribute::Nest)) {
+ E = NestFTy->param_end();
+ I != E; ++NestArgNo, ++I) {
+ AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo);
+ if (AS.hasAttribute(Attribute::Nest)) {
// Record the parameter type and any other attributes.
NestTy = *I;
- NestAttr = NestAttrs.getParamAttributes(NestIdx);
+ NestAttr = AS;
break;
}
+ }
if (NestTy) {
Instruction *Caller = CS.getInstruction();
std::vector<Value*> NewArgs;
+ std::vector<AttributeSet> NewArgAttrs;
NewArgs.reserve(CS.arg_size() + 1);
-
- SmallVector<AttributeSet, 8> NewAttrs;
- NewAttrs.reserve(Attrs.getNumSlots() + 1);
+ NewArgAttrs.reserve(CS.arg_size());
// Insert the nest argument into the call argument list, which may
// mean appending it. Likewise for attributes.
- // Add any result attributes.
- if (Attrs.hasAttributes(AttributeSet::ReturnIndex))
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- Attrs.getRetAttributes()));
-
{
- unsigned Idx = 1;
+ unsigned ArgNo = 0;
CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end();
do {
- if (Idx == NestIdx) {
+ if (ArgNo == NestArgNo) {
// Add the chain argument and attributes.
Value *NestVal = Tramp->getArgOperand(2);
if (NestVal->getType() != NestTy)
NestVal = Builder->CreateBitCast(NestVal, NestTy, "nest");
NewArgs.push_back(NestVal);
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- NestAttr));
+ NewArgAttrs.push_back(NestAttr);
}
if (I == E)
@@ -3415,23 +4278,13 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Add the original argument and attributes.
NewArgs.push_back(*I);
- AttributeSet Attr = Attrs.getParamAttributes(Idx);
- if (Attr.hasAttributes(Idx)) {
- AttrBuilder B(Attr, Idx);
- NewAttrs.push_back(AttributeSet::get(Caller->getContext(),
- Idx + (Idx >= NestIdx), B));
- }
+ NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo));
- ++Idx;
+ ++ArgNo;
++I;
} while (true);
}
- // Add any function attributes.
- if (Attrs.hasAttributes(AttributeSet::FunctionIndex))
- NewAttrs.push_back(AttributeSet::get(FTy->getContext(),
- Attrs.getFnAttributes()));
-
// The trampoline may have been bitcast to a bogus type (FTy).
// Handle this by synthesizing a new function type, equal to FTy
// with the chain parameter inserted.
@@ -3442,12 +4295,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Insert the chain's type into the list of parameter types, which may
// mean appending it.
{
- unsigned Idx = 1;
+ unsigned ArgNo = 0;
FunctionType::param_iterator I = FTy->param_begin(),
E = FTy->param_end();
do {
- if (Idx == NestIdx)
+ if (ArgNo == NestArgNo)
// Add the chain's type.
NewTypes.push_back(NestTy);
@@ -3457,7 +4310,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
// Add the original type.
NewTypes.push_back(*I);
- ++Idx;
+ ++ArgNo;
++I;
} while (true);
}
@@ -3470,8 +4323,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS,
NestF->getType() == PointerType::getUnqual(NewFTy) ?
NestF : ConstantExpr::getBitCast(NestF,
PointerType::getUnqual(NewFTy));
- const AttributeSet &NewPAL =
- AttributeSet::get(FTy->getContext(), NewAttrs);
+ AttributeList NewPAL =
+ AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(),
+ Attrs.getRetAttributes(), NewArgAttrs);
SmallVector<OperandBundleDef, 1> OpBundles;
CS.getOperandBundlesAsDefs(OpBundles);
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index e74b590e2b7c..25683132c786 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -274,12 +274,12 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) {
return NV;
// If we are casting a PHI, then fold the cast into the PHI.
- if (isa<PHINode>(Src)) {
+ if (auto *PN = dyn_cast<PHINode>(Src)) {
// Don't do this if it would create a PHI node with an illegal type from a
// legal type.
if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() ||
- ShouldChangeType(CI.getType(), Src->getType()))
- if (Instruction *NV = FoldOpIntoPhi(CI))
+ shouldChangeType(CI.getType(), Src->getType()))
+ if (Instruction *NV = foldOpIntoPhi(CI, PN))
return NV;
}
@@ -447,7 +447,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC,
Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {
Type *SrcTy = Trunc.getSrcTy();
Type *DestTy = Trunc.getType();
- if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy))
+ if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy))
return nullptr;
BinaryOperator *LogicOp;
@@ -463,6 +463,56 @@ Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) {
return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC);
}
+/// Try to narrow the width of a splat shuffle. This could be generalized to any
+/// shuffle with a constant operand, but we limit the transform to avoid
+/// creating a shuffle type that targets may not be able to lower effectively.
+static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
+ InstCombiner::BuilderTy &Builder) {
+ auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
+ if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) &&
+ Shuf->getMask()->getSplatValue() &&
+ Shuf->getType() == Shuf->getOperand(0)->getType()) {
+ // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask
+ Constant *NarrowUndef = UndefValue::get(Trunc.getType());
+ Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType());
+ return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask());
+ }
+
+ return nullptr;
+}
+
+/// Try to narrow the width of an insert element. This could be generalized for
+/// any vector constant, but we limit the transform to insertion into undef to
+/// avoid potential backend problems from unsupported insertion widths. This
+/// could also be extended to handle the case of inserting a scalar constant
+/// into a vector variable.
+static Instruction *shrinkInsertElt(CastInst &Trunc,
+ InstCombiner::BuilderTy &Builder) {
+ Instruction::CastOps Opcode = Trunc.getOpcode();
+ assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) &&
+ "Unexpected instruction for shrinking");
+
+ auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0));
+ if (!InsElt || !InsElt->hasOneUse())
+ return nullptr;
+
+ Type *DestTy = Trunc.getType();
+ Type *DestScalarTy = DestTy->getScalarType();
+ Value *VecOp = InsElt->getOperand(0);
+ Value *ScalarOp = InsElt->getOperand(1);
+ Value *Index = InsElt->getOperand(2);
+
+ if (isa<UndefValue>(VecOp)) {
+ // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index
+ // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index
+ UndefValue *NarrowUndef = UndefValue::get(DestTy);
+ Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy);
+ return InsertElementInst::Create(NarrowUndef, NarrowOp, Index);
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
if (Instruction *Result = commonCastTransforms(CI))
return Result;
@@ -488,7 +538,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
// type. Only do this if the dest type is a simple type, don't convert the
// expression tree to something weird like i93 unless the source is also
// strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateTruncated(Src, DestTy, *this, &CI)) {
// If this cast is a truncate, evaluting in a different type always
@@ -554,8 +604,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
if (Instruction *I = shrinkBitwiseLogic(CI))
return I;
+ if (Instruction *I = shrinkSplatShuffle(CI, *Builder))
+ return I;
+
+ if (Instruction *I = shrinkInsertElt(CI, *Builder))
+ return I;
+
if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
- ShouldChangeType(SrcTy, DestTy)) {
+ shouldChangeType(SrcTy, DestTy)) {
// Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
// dest type is native and cst < dest size.
if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) &&
@@ -838,11 +894,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
if (Instruction *Result = commonCastTransforms(CI))
return Result;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
Value *Src = CI.getOperand(0);
Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -851,10 +902,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
// expression tree to something weird like i93 unless the source is also
// strange.
unsigned BitsToClear;
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) {
- assert(BitsToClear < SrcTy->getScalarSizeInBits() &&
- "Unreasonable BitsToClear");
+ assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
+ "Can't clear more bits than in SrcTy");
// Okay, we can transform this! Insert the new expression now.
DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1124,11 +1175,6 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
if (Instruction *I = commonCastTransforms(CI))
return I;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
Value *Src = CI.getOperand(0);
Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -1145,7 +1191,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
// type. Only do this if the dest type is a simple type, don't convert the
// expression tree to something weird like i93 unless the source is also
// strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
canEvaluateSExtd(Src, DestTy)) {
// Okay, we can transform this! Insert the new expression now.
DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1167,18 +1213,16 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
ShAmt);
}
- // If this input is a trunc from our destination, then turn sext(trunc(x))
+ // If the input is a trunc from the destination type, then turn sext(trunc(x))
// into shifts.
- if (TruncInst *TI = dyn_cast<TruncInst>(Src))
- if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) {
- uint32_t SrcBitSize = SrcTy->getScalarSizeInBits();
- uint32_t DestBitSize = DestTy->getScalarSizeInBits();
-
- // We need to emit a shl + ashr to do the sign extend.
- Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
- Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext");
- return BinaryOperator::CreateAShr(Res, ShAmt);
- }
+ Value *X;
+ if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) {
+ // sext(trunc(X)) --> ashr(shl(X, C), C)
+ unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
+ unsigned DestBitSize = DestTy->getScalarSizeInBits();
+ Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize);
+ return BinaryOperator::CreateAShr(Builder->CreateShl(X, ShAmt), ShAmt);
+ }
if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
return transformSExtICmp(ICI, CI);
@@ -1225,17 +1269,15 @@ static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
return nullptr;
}
-/// If this is a floating-point extension instruction, look
-/// through it until we get the source value.
+/// Look through floating-point extensions until we get the source value.
static Value *lookThroughFPExtensions(Value *V) {
- if (Instruction *I = dyn_cast<Instruction>(V))
- if (I->getOpcode() == Instruction::FPExt)
- return lookThroughFPExtensions(I->getOperand(0));
+ while (auto *FPExt = dyn_cast<FPExtInst>(V))
+ V = FPExt->getOperand(0);
// If this value is a constant, return the constant in the smallest FP type
// that can accurately represent it. This allows us to turn
// (float)((double)X+2.0) into x+2.0f.
- if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) {
+ if (auto *CFP = dyn_cast<ConstantFP>(V)) {
if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext()))
return V; // No constant folding of this.
// See if the value can be truncated to half and then reextended.
@@ -1392,24 +1434,49 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) {
IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0));
if (II) {
switch (II->getIntrinsicID()) {
- default: break;
- case Intrinsic::fabs: {
- // (fptrunc (fabs x)) -> (fabs (fptrunc x))
- Value *InnerTrunc = Builder->CreateFPTrunc(II->getArgOperand(0),
- CI.getType());
- Type *IntrinsicType[] = { CI.getType() };
- Function *Overload = Intrinsic::getDeclaration(
- CI.getModule(), II->getIntrinsicID(), IntrinsicType);
-
- SmallVector<OperandBundleDef, 1> OpBundles;
- II->getOperandBundlesAsDefs(OpBundles);
-
- Value *Args[] = { InnerTrunc };
- return CallInst::Create(Overload, Args, OpBundles, II->getName());
+ default: break;
+ case Intrinsic::fabs:
+ case Intrinsic::ceil:
+ case Intrinsic::floor:
+ case Intrinsic::rint:
+ case Intrinsic::round:
+ case Intrinsic::nearbyint:
+ case Intrinsic::trunc: {
+ Value *Src = II->getArgOperand(0);
+ if (!Src->hasOneUse())
+ break;
+
+ // Except for fabs, this transformation requires the input of the unary FP
+ // operation to be itself an fpext from the type to which we're
+ // truncating.
+ if (II->getIntrinsicID() != Intrinsic::fabs) {
+ FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src);
+ if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType())
+ break;
}
+
+ // Do unary FP operation on smaller type.
+ // (fptrunc (fabs x)) -> (fabs (fptrunc x))
+ Value *InnerTrunc = Builder->CreateFPTrunc(Src, CI.getType());
+ Type *IntrinsicType[] = { CI.getType() };
+ Function *Overload = Intrinsic::getDeclaration(
+ CI.getModule(), II->getIntrinsicID(), IntrinsicType);
+
+ SmallVector<OperandBundleDef, 1> OpBundles;
+ II->getOperandBundlesAsDefs(OpBundles);
+
+ Value *Args[] = { InnerTrunc };
+ CallInst *NewCI = CallInst::Create(Overload, Args,
+ OpBundles, II->getName());
+ NewCI->copyFastMathFlags(II);
+ return NewCI;
+ }
}
}
+ if (Instruction *I = shrinkInsertElt(CI, *Builder))
+ return I;
+
return nullptr;
}
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 428f94bb5e93..bbafa9e9f468 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -230,7 +230,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
return nullptr;
uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
- if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays.
+ // Don't blow up on huge arrays.
+ if (ArrayElementCount > MaxArraySizeForCombine)
+ return nullptr;
// There are many forms of this optimization we can handle, for now, just do
// the simple index into a single-dimensional array.
@@ -1663,7 +1665,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
(Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {
// TODO: Is this a good transform for vectors? Wider types may reduce
// throughput. Should this transform be limited (even for scalars) by using
- // ShouldChangeType()?
+ // shouldChangeType()?
if (!Cmp.getType()->isVectorTy()) {
Type *WideType = W->getType();
unsigned WideScalarBits = WideType->getScalarSizeInBits();
@@ -1792,6 +1794,15 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
ConstantInt::get(V->getType(), 1));
}
+ // X | C == C --> X <=u C
+ // X | C != C --> X >u C
+ // iff C+1 is a power of 2 (C is a bitmask of the low bits)
+ if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
+ (*C + 1).isPowerOf2()) {
+ Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
+ return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
+ }
+
if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse())
return nullptr;
@@ -1914,61 +1925,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
ICmpInst::Predicate Pred = Cmp.getPredicate();
Value *X = Shl->getOperand(0);
- if (Cmp.isEquality()) {
- // If the shift is NUW, then it is just shifting out zeros, no need for an
- // AND.
- Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt));
- if (Shl->hasNoUnsignedWrap())
- return new ICmpInst(Pred, X, LShrC);
-
- // If the shift is NSW and we compare to 0, then it is just shifting out
- // sign bits, no need for an AND either.
- if (Shl->hasNoSignedWrap() && *C == 0)
- return new ICmpInst(Pred, X, LShrC);
-
- if (Shl->hasOneUse()) {
- // Otherwise, strength reduce the shift into an and.
- Constant *Mask = ConstantInt::get(Shl->getType(),
- APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
-
- Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
- return new ICmpInst(Pred, And, LShrC);
+ Type *ShType = Shl->getType();
+
+ // NSW guarantees that we are only shifting out sign bits from the high bits,
+ // so we can ASHR the compare constant without needing a mask and eliminate
+ // the shift.
+ if (Shl->hasNoSignedWrap()) {
+ if (Pred == ICmpInst::ICMP_SGT) {
+ // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt)
+ APInt ShiftedC = C->ashr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+ // This is the same code as the SGT case, but assert the pre-condition
+ // that is needed for this to work with equality predicates.
+ assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ "Compare known true or false was not folded");
+ APInt ShiftedC = C->ashr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_SLT) {
+ // SLE is the same as above, but SLE is canonicalized to SLT, so convert:
+ // (X << S) <=s C is equiv to X <=s (C >> S) for all C
+ // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX
+ // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN
+ assert(!C->isMinSignedValue() && "Unexpected icmp slt");
+ APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1;
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ // If this is a signed comparison to 0 and the shift is sign preserving,
+ // use the shift LHS operand instead; isSignTest may change 'Pred', so only
+ // do that if we're sure to not continue on in this function.
+ if (isSignTest(Pred, *C))
+ return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
+ }
+
+ // NUW guarantees that we are only shifting out zero bits from the high bits,
+ // so we can LSHR the compare constant without needing a mask and eliminate
+ // the shift.
+ if (Shl->hasNoUnsignedWrap()) {
+ if (Pred == ICmpInst::ICMP_UGT) {
+ // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt)
+ APInt ShiftedC = C->lshr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+ // This is the same code as the UGT case, but assert the pre-condition
+ // that is needed for this to work with equality predicates.
+ assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+ "Compare known true or false was not folded");
+ APInt ShiftedC = C->lshr(*ShiftAmt);
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+ }
+ if (Pred == ICmpInst::ICMP_ULT) {
+ // ULE is the same as above, but ULE is canonicalized to ULT, so convert:
+ // (X << S) <=u C is equiv to X <=u (C >> S) for all C
+ // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
+ // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
+ assert(C->ugt(0) && "ult 0 should have been eliminated");
+ APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1;
+ return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
}
}
- // If this is a signed comparison to 0 and the shift is sign preserving,
- // use the shift LHS operand instead; isSignTest may change 'Pred', so only
- // do that if we're sure to not continue on in this function.
- if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C))
- return new ICmpInst(Pred, X, Constant::getNullValue(X->getType()));
+ if (Cmp.isEquality() && Shl->hasOneUse()) {
+ // Strength-reduce the shift into an 'and'.
+ Constant *Mask = ConstantInt::get(
+ ShType,
+ APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
+ Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
+ Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt));
+ return new ICmpInst(Pred, And, LShrC);
+ }
// Otherwise, if this is a comparison of the sign bit, simplify to and/test.
bool TrueIfSigned = false;
if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) {
// (X << 31) <s 0 --> (X & 1) != 0
Constant *Mask = ConstantInt::get(
- X->getType(),
+ ShType,
APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1));
Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
- And, Constant::getNullValue(And->getType()));
- }
-
- // When the shift is nuw and pred is >u or <=u, comparison only really happens
- // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the
- // <=u case can be further converted to match <u (see below).
- if (Shl->hasNoUnsignedWrap() &&
- (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) {
- // Derivation for the ult case:
- // (X << S) <=u C is equiv to X <=u (C >> S) for all C
- // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
- // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
- assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) &&
- "Encountered `ult 0` that should have been eliminated by "
- "InstSimplify.");
- APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1
- : C->lshr(*ShiftAmt);
- return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC));
+ And, Constant::getNullValue(ShType));
}
// Transform (icmp pred iM (shl iM %v, N), C)
@@ -1981,8 +2020,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt &&
DL.isLegalInteger(TypeBits - Amt)) {
Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
- if (X->getType()->isVectorTy())
- TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements());
+ if (ShType->isVectorTy())
+ TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
Constant *NewC =
ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt));
return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC);
@@ -2342,8 +2381,24 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
// Fold icmp pred (add X, C2), C.
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
- auto CR =
- ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2);
+ CmpInst::Predicate Pred = Cmp.getPredicate();
+
+ // If the add does not wrap, we can always adjust the compare by subtracting
+ // the constants. Equality comparisons are handled elsewhere. SGE/SLE are
+ // canonicalized to SGT/SLT.
+ if (Add->hasNoSignedWrap() &&
+ (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) {
+ bool Overflow;
+ APInt NewC = C->ssub_ov(*C2, Overflow);
+ // If there is overflow, the result must be true or false.
+ // TODO: Can we assert there is no overflow because InstSimplify always
+ // handles those cases?
+ if (!Overflow)
+ // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2)
+ return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
+ }
+
+ auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2);
const APInt &Upper = CR.getUpper();
const APInt &Lower = CR.getLower();
if (Cmp.isSigned()) {
@@ -2364,16 +2419,14 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
// X+C <u C2 -> (X & -C2) == C
// iff C & (C2-1) == 0
// C2 is a power of 2
- if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() &&
- (*C2 & (*C - 1)) == 0)
+ if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0)
return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)),
ConstantExpr::getNeg(cast<Constant>(Y)));
// X+C >u C2 -> (X & ~C2) != C
// iff C & C2 == 0
// C2+1 is a power of 2
- if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() &&
- (*C2 & *C) == 0)
+ if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0)
return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)),
ConstantExpr::getNeg(cast<Constant>(Y)));
@@ -2656,7 +2709,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
// block. If in the same block, we're encouraging jump threading. If
// not, we are just pessimizing the code by making an i1 phi.
if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
return NV;
break;
case Instruction::Select: {
@@ -2767,12 +2820,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
D = BO1->getOperand(1);
}
- // icmp (X+cst) < 0 --> X < -cst
- if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
- if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
- if (!RHSC->isMinValue(/*isSigned=*/true))
- return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
-
// icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
return new ICmpInst(Pred, A == Op1 ? B : A,
@@ -2847,6 +2894,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
+ // TODO: The subtraction-related identities shown below also hold, but
+ // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
+ // wouldn't happen even if they were implemented.
+ //
+ // icmp ult (X - 1), Y -> icmp ule X, Y
+ // icmp uge (X - 1), Y -> icmp ugt X, Y
+ // icmp ugt X, (Y - 1) -> icmp uge X, Y
+ // icmp ule X, (Y - 1) -> icmp ult X, Y
+
+ // icmp ule (X + 1), Y -> icmp ult X, Y
+ if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
+ return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
+
+ // icmp ugt (X + 1), Y -> icmp uge X, Y
+ if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
+ return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
+
+ // icmp uge X, (Y + 1) -> icmp ugt X, Y
+ if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
+ return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
+
+ // icmp ult X, (Y + 1) -> icmp ule X, Y
+ if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
+ return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
+
// if C1 has greater magnitude than C2:
// icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
// s.t. C3 = C1 - C2
@@ -3738,16 +3810,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,
// greater than the RHS must differ in a bit higher than these due to carry.
case ICmpInst::ICMP_UGT: {
unsigned trailingOnes = RHS.countTrailingOnes();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingOnes);
}
// Similarly, for a ULT comparison, we don't care about the trailing zeros.
// Any value less than the RHS must differ in a higher bit because of carries.
case ICmpInst::ICMP_ULT: {
unsigned trailingZeros = RHS.countTrailingZeros();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingZeros);
}
default:
@@ -3887,7 +3957,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
- // The check for the unique predecessor is not the best that can be
+ // The check for the single predecessor is not the best that can be
// done. But it protects efficiently against cases like when SI's
// home block has two successors, Succ and Succ1, and Succ1 predecessor
// of Succ. Then SI can't be replaced by SIOpd because the use that gets
@@ -3895,8 +3965,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
// guarantees that the path all uses of SI (outside SI's parent) are on
// is disjoint from all other paths out of SI. But that information
// is more expensive to compute, and the trade-off here is in favor
- // of compile-time.
- if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
+ // of compile-time. It should also be noticed that we check for a single
+ // predecessor and not only uniqueness. This to handle the situation when
+ // Succ and Succ1 points to the same basic block.
+ if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
NumSel++;
SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
return true;
@@ -3932,12 +4004,12 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
- if (SimplifyDemandedBits(I.getOperandUse(0),
+ if (SimplifyDemandedBits(&I, 0,
getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
Op0KnownZero, Op0KnownOne, 0))
return &I;
- if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+ if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
Op1KnownZero, Op1KnownOne, 0))
return &I;
@@ -4801,7 +4873,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
// block. If in the same block, we're encouraging jump threading. If
// not, we are just pessimizing the code by making an i1 phi.
if (LHSI->getParent() == I.getParent())
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
return NV;
break;
case Instruction::SIToFP:
diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h
index 2847ce858e79..71000063ab3c 100644
--- a/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -28,6 +28,9 @@
#include "llvm/IR/PatternMatch.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Support/Dwarf.h"
+#include "llvm/IR/DIBuilder.h"
#define DEBUG_TYPE "instcombine"
@@ -40,21 +43,29 @@ class DbgDeclareInst;
class MemIntrinsic;
class MemSetInst;
-/// \brief Assign a complexity or rank value to LLVM Values.
+/// Assign a complexity or rank value to LLVM Values. This is used to reduce
+/// the amount of pattern matching needed for compares and commutative
+/// instructions. For example, if we have:
+/// icmp ugt X, Constant
+/// or
+/// xor (add X, Constant), cast Z
+///
+/// We do not have to consider the commuted variants of these patterns because
+/// canonicalization based on complexity guarantees the above ordering.
///
/// This routine maps IR values to various complexity ranks:
/// 0 -> undef
/// 1 -> Constants
/// 2 -> Other non-instructions
/// 3 -> Arguments
-/// 3 -> Unary operations
-/// 4 -> Other instructions
+/// 4 -> Cast and (f)neg/not instructions
+/// 5 -> Other instructions
static inline unsigned getComplexity(Value *V) {
if (isa<Instruction>(V)) {
- if (BinaryOperator::isNeg(V) || BinaryOperator::isFNeg(V) ||
- BinaryOperator::isNot(V))
- return 3;
- return 4;
+ if (isa<CastInst>(V) || BinaryOperator::isNeg(V) ||
+ BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V))
+ return 4;
+ return 5;
}
if (isa<Argument>(V))
return 3;
@@ -289,6 +300,7 @@ public:
Instruction *visitLoadInst(LoadInst &LI);
Instruction *visitStoreInst(StoreInst &SI);
Instruction *visitBranchInst(BranchInst &BI);
+ Instruction *visitFenceInst(FenceInst &FI);
Instruction *visitSwitchInst(SwitchInst &SI);
Instruction *visitReturnInst(ReturnInst &RI);
Instruction *visitInsertValueInst(InsertValueInst &IV);
@@ -313,9 +325,14 @@ public:
bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp,
const unsigned SIOpd);
+ /// Try to replace instruction \p I with value \p V which are pointers
+ /// in different address space.
+ /// \return true if successful.
+ bool replacePointer(Instruction &I, Value *V);
+
private:
- bool ShouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
- bool ShouldChangeType(Type *From, Type *To) const;
+ bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
+ bool shouldChangeType(Type *From, Type *To) const;
Value *dyn_castNegVal(Value *V) const;
Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const;
Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset,
@@ -456,8 +473,9 @@ public:
/// methods should return the value returned by this function.
Instruction *eraseInstFromFunction(Instruction &I) {
DEBUG(dbgs() << "IC: ERASE " << I << '\n');
-
assert(I.use_empty() && "Cannot erase instruction that is used!");
+ salvageDebugInfo(I);
+
// Make sure that we reprocess all operands now that we reduced their
// use counts.
if (I.getNumOperands() < 8) {
@@ -499,6 +517,9 @@ public:
return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
}
+ /// Maximum size of array considered when transforming.
+ uint64_t MaxArraySizeForCombine;
+
private:
/// \brief Performs a few simplifications for operators which are associative
/// or commutative.
@@ -518,8 +539,16 @@ private:
Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero,
APInt &KnownOne, unsigned Depth,
Instruction *CxtI);
- bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero,
+ bool SimplifyDemandedBits(Instruction *I, unsigned Op,
+ const APInt &DemandedMask, APInt &KnownZero,
APInt &KnownOne, unsigned Depth = 0);
+ /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
+ /// bits. It also tries to handle simplifications that can be done based on
+ /// DemandedMask, but without modifying the Instruction.
+ Value *SimplifyMultipleUseDemandedBits(Instruction *I,
+ const APInt &DemandedMask,
+ APInt &KnownZero, APInt &KnownOne,
+ unsigned Depth, Instruction *CxtI);
/// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded
/// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence.
Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl,
@@ -540,7 +569,7 @@ private:
/// Given a binary operator, cast instruction, or select which has a PHI node
/// as operand #0, see if we can fold the instruction into the PHI (which is
/// only possible if all operands to the PHI are constants).
- Instruction *FoldOpIntoPhi(Instruction &I);
+ Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN);
/// Given an instruction with a select as one operand and a constant as the
/// other operand, try to fold the binary operator into the select arguments.
@@ -549,7 +578,7 @@ private:
Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI);
/// This is a convenience wrapper function for the above two functions.
- Instruction *foldOpWithConstantIntoOperand(Instruction &I);
+ Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I);
/// \brief Try to rotate an operation below a PHI node, using PHI nodes for
/// its operands.
@@ -628,16 +657,16 @@ private:
SelectPatternFlavor SPF2, Value *C);
Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
- Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS,
+ Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,
ConstantInt *AndRHS, BinaryOperator &TheAnd);
- Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask,
- bool isSub, Instruction &I);
Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
bool isSigned, bool Inside);
Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
Instruction *MatchBSwap(BinaryOperator &I);
bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
+
+ Instruction *SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI);
Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
Instruction *SimplifyMemSet(MemSetInst *MI);
diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 49e516e9c176..6288e054f1bc 100644
--- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -12,13 +12,15 @@
//===----------------------------------------------------------------------===//
#include "InstCombineInternal.h"
+#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
@@ -223,6 +225,107 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {
return nullptr;
}
+namespace {
+// If I and V are pointers in different address space, it is not allowed to
+// use replaceAllUsesWith since I and V have different types. A
+// non-target-specific transformation should not use addrspacecast on V since
+// the two address space may be disjoint depending on target.
+//
+// This class chases down uses of the old pointer until reaching the load
+// instructions, then replaces the old pointer in the load instructions with
+// the new pointer. If during the chasing it sees bitcast or GEP, it will
+// create new bitcast or GEP with the new pointer and use them in the load
+// instruction.
+class PointerReplacer {
+public:
+ PointerReplacer(InstCombiner &IC) : IC(IC) {}
+ void replacePointer(Instruction &I, Value *V);
+
+private:
+ void findLoadAndReplace(Instruction &I);
+ void replace(Instruction *I);
+ Value *getReplacement(Value *I);
+
+ SmallVector<Instruction *, 4> Path;
+ MapVector<Value *, Value *> WorkMap;
+ InstCombiner &IC;
+};
+} // end anonymous namespace
+
+void PointerReplacer::findLoadAndReplace(Instruction &I) {
+ for (auto U : I.users()) {
+ auto *Inst = dyn_cast<Instruction>(&*U);
+ if (!Inst)
+ return;
+ DEBUG(dbgs() << "Found pointer user: " << *U << '\n');
+ if (isa<LoadInst>(Inst)) {
+ for (auto P : Path)
+ replace(P);
+ replace(Inst);
+ } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) {
+ Path.push_back(Inst);
+ findLoadAndReplace(*Inst);
+ Path.pop_back();
+ } else {
+ return;
+ }
+ }
+}
+
+Value *PointerReplacer::getReplacement(Value *V) {
+ auto Loc = WorkMap.find(V);
+ if (Loc != WorkMap.end())
+ return Loc->second;
+ return nullptr;
+}
+
+void PointerReplacer::replace(Instruction *I) {
+ if (getReplacement(I))
+ return;
+
+ if (auto *LT = dyn_cast<LoadInst>(I)) {
+ auto *V = getReplacement(LT->getPointerOperand());
+ assert(V && "Operand not replaced");
+ auto *NewI = new LoadInst(V);
+ NewI->takeName(LT);
+ IC.InsertNewInstWith(NewI, *LT);
+ IC.replaceInstUsesWith(*LT, NewI);
+ WorkMap[LT] = NewI;
+ } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
+ auto *V = getReplacement(GEP->getPointerOperand());
+ assert(V && "Operand not replaced");
+ SmallVector<Value *, 8> Indices;
+ Indices.append(GEP->idx_begin(), GEP->idx_end());
+ auto *NewI = GetElementPtrInst::Create(
+ V->getType()->getPointerElementType(), V, Indices);
+ IC.InsertNewInstWith(NewI, *GEP);
+ NewI->takeName(GEP);
+ WorkMap[GEP] = NewI;
+ } else if (auto *BC = dyn_cast<BitCastInst>(I)) {
+ auto *V = getReplacement(BC->getOperand(0));
+ assert(V && "Operand not replaced");
+ auto *NewT = PointerType::get(BC->getType()->getPointerElementType(),
+ V->getType()->getPointerAddressSpace());
+ auto *NewI = new BitCastInst(V, NewT);
+ IC.InsertNewInstWith(NewI, *BC);
+ NewI->takeName(BC);
+ WorkMap[BC] = NewI;
+ } else {
+ llvm_unreachable("should never reach here");
+ }
+}
+
+void PointerReplacer::replacePointer(Instruction &I, Value *V) {
+#ifndef NDEBUG
+ auto *PT = cast<PointerType>(I.getType());
+ auto *NT = cast<PointerType>(V->getType());
+ assert(PT != NT && PT->getElementType() == NT->getElementType() &&
+ "Invalid usage");
+#endif
+ WorkMap[&I] = V;
+ findLoadAndReplace(I);
+}
+
Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
if (auto *I = simplifyAllocaArraySize(*this, AI))
return I;
@@ -293,12 +396,22 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) {
for (unsigned i = 0, e = ToDelete.size(); i != e; ++i)
eraseInstFromFunction(*ToDelete[i]);
Constant *TheSrc = cast<Constant>(Copy->getSource());
- Constant *Cast
- = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType());
- Instruction *NewI = replaceInstUsesWith(AI, Cast);
- eraseInstFromFunction(*Copy);
- ++NumGlobalCopies;
- return NewI;
+ auto *SrcTy = TheSrc->getType();
+ auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(),
+ SrcTy->getPointerAddressSpace());
+ Constant *Cast =
+ ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy);
+ if (AI.getType()->getPointerAddressSpace() ==
+ SrcTy->getPointerAddressSpace()) {
+ Instruction *NewI = replaceInstUsesWith(AI, Cast);
+ eraseInstFromFunction(*Copy);
+ ++NumGlobalCopies;
+ return NewI;
+ } else {
+ PointerReplacer PtrReplacer(*this);
+ PtrReplacer.replacePointer(AI, Cast);
+ ++NumGlobalCopies;
+ }
}
}
}
@@ -608,7 +721,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) {
// arrays of arbitrary size but this has a terrible impact on compile time.
// The threshold here is chosen arbitrarily, maybe needs a little bit of
// tuning.
- if (NumElements > 1024)
+ if (NumElements > IC.MaxArraySizeForCombine)
return nullptr;
const DataLayout &DL = IC.getDataLayout();
@@ -1113,7 +1226,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) {
// arrays of arbitrary size but this has a terrible impact on compile time.
// The threshold here is chosen arbitrarily, maybe needs a little bit of
// tuning.
- if (NumElements > 1024)
+ if (NumElements > IC.MaxArraySizeForCombine)
return false;
const DataLayout &DL = IC.getDataLayout();
@@ -1268,8 +1381,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) {
break;
}
- // Don't skip over loads or things that can modify memory.
- if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory())
+ // Don't skip over loads, throws or things that can modify memory.
+ if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow())
break;
}
@@ -1392,8 +1505,8 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
}
// If we find something that may be using or overwriting the stored
// value, or if we run out of instructions, we can't do the xform.
- if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() ||
- BBI == OtherBB->begin())
+ if (BBI->mayReadFromMemory() || BBI->mayThrow() ||
+ BBI->mayWriteToMemory() || BBI == OtherBB->begin())
return false;
}
@@ -1402,7 +1515,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
// StoreBB.
for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) {
// FIXME: This should really be AA driven.
- if (I->mayReadFromMemory() || I->mayWriteToMemory())
+ if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory())
return false;
}
}
@@ -1425,7 +1538,9 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) {
SI.getOrdering(),
SI.getSynchScope());
InsertNewInstBefore(NewSI, *BBI);
- NewSI->setDebugLoc(OtherStore->getDebugLoc());
+ // The debug locations of the original instructions might differ; merge them.
+ NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(),
+ OtherStore->getDebugLoc()));
// If the two stores had AA tags, merge them.
AAMDNodes AATags;
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 45a19fb0f1f2..f1ac82057e6c 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -298,39 +298,33 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
// (X / Y) * Y = X - (X % Y)
// (X / Y) * -Y = (X % Y) - X
{
- Value *Op1C = Op1;
- BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0);
- if (!BO ||
- (BO->getOpcode() != Instruction::UDiv &&
- BO->getOpcode() != Instruction::SDiv)) {
- Op1C = Op0;
- BO = dyn_cast<BinaryOperator>(Op1);
+ Value *Y = Op1;
+ BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0);
+ if (!Div || (Div->getOpcode() != Instruction::UDiv &&
+ Div->getOpcode() != Instruction::SDiv)) {
+ Y = Op0;
+ Div = dyn_cast<BinaryOperator>(Op1);
}
- Value *Neg = dyn_castNegVal(Op1C);
- if (BO && BO->hasOneUse() &&
- (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) &&
- (BO->getOpcode() == Instruction::UDiv ||
- BO->getOpcode() == Instruction::SDiv)) {
- Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1);
+ Value *Neg = dyn_castNegVal(Y);
+ if (Div && Div->hasOneUse() &&
+ (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) &&
+ (Div->getOpcode() == Instruction::UDiv ||
+ Div->getOpcode() == Instruction::SDiv)) {
+ Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1);
// If the division is exact, X % Y is zero, so we end up with X or -X.
- if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO))
- if (SDiv->isExact()) {
- if (Op1BO == Op1C)
- return replaceInstUsesWith(I, Op0BO);
- return BinaryOperator::CreateNeg(Op0BO);
- }
-
- Value *Rem;
- if (BO->getOpcode() == Instruction::UDiv)
- Rem = Builder->CreateURem(Op0BO, Op1BO);
- else
- Rem = Builder->CreateSRem(Op0BO, Op1BO);
- Rem->takeName(BO);
+ if (Div->isExact()) {
+ if (DivOp1 == Y)
+ return replaceInstUsesWith(I, X);
+ return BinaryOperator::CreateNeg(X);
+ }
- if (Op1BO == Op1C)
- return BinaryOperator::CreateSub(Op0BO, Rem);
- return BinaryOperator::CreateSub(Rem, Op0BO);
+ auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem
+ : Instruction::SRem;
+ Value *Rem = Builder->CreateBinOp(RemOpc, X, DivOp1);
+ if (DivOp1 == Y)
+ return BinaryOperator::CreateSub(X, Rem);
+ return BinaryOperator::CreateSub(Rem, X);
}
}
@@ -1461,16 +1455,16 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
- } else if (isa<PHINode>(Op0I)) {
+ } else if (auto *PN = dyn_cast<PHINode>(Op0I)) {
using namespace llvm::PatternMatch;
const APInt *Op1Int;
if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
(I.getOpcode() == Instruction::URem ||
!Op1Int->isMinSignedValue())) {
- // FoldOpIntoPhi will speculate instructions to the end of the PHI's
+ // foldOpIntoPhi will speculate instructions to the end of the PHI's
// predecessor blocks, so do this only if we know the srem or urem
// will not fault.
- if (Instruction *NV = FoldOpIntoPhi(I))
+ if (Instruction *NV = foldOpIntoPhi(I, PN))
return NV;
}
}
diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 4cbffe9533b7..85e5b6ba2dc2 100644
--- a/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -457,8 +457,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) {
}
// The more common cases of a phi with no constant operands or just one
- // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi()
- // respectively. FoldOpIntoPhi() wants to do the opposite transform that is
+ // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi()
+ // respectively. foldOpIntoPhi() wants to do the opposite transform that is
// performed here. It tries to replicate a cast in the phi operand's basic
// block to expose other folding opportunities. Thus, InstCombine will
// infinite loop without this check.
@@ -507,7 +507,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) {
// Be careful about transforming integer PHIs. We don't want to pessimize
// the code by turning an i32 into an i1293.
if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) {
- if (!ShouldChangeType(PN.getType(), CastSrcTy))
+ if (!shouldChangeType(PN.getType(), CastSrcTy))
return nullptr;
}
} else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) {
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 36644845352e..693b6c95c169 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -120,6 +120,16 @@ static Constant *getSelectFoldableConstant(Instruction *I) {
/// We have (select c, TI, FI), and we know that TI and FI have the same opcode.
Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI,
Instruction *FI) {
+ // Don't break up min/max patterns. The hasOneUse checks below prevent that
+ // for most cases, but vector min/max with bitcasts can be transformed. If the
+ // one-use restrictions are eased for other patterns, we still don't want to
+ // obfuscate min/max.
+ if ((match(&SI, m_SMin(m_Value(), m_Value())) ||
+ match(&SI, m_SMax(m_Value(), m_Value())) ||
+ match(&SI, m_UMin(m_Value(), m_Value())) ||
+ match(&SI, m_UMax(m_Value(), m_Value()))))
+ return nullptr;
+
// If this is a cast from the same type, merge.
if (TI->getNumOperands() == 1 && TI->isCast()) {
Type *FIOpndTy = FI->getOperand(0)->getType();
@@ -364,7 +374,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal,
/// into:
/// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false)
static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
- InstCombiner::BuilderTy *Builder) {
+ InstCombiner::BuilderTy *Builder) {
ICmpInst::Predicate Pred = ICI->getPredicate();
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
@@ -395,13 +405,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) ||
match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) {
IntrinsicInst *II = cast<IntrinsicInst>(Count);
- IRBuilder<> Builder(II);
// Explicitly clear the 'undef_on_zero' flag.
IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone());
Type *Ty = NewI->getArgOperand(1)->getType();
NewI->setArgOperand(1, Constant::getNullValue(Ty));
- Builder.Insert(NewI);
- return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType());
+ Builder->Insert(NewI);
+ return Builder->CreateZExtOrTrunc(NewI, ValueOnZero->getType());
}
return nullptr;
@@ -500,18 +509,16 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) {
return true;
}
-/// If this is an integer min/max where the select's 'true' operand is a
-/// constant, canonicalize that constant to the 'false' operand:
-/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C
+/// If this is an integer min/max (icmp + select) with a constant operand,
+/// create the canonical icmp for the min/max operation and canonicalize the
+/// constant to the 'false' operand of the select:
+/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2
+/// Note: if C1 != C2, this will change the icmp constant to the existing
+/// constant operand of the select.
static Instruction *
canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
InstCombiner::BuilderTy &Builder) {
- // TODO: We should also canonicalize min/max when the select has a different
- // constant value than the cmp constant, but we need to fix the backend first.
- if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) ||
- !isa<Constant>(Sel.getTrueValue()) ||
- isa<Constant>(Sel.getFalseValue()) ||
- Cmp.getOperand(1) != Sel.getTrueValue())
+ if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)))
return nullptr;
// Canonicalize the compare predicate based on whether we have min or max.
@@ -526,16 +533,25 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
default: return nullptr;
}
- // Canonicalize the constant to the right side.
- if (isa<Constant>(LHS))
- std::swap(LHS, RHS);
+ // Is this already canonical?
+ if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS &&
+ Cmp.getPredicate() == NewPred)
+ return nullptr;
+
+ // Create the canonical compare and plug it into the select.
+ Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS));
- Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS);
- SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel);
+ // If the select operands did not change, we're done.
+ if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS)
+ return &Sel;
- // We swapped the select operands, so swap the metadata too.
- NewSel->swapProfMetadata();
- return NewSel;
+ // If we are swapping the select operands, swap the metadata too.
+ assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS &&
+ "Unexpected results from matchSelectPattern");
+ Sel.setTrueValue(LHS);
+ Sel.setFalseValue(RHS);
+ Sel.swapProfMetadata();
+ return &Sel;
}
/// Visit a SelectInst that has an ICmpInst as its first operand.
@@ -786,7 +802,9 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner,
// This transform is performance neutral if we can elide at least one xor from
// the set of three operands, since we'll be tacking on an xor at the very
// end.
- if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
+ if (SelectPatternResult::isMinOrMax(SPF1) &&
+ SelectPatternResult::isMinOrMax(SPF2) &&
+ IsFreeOrProfitableToInvert(A, NotA, ElidesXor) &&
IsFreeOrProfitableToInvert(B, NotB, ElidesXor) &&
IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) {
if (!NotA)
@@ -1035,8 +1053,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) {
// If the select condition element is false, choose from the 2nd vector.
Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts));
} else if (isa<UndefValue>(Elt)) {
- // If the select condition element is undef, the shuffle mask is undef.
- Mask.push_back(UndefValue::get(Int32Ty));
+ // Undef in a select condition (choose one of the operands) does not mean
+ // the same thing as undef in a shuffle mask (any value is acceptable), so
+ // give up.
+ return nullptr;
} else {
// Bail out on a constant expression.
return nullptr;
@@ -1364,11 +1384,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
// See if we can fold the select into a phi node if the condition is a select.
- if (isa<PHINode>(SI.getCondition()))
+ if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
// The true/false values have to be live in the PHI predecessor's blocks.
if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) &&
canSelectOperandBeMappingIntoPredBlock(FalseVal, SI))
- if (Instruction *NV = FoldOpIntoPhi(SI))
+ if (Instruction *NV = foldOpIntoPhi(SI, PN))
return NV;
if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) {
@@ -1450,6 +1470,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
}
}
+ // If we can compute the condition, there's no need for a select.
+ // Like the above fold, we are attempting to reduce compile-time cost by
+ // putting this fold here with limitations rather than in InstSimplify.
+ // The motivation for this call into value tracking is to take advantage of
+ // the assumption cache, so make sure that is populated.
+ if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) {
+ APInt KnownOne(1, 0), KnownZero(1, 0);
+ computeKnownBits(CondVal, KnownZero, KnownOne, 0, &SI);
+ if (KnownOne == 1)
+ return replaceInstUsesWith(SI, TrueVal);
+ if (KnownZero == 1)
+ return replaceInstUsesWith(SI, FalseVal);
+ }
+
if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder))
return BitCastSel;
diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 4ff9b64ac57c..9aa679c60e47 100644
--- a/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -22,8 +22,8 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
- assert(I.getOperand(1)->getType() == I.getOperand(0)->getType());
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ assert(Op0->getType() == Op1->getType());
// See if we can fold away this shift.
if (SimplifyDemandedInstructionBits(I))
@@ -65,63 +65,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
}
/// Return true if we can simplify two logical (either left or right) shifts
-/// that have constant shift amounts.
-static bool canEvaluateShiftedShift(unsigned FirstShiftAmt,
- bool IsFirstShiftLeft,
- Instruction *SecondShift, InstCombiner &IC,
+/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2.
+static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl,
+ Instruction *InnerShift, InstCombiner &IC,
Instruction *CxtI) {
- assert(SecondShift->isLogicalShift() && "Unexpected instruction type");
+ assert(InnerShift->isLogicalShift() && "Unexpected instruction type");
- // We need constant shifts.
- auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1));
- if (!SecondShiftConst)
+ // We need constant scalar or constant splat shifts.
+ const APInt *InnerShiftConst;
+ if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst)))
return false;
- unsigned SecondShiftAmt = SecondShiftConst->getZExtValue();
- bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl;
-
- // We can always fold shl(c1) + shl(c2) -> shl(c1+c2).
- // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2).
- if (IsFirstShiftLeft == IsSecondShiftLeft)
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ if (IsInnerShl == IsOuterShl)
return true;
- // We can always fold lshr(c) + shl(c) -> and(c2).
- // We can always fold shl(c) + lshr(c) -> and(c2).
- if (FirstShiftAmt == SecondShiftAmt)
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ unsigned InnerShAmt = InnerShiftConst->getZExtValue();
+ if (InnerShAmt == OuterShAmt)
return true;
- unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits();
-
// If the 2nd shift is bigger than the 1st, we can fold:
- // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or
- // shl(c1) + lshr(c2) -> lshr(c3) + and(c4),
+ // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3
+ // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3
// but it isn't profitable unless we know the and'd out bits are already zero.
- // Also check that the 2nd shift is valid (less than the type width) or we'll
- // crash trying to produce the bit mask for the 'and'.
- if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) {
- unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt
- : SecondShiftAmt - FirstShiftAmt;
- APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift;
- if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI))
+ // Also, check that the inner shift is valid (less than the type width) or
+ // we'll crash trying to produce the bit mask for the 'and'.
+ unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits();
+ if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) {
+ unsigned MaskShift =
+ IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt;
+ APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift;
+ if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI))
return true;
}
return false;
}
-/// See if we can compute the specified value, but shifted
-/// logically to the left or right by some number of bits. This should return
-/// true if the expression can be computed for the same cost as the current
-/// expression tree. This is used to eliminate extraneous shifting from things
-/// like:
+/// See if we can compute the specified value, but shifted logically to the left
+/// or right by some number of bits. This should return true if the expression
+/// can be computed for the same cost as the current expression tree. This is
+/// used to eliminate extraneous shifting from things like:
/// %C = shl i128 %A, 64
/// %D = shl i128 %B, 96
/// %E = or i128 %C, %D
/// %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, the GetShiftedValue function will be called to produce the
-/// value.
-static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
+/// where the client will ask if E can be computed shifted right by 64-bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
InstCombiner &IC, Instruction *CxtI) {
// We can always evaluate constants shifted.
if (isa<Constant>(V))
@@ -165,8 +162,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
case Instruction::Or:
case Instruction::Xor:
// Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
- return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
- CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+ return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
+ canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
case Instruction::Shl:
case Instruction::LShr:
@@ -176,8 +173,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
SelectInst *SI = cast<SelectInst>(I);
Value *TrueVal = SI->getTrueValue();
Value *FalseVal = SI->getFalseValue();
- return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
- CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+ return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
+ canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
}
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -185,16 +182,79 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (Value *IncValue : PN->incoming_values())
- if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+ if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
return false;
return true;
}
}
}
-/// When CanEvaluateShifted returned true for an expression,
-/// this value inserts the new computation that produces the shifted value.
-static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
+ bool IsOuterShl,
+ InstCombiner::BuilderTy &Builder) {
+ bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+ Type *ShType = InnerShift->getType();
+ unsigned TypeWidth = ShType->getScalarSizeInBits();
+
+ // We only accept shifts-by-a-constant in canEvaluateShifted().
+ const APInt *C1;
+ match(InnerShift->getOperand(1), m_APInt(C1));
+ unsigned InnerShAmt = C1->getZExtValue();
+
+ // Change the shift amount and clear the appropriate IR flags.
+ auto NewInnerShift = [&](unsigned ShAmt) {
+ InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
+ if (IsInnerShl) {
+ InnerShift->setHasNoUnsignedWrap(false);
+ InnerShift->setHasNoSignedWrap(false);
+ } else {
+ InnerShift->setIsExact(false);
+ }
+ return InnerShift;
+ };
+
+ // Two logical shifts in the same direction:
+ // shl (shl X, C1), C2 --> shl X, C1 + C2
+ // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+ if (IsInnerShl == IsOuterShl) {
+ // If this is an oversized composite shift, then unsigned shifts get 0.
+ if (InnerShAmt + OuterShAmt >= TypeWidth)
+ return Constant::getNullValue(ShType);
+
+ return NewInnerShift(InnerShAmt + OuterShAmt);
+ }
+
+ // Equal shift amounts in opposite directions become bitwise 'and':
+ // lshr (shl X, C), C --> and X, C'
+ // shl (lshr X, C), C --> and X, C'
+ if (InnerShAmt == OuterShAmt) {
+ APInt Mask = IsInnerShl
+ ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
+ : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
+ Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
+ ConstantInt::get(ShType, Mask));
+ if (auto *AndI = dyn_cast<Instruction>(And)) {
+ AndI->moveBefore(InnerShift);
+ AndI->takeName(InnerShift);
+ }
+ return And;
+ }
+
+ assert(InnerShAmt > OuterShAmt &&
+ "Unexpected opposite direction logical shift pair");
+
+ // In general, we would need an 'and' for this transform, but
+ // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
+ // lshr (shl X, C1), C2 --> shl X, C1 - C2
+ // shl (lshr X, C1), C2 --> lshr X, C1 - C2
+ return NewInnerShift(InnerShAmt - OuterShAmt);
+}
+
+/// When canEvaluateShifted() returns true for an expression, this function
+/// inserts the new computation that produces the shifted value.
+static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
InstCombiner &IC, const DataLayout &DL) {
// We can always evaluate constants shifted.
if (Constant *C = dyn_cast<Constant>(V)) {
@@ -220,100 +280,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
case Instruction::Xor:
// Bitwise operators can all arbitrarily be arbitrarily evaluated shifted.
I->setOperand(
- 0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+ 0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
return I;
- case Instruction::Shl: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
- if (isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(I->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return I;
- }
-
- // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(BO->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(BO);
- VI->takeName(BO);
- }
- return V;
- }
-
- // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setHasNoUnsignedWrap(false);
- BO->setHasNoSignedWrap(false);
- return BO;
- }
- // FIXME: This is almost identical to the SHL case. Refactor both cases into
- // a helper function.
- case Instruction::LShr: {
- BinaryOperator *BO = cast<BinaryOperator>(I);
- unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
- // We only accept shifts-by-a-constant in CanEvaluateShifted.
- ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
- // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2).
- if (!isLeftShift) {
- // If this is oversized composite shift, then unsigned shifts get 0.
- unsigned NewShAmt = NumBits+CI->getZExtValue();
- if (NewShAmt >= TypeWidth)
- return Constant::getNullValue(BO->getType());
-
- BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
- BO->setIsExact(false);
- return I;
- }
-
- // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have
- // zeros.
- if (CI->getValue() == NumBits) {
- APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits));
- V = IC.Builder->CreateAnd(I->getOperand(0),
- ConstantInt::get(BO->getContext(), Mask));
- if (Instruction *VI = dyn_cast<Instruction>(V)) {
- VI->moveBefore(I);
- VI->takeName(I);
- }
- return V;
- }
-
- // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that
- // the and won't be needed.
- assert(CI->getZExtValue() > NumBits);
- BO->setOperand(1, ConstantInt::get(BO->getType(),
- CI->getZExtValue() - NumBits));
- BO->setIsExact(false);
- return BO;
- }
+ case Instruction::Shl:
+ case Instruction::LShr:
+ return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift,
+ *(IC.Builder));
case Instruction::Select:
I->setOperand(
- 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+ 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
I->setOperand(
- 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
+ 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL));
return I;
case Instruction::PHI: {
// We can change a phi if we can change all operands. Note that we never
@@ -321,215 +302,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
// instructions with a single use.
PHINode *PN = cast<PHINode>(I);
for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
- PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits,
+ PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits,
isLeftShift, IC, DL));
return PN;
}
}
}
-/// Try to fold (X << C1) << C2, where the shifts are some combination of
-/// shl/ashr/lshr.
-static Instruction *
-foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1,
- InstCombiner::BuilderTy *Builder) {
- Value *Op0 = I.getOperand(0);
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
-
- // Find out if this is a shift of a shift by a constant.
- BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0);
- if (ShiftOp && !ShiftOp->isShift())
- ShiftOp = nullptr;
-
- if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) {
-
- // This is a constant shift of a constant shift. Be careful about hiding
- // shl instructions behind bit masks. They are used to represent multiplies
- // by a constant, and it is important that simple arithmetic expressions
- // are still recognizable by scalar evolution.
- //
- // The transforms applied to shl are very similar to the transforms applied
- // to mul by constant. We can be more aggressive about optimizing right
- // shifts.
- //
- // Combinations of right and left shifts will still be optimized in
- // DAGCombine where scalar evolution no longer applies.
-
- ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1));
- uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits);
- uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits);
- assert(ShiftAmt2 != 0 && "Should have been simplified earlier");
- if (ShiftAmt1 == 0)
- return nullptr; // Will be simplified in the future.
- Value *X = ShiftOp->getOperand(0);
-
- IntegerType *Ty = cast<IntegerType>(I.getType());
-
- // Check for (X << c1) << c2 and (X >> c1) >> c2
- if (I.getOpcode() == ShiftOp->getOpcode()) {
- uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift.
- // If this is an oversized composite shift, then unsigned shifts become
- // zero (handled in InstSimplify) and ashr saturates.
- if (AmtSum >= TypeBits) {
- if (I.getOpcode() != Instruction::AShr)
- return nullptr;
- AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr.
- }
-
- return BinaryOperator::Create(I.getOpcode(), X,
- ConstantInt::get(Ty, AmtSum));
- }
-
- if (ShiftAmt1 == ShiftAmt2) {
- // If we have ((X << C) >>u C), turn this into X & (-1 >>u C).
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1));
- return BinaryOperator::CreateAnd(
- X, ConstantInt::get(I.getContext(), Mask));
- }
- } else if (ShiftAmt1 < ShiftAmt2) {
- uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1;
-
- // (X >>?,exact C1) << C2 --> X << (C2-C1)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- assert(ShiftOp->getOpcode() == Instruction::LShr ||
- ShiftOp->getOpcode() == Instruction::AShr);
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
- NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
- return NewShl;
- }
-
- // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- // (X <<nuw C1) >>u C2 --> X >>u (C2-C1)
- if (ShiftOp->hasNoUnsignedWrap()) {
- BinaryOperator *NewLShr =
- BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst);
- NewLShr->setIsExact(I.isExact());
- return NewLShr;
- }
- Value *Shift = Builder->CreateLShr(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X >>s (C2-C1)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewAShr =
- BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst);
- NewAShr->setIsExact(I.isExact());
- return NewAShr;
- }
- }
- } else {
- assert(ShiftAmt2 < ShiftAmt1);
- uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2;
-
- // (X >>?exact C1) << C2 --> X >>?exact (C1-C2)
- // The inexact version is deferred to DAGCombine so we don't hide shl
- // behind a bit mask.
- if (I.getOpcode() == Instruction::Shl &&
- ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShr =
- BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst);
- NewShr->setIsExact(true);
- return NewShr;
- }
-
- // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2)
- if (I.getOpcode() == Instruction::LShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- if (ShiftOp->hasNoUnsignedWrap()) {
- // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2)
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoUnsignedWrap(true);
- return NewShl;
- }
- Value *Shift = Builder->CreateShl(X, ShiftDiffCst);
-
- APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2));
- return BinaryOperator::CreateAnd(
- Shift, ConstantInt::get(I.getContext(), Mask));
- }
-
- // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However,
- // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
- if (I.getOpcode() == Instruction::AShr &&
- ShiftOp->getOpcode() == Instruction::Shl) {
- if (ShiftOp->hasNoSignedWrap()) {
- // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2)
- ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff);
- BinaryOperator *NewShl =
- BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst);
- NewShl->setHasNoSignedWrap(true);
- return NewShl;
- }
- }
- }
- }
-
- return nullptr;
-}
-
Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
BinaryOperator &I) {
bool isLeftShift = I.getOpcode() == Instruction::Shl;
- ConstantInt *COp1 = nullptr;
- if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1))
- COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue());
- else
- COp1 = dyn_cast<ConstantInt>(Op1);
-
- if (!COp1)
+ const APInt *Op1C;
+ if (!match(Op1, m_APInt(Op1C)))
return nullptr;
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
if (I.getOpcode() != Instruction::AShr &&
- CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) {
+ canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) {
DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression"
" to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n");
return replaceInstUsesWith(
- I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL));
+ I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL));
}
// See if we can simplify any instructions used by the instruction whose sole
// purpose is to compute bits we don't care about.
- uint32_t TypeBits = Op0->getType()->getScalarSizeInBits();
+ unsigned TypeBits = Op0->getType()->getScalarSizeInBits();
- assert(!COp1->uge(TypeBits) &&
+ assert(!Op1C->uge(TypeBits) &&
"Shift over the type width should have been removed already");
- // ((X*C1) << C2) == (X * (C1 << C2))
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0))
- if (BO->getOpcode() == Instruction::Mul && isLeftShift)
- if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1)))
- return BinaryOperator::CreateMul(BO->getOperand(0),
- ConstantExpr::getShl(BOOp, Op1));
-
if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I))
return FoldedShift;
@@ -544,7 +349,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
if (TrOp && I.isLogicalShift() && TrOp->isShift() &&
isa<ConstantInt>(TrOp->getOperand(1))) {
// Okay, we'll do this xform. Make the shift of shift.
- Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType());
+ Constant *ShAmt =
+ ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType());
// (shift2 (shift1 & 0x00FF), c2)
Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName());
@@ -561,10 +367,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// shift. We know that it is a logical shift by a constant, so adjust the
// mask as appropriate.
if (I.getOpcode() == Instruction::Shl)
- MaskV <<= COp1->getZExtValue();
+ MaskV <<= Op1C->getZExtValue();
else {
assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift");
- MaskV = MaskV.lshr(COp1->getZExtValue());
+ MaskV = MaskV.lshr(Op1C->getZExtValue());
}
// shift1 & 0x00FF
@@ -598,7 +404,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1,
Op0BO->getOperand(1)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -634,7 +440,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
// (X + (Y << C))
Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS,
Op0BO->getOperand(0)->getName());
- uint32_t Op1Val = COp1->getLimitedValue(TypeBits);
+ unsigned Op1Val = Op1C->getLimitedValue(TypeBits);
APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val);
Constant *Mask = ConstantInt::get(I.getContext(), Bits);
@@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
}
}
- if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder))
- return Folded;
-
return nullptr;
}
@@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V =
- SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(),
- I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(),
+ I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *V = commonShiftTransforms(I))
return V;
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- // Turn:
- // %zext = zext i32 %V to i64
- // %res = shl i64 %V, 8
- //
- // Into:
- // %shl = shl i32 %V, 8
- // %res = zext i32 %shl to i64
- //
- // This is only valid if %V would have zeros shifted out.
- if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) {
- unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits();
- if (ShAmt < SrcBitWidth &&
- MaskedValueIsZero(ZI->getOperand(0),
- APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) {
- auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt);
- return new ZExtInst(Shl, I.getType());
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+
+ // shl (zext X), ShAmt --> zext (shl X, ShAmt)
+ // This is only valid if X would have zeros shifted out.
+ Value *X;
+ if (match(Op0, m_ZExt(m_Value(X)))) {
+ unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+ if (ShAmt < SrcWidth &&
+ MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
+ return new ZExtInst(Builder->CreateShl(X, ShAmt), Ty);
+ }
+
+ // (X >>u C) << C --> X & (-1 << C)
+ if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) {
+ APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ // Be careful about hiding shl instructions behind bit masks. They are used
+ // to represent multiplies by a constant, and it is important that simple
+ // arithmetic expressions are still recognizable by scalar evolution.
+ // The inexact versions are deferred to DAGCombine, so we don't hide shl
+ // behind a bit mask.
+ const APInt *ShOp1;
+ if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShOp1))),
+ m_Exact(m_AShr(m_Value(X), m_APInt(ShOp1)))))) {
+ unsigned ShrAmt = ShOp1->getZExtValue();
+ if (ShrAmt < ShAmt) {
+ // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt);
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
+ return NewShl;
}
+ if (ShrAmt > ShAmt) {
+ // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt);
+ auto *NewShr = BinaryOperator::Create(
+ cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff);
+ NewShr->setIsExact(true);
+ return NewShr;
+ }
+ }
+
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X << C1) << C2 --> X << (C1 + C2)
+ return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is a NUW shift.
if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(I.getOperand(0),
- APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0,
- &I)) {
+ MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setHasNoUnsignedWrap();
return &I;
}
- // If the shifted out value is all signbits, this is a NSW shift.
- if (!I.hasNoSignedWrap() &&
- ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) {
+ // If the shifted-out value is all signbits, then this is a NSW shift.
+ if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) {
I.setHasNoSignedWrap();
return &I;
}
}
- // (C1 << A) << C2 -> (C1 << C2) << A
- Constant *C1, *C2;
- Value *A;
- if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) &&
- match(I.getOperand(1), m_Constant(C2)))
- return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A);
+ Constant *C1;
+ if (match(Op1, m_Constant(C1))) {
+ Constant *C2;
+ Value *X;
+ // (C2 << X) << C1 --> (C2 << C1) << X
+ if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X)))))
+ return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X);
+
+ // (X * C2) << C1 --> X * (C2 << C1)
+ if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
+ return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+ }
return nullptr;
}
@@ -776,43 +617,83 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
-
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
- unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ auto *II = dyn_cast<IntrinsicInst>(Op0);
+ if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt &&
+ (II->getIntrinsicID() == Intrinsic::ctlz ||
+ II->getIntrinsicID() == Intrinsic::cttz ||
+ II->getIntrinsicID() == Intrinsic::ctpop)) {
// ctlz.i32(x)>>5 --> zext(x == 0)
// cttz.i32(x)>>5 --> zext(x == 0)
// ctpop.i32(x)>>5 --> zext(x == -1)
- if ((II->getIntrinsicID() == Intrinsic::ctlz ||
- II->getIntrinsicID() == Intrinsic::cttz ||
- II->getIntrinsicID() == Intrinsic::ctpop) &&
- isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) {
- bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop;
- Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0);
- Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
- return new ZExtInst(Cmp, II->getType());
+ bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop;
+ Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? -1 : 0);
+ Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS);
+ return new ZExtInst(Cmp, Ty);
+ }
+
+ Value *X;
+ const APInt *ShOp1;
+ if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1)
+ auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff);
+ NewLShr->setIsExact(I.isExact());
+ return NewLShr;
+ }
+ // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2)
+ Value *NewLShr = Builder->CreateLShr(X, ShiftDiff, "", I.isExact());
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask));
}
+ if (ShlAmt > ShAmt) {
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) {
+ // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2)
+ auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
+ NewShl->setHasNoUnsignedWrap(true);
+ return NewShl;
+ }
+ // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2)
+ Value *NewShl = Builder->CreateShl(X, ShiftDiff);
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask));
+ }
+ assert(ShlAmt == ShAmt);
+ // (X << C) >>u C --> X & (-1 >>u C)
+ APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt));
+ return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask));
+ }
+
+ if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized shifts are simplified to zero in InstSimplify.
+ if (AmtSum < BitWidth)
+ // (X >>u C1) >>u C2 --> X >>u (C1 + C2)
+ return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)){
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
-
return nullptr;
}
@@ -820,48 +701,66 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
if (Value *V = SimplifyVectorOp(I))
return replaceInstUsesWith(I, V);
- if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
- DL, &TLI, &DT, &AC))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC))
return replaceInstUsesWith(I, V);
if (Instruction *R = commonShiftTransforms(I))
return R;
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-
- if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
- unsigned ShAmt = Op1C->getZExtValue();
+ Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ const APInt *ShAmtAPInt;
+ if (match(Op1, m_APInt(ShAmtAPInt))) {
+ unsigned ShAmt = ShAmtAPInt->getZExtValue();
- // If the input is a SHL by the same constant (ashr (shl X, C), C), then we
- // have a sign-extend idiom.
+ // If the shift amount equals the difference in width of the destination
+ // and source scalar types:
+ // ashr (shl (zext X), C), C --> sext X
Value *X;
- if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) {
- // If the input is an extension from the shifted amount value, e.g.
- // %x = zext i8 %A to i32
- // %y = shl i32 %x, 24
- // %z = ashr %y, 24
- // then turn this into "z = sext i8 A to i32".
- if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) {
- uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits();
- uint32_t DestBits = ZI->getType()->getScalarSizeInBits();
- if (Op1C->getZExtValue() == DestBits-SrcBits)
- return new SExtInst(ZI->getOperand(0), ZI->getType());
+ if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) &&
+ ShAmt == BitWidth - X->getType()->getScalarSizeInBits())
+ return new SExtInst(X, Ty);
+
+ // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However,
+ // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits.
+ const APInt *ShOp1;
+ if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned ShlAmt = ShOp1->getZExtValue();
+ if (ShlAmt < ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt);
+ auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff);
+ NewAShr->setIsExact(I.isExact());
+ return NewAShr;
}
+ if (ShlAmt > ShAmt) {
+ // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2)
+ Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt);
+ auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff);
+ NewShl->setHasNoSignedWrap(true);
+ return NewShl;
+ }
+ }
+
+ if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) {
+ unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+ // Oversized arithmetic shifts replicate the sign bit.
+ AmtSum = std::min(AmtSum, BitWidth - 1);
+ // (X >>s C1) >>s C2 --> X >>s (C1 + C2)
+ return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum));
}
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
- MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt),
- 0, &I)) {
+ MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) {
I.setIsExact();
return &I;
}
}
// See if we can turn a signed shr into an unsigned shr.
- if (MaskedValueIsZero(Op0,
- APInt::getSignBit(I.getType()->getScalarSizeInBits()),
- 0, &I))
+ if (MaskedValueIsZero(Op0, APInt::getSignBit(BitWidth), 0, &I))
return BinaryOperator::CreateLShr(Op0, Op1);
return nullptr;
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index 8b930bd95dfe..4e6f02058d83 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -30,18 +30,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo,
assert(I && "No instruction?");
assert(OpNo < I->getNumOperands() && "Operand index too large");
- // If the operand is not a constant integer, nothing to do.
- ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo));
- if (!OpC) return false;
+ // The operand must be a constant integer or splat integer.
+ Value *Op = I->getOperand(OpNo);
+ const APInt *C;
+ if (!match(Op, m_APInt(C)))
+ return false;
// If there are no bits set that aren't demanded, nothing to do.
- Demanded = Demanded.zextOrTrunc(OpC->getValue().getBitWidth());
- if ((~Demanded & OpC->getValue()) == 0)
+ Demanded = Demanded.zextOrTrunc(C->getBitWidth());
+ if ((~Demanded & *C) == 0)
return false;
// This instruction is producing bits that are not demanded. Shrink the RHS.
- Demanded &= OpC->getValue();
- I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded));
+ Demanded &= *C;
+ I->setOperand(OpNo, ConstantInt::get(Op->getType(), Demanded));
return true;
}
@@ -66,12 +68,13 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) {
/// This form of SimplifyDemandedBits simplifies the specified instruction
/// operand if possible, updating it in place. It returns true if it made any
/// change and false otherwise.
-bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask,
+bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo,
+ const APInt &DemandedMask,
APInt &KnownZero, APInt &KnownOne,
unsigned Depth) {
- auto *UserI = dyn_cast<Instruction>(U.getUser());
+ Use &U = I->getOperandUse(OpNo);
Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero,
- KnownOne, Depth, UserI);
+ KnownOne, Depth, I);
if (!NewVal) return false;
U = NewVal;
return true;
@@ -114,9 +117,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
KnownOne.getBitWidth() == BitWidth &&
"Value *V, DemandedMask, KnownZero and KnownOne "
"must have same BitWidth");
- if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
- // We know all of the bits for a constant!
- KnownOne = CI->getValue() & DemandedMask;
+ const APInt *C;
+ if (match(V, m_APInt(C))) {
+ // We know all of the bits for a scalar constant or a splat vector constant!
+ KnownOne = *C & DemandedMask;
KnownZero = ~KnownOne & DemandedMask;
return nullptr;
}
@@ -138,9 +142,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 6) // Limit search depth.
return nullptr;
- APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
- APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
-
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI);
@@ -151,107 +152,43 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// we can't do any simplifications of the operands, because DemandedMask
// only reflects the bits demanded by *one* of the users.
if (Depth != 0 && !I->hasOneUse()) {
- // Despite the fact that we can't simplify this instruction in all User's
- // context, we can at least compute the knownzero/knownone bits, and we can
- // do simplifications that apply to *just* the one user if we know that
- // this instruction has a simpler value in that context.
- if (I->getOpcode() == Instruction::And) {
- // If either the LHS or the RHS are Zero, the result is zero.
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known 1 on one side, return the other.
- // These bits cannot contribute to the result of the 'and' in this
- // context.
- if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
- (DemandedMask & ~LHSKnownZero))
- return I->getOperand(0);
- if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
- (DemandedMask & ~RHSKnownZero))
- return I->getOperand(1);
-
- // If all of the demanded bits in the inputs are known zeros, return zero.
- if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
- return Constant::getNullValue(VTy);
-
- } else if (I->getOpcode() == Instruction::Or) {
- // We can simplify (X|Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
- // If either the LHS or the RHS are One, the result is One.
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known zero on one side, return the
- // other. These bits cannot contribute to the result of the 'or' in this
- // context.
- if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
- (DemandedMask & ~LHSKnownOne))
- return I->getOperand(0);
- if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
- (DemandedMask & ~RHSKnownOne))
- return I->getOperand(1);
-
- // If all of the potentially set bits on one side are known to be set on
- // the other side, just use the 'other' side.
- if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
- (DemandedMask & (~RHSKnownZero)))
- return I->getOperand(0);
- if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
- (DemandedMask & (~LHSKnownZero)))
- return I->getOperand(1);
- } else if (I->getOpcode() == Instruction::Xor) {
- // We can simplify (X^Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
- CxtI);
- computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
- CxtI);
-
- // If all of the demanded bits are known zero on one side, return the
- // other.
- if ((DemandedMask & RHSKnownZero) == DemandedMask)
- return I->getOperand(0);
- if ((DemandedMask & LHSKnownZero) == DemandedMask)
- return I->getOperand(1);
- }
-
- // Compute the KnownZero/KnownOne bits to simplify things downstream.
- computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
- return nullptr;
+ return SimplifyMultipleUseDemandedBits(I, DemandedMask, KnownZero, KnownOne,
+ Depth, CxtI);
}
+ APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+ APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
+
// If this is the root being simplified, allow it to have multiple uses,
// just set the DemandedMask to all bits so that we can try to simplify the
// operands. This allows visitTruncInst (for example) to simplify the
// operand of a trunc without duplicating all the logic below.
if (Depth == 0 && !V->hasOneUse())
- DemandedMask = APInt::getAllOnesValue(BitWidth);
+ DemandedMask.setAllBits();
switch (I->getOpcode()) {
default:
computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
break;
- case Instruction::And:
+ case Instruction::And: {
// If either the LHS or the RHS are Zero, the result is zero.
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero,
- LHSKnownZero, LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownZero, LHSKnownZero,
+ LHSKnownOne, Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
+ // Output known-0 are known to be clear if zero in either the LHS | RHS.
+ APInt IKnownZero = RHSKnownZero | LHSKnownZero;
+ // Output known-1 bits are only known if set in both the LHS & RHS.
+ APInt IKnownOne = RHSKnownOne & LHSKnownOne;
+
// If the client is only demanding bits that we know, return the known
// constant.
- if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)|
- (RHSKnownOne & LHSKnownOne))) == DemandedMask)
- return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne);
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(VTy, IKnownOne);
// If all of the demanded bits are known 1 on one side, return the other.
// These bits cannot contribute to the result of the 'and'.
@@ -262,34 +199,33 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
(DemandedMask & ~RHSKnownZero))
return I->getOperand(1);
- // If all of the demanded bits in the inputs are known zeros, return zero.
- if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask)
- return Constant::getNullValue(VTy);
-
// If the RHS is a constant, see if we can simplify it.
if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero))
return I;
- // Output known-1 bits are only known if set in both the LHS & RHS.
- KnownOne = RHSKnownOne & LHSKnownOne;
- // Output known-0 are known to be clear if zero in either the LHS | RHS.
- KnownZero = RHSKnownZero | LHSKnownZero;
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
- case Instruction::Or:
+ }
+ case Instruction::Or: {
// If either the LHS or the RHS are One, the result is One.
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne,
- LHSKnownZero, LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownOne, LHSKnownZero,
+ LHSKnownOne, Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
+ // Output known-0 bits are only known if clear in both the LHS & RHS.
+ APInt IKnownZero = RHSKnownZero & LHSKnownZero;
+ // Output known-1 are known to be set if set in either the LHS | RHS.
+ APInt IKnownOne = RHSKnownOne | LHSKnownOne;
+
// If the client is only demanding bits that we know, return the known
// constant.
- if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)|
- (RHSKnownOne | LHSKnownOne))) == DemandedMask)
- return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne);
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(VTy, IKnownOne);
// If all of the demanded bits are known zero on one side, return the other.
// These bits cannot contribute to the result of the 'or'.
@@ -313,16 +249,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
- // Output known-0 bits are only known if clear in both the LHS & RHS.
- KnownZero = RHSKnownZero & LHSKnownZero;
- // Output known-1 are known to be set if set in either the LHS | RHS.
- KnownOne = RHSKnownOne | LHSKnownOne;
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
+ }
case Instruction::Xor: {
- if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 0, DemandedMask, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
@@ -400,9 +335,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
// Output known-0 bits are known if clear or set in both the LHS & RHS.
- KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne);
+ KnownZero = std::move(IKnownZero);
// Output known-1 are known to be set if set in only one of the LHS, RHS.
- KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero);
+ KnownOne = std::move(IKnownOne);
break;
}
case Instruction::Select:
@@ -412,10 +347,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN)
return nullptr;
- if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero,
- RHSKnownOne, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnownZero, RHSKnownOne,
+ Depth + 1) ||
+ SimplifyDemandedBits(I, 1, DemandedMask, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?");
assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?");
@@ -434,8 +369,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask = DemandedMask.zext(truncBf);
KnownZero = KnownZero.zext(truncBf);
KnownOne = KnownOne.zext(truncBf);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
DemandedMask = DemandedMask.trunc(BitWidth);
KnownZero = KnownZero.trunc(BitWidth);
@@ -460,8 +395,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Don't touch a vector-to-scalar bitcast.
return nullptr;
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
break;
@@ -472,15 +407,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
DemandedMask = DemandedMask.trunc(SrcBitWidth);
KnownZero = KnownZero.trunc(SrcBitWidth);
KnownOne = KnownOne.trunc(SrcBitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne,
+ Depth + 1))
return I;
DemandedMask = DemandedMask.zext(BitWidth);
KnownZero = KnownZero.zext(BitWidth);
KnownOne = KnownOne.zext(BitWidth);
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
// The top bits are known to be zero.
- KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth);
+ KnownZero.setBitsFrom(SrcBitWidth);
break;
}
case Instruction::SExt: {
@@ -490,7 +425,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt InputDemandedBits = DemandedMask &
APInt::getLowBitsSet(BitWidth, SrcBitWidth);
- APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth));
+ APInt NewBits(APInt::getBitsSetFrom(BitWidth, SrcBitWidth));
// If any of the sign extended bits are demanded, we know that the sign
// bit is demanded.
if ((NewBits & DemandedMask) != 0)
@@ -499,8 +434,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth);
KnownZero = KnownZero.trunc(SrcBitWidth);
KnownOne = KnownOne.trunc(SrcBitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, InputDemandedBits, KnownZero, KnownOne,
+ Depth + 1))
return I;
InputDemandedBits = InputDemandedBits.zext(BitWidth);
KnownZero = KnownZero.zext(BitWidth);
@@ -530,11 +465,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Right fill the mask of bits for this ADD/SUB to demand the most
// significant bit and all those below it.
APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ));
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps,
- LHSKnownZero, LHSKnownOne, Depth + 1) ||
+ if (ShrinkDemandedConstant(I, 0, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnownZero, LHSKnownOne,
+ Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
- SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps,
- LHSKnownZero, LHSKnownOne, Depth + 1)) {
+ SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnownZero, RHSKnownOne,
+ Depth + 1)) {
// Disable the nsw and nuw flags here: We can no longer guarantee that
// we won't wrap after simplification. Removing the nsw/nuw flags is
// legal here because the top bit is not demanded.
@@ -543,6 +479,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
BinOP.setHasNoUnsignedWrap(false);
return I;
}
+
+ // If we are known to be adding/subtracting zeros to every bit below
+ // the highest demanded bit, we just return the other side.
+ if ((DemandedFromOps & RHSKnownZero) == DemandedFromOps)
+ return I->getOperand(0);
+ // We can't do this with the LHS for subtraction.
+ if (I->getOpcode() == Instruction::Add &&
+ (DemandedFromOps & LHSKnownZero) == DemandedFromOps)
+ return I->getOperand(1);
}
// Otherwise just hand the add/sub off to computeKnownBits to fill in
@@ -569,19 +514,19 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the shift is NUW/NSW, then it does demand the high bits.
ShlOperator *IOp = cast<ShlOperator>(I);
if (IOp->hasNoSignedWrap())
- DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1);
+ DemandedMaskIn.setHighBits(ShiftAmt+1);
else if (IOp->hasNoUnsignedWrap())
- DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setHighBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
KnownZero <<= ShiftAmt;
KnownOne <<= ShiftAmt;
// low bits known zero.
if (ShiftAmt)
- KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ KnownZero.setLowBits(ShiftAmt);
}
break;
case Instruction::LShr:
@@ -595,19 +540,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If the shift is exact, then it does demand the low bits (and knows that
// they are zero).
if (cast<LShrOperator>(I)->isExact())
- DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setLowBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
- KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
- KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);
- if (ShiftAmt) {
- // Compute the new bits that are at the top now.
- APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
- KnownZero |= HighBits; // high bits known zero.
- }
+ KnownZero = KnownZero.lshr(ShiftAmt);
+ KnownOne = KnownOne.lshr(ShiftAmt);
+ if (ShiftAmt)
+ KnownZero.setHighBits(ShiftAmt); // high bits known zero.
}
break;
case Instruction::AShr:
@@ -635,26 +577,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If any of the "high bits" are demanded, we should set the sign bit as
// demanded.
if (DemandedMask.countLeadingZeros() <= ShiftAmt)
- DemandedMaskIn.setBit(BitWidth-1);
+ DemandedMaskIn.setSignBit();
// If the shift is exact, then it does demand the low bits (and knows that
// they are zero).
if (cast<AShrOperator>(I)->isExact())
- DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt);
+ DemandedMaskIn.setLowBits(ShiftAmt);
- if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero,
- KnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne,
+ Depth + 1))
return I;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
// Compute the new bits that are at the top now.
APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt));
- KnownZero = APIntOps::lshr(KnownZero, ShiftAmt);
- KnownOne = APIntOps::lshr(KnownOne, ShiftAmt);
+ KnownZero = KnownZero.lshr(ShiftAmt);
+ KnownOne = KnownOne.lshr(ShiftAmt);
// Handle the sign bits.
APInt SignBit(APInt::getSignBit(BitWidth));
// Adjust to where it is now in the mask.
- SignBit = APIntOps::lshr(SignBit, ShiftAmt);
+ SignBit = SignBit.lshr(ShiftAmt);
// If the input sign bit is known to be zero, or if none of the top bits
// are demanded, turn this into an unsigned shift right.
@@ -683,8 +625,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
APInt LowBits = RA - 1;
APInt Mask2 = LowBits | APInt::getSignBit(BitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero,
- LHSKnownOne, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, Mask2, LHSKnownZero, LHSKnownOne,
+ Depth + 1))
return I;
// The low bits of LHS are unchanged by the srem.
@@ -693,12 +635,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If LHS is non-negative or has all low bits zero, then the upper bits
// are all zero.
- if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits))
+ if (LHSKnownZero.isNegative() || ((LHSKnownZero & LowBits) == LowBits))
KnownZero |= ~LowBits;
// If LHS is negative and not all low bits are zero, then the upper bits
// are all one.
- if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0))
+ if (LHSKnownOne.isNegative() && ((LHSKnownOne & LowBits) != 0))
KnownOne |= ~LowBits;
assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?");
@@ -713,21 +655,17 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
CxtI);
// If it's known zero, our sign bit is also zero.
if (LHSKnownZero.isNegative())
- KnownZero.setBit(KnownZero.getBitWidth() - 1);
+ KnownZero.setSignBit();
}
break;
case Instruction::URem: {
APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0);
APInt AllOnes = APInt::getAllOnesValue(BitWidth);
- if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2,
- KnownOne2, Depth + 1) ||
- SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2,
- KnownOne2, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, AllOnes, KnownZero2, KnownOne2, Depth + 1) ||
+ SimplifyDemandedBits(I, 1, AllOnes, KnownZero2, KnownOne2, Depth + 1))
return I;
unsigned Leaders = KnownZero2.countLeadingOnes();
- Leaders = std::max(Leaders,
- KnownZero2.countLeadingOnes());
KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask;
break;
}
@@ -792,11 +730,11 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return ConstantInt::getNullValue(VTy);
// We know that the upper bits are set to zero.
- KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth);
+ KnownZero.setBitsFrom(ArgWidth);
return nullptr;
}
case Intrinsic::x86_sse42_crc32_64_64:
- KnownZero = APInt::getHighBitsSet(64, 32);
+ KnownZero.setBitsFrom(32);
return nullptr;
}
}
@@ -811,6 +749,150 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return nullptr;
}
+/// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne
+/// bits. It also tries to handle simplifications that can be done based on
+/// DemandedMask, but without modifying the Instruction.
+Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I,
+ const APInt &DemandedMask,
+ APInt &KnownZero,
+ APInt &KnownOne,
+ unsigned Depth,
+ Instruction *CxtI) {
+ unsigned BitWidth = DemandedMask.getBitWidth();
+ Type *ITy = I->getType();
+
+ APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0);
+ APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0);
+
+ // Despite the fact that we can't simplify this instruction in all User's
+ // context, we can at least compute the knownzero/knownone bits, and we can
+ // do simplifications that apply to *just* the one user if we know that
+ // this instruction has a simpler value in that context.
+ switch (I->getOpcode()) {
+ case Instruction::And: {
+ // If either the LHS or the RHS are Zero, the result is zero.
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 are known to be clear if zero in either the LHS | RHS.
+ APInt IKnownZero = RHSKnownZero | LHSKnownZero;
+ // Output known-1 bits are only known if set in both the LHS & RHS.
+ APInt IKnownOne = RHSKnownOne & LHSKnownOne;
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known 1 on one side, return the other.
+ // These bits cannot contribute to the result of the 'and' in this
+ // context.
+ if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) ==
+ (DemandedMask & ~LHSKnownZero))
+ return I->getOperand(0);
+ if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) ==
+ (DemandedMask & ~RHSKnownZero))
+ return I->getOperand(1);
+
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ case Instruction::Or: {
+ // We can simplify (X|Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+
+ // If either the LHS or the RHS are One, the result is One.
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 bits are only known if clear in both the LHS & RHS.
+ APInt IKnownZero = RHSKnownZero & LHSKnownZero;
+ // Output known-1 are known to be set if set in either the LHS | RHS.
+ APInt IKnownOne = RHSKnownOne | LHSKnownOne;
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known zero on one side, return the
+ // other. These bits cannot contribute to the result of the 'or' in this
+ // context.
+ if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) ==
+ (DemandedMask & ~LHSKnownOne))
+ return I->getOperand(0);
+ if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) ==
+ (DemandedMask & ~RHSKnownOne))
+ return I->getOperand(1);
+
+ // If all of the potentially set bits on one side are known to be set on
+ // the other side, just use the 'other' side.
+ if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) ==
+ (DemandedMask & (~RHSKnownZero)))
+ return I->getOperand(0);
+ if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) ==
+ (DemandedMask & (~LHSKnownZero)))
+ return I->getOperand(1);
+
+ KnownZero = std::move(IKnownZero);
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ case Instruction::Xor: {
+ // We can simplify (X^Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+
+ computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1,
+ CxtI);
+ computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1,
+ CxtI);
+
+ // Output known-0 bits are known if clear or set in both the LHS & RHS.
+ APInt IKnownZero = (RHSKnownZero & LHSKnownZero) |
+ (RHSKnownOne & LHSKnownOne);
+ // Output known-1 are known to be set if set in only one of the LHS, RHS.
+ APInt IKnownOne = (RHSKnownZero & LHSKnownOne) |
+ (RHSKnownOne & LHSKnownZero);
+
+ // If the client is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, IKnownOne);
+
+ // If all of the demanded bits are known zero on one side, return the
+ // other.
+ if ((DemandedMask & RHSKnownZero) == DemandedMask)
+ return I->getOperand(0);
+ if ((DemandedMask & LHSKnownZero) == DemandedMask)
+ return I->getOperand(1);
+
+ // Output known-0 bits are known if clear or set in both the LHS & RHS.
+ KnownZero = std::move(IKnownZero);
+ // Output known-1 are known to be set if set in only one of the LHS, RHS.
+ KnownOne = std::move(IKnownOne);
+ break;
+ }
+ default:
+ // Compute the KnownZero/KnownOne bits to simplify things downstream.
+ computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI);
+
+ // If this user is only demanding bits that we know, return the known
+ // constant.
+ if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask)
+ return Constant::getIntegerValue(ITy, KnownOne);
+
+ break;
+ }
+
+ return nullptr;
+}
+
+
/// Helper routine of SimplifyDemandedUseBits. It tries to simplify
/// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into
/// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign
@@ -849,7 +931,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr,
unsigned ShrAmt = ShrOp1.getZExtValue();
KnownOne.clearAllBits();
- KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1);
+ KnownZero.setLowBits(ShlAmt - 1);
KnownZero &= DemandedMask;
APInt BitMask1(APInt::getAllOnesValue(BitWidth));
@@ -1472,14 +1554,136 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
break;
}
+ case Intrinsic::x86_sse2_packssdw_128:
+ case Intrinsic::x86_sse2_packsswb_128:
+ case Intrinsic::x86_sse2_packuswb_128:
+ case Intrinsic::x86_sse41_packusdw:
+ case Intrinsic::x86_avx2_packssdw:
+ case Intrinsic::x86_avx2_packsswb:
+ case Intrinsic::x86_avx2_packusdw:
+ case Intrinsic::x86_avx2_packuswb:
+ case Intrinsic::x86_avx512_packssdw_512:
+ case Intrinsic::x86_avx512_packsswb_512:
+ case Intrinsic::x86_avx512_packusdw_512:
+ case Intrinsic::x86_avx512_packuswb_512: {
+ auto *Ty0 = II->getArgOperand(0)->getType();
+ unsigned InnerVWidth = Ty0->getVectorNumElements();
+ assert(VWidth == (InnerVWidth * 2) && "Unexpected input size");
+
+ unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128;
+ unsigned VWidthPerLane = VWidth / NumLanes;
+ unsigned InnerVWidthPerLane = InnerVWidth / NumLanes;
+
+ // Per lane, pack the elements of the first input and then the second.
+ // e.g.
+ // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3])
+ // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15])
+ for (int OpNum = 0; OpNum != 2; ++OpNum) {
+ APInt OpDemandedElts(InnerVWidth, 0);
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ unsigned LaneIdx = Lane * VWidthPerLane;
+ for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) {
+ unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum;
+ if (DemandedElts[Idx])
+ OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt);
+ }
+ }
+
+ // Demand elements from the operand.
+ auto *Op = II->getArgOperand(OpNum);
+ APInt OpUndefElts(InnerVWidth, 0);
+ TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts,
+ Depth + 1);
+ if (TmpV) {
+ II->setArgOperand(OpNum, TmpV);
+ MadeChange = true;
+ }
+
+ // Pack the operand's UNDEF elements, one lane at a time.
+ OpUndefElts = OpUndefElts.zext(VWidth);
+ for (unsigned Lane = 0; Lane != NumLanes; ++Lane) {
+ APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane);
+ LaneElts = LaneElts.getLoBits(InnerVWidthPerLane);
+ LaneElts = LaneElts.shl(InnerVWidthPerLane * (2 * Lane + OpNum));
+ UndefElts |= LaneElts;
+ }
+ }
+ break;
+ }
+
+ // PSHUFB
+ case Intrinsic::x86_ssse3_pshuf_b_128:
+ case Intrinsic::x86_avx2_pshuf_b:
+ case Intrinsic::x86_avx512_pshuf_b_512:
+ // PERMILVAR
+ case Intrinsic::x86_avx_vpermilvar_ps:
+ case Intrinsic::x86_avx_vpermilvar_ps_256:
+ case Intrinsic::x86_avx512_vpermilvar_ps_512:
+ case Intrinsic::x86_avx_vpermilvar_pd:
+ case Intrinsic::x86_avx_vpermilvar_pd_256:
+ case Intrinsic::x86_avx512_vpermilvar_pd_512:
+ // PERMV
+ case Intrinsic::x86_avx2_permd:
+ case Intrinsic::x86_avx2_permps: {
+ Value *Op1 = II->getArgOperand(1);
+ TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts,
+ Depth + 1);
+ if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; }
+ break;
+ }
+
// SSE4A instructions leave the upper 64-bits of the 128-bit result
// in an undefined state.
case Intrinsic::x86_sse4a_extrq:
case Intrinsic::x86_sse4a_extrqi:
case Intrinsic::x86_sse4a_insertq:
case Intrinsic::x86_sse4a_insertqi:
- UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2);
+ UndefElts.setHighBits(VWidth / 2);
break;
+ case Intrinsic::amdgcn_buffer_load:
+ case Intrinsic::amdgcn_buffer_load_format: {
+ if (VWidth == 1 || !DemandedElts.isMask())
+ return nullptr;
+
+ // TODO: Handle 3 vectors when supported in code gen.
+ unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes());
+ if (NewNumElts == VWidth)
+ return nullptr;
+
+ Module *M = II->getParent()->getParent()->getParent();
+ Type *EltTy = V->getType()->getVectorElementType();
+
+ Type *NewTy = (NewNumElts == 1) ? EltTy :
+ VectorType::get(EltTy, NewNumElts);
+
+ Function *NewIntrin = Intrinsic::getDeclaration(M, II->getIntrinsicID(),
+ NewTy);
+
+ SmallVector<Value *, 5> Args;
+ for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I)
+ Args.push_back(II->getArgOperand(I));
+
+ IRBuilderBase::InsertPointGuard Guard(*Builder);
+ Builder->SetInsertPoint(II);
+
+ CallInst *NewCall = Builder->CreateCall(NewIntrin, Args);
+ NewCall->takeName(II);
+ NewCall->copyMetadata(*II);
+ if (NewNumElts == 1) {
+ return Builder->CreateInsertElement(UndefValue::get(V->getType()),
+ NewCall, static_cast<uint64_t>(0));
+ }
+
+ SmallVector<uint32_t, 8> EltMask;
+ for (unsigned I = 0; I < VWidth; ++I)
+ EltMask.push_back(I);
+
+ Value *Shuffle = Builder->CreateShuffleVector(
+ NewCall, UndefValue::get(NewTy), EltMask);
+
+ MadeChange = true;
+ return Shuffle;
+ }
}
break;
}
diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index b2477f6c8633..e89b400a4afc 100644
--- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -645,6 +645,36 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) {
return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask);
}
+/// If we have an insertelement instruction feeding into another insertelement
+/// and the 2nd is inserting a constant into the vector, canonicalize that
+/// constant insertion before the insertion of a variable:
+///
+/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 -->
+/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1
+///
+/// This has the potential of eliminating the 2nd insertelement instruction
+/// via constant folding of the scalar constant into a vector constant.
+static Instruction *hoistInsEltConst(InsertElementInst &InsElt2,
+ InstCombiner::BuilderTy &Builder) {
+ auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0));
+ if (!InsElt1 || !InsElt1->hasOneUse())
+ return nullptr;
+
+ Value *X, *Y;
+ Constant *ScalarC;
+ ConstantInt *IdxC1, *IdxC2;
+ if (match(InsElt1->getOperand(0), m_Value(X)) &&
+ match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) &&
+ match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) &&
+ match(InsElt2.getOperand(1), m_Constant(ScalarC)) &&
+ match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) {
+ Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2);
+ return InsertElementInst::Create(NewInsElt1, Y, IdxC1);
+ }
+
+ return nullptr;
+}
+
/// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex
/// --> shufflevector X, CVec', Mask'
static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) {
@@ -806,6 +836,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE))
return Shuf;
+ if (Instruction *NewInsElt = hoistInsEltConst(IE, *Builder))
+ return NewInsElt;
+
// Turn a sequence of inserts that broadcasts a scalar into a single
// insert + shufflevector.
if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE))
@@ -1107,12 +1140,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SmallVector<int, 16> Mask = SVI.getShuffleMask();
Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
- bool MadeChange = false;
-
- // Undefined shuffle mask -> undefined value.
- if (isa<UndefValue>(SVI.getOperand(2)))
- return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType()));
+ if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getMask(),
+ SVI.getType(), DL, &TLI, &DT, &AC))
+ return replaceInstUsesWith(SVI, V);
+ bool MadeChange = false;
unsigned VWidth = SVI.getType()->getVectorNumElements();
APInt UndefElts(VWidth, 0);
@@ -1209,7 +1241,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
if (isShuffleExtractingFromLHS(SVI, Mask)) {
Value *V = LHS;
unsigned MaskElems = Mask.size();
- unsigned BegIdx = Mask.front();
VectorType *SrcTy = cast<VectorType>(V->getType());
unsigned VecBitWidth = SrcTy->getBitWidth();
unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
@@ -1223,6 +1254,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
// Only visit bitcasts that weren't previously handled.
BCs.push_back(BC);
for (BitCastInst *BC : BCs) {
+ unsigned BegIdx = Mask.front();
Type *TgtTy = BC->getDestTy();
unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy);
if (!TgtElemBitWidth)
diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp
index 27fc34d23175..88ef17bbc8fa 100644
--- a/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -82,18 +82,24 @@ static cl::opt<bool>
EnableExpensiveCombines("expensive-combines",
cl::desc("Enable expensive instruction combines"));
+static cl::opt<unsigned>
+MaxArraySize("instcombine-maxarray-size", cl::init(1024),
+ cl::desc("Maximum array size considered when doing a combine"));
+
Value *InstCombiner::EmitGEPOffset(User *GEP) {
return llvm::EmitGEPOffset(Builder, DL, GEP);
}
/// Return true if it is desirable to convert an integer computation from a
/// given bit width to a new bit width.
-/// We don't want to convert from a legal to an illegal type for example or from
-/// a smaller to a larger illegal type.
-bool InstCombiner::ShouldChangeType(unsigned FromWidth,
+/// We don't want to convert from a legal to an illegal type or from a smaller
+/// to a larger illegal type. A width of '1' is always treated as a legal type
+/// because i1 is a fundamental type in IR, and there are many specialized
+/// optimizations for i1 types.
+bool InstCombiner::shouldChangeType(unsigned FromWidth,
unsigned ToWidth) const {
- bool FromLegal = DL.isLegalInteger(FromWidth);
- bool ToLegal = DL.isLegalInteger(ToWidth);
+ bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth);
+ bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth);
// If this is a legal integer from type, and the result would be an illegal
// type, don't do the transformation.
@@ -109,14 +115,16 @@ bool InstCombiner::ShouldChangeType(unsigned FromWidth,
}
/// Return true if it is desirable to convert a computation from 'From' to 'To'.
-/// We don't want to convert from a legal to an illegal type for example or from
-/// a smaller to a larger illegal type.
-bool InstCombiner::ShouldChangeType(Type *From, Type *To) const {
+/// We don't want to convert from a legal to an illegal type or from a smaller
+/// to a larger illegal type. i1 is always treated as a legal type because it is
+/// a fundamental type in IR, and there are many specialized optimizations for
+/// i1 types.
+bool InstCombiner::shouldChangeType(Type *From, Type *To) const {
assert(From->isIntegerTy() && To->isIntegerTy());
unsigned FromWidth = From->getPrimitiveSizeInBits();
unsigned ToWidth = To->getPrimitiveSizeInBits();
- return ShouldChangeType(FromWidth, ToWidth);
+ return shouldChangeType(FromWidth, ToWidth);
}
// Return true, if No Signed Wrap should be maintained for I.
@@ -447,16 +455,11 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp,
/// This function returns identity value for given opcode, which can be used to
/// factor patterns like (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1).
-static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
+static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) {
if (isa<Constant>(V))
return nullptr;
- if (OpCode == Instruction::Mul)
- return ConstantInt::get(V->getType(), 1);
-
- // TODO: We can handle other cases e.g. Instruction::And, Instruction::Or etc.
-
- return nullptr;
+ return ConstantExpr::getBinOpIdentity(Opcode, V->getType());
}
/// This function factors binary ops which can be combined using distributive
@@ -468,8 +471,7 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) {
static Instruction::BinaryOps
getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode,
BinaryOperator *Op, Value *&LHS, Value *&RHS) {
- if (!Op)
- return Instruction::BinaryOpsEnd;
+ assert(Op && "Expected a binary operator");
LHS = Op->getOperand(0);
RHS = Op->getOperand(1);
@@ -499,11 +501,7 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,
const DataLayout &DL, BinaryOperator &I,
Instruction::BinaryOps InnerOpcode, Value *A,
Value *B, Value *C, Value *D) {
-
- // If any of A, B, C, D are null, we can not factor I, return early.
- // Checking A and C should be enough.
- if (!A || !C || !B || !D)
- return nullptr;
+ assert(A && B && C && D && "All values must be provided");
Value *V = nullptr;
Value *SimplifiedInst = nullptr;
@@ -564,13 +562,11 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder,
if (isa<OverflowingBinaryOperator>(&I))
HasNSW = I.hasNoSignedWrap();
- if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS))
- if (isa<OverflowingBinaryOperator>(Op0))
- HasNSW &= Op0->hasNoSignedWrap();
+ if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS))
+ HasNSW &= LOBO->hasNoSignedWrap();
- if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS))
- if (isa<OverflowingBinaryOperator>(Op1))
- HasNSW &= Op1->hasNoSignedWrap();
+ if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS))
+ HasNSW &= ROBO->hasNoSignedWrap();
// We can propagate 'nsw' if we know that
// %Y = mul nsw i16 %X, C
@@ -599,31 +595,39 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+ Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
- // Factorization.
- Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr;
- auto TopLevelOpcode = I.getOpcode();
- auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
- auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
-
- // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
- // a common term.
- if (LHSOpcode == RHSOpcode) {
- if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D))
- return V;
- }
-
- // The instruction has the form "(A op' B) op (C)". Try to factorize common
- // term.
- if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS,
- getIdentityValue(LHSOpcode, RHS)))
- return V;
+ {
+ // Factorization.
+ Value *A, *B, *C, *D;
+ Instruction::BinaryOps LHSOpcode, RHSOpcode;
+ if (Op0)
+ LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+ if (Op1)
+ RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+ // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
+ // a common term.
+ if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+ if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D))
+ return V;
+
+ // The instruction has the form "(A op' B) op (C)". Try to factorize common
+ // term.
+ if (Op0)
+ if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+ if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS,
+ Ident))
+ return V;
- // The instruction has the form "(B) op (C op' D)". Try to factorize common
- // term.
- if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS,
- getIdentityValue(RHSOpcode, LHS), C, D))
- return V;
+ // The instruction has the form "(B) op (C op' D)". Try to factorize common
+ // term.
+ if (Op1)
+ if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+ if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, Ident,
+ C, D))
+ return V;
+ }
// Expansion.
if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
@@ -720,6 +724,21 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const {
if (C->getType()->getElementType()->isIntegerTy())
return ConstantExpr::getNeg(C);
+ if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) {
+ for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+ Constant *Elt = CV->getAggregateElement(i);
+ if (!Elt)
+ return nullptr;
+
+ if (isa<UndefValue>(Elt))
+ continue;
+
+ if (!isa<ConstantInt>(Elt))
+ return nullptr;
+ }
+ return ConstantExpr::getNeg(CV);
+ }
+
return nullptr;
}
@@ -820,8 +839,29 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) {
return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
}
-Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
- PHINode *PN = cast<PHINode>(I.getOperand(0));
+static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
+ InstCombiner *IC) {
+ bool ConstIsRHS = isa<Constant>(I->getOperand(1));
+ Constant *C = cast<Constant>(I->getOperand(ConstIsRHS));
+
+ if (auto *InC = dyn_cast<Constant>(InV)) {
+ if (ConstIsRHS)
+ return ConstantExpr::get(I->getOpcode(), InC, C);
+ return ConstantExpr::get(I->getOpcode(), C, InC);
+ }
+
+ Value *Op0 = InV, *Op1 = C;
+ if (!ConstIsRHS)
+ std::swap(Op0, Op1);
+
+ Value *RI = IC->Builder->CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp");
+ auto *FPInst = dyn_cast<Instruction>(RI);
+ if (FPInst && isa<FPMathOperator>(FPInst))
+ FPInst->copyFastMathFlags(I);
+ return RI;
+}
+
+Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) {
unsigned NumPHIValues = PN->getNumIncomingValues();
if (NumPHIValues == 0)
return nullptr;
@@ -902,7 +942,11 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
// Beware of ConstantExpr: it may eventually evaluate to getNullValue,
// even if currently isNullValue gives false.
Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i));
- if (InC && !isa<ConstantExpr>(InC))
+ // For vector constants, we cannot use isNullValue to fold into
+ // FalseVInPred versus TrueVInPred. When we have individual nonzero
+ // elements in the vector, we will incorrectly fold InC to
+ // `TrueVInPred`.
+ if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC))
InV = InC->isNullValue() ? FalseVInPred : TrueVInPred;
else
InV = Builder->CreateSelect(PN->getIncomingValue(i),
@@ -923,15 +967,9 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
C, "phitmp");
NewPN->addIncoming(InV, PN->getIncomingBlock(i));
}
- } else if (I.getNumOperands() == 2) {
- Constant *C = cast<Constant>(I.getOperand(1));
+ } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV = nullptr;
- if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)))
- InV = ConstantExpr::get(I.getOpcode(), InC, C);
- else
- InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(),
- PN->getIncomingValue(i), C, "phitmp");
+ Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), this);
NewPN->addIncoming(InV, PN->getIncomingBlock(i));
}
} else {
@@ -957,14 +995,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) {
return replaceInstUsesWith(I, NewPN);
}
-Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) {
+Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) {
assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type");
if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) {
if (Instruction *NewSel = FoldOpIntoSelect(I, Sel))
return NewSel;
- } else if (isa<PHINode>(I.getOperand(0))) {
- if (Instruction *NewPhi = FoldOpIntoPhi(I))
+ } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) {
+ if (Instruction *NewPhi = foldOpIntoPhi(I, PN))
return NewPhi;
}
return nullptr;
@@ -1315,22 +1353,19 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) {
assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth);
assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth);
- // If both arguments of binary operation are shuffles, which use the same
- // mask and shuffle within a single vector, it is worthwhile to move the
- // shuffle after binary operation:
+ // If both arguments of the binary operation are shuffles that use the same
+ // mask and shuffle within a single vector, move the shuffle after the binop:
// Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m)
- if (isa<ShuffleVectorInst>(LHS) && isa<ShuffleVectorInst>(RHS)) {
- ShuffleVectorInst *LShuf = cast<ShuffleVectorInst>(LHS);
- ShuffleVectorInst *RShuf = cast<ShuffleVectorInst>(RHS);
- if (isa<UndefValue>(LShuf->getOperand(1)) &&
- isa<UndefValue>(RShuf->getOperand(1)) &&
- LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType() &&
- LShuf->getMask() == RShuf->getMask()) {
- Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0),
- RShuf->getOperand(0), Builder);
- return Builder->CreateShuffleVector(NewBO,
- UndefValue::get(NewBO->getType()), LShuf->getMask());
- }
+ auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS);
+ auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS);
+ if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() &&
+ isa<UndefValue>(LShuf->getOperand(1)) &&
+ isa<UndefValue>(RShuf->getOperand(1)) &&
+ LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) {
+ Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0),
+ RShuf->getOperand(0), Builder);
+ return Builder->CreateShuffleVector(
+ NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask());
}
// If one argument is a shuffle within one vector, the other is a constant,
@@ -1559,27 +1594,21 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
// Replace: gep (gep %P, long B), long A, ...
// With: T = long A+B; gep %P, T, ...
//
- Value *Sum;
Value *SO1 = Src->getOperand(Src->getNumOperands()-1);
Value *GO1 = GEP.getOperand(1);
- if (SO1 == Constant::getNullValue(SO1->getType())) {
- Sum = GO1;
- } else if (GO1 == Constant::getNullValue(GO1->getType())) {
- Sum = SO1;
- } else {
- // If they aren't the same type, then the input hasn't been processed
- // by the loop above yet (which canonicalizes sequential index types to
- // intptr_t). Just avoid transforming this until the input has been
- // normalized.
- if (SO1->getType() != GO1->getType())
- return nullptr;
- // Only do the combine when GO1 and SO1 are both constants. Only in
- // this case, we are sure the cost after the merge is never more than
- // that before the merge.
- if (!isa<Constant>(GO1) || !isa<Constant>(SO1))
- return nullptr;
- Sum = Builder->CreateAdd(SO1, GO1, PtrOp->getName()+".sum");
- }
+
+ // If they aren't the same type, then the input hasn't been processed
+ // by the loop above yet (which canonicalizes sequential index types to
+ // intptr_t). Just avoid transforming this until the input has been
+ // normalized.
+ if (SO1->getType() != GO1->getType())
+ return nullptr;
+
+ Value* Sum = SimplifyAddInst(GO1, SO1, false, false, DL, &TLI, &DT, &AC);
+ // Only do the combine when we are sure the cost after the
+ // merge is never more than that before the merge.
+ if (Sum == nullptr)
+ return nullptr;
// Update the GEP in place if possible.
if (Src->getNumOperands() == 2) {
@@ -1654,14 +1683,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
}
}
- // Handle gep(bitcast x) and gep(gep x, 0, 0, 0).
- Value *StrippedPtr = PtrOp->stripPointerCasts();
- PointerType *StrippedPtrTy = dyn_cast<PointerType>(StrippedPtr->getType());
-
// We do not handle pointer-vector geps here.
- if (!StrippedPtrTy)
+ if (GEP.getType()->isVectorTy())
return nullptr;
+ // Handle gep(bitcast x) and gep(gep x, 0, 0, 0).
+ Value *StrippedPtr = PtrOp->stripPointerCasts();
+ PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType());
+
if (StrippedPtr != PtrOp) {
bool HasZeroPointerIndex = false;
if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1)))
@@ -2239,11 +2268,11 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
ConstantInt *AddRHS;
if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) {
// Change 'switch (X+4) case 1:' into 'switch (X) case -3'.
- for (SwitchInst::CaseIt CaseIter : SI.cases()) {
- Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS);
+ for (auto Case : SI.cases()) {
+ Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS);
assert(isa<ConstantInt>(NewCase) &&
"Result of expression should be constant");
- CaseIter.setValue(cast<ConstantInt>(NewCase));
+ Case.setValue(cast<ConstantInt>(NewCase));
}
SI.setCondition(Op0);
return &SI;
@@ -2275,9 +2304,9 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) {
Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc");
SI.setCondition(NewCond);
- for (SwitchInst::CaseIt CaseIter : SI.cases()) {
- APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth);
- CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
+ for (auto Case : SI.cases()) {
+ APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth);
+ Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase));
}
return &SI;
}
@@ -2934,8 +2963,8 @@ bool InstCombiner::run() {
Result->takeName(I);
// Push the new instruction and any users onto the worklist.
- Worklist.Add(Result);
Worklist.AddUsersToWorkList(*Result);
+ Worklist.Add(Result);
// Insert the new instruction into the basic block...
BasicBlock *InstParent = I->getParent();
@@ -2958,8 +2987,8 @@ bool InstCombiner::run() {
if (isInstructionTriviallyDead(I, &TLI)) {
eraseInstFromFunction(*I);
} else {
- Worklist.Add(I);
Worklist.AddUsersToWorkList(*I);
+ Worklist.Add(I);
}
}
MadeIRChange = true;
@@ -3022,12 +3051,11 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
}
// See if we can constant fold its operands.
- for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e;
- ++i) {
- if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i))
+ for (Use &U : Inst->operands()) {
+ if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U))
continue;
- auto *C = cast<Constant>(i);
+ auto *C = cast<Constant>(U);
Constant *&FoldRes = FoldedConstants[C];
if (!FoldRes)
FoldRes = ConstantFoldConstant(C, DL, TLI);
@@ -3035,7 +3063,10 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
FoldRes = C;
if (FoldRes != C) {
- *i = FoldRes;
+ DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst
+ << "\n Old = " << *C
+ << "\n New = " << *FoldRes << '\n');
+ U = FoldRes;
MadeIRChange = true;
}
}
@@ -3055,17 +3086,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL,
}
} else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
- // See if this is an explicit destination.
- for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end();
- i != e; ++i)
- if (i.getCaseValue() == Cond) {
- BasicBlock *ReachableBB = i.getCaseSuccessor();
- Worklist.push_back(ReachableBB);
- continue;
- }
-
- // Otherwise it is the default destination.
- Worklist.push_back(SI->getDefaultDest());
+ Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());
continue;
}
}
@@ -3152,6 +3173,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist,
InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines,
AA, AC, TLI, DT, DL, LI);
+ IC.MaxArraySizeForCombine = MaxArraySize;
Changed |= IC.run();
if (!Changed)
@@ -3176,9 +3198,10 @@ PreservedAnalyses InstCombinePass::run(Function &F,
return PreservedAnalyses::all();
// Mark all the analyses that instcombine updates as preserved.
- // FIXME: This should also 'preserve the CFG'.
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
+ PA.preserve<AAManager>();
+ PA.preserve<GlobalsAA>();
return PA;
}