Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
 llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 328 ++++++++++++-------
 1 file changed, 255 insertions(+), 73 deletions(-)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 4a459ec6c550..b68efc993723 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -576,8 +576,7 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
}
}
- assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) &&
- "out-of-bound access");
+ assert((NextTmpIdx <= std::size(TmpResult) + 1) && "out-of-bound access");
Value *Result;
if (!SimpVect.empty())
@@ -849,6 +848,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
+ Type *Ty = Add.getType();
Constant *Op1C;
if (!match(Op1, m_ImmConstant(Op1C)))
return nullptr;
@@ -883,7 +883,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
if (match(Op0, m_Not(m_Value(X))))
return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);
+ // (iN X s>> (N - 1)) + 1 --> zext (X > -1)
const APInt *C;
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ if (match(Op0, m_OneUse(m_AShr(m_Value(X),
+ m_SpecificIntAllowUndef(BitWidth - 1)))) &&
+ match(Op1, m_One()))
+ return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
+
if (!match(Op1, m_APInt(C)))
return nullptr;
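For reference, a minimal LLVM IR sketch of the pattern the new fold above matches (i8 width, the @src name, and value names are illustrative, not part of the patch):

define i8 @src(i8 %x) {
  %sign = ashr i8 %x, 7          ; 0 when %x is non-negative, -1 when negative
  %r = add i8 %sign, 1           ; 1 when non-negative, 0 when negative
  ret i8 %r
}
; Expected after instcombine:
;   %isnotneg = icmp sgt i8 %x, -1
;   %r = zext i1 %isnotneg to i8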
@@ -911,7 +918,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
// Is this add the last step in a convoluted sext?
// add(zext(xor i16 X, -32768), -32768) --> sext X
- Type *Ty = Add.getType();
if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
return CastInst::Create(Instruction::SExt, X, Ty);
@@ -969,15 +975,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
}
}
- // If all bits affected by the add are included in a high-bit-mask, do the
- // add before the mask op:
- // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00
- if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) &&
- C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) {
- Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C));
- return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2));
- }
-
return nullptr;
}
@@ -1132,6 +1129,35 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+/// Try to reduce signed division by power-of-2 to an arithmetic shift right.
+static Instruction *foldAddToAshr(BinaryOperator &Add) {
+ // Division must be by power-of-2, but not the minimum signed value.
+ Value *X;
+ const APInt *DivC;
+ if (!match(Add.getOperand(0), m_SDiv(m_Value(X), m_Power2(DivC))) ||
+ DivC->isNegative())
+ return nullptr;
+
+ // Rounding is done by adding -1 if the dividend (X) is negative and has any
+ // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
+ // sext (icmp ugt (X & (DivC - 1)), SMIN)
+ const APInt *MaskC;
+ ICmpInst::Predicate Pred;
+ if (!match(Add.getOperand(1),
+ m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
+ m_SignMask()))) ||
+ Pred != ICmpInst::ICMP_UGT)
+ return nullptr;
+
+ APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
+ if (*MaskC != (SMin | (*DivC - 1)))
+ return nullptr;
+
+ // (X / DivC) + sext ((X & (SMin | (DivC - 1))) >u SMin) --> X >>s log2(DivC)
+ return BinaryOperator::CreateAShr(
+ X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
+}
+
Instruction *InstCombinerImpl::
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
BinaryOperator &I) {
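A direct instantiation of the pattern foldAddToAshr matches, using DivC = 4 at i32 (constants and names chosen for illustration; -2147483645 is SMin | (DivC - 1)):

define i32 @src(i32 %x) {
  %div = sdiv i32 %x, 4
  %mask = and i32 %x, -2147483645        ; SMin | 3
  %cmp = icmp ugt i32 %mask, -2147483648 ; %x is negative and has low bits set
  %ext = sext i1 %cmp to i32             ; -1 when a rounding fix-up is needed
  %r = add i32 %div, %ext
  ret i32 %r
}
; Expected after instcombine: %r = ashr i32 %x, 2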
@@ -1234,7 +1260,7 @@ Instruction *InstCombinerImpl::
}
/// This is a specialization of a more general transform from
-/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// foldUsingDistributiveLaws. If that code can be made to work optimally
/// for multi-use cases or propagating nsw/nuw, then we would not need this.
static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
@@ -1270,6 +1296,45 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
return NewShl;
}
+/// Reduce a sequence of masked half-width multiplies to a single multiply.
+/// (((XLow * YHigh) + (YLow * XHigh)) << HalfBits) + (XLow * YLow) --> X * Y
+static Instruction *foldBoxMultiply(BinaryOperator &I) {
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ // Skip the odd bitwidth types.
+ if ((BitWidth & 0x1))
+ return nullptr;
+
+ unsigned HalfBits = BitWidth >> 1;
+ APInt HalfMask = APInt::getMaxValue(HalfBits);
+
+ // ResLo = (CrossSum << HalfBits) + (YLo * XLo)
+ Value *XLo, *YLo;
+ Value *CrossSum;
+ if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
+ m_Mul(m_Value(YLo), m_Value(XLo)))))
+ return nullptr;
+
+ // XLo = X & HalfMask
+ // YLo = Y & HalfMask
+ // TODO: Refactor using SimplifyDemandedBits or the known leading zeros from
+ // KnownBits to make this matching more robust.
+ Value *X, *Y;
+ if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) ||
+ !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask))))
+ return nullptr;
+
+ // CrossSum = (X' * (Y >> HalfBits)) + (Y' * (X >> HalfBits))
+ // X' can be either X or XLo in the pattern (and the same for Y')
+ if (match(CrossSum,
+ m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)),
+ m_CombineOr(m_Specific(X), m_Specific(XLo))),
+ m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)),
+ m_CombineOr(m_Specific(Y), m_Specific(YLo))))))
+ return BinaryOperator::CreateMul(X, Y);
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
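A sketch of the half-width multiply sequence foldBoxMultiply is meant to collapse, at i32 (HalfBits = 16); all names are illustrative:

define i32 @src(i32 %x, i32 %y) {
  %xl = and i32 %x, 65535
  %yl = and i32 %y, 65535
  %xh = lshr i32 %x, 16
  %yh = lshr i32 %y, 16
  %c0 = mul i32 %xl, %yh
  %c1 = mul i32 %yl, %xh
  %cross = add i32 %c0, %c1        ; CrossSum
  %hi = shl i32 %cross, 16
  %lo = mul i32 %yl, %xl
  %r = add i32 %hi, %lo
  ret i32 %r
}
; Expected after instcombine: %r = mul i32 %x, %y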
@@ -1286,9 +1351,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return Phi;
// (A*B)+(A*C) -> A*(B+C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
+ if (Instruction *R = foldBoxMultiply(I))
+ return R;
+
if (Instruction *R = factorizeMathWithShlOps(I, Builder))
return R;
@@ -1376,35 +1444,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(A, NewMask);
}
+ // ZExt (B - A) + ZExt(A) --> ZExt(B) (requires the inner sub to be 'nuw')
+ if ((match(RHS, m_ZExt(m_Value(A))) &&
+ match(LHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))) ||
+ (match(LHS, m_ZExt(m_Value(A))) &&
+ match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))))
+ return new ZExtInst(B, LHS->getType());
+
// A+B --> A|B iff A and B have no bits set in common.
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
return BinaryOperator::CreateOr(LHS, RHS);
- // add (select X 0 (sub n A)) A --> select X A n
- {
- SelectInst *SI = dyn_cast<SelectInst>(LHS);
- Value *A = RHS;
- if (!SI) {
- SI = dyn_cast<SelectInst>(RHS);
- A = LHS;
- }
- if (SI && SI->hasOneUse()) {
- Value *TV = SI->getTrueValue();
- Value *FV = SI->getFalseValue();
- Value *N;
-
- // Can we fold the add into the argument of the select?
- // We check both true and false select arguments for a matching subtract.
- if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
- // Fold the add into the true select value.
- return SelectInst::Create(SI->getCondition(), N, A);
-
- if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
- // Fold the add into the false select value.
- return SelectInst::Create(SI->getCondition(), A, N);
- }
- }
-
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
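An illustrative instance of the new ZExt fold; the 'nuw' on the inner sub is what guarantees (B - A) + A == B in the narrow type (names illustrative):

define i32 @src(i8 %a, i8 %b) {
  %d = sub nuw i8 %b, %a
  %zd = zext i8 %d to i32
  %za = zext i8 %a to i32
  %r = add i32 %zd, %za
  ret i32 %r
}
; Expected after instcombine: %r = zext i8 %b to i32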
@@ -1424,6 +1474,68 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return &I;
}
+ // (add A (or A, -A)) --> (and (add A, -1) A)
+ // (add A (or -A, A)) --> (and (add A, -1) A)
+ // (add (or A, -A) A) --> (and (add A, -1) A)
+ // (add (or -A, A) A) --> (and (add A, -1) A)
+ if (match(&I, m_c_BinOp(m_Value(A), m_OneUse(m_c_Or(m_Neg(m_Deferred(A)),
+ m_Deferred(A)))))) {
+ Value *Add =
+ Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()), "",
+ I.hasNoUnsignedWrap(), I.hasNoSignedWrap());
+ return BinaryOperator::CreateAnd(Add, A);
+ }
+
+ // Canonicalize ((A & -A) - 1) --> ((A - 1) & ~A)
+ // Forms all commutable operations, and simplifies ctpop -> cttz folds.
+ if (match(&I,
+ m_Add(m_OneUse(m_c_And(m_Value(A), m_OneUse(m_Neg(m_Deferred(A))))),
+ m_AllOnes()))) {
+ Constant *AllOnes = ConstantInt::getAllOnesValue(RHS->getType());
+ Value *Dec = Builder.CreateAdd(A, AllOnes);
+ Value *Not = Builder.CreateXor(A, AllOnes);
+ return BinaryOperator::CreateAnd(Dec, Not);
+ }
+
+ // Disguised reassociation/factorization:
+ // ~(A * C1) + A
+ // ((A * -C1) - 1) + A
+ // ((A * -C1) + A) - 1
+ // (A * (1 - C1)) - 1
+ if (match(&I,
+ m_c_Add(m_OneUse(m_Not(m_OneUse(m_Mul(m_Value(A), m_APInt(C1))))),
+ m_Deferred(A)))) {
+ Type *Ty = I.getType();
+ Constant *NewMulC = ConstantInt::get(Ty, 1 - *C1);
+ Value *NewMul = Builder.CreateMul(A, NewMulC);
+ return BinaryOperator::CreateAdd(NewMul, ConstantInt::getAllOnesValue(Ty));
+ }
+
+ // (A * -2**C) + B --> B - (A << C)
+ const APInt *NegPow2C;
+ if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))),
+ m_Value(B)))) {
+ Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros());
+ Value *Shl = Builder.CreateShl(A, ShiftAmtC);
+ return BinaryOperator::CreateSub(B, Shl);
+ }
+
+ // Canonicalize signum variant that ends in add:
+ // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0))
+ ICmpInst::Predicate Pred;
+ uint64_t BitWidth = Ty->getScalarSizeInBits();
+ if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) &&
+ match(RHS, m_OneUse(m_ZExt(
+ m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) &&
+ Pred == CmpInst::ICMP_SGT) {
+ Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull");
+ Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext");
+ return BinaryOperator::CreateOr(LHS, Zext);
+ }
+
+ if (Instruction *Ashr = foldAddToAshr(I))
+ return Ashr;
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
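Of the folds added in this hunk, the signum canonicalization is the least obvious; a sketch at i32 (names illustrative):

define i32 @src(i32 %a) {
  %sign = ashr i32 %a, 31              ; -1 for negative %a, else 0
  %ispos = icmp sgt i32 %a, 0
  %pos = zext i1 %ispos to i32         ; 1 for positive %a, else 0
  %r = add i32 %sign, %pos             ; signum(%a)
  ret i32 %r
}
; Expected after instcombine:
;   %isnotnull = icmp ne i32 %a, 0
;   %isnotnull.zext = zext i1 %isnotnull to i32
;   %r = or i32 %sign, %isnotnull.zext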
@@ -1665,6 +1777,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return BinaryOperator::CreateFMulFMF(X, NewMulC, &I);
}
+ // (-X - Y) + (X + Z) --> Z - Y
+ if (match(&I, m_c_FAdd(m_FSub(m_FNeg(m_Value(X)), m_Value(Y)),
+ m_c_FAdd(m_Deferred(X), m_Value(Z)))))
+ return BinaryOperator::CreateFSubFMF(Z, Y, &I);
+
if (Value *V = FAddCombine(Builder).simplify(&I))
return replaceInstUsesWith(I, V);
}
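A sketch of the new FAdd reassociation; 'fast' stands in for the reassoc/nsz flags assumed to guard this block (names illustrative):

define float @src(float %x, float %y, float %z) {
  %negx = fneg float %x
  %s1 = fsub fast float %negx, %y     ; -x - y
  %s2 = fadd fast float %x, %z        ;  x + z
  %r = fadd fast float %s1, %s2
  ret float %r
}
; Expected after instcombine: %r = fsub fast float %z, %y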
@@ -1879,7 +1996,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
return TryToNarrowDeduceFlags(); // Should have been handled in Negator!
// (A*B)-(A*C) -> A*(B-C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
if (I.getType()->isIntOrIntVectorTy(1))
@@ -1967,12 +2084,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
const APInt *Op0C;
- if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) {
- // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
- // zero.
- KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
- if ((*Op0C | RHSKnown.Zero).isAllOnes())
- return BinaryOperator::CreateXor(Op1, Op0);
+ if (match(Op0, m_APInt(Op0C))) {
+ if (Op0C->isMask()) {
+ // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
+ // zero.
+ KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
+ if ((*Op0C | RHSKnown.Zero).isAllOnes())
+ return BinaryOperator::CreateXor(Op1, Op0);
+ }
+
+ // C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when:
+ // (C3 - ((C2 & C3) - 1)) is pow2
+ // ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1)
+ // C2 is a negated power of 2, or the inner sub is 'nuw'
+ const APInt *C2, *C3;
+ BinaryOperator *InnerSub;
+ if (match(Op1, m_OneUse(m_And(m_BinOp(InnerSub), m_APInt(C2)))) &&
+ match(InnerSub, m_Sub(m_APInt(C3), m_Value(X))) &&
+ (InnerSub->hasNoUnsignedWrap() || C2->isNegatedPowerOf2())) {
+ APInt C2AndC3 = *C2 & *C3;
+ APInt C2AndC3Minus1 = C2AndC3 - 1;
+ APInt C2AddC3 = *C2 + *C3;
+ if ((*C3 - C2AndC3Minus1).isPowerOf2() &&
+ C2AndC3Minus1.isSubsetOf(C2AddC3)) {
+ Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), *C2));
+ return BinaryOperator::CreateAdd(
+ And, ConstantInt::get(I.getType(), *Op0C - C2AndC3));
+ }
+ }
}
{
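A worked instance of the new constant-sub fold, with constants chosen to satisfy all three conditions (C = 100, C3 = 31, C2 = -16; C2 & C3 = 16, so the result constant is 100 - 16 = 84; names illustrative):

define i8 @src(i8 %x) {
  %sub = sub i8 31, %x        ; C3 - X
  %and = and i8 %sub, -16     ; C2 is a negated power of 2, so no 'nuw' needed
  %r = sub i8 100, %and       ; C - ((C3 - X) & C2)
  ret i8 %r
}
; Expected after instcombine:
;   %masked = and i8 %x, -16
;   %r = add i8 %masked, 84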
@@ -2165,8 +2304,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
Value *A;
const APInt *ShAmt;
Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
- Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
+ Op1->hasNUses(2) && *ShAmt == BitWidth - 1 &&
match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
// B = ashr i32 A, 31 ; smear the sign bit
// sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1)
@@ -2185,7 +2325,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
const APInt *AddC, *AndC;
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
- unsigned BitWidth = Ty->getScalarSizeInBits();
unsigned Cttz = AddC->countTrailingZeros();
APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
if ((HighMask & *AndC).isZero())
@@ -2227,18 +2366,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
// C - ctpop(X) => ctpop(~X) if C is bitwidth
- if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) &&
+ if (match(Op0, m_SpecificInt(BitWidth)) &&
match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X)))))
return replaceInstUsesWith(
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateNot(X)}));
+ // Reduce multiplies for difference-of-squares by factoring:
+ // (X * X) - (Y * Y) --> (X + Y) * (X - Y)
+ if (match(Op0, m_OneUse(m_Mul(m_Value(X), m_Deferred(X)))) &&
+ match(Op1, m_OneUse(m_Mul(m_Value(Y), m_Deferred(Y))))) {
+ auto *OBO0 = cast<OverflowingBinaryOperator>(Op0);
+ auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
+ bool PropagateNSW = I.hasNoSignedWrap() && OBO0->hasNoSignedWrap() &&
+ OBO1->hasNoSignedWrap() && BitWidth > 2;
+ bool PropagateNUW = I.hasNoUnsignedWrap() && OBO0->hasNoUnsignedWrap() &&
+ OBO1->hasNoUnsignedWrap() && BitWidth > 1;
+ Value *Add = Builder.CreateAdd(X, Y, "add", PropagateNUW, PropagateNSW);
+ Value *Sub = Builder.CreateSub(X, Y, "sub", PropagateNUW, PropagateNSW);
+ Value *Mul = Builder.CreateMul(Add, Sub, "", PropagateNUW, PropagateNSW);
+ return replaceInstUsesWith(I, Mul);
+ }
+
return TryToNarrowDeduceFlags();
}
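The difference-of-squares factorization added just above, as a minimal IR example (names illustrative; both multiplies must be single-use):

define i32 @src(i32 %x, i32 %y) {
  %xx = mul i32 %x, %x
  %yy = mul i32 %y, %y
  %r = sub i32 %xx, %yy
  ret i32 %r
}
; Expected after instcombine:
;   %add = add i32 %x, %y
;   %sub = sub i32 %x, %y
;   %r = mul i32 %add, %sub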
/// This eliminates floating-point negation in either 'fneg(X)' or
/// 'fsub(-0.0, X)' form by combining into a constant operand.
-static Instruction *foldFNegIntoConstant(Instruction &I) {
+static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
// This is limited with one-use because fneg is assumed better for
// reassociation and cheaper in codegen than fmul/fdiv.
// TODO: Should the m_OneUse restriction be removed?
@@ -2252,28 +2407,31 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
// Fold negation into constant operand.
// -(X * C) --> X * (-C)
if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFMulFMF(X, NegC, &I);
// -(X / C) --> X / (-C)
if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFDivFMF(X, NegC, &I);
// -(C / X) --> (-C) / X
- if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) {
- Instruction *FDiv =
- BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
-
- // Intersect 'nsz' and 'ninf' because those special value exceptions may not
- // apply to the fdiv. Everything else propagates from the fneg.
- // TODO: We could propagate nsz/ninf from fdiv alone?
- FastMathFlags FMF = I.getFastMathFlags();
- FastMathFlags OpFMF = FNegOp->getFastMathFlags();
- FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
- FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
- return FDiv;
- }
+ if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X))))
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) {
+ Instruction *FDiv = BinaryOperator::CreateFDivFMF(NegC, X, &I);
+
+ // Intersect 'nsz' and 'ninf' because those special value exceptions may
+ // not apply to the fdiv. Everything else propagates from the fneg.
+ // TODO: We could propagate nsz/ninf from fdiv alone?
+ FastMathFlags FMF = I.getFastMathFlags();
+ FastMathFlags OpFMF = FNegOp->getFastMathFlags();
+ FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
+ FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
+ return FDiv;
+ }
// With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]:
// -(X + C) --> -X + -C --> -C - X
if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFSubFMF(NegC, X, &I);
return nullptr;
}
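The -(X * C) --> X * (-C) case of foldFNegIntoConstant, sketched in IR (constant and names illustrative; the inner fmul must be single-use):

define float @src(float %x) {
  %mul = fmul float %x, 4.000000e+00
  %r = fneg float %mul
  ret float %r
}
; Expected after instcombine: %r = fmul float %x, -4.000000e+00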
@@ -2301,7 +2459,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
getSimplifyQuery().getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- if (Instruction *X = foldFNegIntoConstant(I))
+ if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
Value *X, *Y;
@@ -2314,18 +2472,26 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
return R;
+ Value *OneUse;
+ if (!match(Op, m_OneUse(m_Value(OneUse))))
+ return nullptr;
+
// Try to eliminate fneg if at least 1 arm of the select is negated.
Value *Cond;
- if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) {
+ if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
// Unlike most transforms, this one is not safe to propagate nsz unless
- // it is present on the original select. (We are conservatively intersecting
- // the nsz flags from the select and root fneg instruction.)
+ // it is present on the original select. We union the flags from the select
+ // and fneg and then remove nsz if needed.
auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
S->copyFastMathFlags(&I);
- if (auto *OldSel = dyn_cast<SelectInst>(Op))
+ if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
+ FastMathFlags FMF = I.getFastMathFlags();
+ FMF |= OldSel->getFastMathFlags();
+ S->setFastMathFlags(FMF);
if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
!isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
S->setHasNoSignedZeros(false);
+ }
};
// -(Cond ? -P : Y) --> Cond ? P : -Y
Value *P;
@@ -2344,6 +2510,21 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
}
}
+ // fneg (copysign x, y) -> copysign x, (fneg y)
+ if (match(OneUse, m_CopySign(m_Value(X), m_Value(Y)))) {
+ // The source copysign has an additional value input, so we can't propagate
+ // flags the copysign doesn't also have.
+ FastMathFlags FMF = I.getFastMathFlags();
+ FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags();
+
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(FMF);
+
+ Value *NegY = Builder.CreateFNeg(Y);
+ Value *NewCopySign = Builder.CreateCopySign(X, NegY);
+ return replaceInstUsesWith(I, NewCopySign);
+ }
+
return nullptr;
}
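A sketch of the new copysign fold (names illustrative; the copysign call must have a single use):

define float @src(float %x, float %y) {
  %c = call float @llvm.copysign.f32(float %x, float %y)
  %r = fneg float %c
  ret float %r
}
declare float @llvm.copysign.f32(float, float)
; Expected after instcombine:
;   %negy = fneg float %y
;   %r = call float @llvm.copysign.f32(float %x, float %negy)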
@@ -2370,7 +2551,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (match(&I, m_FNeg(m_Value(Op))))
return UnaryOperator::CreateFNegFMF(Op, &I);
- if (Instruction *X = foldFNegIntoConstant(I))
+ if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
@@ -2409,7 +2590,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
// But don't transform constant expressions because there's an inverse fold
// for X + (-Y) --> X - Y.
if (match(Op1, m_ImmConstant(C)))
- return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFAddFMF(Op0, NegC, &I);
// X - (-Y) --> X + Y
if (match(Op1, m_FNeg(m_Value(Y))))