aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2023-12-09 13:28:42 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2023-12-09 13:28:42 +0000
commit b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch)
tree 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
parent 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r-- llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp 274
1 files changed, 221 insertions, 53 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 91ca44e0f11e..719a2678fc18 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -830,15 +830,15 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
// (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C)
Constant *NarrowC;
if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) {
- Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty);
- Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+ Value *WideC = Builder.CreateSExt(NarrowC, Ty);
+ Value *NewC = Builder.CreateAdd(WideC, Op1C);
Value *WideX = Builder.CreateSExt(X, Ty);
return BinaryOperator::CreateAdd(WideX, NewC);
}
// (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C)
if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) {
- Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty);
- Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+ Value *WideC = Builder.CreateZExt(NarrowC, Ty);
+ Value *NewC = Builder.CreateAdd(WideC, Op1C);
Value *WideX = Builder.CreateZExt(X, Ty);
return BinaryOperator::CreateAdd(WideX, NewC);
}
@@ -903,8 +903,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
// (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add`
Constant *Op01C;
- if (match(Op0, m_Or(m_Value(X), m_ImmConstant(Op01C))) &&
- haveNoCommonBitsSet(X, Op01C, DL, &AC, &Add, &DT))
+ if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C))))
return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C));
// (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C)
@@ -995,6 +994,69 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
return nullptr;
}
+// Match variations of a^2 + 2*a*b + b^2 rooted at I; on success A and B are
+// bound to the two summands (their values are only meaningful on `true`).
+// The FP template parameter selects the opcodes (fmul/fadd vs. mul/add/shl),
+// so the FP and Int folds can share this one matcher.
+//
+// Mul2Rhs: matcher for the RHS of the "times two" operation; should be
+// `m_SpecificFP(2.0)` for FP and `m_SpecificInt(1)` for Int (for Int we
+// match `X << 1` instead of `X * 2`).
+template <bool FP, typename Mul2Rhs>
+static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
+ Value *&B) {
+ constexpr unsigned MulOp = FP ? Instruction::FMul : Instruction::Mul; // a*a, b*b, a*b
+ constexpr unsigned AddOp = FP ? Instruction::FAdd : Instruction::Add;
+ constexpr unsigned Mul2Op = FP ? Instruction::FMul : Instruction::Shl; // x*2.0 or x<<1
+
+ // Form 1: a^2 + (2a + b) * b, i.e. (a * a) + (((a * 2) + b) * b).
+ if (match(&I, m_c_BinOp(
+ AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))), // one-use a * a
+ m_OneUse(m_BinOp( // one-use ((2a + b) * b)
+ MulOp,
+ m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs),
+ m_Value(B)),
+ m_Deferred(B))))))
+ return true;
+
+ // Form 2: 2ab + (a^2 + b^2), with 2ab as ((a * b) * 2) or ((a * 2) * b)
+ //   +
+ // and the squares in either order: (a * a + b * b) or (b * b + a * a)
+ return match(
+ &I,
+ m_c_BinOp(AddOp,
+ m_CombineOr( // the 2ab term must be single-use in either shape
+ m_OneUse(m_BinOp(
+ Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
+ m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
+ m_Value(B)))),
+ m_OneUse(m_c_BinOp( // a^2 + b^2 reuses the A/B bound by the 2ab term
+ AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
+ m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
+}
+
+// Integer square-sum fold: a^2 + 2*a*b + b^2 --> (a + b) * (a + b).
+// The identity holds modulo 2^N, so the replacement add and mul are built
+// without any nuw/nsw flags.
+Instruction *InstCombinerImpl::foldSquareSumInt(BinaryOperator &I) {
+  Value *A, *B;
+  // For integers the "times two" step is recognised as a shift left by one.
+  if (!matchesSquareSum</*FP*/ false>(I, m_SpecificInt(1), A, B))
+    return nullptr;
+  Value *Sum = Builder.CreateAdd(A, B);
+  return BinaryOperator::CreateMul(Sum, Sum);
+}
+
+// Floating-point square-sum fold: a^2 + 2*a*b + b^2 --> (a + b) * (a + b).
+// I must carry both `reassoc` and `nsz`; the caller is responsible for
+// checking this (it is asserted below). The fast-math flags of I are copied
+// onto the replacement fadd and fmul.
+Instruction *InstCombinerImpl::foldSquareSumFP(BinaryOperator &I) {
+  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");
+  Value *A, *B;
+  // For FP the "times two" step is recognised as a multiply by 2.0.
+  if (!matchesSquareSum</*FP*/ true>(I, m_SpecificFP(2.0), A, B))
+    return nullptr;
+  Value *Sum = Builder.CreateFAddFMF(A, B, &I);
+  return BinaryOperator::CreateFMulFMF(Sum, Sum, &I);
+}
+
// Matches multiplication expression Op * C where C is a constant. Returns the
// constant value in C and the other operand in Op. Returns true if such a
// match is found.
@@ -1146,6 +1208,21 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+// Turn an add of a one-use shifted one-use negation into a subtraction:
+//   (add A, (shl (neg B), Y))  -->  (sub A, (shl B, Y))
+// The shifted negation may be either operand of the add. This is valid in
+// modular arithmetic because (-B) << Y == -(B << Y).
+static Instruction *combineAddSubWithShlAddSub(InstCombiner::BuilderTy &Builder,
+                                               const BinaryOperator &I) {
+  Value *A, *B, *Cnt;
+  auto ShiftedNeg =
+      m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(B))), m_Value(Cnt)));
+  if (!match(&I, m_c_Add(ShiftedNeg, m_Value(A))))
+    return nullptr;
+  Value *ShiftedB = Builder.CreateShl(B, Cnt);
+  return BinaryOperator::CreateSub(A, ShiftedB);
+}
+
/// Try to reduce signed division by power-of-2 to an arithmetic shift right.
static Instruction *foldAddToAshr(BinaryOperator &Add) {
// Division must be by power-of-2, but not the minimum signed value.
@@ -1156,18 +1233,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
return nullptr;
// Rounding is done by adding -1 if the dividend (X) is negative and has any
- // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
- // sext (icmp ugt (X & (DivC - 1)), SMIN)
- const APInt *MaskC;
+ // low bits set. It recognizes two canonical patterns:
+ // 1. For an 'ugt' cmp with the signed minimum value (SMIN), the
+ // pattern is: sext (icmp ugt (X & (DivC - 1)), SMIN).
+ // 2. For an 'eq' cmp, the pattern's: sext (icmp eq X & (SMIN + 1), SMIN + 1).
+ // Note that, by the time we end up here, if possible, ugt has been
+ // canonicalized into eq.
+ const APInt *MaskC, *MaskCCmp;
ICmpInst::Predicate Pred;
if (!match(Add.getOperand(1),
m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
- m_SignMask()))) ||
- Pred != ICmpInst::ICMP_UGT)
+ m_APInt(MaskCCmp)))))
+ return nullptr;
+
+ if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) &&
+ (Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC))
return nullptr;
APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
- if (*MaskC != (SMin | (*DivC - 1)))
+ bool IsMaskValid = Pred == ICmpInst::ICMP_UGT
+ ? (*MaskC == (SMin | (*DivC - 1)))
+ : (*DivC == 2 && *MaskC == SMin + 1);
+ if (!IsMaskValid)
return nullptr;
// (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC)
@@ -1327,8 +1414,10 @@ static Instruction *foldBoxMultiply(BinaryOperator &I) {
// ResLo = (CrossSum << HalfBits) + (YLo * XLo)
Value *XLo, *YLo;
Value *CrossSum;
+ // Require one-use on the multiply to avoid increasing the number of
+ // multiplications.
if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
- m_Mul(m_Value(YLo), m_Value(XLo)))))
+ m_OneUse(m_Mul(m_Value(YLo), m_Value(XLo))))))
return nullptr;
// XLo = X & HalfMask
@@ -1386,6 +1475,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *R = foldBinOpShiftWithShift(I))
return R;
+ if (Instruction *R = combineAddSubWithShlAddSub(Builder, I))
+ return R;
+
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
@@ -1406,7 +1498,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B));
// -A + B --> B - A
- return BinaryOperator::CreateSub(RHS, A);
+ auto *Sub = BinaryOperator::CreateSub(RHS, A);
+ auto *OB0 = cast<OverflowingBinaryOperator>(LHS);
+ Sub->setHasNoSignedWrap(I.hasNoSignedWrap() && OB0->hasNoSignedWrap());
+
+ return Sub;
}
// A + -B --> A - B
@@ -1485,8 +1581,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
// A+B --> A|B iff A and B have no bits set in common.
- if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
- return BinaryOperator::CreateOr(LHS, RHS);
+ WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
+ if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
+ return BinaryOperator::CreateDisjointOr(LHS, RHS);
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1576,15 +1673,33 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
m_c_UMin(m_Deferred(A), m_Deferred(B))))))
return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
+ // (~X) + (~Y) --> -2 - (X + Y)
+ {
+ // To ensure we can save instructions we need to ensure that we consume both
+ // LHS/RHS (i.e they have a `not`).
+ bool ConsumesLHS, ConsumesRHS;
+ if (isFreeToInvert(LHS, LHS->hasOneUse(), ConsumesLHS) && ConsumesLHS &&
+ isFreeToInvert(RHS, RHS->hasOneUse(), ConsumesRHS) && ConsumesRHS) {
+ Value *NotLHS = getFreelyInverted(LHS, LHS->hasOneUse(), &Builder);
+ Value *NotRHS = getFreelyInverted(RHS, RHS->hasOneUse(), &Builder);
+ assert(NotLHS != nullptr && NotRHS != nullptr &&
+ "isFreeToInvert desynced with getFreelyInverted");
+ Value *LHSPlusRHS = Builder.CreateAdd(NotLHS, NotRHS);
+ return BinaryOperator::CreateSub(ConstantInt::get(RHS->getType(), -2),
+ LHSPlusRHS);
+ }
+ }
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
bool Changed = false;
- if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
+ if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoSignedWrap(true);
}
- if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
+ if (!I.hasNoUnsignedWrap() &&
+ willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
Changed = true;
I.setHasNoUnsignedWrap(true);
}
@@ -1610,11 +1725,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// ctpop(A) + ctpop(B) => ctpop(A | B) if A and B have no bits set in common.
if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(A)))) &&
match(RHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(B)))) &&
- haveNoCommonBitsSet(A, B, DL, &AC, &I, &DT))
+ haveNoCommonBitsSet(A, B, SQ.getWithInstruction(&I)))
return replaceInstUsesWith(
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateOr(A, B)}));
+ if (Instruction *Res = foldSquareSumInt(I))
+ return Res;
+
if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
return Res;
@@ -1755,10 +1873,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
// instcombined.
if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
if (IsValidPromotion(FPType, LHSIntVal->getType())) {
- Constant *CI =
- ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
+ Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
+ LHSIntVal->getType(), DL);
if (LHSConv->hasOneUse() &&
- ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
+ ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(), DL) ==
+ CFP &&
willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
// Insert the new integer add.
Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
@@ -1794,6 +1913,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
if (Instruction *F = factorizeFAddFSub(I, Builder))
return F;
+ if (Instruction *F = foldSquareSumFP(I))
+ return F;
+
// Try to fold fadd into start value of reduction intrinsic.
if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(
m_AnyZeroFP(), m_Value(X))),
@@ -2017,14 +2139,16 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// C-(X+C2) --> (C-C2)-X
if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) {
- // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW
- // => (C-C2)-X can have NSW
+ // C-C2 never overflow, and C-(X+C2), (X+C2) has NSW/NUW
+ // => (C-C2)-X can have NSW/NUW
bool WillNotSOV = willNotOverflowSignedSub(C, C2, I);
BinaryOperator *Res =
BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() &&
WillNotSOV);
+ Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap() &&
+ OBO1->hasNoUnsignedWrap());
return Res;
}
}
@@ -2058,7 +2182,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
})) {
- if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this))
+ if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
+ I.hasNoSignedWrap(),
+ Op1, *this))
return BinaryOperator::CreateAdd(NegOp1, Op0);
}
if (IsNegation)
@@ -2093,19 +2219,50 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
// ((X - Y) - Op1) --> X - (Y + Op1)
if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) {
- Value *Add = Builder.CreateAdd(Y, Op1);
- return BinaryOperator::CreateSub(X, Add);
+ OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0);
+ bool HasNUW = I.hasNoUnsignedWrap() && LHSSub->hasNoUnsignedWrap();
+ bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap();
+ Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW,
+ /* HasNSW */ HasNSW);
+ BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add);
+ Sub->setHasNoUnsignedWrap(HasNUW);
+ Sub->setHasNoSignedWrap(HasNSW);
+ return Sub;
+ }
+
+ {
+ // (X + Z) - (Y + Z) --> (X - Y)
+ // This is done in other passes, but we want to be able to consume this
+ // pattern in InstCombine so we can generate it without creating infinite
+ // loops.
+ if (match(Op0, m_Add(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Add(m_Value(Y), m_Specific(Z))))
+ return BinaryOperator::CreateSub(X, Y);
+
+ // (X + C0) - (Y + C1) --> (X - Y) + (C0 - C1)
+ Constant *CX, *CY;
+ if (match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(CX)))) &&
+ match(Op1, m_OneUse(m_Add(m_Value(Y), m_ImmConstant(CY))))) {
+ Value *OpsSub = Builder.CreateSub(X, Y);
+ Constant *ConstsSub = ConstantExpr::getSub(CX, CY);
+ return BinaryOperator::CreateAdd(OpsSub, ConstsSub);
+ }
}
// (~X) - (~Y) --> Y - X
- // This is placed after the other reassociations and explicitly excludes a
- // sub-of-sub pattern to avoid infinite looping.
- if (isFreeToInvert(Op0, Op0->hasOneUse()) &&
- isFreeToInvert(Op1, Op1->hasOneUse()) &&
- !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) {
- Value *NotOp0 = Builder.CreateNot(Op0);
- Value *NotOp1 = Builder.CreateNot(Op1);
- return BinaryOperator::CreateSub(NotOp1, NotOp0);
+ {
+ // Need to ensure we can consume at least one of the `not` instructions,
+ // otherwise this can inf loop.
+ bool ConsumesOp0, ConsumesOp1;
+ if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
+ isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
+ (ConsumesOp0 || ConsumesOp1)) {
+ Value *NotOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
+ Value *NotOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
+ assert(NotOp0 != nullptr && NotOp1 != nullptr &&
+ "isFreeToInvert desynced with getFreelyInverted");
+ return BinaryOperator::CreateSub(NotOp1, NotOp0);
+ }
}
auto m_AddRdx = [](Value *&Vec) {
@@ -2520,18 +2677,33 @@ static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
return nullptr;
}
-static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
- InstCombiner::BuilderTy &Builder) {
- Value *FNeg;
- if (!match(&I, m_FNeg(m_Value(FNeg))))
- return nullptr;
-
+Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp,
+ Instruction &FMFSource) {
Value *X, *Y;
- if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
- return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+ if (match(FNegOp, m_FMul(m_Value(X), m_Value(Y)))) {
+ return cast<Instruction>(Builder.CreateFMulFMF(
+ Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+ }
+
+ if (match(FNegOp, m_FDiv(m_Value(X), m_Value(Y)))) {
+ return cast<Instruction>(Builder.CreateFDivFMF(
+ Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+ }
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(FNegOp)) {
+ // Make sure to preserve flags and metadata on the call.
+ if (II->getIntrinsicID() == Intrinsic::ldexp) {
+ FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags();
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(FMF);
- if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
- return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+ CallInst *New = Builder.CreateCall(
+ II->getCalledFunction(),
+ {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)});
+ New->copyMetadata(*II);
+ return New;
+ }
+ }
return nullptr;
}
@@ -2553,13 +2725,13 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
return BinaryOperator::CreateFSubFMF(Y, X, &I);
- if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
- return R;
-
Value *OneUse;
if (!match(Op, m_OneUse(m_Value(OneUse))))
return nullptr;
+ if (Instruction *R = hoistFNegAboveFMulFDiv(OneUse, I))
+ return replaceInstUsesWith(I, R);
+
// Try to eliminate fneg if at least 1 arm of the select is negated.
Value *Cond;
if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
@@ -2569,8 +2741,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
S->copyFastMathFlags(&I);
if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
- FastMathFlags FMF = I.getFastMathFlags();
- FMF |= OldSel->getFastMathFlags();
+ FastMathFlags FMF = I.getFastMathFlags() | OldSel->getFastMathFlags();
S->setFastMathFlags(FMF);
if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
!isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
@@ -2638,9 +2809,6 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
- if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
- return R;
-
Value *X, *Y;
Constant *C;