diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 332 |
1 files changed, 161 insertions, 171 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index ec976a971e3ce..a7f5e0a7774d2 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -270,7 +270,7 @@ void FAddendCoef::operator=(const FAddendCoef &That) { } void FAddendCoef::operator+=(const FAddendCoef &That) { - enum APFloat::roundingMode RndMode = APFloat::rmNearestTiesToEven; + RoundingMode RndMode = RoundingMode::NearestTiesToEven; if (isInt() == That.isInt()) { if (isInt()) IntVal += That.IntVal; @@ -663,8 +663,7 @@ Value *FAddCombine::createFSub(Value *Opnd0, Value *Opnd1) { } Value *FAddCombine::createFNeg(Value *V) { - Value *Zero = cast<Value>(ConstantFP::getZeroValueForNegation(V->getType())); - Value *NewV = createFSub(Zero, V); + Value *NewV = Builder.CreateFNeg(V); if (Instruction *I = dyn_cast<Instruction>(NewV)) createInstPostProc(I, true); // fneg's don't receive instruction numbers. return NewV; @@ -724,8 +723,6 @@ unsigned FAddCombine::calcInstrNumber(const AddendVect &Opnds) { if (!CE.isMinusOne() && !CE.isOne()) InstrNeeded++; } - if (NegOpndNum == OpndNum) - InstrNeeded++; return InstrNeeded; } @@ -1044,8 +1041,7 @@ Value *InstCombiner::SimplifyAddWithRemainder(BinaryOperator &I) { // Match RemOpV = X / C0 if (MatchDiv(RemOpV, DivOpV, DivOpC, IsSigned) && X == DivOpV && C0 == DivOpC && !MulWillOverflow(C0, C1, IsSigned)) { - Value *NewDivisor = - ConstantInt::get(X->getType()->getContext(), C0 * C1); + Value *NewDivisor = ConstantInt::get(X->getType(), C0 * C1); return IsSigned ? Builder.CreateSRem(X, NewDivisor, "srem") : Builder.CreateURem(X, NewDivisor, "urem"); } @@ -1307,9 +1303,28 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { match(&I, m_BinOp(m_c_Add(m_Not(m_Value(B)), m_Value(A)), m_One()))) return BinaryOperator::CreateSub(A, B); + // (A + RHS) + RHS --> A + (RHS << 1) + if (match(LHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(RHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(RHS, 1, "reass.add")); + + // LHS + (A + LHS) --> A + (LHS << 1) + if (match(RHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(LHS))))) + return BinaryOperator::CreateAdd(A, Builder.CreateShl(LHS, 1, "reass.add")); + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); + // ((X s/ C1) << C2) + X => X s% -C1 where -C1 is 1 << C2 + const APInt *C1, *C2; + if (match(LHS, m_Shl(m_SDiv(m_Specific(RHS), m_APInt(C1)), m_APInt(C2)))) { + APInt one(C2->getBitWidth(), 1); + APInt minusC1 = -(*C1); + if (minusC1 == (one << *C2)) { + Constant *NewRHS = ConstantInt::get(RHS->getType(), minusC1); + return BinaryOperator::CreateSRem(RHS, NewRHS); + } + } + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); @@ -1380,8 +1395,9 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // (add (and A, B) (or A, B)) --> (add A, B) if (match(&I, m_c_BinOp(m_Or(m_Value(A), m_Value(B)), m_c_And(m_Deferred(A), m_Deferred(B))))) { - I.setOperand(0, A); - I.setOperand(1, B); + // Replacing operands in-place to preserve nuw/nsw flags. + replaceOperand(I, 0, A); + replaceOperand(I, 1, B); return &I; } @@ -1685,12 +1701,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (Instruction *X = foldVectorBinop(I)) return X; - // (A*B)-(A*C) -> A*(B-C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) - return replaceInstUsesWith(I, V); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // If this is a 'B = x-(-A)', change to B = x+A. - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + // We deal with this without involving Negator to preserve NSW flag. if (Value *V = dyn_castNegVal(Op1)) { BinaryOperator *Res = BinaryOperator::CreateAdd(Op0, V); @@ -1707,6 +1721,45 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return Res; } + auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * { + if (Instruction *Ext = narrowMathIfNoOverflow(I)) + return Ext; + + bool Changed = false; + if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoSignedWrap(true); + } + if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { + Changed = true; + I.setHasNoUnsignedWrap(true); + } + + return Changed ? &I : nullptr; + }; + + // First, let's try to interpret `sub a, b` as `add a, (sub 0, b)`, + // and let's try to sink `(sub 0, b)` into `b` itself. But only if this isn't + // a pure negation used by a select that looks like abs/nabs. + bool IsNegation = match(Op0, m_ZeroInt()); + if (!IsNegation || none_of(I.users(), [&I, Op1](const User *U) { + const Instruction *UI = dyn_cast<Instruction>(U); + if (!UI) + return false; + return match(UI, + m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) || + match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1))); + })) { + if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this)) + return BinaryOperator::CreateAdd(NegOp1, Op0); + } + if (IsNegation) + return TryToNarrowDeduceFlags(); // Should have been handled in Negator! + + // (A*B)-(A*C) -> A*(B-C) etc + if (Value *V = SimplifyUsingDistributiveLaws(I)) + return replaceInstUsesWith(I, V); + if (I.getType()->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(Op0, Op1); @@ -1723,33 +1776,40 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes())))) return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X); - // Y - (X + 1) --> ~X + Y - if (match(Op1, m_OneUse(m_Add(m_Value(X), m_One())))) - return BinaryOperator::CreateAdd(Builder.CreateNot(X), Op0); + // Reassociate sub/add sequences to create more add instructions and + // reduce dependency chains: + // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) + Value *Z; + if (match(Op0, m_OneUse(m_c_Add(m_OneUse(m_Sub(m_Value(X), m_Value(Y))), + m_Value(Z))))) { + Value *XZ = Builder.CreateAdd(X, Z); + Value *YW = Builder.CreateAdd(Y, Op1); + return BinaryOperator::CreateSub(XZ, YW); + } - // Y - ~X --> (X + 1) + Y - if (match(Op1, m_OneUse(m_Not(m_Value(X))))) { - return BinaryOperator::CreateAdd( - Builder.CreateAdd(Op0, ConstantInt::get(I.getType(), 1)), X); + auto m_AddRdx = [](Value *&Vec) { + return m_OneUse( + m_Intrinsic<Intrinsic::experimental_vector_reduce_add>(m_Value(Vec))); + }; + Value *V0, *V1; + if (match(Op0, m_AddRdx(V0)) && match(Op1, m_AddRdx(V1)) && + V0->getType() == V1->getType()) { + // Difference of sums is sum of differences: + // add_rdx(V0) - add_rdx(V1) --> add_rdx(V0 - V1) + Value *Sub = Builder.CreateSub(V0, V1); + Value *Rdx = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_reduce_add, {Sub->getType()}, {Sub}); + return replaceInstUsesWith(I, Rdx); } if (Constant *C = dyn_cast<Constant>(Op0)) { - bool IsNegate = match(C, m_ZeroInt()); Value *X; - if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - // 0 - (zext bool) --> sext bool + if (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (zext bool) --> bool ? C - 1 : C - if (IsNegate) - return CastInst::CreateSExtOrBitCast(X, I.getType()); return SelectInst::Create(X, SubOne(C), C); - } - if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - // 0 - (sext bool) --> zext bool + if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) // C - (sext bool) --> bool ? C + 1 : C - if (IsNegate) - return CastInst::CreateZExtOrBitCast(X, I.getType()); return SelectInst::Create(X, AddOne(C), C); - } // C - ~X == X + (1+C) if (match(Op1, m_Not(m_Value(X)))) @@ -1768,7 +1828,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Constant *C2; // C-(C2-X) --> X+(C-C2) - if (match(Op1, m_Sub(m_Constant(C2), m_Value(X)))) + if (match(Op1, m_Sub(m_Constant(C2), m_Value(X))) && !isa<ConstantExpr>(C2)) return BinaryOperator::CreateAdd(X, ConstantExpr::getSub(C, C2)); // C-(X+C2) --> (C-C2)-X @@ -1777,62 +1837,12 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { } const APInt *Op0C; - if (match(Op0, m_APInt(Op0C))) { - - if (Op0C->isNullValue()) { - Value *Op1Wide; - match(Op1, m_TruncOrSelf(m_Value(Op1Wide))); - bool HadTrunc = Op1Wide != Op1; - bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse(); - unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits(); - - Value *X; - const APInt *ShAmt; - // -(X >>u 31) -> (X >>s 31) - if (NoTruncOrTruncIsOneUse && - match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) && - *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); - Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp); - NewShift->copyIRFlags(Op1Wide); - if (!HadTrunc) - return NewShift; - Builder.Insert(NewShift); - return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); - } - // -(X >>s 31) -> (X >>u 31) - if (NoTruncOrTruncIsOneUse && - match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) && - *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); - Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp); - NewShift->copyIRFlags(Op1Wide); - if (!HadTrunc) - return NewShift; - Builder.Insert(NewShift); - return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); - } - - if (!HadTrunc && Op1->hasOneUse()) { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; - if (SPF == SPF_ABS || SPF == SPF_NABS) { - // This is a negate of an ABS/NABS pattern. Just swap the operands - // of the select. - cast<SelectInst>(Op1)->swapValues(); - // Don't swap prof metadata, we didn't change the branch behavior. - return replaceInstUsesWith(I, Op1); - } - } - } - + if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. - if (Op0C->isMask()) { - KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) - return BinaryOperator::CreateXor(Op1, Op0); - } + KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) + return BinaryOperator::CreateXor(Op1, Op0); } { @@ -1956,71 +1966,11 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return NewSel; } - if (Op1->hasOneUse()) { - Value *X = nullptr, *Y = nullptr, *Z = nullptr; - Constant *C = nullptr; - - // (X - (Y - Z)) --> (X + (Z - Y)). - if (match(Op1, m_Sub(m_Value(Y), m_Value(Z)))) - return BinaryOperator::CreateAdd(Op0, - Builder.CreateSub(Z, Y, Op1->getName())); - - // (X - (X & Y)) --> (X & ~Y) - if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) - return BinaryOperator::CreateAnd(Op0, - Builder.CreateNot(Y, Y->getName() + ".not")); - - // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. - if (match(Op0, m_Zero())) { - Constant *Op11C; - if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) && - !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() && - Op11C->isNotOneValue()) { - Instruction *BO = - BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C)); - BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); - return BO; - } - } - - // 0 - (X << Y) -> (-X << Y) when X is freely negatable. - if (match(Op1, m_Shl(m_Value(X), m_Value(Y))) && match(Op0, m_Zero())) - if (Value *XNeg = dyn_castNegVal(X)) - return BinaryOperator::CreateShl(XNeg, Y); - - // Subtracting -1/0 is the same as adding 1/0: - // sub [nsw] Op0, sext(bool Y) -> add [nsw] Op0, zext(bool Y) - // 'nuw' is dropped in favor of the canonical form. - if (match(Op1, m_SExt(m_Value(Y))) && - Y->getType()->getScalarSizeInBits() == 1) { - Value *Zext = Builder.CreateZExt(Y, I.getType()); - BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Zext); - Add->setHasNoSignedWrap(I.hasNoSignedWrap()); - return Add; - } - // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y) - // 'nuw' is dropped in favor of the canonical form. - if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) { - Value *Sext = Builder.CreateSExt(Y, I.getType()); - BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext); - Add->setHasNoSignedWrap(I.hasNoSignedWrap()); - return Add; - } - - // X - A*-B -> X + A*B - // X - -A*B -> X + A*B - Value *A, *B; - if (match(Op1, m_c_Mul(m_Value(A), m_Neg(m_Value(B))))) - return BinaryOperator::CreateAdd(Op0, Builder.CreateMul(A, B)); - - // X - A*C -> X + A*-C - // No need to handle commuted multiply because multiply handling will - // ensure constant will be move to the right hand side. - if (match(Op1, m_Mul(m_Value(A), m_Constant(C))) && !isa<ConstantExpr>(C)) { - Value *NewMul = Builder.CreateMul(A, ConstantExpr::getNeg(C)); - return BinaryOperator::CreateAdd(Op0, NewMul); - } - } + // (X - (X & Y)) --> (X & ~Y) + if (match(Op1, m_c_And(m_Specific(Op0), m_Value(Y))) && + (Op1->hasOneUse() || isa<Constant>(Y))) + return BinaryOperator::CreateAnd( + Op0, Builder.CreateNot(Y, Y->getName() + ".not")); { // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A @@ -2096,20 +2046,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) return V; - if (Instruction *Ext = narrowMathIfNoOverflow(I)) - return Ext; - - bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedSub(Op0, Op1, I)) { - Changed = true; - I.setHasNoSignedWrap(true); - } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedSub(Op0, Op1, I)) { - Changed = true; - I.setHasNoUnsignedWrap(true); - } - - return Changed ? &I : nullptr; + return TryToNarrowDeduceFlags(); } /// This eliminates floating-point negation in either 'fneg(X)' or @@ -2132,6 +2069,12 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { if (match(&I, m_FNeg(m_OneUse(m_FDiv(m_Constant(C), m_Value(X)))))) return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: + // -(X + C) --> -X + -C --> -C - X + if (I.hasNoSignedZeros() && + match(&I, m_FNeg(m_OneUse(m_FAdd(m_Value(X), m_Constant(C)))))) + return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + return nullptr; } @@ -2184,10 +2127,15 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return X; // Subtraction from -0.0 is the canonical form of fneg. - // fsub nsz 0, X ==> fsub nsz -0.0, X - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - if (I.hasNoSignedZeros() && match(Op0, m_PosZeroFP())) - return BinaryOperator::CreateFNegFMF(Op1, &I); + // fsub -0.0, X ==> fneg X + // fsub nsz 0.0, X ==> fneg nsz X + // + // FIXME This matcher does not respect FTZ or DAZ yet: + // fsub -0.0, Denorm ==> +-0 + // fneg Denorm ==> -Denorm + Value *Op; + if (match(&I, m_FNeg(m_Value(Op)))) + return UnaryOperator::CreateFNegFMF(Op, &I); if (Instruction *X = foldFNegIntoConstant(I)) return X; @@ -2198,6 +2146,7 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { Value *X, *Y; Constant *C; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // If Op0 is not -0.0 or we can ignore -0.0: Z - (X - Y) --> Z + (Y - X) // Canonicalize to fadd to make analysis easier. // This can also help codegen because fadd is commutative. @@ -2211,6 +2160,13 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } } + // (-X) - Op1 --> -(X + Op1) + if (I.hasNoSignedZeros() && !isa<ConstantExpr>(Op0) && + match(Op0, m_OneUse(m_FNeg(m_Value(X))))) { + Value *FAdd = Builder.CreateFAddFMF(X, Op1, &I); + return UnaryOperator::CreateFNegFMF(FAdd, &I); + } + if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) if (Instruction *NV = FoldOpIntoSelect(I, SI)) @@ -2258,12 +2214,12 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (I.hasAllowReassoc() && I.hasNoSignedZeros()) { // (Y - X) - Y --> -X if (match(Op0, m_FSub(m_Specific(Op1), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + return UnaryOperator::CreateFNegFMF(X, &I); // Y - (X + Y) --> -X // Y - (Y + X) --> -X if (match(Op1, m_c_FAdd(m_Specific(Op0), m_Value(X)))) - return BinaryOperator::CreateFNegFMF(X, &I); + return UnaryOperator::CreateFNegFMF(X, &I); // (X * C) - X --> X * (C - 1.0) if (match(Op0, m_FMul(m_Specific(Op1), m_Constant(C)))) { @@ -2276,6 +2232,34 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { return BinaryOperator::CreateFMulFMF(Op0, OneSubC, &I); } + // Reassociate fsub/fadd sequences to create more fadd instructions and + // reduce dependency chains: + // ((X - Y) + Z) - Op1 --> (X + Z) - (Y + Op1) + Value *Z; + if (match(Op0, m_OneUse(m_c_FAdd(m_OneUse(m_FSub(m_Value(X), m_Value(Y))), + m_Value(Z))))) { + Value *XZ = Builder.CreateFAddFMF(X, Z, &I); + Value *YW = Builder.CreateFAddFMF(Y, Op1, &I); + return BinaryOperator::CreateFSubFMF(XZ, YW, &I); + } + + auto m_FaddRdx = [](Value *&Sum, Value *&Vec) { + return m_OneUse( + m_Intrinsic<Intrinsic::experimental_vector_reduce_v2_fadd>( + m_Value(Sum), m_Value(Vec))); + }; + Value *A0, *A1, *V0, *V1; + if (match(Op0, m_FaddRdx(A0, V0)) && match(Op1, m_FaddRdx(A1, V1)) && + V0->getType() == V1->getType()) { + // Difference of sums is sum of differences: + // add_rdx(A0, V0) - add_rdx(A1, V1) --> add_rdx(A0, V0 - V1) - A1 + Value *Sub = Builder.CreateFSubFMF(V0, V1, &I); + Value *Rdx = Builder.CreateIntrinsic( + Intrinsic::experimental_vector_reduce_v2_fadd, + {A0->getType(), Sub->getType()}, {A0, Sub}, &I); + return BinaryOperator::CreateFSubFMF(Rdx, A1, &I); + } + if (Instruction *F = factorizeFAddFSub(I, Builder)) return F; @@ -2285,6 +2269,12 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { // complex pattern matching and remove this from InstCombine. if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); + + // (X - Y) - Op1 --> X - (Y + Op1) + if (match(Op0, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) { + Value *FAdd = Builder.CreateFAddFMF(Y, Op1, &I); + return BinaryOperator::CreateFSubFMF(X, FAdd, &I); + } } return nullptr; |