diff options
Diffstat (limited to 'contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
| -rw-r--r-- | contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 389 |
1 files changed, 348 insertions, 41 deletions
diff --git a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index ba15b023f2a3..ec976a971e3c 100644 --- a/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/contrib/llvm-project/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -890,6 +890,10 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) return SelectInst::Create(X, AddOne(Op1C), Op1); + // sext(bool) + C -> bool ? C - 1 : C + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, SubOne(Op1C), Op1); // ~X + C --> (C-1) - X if (match(Op0, m_Not(m_Value(X)))) @@ -1097,6 +1101,107 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } +Instruction * +InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I) { + assert((I.getOpcode() == Instruction::Add || + I.getOpcode() == Instruction::Or || + I.getOpcode() == Instruction::Sub) && + "Expecting add/or/sub instruction"); + + // We have a subtraction/addition between a (potentially truncated) *logical* + // right-shift of X and a "select". + Value *X, *Select; + Instruction *LowBitsToSkip, *Extract; + if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd( + m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)), + m_Instruction(Extract))), + m_Value(Select)))) + return nullptr; + + // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS. + if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select) + return nullptr; + + Type *XTy = X->getType(); + bool HadTrunc = I.getType() != XTy; + + // If there was a truncation of extracted value, then we'll need to produce + // one extra instruction, so we need to ensure one instruction will go away. + if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value()))) + return nullptr; + + // Extraction should extract high NBits bits, with shift amount calculated as: + // low bits to skip = shift bitwidth - high bits to extract + // The shift amount itself may be extended, and we need to look past zero-ext + // when matching NBits, that will matter for matching later. + Constant *C; + Value *NBits; + if (!match( + LowBitsToSkip, + m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) || + !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(C->getType()->getScalarSizeInBits(), + X->getType()->getScalarSizeInBits())))) + return nullptr; + + // Sign-extending value can be zero-extended if we `sub`tract it, + // or sign-extended otherwise. + auto SkipExtInMagic = [&I](Value *&V) { + if (I.getOpcode() == Instruction::Sub) + match(V, m_ZExtOrSelf(m_Value(V))); + else + match(V, m_SExtOrSelf(m_Value(V))); + }; + + // Now, finally validate the sign-extending magic. + // `select` itself may be appropriately extended, look past that. + SkipExtInMagic(Select); + + ICmpInst::Predicate Pred; + const APInt *Thr; + Value *SignExtendingValue, *Zero; + bool ShouldSignext; + // It must be a select between two values we will later establish to be a + // sign-extending value and a zero constant. The condition guarding the + // sign-extension must be based on a sign bit of the same X we had in `lshr`. + if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)), + m_Value(SignExtendingValue), m_Value(Zero))) || + !isSignBitCheck(Pred, *Thr, ShouldSignext)) + return nullptr; + + // icmp-select pair is commutative. + if (!ShouldSignext) + std::swap(SignExtendingValue, Zero); + + // If we should not perform sign-extension then we must add/or/subtract zero. + if (!match(Zero, m_Zero())) + return nullptr; + // Otherwise, it should be some constant, left-shifted by the same NBits we + // had in `lshr`. Said left-shift can also be appropriately extended. + // Again, we must look past zero-ext when looking for NBits. + SkipExtInMagic(SignExtendingValue); + Constant *SignExtendingValueBaseConstant; + if (!match(SignExtendingValue, + m_Shl(m_Constant(SignExtendingValueBaseConstant), + m_ZExtOrSelf(m_Specific(NBits))))) + return nullptr; + // If we `sub`, then the constant should be one, else it should be all-ones. + if (I.getOpcode() == Instruction::Sub + ? !match(SignExtendingValueBaseConstant, m_One()) + : !match(SignExtendingValueBaseConstant, m_AllOnes())) + return nullptr; + + auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip, + Extract->getName() + ".sext"); + NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness. + if (!HadTrunc) + return NewAShr; + + Builder.Insert(NewAShr); + return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType()); +} + Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), @@ -1187,12 +1292,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(RHS, A); } - // Canonicalize sext to zext for better value tracking potential. - // add A, sext(B) --> sub A, zext(B) - if (match(&I, m_c_Add(m_Value(A), m_OneUse(m_SExt(m_Value(B))))) && - B->getType()->isIntOrIntVectorTy(1)) - return BinaryOperator::CreateSub(A, Builder.CreateZExt(B, Ty)); - // A + -B --> A - B if (match(RHS, m_Neg(m_Value(B)))) return BinaryOperator::CreateSub(LHS, B); @@ -1302,12 +1401,32 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (Instruction *V = canonicalizeLowbitMask(I, Builder)) return V; + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I)) return SatAdd; return Changed ? &I : nullptr; } +/// Eliminate an op from a linear interpolation (lerp) pattern. +static Instruction *factorizeLerp(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y, *Z; + if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y), + m_OneUse(m_FSub(m_FPOne(), + m_Value(Z))))), + m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z)))))) + return nullptr; + + // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants] + Value *XY = Builder.CreateFSubFMF(X, Y, &I); + Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I); + return BinaryOperator::CreateFAddFMF(Y, MulZ, &I); +} + /// Factor a common operand out of fadd/fsub of fmul/fdiv. static Instruction *factorizeFAddFSub(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { @@ -1315,6 +1434,10 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I, I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub"); assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "FP factorization requires FMF"); + + if (Instruction *Lerp = factorizeLerp(I, Builder)) + return Lerp; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *X, *Y, *Z; bool IsFMul; @@ -1362,17 +1485,32 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I)) return FoldedFAdd; - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - Value *X; // (-X) + Y --> Y - X - if (match(LHS, m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFSubFMF(RHS, X, &I); - // Y + (-X) --> Y - X - if (match(RHS, m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFSubFMF(LHS, X, &I); + Value *X, *Y; + if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y)))) + return BinaryOperator::CreateFSubFMF(Y, X, &I); + + // Similar to above, but look through fmul/fdiv for the negated term. + // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants] + Value *Z; + if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z)))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } + // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants] + // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants] + if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))), + m_Value(Z))) || + match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))), + m_Value(Z)))) { + Value *XY = Builder.CreateFDivFMF(X, Y, &I); + return BinaryOperator::CreateFSubFMF(Z, XY, &I); + } // Check for (fadd double (sitofp x), y), see if we can merge this into an // integer add followed by a promotion. + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) { Value *LHSIntVal = LHSConv->getOperand(0); Type *FPType = LHSConv->getType(); @@ -1447,7 +1585,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, - Type *Ty) { + Type *Ty, bool IsNUW) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize // this. bool Swapped = false; @@ -1515,6 +1653,15 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, // Emit the offset of the GEP and an intptr_t. Value *Result = EmitGEPOffset(GEP1); + // If this is a single inbounds GEP and the original sub was nuw, + // then the final multiplication is also nuw. We match an extra add zero + // here, because that's what EmitGEPOffset() generates. + Instruction *I; + if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && + match(Result, m_Add(m_Instruction(I), m_Zero())) && + I->getOpcode() == Instruction::Mul) + I->setHasNoUnsignedWrap(); + // If we had a constant expression GEP on the other side offsetting the // pointer, subtract it from the offset we have. if (GEP2) { @@ -1631,37 +1778,50 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { const APInt *Op0C; if (match(Op0, m_APInt(Op0C))) { - unsigned BitWidth = I.getType()->getScalarSizeInBits(); - // -(X >>u 31) -> (X >>s 31) - // -(X >>s 31) -> (X >>u 31) if (Op0C->isNullValue()) { + Value *Op1Wide; + match(Op1, m_TruncOrSelf(m_Value(Op1Wide))); + bool HadTrunc = Op1Wide != Op1; + bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse(); + unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits(); + Value *X; const APInt *ShAmt; - if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) && + // -(X >>u 31) -> (X >>s 31) + if (NoTruncOrTruncIsOneUse && + match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) && *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); - return BinaryOperator::CreateAShr(X, ShAmtOp); + Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); + Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp); + NewShift->copyIRFlags(Op1Wide); + if (!HadTrunc) + return NewShift; + Builder.Insert(NewShift); + return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); } - if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) && + // -(X >>s 31) -> (X >>u 31) + if (NoTruncOrTruncIsOneUse && + match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) && *ShAmt == BitWidth - 1) { - Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1); - return BinaryOperator::CreateLShr(X, ShAmtOp); + Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1); + Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp); + NewShift->copyIRFlags(Op1Wide); + if (!HadTrunc) + return NewShift; + Builder.Insert(NewShift); + return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType()); } - if (Op1->hasOneUse()) { + if (!HadTrunc && Op1->hasOneUse()) { Value *LHS, *RHS; SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor; if (SPF == SPF_ABS || SPF == SPF_NABS) { // This is a negate of an ABS/NABS pattern. Just swap the operands // of the select. - SelectInst *SI = cast<SelectInst>(Op1); - Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - SI->setTrueValue(FalseVal); - SI->setFalseValue(TrueVal); + cast<SelectInst>(Op1)->swapValues(); // Don't swap prof metadata, we didn't change the branch behavior. - return replaceInstUsesWith(I, SI); + return replaceInstUsesWith(I, Op1); } } } @@ -1686,6 +1846,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } + // (sub (or A, B) (and A, B)) --> (xor A, B) + { + Value *A, *B; + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + match(Op0, m_c_Or(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateXor(A, B); + } + + // (sub (and A, B) (or A, B)) --> neg (xor A, B) + { + Value *A, *B; + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateXor(A, B)); + } + // (sub (or A, B), (xor A, B)) --> (and A, B) { Value *A, *B; @@ -1694,6 +1871,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, B); } + // (sub (xor A, B) (or A, B)) --> neg (and A, B) + { + Value *A, *B; + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) && + (Op0->hasOneUse() || Op1->hasOneUse())) + return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B)); + } + { Value *Y; // ((X | Y) - X) --> (~X & Y) @@ -1702,6 +1888,74 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Y, Builder.CreateNot(Op1, Op1->getName() + ".not")); } + { + // (sub (and Op1, (neg X)), Op1) --> neg (and Op1, (add X, -1)) + Value *X; + if (match(Op0, m_OneUse(m_c_And(m_Specific(Op1), + m_OneUse(m_Neg(m_Value(X))))))) { + return BinaryOperator::CreateNeg(Builder.CreateAnd( + Op1, Builder.CreateAdd(X, Constant::getAllOnesValue(I.getType())))); + } + } + + { + // (sub (and Op1, C), Op1) --> neg (and Op1, ~C) + Constant *C; + if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_Constant(C))))) { + return BinaryOperator::CreateNeg( + Builder.CreateAnd(Op1, Builder.CreateNot(C))); + } + } + + { + // If we have a subtraction between some value and a select between + // said value and something else, sink subtraction into select hands, i.e.: + // sub (select %Cond, %TrueVal, %FalseVal), %Op1 + // -> + // select %Cond, (sub %TrueVal, %Op1), (sub %FalseVal, %Op1) + // or + // sub %Op0, (select %Cond, %TrueVal, %FalseVal) + // -> + // select %Cond, (sub %Op0, %TrueVal), (sub %Op0, %FalseVal) + // This will result in select between new subtraction and 0. + auto SinkSubIntoSelect = + [Ty = I.getType()](Value *Select, Value *OtherHandOfSub, + auto SubBuilder) -> Instruction * { + Value *Cond, *TrueVal, *FalseVal; + if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal), + m_Value(FalseVal))))) + return nullptr; + if (OtherHandOfSub != TrueVal && OtherHandOfSub != FalseVal) + return nullptr; + // While it is really tempting to just create two subtractions and let + // InstCombine fold one of those to 0, it isn't possible to do so + // because of worklist visitation order. So ugly it is. + bool OtherHandOfSubIsTrueVal = OtherHandOfSub == TrueVal; + Value *NewSub = SubBuilder(OtherHandOfSubIsTrueVal ? FalseVal : TrueVal); + Constant *Zero = Constant::getNullValue(Ty); + SelectInst *NewSel = + SelectInst::Create(Cond, OtherHandOfSubIsTrueVal ? Zero : NewSub, + OtherHandOfSubIsTrueVal ? NewSub : Zero); + // Preserve prof metadata if any. + NewSel->copyMetadata(cast<Instruction>(*Select)); + return NewSel; + }; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op0, /*OtherHandOfSub=*/Op1, + [Builder = &Builder, Op1](Value *OtherHandOfSelect) { + return Builder->CreateSub(OtherHandOfSelect, + /*OtherHandOfSub=*/Op1); + })) + return NewSel; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op1, /*OtherHandOfSub=*/Op0, + [Builder = &Builder, Op0](Value *OtherHandOfSelect) { + return Builder->CreateSub(/*OtherHandOfSub=*/Op0, + OtherHandOfSelect); + })) + return NewSel; + } + if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; @@ -1717,14 +1971,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Builder.CreateNot(Y, Y->getName() + ".not")); // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. - // TODO: This could be extended to match arbitrary vector constants. - const APInt *DivC; - if (match(Op0, m_Zero()) && match(Op1, m_SDiv(m_Value(X), m_APInt(DivC))) && - !DivC->isMinSignedValue() && *DivC != 1) { - Constant *NegDivC = ConstantInt::get(I.getType(), -(*DivC)); - Instruction *BO = BinaryOperator::CreateSDiv(X, NegDivC); - BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); - return BO; + if (match(Op0, m_Zero())) { + Constant *Op11C; + if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) && + !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() && + Op11C->isNotOneValue()) { + Instruction *BO = + BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C)); + BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); + return BO; + } } // 0 - (X << Y) -> (-X << Y) when X is freely negatable. @@ -1742,6 +1998,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Add->setHasNoSignedWrap(I.hasNoSignedWrap()); return Add; } + // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y) + // 'nuw' is dropped in favor of the canonical form. + if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) { + Value *Sext = Builder.CreateSExt(Y, I.getType()); + BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext); + Add->setHasNoSignedWrap(I.hasNoSignedWrap()); + return Add; + } // X - A*-B -> X + A*B // X - -A*B -> X + A*B @@ -1778,7 +2042,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { std::swap(LHS, RHS); // LHS is now O above and expected to have at least 2 uses (the min/max) // NotA is epected to have 2 uses from the min/max and 1 from the sub. - if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && + if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && !NotA->hasNUsesOrMore(4)) { // Note: We don't generate the inverse max/min, just create the not of // it and let other folds do the rest. @@ -1796,13 +2060,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *LHSOp, *RHSOp; if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && match(Op1, m_PtrToInt(m_Value(RHSOp)))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + I.hasNoUnsignedWrap())) return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + /* IsNUW */ false)) return replaceInstUsesWith(I, Res); // Canonicalize a shifty way to code absolute value to the common pattern. @@ -1826,6 +2092,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return SelectInst::Create(Cmp, Neg, A); } + if (Instruction *V = + canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I)) + return V; + if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1865,6 +2135,22 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { return nullptr; } +static Instruction *hoistFNegAboveFMulFDiv(Instruction &I, + InstCombiner::BuilderTy &Builder) { + Value *FNeg; + if (!match(&I, m_FNeg(m_Value(FNeg)))) + return nullptr; + + Value *X, *Y; + if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y))))) + return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I); + + return nullptr; +} + Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { Value *Op = I.getOperand(0); @@ -1882,6 +2168,9 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) { match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) return BinaryOperator::CreateFSubFMF(Y, X, &I); + if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) + return R; + return nullptr; } @@ -1903,6 +2192,9 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (Instruction *X = foldFNegIntoConstant(I)) return X; + if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) + return R; + Value *X, *Y; Constant *C; @@ -1944,6 +2236,21 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y)))))) return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I); + // Similar to above, but look through fmul/fdiv of the negated value: + // Op0 - (-X * Y) --> Op0 + (X * Y) + // Op0 - (Y * -X) --> Op0 + (X * Y) + if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) { + Value *FMul = Builder.CreateFMulFMF(X, Y, &I); + return BinaryOperator::CreateFAddFMF(Op0, FMul, &I); + } + // Op0 - (-X / Y) --> Op0 + (X / Y) + // Op0 - (X / -Y) --> Op0 + (X / Y) + if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) || + match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) { + Value *FDiv = Builder.CreateFDivFMF(X, Y, &I); + return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I); + } + // Handle special cases for FSub with selects feeding the operation if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); |
