aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp155
1 files changed, 101 insertions, 54 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 2774e46151fa..c6233a68847d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -72,7 +72,7 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
// We know that this is an exact/nuw shift and that the input is a
// non-zero context as well.
if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
- I->setOperand(0, V2);
+ IC.replaceOperand(*I, 0, V2);
MadeChange = true;
}
@@ -96,19 +96,22 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
/// A helper routine of InstCombiner::visitMul().
///
-/// If C is a scalar/vector of known powers of 2, then this function returns
-/// a new scalar/vector obtained from logBase2 of C.
+/// If C is a scalar/fixed width vector of known powers of 2, then this
+/// function returns a new scalar/fixed width vector obtained from logBase2
+/// of C.
/// Return a null pointer otherwise.
static Constant *getLogBase2(Type *Ty, Constant *C) {
const APInt *IVal;
if (match(C, m_APInt(IVal)) && IVal->isPowerOf2())
return ConstantInt::get(Ty, IVal->logBase2());
- if (!Ty->isVectorTy())
+ // FIXME: We can extract pow of 2 of splat constant for scalable vectors.
+ if (!isa<FixedVectorType>(Ty))
return nullptr;
SmallVector<Constant *, 4> Elts;
- for (unsigned I = 0, E = Ty->getVectorNumElements(); I != E; ++I) {
+ for (unsigned I = 0, E = cast<FixedVectorType>(Ty)->getNumElements(); I != E;
+ ++I) {
Constant *Elt = C->getAggregateElement(I);
if (!Elt)
return nullptr;
@@ -274,6 +277,15 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
+ // abs(X) * abs(X) -> X * X
+ // nabs(X) * nabs(X) -> X * X
+ if (Op0 == Op1) {
+ Value *X, *Y;
+ SelectPatternFlavor SPF = matchSelectPattern(Op0, X, Y).Flavor;
+ if (SPF == SPF_ABS || SPF == SPF_NABS)
+ return BinaryOperator::CreateMul(X, X);
+ }
+
// -X * C --> X * -C
Value *X, *Y;
Constant *Op1C;
@@ -354,6 +366,27 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
}
}
+ // (zext bool X) * (zext bool Y) --> zext (and X, Y)
+ // (sext bool X) * (sext bool Y) --> zext (and X, Y)
+ // Note: -1 * -1 == 1 * 1 == 1 (if the extends match, the result is the same)
+ if (((match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) ||
+ (match(Op0, m_SExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) &&
+ X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ Value *And = Builder.CreateAnd(X, Y, "mulbool");
+ return CastInst::Create(Instruction::ZExt, And, I.getType());
+ }
+ // (sext bool X) * (zext bool Y) --> sext (and X, Y)
+ // (zext bool X) * (sext bool Y) --> sext (and X, Y)
+ // Note: -1 * 1 == 1 * -1 == -1
+ if (((match(Op0, m_SExt(m_Value(X))) && match(Op1, m_ZExt(m_Value(Y)))) ||
+ (match(Op0, m_ZExt(m_Value(X))) && match(Op1, m_SExt(m_Value(Y))))) &&
+ X->getType()->isIntOrIntVectorTy(1) && X->getType() == Y->getType() &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ Value *And = Builder.CreateAnd(X, Y, "mulbool");
+ return CastInst::Create(Instruction::SExt, And, I.getType());
+ }
+
// (bool X) * Y --> X ? Y : 0
// Y * (bool X) --> X ? Y : 0
if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1))
@@ -390,6 +423,40 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
return Changed ? &I : nullptr;
}
+Instruction *InstCombiner::foldFPSignBitOps(BinaryOperator &I) {
+ BinaryOperator::BinaryOps Opcode = I.getOpcode();
+ assert((Opcode == Instruction::FMul || Opcode == Instruction::FDiv) &&
+ "Expected fmul or fdiv");
+
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Value *X, *Y;
+
+ // -X * -Y --> X * Y
+ // -X / -Y --> X / Y
+ if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
+ return BinaryOperator::CreateWithCopiedFlags(Opcode, X, Y, &I);
+
+ // fabs(X) * fabs(X) -> X * X
+ // fabs(X) / fabs(X) -> X / X
+ if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
+ return BinaryOperator::CreateWithCopiedFlags(Opcode, X, X, &I);
+
+ // fabs(X) * fabs(Y) --> fabs(X * Y)
+ // fabs(X) / fabs(Y) --> fabs(X / Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))) &&
+ match(Op1, m_Intrinsic<Intrinsic::fabs>(m_Value(Y))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(I.getFastMathFlags());
+ Value *XY = Builder.CreateBinOp(Opcode, X, Y);
+ Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, XY);
+ Fabs->takeName(&I);
+ return replaceInstUsesWith(I, Fabs);
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (Value *V = SimplifyFMulInst(I.getOperand(0), I.getOperand(1),
I.getFastMathFlags(),
@@ -408,25 +475,20 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (Value *FoldedMul = foldMulSelectToNegate(I, Builder))
return replaceInstUsesWith(I, FoldedMul);
+ if (Instruction *R = foldFPSignBitOps(I))
+ return R;
+
// X * -1.0 --> -X
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (match(Op1, m_SpecificFP(-1.0)))
- return BinaryOperator::CreateFNegFMF(Op0, &I);
-
- // -X * -Y --> X * Y
- Value *X, *Y;
- if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y))))
- return BinaryOperator::CreateFMulFMF(X, Y, &I);
+ return UnaryOperator::CreateFNegFMF(Op0, &I);
// -X * C --> X * -C
+ Value *X, *Y;
Constant *C;
if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C)))
return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
- // fabs(X) * fabs(X) -> X * X
- if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X))))
- return BinaryOperator::CreateFMulFMF(X, X, &I);
-
// (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);
@@ -563,8 +625,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
Y = Op0;
}
if (Log2) {
- Log2->setArgOperand(0, X);
- Log2->copyFastMathFlags(&I);
+ Value *Log2 = Builder.CreateUnaryIntrinsic(Intrinsic::log2, X, &I);
Value *LogXTimesY = Builder.CreateFMulFMF(Log2, Y, &I);
return BinaryOperator::CreateFSubFMF(LogXTimesY, Y, &I);
}
@@ -592,7 +653,7 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
return false;
// Change the div/rem to use 'Y' instead of the select.
- I.setOperand(1, SI->getOperand(NonNullOperand));
+ replaceOperand(I, 1, SI->getOperand(NonNullOperand));
// Okay, we know we replace the operand of the div/rem with 'Y' with no
// problem. However, the select, or the condition of the select may have
@@ -620,12 +681,12 @@ bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) {
for (Instruction::op_iterator I = BBI->op_begin(), E = BBI->op_end();
I != E; ++I) {
if (*I == SI) {
- *I = SI->getOperand(NonNullOperand);
- Worklist.Add(&*BBI);
+ replaceUse(*I, SI->getOperand(NonNullOperand));
+ Worklist.push(&*BBI);
} else if (*I == SelectCond) {
- *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy)
- : ConstantInt::getFalse(CondTy);
- Worklist.Add(&*BBI);
+ replaceUse(*I, NonNullOperand == 1 ? ConstantInt::getTrue(CondTy)
+ : ConstantInt::getFalse(CondTy));
+ Worklist.push(&*BBI);
}
}
@@ -683,10 +744,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
Type *Ty = I.getType();
// The RHS is known non-zero.
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
- I.setOperand(1, V);
- return &I;
- }
+ if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
+ return replaceOperand(I, 1, V);
// Handle cases involving: [su]div X, (select Cond, Y, Z)
// This does not apply for fdiv.
@@ -800,8 +859,8 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) {
bool HasNSW = cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap();
bool HasNUW = cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap();
if ((IsSigned && HasNSW) || (!IsSigned && HasNUW)) {
- I.setOperand(0, ConstantInt::get(Ty, 1));
- I.setOperand(1, Y);
+ replaceOperand(I, 0, ConstantInt::get(Ty, 1));
+ replaceOperand(I, 1, Y);
return &I;
}
}
@@ -1214,6 +1273,9 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
if (Instruction *R = foldFDivConstantDividend(I))
return R;
+ if (Instruction *R = foldFPSignBitOps(I))
+ return R;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (isa<Constant>(Op0))
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
@@ -1274,21 +1336,14 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
}
}
- // -X / -Y -> X / Y
- Value *X, *Y;
- if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_FNeg(m_Value(Y)))) {
- I.setOperand(0, X);
- I.setOperand(1, Y);
- return &I;
- }
-
// X / (X * Y) --> 1.0 / Y
// Reassociate to (X / X -> 1.0) is legal when NaNs are not allowed.
// We can ignore the possibility that X is infinity because INF/INF is NaN.
+ Value *X, *Y;
if (I.hasNoNaNs() && I.hasAllowReassoc() &&
match(Op1, m_c_FMul(m_Specific(Op0), m_Value(Y)))) {
- I.setOperand(0, ConstantFP::get(I.getType(), 1.0));
- I.setOperand(1, Y);
+ replaceOperand(I, 0, ConstantFP::get(I.getType(), 1.0));
+ replaceOperand(I, 1, Y);
return &I;
}
@@ -1314,10 +1369,8 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
// The RHS is known non-zero.
- if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I)) {
- I.setOperand(1, V);
- return &I;
- }
+ if (Value *V = simplifyValueKnownNonZero(I.getOperand(1), *this, I))
+ return replaceOperand(I, 1, V);
// Handle cases involving: rem X, (select Cond, Y, Z)
if (simplifyDivRemOfSelectWithZeroOp(I))
@@ -1417,11 +1470,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
{
const APInt *Y;
// X % -Y -> X % Y
- if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue()) {
- Worklist.AddValue(I.getOperand(1));
- I.setOperand(1, ConstantInt::get(I.getType(), -*Y));
- return &I;
- }
+ if (match(Op1, m_Negative(Y)) && !Y->isMinSignedValue())
+ return replaceOperand(I, 1, ConstantInt::get(I.getType(), -*Y));
}
// -X srem Y --> -(X srem Y)
@@ -1441,7 +1491,7 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
// If it's a constant vector, flip any negative values positive.
if (isa<ConstantVector>(Op1) || isa<ConstantDataVector>(Op1)) {
Constant *C = cast<Constant>(Op1);
- unsigned VWidth = C->getType()->getVectorNumElements();
+ unsigned VWidth = cast<VectorType>(C->getType())->getNumElements();
bool hasNegative = false;
bool hasMissing = false;
@@ -1468,11 +1518,8 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
}
Constant *NewRHSV = ConstantVector::get(Elts);
- if (NewRHSV != C) { // Don't loop on -MININT
- Worklist.AddValue(I.getOperand(1));
- I.setOperand(1, NewRHSV);
- return &I;
- }
+ if (NewRHSV != C) // Don't loop on -MININT
+ return replaceOperand(I, 1, NewRHSV);
}
}