author    | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000
committer | Dimitry Andric <dim@FreeBSD.org> | 2020-07-26 19:36:28 +0000
commit    | cfca06d7963fa0909f90483b42a6d7d194d01e08 (patch)
tree      | 209fb2a2d68f8f277793fc8df46c753d31bc853b /llvm/lib/Transforms/InstCombine
parent    | 706b4fc47bbc608932d3b491ae19a3b9cde9497b (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
16 files changed, 3612 insertions, 2036 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index ec976a971e3ce..a7f5e0a7774d2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -270,7 +270,7 @@ void FAddendCoef::operator=(const FAddendCoef &That) { } void FAddendCoef::operator+=(const FAddendCoef &That) { - enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven; + RoundingMode RndMode = RoundingMode::NearestTiesToEven; if (isInt() == That.isInt()) { if (isInt()) IntVal += That.IntVal; @@ -663,8 +663,7 @@ Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { } Value *FAddCombine::createFNeg(Value *V) { - Value *Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType())); - Value *NewV = createFSub(Zero, V); + Value *NewV = Builder.CreateFNeg(V); if (Instruction *I = dyn_cast<Instruction>(NewV)) createInstPostProc(I, true); // fneg's don't receive instruction numbers. return NewV; @@ -724,8 +723,6 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { if (!CE.isMinusOne() && !CE.isOne()) InstrNeeded++; } - if (NegOpndNum == OpndNum) - InstrNeeded++; return InstrNeeded; } @@ -1044,8 +1041,7 @@ Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) { // Match RemOpV = X / C0 if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) { - Value *NewDivisor = - ConstantInt::get(X->getType()->getContext(), C0 * C1); + Value *NewDivisor = ConstantInt::get(X->getType(), C0 * C1); return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem") : Builder.CreateURem(X, NewDivisor, "urem"); } @@ -1307,9 +1303,28 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { match(&I, m_BinOp(m_c_Add(m_Not(m_Value(B)), m_Value(A)), m_One()))) return BinaryOperator::CreateSub(A, B); + // (A + RHS) + RHS --> A + (RHS << 1) + if (match(LHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(RHS, 1, "reass.add")); + + // LHS + (A + LHS) --> A + (LHS << 1) + if (match(RHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(LHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(LHS, 1, "reass.add")); + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 + const APInt *C1, *C2; + if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + APInt one(C2->getBitWidth(), 1); + APInt minusC1 = -(*C1); + if (minusC1 == (one << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); + return BinaryOperator::CreateSRem(RHS, NewRHS); + } + } + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1380,8 +1395,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (and A, B) (or A, B)) --> (add A, B) if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), m_c_And(m_Deferred(A), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); + // Replacing operands in-place to preserve nuw/nsw flags. 
+ replaceOperand(I, 0, A); + replaceOperand(I, 1, B); return &I; } @@ -1685,12 +1701,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Instruction *X = foldVectorBinop(I)) return X; - // (A*B)-(A*C) -> A*(B-C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) - return replaceInstUsesWith(I, V); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // If this is a 'B = x-(-A)', change to B = x+A. - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + // We deal with this without involving Negator to preserve NSW flag. if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); @@ -1707,6 +1721,45 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } + auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; + + bool Changed = false; + if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + + return Changed ? &I : nullptr; + }; + + // First, let's try to interpret `sub a, b` as `add a, (sub 0, b)`, + // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't + // a pure negation used by a select that looks like abs/nabs. + bool IsNegation = match(Op0, m_ZeroInt()); + if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) { + const Instruction *UI = dyn_cast<Instruction>(U); + if (!UI) + return false; + return match(UI, + m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) || + match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1))); + })) { + if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this)) + return BinaryOperator::CreateAdd(NegOp1, Op0); + } + if (IsNegation) + return TryToNarrowDeduceFlags(); // Should have been handled in Negator! 
+ + // (A*B)-(A*C) -> A*(B-C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(Op0, Op1); @@ -1723,33 +1776,40 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes())))) return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X); - // Y - (X + 1) --> ~X + Y - if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One())))) - return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0); + // Reassociate sub/add sequences to create more add instructions and + // reduce dependency chains: + // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) + Value *Z; + if (match(Op0, m_OneUse(m_c_Add(m_OneUse(m_Sub(m_Value(X), m_Value(Y))), + m_Value(Z))))) { + Value *XZ = Builder.CreateAdd(X, Z); + Value *YW = Builder.CreateAdd(Y, Op1); + return BinaryOperator::CreateSub(XZ, YW); + } - // Y - ~X --> (X + 1) + Y - if (match(Op1, m_OneUse(m_Not(m_Value(X))))) { - return BinaryOperator::CreateAdd( - Builder.CreateAdd(Op0, ConstantInt::get(I.getType(), 1)), X); + auto m_AddRdx = [](Value *&Vec) { + return m_OneUse( + m_Intrinsic<Intrinsic::experimental_vector_reduce_add>(m_Value(Vec))); + }; + Value *V0, *V1; + if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) && + V0->getType() == V1->getType()) { + // Difference of sums is sum of differences: + // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1) + Value *Sub = Builder.CreateSub(V0, V1); + Value *Rdx = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_reduce_add, {Sub->getType()}, {Sub}); + return replaceInstUsesWith(I, Rdx); } if (Constant *C = dyn_cast<Constant>(Op0)) { - bool IsNegate = match(C, m_ZeroInt()); Value *X; - if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - // 0 - (zext bool) --> sext bool + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (zext bool) --> bool ? C - 1 : C - if (IsNegate) - return CastInst::CreateSExtOrBitCast(X, I.getType()); return SelectInst::Create(X, SubOne(C), C); - } - if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - // 0 - (sext bool) --> zext bool + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (sext bool) --> bool ? 
C + 1 : C - if (IsNegate) - return CastInst::CreateZExtOrBitCast(X, I.getType()); return SelectInst::Create(X, AddOne(C), C); - } // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) @@ -1768,7 +1828,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; // C-(C2-X) --> X+(C-C2) - if (match(Op1, m_Sub(m_Constant(C2), m_Value(X)))) + if (match(Op1, m_Sub(m_Constant(C2), m_Value(X))) && !isa<ConstantExpr>(C2)) return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2)); // C-(X+C2) --> (C-C2)-X @@ -1777,62 +1837,12 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } const APInt *Op0C; - if (match(Op0, m_APInt(Op0C))) { - - if (Op0C->isNullValue()) { - Value *Op1Wide; - match(Op1, m_TruncOrSelf(m_Value(Op1Wide))); - bool HadTrunc = Op1Wide != Op1; - bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse(); - unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits(); - - Value *X; - const APInt *ShAmt; - // -(X >>u 31) -> (X >>s 31) - if (NoTruncOrTruncIsOneUse && - match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) && - *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); - Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp); - NewShift->copyIRFlags(Op1Wide); - if (!HadTrunc) - return NewShift; - Builder.Insert(NewShift); - return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); - } - // -(X >>s 31) -> (X >>u 31) - if (NoTruncOrTruncIsOneUse && - match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) && - *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); - Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp); - NewShift->copyIRFlags(Op1Wide); - if (!HadTrunc) - return NewShift; - Builder.Insert(NewShift); - return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); - } - - if (!HadTrunc && Op1->hasOneUse()) { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; - if (SPF == SPF_ABS || SPF == SPF_NABS) { - // This is a negate of an ABS/NABS pattern. Just swap the operands - // of the select. - cast<SelectInst>(Op1)->swapValues(); - // Don't swap prof metadata, we didn't change the branch behavior. - return replaceInstUsesWith(I, Op1); - } - } - } - + if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. - if (Op0C->isMask()) { - KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) - return BinaryOperator::CreateXor(Op1, Op0); - } + KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) + return BinaryOperator::CreateXor(Op1, Op0); } { @@ -1956,71 +1966,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return NewSel; } - if (Op1->hasOneUse()) { - Value *X = nullptr, *Y = nullptr, *Z = nullptr; - Constant *C = nullptr; - - // (X - (Y - Z)) --> (X + (Z - Y)). - if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) - return BinaryOperator::CreateAdd(Op0, - Builder.CreateSub(Z, Y, Op1->getName())); - - // (X - (X & Y)) --> (X & ~Y) - if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) - return BinaryOperator::CreateAnd(Op0, - Builder.CreateNot(Y, Y->getName() + ".not")); - - // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. 
- if (match(Op0, m_Zero())) { - Constant *Op11C; - if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) && - !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() && - Op11C->isNotOneValue()) { - Instruction *BO = - BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C)); - BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); - return BO; - } - } - - // 0 - (X << Y) -> (-X << Y) when X is freely negatable. - if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero())) - if (Value *XNeg = dyn_castNegVal(X)) - return BinaryOperator::CreateShl(XNeg, Y); - - // Subtracting -1/0 is the same as adding 1/0: - // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y) - // 'nuw' is dropped in favor of the canonical form. - if (match(Op1, m_SExt(m_Value(Y))) && - Y->getType()->getScalarSizeInBits() == 1) { - Value *Zext = Builder.CreateZExt(Y, I.getType()); - BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext); - Add->setHasNoSignedWrap(I.hasNoSignedWrap()); - return Add; - } - // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y) - // 'nuw' is dropped in favor of the canonical form. - if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) { - Value *Sext = Builder.CreateSExt(Y, I.getType()); - BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext); - Add->setHasNoSignedWrap(I.hasNoSignedWrap()); - return Add; - } - - // X - A*-B -> X + A*B - // X - -A*B -> X + A*B - Value *A, *B; - if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B))))) - return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B)); - - // X - A*C -> X + A*-C - // No need to handle commuted multiply because multiply handling will - // ensure constant will be move to the right hand side. - if (match(Op1, m_Mul(m_Value(A), m_Constant(C))) && !isa<ConstantExpr>(C)) { - Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(C)); - return BinaryOperator::CreateAdd(Op0, NewMul); - } - } + // (X - (X & Y)) --> (X & ~Y) + if (match(Op1, m_c_And(m_Specific(Op0), m_Value(Y))) && + (Op1->hasOneUse() || isa<Constant>(Y))) + return BinaryOperator::CreateAnd( + Op0, Builder.CreateNot(Y, Y->getName() + ".not")); { // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A @@ -2096,20 +2046,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; - if (Instruction *Ext = narrowMathIfNoOverflow(I)) - return Ext; - - bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { - Changed = true; - I.setHasNoSignedWrap(true); - } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { - Changed = true; - I.setHasNoUnsignedWrap(true); - } - - return Changed ? &I : nullptr; + return TryToNarrowDeduceFlags(); } /// This eliminates floating-point negation in either 'fneg(X)' or @@ -2132,6 +2069,12 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: + // -(X + C) --> -X + -C --> -C - X + if (I.hasNoSignedZeros() && + match(&I, m_FNeg(m_OneUse(m_FAdd(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + return nullptr; } @@ -2184,10 +2127,15 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return X; // Subtraction from -0.0 is the canonical form of fneg. 
- // fsub nsz 0, X ==> fsub nsz -0.0, X - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + // fsub -0.0, X ==> fneg X + // fsub nsz 0.0, X ==> fneg nsz X + // + // FIXME This matcher does not respect FTZ or DAZ yet: + // fsub -0.0, Denorm ==> +-0 + // fneg Denorm ==> -Denorm + Value *Op; + if (match(&I, m_FNeg(m_Value(Op)))) + return UnaryOperator::CreateFNegFMF(Op, &I); if (Instruction *X = foldFNegIntoConstant(I)) return X; @@ -2198,6 +2146,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { Value *X, *Y; Constant *C; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. @@ -2211,6 +2160,13 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } } + // (-X) - Op1 --> -(X + Op1) + if (I.hasNoSignedZeros() && !isa<ConstantExpr>(Op0) && + match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { + Value *FAdd = Builder.CreateFAddFMF(X, Op1, &I); + return UnaryOperator::CreateFNegFMF(FAdd, &I); + } + if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) if (Instruction *NV = FoldOpIntoSelect(I, SI)) @@ -2258,12 +2214,12 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + return UnaryOperator::CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + return UnaryOperator::CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { @@ -2276,6 +2232,34 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); } + // Reassociate fsub/fadd sequences to create more fadd instructions and + // reduce dependency chains: + // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) + Value *Z; + if (match(Op0, m_OneUse(m_c_FAdd(m_OneUse(m_FSub(m_Value(X), m_Value(Y))), + m_Value(Z))))) { + Value *XZ = Builder.CreateFAddFMF(X, Z, &I); + Value *YW = Builder.CreateFAddFMF(Y, Op1, &I); + return BinaryOperator::CreateFSubFMF(XZ, YW, &I); + } + + auto m_FaddRdx = [](Value *&Sum, Value *&Vec) { + return m_OneUse( + m_Intrinsic<Intrinsic::experimental_vector_reduce_v2_fadd>( + m_Value(Sum), m_Value(Vec))); + }; + Value *A0, *A1, *V0, *V1; + if (match(Op0, m_FaddRdx(A0, V0)) && match(Op1, m_FaddRdx(A1, V1)) && + V0->getType() == V1->getType()) { + // Difference of sums is sum of differences: + // add_rdx(A0, V0) - add_rdx(A1, V1) --> add_rdx(A0, V0 - V1) - A1 + Value *Sub = Builder.CreateFSubFMF(V0, V1, &I); + Value *Rdx = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_reduce_v2_fadd, + {A0->getType(), Sub->getType()}, {A0, Sub}, &I); + return BinaryOperator::CreateFSubFMF(Rdx, A1, &I); + } + if (Instruction *F = factorizeFAddFSub(I, Builder)) return F; @@ -2285,6 +2269,12 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // complex pattern matching and remove this from InstCombine. 
if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); + + // (X - Y) - Op1 --> X - (Y + Op1) + if (match(Op0, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *FAdd = Builder.CreateFAddFMF(Y, Op1, &I); + return BinaryOperator::CreateFSubFMF(X, FAdd, &I); + } } return nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index cc0a9127f8b18..d3c718a919c0a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -143,8 +143,7 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, // the XOR is to toggle the bit. If it is clear, then the ADD has // no effect. if ((AddRHS & AndRHSV).isNullValue()) { // Bit is not set, noop - TheAnd.setOperand(0, X); - return &TheAnd; + return replaceOperand(TheAnd, 0, X); } else { // Pull the XOR out of the AND. Value *NewAnd = Builder.CreateAnd(X, AndRHS); @@ -858,8 +857,10 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - bool JoinedByAnd, - Instruction &CxtI) { + BinaryOperator &Logic) { + bool JoinedByAnd = Logic.getOpcode() == Instruction::And; + assert((JoinedByAnd || Logic.getOpcode() == Instruction::Or) && + "Wrong opcode"); ICmpInst::Predicate Pred = LHS->getPredicate(); if (Pred != RHS->getPredicate()) return nullptr; @@ -883,8 +884,8 @@ Value *InstCombiner::foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, std::swap(A, B); if (A == C && - isKnownToBeAPowerOfTwo(B, false, 0, &CxtI) && - isKnownToBeAPowerOfTwo(D, false, 0, &CxtI)) { + isKnownToBeAPowerOfTwo(B, false, 0, &Logic) && + isKnownToBeAPowerOfTwo(D, false, 0, &Logic)) { Value *Mask = Builder.CreateOr(B, D); Value *Masked = Builder.CreateAnd(A, Mask); auto NewPred = JoinedByAnd ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE; @@ -1072,9 +1073,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) && match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) && (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) { - if (UnsignedICmp->getOperand(0) != ZeroCmpOp) - UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); - auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) { if (!IsKnownNonZero(NonZero)) std::swap(NonZero, Other); @@ -1111,8 +1109,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) || !ICmpInst::isUnsigned(UnsignedPred)) return nullptr; - if (UnsignedICmp->getOperand(0) != Base) - UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred); // Base >=/> Offset && (Base - Offset) != 0 <--> Base > Offset // (no overflow and not null) @@ -1141,14 +1137,59 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp, return nullptr; } +/// Reduce logic-of-compares with equality to a constant by substituting a +/// common operand with the constant. Callers are expected to call this with +/// Cmp0/Cmp1 switched to handle logic op commutativity. 
+static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, + BinaryOperator &Logic, + InstCombiner::BuilderTy &Builder, + const SimplifyQuery &Q) { + bool IsAnd = Logic.getOpcode() == Instruction::And; + assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op"); + + // Match an equality compare with a non-poison constant as Cmp0. + ICmpInst::Predicate Pred0; + Value *X; + Constant *C; + if (!match(Cmp0, m_ICmp(Pred0, m_Value(X), m_Constant(C))) || + !isGuaranteedNotToBeUndefOrPoison(C)) + return nullptr; + if ((IsAnd && Pred0 != ICmpInst::ICMP_EQ) || + (!IsAnd && Pred0 != ICmpInst::ICMP_NE)) + return nullptr; + + // The other compare must include a common operand (X). Canonicalize the + // common operand as operand 1 (Pred1 is swapped if the common operand was + // operand 0). + Value *Y; + ICmpInst::Predicate Pred1; + if (!match(Cmp1, m_c_ICmp(Pred1, m_Value(Y), m_Deferred(X)))) + return nullptr; + + // Replace variable with constant value equivalence to remove a variable use: + // (X == C) && (Y Pred1 X) --> (X == C) && (Y Pred1 C) + // (X != C) || (Y Pred1 X) --> (X != C) || (Y Pred1 C) + // Can think of the 'or' substitution with the 'and' bool equivalent: + // A || B --> A || (!A && B) + Value *SubstituteCmp = SimplifyICmpInst(Pred1, Y, C, Q); + if (!SubstituteCmp) { + // If we need to create a new instruction, require that the old compare can + // be removed. + if (!Cmp1->hasOneUse()) + return nullptr; + SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); + } + return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction &CxtI) { - const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + BinaryOperator &And) { + const SimplifyQuery Q = SQ.getWithInstruction(&And); // Fold (!iszero(A & K1) & !iszero(A & K2)) -> (A & (K1 | K2)) == (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, And)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -1171,6 +1212,11 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder)) return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, And, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q)) + return V; + // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/false)) return V; @@ -1182,7 +1228,7 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) return V; - if (Value *V = foldSignedTruncationCheck(LHS, RHS, CxtI, Builder)) + if (Value *V = foldSignedTruncationCheck(LHS, RHS, And, Builder)) return V; if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder)) @@ -1658,7 +1704,7 @@ static bool canNarrowShiftAmt(Constant *C, unsigned BitWidth) { if (C->getType()->isVectorTy()) { // Check each element of a constant vector. 
- unsigned NumElts = C->getType()->getVectorNumElements(); + unsigned NumElts = cast<VectorType>(C->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) @@ -1802,7 +1848,17 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::Create(BinOp, NewLHS, Y); } } - + const APInt *ShiftC; + if (match(Op0, m_OneUse(m_SExt(m_AShr(m_Value(X), m_APInt(ShiftC)))))) { + unsigned Width = I.getType()->getScalarSizeInBits(); + if (*C == APInt::getLowBitsSet(Width, Width - ShiftC->getZExtValue())) { + // We are clearing high bits that were potentially set by sext+ashr: + // and (sext (ashr X, ShiftC)), C --> lshr (sext X), ShiftC + Value *Sext = Builder.CreateSExt(X, I.getType()); + Constant *ShAmtC = ConstantInt::get(I.getType(), ShiftC->zext(Width)); + return BinaryOperator::CreateLShr(Sext, ShAmtC); + } + } } if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { @@ -2020,7 +2076,7 @@ Instruction *InstCombiner::matchBSwap(BinaryOperator &Or) { LastInst->removeFromParent(); for (auto *Inst : Insts) - Worklist.Add(Inst); + Worklist.push(Inst); return LastInst; } @@ -2086,9 +2142,62 @@ static Instruction *matchRotate(Instruction &Or) { return IntrinsicInst::Create(F, { ShVal, ShVal, ShAmt }); } +/// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns. +static Instruction *matchOrConcat(Instruction &Or, + InstCombiner::BuilderTy &Builder) { + assert(Or.getOpcode() == Instruction::Or && "bswap requires an 'or'"); + Value *Op0 = Or.getOperand(0), *Op1 = Or.getOperand(1); + Type *Ty = Or.getType(); + + unsigned Width = Ty->getScalarSizeInBits(); + if ((Width & 1) != 0) + return nullptr; + unsigned HalfWidth = Width / 2; + + // Canonicalize zext (lower half) to LHS. + if (!isa<ZExtInst>(Op0)) + std::swap(Op0, Op1); + + // Find lower/upper half. + Value *LowerSrc, *ShlVal, *UpperSrc; + const APInt *C; + if (!match(Op0, m_OneUse(m_ZExt(m_Value(LowerSrc)))) || + !match(Op1, m_OneUse(m_Shl(m_Value(ShlVal), m_APInt(C)))) || + !match(ShlVal, m_OneUse(m_ZExt(m_Value(UpperSrc))))) + return nullptr; + if (*C != HalfWidth || LowerSrc->getType() != UpperSrc->getType() || + LowerSrc->getType()->getScalarSizeInBits() != HalfWidth) + return nullptr; + + auto ConcatIntrinsicCalls = [&](Intrinsic::ID id, Value *Lo, Value *Hi) { + Value *NewLower = Builder.CreateZExt(Lo, Ty); + Value *NewUpper = Builder.CreateZExt(Hi, Ty); + NewUpper = Builder.CreateShl(NewUpper, HalfWidth); + Value *BinOp = Builder.CreateOr(NewLower, NewUpper); + Function *F = Intrinsic::getDeclaration(Or.getModule(), id, Ty); + return Builder.CreateCall(F, BinOp); + }; + + // BSWAP: Push the concat down, swapping the lower/upper sources. + // concat(bswap(x),bswap(y)) -> bswap(concat(x,y)) + Value *LowerBSwap, *UpperBSwap; + if (match(LowerSrc, m_BSwap(m_Value(LowerBSwap))) && + match(UpperSrc, m_BSwap(m_Value(UpperBSwap)))) + return ConcatIntrinsicCalls(Intrinsic::bswap, UpperBSwap, LowerBSwap); + + // BITREVERSE: Push the concat down, swapping the lower/upper sources. + // concat(bitreverse(x),bitreverse(y)) -> bitreverse(concat(x,y)) + Value *LowerBRev, *UpperBRev; + if (match(LowerSrc, m_BitReverse(m_Value(LowerBRev))) && + match(UpperSrc, m_BitReverse(m_Value(UpperBRev)))) + return ConcatIntrinsicCalls(Intrinsic::bitreverse, UpperBRev, LowerBRev); + + return nullptr; +} + /// If all elements of two constant vectors are 0/-1 and inverses, return true. 
static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { - unsigned NumElts = C1->getType()->getVectorNumElements(); + unsigned NumElts = cast<VectorType>(C1->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *EltC1 = C1->getAggregateElement(i); Constant *EltC2 = C2->getAggregateElement(i); @@ -2185,12 +2294,12 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B, /// Fold (icmp)|(icmp) if possible. Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, - Instruction &CxtI) { - const SimplifyQuery Q = SQ.getWithInstruction(&CxtI); + BinaryOperator &Or) { + const SimplifyQuery Q = SQ.getWithInstruction(&Or); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI)) + if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, Or)) return V; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); @@ -2299,6 +2408,11 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Builder.CreateAdd(B, ConstantInt::getSigned(B->getType(), -1)), A); } + if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q)) + return V; + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q)) + return V; + // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n if (Value *V = simplifyRangeCheck(LHS, RHS, /*Inverted=*/true)) return V; @@ -2481,6 +2595,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Instruction *Rotate = matchRotate(I)) return Rotate; + if (Instruction *Concat = matchOrConcat(I, Builder)) + return replaceInstUsesWith(I, Concat); + Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && @@ -2729,6 +2846,32 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; + CmpInst::Predicate Pred; + Value *Mul, *Ov, *MulIsNotZero, *UMulWithOv; + // Check if the OR weakens the overflow condition for umul.with.overflow by + // treating any non-zero result as overflow. In that case, we overflow if both + // umul.with.overflow operands are != 0, as in that case the result can only + // be 0, iff the multiplication overflows. 
+ if (match(&I, + m_c_Or(m_CombineAnd(m_ExtractValue<1>(m_Value(UMulWithOv)), + m_Value(Ov)), + m_CombineAnd(m_ICmp(Pred, + m_CombineAnd(m_ExtractValue<0>( + m_Deferred(UMulWithOv)), + m_Value(Mul)), + m_ZeroInt()), + m_Value(MulIsNotZero)))) && + (Ov->hasOneUse() || (MulIsNotZero->hasOneUse() && Mul->hasOneUse())) && + Pred == CmpInst::ICMP_NE) { + Value *A, *B; + if (match(UMulWithOv, m_Intrinsic<Intrinsic::umul_with_overflow>( + m_Value(A), m_Value(B)))) { + Value *NotNullA = Builder.CreateIsNotNull(A); + Value *NotNullB = Builder.CreateIsNotNull(B); + return BinaryOperator::CreateAnd(NotNullA, NotNullB); + } + } + return nullptr; } @@ -2748,33 +2891,24 @@ static Instruction *foldXorToXor(BinaryOperator &I, // (A | B) ^ (A & B) -> A ^ B // (A | B) ^ (B & A) -> A ^ B if (match(&I, m_c_Xor(m_And(m_Value(A), m_Value(B)), - m_c_Or(m_Deferred(A), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_Or(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); // (A | ~B) ^ (~A | B) -> A ^ B // (~B | A) ^ (~A | B) -> A ^ B // (~A | B) ^ (A | ~B) -> A ^ B // (B | ~A) ^ (A | ~B) -> A ^ B if (match(&I, m_Xor(m_c_Or(m_Value(A), m_Not(m_Value(B))), - m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_Or(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); // (A & ~B) ^ (~A & B) -> A ^ B // (~B & A) ^ (~A & B) -> A ^ B // (~A & B) ^ (A & ~B) -> A ^ B // (B & ~A) ^ (A & ~B) -> A ^ B if (match(&I, m_Xor(m_c_And(m_Value(A), m_Not(m_Value(B))), - m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); - return &I; - } + m_c_And(m_Not(m_Deferred(A)), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); // For the remaining cases we need to get rid of one of the operands. if (!Op0->hasOneUse() && !Op1->hasOneUse()) @@ -2878,6 +3012,7 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator())); Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); // Replace all uses of Y (excluding the one in NotY!) with NotY. + Worklist.pushUsersToWorkList(*Y); Y->replaceUsesWithIf(NotY, [NotY](Use &U) { return U.getUser() != NotY; }); } @@ -2924,6 +3059,9 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, Constant *C; if (D->hasOneUse() && match(M, m_Constant(C))) { + // Propagating undef is unsafe. Clamp undef elements to -1. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); // Unfold. Value *LHS = Builder.CreateAnd(X, C); Value *NotC = Builder.CreateNot(C); @@ -3058,13 +3196,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // ~(C >>s Y) --> ~C >>u Y (when inverting the replicated sign bits) Constant *C; if (match(NotVal, m_AShr(m_Constant(C), m_Value(Y))) && - match(C, m_Negative())) + match(C, m_Negative())) { + // We matched a negative constant, so propagating undef is unsafe. + // Clamp undef elements to -1. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getAllOnesValue(EltTy)); return BinaryOperator::CreateLShr(ConstantExpr::getNot(C), Y); + } // ~(C >>u Y) --> ~C >>s Y (when inverting the replicated sign bits) if (match(NotVal, m_LShr(m_Constant(C), m_Value(Y))) && - match(C, m_NonNegative())) + match(C, m_NonNegative())) { + // We matched a non-negative constant, so propagating undef is unsafe. 
+ // Clamp undef elements to 0. + Type *EltTy = C->getType()->getScalarType(); + C = Constant::replaceUndefsWith(C, ConstantInt::getNullValue(EltTy)); return BinaryOperator::CreateAShr(ConstantExpr::getNot(C), Y); + } // ~(X + C) --> -(C + 1) - X if (match(Op0, m_Add(m_Value(X), m_Constant(C)))) @@ -3114,10 +3262,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && MaskedValueIsZero(X, *C, 0, &I)) { Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC); - Worklist.Add(cast<Instruction>(Op0)); - I.setOperand(0, X); - I.setOperand(1, NewC); - return &I; + return BinaryOperator::CreateXor(X, NewC); } } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index 825f4b468b0a7..ba1cf982229d7 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -124,7 +124,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { auto *SI = new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), &RMWI); SI->setAtomic(Ordering, RMWI.getSyncScopeID()); - SI->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType()))); + SI->setAlignment(DL.getABITypeAlign(RMWI.getType())); return eraseInstFromFunction(RMWI); } @@ -138,13 +138,11 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { if (RMWI.getType()->isIntegerTy() && RMWI.getOperation() != AtomicRMWInst::Or) { RMWI.setOperation(AtomicRMWInst::Or); - RMWI.setOperand(1, ConstantInt::get(RMWI.getType(), 0)); - return &RMWI; + return replaceOperand(RMWI, 1, ConstantInt::get(RMWI.getType(), 0)); } else if (RMWI.getType()->isFloatingPointTy() && RMWI.getOperation() != AtomicRMWInst::FAdd) { RMWI.setOperation(AtomicRMWInst::FAdd); - RMWI.setOperand(1, ConstantFP::getNegativeZero(RMWI.getType())); - return &RMWI; + return replaceOperand(RMWI, 1, ConstantFP::getNegativeZero(RMWI.getType())); } // Check if the required ordering is compatible with an atomic load. 
@@ -152,8 +150,8 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) { Ordering != AtomicOrdering::Monotonic) return nullptr; - LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand()); - Load->setAtomic(Ordering, RMWI.getSyncScopeID()); - Load->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType()))); + LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "", + false, DL.getABITypeAlign(RMWI.getType()), + Ordering, RMWI.getSyncScopeID()); return Load; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index f463c5fa1138a..c734c9a68fb2d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -15,12 +15,15 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumeBundleQueries.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" @@ -40,12 +43,13 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" -#include "llvm/IR/IntrinsicsX86.h" -#include "llvm/IR/IntrinsicsARM.h" #include "llvm/IR/IntrinsicsAArch64.h" -#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IntrinsicsHexagon.h" +#include "llvm/IR/IntrinsicsNVPTX.h" #include "llvm/IR/IntrinsicsPowerPC.h" +#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" @@ -114,16 +118,16 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { } Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { - unsigned DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); - unsigned CopyDstAlign = MI->getDestAlignment(); - if (CopyDstAlign < DstAlign){ + Align DstAlign = getKnownAlignment(MI->getRawDest(), DL, MI, &AC, &DT); + MaybeAlign CopyDstAlign = MI->getDestAlign(); + if (!CopyDstAlign || *CopyDstAlign < DstAlign) { MI->setDestAlignment(DstAlign); return MI; } - unsigned SrcAlign = getKnownAlignment(MI->getRawSource(), DL, MI, &AC, &DT); - unsigned CopySrcAlign = MI->getSourceAlignment(); - if (CopySrcAlign < SrcAlign) { + Align SrcAlign = getKnownAlignment(MI->getRawSource(), DL, MI, &AC, &DT); + MaybeAlign CopySrcAlign = MI->getSourceAlign(); + if (!CopySrcAlign || *CopySrcAlign < SrcAlign) { MI->setSourceAlignment(SrcAlign); return MI; } @@ -157,7 +161,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { // into libcall in CodeGen. This is not evident performance gain so disable // it now. if (isa<AtomicMemTransferInst>(MI)) - if (CopyDstAlign < Size || CopySrcAlign < Size) + if (*CopyDstAlign < Size || *CopySrcAlign < Size) return nullptr; // Use an integer load+store unless we can find something better. @@ -191,8 +195,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy); LoadInst *L = Builder.CreateLoad(IntType, Src); // Alignment from the mem intrinsic will be better, so use it. 
- L->setAlignment( - MaybeAlign(CopySrcAlign)); // FIXME: Check if we can use Align instead. + L->setAlignment(*CopySrcAlign); if (CopyMD) L->setMetadata(LLVMContext::MD_tbaa, CopyMD); MDNode *LoopMemParallelMD = @@ -205,8 +208,7 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { StoreInst *S = Builder.CreateStore(L, Dest); // Alignment from the mem intrinsic will be better, so use it. - S->setAlignment( - MaybeAlign(CopyDstAlign)); // FIXME: Check if we can use Align instead. + S->setAlignment(*CopyDstAlign); if (CopyMD) S->setMetadata(LLVMContext::MD_tbaa, CopyMD); if (LoopMemParallelMD) @@ -231,9 +233,10 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { } Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) { - const unsigned KnownAlignment = + const Align KnownAlignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT); - if (MI->getDestAlignment() < KnownAlignment) { + MaybeAlign MemSetAlign = MI->getDestAlign(); + if (!MemSetAlign || *MemSetAlign < KnownAlignment) { MI->setDestAlignment(KnownAlignment); return MI; } @@ -293,106 +296,154 @@ static Value *simplifyX86immShift(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { bool LogicalShift = false; bool ShiftLeft = false; + bool IsImm = false; switch (II.getIntrinsicID()) { default: llvm_unreachable("Unexpected intrinsic!"); - case Intrinsic::x86_sse2_psra_d: - case Intrinsic::x86_sse2_psra_w: case Intrinsic::x86_sse2_psrai_d: case Intrinsic::x86_sse2_psrai_w: - case Intrinsic::x86_avx2_psra_d: - case Intrinsic::x86_avx2_psra_w: case Intrinsic::x86_avx2_psrai_d: case Intrinsic::x86_avx2_psrai_w: - case Intrinsic::x86_avx512_psra_q_128: case Intrinsic::x86_avx512_psrai_q_128: - case Intrinsic::x86_avx512_psra_q_256: case Intrinsic::x86_avx512_psrai_q_256: - case Intrinsic::x86_avx512_psra_d_512: - case Intrinsic::x86_avx512_psra_q_512: - case Intrinsic::x86_avx512_psra_w_512: case Intrinsic::x86_avx512_psrai_d_512: case Intrinsic::x86_avx512_psrai_q_512: case Intrinsic::x86_avx512_psrai_w_512: - LogicalShift = false; ShiftLeft = false; + IsImm = true; + LLVM_FALLTHROUGH; + case Intrinsic::x86_sse2_psra_d: + case Intrinsic::x86_sse2_psra_w: + case Intrinsic::x86_avx2_psra_d: + case Intrinsic::x86_avx2_psra_w: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psra_w_512: + LogicalShift = false; + ShiftLeft = false; break; - case Intrinsic::x86_sse2_psrl_d: - case Intrinsic::x86_sse2_psrl_q: - case Intrinsic::x86_sse2_psrl_w: case Intrinsic::x86_sse2_psrli_d: case Intrinsic::x86_sse2_psrli_q: case Intrinsic::x86_sse2_psrli_w: - case Intrinsic::x86_avx2_psrl_d: - case Intrinsic::x86_avx2_psrl_q: - case Intrinsic::x86_avx2_psrl_w: case Intrinsic::x86_avx2_psrli_d: case Intrinsic::x86_avx2_psrli_q: case Intrinsic::x86_avx2_psrli_w: - case Intrinsic::x86_avx512_psrl_d_512: - case Intrinsic::x86_avx512_psrl_q_512: - case Intrinsic::x86_avx512_psrl_w_512: case Intrinsic::x86_avx512_psrli_d_512: case Intrinsic::x86_avx512_psrli_q_512: case Intrinsic::x86_avx512_psrli_w_512: - LogicalShift = true; ShiftLeft = false; + IsImm = true; + LLVM_FALLTHROUGH; + case Intrinsic::x86_sse2_psrl_d: + case Intrinsic::x86_sse2_psrl_q: + case Intrinsic::x86_sse2_psrl_w: + case Intrinsic::x86_avx2_psrl_d: + case Intrinsic::x86_avx2_psrl_q: + case Intrinsic::x86_avx2_psrl_w: + case Intrinsic::x86_avx512_psrl_d_512: + case 
Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psrl_w_512: + LogicalShift = true; + ShiftLeft = false; break; - case Intrinsic::x86_sse2_psll_d: - case Intrinsic::x86_sse2_psll_q: - case Intrinsic::x86_sse2_psll_w: case Intrinsic::x86_sse2_pslli_d: case Intrinsic::x86_sse2_pslli_q: case Intrinsic::x86_sse2_pslli_w: - case Intrinsic::x86_avx2_psll_d: - case Intrinsic::x86_avx2_psll_q: - case Intrinsic::x86_avx2_psll_w: case Intrinsic::x86_avx2_pslli_d: case Intrinsic::x86_avx2_pslli_q: case Intrinsic::x86_avx2_pslli_w: - case Intrinsic::x86_avx512_psll_d_512: - case Intrinsic::x86_avx512_psll_q_512: - case Intrinsic::x86_avx512_psll_w_512: case Intrinsic::x86_avx512_pslli_d_512: case Intrinsic::x86_avx512_pslli_q_512: case Intrinsic::x86_avx512_pslli_w_512: - LogicalShift = true; ShiftLeft = true; + IsImm = true; + LLVM_FALLTHROUGH; + case Intrinsic::x86_sse2_psll_d: + case Intrinsic::x86_sse2_psll_q: + case Intrinsic::x86_sse2_psll_w: + case Intrinsic::x86_avx2_psll_d: + case Intrinsic::x86_avx2_psll_q: + case Intrinsic::x86_avx2_psll_w: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_psll_w_512: + LogicalShift = true; + ShiftLeft = true; break; } assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); - // Simplify if count is constant. - auto Arg1 = II.getArgOperand(1); - auto CAZ = dyn_cast<ConstantAggregateZero>(Arg1); - auto CDV = dyn_cast<ConstantDataVector>(Arg1); - auto CInt = dyn_cast<ConstantInt>(Arg1); - if (!CAZ && !CDV && !CInt) - return nullptr; - - APInt Count(64, 0); - if (CDV) { - // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector - // operand to compute the shift amount. - auto VT = cast<VectorType>(CDV->getType()); - unsigned BitWidth = VT->getElementType()->getPrimitiveSizeInBits(); - assert((64 % BitWidth) == 0 && "Unexpected packed shift size"); - unsigned NumSubElts = 64 / BitWidth; - - // Concatenate the sub-elements to create the 64-bit value. - for (unsigned i = 0; i != NumSubElts; ++i) { - unsigned SubEltIdx = (NumSubElts - 1) - i; - auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); - Count <<= BitWidth; - Count |= SubElt->getValue().zextOrTrunc(64); - } - } - else if (CInt) - Count = CInt->getValue(); - auto Vec = II.getArgOperand(0); + auto Amt = II.getArgOperand(1); auto VT = cast<VectorType>(Vec->getType()); auto SVT = VT->getElementType(); + auto AmtVT = Amt->getType(); unsigned VWidth = VT->getNumElements(); unsigned BitWidth = SVT->getPrimitiveSizeInBits(); + // If the shift amount is guaranteed to be in-range we can replace it with a + // generic shift. If its guaranteed to be out of range, logical shifts combine to + // zero and arithmetic shifts are clamped to (BitWidth - 1). + if (IsImm) { + assert(AmtVT ->isIntegerTy(32) && + "Unexpected shift-by-immediate type"); + KnownBits KnownAmtBits = + llvm::computeKnownBits(Amt, II.getModule()->getDataLayout()); + if (KnownAmtBits.getMaxValue().ult(BitWidth)) { + Amt = Builder.CreateZExtOrTrunc(Amt, SVT); + Amt = Builder.CreateVectorSplat(VWidth, Amt); + return (LogicalShift ? (ShiftLeft ? 
Builder.CreateShl(Vec, Amt) + : Builder.CreateLShr(Vec, Amt)) + : Builder.CreateAShr(Vec, Amt)); + } + if (KnownAmtBits.getMinValue().uge(BitWidth)) { + if (LogicalShift) + return ConstantAggregateZero::get(VT); + Amt = ConstantInt::get(SVT, BitWidth - 1); + return Builder.CreateAShr(Vec, Builder.CreateVectorSplat(VWidth, Amt)); + } + } else { + // Ensure the first element has an in-range value and the rest of the + // elements in the bottom 64 bits are zero. + assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 && + cast<VectorType>(AmtVT)->getElementType() == SVT && + "Unexpected shift-by-scalar type"); + unsigned NumAmtElts = cast<VectorType>(AmtVT)->getNumElements(); + APInt DemandedLower = APInt::getOneBitSet(NumAmtElts, 0); + APInt DemandedUpper = APInt::getBitsSet(NumAmtElts, 1, NumAmtElts / 2); + KnownBits KnownLowerBits = llvm::computeKnownBits( + Amt, DemandedLower, II.getModule()->getDataLayout()); + KnownBits KnownUpperBits = llvm::computeKnownBits( + Amt, DemandedUpper, II.getModule()->getDataLayout()); + if (KnownLowerBits.getMaxValue().ult(BitWidth) && + (DemandedUpper.isNullValue() || KnownUpperBits.isZero())) { + SmallVector<int, 16> ZeroSplat(VWidth, 0); + Amt = Builder.CreateShuffleVector(Amt, Amt, ZeroSplat); + return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt) + : Builder.CreateLShr(Vec, Amt)) + : Builder.CreateAShr(Vec, Amt)); + } + } + + // Simplify if count is constant vector. + auto CDV = dyn_cast<ConstantDataVector>(Amt); + if (!CDV) + return nullptr; + + // SSE2/AVX2 uses all the first 64-bits of the 128-bit vector + // operand to compute the shift amount. + assert(AmtVT->isVectorTy() && AmtVT->getPrimitiveSizeInBits() == 128 && + cast<VectorType>(AmtVT)->getElementType() == SVT && + "Unexpected shift-by-scalar type"); + + // Concatenate the sub-elements to create the 64-bit value. + APInt Count(64, 0); + for (unsigned i = 0, NumSubElts = 64 / BitWidth; i != NumSubElts; ++i) { + unsigned SubEltIdx = (NumSubElts - 1) - i; + auto SubElt = cast<ConstantInt>(CDV->getElementAsConstant(SubEltIdx)); + Count <<= BitWidth; + Count |= SubElt->getValue().zextOrTrunc(64); + } + // If shift-by-zero then just return the original value. if (Count.isNullValue()) return Vec; @@ -469,17 +520,29 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, } assert((LogicalShift || !ShiftLeft) && "Only logical shifts can shift left"); - // Simplify if all shift amounts are constant/undef. - auto *CShift = dyn_cast<Constant>(II.getArgOperand(1)); - if (!CShift) - return nullptr; - auto Vec = II.getArgOperand(0); + auto Amt = II.getArgOperand(1); auto VT = cast<VectorType>(II.getType()); - auto SVT = VT->getVectorElementType(); + auto SVT = VT->getElementType(); int NumElts = VT->getNumElements(); int BitWidth = SVT->getIntegerBitWidth(); + // If the shift amount is guaranteed to be in-range we can replace it with a + // generic shift. + APInt UpperBits = + APInt::getHighBitsSet(BitWidth, BitWidth - Log2_32(BitWidth)); + if (llvm::MaskedValueIsZero(Amt, UpperBits, + II.getModule()->getDataLayout())) { + return (LogicalShift ? (ShiftLeft ? Builder.CreateShl(Vec, Amt) + : Builder.CreateLShr(Vec, Amt)) + : Builder.CreateAShr(Vec, Amt)); + } + + // Simplify if all shift amounts are constant/undef. + auto *CShift = dyn_cast<Constant>(Amt); + if (!CShift) + return nullptr; + // Collect each element's shift amount. // We also collect special cases: UNDEF = -1, OUT-OF-RANGE = BitWidth. 
bool AnyOutOfRange = false; @@ -557,10 +620,10 @@ static Value *simplifyX86pack(IntrinsicInst &II, if (isa<UndefValue>(Arg0) && isa<UndefValue>(Arg1)) return UndefValue::get(ResTy); - Type *ArgTy = Arg0->getType(); + auto *ArgTy = cast<VectorType>(Arg0->getType()); unsigned NumLanes = ResTy->getPrimitiveSizeInBits() / 128; - unsigned NumSrcElts = ArgTy->getVectorNumElements(); - assert(ResTy->getVectorNumElements() == (2 * NumSrcElts) && + unsigned NumSrcElts = ArgTy->getNumElements(); + assert(cast<VectorType>(ResTy)->getNumElements() == (2 * NumSrcElts) && "Unexpected packing types"); unsigned NumSrcEltsPerLane = NumSrcElts / NumLanes; @@ -600,7 +663,7 @@ static Value *simplifyX86pack(IntrinsicInst &II, Arg1 = Builder.CreateSelect(Builder.CreateICmpSGT(Arg1, MaxC), MaxC, Arg1); // Shuffle clamped args together at the lane level. - SmallVector<unsigned, 32> PackMask; + SmallVector<int, 32> PackMask; for (unsigned Lane = 0; Lane != NumLanes; ++Lane) { for (unsigned Elt = 0; Elt != NumSrcEltsPerLane; ++Elt) PackMask.push_back(Elt + (Lane * NumSrcEltsPerLane)); @@ -617,14 +680,14 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder) { Value *Arg = II.getArgOperand(0); Type *ResTy = II.getType(); - Type *ArgTy = Arg->getType(); // movmsk(undef) -> zero as we must ensure the upper bits are zero. if (isa<UndefValue>(Arg)) return Constant::getNullValue(ResTy); + auto *ArgTy = dyn_cast<VectorType>(Arg->getType()); // We can't easily peek through x86_mmx types. - if (!ArgTy->isVectorTy()) + if (!ArgTy) return nullptr; // Expand MOVMSK to compare/bitcast/zext: @@ -632,8 +695,8 @@ static Value *simplifyX86movmsk(const IntrinsicInst &II, // %cmp = icmp slt <16 x i8> %x, zeroinitializer // %int = bitcast <16 x i1> %cmp to i16 // %res = zext i16 %int to i32 - unsigned NumElts = ArgTy->getVectorNumElements(); - Type *IntegerVecTy = VectorType::getInteger(cast<VectorType>(ArgTy)); + unsigned NumElts = ArgTy->getNumElements(); + Type *IntegerVecTy = VectorType::getInteger(ArgTy); Type *IntegerTy = Builder.getIntNTy(NumElts); Value *Res = Builder.CreateBitCast(Arg, IntegerVecTy); @@ -697,7 +760,7 @@ static Value *simplifyX86insertps(const IntrinsicInst &II, return ZeroVector; // Initialize by passing all of the first source bits through. - uint32_t ShuffleMask[4] = { 0, 1, 2, 3 }; + int ShuffleMask[4] = {0, 1, 2, 3}; // We may replace the second operand with the zero vector. 
Value *V1 = II.getArgOperand(1); @@ -777,22 +840,19 @@ static Value *simplifyX86extrq(IntrinsicInst &II, Value *Op0, Index /= 8; Type *IntTy8 = Type::getInt8Ty(II.getContext()); - Type *IntTy32 = Type::getInt32Ty(II.getContext()); - VectorType *ShufTy = VectorType::get(IntTy8, 16); + auto *ShufTy = FixedVectorType::get(IntTy8, 16); - SmallVector<Constant *, 16> ShuffleMask; + SmallVector<int, 16> ShuffleMask; for (int i = 0; i != (int)Length; ++i) - ShuffleMask.push_back( - Constant::getIntegerValue(IntTy32, APInt(32, i + Index))); + ShuffleMask.push_back(i + Index); for (int i = Length; i != 8; ++i) - ShuffleMask.push_back( - Constant::getIntegerValue(IntTy32, APInt(32, i + 16))); + ShuffleMask.push_back(i + 16); for (int i = 8; i != 16; ++i) - ShuffleMask.push_back(UndefValue::get(IntTy32)); + ShuffleMask.push_back(-1); Value *SV = Builder.CreateShuffleVector( Builder.CreateBitCast(Op0, ShufTy), - ConstantAggregateZero::get(ShufTy), ConstantVector::get(ShuffleMask)); + ConstantAggregateZero::get(ShufTy), ShuffleMask); return Builder.CreateBitCast(SV, II.getType()); } @@ -857,23 +917,21 @@ static Value *simplifyX86insertq(IntrinsicInst &II, Value *Op0, Value *Op1, Index /= 8; Type *IntTy8 = Type::getInt8Ty(II.getContext()); - Type *IntTy32 = Type::getInt32Ty(II.getContext()); - VectorType *ShufTy = VectorType::get(IntTy8, 16); + auto *ShufTy = FixedVectorType::get(IntTy8, 16); - SmallVector<Constant *, 16> ShuffleMask; + SmallVector<int, 16> ShuffleMask; for (int i = 0; i != (int)Index; ++i) - ShuffleMask.push_back(Constant::getIntegerValue(IntTy32, APInt(32, i))); + ShuffleMask.push_back(i); for (int i = 0; i != (int)Length; ++i) - ShuffleMask.push_back( - Constant::getIntegerValue(IntTy32, APInt(32, i + 16))); + ShuffleMask.push_back(i + 16); for (int i = Index + Length; i != 8; ++i) - ShuffleMask.push_back(Constant::getIntegerValue(IntTy32, APInt(32, i))); + ShuffleMask.push_back(i); for (int i = 8; i != 16; ++i) - ShuffleMask.push_back(UndefValue::get(IntTy32)); + ShuffleMask.push_back(-1); Value *SV = Builder.CreateShuffleVector(Builder.CreateBitCast(Op0, ShufTy), Builder.CreateBitCast(Op1, ShufTy), - ConstantVector::get(ShuffleMask)); + ShuffleMask); return Builder.CreateBitCast(SV, II.getType()); } @@ -925,13 +983,12 @@ static Value *simplifyX86pshufb(const IntrinsicInst &II, return nullptr; auto *VecTy = cast<VectorType>(II.getType()); - auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned NumElts = VecTy->getNumElements(); assert((NumElts == 16 || NumElts == 32 || NumElts == 64) && "Unexpected number of elements in shuffle mask!"); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[64] = {nullptr}; + int Indexes[64]; // Each byte in the shuffle control mask forms an index to permute the // corresponding byte in the destination operand. @@ -941,7 +998,7 @@ static Value *simplifyX86pshufb(const IntrinsicInst &II, return nullptr; if (isa<UndefValue>(COp)) { - Indexes[I] = UndefValue::get(MaskEltTy); + Indexes[I] = -1; continue; } @@ -955,13 +1012,12 @@ static Value *simplifyX86pshufb(const IntrinsicInst &II, // The value of each index for the high 128-bit lane is the least // significant 4 bits of the respective shuffle control byte. Index = ((Index < 0) ? 
NumElts : Index & 0x0F) + (I & 0xF0); - Indexes[I] = ConstantInt::get(MaskEltTy, Index); + Indexes[I] = Index; } - auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); auto V1 = II.getArgOperand(0); auto V2 = Constant::getNullValue(VecTy); - return Builder.CreateShuffleVector(V1, V2, ShuffleMask); + return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts)); } /// Attempt to convert vpermilvar* to shufflevector if the mask is constant. @@ -972,14 +1028,13 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, return nullptr; auto *VecTy = cast<VectorType>(II.getType()); - auto *MaskEltTy = Type::getInt32Ty(II.getContext()); - unsigned NumElts = VecTy->getVectorNumElements(); + unsigned NumElts = VecTy->getNumElements(); bool IsPD = VecTy->getScalarType()->isDoubleTy(); unsigned NumLaneElts = IsPD ? 2 : 4; assert(NumElts == 16 || NumElts == 8 || NumElts == 4 || NumElts == 2); // Construct a shuffle mask from constant integers or UNDEFs. - Constant *Indexes[16] = {nullptr}; + int Indexes[16]; // The intrinsics only read one or two bits, clear the rest. for (unsigned I = 0; I < NumElts; ++I) { @@ -988,7 +1043,7 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, return nullptr; if (isa<UndefValue>(COp)) { - Indexes[I] = UndefValue::get(MaskEltTy); + Indexes[I] = -1; continue; } @@ -1005,13 +1060,12 @@ static Value *simplifyX86vpermilvar(const IntrinsicInst &II, // shuffle, we have to make that explicit. Index += APInt(32, (I / NumLaneElts) * NumLaneElts); - Indexes[I] = ConstantInt::get(MaskEltTy, Index); + Indexes[I] = Index.getZExtValue(); } - auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, NumElts)); auto V1 = II.getArgOperand(0); auto V2 = UndefValue::get(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, ShuffleMask); + return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, NumElts)); } /// Attempt to convert vpermd/vpermps to shufflevector if the mask is constant. @@ -1022,13 +1076,12 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, return nullptr; auto *VecTy = cast<VectorType>(II.getType()); - auto *MaskEltTy = Type::getInt32Ty(II.getContext()); unsigned Size = VecTy->getNumElements(); assert((Size == 4 || Size == 8 || Size == 16 || Size == 32 || Size == 64) && "Unexpected shuffle mask size"); // Construct a shuffle mask from constant integers or UNDEFs. 
- Constant *Indexes[64] = {nullptr}; + int Indexes[64]; for (unsigned I = 0; I < Size; ++I) { Constant *COp = V->getAggregateElement(I); @@ -1036,26 +1089,26 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, return nullptr; if (isa<UndefValue>(COp)) { - Indexes[I] = UndefValue::get(MaskEltTy); + Indexes[I] = -1; continue; } uint32_t Index = cast<ConstantInt>(COp)->getZExtValue(); Index &= Size - 1; - Indexes[I] = ConstantInt::get(MaskEltTy, Index); + Indexes[I] = Index; } - auto ShuffleMask = ConstantVector::get(makeArrayRef(Indexes, Size)); auto V1 = II.getArgOperand(0); auto V2 = UndefValue::get(VecTy); - return Builder.CreateShuffleVector(V1, V2, ShuffleMask); + return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes, Size)); } // TODO, Obvious Missing Transforms: // * Narrow width by halfs excluding zero/undef lanes Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { Value *LoadPtr = II.getArgOperand(0); - unsigned Alignment = cast<ConstantInt>(II.getArgOperand(1))->getZExtValue(); + const Align Alignment = + cast<ConstantInt>(II.getArgOperand(1))->getAlignValue(); // If the mask is all ones or undefs, this is a plain vector load of the 1st // argument. @@ -1065,9 +1118,9 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. TODO: use DT for context sensitive query - if (isDereferenceableAndAlignedPointer( - LoadPtr, II.getType(), MaybeAlign(Alignment), - II.getModule()->getDataLayout(), &II, nullptr)) { + if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment, + II.getModule()->getDataLayout(), &II, + nullptr)) { Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3)); @@ -1091,8 +1144,7 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { // If the mask is all ones, this is a plain vector store of the 1st argument. 
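The masked load/store hunks are part of the patch-wide move from raw unsigned alignments to the llvm::Align type. A small sketch of how that type behaves (illustrative only; the function name is made up):

#include "llvm/Support/Alignment.h"
using llvm::Align;
using llvm::MaybeAlign;

void alignSketch() {
  Align A(16);                       // constructor asserts a power of two
  MaybeAlign M;                      // default state: alignment unknown
  bool BigEnough = A.value() >= 16;  // convert back when an API wants an integer
  Align B = M.valueOrOne();          // unknown alignment decays to 1
  (void)BigEnough; (void)B;
}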
if (ConstMask->isAllOnesValue()) { Value *StorePtr = II.getArgOperand(1); - MaybeAlign Alignment( - cast<ConstantInt>(II.getArgOperand(2))->getZExtValue()); + Align Alignment = cast<ConstantInt>(II.getArgOperand(2))->getAlignValue(); return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment); } @@ -1100,10 +1152,8 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) { APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) { - II.setOperand(0, V); - return &II; - } + DemandedElts, UndefElts)) + return replaceOperand(II, 0, V); return nullptr; } @@ -1138,15 +1188,11 @@ Instruction *InstCombiner::simplifyMaskedScatter(IntrinsicInst &II) { APInt DemandedElts = possiblyDemandedEltsInMask(ConstMask); APInt UndefElts(DemandedElts.getBitWidth(), 0); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(0), - DemandedElts, UndefElts)) { - II.setOperand(0, V); - return &II; - } + DemandedElts, UndefElts)) + return replaceOperand(II, 0, V); if (Value *V = SimplifyDemandedVectorElts(II.getOperand(1), - DemandedElts, UndefElts)) { - II.setOperand(1, V); - return &II; - } + DemandedElts, UndefElts)) + return replaceOperand(II, 1, V); return nullptr; } @@ -1202,19 +1248,15 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { if (IsTZ) { // cttz(-x) -> cttz(x) - if (match(Op0, m_Neg(m_Value(X)))) { - II.setOperand(0, X); - return &II; - } + if (match(Op0, m_Neg(m_Value(X)))) + return IC.replaceOperand(II, 0, X); // cttz(abs(x)) -> cttz(x) // cttz(nabs(x)) -> cttz(x) Value *Y; SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; - if (SPF == SPF_ABS || SPF == SPF_NABS) { - II.setOperand(0, X); - return &II; - } + if (SPF == SPF_ABS || SPF == SPF_NABS) + return IC.replaceOperand(II, 0, X); } KnownBits Known = IC.computeKnownBits(Op0, 0, &II); @@ -1240,10 +1282,8 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { if (!Known.One.isNullValue() || isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II, &IC.getDominatorTree())) { - if (!match(II.getArgOperand(1), m_One())) { - II.setOperand(1, IC.Builder.getTrue()); - return &II; - } + if (!match(II.getArgOperand(1), m_One())) + return IC.replaceOperand(II, 1, IC.Builder.getTrue()); } // Add range metadata since known bits can't completely reflect what we know. 
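The cttz(-x) and cttz(abs(x)) folds rely on a two's-complement fact that is easy to check outside the compiler with the GCC/Clang ctz builtin (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Negation flips only the bits above the lowest set bit, so the number of
  // trailing zeros is unchanged; the same holds for |x|, which is why the
  // abs/nabs select patterns fold away as well.
  for (uint64_t X : {1ull, 40ull, 0x80ull, 0xdeadbeef000ull})
    assert(__builtin_ctzll(X) == __builtin_ctzll(0 - X));
  return 0;
}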
@@ -1264,21 +1304,39 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombiner &IC) { static Instruction *foldCtpop(IntrinsicInst &II, InstCombiner &IC) { assert(II.getIntrinsicID() == Intrinsic::ctpop && "Expected ctpop intrinsic"); + Type *Ty = II.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); Value *Op0 = II.getArgOperand(0); Value *X; + // ctpop(bitreverse(x)) -> ctpop(x) // ctpop(bswap(x)) -> ctpop(x) - if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X)))) { - II.setOperand(0, X); - return &II; + if (match(Op0, m_BitReverse(m_Value(X))) || match(Op0, m_BSwap(m_Value(X)))) + return IC.replaceOperand(II, 0, X); + + // ctpop(x | -x) -> bitwidth - cttz(x, false) + if (Op0->hasOneUse() && + match(Op0, m_c_Or(m_Value(X), m_Neg(m_Deferred(X))))) { + Function *F = + Intrinsic::getDeclaration(II.getModule(), Intrinsic::cttz, Ty); + auto *Cttz = IC.Builder.CreateCall(F, {X, IC.Builder.getFalse()}); + auto *Bw = ConstantInt::get(Ty, APInt(BitWidth, BitWidth)); + return IC.replaceInstUsesWith(II, IC.Builder.CreateSub(Bw, Cttz)); + } + + // ctpop(~x & (x - 1)) -> cttz(x, false) + if (match(Op0, + m_c_And(m_Not(m_Value(X)), m_Add(m_Deferred(X), m_AllOnes())))) { + Function *F = + Intrinsic::getDeclaration(II.getModule(), Intrinsic::cttz, Ty); + return CallInst::Create(F, {X, IC.Builder.getFalse()}); } // FIXME: Try to simplify vectors of integers. - auto *IT = dyn_cast<IntegerType>(Op0->getType()); + auto *IT = dyn_cast<IntegerType>(Ty); if (!IT) return nullptr; - unsigned BitWidth = IT->getBitWidth(); KnownBits Known(BitWidth); IC.computeKnownBits(Op0, Known, 0, &II); @@ -1330,7 +1388,7 @@ static Instruction *simplifyX86MaskedLoad(IntrinsicInst &II, InstCombiner &IC) { // The pass-through vector for an x86 masked load is a zero vector. CallInst *NewMaskedLoad = - IC.Builder.CreateMaskedLoad(PtrCast, 1, BoolMask, ZeroVec); + IC.Builder.CreateMaskedLoad(PtrCast, Align(1), BoolMask, ZeroVec); return IC.replaceInstUsesWith(II, NewMaskedLoad); } @@ -1371,7 +1429,7 @@ static bool simplifyX86MaskedStore(IntrinsicInst &II, InstCombiner &IC) { // on each element's most significant bit (the sign bit). Constant *BoolMask = getNegativeIsTrueBoolVec(ConstMask); - IC.Builder.CreateMaskedStore(Vec, PtrCast, 1, BoolMask); + IC.Builder.CreateMaskedStore(Vec, PtrCast, Align(1), BoolMask); // 'Replace uses' doesn't work for stores. Erase the original masked store. IC.eraseInstFromFunction(II); @@ -1417,7 +1475,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, if (!VecTy->getElementType()->isIntegerTy(8) || NumElts != 8) return nullptr; - uint32_t Indexes[8]; + int Indexes[8]; for (unsigned I = 0; I < NumElts; ++I) { Constant *COp = C->getAggregateElement(I); @@ -1428,15 +1486,13 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, Indexes[I] = cast<ConstantInt>(COp)->getLimitedValue(); // Make sure the mask indices are in range. - if (Indexes[I] >= NumElts) + if ((unsigned)Indexes[I] >= NumElts) return nullptr; } - auto *ShuffleMask = ConstantDataVector::get(II.getContext(), - makeArrayRef(Indexes)); auto *V1 = II.getArgOperand(0); auto *V2 = Constant::getNullValue(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, ShuffleMask); + return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes)); } /// Convert a vector load intrinsic into a simple llvm load instruction. 
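The two new ctpop folds can likewise be sanity-checked with plain integer arithmetic and the GCC/Clang builtins (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {1u, 6u, 0x80000000u, 0x00f0f000u}) {
    // x | -x sets the lowest set bit and every bit above it,
    // so ctpop(x | -x) == bitwidth - cttz(x).
    assert(__builtin_popcount(X | (0 - X)) == 32 - __builtin_ctz(X));
    // ~x & (x - 1) keeps exactly the bits below the lowest set bit,
    // so ctpop(~x & (x - 1)) == cttz(x).
    assert(__builtin_popcount(~X & (X - 1)) == __builtin_ctz(X));
  }
  return 0;
}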
@@ -1458,7 +1514,7 @@ static Value *simplifyNeonVld1(const IntrinsicInst &II, auto *BCastInst = Builder.CreateBitCast(II.getArgOperand(0), PointerType::get(II.getType(), 0)); - return Builder.CreateAlignedLoad(II.getType(), BCastInst, Alignment); + return Builder.CreateAlignedLoad(II.getType(), BCastInst, Align(Alignment)); } // Returns true iff the 2 intrinsics have the same operands, limiting the @@ -1478,24 +1534,30 @@ static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, // start/end intrinsics in between). As this handles only the most trivial // cases, tracking the nesting level is not needed: // -// call @llvm.foo.start(i1 0) ; &I // call @llvm.foo.start(i1 0) -// call @llvm.foo.end(i1 0) ; This one will not be skipped: it will be removed +// call @llvm.foo.start(i1 0) ; This one won't be skipped: it will be removed // call @llvm.foo.end(i1 0) -static bool removeTriviallyEmptyRange(IntrinsicInst &I, unsigned StartID, - unsigned EndID, InstCombiner &IC) { - assert(I.getIntrinsicID() == StartID && - "Start intrinsic does not have expected ID"); - BasicBlock::iterator BI(I), BE(I.getParent()->end()); - for (++BI; BI != BE; ++BI) { - if (auto *E = dyn_cast<IntrinsicInst>(BI)) { - if (isa<DbgInfoIntrinsic>(E) || E->getIntrinsicID() == StartID) +// call @llvm.foo.end(i1 0) ; &I +static bool removeTriviallyEmptyRange( + IntrinsicInst &EndI, InstCombiner &IC, + std::function<bool(const IntrinsicInst &)> IsStart) { + // We start from the end intrinsic and scan backwards, so that InstCombine + // has already processed (and potentially removed) all the instructions + // before the end intrinsic. + BasicBlock::reverse_iterator BI(EndI), BE(EndI.getParent()->rend()); + for (; BI != BE; ++BI) { + if (auto *I = dyn_cast<IntrinsicInst>(&*BI)) { + if (isa<DbgInfoIntrinsic>(I) || + I->getIntrinsicID() == EndI.getIntrinsicID()) + continue; + if (IsStart(*I)) { + if (haveSameOperands(EndI, *I, EndI.getNumArgOperands())) { + IC.eraseInstFromFunction(*I); + IC.eraseInstFromFunction(EndI); + return true; + } + // Skip start intrinsics that don't pair with this end intrinsic. continue; - if (E->getIntrinsicID() == EndID && - haveSameOperands(I, *E, E->getNumArgOperands())) { - IC.eraseInstFromFunction(*E); - IC.eraseInstFromFunction(I); - return true; } } break; @@ -1709,9 +1771,11 @@ static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { // intrinsic, we don't have to look up any module metadata, as // FtzRequirementTy will be FTZ_Any.) 
if (Action.FtzRequirement != FTZ_Any) { - bool FtzEnabled = - II->getFunction()->getFnAttribute("nvptx-f32ftz").getValueAsString() == - "true"; + StringRef Attr = II->getFunction() + ->getFnAttribute("denormal-fp-math-f32") + .getValueAsString(); + DenormalMode Mode = parseDenormalFPAttribute(Attr); + bool FtzEnabled = Mode.Output != DenormalMode::IEEE; if (FtzEnabled != (Action.FtzRequirement == FTZ_MustBeOn)) return nullptr; @@ -1751,13 +1815,11 @@ static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { llvm_unreachable("All SpecialCase enumerators should be handled in switch."); } -Instruction *InstCombiner::visitVAStartInst(VAStartInst &I) { - removeTriviallyEmptyRange(I, Intrinsic::vastart, Intrinsic::vaend, *this); - return nullptr; -} - -Instruction *InstCombiner::visitVACopyInst(VACopyInst &I) { - removeTriviallyEmptyRange(I, Intrinsic::vacopy, Intrinsic::vaend, *this); +Instruction *InstCombiner::visitVAEndInst(VAEndInst &I) { + removeTriviallyEmptyRange(I, *this, [](const IntrinsicInst &I) { + return I.getIntrinsicID() == Intrinsic::vastart || + I.getIntrinsicID() == Intrinsic::vacopy; + }); return nullptr; } @@ -1786,8 +1848,11 @@ Instruction *InstCombiner::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { /// instructions. For normal calls, it allows visitCallBase to do the heavy /// lifting. Instruction *InstCombiner::visitCallInst(CallInst &CI) { - if (Value *V = SimplifyCall(&CI, SQ.getWithInstruction(&CI))) - return replaceInstUsesWith(CI, V); + // Don't try to simplify calls without uses. It will not do anything useful, + // but will result in the following folds being skipped. + if (!CI.use_empty()) + if (Value *V = SimplifyCall(&CI, SQ.getWithInstruction(&CI))) + return replaceInstUsesWith(CI, V); if (isFreeCall(&CI, &TLI)) return visitFree(CI); @@ -1802,6 +1867,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(&CI); if (!II) return visitCallBase(CI); + // For atomic unordered mem intrinsics if len is not a positive or + // not a multiple of element size then behavior is undefined. + if (auto *AMI = dyn_cast<AtomicMemIntrinsic>(II)) + if (ConstantInt *NumBytes = dyn_cast<ConstantInt>(AMI->getLength())) + if (NumBytes->getSExtValue() < 0 || + (NumBytes->getZExtValue() % AMI->getElementSizeInBytes() != 0)) { + CreateNonTerminatorUnreachable(AMI); + assert(AMI->getType()->isVoidTy() && + "non void atomic unordered mem intrinsic"); + return eraseInstFromFunction(*AMI); + } + // Intrinsics cannot occur in an invoke or a callbr, so handle them here // instead of in visitCallBase. if (auto *MI = dyn_cast<AnyMemIntrinsic>(II)) { @@ -1863,9 +1940,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } - // For vector result intrinsics, use the generic demanded vector support. - if (II->getType()->isVectorTy()) { - auto VWidth = II->getType()->getVectorNumElements(); + // For fixed width vector result intrinsics, use the generic demanded vector + // support. + if (auto *IIFVTy = dyn_cast<FixedVectorType>(II->getType())) { + auto VWidth = IIFVTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) { @@ -1958,10 +2036,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Canonicalize a shift amount constant operand to modulo the bit-width. 
Constant *WidthC = ConstantInt::get(Ty, BitWidth); Constant *ModuloC = ConstantExpr::getURem(ShAmtC, WidthC); - if (ModuloC != ShAmtC) { - II->setArgOperand(2, ModuloC); - return II; - } + if (ModuloC != ShAmtC) + return replaceOperand(*II, 2, ModuloC); + assert(ConstantExpr::getICmp(ICmpInst::ICMP_UGT, WidthC, ShAmtC) == ConstantInt::getTrue(CmpInst::makeCmpResultType(Ty)) && "Shift amount expected to be modulo bitwidth"); @@ -2189,7 +2266,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { llvm_unreachable("unexpected intrinsic ID"); } Value *NewCall = Builder.CreateBinaryIntrinsic(NewIID, X, Y, II); - Instruction *FNeg = BinaryOperator::CreateFNeg(NewCall); + Instruction *FNeg = UnaryOperator::CreateFNeg(NewCall); FNeg->copyIRFlags(II); return FNeg; } @@ -2220,12 +2297,31 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { llvm_unreachable("unexpected intrinsic ID"); } Instruction *NewCall = Builder.CreateBinaryIntrinsic( - IID, X, ConstantFP::get(Arg0->getType(), Res)); - NewCall->copyIRFlags(II); + IID, X, ConstantFP::get(Arg0->getType(), Res), II); + // TODO: Conservatively intersecting FMF. If Res == C2, the transform + // was a simplification (so Arg0 and its original flags could + // propagate?) + NewCall->andIRFlags(M); return replaceInstUsesWith(*II, NewCall); } } + Value *ExtSrc0; + Value *ExtSrc1; + + // minnum (fpext x), (fpext y) -> minnum x, y + // maxnum (fpext x), (fpext y) -> maxnum x, y + if (match(II->getArgOperand(0), m_OneUse(m_FPExt(m_Value(ExtSrc0)))) && + match(II->getArgOperand(1), m_OneUse(m_FPExt(m_Value(ExtSrc1)))) && + ExtSrc0->getType() == ExtSrc1->getType()) { + Function *F = Intrinsic::getDeclaration( + II->getModule(), II->getIntrinsicID(), {ExtSrc0->getType()}); + CallInst *NewCall = Builder.CreateCall(F, { ExtSrc0, ExtSrc1 }); + NewCall->copyFastMathFlags(II); + NewCall->takeName(II); + return new FPExtInst(NewCall, II->getType()); + } + break; } case Intrinsic::fmuladd: { @@ -2260,16 +2356,16 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src1 = II->getArgOperand(1); Value *X, *Y; if (match(Src0, m_FNeg(m_Value(X))) && match(Src1, m_FNeg(m_Value(Y)))) { - II->setArgOperand(0, X); - II->setArgOperand(1, Y); + replaceOperand(*II, 0, X); + replaceOperand(*II, 1, Y); return II; } // fma fabs(x), fabs(x), z -> fma x, x, z if (match(Src0, m_FAbs(m_Value(X))) && match(Src1, m_FAbs(m_Specific(X)))) { - II->setArgOperand(0, X); - II->setArgOperand(1, X); + replaceOperand(*II, 0, X); + replaceOperand(*II, 1, X); return II; } @@ -2283,6 +2379,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return FAdd; } + // fma x, y, 0 -> fmul x, y + // This is always valid for -0.0, but requires nsz for +0.0 as + // -0.0 + 0.0 = 0.0, which would not be the same as the fmul on its own. 
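That sign-of-zero caveat shows up directly with the C library fma (standalone sketch, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
  double X = -3.0, Y = 0.0;                     // X*Y is -0.0
  assert(std::signbit(X * Y));
  assert(!std::signbit(std::fma(X, Y, +0.0)));  // -0.0 + +0.0 rounds to +0.0
  assert(std::signbit(std::fma(X, Y, -0.0)));   // -0.0 + -0.0 stays -0.0
  return 0;
}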
+ if (match(II->getArgOperand(2), m_NegZeroFP()) || + (match(II->getArgOperand(2), m_PosZeroFP()) && + II->getFastMathFlags().noSignedZeros())) + return BinaryOperator::CreateFMulFMF(Src0, Src1, II); + break; } case Intrinsic::copysign: { @@ -2307,10 +2411,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // copysign X, (copysign ?, SignArg) --> copysign X, SignArg Value *SignArg; if (match(II->getArgOperand(1), - m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg)))) { - II->setArgOperand(1, SignArg); - return II; - } + m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg)))) + return replaceOperand(*II, 1, SignArg); break; } @@ -2329,6 +2431,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ceil: case Intrinsic::floor: case Intrinsic::round: + case Intrinsic::roundeven: case Intrinsic::nearbyint: case Intrinsic::rint: case Intrinsic::trunc: { @@ -2347,8 +2450,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (match(Src, m_FNeg(m_Value(X))) || match(Src, m_FAbs(m_Value(X)))) { // cos(-x) -> cos(x) // cos(fabs(x)) -> cos(x) - II->setArgOperand(0, X); - return II; + return replaceOperand(*II, 0, X); } break; } @@ -2357,7 +2459,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (match(II->getArgOperand(0), m_OneUse(m_FNeg(m_Value(X))))) { // sin(-x) --> -sin(x) Value *NewSin = Builder.CreateUnaryIntrinsic(Intrinsic::sin, X, II); - Instruction *FNeg = BinaryOperator::CreateFNeg(NewSin); + Instruction *FNeg = UnaryOperator::CreateFNeg(NewSin); FNeg->copyFastMathFlags(II); return FNeg; } @@ -2366,11 +2468,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::ppc_altivec_lvx: case Intrinsic::ppc_altivec_lvxl: // Turn PPC lvx -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC, &DT) >= 16) { Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr); + return new LoadInst(II->getType(), Ptr, "", false, Align(16)); } break; case Intrinsic::ppc_vsx_lxvw4x: @@ -2378,17 +2480,17 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC VSX loads into normal loads. Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr, Twine(""), false, Align::None()); + return new LoadInst(II->getType(), Ptr, Twine(""), false, Align(1)); } case Intrinsic::ppc_altivec_stvx: case Intrinsic::ppc_altivec_stvxl: // Turn stvx -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC, &DT) >= 16) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr); + return new StoreInst(II->getArgOperand(0), Ptr, false, Align(16)); } break; case Intrinsic::ppc_vsx_stxvw4x: @@ -2396,14 +2498,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Turn PPC VSX stores into normal stores. 
Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr, false, Align::None()); + return new StoreInst(II->getArgOperand(0), Ptr, false, Align(1)); } case Intrinsic::ppc_qpx_qvlfs: // Turn PPC QPX qvlfs -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 16, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(16), DL, II, &AC, &DT) >= 16) { - Type *VTy = VectorType::get(Builder.getFloatTy(), - II->getType()->getVectorNumElements()); + Type *VTy = + VectorType::get(Builder.getFloatTy(), + cast<VectorType>(II->getType())->getElementCount()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(VTy)); Value *Load = Builder.CreateLoad(VTy, Ptr); @@ -2412,33 +2515,34 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::ppc_qpx_qvlfd: // Turn PPC QPX qvlfd -> load if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(0), 32, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(0), Align(32), DL, II, &AC, &DT) >= 32) { Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0), PointerType::getUnqual(II->getType())); - return new LoadInst(II->getType(), Ptr); + return new LoadInst(II->getType(), Ptr, "", false, Align(32)); } break; case Intrinsic::ppc_qpx_qvstfs: // Turn PPC QPX qvstfs -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 16, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(16), DL, II, &AC, &DT) >= 16) { - Type *VTy = VectorType::get(Builder.getFloatTy(), - II->getArgOperand(0)->getType()->getVectorNumElements()); + Type *VTy = VectorType::get( + Builder.getFloatTy(), + cast<VectorType>(II->getArgOperand(0)->getType())->getElementCount()); Value *TOp = Builder.CreateFPTrunc(II->getArgOperand(0), VTy); Type *OpPtrTy = PointerType::getUnqual(VTy); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(TOp, Ptr); + return new StoreInst(TOp, Ptr, false, Align(16)); } break; case Intrinsic::ppc_qpx_qvstfd: // Turn PPC QPX qvstfd -> store if the pointer is known aligned. - if (getOrEnforceKnownAlignment(II->getArgOperand(1), 32, DL, II, &AC, + if (getOrEnforceKnownAlignment(II->getArgOperand(1), Align(32), DL, II, &AC, &DT) >= 32) { Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType()); Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy); - return new StoreInst(II->getArgOperand(0), Ptr); + return new StoreInst(II->getArgOperand(0), Ptr, false, Align(32)); } break; @@ -2546,50 +2650,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::x86_vcvtph2ps_128: - case Intrinsic::x86_vcvtph2ps_256: { - auto Arg = II->getArgOperand(0); - auto ArgType = cast<VectorType>(Arg->getType()); - auto RetType = cast<VectorType>(II->getType()); - unsigned ArgWidth = ArgType->getNumElements(); - unsigned RetWidth = RetType->getNumElements(); - assert(RetWidth <= ArgWidth && "Unexpected input/return vector widths"); - assert(ArgType->isIntOrIntVectorTy() && - ArgType->getScalarSizeInBits() == 16 && - "CVTPH2PS input type should be 16-bit integer vector"); - assert(RetType->getScalarType()->isFloatTy() && - "CVTPH2PS output type should be 32-bit float vector"); - - // Constant folding: Convert to generic half to single conversion. 
- if (isa<ConstantAggregateZero>(Arg)) - return replaceInstUsesWith(*II, ConstantAggregateZero::get(RetType)); - - if (isa<ConstantDataVector>(Arg)) { - auto VectorHalfAsShorts = Arg; - if (RetWidth < ArgWidth) { - SmallVector<uint32_t, 8> SubVecMask; - for (unsigned i = 0; i != RetWidth; ++i) - SubVecMask.push_back((int)i); - VectorHalfAsShorts = Builder.CreateShuffleVector( - Arg, UndefValue::get(ArgType), SubVecMask); - } - - auto VectorHalfType = - VectorType::get(Type::getHalfTy(II->getContext()), RetWidth); - auto VectorHalfs = - Builder.CreateBitCast(VectorHalfAsShorts, VectorHalfType); - auto VectorFloats = Builder.CreateFPExt(VectorHalfs, RetType); - return replaceInstUsesWith(*II, VectorFloats); - } - - // We only use the lowest lanes of the argument. - if (Value *V = SimplifyDemandedVectorEltsLow(Arg, ArgWidth, RetWidth)) { - II->setArgOperand(0, V); - return II; - } - break; - } - case Intrinsic::x86_sse_cvtss2si: case Intrinsic::x86_sse_cvtss2si64: case Intrinsic::x86_sse_cvttss2si: @@ -2617,11 +2677,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // These intrinsics only demand the 0th element of their input vectors. If // we can simplify the input based on that, do so now. Value *Arg = II->getArgOperand(0); - unsigned VWidth = Arg->getType()->getVectorNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg, VWidth, 1)) { - II->setArgOperand(0, V); - return II; - } + unsigned VWidth = cast<VectorType>(Arg->getType())->getNumElements(); + if (Value *V = SimplifyDemandedVectorEltsLow(Arg, VWidth, 1)) + return replaceOperand(*II, 0, V); break; } @@ -2669,13 +2727,13 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { bool MadeChange = false; Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg0->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements(); if (Value *V = SimplifyDemandedVectorEltsLow(Arg0, VWidth, 1)) { - II->setArgOperand(0, V); + replaceOperand(*II, 0, V); MadeChange = true; } if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, 1)) { - II->setArgOperand(1, V); + replaceOperand(*II, 1, V); MadeChange = true; } if (MadeChange) @@ -2707,8 +2765,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { cast<Instruction>(Arg0)->getFastMathFlags().noInfs())) { if (Arg0IsZero) std::swap(A, B); - II->setArgOperand(0, A); - II->setArgOperand(1, B); + replaceOperand(*II, 0, A); + replaceOperand(*II, 1, B); return II; } break; @@ -2800,8 +2858,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // We don't need a select if we know the mask bit is a 1. if (!C || !C->getValue()[0]) { // Cast the mask to an i1 vector and then extract the lowest element. - auto *MaskTy = VectorType::get(Builder.getInt1Ty(), - cast<IntegerType>(Mask->getType())->getBitWidth()); + auto *MaskTy = FixedVectorType::get( + Builder.getInt1Ty(), + cast<IntegerType>(Mask->getType())->getBitWidth()); Mask = Builder.CreateBitCast(Mask, MaskTy); Mask = Builder.CreateExtractElement(Mask, (uint64_t)0); // Extract the lowest element from the passthru operand. 
@@ -2887,12 +2946,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Arg1 = II->getArgOperand(1); assert(Arg1->getType()->getPrimitiveSizeInBits() == 128 && "Unexpected packed shift size"); - unsigned VWidth = Arg1->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(Arg1->getType())->getNumElements(); - if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, VWidth / 2)) { - II->setArgOperand(1, V); - return II; - } + if (Value *V = SimplifyDemandedVectorEltsLow(Arg1, VWidth, VWidth / 2)) + return replaceOperand(*II, 1, V); break; } @@ -2956,14 +3013,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { bool MadeChange = false; Value *Arg0 = II->getArgOperand(0); Value *Arg1 = II->getArgOperand(1); - unsigned VWidth = Arg0->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(Arg0->getType())->getNumElements(); APInt UndefElts1(VWidth, 0); APInt DemandedElts1 = APInt::getSplat(VWidth, APInt(2, (Imm & 0x01) ? 2 : 1)); if (Value *V = SimplifyDemandedVectorElts(Arg0, DemandedElts1, UndefElts1)) { - II->setArgOperand(0, V); + replaceOperand(*II, 0, V); MadeChange = true; } @@ -2972,7 +3029,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { APInt(2, (Imm & 0x10) ? 2 : 1)); if (Value *V = SimplifyDemandedVectorElts(Arg1, DemandedElts2, UndefElts2)) { - II->setArgOperand(1, V); + replaceOperand(*II, 1, V); MadeChange = true; } @@ -2996,8 +3053,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse4a_extrq: { Value *Op0 = II->getArgOperand(0); Value *Op1 = II->getArgOperand(1); - unsigned VWidth0 = Op0->getType()->getVectorNumElements(); - unsigned VWidth1 = Op1->getType()->getVectorNumElements(); + unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements(); + unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements(); assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 && VWidth1 == 16 && "Unexpected operand sizes"); @@ -3019,11 +3076,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // operands and the lowest 16-bits of the second. bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { - II->setArgOperand(0, V); + replaceOperand(*II, 0, V); MadeChange = true; } if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 2)) { - II->setArgOperand(1, V); + replaceOperand(*II, 1, V); MadeChange = true; } if (MadeChange) @@ -3035,7 +3092,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // EXTRQI: Extract Length bits starting from Index. Zero pad the remaining // bits of the lower 64-bits. The upper 64-bits are undefined. Value *Op0 = II->getArgOperand(0); - unsigned VWidth = Op0->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements(); assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 && "Unexpected operand size"); @@ -3049,20 +3106,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // EXTRQI only uses the lowest 64-bits of the first 128-bit vector // operand. 
- if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) { - II->setArgOperand(0, V); - return II; - } + if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) + return replaceOperand(*II, 0, V); break; } case Intrinsic::x86_sse4a_insertq: { Value *Op0 = II->getArgOperand(0); Value *Op1 = II->getArgOperand(1); - unsigned VWidth = Op0->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(Op0->getType())->getNumElements(); assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth == 2 && - Op1->getType()->getVectorNumElements() == 2 && + cast<VectorType>(Op1->getType())->getNumElements() == 2 && "Unexpected operand size"); // See if we're dealing with constant values. @@ -3082,10 +3137,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // INSERTQ only uses the lowest 64-bits of the first 128-bit vector // operand. - if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) { - II->setArgOperand(0, V); - return II; - } + if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth, 1)) + return replaceOperand(*II, 0, V); break; } @@ -3095,8 +3148,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // undefined. Value *Op0 = II->getArgOperand(0); Value *Op1 = II->getArgOperand(1); - unsigned VWidth0 = Op0->getType()->getVectorNumElements(); - unsigned VWidth1 = Op1->getType()->getVectorNumElements(); + unsigned VWidth0 = cast<VectorType>(Op0->getType())->getNumElements(); + unsigned VWidth1 = cast<VectorType>(Op1->getType())->getNumElements(); assert(Op0->getType()->getPrimitiveSizeInBits() == 128 && Op1->getType()->getPrimitiveSizeInBits() == 128 && VWidth0 == 2 && VWidth1 == 2 && "Unexpected operand sizes"); @@ -3117,11 +3170,11 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // operands. bool MadeChange = false; if (Value *V = SimplifyDemandedVectorEltsLow(Op0, VWidth0, 1)) { - II->setArgOperand(0, V); + replaceOperand(*II, 0, V); MadeChange = true; } if (Value *V = SimplifyDemandedVectorEltsLow(Op1, VWidth1, 1)) { - II->setArgOperand(1, V); + replaceOperand(*II, 1, V); MadeChange = true; } if (MadeChange) @@ -3163,8 +3216,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getType()->getPrimitiveSizeInBits() && "Not expecting mask and operands with different sizes"); - unsigned NumMaskElts = Mask->getType()->getVectorNumElements(); - unsigned NumOperandElts = II->getType()->getVectorNumElements(); + unsigned NumMaskElts = + cast<VectorType>(Mask->getType())->getNumElements(); + unsigned NumOperandElts = + cast<VectorType>(II->getType())->getNumElements(); if (NumMaskElts == NumOperandElts) return SelectInst::Create(BoolVec, Op1, Op0); @@ -3255,7 +3310,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // the permutation mask with respect to 31 and reverse the order of // V1 and V2. if (Constant *Mask = dyn_cast<Constant>(II->getArgOperand(2))) { - assert(Mask->getType()->getVectorNumElements() == 16 && + assert(cast<VectorType>(Mask->getType())->getNumElements() == 16 && "Bad type for intrinsic!"); // Check that all of the elements are integer constants or undefs. 
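The mask-to-select rewrite above follows the x86 blendv semantics: a result lane takes the second operand when the sign bit of the corresponding mask lane is set, otherwise the first. A scalar model of a single lane (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

static int8_t blendvLane(int8_t Op0, int8_t Op1, int8_t MaskLane) {
  return MaskLane < 0 ? Op1 : Op0;  // "negative is true" selects Op1
}

int main() {
  assert(blendvLane(10, 20, int8_t(-1)) == 20);  // sign bit set
  assert(blendvLane(10, 20, int8_t(127)) == 10); // sign bit clear
  return 0;
}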
@@ -3307,9 +3362,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; case Intrinsic::arm_neon_vld1: { - unsigned MemAlign = getKnownAlignment(II->getArgOperand(0), - DL, II, &AC, &DT); - if (Value *V = simplifyNeonVld1(*II, MemAlign, Builder)) + Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); + if (Value *V = simplifyNeonVld1(*II, MemAlign.value(), Builder)) return replaceInstUsesWith(*II, V); break; } @@ -3327,16 +3381,14 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::arm_neon_vst2lane: case Intrinsic::arm_neon_vst3lane: case Intrinsic::arm_neon_vst4lane: { - unsigned MemAlign = - getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); + Align MemAlign = getKnownAlignment(II->getArgOperand(0), DL, II, &AC, &DT); unsigned AlignArg = II->getNumArgOperands() - 1; - ConstantInt *IntrAlign = dyn_cast<ConstantInt>(II->getArgOperand(AlignArg)); - if (IntrAlign && IntrAlign->getZExtValue() < MemAlign) { - II->setArgOperand(AlignArg, - ConstantInt::get(Type::getInt32Ty(II->getContext()), - MemAlign, false)); - return II; - } + Value *AlignArgOp = II->getArgOperand(AlignArg); + MaybeAlign Align = cast<ConstantInt>(AlignArgOp)->getMaybeAlignValue(); + if (Align && *Align < MemAlign) + return replaceOperand(*II, AlignArg, + ConstantInt::get(Type::getInt32Ty(II->getContext()), + MemAlign.value(), false)); break; } @@ -3395,8 +3447,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Data, *Key; if (match(KeyArg, m_ZeroInt()) && match(DataArg, m_Xor(m_Value(Data), m_Value(Key)))) { - II->setArgOperand(0, Data); - II->setArgOperand(1, Key); + replaceOperand(*II, 0, Data); + replaceOperand(*II, 1, Key); return II; } break; @@ -3415,7 +3467,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (auto *CI = dyn_cast<ConstantInt>(XorMask)) { if (CI->getValue().trunc(16).isAllOnesValue()) { auto TrueVector = Builder.CreateVectorSplat( - II->getType()->getVectorNumElements(), Builder.getTrue()); + cast<VectorType>(II->getType())->getNumElements(), + Builder.getTrue()); return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector); } } @@ -3459,18 +3512,25 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src = II->getArgOperand(0); // TODO: Move to ConstantFolding/InstSimplify? - if (isa<UndefValue>(Src)) - return replaceInstUsesWith(CI, Src); + if (isa<UndefValue>(Src)) { + Type *Ty = II->getType(); + auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); + return replaceInstUsesWith(CI, QNaN); + } + + if (II->isStrictFP()) + break; if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { const APFloat &ArgVal = C->getValueAPF(); APFloat Val(ArgVal.getSemantics(), 1); - APFloat::opStatus Status = Val.divide(ArgVal, - APFloat::rmNearestTiesToEven); - // Only do this if it was exact and therefore not dependent on the - // rounding mode. - if (Status == APFloat::opOK) - return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); + Val.divide(ArgVal, APFloat::rmNearestTiesToEven); + + // This is more precise than the instruction may give. + // + // TODO: The instruction always flushes denormal results (except for f16), + // should this also? + return replaceInstUsesWith(CI, ConstantFP::get(II->getContext(), Val)); } break; @@ -3479,8 +3539,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *Src = II->getArgOperand(0); // TODO: Move to ConstantFolding/InstSimplify? 
- if (isa<UndefValue>(Src)) - return replaceInstUsesWith(CI, Src); + if (isa<UndefValue>(Src)) { + Type *Ty = II->getType(); + auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); + return replaceInstUsesWith(CI, QNaN); + } + break; } case Intrinsic::amdgcn_frexp_mant: @@ -3563,11 +3627,9 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other - if (((Mask & S_NAN) || (Mask & Q_NAN)) && isKnownNeverNaN(Src0, &TLI)) { - II->setArgOperand(1, ConstantInt::get(Src1->getType(), - Mask & ~(S_NAN | Q_NAN))); - return II; - } + if (((Mask & S_NAN) || (Mask & Q_NAN)) && isKnownNeverNaN(Src0, &TLI)) + return replaceOperand(*II, 1, ConstantInt::get(Src1->getType(), + Mask & ~(S_NAN | Q_NAN))); const ConstantFP *CVal = dyn_cast<ConstantFP>(Src0); if (!CVal) { @@ -3657,23 +3719,19 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if ((Width & (IntSize - 1)) == 0) return replaceInstUsesWith(*II, ConstantInt::getNullValue(Ty)); - if (Width >= IntSize) { - // Hardware ignores high bits, so remove those. - II->setArgOperand(2, ConstantInt::get(CWidth->getType(), - Width & (IntSize - 1))); - return II; - } + // Hardware ignores high bits, so remove those. + if (Width >= IntSize) + return replaceOperand(*II, 2, ConstantInt::get(CWidth->getType(), + Width & (IntSize - 1))); } unsigned Offset; ConstantInt *COffset = dyn_cast<ConstantInt>(II->getArgOperand(1)); if (COffset) { Offset = COffset->getZExtValue(); - if (Offset >= IntSize) { - II->setArgOperand(1, ConstantInt::get(COffset->getType(), - Offset & (IntSize - 1))); - return II; - } + if (Offset >= IntSize) + return replaceOperand(*II, 1, ConstantInt::get(COffset->getType(), + Offset & (IntSize - 1))); } bool Signed = IID == Intrinsic::amdgcn_sbfe; @@ -3716,7 +3774,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { (IsCompr && ((EnBits & (0x3 << (2 * I))) == 0))) { Value *Src = II->getArgOperand(I + 2); if (!isa<UndefValue>(Src)) { - II->setArgOperand(I + 2, UndefValue::get(Src->getType())); + replaceOperand(*II, I + 2, UndefValue::get(Src->getType())); Changed = true; } } @@ -3855,8 +3913,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { ((match(Src1, m_One()) && match(Src0, m_ZExt(m_Value(ExtSrc)))) || (match(Src1, m_AllOnes()) && match(Src0, m_SExt(m_Value(ExtSrc))))) && ExtSrc->getType()->isIntegerTy(1)) { - II->setArgOperand(1, ConstantInt::getNullValue(Src1->getType())); - II->setArgOperand(2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE)); + replaceOperand(*II, 1, ConstantInt::getNullValue(Src1->getType())); + replaceOperand(*II, 2, ConstantInt::get(CC->getType(), CmpInst::ICMP_NE)); return II; } @@ -3928,6 +3986,35 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_ballot: { + if (auto *Src = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + if (Src->isZero()) { + // amdgcn.ballot(i1 0) is zero. + return replaceInstUsesWith(*II, Constant::getNullValue(II->getType())); + } + + if (Src->isOne()) { + // amdgcn.ballot(i1 1) is exec. 
+ const char *RegName = "exec"; + if (II->getType()->isIntegerTy(32)) + RegName = "exec_lo"; + else if (!II->getType()->isIntegerTy(64)) + break; + + Function *NewF = Intrinsic::getDeclaration( + II->getModule(), Intrinsic::read_register, II->getType()); + Metadata *MDArgs[] = {MDString::get(II->getContext(), RegName)}; + MDNode *MD = MDNode::get(II->getContext(), MDArgs); + Value *Args[] = {MetadataAsValue::get(II->getContext(), MD)}; + CallInst *NewCall = Builder.CreateCall(NewF, Args); + NewCall->addAttribute(AttributeList::FunctionIndex, + Attribute::Convergent); + NewCall->takeName(II); + return replaceInstUsesWith(*II, NewCall); + } + } + break; + } case Intrinsic::amdgcn_wqm_vote: { // wqm_vote is identity when the argument is constant. if (!isa<Constant>(II->getArgOperand(0))) @@ -3956,8 +4043,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; // If bound_ctrl = 1, row mask = bank mask = 0xf we can omit old value. - II->setOperand(0, UndefValue::get(Old->getType())); - return II; + return replaceOperand(*II, 0, UndefValue::get(Old->getType())); + } + case Intrinsic::amdgcn_permlane16: + case Intrinsic::amdgcn_permlanex16: { + // Discard vdst_in if it's not going to be read. + Value *VDstIn = II->getArgOperand(0); + if (isa<UndefValue>(VDstIn)) + break; + + ConstantInt *FetchInvalid = cast<ConstantInt>(II->getArgOperand(4)); + ConstantInt *BoundCtrl = cast<ConstantInt>(II->getArgOperand(5)); + if (!FetchInvalid->getZExtValue() && !BoundCtrl->getZExtValue()) + break; + + return replaceOperand(*II, 0, UndefValue::get(VDstIn->getType())); } case Intrinsic::amdgcn_readfirstlane: case Intrinsic::amdgcn_readlane: { @@ -3990,6 +4090,71 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_ldexp: { + // FIXME: This doesn't introduce new instructions and belongs in + // InstructionSimplify. + Type *Ty = II->getType(); + Value *Op0 = II->getArgOperand(0); + Value *Op1 = II->getArgOperand(1); + + // Folding undef to qnan is safe regardless of the FP mode. + if (isa<UndefValue>(Op0)) { + auto *QNaN = ConstantFP::get(Ty, APFloat::getQNaN(Ty->getFltSemantics())); + return replaceInstUsesWith(*II, QNaN); + } + + const APFloat *C = nullptr; + match(Op0, m_APFloat(C)); + + // FIXME: Should flush denorms depending on FP mode, but that's ignored + // everywhere else. + // + // These cases should be safe, even with strictfp. + // ldexp(0.0, x) -> 0.0 + // ldexp(-0.0, x) -> -0.0 + // ldexp(inf, x) -> inf + // ldexp(-inf, x) -> -inf + if (C && (C->isZero() || C->isInfinity())) + return replaceInstUsesWith(*II, Op0); + + // With strictfp, be more careful about possibly needing to flush denormals + // or not, and snan behavior depends on ieee_mode. + if (II->isStrictFP()) + break; + + if (C && C->isNaN()) { + // FIXME: We just need to make the nan quiet here, but that's unavailable + // on APFloat, only IEEEfloat + auto *Quieted = ConstantFP::get( + Ty, scalbn(*C, 0, APFloat::rmNearestTiesToEven)); + return replaceInstUsesWith(*II, Quieted); + } + + // ldexp(x, 0) -> x + // ldexp(x, undef) -> x + if (isa<UndefValue>(Op1) || match(Op1, m_ZeroInt())) + return replaceInstUsesWith(*II, Op0); + + break; + } + case Intrinsic::hexagon_V6_vandvrt: + case Intrinsic::hexagon_V6_vandvrt_128B: { + // Simplify Q -> V -> Q conversion. 
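The amdgcn_ldexp special cases above mirror the ordinary ldexp identities (standalone sketch using the C library ldexp, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
  assert(std::ldexp(0.0, 42) == 0.0);            // zero is a fixed point
  assert(std::signbit(std::ldexp(-0.0, 42)));    // and keeps its sign
  assert(std::isinf(std::ldexp(INFINITY, -7)));  // so is infinity
  assert(std::ldexp(1.5, 0) == 1.5);             // exponent 0 is the identity
  return 0;
}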
+ if (auto Op0 = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { + Intrinsic::ID ID0 = Op0->getIntrinsicID(); + if (ID0 != Intrinsic::hexagon_V6_vandqrt && + ID0 != Intrinsic::hexagon_V6_vandqrt_128B) + break; + Value *Bytes = Op0->getArgOperand(1), *Mask = II->getArgOperand(1); + uint64_t Bytes1 = computeKnownBits(Bytes, 0, Op0).One.getZExtValue(); + uint64_t Mask1 = computeKnownBits(Mask, 0, II).One.getZExtValue(); + // Check if every byte has common bits in Bytes and Mask. + uint64_t C = Bytes1 & Mask1; + if ((C & 0xFF) && (C & 0xFF00) && (C & 0xFF0000) && (C & 0xFF000000)) + return replaceInstUsesWith(*II, Op0->getArgOperand(0)); + } + break; + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -4040,7 +4205,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return eraseInstFromFunction(CI); break; } - case Intrinsic::lifetime_start: + case Intrinsic::lifetime_end: // Asan needs to poison memory to detect invalid access which is possible // even for empty lifetime range. if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || @@ -4048,34 +4213,41 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) break; - if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, - Intrinsic::lifetime_end, *this)) + if (removeTriviallyEmptyRange(*II, *this, [](const IntrinsicInst &I) { + return I.getIntrinsicID() == Intrinsic::lifetime_start; + })) return nullptr; break; case Intrinsic::assume: { Value *IIOperand = II->getArgOperand(0); + SmallVector<OperandBundleDef, 4> OpBundles; + II->getOperandBundlesAsDefs(OpBundles); + bool HasOpBundles = !OpBundles.empty(); // Remove an assume if it is followed by an identical assume. // TODO: Do we need this? Unless there are conflicting assumptions, the // computeKnownBits(IIOperand) below here eliminates redundant assumes. Instruction *Next = II->getNextNonDebugInstruction(); - if (match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand)))) + if (HasOpBundles && + match(Next, m_Intrinsic<Intrinsic::assume>(m_Specific(IIOperand))) && + !cast<IntrinsicInst>(Next)->hasOperandBundles()) return eraseInstFromFunction(CI); // Canonicalize assume(a && b) -> assume(a); assume(b); // Note: New assumption intrinsics created here are registered by // the InstCombineIRInserter object. 
FunctionType *AssumeIntrinsicTy = II->getFunctionType(); - Value *AssumeIntrinsic = II->getCalledValue(); + Value *AssumeIntrinsic = II->getCalledOperand(); Value *A, *B; if (match(IIOperand, m_And(m_Value(A), m_Value(B)))) { - Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, II->getName()); + Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, A, OpBundles, + II->getName()); Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, B, II->getName()); return eraseInstFromFunction(*II); } // assume(!(a || b)) -> assume(!a); assume(!b); if (match(IIOperand, m_Not(m_Or(m_Value(A), m_Value(B))))) { Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, - Builder.CreateNot(A), II->getName()); + Builder.CreateNot(A), OpBundles, II->getName()); Builder.CreateCall(AssumeIntrinsicTy, AssumeIntrinsic, Builder.CreateNot(B), II->getName()); return eraseInstFromFunction(*II); @@ -4091,7 +4263,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { isValidAssumeForContext(II, LHS, &DT)) { MDNode *MD = MDNode::get(II->getContext(), None); LHS->setMetadata(LLVMContext::MD_nonnull, MD); - return eraseInstFromFunction(*II); + if (!HasOpBundles) + return eraseInstFromFunction(*II); // TODO: apply nonnull return attributes to calls and invokes // TODO: apply range metadata for range check patterns? @@ -4101,7 +4274,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // then this one is redundant, and should be removed. KnownBits Known(1); computeKnownBits(IIOperand, Known, 0, II); - if (Known.isAllOnes()) + if (Known.isAllOnes() && isAssumeWithEmptyBundle(*II)) return eraseInstFromFunction(*II); // Update the cache of affected values for this assumption (we might be @@ -4117,10 +4290,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (GCR.getBasePtr() == GCR.getDerivedPtr() && GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) { auto *OpIntTy = GCR.getOperand(2)->getType(); - II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex())); - return II; + return replaceOperand(*II, 2, + ConstantInt::get(OpIntTy, GCR.getBasePtrIndex())); } - + // Translate facts known about a pointer before relocating into // facts about the relocate value, while being careful to // preserve relocation semantics. @@ -4187,7 +4360,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { MoveI = MoveI->getNextNonDebugInstruction(); Temp->moveBefore(II); } - II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); + replaceOperand(*II, 0, Builder.CreateAnd(CurrCond, NextCond)); } eraseInstFromFunction(*NextInst); return II; @@ -4232,13 +4405,14 @@ static bool isSafeToEliminateVarargsCast(const CallBase &Call, // TODO: This is probably something which should be expanded to all // intrinsics since the entire point of intrinsics is that // they are understandable by the optimizer. - if (isStatepoint(&Call) || isGCRelocate(&Call) || isGCResult(&Call)) + if (isa<GCStatepointInst>(Call) || isa<GCRelocateInst>(Call) || + isa<GCResultInst>(Call)) return false; // The size of ByVal or InAlloca arguments is derived from the type, so we // can't change to a type with a different size. If the size were // passed explicitly we could avoid this check. 
- if (!Call.isByValOrInAllocaArgument(ix)) + if (!Call.isPassPointeeByValueArgument(ix)) return true; Type* SrcTy = @@ -4264,7 +4438,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { }; LibCallSimplifier Simplifier(DL, &TLI, ORE, BFI, PSI, InstCombineRAUW, InstCombineErase); - if (Value *With = Simplifier.optimizeCall(CI)) { + if (Value *With = Simplifier.optimizeCall(CI, Builder)) { ++NumSimplified; return CI->use_empty() ? CI : replaceInstUsesWith(*CI, With); } @@ -4353,7 +4527,8 @@ static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0)); ConstantInt *Op1C = (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1)); - // Bail out if the allocation size is zero. + // Bail out if the allocation size is zero (or an invalid alignment of zero + // with aligned_alloc). if ((Op0C && Op0C->isNullValue()) || (Op1C && Op1C->isNullValue())) return; @@ -4366,6 +4541,18 @@ static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { Call.addAttribute(AttributeList::ReturnIndex, Attribute::getWithDereferenceableOrNullBytes( Call.getContext(), Op0C->getZExtValue())); + } else if (isAlignedAllocLikeFn(&Call, TLI) && Op1C) { + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); + // Add alignment attribute if alignment is a power of two constant. + if (Op0C && Op0C->getValue().ult(llvm::Value::MaximumAlignment)) { + uint64_t AlignmentVal = Op0C->getZExtValue(); + if (llvm::isPowerOf2_64(AlignmentVal)) + Call.addAttribute(AttributeList::ReturnIndex, + Attribute::getWithAlignment(Call.getContext(), + Align(AlignmentVal))); + } } else if (isReallocLikeFn(&Call, TLI) && Op1C) { Call.addAttribute(AttributeList::ReturnIndex, Attribute::getWithDereferenceableOrNullBytes( @@ -4430,7 +4617,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { // If the callee is a pointer to a function, attempt to move any casts to the // arguments of the call/callbr/invoke. - Value *Callee = Call.getCalledValue(); + Value *Callee = Call.getCalledOperand(); if (!isa<Function>(Callee) && transformConstExprCastCall(Call)) return nullptr; @@ -4500,7 +4687,7 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { I != E; ++I, ++ix) { CastInst *CI = dyn_cast<CastInst>(*I); if (CI && isSafeToEliminateVarargsCast(Call, DL, CI, ix)) { - *I = CI->getOperand(0); + replaceUse(*I, CI->getOperand(0)); // Update the byval type to match the argument type. if (Call.isByValArgument(ix)) { @@ -4531,6 +4718,15 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { if (I) return eraseInstFromFunction(*I); } + if (!Call.use_empty() && !Call.isMustTailCall()) + if (Value *ReturnedArg = Call.getReturnedArgOperand()) { + Type *CallTy = Call.getType(); + Type *RetArgTy = ReturnedArg->getType(); + if (RetArgTy->canLosslesslyBitCastTo(CallTy)) + return replaceInstUsesWith( + Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy)); + } + if (isAllocLikeFn(&Call, &TLI)) return visitAllocSite(Call); @@ -4540,7 +4736,8 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) { /// If the callee is a constexpr cast of a function, attempt to move the cast to /// the arguments of the call/callbr/invoke. 
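The aligned_alloc annotation above only attaches an align attribute when the alignment operand is a power-of-two constant; llvm::isPowerOf2_64 is the usual bit trick for that test (standalone sketch, not part of the patch):

#include <cassert>
#include <cstdint>

static bool isPow2(uint64_t V) {
  // A power of two has exactly one bit set, so clearing the lowest set bit
  // leaves zero; zero itself is excluded.
  return V && (V & (V - 1)) == 0;
}

int main() {
  assert(isPow2(16) && !isPow2(24) && !isPow2(0));
  return 0;
}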
bool InstCombiner::transformConstExprCastCall(CallBase &Call) { - auto *Callee = dyn_cast<Function>(Call.getCalledValue()->stripPointerCasts()); + auto *Callee = + dyn_cast<Function>(Call.getCalledOperand()->stripPointerCasts()); if (!Callee) return false; @@ -4618,6 +4815,7 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { // // Similarly, avoid folding away bitcasts of byval calls. if (Callee->getAttributes().hasAttrSomewhere(Attribute::InAlloca) || + Callee->getAttributes().hasAttrSomewhere(Attribute::Preallocated) || Callee->getAttributes().hasAttrSomewhere(Attribute::ByVal)) return false; @@ -4658,7 +4856,7 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { // If the callee is just a declaration, don't change the varargsness of the // call. We don't want to introduce a varargs call where one doesn't // already exist. - PointerType *APTy = cast<PointerType>(Call.getCalledValue()->getType()); + PointerType *APTy = cast<PointerType>(Call.getCalledOperand()->getType()); if (FT->isVarArg()!=cast<FunctionType>(APTy->getElementType())->isVarArg()) return false; @@ -4774,11 +4972,8 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { NewCall->setCallingConv(Call.getCallingConv()); NewCall->setAttributes(NewCallerPAL); - // Preserve the weight metadata for the new call instruction. The metadata - // is used by SamplePGO to check callsite's hotness. - uint64_t W; - if (Caller->extractProfTotalWeight(W)) - NewCall->setProfWeight(W); + // Preserve prof metadata if any. + NewCall->copyMetadata(*Caller, {LLVMContext::MD_prof}); // Insert a cast of the return type as necessary. Instruction *NC = NewCall; @@ -4800,7 +4995,7 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { // Otherwise, it's a call, just insert cast right after the call. InsertNewInstBefore(NC, *Caller); } - Worklist.AddUsersToWorkList(*Caller); + Worklist.pushUsersToWorkList(*Caller); } else { NV = UndefValue::get(Caller->getType()); } @@ -4826,7 +5021,7 @@ bool InstCombiner::transformConstExprCastCall(CallBase &Call) { Instruction * InstCombiner::transformCallThroughTrampoline(CallBase &Call, IntrinsicInst &Tramp) { - Value *Callee = Call.getCalledValue(); + Value *Callee = Call.getCalledOperand(); Type *CalleeTy = Callee->getType(); FunctionType *FTy = Call.getFunctionType(); AttributeList Attrs = Call.getAttributes(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 71b7f279e5fa5..3639edb5df4d1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -85,16 +85,16 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, AllocaInst &AI) { PointerType *PTy = cast<PointerType>(CI.getType()); - BuilderTy AllocaBuilder(Builder); - AllocaBuilder.SetInsertPoint(&AI); + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(&AI); // Get the type really allocated and the type casted to. 
Type *AllocElTy = AI.getAllocatedType(); Type *CastElTy = PTy->getElementType(); if (!AllocElTy->isSized() || !CastElTy->isSized()) return nullptr; - unsigned AllocElTyAlign = DL.getABITypeAlignment(AllocElTy); - unsigned CastElTyAlign = DL.getABITypeAlignment(CastElTy); + Align AllocElTyAlign = DL.getABITypeAlign(AllocElTy); + Align CastElTyAlign = DL.getABITypeAlign(CastElTy); if (CastElTyAlign < AllocElTyAlign) return nullptr; // If the allocation has multiple uses, only promote it if we are strictly @@ -131,17 +131,17 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, } else { Amt = ConstantInt::get(AI.getArraySize()->getType(), Scale); // Insert before the alloca, not before the cast. - Amt = AllocaBuilder.CreateMul(Amt, NumElements); + Amt = Builder.CreateMul(Amt, NumElements); } if (uint64_t Offset = (AllocElTySize*ArrayOffset)/CastElTySize) { Value *Off = ConstantInt::get(AI.getArraySize()->getType(), Offset, true); - Amt = AllocaBuilder.CreateAdd(Amt, Off); + Amt = Builder.CreateAdd(Amt, Off); } - AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt); - New->setAlignment(MaybeAlign(AI.getAlignment())); + AllocaInst *New = Builder.CreateAlloca(CastElTy, Amt); + New->setAlignment(AI.getAlign()); New->takeName(&AI); New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); @@ -151,8 +151,9 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI, if (!AI.hasOneUse()) { // New is the allocation instruction, pointer typed. AI is the original // allocation instruction, also pointer typed. Thus, cast to use is BitCast. - Value *NewCast = AllocaBuilder.CreateBitCast(New, AI.getType(), "tmpcast"); + Value *NewCast = Builder.CreateBitCast(New, AI.getType(), "tmpcast"); replaceInstUsesWith(AI, NewCast); + eraseInstFromFunction(AI); } return replaceInstUsesWith(CI, New); } @@ -164,9 +165,7 @@ Value *InstCombiner::EvaluateInDifferentType(Value *V, Type *Ty, if (Constant *C = dyn_cast<Constant>(V)) { C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); // If we got a constantexpr back, try to simplify it with DL info. - if (Constant *FoldedC = ConstantFoldConstant(C, DL, &TLI)) - C = FoldedC; - return C; + return ConstantFoldConstant(C, DL, &TLI); } // Otherwise, it must be an instruction. @@ -276,16 +275,20 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { } if (auto *Sel = dyn_cast<SelectInst>(Src)) { - // We are casting a select. Try to fold the cast into the select, but only - // if the select does not have a compare instruction with matching operand - // types. Creating a select with operands that are different sizes than its + // We are casting a select. Try to fold the cast into the select if the + // select does not have a compare instruction with matching operand types + // or the select is likely better done in a narrow type. + // Creating a select with operands that are different sizes than its // condition may inhibit other folds and lead to worse codegen. auto *Cmp = dyn_cast<CmpInst>(Sel->getCondition()); - if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType()) + if (!Cmp || Cmp->getOperand(0)->getType() != Sel->getType() || + (CI.getOpcode() == Instruction::Trunc && + shouldChangeType(CI.getSrcTy(), CI.getType()))) { if (Instruction *NV = FoldOpIntoSelect(CI, Sel)) { replaceAllDbgUsesWith(*Sel, *NV, CI, DT); return NV; } + } } // If we are casting a PHI, then fold the cast into the PHI. 
@@ -293,7 +296,7 @@ Instruction *InstCombiner::commonCastTransforms(CastInst &CI) { // Don't do this if it would create a PHI node with an illegal type from a // legal type. if (!Src->getType()->isIntegerTy() || !CI.getType()->isIntegerTy() || - shouldChangeType(CI.getType(), Src->getType())) + shouldChangeType(CI.getSrcTy(), CI.getType())) if (Instruction *NV = foldOpIntoPhi(CI, PN)) return NV; } @@ -374,29 +377,31 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, break; } case Instruction::Shl: { - // If we are truncating the result of this SHL, and if it's a shift of a - // constant amount, we can always perform a SHL in a smaller type. - const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + // If we are truncating the result of this SHL, and if it's a shift of an + // inrange amount, we can always perform a SHL in a smaller type. + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + if (AmtKnownBits.getMaxValue().ult(BitWidth)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::LShr: { // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. - const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth && - IC.MaskedValueIsZero(I->getOperand(0), - APInt::getBitsSetFrom(OrigBitWidth, BitWidth), 0, CxtI)) { - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + // TODO: It is enough to check that the bits we would be shifting in are + // zero - use AmtKnownBits.getMaxValue(). + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + APInt ShiftedBits = APInt::getBitsSetFrom(OrigBitWidth, BitWidth); + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + IC.MaskedValueIsZero(I->getOperand(0), ShiftedBits, 0, CxtI)) { + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); } break; } @@ -406,15 +411,15 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, // original type and the sign bit of the truncate type are similar. // TODO: It is enough to check that the bits we would be shifting in are // similar to sign bit of the truncate type. 
- const APInt *Amt; - if (match(I->getOperand(1), m_APInt(Amt))) { - uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); - uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (Amt->getLimitedValue(BitWidth) < BitWidth && - OrigBitWidth - BitWidth < - IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) - return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); - } + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + KnownBits AmtKnownBits = + llvm::computeKnownBits(I->getOperand(1), IC.getDataLayout()); + unsigned ShiftedBits = OrigBitWidth - BitWidth; + if (AmtKnownBits.getMaxValue().ult(BitWidth) && + ShiftedBits < IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI) && + canEvaluateTruncated(I->getOperand(1), Ty, IC, CxtI); break; } case Instruction::Trunc: @@ -480,7 +485,7 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { // bitcast it to a vector type that we can extract from. unsigned NumVecElts = VecWidth / DestWidth; if (VecType->getElementType() != DestType) { - VecType = VectorType::get(DestType, NumVecElts); + VecType = FixedVectorType::get(DestType, NumVecElts); VecInput = IC.Builder.CreateBitCast(VecInput, VecType, "bc"); } @@ -639,12 +644,12 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && isa<UndefValue>(Shuf->getOperand(1)) && - Shuf->getMask()->getSplatValue() && + is_splat(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask Constant *NarrowUndef = UndefValue::get(Trunc.getType()); Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); - return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getMask()); + return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask()); } return nullptr; @@ -682,29 +687,51 @@ static Instruction *shrinkInsertElt(CastInst &Trunc, return nullptr; } -Instruction *InstCombiner::visitTrunc(TruncInst &CI) { - if (Instruction *Result = commonCastTransforms(CI)) +Instruction *InstCombiner::visitTrunc(TruncInst &Trunc) { + if (Instruction *Result = commonCastTransforms(Trunc)) return Result; - Value *Src = CI.getOperand(0); - Type *DestTy = CI.getType(), *SrcTy = Src->getType(); + Value *Src = Trunc.getOperand(0); + Type *DestTy = Trunc.getType(), *SrcTy = Src->getType(); + unsigned DestWidth = DestTy->getScalarSizeInBits(); + unsigned SrcWidth = SrcTy->getScalarSizeInBits(); + ConstantInt *Cst; // Attempt to truncate the entire input expression tree to the destination // type. Only do this if the dest type is a simple type, don't convert the // expression tree to something weird like i93 unless the source is also // strange. if ((DestTy->isVectorTy() || shouldChangeType(SrcTy, DestTy)) && - canEvaluateTruncated(Src, DestTy, *this, &CI)) { + canEvaluateTruncated(Src, DestTy, *this, &Trunc)) { // If this cast is a truncate, evaluting in a different type always // eliminates the cast, so it is always a win. 
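// For example (illustrative; assumes i16 is a desirable type per
// shouldChangeType), the whole tree is re-evaluated in the narrow type:
//   %z = zext i8 %x to i32
//   %s = shl i32 %z, 3
//   %t = trunc i32 %s to i16
// becomes
//   %z.tr = zext i8 %x to i16
//   %t    = shl i16 %z.tr, 3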
LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid cast: " - << CI << '\n'); + << Trunc << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); - return replaceInstUsesWith(CI, Res); + return replaceInstUsesWith(Trunc, Res); + } + + // For integer types, check if we can shorten the entire input expression to + // DestWidth * 2, which won't allow removing the truncate, but reducing the + // width may enable further optimizations, e.g. allowing for larger + // vectorization factors. + if (auto *DestITy = dyn_cast<IntegerType>(DestTy)) { + if (DestWidth * 2 < SrcWidth) { + auto *NewDestTy = DestITy->getExtendedType(); + if (shouldChangeType(SrcTy, NewDestTy) && + canEvaluateTruncated(Src, NewDestTy, *this, &Trunc)) { + LLVM_DEBUG( + dbgs() << "ICE: EvaluateInDifferentType converting expression type" + " to reduce the width of operand of" + << Trunc << '\n'); + Value *Res = EvaluateInDifferentType(Src, NewDestTy, false); + return new TruncInst(Res, DestTy); + } + } } // Test if the trunc is the user of a select which is part of a @@ -712,17 +739,17 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // Even simplifying demanded bits can break the canonical form of a // min/max. Value *LHS, *RHS; - if (SelectInst *SI = dyn_cast<SelectInst>(CI.getOperand(0))) - if (matchSelectPattern(SI, LHS, RHS).Flavor != SPF_UNKNOWN) + if (SelectInst *Sel = dyn_cast<SelectInst>(Src)) + if (matchSelectPattern(Sel, LHS, RHS).Flavor != SPF_UNKNOWN) return nullptr; // See if we can simplify any instructions used by the input whose sole // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(CI)) - return &CI; + if (SimplifyDemandedInstructionBits(Trunc)) + return &Trunc; - if (DestTy->getScalarSizeInBits() == 1) { - Value *Zero = Constant::getNullValue(Src->getType()); + if (DestWidth == 1) { + Value *Zero = Constant::getNullValue(SrcTy); if (DestTy->isIntegerTy()) { // Canonicalize trunc x to i1 -> icmp ne (and x, 1), 0 (scalar only). // TODO: We canonicalize to more instructions here because we are probably @@ -736,18 +763,21 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // For vectors, we do not canonicalize all truncs to icmp, so optimize // patterns that would be covered within visitICmpInst. 
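// For example (illustrative, vector form of the fold below):
//   %s = lshr <2 x i8> %x, <i8 3, i8 3>
//   %t = trunc <2 x i8> %s to <2 x i1>
// becomes
//   %m = and <2 x i8> %x, <i8 8, i8 8>
//   %t = icmp ne <2 x i8> %m, zeroinitializer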
Value *X; - const APInt *C; - if (match(Src, m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + Constant *C; + if (match(Src, m_OneUse(m_LShr(m_Value(X), m_Constant(C))))) { // trunc (lshr X, C) to i1 --> icmp ne (and X, C'), 0 - APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C); - Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } - if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_APInt(C)), + if (match(Src, m_OneUse(m_c_Or(m_LShr(m_Value(X), m_Constant(C)), m_Deferred(X))))) { // trunc (or (lshr X, C), X) to i1 --> icmp ne (and X, C'), 0 - APInt MaskC = APInt(SrcTy->getScalarSizeInBits(), 1).shl(*C) | 1; - Value *And = Builder.CreateAnd(X, ConstantInt::get(SrcTy, MaskC)); + Constant *One = ConstantInt::get(SrcTy, APInt(SrcWidth, 1)); + Constant *MaskC = ConstantExpr::getShl(One, C); + MaskC = ConstantExpr::getOr(MaskC, One); + Value *And = Builder.CreateAnd(X, MaskC); return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); } } @@ -756,7 +786,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // more efficiently. Support vector types. Cleanup code by using m_OneUse. // Transform trunc(lshr (zext A), Cst) to eliminate one type conversion. - Value *A = nullptr; ConstantInt *Cst = nullptr; + Value *A = nullptr; if (Src->hasOneUse() && match(Src, m_LShr(m_ZExt(m_Value(A)), m_ConstantInt(Cst)))) { // We have three types to worry about here, the type of A, the source of @@ -768,7 +798,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { // If the shift amount is larger than the size of A, then the result is // known to be zero because all the input bits got shifted out. if (Cst->getZExtValue() >= ASize) - return replaceInstUsesWith(CI, Constant::getNullValue(DestTy)); + return replaceInstUsesWith(Trunc, Constant::getNullValue(DestTy)); // Since we're doing an lshr and a zero extend, and know that the shift // amount is smaller than ASize, it is always safe to do the shift in A's @@ -778,45 +808,37 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return CastInst::CreateIntegerCast(Shift, DestTy, false); } - // FIXME: We should canonicalize to zext/trunc and remove this transform. - // Transform trunc(lshr (sext A), Cst) to ashr A, Cst to eliminate type - // conversion. - // It works because bits coming from sign extension have the same value as - // the sign bit of the original value; performing ashr instead of lshr - // generates bits of the same value as the sign bit. - if (Src->hasOneUse() && - match(Src, m_LShr(m_SExt(m_Value(A)), m_ConstantInt(Cst)))) { - Value *SExt = cast<Instruction>(Src)->getOperand(0); - const unsigned SExtSize = SExt->getType()->getPrimitiveSizeInBits(); - const unsigned ASize = A->getType()->getPrimitiveSizeInBits(); - const unsigned CISize = CI.getType()->getPrimitiveSizeInBits(); - const unsigned MaxAmt = SExtSize - std::max(CISize, ASize); - unsigned ShiftAmt = Cst->getZExtValue(); - - // This optimization can be only performed when zero bits generated by - // the original lshr aren't pulled into the value after truncation, so we - // can only shift by values no larger than the number of extension bits. - // FIXME: Instead of bailing when the shift is too large, use and to clear - // the extra bits. 
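// For example (illustrative), when the trunc narrows back to A's own type:
//   %e = sext i8 %a to i32
//   %s = lshr i32 %e, 2
//   %t = trunc i32 %s to i8
// becomes
//   %t = ashr i8 %a, 2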
- if (ShiftAmt <= MaxAmt) { - if (CISize == ASize) - return BinaryOperator::CreateAShr(A, ConstantInt::get(CI.getType(), - std::min(ShiftAmt, ASize - 1))); - if (SExt->hasOneUse()) { - Value *Shift = Builder.CreateAShr(A, std::min(ShiftAmt, ASize - 1)); - Shift->takeName(Src); - return CastInst::CreateIntegerCast(Shift, CI.getType(), true); + const APInt *C; + if (match(Src, m_LShr(m_SExt(m_Value(A)), m_APInt(C)))) { + unsigned AWidth = A->getType()->getScalarSizeInBits(); + unsigned MaxShiftAmt = SrcWidth - std::max(DestWidth, AWidth); + + // If the shift is small enough, all zero bits created by the shift are + // removed by the trunc. + if (C->getZExtValue() <= MaxShiftAmt) { + // trunc (lshr (sext A), C) --> ashr A, C + if (A->getType() == DestTy) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), DestWidth - 1); + return BinaryOperator::CreateAShr(A, ConstantInt::get(DestTy, ShAmt)); + } + // The types are mismatched, so create a cast after shifting: + // trunc (lshr (sext A), C) --> sext/trunc (ashr A, C) + if (Src->hasOneUse()) { + unsigned ShAmt = std::min((unsigned)C->getZExtValue(), AWidth - 1); + Value *Shift = Builder.CreateAShr(A, ShAmt); + return CastInst::CreateIntegerCast(Shift, DestTy, true); } } + // TODO: Mask high bits with 'and'. } - if (Instruction *I = narrowBinOp(CI)) + if (Instruction *I = narrowBinOp(Trunc)) return I; - if (Instruction *I = shrinkSplatShuffle(CI, Builder)) + if (Instruction *I = shrinkSplatShuffle(Trunc, Builder)) return I; - if (Instruction *I = shrinkInsertElt(CI, Builder)) + if (Instruction *I = shrinkInsertElt(Trunc, Builder)) return I; if (Src->hasOneUse() && isa<IntegerType>(SrcTy) && @@ -827,20 +849,48 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { !match(A, m_Shr(m_Value(), m_Constant()))) { // Skip shifts of shift by constants. It undoes a combine in // FoldShiftByConstant and is the extend in reg pattern. - const unsigned DestSize = DestTy->getScalarSizeInBits(); - if (Cst->getValue().ult(DestSize)) { + if (Cst->getValue().ult(DestWidth)) { Value *NewTrunc = Builder.CreateTrunc(A, DestTy, A->getName() + ".tr"); return BinaryOperator::Create( Instruction::Shl, NewTrunc, - ConstantInt::get(DestTy, Cst->getValue().trunc(DestSize))); + ConstantInt::get(DestTy, Cst->getValue().trunc(DestWidth))); } } } - if (Instruction *I = foldVecTruncToExtElt(CI, *this)) + if (Instruction *I = foldVecTruncToExtElt(Trunc, *this)) return I; + // Whenever an element is extracted from a vector, and then truncated, + // canonicalize by converting it to a bitcast followed by an + // extractelement. + // + // Example (little endian): + // trunc (extractelement <4 x i64> %X, 0) to i32 + // ---> + // extractelement <8 x i32> (bitcast <4 x i64> %X to <8 x i32>), i32 0 + Value *VecOp; + if (match(Src, m_OneUse(m_ExtractElt(m_Value(VecOp), m_ConstantInt(Cst))))) { + auto *VecOpTy = cast<VectorType>(VecOp->getType()); + unsigned VecNumElts = VecOpTy->getNumElements(); + + // A badly fit destination size would result in an invalid cast. + if (SrcWidth % DestWidth == 0) { + uint64_t TruncRatio = SrcWidth / DestWidth; + uint64_t BitCastNumElts = VecNumElts * TruncRatio; + uint64_t VecOpIdx = Cst->getZExtValue(); + uint64_t NewIdx = DL.isBigEndian() ? 
(VecOpIdx + 1) * TruncRatio - 1 + : VecOpIdx * TruncRatio; + assert(BitCastNumElts <= std::numeric_limits<uint32_t>::max() && + "overflow 32-bits"); + + auto *BitCastTo = FixedVectorType::get(DestTy, BitCastNumElts); + Value *BitCast = Builder.CreateBitCast(VecOp, BitCastTo); + return ExtractElementInst::Create(BitCast, Builder.getInt32(NewIdx)); + } + } + return nullptr; } @@ -1431,16 +1481,17 @@ Instruction *InstCombiner::visitSExt(SExtInst &CI) { // %d = ashr i32 %a, 30 Value *A = nullptr; // TODO: Eventually this could be subsumed by EvaluateInDifferentType. - ConstantInt *BA = nullptr, *CA = nullptr; - if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_ConstantInt(BA)), - m_ConstantInt(CA))) && + Constant *BA = nullptr, *CA = nullptr; + if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), + m_Constant(CA))) && BA == CA && A->getType() == CI.getType()) { unsigned MidSize = Src->getType()->getScalarSizeInBits(); unsigned SrcDstSize = CI.getType()->getScalarSizeInBits(); - unsigned ShAmt = CA->getZExtValue()+SrcDstSize-MidSize; - Constant *ShAmtV = ConstantInt::get(CI.getType(), ShAmt); - A = Builder.CreateShl(A, ShAmtV, CI.getName()); - return BinaryOperator::CreateAShr(A, ShAmtV); + Constant *SizeDiff = ConstantInt::get(CA->getType(), SrcDstSize - MidSize); + Constant *ShAmt = ConstantExpr::getAdd(CA, SizeDiff); + Constant *ShAmtExt = ConstantExpr::getSExt(ShAmt, CI.getType()); + A = Builder.CreateShl(A, ShAmtExt, CI.getName()); + return BinaryOperator::CreateAShr(A, ShAmtExt); } return nullptr; @@ -1478,12 +1529,13 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // TODO: Make these support undef elements. static Type *shrinkFPConstantVector(Value *V) { auto *CV = dyn_cast<Constant>(V); - if (!CV || !CV->getType()->isVectorTy()) + auto *CVVTy = dyn_cast<VectorType>(V->getType()); + if (!CV || !CVVTy) return nullptr; Type *MinType = nullptr; - unsigned NumElts = CV->getType()->getVectorNumElements(); + unsigned NumElts = CVVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) @@ -1500,7 +1552,7 @@ static Type *shrinkFPConstantVector(Value *V) { } // Make a vector type from the minimal type. - return VectorType::get(MinType, NumElts); + return FixedVectorType::get(MinType, NumElts); } /// Find the minimum FP type we can safely truncate to. @@ -1522,6 +1574,48 @@ static Type *getMinimumFPType(Value *V) { return V->getType(); } +/// Return true if the cast from integer to FP can be proven to be exact for all +/// possible inputs (the conversion does not lose any precision). +static bool isKnownExactCastIntToFP(CastInst &I) { + CastInst::CastOps Opcode = I.getOpcode(); + assert((Opcode == CastInst::SIToFP || Opcode == CastInst::UIToFP) && + "Unexpected cast"); + Value *Src = I.getOperand(0); + Type *SrcTy = Src->getType(); + Type *FPTy = I.getType(); + bool IsSigned = Opcode == Instruction::SIToFP; + int SrcSize = (int)SrcTy->getScalarSizeInBits() - IsSigned; + + // Easy case - if the source integer type has less bits than the FP mantissa, + // then the cast must be exact. + int DestNumSigBits = FPTy->getFPMantissaWidth(); + if (SrcSize <= DestNumSigBits) + return true; + + // Cast from FP to integer and back to FP is independent of the intermediate + // integer width because of poison on overflow. 
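// For example (illustrative; float carries a 24-bit significand):
//   %n = fptosi float %f to i64
//   %i = sitofp i64 %n to float
// is known exact: %n can hold at most 24 significant bits (an out-of-range
// fptosi is poison), so the width of the intermediate i64 does not matter.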
+ Value *F; + if (match(Src, m_FPToSI(m_Value(F))) || match(Src, m_FPToUI(m_Value(F)))) { + // If this is uitofp (fptosi F), the source needs an extra bit to avoid + // potential rounding of negative FP input values. + int SrcNumSigBits = F->getType()->getFPMantissaWidth(); + if (!IsSigned && match(Src, m_FPToSI(m_Value()))) + SrcNumSigBits++; + + // [su]itofp (fpto[su]i F) --> exact if the source type has less or equal + // significant bits than the destination (and make sure neither type is + // weird -- ppc_fp128). + if (SrcNumSigBits > 0 && DestNumSigBits > 0 && + SrcNumSigBits <= DestNumSigBits) + return true; + } + + // TODO: + // Try harder to find if the source integer type has less significant bits. + // For example, compute number of sign bits or compute low bit mask. + return false; +} + Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (Instruction *I = commonCastTransforms(FPT)) return I; @@ -1632,10 +1726,6 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (match(Op, m_FNeg(m_Value(X)))) { Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); - // FIXME: Once we're sure that unary FNeg optimizations are on par with - // binary FNeg, this should always return a unary operator. - if (isa<BinaryOperator>(Op)) - return BinaryOperator::CreateFNegFMF(InnerTrunc, Op); return UnaryOperator::CreateFNegFMF(InnerTrunc, Op); } @@ -1667,6 +1757,7 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { case Intrinsic::nearbyint: case Intrinsic::rint: case Intrinsic::round: + case Intrinsic::roundeven: case Intrinsic::trunc: { Value *Src = II->getArgOperand(0); if (!Src->hasOneUse()) @@ -1699,74 +1790,83 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { if (Instruction *I = shrinkInsertElt(FPT, Builder)) return I; + Value *Src = FPT.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + return nullptr; } -Instruction *InstCombiner::visitFPExt(CastInst &CI) { - return commonCastTransforms(CI); +Instruction *InstCombiner::visitFPExt(CastInst &FPExt) { + // If the source operand is a cast from integer to FP and known exact, then + // cast the integer operand directly to the destination type. + Type *Ty = FPExt.getType(); + Value *Src = FPExt.getOperand(0); + if (isa<SIToFPInst>(Src) || isa<UIToFPInst>(Src)) { + auto *FPCast = cast<CastInst>(Src); + if (isKnownExactCastIntToFP(*FPCast)) + return CastInst::Create(FPCast->getOpcode(), FPCast->getOperand(0), Ty); + } + + return commonCastTransforms(FPExt); } -// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) -// This is safe if the intermediate type has enough bits in its mantissa to -// accurately represent all values of X. For example, this won't work with -// i64 -> float -> i64. -Instruction *InstCombiner::FoldItoFPtoI(Instruction &FI) { +/// fpto{s/u}i({u/s}itofp(X)) --> X or zext(X) or sext(X) or trunc(X) +/// This is safe if the intermediate type has enough bits in its mantissa to +/// accurately represent all values of X. For example, this won't work with +/// i64 -> float -> i64. 
+Instruction *InstCombiner::foldItoFPtoI(CastInst &FI) { if (!isa<UIToFPInst>(FI.getOperand(0)) && !isa<SIToFPInst>(FI.getOperand(0))) return nullptr; - Instruction *OpI = cast<Instruction>(FI.getOperand(0)); - Value *SrcI = OpI->getOperand(0); - Type *FITy = FI.getType(); - Type *OpITy = OpI->getType(); - Type *SrcTy = SrcI->getType(); - bool IsInputSigned = isa<SIToFPInst>(OpI); + auto *OpI = cast<CastInst>(FI.getOperand(0)); + Value *X = OpI->getOperand(0); + Type *XType = X->getType(); + Type *DestType = FI.getType(); bool IsOutputSigned = isa<FPToSIInst>(FI); - // We can safely assume the conversion won't overflow the output range, - // because (for example) (uint8_t)18293.f is undefined behavior. - // Since we can assume the conversion won't overflow, our decision as to // whether the input will fit in the float should depend on the minimum // of the input range and output range. // This means this is also safe for a signed input and unsigned output, since // a negative input would lead to undefined behavior. - int InputSize = (int)SrcTy->getScalarSizeInBits() - IsInputSigned; - int OutputSize = (int)FITy->getScalarSizeInBits() - IsOutputSigned; - int ActualSize = std::min(InputSize, OutputSize); - - if (ActualSize <= OpITy->getFPMantissaWidth()) { - if (FITy->getScalarSizeInBits() > SrcTy->getScalarSizeInBits()) { - if (IsInputSigned && IsOutputSigned) - return new SExtInst(SrcI, FITy); - return new ZExtInst(SrcI, FITy); - } - if (FITy->getScalarSizeInBits() < SrcTy->getScalarSizeInBits()) - return new TruncInst(SrcI, FITy); - if (SrcTy == FITy) - return replaceInstUsesWith(FI, SrcI); - return new BitCastInst(SrcI, FITy); + if (!isKnownExactCastIntToFP(*OpI)) { + // The first cast may not round exactly based on the source integer width + // and FP width, but the overflow UB rules can still allow this to fold. + // If the destination type is narrow, that means the intermediate FP value + // must be large enough to hold the source value exactly. + // For example, (uint8_t)((float)(uint32_t 16777217) is undefined behavior. 
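// For example (illustrative), when the int-to-FP cast is known exact the
// round trip reduces to an integer extension:
//   %f = sitofp i16 %x to float
//   %r = fptosi float %f to i32
// becomes
//   %r = sext i16 %x to i32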
+ int OutputSize = (int)DestType->getScalarSizeInBits() - IsOutputSigned; + if (OutputSize > OpI->getType()->getFPMantissaWidth()) + return nullptr; } - return nullptr; + + if (DestType->getScalarSizeInBits() > XType->getScalarSizeInBits()) { + bool IsInputSigned = isa<SIToFPInst>(OpI); + if (IsInputSigned && IsOutputSigned) + return new SExtInst(X, DestType); + return new ZExtInst(X, DestType); + } + if (DestType->getScalarSizeInBits() < XType->getScalarSizeInBits()) + return new TruncInst(X, DestType); + + assert(XType == DestType && "Unexpected types for int to FP to int casts"); + return replaceInstUsesWith(FI, X); } Instruction *InstCombiner::visitFPToUI(FPToUIInst &FI) { - Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0)); - if (!OpI) - return commonCastTransforms(FI); - - if (Instruction *I = FoldItoFPtoI(FI)) + if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); } Instruction *InstCombiner::visitFPToSI(FPToSIInst &FI) { - Instruction *OpI = dyn_cast<Instruction>(FI.getOperand(0)); - if (!OpI) - return commonCastTransforms(FI); - - if (Instruction *I = FoldItoFPtoI(FI)) + if (Instruction *I = foldItoFPtoI(FI)) return I; return commonCastTransforms(FI); @@ -1788,8 +1888,9 @@ Instruction *InstCombiner::visitIntToPtr(IntToPtrInst &CI) { if (CI.getOperand(0)->getType()->getScalarSizeInBits() != DL.getPointerSizeInBits(AS)) { Type *Ty = DL.getIntPtrType(CI.getContext(), AS); - if (CI.getType()->isVectorTy()) // Handle vectors of pointers. - Ty = VectorType::get(Ty, CI.getType()->getVectorNumElements()); + // Handle vectors of pointers. + if (auto *CIVTy = dyn_cast<VectorType>(CI.getType())) + Ty = VectorType::get(Ty, CIVTy->getElementCount()); Value *P = Builder.CreateZExtOrTrunc(CI.getOperand(0), Ty); return new IntToPtrInst(P, CI.getType()); @@ -1817,9 +1918,7 @@ Instruction *InstCombiner::commonPointerCastTransforms(CastInst &CI) { // Changing the cast operand is usually not a good idea but it is safe // here because the pointer operand is being replaced with another // pointer operand so the opcode doesn't need to change. - Worklist.Add(GEP); - CI.setOperand(0, GEP->getOperand(0)); - return &CI; + return replaceOperand(CI, 0, GEP->getOperand(0)); } } @@ -1838,8 +1937,11 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { return commonPointerCastTransforms(CI); Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); - if (Ty->isVectorTy()) // Handle vectors of pointers. - PtrTy = VectorType::get(PtrTy, Ty->getVectorNumElements()); + if (auto *VTy = dyn_cast<VectorType>(Ty)) { + // Handle vectors of pointers. + // FIXME: what should happen for scalable vectors? + PtrTy = FixedVectorType::get(PtrTy, VTy->getNumElements()); + } Value *P = Builder.CreatePtrToInt(CI.getOperand(0), PtrTy); return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); @@ -1878,7 +1980,8 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, DestTy->getElementType()->getPrimitiveSizeInBits()) return nullptr; - SrcTy = VectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); + SrcTy = + FixedVectorType::get(DestTy->getElementType(), SrcTy->getNumElements()); InVal = IC.Builder.CreateBitCast(InVal, SrcTy); } @@ -1891,8 +1994,8 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, // Now that the element types match, get the shuffle mask and RHS of the // shuffle to use, which depends on whether we're increasing or decreasing the // size of the input. 
- SmallVector<uint32_t, 16> ShuffleMaskStorage; - ArrayRef<uint32_t> ShuffleMask; + SmallVector<int, 16> ShuffleMaskStorage; + ArrayRef<int> ShuffleMask; Value *V2; // Produce an identify shuffle mask for the src vector. @@ -1931,9 +2034,7 @@ static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, ShuffleMask = ShuffleMaskStorage; } - return new ShuffleVectorInst(InVal, V2, - ConstantDataVector::get(V2->getContext(), - ShuffleMask)); + return new ShuffleVectorInst(InVal, V2, ShuffleMask); } static bool isMultipleOfTypeSize(unsigned Value, Type *Ty) { @@ -2106,7 +2207,7 @@ static Instruction *canonicalizeBitCastExtElt(BitCastInst &BitCast, return nullptr; unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); - auto *NewVecType = VectorType::get(DestType, NumElts); + auto *NewVecType = FixedVectorType::get(DestType, NumElts); auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), NewVecType, "bc"); return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); @@ -2151,7 +2252,7 @@ static Instruction *foldBitCastBitwiseLogic(BitCastInst &BitCast, if (match(BO->getOperand(1), m_Constant(C))) { // bitcast (logic X, C) --> logic (bitcast X, C') Value *CastedOp0 = Builder.CreateBitCast(BO->getOperand(0), DestTy); - Value *CastedC = ConstantExpr::getBitCast(C, DestTy); + Value *CastedC = Builder.CreateBitCast(C, DestTy); return BinaryOperator::Create(BO->getOpcode(), CastedOp0, CastedC); } @@ -2169,10 +2270,10 @@ static Instruction *foldBitCastSelect(BitCastInst &BitCast, // A vector select must maintain the same number of elements in its operands. Type *CondTy = Cond->getType(); Type *DestTy = BitCast.getType(); - if (CondTy->isVectorTy()) { + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { if (!DestTy->isVectorTy()) return nullptr; - if (DestTy->getVectorNumElements() != CondTy->getVectorNumElements()) + if (cast<VectorType>(DestTy)->getNumElements() != CondVTy->getNumElements()) return nullptr; } @@ -2359,7 +2460,7 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { auto *NewBC = cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); SI->setOperand(0, NewBC); - Worklist.Add(SI); + Worklist.push(SI); assert(hasStoreUsersOnly(*NewBC)); } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { @@ -2395,8 +2496,9 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { if (DestTy == Src->getType()) return replaceInstUsesWith(CI, Src); - if (PointerType *DstPTy = dyn_cast<PointerType>(DestTy)) { + if (isa<PointerType>(SrcTy) && isa<PointerType>(DestTy)) { PointerType *SrcPTy = cast<PointerType>(SrcTy); + PointerType *DstPTy = cast<PointerType>(DestTy); Type *DstElTy = DstPTy->getElementType(); Type *SrcElTy = SrcPTy->getElementType(); @@ -2425,10 +2527,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // to a getelementptr X, 0, 0, 0... turn it into the appropriate gep. // This can enhance SROA and other transforms that want type-safe pointers. 
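// For example (illustrative; the zero-index width shown here is arbitrary):
//   %p = bitcast [4 x i32]* %x to i32*
// can become
//   %p = getelementptr [4 x i32], [4 x i32]* %x, i32 0, i32 0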
unsigned NumZeros = 0; - while (SrcElTy != DstElTy && - isa<CompositeType>(SrcElTy) && !SrcElTy->isPointerTy() && - SrcElTy->getNumContainedTypes() /* not "{}" */) { - SrcElTy = cast<CompositeType>(SrcElTy)->getTypeAtIndex(0U); + while (SrcElTy && SrcElTy != DstElTy) { + SrcElTy = GetElementPtrInst::getTypeAtIndex(SrcElTy, (uint64_t)0); ++NumZeros; } @@ -2455,12 +2555,12 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } - if (VectorType *DestVTy = dyn_cast<VectorType>(DestTy)) { - if (DestVTy->getNumElements() == 1 && !SrcTy->isVectorTy()) { + if (FixedVectorType *DestVTy = dyn_cast<FixedVectorType>(DestTy)) { + // Beware: messing with this target-specific oddity may cause trouble. + if (DestVTy->getNumElements() == 1 && SrcTy->isX86_MMXTy()) { Value *Elem = Builder.CreateBitCast(Src, DestVTy->getElementType()); return InsertElementInst::Create(UndefValue::get(DestTy), Elem, Constant::getNullValue(Type::getInt32Ty(CI.getContext()))); - // FIXME: Canonicalize bitcast(insertelement) -> insertelement(bitcast) } if (isa<IntegerType>(SrcTy)) { @@ -2484,7 +2584,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { } } - if (VectorType *SrcVTy = dyn_cast<VectorType>(SrcTy)) { + if (FixedVectorType *SrcVTy = dyn_cast<FixedVectorType>(SrcTy)) { if (SrcVTy->getNumElements() == 1) { // If our destination is not a vector, then make this a straight // scalar-scalar cast. @@ -2508,10 +2608,11 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { // a bitcast to a vector with the same # elts. Value *ShufOp0 = Shuf->getOperand(0); Value *ShufOp1 = Shuf->getOperand(1); - unsigned NumShufElts = Shuf->getType()->getVectorNumElements(); - unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements(); + unsigned NumShufElts = Shuf->getType()->getNumElements(); + unsigned NumSrcVecElts = + cast<VectorType>(ShufOp0->getType())->getNumElements(); if (Shuf->hasOneUse() && DestTy->isVectorTy() && - DestTy->getVectorNumElements() == NumShufElts && + cast<VectorType>(DestTy)->getNumElements() == NumShufElts && NumShufElts == NumSrcVecElts) { BitCastInst *Tmp; // If either of the operands is a cast from CI.getType(), then @@ -2525,7 +2626,7 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy); // Return a new shuffle vector. Use the same element ID's, as we // know the vector types match #elts. - return new ShuffleVectorInst(LHS, RHS, Shuf->getOperand(2)); + return new ShuffleVectorInst(LHS, RHS, Shuf->getShuffleMask()); } } @@ -2578,7 +2679,8 @@ Instruction *InstCombiner::visitAddrSpaceCast(AddrSpaceCastInst &CI) { Type *MidTy = PointerType::get(DestElemTy, SrcTy->getAddressSpace()); if (VectorType *VT = dyn_cast<VectorType>(CI.getType())) { // Handle vectors of pointers. - MidTy = VectorType::get(MidTy, VT->getNumElements()); + // FIXME: what should happen for scalable vectors? + MidTy = FixedVectorType::get(MidTy, VT->getNumElements()); } Value *NewBitCast = Builder.CreateBitCast(Src, MidTy); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index f38dc436722dc..f1233b62445d0 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -897,7 +897,7 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // For vectors, we apply the same reasoning on a per-lane basis. 
auto *Base = GEPLHS->getPointerOperand(); if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) { - int NumElts = GEPLHS->getType()->getVectorNumElements(); + int NumElts = cast<VectorType>(GEPLHS->getType())->getNumElements(); Base = Builder.CreateVectorSplat(NumElts, Base); } return new ICmpInst(Cond, Base, @@ -1330,6 +1330,7 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // The inner add was the result of the narrow add, zero extended to the // wider type. Replace it with the result computed by the intrinsic. IC.replaceInstUsesWith(*OrigAdd, ZExt); + IC.eraseInstFromFunction(*OrigAdd); // The original icmp gets replaced with the overflow value. return ExtractValueInst::Create(Call, 1, "sadd.overflow"); @@ -1451,6 +1452,27 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { if (Instruction *Res = processUGT_ADDCST_ADD(Cmp, A, B, CI2, CI, *this)) return Res; + // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...). + Constant *C = dyn_cast<Constant>(Op1); + if (!C) + return nullptr; + + if (auto *Phi = dyn_cast<PHINode>(Op0)) + if (all_of(Phi->operands(), [](Value *V) { return isa<Constant>(V); })) { + Type *Ty = Cmp.getType(); + Builder.SetInsertPoint(Phi); + PHINode *NewPhi = + Builder.CreatePHI(Ty, Phi->getNumOperands()); + for (BasicBlock *Predecessor : predecessors(Phi->getParent())) { + auto *Input = + cast<Constant>(Phi->getIncomingValueForBlock(Predecessor)); + auto *BoolInput = ConstantExpr::getCompare(Pred, Input, C); + NewPhi->addIncoming(BoolInput, Predecessor); + } + NewPhi->takeName(&Cmp); + return replaceInstUsesWith(Cmp, NewPhi); + } + return nullptr; } @@ -1575,11 +1597,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, // If the sign bit of the XorCst is not set, there is no change to // the operation, just stop using the Xor. - if (!XorC->isNegative()) { - Cmp.setOperand(0, X); - Worklist.Add(Xor); - return &Cmp; - } + if (!XorC->isNegative()) + return replaceOperand(Cmp, 0, X); // Emit the opposite comparison. if (TrueIfSigned) @@ -1645,51 +1664,53 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, bool IsShl = ShiftOpcode == Instruction::Shl; const APInt *C3; if (match(Shift->getOperand(1), m_APInt(C3))) { - bool CanFold = false; + APInt NewAndCst, NewCmpCst; + bool AnyCmpCstBitsShiftedOut; if (ShiftOpcode == Instruction::Shl) { // For a left shift, we can fold if the comparison is not signed. We can // also fold a signed comparison if the mask value and comparison value // are not negative. These constraints may not be obvious, but we can // prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative())) - CanFold = true; - } else { - bool IsAshr = ShiftOpcode == Instruction::AShr; + if (Cmp.isSigned() && (C2.isNegative() || C1.isNegative())) + return nullptr; + + NewCmpCst = C1.lshr(*C3); + NewAndCst = C2.lshr(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.shl(*C3) != C1; + } else if (ShiftOpcode == Instruction::LShr) { // For a logical right shift, we can fold if the comparison is not signed. // We can also fold a signed comparison if the shifted mask value and the // shifted comparison value are not negative. These constraints may not be // obvious, but we can prove that they are correct using an SMT solver. - // For an arithmetic shift right we can do the same, if we ensure - // the And doesn't use any bits being shifted in. 
Normally these would - // be turned into lshr by SimplifyDemandedBits, but not if there is an - // additional user. - if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) { - if (!Cmp.isSigned() || - (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative())) - CanFold = true; - } + NewCmpCst = C1.shl(*C3); + NewAndCst = C2.shl(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.lshr(*C3) != C1; + if (Cmp.isSigned() && (NewAndCst.isNegative() || NewCmpCst.isNegative())) + return nullptr; + } else { + // For an arithmetic shift, check that both constants don't use (in a + // signed sense) the top bits being shifted out. + assert(ShiftOpcode == Instruction::AShr && "Unknown shift opcode"); + NewCmpCst = C1.shl(*C3); + NewAndCst = C2.shl(*C3); + AnyCmpCstBitsShiftedOut = NewCmpCst.ashr(*C3) != C1; + if (NewAndCst.ashr(*C3) != C2) + return nullptr; } - if (CanFold) { - APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3); - APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); - // Check to see if we are shifting out any of the bits being compared. - if (SameAsC1 != C1) { - // If we shifted bits out, the fold is not going to work out. As a - // special case, check to see if this means that the result is always - // true or false now. - if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) - return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); - if (Cmp.getPredicate() == ICmpInst::ICMP_NE) - return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); - } else { - Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); - APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3); - And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); - And->setOperand(0, Shift->getOperand(0)); - Worklist.Add(Shift); // Shift is dead. - return &Cmp; - } + if (AnyCmpCstBitsShiftedOut) { + // If we shifted bits out, the fold is not going to work out. As a + // special case, check to see if this means that the result is always + // true or false now. + if (Cmp.getPredicate() == ICmpInst::ICMP_EQ) + return replaceInstUsesWith(Cmp, ConstantInt::getFalse(Cmp.getType())); + if (Cmp.getPredicate() == ICmpInst::ICMP_NE) + return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); + } else { + Value *NewAnd = Builder.CreateAnd( + Shift->getOperand(0), ConstantInt::get(And->getType(), NewAndCst)); + return new ICmpInst(Cmp.getPredicate(), + NewAnd, ConstantInt::get(And->getType(), NewCmpCst)); } } @@ -1705,8 +1726,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, // Compute X & (C2 << Y). 
Value *NewAnd = Builder.CreateAnd(Shift->getOperand(0), NewShift); - Cmp.setOperand(0, NewAnd); - return &Cmp; + return replaceOperand(Cmp, 0, NewAnd); } return nullptr; @@ -1812,8 +1832,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, } if (NewOr) { Value *NewAnd = Builder.CreateAnd(A, NewOr, And->getName()); - Cmp.setOperand(0, NewAnd); - return &Cmp; + return replaceOperand(Cmp, 0, NewAnd); } } } @@ -1863,8 +1882,8 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); - if (And->getType()->isVectorTy()) - NTy = VectorType::get(NTy, And->getType()->getVectorNumElements()); + if (auto *AndVTy = dyn_cast<VectorType>(And->getType())) + NTy = FixedVectorType::get(NTy, AndVTy->getNumElements()); Value *Trunc = Builder.CreateTrunc(X, NTy); auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE : CmpInst::ICMP_SLT; @@ -1888,20 +1907,24 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, } Value *OrOp0 = Or->getOperand(0), *OrOp1 = Or->getOperand(1); - if (Cmp.isEquality() && Cmp.getOperand(1) == OrOp1) { - // X | C == C --> X <=u C - // X | C != C --> X >u C - // iff C+1 is a power of 2 (C is a bitmask of the low bits) - if ((C + 1).isPowerOf2()) { + const APInt *MaskC; + if (match(OrOp1, m_APInt(MaskC)) && Cmp.isEquality()) { + if (*MaskC == C && (C + 1).isPowerOf2()) { + // X | C == C --> X <=u C + // X | C != C --> X >u C + // iff C+1 is a power of 2 (C is a bitmask of the low bits) Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; return new ICmpInst(Pred, OrOp0, OrOp1); } - // More general: are all bits outside of a mask constant set or not set? - // X | C == C --> (X & ~C) == 0 - // X | C != C --> (X & ~C) != 0 + + // More general: canonicalize 'equality with set bits mask' to + // 'equality with clear bits mask'. + // (X | MaskC) == C --> (X & ~MaskC) == C ^ MaskC + // (X | MaskC) != C --> (X & ~MaskC) != C ^ MaskC if (Or->hasOneUse()) { - Value *A = Builder.CreateAnd(OrOp0, ~C); - return new ICmpInst(Pred, A, ConstantInt::getNullValue(OrOp0->getType())); + Value *And = Builder.CreateAnd(OrOp0, ~(*MaskC)); + Constant *NewC = ConstantInt::get(Or->getType(), C ^ (*MaskC)); + return new ICmpInst(Pred, And, NewC); } } @@ -2149,8 +2172,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); - if (ShType->isVectorTy()) - TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); + if (auto *ShVTy = dyn_cast<VectorType>(ShType)) + TruncTy = FixedVectorType::get(TruncTy, ShVTy->getNumElements()); Constant *NewC = ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); @@ -2763,6 +2786,37 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, if (match(BCSrcOp, m_UIToFP(m_Value(X)))) if (Cmp.isEquality() && match(Op1, m_Zero())) return new ICmpInst(Pred, X, ConstantInt::getNullValue(X->getType())); + + // If this is a sign-bit test of a bitcast of a casted FP value, eliminate + // the FP extend/truncate because that cast does not change the sign-bit. + // This is true for all standard IEEE-754 types and the X86 80-bit type. 
+ // The sign-bit is always the most significant bit in those types. + const APInt *C; + bool TrueIfSigned; + if (match(Op1, m_APInt(C)) && Bitcast->hasOneUse() && + isSignBitCheck(Pred, *C, TrueIfSigned)) { + if (match(BCSrcOp, m_FPExt(m_Value(X))) || + match(BCSrcOp, m_FPTrunc(m_Value(X)))) { + // (bitcast (fpext/fptrunc X)) to iX) < 0 --> (bitcast X to iY) < 0 + // (bitcast (fpext/fptrunc X)) to iX) > -1 --> (bitcast X to iY) > -1 + Type *XType = X->getType(); + + // We can't currently handle Power style floating point operations here. + if (!(XType->isPPC_FP128Ty() || BCSrcOp->getType()->isPPC_FP128Ty())) { + + Type *NewType = Builder.getIntNTy(XType->getScalarSizeInBits()); + if (auto *XVTy = dyn_cast<VectorType>(XType)) + NewType = FixedVectorType::get(NewType, XVTy->getNumElements()); + Value *NewBitcast = Builder.CreateBitCast(X, NewType); + if (TrueIfSigned) + return new ICmpInst(ICmpInst::ICMP_SLT, NewBitcast, + ConstantInt::getNullValue(NewType)); + else + return new ICmpInst(ICmpInst::ICMP_SGT, NewBitcast, + ConstantInt::getAllOnesValue(NewType)); + } + } + } } // Test to see if the operands of the icmp are casted versions of other @@ -2792,11 +2846,10 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, return nullptr; Value *Vec; - Constant *Mask; - if (match(BCSrcOp, - m_ShuffleVector(m_Value(Vec), m_Undef(), m_Constant(Mask)))) { + ArrayRef<int> Mask; + if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { // Check whether every element of Mask is the same constant - if (auto *Elem = dyn_cast_or_null<ConstantInt>(Mask->getSplatValue())) { + if (is_splat(Mask)) { auto *VecTy = cast<VectorType>(BCSrcOp->getType()); auto *EltTy = cast<IntegerType>(VecTy->getElementType()); if (C->isSplat(EltTy->getBitWidth())) { @@ -2805,6 +2858,7 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, // then: // => %E = extractelement <N x iK> %vec, i32 Elem // icmp <pred> iK %SplatVal, <pattern> + Value *Elem = Builder.getInt32(Mask[0]); Value *Extract = Builder.CreateExtractElement(Vec, Elem); Value *NewC = ConstantInt::get(EltTy, C->trunc(EltTy->getBitWidth())); return new ICmpInst(Pred, Extract, NewC); @@ -2928,12 +2982,9 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; case Instruction::Add: { // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - const APInt *BOC; - if (match(BOp1, m_APInt(BOC))) { - if (BO->hasOneUse()) { - Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); - return new ICmpInst(Pred, BOp0, SubC); - } + if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + if (BO->hasOneUse()) + return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); } else if (C.isNullValue()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. @@ -2963,11 +3014,11 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; case Instruction::Sub: if (BO->hasOneUse()) { - const APInt *BOC; - if (match(BOp0, m_APInt(BOC))) { + // Only check for constant LHS here, as constant RHS will be canonicalized + // to add and use the fold above. + if (Constant *BOC = dyn_cast<Constant>(BOp0)) { // Replace ((sub BOC, B) != C) with (B != BOC-C). - Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); - return new ICmpInst(Pred, BOp1, SubC); + return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS)); } else if (C.isNullValue()) { // Replace ((sub A, B) != 0) with (A != B). 
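// For example (illustrative, assuming the sub has no other uses):
//   %d = sub i32 %a, %b
//   %c = icmp ne i32 %d, 0
// becomes
//   %c = icmp ne i32 %a, %b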
return new ICmpInst(Pred, BOp0, BOp1); @@ -3028,20 +3079,16 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, unsigned BitWidth = C.getBitWidth(); switch (II->getIntrinsicID()) { case Intrinsic::bswap: - Worklist.Add(II); - Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap())); - return &Cmp; + // bswap(A) == C -> A == bswap(C) + return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + ConstantInt::get(Ty, C.byteSwap())); case Intrinsic::ctlz: case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (C == BitWidth) { - Worklist.Add(II); - Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); - return &Cmp; - } + if (C == BitWidth) + return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + ConstantInt::getNullValue(Ty)); // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set // and Mask1 has bits 0..C+1 set. Similar for ctl, but for high bits. @@ -3054,10 +3101,9 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, APInt Mask2 = IsTrailing ? APInt::getOneBitSet(BitWidth, Num) : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); - Cmp.setOperand(0, Builder.CreateAnd(II->getArgOperand(0), Mask1)); - Cmp.setOperand(1, ConstantInt::get(Ty, Mask2)); - Worklist.Add(II); - return &Cmp; + return new ICmpInst(Cmp.getPredicate(), + Builder.CreateAnd(II->getArgOperand(0), Mask1), + ConstantInt::get(Ty, Mask2)); } break; } @@ -3066,14 +3112,10 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != bool IsZero = C.isNullValue(); - if (IsZero || C == BitWidth) { - Worklist.Add(II); - Cmp.setOperand(0, II->getArgOperand(0)); - auto *NewOp = - IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty); - Cmp.setOperand(1, NewOp); - return &Cmp; - } + if (IsZero || C == BitWidth) + return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty)); + break; } @@ -3081,9 +3123,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, // uadd.sat(a, b) == 0 -> (a | b) == 0 if (C.isNullValue()) { Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); - return replaceInstUsesWith(Cmp, Builder.CreateICmp( - Cmp.getPredicate(), Or, Constant::getNullValue(Ty))); - + return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty)); } break; } @@ -3093,8 +3133,7 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, if (C.isNullValue()) { ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; - return ICmpInst::Create(Instruction::ICmp, NewPred, - II->getArgOperand(0), II->getArgOperand(1)); + return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1)); } break; } @@ -3300,30 +3339,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, // x & (-1 >> y) != x -> x u> (-1 >> y) DstPred = ICmpInst::Predicate::ICMP_UGT; break; - case ICmpInst::Predicate::ICMP_UGT: + case ICmpInst::Predicate::ICMP_ULT: + // x & (-1 >> y) u< x -> x u> (-1 >> y) // x u> x & (-1 >> y) -> x u> (-1 >> y) - assert(X == I.getOperand(0) && "instsimplify took care of commut. 
variant"); DstPred = ICmpInst::Predicate::ICMP_UGT; break; case ICmpInst::Predicate::ICMP_UGE: // x & (-1 >> y) u>= x -> x u<= (-1 >> y) - assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); - DstPred = ICmpInst::Predicate::ICMP_ULE; - break; - case ICmpInst::Predicate::ICMP_ULT: - // x & (-1 >> y) u< x -> x u> (-1 >> y) - assert(X == I.getOperand(1) && "instsimplify took care of commut. variant"); - DstPred = ICmpInst::Predicate::ICMP_UGT; - break; - case ICmpInst::Predicate::ICMP_ULE: // x u<= x & (-1 >> y) -> x u<= (-1 >> y) - assert(X == I.getOperand(0) && "instsimplify took care of commut. variant"); DstPred = ICmpInst::Predicate::ICMP_ULE; break; - case ICmpInst::Predicate::ICMP_SGT: + case ICmpInst::Predicate::ICMP_SLT: + // x & (-1 >> y) s< x -> x s> (-1 >> y) // x s> x & (-1 >> y) -> x s> (-1 >> y) - if (X != I.getOperand(0)) // X must be on LHS of comparison! - return nullptr; // Ignore the other case. if (!match(M, m_Constant())) // Can not do this fold with non-constant. return nullptr; if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. @@ -3332,33 +3360,19 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, break; case ICmpInst::Predicate::ICMP_SGE: // x & (-1 >> y) s>= x -> x s<= (-1 >> y) - if (X != I.getOperand(1)) // X must be on RHS of comparison! - return nullptr; // Ignore the other case. + // x s<= x & (-1 >> y) -> x s<= (-1 >> y) if (!match(M, m_Constant())) // Can not do this fold with non-constant. return nullptr; if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. return nullptr; DstPred = ICmpInst::Predicate::ICMP_SLE; break; - case ICmpInst::Predicate::ICMP_SLT: - // x & (-1 >> y) s< x -> x s> (-1 >> y) - if (X != I.getOperand(1)) // X must be on RHS of comparison! - return nullptr; // Ignore the other case. - if (!match(M, m_Constant())) // Can not do this fold with non-constant. - return nullptr; - if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. - return nullptr; - DstPred = ICmpInst::Predicate::ICMP_SGT; - break; + case ICmpInst::Predicate::ICMP_SGT: case ICmpInst::Predicate::ICMP_SLE: - // x s<= x & (-1 >> y) -> x s<= (-1 >> y) - if (X != I.getOperand(0)) // X must be on LHS of comparison! - return nullptr; // Ignore the other case. - if (!match(M, m_Constant())) // Can not do this fold with non-constant. - return nullptr; - if (!match(M, m_NonNegative())) // Must not have any -1 vector elements. - return nullptr; - DstPred = ICmpInst::Predicate::ICMP_SLE; + return nullptr; + case ICmpInst::Predicate::ICMP_UGT: + case ICmpInst::Predicate::ICMP_ULE: + llvm_unreachable("Instsimplify took care of commut. 
variant"); break; default: llvm_unreachable("All possible folds are handled."); @@ -3370,8 +3384,9 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, Type *OpTy = M->getType(); auto *VecC = dyn_cast<Constant>(M); if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) { + auto *OpVTy = cast<VectorType>(OpTy); Constant *SafeReplacementConstant = nullptr; - for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) { + for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { if (!isa<UndefValue>(VecC->getAggregateElement(i))) { SafeReplacementConstant = VecC->getAggregateElement(i); break; @@ -3494,7 +3509,8 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, Instruction *NarrowestShift = XShift; Type *WidestTy = WidestShift->getType(); - assert(NarrowestShift->getType() == I.getOperand(0)->getType() && + Type *NarrowestTy = NarrowestShift->getType(); + assert(NarrowestTy == I.getOperand(0)->getType() && "We did not look past any shifts while matching XShift though."); bool HadTrunc = WidestTy != I.getOperand(0)->getType(); @@ -3533,6 +3549,23 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, if (XShAmt->getType() != YShAmt->getType()) return nullptr; + // As input, we have the following pattern: + // icmp eq/ne (and ((x shift Q), (y oppositeshift K))), 0 + // We want to rewrite that as: + // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. + unsigned MaximalPossibleTotalShiftAmount = + (WidestTy->getScalarSizeInBits() - 1) + + (NarrowestTy->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // Can we fold (XShAmt+YShAmt) ? auto *NewShAmt = dyn_cast_or_null<Constant>( SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, @@ -3627,9 +3660,6 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), m_Value(Y)))) { Mul = nullptr; - // Canonicalize as-if y was on RHS. - if (I.getOperand(1) != Y) - Pred = I.getSwappedPredicate(); // Are we checking that overflow does not happen, or does happen? switch (Pred) { @@ -3674,6 +3704,11 @@ Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { if (NeedNegation) // This technically increases instruction count. Res = Builder.CreateNot(Res, "umul.not.ov"); + // If we replaced the mul, erase it. Do this after all uses of Builder, + // as the mul is used as insertion point. 
+ if (MulHadOtherUses) + eraseInstFromFunction(*Mul); + return Res; } @@ -4202,9 +4237,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { if (X) { // Build (X^Y) & Z Op1 = Builder.CreateXor(X, Y); Op1 = Builder.CreateAnd(Op1, Z); - I.setOperand(0, Op1); - I.setOperand(1, Constant::getNullValue(Op1->getType())); - return &I; + return new ICmpInst(Pred, Op1, Constant::getNullValue(Op1->getType())); } } @@ -4613,17 +4646,6 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, case ICmpInst::ICMP_NE: // Recognize pattern: // mulval = mul(zext A, zext B) - // cmp eq/neq mulval, zext trunc mulval - if (ZExtInst *Zext = dyn_cast<ZExtInst>(OtherVal)) - if (Zext->hasOneUse()) { - Value *ZextArg = Zext->getOperand(0); - if (TruncInst *Trunc = dyn_cast<TruncInst>(ZextArg)) - if (Trunc->getType()->getPrimitiveSizeInBits() == MulWidth) - break; //Recognized - } - - // Recognize pattern: - // mulval = mul(zext A, zext B) // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits. ConstantInt *CI; Value *ValToMask; @@ -4701,7 +4723,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, Function *F = Intrinsic::getDeclaration( I.getModule(), Intrinsic::umul_with_overflow, MulType); CallInst *Call = Builder.CreateCall(F, {MulA, MulB}, "umul"); - IC.Worklist.Add(MulInstr); + IC.Worklist.push(MulInstr); // If there are uses of mul result other than the comparison, we know that // they are truncation or binary AND. Change them to use result of @@ -4723,18 +4745,16 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, ConstantInt *CI = cast<ConstantInt>(BO->getOperand(1)); APInt ShortMask = CI->getValue().trunc(MulWidth); Value *ShortAnd = Builder.CreateAnd(Mul, ShortMask); - Instruction *Zext = - cast<Instruction>(Builder.CreateZExt(ShortAnd, BO->getType())); - IC.Worklist.Add(Zext); + Value *Zext = Builder.CreateZExt(ShortAnd, BO->getType()); IC.replaceInstUsesWith(*BO, Zext); } else { llvm_unreachable("Unexpected Binary operation"); } - IC.Worklist.Add(cast<Instruction>(U)); + IC.Worklist.push(cast<Instruction>(U)); } } if (isa<Instruction>(OtherVal)) - IC.Worklist.Add(cast<Instruction>(OtherVal)); + IC.Worklist.push(cast<Instruction>(OtherVal)); // The original icmp gets replaced with the overflow value, maybe inverted // depending on predicate. @@ -5189,8 +5209,8 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // Bail out if the constant can't be safely incremented/decremented. if (!ConstantIsOk(CI)) return llvm::None; - } else if (Type->isVectorTy()) { - unsigned NumElts = Type->getVectorNumElements(); + } else if (auto *VTy = dyn_cast<VectorType>(Type)) { + unsigned NumElts = VTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) @@ -5252,6 +5272,47 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second); } +/// If we have a comparison with a non-canonical predicate, if we can update +/// all the users, invert the predicate and adjust all the users. +static CmpInst *canonicalizeICmpPredicate(CmpInst &I) { + // Is the predicate already canonical? + CmpInst::Predicate Pred = I.getPredicate(); + if (isCanonicalPredicate(Pred)) + return nullptr; + + // Can all users be adjusted to predicate inversion? + if (!canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return nullptr; + + // Ok, we can canonicalize comparison! 
+ // Let's first invert the comparison's predicate. + I.setPredicate(CmpInst::getInversePredicate(Pred)); + I.setName(I.getName() + ".not"); + + // And now let's adjust every user. + for (User *U : I.users()) { + switch (cast<Instruction>(U)->getOpcode()) { + case Instruction::Select: { + auto *SI = cast<SelectInst>(U); + SI->swapValues(); + SI->swapProfMetadata(); + break; + } + case Instruction::Br: + cast<BranchInst>(U)->swapSuccessors(); // swaps prof metadata too + break; + case Instruction::Xor: + U->replaceAllUsesWith(&I); + break; + default: + llvm_unreachable("Got unexpected user - out of sync with " + "canFreelyInvertAllUsersOf() ?"); + } + } + + return &I; +} + /// Integer compare with boolean values can always be turned into bitwise ops. static Instruction *canonicalizeICmpBool(ICmpInst &I, InstCombiner::BuilderTy &Builder) { @@ -5338,10 +5399,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp, Value *X, *Y; if (match(&Cmp, m_c_ICmp(Pred, m_OneUse(m_Shl(m_One(), m_Value(Y))), m_Value(X)))) { - // We want X to be the icmp's second operand, so swap predicate if it isn't. - if (Cmp.getOperand(0) == X) - Pred = Cmp.getSwappedPredicate(); - switch (Pred) { case ICmpInst::ICMP_ULE: NewPred = ICmpInst::ICMP_NE; @@ -5361,10 +5418,6 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp, // The variant with 'add' is not canonical, (the variant with 'not' is) // we only get it because it has extra uses, and can't be canonicalized, - // We want X to be the icmp's second operand, so swap predicate if it isn't. - if (Cmp.getOperand(0) == X) - Pred = Cmp.getSwappedPredicate(); - switch (Pred) { case ICmpInst::ICMP_ULT: NewPred = ICmpInst::ICMP_NE; @@ -5385,21 +5438,45 @@ static Instruction *foldICmpWithHighBitMask(ICmpInst &Cmp, static Instruction *foldVectorCmp(CmpInst &Cmp, InstCombiner::BuilderTy &Builder) { - // If both arguments of the cmp are shuffles that use the same mask and - // shuffle within a single vector, move the shuffle after the cmp. + const CmpInst::Predicate Pred = Cmp.getPredicate(); Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); Value *V1, *V2; - Constant *M; - if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(M))) && - match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(M))) && - V1->getType() == V2->getType() && - (LHS->hasOneUse() || RHS->hasOneUse())) { - // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M - CmpInst::Predicate P = Cmp.getPredicate(); - Value *NewCmp = isa<ICmpInst>(Cmp) ? Builder.CreateICmp(P, V1, V2) - : Builder.CreateFCmp(P, V1, V2); + ArrayRef<int> M; + if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M)))) + return nullptr; + + // If both arguments of the cmp are shuffles that use the same mask and + // shuffle within a single vector, move the shuffle after the cmp: + // cmp (shuffle V1, M), (shuffle V2, M) --> shuffle (cmp V1, V2), M + Type *V1Ty = V1->getType(); + if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) && + V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) { + Value *NewCmp = Builder.CreateCmp(Pred, V1, V2); return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M); } + + // Try to canonicalize compare with splatted operand and splat constant. + // TODO: We could generalize this for more than splats. See/use the code in + // InstCombiner::foldVectorBinop(). 
+ Constant *C; + if (!LHS->hasOneUse() || !match(RHS, m_Constant(C))) + return nullptr; + + // Length-changing splats are ok, so adjust the constants as needed: + // cmp (shuffle V1, M), C --> shuffle (cmp V1, C'), M + Constant *ScalarC = C->getSplatValue(/* AllowUndefs */ true); + int MaskSplatIndex; + if (ScalarC && match(M, m_SplatOrUndefMask(MaskSplatIndex))) { + // We allow undefs in matching, but this transform removes those for safety. + // Demanded elements analysis should be able to recover some/all of that. + C = ConstantVector::getSplat(cast<VectorType>(V1Ty)->getElementCount(), + ScalarC); + SmallVector<int, 8> NewM(M.size(), MaskSplatIndex); + Value *NewCmp = Builder.CreateCmp(Pred, V1, C); + return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), + NewM); + } + return nullptr; } @@ -5474,8 +5551,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = canonicalizeICmpBool(I, Builder)) return Res; - if (ICmpInst *NewICmp = canonicalizeCmpWithConstant(I)) - return NewICmp; + if (Instruction *Res = canonicalizeCmpWithConstant(I)) + return Res; + + if (Instruction *Res = canonicalizeICmpPredicate(I)) + return Res; if (Instruction *Res = foldICmpWithConstant(I)) return Res; @@ -5565,6 +5645,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpBitCast(I, Builder)) return Res; + // TODO: Hoist this above the min/max bailout. if (Instruction *R = foldICmpWithCastOp(I)) return R; @@ -5600,9 +5681,13 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { isa<IntegerType>(A->getType())) { Value *Result; Constant *Overflow; - if (OptimizeOverflowCheck(Instruction::Add, /*Signed*/false, A, B, - *AddI, Result, Overflow)) { + // m_UAddWithOverflow can match patterns that do not include an explicit + // "add" instruction, so check the opcode of the matched op. 
+ if (AddI->getOpcode() == Instruction::Add && + OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI, + Result, Overflow)) { replaceInstUsesWith(*AddI, Result); + eraseInstFromFunction(*AddI); return replaceInstUsesWith(I, Overflow); } } @@ -5689,7 +5774,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, // TODO: Can never be -0.0 and other non-representable values APFloat RHSRoundInt(RHS); RHSRoundInt.roundToIntegral(APFloat::rmNearestTiesToEven); - if (RHS.compare(RHSRoundInt) != APFloat::cmpEqual) { + if (RHS != RHSRoundInt) { if (P == FCmpInst::FCMP_OEQ || P == FCmpInst::FCMP_UEQ) return replaceInstUsesWith(I, Builder.getFalse()); @@ -5777,7 +5862,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, APFloat SMax(RHS.getSemantics()); SMax.convertFromAPInt(APInt::getSignedMaxValue(IntWidth), true, APFloat::rmNearestTiesToEven); - if (SMax.compare(RHS) == APFloat::cmpLessThan) { // smax < 13123.0 + if (SMax < RHS) { // smax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) return replaceInstUsesWith(I, Builder.getTrue()); @@ -5789,7 +5874,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, APFloat UMax(RHS.getSemantics()); UMax.convertFromAPInt(APInt::getMaxValue(IntWidth), false, APFloat::rmNearestTiesToEven); - if (UMax.compare(RHS) == APFloat::cmpLessThan) { // umax < 13123.0 + if (UMax < RHS) { // umax < 13123.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) return replaceInstUsesWith(I, Builder.getTrue()); @@ -5802,7 +5887,7 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, APFloat SMin(RHS.getSemantics()); SMin.convertFromAPInt(APInt::getSignedMinValue(IntWidth), true, APFloat::rmNearestTiesToEven); - if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // smin > 12312.0 + if (SMin > RHS) { // smin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) return replaceInstUsesWith(I, Builder.getTrue()); @@ -5810,10 +5895,10 @@ Instruction *InstCombiner::foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, } } else { // See if the RHS value is < UnsignedMin. - APFloat SMin(RHS.getSemantics()); - SMin.convertFromAPInt(APInt::getMinValue(IntWidth), true, + APFloat UMin(RHS.getSemantics()); + UMin.convertFromAPInt(APInt::getMinValue(IntWidth), false, APFloat::rmNearestTiesToEven); - if (SMin.compare(RHS) == APFloat::cmpGreaterThan) { // umin > 12312.0 + if (UMin > RHS) { // umin > 12312.0 if (Pred == ICmpInst::ICMP_NE || Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) return replaceInstUsesWith(I, Builder.getTrue()); @@ -5949,16 +6034,15 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI, } /// Optimize fabs(X) compared with zero. 
-static Instruction *foldFabsWithFcmpZero(FCmpInst &I) { +static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombiner &IC) { Value *X; if (!match(I.getOperand(0), m_Intrinsic<Intrinsic::fabs>(m_Value(X))) || !match(I.getOperand(1), m_PosZeroFP())) return nullptr; - auto replacePredAndOp0 = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { + auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) { I->setPredicate(P); - I->setOperand(0, X); - return I; + return IC.replaceOperand(*I, 0, X); }; switch (I.getPredicate()) { @@ -6058,14 +6142,11 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, // then canonicalize the operand to 0.0. if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { - if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) { - I.setOperand(0, ConstantFP::getNullValue(OpType)); - return &I; - } - if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) { - I.setOperand(1, ConstantFP::getNullValue(OpType)); - return &I; - } + if (!match(Op0, m_PosZeroFP()) && isKnownNeverNaN(Op0, &TLI)) + return replaceOperand(I, 0, ConstantFP::getNullValue(OpType)); + + if (!match(Op1, m_PosZeroFP()) && isKnownNeverNaN(Op1, &TLI)) + return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); } // fcmp pred (fneg X), (fneg Y) -> fcmp swap(pred) X, Y @@ -6090,10 +6171,8 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // The sign of 0.0 is ignored by fcmp, so canonicalize to +0.0: // fcmp Pred X, -0.0 --> fcmp Pred X, 0.0 - if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) { - I.setOperand(1, ConstantFP::getNullValue(OpType)); - return &I; - } + if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) + return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); // Handle fcmp with instruction LHS and constant RHS. Instruction *LHSI; @@ -6128,7 +6207,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { } } - if (Instruction *R = foldFabsWithFcmpZero(I)) + if (Instruction *R = foldFabsWithFcmpZero(I, *this)) return R; if (match(Op0, m_FNeg(m_Value(X)))) { @@ -6159,8 +6238,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { APFloat Fabs = TruncC; Fabs.clearSign(); if (!Lossy && - ((Fabs.compare(APFloat::getSmallestNormalized(FPSem)) != - APFloat::cmpLessThan) || Fabs.isZero())) { + (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) { Constant *NewC = ConstantFP::get(X->getType(), TruncC); return new FCmpInst(Pred, X, NewC, "", &I); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 1a746cb87abb4..f918dc7198ca9 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -16,7 +16,8 @@ #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #include "llvm/ADT/ArrayRef.h" -#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Statistic.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" @@ -50,6 +51,7 @@ using namespace llvm::PatternMatch; namespace llvm { +class AAResults; class APInt; class AssumptionCache; class BlockFrequencyInfo; @@ -213,18 +215,23 @@ static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) { } /// Given i1 V, can every user of V be freely adapted if V is changed to !V ? 
+/// InstCombine's canonicalizeICmpPredicate() must be kept in sync with this fn. /// /// See also: isFreeToInvert() static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) { // Look at every user of V. - for (User *U : V->users()) { - if (U == IgnoredUser) + for (Use &U : V->uses()) { + if (U.getUser() == IgnoredUser) continue; // Don't consider this user. - auto *I = cast<Instruction>(U); + auto *I = cast<Instruction>(U.getUser()); switch (I->getOpcode()) { case Instruction::Select: + if (U.getOperandNo() != 0) // Only if the value is used as select cond. + return false; + break; case Instruction::Br: + assert(U.getOperandNo() == 0 && "Must be branching on that value."); break; // Free to invert by swapping true/false values/destinations. case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it. if (!match(I, m_Not(m_Value()))) @@ -244,9 +251,10 @@ static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) { /// If no identity constant exists, replace undef with some other safe constant. static inline Constant *getSafeVectorConstantForBinop( BinaryOperator::BinaryOps Opcode, Constant *In, bool IsRHSConstant) { - assert(In->getType()->isVectorTy() && "Not expecting scalars here"); + auto *InVTy = dyn_cast<VectorType>(In->getType()); + assert(InVTy && "Not expecting scalars here"); - Type *EltTy = In->getType()->getVectorElementType(); + Type *EltTy = InVTy->getElementType(); auto *SafeC = ConstantExpr::getBinOpIdentity(Opcode, EltTy, IsRHSConstant); if (!SafeC) { // TODO: Should this be available as a constant utility function? It is @@ -284,7 +292,7 @@ static inline Constant *getSafeVectorConstantForBinop( } } assert(SafeC && "Must have safe constant for binop"); - unsigned NumElts = In->getType()->getVectorNumElements(); + unsigned NumElts = InVTy->getNumElements(); SmallVector<Constant *, 16> Out(NumElts); for (unsigned i = 0; i != NumElts; ++i) { Constant *C = In->getAggregateElement(i); @@ -313,10 +321,7 @@ private: // Mode in which we are running the combiner. const bool MinimizeSize; - /// Enable combines that trigger rarely but are costly in compiletime. - const bool ExpensiveCombines; - - AliasAnalysis *AA; + AAResults *AA; // Required analyses. AssumptionCache &AC; @@ -336,12 +341,12 @@ private: public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, - bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, + bool MinimizeSize, AAResults *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, const DataLayout &DL, LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), - ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), + AA(AA), AC(AC), TLI(TLI), DT(DT), DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), BFI(BFI), PSI(PSI), LI(LI) {} /// Run the combiner over the entire worklist until it is empty. 
@@ -420,7 +425,7 @@ public: Instruction *visitIntToPtr(IntToPtrInst &CI); Instruction *visitBitCast(BitCastInst &CI); Instruction *visitAddrSpaceCast(AddrSpaceCastInst &CI); - Instruction *FoldItoFPtoI(Instruction &FI); + Instruction *foldItoFPtoI(CastInst &FI); Instruction *visitSelectInst(SelectInst &SI); Instruction *visitCallInst(CallInst &CI); Instruction *visitInvokeInst(InvokeInst &II); @@ -435,6 +440,7 @@ public: Instruction *visitLoadInst(LoadInst &LI); Instruction *visitStoreInst(StoreInst &SI); Instruction *visitAtomicRMWInst(AtomicRMWInst &SI); + Instruction *visitUnconditionalBranchInst(BranchInst &BI); Instruction *visitBranchInst(BranchInst &BI); Instruction *visitFenceInst(FenceInst &FI); Instruction *visitSwitchInst(SwitchInst &SI); @@ -445,8 +451,7 @@ public: Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); Instruction *visitExtractValueInst(ExtractValueInst &EV); Instruction *visitLandingPadInst(LandingPadInst &LI); - Instruction *visitVAStartInst(VAStartInst &I); - Instruction *visitVACopyInst(VACopyInst &I); + Instruction *visitVAEndInst(VAEndInst &I); Instruction *visitFreeze(FreezeInst &I); /// Specify what to return for unhandled instructions. @@ -515,7 +520,7 @@ private: Instruction *simplifyMaskedStore(IntrinsicInst &II); Instruction *simplifyMaskedGather(IntrinsicInst &II); Instruction *simplifyMaskedScatter(IntrinsicInst &II); - + /// Transform (zext icmp) to bitwise / integer operations in order to /// eliminate it. /// @@ -621,9 +626,9 @@ private: Instruction::CastOps isEliminableCastPair(const CastInst *CI1, const CastInst *CI2); - Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &I); + Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And); + Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or); + Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor); /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should @@ -631,11 +636,12 @@ private: Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd); Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, - bool JoinedByAnd, Instruction &CxtI); + BinaryOperator &Logic); Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); Value *getSelectCondition(Value *A, Value *B); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); + Instruction *foldFPSignBitOps(BinaryOperator &I); public: /// Inserts an instruction \p New before instruction \p Old @@ -647,7 +653,7 @@ public: "New instruction already inserted into a basic block!"); BasicBlock *BB = Old.getParent(); BB->getInstList().insert(Old.getIterator(), New); // Insert inst - Worklist.Add(New); + Worklist.push(New); return New; } @@ -668,7 +674,7 @@ public: // no changes were made to the program. if (I.use_empty()) return nullptr; - Worklist.AddUsersToWorkList(I); // Add all modified instrs to worklist. + Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist. // If we are replacing the instruction with itself, this must be in a // segment of unreachable code, so just clobber the instruction. @@ -682,6 +688,19 @@ public: return &I; } + /// Replace operand of instruction and add old operand to the worklist. 
+ Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) { + Worklist.addValue(I.getOperand(OpNum)); + I.setOperand(OpNum, V); + return &I; + } + + /// Replace use and add the previously used value to the worklist. + void replaceUse(Use &U, Value *NewValue) { + Worklist.addValue(U); + U = NewValue; + } + /// Creates a result tuple for an overflow intrinsic \p II with a given /// \p Result and a constant \p Overflow value. Instruction *CreateOverflowTuple(IntrinsicInst *II, Value *Result, @@ -710,16 +729,15 @@ public: Instruction *eraseInstFromFunction(Instruction &I) { LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); - salvageDebugInfoOrMarkUndef(I); + salvageDebugInfo(I); // Make sure that we reprocess all operands now that we reduced their // use counts. - if (I.getNumOperands() < 8) { - for (Use &Operand : I.operands()) - if (auto *Inst = dyn_cast<Instruction>(Operand)) - Worklist.Add(Inst); - } - Worklist.Remove(&I); + for (Use &Operand : I.operands()) + if (auto *Inst = dyn_cast<Instruction>(Operand)) + Worklist.add(Inst); + + Worklist.remove(&I); I.eraseFromParent(); MadeIRChange = true; return nullptr; // Don't do anything with FI @@ -869,6 +887,7 @@ private: /// Canonicalize the position of binops relative to shufflevector. Instruction *foldVectorBinop(BinaryOperator &Inst); + Instruction *foldVectorSelect(SelectInst &Sel); /// Given a binary operator, cast instruction, or select which has a PHI node /// as operand #0, see if we can fold the instruction into the PHI (which is @@ -1004,6 +1023,64 @@ private: Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); }; +namespace { + +// As a default, let's assume that we want to be aggressive, +// and attempt to traverse with no limits in attempt to sink negation. +static constexpr unsigned NegatorDefaultMaxDepth = ~0U; + +// Let's guesstimate that most often we will end up visiting/producing +// fairly small number of new instructions. +static constexpr unsigned NegatorMaxNodesSSO = 16; + +} // namespace + +class Negator final { + /// Top-to-bottom, def-to-use negated instruction tree we produced. + SmallVector<Instruction *, NegatorMaxNodesSSO> NewInstructions; + + using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; + BuilderTy Builder; + + const DataLayout &DL; + AssumptionCache &AC; + const DominatorTree &DT; + + const bool IsTrulyNegation; + + SmallDenseMap<Value *, Value *> NegationsCache; + + Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC, + const DominatorTree &DT, bool IsTrulyNegation); + +#if LLVM_ENABLE_STATS + unsigned NumValuesVisitedInThisNegator = 0; + ~Negator(); +#endif + + using Result = std::pair<ArrayRef<Instruction *> /*NewInstructions*/, + Value * /*NegatedRoot*/>; + + LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth); + + LLVM_NODISCARD Value *negate(Value *V, unsigned Depth); + + /// Recurse depth-first and attempt to sink the negation. + /// FIXME: use worklist? + LLVM_NODISCARD Optional<Result> run(Value *Root); + + Negator(const Negator &) = delete; + Negator(Negator &&) = delete; + Negator &operator=(const Negator &) = delete; + Negator &operator=(Negator &&) = delete; + +public: + /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, + /// otherwise returns negated value. 
+ LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root, + InstCombiner &IC); +}; + } // end namespace llvm #undef DEBUG_TYPE diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index ebf9d24eecc41..dad2f23120bdb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -14,8 +14,8 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/Loads.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfoMetadata.h" @@ -24,6 +24,7 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -32,22 +33,6 @@ using namespace PatternMatch; STATISTIC(NumDeadStore, "Number of dead stores eliminated"); STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); -/// pointsToConstantGlobal - Return true if V (possibly indirectly) points to -/// some part of a constant global variable. This intentionally only accepts -/// constant expressions because we can't rewrite arbitrary instructions. -static bool pointsToConstantGlobal(Value *V) { - if (GlobalVariable *GV = dyn_cast<GlobalVariable>(V)) - return GV->isConstant(); - - if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) { - if (CE->getOpcode() == Instruction::BitCast || - CE->getOpcode() == Instruction::AddrSpaceCast || - CE->getOpcode() == Instruction::GetElementPtr) - return pointsToConstantGlobal(CE->getOperand(0)); - } - return false; -} - /// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived) /// pointer to an alloca. Ignore any reads of the pointer, return false if we /// see any stores or other unknown uses. If we see pointer arithmetic, keep @@ -56,7 +41,8 @@ static bool pointsToConstantGlobal(Value *V) { /// the alloca, and if the source pointer is a pointer to a constant global, we /// can optimize this. static bool -isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, +isOnlyCopiedFromConstantMemory(AAResults *AA, + Value *V, MemTransferInst *&TheCopy, SmallVectorImpl<Instruction *> &ToDelete) { // We track lifetime intrinsics as we encounter them. If we decide to go // ahead and replace the value with the global, this lets the caller quickly @@ -145,7 +131,7 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, if (U.getOperandNo() != 0) return false; // If the source of the memcpy/move is not a constant global, reject it. - if (!pointsToConstantGlobal(MI->getSource())) + if (!AA->pointsToConstantMemory(MI->getSource())) return false; // Otherwise, the transform is safe. Remember the copy instruction. @@ -159,10 +145,11 @@ isOnlyCopiedFromConstantGlobal(Value *V, MemTransferInst *&TheCopy, /// modified by a copy from a constant global. If we can prove this, we can /// replace any uses of the alloca with uses of the global directly. 
static MemTransferInst * -isOnlyCopiedFromConstantGlobal(AllocaInst *AI, +isOnlyCopiedFromConstantMemory(AAResults *AA, + AllocaInst *AI, SmallVectorImpl<Instruction *> &ToDelete) { MemTransferInst *TheCopy = nullptr; - if (isOnlyCopiedFromConstantGlobal(AI, TheCopy, ToDelete)) + if (isOnlyCopiedFromConstantMemory(AA, AI, TheCopy, ToDelete)) return TheCopy; return nullptr; } @@ -187,9 +174,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { return nullptr; // Canonicalize it. - Value *V = IC.Builder.getInt32(1); - AI.setOperand(0, V); - return &AI; + return IC.replaceOperand(AI, 0, IC.Builder.getInt32(1)); } // Convert: alloca Ty, C - where C is a constant != 1 into: alloca [C x Ty], 1 @@ -197,7 +182,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { if (C->getValue().getActiveBits() <= 64) { Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); - New->setAlignment(MaybeAlign(AI.getAlignment())); + New->setAlignment(AI.getAlign()); // Scan to the end of the allocation instructions, to skip over a block of // allocas if possible...also skip interleaved debug info @@ -230,8 +215,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { Type *IntPtrTy = IC.getDataLayout().getIntPtrType(AI.getType()); if (AI.getArraySize()->getType() != IntPtrTy) { Value *V = IC.Builder.CreateIntCast(AI.getArraySize(), IntPtrTy, false); - AI.setOperand(0, V); - return &AI; + return IC.replaceOperand(AI, 0, V); } return nullptr; @@ -298,7 +282,8 @@ void PointerReplacer::replace(Instruction *I) { if (auto *LT = dyn_cast<LoadInst>(I)) { auto *V = getReplacement(LT->getPointerOperand()); assert(V && "Operand not replaced"); - auto *NewI = new LoadInst(I->getType(), V); + auto *NewI = new LoadInst(I->getType(), V, "", false, + IC.getDataLayout().getABITypeAlign(I->getType())); NewI->takeName(LT); IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); @@ -343,22 +328,16 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { return I; if (AI.getAllocatedType()->isSized()) { - // If the alignment is 0 (unspecified), assign it the preferred alignment. - if (AI.getAlignment() == 0) - AI.setAlignment( - MaybeAlign(DL.getPrefTypeAlignment(AI.getAllocatedType()))); - // Move all alloca's of zero byte objects to the entry block and merge them // together. Note that we only do this for alloca's, because malloc should // allocate and return a unique pointer, even for a zero byte allocation. - if (DL.getTypeAllocSize(AI.getAllocatedType()) == 0) { + if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinSize() == 0) { // For a zero sized alloca there is no point in doing an array allocation. // This is helpful if the array size is a complicated expression not used // elsewhere. - if (AI.isArrayAllocation()) { - AI.setOperand(0, ConstantInt::get(AI.getArraySize()->getType(), 1)); - return &AI; - } + if (AI.isArrayAllocation()) + return replaceOperand(AI, 0, + ConstantInt::get(AI.getArraySize()->getType(), 1)); // Get the first instruction in the entry block. BasicBlock &EntryBlock = AI.getParent()->getParent()->getEntryBlock(); @@ -369,21 +348,16 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // dominance as the array size was forced to a constant earlier already. 
AllocaInst *EntryAI = dyn_cast<AllocaInst>(FirstInst); if (!EntryAI || !EntryAI->getAllocatedType()->isSized() || - DL.getTypeAllocSize(EntryAI->getAllocatedType()) != 0) { + DL.getTypeAllocSize(EntryAI->getAllocatedType()) + .getKnownMinSize() != 0) { AI.moveBefore(FirstInst); return &AI; } - // If the alignment of the entry block alloca is 0 (unspecified), - // assign it the preferred alignment. - if (EntryAI->getAlignment() == 0) - EntryAI->setAlignment( - MaybeAlign(DL.getPrefTypeAlignment(EntryAI->getAllocatedType()))); // Replace this zero-sized alloca with the one at the start of the entry // block after ensuring that the address will be aligned enough for both // types. - const MaybeAlign MaxAlign( - std::max(EntryAI->getAlignment(), AI.getAlignment())); + const Align MaxAlign = std::max(EntryAI->getAlign(), AI.getAlign()); EntryAI->setAlignment(MaxAlign); if (AI.getType() != EntryAI->getType()) return new BitCastInst(EntryAI, AI.getType()); @@ -392,41 +366,40 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { } } - if (AI.getAlignment()) { - // Check to see if this allocation is only modified by a memcpy/memmove from - // a constant global whose alignment is equal to or exceeds that of the - // allocation. If this is the case, we can change all users to use - // the constant global instead. This is commonly produced by the CFE by - // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' - // is only subsequently read. - SmallVector<Instruction *, 4> ToDelete; - if (MemTransferInst *Copy = isOnlyCopiedFromConstantGlobal(&AI, ToDelete)) { - unsigned SourceAlign = getOrEnforceKnownAlignment( - Copy->getSource(), AI.getAlignment(), DL, &AI, &AC, &DT); - if (AI.getAlignment() <= SourceAlign && - isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { - LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); - LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); - for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) - eraseInstFromFunction(*ToDelete[i]); - Constant *TheSrc = cast<Constant>(Copy->getSource()); - auto *SrcTy = TheSrc->getType(); - auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), - SrcTy->getPointerAddressSpace()); - Constant *Cast = - ConstantExpr::getPointerBitCastOrAddrSpaceCast(TheSrc, DestTy); - if (AI.getType()->getPointerAddressSpace() == - SrcTy->getPointerAddressSpace()) { - Instruction *NewI = replaceInstUsesWith(AI, Cast); - eraseInstFromFunction(*Copy); - ++NumGlobalCopies; - return NewI; - } else { - PointerReplacer PtrReplacer(*this); - PtrReplacer.replacePointer(AI, Cast); - ++NumGlobalCopies; - } + // Check to see if this allocation is only modified by a memcpy/memmove from + // a constant whose alignment is equal to or exceeds that of the allocation. + // If this is the case, we can change all users to use the constant global + // instead. This is commonly produced by the CFE by constructs like "void + // foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' is only subsequently + // read. 
+ SmallVector<Instruction *, 4> ToDelete; + if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) { + Align AllocaAlign = AI.getAlign(); + Align SourceAlign = getOrEnforceKnownAlignment( + Copy->getSource(), AllocaAlign, DL, &AI, &AC, &DT); + if (AllocaAlign <= SourceAlign && + isDereferenceableForAllocaSize(Copy->getSource(), &AI, DL)) { + LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); + LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); + for (unsigned i = 0, e = ToDelete.size(); i != e; ++i) + eraseInstFromFunction(*ToDelete[i]); + Value *TheSrc = Copy->getSource(); + auto *SrcTy = TheSrc->getType(); + auto *DestTy = PointerType::get(AI.getType()->getPointerElementType(), + SrcTy->getPointerAddressSpace()); + Value *Cast = + Builder.CreatePointerBitCastOrAddrSpaceCast(TheSrc, DestTy); + if (AI.getType()->getPointerAddressSpace() == + SrcTy->getPointerAddressSpace()) { + Instruction *NewI = replaceInstUsesWith(AI, Cast); + eraseInstFromFunction(*Copy); + ++NumGlobalCopies; + return NewI; } + + PointerReplacer PtrReplacer(*this); + PtrReplacer.replacePointer(AI, Cast); + ++NumGlobalCopies; } } @@ -462,15 +435,8 @@ LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy, NewPtr->getType()->getPointerAddressSpace() == AS)) NewPtr = Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); - unsigned Align = LI.getAlignment(); - if (!Align) - // If old load did not have an explicit alignment specified, - // manually preserve the implied (ABI) alignment of the load. - // Else we may inadvertently incorrectly over-promise alignment. - Align = getDataLayout().getABITypeAlignment(LI.getType()); - LoadInst *NewLoad = Builder.CreateAlignedLoad( - NewTy, NewPtr, Align, LI.isVolatile(), LI.getName() + Suffix); + NewTy, NewPtr, LI.getAlign(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); copyMetadataForLoad(*NewLoad, LI); return NewLoad; @@ -490,7 +456,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value StoreInst *NewStore = IC.Builder.CreateAlignedStore( V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), - SI.getAlignment(), SI.isVolatile()); + SI.getAlign(), SI.isVolatile()); NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -594,11 +560,9 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // Do not perform canonicalization if minmax pattern is found (to avoid // infinite loop). 
Type *Dummy; - if (!Ty->isIntegerTy() && Ty->isSized() && - !(Ty->isVectorTy() && Ty->getVectorIsScalable()) && + if (!Ty->isIntegerTy() && Ty->isSized() && !isa<ScalableVectorType>(Ty) && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && - DL.typeSizeEqualsStoreSize(Ty) && - !DL.isNonIntegralPointerType(Ty) && + DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) && !isMinMaxWithLoads( peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true), Dummy)) { @@ -674,10 +638,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { if (SL->hasPadding()) return nullptr; - auto Align = LI.getAlignment(); - if (!Align) - Align = DL.getABITypeAlignment(ST); - + const auto Align = LI.getAlign(); auto *Addr = LI.getPointerOperand(); auto *IdxType = Type::getInt32Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); @@ -690,9 +651,9 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { }; auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), Name + ".elt"); - auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); - auto *L = IC.Builder.CreateAlignedLoad(ST->getElementType(i), Ptr, - EltAlign, Name + ".unpack"); + auto *L = IC.Builder.CreateAlignedLoad( + ST->getElementType(i), Ptr, + commonAlignment(Align, SL->getElementOffset(i)), Name + ".unpack"); // Propagate AA metadata. It'll still be valid on the narrowed load. AAMDNodes AAMD; LI.getAAMetadata(AAMD); @@ -725,9 +686,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { const DataLayout &DL = IC.getDataLayout(); auto EltSize = DL.getTypeAllocSize(ET); - auto Align = LI.getAlignment(); - if (!Align) - Align = DL.getABITypeAlignment(T); + const auto Align = LI.getAlign(); auto *Addr = LI.getPointerOperand(); auto *IdxType = Type::getInt64Ty(T->getContext()); @@ -742,8 +701,9 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { }; auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), Name + ".elt"); - auto *L = IC.Builder.CreateAlignedLoad( - AT->getElementType(), Ptr, MinAlign(Align, Offset), Name + ".unpack"); + auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, + commonAlignment(Align, Offset), + Name + ".unpack"); AAMDNodes AAMD; LI.getAAMetadata(AAMD); L->setAAMetadata(AAMD); @@ -964,20 +924,14 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { return Res; // Attempt to improve the alignment. - unsigned KnownAlign = getOrEnforceKnownAlignment( - Op, DL.getPrefTypeAlignment(LI.getType()), DL, &LI, &AC, &DT); - unsigned LoadAlign = LI.getAlignment(); - unsigned EffectiveLoadAlign = - LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType()); - - if (KnownAlign > EffectiveLoadAlign) - LI.setAlignment(MaybeAlign(KnownAlign)); - else if (LoadAlign == 0) - LI.setAlignment(MaybeAlign(EffectiveLoadAlign)); + Align KnownAlign = getOrEnforceKnownAlignment( + Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); + if (KnownAlign > LI.getAlign()) + LI.setAlignment(KnownAlign); // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { - Worklist.Add(NewGEPI); + Worklist.push(NewGEPI); return &LI; } @@ -1030,7 +984,7 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). 
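The equivalence named in the comment above, spot-checked at the source level in plain C++ (illustrative only; the IR-level fold additionally requires both pointers to be safe to load unconditionally):

#include <cassert>

int main() {
  int V1 = 11, V2 = 22;
  for (bool Cond : {false, true}) {
    int *P = Cond ? &V1 : &V2;          // select(Cond, &V1, &V2)
    int LoadOfSelect = *P;              // load (select ...)
    int SelectOfLoads = Cond ? V1 : V2; // select(Cond, load &V1, load &V2)
    assert(LoadOfSelect == SelectOfLoads);
  }
  return 0;
}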
- const MaybeAlign Alignment(LI.getAlignment()); + Align Alignment = LI.getAlign(); if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), Alignment, DL, SI) && isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), @@ -1052,18 +1006,14 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // load (select (cond, null, P)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(1)) && !NullPointerIsDefined(SI->getFunction(), - LI.getPointerAddressSpace())) { - LI.setOperand(0, SI->getOperand(2)); - return &LI; - } + LI.getPointerAddressSpace())) + return replaceOperand(LI, 0, SI->getOperand(2)); // load (select (cond, P, null)) -> load P if (isa<ConstantPointerNull>(SI->getOperand(2)) && !NullPointerIsDefined(SI->getFunction(), - LI.getPointerAddressSpace())) { - LI.setOperand(0, SI->getOperand(1)); - return &LI; - } + LI.getPointerAddressSpace())) + return replaceOperand(LI, 0, SI->getOperand(1)); } } return nullptr; @@ -1204,9 +1154,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { if (SL->hasPadding()) return false; - auto Align = SI.getAlignment(); - if (!Align) - Align = DL.getABITypeAlignment(ST); + const auto Align = SI.getAlign(); SmallString<16> EltName = V->getName(); EltName += ".elt"; @@ -1224,7 +1172,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); - auto EltAlign = MinAlign(Align, SL->getElementOffset(i)); + auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); AAMDNodes AAMD; SI.getAAMetadata(AAMD); @@ -1252,9 +1200,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { const DataLayout &DL = IC.getDataLayout(); auto EltSize = DL.getTypeAllocSize(AT->getElementType()); - auto Align = SI.getAlignment(); - if (!Align) - Align = DL.getABITypeAlignment(T); + const auto Align = SI.getAlign(); SmallString<16> EltName = V->getName(); EltName += ".elt"; @@ -1274,7 +1220,7 @@ static bool unpackStoreToAggregate(InstCombiner &IC, StoreInst &SI) { auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); - auto EltAlign = MinAlign(Align, Offset); + auto EltAlign = commonAlignment(Align, Offset); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); AAMDNodes AAMD; SI.getAAMetadata(AAMD); @@ -1336,6 +1282,11 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy)) return false; + // Make sure the type would actually change. + // This condition can be hit with chains of bitcasts. + if (LI->getType() == CmpLoadTy) + return false; + // Make sure we're not changing the size of the load/store. const auto &DL = IC.getDataLayout(); if (DL.getTypeStoreSizeInBits(LI->getType()) != @@ -1372,16 +1323,10 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { return eraseInstFromFunction(SI); // Attempt to improve the alignment. - const Align KnownAlign = Align(getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT)); - const MaybeAlign StoreAlign = MaybeAlign(SI.getAlignment()); - const Align EffectiveStoreAlign = - StoreAlign ? 
*StoreAlign : Align(DL.getABITypeAlignment(Val->getType())); - - if (KnownAlign > EffectiveStoreAlign) + const Align KnownAlign = getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); + if (KnownAlign > SI.getAlign()) SI.setAlignment(KnownAlign); - else if (!StoreAlign) - SI.setAlignment(EffectiveStoreAlign); // Try to canonicalize the stored type. if (unpackStoreToAggregate(*this, SI)) @@ -1392,7 +1337,7 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { - Worklist.Add(NewGEPI); + Worklist.push(NewGEPI); return &SI; } @@ -1439,9 +1384,12 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (PrevSI->isUnordered() && equivalentAddressValues(PrevSI->getOperand(1), SI.getOperand(1))) { ++NumDeadStore; - ++BBI; + // Manually add back the original store to the worklist now, so it will + // be processed after the operands of the removed store, as this may + // expose additional DSE opportunities. + Worklist.push(&SI); eraseInstFromFunction(*PrevSI); - continue; + return nullptr; } break; } @@ -1468,11 +1416,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { // store X, null -> turns into 'unreachable' in SimplifyCFG // store X, GEP(null, Y) -> turns into 'unreachable' in SimplifyCFG if (canSimplifyNullStoreOrGEP(SI)) { - if (!isa<UndefValue>(Val)) { - SI.setOperand(0, UndefValue::get(Val->getType())); - if (Instruction *U = dyn_cast<Instruction>(Val)) - Worklist.Add(U); // Dropped a use. - } + if (!isa<UndefValue>(Val)) + return replaceOperand(SI, 0, UndefValue::get(Val->getType())); return nullptr; // Do not modify these! } @@ -1480,19 +1425,6 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (isa<UndefValue>(Val)) return eraseInstFromFunction(SI); - // If this store is the second-to-last instruction in the basic block - // (excluding debug info and bitcasts of pointers) and if the block ends with - // an unconditional branch, try to move the store to the successor block. - BBI = SI.getIterator(); - do { - ++BBI; - } while (isa<DbgInfoIntrinsic>(BBI) || - (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())); - - if (BranchInst *BI = dyn_cast<BranchInst>(BBI)) - if (BI->isUnconditional()) - mergeStoreIntoSuccessor(SI); - return nullptr; } @@ -1502,8 +1434,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { /// *P = v1; if () { *P = v2; } /// into a phi node with a store in the successor. bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { - assert(SI.isUnordered() && - "This code has not been audited for volatile or ordered store case."); + if (!SI.isUnordered()) + return false; // This code has not been audited for volatile/ordered case. // Check if the successor block has exactly 2 incoming edges. BasicBlock *StoreBB = SI.getParent(); @@ -1595,9 +1527,9 @@ bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { // Advance to a place where it is safe to insert the new store and insert it. 
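A source-level picture of the store sinking performed here, in plain C++ (the function names are made up; the rewrite is valid because nothing observes *P between the two stores): the conditional second store becomes a merged value, which is what the phi plus single store in the successor block expresses in IR.

void beforeSinking(int *P, bool Cond, int V1, int V2) {
  *P = V1;
  if (Cond)
    *P = V2;
}

void afterSinking(int *P, bool Cond, int V1, int V2) {
  int Merged = Cond ? V2 : V1; // the phi of the two incoming values
  *P = Merged;                 // the single store in the successor block
}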
BBI = DestBB->getFirstInsertionPt(); - StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), - MaybeAlign(SI.getAlignment()), - SI.getOrdering(), SI.getSyncScopeID()); + StoreInst *NewSI = + new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), SI.getAlign(), + SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(MergedLoc); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 2774e46151faf..c6233a68847dd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -72,7 +72,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { - I->setOperand(0, V2); + IC.replaceOperand(*I, 0, V2); MadeChange = true; } @@ -96,19 +96,22 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, /// A helper routine of InstCombiner::visitMul(). /// -/// If C is a scalar/vector of known powers of 2, then this function returns -/// a new scalar/vector obtained from logBase2 of C. +/// If C is a scalar/fixed width vector of known powers of 2, then this +/// function returns a new scalar/fixed width vector obtained from logBase2 +/// of C. /// Return a null pointer otherwise. static Constant *getLogBase2(Type *Ty, Constant *C) { const APInt *IVal; if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) return ConstantInt::get(Ty, IVal->logBase2()); - if (!Ty->isVectorTy()) + // FIXME: We can extract pow of 2 of splat constant for scalable vectors. + if (!isa<FixedVectorType>(Ty)) return nullptr; SmallVector<Constant *, 4> Elts; - for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { + for (unsigned I = 0, E = cast<FixedVectorType>(Ty)->getNumElements(); I != E; + ++I) { Constant *Elt = C->getAggregateElement(I); if (!Elt) return nullptr; @@ -274,6 +277,15 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // abs(X) * abs(X) -> X * X + // nabs(X) * nabs(X) -> X * X + if (Op0 == Op1) { + Value *X, *Y; + SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) + return BinaryOperator::CreateMul(X, X); + } + // -X * C --> X * -C Value *X, *Y; Constant *Op1C; @@ -354,6 +366,27 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // (zext bool X) * (zext bool Y) --> zext (and X, Y) + // (sext bool X) * (sext bool Y) --> zext (and X, Y) + // Note: -1 * -1 == 1 * 1 == 1 (if the extends match, the result is the same) + if (((match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) || + (match(Op0, m_SExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *And = Builder.CreateAnd(X, Y, "mulbool"); + return CastInst::Create(Instruction::ZExt, And, I.getType()); + } + // (sext bool X) * (zext bool Y) --> sext (and X, Y) + // (zext bool X) * (sext bool Y) --> sext (and X, Y) + // Note: -1 * 1 == 1 * -1 == -1 + if (((match(Op0, m_SExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) || + (match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *And = 
Builder.CreateAnd(X, Y, "mulbool"); + return CastInst::Create(Instruction::SExt, And, I.getType()); + } + // (bool X) * Y --> X ? Y : 0 // Y * (bool X) --> X ? Y : 0 if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) @@ -390,6 +423,40 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : nullptr; } +Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) { + BinaryOperator::BinaryOps Opcode = I.getOpcode(); + assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) && + "Expected fmul or fdiv"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X, *Y; + + // -X * -Y --> X * Y + // -X / -Y --> X / Y + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, Y, &I); + + // fabs(X) * fabs(X) -> X * X + // fabs(X) / fabs(X) -> X / X + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I); + + // fabs(X) * fabs(Y) --> fabs(X * Y) + // fabs(X) / fabs(Y) --> fabs(X / Y) + if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + Value *XY = Builder.CreateBinOp(Opcode, X, Y); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY); + Fabs->takeName(&I); + return replaceInstUsesWith(I, Fabs); + } + + return nullptr; +} + Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -408,25 +475,20 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) return replaceInstUsesWith(I, FoldedMul); + if (Instruction *R = foldFPSignBitOps(I)) + return R; + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) - return BinaryOperator::CreateFNegFMF(Op0, &I); - - // -X * -Y --> X * Y - Value *X, *Y; - if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFMulFMF(X, Y, &I); + return UnaryOperator::CreateFNegFMF(Op0, &I); // -X * C --> X * -C + Value *X, *Y; Constant *C; if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - // fabs(X) * fabs(X) -> X * X - if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) - return BinaryOperator::CreateFMulFMF(X, X, &I); - // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); @@ -563,8 +625,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Y = Op0; } if (Log2) { - Log2->setArgOperand(0, X); - Log2->copyFastMathFlags(&I); + Value *Log2 = Builder.CreateUnaryIntrinsic(Intrinsic::log2, X, &I); Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I); return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I); } @@ -592,7 +653,7 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { return false; // Change the div/rem to use 'Y' instead of the select. - I.setOperand(1, SI->getOperand(NonNullOperand)); + replaceOperand(I, 1, SI->getOperand(NonNullOperand)); // Okay, we know we replace the operand of the div/rem with 'Y' with no // problem. 
However, the select, or the condition of the select may have @@ -620,12 +681,12 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end(); I != E; ++I) { if (*I == SI) { - *I = SI->getOperand(NonNullOperand); - Worklist.Add(&*BBI); + replaceUse(*I, SI->getOperand(NonNullOperand)); + Worklist.push(&*BBI); } else if (*I == SelectCond) { - *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) - : ConstantInt::getFalse(CondTy); - Worklist.Add(&*BBI); + replaceUse(*I, NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) + : ConstantInt::getFalse(CondTy)); + Worklist.push(&*BBI); } } @@ -683,10 +744,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Type *Ty = I.getType(); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { - I.setOperand(1, V); - return &I; - } + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) + return replaceOperand(I, 1, V); // Handle cases involving: [su]div X, (select Cond, Y, Z) // This does not apply for fdiv. @@ -800,8 +859,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap(); bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap(); if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) { - I.setOperand(0, ConstantInt::get(Ty, 1)); - I.setOperand(1, Y); + replaceOperand(I, 0, ConstantInt::get(Ty, 1)); + replaceOperand(I, 1, Y); return &I; } } @@ -1214,6 +1273,9 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Instruction *R = foldFDivConstantDividend(I)) return R; + if (Instruction *R = foldFPSignBitOps(I)) + return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1274,21 +1336,14 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } - // -X / -Y -> X / Y - Value *X, *Y; - if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) { - I.setOperand(0, X); - I.setOperand(1, Y); - return &I; - } - // X / (X * Y) --> 1.0 / Y // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed. // We can ignore the possibility that X is infinity because INF/INF is NaN. + Value *X, *Y; if (I.hasNoNaNs() && I.hasAllowReassoc() && match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) { - I.setOperand(0, ConstantFP::get(I.getType(), 1.0)); - I.setOperand(1, Y); + replaceOperand(I, 0, ConstantFP::get(I.getType(), 1.0)); + replaceOperand(I, 1, Y); return &I; } @@ -1314,10 +1369,8 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. 
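One remainder identity used a little further down (the X % -Y to X % Y rewrite), spot-checked in plain C++: for signed operands the sign of the result follows the dividend, so flipping the sign of a non-minimum constant divisor does not change the value.

#include <cassert>

int main() {
  for (int X = -5; X <= 5; ++X)
    for (int Y = 1; Y <= 5; ++Y)
      assert(X % -Y == X % Y);
  return 0;
}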
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { - I.setOperand(1, V); - return &I; - } + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) + return replaceOperand(I, 1, V); // Handle cases involving: rem X, (select Cond, Y, Z) if (simplifyDivRemOfSelectWithZeroOp(I)) @@ -1417,11 +1470,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { { const APInt *Y; // X % -Y -> X % Y - if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) { - Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); - return &I; - } + if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) + return replaceOperand(I, 1, ConstantInt::get(I.getType(), -*Y)); } // -X srem Y --> -(X srem Y) @@ -1441,7 +1491,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // If it's a constant vector, flip any negative values positive. if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) { Constant *C = cast<Constant>(Op1); - unsigned VWidth = C->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(C->getType())->getNumElements(); bool hasNegative = false; bool hasMissing = false; @@ -1468,11 +1518,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { } Constant *NewRHSV = ConstantVector::get(Elts); - if (NewRHSV != C) { // Don't loop on -MININT - Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, NewRHSV); - return &I; - } + if (NewRHSV != C) // Don't loop on -MININT + return replaceOperand(I, 1, NewRHSV); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp new file mode 100644 index 0000000000000..3fe615ac54391 --- /dev/null +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -0,0 +1,474 @@ +//===- InstCombineNegator.cpp -----------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements sinking of negation into expression trees, +// as long as that can be done without increasing instruction count. 
+// +//===----------------------------------------------------------------------===// + +#include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/TargetFolder.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include <functional> +#include <tuple> +#include <type_traits> +#include <utility> + +namespace llvm { +class AssumptionCache; +class DataLayout; +class DominatorTree; +class LLVMContext; +} // namespace llvm + +using namespace llvm; + +#define DEBUG_TYPE "instcombine" + +STATISTIC(NegatorTotalNegationsAttempted, + "Negator: Number of negations attempted to be sinked"); +STATISTIC(NegatorNumTreesNegated, + "Negator: Number of negations successfully sinked"); +STATISTIC(NegatorMaxDepthVisited, "Negator: Maximal traversal depth ever " + "reached while attempting to sink negation"); +STATISTIC(NegatorTimesDepthLimitReached, + "Negator: How many times did the traversal depth limit was reached " + "during sinking"); +STATISTIC( + NegatorNumValuesVisited, + "Negator: Total number of values visited during attempts to sink negation"); +STATISTIC(NegatorNumNegationsFoundInCache, + "Negator: How many negations did we retrieve/reuse from cache"); +STATISTIC(NegatorMaxTotalValuesVisited, + "Negator: Maximal number of values ever visited while attempting to " + "sink negation"); +STATISTIC(NegatorNumInstructionsCreatedTotal, + "Negator: Number of new negated instructions created, total"); +STATISTIC(NegatorMaxInstructionsCreated, + "Negator: Maximal number of new instructions created during negation " + "attempt"); +STATISTIC(NegatorNumInstructionsNegatedSuccess, + "Negator: Number of new negated instructions created in successful " + "negation sinking attempts"); + +DEBUG_COUNTER(NegatorCounter, "instcombine-negator", + "Controls Negator transformations in InstCombine pass"); + +static cl::opt<bool> + NegatorEnabled("instcombine-negator-enabled", cl::init(true), + cl::desc("Should we attempt to sink negations?")); + +static cl::opt<unsigned> + NegatorMaxDepth("instcombine-negator-max-depth", + cl::init(NegatorDefaultMaxDepth), + cl::desc("What is the maximal lookup depth when trying to " + "check for viability of negation sinking.")); + +Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_, + const DominatorTree &DT_, bool IsTrulyNegation_) + : Builder(C, TargetFolder(DL_), + IRBuilderCallbackInserter([&](Instruction *I) { + ++NegatorNumInstructionsCreatedTotal; + NewInstructions.push_back(I); + })), + DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {} + +#if LLVM_ENABLE_STATS +Negator::~Negator() { + 
NegatorMaxTotalValuesVisited.updateMax(NumValuesVisitedInThisNegator); +} +#endif + +// FIXME: can this be reworked into a worklist-based algorithm while preserving +// the depth-first, early bailout traversal? +LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { + // -(undef) -> undef. + if (match(V, m_Undef())) + return V; + + // In i1, negation can simply be ignored. + if (V->getType()->isIntOrIntVectorTy(1)) + return V; + + Value *X; + + // -(-(X)) -> X. + if (match(V, m_Neg(m_Value(X)))) + return X; + + // Integral constants can be freely negated. + if (match(V, m_AnyIntegralConstant())) + return ConstantExpr::getNeg(cast<Constant>(V), /*HasNUW=*/false, + /*HasNSW=*/false); + + // If we have a non-instruction, then give up. + if (!isa<Instruction>(V)) + return nullptr; + + // If we have started with a true negation (i.e. `sub 0, %y`), then if we've + // got instruction that does not require recursive reasoning, we can still + // negate it even if it has other uses, without increasing instruction count. + if (!V->hasOneUse() && !IsTrulyNegation) + return nullptr; + + auto *I = cast<Instruction>(V); + unsigned BitWidth = I->getType()->getScalarSizeInBits(); + + // We must preserve the insertion point and debug info that is set in the + // builder at the time this function is called. + InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); + // And since we are trying to negate instruction I, that tells us about the + // insertion point and the debug info that we need to keep. + Builder.SetInsertPoint(I); + + // In some cases we can give the answer without further recursion. + switch (I->getOpcode()) { + case Instruction::Add: + // `inc` is always negatible. + if (match(I->getOperand(1), m_One())) + return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg"); + break; + case Instruction::Xor: + // `not` is always negatible. + if (match(I, m_Not(m_Value(X)))) + return Builder.CreateAdd(X, ConstantInt::get(X->getType(), 1), + I->getName() + ".neg"); + break; + case Instruction::AShr: + case Instruction::LShr: { + // Right-shift sign bit smear is negatible. + const APInt *Op1Val; + if (match(I->getOperand(1), m_APInt(Op1Val)) && *Op1Val == BitWidth - 1) { + Value *BO = I->getOpcode() == Instruction::AShr + ? Builder.CreateLShr(I->getOperand(0), I->getOperand(1)) + : Builder.CreateAShr(I->getOperand(0), I->getOperand(1)); + if (auto *NewInstr = dyn_cast<Instruction>(BO)) { + NewInstr->copyIRFlags(I); + NewInstr->setName(I->getName() + ".neg"); + } + return BO; + } + break; + } + case Instruction::SExt: + case Instruction::ZExt: + // `*ext` of i1 is always negatible + if (I->getOperand(0)->getType()->isIntOrIntVectorTy(1)) + return I->getOpcode() == Instruction::SExt + ? Builder.CreateZExt(I->getOperand(0), I->getType(), + I->getName() + ".neg") + : Builder.CreateSExt(I->getOperand(0), I->getType(), + I->getName() + ".neg"); + break; + default: + break; // Other instructions require recursive reasoning. + } + + // Some other cases, while still don't require recursion, + // are restricted to the one-use case. + if (!V->hasOneUse()) + return nullptr; + + switch (I->getOpcode()) { + case Instruction::Sub: + // `sub` is always negatible. + // But if the old `sub` sticks around, even thought we don't increase + // instruction count, this is a likely regression since we increased + // live-range of *both* of the operands, which might lead to more spilling. 
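// A minimal illustrative sketch, not part of the patch: the non-recursive
// cases above, and the `sub` case that follows, rest on plain two's-complement
// identities. Helper names below are illustrative only.
#include <cstdint>
namespace negator_identities_sketch {
constexpr std::uint32_t neg(std::uint32_t X) { return 0u - X; }
constexpr std::uint32_t bnot(std::uint32_t X) { return ~X; }
constexpr std::uint32_t sub(std::uint32_t A, std::uint32_t B) { return A - B; }
// Add of 1: -(X + 1) == ~X, so an `inc` negates to a `not`.
static_assert(neg(41u + 1u) == bnot(41u), "");
// Not: -(~X) == X + 1, so a `not` negates to an `inc`.
static_assert(neg(bnot(41u)) == 41u + 1u, "");
// Right-shift sign-bit smear: -(X u>> 31) == X s>> 31 for i32.
static_assert(neg(std::uint32_t(0x80000000u) >> 31) == 0xFFFFFFFFu, "");
// Sub: -(A - B) == B - A; kept behind the hasOneUse() check because keeping
// both the old and the swapped sub alive lengthens both operands' live ranges.
static_assert(neg(sub(9u, 4u)) == sub(4u, 9u), "");
} // namespace negator_identities_sketch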
+ return Builder.CreateSub(I->getOperand(1), I->getOperand(0), + I->getName() + ".neg"); + case Instruction::SDiv: + // `sdiv` is negatible if divisor is not undef/INT_MIN/1. + // While this is normally not behind a use-check, + // let's consider division to be special since it's costly. + if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) { + if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() && + Op1C->isNotOneValue()) { + Value *BO = + Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), + I->getName() + ".neg"); + if (auto *NewInstr = dyn_cast<Instruction>(BO)) + NewInstr->setIsExact(I->isExact()); + return BO; + } + } + break; + } + + // Rest of the logic is recursive, so if it's time to give up then it's time. + if (Depth > NegatorMaxDepth) { + LLVM_DEBUG(dbgs() << "Negator: reached maximal allowed traversal depth in " + << *V << ". Giving up.\n"); + ++NegatorTimesDepthLimitReached; + return nullptr; + } + + switch (I->getOpcode()) { + case Instruction::PHI: { + // `phi` is negatible if all the incoming values are negatible. + auto *PHI = cast<PHINode>(I); + SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands()); + for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) { + if (!(std::get<1>(I) = + negate(std::get<0>(I), Depth + 1))) // Early return. + return nullptr; + } + // All incoming values are indeed negatible. Create negated PHI node. + PHINode *NegatedPHI = Builder.CreatePHI( + PHI->getType(), PHI->getNumOperands(), PHI->getName() + ".neg"); + for (auto I : zip(NegatedIncomingValues, PHI->blocks())) + NegatedPHI->addIncoming(std::get<0>(I), std::get<1>(I)); + return NegatedPHI; + } + case Instruction::Select: { + { + // `abs`/`nabs` is always negatible. + Value *LHS, *RHS; + SelectPatternFlavor SPF = + matchSelectPattern(I, LHS, RHS, /*CastOp=*/nullptr, Depth).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) { + auto *NewSelect = cast<SelectInst>(I->clone()); + // Just swap the operands of the select. + NewSelect->swapValues(); + // Don't swap prof metadata, we didn't change the branch behavior. + NewSelect->setName(I->getName() + ".neg"); + Builder.Insert(NewSelect); + return NewSelect; + } + } + // `select` is negatible if both hands of `select` are negatible. + Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + if (!NegOp1) // Early return. + return nullptr; + Value *NegOp2 = negate(I->getOperand(2), Depth + 1); + if (!NegOp2) + return nullptr; + // Do preserve the metadata! + return Builder.CreateSelect(I->getOperand(0), NegOp1, NegOp2, + I->getName() + ".neg", /*MDFrom=*/I); + } + case Instruction::ShuffleVector: { + // `shufflevector` is negatible if both operands are negatible. + auto *Shuf = cast<ShuffleVectorInst>(I); + Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + if (!NegOp0) // Early return. + return nullptr; + Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + if (!NegOp1) + return nullptr; + return Builder.CreateShuffleVector(NegOp0, NegOp1, Shuf->getShuffleMask(), + I->getName() + ".neg"); + } + case Instruction::ExtractElement: { + // `extractelement` is negatible if source operand is negatible. + auto *EEI = cast<ExtractElementInst>(I); + Value *NegVector = negate(EEI->getVectorOperand(), Depth + 1); + if (!NegVector) // Early return. 
+ return nullptr; + return Builder.CreateExtractElement(NegVector, EEI->getIndexOperand(), + I->getName() + ".neg"); + } + case Instruction::InsertElement: { + // `insertelement` is negatible if both the source vector and + // element-to-be-inserted are negatible. + auto *IEI = cast<InsertElementInst>(I); + Value *NegVector = negate(IEI->getOperand(0), Depth + 1); + if (!NegVector) // Early return. + return nullptr; + Value *NegNewElt = negate(IEI->getOperand(1), Depth + 1); + if (!NegNewElt) // Early return. + return nullptr; + return Builder.CreateInsertElement(NegVector, NegNewElt, IEI->getOperand(2), + I->getName() + ".neg"); + } + case Instruction::Trunc: { + // `trunc` is negatible if its operand is negatible. + Value *NegOp = negate(I->getOperand(0), Depth + 1); + if (!NegOp) // Early return. + return nullptr; + return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg"); + } + case Instruction::Shl: { + // `shl` is negatible if the first operand is negatible. + Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + if (!NegOp0) // Early return. + return nullptr; + return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + } + case Instruction::Or: + if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, + &DT)) + return nullptr; // Don't know how to handle `or` in general. + // `or`/`add` are interchangeable when operands have no common bits set. + // `inc` is always negatible. + if (match(I->getOperand(1), m_One())) + return Builder.CreateNot(I->getOperand(0), I->getName() + ".neg"); + // Else, just defer to Instruction::Add handling. + LLVM_FALLTHROUGH; + case Instruction::Add: { + // `add` is negatible if both of its operands are negatible. + Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + if (!NegOp0) // Early return. + return nullptr; + Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + if (!NegOp1) + return nullptr; + return Builder.CreateAdd(NegOp0, NegOp1, I->getName() + ".neg"); + } + case Instruction::Xor: + // `xor` is negatible if one of its operands is invertible. + // FIXME: InstCombineInverter? But how to connect Inverter and Negator? + if (auto *C = dyn_cast<Constant>(I->getOperand(1))) { + Value *Xor = Builder.CreateXor(I->getOperand(0), ConstantExpr::getNot(C)); + return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), + I->getName() + ".neg"); + } + return nullptr; + case Instruction::Mul: { + // `mul` is negatible if one of its operands is negatible. + Value *NegatedOp, *OtherOp; + // First try the second operand, in case it's a constant it will be best to + // just invert it instead of sinking the `neg` deeper. + if (Value *NegOp1 = negate(I->getOperand(1), Depth + 1)) { + NegatedOp = NegOp1; + OtherOp = I->getOperand(0); + } else if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) { + NegatedOp = NegOp0; + OtherOp = I->getOperand(1); + } else + // Can't negate either of them. + return nullptr; + return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg"); + } + default: + return nullptr; // Don't know, likely not negatible for free. + } + + llvm_unreachable("Can't get here. We always return from switch."); +} + +LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { + NegatorMaxDepthVisited.updateMax(Depth); + ++NegatorNumValuesVisited; + +#if LLVM_ENABLE_STATS + ++NumValuesVisitedInThisNegator; +#endif + +#ifndef NDEBUG + // We can't ever have a Value with such an address. 
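// A loose stand-alone analogue, not part of the patch: the cache below follows
// a common memoize-with-sentinel pattern, where a placeholder is planted before
// recursing so that re-entering the same key can be asserted on. The recursion
// body here is an arbitrary stand-in, not Negator logic.
#include <cassert>
#include <unordered_map>
struct MemoWithSentinelSketch {
  static constexpr int InFlight = -1;            // sentinel, never a real result
  std::unordered_map<int, int> Cache;
  int compute(int Key) {
    auto It = Cache.find(Key);
    if (It != Cache.end()) {
      assert(It->second != InFlight && "cycle detected while computing Key");
      return It->second;                         // reuse the cached result
    }
    Cache[Key] = InFlight;                       // placeholder while we recurse
    int Result = Key <= 1 ? Key : compute(Key - 1) + 1;
    Cache[Key] = Result;                         // overwrite the placeholder
    return Result;
  }
};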
+ Value *Placeholder = reinterpret_cast<Value *>(static_cast<uintptr_t>(-1)); +#endif + + // Did we already try to negate this value? + auto NegationsCacheIterator = NegationsCache.find(V); + if (NegationsCacheIterator != NegationsCache.end()) { + ++NegatorNumNegationsFoundInCache; + Value *NegatedV = NegationsCacheIterator->second; + assert(NegatedV != Placeholder && "Encountered a cycle during negation."); + return NegatedV; + } + +#ifndef NDEBUG + // We did not find a cached result for negation of V. While there, + // let's temporairly cache a placeholder value, with the idea that if later + // during negation we fetch it from cache, we'll know we're in a cycle. + NegationsCache[V] = Placeholder; +#endif + + // No luck. Try negating it for real. + Value *NegatedV = visitImpl(V, Depth); + // And cache the (real) result for the future. + NegationsCache[V] = NegatedV; + + return NegatedV; +} + +LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) { + Value *Negated = negate(Root, /*Depth=*/0); + if (!Negated) { + // We must cleanup newly-inserted instructions, to avoid any potential + // endless combine looping. + llvm::for_each(llvm::reverse(NewInstructions), + [&](Instruction *I) { I->eraseFromParent(); }); + return llvm::None; + } + return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); +} + +LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, + InstCombiner &IC) { + ++NegatorTotalNegationsAttempted; + LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root + << "\n"); + + if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter)) + return nullptr; + + Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), + IC.getDominatorTree(), LHSIsZero); + Optional<Result> Res = N.run(Root); + if (!Res) { // Negation failed. + LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root + << "\n"); + return nullptr; + } + + LLVM_DEBUG(dbgs() << "Negator: successfully sunk negation into " << *Root + << "\n NEW: " << *Res->second << "\n"); + ++NegatorNumTreesNegated; + + // We must temporarily unset the 'current' insertion point and DebugLoc of the + // InstCombine's IRBuilder so that it won't interfere with the ones we have + // already specified when producing negated instructions. + InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder); + IC.Builder.ClearInsertionPoint(); + IC.Builder.SetCurrentDebugLocation(DebugLoc()); + + // And finally, we must add newly-created instructions into the InstCombine's + // worklist (in a proper order!) so it can attempt to combine them. + LLVM_DEBUG(dbgs() << "Negator: Propagating " << Res->first.size() + << " instrs to InstCombine\n"); + NegatorMaxInstructionsCreated.updateMax(Res->first.size()); + NegatorNumInstructionsNegatedSuccess += Res->first.size(); + + // They are in def-use order, so nothing fancy, just insert them in order. + llvm::for_each(Res->first, + [&](Instruction *I) { IC.Builder.Insert(I, I->getName()); }); + + // And return the new root. + return Res->second; +} diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 74e015a4f1d44..2b2f2e1b9470f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -218,13 +218,21 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { return nullptr; // If any of the operand that requires casting is a terminator - // instruction, do not do it. 
+ // instruction, do not do it. Similarly, do not do the transform if the value + // is PHI in a block with no insertion point, for example, a catchswitch + // block, since we will not be able to insert a cast after the PHI. if (any_of(AvailablePtrVals, [&](Value *V) { if (V->getType() == IntToPtr->getType()) return false; - auto *Inst = dyn_cast<Instruction>(V); - return Inst && Inst->isTerminator(); + if (!Inst) + return false; + if (Inst->isTerminator()) + return true; + auto *BB = Inst->getParent(); + if (isa<PHINode>(Inst) && BB->getFirstInsertionPt() == BB->end()) + return true; + return false; })) return nullptr; @@ -264,8 +272,10 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { if (auto *IncomingI = dyn_cast<Instruction>(IncomingVal)) { BasicBlock::iterator InsertPos(IncomingI); InsertPos++; + BasicBlock *BB = IncomingI->getParent(); if (isa<PHINode>(IncomingI)) - InsertPos = IncomingI->getParent()->getFirstInsertionPt(); + InsertPos = BB->getFirstInsertionPt(); + assert(InsertPos != BB->end() && "should have checked above"); InsertNewInstBefore(CI, *InsertPos); } else { auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); @@ -544,7 +554,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // visitLoadInst will propagate an alignment onto the load when TD is around, // and if TD isn't around, we can't handle the mixed case. bool isVolatile = FirstLI->isVolatile(); - MaybeAlign LoadAlignment(FirstLI->getAlignment()); + Align LoadAlignment = FirstLI->getAlign(); unsigned LoadAddrSpace = FirstLI->getPointerAddressSpace(); // We can't sink the load if the loaded value could be modified between the @@ -574,12 +584,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { !isSafeAndProfitableToSinkLoad(LI)) return nullptr; - // If some of the loads have an alignment specified but not all of them, - // we can't do the transformation. - if ((LoadAlignment.hasValue()) != (LI->getAlignment() != 0)) - return nullptr; - - LoadAlignment = std::min(LoadAlignment, MaybeAlign(LI->getAlignment())); + LoadAlignment = std::min(LoadAlignment, Align(LI->getAlign())); // If the PHI is of volatile loads and the load block has multiple // successors, sinking it would remove a load of the volatile value from @@ -1184,15 +1189,22 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() && match(CmpInst->getOperand(1), m_Zero())) { ConstantInt *NonZeroConst = nullptr; + bool MadeChange = false; for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { Instruction *CtxI = PN.getIncomingBlock(i)->getTerminator(); Value *VA = PN.getIncomingValue(i); if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { if (!NonZeroConst) NonZeroConst = GetAnyNonZeroConstInt(PN); - PN.setIncomingValue(i, NonZeroConst); + + if (NonZeroConst != VA) { + replaceOperand(PN, i, NonZeroConst); + MadeChange = true; + } } } + if (MadeChange) + return &PN; } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 05a624fde86b6..17124f717af79 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -56,7 +56,8 @@ static Value *createMinMax(InstCombiner::BuilderTy &Builder, /// Replace a select operand based on an equality comparison with the identity /// constant of a binop. 
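// A minimal illustrative sketch, not part of the patch: the fold documented
// above uses the equality from the select condition to substitute the constant
// into the chosen arm; if that arm is `binop X, Y` and the constant is the
// binop's identity, the arm collapses to Y. Names below are illustrative only.
namespace sel_binop_identity_sketch {
constexpr int before(int X, int Y, int Z) { return X == 0 ? Y + X : Z; }
constexpr int after (int X, int Y, int Z) { return X == 0 ? Y     : Z; }
// The `+ X` in the true arm is dead: that arm is only taken when X == 0.
static_assert(before(0, 7, 9) == after(0, 7, 9), "");
static_assert(before(3, 7, 9) == after(3, 7, 9), "");
} // namespace sel_binop_identity_sketch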
static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, - const TargetLibraryInfo &TLI) { + const TargetLibraryInfo &TLI, + InstCombiner &IC) { // The select condition must be an equality compare with a constant operand. Value *X; Constant *C; @@ -107,8 +108,7 @@ static Instruction *foldSelectBinOpIdentity(SelectInst &Sel, // S = { select (cmp eq X, C), BO, ? } or { select (cmp ne X, C), ?, BO } // => // S = { select (cmp eq X, C), Y, ? } or { select (cmp ne X, C), ?, Y } - Sel.setOperand(IsEq ? 1 : 2, Y); - return &Sel; + return IC.replaceOperand(Sel, IsEq ? 1 : 2, Y); } /// This folds: @@ -301,10 +301,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, // The select condition may be a vector. We may only change the operand // type if the vector width remains the same (and matches the condition). - if (CondTy->isVectorTy()) { + if (auto *CondVTy = dyn_cast<VectorType>(CondTy)) { if (!FIOpndTy->isVectorTy()) return nullptr; - if (CondTy->getVectorNumElements() != FIOpndTy->getVectorNumElements()) + if (CondVTy->getNumElements() != + cast<VectorType>(FIOpndTy)->getNumElements()) return nullptr; // TODO: If the backend knew how to deal with casts better, we could @@ -338,11 +339,7 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && (TI->hasOneUse() || FI->hasOneUse())) { Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); - // TODO: Remove the hack for the binop form when the unary op is optimized - // properly with all IR passes. - if (TI->getOpcode() != Instruction::FNeg) - return BinaryOperator::CreateFNegFMF(NewSel, cast<BinaryOperator>(TI)); - return UnaryOperator::CreateFNeg(NewSel); + return UnaryOperator::CreateFNegFMF(NewSel, TI); } // Only handle binary operators (including two-operand getelementptr) with @@ -674,6 +671,38 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, return Builder.CreateOr(V, Y); } +/// Canonicalize a set or clear of a masked set of constant bits to +/// select-of-constants form. +static Instruction *foldSetClearBits(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *T = Sel.getTrueValue(); + Value *F = Sel.getFalseValue(); + Type *Ty = Sel.getType(); + Value *X; + const APInt *NotC, *C; + + // Cond ? (X & ~C) : (X | C) --> (X & ~C) | (Cond ? 0 : C) + if (match(T, m_And(m_Value(X), m_APInt(NotC))) && + match(F, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Constant *OrC = ConstantInt::get(Ty, *C); + Value *NewSel = Builder.CreateSelect(Cond, Zero, OrC, "masksel", &Sel); + return BinaryOperator::CreateOr(T, NewSel); + } + + // Cond ? (X | C) : (X & ~C) --> (X & ~C) | (Cond ? C : 0) + if (match(F, m_And(m_Value(X), m_APInt(NotC))) && + match(T, m_OneUse(m_Or(m_Specific(X), m_APInt(C)))) && *NotC == ~(*C)) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Constant *OrC = ConstantInt::get(Ty, *C); + Value *NewSel = Builder.CreateSelect(Cond, OrC, Zero, "masksel", &Sel); + return BinaryOperator::CreateOr(F, NewSel); + } + + return nullptr; +} + /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). /// There are 8 commuted/swapped variants of this pattern. /// TODO: Also support a - UMIN(a,b) patterns. 
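// A minimal illustrative sketch, not part of the patch: the foldSetClearBits
// transform added above rewrites a conditional set/clear of the same constant
// mask into one AND plus a select of constants. The two forms are bit-for-bit
// equal; names and the mask value below are illustrative only.
namespace set_clear_bits_sketch {
constexpr unsigned C = 0x0Fu;
constexpr unsigned before(bool Cond, unsigned X) {
  return Cond ? (X & ~C) : (X | C);              // clear mask vs. set mask
}
constexpr unsigned after(bool Cond, unsigned X) {
  return (X & ~C) | (Cond ? 0u : C);             // single AND, select constants
}
static_assert(before(true,  0xABCDu) == after(true,  0xABCDu), "");
static_assert(before(false, 0xABCDu) == after(false, 0xABCDu), "");
} // namespace set_clear_bits_sketch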
@@ -857,16 +886,16 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) return nullptr; - Value *Count = FalseVal; + Value *SelectArg = FalseVal; Value *ValueOnZero = TrueVal; if (Pred == ICmpInst::ICMP_NE) - std::swap(Count, ValueOnZero); + std::swap(SelectArg, ValueOnZero); // Skip zero extend/truncate. - Value *V = nullptr; - if (match(Count, m_ZExt(m_Value(V))) || - match(Count, m_Trunc(m_Value(V)))) - Count = V; + Value *Count = nullptr; + if (!match(SelectArg, m_ZExt(m_Value(Count))) && + !match(SelectArg, m_Trunc(m_Value(Count)))) + Count = SelectArg; // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the // input to the cttz/ctlz is used as LHS for the compare instruction. @@ -880,17 +909,17 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // sizeof in bits of 'Count'. unsigned SizeOfInBits = Count->getType()->getScalarSizeInBits(); if (match(ValueOnZero, m_SpecificInt(SizeOfInBits))) { - // Explicitly clear the 'undef_on_zero' flag. - IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); - NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); - Builder.Insert(NewI); - return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); + // Explicitly clear the 'undef_on_zero' flag. It's always valid to go from + // true to false on this flag, so we can replace it for all users. + II->setArgOperand(1, ConstantInt::getFalse(II->getContext())); + return SelectArg; } - // If the ValueOnZero is not the bitwidth, we can at least make use of the - // fact that the cttz/ctlz result will not be used if the input is zero, so - // it's okay to relax it to undef for that case. - if (II->hasOneUse() && !match(II->getArgOperand(1), m_One())) + // The ValueOnZero is not the bitwidth. But if the cttz/ctlz (and optional + // zext/trunc) have one use (ending at the select), the cttz/ctlz result will + // not be used if the input is zero. Relax to 'undef_on_zero' for that case. + if (II->hasOneUse() && SelectArg->hasOneUse() && + !match(II->getArgOperand(1), m_One())) II->setArgOperand(1, ConstantInt::getTrue(II->getContext())); return nullptr; @@ -997,7 +1026,7 @@ static bool adjustMinMax(SelectInst &Sel, ICmpInst &Cmp) { /// constant operand of the select. static Instruction * canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner::BuilderTy &Builder) { + InstCombiner &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; @@ -1013,8 +1042,14 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, Cmp.getPredicate() == CanonicalPred) return nullptr; + // Bail out on unsimplified X-0 operand (due to some worklist management bug), + // as this may cause an infinite combine loop. Let the sub be folded first. + if (match(LHS, m_Sub(m_Value(), m_Zero())) || + match(RHS, m_Sub(m_Value(), m_Zero()))) + return nullptr; + // Create the canonical compare and plug it into the select. - Sel.setCondition(Builder.CreateICmp(CanonicalPred, LHS, RHS)); + IC.replaceOperand(Sel, 0, IC.Builder.CreateICmp(CanonicalPred, LHS, RHS)); // If the select operands did not change, we're done. if (Sel.getTrueValue() == LHS && Sel.getFalseValue() == RHS) @@ -1035,7 +1070,7 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, /// Canonicalize all these variants to 1 pattern. /// This makes CSE more likely. 
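// A minimal illustrative sketch, not part of the patch: canonicalizeAbsNabs
// below funnels the equivalent compare/negate spellings of |X| (and -|X|) into
// the single `(X s< 0) ? -X : X` shape so CSE can merge them. Two spellings:
namespace abs_canon_sketch {
constexpr int abs_sgt(int X) { return X > -1 ? X : -X; }   // icmp sgt X, -1
constexpr int abs_slt(int X) { return X < 0 ? -X : X; }    // canonical form
static_assert(abs_sgt(-5) == 5 && abs_slt(-5) == 5, "");
static_assert(abs_sgt( 5) == 5 && abs_slt( 5) == 5, "");
static_assert(abs_sgt( 0) == 0 && abs_slt( 0) == 0, "");
} // namespace abs_canon_sketch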
static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner::BuilderTy &Builder) { + InstCombiner &IC) { if (!Cmp.hasOneUse() || !isa<Constant>(Cmp.getOperand(1))) return nullptr; @@ -1067,10 +1102,11 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, if (CmpCanonicalized && RHSCanonicalized) return nullptr; - // If RHS is used by other instructions except compare and select, don't - // canonicalize it to not increase the instruction count. - if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) - return nullptr; + // If RHS is not canonical but is used by other instructions, don't + // canonicalize it and potentially increase the instruction count. + if (!RHSCanonicalized) + if (!(RHS->hasOneUse() || (RHS->hasNUses(2) && CmpUsesNegatedOp))) + return nullptr; // Create the canonical compare: icmp slt LHS 0. if (!CmpCanonicalized) { @@ -1083,12 +1119,14 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp, // Create the canonical RHS: RHS = sub (0, LHS). if (!RHSCanonicalized) { assert(RHS->hasOneUse() && "RHS use number is not right"); - RHS = Builder.CreateNeg(LHS); + RHS = IC.Builder.CreateNeg(LHS); if (TVal == LHS) { - Sel.setFalseValue(RHS); + // Replace false value. + IC.replaceOperand(Sel, 2, RHS); FVal = RHS; } else { - Sel.setTrueValue(RHS); + // Replace true value. + IC.replaceOperand(Sel, 1, RHS); TVal = RHS; } } @@ -1322,7 +1360,7 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // and swap the hands of select. static Instruction * tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, - InstCombiner::BuilderTy &Builder) { + InstCombiner &IC) { ICmpInst::Predicate Pred; Value *X; Constant *C0; @@ -1374,13 +1412,13 @@ tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp, return nullptr; // It matched! Lets insert the new comparison just before select. - InstCombiner::BuilderTy::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(&Sel); + InstCombiner::BuilderTy::InsertPointGuard Guard(IC.Builder); + IC.Builder.SetInsertPoint(&Sel); Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped. 
- Value *NewCmp = Builder.CreateICmp(Pred, X, FlippedStrictness->second, - Cmp.getName() + ".inv"); - Sel.setCondition(NewCmp); + Value *NewCmp = IC.Builder.CreateICmp(Pred, X, FlippedStrictness->second, + Cmp.getName() + ".inv"); + IC.replaceOperand(Sel, 0, NewCmp); Sel.swapValues(); Sel.swapProfMetadata(); @@ -1393,17 +1431,17 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) return replaceInstUsesWith(SI, V); - if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, *this)) return NewSel; - if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) + if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, *this)) return NewAbs; if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder)) return NewAbs; if (Instruction *NewSel = - tryToReuseConstantFromSelectInComparison(SI, *ICI, Builder)) + tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) return NewSel; bool Changed = adjustMinMax(SI, *ICI); @@ -1892,7 +1930,7 @@ Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { Type *SelType = Sel.getType(); Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); - if (ExtC == C) { + if (ExtC == C && ExtInst->hasOneUse()) { Value *TruncCVal = cast<Value>(TruncC); if (ExtInst == Sel.getFalseValue()) std::swap(X, TruncCVal); @@ -1931,10 +1969,9 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { if (!CondVal->getType()->isVectorTy() || !match(CondVal, m_Constant(CondC))) return nullptr; - unsigned NumElts = CondVal->getType()->getVectorNumElements(); - SmallVector<Constant *, 16> Mask; + unsigned NumElts = cast<VectorType>(CondVal->getType())->getNumElements(); + SmallVector<int, 16> Mask; Mask.reserve(NumElts); - Type *Int32Ty = Type::getInt32Ty(CondVal->getContext()); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = CondC->getAggregateElement(i); if (!Elt) @@ -1942,10 +1979,10 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { if (Elt->isOneValue()) { // If the select condition element is true, choose from the 1st vector. - Mask.push_back(ConstantInt::get(Int32Ty, i)); + Mask.push_back(i); } else if (Elt->isNullValue()) { // If the select condition element is false, choose from the 2nd vector. - Mask.push_back(ConstantInt::get(Int32Ty, i + NumElts)); + Mask.push_back(i + NumElts); } else if (isa<UndefValue>(Elt)) { // Undef in a select condition (choose one of the operands) does not mean // the same thing as undef in a shuffle mask (any value is acceptable), so @@ -1957,8 +1994,7 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { } } - return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), - ConstantVector::get(Mask)); + return new ShuffleVectorInst(SI.getTrueValue(), SI.getFalseValue(), Mask); } /// If we have a select of vectors with a scalar condition, try to convert that @@ -1966,23 +2002,21 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { /// other operations in IR and having all operands of a select be vector types /// is likely better for vector codegen. 
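// A minimal illustrative sketch, not part of the patch: the transform below
// preserves the semantics of a whole-vector select under a scalar condition,
// because splatting the condition and selecting lane-by-lane picks the same
// lanes when every lane sees the identical condition value.
#include <array>
namespace scalar_sel_of_vecs_sketch {
using Vec4 = std::array<int, 4>;
inline Vec4 selScalarCond(bool C, const Vec4 &T, const Vec4 &F) {
  return C ? T : F;                              // select of whole vectors
}
inline Vec4 selSplatCond(bool C, const Vec4 &T, const Vec4 &F) {
  Vec4 R{};
  for (unsigned I = 0; I != 4; ++I)
    R[I] = C ? T[I] : F[I];                      // per-lane select, same C in every lane
  return R;
}
} // namespace scalar_sel_of_vecs_sketch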
static Instruction *canonicalizeScalarSelectOfVecs( - SelectInst &Sel, InstCombiner::BuilderTy &Builder) { - Type *Ty = Sel.getType(); - if (!Ty->isVectorTy()) + SelectInst &Sel, InstCombiner &IC) { + auto *Ty = dyn_cast<VectorType>(Sel.getType()); + if (!Ty) return nullptr; // We can replace a single-use extract with constant index. Value *Cond = Sel.getCondition(); - if (!match(Cond, m_OneUse(m_ExtractElement(m_Value(), m_ConstantInt())))) + if (!match(Cond, m_OneUse(m_ExtractElt(m_Value(), m_ConstantInt())))) return nullptr; // select (extelt V, Index), T, F --> select (splat V, Index), T, F // Splatting the extracted condition reduces code (we could directly create a // splat shuffle of the source vector to eliminate the intermediate step). - unsigned NumElts = Ty->getVectorNumElements(); - Value *SplatCond = Builder.CreateVectorSplat(NumElts, Cond); - Sel.setCondition(SplatCond); - return &Sel; + unsigned NumElts = Ty->getNumElements(); + return IC.replaceOperand(Sel, 0, IC.Builder.CreateVectorSplat(NumElts, Cond)); } /// Reuse bitcasted operands between a compare and select: @@ -2055,7 +2089,7 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, /// %1 = extractvalue { i64, i1 } %0, 0 /// ret i64 %1 /// -static Instruction *foldSelectCmpXchg(SelectInst &SI) { +static Value *foldSelectCmpXchg(SelectInst &SI) { // A helper that determines if V is an extractvalue instruction whose // aggregate operand is a cmpxchg instruction and whose single index is equal // to I. If such conditions are true, the helper returns the cmpxchg @@ -2087,19 +2121,15 @@ static Instruction *foldSelectCmpXchg(SelectInst &SI) { // value of the same cmpxchg used by the condition, and the false value is the // cmpxchg instruction's compare operand. if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) - if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) { - SI.setTrueValue(SI.getFalseValue()); - return &SI; - } + if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) + return SI.getFalseValue(); // Check the false value case: The false value of the select is the returned // value of the same cmpxchg used by the condition, and the true value is the // cmpxchg instruction's compare operand. if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) - if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) { - SI.setTrueValue(SI.getFalseValue()); - return &SI; - } + if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) + return SI.getFalseValue(); return nullptr; } @@ -2317,6 +2347,174 @@ static Instruction *foldSelectRotate(SelectInst &Sel) { return IntrinsicInst::Create(F, { TVal, TVal, ShAmt }); } +static Instruction *foldSelectToCopysign(SelectInst &Sel, + InstCombiner::BuilderTy &Builder) { + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Type *SelType = Sel.getType(); + + // Match select ?, TC, FC where the constants are equal but negated. + // TODO: Generalize to handle a negated variable operand? 
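// A minimal illustrative sketch, not part of the patch: the sign-bit-select
// patterns that foldSelectToCopysign matches are what std::copysign computes.
// For the "(bitcast X) < 0 ? -C : C" form, with C = 42.0 as an example:
#include <cmath>
#include <cstdint>
#include <cstring>
namespace copysign_sketch {
inline double before(double X) {
  std::int64_t Bits;
  std::memcpy(&Bits, &X, sizeof Bits);           // the "bitcast" to an integer
  return Bits < 0 ? -42.0 : 42.0;                // sign bit set -> -C, else +C
}
inline double after(double X) { return std::copysign(42.0, X); }
// before(-3.5) == after(-3.5) == -42.0; before(0.0) == after(0.0) == 42.0;
// the integer compare also sees the sign of -0.0, matching copysign.
} // namespace copysign_sketch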
+ const APFloat *TC, *FC; + if (!match(TVal, m_APFloat(TC)) || !match(FVal, m_APFloat(FC)) || + !abs(*TC).bitwiseIsEqual(abs(*FC))) + return nullptr; + + assert(TC != FC && "Expected equal select arms to simplify"); + + Value *X; + const APInt *C; + bool IsTrueIfSignSet; + ICmpInst::Predicate Pred; + if (!match(Cond, m_OneUse(m_ICmp(Pred, m_BitCast(m_Value(X)), m_APInt(C)))) || + !isSignBitCheck(Pred, *C, IsTrueIfSignSet) || X->getType() != SelType) + return nullptr; + + // If needed, negate the value that will be the sign argument of the copysign: + // (bitcast X) < 0 ? -TC : TC --> copysign(TC, X) + // (bitcast X) < 0 ? TC : -TC --> copysign(TC, -X) + // (bitcast X) >= 0 ? -TC : TC --> copysign(TC, -X) + // (bitcast X) >= 0 ? TC : -TC --> copysign(TC, X) + if (IsTrueIfSignSet ^ TC->isNegative()) + X = Builder.CreateFNegFMF(X, &Sel); + + // Canonicalize the magnitude argument as the positive constant since we do + // not care about its sign. + Value *MagArg = TC->isNegative() ? FVal : TVal; + Function *F = Intrinsic::getDeclaration(Sel.getModule(), Intrinsic::copysign, + Sel.getType()); + Instruction *CopySign = IntrinsicInst::Create(F, { MagArg, X }); + CopySign->setFastMathFlags(Sel.getFastMathFlags()); + return CopySign; +} + +Instruction *InstCombiner::foldVectorSelect(SelectInst &Sel) { + auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); + if (!VecTy) + return nullptr; + + unsigned NumElts = VecTy->getNumElements(); + APInt UndefElts(NumElts, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(NumElts)); + if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) { + if (V != &Sel) + return replaceInstUsesWith(Sel, V); + return &Sel; + } + + // A select of a "select shuffle" with a common operand can be rearranged + // to select followed by "select shuffle". Because of poison, this only works + // in the case of a shuffle with no undefined mask elements. + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Value *X, *Y; + ArrayRef<int> Mask; + if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && + !is_contained(Mask, UndefMaskElem) && + cast<ShuffleVectorInst>(TVal)->isSelect()) { + if (X == FVal) { + // select Cond, (shuf_sel X, Y), X --> shuf_sel X, (select Cond, Y, X) + Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel); + return new ShuffleVectorInst(X, NewSel, Mask); + } + if (Y == FVal) { + // select Cond, (shuf_sel X, Y), Y --> shuf_sel (select Cond, X, Y), Y + Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel); + return new ShuffleVectorInst(NewSel, Y, Mask); + } + } + if (match(FVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && + !is_contained(Mask, UndefMaskElem) && + cast<ShuffleVectorInst>(FVal)->isSelect()) { + if (X == TVal) { + // select Cond, X, (shuf_sel X, Y) --> shuf_sel X, (select Cond, X, Y) + Value *NewSel = Builder.CreateSelect(Cond, X, Y, "sel", &Sel); + return new ShuffleVectorInst(X, NewSel, Mask); + } + if (Y == TVal) { + // select Cond, Y, (shuf_sel X, Y) --> shuf_sel (select Cond, Y, X), Y + Value *NewSel = Builder.CreateSelect(Cond, Y, X, "sel", &Sel); + return new ShuffleVectorInst(NewSel, Y, Mask); + } + } + + return nullptr; +} + +static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, + const DominatorTree &DT, + InstCombiner::BuilderTy &Builder) { + // Find the block's immediate dominator that ends with a conditional branch + // that matches select's condition (maybe inverted). 
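// A source-level sketch, not part of the patch, of the rewrite implemented
// below: if the join block's dominator already branched on the select's
// condition, each incoming edge has decided that condition, so the select can
// become a per-edge choice, i.e. a phi at the join point.
namespace select_to_phi_sketch {
inline int before(bool C, int A, int B) {
  if (C) A += 1; else B += 1;                    // dominating branch on C
  return C ? A : B;                              // select at the join
}
inline int after(bool C, int A, int B) {
  int Phi;
  if (C) { A += 1; Phi = A; }                    // incoming value, true edge
  else   { B += 1; Phi = B; }                    // incoming value, false edge
  return Phi;
}
} // namespace select_to_phi_sketch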
+ auto *IDomNode = DT[BB]->getIDom(); + if (!IDomNode) + return nullptr; + BasicBlock *IDom = IDomNode->getBlock(); + + Value *Cond = Sel.getCondition(); + Value *IfTrue, *IfFalse; + BasicBlock *TrueSucc, *FalseSucc; + if (match(IDom->getTerminator(), + m_Br(m_Specific(Cond), m_BasicBlock(TrueSucc), + m_BasicBlock(FalseSucc)))) { + IfTrue = Sel.getTrueValue(); + IfFalse = Sel.getFalseValue(); + } else if (match(IDom->getTerminator(), + m_Br(m_Not(m_Specific(Cond)), m_BasicBlock(TrueSucc), + m_BasicBlock(FalseSucc)))) { + IfTrue = Sel.getFalseValue(); + IfFalse = Sel.getTrueValue(); + } else + return nullptr; + + // We want to replace select %cond, %a, %b with a phi that takes value %a + // for all incoming edges that are dominated by condition `%cond == true`, + // and value %b for edges dominated by condition `%cond == false`. If %a + // or %b are also phis from the same basic block, we can go further and take + // their incoming values from the corresponding blocks. + BasicBlockEdge TrueEdge(IDom, TrueSucc); + BasicBlockEdge FalseEdge(IDom, FalseSucc); + DenseMap<BasicBlock *, Value *> Inputs; + for (auto *Pred : predecessors(BB)) { + // Check implication. + BasicBlockEdge Incoming(Pred, BB); + if (DT.dominates(TrueEdge, Incoming)) + Inputs[Pred] = IfTrue->DoPHITranslation(BB, Pred); + else if (DT.dominates(FalseEdge, Incoming)) + Inputs[Pred] = IfFalse->DoPHITranslation(BB, Pred); + else + return nullptr; + // Check availability. + if (auto *Insn = dyn_cast<Instruction>(Inputs[Pred])) + if (!DT.dominates(Insn, Pred->getTerminator())) + return nullptr; + } + + Builder.SetInsertPoint(&*BB->begin()); + auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size()); + for (auto *Pred : predecessors(BB)) + PN->addIncoming(Inputs[Pred], Pred); + PN->takeName(&Sel); + return PN; +} + +static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, + InstCombiner::BuilderTy &Builder) { + // Try to replace this select with Phi in one of these blocks. + SmallSetVector<BasicBlock *, 4> CandidateBlocks; + CandidateBlocks.insert(Sel.getParent()); + for (Value *V : Sel.operands()) + if (auto *I = dyn_cast<Instruction>(V)) + CandidateBlocks.insert(I->getParent()); + + for (BasicBlock *BB : CandidateBlocks) + if (auto *PN = foldSelectToPhiImpl(Sel, BB, DT, Builder)) + return PN; + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -2346,25 +2544,10 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = canonicalizeSelectToShuffle(SI)) return I; - if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, Builder)) + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) return I; - // Canonicalize a one-use integer compare with a non-canonical predicate by - // inverting the predicate and swapping the select operands. This matches a - // compare canonicalization for conditional branches. - // TODO: Should we do the same for FP compares? CmpInst::Predicate Pred; - if (match(CondVal, m_OneUse(m_ICmp(Pred, m_Value(), m_Value()))) && - !isCanonicalPredicate(Pred)) { - // Swap true/false values and condition. 
- CmpInst *Cond = cast<CmpInst>(CondVal); - Cond->setPredicate(CmpInst::getInversePredicate(Pred)); - SI.setOperand(1, FalseVal); - SI.setOperand(2, TrueVal); - SI.swapProfMetadata(); - Worklist.Add(Cond); - return &SI; - } if (SelType->isIntOrIntVectorTy(1) && TrueVal->getType() == CondVal->getType()) { @@ -2514,6 +2697,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return Add; if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) return Add; + if (Instruction *Or = foldSetClearBits(SI, Builder)) + return Or; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) auto *TI = dyn_cast<Instruction>(TrueVal); @@ -2650,16 +2835,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (TrueSI->getCondition() == CondVal) { if (SI.getTrueValue() == TrueSI->getTrueValue()) return nullptr; - SI.setOperand(1, TrueSI->getTrueValue()); - return &SI; + return replaceOperand(SI, 1, TrueSI->getTrueValue()); } // select(C0, select(C1, a, b), b) -> select(C0&C1, a, b) // We choose this as normal form to enable folding on the And and shortening // paths for the values (this helps GetUnderlyingObjects() for example). if (TrueSI->getFalseValue() == FalseVal && TrueSI->hasOneUse()) { Value *And = Builder.CreateAnd(CondVal, TrueSI->getCondition()); - SI.setOperand(0, And); - SI.setOperand(1, TrueSI->getTrueValue()); + replaceOperand(SI, 0, And); + replaceOperand(SI, 1, TrueSI->getTrueValue()); return &SI; } } @@ -2670,14 +2854,13 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (FalseSI->getCondition() == CondVal) { if (SI.getFalseValue() == FalseSI->getFalseValue()) return nullptr; - SI.setOperand(2, FalseSI->getFalseValue()); - return &SI; + return replaceOperand(SI, 2, FalseSI->getFalseValue()); } // select(C0, a, select(C1, a, b)) -> select(C0|C1, a, b) if (FalseSI->getTrueValue() == TrueVal && FalseSI->hasOneUse()) { Value *Or = Builder.CreateOr(CondVal, FalseSI->getCondition()); - SI.setOperand(0, Or); - SI.setOperand(2, FalseSI->getFalseValue()); + replaceOperand(SI, 0, Or); + replaceOperand(SI, 2, FalseSI->getFalseValue()); return &SI; } } @@ -2704,15 +2887,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { canMergeSelectThroughBinop(TrueBO)) { if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { if (TrueBOSI->getCondition() == CondVal) { - TrueBO->setOperand(0, TrueBOSI->getTrueValue()); - Worklist.Add(TrueBO); + replaceOperand(*TrueBO, 0, TrueBOSI->getTrueValue()); + Worklist.push(TrueBO); return &SI; } } if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { if (TrueBOSI->getCondition() == CondVal) { - TrueBO->setOperand(1, TrueBOSI->getTrueValue()); - Worklist.Add(TrueBO); + replaceOperand(*TrueBO, 1, TrueBOSI->getTrueValue()); + Worklist.push(TrueBO); return &SI; } } @@ -2724,15 +2907,15 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { canMergeSelectThroughBinop(FalseBO)) { if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { if (FalseBOSI->getCondition() == CondVal) { - FalseBO->setOperand(0, FalseBOSI->getFalseValue()); - Worklist.Add(FalseBO); + replaceOperand(*FalseBO, 0, FalseBOSI->getFalseValue()); + Worklist.push(FalseBO); return &SI; } } if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { if (FalseBOSI->getCondition() == CondVal) { - FalseBO->setOperand(1, FalseBOSI->getFalseValue()); - Worklist.Add(FalseBO); + replaceOperand(*FalseBO, 1, FalseBOSI->getFalseValue()); + Worklist.push(FalseBO); return &SI; } } @@ -2740,23 
+2923,14 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *NotCond; if (match(CondVal, m_Not(m_Value(NotCond)))) { - SI.setOperand(0, NotCond); - SI.setOperand(1, FalseVal); - SI.setOperand(2, TrueVal); + replaceOperand(SI, 0, NotCond); + SI.swapValues(); SI.swapProfMetadata(); return &SI; } - if (VectorType *VecTy = dyn_cast<VectorType>(SelType)) { - unsigned VWidth = VecTy->getNumElements(); - APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); - if (Value *V = SimplifyDemandedVectorElts(&SI, AllOnesEltMask, UndefElts)) { - if (V != &SI) - return replaceInstUsesWith(SI, V); - return &SI; - } - } + if (Instruction *I = foldVectorSelect(SI)) + return I; // If we can compute the condition, there's no need for a select. // Like the above fold, we are attempting to reduce compile-time cost by @@ -2776,14 +2950,20 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return BitCastSel; // Simplify selects that test the returned flag of cmpxchg instructions. - if (Instruction *Select = foldSelectCmpXchg(SI)) - return Select; + if (Value *V = foldSelectCmpXchg(SI)) + return replaceInstUsesWith(SI, V); - if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI)) + if (Instruction *Select = foldSelectBinOpIdentity(SI, TLI, *this)) return Select; if (Instruction *Rot = foldSelectRotate(SI)) return Rot; + if (Instruction *Copysign = foldSelectToCopysign(SI, Builder)) + return Copysign; + + if (Instruction *PN = foldSelectToPhi(SI, DT, Builder)) + return replaceInstUsesWith(SI, PN); + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index fbff5dd4a8cd5..0a842b4e10475 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -23,8 +23,11 @@ using namespace PatternMatch; // Given pattern: // (x shiftopcode Q) shiftopcode K // we should rewrite it as -// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) -// This is valid for any shift, but they must be identical. +// x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) and +// +// This is valid for any shift, but they must be identical, and we must be +// careful in case we have (zext(Q)+zext(K)) and look past extensions, +// (Q+K) must not overflow or else (Q+K) u< bitwidth(x) is bogus. // // AnalyzeForSignBitExtraction indicates that we will only analyze whether this // pattern has any 2 right-shifts that sum to 1 less than original bit width. @@ -58,6 +61,23 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( if (ShAmt0->getType() != ShAmt1->getType()) return nullptr; + // As input, we have the following pattern: + // Sh0 (Sh1 X, Q), K + // We want to rewrite that as: + // Sh x, (Q+K) iff (Q+K) u< bitwidth(x) + // While we know that originally (Q+K) would not overflow + // (because 2 * (N-1) u<= iN -1), we have looked past extensions of + // shift amounts. so it may now overflow in smaller bitwidth. + // To ensure that does not happen, we need to ensure that the total maximal + // shift amount is still representable in that smaller bit width. 
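// A minimal illustrative sketch, not part of the patch: the reassociation and
// the guard just below in concrete numbers.
#include <cstdint>
namespace shift_reassoc_sketch {
// Two same-direction shifts combine while the total stays under the bit width.
constexpr std::uint32_t twoShifts(std::uint32_t X) { return (X >> 3) >> 7; }
constexpr std::uint32_t oneShift (std::uint32_t X) { return X >> (3 + 7); }
static_assert(twoShifts(0xDEADBEEFu) == oneShift(0xDEADBEEFu), "");
// The guard matters once shift amounts were zero-extended from a narrow type:
// shifting an i256 value with i8 amounts 200 and 100, the real total 300 is
// out of range, yet in i8 it wraps to 44 and would pass a "u< 256" check.
static_assert(static_cast<std::uint8_t>(200 + 100) == 44, "");
} // namespace shift_reassoc_sketch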
+ unsigned MaximalPossibleTotalShiftAmount = + (Sh0->getType()->getScalarSizeInBits() - 1) + + (Sh1->getType()->getScalarSizeInBits() - 1); + APInt MaximalRepresentableShiftAmount = + APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) + return nullptr; + // We are only looking for signbit extraction if we have two right shifts. bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) && match(Sh1, m_Shr(m_Value(), m_Value())); @@ -388,8 +408,7 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { // demand the sign bit (and many others) here?? Value *Rem = Builder.CreateAnd(A, ConstantInt::get(I.getType(), *B - 1), Op1->getName()); - I.setOperand(1, Rem); - return &I; + return replaceOperand(I, 1, Rem); } if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) @@ -593,19 +612,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, // We can always evaluate constants shifted. if (Constant *C = dyn_cast<Constant>(V)) { if (isLeftShift) - V = IC.Builder.CreateShl(C, NumBits); + return IC.Builder.CreateShl(C, NumBits); else - V = IC.Builder.CreateLShr(C, NumBits); - // If we got a constantexpr back, try to simplify it with TD info. - if (auto *C = dyn_cast<Constant>(V)) - if (auto *FoldedC = - ConstantFoldConstant(C, DL, &IC.getTargetLibraryInfo())) - V = FoldedC; - return V; + return IC.Builder.CreateLShr(C, NumBits); } Instruction *I = cast<Instruction>(V); - IC.Worklist.Add(I); + IC.Worklist.push(I); switch (I->getOpcode()) { default: llvm_unreachable("Inconsistency with CanEvaluateShifted"); @@ -761,7 +774,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } @@ -796,7 +809,7 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); Constant *Mask = ConstantInt::get(I.getContext(), Bits); if (VectorType *VT = dyn_cast<VectorType>(X->getType())) - Mask = ConstantVector::getSplat(VT->getNumElements(), Mask); + Mask = ConstantVector::getSplat(VT->getElementCount(), Mask); return BinaryOperator::CreateAnd(X, Mask); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 47ce83974c8d8..7cfe4c8b5892b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -87,7 +87,10 @@ bool InstCombiner::SimplifyDemandedBits(Instruction *I, unsigned OpNo, Value *NewVal = SimplifyDemandedUseBits(U.get(), DemandedMask, Known, Depth, I); if (!NewVal) return false; - U = NewVal; + if (Instruction* OpInst = dyn_cast<Instruction>(U)) + salvageDebugInfo(*OpInst); + + replaceUse(U, NewVal); return true; } @@ -173,15 +176,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - // Output known-0 are known to be clear if zero in either the LHS | RHS. 
- APInt IKnownZero = RHSKnown.Zero | LHSKnown.Zero; - // Output known-1 bits are only known if set in both the LHS & RHS. - APInt IKnownOne = RHSKnown.One & LHSKnown.One; + Known = LHSKnown & RHSKnown; // If the client is only demanding bits that we know, return the known // constant. - if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(VTy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and'. @@ -194,8 +194,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (ShrinkDemandedConstant(I, 1, DemandedMask & ~LHSKnown.Zero)) return I; - Known.Zero = std::move(IKnownZero); - Known.One = std::move(IKnownOne); break; } case Instruction::Or: { @@ -207,15 +205,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - // Output known-0 bits are only known if clear in both the LHS & RHS. - APInt IKnownZero = RHSKnown.Zero & LHSKnown.Zero; - // Output known-1 are known. to be set if s.et in either the LHS | RHS. - APInt IKnownOne = RHSKnown.One | LHSKnown.One; + Known = LHSKnown | RHSKnown; // If the client is only demanding bits that we know, return the known // constant. - if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(VTy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'or'. @@ -228,8 +223,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (ShrinkDemandedConstant(I, 1, DemandedMask)) return I; - Known.Zero = std::move(IKnownZero); - Known.One = std::move(IKnownOne); break; } case Instruction::Xor: { @@ -239,17 +232,12 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?"); assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); - // Output known-0 bits are known if clear or set in both the LHS & RHS. - APInt IKnownZero = (RHSKnown.Zero & LHSKnown.Zero) | - (RHSKnown.One & LHSKnown.One); - // Output known-1 are known to be set if set in only one of the LHS, RHS. - APInt IKnownOne = (RHSKnown.Zero & LHSKnown.One) | - (RHSKnown.One & LHSKnown.Zero); + Known = LHSKnown ^ RHSKnown; // If the client is only demanding bits that we know, return the known // constant. - if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(VTy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(VTy, Known.One); // If all of the demanded bits are known zero on one side, return the other. // These bits cannot contribute to the result of the 'xor'. @@ -309,10 +297,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return InsertNewInstWith(NewXor, *I); } - // Output known-0 bits are known if clear or set in both the LHS & RHS. - Known.Zero = std::move(IKnownZero); - // Output known-1 are known to be set if set in only one of the LHS, RHS. 
- Known.One = std::move(IKnownOne); break; } case Instruction::Select: { @@ -396,8 +380,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) return I; assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); - Known = InputKnown.zextOrTrunc(BitWidth, - true /* ExtendedBitsAreKnownZero */); + Known = InputKnown.zextOrTrunc(BitWidth); assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } @@ -453,6 +436,43 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, break; } case Instruction::Add: + if ((DemandedMask & 1) == 0) { + // If we do not need the low bit, try to convert bool math to logic: + // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN + Value *X, *Y; + if (match(I, m_c_Add(m_OneUse(m_ZExt(m_Value(X))), + m_OneUse(m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { + // Truth table for inputs and output signbits: + // X:0 | X:1 + // ---------- + // Y:0 | 0 | 0 | + // Y:1 | -1 | 0 | + // ---------- + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *AndNot = Builder.CreateAnd(Builder.CreateNot(X), Y); + return Builder.CreateSExt(AndNot, VTy); + } + + // add iN (sext i1 X), (sext i1 Y) --> sext (X | Y) to iN + // TODO: Relax the one-use checks because we are removing an instruction? + if (match(I, m_Add(m_OneUse(m_SExt(m_Value(X))), + m_OneUse(m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType()) { + // Truth table for inputs and output signbits: + // X:0 | X:1 + // ----------- + // Y:0 | -1 | -1 | + // Y:1 | -1 | 0 | + // ----------- + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Or = Builder.CreateOr(X, Y); + return Builder.CreateSExt(Or, VTy); + } + } + LLVM_FALLTHROUGH; case Instruction::Sub: { /// If the high-bits of an ADD/SUB are not demanded, then we do not care /// about the high bits of the operands. @@ -515,11 +535,27 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); + + bool SignBitZero = Known.Zero.isSignBitSet(); + bool SignBitOne = Known.One.isSignBitSet(); Known.Zero <<= ShiftAmt; Known.One <<= ShiftAmt; // low bits known zero. if (ShiftAmt) Known.Zero.setLowBits(ShiftAmt); + + // If this shift has "nsw" keyword, then the result is either a poison + // value or has the same sign bit as the first operand. + if (IOp->hasNoSignedWrap()) { + if (SignBitZero) + Known.Zero.setSignBit(); + else if (SignBitOne) + Known.One.setSignBit(); + if (Known.hasConflict()) + return UndefValue::get(I->getType()); + } + } else { + computeKnownBits(I, Known, Depth, CxtI); } break; } @@ -543,6 +579,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known.One.lshrInPlace(ShiftAmt); if (ShiftAmt) Known.Zero.setHighBits(ShiftAmt); // high bits known zero. + } else { + computeKnownBits(I, Known, Depth, CxtI); } break; } @@ -603,6 +641,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. 
Known.One |= HighBits; } + } else { + computeKnownBits(I, Known, Depth, CxtI); } break; } @@ -624,6 +664,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // Propagate zero bits from the input. Known.Zero.setHighBits(std::min( BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); + } else { + computeKnownBits(I, Known, Depth, CxtI); } break; } @@ -682,7 +724,8 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known.Zero = APInt::getHighBitsSet(BitWidth, Leaders) & DemandedMask; break; } - case Instruction::Call: + case Instruction::Call: { + bool KnownBitsComputed = false; if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { default: break; @@ -714,8 +757,6 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, NewVal->takeName(I); return InsertNewInstWith(NewVal, *I); } - - // TODO: Could compute known zero/one bits based on the input. break; } case Intrinsic::fshr: @@ -740,6 +781,7 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, RHSKnown.Zero.lshr(BitWidth - ShiftAmt); Known.One = LHSKnown.One.shl(ShiftAmt) | RHSKnown.One.lshr(BitWidth - ShiftAmt); + KnownBitsComputed = true; break; } case Intrinsic::x86_mmx_pmovmskb: @@ -768,16 +810,21 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // We know that the upper bits are set to zero. Known.Zero.setBitsFrom(ArgWidth); - return nullptr; + KnownBitsComputed = true; + break; } case Intrinsic::x86_sse42_crc32_64_64: Known.Zero.setBitsFrom(32); - return nullptr; + KnownBitsComputed = true; + break; } } - computeKnownBits(V, Known, Depth, CxtI); + + if (!KnownBitsComputed) + computeKnownBits(V, Known, Depth, CxtI); break; } + } // If the client is only demanding bits that we know, return the known // constant. @@ -811,15 +858,12 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - // Output known-0 are known to be clear if zero in either the LHS | RHS. - APInt IKnownZero = RHSKnown.Zero | LHSKnown.Zero; - // Output known-1 bits are only known if set in both the LHS & RHS. - APInt IKnownOne = RHSKnown.One & LHSKnown.One; + Known = LHSKnown & RHSKnown; // If the client is only demanding bits that we know, return the known // constant. - if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(ITy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); // If all of the demanded bits are known 1 on one side, return the other. // These bits cannot contribute to the result of the 'and' in this @@ -829,8 +873,6 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) return I->getOperand(1); - Known.Zero = std::move(IKnownZero); - Known.One = std::move(IKnownOne); break; } case Instruction::Or: { @@ -842,15 +884,12 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - // Output known-0 bits are only known if clear in both the LHS & RHS. - APInt IKnownZero = RHSKnown.Zero & LHSKnown.Zero; - // Output known-1 are known to be set if set in either the LHS | RHS. - APInt IKnownOne = RHSKnown.One | LHSKnown.One; + Known = LHSKnown | RHSKnown; // If the client is only demanding bits that we know, return the known // constant. 
- if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(ITy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); // If all of the demanded bits are known zero on one side, return the // other. These bits cannot contribute to the result of the 'or' in this @@ -860,8 +899,6 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) return I->getOperand(1); - Known.Zero = std::move(IKnownZero); - Known.One = std::move(IKnownOne); break; } case Instruction::Xor: { @@ -872,17 +909,12 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - // Output known-0 bits are known if clear or set in both the LHS & RHS. - APInt IKnownZero = (RHSKnown.Zero & LHSKnown.Zero) | - (RHSKnown.One & LHSKnown.One); - // Output known-1 are known to be set if set in only one of the LHS, RHS. - APInt IKnownOne = (RHSKnown.Zero & LHSKnown.One) | - (RHSKnown.One & LHSKnown.Zero); + Known = LHSKnown ^ RHSKnown; // If the client is only demanding bits that we know, return the known // constant. - if (DemandedMask.isSubsetOf(IKnownZero|IKnownOne)) - return Constant::getIntegerValue(ITy, IKnownOne); + if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) + return Constant::getIntegerValue(ITy, Known.One); // If all of the demanded bits are known zero on one side, return the // other. @@ -891,10 +923,6 @@ Value *InstCombiner::SimplifyMultipleUseDemandedBits(Instruction *I, if (DemandedMask.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); - // Output known-0 bits are known if clear or set in both the LHS & RHS. - Known.Zero = std::move(IKnownZero); - // Output known-1 are known to be set if set in only one of the LHS, RHS. - Known.One = std::move(IKnownOne); break; } default: @@ -1008,17 +1036,69 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, DemandedElts.getActiveBits() == 3) return nullptr; - unsigned VWidth = II->getType()->getVectorNumElements(); + auto *IIVTy = cast<VectorType>(II->getType()); + unsigned VWidth = IIVTy->getNumElements(); if (VWidth == 1) return nullptr; - ConstantInt *NewDMask = nullptr; + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(II); + + // Assume the arguments are unchanged and later override them, if needed. + SmallVector<Value *, 16> Args(II->arg_begin(), II->arg_end()); if (DMaskIdx < 0) { - // Pretend that a prefix of elements is demanded to simplify the code - // below. - DemandedElts = (1 << DemandedElts.getActiveBits()) - 1; + // Buffer case. + + const unsigned ActiveBits = DemandedElts.getActiveBits(); + const unsigned UnusedComponentsAtFront = DemandedElts.countTrailingZeros(); + + // Start assuming the prefix of elements is demanded, but possibly clear + // some other bits if there are trailing zeros (unused components at front) + // and update offset. + DemandedElts = (1 << ActiveBits) - 1; + + if (UnusedComponentsAtFront > 0) { + static const unsigned InvalidOffsetIdx = 0xf; + + unsigned OffsetIdx; + switch (II->getIntrinsicID()) { + case Intrinsic::amdgcn_raw_buffer_load: + OffsetIdx = 1; + break; + case Intrinsic::amdgcn_s_buffer_load: + // If resulting type is vec3, there is no point in trimming the + // load with updated offset, as the vec3 would most likely be widened to + // vec4 anyway during lowering. 
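// Editorial sketch, not part of the commit: a 4-bit model checking that the
// known-bits formulas spelled out in the removed And/Or/Xor lines above (and
// assumed here to be what the new KnownBits operator&, | and ^ compute) are
// sound: every value pair consistent with the operands' known bits yields a
// result consistent with the derived known bits.
#include <cassert>
#include <cstdint>

struct KB { uint8_t Zero, One; };  // disjoint masks of known-0 / known-1 bits

static bool consistent(uint8_t V, KB K) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

int main() {
  const unsigned M = 0xF;  // 4-bit lanes keep the exhaustive check small
  for (unsigned AZ = 0; AZ <= M; ++AZ)
    for (unsigned AO = 0; AO <= M; ++AO) {
      if (AZ & AO) continue;  // a bit cannot be known zero and known one
      for (unsigned BZ = 0; BZ <= M; ++BZ)
        for (unsigned BO = 0; BO <= M; ++BO) {
          if (BZ & BO) continue;
          KB A{uint8_t(AZ), uint8_t(AO)}, B{uint8_t(BZ), uint8_t(BO)};
          KB And{uint8_t(A.Zero | B.Zero), uint8_t(A.One & B.One)};
          KB Or {uint8_t(A.Zero & B.Zero), uint8_t(A.One | B.One)};
          KB Xor{uint8_t((A.Zero & B.Zero) | (A.One & B.One)),
                 uint8_t((A.Zero & B.One) | (A.One & B.Zero))};
          for (unsigned X = 0; X <= M; ++X)
            for (unsigned Y = 0; Y <= M; ++Y) {
              if (!consistent(X, A) || !consistent(Y, B)) continue;
              assert(consistent((X & Y) & M, And));
              assert(consistent((X | Y) & M, Or));
              assert(consistent((X ^ Y) & M, Xor));
            }
        }
    }
  return 0;
}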
+ if (ActiveBits == 4 && UnusedComponentsAtFront == 1) + OffsetIdx = InvalidOffsetIdx; + else + OffsetIdx = 1; + break; + case Intrinsic::amdgcn_struct_buffer_load: + OffsetIdx = 2; + break; + default: + // TODO: handle tbuffer* intrinsics. + OffsetIdx = InvalidOffsetIdx; + break; + } + + if (OffsetIdx != InvalidOffsetIdx) { + // Clear demanded bits and update the offset. + DemandedElts &= ~((1 << UnusedComponentsAtFront) - 1); + auto *Offset = II->getArgOperand(OffsetIdx); + unsigned SingleComponentSizeInBits = + getDataLayout().getTypeSizeInBits(II->getType()->getScalarType()); + unsigned OffsetAdd = + UnusedComponentsAtFront * SingleComponentSizeInBits / 8; + auto *OffsetAddVal = ConstantInt::get(Offset->getType(), OffsetAdd); + Args[OffsetIdx] = Builder.CreateAdd(Offset, OffsetAddVal); + } + } } else { + // Image case. + ConstantInt *DMask = cast<ConstantInt>(II->getArgOperand(DMaskIdx)); unsigned DMaskVal = DMask->getZExtValue() & 0xf; @@ -1037,7 +1117,7 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, } if (DMaskVal != NewDMaskVal) - NewDMask = ConstantInt::get(DMask->getType(), NewDMaskVal); + Args[DMaskIdx] = ConstantInt::get(DMask->getType(), NewDMaskVal); } unsigned NewNumElts = DemandedElts.countPopulation(); @@ -1045,39 +1125,25 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, return UndefValue::get(II->getType()); if (NewNumElts >= VWidth && DemandedElts.isMask()) { - if (NewDMask) - II->setArgOperand(DMaskIdx, NewDMask); + if (DMaskIdx >= 0) + II->setArgOperand(DMaskIdx, Args[DMaskIdx]); return nullptr; } - // Determine the overload types of the original intrinsic. - auto IID = II->getIntrinsicID(); - SmallVector<Intrinsic::IITDescriptor, 16> Table; - getIntrinsicInfoTableEntries(IID, Table); - ArrayRef<Intrinsic::IITDescriptor> TableRef = Table; - // Validate function argument and return types, extracting overloaded types // along the way. - FunctionType *FTy = II->getCalledFunction()->getFunctionType(); SmallVector<Type *, 6> OverloadTys; - Intrinsic::matchIntrinsicSignature(FTy, TableRef, OverloadTys); + if (!Intrinsic::getIntrinsicSignature(II->getCalledFunction(), OverloadTys)) + return nullptr; Module *M = II->getParent()->getParent()->getParent(); - Type *EltTy = II->getType()->getVectorElementType(); - Type *NewTy = (NewNumElts == 1) ? EltTy : VectorType::get(EltTy, NewNumElts); + Type *EltTy = IIVTy->getElementType(); + Type *NewTy = + (NewNumElts == 1) ? 
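// Editorial sketch, not part of the commit: the arithmetic used above when
// trimming unused leading components of a buffer load -- demand a prefix up to
// the highest used component, drop the unused front components, and bump the
// byte offset by UnusedComponentsAtFront * element-size-in-bits / 8. The names
// below (TrimResult, trimDemandedFront) are made up for illustration only.
#include <cassert>

struct TrimResult {
  unsigned NewDemanded;     // demanded-elements mask after trimming
  unsigned OffsetAddBytes;  // amount to add to the load's byte offset
};

static TrimResult trimDemandedFront(unsigned DemandedElts, unsigned EltBits) {
  assert(DemandedElts != 0 && "sketch assumes at least one demanded component");
  // Count unused components at the front (low indices).
  unsigned UnusedFront = 0;
  while (!(DemandedElts >> UnusedFront & 1))
    ++UnusedFront;
  // Demand a full prefix up to the highest demanded component...
  unsigned NewDemanded = 0;
  for (unsigned i = 0; i < 32 && (DemandedElts >> i) != 0; ++i)
    NewDemanded |= 1u << i;
  // ...then clear the unused front components.
  NewDemanded &= ~((1u << UnusedFront) - 1);
  return {NewDemanded, UnusedFront * EltBits / 8};
}

int main() {
  // <4 x float> load where only components 1 and 2 are used: keep components
  // 1..2 and advance the offset by one 32-bit (4-byte) element.
  TrimResult R = trimDemandedFront(0b0110, 32);
  assert(R.NewDemanded == 0b0110 && R.OffsetAddBytes == 4);
  return 0;
}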
EltTy : FixedVectorType::get(EltTy, NewNumElts); OverloadTys[0] = NewTy; - Function *NewIntrin = Intrinsic::getDeclaration(M, IID, OverloadTys); - - SmallVector<Value *, 16> Args; - for (unsigned I = 0, E = II->getNumArgOperands(); I != E; ++I) - Args.push_back(II->getArgOperand(I)); - - if (NewDMask) - Args[DMaskIdx] = NewDMask; - - IRBuilderBase::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(II); + Function *NewIntrin = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), OverloadTys); CallInst *NewCall = Builder.CreateCall(NewIntrin, Args); NewCall->takeName(II); @@ -1088,7 +1154,7 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II, DemandedElts.countTrailingZeros()); } - SmallVector<uint32_t, 8> EltMask; + SmallVector<int, 8> EltMask; unsigned NewLoadIdx = 0; for (unsigned OrigLoadIdx = 0; OrigLoadIdx < VWidth; ++OrigLoadIdx) { if (!!DemandedElts[OrigLoadIdx]) @@ -1120,7 +1186,12 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth, bool AllowMultipleUsers) { - unsigned VWidth = V->getType()->getVectorNumElements(); + // Cannot analyze scalable type. The number of vector elements is not a + // compile-time constant. + if (isa<ScalableVectorType>(V->getType())) + return nullptr; + + unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); APInt EltMask(APInt::getAllOnesValue(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); @@ -1199,10 +1270,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, auto *II = dyn_cast<IntrinsicInst>(Inst); Value *Op = II ? II->getArgOperand(OpNum) : Inst->getOperand(OpNum); if (Value *V = SimplifyDemandedVectorElts(Op, Demanded, Undef, Depth + 1)) { - if (II) - II->setArgOperand(OpNum, V); - else - Inst->setOperand(OpNum, V); + replaceOperand(*Inst, OpNum, V); MadeChange = true; } }; @@ -1268,7 +1336,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // If this is inserting an element that isn't demanded, remove this // insertelement. if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { - Worklist.Add(I); + Worklist.push(I); return I->getOperand(0); } @@ -1282,7 +1350,25 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, Shuffle->getOperand(1)->getType() && "Expected shuffle operands to have same type"); unsigned OpWidth = - Shuffle->getOperand(0)->getType()->getVectorNumElements(); + cast<VectorType>(Shuffle->getOperand(0)->getType())->getNumElements(); + // Handle trivial case of a splat. Only check the first element of LHS + // operand. + if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && + DemandedElts.isAllOnesValue()) { + if (!isa<UndefValue>(I->getOperand(1))) { + I->setOperand(1, UndefValue::get(I->getOperand(1)->getType())); + MadeChange = true; + } + APInt LeftDemanded(OpWidth, 1); + APInt LHSUndefElts(OpWidth, 0); + simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); + if (LHSUndefElts[0]) + UndefElts = EltMask; + else + UndefElts.clearAllBits(); + break; + } + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) { @@ -1396,15 +1482,14 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } if (NewUndefElts) { // Add additional discovered undefs. 
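// Editorial sketch, not part of the commit: a simplified model of how demanded
// lanes of a shufflevector result map back onto demanded lanes of its two
// operands (the LeftDemanded/RightDemanded split started above), using the
// mask convention seen throughout this diff: values in [0, OpWidth) pick from
// the first operand, values in [OpWidth, 2*OpWidth) from the second, and -1
// stands for undef. The in-tree loop is only partially visible here, so treat
// this as an illustration, not a copy.
#include <cassert>
#include <vector>

static void splitDemanded(const std::vector<int> &Mask, unsigned OpWidth,
                          const std::vector<bool> &DemandedOut,
                          std::vector<bool> &LeftDemanded,
                          std::vector<bool> &RightDemanded) {
  LeftDemanded.assign(OpWidth, false);
  RightDemanded.assign(OpWidth, false);
  for (unsigned i = 0; i < Mask.size(); ++i) {
    if (!DemandedOut[i] || Mask[i] < 0)  // result lane unused, or undef lane
      continue;
    if (Mask[i] < (int)OpWidth)
      LeftDemanded[Mask[i]] = true;
    else
      RightDemanded[Mask[i] - OpWidth] = true;
  }
}

int main() {
  // shuf <4 x i32> A, B, <0, 5, -1, 2> with only result lanes 0 and 1 demanded
  // needs A[0] and B[1] and nothing else.
  std::vector<bool> L, R;
  splitDemanded({0, 5, -1, 2}, 4, {true, true, false, false}, L, R);
  assert(L[0] && !L[1] && !L[2] && !L[3]);
  assert(!R[0] && R[1] && !R[2] && !R[3]);
  return 0;
}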
- SmallVector<Constant*, 16> Elts; + SmallVector<int, 16> Elts; for (unsigned i = 0; i < VWidth; ++i) { if (UndefElts[i]) - Elts.push_back(UndefValue::get(Type::getInt32Ty(I->getContext()))); + Elts.push_back(UndefMaskElem); else - Elts.push_back(ConstantInt::get(Type::getInt32Ty(I->getContext()), - Shuffle->getMaskValue(i))); + Elts.push_back(Shuffle->getMaskValue(i)); } - I->setOperand(2, ConstantVector::get(Elts)); + Shuffle->setShuffleMask(Elts); MadeChange = true; } break; @@ -1549,7 +1634,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // use Arg0 if DemandedElts[0] is clear like we do for other intrinsics. // Instead we should return a zero vector. if (!DemandedElts[0]) { - Worklist.Add(II); + Worklist.push(II); return ConstantAggregateZero::get(II->getType()); } @@ -1568,7 +1653,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { - Worklist.Add(II); + Worklist.push(II); return II->getArgOperand(0); } // TODO: If only low elt lower SQRT to FSQRT (with rounding/exceptions @@ -1588,7 +1673,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { - Worklist.Add(II); + Worklist.push(II); return II->getArgOperand(0); } @@ -1615,7 +1700,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // If lowest element of a scalar op isn't used then use Arg0. if (!DemandedElts[0]) { - Worklist.Add(II); + Worklist.push(II); return II->getArgOperand(0); } @@ -1649,7 +1734,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // If lowest element of a scalar op isn't used then use Arg0. 
if (!DemandedElts[0]) { - Worklist.Add(II); + Worklist.push(II); return II->getArgOperand(0); } @@ -1678,7 +1763,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::x86_avx512_packusdw_512: case Intrinsic::x86_avx512_packuswb_512: { auto *Ty0 = II->getArgOperand(0)->getType(); - unsigned InnerVWidth = Ty0->getVectorNumElements(); + unsigned InnerVWidth = cast<VectorType>(Ty0)->getNumElements(); assert(VWidth == (InnerVWidth * 2) && "Unexpected input size"); unsigned NumLanes = Ty0->getPrimitiveSizeInBits() / 128; @@ -1747,6 +1832,7 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::amdgcn_raw_buffer_load: case Intrinsic::amdgcn_raw_buffer_load_format: case Intrinsic::amdgcn_raw_tbuffer_load: + case Intrinsic::amdgcn_s_buffer_load: case Intrinsic::amdgcn_struct_buffer_load: case Intrinsic::amdgcn_struct_buffer_load_format: case Intrinsic::amdgcn_struct_tbuffer_load: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index f604c9dc32cae..ff70347569abc 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/VectorUtils.h" @@ -57,12 +58,15 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { // An insertelement to the same constant index as our extract will simplify // to the scalar inserted element. An insertelement to a different constant // index is irrelevant to our extract. - if (match(V, m_InsertElement(m_Value(), m_Value(), m_ConstantInt()))) + if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt()))) return IsConstantExtractIndex; if (match(V, m_OneUse(m_Load(m_Value())))) return true; + if (match(V, m_OneUse(m_UnOp()))) + return true; + Value *V0, *V1; if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1))))) if (cheapToScalarize(V0, IsConstantExtractIndex) || @@ -172,9 +176,9 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, // If this extractelement is using a bitcast from a vector of the same number // of elements, see if we can find the source element from the source vector: // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] - Type *SrcTy = X->getType(); + auto *SrcTy = cast<VectorType>(X->getType()); Type *DestTy = Ext.getType(); - unsigned NumSrcElts = SrcTy->getVectorNumElements(); + unsigned NumSrcElts = SrcTy->getNumElements(); unsigned NumElts = Ext.getVectorOperandType()->getNumElements(); if (NumSrcElts == NumElts) if (Value *Elt = findScalarElement(X, ExtIndexC)) @@ -185,8 +189,8 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, if (NumSrcElts < NumElts) { Value *Scalar; uint64_t InsIndexC; - if (!match(X, m_InsertElement(m_Value(), m_Value(Scalar), - m_ConstantInt(InsIndexC)))) + if (!match(X, m_InsertElt(m_Value(), m_Value(Scalar), + m_ConstantInt(InsIndexC)))) return nullptr; // The extract must be from the subset of vector elements that we inserted @@ -255,7 +259,7 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, /// Find elements of V demanded by UserInstr. 
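// Editorial sketch, not part of the commit: the same-element-count case of the
// fold quoted in foldBitcastExtElt above, "extelt (bitcast VecX), IndexC -->
// bitcast X[IndexC]", shown with plain C++ objects. When the bitcast does not
// change the element count, each destination element occupies exactly the
// bytes of the source element with the same index, so extracting and then
// bitcasting the scalar is equivalent.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  float X[4] = {1.5f, -0.0f, 3.25f, 42.0f};  // "VecX" : <4 x float>

  // bitcast <4 x float> X to <4 x i32>, then extract element 2 ...
  uint32_t AsInts[4];
  std::memcpy(AsInts, X, sizeof(X));
  uint32_t ExtractAfterBitcast = AsInts[2];

  // ... equals bitcasting the extracted float element 2.
  uint32_t BitcastAfterExtract;
  std::memcpy(&BitcastAfterExtract, &X[2], sizeof(float));

  assert(ExtractAfterBitcast == BitcastAfterExtract);
  return 0;
}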
static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { - unsigned VWidth = V->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); // Conservatively assume that all elements are needed. APInt UsedElts(APInt::getAllOnesValue(VWidth)); @@ -272,7 +276,8 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { } case Instruction::ShuffleVector: { ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); - unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements(); + unsigned MaskNumElts = + cast<VectorType>(UserInstr->getType())->getNumElements(); UsedElts = APInt(VWidth, 0); for (unsigned i = 0; i < MaskNumElts; i++) { @@ -298,7 +303,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { /// no user demands an element of V, then the corresponding bit /// remains unset in the returned value. static APInt findDemandedEltsByAllUsers(Value *V) { - unsigned VWidth = V->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(V->getType())->getNumElements(); APInt UnionUsedElts(VWidth, 0); for (const Use &U : V->uses()) { @@ -327,14 +332,18 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { // find a previously computed scalar that was inserted into the vector. auto *IndexC = dyn_cast<ConstantInt>(Index); if (IndexC) { - unsigned NumElts = EI.getVectorOperandType()->getNumElements(); + ElementCount EC = EI.getVectorOperandType()->getElementCount(); + unsigned NumElts = EC.Min; // InstSimplify should handle cases where the index is invalid. - if (!IndexC->getValue().ule(NumElts)) + // For fixed-length vector, it's invalid to extract out-of-range element. + if (!EC.Scalable && IndexC->getValue().uge(NumElts)) return nullptr; // This instruction only demands the single element from the input vector. - if (NumElts != 1) { + // Skip for scalable type, the number of elements is unknown at + // compile-time. + if (!EC.Scalable && NumElts != 1) { // If the input vector has a single use, simplify it based on this use // property. if (SrcVec->hasOneUse()) { @@ -342,10 +351,8 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { APInt DemandedElts(NumElts, 0); DemandedElts.setBit(IndexC->getZExtValue()); if (Value *V = - SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { - EI.setOperand(0, V); - return &EI; - } + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) + return replaceOperand(EI, 0, V); } else { // If the input vector has multiple uses, simplify it based on a union // of all elements used. @@ -373,6 +380,16 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { return ScalarPHI; } + // TODO come up with a n-ary matcher that subsumes both unary and + // binary matchers. 
+ UnaryOperator *UO; + if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) { + // extelt (unop X), Index --> unop (extelt X, Index) + Value *X = UO->getOperand(0); + Value *E = Builder.CreateExtractElement(X, Index); + return UnaryOperator::CreateWithCopiedFlags(UO->getOpcode(), E, UO); + } + BinaryOperator *BO; if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) { // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index) @@ -399,19 +416,18 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { return replaceInstUsesWith(EI, IE->getOperand(1)); // If the inserted and extracted elements are constants, they must not // be the same value, extract from the pre-inserted value instead. - if (isa<Constant>(IE->getOperand(2)) && IndexC) { - Worklist.AddValue(SrcVec); - EI.setOperand(0, IE->getOperand(0)); - return &EI; - } + if (isa<Constant>(IE->getOperand(2)) && IndexC) + return replaceOperand(EI, 0, IE->getOperand(0)); } else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) { // If this is extracting an element from a shufflevector, figure out where // it came from and extract from the appropriate input element instead. - if (auto *Elt = dyn_cast<ConstantInt>(Index)) { - int SrcIdx = SVI->getMaskValue(Elt->getZExtValue()); + // Restrict the following transformation to fixed-length vector. + if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) { + int SrcIdx = + SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue()); Value *Src; - unsigned LHSWidth = - SVI->getOperand(0)->getType()->getVectorNumElements(); + unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType()) + ->getNumElements(); if (SrcIdx < 0) return replaceInstUsesWith(EI, UndefValue::get(EI.getType())); @@ -422,9 +438,8 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Src = SVI->getOperand(1); } Type *Int32Ty = Type::getInt32Ty(EI.getContext()); - return ExtractElementInst::Create(Src, - ConstantInt::get(Int32Ty, - SrcIdx, false)); + return ExtractElementInst::Create( + Src, ConstantInt::get(Int32Ty, SrcIdx, false)); } } else if (auto *CI = dyn_cast<CastInst>(I)) { // Canonicalize extractelement(cast) -> cast(extractelement). @@ -432,7 +447,6 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { // nothing. if (CI->hasOneUse() && (CI->getOpcode() != Instruction::BitCast)) { Value *EE = Builder.CreateExtractElement(CI->getOperand(0), Index); - Worklist.AddValue(EE); return CastInst::Create(CI->getOpcode(), EE, EI.getType()); } } @@ -443,26 +457,25 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { /// If V is a shuffle of values that ONLY returns elements from either LHS or /// RHS, return the shuffle mask and true. Otherwise, return false. 
static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, - SmallVectorImpl<Constant*> &Mask) { + SmallVectorImpl<int> &Mask) { assert(LHS->getType() == RHS->getType() && "Invalid CollectSingleShuffleElements"); - unsigned NumElts = V->getType()->getVectorNumElements(); + unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); if (isa<UndefValue>(V)) { - Mask.assign(NumElts, UndefValue::get(Type::getInt32Ty(V->getContext()))); + Mask.assign(NumElts, -1); return true; } if (V == LHS) { for (unsigned i = 0; i != NumElts; ++i) - Mask.push_back(ConstantInt::get(Type::getInt32Ty(V->getContext()), i)); + Mask.push_back(i); return true; } if (V == RHS) { for (unsigned i = 0; i != NumElts; ++i) - Mask.push_back(ConstantInt::get(Type::getInt32Ty(V->getContext()), - i+NumElts)); + Mask.push_back(i + NumElts); return true; } @@ -481,14 +494,15 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, // transitively ok. if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { // If so, update the mask to reflect the inserted undef. - Mask[InsertedIdx] = UndefValue::get(Type::getInt32Ty(V->getContext())); + Mask[InsertedIdx] = -1; return true; } } else if (ExtractElementInst *EI = dyn_cast<ExtractElementInst>(ScalarOp)){ if (isa<ConstantInt>(EI->getOperand(1))) { unsigned ExtractedIdx = cast<ConstantInt>(EI->getOperand(1))->getZExtValue(); - unsigned NumLHSElts = LHS->getType()->getVectorNumElements(); + unsigned NumLHSElts = + cast<VectorType>(LHS->getType())->getNumElements(); // This must be extracting from either LHS or RHS. if (EI->getOperand(0) == LHS || EI->getOperand(0) == RHS) { @@ -497,14 +511,10 @@ static bool collectSingleShuffleElements(Value *V, Value *LHS, Value *RHS, if (collectSingleShuffleElements(VecOp, LHS, RHS, Mask)) { // If so, update the mask to reflect the inserted value. if (EI->getOperand(0) == LHS) { - Mask[InsertedIdx % NumElts] = - ConstantInt::get(Type::getInt32Ty(V->getContext()), - ExtractedIdx); + Mask[InsertedIdx % NumElts] = ExtractedIdx; } else { assert(EI->getOperand(0) == RHS); - Mask[InsertedIdx % NumElts] = - ConstantInt::get(Type::getInt32Ty(V->getContext()), - ExtractedIdx + NumLHSElts); + Mask[InsertedIdx % NumElts] = ExtractedIdx + NumLHSElts; } return true; } @@ -524,8 +534,8 @@ static void replaceExtractElements(InsertElementInst *InsElt, InstCombiner &IC) { VectorType *InsVecType = InsElt->getType(); VectorType *ExtVecType = ExtElt->getVectorOperandType(); - unsigned NumInsElts = InsVecType->getVectorNumElements(); - unsigned NumExtElts = ExtVecType->getVectorNumElements(); + unsigned NumInsElts = InsVecType->getNumElements(); + unsigned NumExtElts = ExtVecType->getNumElements(); // The inserted-to vector must be wider than the extracted-from vector. if (InsVecType->getElementType() != ExtVecType->getElementType() || @@ -536,12 +546,11 @@ static void replaceExtractElements(InsertElementInst *InsElt, // values. The mask selects all of the values of the original vector followed // by as many undefined values as needed to create a vector of the same length // as the inserted-to vector. 
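// Editorial sketch, not part of the commit: the integer mask convention this
// diff switches to (plain ints instead of ConstantInt vectors), shown with a
// tiny helper that applies such a mask; the helper and the Undef sentinel
// below are illustrative only. Index i selects lane i of the first operand,
// i + NumElts selects lane i of the second, and -1 marks an undef lane --
// exactly the values collectSingleShuffleElements pushes above.
#include <cassert>
#include <vector>

static const int Undef = -999;  // stand-in value for an undef result lane

static std::vector<int> applyShuffle(const std::vector<int> &LHS,
                                     const std::vector<int> &RHS,
                                     const std::vector<int> &Mask) {
  std::vector<int> Out;
  for (int M : Mask) {
    if (M < 0)
      Out.push_back(Undef);                     // -1 in the mask: undef
    else if (M < (int)LHS.size())
      Out.push_back(LHS[M]);                    // [0, NumElts): from LHS
    else
      Out.push_back(RHS[M - (int)LHS.size()]);  // [NumElts, 2*NumElts): RHS
  }
  return Out;
}

int main() {
  std::vector<int> LHS = {10, 11, 12, 13}, RHS = {20, 21, 22, 23};
  assert(applyShuffle(LHS, RHS, {0, 1, 2, 3}) == LHS);  // identity on LHS
  assert(applyShuffle(LHS, RHS, {4, 5, 6, 7}) == RHS);  // identity on RHS
  std::vector<int> Mixed = applyShuffle(LHS, RHS, {0, 5, -1, 3});
  assert(Mixed[0] == 10 && Mixed[1] == 21 && Mixed[2] == Undef && Mixed[3] == 13);
  return 0;
}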
- SmallVector<Constant *, 16> ExtendMask; - IntegerType *IntType = Type::getInt32Ty(InsElt->getContext()); + SmallVector<int, 16> ExtendMask; for (unsigned i = 0; i < NumExtElts; ++i) - ExtendMask.push_back(ConstantInt::get(IntType, i)); + ExtendMask.push_back(i); for (unsigned i = NumExtElts; i < NumInsElts; ++i) - ExtendMask.push_back(UndefValue::get(IntType)); + ExtendMask.push_back(-1); Value *ExtVecOp = ExtElt->getVectorOperand(); auto *ExtVecOpInst = dyn_cast<Instruction>(ExtVecOp); @@ -569,8 +578,8 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) return; - auto *WideVec = new ShuffleVectorInst(ExtVecOp, UndefValue::get(ExtVecType), - ConstantVector::get(ExtendMask)); + auto *WideVec = + new ShuffleVectorInst(ExtVecOp, UndefValue::get(ExtVecType), ExtendMask); // Insert the new shuffle after the vector operand of the extract is defined // (as long as it's not a PHI) or at the start of the basic block of the @@ -603,21 +612,20 @@ static void replaceExtractElements(InsertElementInst *InsElt, /// often been chosen carefully to be efficiently implementable on the target. using ShuffleOps = std::pair<Value *, Value *>; -static ShuffleOps collectShuffleElements(Value *V, - SmallVectorImpl<Constant *> &Mask, +static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<int> &Mask, Value *PermittedRHS, InstCombiner &IC) { assert(V->getType()->isVectorTy() && "Invalid shuffle!"); - unsigned NumElts = V->getType()->getVectorNumElements(); + unsigned NumElts = cast<FixedVectorType>(V->getType())->getNumElements(); if (isa<UndefValue>(V)) { - Mask.assign(NumElts, UndefValue::get(Type::getInt32Ty(V->getContext()))); + Mask.assign(NumElts, -1); return std::make_pair( PermittedRHS ? UndefValue::get(PermittedRHS->getType()) : V, nullptr); } if (isa<ConstantAggregateZero>(V)) { - Mask.assign(NumElts, ConstantInt::get(Type::getInt32Ty(V->getContext()),0)); + Mask.assign(NumElts, 0); return std::make_pair(V, nullptr); } @@ -648,14 +656,13 @@ static ShuffleOps collectShuffleElements(Value *V, // We tried our best, but we can't find anything compatible with RHS // further up the chain. Return a trivial shuffle. for (unsigned i = 0; i < NumElts; ++i) - Mask[i] = ConstantInt::get(Type::getInt32Ty(V->getContext()), i); + Mask[i] = i; return std::make_pair(V, nullptr); } - unsigned NumLHSElts = RHS->getType()->getVectorNumElements(); - Mask[InsertedIdx % NumElts] = - ConstantInt::get(Type::getInt32Ty(V->getContext()), - NumLHSElts+ExtractedIdx); + unsigned NumLHSElts = + cast<VectorType>(RHS->getType())->getNumElements(); + Mask[InsertedIdx % NumElts] = NumLHSElts + ExtractedIdx; return std::make_pair(LR.first, RHS); } @@ -663,11 +670,9 @@ static ShuffleOps collectShuffleElements(Value *V, // We've gone as far as we can: anything on the other side of the // extractelement will already have been converted into a shuffle. unsigned NumLHSElts = - EI->getOperand(0)->getType()->getVectorNumElements(); + cast<VectorType>(EI->getOperand(0)->getType())->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) - Mask.push_back(ConstantInt::get( - Type::getInt32Ty(V->getContext()), - i == InsertedIdx ? ExtractedIdx : NumLHSElts + i)); + Mask.push_back(i == InsertedIdx ? ExtractedIdx : NumLHSElts + i); return std::make_pair(EI->getOperand(0), PermittedRHS); } @@ -683,7 +688,7 @@ static ShuffleOps collectShuffleElements(Value *V, // Otherwise, we can't do anything fancy. Return an identity vector. 
for (unsigned i = 0; i != NumElts; ++i) - Mask.push_back(ConstantInt::get(Type::getInt32Ty(V->getContext()), i)); + Mask.push_back(i); return std::make_pair(V, nullptr); } @@ -723,8 +728,14 @@ Instruction *InstCombiner::visitInsertValueInst(InsertValueInst &I) { } static bool isShuffleEquivalentToSelect(ShuffleVectorInst &Shuf) { - int MaskSize = Shuf.getMask()->getType()->getVectorNumElements(); - int VecSize = Shuf.getOperand(0)->getType()->getVectorNumElements(); + // Can not analyze scalable type, the number of elements is not a compile-time + // constant. + if (isa<ScalableVectorType>(Shuf.getOperand(0)->getType())) + return false; + + int MaskSize = Shuf.getShuffleMask().size(); + int VecSize = + cast<FixedVectorType>(Shuf.getOperand(0)->getType())->getNumElements(); // A vector select does not change the size of the operands. if (MaskSize != VecSize) @@ -750,8 +761,12 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { if (InsElt.hasOneUse() && isa<InsertElementInst>(InsElt.user_back())) return nullptr; - auto *VecTy = cast<VectorType>(InsElt.getType()); - unsigned NumElements = VecTy->getNumElements(); + VectorType *VecTy = InsElt.getType(); + // Can not handle scalable type, the number of elements is not a compile-time + // constant. + if (isa<ScalableVectorType>(VecTy)) + return nullptr; + unsigned NumElements = cast<FixedVectorType>(VecTy)->getNumElements(); // Do not try to do this for a one-element vector, since that's a nop, // and will cause an inf-loop. @@ -760,7 +775,7 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { Value *SplatVal = InsElt.getOperand(1); InsertElementInst *CurrIE = &InsElt; - SmallVector<bool, 16> ElementPresent(NumElements, false); + SmallBitVector ElementPresent(NumElements, false); InsertElementInst *FirstIE = nullptr; // Walk the chain backwards, keeping track of which indices we inserted into, @@ -792,7 +807,7 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { // TODO: If the base vector is not undef, it might be better to create a splat // and then a select-shuffle (blend) with the base vector. if (!isa<UndefValue>(FirstIE->getOperand(0))) - if (any_of(ElementPresent, [](bool Present) { return !Present; })) + if (!ElementPresent.all()) return nullptr; // Create the insert + shuffle. @@ -803,12 +818,12 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { FirstIE = InsertElementInst::Create(UndefVec, SplatVal, Zero, "", &InsElt); // Splat from element 0, but replace absent elements with undef in the mask. - SmallVector<Constant *, 16> Mask(NumElements, Zero); + SmallVector<int, 16> Mask(NumElements, 0); for (unsigned i = 0; i != NumElements; ++i) if (!ElementPresent[i]) - Mask[i] = UndefValue::get(Int32Ty); + Mask[i] = -1; - return new ShuffleVectorInst(FirstIE, UndefVec, ConstantVector::get(Mask)); + return new ShuffleVectorInst(FirstIE, UndefVec, Mask); } /// Try to fold an insert element into an existing splat shuffle by changing @@ -819,6 +834,11 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { if (!Shuf || !Shuf->isZeroEltSplat()) return nullptr; + // Bail out early if shuffle is scalable type. The number of elements in + // shuffle mask is unknown at compile-time. + if (isa<ScalableVectorType>(Shuf->getType())) + return nullptr; + // Check for a constant insertion index. 
uint64_t IdxC; if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) @@ -827,21 +847,18 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { // Check if the splat shuffle's input is the same as this insert's scalar op. Value *X = InsElt.getOperand(1); Value *Op0 = Shuf->getOperand(0); - if (!match(Op0, m_InsertElement(m_Undef(), m_Specific(X), m_ZeroInt()))) + if (!match(Op0, m_InsertElt(m_Undef(), m_Specific(X), m_ZeroInt()))) return nullptr; // Replace the shuffle mask element at the index of this insert with a zero. // For example: // inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1 // --> shuf (inselt undef, X, 0), undef, <0,0,0,undef> - unsigned NumMaskElts = Shuf->getType()->getVectorNumElements(); - SmallVector<Constant *, 16> NewMaskVec(NumMaskElts); - Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext()); - Constant *Zero = ConstantInt::getNullValue(I32Ty); + unsigned NumMaskElts = Shuf->getType()->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts); for (unsigned i = 0; i != NumMaskElts; ++i) - NewMaskVec[i] = i == IdxC ? Zero : Shuf->getMask()->getAggregateElement(i); + NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i); - Constant *NewMask = ConstantVector::get(NewMaskVec); return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); } @@ -854,6 +871,11 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) return nullptr; + // Bail out early if shuffle is scalable type. The number of elements in + // shuffle mask is unknown at compile-time. + if (isa<ScalableVectorType>(Shuf->getType())) + return nullptr; + // Check for a constant insertion index. uint64_t IdxC; if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) @@ -863,34 +885,31 @@ static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { // input vector. Value *Scalar = InsElt.getOperand(1); Value *X = Shuf->getOperand(0); - if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC)))) + if (!match(Scalar, m_ExtractElt(m_Specific(X), m_SpecificInt(IdxC)))) return nullptr; // Replace the shuffle mask element at the index of this extract+insert with // that same index value. // For example: // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask' - unsigned NumMaskElts = Shuf->getType()->getVectorNumElements(); - SmallVector<Constant *, 16> NewMaskVec(NumMaskElts); - Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext()); - Constant *NewMaskEltC = ConstantInt::get(I32Ty, IdxC); - Constant *OldMask = Shuf->getMask(); + unsigned NumMaskElts = Shuf->getType()->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts); + ArrayRef<int> OldMask = Shuf->getShuffleMask(); for (unsigned i = 0; i != NumMaskElts; ++i) { if (i != IdxC) { // All mask elements besides the inserted element remain the same. - NewMaskVec[i] = OldMask->getAggregateElement(i); - } else if (OldMask->getAggregateElement(i) == NewMaskEltC) { + NewMask[i] = OldMask[i]; + } else if (OldMask[i] == (int)IdxC) { // If the mask element was already set, there's nothing to do // (demanded elements analysis may unset it later). 
return nullptr; } else { - assert(isa<UndefValue>(OldMask->getAggregateElement(i)) && + assert(OldMask[i] == UndefMaskElem && "Unexpected shuffle mask element for identity shuffle"); - NewMaskVec[i] = NewMaskEltC; + NewMask[i] = IdxC; } } - Constant *NewMask = ConstantVector::get(NewMaskVec); return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask); } @@ -958,31 +977,34 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { // mask vector with the insertelt index plus the length of the vector // (because the constant vector operand of a shuffle is always the 2nd // operand). - Constant *Mask = Shuf->getMask(); - unsigned NumElts = Mask->getType()->getVectorNumElements(); + ArrayRef<int> Mask = Shuf->getShuffleMask(); + unsigned NumElts = Mask.size(); SmallVector<Constant *, 16> NewShufElts(NumElts); - SmallVector<Constant *, 16> NewMaskElts(NumElts); + SmallVector<int, 16> NewMaskElts(NumElts); for (unsigned I = 0; I != NumElts; ++I) { if (I == InsEltIndex) { NewShufElts[I] = InsEltScalar; - Type *Int32Ty = Type::getInt32Ty(Shuf->getContext()); - NewMaskElts[I] = ConstantInt::get(Int32Ty, InsEltIndex + NumElts); + NewMaskElts[I] = InsEltIndex + NumElts; } else { // Copy over the existing values. NewShufElts[I] = ShufConstVec->getAggregateElement(I); - NewMaskElts[I] = Mask->getAggregateElement(I); + NewMaskElts[I] = Mask[I]; } } // Create new operands for a shuffle that includes the constant of the // original insertelt. The old shuffle will be dead now. return new ShuffleVectorInst(Shuf->getOperand(0), - ConstantVector::get(NewShufElts), - ConstantVector::get(NewMaskElts)); + ConstantVector::get(NewShufElts), NewMaskElts); } else if (auto *IEI = dyn_cast<InsertElementInst>(Inst)) { // Transform sequences of insertelements ops with constant data/indexes into // a single shuffle op. - unsigned NumElts = InsElt.getType()->getNumElements(); + // Can not handle scalable type, the number of elements needed to create + // shuffle mask is not a compile-time constant. + if (isa<ScalableVectorType>(InsElt.getType())) + return nullptr; + unsigned NumElts = + cast<FixedVectorType>(InsElt.getType())->getNumElements(); uint64_t InsertIdx[2]; Constant *Val[2]; @@ -992,33 +1014,29 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { !match(IEI->getOperand(1), m_Constant(Val[1]))) return nullptr; SmallVector<Constant *, 16> Values(NumElts); - SmallVector<Constant *, 16> Mask(NumElts); + SmallVector<int, 16> Mask(NumElts); auto ValI = std::begin(Val); // Generate new constant vector and mask. // We have 2 values/masks from the insertelements instructions. Insert them // into new value/mask vectors. for (uint64_t I : InsertIdx) { if (!Values[I]) { - assert(!Mask[I]); Values[I] = *ValI; - Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), - NumElts + I); + Mask[I] = NumElts + I; } ++ValI; } // Remaining values are filled with 'undef' values. for (unsigned I = 0; I < NumElts; ++I) { if (!Values[I]) { - assert(!Mask[I]); Values[I] = UndefValue::get(InsElt.getType()->getElementType()); - Mask[I] = ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), I); + Mask[I] = I; } } // Create new operands for a shuffle that includes the constant of the // original insertelt. 
return new ShuffleVectorInst(IEI->getOperand(0), - ConstantVector::get(Values), - ConstantVector::get(Mask)); + ConstantVector::get(Values), Mask); } return nullptr; } @@ -1032,28 +1050,51 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) return replaceInstUsesWith(IE, V); + // If the scalar is bitcast and inserted into undef, do the insert in the + // source type followed by bitcast. + // TODO: Generalize for insert into any constant, not just undef? + Value *ScalarSrc; + if (match(VecOp, m_Undef()) && + match(ScalarOp, m_OneUse(m_BitCast(m_Value(ScalarSrc)))) && + (ScalarSrc->getType()->isIntegerTy() || + ScalarSrc->getType()->isFloatingPointTy())) { + // inselt undef, (bitcast ScalarSrc), IdxOp --> + // bitcast (inselt undef, ScalarSrc, IdxOp) + Type *ScalarTy = ScalarSrc->getType(); + Type *VecTy = VectorType::get(ScalarTy, IE.getType()->getElementCount()); + UndefValue *NewUndef = UndefValue::get(VecTy); + Value *NewInsElt = Builder.CreateInsertElement(NewUndef, ScalarSrc, IdxOp); + return new BitCastInst(NewInsElt, IE.getType()); + } + // If the vector and scalar are both bitcast from the same element type, do // the insert in that source type followed by bitcast. - Value *VecSrc, *ScalarSrc; + Value *VecSrc; if (match(VecOp, m_BitCast(m_Value(VecSrc))) && match(ScalarOp, m_BitCast(m_Value(ScalarSrc))) && (VecOp->hasOneUse() || ScalarOp->hasOneUse()) && VecSrc->getType()->isVectorTy() && !ScalarSrc->getType()->isVectorTy() && - VecSrc->getType()->getVectorElementType() == ScalarSrc->getType()) { + cast<VectorType>(VecSrc->getType())->getElementType() == + ScalarSrc->getType()) { // inselt (bitcast VecSrc), (bitcast ScalarSrc), IdxOp --> // bitcast (inselt VecSrc, ScalarSrc, IdxOp) Value *NewInsElt = Builder.CreateInsertElement(VecSrc, ScalarSrc, IdxOp); return new BitCastInst(NewInsElt, IE.getType()); } - // If the inserted element was extracted from some other vector and both - // indexes are valid constants, try to turn this into a shuffle. + // If the inserted element was extracted from some other fixed-length vector + // and both indexes are valid constants, try to turn this into a shuffle. + // Can not handle scalable vector type, the number of elements needed to + // create shuffle mask is not a compile-time constant. uint64_t InsertedIdx, ExtractedIdx; Value *ExtVecOp; - if (match(IdxOp, m_ConstantInt(InsertedIdx)) && - match(ScalarOp, m_ExtractElement(m_Value(ExtVecOp), - m_ConstantInt(ExtractedIdx))) && - ExtractedIdx < ExtVecOp->getType()->getVectorNumElements()) { + if (isa<FixedVectorType>(IE.getType()) && + match(IdxOp, m_ConstantInt(InsertedIdx)) && + match(ScalarOp, + m_ExtractElt(m_Value(ExtVecOp), m_ConstantInt(ExtractedIdx))) && + isa<FixedVectorType>(ExtVecOp->getType()) && + ExtractedIdx < + cast<FixedVectorType>(ExtVecOp->getType())->getNumElements()) { // TODO: Looking at the user(s) to determine if this insert is a // fold-to-shuffle opportunity does not match the usual instcombine // constraints. We should decide if the transform is worthy based only @@ -1079,7 +1120,7 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { // Try to form a shuffle from a chain of extract-insert ops. 
if (isShuffleRootCandidate(IE)) { - SmallVector<Constant*, 16> Mask; + SmallVector<int, 16> Mask; ShuffleOps LR = collectShuffleElements(&IE, Mask, nullptr, *this); // The proposed shuffle may be trivial, in which case we shouldn't @@ -1088,19 +1129,20 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { // We now have a shuffle of LHS, RHS, Mask. if (LR.second == nullptr) LR.second = UndefValue::get(LR.first->getType()); - return new ShuffleVectorInst(LR.first, LR.second, - ConstantVector::get(Mask)); + return new ShuffleVectorInst(LR.first, LR.second, Mask); } } } - unsigned VWidth = VecOp->getType()->getVectorNumElements(); - APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); - if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { - if (V != &IE) - return replaceInstUsesWith(IE, V); - return &IE; + if (auto VecTy = dyn_cast<FixedVectorType>(VecOp->getType())) { + unsigned VWidth = VecTy->getNumElements(); + APInt UndefElts(VWidth, 0); + APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { + if (V != &IE) + return replaceInstUsesWith(IE, V); + return &IE; + } } if (Instruction *Shuf = foldConstantInsEltIntoShuffle(IE)) @@ -1179,7 +1221,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // Bail out if we would create longer vector ops. We could allow creating // longer vector ops, but that may result in more expensive codegen. Type *ITy = I->getType(); - if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements()) + if (ITy->isVectorTy() && + Mask.size() > cast<VectorType>(ITy)->getNumElements()) return false; for (Value *Operand : I->operands()) { if (!canEvaluateShuffled(Operand, Mask, Depth - 1)) @@ -1267,9 +1310,9 @@ static Value *buildNew(Instruction *I, ArrayRef<Value*> NewOps) { case Instruction::FPExt: { // It's possible that the mask has a different number of elements from // the original cast. We recompute the destination type to match the mask. 
- Type *DestTy = - VectorType::get(I->getType()->getScalarType(), - NewOps[0]->getType()->getVectorNumElements()); + Type *DestTy = VectorType::get( + I->getType()->getScalarType(), + cast<VectorType>(NewOps[0]->getType())->getElementCount()); assert(NewOps.size() == 1 && "cast with #ops != 1"); return CastInst::Create(cast<CastInst>(I)->getOpcode(), NewOps[0], DestTy, "", I); @@ -1293,22 +1336,14 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { Type *EltTy = V->getType()->getScalarType(); Type *I32Ty = IntegerType::getInt32Ty(V->getContext()); if (isa<UndefValue>(V)) - return UndefValue::get(VectorType::get(EltTy, Mask.size())); + return UndefValue::get(FixedVectorType::get(EltTy, Mask.size())); if (isa<ConstantAggregateZero>(V)) - return ConstantAggregateZero::get(VectorType::get(EltTy, Mask.size())); + return ConstantAggregateZero::get(FixedVectorType::get(EltTy, Mask.size())); - if (Constant *C = dyn_cast<Constant>(V)) { - SmallVector<Constant *, 16> MaskValues; - for (int i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] == -1) - MaskValues.push_back(UndefValue::get(I32Ty)); - else - MaskValues.push_back(ConstantInt::get(I32Ty, Mask[i])); - } + if (Constant *C = dyn_cast<Constant>(V)) return ConstantExpr::getShuffleVector(C, UndefValue::get(C->getType()), - ConstantVector::get(MaskValues)); - } + Mask); Instruction *I = cast<Instruction>(V); switch (I->getOpcode()) { @@ -1344,7 +1379,8 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { case Instruction::Select: case Instruction::GetElementPtr: { SmallVector<Value*, 8> NewOps; - bool NeedsRebuild = (Mask.size() != I->getType()->getVectorNumElements()); + bool NeedsRebuild = + (Mask.size() != cast<VectorType>(I->getType())->getNumElements()); for (int i = 0, e = I->getNumOperands(); i != e; ++i) { Value *V; // Recursively call evaluateInDifferentElementOrder on vector arguments @@ -1397,8 +1433,9 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // Shuffles to: |EE|FF|GG|HH| // +--+--+--+--+ static bool isShuffleExtractingFromLHS(ShuffleVectorInst &SVI, - SmallVector<int, 16> &Mask) { - unsigned LHSElems = SVI.getOperand(0)->getType()->getVectorNumElements(); + ArrayRef<int> Mask) { + unsigned LHSElems = + cast<VectorType>(SVI.getOperand(0)->getType())->getNumElements(); unsigned MaskElems = Mask.size(); unsigned BegIdx = Mask.front(); unsigned EndIdx = Mask.back(); @@ -1480,12 +1517,12 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { // Example: shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4} // Example: shuf X, (add X, {-1,-2,-3,-4}), {0,1,6,7} --> add X, {0,0,-3,-4} // The existing binop constant vector remains in the same operand position. - Constant *Mask = Shuf.getMask(); + ArrayRef<int> Mask = Shuf.getShuffleMask(); Constant *NewC = Op0IsBinop ? ConstantExpr::getShuffleVector(C, IdC, Mask) : ConstantExpr::getShuffleVector(IdC, C, Mask); bool MightCreatePoisonOrUB = - Mask->containsUndefElement() && + is_contained(Mask, UndefMaskElem) && (Instruction::isIntDivRem(BOpcode) || Instruction::isShift(BOpcode)); if (MightCreatePoisonOrUB) NewC = getSafeVectorConstantForBinop(BOpcode, NewC, true); @@ -1499,7 +1536,7 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { // An undef shuffle mask element may propagate as an undef constant element in // the new binop. That would produce poison where the original code might not. 
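// Editorial sketch, not part of the commit: a numeric spot-check of the first
// example in the foldSelectShuffleWith1Binop comment above,
// shuf (mul X, {-1,-2,-3,-4}), X, {0,5,6,3} --> mul X, {-1,1,1,-4},
// using plain arrays and the usual mask convention (indices 0..3 pick from the
// first operand, 4..7 from the second).
#include <cassert>

int main() {
  int X[4] = {7, -3, 11, 5};
  int C[4] = {-1, -2, -3, -4};
  int Mask[4] = {0, 5, 6, 3};

  // Original form: a select-shuffle of (mul X, C) and X.
  int MulXC[4], Original[4];
  for (int i = 0; i < 4; ++i)
    MulXC[i] = X[i] * C[i];
  for (int i = 0; i < 4; ++i)
    Original[i] = Mask[i] < 4 ? MulXC[Mask[i]] : X[Mask[i] - 4];

  // Folded form: mul X, {-1,1,1,-4} -- the shuffle has been applied to the
  // constant operand instead, and the lanes taken from plain X get the
  // multiplicative identity 1.
  int NewC[4] = {-1, 1, 1, -4};
  for (int i = 0; i < 4; ++i)
    assert(Original[i] == X[i] * NewC[i]);
  return 0;
}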
// If we already made a safe constant, then there's no danger. - if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) NewBO->dropPoisonGeneratingFlags(); return NewBO; } @@ -1511,14 +1548,14 @@ static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, InstCombiner::BuilderTy &Builder) { Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); - Constant *Mask = Shuf.getMask(); + ArrayRef<int> Mask = Shuf.getShuffleMask(); Value *X; uint64_t IndexC; // Match a shuffle that is a splat to a non-zero element. - if (!match(Op0, m_OneUse(m_InsertElement(m_Undef(), m_Value(X), - m_ConstantInt(IndexC)))) || - !match(Op1, m_Undef()) || match(Mask, m_ZeroInt()) || IndexC == 0) + if (!match(Op0, m_OneUse(m_InsertElt(m_Undef(), m_Value(X), + m_ConstantInt(IndexC)))) || + !match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0) return nullptr; // Insert into element 0 of an undef vector. @@ -1530,13 +1567,13 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // For example: // shuf (inselt undef, X, 2), undef, <2,2,undef> // --> shuf (inselt undef, X, 0), undef, <0,0,undef> - unsigned NumMaskElts = Shuf.getType()->getVectorNumElements(); - SmallVector<Constant *, 16> NewMask(NumMaskElts, Zero); + unsigned NumMaskElts = Shuf.getType()->getNumElements(); + SmallVector<int, 16> NewMask(NumMaskElts, 0); for (unsigned i = 0; i != NumMaskElts; ++i) - if (isa<UndefValue>(Mask->getAggregateElement(i))) - NewMask[i] = Mask->getAggregateElement(i); + if (Mask[i] == UndefMaskElem) + NewMask[i] = Mask[i]; - return new ShuffleVectorInst(NewIns, UndefVec, ConstantVector::get(NewMask)); + return new ShuffleVectorInst(NewIns, UndefVec, NewMask); } /// Try to fold shuffles that are the equivalent of a vector select. @@ -1548,7 +1585,7 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, // Canonicalize to choose from operand 0 first unless operand 1 is undefined. // Commuting undef to operand 0 conflicts with another canonicalization. - unsigned NumElts = Shuf.getType()->getVectorNumElements(); + unsigned NumElts = Shuf.getType()->getNumElements(); if (!isa<UndefValue>(Shuf.getOperand(1)) && Shuf.getMaskValue(0) >= (int)NumElts) { // TODO: Can we assert that both operands of a shuffle-select are not undef @@ -1605,14 +1642,14 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, BinaryOperator::BinaryOps BOpc = Opc0; // Select the constant elements needed for the single binop. - Constant *Mask = Shuf.getMask(); + ArrayRef<int> Mask = Shuf.getShuffleMask(); Constant *NewC = ConstantExpr::getShuffleVector(C0, C1, Mask); // We are moving a binop after a shuffle. When a shuffle has an undefined // mask element, the result is undefined, but it is not poison or undefined // behavior. That is not necessarily true for div/rem/shift. 
bool MightCreatePoisonOrUB = - Mask->containsUndefElement() && + is_contained(Mask, UndefMaskElem) && (Instruction::isIntDivRem(BOpc) || Instruction::isShift(BOpc)); if (MightCreatePoisonOrUB) NewC = getSafeVectorConstantForBinop(BOpc, NewC, ConstantsAreOp1); @@ -1661,11 +1698,53 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, NewBO->andIRFlags(B1); if (DropNSW) NewBO->setHasNoSignedWrap(false); - if (Mask->containsUndefElement() && !MightCreatePoisonOrUB) + if (is_contained(Mask, UndefMaskElem) && !MightCreatePoisonOrUB) NewBO->dropPoisonGeneratingFlags(); return NewBO; } +/// Convert a narrowing shuffle of a bitcasted vector into a vector truncate. +/// Example (little endian): +/// shuf (bitcast <4 x i16> X to <8 x i8>), <0, 2, 4, 6> --> trunc X to <4 x i8> +static Instruction *foldTruncShuffle(ShuffleVectorInst &Shuf, + bool IsBigEndian) { + // This must be a bitcasted shuffle of 1 vector integer operand. + Type *DestType = Shuf.getType(); + Value *X; + if (!match(Shuf.getOperand(0), m_BitCast(m_Value(X))) || + !match(Shuf.getOperand(1), m_Undef()) || !DestType->isIntOrIntVectorTy()) + return nullptr; + + // The source type must have the same number of elements as the shuffle, + // and the source element type must be larger than the shuffle element type. + Type *SrcType = X->getType(); + if (!SrcType->isVectorTy() || !SrcType->isIntOrIntVectorTy() || + cast<VectorType>(SrcType)->getNumElements() != + cast<VectorType>(DestType)->getNumElements() || + SrcType->getScalarSizeInBits() % DestType->getScalarSizeInBits() != 0) + return nullptr; + + assert(Shuf.changesLength() && !Shuf.increasesLength() && + "Expected a shuffle that decreases length"); + + // Last, check that the mask chooses the correct low bits for each narrow + // element in the result. + uint64_t TruncRatio = + SrcType->getScalarSizeInBits() / DestType->getScalarSizeInBits(); + ArrayRef<int> Mask = Shuf.getShuffleMask(); + for (unsigned i = 0, e = Mask.size(); i != e; ++i) { + if (Mask[i] == UndefMaskElem) + continue; + uint64_t LSBIndex = IsBigEndian ? (i + 1) * TruncRatio - 1 : i * TruncRatio; + assert(LSBIndex <= std::numeric_limits<int32_t>::max() && + "Overflowed 32-bits"); + if (Mask[i] != (int)LSBIndex) + return nullptr; + } + + return new TruncInst(X, DestType); +} + /// Match a shuffle-select-shuffle pattern where the shuffles are widening and /// narrowing (concatenating with undef and extracting back to the original /// length). This allows replacing the wide select with a narrow select. @@ -1685,19 +1764,19 @@ static Instruction *narrowVectorSelect(ShuffleVectorInst &Shuf, // We need a narrow condition value. It must be extended with undef elements // and have the same number of elements as this shuffle. 
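// Editorial sketch, not part of the commit: the little-endian form of the
// foldTruncShuffle transform above -- picking mask elements <0, 2, 4, 6> out
// of a <4 x i16> reinterpreted as <8 x i8> yields the same bytes as truncating
// each i16 -- checked with plain C++ objects. The check is skipped on
// big-endian hosts, where the low byte sits at the (i+1)*TruncRatio-1
// positions handled above instead.
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  const uint16_t X[4] = {0x1234, 0xABCD, 0x00FF, 0x8001};  // <4 x i16>

  // Detect host endianness.
  const uint16_t Probe = 1;
  uint8_t FirstByte;
  std::memcpy(&FirstByte, &Probe, 1);
  if (FirstByte != 1)
    return 0;  // big-endian host: this particular mask does not apply

  // shuf (bitcast <4 x i16> X to <8 x i8>), undef, <0, 2, 4, 6>
  uint8_t Bytes[8];
  std::memcpy(Bytes, X, sizeof(X));
  const int Mask[4] = {0, 2, 4, 6};

  // trunc <4 x i16> X to <4 x i8> keeps the low 8 bits of each element.
  for (int i = 0; i < 4; ++i) {
    uint8_t Truncated = static_cast<uint8_t>(X[i]);
    assert(Bytes[Mask[i]] == Truncated);
  }
  return 0;
}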
- unsigned NarrowNumElts = Shuf.getType()->getVectorNumElements(); + unsigned NarrowNumElts = Shuf.getType()->getNumElements(); Value *NarrowCond; - if (!match(Cond, m_OneUse(m_ShuffleVector(m_Value(NarrowCond), m_Undef(), - m_Constant()))) || - NarrowCond->getType()->getVectorNumElements() != NarrowNumElts || + if (!match(Cond, m_OneUse(m_Shuffle(m_Value(NarrowCond), m_Undef()))) || + cast<VectorType>(NarrowCond->getType())->getNumElements() != + NarrowNumElts || !cast<ShuffleVectorInst>(Cond)->isIdentityWithPadding()) return nullptr; // shuf (sel (shuf NarrowCond, undef, WideMask), X, Y), undef, NarrowMask) --> // sel NarrowCond, (shuf X, undef, NarrowMask), (shuf Y, undef, NarrowMask) Value *Undef = UndefValue::get(X->getType()); - Value *NarrowX = Builder.CreateShuffleVector(X, Undef, Shuf.getMask()); - Value *NarrowY = Builder.CreateShuffleVector(Y, Undef, Shuf.getMask()); + Value *NarrowX = Builder.CreateShuffleVector(X, Undef, Shuf.getShuffleMask()); + Value *NarrowY = Builder.CreateShuffleVector(Y, Undef, Shuf.getShuffleMask()); return SelectInst::Create(NarrowCond, NarrowX, NarrowY); } @@ -1708,8 +1787,8 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { return nullptr; Value *X, *Y; - Constant *Mask; - if (!match(Op0, m_ShuffleVector(m_Value(X), m_Value(Y), m_Constant(Mask)))) + ArrayRef<int> Mask; + if (!match(Op0, m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) return nullptr; // Be conservative with shuffle transforms. If we can't kill the 1st shuffle, @@ -1728,30 +1807,32 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { // new shuffle mask. Otherwise, copy the original mask element. Example: // shuf (shuf X, Y, <C0, C1, C2, undef, C4>), undef, <0, undef, 2, 3> --> // shuf X, Y, <C0, undef, C2, undef> - unsigned NumElts = Shuf.getType()->getVectorNumElements(); - SmallVector<Constant *, 16> NewMask(NumElts); - assert(NumElts < Mask->getType()->getVectorNumElements() && + unsigned NumElts = Shuf.getType()->getNumElements(); + SmallVector<int, 16> NewMask(NumElts); + assert(NumElts < Mask.size() && "Identity with extract must have less elements than its inputs"); for (unsigned i = 0; i != NumElts; ++i) { - Constant *ExtractMaskElt = Shuf.getMask()->getAggregateElement(i); - Constant *MaskElt = Mask->getAggregateElement(i); - NewMask[i] = isa<UndefValue>(ExtractMaskElt) ? ExtractMaskElt : MaskElt; + int ExtractMaskElt = Shuf.getMaskValue(i); + int MaskElt = Mask[i]; + NewMask[i] = ExtractMaskElt == UndefMaskElem ? ExtractMaskElt : MaskElt; } - return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); + return new ShuffleVectorInst(X, Y, NewMask); } /// Try to replace a shuffle with an insertelement or try to replace a shuffle /// operand with the operand of an insertelement. -static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { +static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, + InstCombiner &IC) { Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1); - SmallVector<int, 16> Mask = Shuf.getShuffleMask(); + SmallVector<int, 16> Mask; + Shuf.getShuffleMask(Mask); // The shuffle must not change vector sizes. // TODO: This restriction could be removed if the insert has only one use // (because the transform would require a new length-changing shuffle). 
int NumElts = Mask.size(); - if (NumElts != (int)(V0->getType()->getVectorNumElements())) + if (NumElts != (int)(cast<VectorType>(V0->getType())->getNumElements())) return nullptr; // This is a specialization of a fold in SimplifyDemandedVectorElts. We may @@ -1761,29 +1842,25 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { // operand with the source vector of the insertelement. Value *X; uint64_t IdxC; - if (match(V0, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + if (match(V0, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { // shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask - if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { - Shuf.setOperand(0, X); - return &Shuf; - } + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) + return IC.replaceOperand(Shuf, 0, X); } - if (match(V1, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { // Offset the index constant by the vector width because we are checking for // accesses to the 2nd vector input of the shuffle. IdxC += NumElts; // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask - if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { - Shuf.setOperand(1, X); - return &Shuf; - } + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) + return IC.replaceOperand(Shuf, 1, X); } // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { // We need an insertelement with a constant index. - if (!match(V0, m_InsertElement(m_Value(), m_Value(Scalar), - m_ConstantInt(IndexC)))) + if (!match(V0, m_InsertElt(m_Value(), m_Value(Scalar), + m_ConstantInt(IndexC)))) return false; // Test the shuffle mask to see if it splices the inserted scalar into the @@ -1850,9 +1927,9 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { Value *X = Shuffle0->getOperand(0); Value *Y = Shuffle1->getOperand(0); if (X->getType() != Y->getType() || - !isPowerOf2_32(Shuf.getType()->getVectorNumElements()) || - !isPowerOf2_32(Shuffle0->getType()->getVectorNumElements()) || - !isPowerOf2_32(X->getType()->getVectorNumElements()) || + !isPowerOf2_32(Shuf.getType()->getNumElements()) || + !isPowerOf2_32(Shuffle0->getType()->getNumElements()) || + !isPowerOf2_32(cast<VectorType>(X->getType())->getNumElements()) || isa<UndefValue>(X) || isa<UndefValue>(Y)) return nullptr; assert(isa<UndefValue>(Shuffle0->getOperand(1)) && @@ -1863,13 +1940,12 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // operands directly by adjusting the shuffle mask to account for the narrower // types: // shuf (widen X), (widen Y), Mask --> shuf X, Y, Mask' - int NarrowElts = X->getType()->getVectorNumElements(); - int WideElts = Shuffle0->getType()->getVectorNumElements(); + int NarrowElts = cast<VectorType>(X->getType())->getNumElements(); + int WideElts = Shuffle0->getType()->getNumElements(); assert(WideElts > NarrowElts && "Unexpected types for identity with padding"); - Type *I32Ty = IntegerType::getInt32Ty(Shuf.getContext()); - SmallVector<int, 16> Mask = Shuf.getShuffleMask(); - SmallVector<Constant *, 16> NewMask(Mask.size(), UndefValue::get(I32Ty)); + ArrayRef<int> Mask = Shuf.getShuffleMask(); + SmallVector<int, 16> NewMask(Mask.size(), -1); for (int i = 0, e = Mask.size(); i != e; ++i) { if (Mask[i] == -1) continue; 
@@ -1889,42 +1965,71 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { // element is offset down to adjust for the narrow vector widths. if (Mask[i] < WideElts) { assert(Mask[i] < NarrowElts && "Unexpected shuffle mask"); - NewMask[i] = ConstantInt::get(I32Ty, Mask[i]); + NewMask[i] = Mask[i]; } else { assert(Mask[i] < (WideElts + NarrowElts) && "Unexpected shuffle mask"); - NewMask[i] = ConstantInt::get(I32Ty, Mask[i] - (WideElts - NarrowElts)); + NewMask[i] = Mask[i] - (WideElts - NarrowElts); } } - return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); + return new ShuffleVectorInst(X, Y, NewMask); } Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); - if (auto *V = SimplifyShuffleVectorInst( - LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) + SimplifyQuery ShufQuery = SQ.getWithInstruction(&SVI); + if (auto *V = SimplifyShuffleVectorInst(LHS, RHS, SVI.getShuffleMask(), + SVI.getType(), ShufQuery)) return replaceInstUsesWith(SVI, V); // shuffle x, x, mask --> shuffle x, undef, mask' - unsigned VWidth = SVI.getType()->getVectorNumElements(); - unsigned LHSWidth = LHS->getType()->getVectorNumElements(); - SmallVector<int, 16> Mask = SVI.getShuffleMask(); + unsigned VWidth = SVI.getType()->getNumElements(); + unsigned LHSWidth = cast<VectorType>(LHS->getType())->getNumElements(); + ArrayRef<int> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); + + // Peek through a bitcasted shuffle operand by scaling the mask. If the + // simulated shuffle can simplify, then this shuffle is unnecessary: + // shuf (bitcast X), undef, Mask --> bitcast X' + // TODO: This could be extended to allow length-changing shuffles. + // The transform might also be obsoleted if we allowed canonicalization + // of bitcasted shuffles. + Value *X; + if (match(LHS, m_BitCast(m_Value(X))) && match(RHS, m_Undef()) && + X->getType()->isVectorTy() && VWidth == LHSWidth) { + // Try to create a scaled mask constant. + auto *XType = cast<VectorType>(X->getType()); + unsigned XNumElts = XType->getNumElements(); + SmallVector<int, 16> ScaledMask; + if (XNumElts >= VWidth) { + assert(XNumElts % VWidth == 0 && "Unexpected vector bitcast"); + narrowShuffleMaskElts(XNumElts / VWidth, Mask, ScaledMask); + } else { + assert(VWidth % XNumElts == 0 && "Unexpected vector bitcast"); + if (!widenShuffleMaskElts(VWidth / XNumElts, Mask, ScaledMask)) + ScaledMask.clear(); + } + if (!ScaledMask.empty()) { + // If the shuffled source vector simplifies, cast that value to this + // shuffle's type. + if (auto *V = SimplifyShuffleVectorInst(X, UndefValue::get(XType), + ScaledMask, XType, ShufQuery)) + return BitCastInst::Create(Instruction::BitCast, V, SVI.getType()); + } + } + if (LHS == RHS) { assert(!isa<UndefValue>(RHS) && "Shuffle with 2 undef ops not simplified?"); // Remap any references to RHS to use LHS. - SmallVector<Constant*, 16> Elts; + SmallVector<int, 16> Elts; for (unsigned i = 0; i != VWidth; ++i) { // Propagate undef elements or force mask to LHS. 
if (Mask[i] < 0) - Elts.push_back(UndefValue::get(Int32Ty)); + Elts.push_back(UndefMaskElem); else - Elts.push_back(ConstantInt::get(Int32Ty, Mask[i] % LHSWidth)); + Elts.push_back(Mask[i] % LHSWidth); } - SVI.setOperand(0, SVI.getOperand(1)); - SVI.setOperand(1, UndefValue::get(RHS->getType())); - SVI.setOperand(2, ConstantVector::get(Elts)); - return &SVI; + return new ShuffleVectorInst(LHS, UndefValue::get(RHS->getType()), Elts); } // shuffle undef, x, mask --> shuffle x, undef, mask' @@ -1939,6 +2044,9 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = foldSelectShuffle(SVI, Builder, DL)) return I; + if (Instruction *I = foldTruncShuffle(SVI, DL.isBigEndian())) + return I; + if (Instruction *I = narrowVectorSelect(SVI, Builder)) return I; @@ -1955,7 +2063,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { // These transforms have the potential to lose undef knowledge, so they are // intentionally placed after SimplifyDemandedVectorElts(). - if (Instruction *I = foldShuffleWithInsert(SVI)) + if (Instruction *I = foldShuffleWithInsert(SVI, *this)) return I; if (Instruction *I = foldIdentityPaddedShuffles(SVI)) return I; @@ -1999,7 +2107,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *V = LHS; unsigned MaskElems = Mask.size(); VectorType *SrcTy = cast<VectorType>(V->getType()); - unsigned VecBitWidth = SrcTy->getBitWidth(); + unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); assert(SrcElemBitWidth && "vector elements must have a bitwidth"); unsigned SrcNumElems = SrcTy->getNumElements(); @@ -2023,16 +2131,15 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { continue; if (!VectorType::isValidElementType(TgtTy)) continue; - VectorType *CastSrcTy = VectorType::get(TgtTy, TgtNumElems); + auto *CastSrcTy = FixedVectorType::get(TgtTy, TgtNumElems); if (!BegIsAligned) { // Shuffle the input so [0,NumElements) contains the output, and // [NumElems,SrcNumElems) is undef. - SmallVector<Constant *, 16> ShuffleMask(SrcNumElems, - UndefValue::get(Int32Ty)); + SmallVector<int, 16> ShuffleMask(SrcNumElems, -1); for (unsigned I = 0, E = MaskElems, Idx = BegIdx; I != E; ++Idx, ++I) - ShuffleMask[I] = ConstantInt::get(Int32Ty, Idx); + ShuffleMask[I] = Idx; V = Builder.CreateShuffleVector(V, UndefValue::get(V->getType()), - ConstantVector::get(ShuffleMask), + ShuffleMask, SVI.getName() + ".extract"); BegIdx = 0; } @@ -2117,11 +2224,11 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHSShuffle) { LHSOp0 = LHSShuffle->getOperand(0); LHSOp1 = LHSShuffle->getOperand(1); - LHSOp0Width = LHSOp0->getType()->getVectorNumElements(); + LHSOp0Width = cast<VectorType>(LHSOp0->getType())->getNumElements(); } if (RHSShuffle) { RHSOp0 = RHSShuffle->getOperand(0); - RHSOp0Width = RHSOp0->getType()->getVectorNumElements(); + RHSOp0Width = cast<VectorType>(RHSOp0->getType())->getNumElements(); } Value* newLHS = LHS; Value* newRHS = RHS; @@ -2149,8 +2256,8 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (newLHS == LHS && newRHS == RHS) return MadeChange ? 
&SVI : nullptr; - SmallVector<int, 16> LHSMask; - SmallVector<int, 16> RHSMask; + ArrayRef<int> LHSMask; + ArrayRef<int> RHSMask; if (newLHS != LHS) LHSMask = LHSShuffle->getShuffleMask(); if (RHSShuffle && newRHS != RHS) diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 801c09a317a7f..b3254c10a0b2b 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -60,6 +60,7 @@ #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constant.h" @@ -129,10 +130,6 @@ static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); -static cl::opt<bool> -EnableExpensiveCombines("expensive-combines", - cl::desc("Enable expensive instruction combines")); - static cl::opt<unsigned> LimitMaxIterations( "instcombine-max-iterations", cl::desc("Limit the maximum number of instruction combining iterations"), @@ -267,7 +264,7 @@ static void ClearSubclassDataAfterReassociation(BinaryOperator &I) { /// cast to eliminate one of the associative operations: /// (op (cast (op X, C2)), C1) --> (cast (op X, op (C1, C2))) /// (op (cast (op X, C2)), C1) --> (op (cast X), op (C1, C2)) -static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { +static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, InstCombiner &IC) { auto *Cast = dyn_cast<CastInst>(BinOp1->getOperand(0)); if (!Cast || !Cast->hasOneUse()) return false; @@ -300,8 +297,8 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1) { Type *DestTy = C1->getType(); Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); Constant *FoldedC = ConstantExpr::get(AssocOpcode, C1, CastC2); - Cast->setOperand(0, BinOp2->getOperand(0)); - BinOp1->setOperand(1, FoldedC); + IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0)); + IC.replaceOperand(*BinOp1, 1, FoldedC); return true; } @@ -350,8 +347,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Does "B op C" simplify? if (Value *V = SimplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "A op V". - I.setOperand(0, A); - I.setOperand(1, V); + replaceOperand(I, 0, A); + replaceOperand(I, 1, V); bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0); bool IsNSW = maintainNoSignedWrap(I, B, C) && hasNoSignedWrap(*Op0); @@ -383,8 +380,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Does "A op B" simplify? if (Value *V = SimplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "V op C". - I.setOperand(0, V); - I.setOperand(1, C); + replaceOperand(I, 0, V); + replaceOperand(I, 1, C); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -396,7 +393,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { } if (I.isAssociative() && I.isCommutative()) { - if (simplifyAssocCastAssoc(&I)) { + if (simplifyAssocCastAssoc(&I, *this)) { Changed = true; ++NumReassoc; continue; @@ -411,8 +408,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Does "C op A" simplify? if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. 
Form "V op B". - I.setOperand(0, V); - I.setOperand(1, B); + replaceOperand(I, 0, V); + replaceOperand(I, 1, B); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -431,8 +428,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // Does "C op A" simplify? if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) { // It simplifies to V. Form "B op V". - I.setOperand(0, B); - I.setOperand(1, V); + replaceOperand(I, 0, B); + replaceOperand(I, 1, V); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -465,8 +462,8 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { } InsertNewInstWith(NewBO, I); NewBO->takeName(Op1); - I.setOperand(0, NewBO); - I.setOperand(1, ConstantExpr::get(Opcode, C1, C2)); + replaceOperand(I, 0, NewBO); + replaceOperand(I, 1, ConstantExpr::get(Opcode, C1, C2)); // Conservatively clear the optional flags, since they may not be // preserved by the reassociation. ClearSubclassDataAfterReassociation(I); @@ -925,8 +922,31 @@ Instruction *InstCombiner::FoldOpIntoSelect(Instruction &Op, SelectInst *SI) { if (auto *CI = dyn_cast<CmpInst>(SI->getCondition())) { if (CI->hasOneUse()) { Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + + // FIXME: This is a hack to avoid infinite looping with min/max patterns. + // We have to ensure that vector constants that only differ with + // undef elements are treated as equivalent. + auto areLooselyEqual = [](Value *A, Value *B) { + if (A == B) + return true; + + // Test for vector constants. + Constant *ConstA, *ConstB; + if (!match(A, m_Constant(ConstA)) || !match(B, m_Constant(ConstB))) + return false; + + // TODO: Deal with FP constants? + if (!A->getType()->isIntOrIntVectorTy() || A->getType() != B->getType()) + return false; + + // Compare for equality including undefs as equal. + auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB); + const APInt *C; + return match(Cmp, m_APIntAllowUndef(C)) && C->isOneValue(); + }; + + if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) || + (areLooselyEqual(FV, Op0) && areLooselyEqual(TV, Op1))) return nullptr; } } @@ -951,7 +971,7 @@ static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, if (!ConstIsRHS) std::swap(Op0, Op1); - Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phitmp"); + Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phi.bo"); auto *FPInst = dyn_cast<Instruction>(RI); if (FPInst && isa<FPMathOperator>(FPInst)) FPInst->copyFastMathFlags(I); @@ -1056,7 +1076,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { // the select would be generated exactly once in the NonConstBB. 
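FoldOpIntoSelect and foldOpIntoPhi above distribute an operation over the arms of a select or the incoming values of a phi (the per-value binop is visible in foldOperationIntoPhiValue). A minimal scalar sketch of the underlying identity; this is an illustration only, not part of the patch, and the names are invented:

#include <cassert>

// op(select(c, a, b), k)  ==  select(c, op(a, k), op(b, k)) for a pure op.
static int foldedIntoSelect(bool c, int a, int b, int k) {
  return c ? (a + k) : (b + k); // the add is duplicated into each arm
}

int main() {
  assert(((true  ? 7 : 9) + 5) == foldedIntoSelect(true,  7, 9, 5));
  assert(((false ? 7 : 9) + 5) == foldedIntoSelect(false, 7, 9, 5));
  return 0;
}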
Builder.SetInsertPoint(ThisBB->getTerminator()); InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred, - FalseVInPred, "phitmp"); + FalseVInPred, "phi.sel"); } NewPN->addIncoming(InV, ThisBB); } @@ -1064,14 +1084,11 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { Constant *C = cast<Constant>(I.getOperand(1)); for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InV = nullptr; - if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) + if (auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); - else if (isa<ICmpInst>(CI)) - InV = Builder.CreateICmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phitmp"); else - InV = Builder.CreateFCmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phitmp"); + InV = Builder.CreateCmp(CI->getPredicate(), PN->getIncomingValue(i), + C, "phi.cmp"); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { @@ -1089,7 +1106,7 @@ Instruction *InstCombiner::foldOpIntoPhi(Instruction &I, PHINode *PN) { InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); else InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i), - I.getType(), "phitmp"); + I.getType(), "phi.cast"); NewPN->addIncoming(InV, PN->getIncomingBlock(i)); } } @@ -1391,8 +1408,8 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { assert(Parent.first->hasOneUse() && "Drilled down when more than one use!"); assert(Op != Parent.first->getOperand(Parent.second) && "Descaling was a no-op?"); - Parent.first->setOperand(Parent.second, Op); - Worklist.Add(Parent.first); + replaceOperand(*Parent.first, Parent.second, Op); + Worklist.push(Parent.first); // Now work back up the expression correcting nsw flags. The logic is based // on the following observation: if X * Y is known not to overflow as a signed @@ -1410,7 +1427,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { NoSignedWrap &= OpNoSignedWrap; if (NoSignedWrap != OpNoSignedWrap) { BO->setHasNoSignedWrap(NoSignedWrap); - Worklist.Add(Ancestor); + Worklist.push(Ancestor); } } else if (Ancestor->getOpcode() == Instruction::Trunc) { // The fact that the descaled input to the trunc has smaller absolute @@ -1432,21 +1449,24 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { - if (!Inst.getType()->isVectorTy()) return nullptr; + // FIXME: some of this is likely fine for scalable vectors + if (!isa<FixedVectorType>(Inst.getType())) + return nullptr; BinaryOperator::BinaryOps Opcode = Inst.getOpcode(); - unsigned NumElts = cast<VectorType>(Inst.getType())->getNumElements(); Value *LHS = Inst.getOperand(0), *RHS = Inst.getOperand(1); - assert(cast<VectorType>(LHS->getType())->getNumElements() == NumElts); - assert(cast<VectorType>(RHS->getType())->getNumElements() == NumElts); + assert(cast<VectorType>(LHS->getType())->getElementCount() == + cast<VectorType>(Inst.getType())->getElementCount()); + assert(cast<VectorType>(RHS->getType())->getElementCount() == + cast<VectorType>(Inst.getType())->getElementCount()); // If both operands of the binop are vector concatenations, then perform the // narrow binop on each pair of the source operands followed by concatenation // of the results. 
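The comment above describes splitting a binop of two concatenated vectors into two narrow binops followed by a concatenation of the results. A standalone sketch of that identity on plain std::vector values (an illustration under invented names, not LLVM IR):

#include <cassert>
#include <vector>

using Vec = std::vector<int>;

static Vec concat(const Vec &A, const Vec &B) {
  Vec R(A);
  R.insert(R.end(), B.begin(), B.end());
  return R;
}

static Vec add(const Vec &A, const Vec &B) { // element-wise binop
  Vec R(A.size());
  for (size_t i = 0; i < A.size(); ++i)
    R[i] = A[i] + B[i];
  return R;
}

int main() {
  Vec L0{1, 2}, L1{3, 4}, R0{10, 20}, R1{30, 40};
  // binop(concat(L0, L1), concat(R0, R1)) == concat(binop(L0, R0), binop(L1, R1))
  assert(add(concat(L0, L1), concat(R0, R1)) == concat(add(L0, R0), add(L1, R1)));
  return 0;
}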
Value *L0, *L1, *R0, *R1; - Constant *Mask; - if (match(LHS, m_ShuffleVector(m_Value(L0), m_Value(L1), m_Constant(Mask))) && - match(RHS, m_ShuffleVector(m_Value(R0), m_Value(R1), m_Specific(Mask))) && + ArrayRef<int> Mask; + if (match(LHS, m_Shuffle(m_Value(L0), m_Value(L1), m_Mask(Mask))) && + match(RHS, m_Shuffle(m_Value(R0), m_Value(R1), m_SpecificMask(Mask))) && LHS->hasOneUse() && RHS->hasOneUse() && cast<ShuffleVectorInst>(LHS)->isConcat() && cast<ShuffleVectorInst>(RHS)->isConcat()) { @@ -1470,7 +1490,7 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { if (!isSafeToSpeculativelyExecute(&Inst)) return nullptr; - auto createBinOpShuffle = [&](Value *X, Value *Y, Constant *M) { + auto createBinOpShuffle = [&](Value *X, Value *Y, ArrayRef<int> M) { Value *XY = Builder.CreateBinOp(Opcode, X, Y); if (auto *BO = dyn_cast<BinaryOperator>(XY)) BO->copyIRFlags(&Inst); @@ -1480,8 +1500,8 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // If both arguments of the binary operation are shuffles that use the same // mask and shuffle within a single vector, move the shuffle after the binop. Value *V1, *V2; - if (match(LHS, m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))) && - match(RHS, m_ShuffleVector(m_Value(V2), m_Undef(), m_Specific(Mask))) && + if (match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))) && + match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) && V1->getType() == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse() || LHS == RHS)) { // Op(shuffle(V1, Mask), shuffle(V2, Mask)) -> shuffle(Op(V1, V2), Mask) @@ -1491,17 +1511,19 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // If both arguments of a commutative binop are select-shuffles that use the // same mask with commuted operands, the shuffles are unnecessary. if (Inst.isCommutative() && - match(LHS, m_ShuffleVector(m_Value(V1), m_Value(V2), m_Constant(Mask))) && - match(RHS, m_ShuffleVector(m_Specific(V2), m_Specific(V1), - m_Specific(Mask)))) { + match(LHS, m_Shuffle(m_Value(V1), m_Value(V2), m_Mask(Mask))) && + match(RHS, + m_Shuffle(m_Specific(V2), m_Specific(V1), m_SpecificMask(Mask)))) { auto *LShuf = cast<ShuffleVectorInst>(LHS); auto *RShuf = cast<ShuffleVectorInst>(RHS); // TODO: Allow shuffles that contain undefs in the mask? // That is legal, but it reduces undef knowledge. // TODO: Allow arbitrary shuffles by shuffling after binop? // That might be legal, but we have to deal with poison. - if (LShuf->isSelect() && !LShuf->getMask()->containsUndefElement() && - RShuf->isSelect() && !RShuf->getMask()->containsUndefElement()) { + if (LShuf->isSelect() && + !is_contained(LShuf->getShuffleMask(), UndefMaskElem) && + RShuf->isSelect() && + !is_contained(RShuf->getShuffleMask(), UndefMaskElem)) { // Example: // LHS = shuffle V1, V2, <0, 5, 6, 3> // RHS = shuffle V2, V1, <0, 5, 6, 3> @@ -1517,11 +1539,12 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // intends to move shuffles closer to other shuffles and binops closer to // other binops, so they can be folded. It may also enable demanded elements // transforms. 
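The canonicalization described above (moving a shuffle after a binop whose other operand is a constant) requires rebuilding the constant so that shuffling the new constant with the same mask reproduces the original one. A standalone sketch of that mask-driven rebuild, not part of the patch; -1 stands in for an unused/undef lane:

#include <cassert>
#include <vector>

// Build NewC such that shuffling NewC with Mask reproduces C.
// Returns an empty vector if two mask lanes demand conflicting values.
static std::vector<int> unshuffleConstant(const std::vector<int> &Mask,
                                          const std::vector<int> &C,
                                          int SrcNumElts) {
  const int Undef = -1; // placeholder for an unused ("undef") source lane
  std::vector<int> NewC(SrcNumElts, Undef);
  for (size_t i = 0; i < Mask.size(); ++i) {
    int Src = Mask[i];
    if (Src < 0)
      continue; // undef mask lane: no constraint
    if (NewC[Src] != Undef && NewC[Src] != C[i])
      return {}; // conflicting requirements; the reorder is not possible
    NewC[Src] = C[i];
  }
  return NewC;
}

int main() {
  // ShMask = <1,1,2,2> and C = <5,5,6,6>  -->  NewC = <undef,5,6,undef>
  std::vector<int> NewC = unshuffleConstant({1, 1, 2, 2}, {5, 5, 6, 6}, 4);
  assert(NewC == (std::vector<int>{-1, 5, 6, -1}));
  return 0;
}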
+ unsigned NumElts = cast<FixedVectorType>(Inst.getType())->getNumElements(); Constant *C; - if (match(&Inst, m_c_BinOp( - m_OneUse(m_ShuffleVector(m_Value(V1), m_Undef(), m_Constant(Mask))), - m_Constant(C))) && - V1->getType()->getVectorNumElements() <= NumElts) { + if (match(&Inst, + m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))), + m_Constant(C))) && + cast<FixedVectorType>(V1->getType())->getNumElements() <= NumElts) { assert(Inst.getType()->getScalarType() == V1->getType()->getScalarType() && "Shuffle should not change scalar type"); @@ -1531,9 +1554,9 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // reorder is not possible. A 1-to-1 mapping is not required. Example: // ShMask = <1,1,2,2> and C = <5,5,6,6> --> NewC = <undef,5,6,undef> bool ConstOp1 = isa<Constant>(RHS); - SmallVector<int, 16> ShMask; - ShuffleVectorInst::getShuffleMask(Mask, ShMask); - unsigned SrcVecNumElts = V1->getType()->getVectorNumElements(); + ArrayRef<int> ShMask = Mask; + unsigned SrcVecNumElts = + cast<FixedVectorType>(V1->getType())->getNumElements(); UndefValue *UndefScalar = UndefValue::get(C->getType()->getScalarType()); SmallVector<Constant *, 16> NewVecC(SrcVecNumElts, UndefScalar); bool MayChange = true; @@ -1590,6 +1613,57 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { } } + // Try to reassociate to sink a splat shuffle after a binary operation. + if (Inst.isAssociative() && Inst.isCommutative()) { + // Canonicalize shuffle operand as LHS. + if (isa<ShuffleVectorInst>(RHS)) + std::swap(LHS, RHS); + + Value *X; + ArrayRef<int> MaskC; + int SplatIndex; + BinaryOperator *BO; + if (!match(LHS, + m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(MaskC)))) || + !match(MaskC, m_SplatOrUndefMask(SplatIndex)) || + X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) || + BO->getOpcode() != Opcode) + return nullptr; + + // FIXME: This may not be safe if the analysis allows undef elements. By + // moving 'Y' before the splat shuffle, we are implicitly assuming + // that it is not undef/poison at the splat index. + Value *Y, *OtherOp; + if (isSplatValue(BO->getOperand(0), SplatIndex)) { + Y = BO->getOperand(0); + OtherOp = BO->getOperand(1); + } else if (isSplatValue(BO->getOperand(1), SplatIndex)) { + Y = BO->getOperand(1); + OtherOp = BO->getOperand(0); + } else { + return nullptr; + } + + // X and Y are splatted values, so perform the binary operation on those + // values followed by a splat followed by the 2nd binary operation: + // bo (splat X), (bo Y, OtherOp) --> bo (splat (bo X, Y)), OtherOp + Value *NewBO = Builder.CreateBinOp(Opcode, X, Y); + UndefValue *Undef = UndefValue::get(Inst.getType()); + SmallVector<int, 8> NewMask(MaskC.size(), SplatIndex); + Value *NewSplat = Builder.CreateShuffleVector(NewBO, Undef, NewMask); + Instruction *R = BinaryOperator::Create(Opcode, NewSplat, OtherOp); + + // Intersect FMF on both new binops. Other (poison-generating) flags are + // dropped to be safe. + if (isa<FPMathOperator>(R)) { + R->copyFastMathFlags(&Inst); + R->andIRFlags(BO); + } + if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO)) + NewInstBO->copyIRFlags(R); + return R; + } + return nullptr; } @@ -1658,16 +1732,46 @@ static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) { (GEP2.isInBounds() || GEP2.hasAllZeroIndices()); } +/// Thread a GEP operation with constant indices through the constant true/false +/// arms of a select. 
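The reassociation added in the hunk above rewrites bo (splat X), (bo Y, OtherOp), where X and Y are both splats of a single scalar, into bo (splat (bo X, Y)), OtherOp, so the two splatted scalars are combined once before being broadcast. Per lane this is just associativity and commutativity of the operation; a tiny sketch, illustration only:

#include <cassert>
#include <vector>

int main() {
  const int X = 3, Y = 4;                 // the two splatted scalars
  const std::vector<int> Other{10, 20, 30, 40};

  for (size_t Lane = 0; Lane < Other.size(); ++Lane) {
    int Before = X + (Y + Other[Lane]);   // bo (splat X), (bo (splat Y), Other)
    int After = (X + Y) + Other[Lane];    // bo (splat (bo X, Y)), Other
    assert(Before == After);
  }
  return 0;
}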
+static Instruction *foldSelectGEP(GetElementPtrInst &GEP, + InstCombiner::BuilderTy &Builder) { + if (!GEP.hasAllConstantIndices()) + return nullptr; + + Instruction *Sel; + Value *Cond; + Constant *TrueC, *FalseC; + if (!match(GEP.getPointerOperand(), m_Instruction(Sel)) || + !match(Sel, + m_Select(m_Value(Cond), m_Constant(TrueC), m_Constant(FalseC)))) + return nullptr; + + // gep (select Cond, TrueC, FalseC), IndexC --> select Cond, TrueC', FalseC' + // Propagate 'inbounds' and metadata from existing instructions. + // Note: using IRBuilder to create the constants for efficiency. + SmallVector<Value *, 4> IndexC(GEP.idx_begin(), GEP.idx_end()); + bool IsInBounds = GEP.isInBounds(); + Value *NewTrueC = IsInBounds ? Builder.CreateInBoundsGEP(TrueC, IndexC) + : Builder.CreateGEP(TrueC, IndexC); + Value *NewFalseC = IsInBounds ? Builder.CreateInBoundsGEP(FalseC, IndexC) + : Builder.CreateGEP(FalseC, IndexC); + return SelectInst::Create(Cond, NewTrueC, NewFalseC, "", nullptr, Sel); +} + Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); + bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); // For vector geps, use the generic demanded vector support. - if (GEP.getType()->isVectorTy()) { - auto VWidth = GEP.getType()->getVectorNumElements(); + // Skip if GEP return type is scalable. The number of elements is unknown at + // compile-time. + if (auto *GEPFVTy = dyn_cast<FixedVectorType>(GEPType)) { + auto VWidth = GEPFVTy->getNumElements(); APInt UndefElts(VWidth, 0); APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&GEP, AllOnesEltMask, @@ -1679,7 +1783,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // TODO: 1) Scalarize splat operands, 2) scalarize entire instruction if // possible (decide on canonical form for pointer broadcast), 3) exploit - // undef elements to decrease demanded bits + // undef elements to decrease demanded bits } Value *PtrOp = GEP.getOperand(0); @@ -1703,13 +1807,14 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { Type *IndexTy = (*I)->getType(); Type *NewIndexType = IndexTy->isVectorTy() - ? VectorType::get(NewScalarIndexTy, IndexTy->getVectorNumElements()) + ? VectorType::get(NewScalarIndexTy, + cast<VectorType>(IndexTy)->getElementCount()) : NewScalarIndexTy; // If the element type has zero size then any index over it is equivalent // to an index of zero, so replace it with zero if it is not zero already. 
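foldSelectGEP above threads a constant-index GEP through a select of two constant pointers: gep (select Cond, TrueC, FalseC), IndexC becomes select Cond, (gep TrueC, IndexC), (gep FalseC, IndexC). In C terms this is the identity that indexing distributes over a conditional pointer; a sketch for illustration only:

#include <cassert>

int main() {
  static int A[4] = {1, 2, 3, 4};
  static int B[4] = {5, 6, 7, 8};
  const int Idx = 2;

  for (int Cond = 0; Cond <= 1; ++Cond) {
    int *Before = (Cond ? A : B) + Idx;          // gep (select Cond, A, B), Idx
    int *After = Cond ? (A + Idx) : (B + Idx);   // select Cond, (gep A, Idx), (gep B, Idx)
    assert(Before == After);
  }
  return 0;
}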
Type *EltTy = GTI.getIndexedType(); - if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0) + if (EltTy->isSized() && DL.getTypeAllocSize(EltTy).isZero()) if (!isa<Constant>(*I) || !match(I->get(), m_Zero())) { *I = Constant::getNullValue(NewIndexType); MadeChange = true; @@ -1789,10 +1894,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (J > 0) { if (J == 1) { CurTy = Op1->getSourceElementType(); - } else if (auto *CT = dyn_cast<CompositeType>(CurTy)) { - CurTy = CT->getTypeAtIndex(Op1->getOperand(J)); } else { - CurTy = nullptr; + CurTy = + GetElementPtrInst::getTypeAtIndex(CurTy, Op1->getOperand(J)); } } } @@ -1808,8 +1912,6 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (DI == -1) { // All the GEPs feeding the PHI are identical. Clone one down into our // BB so that it can be merged with the current GEP. - GEP.getParent()->getInstList().insert( - GEP.getParent()->getFirstInsertionPt(), NewGEP); } else { // All the GEPs feeding the PHI differ at a single offset. Clone a GEP // into the current block so it can be merged, and create a new PHI to @@ -1827,12 +1929,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { PN->getIncomingBlock(I)); NewGEP->setOperand(DI, NewPN); - GEP.getParent()->getInstList().insert( - GEP.getParent()->getFirstInsertionPt(), NewGEP); - NewGEP->setOperand(DI, NewPN); } - GEP.setOperand(0, NewGEP); + GEP.getParent()->getInstList().insert( + GEP.getParent()->getFirstInsertionPt(), NewGEP); + replaceOperand(GEP, 0, NewGEP); PtrOp = NewGEP; } @@ -1932,8 +2033,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Update the GEP in place if possible. if (Src->getNumOperands() == 2) { GEP.setIsInBounds(isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))); - GEP.setOperand(0, Src->getOperand(0)); - GEP.setOperand(1, Sum); + replaceOperand(GEP, 0, Src->getOperand(0)); + replaceOperand(GEP, 1, Sum); return &GEP; } Indices.append(Src->op_begin()+1, Src->op_end()-1); @@ -1957,11 +2058,13 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { GEP.getName()); } - if (GEP.getNumIndices() == 1) { + // Skip if GEP source element type is scalable. The type alloc size is unknown + // at compile-time. + if (GEP.getNumIndices() == 1 && !IsGEPSrcEleScalable) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) { - uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType); + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); bool Matched = false; uint64_t C; @@ -2051,9 +2154,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // array. Because the array type is never stepped over (there // is a leading zero) we can fold the cast into this GEP. if (StrippedPtrTy->getAddressSpace() == GEP.getAddressSpace()) { - GEP.setOperand(0, StrippedPtr); GEP.setSourceElementType(XATy); - return &GEP; + return replaceOperand(GEP, 0, StrippedPtr); } // Cannot replace the base pointer directly because StrippedPtr's // address space is different. 
Instead, create a new GEP followed by @@ -2075,10 +2177,12 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } } - } else if (GEP.getNumOperands() == 2) { - // Transform things like: - // %t = getelementptr i32* bitcast ([2 x i32]* %str to i32*), i32 %V - // into: %t1 = getelementptr [2 x i32]* %str, i32 0, i32 %V; bitcast + } else if (GEP.getNumOperands() == 2 && !IsGEPSrcEleScalable) { + // Skip if GEP source element type is scalable. The type alloc size is + // unknown at compile-time. + // Transform things like: %t = getelementptr i32* + // bitcast ([2 x i32]* %str to i32*), i32 %V into: %t1 = getelementptr [2 + // x i32]* %str, i32 0, i32 %V; bitcast if (StrippedPtrEltTy->isArrayTy() && DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) == DL.getTypeAllocSize(GEPEltType)) { @@ -2102,8 +2206,8 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); - uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy).getFixedSize(); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2142,9 +2246,10 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { StrippedPtrEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); uint64_t ArrayEltSize = - DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()); + DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) + .getFixedSize(); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2203,8 +2308,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, const DataLayout &DL) { - return ArrTy->getArrayElementType() == VecTy->getVectorElementType() && - ArrTy->getArrayNumElements() == VecTy->getVectorNumElements() && + auto *VecVTy = cast<VectorType>(VecTy); + return ArrTy->getArrayElementType() == VecVTy->getElementType() && + ArrTy->getArrayNumElements() == VecVTy->getNumElements() && DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); }; if (GEP.getNumOperands() == 3 && @@ -2291,7 +2397,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (auto *AI = dyn_cast<AllocaInst>(UnderlyingPtrOp)) { if (GEP.accumulateConstantOffset(DL, BasePtrOffset) && BasePtrOffset.isNonNegative()) { - APInt AllocSize(IdxWidth, DL.getTypeAllocSize(AI->getAllocatedType())); + APInt AllocSize( + IdxWidth, + DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinSize()); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( GEP.getSourceElementType(), PtrOp, makeArrayRef(Ops).slice(1), @@ -2301,6 +2409,9 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } + if (Instruction *R = foldSelectGEP(GEP, Builder)) + return R; + return nullptr; } @@ -2369,6 +2480,7 @@ 
static bool isAllocSiteRemovable(Instruction *AI, return false; LLVM_FALLTHROUGH; } + case Intrinsic::assume: case Intrinsic::invariant_start: case Intrinsic::invariant_end: case Intrinsic::lifetime_start: @@ -2517,7 +2629,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, // If there are more than 2 instructions, check that they are noops // i.e., they won't hurt the performance of the generated code. if (FreeInstrBB->size() != 2) { - for (const Instruction &Inst : *FreeInstrBB) { + for (const Instruction &Inst : FreeInstrBB->instructionsWithoutDebug()) { if (&Inst == &FI || &Inst == FreeInstrBBTerminator) continue; auto *Cast = dyn_cast<CastInst>(&Inst); @@ -2579,60 +2691,108 @@ Instruction *InstCombiner::visitFree(CallInst &FI) { // if (foo) free(foo); // into // free(foo); - if (MinimizeSize) - if (Instruction *I = tryToMoveFreeBeforeNullTest(FI, DL)) - return I; + // + // Note that we can only do this for 'free' and not for any flavor of + // 'operator delete'; there is no 'operator delete' symbol for which we are + // permitted to invent a call, even if we're passing in a null pointer. + if (MinimizeSize) { + LibFunc Func; + if (TLI.getLibFunc(FI, Func) && TLI.has(Func) && Func == LibFunc_free) + if (Instruction *I = tryToMoveFreeBeforeNullTest(FI, DL)) + return I; + } return nullptr; } +static bool isMustTailCall(Value *V) { + if (auto *CI = dyn_cast<CallInst>(V)) + return CI->isMustTailCall(); + return false; +} + Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) { if (RI.getNumOperands() == 0) // ret void return nullptr; Value *ResultOp = RI.getOperand(0); Type *VTy = ResultOp->getType(); - if (!VTy->isIntegerTy()) + if (!VTy->isIntegerTy() || isa<Constant>(ResultOp)) + return nullptr; + + // Don't replace result of musttail calls. + if (isMustTailCall(ResultOp)) return nullptr; // There might be assume intrinsics dominating this return that completely // determine the value. If so, constant fold it. KnownBits Known = computeKnownBits(ResultOp, 0, &RI); if (Known.isConstant()) - RI.setOperand(0, Constant::getIntegerValue(VTy, Known.getConstant())); + return replaceOperand(RI, 0, + Constant::getIntegerValue(VTy, Known.getConstant())); + + return nullptr; +} + +Instruction *InstCombiner::visitUnconditionalBranchInst(BranchInst &BI) { + assert(BI.isUnconditional() && "Only for unconditional branches."); + + // If this store is the second-to-last instruction in the basic block + // (excluding debug info and bitcasts of pointers) and if the block ends with + // an unconditional branch, try to move the store to the successor block. 
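tryToMoveFreeBeforeNullTest relies on free(NULL) being a no-op, so when the function is optimized for minimum size the null check guarding a free can be dropped; the hunk above additionally restricts this to the real libc free (never any flavor of operator delete, for which no call may be invented). The source-level shape of the transform, as an illustration only:

#include <cstdlib>

// Before: the call is guarded by a null test.
void releaseGuarded(void *P) {
  if (P)
    free(P);
}

// After: free(NULL) is defined to do nothing, so the branch is redundant
// and dropping it saves code size.
void releaseUnguarded(void *P) {
  free(P);
}

int main() {
  releaseGuarded(nullptr);
  releaseUnguarded(nullptr);
  void *P = malloc(16);
  releaseUnguarded(P);
  return 0;
}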
+ + auto GetLastSinkableStore = [](BasicBlock::iterator BBI) { + auto IsNoopInstrForStoreMerging = [](BasicBlock::iterator BBI) { + return isa<DbgInfoIntrinsic>(BBI) || + (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy()); + }; + + BasicBlock::iterator FirstInstr = BBI->getParent()->begin(); + do { + if (BBI != FirstInstr) + --BBI; + } while (BBI != FirstInstr && IsNoopInstrForStoreMerging(BBI)); + + return dyn_cast<StoreInst>(BBI); + }; + + if (StoreInst *SI = GetLastSinkableStore(BasicBlock::iterator(BI))) + if (mergeStoreIntoSuccessor(*SI)) + return &BI; return nullptr; } Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { + if (BI.isUnconditional()) + return visitUnconditionalBranchInst(BI); + // Change br (not X), label True, label False to: br X, label False, True Value *X = nullptr; if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && !isa<Constant>(X)) { // Swap Destinations and condition... - BI.setCondition(X); BI.swapSuccessors(); - return &BI; + return replaceOperand(BI, 0, X); } // If the condition is irrelevant, remove the use so that other // transforms on the condition become more effective. - if (BI.isConditional() && !isa<ConstantInt>(BI.getCondition()) && - BI.getSuccessor(0) == BI.getSuccessor(1)) { - BI.setCondition(ConstantInt::getFalse(BI.getCondition()->getType())); - return &BI; - } + if (!isa<ConstantInt>(BI.getCondition()) && + BI.getSuccessor(0) == BI.getSuccessor(1)) + return replaceOperand( + BI, 0, ConstantInt::getFalse(BI.getCondition()->getType())); - // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq. + // Canonicalize, for example, fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), + if (match(&BI, m_Br(m_OneUse(m_FCmp(Pred, m_Value(), m_Value())), m_BasicBlock(), m_BasicBlock())) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. CmpInst *Cond = cast<CmpInst>(BI.getCondition()); Cond->setPredicate(CmpInst::getInversePredicate(Pred)); BI.swapSuccessors(); - Worklist.Add(Cond); + Worklist.push(Cond); return &BI; } @@ -2651,8 +2811,7 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { "Result of expression should be constant"); Case.setValue(cast<ConstantInt>(NewCase)); } - SI.setCondition(Op0); - return &SI; + return replaceOperand(SI, 0, Op0); } KnownBits Known = computeKnownBits(Cond, 0, &SI); @@ -2679,13 +2838,12 @@ Instruction *InstCombiner::visitSwitchInst(SwitchInst &SI) { IntegerType *Ty = IntegerType::get(SI.getContext(), NewWidth); Builder.SetInsertPoint(&SI); Value *NewCond = Builder.CreateTrunc(Cond, Ty, "trunc"); - SI.setCondition(NewCond); for (auto Case : SI.cases()) { APInt TruncatedCase = Case.getCaseValue()->getValue().trunc(NewWidth); Case.setValue(ConstantInt::get(SI.getContext(), TruncatedCase)); } - return &SI; + return replaceOperand(SI, 0, NewCond); } return nullptr; @@ -3175,7 +3333,7 @@ Instruction *InstCombiner::visitFreeze(FreezeInst &I) { /// instruction past all of the instructions between it and the end of its /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { - assert(I->hasOneUse() && "Invariants didn't hold!"); + assert(I->getSingleUndroppableUse() && "Invariants didn't hold!"); BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. 
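visitBranchInst above canonicalizes br (not X), True, False into br X, False, True by swapping the successors instead of keeping the negation. The equivalent source-level rewrite, shown only as an illustration:

#include <cassert>

static int branchOnNot(bool X) { return !X ? 1 : 2; }   // br (not X), True, False
static int branchSwapped(bool X) { return X ? 2 : 1; }  // br X, False, True

int main() {
  assert(branchOnNot(true) == branchSwapped(true));
  assert(branchOnNot(false) == branchSwapped(false));
  return 0;
}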
@@ -3202,12 +3360,26 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { // We can only sink load instructions if there is nothing between the load and // the end of block that could change the value. if (I->mayReadFromMemory()) { + // We don't want to do any sophisticated alias analysis, so we only check + // the instructions after I in I's parent block if we try to sink to its + // successor block. + if (DestBlock->getUniquePredecessor() != I->getParent()) + return false; for (BasicBlock::iterator Scan = I->getIterator(), E = I->getParent()->end(); Scan != E; ++Scan) if (Scan->mayWriteToMemory()) return false; } + + I->dropDroppableUses([DestBlock](const Use *U) { + if (auto *I = dyn_cast<Instruction>(U->getUser())) + return I->getParent() != DestBlock; + return true; + }); + /// FIXME: We could remove droppable uses that are not dominated by + /// the new position. + BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt(); I->moveBefore(&*InsertPos); ++NumSunkInst; @@ -3219,60 +3391,70 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { // here, but that computation has been sunk. SmallVector<DbgVariableIntrinsic *, 2> DbgUsers; findDbgUsers(DbgUsers, I); - for (auto *DII : reverse(DbgUsers)) { - if (DII->getParent() == SrcBlock) { - if (isa<DbgDeclareInst>(DII)) { - // A dbg.declare instruction should not be cloned, since there can only be - // one per variable fragment. It should be left in the original place since - // sunk instruction is not an alloca(otherwise we could not be here). - // But we need to update arguments of dbg.declare instruction, so that it - // would not point into sunk instruction. - if (!isa<CastInst>(I)) - continue; // dbg.declare points at something it shouldn't - - DII->setOperand( - 0, MetadataAsValue::get(I->getContext(), - ValueAsMetadata::get(I->getOperand(0)))); - continue; - } - // dbg.value is in the same basic block as the sunk inst, see if we can - // salvage it. Clone a new copy of the instruction: on success we need - // both salvaged and unsalvaged copies. - SmallVector<DbgVariableIntrinsic *, 1> TmpUser{ - cast<DbgVariableIntrinsic>(DII->clone())}; - - if (!salvageDebugInfoForDbgValues(*I, TmpUser)) { - // We are unable to salvage: sink the cloned dbg.value, and mark the - // original as undef, terminating any earlier variable location. - LLVM_DEBUG(dbgs() << "SINK: " << *DII << '\n'); - TmpUser[0]->insertBefore(&*InsertPos); - Value *Undef = UndefValue::get(I->getType()); - DII->setOperand(0, MetadataAsValue::get(DII->getContext(), - ValueAsMetadata::get(Undef))); - } else { - // We successfully salvaged: place the salvaged dbg.value in the - // original location, and move the unmodified dbg.value to sink with - // the sunk inst. - TmpUser[0]->insertBefore(DII); - DII->moveBefore(&*InsertPos); - } + // Update the arguments of a dbg.declare instruction, so that it + // does not point into a sunk instruction. + auto updateDbgDeclare = [&I](DbgVariableIntrinsic *DII) { + if (!isa<DbgDeclareInst>(DII)) + return false; + + if (isa<CastInst>(I)) + DII->setOperand( + 0, MetadataAsValue::get(I->getContext(), + ValueAsMetadata::get(I->getOperand(0)))); + return true; + }; + + SmallVector<DbgVariableIntrinsic *, 2> DIIClones; + for (auto User : DbgUsers) { + // A dbg.declare instruction should not be cloned, since there can only be + // one per variable fragment. It should be left in the original place + // because the sunk instruction is not an alloca (otherwise we could not be + // here). 
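The extra check added above refuses to sink a load unless the destination block is reached only from the load's block and nothing between the load and the end of its block mayWriteToMemory(). The hazard it guards against is ordinary load/store aliasing; a minimal sketch of why the order matters, illustration only:

#include <cassert>

int main() {
  int X = 1;
  int *P = &X, *Q = &X; // P and Q alias

  int Loaded = *P; // the load we would like to sink
  *Q = 2;          // an intervening instruction that may write memory
  // Sinking the load below the store would observe 2 instead of 1,
  // so the transform must prove no such write exists (or give up).
  assert(Loaded == 1);
  return 0;
}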
+ if (User->getParent() != SrcBlock || updateDbgDeclare(User)) + continue; + + DIIClones.emplace_back(cast<DbgVariableIntrinsic>(User->clone())); + LLVM_DEBUG(dbgs() << "CLONE: " << *DIIClones.back() << '\n'); + } + + // Perform salvaging without the clones, then sink the clones. + if (!DIIClones.empty()) { + salvageDebugInfoForDbgValues(*I, DbgUsers); + for (auto &DIIClone : DIIClones) { + DIIClone->insertBefore(&*InsertPos); + LLVM_DEBUG(dbgs() << "SINK: " << *DIIClone << '\n'); } } + return true; } bool InstCombiner::run() { while (!Worklist.isEmpty()) { - Instruction *I = Worklist.RemoveOne(); + // Walk deferred instructions in reverse order, and push them to the + // worklist, which means they'll end up popped from the worklist in-order. + while (Instruction *I = Worklist.popDeferred()) { + // Check to see if we can DCE the instruction. We do this already here to + // reduce the number of uses and thus allow other folds to trigger. + // Note that eraseInstFromFunction() may push additional instructions on + // the deferred worklist, so this will DCE whole instruction chains. + if (isInstructionTriviallyDead(I, &TLI)) { + eraseInstFromFunction(*I); + ++NumDeadInst; + continue; + } + + Worklist.push(I); + } + + Instruction *I = Worklist.removeOne(); if (I == nullptr) continue; // skip null values. // Check to see if we can DCE the instruction. if (isInstructionTriviallyDead(I, &TLI)) { - LLVM_DEBUG(dbgs() << "IC: DCE: " << *I << '\n'); eraseInstFromFunction(*I); ++NumDeadInst; - MadeIRChange = true; continue; } @@ -3296,65 +3478,51 @@ bool InstCombiner::run() { } } - // In general, it is possible for computeKnownBits to determine all bits in - // a value even when the operands are not all constants. - Type *Ty = I->getType(); - if (ExpensiveCombines && !I->use_empty() && Ty->isIntOrIntVectorTy()) { - KnownBits Known = computeKnownBits(I, /*Depth*/0, I); - if (Known.isConstant()) { - Constant *C = ConstantInt::get(Ty, Known.getConstant()); - LLVM_DEBUG(dbgs() << "IC: ConstFold (all bits known) to: " << *C - << " from: " << *I << '\n'); - - // Add operands to the worklist. - replaceInstUsesWith(*I, C); - ++NumConstProp; - if (isInstructionTriviallyDead(I, &TLI)) - eraseInstFromFunction(*I); - MadeIRChange = true; - continue; - } - } - - // See if we can trivially sink this instruction to a successor basic block. - if (EnableCodeSinking && I->hasOneUse()) { - BasicBlock *BB = I->getParent(); - Instruction *UserInst = cast<Instruction>(*I->user_begin()); - BasicBlock *UserParent; - - // Get the block the use occurs in. - if (PHINode *PN = dyn_cast<PHINode>(UserInst)) - UserParent = PN->getIncomingBlock(*I->use_begin()); - else - UserParent = UserInst->getParent(); - - if (UserParent != BB) { - bool UserIsSuccessor = false; - // See if the user is one of our successors. - for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) - if (*SI == UserParent) { - UserIsSuccessor = true; - break; + // See if we can trivially sink this instruction to its user if we can + // prove that the successor is not executed more frequently than our block. + if (EnableCodeSinking) + if (Use *SingleUse = I->getSingleUndroppableUse()) { + BasicBlock *BB = I->getParent(); + Instruction *UserInst = cast<Instruction>(SingleUse->getUser()); + BasicBlock *UserParent; + + // Get the block the use occurs in. 
+ if (PHINode *PN = dyn_cast<PHINode>(UserInst)) + UserParent = PN->getIncomingBlock(*SingleUse); + else + UserParent = UserInst->getParent(); + + if (UserParent != BB) { + // See if the user is one of our successors that has only one + // predecessor, so that we don't have to split the critical edge. + bool ShouldSink = UserParent->getUniquePredecessor() == BB; + // Another option where we can sink is a block that ends with a + // terminator that does not pass control to other block (such as + // return or unreachable). In this case: + // - I dominates the User (by SSA form); + // - the User will be executed at most once. + // So sinking I down to User is always profitable or neutral. + if (!ShouldSink) { + auto *Term = UserParent->getTerminator(); + ShouldSink = isa<ReturnInst>(Term) || isa<UnreachableInst>(Term); } - - // If the user is one of our immediate successors, and if that successor - // only has us as a predecessors (we'd have to split the critical edge - // otherwise), we can keep going. - if (UserIsSuccessor && UserParent->getUniquePredecessor()) { - // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent)) { - LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); - MadeIRChange = true; - // We'll add uses of the sunk instruction below, but since sinking - // can expose opportunities for it's *operands* add them to the - // worklist - for (Use &U : I->operands()) - if (Instruction *OpI = dyn_cast<Instruction>(U.get())) - Worklist.Add(OpI); + if (ShouldSink) { + assert(DT.dominates(BB, UserParent) && + "Dominance relation broken?"); + // Okay, the CFG is simple enough, try to sink this instruction. + if (TryToSinkInstruction(I, UserParent)) { + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since sinking + // can expose opportunities for it's *operands* add them to the + // worklist + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(U.get())) + Worklist.push(OpI); + } } } } - } // Now that we have an instruction, try combining it to simplify it. Builder.SetInsertPoint(I); @@ -3393,8 +3561,8 @@ bool InstCombiner::run() { InstParent->getInstList().insert(InsertPos, Result); // Push the new instruction and any users onto the worklist. - Worklist.AddUsersToWorkList(*Result); - Worklist.Add(Result); + Worklist.pushUsersToWorkList(*Result); + Worklist.push(Result); eraseInstFromFunction(*I); } else { @@ -3406,39 +3574,39 @@ bool InstCombiner::run() { if (isInstructionTriviallyDead(I, &TLI)) { eraseInstFromFunction(*I); } else { - Worklist.AddUsersToWorkList(*I); - Worklist.Add(I); + Worklist.pushUsersToWorkList(*I); + Worklist.push(I); } } MadeIRChange = true; } } - Worklist.Zap(); + Worklist.zap(); return MadeIRChange; } -/// Walk the function in depth-first order, adding all reachable code to the -/// worklist. +/// Populate the IC worklist from a function, by walking it in depth-first +/// order and adding all reachable code to the worklist. /// /// This has a couple of tricks to make the code faster and more powerful. In /// particular, we constant fold and DCE instructions as we go, to avoid adding /// them to the worklist (this significantly speeds up instcombine on code where /// many instructions are dead or constant). Additionally, if we find a branch /// whose condition is a known constant, we only visit the reachable successors. 
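The worklist-population walk described in the comment above visits blocks depth-first and, when a branch condition is a known constant, follows only the reachable successor, so statically dead blocks never feed the worklist. A generic sketch of that pruned traversal on a toy CFG; all types and names here are invented for illustration:

#include <cassert>
#include <vector>

struct Block {
  std::vector<int> Succs; // successor block ids
  int ConstTaken;         // >= 0: branch condition is a known constant and
                          // only Succs[ConstTaken] is actually reachable
};

static std::vector<bool> reachable(const std::vector<Block> &CFG, int Entry) {
  std::vector<bool> Visited(CFG.size(), false);
  std::vector<int> Worklist{Entry};
  while (!Worklist.empty()) {
    int BB = Worklist.back();
    Worklist.pop_back();
    if (Visited[BB])
      continue;
    Visited[BB] = true;
    const Block &B = CFG[BB];
    if (B.ConstTaken >= 0)
      Worklist.push_back(B.Succs[B.ConstTaken]); // only the taken successor
    else
      Worklist.insert(Worklist.end(), B.Succs.begin(), B.Succs.end());
  }
  return Visited;
}

int main() {
  // Block 0 branches on a constant and always goes to block 2; block 1 is dead.
  std::vector<Block> CFG{{{1, 2}, 1}, {{3}, -1}, {{3}, -1}, {{}, -1}};
  std::vector<bool> R = reachable(CFG, 0);
  assert(R[0] && !R[1] && R[2] && R[3]);
  return 0;
}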
-static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, - SmallPtrSetImpl<BasicBlock *> &Visited, - InstCombineWorklist &ICWorklist, - const TargetLibraryInfo *TLI) { +static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, + const TargetLibraryInfo *TLI, + InstCombineWorklist &ICWorklist) { bool MadeIRChange = false; + SmallPtrSet<BasicBlock *, 32> Visited; SmallVector<BasicBlock*, 256> Worklist; - Worklist.push_back(BB); + Worklist.push_back(&F.front()); SmallVector<Instruction*, 128> InstrsForInstCombineWorklist; DenseMap<Constant *, Constant *> FoldedConstants; do { - BB = Worklist.pop_back_val(); + BasicBlock *BB = Worklist.pop_back_val(); // We have now visited this block! If we've already been here, ignore it. if (!Visited.insert(BB).second) @@ -3447,16 +3615,6 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { Instruction *Inst = &*BBI++; - // DCE instruction if trivially dead. - if (isInstructionTriviallyDead(Inst, TLI)) { - ++NumDeadInst; - LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); - salvageDebugInfoOrMarkUndef(*Inst); - Inst->eraseFromParent(); - MadeIRChange = true; - continue; - } - // ConstantProp instruction if trivially constant. if (!Inst->use_empty() && (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0)))) @@ -3480,8 +3638,6 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, Constant *&FoldRes = FoldedConstants[C]; if (!FoldRes) FoldRes = ConstantFoldConstant(C, DL, TLI); - if (!FoldRes) - FoldRes = C; if (FoldRes != C) { LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst @@ -3519,36 +3675,9 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, Worklist.push_back(SuccBB); } while (!Worklist.empty()); - // Once we've found all of the instructions to add to instcombine's worklist, - // add them in reverse order. This way instcombine will visit from the top - // of the function down. This jives well with the way that it adds all uses - // of instructions to the worklist after doing a transformation, thus avoiding - // some N^2 behavior in pathological cases. - ICWorklist.AddInitialGroup(InstrsForInstCombineWorklist); - - return MadeIRChange; -} - -/// Populate the IC worklist from a function, and prune any dead basic -/// blocks discovered in the process. -/// -/// This also does basic constant propagation and other forward fixing to make -/// the combiner itself run much faster. -static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, - TargetLibraryInfo *TLI, - InstCombineWorklist &ICWorklist) { - bool MadeIRChange = false; - - // Do a depth-first traversal of the function, populate the worklist with - // the reachable instructions. Ignore blocks that are not reachable. Keep - // track of which blocks we visit. - SmallPtrSet<BasicBlock *, 32> Visited; - MadeIRChange |= - AddReachableCodeToWorklist(&F.front(), DL, Visited, ICWorklist, TLI); - - // Do a quick scan over the function. If we find any blocks that are - // unreachable, remove any instructions inside of them. This prevents - // the instcombine code from having to deal with some bad special cases. + // Remove instructions inside unreachable blocks. This prevents the + // instcombine code from having to deal with some bad special cases, and + // reduces use counts of instructions. 
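The FoldedConstants map above memoizes ConstantFoldConstant so a constant expression that shows up as an operand many times is folded only once. A generic memoization sketch of the same idea, using invented names and std containers rather than LLVM's DenseMap:

#include <cassert>
#include <unordered_map>

// Stand-in for an expensive fold of a constant expression.
static int foldConstant(int C) { return C * 2; }

static int foldCached(int C, std::unordered_map<int, int> &Cache,
                      int &FoldCalls) {
  auto It = Cache.find(C);
  if (It != Cache.end())
    return It->second; // already folded: reuse the cached result
  ++FoldCalls;
  return Cache[C] = foldConstant(C);
}

int main() {
  std::unordered_map<int, int> Cache;
  int FoldCalls = 0;
  assert(foldCached(21, Cache, FoldCalls) == 42);
  assert(foldCached(21, Cache, FoldCalls) == 42); // cache hit, no refold
  assert(FoldCalls == 1);
  return 0;
}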
for (BasicBlock &BB : F) { if (Visited.count(&BB)) continue; @@ -3558,6 +3687,27 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, NumDeadInst += NumDeadInstInBB; } + // Once we've found all of the instructions to add to instcombine's worklist, + // add them in reverse order. This way instcombine will visit from the top + // of the function down. This jives well with the way that it adds all uses + // of instructions to the worklist after doing a transformation, thus avoiding + // some N^2 behavior in pathological cases. + ICWorklist.reserve(InstrsForInstCombineWorklist.size()); + for (Instruction *Inst : reverse(InstrsForInstCombineWorklist)) { + // DCE instruction if trivially dead. As we iterate in reverse program + // order here, we will clean up whole chains of dead instructions. + if (isInstructionTriviallyDead(Inst, TLI)) { + ++NumDeadInst; + LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + salvageDebugInfo(*Inst); + Inst->eraseFromParent(); + MadeIRChange = true; + continue; + } + + ICWorklist.push(Inst); + } + return MadeIRChange; } @@ -3565,10 +3715,8 @@ static bool combineInstructionsOverFunction( Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, bool ExpensiveCombines, unsigned MaxIterations, - LoopInfo *LI) { + ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) { auto &DL = F.getParent()->getDataLayout(); - ExpensiveCombines |= EnableExpensiveCombines; MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue()); /// Builder - This is an IRBuilder that automatically inserts new @@ -3576,7 +3724,7 @@ static bool combineInstructionsOverFunction( IRBuilder<TargetFolder, IRBuilderCallbackInserter> Builder( F.getContext(), TargetFolder(DL), IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { - Worklist.Add(I); + Worklist.add(I); if (match(I, m_Intrinsic<Intrinsic::assume>())) AC.registerAssumption(cast<CallInst>(I)); })); @@ -3610,7 +3758,7 @@ static bool combineInstructionsOverFunction( MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, Builder, F.hasMinSize(), ExpensiveCombines, AA, + InstCombiner IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, DT, ORE, BFI, PSI, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; @@ -3623,11 +3771,10 @@ static bool combineInstructionsOverFunction( return MadeIRChange; } -InstCombinePass::InstCombinePass(bool ExpensiveCombines) - : ExpensiveCombines(ExpensiveCombines), MaxIterations(LimitMaxIterations) {} +InstCombinePass::InstCombinePass() : MaxIterations(LimitMaxIterations) {} -InstCombinePass::InstCombinePass(bool ExpensiveCombines, unsigned MaxIterations) - : ExpensiveCombines(ExpensiveCombines), MaxIterations(MaxIterations) {} +InstCombinePass::InstCombinePass(unsigned MaxIterations) + : MaxIterations(MaxIterations) {} PreservedAnalyses InstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { @@ -3639,16 +3786,14 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto *LI = AM.getCachedResult<LoopAnalysis>(F); auto *AA = &AM.getResult<AAManager>(F); - const ModuleAnalysisManager &MAM = - AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + auto &MAMProxy = AM.getResult<ModuleAnalysisManagerFunctionProxy>(F); ProfileSummaryInfo *PSI = - MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + 
MAMProxy.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI, - PSI, ExpensiveCombines, MaxIterations, - LI)) + PSI, MaxIterations, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -3698,22 +3843,18 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { nullptr; return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI, - PSI, ExpensiveCombines, MaxIterations, - LI); + PSI, MaxIterations, LI); } char InstructionCombiningPass::ID = 0; -InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines) - : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines), - MaxIterations(InstCombineDefaultMaxIterations) { +InstructionCombiningPass::InstructionCombiningPass() + : FunctionPass(ID), MaxIterations(InstCombineDefaultMaxIterations) { initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); } -InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines, - unsigned MaxIterations) - : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines), - MaxIterations(MaxIterations) { +InstructionCombiningPass::InstructionCombiningPass(unsigned MaxIterations) + : FunctionPass(ID), MaxIterations(MaxIterations) { initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); } @@ -3739,13 +3880,12 @@ void LLVMInitializeInstCombine(LLVMPassRegistryRef R) { initializeInstructionCombiningPassPass(*unwrap(R)); } -FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { - return new InstructionCombiningPass(ExpensiveCombines); +FunctionPass *llvm::createInstructionCombiningPass() { + return new InstructionCombiningPass(); } -FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines, - unsigned MaxIterations) { - return new InstructionCombiningPass(ExpensiveCombines, MaxIterations); +FunctionPass *llvm::createInstructionCombiningPass(unsigned MaxIterations) { + return new InstructionCombiningPass(MaxIterations); } void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { |