Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 328
1 file changed, 255 insertions, 73 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 4a459ec6c550..b68efc993723 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -576,8 +576,7 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
     }
   }

-  assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) &&
-         "out-of-bound access");
+  assert((NextTmpIdx <= std::size(TmpResult) + 1) && "out-of-bound access");

   Value *Result;
   if (!SimpVect.empty())
@@ -849,6 +848,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,

 Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
+  Type *Ty = Add.getType();
   Constant *Op1C;
   if (!match(Op1, m_ImmConstant(Op1C)))
     return nullptr;
@@ -883,7 +883,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   if (match(Op0, m_Not(m_Value(X))))
     return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);

+  // (iN X s>> (N - 1)) + 1 --> zext (X > -1)
   const APInt *C;
+  unsigned BitWidth = Ty->getScalarSizeInBits();
+  if (match(Op0, m_OneUse(m_AShr(m_Value(X),
+                                 m_SpecificIntAllowUndef(BitWidth - 1)))) &&
+      match(Op1, m_One()))
+    return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
+
   if (!match(Op1, m_APInt(C)))
     return nullptr;

@@ -911,7 +918,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {

   // Is this add the last step in a convoluted sext?
   // add(zext(xor i16 X, -32768), -32768) --> sext X
-  Type *Ty = Add.getType();
   if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
       C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
     return CastInst::Create(Instruction::SExt, X, Ty);
@@ -969,15 +975,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
     }
   }

-  // If all bits affected by the add are included in a high-bit-mask, do the
-  // add before the mask op:
-  // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00
-  if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) &&
-      C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) {
-    Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C));
-    return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2));
-  }
-
   return nullptr;
 }

@@ -1132,6 +1129,35 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
   return nullptr;
 }

+/// Try to reduce signed division by power-of-2 to an arithmetic shift right.
+static Instruction *foldAddToAshr(BinaryOperator &Add) {
+  // Division must be by power-of-2, but not the minimum signed value.
+  Value *X;
+  const APInt *DivC;
+  if (!match(Add.getOperand(0), m_SDiv(m_Value(X), m_Power2(DivC))) ||
+      DivC->isNegative())
+    return nullptr;
+
+  // Rounding is done by adding -1 if the dividend (X) is negative and has any
+  // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
+  // sext (icmp ugt (X & (DivC - 1)), SMIN)
+  const APInt *MaskC;
+  ICmpInst::Predicate Pred;
+  if (!match(Add.getOperand(1),
+             m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
+                           m_SignMask()))) ||
+      Pred != ICmpInst::ICMP_UGT)
+    return nullptr;
+
+  APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
+  if (*MaskC != (SMin | (*DivC - 1)))
+    return nullptr;
+
+  // (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC)
+  return BinaryOperator::CreateAShr(
+      X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
+}
+
 Instruction *InstCombinerImpl::
     canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
         BinaryOperator &I) {
@@ -1234,7 +1260,7 @@ Instruction *InstCombinerImpl::
 }

 /// This is a specialization of a more general transform from
-/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// foldUsingDistributiveLaws. If that code can be made to work optimally
 /// for multi-use cases or propagating nsw/nuw, then we would not need this.
 static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
                                             InstCombiner::BuilderTy &Builder) {
@@ -1270,6 +1296,45 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
   return NewShl;
 }

+/// Reduce a sequence of masked half-width multiplies to a single multiply.
+/// ((XLow * YHigh) + (YLow * XHigh)) << HalfBits) + (XLow * YLow) --> X * Y
+static Instruction *foldBoxMultiply(BinaryOperator &I) {
+  unsigned BitWidth = I.getType()->getScalarSizeInBits();
+  // Skip the odd bitwidth types.
+  if ((BitWidth & 0x1))
+    return nullptr;
+
+  unsigned HalfBits = BitWidth >> 1;
+  APInt HalfMask = APInt::getMaxValue(HalfBits);
+
+  // ResLo = (CrossSum << HalfBits) + (YLo * XLo)
+  Value *XLo, *YLo;
+  Value *CrossSum;
+  if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
+                         m_Mul(m_Value(YLo), m_Value(XLo)))))
+    return nullptr;
+
+  // XLo = X & HalfMask
+  // YLo = Y & HalfMask
+  // TODO: Refactor with SimplifyDemandedBits or KnownBits known leading zeros
+  // to enhance robustness
+  Value *X, *Y;
+  if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) ||
+      !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask))))
+    return nullptr;
+
+  // CrossSum = (X' * (Y >> Halfbits)) + (Y' * (X >> HalfBits))
+  // X' can be either X or XLo in the pattern (and the same for Y')
+  if (match(CrossSum,
+            m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)),
+                            m_CombineOr(m_Specific(X), m_Specific(XLo))),
+                    m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)),
+                            m_CombineOr(m_Specific(Y), m_Specific(YLo))))))
+    return BinaryOperator::CreateMul(X, Y);
+
+  return nullptr;
+}
+
 Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1286,9 +1351,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return Phi;

   // (A*B)+(A*C) -> A*(B+C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);

+  if (Instruction *R = foldBoxMultiply(I))
+    return R;
+
   if (Instruction *R = factorizeMathWithShlOps(I, Builder))
     return R;

@@ -1376,35 +1444,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return BinaryOperator::CreateAnd(A, NewMask);
   }

+  // ZExt (B - A) + ZExt(A) --> ZExt(B)
+  if ((match(RHS, m_ZExt(m_Value(A))) &&
+       match(LHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))) ||
+      (match(LHS, m_ZExt(m_Value(A))) &&
+       match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))))
+    return new ZExtInst(B, LHS->getType());
+
   // A+B --> A|B iff A and B have no bits set in common.
   if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
     return BinaryOperator::CreateOr(LHS, RHS);

-  // add (select X 0 (sub n A)) A --> select X A n
-  {
-    SelectInst *SI = dyn_cast<SelectInst>(LHS);
-    Value *A = RHS;
-    if (!SI) {
-      SI = dyn_cast<SelectInst>(RHS);
-      A = LHS;
-    }
-    if (SI && SI->hasOneUse()) {
-      Value *TV = SI->getTrueValue();
-      Value *FV = SI->getFalseValue();
-      Value *N;
-
-      // Can we fold the add into the argument of the select?
-      // We check both true and false select arguments for a matching subtract.
-      if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
-        // Fold the add into the true select value.
-        return SelectInst::Create(SI->getCondition(), N, A);
-
-      if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
-        // Fold the add into the false select value.
-        return SelectInst::Create(SI->getCondition(), A, N);
-    }
-  }
-
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;

@@ -1424,6 +1474,68 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return &I;
   }

+  // (add A (or A, -A)) --> (and (add A, -1) A)
+  // (add A (or -A, A)) --> (and (add A, -1) A)
+  // (add (or A, -A) A) --> (and (add A, -1) A)
+  // (add (or -A, A) A) --> (and (add A, -1) A)
+  if (match(&I, m_c_BinOp(m_Value(A), m_OneUse(m_c_Or(m_Neg(m_Deferred(A)),
+                                                      m_Deferred(A)))))) {
+    Value *Add =
+        Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()), "",
+                          I.hasNoUnsignedWrap(), I.hasNoSignedWrap());
+    return BinaryOperator::CreateAnd(Add, A);
+  }
+
+  // Canonicalize ((A & -A) - 1) --> ((A - 1) & ~A)
+  // Forms all commutable operations, and simplifies ctpop -> cttz folds.
+  if (match(&I,
+            m_Add(m_OneUse(m_c_And(m_Value(A), m_OneUse(m_Neg(m_Deferred(A))))),
+                  m_AllOnes()))) {
+    Constant *AllOnes = ConstantInt::getAllOnesValue(RHS->getType());
+    Value *Dec = Builder.CreateAdd(A, AllOnes);
+    Value *Not = Builder.CreateXor(A, AllOnes);
+    return BinaryOperator::CreateAnd(Dec, Not);
+  }
+
+  // Disguised reassociation/factorization:
+  // ~(A * C1) + A
+  // ((A * -C1) - 1) + A
+  // ((A * -C1) + A) - 1
+  // (A * (1 - C1)) - 1
+  if (match(&I,
+            m_c_Add(m_OneUse(m_Not(m_OneUse(m_Mul(m_Value(A), m_APInt(C1))))),
+                    m_Deferred(A)))) {
+    Type *Ty = I.getType();
+    Constant *NewMulC = ConstantInt::get(Ty, 1 - *C1);
+    Value *NewMul = Builder.CreateMul(A, NewMulC);
+    return BinaryOperator::CreateAdd(NewMul, ConstantInt::getAllOnesValue(Ty));
+  }
+
+  // (A * -2**C) + B --> B - (A << C)
+  const APInt *NegPow2C;
+  if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))),
+                        m_Value(B)))) {
+    Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros());
+    Value *Shl = Builder.CreateShl(A, ShiftAmtC);
+    return BinaryOperator::CreateSub(B, Shl);
+  }
+
+  // Canonicalize signum variant that ends in add:
+  // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0))
+  ICmpInst::Predicate Pred;
+  uint64_t BitWidth = Ty->getScalarSizeInBits();
+  if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) &&
+      match(RHS, m_OneUse(m_ZExt(
+                     m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) &&
+      Pred == CmpInst::ICMP_SGT) {
+    Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull");
+    Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext");
+    return BinaryOperator::CreateOr(LHS, Zext);
+  }
+
+  if (Instruction *Ashr = foldAddToAshr(I))
+    return Ashr;
+
   // TODO(jingyue): Consider willNotOverflowSignedAdd and
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
@@ -1665,6 +1777,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
         return BinaryOperator::CreateFMulFMF(X, NewMulC, &I);
     }

+    // (-X - Y) + (X + Z) --> Z - Y
+    if (match(&I, m_c_FAdd(m_FSub(m_FNeg(m_Value(X)), m_Value(Y)),
+                           m_c_FAdd(m_Deferred(X), m_Value(Z)))))
+      return BinaryOperator::CreateFSubFMF(Z, Y, &I);
+
     if (Value *V = FAddCombine(Builder).simplify(&I))
       return replaceInstUsesWith(I, V);
   }
@@ -1879,7 +1996,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
     return TryToNarrowDeduceFlags(); // Should have been handled in Negator!

   // (A*B)-(A*C) -> A*(B-C) etc
-  if (Value *V = SimplifyUsingDistributiveLaws(I))
+  if (Value *V = foldUsingDistributiveLaws(I))
     return replaceInstUsesWith(I, V);

   if (I.getType()->isIntOrIntVectorTy(1))
@@ -1967,12 +2084,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   }

   const APInt *Op0C;
-  if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) {
-    // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
-    // zero.
-    KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
-    if ((*Op0C | RHSKnown.Zero).isAllOnes())
-      return BinaryOperator::CreateXor(Op1, Op0);
+  if (match(Op0, m_APInt(Op0C))) {
+    if (Op0C->isMask()) {
+      // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
+      // zero.
+      KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
+      if ((*Op0C | RHSKnown.Zero).isAllOnes())
+        return BinaryOperator::CreateXor(Op1, Op0);
+    }
+
+    // C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when:
+    // (C3 - ((C2 & C3) - 1)) is pow2
+    // ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1)
+    // C2 is negative pow2 || sub nuw
+    const APInt *C2, *C3;
+    BinaryOperator *InnerSub;
+    if (match(Op1, m_OneUse(m_And(m_BinOp(InnerSub), m_APInt(C2)))) &&
+        match(InnerSub, m_Sub(m_APInt(C3), m_Value(X))) &&
+        (InnerSub->hasNoUnsignedWrap() || C2->isNegatedPowerOf2())) {
+      APInt C2AndC3 = *C2 & *C3;
+      APInt C2AndC3Minus1 = C2AndC3 - 1;
+      APInt C2AddC3 = *C2 + *C3;
+      if ((*C3 - C2AndC3Minus1).isPowerOf2() &&
+          C2AndC3Minus1.isSubsetOf(C2AddC3)) {
+        Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), *C2));
+        return BinaryOperator::CreateAdd(
+            And, ConstantInt::get(I.getType(), *Op0C - C2AndC3));
+      }
+    }
   }

   {
@@ -2165,8 +2304,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   Value *A;
   const APInt *ShAmt;
   Type *Ty = I.getType();
+  unsigned BitWidth = Ty->getScalarSizeInBits();
   if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
-      Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
+      Op1->hasNUses(2) && *ShAmt == BitWidth - 1 &&
       match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
     // B = ashr i32 A, 31 ; smear the sign bit
     // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1)
@@ -2185,7 +2325,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   const APInt *AddC, *AndC;
   if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
       match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
-    unsigned BitWidth = Ty->getScalarSizeInBits();
     unsigned Cttz = AddC->countTrailingZeros();
     APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
     if ((HighMask & *AndC).isZero())
@@ -2227,18 +2366,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
   }

   // C - ctpop(X) => ctpop(~X) if C is bitwidth
-  if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) &&
+  if (match(Op0, m_SpecificInt(BitWidth)) &&
       match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X)))))
     return replaceInstUsesWith(
         I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
                                    {Builder.CreateNot(X)}));

+  // Reduce multiplies for difference-of-squares by factoring:
+  // (X * X) - (Y * Y) --> (X + Y) * (X - Y)
+  if (match(Op0, m_OneUse(m_Mul(m_Value(X), m_Deferred(X)))) &&
+      match(Op1, m_OneUse(m_Mul(m_Value(Y), m_Deferred(Y))))) {
+    auto *OBO0 = cast<OverflowingBinaryOperator>(Op0);
+    auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
+    bool PropagateNSW = I.hasNoSignedWrap() && OBO0->hasNoSignedWrap() &&
+                        OBO1->hasNoSignedWrap() && BitWidth > 2;
+    bool PropagateNUW = I.hasNoUnsignedWrap() && OBO0->hasNoUnsignedWrap() &&
+                        OBO1->hasNoUnsignedWrap() && BitWidth > 1;
+    Value *Add = Builder.CreateAdd(X, Y, "add", PropagateNUW, PropagateNSW);
+    Value *Sub = Builder.CreateSub(X, Y, "sub", PropagateNUW, PropagateNSW);
+    Value *Mul = Builder.CreateMul(Add, Sub, "", PropagateNUW, PropagateNSW);
+    return replaceInstUsesWith(I, Mul);
+  }
+
   return TryToNarrowDeduceFlags();
 }

 /// This eliminates floating-point negation in either 'fneg(X)' or
 /// 'fsub(-0.0, X)' form by combining into a constant operand.
-static Instruction *foldFNegIntoConstant(Instruction &I) {
+static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
   // This is limited with one-use because fneg is assumed better for
   // reassociation and cheaper in codegen than fmul/fdiv.
   // TODO: Should the m_OneUse restriction be removed?
@@ -2252,28 +2407,31 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
   // Fold negation into constant operand.
   // -(X * C) --> X * (-C)
   if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C))))
-    return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
+    if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+      return BinaryOperator::CreateFMulFMF(X, NegC, &I);

   // -(X / C) --> X / (-C)
   if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C))))
-    return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
+    if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+      return BinaryOperator::CreateFDivFMF(X, NegC, &I);

   // -(C / X) --> (-C) / X
-  if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) {
-    Instruction *FDiv =
-        BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
-
-    // Intersect 'nsz' and 'ninf' because those special value exceptions may not
-    // apply to the fdiv. Everything else propagates from the fneg.
-    // TODO: We could propagate nsz/ninf from fdiv alone?
-    FastMathFlags FMF = I.getFastMathFlags();
-    FastMathFlags OpFMF = FNegOp->getFastMathFlags();
-    FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
-    FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
-    return FDiv;
-  }
+  if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X))))
+    if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) {
+      Instruction *FDiv = BinaryOperator::CreateFDivFMF(NegC, X, &I);
+
+      // Intersect 'nsz' and 'ninf' because those special value exceptions may
+      // not apply to the fdiv. Everything else propagates from the fneg.
+      // TODO: We could propagate nsz/ninf from fdiv alone?
+      FastMathFlags FMF = I.getFastMathFlags();
+      FastMathFlags OpFMF = FNegOp->getFastMathFlags();
+      FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
+      FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
+      return FDiv;
+    }

   // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]:
   // -(X + C) --> -X + -C --> -C - X
   if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C))))
-    return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I);
+    if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+      return BinaryOperator::CreateFSubFMF(NegC, X, &I);

   return nullptr;
 }
@@ -2301,7 +2459,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
                                   getSimplifyQuery().getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);

-  if (Instruction *X = foldFNegIntoConstant(I))
+  if (Instruction *X = foldFNegIntoConstant(I, DL))
     return X;

   Value *X, *Y;
@@ -2314,18 +2472,26 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
   if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
     return R;

+  Value *OneUse;
+  if (!match(Op, m_OneUse(m_Value(OneUse))))
+    return nullptr;
+
   // Try to eliminate fneg if at least 1 arm of the select is negated.
   Value *Cond;
-  if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) {
+  if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
     // Unlike most transforms, this one is not safe to propagate nsz unless
-    // it is present on the original select. (We are conservatively intersecting
-    // the nsz flags from the select and root fneg instruction.)
+    // it is present on the original select. We union the flags from the select
+    // and fneg and then remove nsz if needed.
     auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
       S->copyFastMathFlags(&I);
-      if (auto *OldSel = dyn_cast<SelectInst>(Op))
+      if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
+        FastMathFlags FMF = I.getFastMathFlags();
+        FMF |= OldSel->getFastMathFlags();
+        S->setFastMathFlags(FMF);
         if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
             !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
           S->setHasNoSignedZeros(false);
+      }
     };
     // -(Cond ? -P : Y) --> Cond ? P : -Y
     Value *P;
@@ -2344,6 +2510,21 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
     }
   }

+  // fneg (copysign x, y) -> copysign x, (fneg y)
+  if (match(OneUse, m_CopySign(m_Value(X), m_Value(Y)))) {
+    // The source copysign has an additional value input, so we can't propagate
+    // flags the copysign doesn't also have.
+    FastMathFlags FMF = I.getFastMathFlags();
+    FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags();
+
+    IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+    Builder.setFastMathFlags(FMF);
+
+    Value *NegY = Builder.CreateFNeg(Y);
+    Value *NewCopySign = Builder.CreateCopySign(X, NegY);
+    return replaceInstUsesWith(I, NewCopySign);
+  }
+
   return nullptr;
 }
@@ -2370,7 +2551,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
   if (match(&I, m_FNeg(m_Value(Op))))
     return UnaryOperator::CreateFNegFMF(Op, &I);

-  if (Instruction *X = foldFNegIntoConstant(I))
+  if (Instruction *X = foldFNegIntoConstant(I, DL))
     return X;

   if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
     return R;
@@ -2409,7 +2590,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
   // But don't transform constant expressions because there's an inverse fold
   // for X + (-Y) --> X - Y.
   if (match(Op1, m_ImmConstant(C)))
-    return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);
+    if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+      return BinaryOperator::CreateFAddFMF(Op0, NegC, &I);

   // X - (-Y) --> X + Y
   if (match(Op1, m_FNeg(m_Value(Y))))
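The new foldBoxMultiply and difference-of-squares folds rest on plain modular-arithmetic identities. The following standalone C++ snippet (illustrative only, not part of the patch; the test values and variable names are made up) spot-checks both at 32 bits, where unsigned arithmetic wraps modulo 2^32 just as IR add/mul do without nuw/nsw:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t HalfBits = 16, HalfMask = 0xFFFFu;
  uint32_t Tests[][2] = {{0x12345678u, 0x9ABCDEF0u}, {0xFFFFFFFFu, 0x7u}, {3u, 5u}};
  for (auto &T : Tests) {
    uint32_t X = T[0], Y = T[1];
    // Box multiply: the masked half-width products reassemble to the low half of X * Y.
    uint32_t XLo = X & HalfMask, YLo = Y & HalfMask;
    uint32_t CrossSum = XLo * (Y >> HalfBits) + YLo * (X >> HalfBits);
    assert((CrossSum << HalfBits) + XLo * YLo == X * Y);
    // Difference of squares: (X * X) - (Y * Y) == (X + Y) * (X - Y), even with wrapping.
    assert(X * X - Y * Y == (X + Y) * (X - Y));
  }
  return 0;
}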
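Likewise, foldAddToAshr and the (iN X s>> (N - 1)) + 1 --> zext (X > -1) fold encode small rounding identities. This sketch (again illustrative, not from the patch; it assumes arithmetic right shift of negative int32_t, guaranteed since C++20 and the common behavior elsewhere) checks that sdiv-by-power-of-2 plus the sign/remainder correction equals an ashr:

#include <cassert>
#include <cstdint>

int main() {
  const int32_t Vals[] = {-100, -8, -7, -1, 0, 1, 7, 100, INT32_MIN + 1, INT32_MAX};
  for (int32_t X : Vals) {
    for (unsigned Log2 : {1u, 2u, 5u}) {
      int32_t DivC = int32_t(1) << Log2;
      // sdiv rounds toward zero; ashr rounds toward negative infinity. The
      // correction is -1 exactly when X is negative and has remainder bits set,
      // which the patch matches as (X & (SMin | (DivC - 1))) u> SMin.
      uint32_t Mask = 0x80000000u | uint32_t(DivC - 1);
      int32_t Correction = (uint32_t(X) & Mask) > 0x80000000u ? -1 : 0;
      assert((X >> Log2) == X / DivC + Correction);
    }
    // (X s>> 31) + 1 is 1 for non-negative X and 0 for negative X, i.e. zext (X > -1).
    assert((X >> 31) + 1 == (X > -1 ? 1 : 0));
  }
  return 0;
}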