diff options
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 155 |
1 files changed, 101 insertions, 54 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 2774e46151fa..c6233a68847d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -72,7 +72,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // We know that this is an exact/nuw shift and that the input is a // non-zero context as well. if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { - I->setOperand(0, V2); + IC.replaceOperand(*I, 0, V2); MadeChange = true; } @@ -96,19 +96,22 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, /// A helper routine of InstCombiner::visitMul(). /// -/// If C is a scalar/vector of known powers of 2, then this function returns -/// a new scalar/vector obtained from logBase2 of C. +/// If C is a scalar/fixed width vector of known powers of 2, then this +/// function returns a new scalar/fixed width vector obtained from logBase2 +/// of C. /// Return a null pointer otherwise. static Constant *getLogBase2(Type *Ty, Constant *C) { const APInt *IVal; if (match(C, m_APInt(IVal)) && IVal->isPowerOf2()) return ConstantInt::get(Ty, IVal->logBase2()); - if (!Ty->isVectorTy()) + // FIXME: We can extract pow of 2 of splat constant for scalable vectors. + if (!isa<FixedVectorType>(Ty)) return nullptr; SmallVector<Constant *, 4> Elts; - for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) { + for (unsigned I = 0, E = cast<FixedVectorType>(Ty)->getNumElements(); I != E; + ++I) { Constant *Elt = C->getAggregateElement(I); if (!Elt) return nullptr; @@ -274,6 +277,15 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // abs(X) * abs(X) -> X * X + // nabs(X) * nabs(X) -> X * X + if (Op0 == Op1) { + Value *X, *Y; + SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor; + if (SPF == SPF_ABS || SPF == SPF_NABS) + return BinaryOperator::CreateMul(X, X); + } + // -X * C --> X * -C Value *X, *Y; Constant *Op1C; @@ -354,6 +366,27 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { } } + // (zext bool X) * (zext bool Y) --> zext (and X, Y) + // (sext bool X) * (sext bool Y) --> zext (and X, Y) + // Note: -1 * -1 == 1 * 1 == 1 (if the extends match, the result is the same) + if (((match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) || + (match(Op0, m_SExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *And = Builder.CreateAnd(X, Y, "mulbool"); + return CastInst::Create(Instruction::ZExt, And, I.getType()); + } + // (sext bool X) * (zext bool Y) --> sext (and X, Y) + // (zext bool X) * (sext bool Y) --> sext (and X, Y) + // Note: -1 * 1 == 1 * -1 == -1 + if (((match(Op0, m_SExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) || + (match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) && + X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() && + (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *And = Builder.CreateAnd(X, Y, "mulbool"); + return CastInst::Create(Instruction::SExt, And, I.getType()); + } + // (bool X) * Y --> X ? Y : 0 // Y * (bool X) --> X ? Y : 0 if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) @@ -390,6 +423,40 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { return Changed ? &I : nullptr; } +Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) { + BinaryOperator::BinaryOps Opcode = I.getOpcode(); + assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) && + "Expected fmul or fdiv"); + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *X, *Y; + + // -X * -Y --> X * Y + // -X / -Y --> X / Y + if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, Y, &I); + + // fabs(X) * fabs(X) -> X * X + // fabs(X) / fabs(X) -> X / X + if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) + return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I); + + // fabs(X) * fabs(Y) --> fabs(X * Y) + // fabs(X) / fabs(Y) --> fabs(X / Y) + if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + Value *XY = Builder.CreateBinOp(Opcode, X, Y); + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY); + Fabs->takeName(&I); + return replaceInstUsesWith(I, Fabs); + } + + return nullptr; +} + Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -408,25 +475,20 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) return replaceInstUsesWith(I, FoldedMul); + if (Instruction *R = foldFPSignBitOps(I)) + return R; + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) - return BinaryOperator::CreateFNegFMF(Op0, &I); - - // -X * -Y --> X * Y - Value *X, *Y; - if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) - return BinaryOperator::CreateFMulFMF(X, Y, &I); + return UnaryOperator::CreateFNegFMF(Op0, &I); // -X * C --> X * -C + Value *X, *Y; Constant *C; if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - // fabs(X) * fabs(X) -> X * X - if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) - return BinaryOperator::CreateFMulFMF(X, X, &I); - // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); @@ -563,8 +625,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Y = Op0; } if (Log2) { - Log2->setArgOperand(0, X); - Log2->copyFastMathFlags(&I); + Value *Log2 = Builder.CreateUnaryIntrinsic(Intrinsic::log2, X, &I); Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I); return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I); } @@ -592,7 +653,7 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { return false; // Change the div/rem to use 'Y' instead of the select. - I.setOperand(1, SI->getOperand(NonNullOperand)); + replaceOperand(I, 1, SI->getOperand(NonNullOperand)); // Okay, we know we replace the operand of the div/rem with 'Y' with no // problem. However, the select, or the condition of the select may have @@ -620,12 +681,12 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end(); I != E; ++I) { if (*I == SI) { - *I = SI->getOperand(NonNullOperand); - Worklist.Add(&*BBI); + replaceUse(*I, SI->getOperand(NonNullOperand)); + Worklist.push(&*BBI); } else if (*I == SelectCond) { - *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) - : ConstantInt::getFalse(CondTy); - Worklist.Add(&*BBI); + replaceUse(*I, NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) + : ConstantInt::getFalse(CondTy)); + Worklist.push(&*BBI); } } @@ -683,10 +744,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { Type *Ty = I.getType(); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { - I.setOperand(1, V); - return &I; - } + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) + return replaceOperand(I, 1, V); // Handle cases involving: [su]div X, (select Cond, Y, Z) // This does not apply for fdiv. @@ -800,8 +859,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap(); bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap(); if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) { - I.setOperand(0, ConstantInt::get(Ty, 1)); - I.setOperand(1, Y); + replaceOperand(I, 0, ConstantInt::get(Ty, 1)); + replaceOperand(I, 1, Y); return &I; } } @@ -1214,6 +1273,9 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Instruction *R = foldFDivConstantDividend(I)) return R; + if (Instruction *R = foldFPSignBitOps(I)) + return R; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1274,21 +1336,14 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } - // -X / -Y -> X / Y - Value *X, *Y; - if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) { - I.setOperand(0, X); - I.setOperand(1, Y); - return &I; - } - // X / (X * Y) --> 1.0 / Y // Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed. // We can ignore the possibility that X is infinity because INF/INF is NaN. + Value *X, *Y; if (I.hasNoNaNs() && I.hasAllowReassoc() && match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) { - I.setOperand(0, ConstantFP::get(I.getType(), 1.0)); - I.setOperand(1, Y); + replaceOperand(I, 0, ConstantFP::get(I.getType(), 1.0)); + replaceOperand(I, 1, Y); return &I; } @@ -1314,10 +1369,8 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // The RHS is known non-zero. - if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) { - I.setOperand(1, V); - return &I; - } + if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) + return replaceOperand(I, 1, V); // Handle cases involving: rem X, (select Cond, Y, Z) if (simplifyDivRemOfSelectWithZeroOp(I)) @@ -1417,11 +1470,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { { const APInt *Y; // X % -Y -> X % Y - if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) { - Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, ConstantInt::get(I.getType(), -*Y)); - return &I; - } + if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) + return replaceOperand(I, 1, ConstantInt::get(I.getType(), -*Y)); } // -X srem Y --> -(X srem Y) @@ -1441,7 +1491,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { // If it's a constant vector, flip any negative values positive. if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) { Constant *C = cast<Constant>(Op1); - unsigned VWidth = C->getType()->getVectorNumElements(); + unsigned VWidth = cast<VectorType>(C->getType())->getNumElements(); bool hasNegative = false; bool hasMissing = false; @@ -1468,11 +1518,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { } Constant *NewRHSV = ConstantVector::get(Elts); - if (NewRHSV != C) { // Don't loop on -MININT - Worklist.AddValue(I.getOperand(1)); - I.setOperand(1, NewRHSV); - return &I; - } + if (NewRHSV != C) // Don't loop on -MININT + return replaceOperand(I, 1, NewRHSV); } } |