Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAddSub.cpp | 135
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 1192
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCalls.cpp | 1148
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 169
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 212
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineInternal.h | 61
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 145
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 58
-rw-r--r-- | lib/Transforms/InstCombine/InstCombinePHI.cpp | 6
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 86
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 703
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 562
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 44
-rw-r--r-- | lib/Transforms/InstCombine/InstructionCombining.cpp | 285
14 files changed, 3029 insertions, 1777 deletions
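Several of the visitAdd/visitSub folds in the AddSub hunks below rest on standard two's-complement bit identities. As a quick aid to review, here is a minimal standalone C++ harness (illustrative only, not part of the patch) that spot-checks each identity with wrapping 32-bit arithmetic, which models i32 add/sub in LLVM IR:

#include <cassert>
#include <cstdint>

int main() {
  // Wrapping uint32_t arithmetic models LLVM's i32 add/sub exactly.
  const uint32_t SignBit = 0x80000000u;
  const uint32_t Vals[] = {0u, 1u, 13u, 0x7fffffffu, 0x80000000u,
                           0xdeadbeefu, 0xffffffffu};
  for (uint32_t A : Vals) {
    for (uint32_t B : Vals) {
      assert(A + SignBit == (A ^ SignBit)); // X + signbit --> X ^ signbit
      assert((A & B) + (A ^ B) == (A | B)); // (A & B) + (A ^ B) --> A | B
      assert((A | B) + (A & B) == A + B);   // (A | B) + (A & B) --> A + B
      assert((A | B) - (A ^ B) == (A & B)); // (or A, B) - (xor A, B) --> and A, B
      assert((A | B) - A == (~A & B));      // ((X | Y) - X) --> ~X & Y
      assert(A - (A + B) == 0u - B);        // X - (X + Y) --> -Y
    }
  }
  return 0;
}

The third identity also holds exactly over the integers, since A & B and A ^ B partition the bits of the sum: A + B = (A ^ B) + 2*(A & B) = (A | B) + (A & B). That is why the (A | B) + (A & B) --> A + B hunk can carry the nsw/nuw flags over to the new add.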
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 2d34c1cc74bd..174ec8036274 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -902,7 +902,7 @@ bool InstCombiner::WillNotOverflowSignedAdd(Value *LHS, Value *RHS, APInt RHSKnownOne(BitWidth, 0); computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); - // Addition of two 2's compliment numbers having opposite signs will never + // Addition of two 2's complement numbers having opposite signs will never // overflow. if ((LHSKnownOne[BitWidth - 1] && RHSKnownZero[BitWidth - 1]) || (LHSKnownZero[BitWidth - 1] && RHSKnownOne[BitWidth - 1])) @@ -939,7 +939,7 @@ bool InstCombiner::WillNotOverflowSignedSub(Value *LHS, Value *RHS, APInt RHSKnownOne(BitWidth, 0); computeKnownBits(RHS, RHSKnownZero, RHSKnownOne, 0, &CxtI); - // Subtraction of two 2's compliment numbers having identical signs will + // Subtraction of two 2's complement numbers having identical signs will // never overflow. if ((LHSKnownOne[BitWidth - 1] && RHSKnownOne[BitWidth - 1]) || (LHSKnownZero[BitWidth - 1] && RHSKnownZero[BitWidth - 1])) @@ -1042,43 +1042,42 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - const APInt *Val; - if (match(RHS, m_APInt(Val))) { - // X + (signbit) --> X ^ signbit - if (Val->isSignBit()) + const APInt *RHSC; + if (match(RHS, m_APInt(RHSC))) { + if (RHSC->isSignBit()) { + // If wrapping is not allowed, then the addition must set the sign bit: + // X + (signbit) --> X | signbit + if (I.hasNoSignedWrap() || I.hasNoUnsignedWrap()) + return BinaryOperator::CreateOr(LHS, RHS); + + // If wrapping is allowed, then the addition flips the sign bit of LHS: + // X + (signbit) --> X ^ signbit return BinaryOperator::CreateXor(LHS, RHS); + } // Is this add the last step in a convoluted sext? Value *X; const APInt *C; if (match(LHS, m_ZExt(m_Xor(m_Value(X), m_APInt(C)))) && C->isMinSignedValue() && - C->sext(LHS->getType()->getScalarSizeInBits()) == *Val) { + C->sext(LHS->getType()->getScalarSizeInBits()) == *RHSC) { // add(zext(xor i16 X, -32768), -32768) --> sext X return CastInst::Create(Instruction::SExt, X, LHS->getType()); } - if (Val->isNegative() && + if (RHSC->isNegative() && match(LHS, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C)))) && - Val->sge(-C->sext(Val->getBitWidth()))) { + RHSC->sge(-C->sext(RHSC->getBitWidth()))) { // (add (zext (add nuw X, C)), Val) -> (zext (add nuw X, C+Val)) - return CastInst::Create( - Instruction::ZExt, - Builder->CreateNUWAdd( - X, Constant::getIntegerValue(X->getType(), - *C + Val->trunc(C->getBitWidth()))), - I.getType()); + Constant *NewC = + ConstantInt::get(X->getType(), *C + RHSC->trunc(C->getBitWidth())); + return new ZExtInst(Builder->CreateNUWAdd(X, NewC), I.getType()); } } // FIXME: Use the match above instead of dyn_cast to allow these transforms // for splat vectors. if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // See if SimplifyDemandedBits can simplify this. This handles stuff like - // (X & 254)+1 -> (X&254)|1 - if (SimplifyDemandedInstructionBits(I)) - return &I; - // zext(bool) + C -> bool ? 
C + 1 : C if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) if (ZI->getSrcTy()->isIntegerTy(1)) @@ -1129,8 +1128,8 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - if (isa<Constant>(RHS) && isa<PHINode>(LHS)) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (isa<Constant>(RHS)) + if (Instruction *NV = foldOpWithConstantIntoOperand(I)) return NV; if (I.getType()->getScalarType()->isIntegerTy(1)) @@ -1201,11 +1200,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateAnd(NewAdd, C2); } } - - // Try to fold constant add into select arguments. - if (SelectInst *SI = dyn_cast<SelectInst>(LHS)) - if (Instruction *R = FoldOpIntoSelect(I, SI)) - return R; } // add (select X 0 (sub n A)) A --> select X A n @@ -1253,7 +1247,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (sext x), (sext y)) --> (sext (add int x, y)) if (SExtInst *RHSConv = dyn_cast<SExtInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a + // Only do this if x/y have the same type, if at least one of them has a // single use (so we don't increase the number of sexts), and if the // integer add will not overflow. if (LHSConv->getOperand(0)->getType() == @@ -1290,7 +1284,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (zext x), (zext y)) --> (zext (add int x, y)) if (auto *RHSConv = dyn_cast<ZExtInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a + // Only do this if x/y have the same type, if at least one of them has a // single use (so we don't increase the number of zexts), and if the // integer add will not overflow. if (LHSConv->getOperand(0)->getType() == @@ -1311,13 +1305,11 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - (match(LHS, m_And(m_Specific(A), m_Specific(B))) || - match(LHS, m_And(m_Specific(B), m_Specific(A))))) + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - (match(RHS, m_And(m_Specific(A), m_Specific(B))) || - match(RHS, m_And(m_Specific(B), m_Specific(A))))) + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); } @@ -1325,8 +1317,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { { Value *A = nullptr, *B = nullptr; if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - (match(LHS, m_And(m_Specific(A), m_Specific(B))) || - match(LHS, m_And(m_Specific(B), m_Specific(A))))) { + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { auto *New = BinaryOperator::CreateAdd(A, B); New->setHasNoSignedWrap(I.hasNoSignedWrap()); New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1334,8 +1325,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - (match(RHS, m_And(m_Specific(A), m_Specific(B))) || - match(RHS, m_And(m_Specific(B), m_Specific(A))))) { + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { auto *New = BinaryOperator::CreateAdd(A, B); New->setHasNoSignedWrap(I.hasNoSignedWrap()); New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); @@ -1394,6 +1384,8 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. 
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { + Value *LHSIntVal = LHSConv->getOperand(0); + // (fadd double (sitofp x), fpcst) --> (sitofp (add int x, intcst)) // ... if the constant fits in the integer value. This is useful for things // like (double)(x & 1234) + 4.0 -> (double)((X & 1234)+4) which no longer @@ -1401,12 +1393,12 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { // instcombined. if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS)) { Constant *CI = - ConstantExpr::getFPToSI(CFP, LHSConv->getOperand(0)->getType()); + ConstantExpr::getFPToSI(CFP, LHSIntVal->getType()); if (LHSConv->hasOneUse() && ConstantExpr::getSIToFP(CI, I.getType()) == CFP && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { + WillNotOverflowSignedAdd(LHSIntVal, CI, I)) { // Insert the new integer add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), + Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal, CI, "addconv"); return new SIToFPInst(NewAdd, I.getType()); } @@ -1414,17 +1406,17 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { // (fadd double (sitofp x), (sitofp y)) --> (sitofp (add int x, y)) if (SIToFPInst *RHSConv = dyn_cast<SIToFPInst>(RHS)) { - // Only do this if x/y have the same type, if at last one of them has a + Value *RHSIntVal = RHSConv->getOperand(0); + + // Only do this if x/y have the same type, if at least one of them has a // single use (so we don't increase the number of int->fp conversions), // and if the integer add will not overflow. - if (LHSConv->getOperand(0)->getType() == - RHSConv->getOperand(0)->getType() && + if (LHSIntVal->getType() == RHSIntVal->getType() && (LHSConv->hasOneUse() || RHSConv->hasOneUse()) && - WillNotOverflowSignedAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0), I)) { + WillNotOverflowSignedAdd(LHSIntVal, RHSIntVal, I)) { // Insert the new integer add. - Value *NewAdd = Builder->CreateNSWAdd(LHSConv->getOperand(0), - RHSConv->getOperand(0),"addconv"); + Value *NewAdd = Builder->CreateNSWAdd(LHSIntVal, + RHSIntVal, "addconv"); return new SIToFPInst(NewAdd, I.getType()); } } @@ -1562,7 +1554,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } - if (I.getType()->isIntegerTy(1)) + if (I.getType()->getScalarType()->isIntegerTy(1)) return BinaryOperator::CreateXor(Op0, Op1); // Replace (-1 - A) with (~A). @@ -1580,14 +1572,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; + // Try to fold constant sub into PHI values. + if (PHINode *PN = dyn_cast<PHINode>(Op1)) + if (Instruction *R = foldOpIntoPhi(I, PN)) + return R; + // C-(X+C2) --> (C-C2)-X Constant *C2; if (match(Op1, m_Add(m_Value(X), m_Constant(C2)))) return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X); - if (SimplifyDemandedInstructionBits(I)) - return &I; - // Fold (sub 0, (zext bool to B)) --> (sext bool to B) if (C->isNullValue() && match(Op1, m_ZExt(m_Value(X)))) if (X->getType()->getScalarType()->isIntegerTy(1)) @@ -1622,11 +1616,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. 
- if ((*Op0C + 1).isPowerOf2()) { - APInt KnownZero(BitWidth, 0); - APInt KnownOne(BitWidth, 0); - computeKnownBits(&I, KnownZero, KnownOne, 0, &I); - if ((*Op0C | KnownZero).isAllOnesValue()) + if (Op0C->isMask()) { + APInt RHSKnownZero(BitWidth, 0); + APInt RHSKnownOne(BitWidth, 0); + computeKnownBits(Op1, RHSKnownZero, RHSKnownOne, 0, &I); + if ((*Op0C | RHSKnownZero).isAllOnesValue()) return BinaryOperator::CreateXor(Op1, Op0); } } @@ -1634,8 +1628,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { { Value *Y; // X-(X+Y) == -Y X-(Y+X) == -Y - if (match(Op1, m_Add(m_Specific(Op0), m_Value(Y))) || - match(Op1, m_Add(m_Value(Y), m_Specific(Op0)))) + if (match(Op1, m_c_Add(m_Specific(Op0), m_Value(Y)))) return BinaryOperator::CreateNeg(Y); // (X-Y)-X == -Y @@ -1645,18 +1638,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // (sub (or A, B) (xor A, B)) --> (and A, B) { - Value *A = nullptr, *B = nullptr; + Value *A, *B; if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && - (match(Op0, m_Or(m_Specific(A), m_Specific(B))) || - match(Op0, m_Or(m_Specific(B), m_Specific(A))))) + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateAnd(A, B); } - if (Op0->hasOneUse()) { - Value *Y = nullptr; + { + Value *Y; // ((X | Y) - X) --> (~X & Y) - if (match(Op0, m_Or(m_Value(Y), m_Specific(Op1))) || - match(Op0, m_Or(m_Specific(Op1), m_Value(Y)))) + if (match(Op0, m_OneUse(m_c_Or(m_Value(Y), m_Specific(Op1))))) return BinaryOperator::CreateAnd( Y, Builder->CreateNot(Op1, Op1->getName() + ".not")); } @@ -1664,7 +1655,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; - Constant *CI = nullptr; // (X - (Y - Z)) --> (X + (Z - Y)). if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) @@ -1673,8 +1663,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // (X - (X & Y)) --> (X & ~Y) // - if (match(Op1, m_And(m_Value(Y), m_Specific(Op0))) || - match(Op1, m_And(m_Specific(Op0), m_Value(Y)))) + if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) return BinaryOperator::CreateAnd(Op0, Builder->CreateNot(Y, Y->getName() + ".not")); @@ -1702,14 +1691,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { // X - A*-B -> X + A*B // X - -A*B -> X + A*B Value *A, *B; - if (match(Op1, m_Mul(m_Value(A), m_Neg(m_Value(B)))) || - match(Op1, m_Mul(m_Neg(m_Value(A)), m_Value(B)))) + Constant *CI; + if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B))))) return BinaryOperator::CreateAdd(Op0, Builder->CreateMul(A, B)); // X - A*CI -> X + A*-CI - // X - CI*A -> X + A*-CI - if (match(Op1, m_Mul(m_Value(A), m_Constant(CI))) || - match(Op1, m_Mul(m_Constant(CI), m_Value(A)))) { + // No need to handle commuted multiply because multiply handling will + // ensure constant will be move to the right hand side. + if (match(Op1, m_Mul(m_Value(A), m_Constant(CI)))) { Value *NewMul = Builder->CreateMul(A, ConstantExpr::getNeg(CI)); return BinaryOperator::CreateAdd(Op0, NewMul); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index da5384a86aac..b2a41c699202 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -137,9 +137,8 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { } /// This handles expressions of the form ((val OP C1) & C2). Where -/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. 
Op is -/// guaranteed to be a binary operator. -Instruction *InstCombiner::OptAndOp(Instruction *Op, +/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. +Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd) { @@ -149,6 +148,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, Together = ConstantExpr::getAnd(AndRHS, OpRHS); switch (Op->getOpcode()) { + default: break; case Instruction::Xor: if (Op->hasOneUse()) { // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) @@ -159,13 +159,6 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, break; case Instruction::Or: if (Op->hasOneUse()){ - if (Together != OpRHS) { - // (X | C1) & C2 --> (X | (C1&C2)) & C2 - Value *Or = Builder->CreateOr(X, Together); - Or->takeName(Op); - return BinaryOperator::CreateAnd(Or, AndRHS); - } - ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together); if (TogetherCI && !TogetherCI->isZero()){ // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 @@ -302,178 +295,91 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo); } -/// Returns true iff Val consists of one contiguous run of 1s with any number -/// of 0s on either side. The 1s are allowed to wrap from LSB to MSB, -/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is -/// not, since all 1s are not contiguous. -static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { - const APInt& V = Val->getValue(); - uint32_t BitWidth = Val->getType()->getBitWidth(); - if (!APIntOps::isShiftedMask(BitWidth, V)) return false; - - // look for the first zero bit after the run of ones - MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); - // look for the first non-zero bit - ME = V.getActiveBits(); - return true; -} - -/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines -/// whether the operator is a sub. If we can fold one of the following xforms: +/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns +/// that can be simplified. +/// One of A and B is considered the mask. The other is the value. This is +/// described as the "AMask" or "BMask" part of the enum. If the enum contains +/// only "Mask", then both A and B can be considered masks. If A is the mask, +/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0. +/// If both A and C are constants, this proof is also easy. +/// For the following explanations, we assume that A is the mask. /// -/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask -/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 -/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// "AllOnes" declares that the comparison is true only if (A & B) == A or all +/// bits of A are set in B. +/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes /// -/// return (A +/- B). +/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all +/// bits of A are cleared in B. +/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeroes +/// +/// "Mixed" declares that (A & B) == C and C might or might not contain any +/// number of one bits and zero bits. +/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed +/// +/// "Not" means that in above descriptions "==" should be replaced by "!=". 
+/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes /// -Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, - ConstantInt *Mask, bool isSub, - Instruction &I) { - Instruction *LHSI = dyn_cast<Instruction>(LHS); - if (!LHSI || LHSI->getNumOperands() != 2 || - !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr; - - ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1)); - - switch (LHSI->getOpcode()) { - default: return nullptr; - case Instruction::And: - if (ConstantExpr::getAnd(N, Mask) == Mask) { - // If the AndRHS is a power of two minus one (0+1+), this is simple. - if ((Mask->getValue().countLeadingZeros() + - Mask->getValue().countPopulation()) == - Mask->getValue().getBitWidth()) - break; - - // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+ - // part, we don't need any explicit masks to take them out of A. If that - // is all N is, ignore it. - uint32_t MB = 0, ME = 0; - if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive - uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); - APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); - if (MaskedValueIsZero(RHS, Mask, 0, &I)) - break; - } - } - return nullptr; - case Instruction::Or: - case Instruction::Xor: - // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0 - if ((Mask->getValue().countLeadingZeros() + - Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth() - && ConstantExpr::getAnd(N, Mask)->isNullValue()) - break; - return nullptr; - } - - if (isSub) - return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold"); - return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold"); -} - -/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C) -/// One of A and B is considered the mask, the other the value. This is -/// described as the "AMask" or "BMask" part of the enum. If the enum -/// contains only "Mask", then both A and B can be considered masks. -/// If A is the mask, then it was proven, that (A & C) == C. This -/// is trivial if C == A, or C == 0. If both A and C are constants, this -/// proof is also easy. -/// For the following explanations we assume that A is the mask. -/// The part "AllOnes" declares, that the comparison is true only -/// if (A & B) == A, or all bits of A are set in B. -/// Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes -/// The part "AllZeroes" declares, that the comparison is true only -/// if (A & B) == 0, or all bits of A are cleared in B. -/// Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes -/// The part "Mixed" declares, that (A & B) == C and C might or might not -/// contain any number of one bits and zero bits. -/// Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed -/// The Part "Not" means, that in above descriptions "==" should be replaced -/// by "!=". 
-/// Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes /// If the mask A contains a single bit, then the following is equivalent: /// (icmp eq (A & B), A) equals (icmp ne (A & B), 0) /// (icmp ne (A & B), A) equals (icmp eq (A & B), 0) enum MaskedICmpType { - FoldMskICmp_AMask_AllOnes = 1, - FoldMskICmp_AMask_NotAllOnes = 2, - FoldMskICmp_BMask_AllOnes = 4, - FoldMskICmp_BMask_NotAllOnes = 8, - FoldMskICmp_Mask_AllZeroes = 16, - FoldMskICmp_Mask_NotAllZeroes = 32, - FoldMskICmp_AMask_Mixed = 64, - FoldMskICmp_AMask_NotMixed = 128, - FoldMskICmp_BMask_Mixed = 256, - FoldMskICmp_BMask_NotMixed = 512 + AMask_AllOnes = 1, + AMask_NotAllOnes = 2, + BMask_AllOnes = 4, + BMask_NotAllOnes = 8, + Mask_AllZeros = 16, + Mask_NotAllZeros = 32, + AMask_Mixed = 64, + AMask_NotMixed = 128, + BMask_Mixed = 256, + BMask_NotMixed = 512 }; -/// Return the set of pattern classes (from MaskedICmpType) -/// that (icmp SCC (A & B), C) satisfies. -static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, - ICmpInst::Predicate SCC) -{ +/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C) +/// satisfies. +static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, + ICmpInst::Predicate Pred) { ConstantInt *ACst = dyn_cast<ConstantInt>(A); ConstantInt *BCst = dyn_cast<ConstantInt>(B); ConstantInt *CCst = dyn_cast<ConstantInt>(C); - bool icmp_eq = (SCC == ICmpInst::ICMP_EQ); - bool icmp_abit = (ACst && !ACst->isZero() && - ACst->getValue().isPowerOf2()); - bool icmp_bbit = (BCst && !BCst->isZero() && - BCst->getValue().isPowerOf2()); - unsigned result = 0; + bool IsEq = (Pred == ICmpInst::ICMP_EQ); + bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); + bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); + unsigned MaskVal = 0; if (CCst && CCst->isZero()) { // if C is zero, then both A and B qualify as mask - result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed)); - return result; + MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) + : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed) + : (AMask_AllOnes | AMask_Mixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed) + : (BMask_AllOnes | BMask_Mixed)); + return MaskVal; } + if (A == C) { - result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed) - : (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed)); - } else if (ACst && CCst && - ConstantExpr::getAnd(ACst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_AMask_Mixed - : FoldMskICmp_AMask_NotMixed); + MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed) + : (AMask_NotAllOnes | AMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? 
(Mask_NotAllZeros | AMask_NotMixed) + : (Mask_AllZeros | AMask_Mixed)); + } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { + MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); } + if (B == C) { - result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_BMask_Mixed)); - } else if (BCst && CCst && - ConstantExpr::getAnd(BCst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_BMask_Mixed - : FoldMskICmp_BMask_NotMixed); - } - return result; + MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed) + : (BMask_NotAllOnes | BMask_NotMixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) + : (Mask_AllZeros | BMask_Mixed)); + } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { + MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); + } + + return MaskVal; } /// Convert an analysis of a masked ICmp into its equivalent if all boolean @@ -482,32 +388,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, /// involves swapping those bits over. static unsigned conjugateICmpMask(unsigned Mask) { unsigned NewMask; - NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes | - FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed)) + NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros | + AMask_Mixed | BMask_Mixed)) << 1; - NewMask |= - (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)) - >> 1; + NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros | + AMask_NotMixed | BMask_NotMixed)) + >> 1; return NewMask; } -/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) -/// Return the set of pattern classes (from MaskedICmpType) -/// that both LHS and RHS satisfy. -static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, - Value*& B, Value*& C, - Value*& D, Value*& E, - ICmpInst *LHS, ICmpInst *RHS, - ICmpInst::Predicate &LHSCC, - ICmpInst::Predicate &RHSCC) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; +/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). +/// Return the set of pattern classes (from MaskedICmpType) that both LHS and +/// RHS satisfy. +static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { + if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + return 0; // vectors are not (yet?) supported - if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; + if (LHS->getOperand(0)->getType()->isVectorTy()) + return 0; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -517,9 +421,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, // above. Value *L1 = LHS->getOperand(0); Value *L2 = LHS->getOperand(1); - Value *L11,*L12,*L21,*L22; + Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. - if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { + if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { // Look for ANDs in the LHS icmp. 
@@ -543,22 +447,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } // Bail if LHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(LHSCC)) + if (!ICmpInst::isEquality(PredL)) return 0; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); - Value *R11,*R12; - bool ok = false; - if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { + Value *R11, *R12; + bool Ok = false; + if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; + A = R11; + D = R12; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; + A = R12; + D = R11; } else { return 0; } - E = R2; R1 = nullptr; ok = true; + E = R2; + R1 = nullptr; + Ok = true; } else if (R1->getType()->isIntegerTy()) { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an @@ -568,46 +476,62 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R2; ok = true; + A = R11; + D = R12; + E = R2; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R2; ok = true; + A = R12; + D = R11; + E = R2; + Ok = true; } } // Bail if RHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(RHSCC)) + if (!ICmpInst::isEquality(PredR)) return 0; // Look for ANDs on the right side of the RHS icmp. - if (!ok && R2->getType()->isIntegerTy()) { + if (!Ok && R2->getType()->isIntegerTy()) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; R12 = Constant::getAllOnesValue(R2->getType()); } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R1; ok = true; + A = R11; + D = R12; + E = R1; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R1; ok = true; + A = R12; + D = R11; + E = R1; + Ok = true; } else { return 0; } } - if (!ok) + if (!Ok) return 0; if (L11 == A) { - B = L12; C = L2; + B = L12; + C = L2; } else if (L12 == A) { - B = L11; C = L2; + B = L11; + C = L2; } else if (L21 == A) { - B = L22; C = L1; + B = L22; + C = L1; } else if (L22 == A) { - B = L21; C = L1; + B = L21; + C = L1; } - unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); - unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); + unsigned LeftType = getMaskedICmpType(A, B, C, PredL); + unsigned RightType = getMaskedICmpType(A, D, E, PredR); return LeftType & RightType; } @@ -616,12 +540,14 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy *Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, - LHSCC, RHSCC); - if (Mask == 0) return nullptr; - assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && - "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + unsigned Mask = + getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); + if (Mask == 0) + return nullptr; + + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); 
// In full generality: // (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -642,7 +568,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Mask = conjugateICmpMask(Mask); } - if (Mask & FoldMskICmp_Mask_AllZeroes) { + if (Mask & Mask_AllZeros) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) Value *NewOr = Builder->CreateOr(B, D); @@ -653,14 +579,14 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *Zero = Constant::getNullValue(A->getType()); return Builder->CreateICmp(NewCC, NewAnd, Zero); } - if (Mask & FoldMskICmp_BMask_AllOnes) { + if (Mask & BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) Value *NewOr = Builder->CreateOr(B, D); Value *NewAnd = Builder->CreateAnd(A, NewOr); return Builder->CreateICmp(NewCC, NewAnd, NewOr); } - if (Mask & FoldMskICmp_AMask_AllOnes) { + if (Mask & AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) Value *NewAnd1 = Builder->CreateAnd(B, D); @@ -672,11 +598,13 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. ConstantInt *BCst = dyn_cast<ConstantInt>(B); - if (!BCst) return nullptr; + if (!BCst) + return nullptr; ConstantInt *DCst = dyn_cast<ConstantInt>(D); - if (!DCst) return nullptr; + if (!DCst) + return nullptr; - if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { + if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -689,7 +617,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_AMask_NotAllOnes) { + + if (Mask & AMask_NotAllOnes) { // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -701,7 +630,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_BMask_Mixed) { + + if (Mask & BMask_Mixed) { // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -713,23 +643,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. ConstantInt *CCst = dyn_cast<ConstantInt>(C); - if (!CCst) return nullptr; + if (!CCst) + return nullptr; ConstantInt *ECst = dyn_cast<ConstantInt>(E); - if (!ECst) return nullptr; - if (LHSCC != NewCC) + if (!ECst) + return nullptr; + if (PredL != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (RHSCC != NewCC) + if (PredR != NewCC) ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + // If there is a conflict, we should actually return a false for the // whole construct. 
if (((BCst->getValue() & DCst->getValue()) & (CCst->getValue() ^ ECst->getValue())) != 0) return ConstantInt::get(LHS->getType(), !IsAnd); + Value *NewOr1 = Builder->CreateOr(B, D); Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); Value *NewAnd = Builder->CreateAnd(A, NewOr1); return Builder->CreateICmp(NewCC, NewAnd, NewOr2); } + return nullptr; } @@ -789,12 +724,67 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder->CreateICmp(NewPred, Input, RangeEnd); } +static Value * +foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, + bool JoinedByAnd, + InstCombiner::BuilderTy *Builder) { + Value *X = LHS->getOperand(0); + if (X != RHS->getOperand(0)) + return nullptr; + + const APInt *C1, *C2; + if (!match(LHS->getOperand(1), m_APInt(C1)) || + !match(RHS->getOperand(1), m_APInt(C2))) + return nullptr; + + // We only handle (X != C1 && X != C2) and (X == C1 || X == C2). + ICmpInst::Predicate Pred = LHS->getPredicate(); + if (Pred != RHS->getPredicate()) + return nullptr; + if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) + return nullptr; + if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // The larger unsigned constant goes on the right. + if (C1->ugt(*C2)) + std::swap(C1, C2); + + APInt Xor = *C1 ^ *C2; + if (Xor.isPowerOf2()) { + // If LHSC and RHSC differ by only one bit, then set that bit in X and + // compare against the larger constant: + // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2 + // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2 + // We choose an 'or' with a Pow2 constant rather than the inverse mask with + // 'and' because that may lead to smaller codegen from a smaller constant. + Value *Or = Builder->CreateOr(X, ConstantInt::get(X->getType(), Xor)); + return Builder->CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); + } + + // Special case: get the ordering right when the values wrap around zero. + // Ie, we assumed the constants were unsigned when swapping earlier. + if (*C1 == 0 && C2->isAllOnesValue()) + std::swap(C1, C2); + + if (*C1 == *C2 - 1) { + // (X == 13 || X == 14) --> X - 13 <=u 1 + // (X != 13 && X != 14) --> X - 13 >u 1 + // An 'add' is the canonical IR form, so favor that over a 'sub'. + Value *Add = Builder->CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); + auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; + return Builder->CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); + } + + return nullptr; +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -819,86 +809,90 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). 
- Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (!LHSCst || !RHSCst) return nullptr; + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) // where C is a power of 2 or // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || - (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || + (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { + Value *NewOr = Builder->CreateOr(LHS0, RHS0); + return Builder->CreateICmp(PredL, NewOr, LHSC); } } // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. - if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC && - LHS->hasOneUse() && RHS->hasOneUse()) { + if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && + RHS->hasOneUse()) { Value *V; - ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr; + ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr; // (trunc x) == C1 & (and x, CA) == C2 // (and x, CA) == C2 & (trunc x) == C1 - if (match(Val2, m_Trunc(m_Value(V))) && - match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = RHSCst; - BigCst = LHSCst; - } else if (match(Val, m_Trunc(m_Value(V))) && - match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = LHSCst; - BigCst = RHSCst; + if (match(RHS0, m_Trunc(m_Value(V))) && + match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = RHSC; + BigC = LHSC; + } else if (match(LHS0, m_Trunc(m_Value(V))) && + match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = LHSC; + BigC = RHSC; } - if (SmallCst && BigCst) { - unsigned BigBitSize = BigCst->getType()->getBitWidth(); - unsigned SmallBitSize = SmallCst->getType()->getBitWidth(); + if (SmallC && BigC) { + unsigned BigBitSize = BigC->getType()->getBitWidth(); + unsigned SmallBitSize = SmallC->getType()->getBitWidth(); // Check that the low bits are zero. APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) { - Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue()); - APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue(); - Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N); - return Builder->CreateICmp(LHSCC, NewAnd, NewVal); + if ((Low & AndC->getValue()) == 0 && (Low & BigC->getValue()) == 0) { + Value *NewAnd = Builder->CreateAnd(V, Low | AndC->getValue()); + APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); + Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); + return Builder->CreateICmp(PredL, NewAnd, NewVal); } } } // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. 
- if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) & (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -907,113 +901,95 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know // (from the icmp folding check above), that the two constants // are not equal and that the larger constant is on the RHS - assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 - case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 - case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 + case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 + case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 return LHS; } case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_ULT: - if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 - return Builder->CreateICmpULT(Val, LHSCst); - if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13 + return Builder->CreateICmpULT(LHS0, LHSC); + if (LHSC->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - break; // (X != 13 & X u< 15) -> no change + break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: - if (LHSCst 
== SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 - return Builder->CreateICmpSLT(Val, LHSCst); - break; // (X != 13 & X s< 15) -> no change - case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 - case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15 + if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13 + return Builder->CreateICmpSLT(LHS0, LHSC); + break; // (X != 13 & X s< 15) -> no change + case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 + case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15 return RHS; case ICmpInst::ICMP_NE: - // Special case to get the ordering right when the values wrap around - // zero. - if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue()) - std::swap(LHSCst, RHSCst); - if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1), - Val->getName()+".cmp"); - } - break; // (X != 13 & X != 15) -> no change + // Potential folds for this case should already be handled. + break; } break; case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false - case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false + case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 - case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 + case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 + case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 return LHS; - case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 - case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 + case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 return LHS; - case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change - break; } break; case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 return RHS; - case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change - break; case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> 
(X-14) <u 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14 + return Builder->CreateICmp(PredL, LHS0, RHSC); + break; // (X u> 13 & X != 15) -> no change + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 return RHS; - case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change - break; case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), - true, true); - case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change - break; + if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14 + return Builder->CreateICmp(PredL, LHS0, RHSC); + break; // (X s> 13 & X != 15) -> no change + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, + true); } break; } @@ -1314,39 +1290,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { break; } - case Instruction::Add: - // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes - break; - case Instruction::Sub: - // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - // -x & 1 -> x & 1 if (AndRHSMask == 1 && match(Op0LHS, m_Zero())) return BinaryOperator::CreateAnd(Op0RHS, AndRHS); - // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS - // has 1's for all bits that the subtraction with A might affect. - if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) { - uint32_t BitWidth = AndRHSMask.getBitWidth(); - uint32_t Zeros = AndRHSMask.countLeadingZeros(); - APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - - if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { - Value *NewNeg = Builder->CreateNeg(Op0RHS); - return BinaryOperator::CreateAnd(NewNeg, AndRHS); - } - } break; case Instruction::Shl: @@ -1361,6 +1309,33 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { break; } + // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth + // of X and OP behaves well when given trunc(C1) and X. 
+ switch (Op0I->getOpcode()) { + default: + break; + case Instruction::Xor: + case Instruction::Or: + case Instruction::Mul: + case Instruction::Add: + case Instruction::Sub: + Value *X; + ConstantInt *C1; + if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) { + if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { + auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); + Value *BinOp; + if (isa<ZExtInst>(Op0LHS)) + BinOp = Builder->CreateBinOp(Op0I->getOpcode(), X, TruncC1); + else + BinOp = Builder->CreateBinOp(Op0I->getOpcode(), TruncC1, X); + auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); + auto *And = Builder->CreateAnd(BinOp, TruncC2); + return new ZExtInst(And, I.getType()); + } + } + } + if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I)) return Res; @@ -1381,10 +1356,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateAnd(NewCast, C3); } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; @@ -1630,15 +1606,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, /// Fold (icmp)|(icmp) if possible. Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() && - RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { + if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() && + RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0)); BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0)); @@ -1680,52 +1656,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. 
- if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && - LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && - RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && - LHSCst->getValue() == (RHSCst->getValue())) { + if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && + PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && + LHSC->getType() == RHSC->getType() && + LHSC->getValue() == (RHSC->getValue())) { Value *LAdd = LHS->getOperand(0); Value *RAdd = RHS->getOperand(0); Value *LAddOpnd, *RAddOpnd; - ConstantInt *LAddCst, *RAddCst; - if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && - match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && - LAddCst->getValue().ugt(LHSCst->getValue()) && - RAddCst->getValue().ugt(LHSCst->getValue())) { - - APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); - if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { - ConstantInt *MaxAddCst = nullptr; - if (LAddCst->getValue().ult(RAddCst->getValue())) - MaxAddCst = RAddCst; + ConstantInt *LAddC, *RAddC; + if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && + match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && + LAddC->getValue().ugt(LHSC->getValue()) && + RAddC->getValue().ugt(LHSC->getValue())) { + + APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { + ConstantInt *MaxAddC = nullptr; + if (LAddC->getValue().ult(RAddC->getValue())) + MaxAddC = RAddC; else - MaxAddCst = LAddCst; + MaxAddC = LAddC; - APInt RRangeLow = -RAddCst->getValue(); - APInt RRangeHigh = RRangeLow + LHSCst->getValue(); - APInt LRangeLow = -LAddCst->getValue(); - APInt LRangeHigh = LRangeLow + LHSCst->getValue(); + APInt RRangeLow = -RAddC->getValue(); + APInt RRangeHigh = RRangeLow + LHSC->getValue(); + APInt LRangeLow = -LAddC->getValue(); + APInt LRangeHigh = LRangeLow + LHSC->getValue(); APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? 
LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSCst->getValue())) { - Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); + RangeDiff.ugt(LHSC->getValue())) { + Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); - Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); - Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); - return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); + Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskC); + Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddC); + return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSC)); } } } } // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -1743,25 +1719,25 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder)) return V; - Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); if (LHS->hasOneUse() || RHS->hasOneUse()) { // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1) Value *A = nullptr, *B = nullptr; - if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) { - B = Val; - if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1)) - A = Val2; - else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2) + if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) { + B = LHS0; + if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1)) + A = RHS0; + else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = RHS->getOperand(1); } // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1) // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) - else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { - B = Val2; - if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1)) - A = Val; - else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val) + else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { + B = RHS0; + if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1)) + A = LHS0; + else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = LHS->getOperand(1); } if (A && B) @@ -1778,54 +1754,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). - if (!LHSCst || !RHSCst) return nullptr; + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) { + Value *NewOr = Builder->CreateOr(LHS0, RHS0); + return Builder->CreateICmp(PredL, NewOr, LHSC); } } // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) // iff C2 + CA == C1. 
- if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) { - ConstantInt *AddCst; - if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst)))) - if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue()) - return Builder->CreateICmpULE(Val, LHSCst); + if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { + ConstantInt *AddC; + if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) + if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) + return Builder->CreateICmpULE(LHS0, LHSC); } // From here on, we only handle: // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. - if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) | (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -1834,127 +1814,98 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the // icmp folding check above), that the two constants are not // equal. 
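// For reference, a standalone sketch of the kind of fold this switch (and
// the new foldAndOrOfEqualityCmpsWithConstants helper) covers, using the
// constants from the comments in the switch below:
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned V = 0; V <= 0xff; ++V) {
    uint8_t X = uint8_t(V);
    // (X == 13 | X == 14) -> (X - 13) u< 2
    assert(((X == 13) || (X == 14)) == (uint8_t(X - 13) < 2));
  }
}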
- assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - if (LHS->getOperand(0) == RHS->getOperand(0)) { - // if LHSCst and RHSCst differ only by one bit: - // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2 - assert(LHSCst->getValue().ule(LHSCst->getValue())); - - APInt Xor = LHSCst->getValue() ^ RHSCst->getValue(); - if (Xor.isPowerOf2()) { - Value *Cst = Builder->getInt(Xor); - Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst); - return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst); - } - } - - if (LHSCst == SubOne(RHSCst)) { - // (X == 13 | X == 14) -> X-13 <u 2 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst); - return Builder->CreateICmpULT(Add, AddCST); - } - - break; // (X == 13 | X == 15) -> no change - case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change - case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change + // Potential folds for this case should already be handled. + break; + case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change + case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change break; - case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 - case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 + case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 + case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 return RHS; } break; case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 - case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 - case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 + case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 + case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 return LHS; - case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true - case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true + case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true + case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true return Builder->getTrue(); } case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling + case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 + // If RHSC is [us]MAXINT, it is always false. 
Not handling // this can cause overflow. - if (RHSCst->isMaxValue(false)) + if (RHSC->isMaxValue(false)) return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, false, false); - case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 + case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 return RHS; - case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 + // If RHSC is [us]MAXINT, it is always false. Not handling // this can cause overflow. - if (RHSCst->isMaxValue(true)) + if (RHSC->isMaxValue(true)) return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, - true, false); - case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, + false); + case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 return RHS; - case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change - break; } break; case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 - case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 + case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 return LHS; - case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true + case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true return Builder->getTrue(); - case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 - case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 + case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 return LHS; - case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true - case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true + case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true + case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true 
return Builder->getTrue(); - case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change - break; } break; } @@ -2100,17 +2051,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; - // (X & C1) | C2 --> (X | C2) & (C1|C2) - // iff (C1 & C2) == 0. - if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && - (RHS->getValue() & C1->getValue()) != 0 && - Op0->hasOneUse()) { - Value *Or = Builder->CreateOr(X, RHS); - Or->takeName(Op0); - return BinaryOperator::CreateAnd(Or, - Builder->getInt(RHS->getValue() | C1->getValue())); - } - // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2) if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) && Op0->hasOneUse()) { @@ -2119,45 +2059,51 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateXor(Or, Builder->getInt(C1->getValue() & ~RHS->getValue())); } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } // Given an OR instruction, check to see if this is a bswap. if (Instruction *BSwap = MatchBSwap(I)) return BSwap; - Value *A = nullptr, *B = nullptr; - ConstantInt *C1 = nullptr, *C2 = nullptr; + { + Value *A; + const APInt *C; + // (X^C)|Y -> (X|Y)^C iff Y&C == 0 + if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op1, *C, 0, &I)) { + Value *NOr = Builder->CreateOr(A, Op1); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + cast<Instruction>(Op0)->getOperand(1)); + } - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 - if (Op0->hasOneUse() && - match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op1); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); + // Y|(X^C) -> (X|Y)^C iff Y&C == 0 + if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op0, *C, 0, &I)) { + Value *NOr = Builder->CreateOr(A, Op0); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + cast<Instruction>(Op1)->getOperand(1)); + } } - // Y|(X^C) -> (X|Y)^C iff Y&C == 0 - if (Op1->hasOneUse() && - match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op0); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); - } + Value *A, *B; // ((~A & B) | A) -> (A | B) - if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_Specific(A))) - return BinaryOperator::CreateOr(A, B); + if (match(Op0, m_c_And(m_Not(m_Specific(Op1)), m_Value(A)))) + return BinaryOperator::CreateOr(A, Op1); + if (match(Op1, m_c_And(m_Not(m_Specific(Op0)), m_Value(A)))) + return BinaryOperator::CreateOr(Op0, A); // ((A & B) | ~A) -> (~A | B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_Specific(A)))) - return BinaryOperator::CreateOr(Builder->CreateNot(A), B); + // The NOT is guaranteed to be in the RHS by complexity ordering. 
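// Both or-rewrites above are plain absorption laws; a standalone
// brute-force check over all 8-bit value pairs (illustration only, not
// InstCombine code):
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X <= 0xff; ++X)
    for (unsigned Y = 0; Y <= 0xff; ++Y) {
      uint8_t A = uint8_t(X), B = uint8_t(Y);
      assert(uint8_t((~A & B) | A) == uint8_t(A | B));  // ((~A & B) | A) -> (A | B)
      assert(uint8_t((A & B) | ~A) == uint8_t(~A | B)); // ((A & B) | ~A) -> (~A | B)
    }
}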
+ if (match(Op1, m_Not(m_Value(A))) && + match(Op0, m_c_And(m_Specific(A), m_Value(B)))) + return BinaryOperator::CreateOr(Op1, B); // (A & ~B) | (A ^ B) -> (A ^ B) // (~B & A) | (A ^ B) -> (A ^ B) @@ -2177,8 +2123,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { Value *V1 = nullptr, *V2 = nullptr; - C1 = dyn_cast<ConstantInt>(C); - C2 = dyn_cast<ConstantInt>(D); + ConstantInt *C1 = dyn_cast<ConstantInt>(C); + ConstantInt *C2 = dyn_cast<ConstantInt>(D); if (C1 && C2) { // (A & C1)|(B & C2) if ((C1->getValue() & C2->getValue()) == 0) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) @@ -2403,6 +2349,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // be simplified by a later pass either, so we try swapping the inner/outer // ORs in the hopes that we'll be able to simplify it this way. // (X|C) | V --> (X|V) | C + ConstantInt *C1; if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) && match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) { Value *Inner = Builder->CreateOr(A, Op1); @@ -2493,23 +2440,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - if (Constant *RHS = dyn_cast<Constant>(Op1)) { - if (RHS->isAllOnesValue() && Op0->hasOneUse()) - // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B - if (CmpInst *CI = dyn_cast<CmpInst>(Op0)) - return CmpInst::Create(CI->getOpcode(), - CI->getInversePredicate(), - CI->getOperand(0), CI->getOperand(1)); + // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B + ICmpInst::Predicate Pred; + if (match(Op0, m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))) && + match(Op1, m_AllOnes())) { + cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, Op0); } - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). 
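// Scalar intuition for the ext(!cmp) folds, as a standalone check: xor'ing
// an extended i1 compare with ext(true) is the same as extending the
// inverted compare.
#include <cassert>
#include <cstdint>

int main() {
  for (int A = -2; A <= 2; ++A)
    for (int B = -2; B <= 2; ++B) {
      uint32_t Z = (A < B) ? 1u : 0u;           // zext(cmp)
      assert((Z ^ 1u) == ((A >= B) ? 1u : 0u)); // xor(zext(cmp)), 1
      int32_t S = (A < B) ? -1 : 0;             // sext(cmp)
      assert((S ^ -1) == ((A >= B) ? -1 : 0));  // xor(sext(cmp)), -1
    }
}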
if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) {
 if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) {
 if (CI->hasOneUse() && Op0C->hasOneUse()) {
 Instruction::CastOps Opcode = Op0C->getOpcode();
 if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) &&
- (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(),
+ (RHSC == ConstantExpr::getCast(Opcode, Builder->getTrue(),
 Op0C->getDestTy()))) {
 CI->setPredicate(CI->getInversePredicate());
 return CastInst::Create(Opcode, CI, Op0C->getType());
@@ -2520,26 +2466,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
 if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) {
 // ~(c-X) == X-c-1 == X+(-c-1)
- if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue())
+ if (Op0I->getOpcode() == Instruction::Sub && RHSC->isAllOnesValue())
 if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) {
 Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C);
- Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C,
- ConstantInt::get(I.getType(), 1));
- return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS);
+ return BinaryOperator::CreateAdd(Op0I->getOperand(1),
+ SubOne(NegOp0I0C));
 }
 if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) {
 if (Op0I->getOpcode() == Instruction::Add) {
 // ~(X+c) --> (-c-1)-X
- if (RHS->isAllOnesValue()) {
+ if (RHSC->isAllOnesValue()) {
 Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI);
- return BinaryOperator::CreateSub(
- ConstantExpr::getSub(NegOp0CI,
- ConstantInt::get(I.getType(), 1)),
- Op0I->getOperand(0));
- } else if (RHS->getValue().isSignBit()) {
+ return BinaryOperator::CreateSub(SubOne(NegOp0CI),
+ Op0I->getOperand(0));
+ } else if (RHSC->getValue().isSignBit()) {
 // (X + C) ^ signbit -> (X + C + signbit)
- Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue());
+ Constant *C = Builder->getInt(RHSC->getValue() + Op0CI->getValue());
 return BinaryOperator::CreateAdd(Op0I->getOperand(0), C);
 }
@@ -2547,10 +2490,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
 // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0
 if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(),
 0, &I)) {
- Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS);
+ Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC);
 // Anything in both C1 and C2 is known to be zero, remove it from
 // NewRHS.
- Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS);
+ Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC);
 NewRHS = ConstantExpr::getAnd(NewRHS,
 ConstantExpr::getNot(CommonBits));
 Worklist.Add(Op0I);
@@ -2568,7 +2511,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
 E1->getOpcode() == Instruction::Xor &&
 (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) {
 // fold (C1 >> C2) ^ C3
- ConstantInt *C2 = Op0CI, *C3 = RHS;
+ ConstantInt *C2 = Op0CI, *C3 = RHSC;
 APInt FoldConst = C1->getValue().lshr(C2->getValue());
 FoldConst ^= C3->getValue();
 // Prepare the two operands.
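// The two's complement identities behind the rewrites above, verified
// exhaustively on 8-bit values (standalone illustration):
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned I = 0; I <= 0xff; ++I)
    for (unsigned J = 0; J <= 0xff; ++J) {
      uint8_t X = uint8_t(I), C = uint8_t(J);
      assert(uint8_t(~(C - X)) == uint8_t(X + uint8_t(-C - 1))); // ~(c-X) == X+(-c-1)
      assert(uint8_t(~(X + C)) == uint8_t(uint8_t(-C - 1) - X)); // ~(X+c) == (-c-1)-X
      assert(uint8_t((X + C) ^ 0x80) == uint8_t(X + C + 0x80));  // (X+C)^signbit
    }
}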
@@ -2582,27 +2525,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } - BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1); - if (Op1I) { + { Value *A, *B; - if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) { - if (A == Op0) { // B^(B|A) == (A|B)^B - Op1I->swapOperands(); - I.swapOperands(); - std::swap(Op0, Op1); - } else if (B == Op0) { // B^(A|B) == (A|B)^B + if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { + if (A == Op0) { // A^(A|B) == A^(B|A) + cast<BinaryOperator>(Op1)->swapOperands(); + std::swap(A, B); + } + if (B == Op0) { // A^(B|A) == (B|A)^A I.swapOperands(); // Simplified below. std::swap(Op0, Op1); } - } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && - Op1I->hasOneUse()){ + } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op0) { // A^(A&B) -> A^(B&A) - Op1I->swapOperands(); + cast<BinaryOperator>(Op1)->swapOperands(); std::swap(A, B); } if (B == Op0) { // A^(B&A) -> (B&A)^A @@ -2612,65 +2554,63 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0); - if (Op0I) { + { Value *A, *B; - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()) { + if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { if (A == Op1) // (B|A)^B == (A|B)^B std::swap(A, B); if (B == Op1) // (A|B)^B == A & ~B return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1)); - } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()){ + } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op1) // (A&B)^A -> (B&A)^A std::swap(A, B); + const APInt *C; if (B == Op1 && // (B&A)^A == ~B & A - !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C + !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1); } } } - if (Op0I && Op1I) { + { Value *A, *B, *C, *D; // (A & B)^(A | B) -> A ^ B - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Or(m_Value(C), m_Value(D)))) { + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Or(m_Value(C), m_Value(D)))) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } // (A | B)^(A & B) -> A ^ B - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Value(C), m_Value(D)))) { + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Value(C), m_Value(D)))) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } // (A | ~B) ^ (~A | B) -> A ^ B // (~B | A) ^ (~A | B) -> A ^ B - if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (~A | B) ^ (A | ~B) -> A ^ B - if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { + if (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A & ~B) ^ (~A & B) -> A ^ B // (~B & A) ^ (~A & B) -> A ^ B - if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, 
m_And(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (~A & B) ^ (A & ~B) -> A ^ B - if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { + if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A ^ C)^(A | B) -> ((~A) & B) ^ C - if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) && - match(Op1I, m_Or(m_Value(A), m_Value(B)))) { + if (match(Op0, m_Xor(m_Value(D), m_Value(C))) && + match(Op1, m_Or(m_Value(A), m_Value(B)))) { if (D == A) return BinaryOperator::CreateXor( Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2679,8 +2619,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Builder->CreateAnd(Builder->CreateNot(B), A), C); } // (A | B)^(A ^ C) -> ((~A) & B) ^ C - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Value(D), m_Value(C)))) { + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Value(D), m_Value(C)))) { if (D == A) return BinaryOperator::CreateXor( Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2689,12 +2629,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Builder->CreateAnd(Builder->CreateNot(B), A), C); } // (A & B) ^ (A ^ B) -> (A | B) - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); // (A ^ B) ^ (A & B) -> (A | B) - if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 2ef82ba3ed8c..69484f47223f 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -60,6 +60,12 @@ using namespace PatternMatch; STATISTIC(NumSimplified, "Number of library calls simplified"); +static cl::opt<unsigned> UnfoldElementAtomicMemcpyMaxElements( + "unfold-element-atomic-memcpy-max-elements", + cl::init(16), + cl::desc("Maximum number of elements in atomic memcpy the optimizer is " + "allowed to unfold")); + /// Return the specified type promoted as it would be to pass though a va_arg /// area. static Type *getPromotedType(Type *Ty) { @@ -70,27 +76,6 @@ static Type *getPromotedType(Type *Ty) { return Ty; } -/// Given an aggregate type which ultimately holds a single scalar element, -/// like {{{type}}} or [1 x type], return type. -static Type *reduceToSingleValueType(Type *T) { - while (!T->isSingleValueType()) { - if (StructType *STy = dyn_cast<StructType>(T)) { - if (STy->getNumElements() == 1) - T = STy->getElementType(0); - else - break; - } else if (ArrayType *ATy = dyn_cast<ArrayType>(T)) { - if (ATy->getNumElements() == 1) - T = ATy->getElementType(); - else - break; - } else - break; - } - - return T; -} - /// Return a constant boolean vector that has true elements in all positions /// where the input constant data vector has an element with the sign bit set. 
static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
@@ -108,6 +93,78 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) {
 return ConstantVector::get(BoolVec);
 }
+Instruction *
+InstCombiner::SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI) {
+ // Try to unfold this intrinsic into a sequence of explicit atomic loads and
+ // stores.
+ // First, check that the number of elements is a compile-time constant.
+ auto *NumElementsCI = dyn_cast<ConstantInt>(AMI->getNumElements());
+ if (!NumElementsCI)
+ return nullptr;
+
+ // Check that there are not too many elements.
+ uint64_t NumElements = NumElementsCI->getZExtValue();
+ if (NumElements >= UnfoldElementAtomicMemcpyMaxElements)
+ return nullptr;
+
+ // Don't unfold into illegal integers.
+ uint64_t ElementSizeInBits = AMI->getElementSizeInBytes() * 8;
+ if (!getDataLayout().isLegalInteger(ElementSizeInBits))
+ return nullptr;
+
+ // Cast source and destination to the correct type. Intrinsic input arguments
+ // are usually represented as i8*.
+ // Often operands will be explicitly cast to i8* and we can just strip
+ // those casts instead of inserting new ones. However, it's easier to rely on
+ // other InstCombine rules which will cover trivial cases anyway.
+ Value *Src = AMI->getRawSource();
+ Value *Dst = AMI->getRawDest();
+ Type *ElementPointerType = Type::getIntNPtrTy(
+ AMI->getContext(), ElementSizeInBits, Src->getType()->getPointerAddressSpace());
+
+ Value *SrcCasted = Builder->CreatePointerCast(Src, ElementPointerType,
+ "memcpy_unfold.src_casted");
+ Value *DstCasted = Builder->CreatePointerCast(Dst, ElementPointerType,
+ "memcpy_unfold.dst_casted");
+
+ for (uint64_t i = 0; i < NumElements; ++i) {
+ // Get current element addresses.
+ ConstantInt *ElementIdxCI =
+ ConstantInt::get(AMI->getContext(), APInt(64, i));
+ Value *SrcElementAddr =
+ Builder->CreateGEP(SrcCasted, ElementIdxCI, "memcpy_unfold.src_addr");
+ Value *DstElementAddr =
+ Builder->CreateGEP(DstCasted, ElementIdxCI, "memcpy_unfold.dst_addr");
+
+ // Load from the source. Transfer alignment information and mark the load
+ // as unordered atomic.
+ LoadInst *Load = Builder->CreateLoad(SrcElementAddr, "memcpy_unfold.val");
+ Load->setOrdering(AtomicOrdering::Unordered);
+ // We know the alignment of the first element. It is also guaranteed by the
+ // verifier that the element size is less than or equal to the first
+ // element's alignment and that both of these values are powers of two.
+ // This means that all subsequent accesses are at least element size
+ // aligned.
+ // TODO: We can infer better alignment but there is no evidence that this
+ // will matter.
+ Load->setAlignment(i == 0 ? AMI->getSrcAlignment()
+ : AMI->getElementSizeInBytes());
+ Load->setDebugLoc(AMI->getDebugLoc());
+
+ // Store the loaded value via an unordered atomic store.
+ StoreInst *Store = Builder->CreateStore(Load, DstElementAddr);
+ Store->setOrdering(AtomicOrdering::Unordered);
+ Store->setAlignment(i == 0 ? AMI->getDstAlignment()
+ : AMI->getElementSizeInBytes());
+ Store->setDebugLoc(AMI->getDebugLoc());
+ }
+
+ // Set the number of elements of the copy to 0; it will be deleted on the
+ // next iteration.
+ AMI->setNumElements(Constant::getNullValue(NumElementsCI->getType())); + return AMI; +} + Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { unsigned DstAlign = getKnownAlignment(MI->getArgOperand(0), DL, MI, &AC, &DT); unsigned SrcAlign = getKnownAlignment(MI->getArgOperand(1), DL, MI, &AC, &DT); @@ -144,41 +201,19 @@ Instruction *InstCombiner::SimplifyMemTransfer(MemIntrinsic *MI) { Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp); Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp); - // Memcpy forces the use of i8* for the source and destination. That means - // that if you're using memcpy to move one double around, you'll get a cast - // from double* to i8*. We'd much rather use a double load+store rather than - // an i64 load+store, here because this improves the odds that the source or - // dest address will be promotable. See if we can find a better type than the - // integer datatype. - Value *StrippedDest = MI->getArgOperand(0)->stripPointerCasts(); + // If the memcpy has metadata describing the members, see if we can get the + // TBAA tag describing our copy. MDNode *CopyMD = nullptr; - if (StrippedDest != MI->getArgOperand(0)) { - Type *SrcETy = cast<PointerType>(StrippedDest->getType()) - ->getElementType(); - if (SrcETy->isSized() && DL.getTypeStoreSize(SrcETy) == Size) { - // The SrcETy might be something like {{{double}}} or [1 x double]. Rip - // down through these levels if so. - SrcETy = reduceToSingleValueType(SrcETy); - - if (SrcETy->isSingleValueType()) { - NewSrcPtrTy = PointerType::get(SrcETy, SrcAddrSp); - NewDstPtrTy = PointerType::get(SrcETy, DstAddrSp); - - // If the memcpy has metadata describing the members, see if we can - // get the TBAA tag describing our copy. - if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { - if (M->getNumOperands() == 3 && M->getOperand(0) && - mdconst::hasa<ConstantInt>(M->getOperand(0)) && - mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && - M->getOperand(1) && - mdconst::hasa<ConstantInt>(M->getOperand(1)) && - mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == - Size && - M->getOperand(2) && isa<MDNode>(M->getOperand(2))) - CopyMD = cast<MDNode>(M->getOperand(2)); - } - } - } + if (MDNode *M = MI->getMetadata(LLVMContext::MD_tbaa_struct)) { + if (M->getNumOperands() == 3 && M->getOperand(0) && + mdconst::hasa<ConstantInt>(M->getOperand(0)) && + mdconst::extract<ConstantInt>(M->getOperand(0))->isNullValue() && + M->getOperand(1) && + mdconst::hasa<ConstantInt>(M->getOperand(1)) && + mdconst::extract<ConstantInt>(M->getOperand(1))->getValue() == + Size && + M->getOperand(2) && isa<MDNode>(M->getOperand(2))) + CopyMD = cast<MDNode>(M->getOperand(2)); } // If the memcpy/memmove provides better alignment info than we can @@ -510,6 +545,131 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, return Builder.CreateAShr(Vec, ShiftVec); } +static Value *simplifyX86muldq(const IntrinsicInst &II, + InstCombiner::BuilderTy &Builder) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + Type *ResTy = II.getType(); + assert(Arg0->getType()->getScalarSizeInBits() == 32 && + Arg1->getType()->getScalarSizeInBits() == 32 && + ResTy->getScalarSizeInBits() == 64 && "Unexpected muldq/muludq types"); + + // muldq/muludq(undef, undef) -> zero (matches generic mul behavior) + if (isa<UndefValue>(Arg0) || isa<UndefValue>(Arg1)) + return ConstantAggregateZero::get(ResTy); + + // Constant folding. 
+ // PMULDQ = (mul(vXi64 sext(shuffle<0,2,..>(Arg0)), + // vXi64 sext(shuffle<0,2,..>(Arg1)))) + // PMULUDQ = (mul(vXi64 zext(shuffle<0,2,..>(Arg0)), + // vXi64 zext(shuffle<0,2,..>(Arg1)))) + if (!isa<Constant>(Arg0) || !isa<Constant>(Arg1)) + return nullptr; + + unsigned NumElts = ResTy->getVectorNumElements(); + assert(Arg0->getType()->getVectorNumElements() == (2 * NumElts) && + Arg1->getType()->getVectorNumElements() == (2 * NumElts) && + "Unexpected muldq/muludq types"); + + unsigned IntrinsicID = II.getIntrinsicID(); + bool IsSigned = (Intrinsic::x86_sse41_pmuldq == IntrinsicID || + Intrinsic::x86_avx2_pmul_dq == IntrinsicID || + Intrinsic::x86_avx512_pmul_dq_512 == IntrinsicID); + + SmallVector<unsigned, 16> ShuffleMask; + for (unsigned i = 0; i != NumElts; ++i) + ShuffleMask.push_back(i * 2); + + auto *LHS = Builder.CreateShuffleVector(Arg0, Arg0, ShuffleMask); + auto *RHS = Builder.CreateShuffleVector(Arg1, Arg1, ShuffleMask); + + if (IsSigned) { + LHS = Builder.CreateSExt(LHS, ResTy); + RHS = Builder.CreateSExt(RHS, ResTy); + } else { + LHS = Builder.CreateZExt(LHS, ResTy); + RHS = Builder.CreateZExt(RHS, ResTy); + } + + return Builder.CreateMul(LHS, RHS); +} + +static Value *simplifyX86pack(IntrinsicInst &II, InstCombiner &IC, + InstCombiner::BuilderTy &Builder, bool IsSigned) { + Value *Arg0 = II.getArgOperand(0); + Value *Arg1 = II.getArgOperand(1); + Type *ResTy = II.getType(); + + // Fast all undef handling. + if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1)) + return UndefValue::get(ResTy); + + Type *ArgTy = Arg0->getType(); + unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; + unsigned NumDstElts = ResTy->getVectorNumElements(); + unsigned NumSrcElts = ArgTy->getVectorNumElements(); + assert(NumDstElts == (2 * NumSrcElts) && "Unexpected packing types"); + + unsigned NumDstEltsPerLane = NumDstElts / NumLanes; + unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; + unsigned DstScalarSizeInBits = ResTy->getScalarSizeInBits(); + assert(ArgTy->getScalarSizeInBits() == (2 * DstScalarSizeInBits) && + "Unexpected packing types"); + + // Constant folding. + auto *Cst0 = dyn_cast<Constant>(Arg0); + auto *Cst1 = dyn_cast<Constant>(Arg1); + if (!Cst0 || !Cst1) + return nullptr; + + SmallVector<Constant *, 32> Vals; + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + for (unsigned Elt = 0; Elt != NumDstEltsPerLane; ++Elt) { + unsigned SrcIdx = Lane * NumSrcEltsPerLane + Elt % NumSrcEltsPerLane; + auto *Cst = (Elt >= NumSrcEltsPerLane) ? Cst1 : Cst0; + auto *COp = Cst->getAggregateElement(SrcIdx); + if (COp && isa<UndefValue>(COp)) { + Vals.push_back(UndefValue::get(ResTy->getScalarType())); + continue; + } + + auto *CInt = dyn_cast_or_null<ConstantInt>(COp); + if (!CInt) + return nullptr; + + APInt Val = CInt->getValue(); + assert(Val.getBitWidth() == ArgTy->getScalarSizeInBits() && + "Unexpected constant bitwidth"); + + if (IsSigned) { + // PACKSS: Truncate signed value with signed saturation. + // Source values less than dst minint are saturated to minint. + // Source values greater than dst maxint are saturated to maxint. + if (Val.isSignedIntN(DstScalarSizeInBits)) + Val = Val.trunc(DstScalarSizeInBits); + else if (Val.isNegative()) + Val = APInt::getSignedMinValue(DstScalarSizeInBits); + else + Val = APInt::getSignedMaxValue(DstScalarSizeInBits); + } else { + // PACKUS: Truncate signed value with unsigned saturation. + // Source values less than zero are saturated to zero. + // Source values greater than dst maxuint are saturated to maxuint. 
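// The same two saturation rules in scalar form, assuming a 32-bit to
// 16-bit pack (standalone reference, not the vector constant folder above):
#include <algorithm>
#include <cassert>
#include <cstdint>

int16_t packssScalar(int32_t V) { // PACKSS: signed saturation
  return int16_t(std::min(std::max(V, int32_t(INT16_MIN)), int32_t(INT16_MAX)));
}
uint16_t packusScalar(int32_t V) { // PACKUS: unsigned saturation of a signed source
  return uint16_t(std::min(std::max(V, int32_t(0)), int32_t(UINT16_MAX)));
}

int main() {
  assert(packssScalar(40000) == 32767 && packssScalar(-40000) == -32768);
  assert(packusScalar(70000) == 65535 && packusScalar(-5) == 0);
}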
+ if (Val.isIntN(DstScalarSizeInBits)) + Val = Val.trunc(DstScalarSizeInBits); + else if (Val.isNegative()) + Val = APInt::getNullValue(DstScalarSizeInBits); + else + Val = APInt::getAllOnesValue(DstScalarSizeInBits); + } + + Vals.push_back(ConstantInt::get(ResTy->getScalarType(), Val)); + } + } + + return ConstantVector::get(Vals); +} + static Value *simplifyX86movmsk(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { Value *Arg = II.getArgOperand(0); @@ -1330,6 +1490,27 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { return true; } +// Constant fold llvm.amdgcn.fmed3 intrinsics for standard inputs. +// +// A single NaN input is folded to minnum, so we rely on that folding for +// handling NaNs. +static APFloat fmed3AMDGCN(const APFloat &Src0, const APFloat &Src1, + const APFloat &Src2) { + APFloat Max3 = maxnum(maxnum(Src0, Src1), Src2); + + APFloat::cmpResult Cmp0 = Max3.compare(Src0); + assert(Cmp0 != APFloat::cmpUnordered && "nans handled separately"); + if (Cmp0 == APFloat::cmpEqual) + return maxnum(Src1, Src2); + + APFloat::cmpResult Cmp1 = Max3.compare(Src1); + assert(Cmp1 != APFloat::cmpUnordered && "nans handled separately"); + if (Cmp1 == APFloat::cmpEqual) + return maxnum(Src0, Src2); + + return maxnum(Src0, Src1); +} + // Returns true iff the 2 intrinsics have the same operands, limiting the // comparison to the first NumOperands. static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, @@ -1373,6 +1554,254 @@ static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID, return false; } +// Convert NVVM intrinsics to target-generic LLVM code where possible. +static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { + // Each NVVM intrinsic we can simplify can be replaced with one of: + // + // * an LLVM intrinsic, + // * an LLVM cast operation, + // * an LLVM binary operation, or + // * ad-hoc LLVM IR for the particular operation. + + // Some transformations are only valid when the module's + // flush-denormals-to-zero (ftz) setting is true/false, whereas other + // transformations are valid regardless of the module's ftz setting. + enum FtzRequirementTy { + FTZ_Any, // Any ftz setting is ok. + FTZ_MustBeOn, // Transformation is valid only if ftz is on. + FTZ_MustBeOff, // Transformation is valid only if ftz is off. + }; + // Classes of NVVM intrinsics that can't be replaced one-to-one with a + // target-generic intrinsic, cast op, or binary op but that we can nonetheless + // simplify. + enum SpecialCase { + SPC_Reciprocal, + }; + + // SimplifyAction is a poor-man's variant (plus an additional flag) that + // represents how to replace an NVVM intrinsic with target-generic LLVM IR. + struct SimplifyAction { + // Invariant: At most one of these Optionals has a value. + Optional<Intrinsic::ID> IID; + Optional<Instruction::CastOps> CastOp; + Optional<Instruction::BinaryOps> BinaryOp; + Optional<SpecialCase> Special; + + FtzRequirementTy FtzRequirement = FTZ_Any; + + SimplifyAction() = default; + + SimplifyAction(Intrinsic::ID IID, FtzRequirementTy FtzReq) + : IID(IID), FtzRequirement(FtzReq) {} + + // Cast operations don't have anything to do with FTZ, so we skip that + // argument. 
+ SimplifyAction(Instruction::CastOps CastOp) : CastOp(CastOp) {} + + SimplifyAction(Instruction::BinaryOps BinaryOp, FtzRequirementTy FtzReq) + : BinaryOp(BinaryOp), FtzRequirement(FtzReq) {} + + SimplifyAction(SpecialCase Special, FtzRequirementTy FtzReq) + : Special(Special), FtzRequirement(FtzReq) {} + }; + + // Try to generate a SimplifyAction describing how to replace our + // IntrinsicInstr with target-generic LLVM IR. + const SimplifyAction Action = [II]() -> SimplifyAction { + switch (II->getIntrinsicID()) { + + // NVVM intrinsics that map directly to LLVM intrinsics. + case Intrinsic::nvvm_ceil_d: + return {Intrinsic::ceil, FTZ_Any}; + case Intrinsic::nvvm_ceil_f: + return {Intrinsic::ceil, FTZ_MustBeOff}; + case Intrinsic::nvvm_ceil_ftz_f: + return {Intrinsic::ceil, FTZ_MustBeOn}; + case Intrinsic::nvvm_fabs_d: + return {Intrinsic::fabs, FTZ_Any}; + case Intrinsic::nvvm_fabs_f: + return {Intrinsic::fabs, FTZ_MustBeOff}; + case Intrinsic::nvvm_fabs_ftz_f: + return {Intrinsic::fabs, FTZ_MustBeOn}; + case Intrinsic::nvvm_floor_d: + return {Intrinsic::floor, FTZ_Any}; + case Intrinsic::nvvm_floor_f: + return {Intrinsic::floor, FTZ_MustBeOff}; + case Intrinsic::nvvm_floor_ftz_f: + return {Intrinsic::floor, FTZ_MustBeOn}; + case Intrinsic::nvvm_fma_rn_d: + return {Intrinsic::fma, FTZ_Any}; + case Intrinsic::nvvm_fma_rn_f: + return {Intrinsic::fma, FTZ_MustBeOff}; + case Intrinsic::nvvm_fma_rn_ftz_f: + return {Intrinsic::fma, FTZ_MustBeOn}; + case Intrinsic::nvvm_fmax_d: + return {Intrinsic::maxnum, FTZ_Any}; + case Intrinsic::nvvm_fmax_f: + return {Intrinsic::maxnum, FTZ_MustBeOff}; + case Intrinsic::nvvm_fmax_ftz_f: + return {Intrinsic::maxnum, FTZ_MustBeOn}; + case Intrinsic::nvvm_fmin_d: + return {Intrinsic::minnum, FTZ_Any}; + case Intrinsic::nvvm_fmin_f: + return {Intrinsic::minnum, FTZ_MustBeOff}; + case Intrinsic::nvvm_fmin_ftz_f: + return {Intrinsic::minnum, FTZ_MustBeOn}; + case Intrinsic::nvvm_round_d: + return {Intrinsic::round, FTZ_Any}; + case Intrinsic::nvvm_round_f: + return {Intrinsic::round, FTZ_MustBeOff}; + case Intrinsic::nvvm_round_ftz_f: + return {Intrinsic::round, FTZ_MustBeOn}; + case Intrinsic::nvvm_sqrt_rn_d: + return {Intrinsic::sqrt, FTZ_Any}; + case Intrinsic::nvvm_sqrt_f: + // nvvm_sqrt_f is a special case. For most intrinsics, foo_ftz_f is the + // ftz version, and foo_f is the non-ftz version. But nvvm_sqrt_f adopts + // the ftz-ness of the surrounding code. sqrt_rn_f and sqrt_rn_ftz_f are + // the versions with explicit ftz-ness. + return {Intrinsic::sqrt, FTZ_Any}; + case Intrinsic::nvvm_sqrt_rn_f: + return {Intrinsic::sqrt, FTZ_MustBeOff}; + case Intrinsic::nvvm_sqrt_rn_ftz_f: + return {Intrinsic::sqrt, FTZ_MustBeOn}; + case Intrinsic::nvvm_trunc_d: + return {Intrinsic::trunc, FTZ_Any}; + case Intrinsic::nvvm_trunc_f: + return {Intrinsic::trunc, FTZ_MustBeOff}; + case Intrinsic::nvvm_trunc_ftz_f: + return {Intrinsic::trunc, FTZ_MustBeOn}; + + // NVVM intrinsics that map to LLVM cast operations. + // + // Note that llvm's target-generic conversion operators correspond to the rz + // (round to zero) versions of the nvvm conversion intrinsics, even though + // most everything else here uses the rn (round to nearest even) nvvm ops. 
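// Why the _rz conversions can become plain casts: FPToSI/FPToUI (like the
// C++ casts below) truncate toward zero, which is exactly the "rz" mode.
#include <cassert>

int main() {
  assert(int(3.7) == 3);   // not 4: rounds toward zero
  assert(int(-3.7) == -3); // not -4: rounds toward zero
}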
+ case Intrinsic::nvvm_d2i_rz:
+ case Intrinsic::nvvm_f2i_rz:
+ case Intrinsic::nvvm_d2ll_rz:
+ case Intrinsic::nvvm_f2ll_rz:
+ return {Instruction::FPToSI};
+ case Intrinsic::nvvm_d2ui_rz:
+ case Intrinsic::nvvm_f2ui_rz:
+ case Intrinsic::nvvm_d2ull_rz:
+ case Intrinsic::nvvm_f2ull_rz:
+ return {Instruction::FPToUI};
+ case Intrinsic::nvvm_i2d_rz:
+ case Intrinsic::nvvm_i2f_rz:
+ case Intrinsic::nvvm_ll2d_rz:
+ case Intrinsic::nvvm_ll2f_rz:
+ return {Instruction::SIToFP};
+ case Intrinsic::nvvm_ui2d_rz:
+ case Intrinsic::nvvm_ui2f_rz:
+ case Intrinsic::nvvm_ull2d_rz:
+ case Intrinsic::nvvm_ull2f_rz:
+ return {Instruction::UIToFP};
+
+ // NVVM intrinsics that map to LLVM binary ops.
+ case Intrinsic::nvvm_add_rn_d:
+ return {Instruction::FAdd, FTZ_Any};
+ case Intrinsic::nvvm_add_rn_f:
+ return {Instruction::FAdd, FTZ_MustBeOff};
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return {Instruction::FAdd, FTZ_MustBeOn};
+ case Intrinsic::nvvm_mul_rn_d:
+ return {Instruction::FMul, FTZ_Any};
+ case Intrinsic::nvvm_mul_rn_f:
+ return {Instruction::FMul, FTZ_MustBeOff};
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return {Instruction::FMul, FTZ_MustBeOn};
+ case Intrinsic::nvvm_div_rn_d:
+ return {Instruction::FDiv, FTZ_Any};
+ case Intrinsic::nvvm_div_rn_f:
+ return {Instruction::FDiv, FTZ_MustBeOff};
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return {Instruction::FDiv, FTZ_MustBeOn};
+
+ // The remaining cases are NVVM intrinsics that map to LLVM idioms, but
+ // need special handling.
+ //
+ // We seem to be missing intrinsics for rcp.approx.{ftz.}f32, which is just
+ // as well.
+ case Intrinsic::nvvm_rcp_rn_d:
+ return {SPC_Reciprocal, FTZ_Any};
+ case Intrinsic::nvvm_rcp_rn_f:
+ return {SPC_Reciprocal, FTZ_MustBeOff};
+ case Intrinsic::nvvm_rcp_rn_ftz_f:
+ return {SPC_Reciprocal, FTZ_MustBeOn};
+
+ // We do not currently simplify intrinsics that give an approximate answer.
+ // These include:
+ //
+ // - nvvm_cos_approx_{f,ftz_f}
+ // - nvvm_ex2_approx_{d,f,ftz_f}
+ // - nvvm_lg2_approx_{d,f,ftz_f}
+ // - nvvm_sin_approx_{f,ftz_f}
+ // - nvvm_sqrt_approx_{f,ftz_f}
+ // - nvvm_rsqrt_approx_{d,f,ftz_f}
+ // - nvvm_div_approx_{ftz_d,ftz_f,f}
+ // - nvvm_rcp_approx_ftz_d
+ //
+ // Ideally we'd encode them as e.g. "fast call @llvm.cos", where "fast"
+ // means that fastmath is enabled in the intrinsic. Unfortunately only
+ // binary operators (currently) have a fastmath bit in SelectionDAG, so this
+ // information gets lost and we can't select on it.
+ //
+ // TODO: div and rcp are lowered to a binary op, so in theory we could
+ // lower them to "fast fdiv".
+
+ default:
+ return {};
+ }
+ }();
+
+ // If Action.FtzRequirement is not satisfied by the module's ftz state, we
+ // can bail out now. (Notice that in the case that IID is not an NVVM
+ // intrinsic, we don't have to look up any module metadata, as
+ // FtzRequirement will be FTZ_Any.)
+ if (Action.FtzRequirement != FTZ_Any) {
+ bool FtzEnabled =
+ II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() ==
+ "true";
+
+ if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn))
+ return nullptr;
+ }
+
+ // Simplify to target-generic intrinsic.
+ if (Action.IID) {
+ SmallVector<Value *, 4> Args(II->arg_operands());
+ // All the target-generic intrinsics currently of interest to us have one
+ // type argument, equal to that of the nvvm intrinsic's argument.
+ Type *Tys[] = {II->getArgOperand(0)->getType()}; + return CallInst::Create( + Intrinsic::getDeclaration(II->getModule(), *Action.IID, Tys), Args); + } + + // Simplify to target-generic binary op. + if (Action.BinaryOp) + return BinaryOperator::Create(*Action.BinaryOp, II->getArgOperand(0), + II->getArgOperand(1), II->getName()); + + // Simplify to target-generic cast op. + if (Action.CastOp) + return CastInst::Create(*Action.CastOp, II->getArgOperand(0), II->getType(), + II->getName()); + + // All that's left are the special cases. + if (!Action.Special) + return nullptr; + + switch (*Action.Special) { + case SPC_Reciprocal: + // Simplify reciprocal. + return BinaryOperator::Create( + Instruction::FDiv, ConstantFP::get(II->getArgOperand(0)->getType(), 1), + II->getArgOperand(0), II->getName()); + } + llvm_unreachable("All SpecialCase enumerators should be handled in switch."); +} + Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) { removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this); return nullptr; @@ -1462,6 +1891,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } + if (auto *AMI = dyn_cast<ElementAtomicMemCpyInst>(II)) { + if (Constant *C = dyn_cast<Constant>(AMI->getNumElements())) + if (C->isNullValue()) + return eraseInstFromFunction(*AMI); + + if (Instruction *I = SimplifyElementAtomicMemCpy(AMI)) + return I; + } + + if (Instruction *I = SimplifyNVVMIntrinsic(II, *this)) + return I; + auto SimplifyDemandedVectorEltsLow = [this](Value *Op, unsigned Width, unsigned DemandedWidth) { APInt UndefElts(Width, 0); @@ -1581,8 +2022,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, V); break; } - case Intrinsic::fma: case Intrinsic::fmuladd: { + // Canonicalize fast fmuladd to the separate fmul + fadd. 
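// What the canonicalization below produces, in scalar terms: with unsafe
// algebra the fused operation may be split into an explicit multiply and
// add, accepting the extra intermediate rounding (illustration only).
#include <cassert>
#include <cmath>

double fusedForm(double A, double B, double C) { return std::fma(A, B, C); }
double splitForm(double A, double B, double C) { return A * B + C; }

int main() {
  // Exactly representable inputs, so both forms agree here.
  assert(fusedForm(2.0, 3.0, 1.0) == splitForm(2.0, 3.0, 1.0));
}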
+ if (II->hasUnsafeAlgebra()) { + BuilderTy::FastMathFlagGuard Guard(*Builder); + Builder->setFastMathFlags(II->getFastMathFlags()); + Value *Mul = Builder->CreateFMul(II->getArgOperand(0), + II->getArgOperand(1)); + Value *Add = Builder->CreateFAdd(Mul, II->getArgOperand(2)); + Add->takeName(II); + return replaceInstUsesWith(*II, Add); + } + + LLVM_FALLTHROUGH; + } + case Intrinsic::fma: { Value *Src0 = II->getArgOperand(0); Value *Src1 = II->getArgOperand(1); @@ -1631,6 +2085,26 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return SelectInst::Create(Cond, Call0, Call1); } + LLVM_FALLTHROUGH; + } + case Intrinsic::ceil: + case Intrinsic::floor: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::rint: + case Intrinsic::trunc: { + Value *ExtSrc; + if (match(II->getArgOperand(0), m_FPExt(m_Value(ExtSrc))) && + II->getArgOperand(0)->hasOneUse()) { + // fabs (fpext x) -> fpext (fabs x) + Value *F = Intrinsic::getDeclaration(II->getModule(), II->getIntrinsicID(), + { ExtSrc->getType() }); + CallInst *NewFabs = Builder->CreateCall(F, ExtSrc); + NewFabs->copyFastMathFlags(II); + NewFabs->takeName(II); + return new FPExtInst(NewFabs, II->getType()); + } + break; } case Intrinsic::cos: @@ -1863,6 +2337,37 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; } + case Intrinsic::x86_avx512_mask_cmp_pd_128: + case Intrinsic::x86_avx512_mask_cmp_pd_256: + case Intrinsic::x86_avx512_mask_cmp_pd_512: + case Intrinsic::x86_avx512_mask_cmp_ps_128: + case Intrinsic::x86_avx512_mask_cmp_ps_256: + case Intrinsic::x86_avx512_mask_cmp_ps_512: { + // Folding cmp(sub(a,b),0) -> cmp(a,b) and cmp(0,sub(a,b)) -> cmp(b,a) + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + bool Arg0IsZero = match(Arg0, m_Zero()); + if (Arg0IsZero) + std::swap(Arg0, Arg1); + Value *A, *B; + // This fold requires only the NINF(not +/- inf) since inf minus + // inf is nan. + // NSZ(No Signed Zeros) is not needed because zeros of any sign are + // equal for both compares. + // NNAN is not needed because nans compare the same for both compares. + // The compare intrinsic uses the above assumptions and therefore + // doesn't require additional flags. 
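// The scalar shape of the fold: for finite inputs, comparing a - b against
// zero orders the same way as comparing a against b directly; ninf is
// needed because inf - inf would yield NaN (standalone illustration).
#include <cassert>

int main() {
  double A = 1.5, B = 2.5; // finite operands only
  assert(((A - B) < 0.0) == (A < B));
  assert(((A - B) == 0.0) == (A == B));
}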
+ if ((match(Arg0, m_OneUse(m_FSub(m_Value(A), m_Value(B)))) && + match(Arg1, m_Zero()) && + cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { + if (Arg0IsZero) + std::swap(A, B); + II->setArgOperand(0, A); + II->setArgOperand(1, B); + return II; + } + break; + } case Intrinsic::x86_avx512_mask_add_ps_512: case Intrinsic::x86_avx512_mask_div_ps_512: @@ -2130,6 +2635,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_avx2_pmulu_dq: case Intrinsic::x86_avx512_pmul_dq_512: case Intrinsic::x86_avx512_pmulu_dq_512: { + if (Value *V = simplifyX86muldq(*II, *Builder)) + return replaceInstUsesWith(*II, V); + unsigned VWidth = II->getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); APInt DemandedElts = APInt::getAllOnesValue(VWidth); @@ -2141,6 +2649,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx512_packssdw_512: + case Intrinsic::x86_avx512_packsswb_512: + if (Value *V = simplifyX86pack(*II, *this, *Builder, true)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_sse2_packuswb_128: + case Intrinsic::x86_sse41_packusdw: + case Intrinsic::x86_avx2_packusdw: + case Intrinsic::x86_avx2_packuswb: + case Intrinsic::x86_avx512_packusdw_512: + case Intrinsic::x86_avx512_packuswb_512: + if (Value *V = simplifyX86pack(*II, *this, *Builder, false)) + return replaceInstUsesWith(*II, V); + break; + + case Intrinsic::x86_pclmulqdq: { + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(2))) { + unsigned Imm = C->getZExtValue(); + + bool MadeChange = false; + Value *Arg0 = II->getArgOperand(0); + Value *Arg1 = II->getArgOperand(1); + unsigned VWidth = Arg0->getType()->getVectorNumElements(); + APInt DemandedElts(VWidth, 0); + + APInt UndefElts1(VWidth, 0); + DemandedElts = (Imm & 0x01) ? 2 : 1; + if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts, + UndefElts1)) { + II->setArgOperand(0, V); + MadeChange = true; + } + + APInt UndefElts2(VWidth, 0); + DemandedElts = (Imm & 0x10) ? 2 : 1; + if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts, + UndefElts2)) { + II->setArgOperand(1, V); + MadeChange = true; + } + + // If both input elements are undef, the result is undef. + if (UndefElts1[(Imm & 0x01) ? 1 : 0] || + UndefElts2[(Imm & 0x10) ? 1 : 0]) + return replaceInstUsesWith(*II, + ConstantAggregateZero::get(II->getType())); + + if (MadeChange) + return II; + } + break; + } + case Intrinsic::x86_sse41_insertps: if (Value *V = simplifyX86insertps(*II, *Builder)) return replaceInstUsesWith(*II, V); @@ -2531,9 +3097,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } - case Intrinsic::amdgcn_rcp: { - if (const ConstantFP *C = dyn_cast<ConstantFP>(II->getArgOperand(0))) { + Value *Src = II->getArgOperand(0); + + // TODO: Move to ConstantFolding/InstSimplify? + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, Src); + + if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { const APFloat &ArgVal = C->getValueAPF(); APFloat Val(ArgVal.getSemantics(), 1.0); APFloat::opStatus Status = Val.divide(ArgVal, @@ -2546,6 +3117,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_rsq: { + Value *Src = II->getArgOperand(0); + + // TODO: Move to ConstantFolding/InstSimplify? 
+ if (isa<UndefValue>(Src)) + return replaceInstUsesWith(CI, Src); + break; + } case Intrinsic::amdgcn_frexp_mant: case Intrinsic::amdgcn_frexp_exp: { Value *Src = II->getArgOperand(0); @@ -2650,6 +3229,274 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantInt::get(II->getType(), Result)); } + case Intrinsic::amdgcn_cvt_pkrtz: { + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { + if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { + const fltSemantics &HalfSem + = II->getType()->getScalarType()->getFltSemantics(); + bool LosesInfo; + APFloat Val0 = C0->getValueAPF(); + APFloat Val1 = C1->getValueAPF(); + Val0.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); + Val1.convert(HalfSem, APFloat::rmTowardZero, &LosesInfo); + + Constant *Folded = ConstantVector::get({ + ConstantFP::get(II->getContext(), Val0), + ConstantFP::get(II->getContext(), Val1) }); + return replaceInstUsesWith(*II, Folded); + } + } + + if (isa<UndefValue>(Src0) && isa<UndefValue>(Src1)) + return replaceInstUsesWith(*II, UndefValue::get(II->getType())); + + break; + } + case Intrinsic::amdgcn_ubfe: + case Intrinsic::amdgcn_sbfe: { + // Decompose simple cases into standard shifts. + Value *Src = II->getArgOperand(0); + if (isa<UndefValue>(Src)) + return replaceInstUsesWith(*II, Src); + + unsigned Width; + Type *Ty = II->getType(); + unsigned IntSize = Ty->getIntegerBitWidth(); + + ConstantInt *CWidth = dyn_cast<ConstantInt>(II->getArgOperand(2)); + if (CWidth) { + Width = CWidth->getZExtValue(); + if ((Width & (IntSize - 1)) == 0) + return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty)); + + if (Width >= IntSize) { + // Hardware ignores high bits, so remove those. + II->setArgOperand(2, ConstantInt::get(CWidth->getType(), + Width & (IntSize - 1))); + return II; + } + } + + unsigned Offset; + ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1)); + if (COffset) { + Offset = COffset->getZExtValue(); + if (Offset >= IntSize) { + II->setArgOperand(1, ConstantInt::get(COffset->getType(), + Offset & (IntSize - 1))); + return II; + } + } + + bool Signed = II->getIntrinsicID() == Intrinsic::amdgcn_sbfe; + + // TODO: Also emit sub if only width is constant. + if (!CWidth && COffset && Offset == 0) { + Constant *KSize = ConstantInt::get(COffset->getType(), IntSize); + Value *ShiftVal = Builder->CreateSub(KSize, II->getArgOperand(2)); + ShiftVal = Builder->CreateZExt(ShiftVal, II->getType()); + + Value *Shl = Builder->CreateShl(Src, ShiftVal); + Value *RightShift = Signed ? + Builder->CreateAShr(Shl, ShiftVal) : + Builder->CreateLShr(Shl, ShiftVal); + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + + if (!CWidth || !COffset) + break; + + // TODO: This allows folding to undef when the hardware has specific + // behavior? + if (Offset + Width < IntSize) { + Value *Shl = Builder->CreateShl(Src, IntSize - Offset - Width); + Value *RightShift = Signed ? + Builder->CreateAShr(Shl, IntSize - Width) : + Builder->CreateLShr(Shl, IntSize - Width); + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + + Value *RightShift = Signed ? 
+ Builder->CreateAShr(Src, Offset) : + Builder->CreateLShr(Src, Offset); + + RightShift->takeName(II); + return replaceInstUsesWith(*II, RightShift); + } + case Intrinsic::amdgcn_exp: + case Intrinsic::amdgcn_exp_compr: { + ConstantInt *En = dyn_cast<ConstantInt>(II->getArgOperand(1)); + if (!En) // Illegal. + break; + + unsigned EnBits = En->getZExtValue(); + if (EnBits == 0xf) + break; // All inputs enabled. + + bool IsCompr = II->getIntrinsicID() == Intrinsic::amdgcn_exp_compr; + bool Changed = false; + for (int I = 0; I < (IsCompr ? 2 : 4); ++I) { + if ((!IsCompr && (EnBits & (1 << I)) == 0) || + (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) { + Value *Src = II->getArgOperand(I + 2); + if (!isa<UndefValue>(Src)) { + II->setArgOperand(I + 2, UndefValue::get(Src->getType())); + Changed = true; + } + } + } + + if (Changed) + return II; + + break; + + } + case Intrinsic::amdgcn_fmed3: { + // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled + // for the shader. + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + Value *Src2 = II->getArgOperand(2); + + bool Swap = false; + // Canonicalize constants to RHS operands. + // + // fmed3(c0, x, c1) -> fmed3(x, c0, c1) + if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { + std::swap(Src0, Src1); + Swap = true; + } + + if (isa<Constant>(Src1) && !isa<Constant>(Src2)) { + std::swap(Src1, Src2); + Swap = true; + } + + if (isa<Constant>(Src0) && !isa<Constant>(Src1)) { + std::swap(Src0, Src1); + Swap = true; + } + + if (Swap) { + II->setArgOperand(0, Src0); + II->setArgOperand(1, Src1); + II->setArgOperand(2, Src2); + return II; + } + + if (match(Src2, m_NaN()) || isa<UndefValue>(Src2)) { + CallInst *NewCall = Builder->CreateMinNum(Src0, Src1); + NewCall->copyFastMathFlags(II); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + + if (const ConstantFP *C0 = dyn_cast<ConstantFP>(Src0)) { + if (const ConstantFP *C1 = dyn_cast<ConstantFP>(Src1)) { + if (const ConstantFP *C2 = dyn_cast<ConstantFP>(Src2)) { + APFloat Result = fmed3AMDGCN(C0->getValueAPF(), C1->getValueAPF(), + C2->getValueAPF()); + return replaceInstUsesWith(*II, + ConstantFP::get(Builder->getContext(), Result)); + } + } + } + + break; + } + case Intrinsic::amdgcn_icmp: + case Intrinsic::amdgcn_fcmp: { + const ConstantInt *CC = dyn_cast<ConstantInt>(II->getArgOperand(2)); + if (!CC) + break; + + // Guard against invalid arguments. + int64_t CCVal = CC->getZExtValue(); + bool IsInteger = II->getIntrinsicID() == Intrinsic::amdgcn_icmp; + if ((IsInteger && (CCVal < CmpInst::FIRST_ICMP_PREDICATE || + CCVal > CmpInst::LAST_ICMP_PREDICATE)) || + (!IsInteger && (CCVal < CmpInst::FIRST_FCMP_PREDICATE || + CCVal > CmpInst::LAST_FCMP_PREDICATE))) + break; + + Value *Src0 = II->getArgOperand(0); + Value *Src1 = II->getArgOperand(1); + + if (auto *CSrc0 = dyn_cast<Constant>(Src0)) { + if (auto *CSrc1 = dyn_cast<Constant>(Src1)) { + Constant *CCmp = ConstantExpr::getCompare(CCVal, CSrc0, CSrc1); + return replaceInstUsesWith(*II, + ConstantExpr::getSExt(CCmp, II->getType())); + } + + // Canonicalize constants to RHS. 
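// The swap is sound because mirrored predicates are equivalent, e.g.
// icmp slt C, X <=> icmp sgt X, C (trivial standalone check):
#include <cassert>

int main() {
  for (int X = -4; X <= 4; ++X) {
    const int C = 1;
    assert((C < X) == (X > C));   // slt <-> sgt
    assert((C <= X) == (X >= C)); // sle <-> sge
  }
}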
+ CmpInst::Predicate SwapPred
+ = CmpInst::getSwappedPredicate(static_cast<CmpInst::Predicate>(CCVal));
+ II->setArgOperand(0, Src1);
+ II->setArgOperand(1, Src0);
+ II->setArgOperand(2, ConstantInt::get(CC->getType(),
+ static_cast<int>(SwapPred)));
+ return II;
+ }
+
+ if (CCVal != CmpInst::ICMP_EQ && CCVal != CmpInst::ICMP_NE)
+ break;
+
+ // Canonicalize compare eq with true value to compare != 0
+ // llvm.amdgcn.icmp(zext (i1 x), 1, eq)
+ // -> llvm.amdgcn.icmp(zext (i1 x), 0, ne)
+ // llvm.amdgcn.icmp(sext (i1 x), -1, eq)
+ // -> llvm.amdgcn.icmp(sext (i1 x), 0, ne)
+ Value *ExtSrc;
+ if (CCVal == CmpInst::ICMP_EQ &&
+ ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) ||
+ (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) &&
+ ExtSrc->getType()->isIntegerTy(1)) {
+ II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType()));
+ II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE));
+ return II;
+ }
+
+ CmpInst::Predicate SrcPred;
+ Value *SrcLHS;
+ Value *SrcRHS;
+
+ // Fold compare eq/ne with 0 from a compare result as the predicate to the
+ // intrinsic. The typical use is a wave vote function in the library, which
+ // will be fed from a user code condition compared with 0. Fold in the
+ // redundant compare.
+
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, ne)
+ // -> llvm.amdgcn.[if]cmp(a, b, pred)
+ //
+ // llvm.amdgcn.icmp([sz]ext ([if]cmp pred a, b), 0, eq)
+ // -> llvm.amdgcn.[if]cmp(a, b, inv pred)
+ if (match(Src1, m_Zero()) &&
+ match(Src0,
+ m_ZExtOrSExt(m_Cmp(SrcPred, m_Value(SrcLHS), m_Value(SrcRHS))))) {
+ if (CCVal == CmpInst::ICMP_EQ)
+ SrcPred = CmpInst::getInversePredicate(SrcPred);
+
+ Intrinsic::ID NewIID = CmpInst::isFPPredicate(SrcPred) ?
+ Intrinsic::amdgcn_fcmp : Intrinsic::amdgcn_icmp;
+
+ Value *NewF = Intrinsic::getDeclaration(II->getModule(), NewIID,
+ SrcLHS->getType());
+ Value *Args[] = { SrcLHS, SrcRHS,
+ ConstantInt::get(CC->getType(), SrcPred) };
+ CallInst *NewCall = Builder->CreateCall(NewF, Args);
+ NewCall->takeName(II);
+ return replaceInstUsesWith(*II, NewCall);
+ }
+
+ break;
+ }
 case Intrinsic::stackrestore: {
 // If the save is right next to the restore, remove the restore. This can
 // happen when variable allocas are DCE'd.
@@ -2790,7 +3637,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
 // isKnownNonNull -> nonnull attribute
 if (isKnownNonNullAt(DerivedPtr, II, &DT))
- II->addAttribute(AttributeSet::ReturnIndex, Attribute::NonNull);
+ II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull);
 }
 // TODO: bitcast(relocate(p)) -> relocate(bitcast(p))
@@ -2799,11 +3646,38 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
 // TODO: relocate((gep p, C, C2, ...)) -> gep(relocate(p), C, C2, ...)
 break;
 }
- }
+ case Intrinsic::experimental_guard: {
+ // Is this guard followed by another guard?
+ Instruction *NextInst = II->getNextNode();
+ Value *NextCond = nullptr;
+ if (match(NextInst,
+ m_Intrinsic<Intrinsic::experimental_guard>(m_Value(NextCond)))) {
+ Value *CurrCond = II->getArgOperand(0);
+
+ // Remove a guard that is immediately preceded by an identical guard.
+ if (CurrCond == NextCond)
+ return eraseInstFromFunction(*NextInst);
+
+ // Otherwise canonicalize guard(a); guard(b) -> guard(a & b).
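+ // Illustrative sketch (not from the patch): because the two guards are
+ // adjacent, nothing can run between them, so checking both conditions at
+ // the first call site is equivalent. In hypothetical IR:
+ //   call void (i1, ...) @llvm.experimental.guard(i1 %a) [ "deopt"() ]
+ //   call void (i1, ...) @llvm.experimental.guard(i1 %b) [ "deopt"() ]
+ // becomes
+ //   %cond = and i1 %a, %b
+ //   call void (i1, ...) @llvm.experimental.guard(i1 %cond) [ "deopt"() ]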
+ II->setArgOperand(0, Builder->CreateAnd(CurrCond, NextCond)); + return eraseInstFromFunction(*NextInst); + } + break; + } + } return visitCallSite(II); } +// Fence instruction simplification +Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { + // Remove identical consecutive fences. + if (auto *NFI = dyn_cast<FenceInst>(FI.getNextNode())) + if (FI.isIdenticalTo(NFI)) + return eraseInstFromFunction(FI); + return nullptr; +} + // InvokeInst simplification // Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { @@ -2950,7 +3824,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && - !CS.paramHasAttr(ArgNo + 1, Attribute::NonNull) && + !CS.paramHasAttr(ArgNo, Attribute::NonNull) && isKnownNonNullAt(V, CS.getInstruction(), &DT)) Indices.push_back(ArgNo + 1); ArgNo++; @@ -2959,7 +3833,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { assert(ArgNo == CS.arg_size() && "sanity check"); if (!Indices.empty()) { - AttributeSet AS = CS.getAttributes(); + AttributeList AS = CS.getAttributes(); LLVMContext &Ctx = CS.getInstruction()->getContext(); AS = AS.addAttribute(Ctx, Indices, Attribute::get(Ctx, Attribute::NonNull)); @@ -3081,7 +3955,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { return false; Instruction *Caller = CS.getInstruction(); - const AttributeSet &CallerPAL = CS.getAttributes(); + const AttributeList &CallerPAL = CS.getAttributes(); // Okay, this is a cast from a function to a different type. Unless doing so // would cause a type conversion of one of our arguments, change this call to @@ -3108,7 +3982,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { } if (!CallerPAL.isEmpty() && !Caller->use_empty()) { - AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); + AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex); if (RAttrs.overlaps(AttributeFuncs::typeIncompatible(NewRetTy))) return false; // Attribute not compatible with transformed value. } @@ -3149,8 +4023,8 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. - if (AttrBuilder(CallerPAL.getParamAttributes(i + 1), i + 1). - overlaps(AttributeFuncs::typeIncompatible(ParamTy))) + if (AttrBuilder(CallerPAL.getParamAttributes(i)) + .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) return false; // Attribute not compatible with transformed value. if (CS.isInAllocaArgument(i)) @@ -3158,9 +4032,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. - if (ParamTy != ActTy && - CallerPAL.getParamAttributes(i + 1).hasAttribute(i + 1, - Attribute::ByVal)) { + if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); if (!ParamPTy || !ParamPTy->getElementType()->isSized()) return false; @@ -3205,7 +4077,7 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { break; // Check if it has an attribute that's incompatible with varargs. 
- AttributeSet PAttrs = CallerPAL.getSlotAttributes(i - 1); + AttributeList PAttrs = CallerPAL.getSlotAttributes(i - 1); if (PAttrs.hasAttribute(Index, Attribute::StructRet)) return false; } @@ -3213,44 +4085,37 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // Okay, we decided that this is a safe thing to do: go ahead and start // inserting cast instructions as necessary. - std::vector<Value*> Args; + SmallVector<Value *, 8> Args; + SmallVector<AttributeSet, 8> ArgAttrs; Args.reserve(NumActualArgs); - SmallVector<AttributeSet, 8> attrVec; - attrVec.reserve(NumCommonArgs); + ArgAttrs.reserve(NumActualArgs); // Get any return attributes. - AttrBuilder RAttrs(CallerPAL, AttributeSet::ReturnIndex); + AttrBuilder RAttrs(CallerPAL, AttributeList::ReturnIndex); // If the return value is not being used, the type may not be compatible // with the existing attributes. Wipe out any problematic attributes. RAttrs.remove(AttributeFuncs::typeIncompatible(NewRetTy)); - // Add the new return attributes. - if (RAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(Caller->getContext(), - AttributeSet::ReturnIndex, RAttrs)); - AI = CS.arg_begin(); for (unsigned i = 0; i != NumCommonArgs; ++i, ++AI) { Type *ParamTy = FT->getParamType(i); - if ((*AI)->getType() == ParamTy) { - Args.push_back(*AI); - } else { - Args.push_back(Builder->CreateBitOrPointerCast(*AI, ParamTy)); - } + Value *NewArg = *AI; + if ((*AI)->getType() != ParamTy) + NewArg = Builder->CreateBitOrPointerCast(*AI, ParamTy); + Args.push_back(NewArg); // Add any parameter attributes. - AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); - if (PAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(Caller->getContext(), i + 1, - PAttrs)); + ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); } // If the function takes more arguments than the call was taking, add them // now. - for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) + for (unsigned i = NumCommonArgs; i != FT->getNumParams(); ++i) { Args.push_back(Constant::getNullValue(FT->getParamType(i))); + ArgAttrs.push_back(AttributeSet()); + } // If we are removing arguments to the function, emit an obnoxious warning. if (FT->getNumParams() < NumActualArgs) { @@ -3259,54 +4124,56 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // Add all of the arguments in their promoted form to the arg list. for (unsigned i = FT->getNumParams(); i != NumActualArgs; ++i, ++AI) { Type *PTy = getPromotedType((*AI)->getType()); + Value *NewArg = *AI; if (PTy != (*AI)->getType()) { // Must promote to pass through va_arg area! Instruction::CastOps opcode = CastInst::getCastOpcode(*AI, false, PTy, false); - Args.push_back(Builder->CreateCast(opcode, *AI, PTy)); - } else { - Args.push_back(*AI); + NewArg = Builder->CreateCast(opcode, *AI, PTy); } + Args.push_back(NewArg); // Add any parameter attributes. - AttrBuilder PAttrs(CallerPAL.getParamAttributes(i + 1), i + 1); - if (PAttrs.hasAttributes()) - attrVec.push_back(AttributeSet::get(FT->getContext(), i + 1, - PAttrs)); + ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); } } } AttributeSet FnAttrs = CallerPAL.getFnAttributes(); - if (CallerPAL.hasAttributes(AttributeSet::FunctionIndex)) - attrVec.push_back(AttributeSet::get(Callee->getContext(), FnAttrs)); if (NewRetTy->isVoidTy()) Caller->setName(""); // Void type should not have a name. 
- const AttributeSet &NewCallerPAL = AttributeSet::get(Callee->getContext(), - attrVec); + assert((ArgAttrs.size() == FT->getNumParams() || FT->isVarArg()) && + "missing argument attributes"); + LLVMContext &Ctx = Callee->getContext(); + AttributeList NewCallerPAL = AttributeList::get( + Ctx, FnAttrs, AttributeSet::get(Ctx, RAttrs), ArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); - Instruction *NC; + CallSite NewCS; if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - NC = Builder->CreateInvoke(Callee, II->getNormalDest(), II->getUnwindDest(), - Args, OpBundles); - NC->takeName(II); - cast<InvokeInst>(NC)->setCallingConv(II->getCallingConv()); - cast<InvokeInst>(NC)->setAttributes(NewCallerPAL); + NewCS = Builder->CreateInvoke(Callee, II->getNormalDest(), + II->getUnwindDest(), Args, OpBundles); } else { - CallInst *CI = cast<CallInst>(Caller); - NC = Builder->CreateCall(Callee, Args, OpBundles); - NC->takeName(CI); - cast<CallInst>(NC)->setTailCallKind(CI->getTailCallKind()); - cast<CallInst>(NC)->setCallingConv(CI->getCallingConv()); - cast<CallInst>(NC)->setAttributes(NewCallerPAL); + NewCS = Builder->CreateCall(Callee, Args, OpBundles); + cast<CallInst>(NewCS.getInstruction()) + ->setTailCallKind(cast<CallInst>(Caller)->getTailCallKind()); } + NewCS->takeName(Caller); + NewCS.setCallingConv(CS.getCallingConv()); + NewCS.setAttributes(NewCallerPAL); + + // Preserve the weight metadata for the new call instruction. The metadata + // is used by SamplePGO to check callsite's hotness. + uint64_t W; + if (Caller->extractProfTotalWeight(W)) + NewCS->setProfWeight(W); // Insert a cast of the return type as necessary. + Instruction *NC = NewCS.getInstruction(); Value *NV = NC; if (OldRetTy != NV->getType() && !Caller->use_empty()) { if (!NV->getType()->isVoidTy()) { @@ -3351,7 +4218,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Value *Callee = CS.getCalledValue(); PointerType *PTy = cast<PointerType>(Callee->getType()); FunctionType *FTy = cast<FunctionType>(PTy->getElementType()); - const AttributeSet &Attrs = CS.getAttributes(); + AttributeList Attrs = CS.getAttributes(); // If the call already has the 'nest' attribute somewhere then give up - // otherwise 'nest' would occur twice after splicing in the chain. @@ -3364,50 +4231,46 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, Function *NestF =cast<Function>(Tramp->getArgOperand(1)->stripPointerCasts()); FunctionType *NestFTy = cast<FunctionType>(NestF->getValueType()); - const AttributeSet &NestAttrs = NestF->getAttributes(); + AttributeList NestAttrs = NestF->getAttributes(); if (!NestAttrs.isEmpty()) { - unsigned NestIdx = 1; + unsigned NestArgNo = 0; Type *NestTy = nullptr; AttributeSet NestAttr; // Look for a parameter marked with the 'nest' attribute. for (FunctionType::param_iterator I = NestFTy->param_begin(), - E = NestFTy->param_end(); I != E; ++NestIdx, ++I) - if (NestAttrs.hasAttribute(NestIdx, Attribute::Nest)) { + E = NestFTy->param_end(); + I != E; ++NestArgNo, ++I) { + AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo); + if (AS.hasAttribute(Attribute::Nest)) { // Record the parameter type and any other attributes. 
NestTy = *I; - NestAttr = NestAttrs.getParamAttributes(NestIdx); + NestAttr = AS; break; } + } if (NestTy) { Instruction *Caller = CS.getInstruction(); std::vector<Value*> NewArgs; + std::vector<AttributeSet> NewArgAttrs; NewArgs.reserve(CS.arg_size() + 1); - - SmallVector<AttributeSet, 8> NewAttrs; - NewAttrs.reserve(Attrs.getNumSlots() + 1); + NewArgAttrs.reserve(CS.arg_size()); // Insert the nest argument into the call argument list, which may // mean appending it. Likewise for attributes. - // Add any result attributes. - if (Attrs.hasAttributes(AttributeSet::ReturnIndex)) - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - Attrs.getRetAttributes())); - { - unsigned Idx = 1; + unsigned ArgNo = 0; CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); do { - if (Idx == NestIdx) { + if (ArgNo == NestArgNo) { // Add the chain argument and attributes. Value *NestVal = Tramp->getArgOperand(2); if (NestVal->getType() != NestTy) NestVal = Builder->CreateBitCast(NestVal, NestTy, "nest"); NewArgs.push_back(NestVal); - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - NestAttr)); + NewArgAttrs.push_back(NestAttr); } if (I == E) @@ -3415,23 +4278,13 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original argument and attributes. NewArgs.push_back(*I); - AttributeSet Attr = Attrs.getParamAttributes(Idx); - if (Attr.hasAttributes(Idx)) { - AttrBuilder B(Attr, Idx); - NewAttrs.push_back(AttributeSet::get(Caller->getContext(), - Idx + (Idx >= NestIdx), B)); - } + NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); - ++Idx; + ++ArgNo; ++I; } while (true); } - // Add any function attributes. - if (Attrs.hasAttributes(AttributeSet::FunctionIndex)) - NewAttrs.push_back(AttributeSet::get(FTy->getContext(), - Attrs.getFnAttributes())); - // The trampoline may have been bitcast to a bogus type (FTy). // Handle this by synthesizing a new function type, equal to FTy // with the chain parameter inserted. @@ -3442,12 +4295,12 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Insert the chain's type into the list of parameter types, which may // mean appending it. { - unsigned Idx = 1; + unsigned ArgNo = 0; FunctionType::param_iterator I = FTy->param_begin(), E = FTy->param_end(); do { - if (Idx == NestIdx) + if (ArgNo == NestArgNo) // Add the chain's type. NewTypes.push_back(NestTy); @@ -3457,7 +4310,7 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, // Add the original type. NewTypes.push_back(*I); - ++Idx; + ++ArgNo; ++I; } while (true); } @@ -3470,8 +4323,9 @@ InstCombiner::transformCallThroughTrampoline(CallSite CS, NestF->getType() == PointerType::getUnqual(NewFTy) ? NestF : ConstantExpr::getBitCast(NestF, PointerType::getUnqual(NewFTy)); - const AttributeSet &NewPAL = - AttributeSet::get(FTy->getContext(), NewAttrs); + AttributeList NewPAL = + AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(), + Attrs.getRetAttributes(), NewArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; CS.getOperandBundlesAsDefs(OpBundles); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index e74b590e2b7c..25683132c786 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -274,12 +274,12 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { return NV; // If we are casting a PHI, then fold the cast into the PHI. 
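 // Illustrative sketch (not from the patch): what folding a cast into a phi
 // looks like when all incoming values are constants (hypothetical IR):
 //   %p = phi i32 [ 257, %bb1 ], [ 8, %bb2 ]
 //   %t = trunc i32 %p to i8
 // becomes
 //   %t = phi i8 [ 1, %bb1 ], [ 8, %bb2 ]   ; trunc i32 257 to i8 == 1
 // Each incoming constant is folded in its predecessor, so the cast vanishes.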
- if (isa<PHINode>(Src)) { + if (auto *PN = dyn_cast<PHINode>(Src)) { // Don't do this if it would create a PHI node with an illegal type from a // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || - ShouldChangeType(CI.getType(), Src->getType())) - if (Instruction *NV = FoldOpIntoPhi(CI)) + shouldChangeType(CI.getType(), Src->getType())) + if (Instruction *NV = foldOpIntoPhi(CI, PN)) return NV; } @@ -447,7 +447,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC, Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !ShouldChangeType(SrcTy, DestTy)) + if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; BinaryOperator *LogicOp; @@ -463,6 +463,56 @@ Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); } +/// Try to narrow the width of a splat shuffle. This could be generalized to any +/// shuffle with a constant operand, but we limit the transform to avoid +/// creating a shuffle type that targets may not be able to lower effectively. +static Instruction *shrinkSplatShuffle(TruncInst &Trunc, + InstCombiner::BuilderTy &Builder) { + auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); + if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) && + Shuf->getMask()->getSplatValue() && + Shuf->getType() == Shuf->getOperand(0)->getType()) { + // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask + Constant *NarrowUndef = UndefValue::get(Trunc.getType()); + Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); + return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask()); + } + + return nullptr; +} + +/// Try to narrow the width of an insert element. This could be generalized for +/// any vector constant, but we limit the transform to insertion into undef to +/// avoid potential backend problems from unsupported insertion widths. This +/// could also be extended to handle the case of inserting a scalar constant +/// into a vector variable. +static Instruction *shrinkInsertElt(CastInst &Trunc, + InstCombiner::BuilderTy &Builder) { + Instruction::CastOps Opcode = Trunc.getOpcode(); + assert((Opcode == Instruction::Trunc || Opcode == Instruction::FPTrunc) && + "Unexpected instruction for shrinking"); + + auto *InsElt = dyn_cast<InsertElementInst>(Trunc.getOperand(0)); + if (!InsElt || !InsElt->hasOneUse()) + return nullptr; + + Type *DestTy = Trunc.getType(); + Type *DestScalarTy = DestTy->getScalarType(); + Value *VecOp = InsElt->getOperand(0); + Value *ScalarOp = InsElt->getOperand(1); + Value *Index = InsElt->getOperand(2); + + if (isa<UndefValue>(VecOp)) { + // trunc (inselt undef, X, Index) --> inselt undef, (trunc X), Index + // fptrunc (inselt undef, X, Index) --> inselt undef, (fptrunc X), Index + UndefValue *NarrowUndef = UndefValue::get(DestTy); + Value *NarrowOp = Builder.CreateCast(Opcode, ScalarOp, DestScalarTy); + return InsertElementInst::Create(NarrowUndef, NarrowOp, Index); + } + + return nullptr; +} + Instruction *InstCombiner::visitTrunc(TruncInst &CI) { if (Instruction *Result = commonCastTransforms(CI)) return Result; @@ -488,7 +538,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // type. 
Only do this if the dest type is a simple type, don't convert the
 // expression tree to something weird like i93 unless the source is also
 // strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
 canEvaluateTruncated(Src, DestTy, *this, &CI)) {
 // If this cast is a truncate, evaluating in a different type always
@@ -554,8 +604,14 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) {
 if (Instruction *I = shrinkBitwiseLogic(CI))
 return I;
+ if (Instruction *I = shrinkSplatShuffle(CI, *Builder))
+ return I;
+
+ if (Instruction *I = shrinkInsertElt(CI, *Builder))
+ return I;
+
 if (Src->hasOneUse() && isa<IntegerType>(SrcTy) &&
- ShouldChangeType(SrcTy, DestTy)) {
+ shouldChangeType(SrcTy, DestTy)) {
 // Transform "trunc (shl X, cst)" -> "shl (trunc X), cst" so long as the
 // dest type is native and cst < dest size.
 if (match(Src, m_Shl(m_Value(A), m_ConstantInt(Cst))) &&
@@ -838,11 +894,6 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
 if (Instruction *Result = commonCastTransforms(CI))
 return Result;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
 Value *Src = CI.getOperand(0);
 Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -851,10 +902,10 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) {
 // expression tree to something weird like i93 unless the source is also
 // strange.
 unsigned BitsToClear;
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
 canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) {
- assert(BitsToClear < SrcTy->getScalarSizeInBits() &&
- "Unreasonable BitsToClear");
+ assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
+ "Can't clear more bits than in SrcTy");
 // Okay, we can transform this! Insert the new expression now.
 DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1124,11 +1175,6 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
 if (Instruction *I = commonCastTransforms(CI))
 return I;
- // See if we can simplify any instructions used by the input whose sole
- // purpose is to compute bits we don't care about.
- if (SimplifyDemandedInstructionBits(CI))
- return &CI;
-
 Value *Src = CI.getOperand(0);
 Type *SrcTy = Src->getType(), *DestTy = CI.getType();
@@ -1145,7 +1191,7 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
 // type. Only do this if the dest type is a simple type, don't convert the
 // expression tree to something weird like i93 unless the source is also
 // strange.
- if ((DestTy->isVectorTy() || ShouldChangeType(SrcTy, DestTy)) &&
+ if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) &&
 canEvaluateSExtd(Src, DestTy)) {
 // Okay, we can transform this! Insert the new expression now.
 DEBUG(dbgs() << "ICE: EvaluateInDifferentType converting expression type"
@@ -1167,18 +1213,16 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) {
 ShAmt);
 }
- // If this input is a trunc from our destination, then turn sext(trunc(x))
+ // If the input is a trunc from the destination type, then turn sext(trunc(x))
 // into shifts.
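// Illustrative sketch (not from the patch): the sext(trunc(x)) fold for
// i32 -> i8 -> i32 in plain C++, assuming the usual two's complement
// behavior for the narrowing cast and an arithmetic '>>' on negative
// values, which is what the IR 'ashr' guarantees:
#include <cstdint>
int32_t sextOfTrunc(int32_t X) {
  return (int32_t)(int8_t)X;                  // sext(trunc X to i8)
}
int32_t shlThenAshr(int32_t X) {
  return (int32_t)((uint32_t)X << 24) >> 24;  // ashr(shl(X, 24), 24)
}
// Both map 0x1FF to -1: the low 8 bits are 0xFF and the rest is sign fill.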
- if (TruncInst *TI = dyn_cast<TruncInst>(Src)) - if (TI->hasOneUse() && TI->getOperand(0)->getType() == DestTy) { - uint32_t SrcBitSize = SrcTy->getScalarSizeInBits(); - uint32_t DestBitSize = DestTy->getScalarSizeInBits(); - - // We need to emit a shl + ashr to do the sign extend. - Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); - Value *Res = Builder->CreateShl(TI->getOperand(0), ShAmt, "sext"); - return BinaryOperator::CreateAShr(Res, ShAmt); - } + Value *X; + if (match(Src, m_OneUse(m_Trunc(m_Value(X)))) && X->getType() == DestTy) { + // sext(trunc(X)) --> ashr(shl(X, C), C) + unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); + unsigned DestBitSize = DestTy->getScalarSizeInBits(); + Constant *ShAmt = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + return BinaryOperator::CreateAShr(Builder->CreateShl(X, ShAmt), ShAmt); + } if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) return transformSExtICmp(ICI, CI); @@ -1225,17 +1269,15 @@ static Constant *fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) { return nullptr; } -/// If this is a floating-point extension instruction, look -/// through it until we get the source value. +/// Look through floating-point extensions until we get the source value. static Value *lookThroughFPExtensions(Value *V) { - if (Instruction *I = dyn_cast<Instruction>(V)) - if (I->getOpcode() == Instruction::FPExt) - return lookThroughFPExtensions(I->getOperand(0)); + while (auto *FPExt = dyn_cast<FPExtInst>(V)) + V = FPExt->getOperand(0); // If this value is a constant, return the constant in the smallest FP type // that can accurately represent it. This allows us to turn // (float)((double)X+2.0) into x+2.0f. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { + if (auto *CFP = dyn_cast<ConstantFP>(V)) { if (CFP->getType() == Type::getPPC_FP128Ty(V->getContext())) return V; // No constant folding of this. // See if the value can be truncated to half and then reextended. @@ -1392,24 +1434,49 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &CI) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI.getOperand(0)); if (II) { switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::fabs: { - // (fptrunc (fabs x)) -> (fabs (fptrunc x)) - Value *InnerTrunc = Builder->CreateFPTrunc(II->getArgOperand(0), - CI.getType()); - Type *IntrinsicType[] = { CI.getType() }; - Function *Overload = Intrinsic::getDeclaration( - CI.getModule(), II->getIntrinsicID(), IntrinsicType); - - SmallVector<OperandBundleDef, 1> OpBundles; - II->getOperandBundlesAsDefs(OpBundles); - - Value *Args[] = { InnerTrunc }; - return CallInst::Create(Overload, Args, OpBundles, II->getName()); + default: break; + case Intrinsic::fabs: + case Intrinsic::ceil: + case Intrinsic::floor: + case Intrinsic::rint: + case Intrinsic::round: + case Intrinsic::nearbyint: + case Intrinsic::trunc: { + Value *Src = II->getArgOperand(0); + if (!Src->hasOneUse()) + break; + + // Except for fabs, this transformation requires the input of the unary FP + // operation to be itself an fpext from the type to which we're + // truncating. + if (II->getIntrinsicID() != Intrinsic::fabs) { + FPExtInst *FPExtSrc = dyn_cast<FPExtInst>(Src); + if (!FPExtSrc || FPExtSrc->getOperand(0)->getType() != CI.getType()) + break; } + + // Do unary FP operation on smaller type. 
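+ // Illustrative sketch (not from the patch): why the fpext requirement makes
+ // this safe for the rounding intrinsics, in plain C++. Every float is
+ // exactly representable as a double, and floor/ceil/trunc/... of such a
+ // value fits back in a float, so rounding commutes with the extension:
+ //   #include <cmath>
+ //   float floorViaDouble(float F) { return (float)std::floor((double)F); }
+ //   float floorNarrow(float F)    { return std::floor(F); }
+ // The two agree for every finite input and both propagate NaN. For fabs no
+ // fpext is needed at all, since fabs only clears the sign bit.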
+ // (fptrunc (fabs x)) -> (fabs (fptrunc x)) + Value *InnerTrunc = Builder->CreateFPTrunc(Src, CI.getType()); + Type *IntrinsicType[] = { CI.getType() }; + Function *Overload = Intrinsic::getDeclaration( + CI.getModule(), II->getIntrinsicID(), IntrinsicType); + + SmallVector<OperandBundleDef, 1> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + + Value *Args[] = { InnerTrunc }; + CallInst *NewCI = CallInst::Create(Overload, Args, + OpBundles, II->getName()); + NewCI->copyFastMathFlags(II); + return NewCI; + } } } + if (Instruction *I = shrinkInsertElt(CI, *Builder)) + return I; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 428f94bb5e93..bbafa9e9f468 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -230,7 +230,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, return nullptr; uint64_t ArrayElementCount = Init->getType()->getArrayNumElements(); - if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays. + // Don't blow up on huge arrays. + if (ArrayElementCount > MaxArraySizeForCombine) + return nullptr; // There are many forms of this optimization we can handle, for now, just do // the simple index into a single-dimensional array. @@ -1663,7 +1665,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { // TODO: Is this a good transform for vectors? Wider types may reduce // throughput. Should this transform be limited (even for scalars) by using - // ShouldChangeType()? + // shouldChangeType()? if (!Cmp.getType()->isVectorTy()) { Type *WideType = W->getType(); unsigned WideScalarBits = WideType->getScalarSizeInBits(); @@ -1792,6 +1794,15 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, ConstantInt::get(V->getType(), 1)); } + // X | C == C --> X <=u C + // X | C != C --> X >u C + // iff C+1 is a power of 2 (C is a bitmask of the low bits) + if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && + (*C + 1).isPowerOf2()) { + Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; + return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); + } + if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse()) return nullptr; @@ -1914,61 +1925,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Shl->getOperand(0); - if (Cmp.isEquality()) { - // If the shift is NUW, then it is just shifting out zeros, no need for an - // AND. - Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt)); - if (Shl->hasNoUnsignedWrap()) - return new ICmpInst(Pred, X, LShrC); - - // If the shift is NSW and we compare to 0, then it is just shifting out - // sign bits, no need for an AND either. - if (Shl->hasNoSignedWrap() && *C == 0) - return new ICmpInst(Pred, X, LShrC); - - if (Shl->hasOneUse()) { - // Otherwise, strength reduce the shift into an and. 
- Constant *Mask = ConstantInt::get(Shl->getType(), - APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); - - Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); - return new ICmpInst(Pred, And, LShrC); + Type *ShType = Shl->getType(); + + // NSW guarantees that we are only shifting out sign bits from the high bits, + // so we can ASHR the compare constant without needing a mask and eliminate + // the shift. + if (Shl->hasNoSignedWrap()) { + if (Pred == ICmpInst::ICMP_SGT) { + // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) + APInt ShiftedC = C->ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { + // This is the same code as the SGT case, but assert the pre-condition + // that is needed for this to work with equality predicates. + assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && + "Compare known true or false was not folded"); + APInt ShiftedC = C->ashr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_SLT) { + // SLE is the same as above, but SLE is canonicalized to SLT, so convert: + // (X << S) <=s C is equiv to X <=s (C >> S) for all C + // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX + // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN + assert(!C->isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + // If this is a signed comparison to 0 and the shift is sign preserving, + // use the shift LHS operand instead; isSignTest may change 'Pred', so only + // do that if we're sure to not continue on in this function. + if (isSignTest(Pred, *C)) + return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); + } + + // NUW guarantees that we are only shifting out zero bits from the high bits, + // so we can LSHR the compare constant without needing a mask and eliminate + // the shift. + if (Shl->hasNoUnsignedWrap()) { + if (Pred == ICmpInst::ICMP_UGT) { + // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) + APInt ShiftedC = C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { + // This is the same code as the UGT case, but assert the pre-condition + // that is needed for this to work with equality predicates. + assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && + "Compare known true or false was not folded"); + APInt ShiftedC = C->lshr(*ShiftAmt); + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); + } + if (Pred == ICmpInst::ICMP_ULT) { + // ULE is the same as above, but ULE is canonicalized to ULT, so convert: + // (X << S) <=u C is equiv to X <=u (C >> S) for all C + // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u + // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 + assert(C->ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; + return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } } - // If this is a signed comparison to 0 and the shift is sign preserving, - // use the shift LHS operand instead; isSignTest may change 'Pred', so only - // do that if we're sure to not continue on in this function. 
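+ // Illustrative sketch (not from the patch): one instance of the NSW case
+ // above in plain C++. When the shift cannot overflow (the nsw guarantee),
+ // comparing the shifted value equals comparing X against the arithmetically
+ // right-shifted constant:
+ //   bool cmpShifted(int X) { return (X << 3) > 36; } // valid only w/o overflow
+ //   bool cmpFolded(int X)  { return X > (36 >> 3); } // 36 >>s 3 == 4
+ // For any X where X << 3 does not overflow, both agree, e.g. false for
+ // X == 4 (32 > 36) and true for X == 5 (40 > 36).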
- if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C)) - return new ICmpInst(Pred, X, Constant::getNullValue(X->getType())); + if (Cmp.isEquality() && Shl->hasOneUse()) { + // Strength-reduce the shift into an 'and'. + Constant *Mask = ConstantInt::get( + ShType, + APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); + Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); + Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); + return new ICmpInst(Pred, And, LShrC); + } // Otherwise, if this is a comparison of the sign bit, simplify to and/test. bool TrueIfSigned = false; if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { // (X << 31) <s 0 --> (X & 1) != 0 Constant *Mask = ConstantInt::get( - X->getType(), + ShType, APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1)); Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask"); return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ, - And, Constant::getNullValue(And->getType())); - } - - // When the shift is nuw and pred is >u or <=u, comparison only really happens - // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the - // <=u case can be further converted to match <u (see below). - if (Shl->hasNoUnsignedWrap() && - (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) { - // Derivation for the ult case: - // (X << S) <=u C is equiv to X <=u (C >> S) for all C - // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u - // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 - assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) && - "Encountered `ult 0` that should have been eliminated by " - "InstSimplify."); - APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1 - : C->lshr(*ShiftAmt); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC)); + And, Constant::getNullValue(ShType)); } // Transform (icmp pred iM (shl iM %v, N), C) @@ -1981,8 +2020,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); - if (X->getType()->isVectorTy()) - TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements()); + if (ShType->isVectorTy()) + TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); Constant *NewC = ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC); @@ -2342,8 +2381,24 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // Fold icmp pred (add X, C2), C. Value *X = Add->getOperand(0); Type *Ty = Add->getType(); - auto CR = - ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2); + CmpInst::Predicate Pred = Cmp.getPredicate(); + + // If the add does not wrap, we can always adjust the compare by subtracting + // the constants. Equality comparisons are handled elsewhere. SGE/SLE are + // canonicalized to SGT/SLT. + if (Add->hasNoSignedWrap() && + (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { + bool Overflow; + APInt NewC = C->ssub_ov(*C2, Overflow); + // If there is overflow, the result must be true or false. + // TODO: Can we assert there is no overflow because InstSimplify always + // handles those cases? 
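+ // Illustrative sketch (not from the patch): the add-nsw adjustment above in
+ // plain C++. With no signed overflow on the add or on C - C2, the constant
+ // simply moves across the compare:
+ //   bool cmpAdd(int X)      { return X + 5 > 20; } // icmp sgt (add nsw X, 5), 20
+ //   bool cmpAdjusted(int X) { return X > 15; }     // icmp sgt X, 20 - 5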
+ if (!Overflow) + // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2) + return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); + } + + auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { @@ -2364,16 +2419,14 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (*C2 & (*C - 1)) == 0) + if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0) return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)), ConstantExpr::getNeg(cast<Constant>(Y))); // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && - (*C2 & *C) == 0) + if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0) return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)), ConstantExpr::getNeg(cast<Constant>(Y))); @@ -2656,7 +2709,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) { // block. If in the same block, we're encouraging jump threading. If // not, we are just pessimizing the code by making an i1 phi. if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; break; case Instruction::Select: { @@ -2767,12 +2820,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X+cst) < 0 --> X < -cst - if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero())) - if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B)) - if (!RHSC->isMinValue(/*isSigned=*/true)) - return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC)); - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, @@ -2847,6 +2894,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One())) return new ICmpInst(CmpInst::ICMP_SLE, Op0, C); + // TODO: The subtraction-related identities shown below also hold, but + // canonicalization from (X -nuw 1) to (X + -1) means that the combinations + // wouldn't happen even if they were implemented. + // + // icmp ult (X - 1), Y -> icmp ule X, Y + // icmp uge (X - 1), Y -> icmp ugt X, Y + // icmp ugt X, (Y - 1) -> icmp uge X, Y + // icmp ule X, (Y - 1) -> icmp ult X, Y + + // icmp ule (X + 1), Y -> icmp ult X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_ULT, A, Op1); + + // icmp ugt (X + 1), Y -> icmp uge X, Y + if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One())) + return new ICmpInst(CmpInst::ICMP_UGE, A, Op1); + + // icmp uge X, (Y + 1) -> icmp ugt X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_UGT, Op0, C); + + // icmp ult X, (Y + 1) -> icmp ule X, Y + if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One())) + return new ICmpInst(CmpInst::ICMP_ULE, Op0, C); + // if C1 has greater magnitude than C2: // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y // s.t. 
C3 = C1 - C2
@@ -3738,16 +3810,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,
 // greater than the RHS must differ in a bit higher than these due to carry.
 case ICmpInst::ICMP_UGT: {
 unsigned trailingOnes = RHS.countTrailingOnes();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingOnes);
 }
 // Similarly, for a ULT comparison, we don't care about the trailing zeros.
 // Any value less than the RHS must differ in a higher bit because of carries.
 case ICmpInst::ICMP_ULT: {
 unsigned trailingZeros = RHS.countTrailingZeros();
- APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros);
- return ~lowBitsSet;
+ return APInt::getBitsSetFrom(BitWidth, trailingZeros);
 }
 default:
@@ -3887,7 +3957,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
 assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
 if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
 BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
- // The check for the unique predecessor is not the best that can be
+ // The check for the single predecessor is not the best that can be
 // done. But it protects efficiently against cases like when SI's
 // home block has two successors, Succ and Succ1, and Succ1 is a predecessor
 // of Succ. Then SI can't be replaced by SIOpd because the use that gets
@@ -3895,8 +3965,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
 // guarantees that the path all uses of SI (outside SI's parent) are on
 // is disjoint from all other paths out of SI. But that information
 // is more expensive to compute, and the trade-off here is in favor
- // of compile-time.
- if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
+ // of compile-time. It should also be noticed that we check for a single
+ // predecessor and not only uniqueness. This is to handle the situation when
+ // Succ and Succ1 point to the same basic block.
+ if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
 NumSel++;
 SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
 return true;
@@ -3932,12 +4004,12 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
 APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
 APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
- if (SimplifyDemandedBits(I.getOperandUse(0),
+ if (SimplifyDemandedBits(&I, 0,
 getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
 Op0KnownZero, Op0KnownOne, 0))
 return &I;
- if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+ if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
 Op1KnownZero, Op1KnownOne, 0))
 return &I;
@@ -4801,7 +4873,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
 // block. If in the same block, we're encouraging jump threading. If
 // not, we are just pessimizing the code by making an i1 phi.
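// Illustrative sketch (not from the patch), for the getDemandedBitsLHSMask
// change above: getBitsSetFrom(BitWidth, N) equals ~getLowBitsSet(BitWidth, N),
// and the low bits really are dead for 'icmp ugt X, RHS'. A C++ check of that
// claim (assumes RHS != ~0u; uses the GCC/Clang ctz builtin):
#include <cstdint>
bool ugtIgnoresLowBits(uint32_t X, uint32_t RHS) {
  unsigned TrailingOnes = __builtin_ctz(~RHS);  // e.g. RHS = 0b0111 -> 3
  uint32_t HighBits = ~0u << TrailingOnes;      // getBitsSetFrom
  return (X > RHS) == ((X & HighBits) > RHS);   // holds for all X
}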
if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) return NV; break; case Instruction::SIToFP: diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 2847ce858e79..71000063ab3c 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -28,6 +28,9 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Pass.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Support/Dwarf.h" +#include "llvm/IR/DIBuilder.h" #define DEBUG_TYPE "instcombine" @@ -40,21 +43,29 @@ class DbgDeclareInst; class MemIntrinsic; class MemSetInst; -/// \brief Assign a complexity or rank value to LLVM Values. +/// Assign a complexity or rank value to LLVM Values. This is used to reduce +/// the amount of pattern matching needed for compares and commutative +/// instructions. For example, if we have: +/// icmp ugt X, Constant +/// or +/// xor (add X, Constant), cast Z +/// +/// We do not have to consider the commuted variants of these patterns because +/// canonicalization based on complexity guarantees the above ordering. /// /// This routine maps IR values to various complexity ranks: /// 0 -> undef /// 1 -> Constants /// 2 -> Other non-instructions /// 3 -> Arguments -/// 3 -> Unary operations -/// 4 -> Other instructions +/// 4 -> Cast and (f)neg/not instructions +/// 5 -> Other instructions static inline unsigned getComplexity(Value *V) { if (isa<Instruction>(V)) { - if (BinaryOperator::isNeg(V) || BinaryOperator::isFNeg(V) || - BinaryOperator::isNot(V)) - return 3; - return 4; + if (isa<CastInst>(V) || BinaryOperator::isNeg(V) || + BinaryOperator::isFNeg(V) || BinaryOperator::isNot(V)) + return 4; + return 5; } if (isa<Argument>(V)) return 3; @@ -289,6 +300,7 @@ public: Instruction *visitLoadInst(LoadInst &LI); Instruction *visitStoreInst(StoreInst &SI); Instruction *visitBranchInst(BranchInst &BI); + Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); Instruction *visitReturnInst(ReturnInst &RI); Instruction *visitInsertValueInst(InsertValueInst &IV); @@ -313,9 +325,14 @@ public: bool replacedSelectWithOperand(SelectInst *SI, const ICmpInst *Icmp, const unsigned SIOpd); + /// Try to replace instruction \p I with value \p V which are pointers + /// in different address space. + /// \return true if successful. + bool replacePointer(Instruction &I, Value *V); + private: - bool ShouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; - bool ShouldChangeType(Type *From, Type *To) const; + bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; + bool shouldChangeType(Type *From, Type *To) const; Value *dyn_castNegVal(Value *V) const; Value *dyn_castFNegVal(Value *V, bool NoSignedZero = false) const; Type *FindElementAtOffset(PointerType *PtrTy, int64_t Offset, @@ -456,8 +473,9 @@ public: /// methods should return the value returned by this function. Instruction *eraseInstFromFunction(Instruction &I) { DEBUG(dbgs() << "IC: ERASE " << I << '\n'); - assert(I.use_empty() && "Cannot erase instruction that is used!"); + salvageDebugInfo(I); + // Make sure that we reprocess all operands now that we reduced their // use counts. 
if (I.getNumOperands() < 8) { @@ -499,6 +517,9 @@ public: return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + /// Maximum size of array considered when transforming. + uint64_t MaxArraySizeForCombine; + private: /// \brief Performs a few simplifications for operators which are associative /// or commutative. @@ -518,8 +539,16 @@ private: Value *SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth, Instruction *CxtI); - bool SimplifyDemandedBits(Use &U, const APInt &DemandedMask, APInt &KnownZero, + bool SimplifyDemandedBits(Instruction *I, unsigned Op, + const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth = 0); + /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne + /// bits. It also tries to handle simplifications that can be done based on + /// DemandedMask, but without modifying the Instruction. + Value *SimplifyMultipleUseDemandedBits(Instruction *I, + const APInt &DemandedMask, + APInt &KnownZero, APInt &KnownOne, + unsigned Depth, Instruction *CxtI); /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. Value *SimplifyShrShlDemandedBits(Instruction *Lsr, Instruction *Sftl, @@ -540,7 +569,7 @@ private: /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is /// only possible if all operands to the PHI are constants). - Instruction *FoldOpIntoPhi(Instruction &I); + Instruction *foldOpIntoPhi(Instruction &I, PHINode *PN); /// Given an instruction with a select as one operand and a constant as the /// other operand, try to fold the binary operator into the select arguments. @@ -549,7 +578,7 @@ private: Instruction *FoldOpIntoSelect(Instruction &Op, SelectInst *SI); /// This is a convenience wrapper function for the above two functions. - Instruction *foldOpWithConstantIntoOperand(Instruction &I); + Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); /// \brief Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. 
@@ -628,16 +657,16 @@
 SelectPatternFlavor SPF2, Value *C);
 Instruction *foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI);
- Instruction *OptAndOp(Instruction *Op, ConstantInt *OpRHS,
+ Instruction *OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS,
 ConstantInt *AndRHS, BinaryOperator &TheAnd);
- Value *FoldLogicalPlusAnd(Value *LHS, Value *RHS, ConstantInt *Mask,
- bool isSub, Instruction &I);
 Value *insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
 bool isSigned, bool Inside);
 Instruction *PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI);
 Instruction *MatchBSwap(BinaryOperator &I);
 bool SimplifyStoreAtEndOfBlock(StoreInst &SI);
+
+ Instruction *SimplifyElementAtomicMemCpy(ElementAtomicMemCpyInst *AMI);
 Instruction *SimplifyMemTransfer(MemIntrinsic *MI);
 Instruction *SimplifyMemSet(MemSetInst *MI);
diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index 49e516e9c176..6288e054f1bc 100644
--- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -12,13 +12,15 @@
//===----------------------------------------------------------------------===//
 #include "InstCombineInternal.h"
+#include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/Analysis/Loads.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/DataLayout.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/DebugInfo.h"
 #include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Local.h"
@@ -223,6 +225,107 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) {
 return nullptr;
}
+namespace {
+// If I and V are pointers in different address spaces, it is not allowed to
+// use replaceAllUsesWith since I and V have different types. A
+// non-target-specific transformation should not use addrspacecast on V since
+// the two address spaces may be disjoint depending on the target.
+//
+// This class chases down uses of the old pointer until reaching the load
+// instructions, then replaces the old pointer in the load instructions with
+// the new pointer. If it sees a bitcast or GEP during the chase, it creates
+// a new bitcast or GEP with the new pointer and uses that in the load
+// instruction.
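+// Illustrative sketch (not from the patch): the situation this class handles,
+// in hypothetical IR. If an alloca fed only by a copy from a constant global
+// in addrspace(2) is elided, a use chain such as
+//   %b = bitcast i8* %alloca to i32*
+//   %v = load i32, i32* %b
+// is rebuilt against the global instead of inserting an addrspacecast:
+//   %b = bitcast i8 addrspace(2)* @g to i32 addrspace(2)*
+//   %v = load i32, i32 addrspace(2)* %b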
+class PointerReplacer { +public: + PointerReplacer(InstCombiner &IC) : IC(IC) {} + void replacePointer(Instruction &I, Value *V); + +private: + void findLoadAndReplace(Instruction &I); + void replace(Instruction *I); + Value *getReplacement(Value *I); + + SmallVector<Instruction *, 4> Path; + MapVector<Value *, Value *> WorkMap; + InstCombiner &IC; +}; +} // end anonymous namespace + +void PointerReplacer::findLoadAndReplace(Instruction &I) { + for (auto U : I.users()) { + auto *Inst = dyn_cast<Instruction>(&*U); + if (!Inst) + return; + DEBUG(dbgs() << "Found pointer user: " << *U << '\n'); + if (isa<LoadInst>(Inst)) { + for (auto P : Path) + replace(P); + replace(Inst); + } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { + Path.push_back(Inst); + findLoadAndReplace(*Inst); + Path.pop_back(); + } else { + return; + } + } +} + +Value *PointerReplacer::getReplacement(Value *V) { + auto Loc = WorkMap.find(V); + if (Loc != WorkMap.end()) + return Loc->second; + return nullptr; +} + +void PointerReplacer::replace(Instruction *I) { + if (getReplacement(I)) + return; + + if (auto *LT = dyn_cast<LoadInst>(I)) { + auto *V = getReplacement(LT->getPointerOperand()); + assert(V && "Operand not replaced"); + auto *NewI = new LoadInst(V); + NewI->takeName(LT); + IC.InsertNewInstWith(NewI, *LT); + IC.replaceInstUsesWith(*LT, NewI); + WorkMap[LT] = NewI; + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + auto *V = getReplacement(GEP->getPointerOperand()); + assert(V && "Operand not replaced"); + SmallVector<Value *, 8> Indices; + Indices.append(GEP->idx_begin(), GEP->idx_end()); + auto *NewI = GetElementPtrInst::Create( + V->getType()->getPointerElementType(), V, Indices); + IC.InsertNewInstWith(NewI, *GEP); + NewI->takeName(GEP); + WorkMap[GEP] = NewI; + } else if (auto *BC = dyn_cast<BitCastInst>(I)) { + auto *V = getReplacement(BC->getOperand(0)); + assert(V && "Operand not replaced"); + auto *NewT = PointerType::get(BC->getType()->getPointerElementType(), + V->getType()->getPointerAddressSpace()); + auto *NewI = new BitCastInst(V, NewT); + IC.InsertNewInstWith(NewI, *BC); + NewI->takeName(BC); + WorkMap[BC] = NewI; + } else { + llvm_unreachable("should never reach here"); + } +} + +void PointerReplacer::replacePointer(Instruction &I, Value *V) { +#ifndef NDEBUG + auto *PT = cast<PointerType>(I.getType()); + auto *NT = cast<PointerType>(V->getType()); + assert(PT != NT && PT->getElementType() == NT->getElementType() && + "Invalid usage"); +#endif + WorkMap[&I] = V; + findLoadAndReplace(I); +} + Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { if (auto *I = simplifyAllocaArraySize(*this, AI)) return I; @@ -293,12 +396,22 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) eraseInstFromFunction(*ToDelete[i]); Constant *TheSrc = cast<Constant>(Copy->getSource()); - Constant *Cast - = ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, AI.getType()); - Instruction *NewI = replaceInstUsesWith(AI, Cast); - eraseInstFromFunction(*Copy); - ++NumGlobalCopies; - return NewI; + auto *SrcTy = TheSrc->getType(); + auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), + SrcTy->getPointerAddressSpace()); + Constant *Cast = + ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy); + if (AI.getType()->getPointerAddressSpace() == + SrcTy->getPointerAddressSpace()) { + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); + ++NumGlobalCopies; + return NewI; 
+ } else { + PointerReplacer PtrReplacer(*this); + PtrReplacer.replacePointer(AI, Cast); + ++NumGlobalCopies; + } } } } @@ -608,7 +721,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { // arrays of arbitrary size but this has a terrible impact on compile time. // The threshold here is chosen arbitrarily, maybe needs a little bit of // tuning. - if (NumElements > 1024) + if (NumElements > IC.MaxArraySizeForCombine) return nullptr; const DataLayout &DL = IC.getDataLayout(); @@ -1113,7 +1226,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { // arrays of arbitrary size but this has a terrible impact on compile time. // The threshold here is chosen arbitrarily, maybe needs a little bit of // tuning. - if (NumElements > 1024) + if (NumElements > IC.MaxArraySizeForCombine) return false; const DataLayout &DL = IC.getDataLayout(); @@ -1268,8 +1381,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { break; } - // Don't skip over loads or things that can modify memory. - if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory()) + // Don't skip over loads, throws or things that can modify memory. + if (BBI->mayWriteToMemory() || BBI->mayReadFromMemory() || BBI->mayThrow()) break; } @@ -1392,8 +1505,8 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { } // If we find something that may be using or overwriting the stored // value, or if we run out of instructions, we can't do the xform. - if (BBI->mayReadFromMemory() || BBI->mayWriteToMemory() || - BBI == OtherBB->begin()) + if (BBI->mayReadFromMemory() || BBI->mayThrow() || + BBI->mayWriteToMemory() || BBI == OtherBB->begin()) return false; } @@ -1402,7 +1515,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { // StoreBB. for (BasicBlock::iterator I = StoreBB->begin(); &*I != &SI; ++I) { // FIXME: This should really be AA driven. - if (I->mayReadFromMemory() || I->mayWriteToMemory()) + if (I->mayReadFromMemory() || I->mayThrow() || I->mayWriteToMemory()) return false; } } @@ -1425,7 +1538,9 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { SI.getOrdering(), SI.getSynchScope()); InsertNewInstBefore(NewSI, *BBI); - NewSI->setDebugLoc(OtherStore->getDebugLoc()); + // The debug locations of the original instructions might differ; merge them. + NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(), + OtherStore->getDebugLoc())); // If the two stores had AA tags, merge them. 
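 // Illustrative sketch (not from the patch): the if/else store merge performed
 // by SimplifyStoreAtEndOfBlock, in hypothetical IR:
 //   then:                                   ; store in each predecessor
 //     store i32 %a, i32* %p
 //     br label %join
 //   else:
 //     store i32 %b, i32* %p
 //     br label %join
 // becomes a phi plus a single store in the successor:
 //   join:
 //     %v = phi i32 [ %a, %then ], [ %b, %else ]
 //     store i32 %v, i32* %p
 // with a debug location merged from the two original stores, as above.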
AAMDNodes AATags; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 45a19fb0f1f2..f1ac82057e6c 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -298,39 +298,33 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { // (X / Y) * Y = X - (X % Y) // (X / Y) * -Y = (X % Y) - X { - Value *Op1C = Op1; - BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0); - if (!BO || - (BO->getOpcode() != Instruction::UDiv && - BO->getOpcode() != Instruction::SDiv)) { - Op1C = Op0; - BO = dyn_cast<BinaryOperator>(Op1); + Value *Y = Op1; + BinaryOperator *Div = dyn_cast<BinaryOperator>(Op0); + if (!Div || (Div->getOpcode() != Instruction::UDiv && + Div->getOpcode() != Instruction::SDiv)) { + Y = Op0; + Div = dyn_cast<BinaryOperator>(Op1); } - Value *Neg = dyn_castNegVal(Op1C); - if (BO && BO->hasOneUse() && - (BO->getOperand(1) == Op1C || BO->getOperand(1) == Neg) && - (BO->getOpcode() == Instruction::UDiv || - BO->getOpcode() == Instruction::SDiv)) { - Value *Op0BO = BO->getOperand(0), *Op1BO = BO->getOperand(1); + Value *Neg = dyn_castNegVal(Y); + if (Div && Div->hasOneUse() && + (Div->getOperand(1) == Y || Div->getOperand(1) == Neg) && + (Div->getOpcode() == Instruction::UDiv || + Div->getOpcode() == Instruction::SDiv)) { + Value *X = Div->getOperand(0), *DivOp1 = Div->getOperand(1); // If the division is exact, X % Y is zero, so we end up with X or -X. - if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) - if (SDiv->isExact()) { - if (Op1BO == Op1C) - return replaceInstUsesWith(I, Op0BO); - return BinaryOperator::CreateNeg(Op0BO); - } - - Value *Rem; - if (BO->getOpcode() == Instruction::UDiv) - Rem = Builder->CreateURem(Op0BO, Op1BO); - else - Rem = Builder->CreateSRem(Op0BO, Op1BO); - Rem->takeName(BO); + if (Div->isExact()) { + if (DivOp1 == Y) + return replaceInstUsesWith(I, X); + return BinaryOperator::CreateNeg(X); + } - if (Op1BO == Op1C) - return BinaryOperator::CreateSub(Op0BO, Rem); - return BinaryOperator::CreateSub(Rem, Op0BO); + auto RemOpc = Div->getOpcode() == Instruction::UDiv ? Instruction::URem + : Instruction::SRem; + Value *Rem = Builder->CreateBinOp(RemOpc, X, DivOp1); + if (DivOp1 == Y) + return BinaryOperator::CreateSub(X, Rem); + return BinaryOperator::CreateSub(Rem, X); } } @@ -1461,16 +1455,16 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (SelectInst *SI = dyn_cast<SelectInst>(Op0I)) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - } else if (isa<PHINode>(Op0I)) { + } else if (auto *PN = dyn_cast<PHINode>(Op0I)) { using namespace llvm::PatternMatch; const APInt *Op1Int; if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() && (I.getOpcode() == Instruction::URem || !Op1Int->isMinSignedValue())) { - // FoldOpIntoPhi will speculate instructions to the end of the PHI's + // foldOpIntoPhi will speculate instructions to the end of the PHI's // predecessor blocks, so do this only if we know the srem or urem // will not fault. 
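The driving identities here are pure integer arithmetic, so they can be checked in isolation. A minimal standalone C++ check, separate from the patch; C++'s truncating division matches sdiv/srem in the non-faulting cases:

#include <cassert>
#include <cstdint>

int main() {
  // Avoid Y == 0 and the INT_MIN / -1 overflow case -- exactly the
  // situations the combiner must also exclude.
  for (int32_t X : {-40, -17, -1, 0, 5, 39}) {
    for (int32_t Y : {-7, -2, 3, 9}) {
      assert((X / Y) * Y == X - (X % Y));  // (X / Y) * Y  --> X - (X % Y)
      assert((X / Y) * -Y == (X % Y) - X); // (X / Y) * -Y --> (X % Y) - X
      if (X % Y == 0)                      // exact division: X % Y is zero,
        assert((X / Y) * Y == X);          // so the product is just X
    }
  }
}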
- if (Instruction *NV = FoldOpIntoPhi(I)) + if (Instruction *NV = foldOpIntoPhi(I, PN)) return NV; } } diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 4cbffe9533b7..85e5b6ba2dc2 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -457,8 +457,8 @@ Instruction *InstCombiner::FoldPHIArgZextsIntoPHI(PHINode &Phi) { } // The more common cases of a phi with no constant operands or just one - // variable operand are handled by FoldPHIArgOpIntoPHI() and FoldOpIntoPhi() - // respectively. FoldOpIntoPhi() wants to do the opposite transform that is + // variable operand are handled by FoldPHIArgOpIntoPHI() and foldOpIntoPhi() + // respectively. foldOpIntoPhi() wants to do the opposite transform that is // performed here. It tries to replicate a cast in the phi operand's basic // block to expose other folding opportunities. Thus, InstCombine will // infinite loop without this check. @@ -507,7 +507,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { // Be careful about transforming integer PHIs. We don't want to pessimize // the code by turning an i32 into an i1293. if (PN.getType()->isIntegerTy() && CastSrcTy->isIntegerTy()) { - if (!ShouldChangeType(PN.getType(), CastSrcTy)) + if (!shouldChangeType(PN.getType(), CastSrcTy)) return nullptr; } } else if (isa<BinaryOperator>(FirstInst) || isa<CmpInst>(FirstInst)) { diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 36644845352e..693b6c95c169 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -120,6 +120,16 @@ static Constant *getSelectFoldableConstant(Instruction *I) { /// We have (select c, TI, FI), and we know that TI and FI have the same opcode. Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI) { + // Don't break up min/max patterns. The hasOneUse checks below prevent that + // for most cases, but vector min/max with bitcasts can be transformed. If the + // one-use restrictions are eased for other patterns, we still don't want to + // obfuscate min/max. + if ((match(&SI, m_SMin(m_Value(), m_Value())) || + match(&SI, m_SMax(m_Value(), m_Value())) || + match(&SI, m_UMin(m_Value(), m_Value())) || + match(&SI, m_UMax(m_Value(), m_Value())))) + return nullptr; + // If this is a cast from the same type, merge. if (TI->getNumOperands() == 1 && TI->isCast()) { Type *FIOpndTy = FI->getOperand(0)->getType(); @@ -364,7 +374,7 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, /// into: /// %0 = tail call i32 @llvm.cttz.i32(i32 %x, i1 false) static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, - InstCombiner::BuilderTy *Builder) { + InstCombiner::BuilderTy *Builder) { ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -395,13 +405,12 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, if (match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) || match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) { IntrinsicInst *II = cast<IntrinsicInst>(Count); - IRBuilder<> Builder(II); // Explicitly clear the 'undef_on_zero' flag. 
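The flag can be cleared because the select already supplies the bit width for a zero input, which is precisely what a zero-defined cttz/ctlz returns. A standalone C++ model of the equivalence, separate from the patch, using a software cttz so the zero case is well defined:

#include <cassert>
#include <cstdint>

// cttz that is defined at zero: returns the bit width, like the intrinsic
// with the is-zero-undef argument cleared.
static unsigned cttz32(uint32_t X) {
  unsigned N = 0;
  while (N < 32 && !(X & (1u << N)))
    ++N;
  return N;
}

int main() {
  for (uint32_t X : {0u, 1u, 6u, 0x40u, 0x80000000u}) {
    // select (icmp eq X, 0), 32, (cttz X) ...
    uint32_t Sel = (X == 0) ? 32 : cttz32(X);
    // ... is just cttz with the zero case made well defined.
    assert(Sel == cttz32(X));
  }
}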
IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); Type *Ty = NewI->getArgOperand(1)->getType(); NewI->setArgOperand(1, Constant::getNullValue(Ty)); - Builder.Insert(NewI); - return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); + Builder->Insert(NewI); + return Builder->CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } return nullptr; @@ -500,18 +509,16 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { return true; } -/// If this is an integer min/max where the select's 'true' operand is a -/// constant, canonicalize that constant to the 'false' operand: -/// select (icmp Pred X, C), C, X --> select (icmp Pred' X, C), X, C +/// If this is an integer min/max (icmp + select) with a constant operand, +/// create the canonical icmp for the min/max operation and canonicalize the +/// constant to the 'false' operand of the select: +/// select (icmp Pred X, C1), C2, X --> select (icmp Pred' X, C2), X, C2 +/// Note: if C1 != C2, this will change the icmp constant to the existing +/// constant operand of the select. static Instruction * canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, InstCombiner::BuilderTy &Builder) { - // TODO: We should also canonicalize min/max when the select has a different - // constant value than the cmp constant, but we need to fix the backend first. - if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1)) || - !isa<Constant>(Sel.getTrueValue()) || - isa<Constant>(Sel.getFalseValue()) || - Cmp.getOperand(1) != Sel.getTrueValue()) + if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; // Canonicalize the compare predicate based on whether we have min or max. @@ -526,16 +533,25 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, default: return nullptr; } - // Canonicalize the constant to the right side. - if (isa<Constant>(LHS)) - std::swap(LHS, RHS); + // Is this already canonical? + if (Cmp.getOperand(0) == LHS && Cmp.getOperand(1) == RHS && + Cmp.getPredicate() == NewPred) + return nullptr; + + // Create the canonical compare and plug it into the select. + Sel.setCondition(Builder.CreateICmp(NewPred, LHS, RHS)); - Value *NewCmp = Builder.CreateICmp(NewPred, LHS, RHS); - SelectInst *NewSel = SelectInst::Create(NewCmp, LHS, RHS, "", nullptr, &Sel); + // If the select operands did not change, we're done. + if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) + return &Sel; - // We swapped the select operands, so swap the metadata too. - NewSel->swapProfMetadata(); - return NewSel; + // If we are swapping the select operands, swap the metadata too. + assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS && + "Unexpected results from matchSelectPattern"); + Sel.setTrueValue(LHS); + Sel.setFalseValue(RHS); + Sel.swapProfMetadata(); + return &Sel; } /// Visit a SelectInst that has an ICmpInst as its first operand. @@ -786,7 +802,9 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, // This transform is performance neutral if we can elide at least one xor from // the set of three operands, since we'll be tacking on an xor at the very // end. 
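The identity being exploited: bitwise-not is strictly order-reversing on signed integers, so it swaps min with max. A standalone check, separate from the patch:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int32_t A : {-50, -1, 0, 7, 123}) {
    for (int32_t B : {-9, 0, 42}) {
      // smax(~A, ~B) == ~smin(A, B), and dually. If some of ~A/~B/~C
      // already exist as values, rewriting the nested min/max this way
      // removes more xors (nots) than the single one it appends.
      assert(std::max(~A, ~B) == ~std::min(A, B));
      assert(std::min(~A, ~B) == ~std::max(A, B));
    }
  }
}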
- if (IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && + if (SelectPatternResult::isMinOrMax(SPF1) && + SelectPatternResult::isMinOrMax(SPF2) && + IsFreeOrProfitableToInvert(A, NotA, ElidesXor) && IsFreeOrProfitableToInvert(B, NotB, ElidesXor) && IsFreeOrProfitableToInvert(C, NotC, ElidesXor) && ElidesXor) { if (!NotA) @@ -1035,8 +1053,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { // If the select condition element is false, choose from the 2nd vector. Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts)); } else if (isa<UndefValue>(Elt)) { - // If the select condition element is undef, the shuffle mask is undef. - Mask.push_back(UndefValue::get(Int32Ty)); + // Undef in a select condition (choose one of the operands) does not mean + // the same thing as undef in a shuffle mask (any value is acceptable), so + // give up. + return nullptr; } else { // Bail out on a constant expression. return nullptr; @@ -1364,11 +1384,11 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } // See if we can fold the select into a phi node if the condition is a select. - if (isa<PHINode>(SI.getCondition())) + if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. if (canSelectOperandBeMappingIntoPredBlock(TrueVal, SI) && canSelectOperandBeMappingIntoPredBlock(FalseVal, SI)) - if (Instruction *NV = FoldOpIntoPhi(SI)) + if (Instruction *NV = foldOpIntoPhi(SI, PN)) return NV; if (SelectInst *TrueSI = dyn_cast<SelectInst>(TrueVal)) { @@ -1450,6 +1470,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + // If we can compute the condition, there's no need for a select. + // Like the above fold, we are attempting to reduce compile-time cost by + // putting this fold here with limitations rather than in InstSimplify. + // The motivation for this call into value tracking is to take advantage of + // the assumption cache, so make sure that is populated. + if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { + APInt KnownOne(1, 0), KnownZero(1, 0); + computeKnownBits(CondVal, KnownZero, KnownOne, 0, &SI); + if (KnownOne == 1) + return replaceInstUsesWith(SI, TrueVal); + if (KnownZero == 1) + return replaceInstUsesWith(SI, FalseVal); + } + if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, *Builder)) return BitCastSel; diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 4ff9b64ac57c..9aa679c60e47 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -22,8 +22,8 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { - assert(I.getOperand(1)->getType() == I.getOperand(0)->getType()); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + assert(Op0->getType() == Op1->getType()); // See if we can fold away this shift. if (SimplifyDemandedInstructionBits(I)) @@ -65,63 +65,60 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { } /// Return true if we can simplify two logical (either left or right) shifts -/// that have constant shift amounts. -static bool canEvaluateShiftedShift(unsigned FirstShiftAmt, - bool IsFirstShiftLeft, - Instruction *SecondShift, InstCombiner &IC, +/// that have constant shift amounts: OuterShift (InnerShift X, C1), C2. 
+static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, + Instruction *InnerShift, InstCombiner &IC, Instruction *CxtI) { - assert(SecondShift->isLogicalShift() && "Unexpected instruction type"); + assert(InnerShift->isLogicalShift() && "Unexpected instruction type"); - // We need constant shifts. - auto *SecondShiftConst = dyn_cast<ConstantInt>(SecondShift->getOperand(1)); - if (!SecondShiftConst) + // We need constant scalar or constant splat shifts. + const APInt *InnerShiftConst; + if (!match(InnerShift->getOperand(1), m_APInt(InnerShiftConst))) return false; - unsigned SecondShiftAmt = SecondShiftConst->getZExtValue(); - bool IsSecondShiftLeft = SecondShift->getOpcode() == Instruction::Shl; - - // We can always fold shl(c1) + shl(c2) -> shl(c1+c2). - // We can always fold lshr(c1) + lshr(c2) -> lshr(c1+c2). - if (IsFirstShiftLeft == IsSecondShiftLeft) + // Two logical shifts in the same direction: + // shl (shl X, C1), C2 --> shl X, C1 + C2 + // lshr (lshr X, C1), C2 --> lshr X, C1 + C2 + bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl; + if (IsInnerShl == IsOuterShl) return true; - // We can always fold lshr(c) + shl(c) -> and(c2). - // We can always fold shl(c) + lshr(c) -> and(c2). - if (FirstShiftAmt == SecondShiftAmt) + // Equal shift amounts in opposite directions become bitwise 'and': + // lshr (shl X, C), C --> and X, C' + // shl (lshr X, C), C --> and X, C' + unsigned InnerShAmt = InnerShiftConst->getZExtValue(); + if (InnerShAmt == OuterShAmt) return true; - unsigned TypeWidth = SecondShift->getType()->getScalarSizeInBits(); - // If the 2nd shift is bigger than the 1st, we can fold: - // lshr(c1) + shl(c2) -> shl(c3) + and(c4) or - // shl(c1) + lshr(c2) -> lshr(c3) + and(c4), + // lshr (shl X, C1), C2 --> and (shl X, C1 - C2), C3 + // shl (lshr X, C1), C2 --> and (lshr X, C1 - C2), C3 // but it isn't profitable unless we know the and'd out bits are already zero. - // Also check that the 2nd shift is valid (less than the type width) or we'll - // crash trying to produce the bit mask for the 'and'. - if (SecondShiftAmt > FirstShiftAmt && SecondShiftAmt < TypeWidth) { - unsigned MaskShift = IsSecondShiftLeft ? TypeWidth - SecondShiftAmt - : SecondShiftAmt - FirstShiftAmt; - APInt Mask = APInt::getLowBitsSet(TypeWidth, FirstShiftAmt) << MaskShift; - if (IC.MaskedValueIsZero(SecondShift->getOperand(0), Mask, 0, CxtI)) + // Also, check that the inner shift is valid (less than the type width) or + // we'll crash trying to produce the bit mask for the 'and'. + unsigned TypeWidth = InnerShift->getType()->getScalarSizeInBits(); + if (InnerShAmt > OuterShAmt && InnerShAmt < TypeWidth) { + unsigned MaskShift = + IsInnerShl ? TypeWidth - InnerShAmt : InnerShAmt - OuterShAmt; + APInt Mask = APInt::getLowBitsSet(TypeWidth, OuterShAmt) << MaskShift; + if (IC.MaskedValueIsZero(InnerShift->getOperand(0), Mask, 0, CxtI)) return true; } return false; } -/// See if we can compute the specified value, but shifted -/// logically to the left or right by some number of bits. This should return -/// true if the expression can be computed for the same cost as the current -/// expression tree. This is used to eliminate extraneous shifting from things -/// like: +/// See if we can compute the specified value, but shifted logically to the left +/// or right by some number of bits. This should return true if the expression +/// can be computed for the same cost as the current expression tree. 
This is
+/// used to eliminate extraneous shifting from things like:
 ///   %C = shl i128 %A, 64
 ///   %D = shl i128 %B, 96
 ///   %E = or i128 %C, %D
 ///   %F = lshr i128 %E, 64
-/// where the client will ask if E can be computed shifted right by 64-bits. If
-/// this succeeds, the GetShiftedValue function will be called to produce the
-/// value.
-static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
+/// where the client will ask if E can be computed shifted right by 64 bits. If
+/// this succeeds, getShiftedValue() will be called to produce the value.
+static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
                                InstCombiner &IC, Instruction *CxtI) {
   // We can always evaluate constants shifted.
   if (isa<Constant>(V))
@@ -165,8 +162,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
   case Instruction::Or:
   case Instruction::Xor:
     // Bitwise operators can all be evaluated shifted.
-    return CanEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
-           CanEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
+    return canEvaluateShifted(I->getOperand(0), NumBits, IsLeftShift, IC, I) &&
+           canEvaluateShifted(I->getOperand(1), NumBits, IsLeftShift, IC, I);
 
   case Instruction::Shl:
   case Instruction::LShr:
@@ -176,8 +173,8 @@ static bool CanEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift,
     SelectInst *SI = cast<SelectInst>(I);
     Value *TrueVal = SI->getTrueValue();
     Value *FalseVal = SI->getFalseValue();
-    return CanEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
-           CanEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
+    return canEvaluateShifted(TrueVal, NumBits, IsLeftShift, IC, SI) &&
+           canEvaluateShifted(FalseVal, NumBits, IsLeftShift, IC, SI);
   }
   case Instruction::PHI: {
     // We can change a phi if we can change all operands. Note that we never
     // get into trouble with cyclic PHIs here because we only consider
     // instructions with a single use.
     PHINode *PN = cast<PHINode>(I);
     for (Value *IncValue : PN->incoming_values())
-      if (!CanEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
+      if (!canEvaluateShifted(IncValue, NumBits, IsLeftShift, IC, PN))
         return false;
     return true;
   }
   }
 }
 
-/// When CanEvaluateShifted returned true for an expression,
-/// this value inserts the new computation that produces the shifted value.
-static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
+/// Fold OuterShift (InnerShift X, C1), C2.
+/// See canEvaluateShiftedShift() for the constraints on these instructions.
+static Value *foldShiftedShift(BinaryOperator *InnerShift, unsigned OuterShAmt,
+                               bool IsOuterShl,
+                               InstCombiner::BuilderTy &Builder) {
+  bool IsInnerShl = InnerShift->getOpcode() == Instruction::Shl;
+  Type *ShType = InnerShift->getType();
+  unsigned TypeWidth = ShType->getScalarSizeInBits();
+
+  // We only accept shifts-by-a-constant in canEvaluateShifted().
+  const APInt *C1;
+  match(InnerShift->getOperand(1), m_APInt(C1));
+  unsigned InnerShAmt = C1->getZExtValue();
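The three cases handled below reduce to ordinary unsigned arithmetic and can be spot-checked in isolation. A standalone 8-bit example, separate from the patch; all shift amounts stay below the type width, as canEvaluateShiftedShift() guarantees:

#include <cassert>
#include <cstdint>

int main() {
  uint8_t X = 0xB5; // 0b10110101
  // Same direction: the shift amounts add.
  assert((uint8_t)((uint8_t)(X << 2) << 3) == (uint8_t)(X << 5));
  // Opposite directions, equal amounts: a single mask.
  assert((uint8_t)((uint8_t)(X << 3) >> 3) == (X & 0x1F));
  assert((uint8_t)((uint8_t)(X >> 3) << 3) == (X & 0xF8));
  // Opposite directions, inner amount larger: one smaller shift. Valid here
  // only because the bits an 'and' would clear are already zero -- the
  // condition canEvaluateShiftedShift() checks with MaskedValueIsZero().
  uint8_t Y = 0xA0; // bits 3 and 4 are zero
  assert((uint8_t)((uint8_t)(Y >> 5) << 2) == (uint8_t)(Y >> 3));
}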
+
+  // Change the shift amount and clear the appropriate IR flags.
+  auto NewInnerShift = [&](unsigned ShAmt) {
+    InnerShift->setOperand(1, ConstantInt::get(ShType, ShAmt));
+    if (IsInnerShl) {
+      InnerShift->setHasNoUnsignedWrap(false);
+      InnerShift->setHasNoSignedWrap(false);
+    } else {
+      InnerShift->setIsExact(false);
+    }
+    return InnerShift;
+  };
+
+  // Two logical shifts in the same direction:
+  // shl (shl X, C1), C2 --> shl X, C1 + C2
+  // lshr (lshr X, C1), C2 --> lshr X, C1 + C2
+  if (IsInnerShl == IsOuterShl) {
+    // If this is an oversized composite shift, then unsigned shifts get 0.
+    if (InnerShAmt + OuterShAmt >= TypeWidth)
+      return Constant::getNullValue(ShType);
+
+    return NewInnerShift(InnerShAmt + OuterShAmt);
+  }
+
+  // Equal shift amounts in opposite directions become bitwise 'and':
+  // lshr (shl X, C), C --> and X, C'
+  // shl (lshr X, C), C --> and X, C'
+  if (InnerShAmt == OuterShAmt) {
+    APInt Mask = IsInnerShl
+                     ? APInt::getLowBitsSet(TypeWidth, TypeWidth - OuterShAmt)
+                     : APInt::getHighBitsSet(TypeWidth, TypeWidth - OuterShAmt);
+    Value *And = Builder.CreateAnd(InnerShift->getOperand(0),
+                                   ConstantInt::get(ShType, Mask));
+    if (auto *AndI = dyn_cast<Instruction>(And)) {
+      AndI->moveBefore(InnerShift);
+      AndI->takeName(InnerShift);
+    }
+    return And;
+  }
+
+  assert(InnerShAmt > OuterShAmt &&
+         "Unexpected opposite direction logical shift pair");
+
+  // In general, we would need an 'and' for this transform, but
+  // canEvaluateShiftedShift() guarantees that the masked-off bits are not used.
+  // lshr (shl X, C1), C2 --> shl X, C1 - C2
+  // shl (lshr X, C1), C2 --> lshr X, C1 - C2
+  return NewInnerShift(InnerShAmt - OuterShAmt);
+}
+
+/// When canEvaluateShifted() returns true for an expression, this function
+/// inserts the new computation that produces the shifted value.
+static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
                               InstCombiner &IC, const DataLayout &DL) {
   // We can always evaluate constants shifted.
   if (Constant *C = dyn_cast<Constant>(V)) {
@@ -220,100 +280,21 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift,
   case Instruction::Xor:
     // Bitwise operators can all be evaluated shifted.
     I->setOperand(
-        0, GetShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
+        0, getShiftedValue(I->getOperand(0), NumBits, isLeftShift, IC, DL));
     I->setOperand(
-        1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
+        1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL));
     return I;
 
-  case Instruction::Shl: {
-    BinaryOperator *BO = cast<BinaryOperator>(I);
-    unsigned TypeWidth = BO->getType()->getScalarSizeInBits();
-
-    // We only accept shifts-by-a-constant in CanEvaluateShifted.
-    ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1));
-
-    // We can always fold shl(c1)+shl(c2) -> shl(c1+c2).
-    if (isLeftShift) {
-      // If this is oversized composite shift, then unsigned shifts get 0.
-      unsigned NewShAmt = NumBits+CI->getZExtValue();
-      if (NewShAmt >= TypeWidth)
-        return Constant::getNullValue(I->getType());
-
-      BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt));
-      BO->setHasNoUnsignedWrap(false);
-      BO->setHasNoSignedWrap(false);
-      return I;
-    }
-
-    // We turn shl(c)+lshr(c) -> and(c2) if the input doesn't already have
-    // zeros.
- if (CI->getValue() == NumBits) { - APInt Mask(APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(BO->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast<Instruction>(V)) { - VI->moveBefore(BO); - VI->takeName(BO); - } - return V; - } - - // We turn shl(c1)+shr(c2) -> shl(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setHasNoUnsignedWrap(false); - BO->setHasNoSignedWrap(false); - return BO; - } - // FIXME: This is almost identical to the SHL case. Refactor both cases into - // a helper function. - case Instruction::LShr: { - BinaryOperator *BO = cast<BinaryOperator>(I); - unsigned TypeWidth = BO->getType()->getScalarSizeInBits(); - // We only accept shifts-by-a-constant in CanEvaluateShifted. - ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); - - // We can always fold lshr(c1)+lshr(c2) -> lshr(c1+c2). - if (!isLeftShift) { - // If this is oversized composite shift, then unsigned shifts get 0. - unsigned NewShAmt = NumBits+CI->getZExtValue(); - if (NewShAmt >= TypeWidth) - return Constant::getNullValue(BO->getType()); - - BO->setOperand(1, ConstantInt::get(BO->getType(), NewShAmt)); - BO->setIsExact(false); - return I; - } - - // We turn lshr(c)+shl(c) -> and(c2) if the input doesn't already have - // zeros. - if (CI->getValue() == NumBits) { - APInt Mask(APInt::getHighBitsSet(TypeWidth, TypeWidth - NumBits)); - V = IC.Builder->CreateAnd(I->getOperand(0), - ConstantInt::get(BO->getContext(), Mask)); - if (Instruction *VI = dyn_cast<Instruction>(V)) { - VI->moveBefore(I); - VI->takeName(I); - } - return V; - } - - // We turn lshr(c1)+shl(c2) -> lshr(c3)+and(c4), but only when we know that - // the and won't be needed. - assert(CI->getZExtValue() > NumBits); - BO->setOperand(1, ConstantInt::get(BO->getType(), - CI->getZExtValue() - NumBits)); - BO->setIsExact(false); - return BO; - } + case Instruction::Shl: + case Instruction::LShr: + return foldShiftedShift(cast<BinaryOperator>(I), NumBits, isLeftShift, + *(IC.Builder)); case Instruction::Select: I->setOperand( - 1, GetShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); + 1, getShiftedValue(I->getOperand(1), NumBits, isLeftShift, IC, DL)); I->setOperand( - 2, GetShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); + 2, getShiftedValue(I->getOperand(2), NumBits, isLeftShift, IC, DL)); return I; case Instruction::PHI: { // We can change a phi if we can change all operands. Note that we never @@ -321,215 +302,39 @@ static Value *GetShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // instructions with a single use. PHINode *PN = cast<PHINode>(I); for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - PN->setIncomingValue(i, GetShiftedValue(PN->getIncomingValue(i), NumBits, + PN->setIncomingValue(i, getShiftedValue(PN->getIncomingValue(i), NumBits, isLeftShift, IC, DL)); return PN; } } } -/// Try to fold (X << C1) << C2, where the shifts are some combination of -/// shl/ashr/lshr. -static Instruction * -foldShiftByConstOfShiftByConst(BinaryOperator &I, ConstantInt *COp1, - InstCombiner::BuilderTy *Builder) { - Value *Op0 = I.getOperand(0); - uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); - - // Find out if this is a shift of a shift by a constant. 
- BinaryOperator *ShiftOp = dyn_cast<BinaryOperator>(Op0); - if (ShiftOp && !ShiftOp->isShift()) - ShiftOp = nullptr; - - if (ShiftOp && isa<ConstantInt>(ShiftOp->getOperand(1))) { - - // This is a constant shift of a constant shift. Be careful about hiding - // shl instructions behind bit masks. They are used to represent multiplies - // by a constant, and it is important that simple arithmetic expressions - // are still recognizable by scalar evolution. - // - // The transforms applied to shl are very similar to the transforms applied - // to mul by constant. We can be more aggressive about optimizing right - // shifts. - // - // Combinations of right and left shifts will still be optimized in - // DAGCombine where scalar evolution no longer applies. - - ConstantInt *ShiftAmt1C = cast<ConstantInt>(ShiftOp->getOperand(1)); - uint32_t ShiftAmt1 = ShiftAmt1C->getLimitedValue(TypeBits); - uint32_t ShiftAmt2 = COp1->getLimitedValue(TypeBits); - assert(ShiftAmt2 != 0 && "Should have been simplified earlier"); - if (ShiftAmt1 == 0) - return nullptr; // Will be simplified in the future. - Value *X = ShiftOp->getOperand(0); - - IntegerType *Ty = cast<IntegerType>(I.getType()); - - // Check for (X << c1) << c2 and (X >> c1) >> c2 - if (I.getOpcode() == ShiftOp->getOpcode()) { - uint32_t AmtSum = ShiftAmt1 + ShiftAmt2; // Fold into one big shift. - // If this is an oversized composite shift, then unsigned shifts become - // zero (handled in InstSimplify) and ashr saturates. - if (AmtSum >= TypeBits) { - if (I.getOpcode() != Instruction::AShr) - return nullptr; - AmtSum = TypeBits - 1; // Saturate to 31 for i32 ashr. - } - - return BinaryOperator::Create(I.getOpcode(), X, - ConstantInt::get(Ty, AmtSum)); - } - - if (ShiftAmt1 == ShiftAmt2) { - // If we have ((X << C) >>u C), turn this into X & (-1 >>u C). - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt1)); - return BinaryOperator::CreateAnd( - X, ConstantInt::get(I.getContext(), Mask)); - } - } else if (ShiftAmt1 < ShiftAmt2) { - uint32_t ShiftDiff = ShiftAmt2 - ShiftAmt1; - - // (X >>?,exact C1) << C2 --> X << (C2-C1) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { - assert(ShiftOp->getOpcode() == Instruction::LShr || - ShiftOp->getOpcode() == Instruction::AShr); - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); - return NewShl; - } - - // (X << C1) >>u C2 --> X >>u (C2-C1) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - // (X <<nuw C1) >>u C2 --> X >>u (C2-C1) - if (ShiftOp->hasNoUnsignedWrap()) { - BinaryOperator *NewLShr = - BinaryOperator::Create(Instruction::LShr, X, ShiftDiffCst); - NewLShr->setIsExact(I.isExact()); - return NewLShr; - } - Value *Shift = Builder->CreateLShr(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd( - Shift, ConstantInt::get(I.getContext(), Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. 
However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X >>s (C2-C1) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewAShr = - BinaryOperator::Create(Instruction::AShr, X, ShiftDiffCst); - NewAShr->setIsExact(I.isExact()); - return NewAShr; - } - } - } else { - assert(ShiftAmt2 < ShiftAmt1); - uint32_t ShiftDiff = ShiftAmt1 - ShiftAmt2; - - // (X >>?exact C1) << C2 --> X >>?exact (C1-C2) - // The inexact version is deferred to DAGCombine so we don't hide shl - // behind a bit mask. - if (I.getOpcode() == Instruction::Shl && - ShiftOp->getOpcode() != Instruction::Shl && ShiftOp->isExact()) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShr = - BinaryOperator::Create(ShiftOp->getOpcode(), X, ShiftDiffCst); - NewShr->setIsExact(true); - return NewShr; - } - - // (X << C1) >>u C2 --> X << (C1-C2) & (-1 >> C2) - if (I.getOpcode() == Instruction::LShr && - ShiftOp->getOpcode() == Instruction::Shl) { - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - if (ShiftOp->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1-C2) - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoUnsignedWrap(true); - return NewShl; - } - Value *Shift = Builder->CreateShl(X, ShiftDiffCst); - - APInt Mask(APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt2)); - return BinaryOperator::CreateAnd( - Shift, ConstantInt::get(I.getContext(), Mask)); - } - - // We can't handle (X << C1) >>s C2, it shifts arbitrary bits in. However, - // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. - if (I.getOpcode() == Instruction::AShr && - ShiftOp->getOpcode() == Instruction::Shl) { - if (ShiftOp->hasNoSignedWrap()) { - // (X <<nsw C1) >>s C2 --> X <<nsw (C1-C2) - ConstantInt *ShiftDiffCst = ConstantInt::get(Ty, ShiftDiff); - BinaryOperator *NewShl = - BinaryOperator::Create(Instruction::Shl, X, ShiftDiffCst); - NewShl->setHasNoSignedWrap(true); - return NewShl; - } - } - } - } - - return nullptr; -} - Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; - ConstantInt *COp1 = nullptr; - if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(Op1)) - COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); - else if (ConstantVector *CV = dyn_cast<ConstantVector>(Op1)) - COp1 = dyn_cast_or_null<ConstantInt>(CV->getSplatValue()); - else - COp1 = dyn_cast<ConstantInt>(Op1); - - if (!COp1) + const APInt *Op1C; + if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. 
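The i128 example in the function comment above shrinks to any width. A standalone 32-bit analog, separate from the patch, of what canEvaluateShifted() and getShiftedValue() accomplish together: the outer lshr is absorbed into the operands of the or, so the wide value is never actually shifted.

#include <cassert>
#include <cstdint>

int main() {
  uint32_t A = 0xAB, B = 0xCD; // both known to fit in their fields
  uint32_t C = A << 16;        // %C = shl i32 %A, 16
  uint32_t D = B << 24;        // %D = shl i32 %B, 24
  uint32_t E = C | D;          // %E = or i32 %C, %D
  uint32_t F = E >> 16;        // %F = lshr i32 %E, 16
  // E "evaluated shifted right by 16" directly from A and B:
  assert(F == (A | (B << 8)));
}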
if (I.getOpcode() != Instruction::AShr && - CanEvaluateShifted(Op0, COp1->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { DEBUG(dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I <<"\n"); return replaceInstUsesWith( - I, GetShiftedValue(Op0, COp1->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole // purpose is to compute bits we don't care about. - uint32_t TypeBits = Op0->getType()->getScalarSizeInBits(); + unsigned TypeBits = Op0->getType()->getScalarSizeInBits(); - assert(!COp1->uge(TypeBits) && + assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); - // ((X*C1) << C2) == (X * (C1 << C2)) - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op0)) - if (BO->getOpcode() == Instruction::Mul && isLeftShift) - if (Constant *BOOp = dyn_cast<Constant>(BO->getOperand(1))) - return BinaryOperator::CreateMul(BO->getOperand(0), - ConstantExpr::getShl(BOOp, Op1)); - if (Instruction *FoldedShift = foldOpWithConstantIntoOperand(I)) return FoldedShift; @@ -544,7 +349,8 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, if (TrOp && I.isLogicalShift() && TrOp->isShift() && isa<ConstantInt>(TrOp->getOperand(1))) { // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(COp1, TrOp->getType()); + Constant *ShAmt = + ConstantExpr::getZExt(cast<Constant>(Op1), TrOp->getType()); // (shift2 (shift1 & 0x00FF), c2) Value *NSh = Builder->CreateBinOp(I.getOpcode(), TrOp, ShAmt,I.getName()); @@ -561,10 +367,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // shift. We know that it is a logical shift by a constant, so adjust the // mask as appropriate. 
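Why the mask must be shifted along with the value: moving a logical shift across a truncate changes which wide bits land in the narrow result. A standalone sketch of the lshr flavor, separate from the patch, where the mask genuinely matters:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0x1234u, 0xFFFFu, 0x00FFu, 0x8001u}) {
    // lshr (trunc X to i8), 3 ...
    uint8_t Narrow = (uint8_t)((uint8_t)X >> 3);
    // ... equals trunc (lshr X, 3) with the mask shifted right alongside it,
    // clearing the wide bits that spill into the narrow result.
    uint8_t Wide = (uint8_t)(X >> 3) & (uint8_t)(0xFF >> 3);
    assert(Narrow == Wide);
  }
}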
if (I.getOpcode() == Instruction::Shl) - MaskV <<= COp1->getZExtValue(); + MaskV <<= Op1C->getZExtValue(); else { assert(I.getOpcode() == Instruction::LShr && "Unknown logical shift"); - MaskV = MaskV.lshr(COp1->getZExtValue()); + MaskV = MaskV.lshr(Op1C->getZExtValue()); } // shift1 & 0x00FF @@ -598,7 +404,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), YS, V1, Op0BO->getOperand(1)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -634,7 +440,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // (X + (Y << C)) Value *X = Builder->CreateBinOp(Op0BO->getOpcode(), V1, YS, Op0BO->getOperand(0)->getName()); - uint32_t Op1Val = COp1->getLimitedValue(TypeBits); + unsigned Op1Val = Op1C->getLimitedValue(TypeBits); APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); @@ -705,9 +511,6 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, } } - if (Instruction *Folded = foldShiftByConstOfShiftByConst(I, COp1, Builder)) - return Folded; - return nullptr; } @@ -715,59 +518,97 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = - SimplifyShlInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), - I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), + I.hasNoUnsignedWrap(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *V = commonShiftTransforms(I)) return V; - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(I.getOperand(1))) { - unsigned ShAmt = Op1C->getZExtValue(); - - // Turn: - // %zext = zext i32 %V to i64 - // %res = shl i64 %V, 8 - // - // Into: - // %shl = shl i32 %V, 8 - // %res = zext i32 %shl to i64 - // - // This is only valid if %V would have zeros shifted out. - if (auto *ZI = dyn_cast<ZExtInst>(I.getOperand(0))) { - unsigned SrcBitWidth = ZI->getSrcTy()->getScalarSizeInBits(); - if (ShAmt < SrcBitWidth && - MaskedValueIsZero(ZI->getOperand(0), - APInt::getHighBitsSet(SrcBitWidth, ShAmt), 0, &I)) { - auto *Shl = Builder->CreateShl(ZI->getOperand(0), ShAmt); - return new ZExtInst(Shl, I.getType()); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + + // shl (zext X), ShAmt --> zext (shl X, ShAmt) + // This is only valid if X would have zeros shifted out. + Value *X; + if (match(Op0, m_ZExt(m_Value(X)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + if (ShAmt < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) + return new ZExtInst(Builder->CreateShl(X, ShAmt), Ty); + } + + // (X >>u C) << C --> X & (-1 << C) + if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) { + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + + // Be careful about hiding shl instructions behind bit masks. 
They are used + // to represent multiplies by a constant, and it is important that simple + // arithmetic expressions are still recognizable by scalar evolution. + // The inexact versions are deferred to DAGCombine, so we don't hide shl + // behind a bit mask. + const APInt *ShOp1; + if (match(Op0, m_CombineOr(m_Exact(m_LShr(m_Value(X), m_APInt(ShOp1))), + m_Exact(m_AShr(m_Value(X), m_APInt(ShOp1)))))) { + unsigned ShrAmt = ShOp1->getZExtValue(); + if (ShrAmt < ShAmt) { + // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); + return NewShl; } + if (ShrAmt > ShAmt) { + // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + auto *NewShr = BinaryOperator::Create( + cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); + NewShr->setIsExact(true); + return NewShr; + } + } + + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized shifts are simplified to zero in InstSimplify. + if (AmtSum < BitWidth) + // (X << C1) << C2 --> X << (C1 + C2) + return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum)); } // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(I.getOperand(0), - APInt::getHighBitsSet(Op1C->getBitWidth(), ShAmt), 0, - &I)) { + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { I.setHasNoUnsignedWrap(); return &I; } - // If the shifted out value is all signbits, this is a NSW shift. - if (!I.hasNoSignedWrap() && - ComputeNumSignBits(I.getOperand(0), 0, &I) > ShAmt) { + // If the shifted-out value is all signbits, then this is a NSW shift. 
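The ComputeNumSignBits() condition can be made concrete. A standalone illustration, separate from the patch; the narrowing casts assume ordinary two's-complement wraparound, which C++ only guarantees from C++20 on:

#include <cassert>
#include <cstdint>

int main() {
  int32_t X = -10; // 0xFFFFFFF6: 28 leading sign bits
  // Shifting by fewer bits than there are sign bits only shifts out
  // redundant copies of the sign, so the result is exactly X * 2^ShAmt.
  assert((int32_t)((uint32_t)X << 27) == (int64_t)X << 27); // 27 < 28: nsw
  // One more bit and the true product no longer fits in 32 bits.
  assert((int32_t)((uint32_t)X << 28) != (int64_t)X << 28); // 28 >= 28: wraps
}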
+ if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { I.setHasNoSignedWrap(); return &I; } } - // (C1 << A) << C2 -> (C1 << C2) << A - Constant *C1, *C2; - Value *A; - if (match(I.getOperand(0), m_OneUse(m_Shl(m_Constant(C1), m_Value(A)))) && - match(I.getOperand(1), m_Constant(C2))) - return BinaryOperator::CreateShl(ConstantExpr::getShl(C1, C2), A); + Constant *C1; + if (match(Op1, m_Constant(C1))) { + Constant *C2; + Value *X; + // (C2 << X) << C1 --> (C2 << C1) << X + if (match(Op0, m_OneUse(m_Shl(m_Constant(C2), m_Value(X))))) + return BinaryOperator::CreateShl(ConstantExpr::getShl(C2, C1), X); + + // (X * C2) << C1 --> X * (C2 << C1) + if (match(Op0, m_Mul(m_Value(X), m_Constant(C2)))) + return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1)); + } return nullptr; } @@ -776,43 +617,83 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyLShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, &TLI, &DT, &AC)) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - unsigned ShAmt = Op1C->getZExtValue(); - - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { - unsigned BitWidth = Op0->getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + auto *II = dyn_cast<IntrinsicInst>(Op0); + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && + (II->getIntrinsicID() == Intrinsic::ctlz || + II->getIntrinsicID() == Intrinsic::cttz || + II->getIntrinsicID() == Intrinsic::ctpop)) { // ctlz.i32(x)>>5 --> zext(x == 0) // cttz.i32(x)>>5 --> zext(x == 0) // ctpop.i32(x)>>5 --> zext(x == -1) - if ((II->getIntrinsicID() == Intrinsic::ctlz || - II->getIntrinsicID() == Intrinsic::cttz || - II->getIntrinsicID() == Intrinsic::ctpop) && - isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt) { - bool isCtPop = II->getIntrinsicID() == Intrinsic::ctpop; - Constant *RHS = ConstantInt::getSigned(Op0->getType(), isCtPop ? -1:0); - Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS); - return new ZExtInst(Cmp, II->getType()); + bool IsPop = II->getIntrinsicID() == Intrinsic::ctpop; + Constant *RHS = ConstantInt::getSigned(Ty, IsPop ? 
-1 : 0); + Value *Cmp = Builder->CreateICmpEQ(II->getArgOperand(0), RHS); + return new ZExtInst(Cmp, Ty); + } + + Value *X; + const APInt *ShOp1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1)))) { + unsigned ShlAmt = ShOp1->getZExtValue(); + if (ShlAmt < ShAmt) { + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) + auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); + NewLShr->setIsExact(I.isExact()); + return NewLShr; + } + // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) + Value *NewLShr = Builder->CreateLShr(X, ShiftDiff, "", I.isExact()); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); } + if (ShlAmt > ShAmt) { + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { + // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) + auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); + NewShl->setHasNoUnsignedWrap(true); + return NewShl; + } + // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) + Value *NewShl = Builder->CreateShl(X, ShiftDiff); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); + } + assert(ShlAmt == ShAmt); + // (X << C) >>u C --> X & (-1 >>u C) + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); + } + + if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized shifts are simplified to zero in InstSimplify. + if (AmtSum < BitWidth) + // (X >>u C1) >>u C2 --> X >>u (C1 + C2) + return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); } // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)){ + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { I.setIsExact(); return &I; } } - return nullptr; } @@ -820,48 +701,66 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); - if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(), - DL, &TLI, &DT, &AC)) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), DL, &TLI, &DT, &AC)) return replaceInstUsesWith(I, V); if (Instruction *R = commonShiftTransforms(I)) return R; - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { - unsigned ShAmt = Op1C->getZExtValue(); + Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); + const APInt *ShAmtAPInt; + if (match(Op1, m_APInt(ShAmtAPInt))) { + unsigned ShAmt = ShAmtAPInt->getZExtValue(); - // If the input is a SHL by the same constant (ashr (shl X, C), C), then we - // have a sign-extend idiom. + // If the shift amount equals the difference in width of the destination + // and source scalar types: + // ashr (shl (zext X), C), C --> sext X Value *X; - if (match(Op0, m_Shl(m_Value(X), m_Specific(Op1)))) { - // If the input is an extension from the shifted amount value, e.g. - // %x = zext i8 %A to i32 - // %y = shl i32 %x, 24 - // %z = ashr %y, 24 - // then turn this into "z = sext i8 A to i32". 
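The idiom in the deleted comment, and the generalized form added above (ashr (shl (zext X), C), C --> sext X), can be reproduced with fixed-width C++ types. A standalone check, separate from the patch; right-shifting a negative value is implementation-defined before C++20 but is the sign-propagating shift on mainstream targets:

#include <cassert>
#include <cstdint>

int main() {
  for (int8_t A : {(int8_t)-128, (int8_t)-10, (int8_t)0, (int8_t)57}) {
    uint32_t X = (uint8_t)A;        // %x = zext i8 %A to i32
    int32_t Y = (int32_t)(X << 24); // %y = shl i32 %x, 24
    int32_t Z = Y >> 24;            // %z = ashr i32 %y, 24
    assert(Z == A);                 // %z == sext i8 %A to i32
  }
}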
- if (ZExtInst *ZI = dyn_cast<ZExtInst>(X)) { - uint32_t SrcBits = ZI->getOperand(0)->getType()->getScalarSizeInBits(); - uint32_t DestBits = ZI->getType()->getScalarSizeInBits(); - if (Op1C->getZExtValue() == DestBits-SrcBits) - return new SExtInst(ZI->getOperand(0), ZI->getType()); + if (match(Op0, m_Shl(m_ZExt(m_Value(X)), m_Specific(Op1))) && + ShAmt == BitWidth - X->getType()->getScalarSizeInBits()) + return new SExtInst(X, Ty); + + // We can't handle (X << C1) >>s C2. It shifts arbitrary bits in. However, + // we can handle (X <<nsw C1) >>s C2 since it only shifts in sign bits. + const APInt *ShOp1; + if (match(Op0, m_NSWShl(m_Value(X), m_APInt(ShOp1)))) { + unsigned ShlAmt = ShOp1->getZExtValue(); + if (ShlAmt < ShAmt) { + // (X <<nsw C1) >>s C2 --> X >>s (C2 - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + auto *NewAShr = BinaryOperator::CreateAShr(X, ShiftDiff); + NewAShr->setIsExact(I.isExact()); + return NewAShr; } + if (ShlAmt > ShAmt) { + // (X <<nsw C1) >>s C2 --> X <<nsw (C1 - C2) + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + auto *NewShl = BinaryOperator::Create(Instruction::Shl, X, ShiftDiff); + NewShl->setHasNoSignedWrap(true); + return NewShl; + } + } + + if (match(Op0, m_AShr(m_Value(X), m_APInt(ShOp1)))) { + unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // Oversized arithmetic shifts replicate the sign bit. + AmtSum = std::min(AmtSum, BitWidth - 1); + // (X >>s C1) >>s C2 --> X >>s (C1 + C2) + return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); } // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(Op1C->getBitWidth(), ShAmt), - 0, &I)) { + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { I.setIsExact(); return &I; } } // See if we can turn a signed shr into an unsigned shr. - if (MaskedValueIsZero(Op0, - APInt::getSignBit(I.getType()->getScalarSizeInBits()), - 0, &I)) + if (MaskedValueIsZero(Op0, APInt::getSignBit(BitWidth), 0, &I)) return BinaryOperator::CreateLShr(Op0, Op1); return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 8b930bd95dfe..4e6f02058d83 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -30,18 +30,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, assert(I && "No instruction?"); assert(OpNo < I->getNumOperands() && "Operand index too large"); - // If the operand is not a constant integer, nothing to do. - ConstantInt *OpC = dyn_cast<ConstantInt>(I->getOperand(OpNo)); - if (!OpC) return false; + // The operand must be a constant integer or splat integer. + Value *Op = I->getOperand(OpNo); + const APInt *C; + if (!match(Op, m_APInt(C))) + return false; // If there are no bits set that aren't demanded, nothing to do. - Demanded = Demanded.zextOrTrunc(OpC->getValue().getBitWidth()); - if ((~Demanded & OpC->getValue()) == 0) + Demanded = Demanded.zextOrTrunc(C->getBitWidth()); + if ((~Demanded & *C) == 0) return false; // This instruction is producing bits that are not demanded. Shrink the RHS. 
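What shrinking means on a concrete AND, as a standalone sketch separate from the patch: if only the low nibble of (X & 0x3C) is demanded, the constant's bits outside the demanded mask do no work and can be cleared without changing any demanded result bit.

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Demanded = 0x0F;       // the caller only uses these bits
  const uint32_t C = 0x3C;              // original constant
  const uint32_t Shrunk = C & Demanded; // 0x0C: the shrunk constant
  for (uint32_t X : {0u, 0x5Au, 0xFFu, 0x1234u})
    assert((Demanded & (X & C)) == (Demanded & (X & Shrunk)));
}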
- Demanded &= OpC->getValue(); - I->setOperand(OpNo, ConstantInt::get(OpC->getType(), Demanded)); + Demanded &= *C; + I->setOperand(OpNo, ConstantInt::get(Op->getType(), Demanded)); return true; } @@ -66,12 +68,13 @@ bool InstCombiner::SimplifyDemandedInstructionBits(Instruction &Inst) { /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. -bool InstCombiner::SimplifyDemandedBits(Use &U, const APInt &DemandedMask, +bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, + const APInt &DemandedMask, APInt &KnownZero, APInt &KnownOne, unsigned Depth) { - auto *UserI = dyn_cast<Instruction>(U.getUser()); + Use &U = I->getOperandUse(OpNo); Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, KnownZero, - KnownOne, Depth, UserI); + KnownOne, Depth, I); if (!NewVal) return false; U = NewVal; return true; @@ -114,9 +117,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownOne.getBitWidth() == BitWidth && "Value *V, DemandedMask, KnownZero and KnownOne " "must have same BitWidth"); - if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { - // We know all of the bits for a constant! - KnownOne = CI->getValue() & DemandedMask; + const APInt *C; + if (match(V, m_APInt(C))) { + // We know all of the bits for a scalar constant or a splat vector constant! + KnownOne = *C & DemandedMask; KnownZero = ~KnownOne & DemandedMask; return nullptr; } @@ -138,9 +142,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 6) // Limit search depth. return nullptr; - APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); - APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); - Instruction *I = dyn_cast<Instruction>(V); if (!I) { computeKnownBits(V, KnownZero, KnownOne, Depth, CxtI); @@ -151,107 +152,43 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // we can't do any simplifications of the operands, because DemandedMask // only reflects the bits demanded by *one* of the users. if (Depth != 0 && !I->hasOneUse()) { - // Despite the fact that we can't simplify this instruction in all User's - // context, we can at least compute the knownzero/knownone bits, and we can - // do simplifications that apply to *just* the one user if we know that - // this instruction has a simpler value in that context. - if (I->getOpcode() == Instruction::And) { - // If either the LHS or the RHS are Zero, the result is zero. - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known 1 on one side, return the other. - // These bits cannot contribute to the result of the 'and' in this - // context. - if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == - (DemandedMask & ~LHSKnownZero)) - return I->getOperand(0); - if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == - (DemandedMask & ~RHSKnownZero)) - return I->getOperand(1); - - // If all of the demanded bits in the inputs are known zeros, return zero. - if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) - return Constant::getNullValue(VTy); - - } else if (I->getOpcode() == Instruction::Or) { - // We can simplify (X|Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - // If either the LHS or the RHS are One, the result is One. 
- computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known zero on one side, return the - // other. These bits cannot contribute to the result of the 'or' in this - // context. - if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == - (DemandedMask & ~LHSKnownOne)) - return I->getOperand(0); - if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == - (DemandedMask & ~RHSKnownOne)) - return I->getOperand(1); - - // If all of the potentially set bits on one side are known to be set on - // the other side, just use the 'other' side. - if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == - (DemandedMask & (~RHSKnownZero))) - return I->getOperand(0); - if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == - (DemandedMask & (~LHSKnownZero))) - return I->getOperand(1); - } else if (I->getOpcode() == Instruction::Xor) { - // We can simplify (X^Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, - CxtI); - computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, - CxtI); - - // If all of the demanded bits are known zero on one side, return the - // other. - if ((DemandedMask & RHSKnownZero) == DemandedMask) - return I->getOperand(0); - if ((DemandedMask & LHSKnownZero) == DemandedMask) - return I->getOperand(1); - } - - // Compute the KnownZero/KnownOne bits to simplify things downstream. - computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); - return nullptr; + return SimplifyMultipleUseDemandedBits(I, DemandedMask, KnownZero, KnownOne, + Depth, CxtI); } + APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + // If this is the root being simplified, allow it to have multiple uses, // just set the DemandedMask to all bits so that we can try to simplify the // operands. This allows visitTruncInst (for example) to simplify the // operand of a trunc without duplicating all the logic below. if (Depth == 0 && !V->hasOneUse()) - DemandedMask = APInt::getAllOnesValue(BitWidth); + DemandedMask.setAllBits(); switch (I->getOpcode()) { default: computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); break; - case Instruction::And: + case Instruction::And: { // If either the LHS or the RHS are Zero, the result is zero. - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownZero, - LHSKnownZero, LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, + Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownZero, LHSKnownZero, + LHSKnownOne, Depth + 1)) return I; assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // Output known-0 are known to be clear if zero in either the LHS | RHS. + APInt IKnownZero = RHSKnownZero | LHSKnownZero; + // Output known-1 bits are only known if set in both the LHS & RHS. + APInt IKnownOne = RHSKnownOne & LHSKnownOne; + // If the client is only demanding bits that we know, return the known // constant. 
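The IKnownZero/IKnownOne formulas above are the standard known-bits algebra. A standalone model, separate from the patch, exhaustively checking one pattern of unknown bits:

#include <cassert>
#include <cstdint>

struct Known {
  uint32_t Zero, One; // disjoint masks: bits known 0, bits known 1
};

int main() {
  Known X{0x4, 0x9}; // X = 0b10?1 (bit 1 unknown)
  Known Y{0x4, 0x3}; // Y = 0b?011 (bit 3 unknown)
  // and: known-0 if either side is known 0; known-1 only if both are.
  Known And{X.Zero | Y.Zero, X.One & Y.One};
  // or: known-0 only if both sides are; known-1 if either side is.
  Known Or{X.Zero & Y.Zero, X.One | Y.One};
  for (uint32_t XV : {0x9u, 0xBu}) {   // the two values matching X
    for (uint32_t YV : {0x3u, 0xBu}) { // the two values matching Y
      assert(((XV & YV) & And.Zero) == 0 && ((XV & YV) & And.One) == And.One);
      assert(((XV | YV) & Or.Zero) == 0 && ((XV | YV) & Or.One) == Or.One);
    }
  }
}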
- if ((DemandedMask & ((RHSKnownZero | LHSKnownZero)| - (RHSKnownOne & LHSKnownOne))) == DemandedMask) - return Constant::getIntegerValue(VTy, RHSKnownOne & LHSKnownOne); + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(VTy, IKnownOne); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and'. @@ -262,34 +199,33 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, (DemandedMask & ~RHSKnownZero)) return I->getOperand(1); - // If all of the demanded bits in the inputs are known zeros, return zero. - if ((DemandedMask & (RHSKnownZero|LHSKnownZero)) == DemandedMask) - return Constant::getNullValue(VTy); - // If the RHS is a constant, see if we can simplify it. if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnownZero)) return I; - // Output known-1 bits are only known if set in both the LHS & RHS. - KnownOne = RHSKnownOne & LHSKnownOne; - // Output known-0 are known to be clear if zero in either the LHS | RHS. - KnownZero = RHSKnownZero | LHSKnownZero; + KnownZero = std::move(IKnownZero); + KnownOne = std::move(IKnownOne); break; - case Instruction::Or: + } + case Instruction::Or: { // If either the LHS or the RHS are One, the result is One. - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask & ~RHSKnownOne, - LHSKnownZero, LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, + Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnownOne, LHSKnownZero, + LHSKnownOne, Depth + 1)) return I; assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); + // Output known-0 bits are only known if clear in both the LHS & RHS. + APInt IKnownZero = RHSKnownZero & LHSKnownZero; + // Output known-1 are known to be set if set in either the LHS | RHS. + APInt IKnownOne = RHSKnownOne | LHSKnownOne; + // If the client is only demanding bits that we know, return the known // constant. - if ((DemandedMask & ((RHSKnownZero & LHSKnownZero)| - (RHSKnownOne | LHSKnownOne))) == DemandedMask) - return Constant::getIntegerValue(VTy, RHSKnownOne | LHSKnownOne); + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(VTy, IKnownOne); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. @@ -313,16 +249,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (ShrinkDemandedConstant(I, 1, DemandedMask)) return I; - // Output known-0 bits are only known if clear in both the LHS & RHS. - KnownZero = RHSKnownZero & LHSKnownZero; - // Output known-1 are known to be set if set in either the LHS | RHS. 
- KnownOne = RHSKnownOne | LHSKnownOne; + KnownZero = std::move(IKnownZero); + KnownOne = std::move(IKnownOne); break; + } case Instruction::Xor: { - if (SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, LHSKnownZero, - LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnownZero, RHSKnownOne, + Depth + 1) || + SimplifyDemandedBits(I, 0, DemandedMask, LHSKnownZero, LHSKnownOne, + Depth + 1)) return I; assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); @@ -400,9 +335,9 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } // Output known-0 bits are known if clear or set in both the LHS & RHS. - KnownZero= (RHSKnownZero & LHSKnownZero) | (RHSKnownOne & LHSKnownOne); + KnownZero = std::move(IKnownZero); // Output known-1 are known to be set if set in only one of the LHS, RHS. - KnownOne = (RHSKnownZero & LHSKnownOne) | (RHSKnownOne & LHSKnownZero); + KnownOne = std::move(IKnownOne); break; } case Instruction::Select: @@ -412,10 +347,10 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (matchSelectPattern(I, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; - if (SimplifyDemandedBits(I->getOperandUse(2), DemandedMask, RHSKnownZero, - RHSKnownOne, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(1), DemandedMask, LHSKnownZero, - LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 2, DemandedMask, RHSKnownZero, RHSKnownOne, + Depth + 1) || + SimplifyDemandedBits(I, 1, DemandedMask, LHSKnownZero, LHSKnownOne, + Depth + 1)) return I; assert(!(RHSKnownZero & RHSKnownOne) && "Bits known to be one AND zero?"); assert(!(LHSKnownZero & LHSKnownOne) && "Bits known to be one AND zero?"); @@ -434,8 +369,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, DemandedMask = DemandedMask.zext(truncBf); KnownZero = KnownZero.zext(truncBf); KnownOne = KnownOne.zext(truncBf); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, + Depth + 1)) return I; DemandedMask = DemandedMask.trunc(BitWidth); KnownZero = KnownZero.trunc(BitWidth); @@ -460,8 +395,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Don't touch a vector-to-scalar bitcast. return nullptr; - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, + Depth + 1)) return I; assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); break; @@ -472,15 +407,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, DemandedMask = DemandedMask.trunc(SrcBitWidth); KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMask, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMask, KnownZero, KnownOne, + Depth + 1)) return I; DemandedMask = DemandedMask.zext(BitWidth); KnownZero = KnownZero.zext(BitWidth); KnownOne = KnownOne.zext(BitWidth); assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); // The top bits are known to be zero. 
- KnownZero |= APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth); + KnownZero.setBitsFrom(SrcBitWidth); break; } case Instruction::SExt: { @@ -490,7 +425,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt InputDemandedBits = DemandedMask & APInt::getLowBitsSet(BitWidth, SrcBitWidth); - APInt NewBits(APInt::getHighBitsSet(BitWidth, BitWidth - SrcBitWidth)); + APInt NewBits(APInt::getBitsSetFrom(BitWidth, SrcBitWidth)); // If any of the sign extended bits are demanded, we know that the sign // bit is demanded. if ((NewBits & DemandedMask) != 0) @@ -499,8 +434,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, InputDemandedBits = InputDemandedBits.trunc(SrcBitWidth); KnownZero = KnownZero.trunc(SrcBitWidth); KnownOne = KnownOne.trunc(SrcBitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), InputDemandedBits, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, InputDemandedBits, KnownZero, KnownOne, + Depth + 1)) return I; InputDemandedBits = InputDemandedBits.zext(BitWidth); KnownZero = KnownZero.zext(BitWidth); @@ -530,11 +465,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Right fill the mask of bits for this ADD/SUB to demand the most // significant bit and all those below it. APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedFromOps, - LHSKnownZero, LHSKnownOne, Depth + 1) || + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnownZero, LHSKnownOne, + Depth + 1) || ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I->getOperandUse(1), DemandedFromOps, - LHSKnownZero, LHSKnownOne, Depth + 1)) { + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnownZero, RHSKnownOne, + Depth + 1)) { // Disable the nsw and nuw flags here: We can no longer guarantee that // we won't wrap after simplification. Removing the nsw/nuw flags is // legal here because the top bit is not demanded. @@ -543,6 +479,15 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, BinOP.setHasNoUnsignedWrap(false); return I; } + + // If we are known to be adding/subtracting zeros to every bit below + // the highest demanded bit, we just return the other side. + if ((DemandedFromOps & RHSKnownZero) == DemandedFromOps) + return I->getOperand(0); + // We can't do this with the LHS for subtraction. + if (I->getOpcode() == Instruction::Add && + (DemandedFromOps & LHSKnownZero) == DemandedFromOps) + return I->getOperand(1); } // Otherwise just hand the add/sub off to computeKnownBits to fill in @@ -569,19 +514,19 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the shift is NUW/NSW, then it does demand the high bits. ShlOperator *IOp = cast<ShlOperator>(I); if (IOp->hasNoSignedWrap()) - DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt+1); + DemandedMaskIn.setHighBits(ShiftAmt+1); else if (IOp->hasNoUnsignedWrap()) - DemandedMaskIn |= APInt::getHighBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setHighBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, + Depth + 1)) return I; assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); KnownZero <<= ShiftAmt; KnownOne <<= ShiftAmt; // low bits known zero. 
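The early-out added to the Add/Sub case above has a compact statement on plain masks. A rough sketch with illustrative names (not from the patch), assuming DemandedFromOps already covers the highest demanded bit and everything below it:

    #include <cstdint>

    // If every demanded low bit of RHS is known zero, then on those bits
    // X + RHS == X and X - RHS == X, so the whole add/sub can be replaced
    // by X. The symmetric check on LHS is valid only for add: 0 - X != X.
    bool addSubIgnoresRHS(uint64_t DemandedFromOps, uint64_t RHSKnownZero) {
      return (DemandedFromOps & RHSKnownZero) == DemandedFromOps;
    }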
if (ShiftAmt) - KnownZero |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + KnownZero.setLowBits(ShiftAmt); } break; case Instruction::LShr: @@ -595,19 +540,16 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If the shift is exact, then it does demand the low bits (and knows that // they are zero). if (cast<LShrOperator>(I)->isExact()) - DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setLowBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, + Depth + 1)) return I; assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); - KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); - KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); - if (ShiftAmt) { - // Compute the new bits that are at the top now. - APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); - KnownZero |= HighBits; // high bits known zero. - } + KnownZero = KnownZero.lshr(ShiftAmt); + KnownOne = KnownOne.lshr(ShiftAmt); + if (ShiftAmt) + KnownZero.setHighBits(ShiftAmt); // high bits known zero. } break; case Instruction::AShr: @@ -635,26 +577,26 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If any of the "high bits" are demanded, we should set the sign bit as // demanded. if (DemandedMask.countLeadingZeros() <= ShiftAmt) - DemandedMaskIn.setBit(BitWidth-1); + DemandedMaskIn.setSignBit(); // If the shift is exact, then it does demand the low bits (and knows that // they are zero). if (cast<AShrOperator>(I)->isExact()) - DemandedMaskIn |= APInt::getLowBitsSet(BitWidth, ShiftAmt); + DemandedMaskIn.setLowBits(ShiftAmt); - if (SimplifyDemandedBits(I->getOperandUse(0), DemandedMaskIn, KnownZero, - KnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, KnownZero, KnownOne, + Depth + 1)) return I; assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); // Compute the new bits that are at the top now. APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); - KnownZero = APIntOps::lshr(KnownZero, ShiftAmt); - KnownOne = APIntOps::lshr(KnownOne, ShiftAmt); + KnownZero = KnownZero.lshr(ShiftAmt); + KnownOne = KnownOne.lshr(ShiftAmt); // Handle the sign bits. APInt SignBit(APInt::getSignBit(BitWidth)); // Adjust to where it is now in the mask. - SignBit = APIntOps::lshr(SignBit, ShiftAmt); + SignBit = SignBit.lshr(ShiftAmt); // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. @@ -683,8 +625,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt LowBits = RA - 1; APInt Mask2 = LowBits | APInt::getSignBit(BitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), Mask2, LHSKnownZero, - LHSKnownOne, Depth + 1)) + if (SimplifyDemandedBits(I, 0, Mask2, LHSKnownZero, LHSKnownOne, + Depth + 1)) return I; // The low bits of LHS are unchanged by the srem. @@ -693,12 +635,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // If LHS is non-negative or has all low bits zero, then the upper bits // are all zero. - if (LHSKnownZero[BitWidth-1] || ((LHSKnownZero & LowBits) == LowBits)) + if (LHSKnownZero.isNegative() || ((LHSKnownZero & LowBits) == LowBits)) KnownZero |= ~LowBits; // If LHS is negative and not all low bits are zero, then the upper bits // are all one. 
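For LShr the patch shifts the known bits right along with the data and then marks the vacated high bits as known zero. The same propagation on a plain 64-bit value, as an illustrative sketch (assumes 0 <= ShiftAmt < 64; not from the patch):

    #include <cstdint>

    struct Known { uint64_t Zero, One; };

    Known knownForLShr(Known In, unsigned ShiftAmt) {
      // Known bits move right with the data...
      Known Out = {In.Zero >> ShiftAmt, In.One >> ShiftAmt};
      // ...and the ShiftAmt vacated high bits are filled with zeros.
      if (ShiftAmt)
        Out.Zero |= ~0ULL << (64 - ShiftAmt);
      return Out;
    }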
- if (LHSKnownOne[BitWidth-1] && ((LHSKnownOne & LowBits) != 0)) + if (LHSKnownOne.isNegative() && ((LHSKnownOne & LowBits) != 0)) KnownOne |= ~LowBits; assert(!(KnownZero & KnownOne) && "Bits known to be one AND zero?"); @@ -713,21 +655,17 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, CxtI); // If it's known zero, our sign bit is also zero. if (LHSKnownZero.isNegative()) - KnownZero.setBit(KnownZero.getBitWidth() - 1); + KnownZero.setSignBit(); } break; case Instruction::URem: { APInt KnownZero2(BitWidth, 0), KnownOne2(BitWidth, 0); APInt AllOnes = APInt::getAllOnesValue(BitWidth); - if (SimplifyDemandedBits(I->getOperandUse(0), AllOnes, KnownZero2, - KnownOne2, Depth + 1) || - SimplifyDemandedBits(I->getOperandUse(1), AllOnes, KnownZero2, - KnownOne2, Depth + 1)) + if (SimplifyDemandedBits(I, 0, AllOnes, KnownZero2, KnownOne2, Depth + 1) || + SimplifyDemandedBits(I, 1, AllOnes, KnownZero2, KnownOne2, Depth + 1)) return I; unsigned Leaders = KnownZero2.countLeadingOnes(); - Leaders = std::max(Leaders, - KnownZero2.countLeadingOnes()); KnownZero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; break; } @@ -792,11 +730,11 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return ConstantInt::getNullValue(VTy); // We know that the upper bits are set to zero. - KnownZero = APInt::getHighBitsSet(BitWidth, BitWidth - ArgWidth); + KnownZero.setBitsFrom(ArgWidth); return nullptr; } case Intrinsic::x86_sse42_crc32_64_64: - KnownZero = APInt::getHighBitsSet(64, 32); + KnownZero.setBitsFrom(32); return nullptr; } } @@ -811,6 +749,150 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return nullptr; } +/// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne +/// bits. It also tries to handle simplifications that can be done based on +/// DemandedMask, but without modifying the Instruction. +Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, + const APInt &DemandedMask, + APInt &KnownZero, + APInt &KnownOne, + unsigned Depth, + Instruction *CxtI) { + unsigned BitWidth = DemandedMask.getBitWidth(); + Type *ITy = I->getType(); + + APInt LHSKnownZero(BitWidth, 0), LHSKnownOne(BitWidth, 0); + APInt RHSKnownZero(BitWidth, 0), RHSKnownOne(BitWidth, 0); + + // Despite the fact that we can't simplify this instruction in all User's + // context, we can at least compute the knownzero/knownone bits, and we can + // do simplifications that apply to *just* the one user if we know that + // this instruction has a simpler value in that context. + switch (I->getOpcode()) { + case Instruction::And: { + // If either the LHS or the RHS are Zero, the result is zero. + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, + CxtI); + + // Output known-0 are known to be clear if zero in either the LHS | RHS. + APInt IKnownZero = RHSKnownZero | LHSKnownZero; + // Output known-1 bits are only known if set in both the LHS & RHS. + APInt IKnownOne = RHSKnownOne & LHSKnownOne; + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known 1 on one side, return the other. + // These bits cannot contribute to the result of the 'and' in this + // context. 
+ if ((DemandedMask & ~LHSKnownZero & RHSKnownOne) == + (DemandedMask & ~LHSKnownZero)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownZero & LHSKnownOne) == + (DemandedMask & ~RHSKnownZero)) + return I->getOperand(1); + + KnownZero = std::move(IKnownZero); + KnownOne = std::move(IKnownOne); + break; + } + case Instruction::Or: { + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + // If either the LHS or the RHS are One, the result is One. + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, + CxtI); + + // Output known-0 bits are only known if clear in both the LHS & RHS. + APInt IKnownZero = RHSKnownZero & LHSKnownZero; + // Output known-1 are known to be set if set in either the LHS | RHS. + APInt IKnownOne = RHSKnownOne | LHSKnownOne; + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known zero on one side, return the + // other. These bits cannot contribute to the result of the 'or' in this + // context. + if ((DemandedMask & ~LHSKnownOne & RHSKnownZero) == + (DemandedMask & ~LHSKnownOne)) + return I->getOperand(0); + if ((DemandedMask & ~RHSKnownOne & LHSKnownZero) == + (DemandedMask & ~RHSKnownOne)) + return I->getOperand(1); + + // If all of the potentially set bits on one side are known to be set on + // the other side, just use the 'other' side. + if ((DemandedMask & (~RHSKnownZero) & LHSKnownOne) == + (DemandedMask & (~RHSKnownZero))) + return I->getOperand(0); + if ((DemandedMask & (~LHSKnownZero) & RHSKnownOne) == + (DemandedMask & (~LHSKnownZero))) + return I->getOperand(1); + + KnownZero = std::move(IKnownZero); + KnownOne = std::move(IKnownOne); + break; + } + case Instruction::Xor: { + // We can simplify (X^Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + + computeKnownBits(I->getOperand(1), RHSKnownZero, RHSKnownOne, Depth + 1, + CxtI); + computeKnownBits(I->getOperand(0), LHSKnownZero, LHSKnownOne, Depth + 1, + CxtI); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + APInt IKnownZero = (RHSKnownZero & LHSKnownZero) | + (RHSKnownOne & LHSKnownOne); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + APInt IKnownOne = (RHSKnownZero & LHSKnownOne) | + (RHSKnownOne & LHSKnownZero); + + // If the client is only demanding bits that we know, return the known + // constant. + if ((DemandedMask & (IKnownZero|IKnownOne)) == DemandedMask) + return Constant::getIntegerValue(ITy, IKnownOne); + + // If all of the demanded bits are known zero on one side, return the + // other. + if ((DemandedMask & RHSKnownZero) == DemandedMask) + return I->getOperand(0); + if ((DemandedMask & LHSKnownZero) == DemandedMask) + return I->getOperand(1); + + // Output known-0 bits are known if clear or set in both the LHS & RHS. + KnownZero = std::move(IKnownZero); + // Output known-1 are known to be set if set in only one of the LHS, RHS. + KnownOne = std::move(IKnownOne); + break; + } + default: + // Compute the KnownZero/KnownOne bits to simplify things downstream. + computeKnownBits(I, KnownZero, KnownOne, Depth, CxtI); + + // If this user is only demanding bits that we know, return the known + // constant. 
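The Or case of the new multiple-use helper can hand back one operand outright when, for this particular user, the other operand contributes nothing to the demanded bits. The core tests restated on plain masks, as an illustrative sketch:

    #include <cstdint>

    // For a user of (X | Y) that only demands the bits in Demanded:
    // returns 0 if X alone suffices, 1 if Y alone suffices, -1 otherwise.
    int simplifyOrForUser(uint64_t Demanded,
                          uint64_t XKnownZero, uint64_t XKnownOne,
                          uint64_t YKnownZero, uint64_t YKnownOne) {
      // Every demanded bit not already set by X is known zero in Y, so on
      // the demanded bits (X | Y) == X.
      if ((Demanded & ~XKnownOne & YKnownZero) == (Demanded & ~XKnownOne))
        return 0;
      if ((Demanded & ~YKnownOne & XKnownZero) == (Demanded & ~YKnownOne))
        return 1;
      return -1;
    }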
+ if ((DemandedMask & (KnownZero|KnownOne)) == DemandedMask) + return Constant::getIntegerValue(ITy, KnownOne); + + break; + } + + return nullptr; +} + + /// Helper routine of SimplifyDemandedUseBits. It tries to simplify /// "E1 = (X lsr C1) << C2", where the C1 and C2 are constant, into /// "E2 = X << (C2 - C1)" or "E2 = X >> (C1 - C2)", depending on the sign @@ -849,7 +931,7 @@ Value *InstCombiner::SimplifyShrShlDemandedBits(Instruction *Shr, unsigned ShrAmt = ShrOp1.getZExtValue(); KnownOne.clearAllBits(); - KnownZero = APInt::getBitsSet(KnownZero.getBitWidth(), 0, ShlAmt-1); + KnownZero.setLowBits(ShlAmt - 1); KnownZero &= DemandedMask; APInt BitMask1(APInt::getAllOnesValue(BitWidth)); @@ -1472,14 +1554,136 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_sse2_packuswb_128: + case Intrinsic::x86_sse41_packusdw: + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx2_packusdw: + case Intrinsic::x86_avx2_packuswb: + case Intrinsic::x86_avx512_packssdw_512: + case Intrinsic::x86_avx512_packsswb_512: + case Intrinsic::x86_avx512_packusdw_512: + case Intrinsic::x86_avx512_packuswb_512: { + auto *Ty0 = II->getArgOperand(0)->getType(); + unsigned InnerVWidth = Ty0->getVectorNumElements(); + assert(VWidth == (InnerVWidth * 2) && "Unexpected input size"); + + unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128; + unsigned VWidthPerLane = VWidth / NumLanes; + unsigned InnerVWidthPerLane = InnerVWidth / NumLanes; + + // Per lane, pack the elements of the first input and then the second. + // e.g. + // v8i16 PACK(v4i32 X, v4i32 Y) - (X[0..3],Y[0..3]) + // v32i8 PACK(v16i16 X, v16i16 Y) - (X[0..7],Y[0..7]),(X[8..15],Y[8..15]) + for (int OpNum = 0; OpNum != 2; ++OpNum) { + APInt OpDemandedElts(InnerVWidth, 0); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + unsigned LaneIdx = Lane * VWidthPerLane; + for (unsigned Elt = 0; Elt != InnerVWidthPerLane; ++Elt) { + unsigned Idx = LaneIdx + Elt + InnerVWidthPerLane * OpNum; + if (DemandedElts[Idx]) + OpDemandedElts.setBit((Lane * InnerVWidthPerLane) + Elt); + } + } + + // Demand elements from the operand. + auto *Op = II->getArgOperand(OpNum); + APInt OpUndefElts(InnerVWidth, 0); + TmpV = SimplifyDemandedVectorElts(Op, OpDemandedElts, OpUndefElts, + Depth + 1); + if (TmpV) { + II->setArgOperand(OpNum, TmpV); + MadeChange = true; + } + + // Pack the operand's UNDEF elements, one lane at a time. 
+ OpUndefElts = OpUndefElts.zext(VWidth); + for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { + APInt LaneElts = OpUndefElts.lshr(InnerVWidthPerLane * Lane); + LaneElts = LaneElts.getLoBits(InnerVWidthPerLane); + LaneElts = LaneElts.shl(InnerVWidthPerLane * (2 * Lane + OpNum)); + UndefElts |= LaneElts; + } + } + break; + } + + // PSHUFB + case Intrinsic::x86_ssse3_pshuf_b_128: + case Intrinsic::x86_avx2_pshuf_b: + case Intrinsic::x86_avx512_pshuf_b_512: + // PERMILVAR + case Intrinsic::x86_avx_vpermilvar_ps: + case Intrinsic::x86_avx_vpermilvar_ps_256: + case Intrinsic::x86_avx512_vpermilvar_ps_512: + case Intrinsic::x86_avx_vpermilvar_pd: + case Intrinsic::x86_avx_vpermilvar_pd_256: + case Intrinsic::x86_avx512_vpermilvar_pd_512: + // PERMV + case Intrinsic::x86_avx2_permd: + case Intrinsic::x86_avx2_permps: { + Value *Op1 = II->getArgOperand(1); + TmpV = SimplifyDemandedVectorElts(Op1, DemandedElts, UndefElts, + Depth + 1); + if (TmpV) { II->setArgOperand(1, TmpV); MadeChange = true; } + break; + } + // SSE4A instructions leave the upper 64-bits of the 128-bit result // in an undefined state. case Intrinsic::x86_sse4a_extrq: case Intrinsic::x86_sse4a_extrqi: case Intrinsic::x86_sse4a_insertq: case Intrinsic::x86_sse4a_insertqi: - UndefElts |= APInt::getHighBitsSet(VWidth, VWidth / 2); + UndefElts.setHighBits(VWidth / 2); break; + case Intrinsic::amdgcn_buffer_load: + case Intrinsic::amdgcn_buffer_load_format: { + if (VWidth == 1 || !DemandedElts.isMask()) + return nullptr; + + // TODO: Handle 3 vectors when supported in code gen. + unsigned NewNumElts = PowerOf2Ceil(DemandedElts.countTrailingOnes()); + if (NewNumElts == VWidth) + return nullptr; + + Module *M = II->getParent()->getParent()->getParent(); + Type *EltTy = V->getType()->getVectorElementType(); + + Type *NewTy = (NewNumElts == 1) ? 
EltTy : + VectorType::get(EltTy, NewNumElts); + + Function *NewIntrin = Intrinsic::getDeclaration(M, II->getIntrinsicID(), + NewTy); + + SmallVector<Value *, 5> Args; + for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) + Args.push_back(II->getArgOperand(I)); + + IRBuilderBase::InsertPointGuard Guard(*Builder); + Builder->SetInsertPoint(II); + + CallInst *NewCall = Builder->CreateCall(NewIntrin, Args); + NewCall->takeName(II); + NewCall->copyMetadata(*II); + if (NewNumElts == 1) { + return Builder->CreateInsertElement(UndefValue::get(V->getType()), + NewCall, static_cast<uint64_t>(0)); + } + + SmallVector<uint32_t, 8> EltMask; + for (unsigned I = 0; I < VWidth; ++I) + EltMask.push_back(I); + + Value *Shuffle = Builder->CreateShuffleVector( + NewCall, UndefValue::get(NewTy), EltMask); + + MadeChange = true; + return Shuffle; + } } break; } diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index b2477f6c8633..e89b400a4afc 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -645,6 +645,36 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { return new ShuffleVectorInst(InsertFirst, UndefValue::get(VT), ZeroMask); } +/// If we have an insertelement instruction feeding into another insertelement +/// and the 2nd is inserting a constant into the vector, canonicalize that +/// constant insertion before the insertion of a variable: +/// +/// insertelement (insertelement X, Y, IdxC1), ScalarC, IdxC2 --> +/// insertelement (insertelement X, ScalarC, IdxC2), Y, IdxC1 +/// +/// This has the potential of eliminating the 2nd insertelement instruction +/// via constant folding of the scalar constant into a vector constant. +static Instruction *hoistInsEltConst(InsertElementInst &InsElt2, + InstCombiner::BuilderTy &Builder) { + auto *InsElt1 = dyn_cast<InsertElementInst>(InsElt2.getOperand(0)); + if (!InsElt1 || !InsElt1->hasOneUse()) + return nullptr; + + Value *X, *Y; + Constant *ScalarC; + ConstantInt *IdxC1, *IdxC2; + if (match(InsElt1->getOperand(0), m_Value(X)) && + match(InsElt1->getOperand(1), m_Value(Y)) && !isa<Constant>(Y) && + match(InsElt1->getOperand(2), m_ConstantInt(IdxC1)) && + match(InsElt2.getOperand(1), m_Constant(ScalarC)) && + match(InsElt2.getOperand(2), m_ConstantInt(IdxC2)) && IdxC1 != IdxC2) { + Value *NewInsElt1 = Builder.CreateInsertElement(X, ScalarC, IdxC2); + return InsertElementInst::Create(NewInsElt1, Y, IdxC1); + } + + return nullptr; +} + /// insertelt (shufflevector X, CVec, Mask|insertelt X, C1, CIndex1), C, CIndex /// --> shufflevector X, CVec', Mask' static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { @@ -806,6 +836,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) return Shuf; + if (Instruction *NewInsElt = hoistInsEltConst(IE, *Builder)) + return NewInsElt; + // Turn a sequence of inserts that broadcasts a scalar into a single // insert + shufflevector. if (Instruction *Broadcast = foldInsSequenceIntoBroadcast(IE)) @@ -1107,12 +1140,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SmallVector<int, 16> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); - bool MadeChange = false; - - // Undefined shuffle mask -> undefined value. 
- if (isa<UndefValue>(SVI.getOperand(2))) - return replaceInstUsesWith(SVI, UndefValue::get(SVI.getType())); + if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getMask(), + SVI.getType(), DL, &TLI, &DT, &AC)) + return replaceInstUsesWith(SVI, V); + bool MadeChange = false; unsigned VWidth = SVI.getType()->getVectorNumElements(); APInt UndefElts(VWidth, 0); @@ -1209,7 +1241,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (isShuffleExtractingFromLHS(SVI, Mask)) { Value *V = LHS; unsigned MaskElems = Mask.size(); - unsigned BegIdx = Mask.front(); VectorType *SrcTy = cast<VectorType>(V->getType()); unsigned VecBitWidth = SrcTy->getBitWidth(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); @@ -1223,6 +1254,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // Only visit bitcasts that weren't previously handled. BCs.push_back(BC); for (BitCastInst *BC : BCs) { + unsigned BegIdx = Mask.front(); Type *TgtTy = BC->getDestTy(); unsigned TgtElemBitWidth = DL.getTypeSizeInBits(TgtTy); if (!TgtElemBitWidth) diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index 27fc34d23175..88ef17bbc8fa 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -82,18 +82,24 @@ static cl::opt<bool> EnableExpensiveCombines("expensive-combines", cl::desc("Enable expensive instruction combines")); +static cl::opt<unsigned> +MaxArraySize("instcombine-maxarray-size", cl::init(1024), + cl::desc("Maximum array size considered when doing a combine")); + Value *InstCombiner::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(Builder, DL, GEP); } /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. -/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(unsigned FromWidth, +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. A width of '1' is always treated as a legal type +/// because i1 is a fundamental type in IR, and there are many specialized +/// optimizations for i1 types. +bool InstCombiner::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { - bool FromLegal = DL.isLegalInteger(FromWidth); - bool ToLegal = DL.isLegalInteger(ToWidth); + bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); + bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); // If this is a legal integer from type, and the result would be an illegal // type, don't do the transformation. @@ -109,14 +115,16 @@ bool InstCombiner::ShouldChangeType(unsigned FromWidth, } /// Return true if it is desirable to convert a computation from 'From' to 'To'. -/// We don't want to convert from a legal to an illegal type for example or from -/// a smaller to a larger illegal type. -bool InstCombiner::ShouldChangeType(Type *From, Type *To) const { +/// We don't want to convert from a legal to an illegal type or from a smaller +/// to a larger illegal type. i1 is always treated as a legal type because it is +/// a fundamental type in IR, and there are many specialized optimizations for +/// i1 types. 
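A rough standalone model of the width-legality rule documented above, with a hard-coded legal-width set standing in for DataLayout::isLegalInteger (illustrative, not from the patch):

    // Never convert a legal integer type to an illegal one, and never grow
    // an illegal type into a larger illegal type. i1 is always legal here.
    bool shouldChangeTypeSketch(unsigned FromWidth, unsigned ToWidth) {
      auto IsLegal = [](unsigned W) {
        return W == 1 || W == 8 || W == 16 || W == 32 || W == 64;
      };
      if (IsLegal(FromWidth) && !IsLegal(ToWidth))
        return false;
      if (!IsLegal(FromWidth) && !IsLegal(ToWidth) && ToWidth > FromWidth)
        return false;
      return true;
    }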
+bool InstCombiner::shouldChangeType(Type *From, Type *To) const { assert(From->isIntegerTy() && To->isIntegerTy()); unsigned FromWidth = From->getPrimitiveSizeInBits(); unsigned ToWidth = To->getPrimitiveSizeInBits(); - return ShouldChangeType(FromWidth, ToWidth); + return shouldChangeType(FromWidth, ToWidth); } // Return true, if No Signed Wrap should be maintained for I. @@ -447,16 +455,11 @@ static bool RightDistributesOverLeft(Instruction::BinaryOps LOp, /// This function returns identity value for given opcode, which can be used to /// factor patterns like (X * 2) + X ==> (X * 2) + (X * 1) ==> X * (2 + 1). -static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { +static Value *getIdentityValue(Instruction::BinaryOps Opcode, Value *V) { if (isa<Constant>(V)) return nullptr; - if (OpCode == Instruction::Mul) - return ConstantInt::get(V->getType(), 1); - - // TODO: We can handle other cases e.g. Instruction::And, Instruction::Or etc. - - return nullptr; + return ConstantExpr::getBinOpIdentity(Opcode, V->getType()); } /// This function factors binary ops which can be combined using distributive @@ -468,8 +471,7 @@ static Value *getIdentityValue(Instruction::BinaryOps OpCode, Value *V) { static Instruction::BinaryOps getBinOpsForFactorization(Instruction::BinaryOps TopLevelOpcode, BinaryOperator *Op, Value *&LHS, Value *&RHS) { - if (!Op) - return Instruction::BinaryOpsEnd; + assert(Op && "Expected a binary operator"); LHS = Op->getOperand(0); RHS = Op->getOperand(1); @@ -499,11 +501,7 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder, const DataLayout &DL, BinaryOperator &I, Instruction::BinaryOps InnerOpcode, Value *A, Value *B, Value *C, Value *D) { - - // If any of A, B, C, D are null, we can not factor I, return early. - // Checking A and C should be enough. - if (!A || !C || !B || !D) - return nullptr; + assert(A && B && C && D && "All values must be provided"); Value *V = nullptr; Value *SimplifiedInst = nullptr; @@ -564,13 +562,11 @@ static Value *tryFactorization(InstCombiner::BuilderTy *Builder, if (isa<OverflowingBinaryOperator>(&I)) HasNSW = I.hasNoSignedWrap(); - if (BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS)) - if (isa<OverflowingBinaryOperator>(Op0)) - HasNSW &= Op0->hasNoSignedWrap(); + if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) + HasNSW &= LOBO->hasNoSignedWrap(); - if (BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS)) - if (isa<OverflowingBinaryOperator>(Op1)) - HasNSW &= Op1->hasNoSignedWrap(); + if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) + HasNSW &= ROBO->hasNoSignedWrap(); // We can propagate 'nsw' if we know that // %Y = mul nsw i16 %X, C @@ -599,31 +595,39 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); - // Factorization. - Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr; - auto TopLevelOpcode = I.getOpcode(); - auto LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); - auto RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); - - // The instruction has the form "(A op' B) op (C op' D)". Try to factorize - // a common term. - if (LHSOpcode == RHSOpcode) { - if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D)) - return V; - } - - // The instruction has the form "(A op' B) op (C)". 
Try to factorize common - // term. - if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS, - getIdentityValue(LHSOpcode, RHS))) - return V; + { + // Factorization. + Value *A, *B, *C, *D; + Instruction::BinaryOps LHSOpcode, RHSOpcode; + if (Op0) + LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + if (Op1) + RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); + + // The instruction has the form "(A op' B) op (C op' D)". Try to factorize + // a common term. + if (Op0 && Op1 && LHSOpcode == RHSOpcode) + if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, C, D)) + return V; + + // The instruction has the form "(A op' B) op (C)". Try to factorize common + // term. + if (Op0) + if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) + if (Value *V = tryFactorization(Builder, DL, I, LHSOpcode, A, B, RHS, + Ident)) + return V; - // The instruction has the form "(B) op (C op' D)". Try to factorize common - // term. - if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, - getIdentityValue(RHSOpcode, LHS), C, D)) - return V; + // The instruction has the form "(B) op (C op' D)". Try to factorize common + // term. + if (Op1) + if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) + if (Value *V = tryFactorization(Builder, DL, I, RHSOpcode, LHS, Ident, + C, D)) + return V; + } // Expansion. if (Op0 && RightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { @@ -720,6 +724,21 @@ Value *InstCombiner::dyn_castNegVal(Value *V) const { if (C->getType()->getElementType()->isIntegerTy()) return ConstantExpr::getNeg(C); + if (ConstantVector *CV = dyn_cast<ConstantVector>(V)) { + for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { + Constant *Elt = CV->getAggregateElement(i); + if (!Elt) + return nullptr; + + if (isa<UndefValue>(Elt)) + continue; + + if (!isa<ConstantInt>(Elt)) + return nullptr; + } + return ConstantExpr::getNeg(CV); + } + return nullptr; } @@ -820,8 +839,29 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } -Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { - PHINode *PN = cast<PHINode>(I.getOperand(0)); +static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, + InstCombiner *IC) { + bool ConstIsRHS = isa<Constant>(I->getOperand(1)); + Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); + + if (auto *InC = dyn_cast<Constant>(InV)) { + if (ConstIsRHS) + return ConstantExpr::get(I->getOpcode(), InC, C); + return ConstantExpr::get(I->getOpcode(), C, InC); + } + + Value *Op0 = InV, *Op1 = C; + if (!ConstIsRHS) + std::swap(Op0, Op1); + + Value *RI = IC->Builder->CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp"); + auto *FPInst = dyn_cast<Instruction>(RI); + if (FPInst && isa<FPMathOperator>(FPInst)) + FPInst->copyFastMathFlags(I); + return RI; +} + +Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) return nullptr; @@ -902,7 +942,11 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { // Beware of ConstantExpr: it may eventually evaluate to getNullValue, // even if currently isNullValue gives false. Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); - if (InC && !isa<ConstantExpr>(InC)) + // For vector constants, we cannot use isNullValue to fold into + // FalseVInPred versus TrueVInPred. 
When we have individual nonzero + // elements in the vector, we will incorrectly fold InC to + // `TrueVInPred`. + if (InC && !isa<ConstantExpr>(InC) && isa<ConstantInt>(InC)) InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; else InV = Builder->CreateSelect(PN->getIncomingValue(i), @@ -923,15 +967,9 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { C, "phitmp"); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } - } else if (I.getNumOperands() == 2) { - Constant *C = cast<Constant>(I.getOperand(1)); + } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = nullptr; - if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::get(I.getOpcode(), InC, C); - else - InV = Builder->CreateBinOp(cast<BinaryOperator>(I).getOpcode(), - PN->getIncomingValue(i), C, "phitmp"); + Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), this); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } } else { @@ -957,14 +995,14 @@ Instruction *InstCombiner::FoldOpIntoPhi(Instruction &I) { return replaceInstUsesWith(I, NewPN); } -Instruction *InstCombiner::foldOpWithConstantIntoOperand(Instruction &I) { +Instruction *InstCombiner::foldOpWithConstantIntoOperand(BinaryOperator &I) { assert(isa<Constant>(I.getOperand(1)) && "Unexpected operand type"); if (auto *Sel = dyn_cast<SelectInst>(I.getOperand(0))) { if (Instruction *NewSel = FoldOpIntoSelect(I, Sel)) return NewSel; - } else if (isa<PHINode>(I.getOperand(0))) { - if (Instruction *NewPhi = FoldOpIntoPhi(I)) + } else if (auto *PN = dyn_cast<PHINode>(I.getOperand(0))) { + if (Instruction *NewPhi = foldOpIntoPhi(I, PN)) return NewPhi; } return nullptr; @@ -1315,22 +1353,19 @@ Value *InstCombiner::SimplifyVectorOp(BinaryOperator &Inst) { assert(cast<VectorType>(LHS->getType())->getNumElements() == VWidth); assert(cast<VectorType>(RHS->getType())->getNumElements() == VWidth); - // If both arguments of binary operation are shuffles, which use the same - // mask and shuffle within a single vector, it is worthwhile to move the - // shuffle after binary operation: + // If both arguments of the binary operation are shuffles that use the same + // mask and shuffle within a single vector, move the shuffle after the binop: // Op(shuffle(v1, m), shuffle(v2, m)) -> shuffle(Op(v1, v2), m) - if (isa<ShuffleVectorInst>(LHS) && isa<ShuffleVectorInst>(RHS)) { - ShuffleVectorInst *LShuf = cast<ShuffleVectorInst>(LHS); - ShuffleVectorInst *RShuf = cast<ShuffleVectorInst>(RHS); - if (isa<UndefValue>(LShuf->getOperand(1)) && - isa<UndefValue>(RShuf->getOperand(1)) && - LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType() && - LShuf->getMask() == RShuf->getMask()) { - Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), - RShuf->getOperand(0), Builder); - return Builder->CreateShuffleVector(NewBO, - UndefValue::get(NewBO->getType()), LShuf->getMask()); - } + auto *LShuf = dyn_cast<ShuffleVectorInst>(LHS); + auto *RShuf = dyn_cast<ShuffleVectorInst>(RHS); + if (LShuf && RShuf && LShuf->getMask() == RShuf->getMask() && + isa<UndefValue>(LShuf->getOperand(1)) && + isa<UndefValue>(RShuf->getOperand(1)) && + LShuf->getOperand(0)->getType() == RShuf->getOperand(0)->getType()) { + Value *NewBO = CreateBinOpAsGiven(Inst, LShuf->getOperand(0), + RShuf->getOperand(0), Builder); + return Builder->CreateShuffleVector( + NewBO, UndefValue::get(NewBO->getType()), LShuf->getMask()); } // If one argument is a shuffle within one vector, the other is a constant, 
@@ -1559,27 +1594,21 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Replace: gep (gep %P, long B), long A, ... // With: T = long A+B; gep %P, T, ... // - Value *Sum; Value *SO1 = Src->getOperand(Src->getNumOperands()-1); Value *GO1 = GEP.getOperand(1); - if (SO1 == Constant::getNullValue(SO1->getType())) { - Sum = GO1; - } else if (GO1 == Constant::getNullValue(GO1->getType())) { - Sum = SO1; - } else { - // If they aren't the same type, then the input hasn't been processed - // by the loop above yet (which canonicalizes sequential index types to - // intptr_t). Just avoid transforming this until the input has been - // normalized. - if (SO1->getType() != GO1->getType()) - return nullptr; - // Only do the combine when GO1 and SO1 are both constants. Only in - // this case, we are sure the cost after the merge is never more than - // that before the merge. - if (!isa<Constant>(GO1) || !isa<Constant>(SO1)) - return nullptr; - Sum = Builder->CreateAdd(SO1, GO1, PtrOp->getName()+".sum"); - } + + // If they aren't the same type, then the input hasn't been processed + // by the loop above yet (which canonicalizes sequential index types to + // intptr_t). Just avoid transforming this until the input has been + // normalized. + if (SO1->getType() != GO1->getType()) + return nullptr; + + Value* Sum = SimplifyAddInst(GO1, SO1, false, false, DL, &TLI, &DT, &AC); + // Only do the combine when we are sure the cost after the + // merge is never more than that before the merge. + if (Sum == nullptr) + return nullptr; // Update the GEP in place if possible. if (Src->getNumOperands() == 2) { @@ -1654,14 +1683,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } - // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). - Value *StrippedPtr = PtrOp->stripPointerCasts(); - PointerType *StrippedPtrTy = dyn_cast<PointerType>(StrippedPtr->getType()); - // We do not handle pointer-vector geps here. - if (!StrippedPtrTy) + if (GEP.getType()->isVectorTy()) return nullptr; + // Handle gep(bitcast x) and gep(gep x, 0, 0, 0). + Value *StrippedPtr = PtrOp->stripPointerCasts(); + PointerType *StrippedPtrTy = cast<PointerType>(StrippedPtr->getType()); + if (StrippedPtr != PtrOp) { bool HasZeroPointerIndex = false; if (ConstantInt *C = dyn_cast<ConstantInt>(GEP.getOperand(1))) @@ -2239,11 +2268,11 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { ConstantInt *AddRHS; if (match(Cond, m_Add(m_Value(Op0), m_ConstantInt(AddRHS)))) { // Change 'switch (X+4) case 1:' into 'switch (X) case -3'. 
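The switch rewrite above folds an add on the condition into every case value; on N-bit integers the adjustment is a wrapping subtraction, which unsigned arithmetic models exactly. A small sketch (not from the patch):

    #include <cstdint>

    // switch (X + AddRHS) case C:   ==>   switch (X) case C - AddRHS:
    uint64_t rewriteCaseValue(uint64_t CaseValue, uint64_t AddRHS) {
      return CaseValue - AddRHS;  // wraps modulo 2^64, like APInt
    }
    // e.g. rewriteCaseValue(1, 4) == static_cast<uint64_t>(-3), matching
    // "switch (X+4) case 1" becoming "switch (X) case -3".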
- for (SwitchInst::CaseIt CaseIter : SI.cases()) { - Constant *NewCase = ConstantExpr::getSub(CaseIter.getCaseValue(), AddRHS); + for (auto Case : SI.cases()) { + Constant *NewCase = ConstantExpr::getSub(Case.getCaseValue(), AddRHS); assert(isa<ConstantInt>(NewCase) && "Result of expression should be constant"); - CaseIter.setValue(cast<ConstantInt>(NewCase)); + Case.setValue(cast<ConstantInt>(NewCase)); } SI.setCondition(Op0); return &SI; @@ -2275,9 +2304,9 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { Value *NewCond = Builder->CreateTrunc(Cond, Ty, "trunc"); SI.setCondition(NewCond); - for (SwitchInst::CaseIt CaseIter : SI.cases()) { - APInt TruncatedCase = CaseIter.getCaseValue()->getValue().trunc(NewWidth); - CaseIter.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); + for (auto Case : SI.cases()) { + APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); + Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } return &SI; } @@ -2934,8 +2963,8 @@ bool InstCombiner::run() { Result->takeName(I); // Push the new instruction and any users onto the worklist. - Worklist.Add(Result); Worklist.AddUsersToWorkList(*Result); + Worklist.Add(Result); // Insert the new instruction into the basic block... BasicBlock *InstParent = I->getParent(); @@ -2958,8 +2987,8 @@ bool InstCombiner::run() { if (isInstructionTriviallyDead(I, &TLI)) { eraseInstFromFunction(*I); } else { - Worklist.Add(I); Worklist.AddUsersToWorkList(*I); + Worklist.Add(I); } } MadeIRChange = true; @@ -3022,12 +3051,11 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, } // See if we can constant fold its operands. - for (User::op_iterator i = Inst->op_begin(), e = Inst->op_end(); i != e; - ++i) { - if (!isa<ConstantVector>(i) && !isa<ConstantExpr>(i)) + for (Use &U : Inst->operands()) { + if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U)) continue; - auto *C = cast<Constant>(i); + auto *C = cast<Constant>(U); Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) FoldRes = ConstantFoldConstant(C, DL, TLI); @@ -3035,7 +3063,10 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, FoldRes = C; if (FoldRes != C) { - *i = FoldRes; + DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + << "\n Old = " << *C + << "\n New = " << *FoldRes << '\n'); + U = FoldRes; MadeIRChange = true; } } @@ -3055,17 +3086,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, } } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) { if (ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition())) { - // See if this is an explicit destination. - for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); - i != e; ++i) - if (i.getCaseValue() == Cond) { - BasicBlock *ReachableBB = i.getCaseSuccessor(); - Worklist.push_back(ReachableBB); - continue; - } - - // Otherwise it is the default destination. - Worklist.push_back(SI->getDefaultDest()); + Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor()); continue; } } @@ -3152,6 +3173,7 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, InstCombiner IC(Worklist, &Builder, F.optForMinSize(), ExpensiveCombines, AA, AC, TLI, DT, DL, LI); + IC.MaxArraySizeForCombine = MaxArraySize; Changed |= IC.run(); if (!Changed) @@ -3176,9 +3198,10 @@ PreservedAnalyses InstCombinePass::run(Function &F, return PreservedAnalyses::all(); // Mark all the analyses that instcombine updates as preserved. - // FIXME: This should also 'preserve the CFG'. 
PreservedAnalyses PA;
- PA.preserve<DominatorTreeAnalysis>();
+ PA.preserveSet<CFGAnalyses>();
+ PA.preserve<AAManager>();
+ PA.preserve<GlobalsAA>();
return PA;
}
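The final hunk widens what InstCombine reports as preserved: the whole CFG-analyses set plus the alias analyses, instead of just the dominator tree. A minimal sketch of the same reporting pattern in a new-pass-manager pass, assuming the usual LLVM headers; the pass and its transform() are hypothetical, not part of the patch:

    #include "llvm/IR/PassManager.h"
    #include "llvm/Analysis/AliasAnalysis.h"
    #include "llvm/Analysis/GlobalsModRef.h"

    using namespace llvm;

    struct ExamplePass : PassInfoMixin<ExamplePass> {
      // Stand-in for a rewrite that never adds or removes blocks or edges.
      bool transform(Function &F) { return false; }

      PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM) {
        if (!transform(F))
          return PreservedAnalyses::all();
        PreservedAnalyses PA;
        PA.preserveSet<CFGAnalyses>();  // instruction-level changes only
        PA.preserve<AAManager>();
        PA.preserve<GlobalsAA>();
        return PA;
      }
    };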