summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineMulDivRem.cpp165
1 files changed, 102 insertions, 63 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 160792b0a0000..788097f33f121 100644
--- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -45,28 +45,28 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC,
// (PowerOfTwo >>u B) --> isExact since shifting out the result would make it
// inexact. Similarly for <<.
- if (BinaryOperator *I = dyn_cast<BinaryOperator>(V))
- if (I->isLogicalShift() &&
- isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0,
- IC.getAssumptionCache(), &CxtI,
- IC.getDominatorTree())) {
- // We know that this is an exact/nuw shift and that the input is a
- // non-zero context as well.
- if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
- I->setOperand(0, V2);
- MadeChange = true;
- }
+ BinaryOperator *I = dyn_cast<BinaryOperator>(V);
+ if (I && I->isLogicalShift() &&
+ isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0,
+ IC.getAssumptionCache(), &CxtI,
+ IC.getDominatorTree())) {
+ // We know that this is an exact/nuw shift and that the input is a
+ // non-zero context as well.
+ if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) {
+ I->setOperand(0, V2);
+ MadeChange = true;
+ }
- if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
- I->setIsExact();
- MadeChange = true;
- }
+ if (I->getOpcode() == Instruction::LShr && !I->isExact()) {
+ I->setIsExact();
+ MadeChange = true;
+ }
- if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
- I->setHasNoUnsignedWrap();
- MadeChange = true;
- }
+ if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) {
+ I->setHasNoUnsignedWrap();
+ MadeChange = true;
}
+ }
// TODO: Lots more we could do here:
// If V is a phi node, we can call this on each of its operands.
@@ -177,13 +177,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyUsingDistributiveLaws(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// X * -1 == 0 - X
if (match(Op1, m_AllOnes())) {
@@ -323,7 +323,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO))
if (SDiv->isExact()) {
if (Op1BO == Op1C)
- return ReplaceInstUsesWith(I, Op0BO);
+ return replaceInstUsesWith(I, Op0BO);
return BinaryOperator::CreateNeg(Op0BO);
}
@@ -374,10 +374,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) {
APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true);
Value *BoolCast = nullptr, *OtherOp = nullptr;
- if (MaskedValueIsZero(Op0, Negative2, 0, &I))
- BoolCast = Op0, OtherOp = Op1;
- else if (MaskedValueIsZero(Op1, Negative2, 0, &I))
- BoolCast = Op1, OtherOp = Op0;
+ if (MaskedValueIsZero(Op0, Negative2, 0, &I)) {
+ BoolCast = Op0;
+ OtherOp = Op1;
+ } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) {
+ BoolCast = Op1;
+ OtherOp = Op0;
+ }
if (BoolCast) {
Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()),
@@ -536,14 +539,14 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (isa<Constant>(Op0))
std::swap(Op0, Op1);
if (Value *V =
SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
bool AllowReassociate = I.hasUnsafeAlgebra();
@@ -574,7 +577,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
// Try to simplify "MDC * Constant"
if (isFMulOrFDivWithConstant(Op0))
if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// (MDC +/- C1) * C => (MDC * C) +/- (C1 * C)
Instruction *FAddSub = dyn_cast<Instruction>(Op0);
@@ -612,11 +615,22 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
}
}
- // sqrt(X) * sqrt(X) -> X
- if (AllowReassociate && (Op0 == Op1))
- if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0))
- if (II->getIntrinsicID() == Intrinsic::sqrt)
- return ReplaceInstUsesWith(I, II->getOperand(0));
+ if (Op0 == Op1) {
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) {
+ // sqrt(X) * sqrt(X) -> X
+ if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt)
+ return replaceInstUsesWith(I, II->getOperand(0));
+
+ // fabs(X) * fabs(X) -> X * X
+ if (II->getIntrinsicID() == Intrinsic::fabs) {
+ Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0),
+ II->getOperand(0),
+ I.getName());
+ FMulVal->copyFastMathFlags(&I);
+ return FMulVal;
+ }
+ }
+ }
// Under unsafe algebra do:
// X * log2(0.5*Y) = X*log2(Y) - X
@@ -641,7 +655,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
Value *FMulVal = Builder->CreateFMul(OpX, Log2);
Value *FSub = Builder->CreateFSub(FMulVal, OpX);
FSub->takeName(&I);
- return ReplaceInstUsesWith(I, FSub);
+ return replaceInstUsesWith(I, FSub);
}
}
@@ -661,7 +675,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
if (N1) {
Value *FMul = Builder->CreateFMul(N0, N1);
FMul->takeName(&I);
- return ReplaceInstUsesWith(I, FMul);
+ return replaceInstUsesWith(I, FMul);
}
if (Opnd0->hasOneUse()) {
@@ -669,7 +683,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
Value *T = Builder->CreateFMul(N0, Opnd1);
Value *Neg = Builder->CreateFNeg(T);
Neg->takeName(&I);
- return ReplaceInstUsesWith(I, Neg);
+ return replaceInstUsesWith(I, Neg);
}
}
@@ -698,7 +712,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) {
Value *R = Builder->CreateFMul(T, Y);
R->takeName(&I);
- return ReplaceInstUsesWith(I, R);
+ return replaceInstUsesWith(I, R);
}
}
}
@@ -1043,10 +1057,10 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// Handle the integer div common cases
if (Instruction *Common = commonIDivTransforms(I))
@@ -1116,27 +1130,43 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// Handle the integer div common cases
if (Instruction *Common = commonIDivTransforms(I))
return Common;
- // sdiv X, -1 == -X
- if (match(Op1, m_AllOnes()))
- return BinaryOperator::CreateNeg(Op0);
+ const APInt *Op1C;
+ if (match(Op1, m_APInt(Op1C))) {
+ // sdiv X, -1 == -X
+ if (Op1C->isAllOnesValue())
+ return BinaryOperator::CreateNeg(Op0);
- if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) {
- // sdiv X, C --> ashr exact X, log2(C)
- if (I.isExact() && RHS->getValue().isNonNegative() &&
- RHS->getValue().isPowerOf2()) {
- Value *ShAmt = llvm::ConstantInt::get(RHS->getType(),
- RHS->getValue().exactLogBase2());
+ // sdiv exact X, C --> ashr exact X, log2(C)
+ if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) {
+ Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2());
return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName());
}
+
+ // If the dividend is sign-extended and the constant divisor is small enough
+ // to fit in the source type, shrink the division to the narrower type:
+ // (sext X) sdiv C --> sext (X sdiv C)
+ Value *Op0Src;
+ if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
+ Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
+
+ // In the general case, we need to make sure that the dividend is not the
+ // minimum signed value because dividing that by -1 is UB. But here, we
+ // know that the -1 divisor case is already handled above.
+
+ Constant *NarrowDivisor =
+ ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType());
+ Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor);
+ return new SExtInst(NarrowOp, Op0->getType());
+ }
}
if (Constant *RHS = dyn_cast<Constant>(Op1)) {
@@ -1214,11 +1244,11 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(),
DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (isa<Constant>(Op0))
if (SelectInst *SI = dyn_cast<SelectInst>(Op1))
@@ -1363,8 +1393,17 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) {
if (Instruction *R = FoldOpIntoSelect(I, SI))
return R;
} else if (isa<PHINode>(Op0I)) {
- if (Instruction *NV = FoldOpIntoPhi(I))
- return NV;
+ using namespace llvm::PatternMatch;
+ const APInt *Op1Int;
+ if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() &&
+ (I.getOpcode() == Instruction::URem ||
+ !Op1Int->isMinSignedValue())) {
+ // FoldOpIntoPhi will speculate instructions to the end of the PHI's
+ // predecessor blocks, so do this only if we know the srem or urem
+ // will not fault.
+ if (Instruction *NV = FoldOpIntoPhi(I))
+ return NV;
+ }
}
// See if we can fold away this rem instruction.
@@ -1380,10 +1419,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Instruction *common = commonIRemTransforms(I))
return common;
@@ -1405,7 +1444,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) {
if (match(Op0, m_One())) {
Value *Cmp = Builder->CreateICmpNE(Op1, Op0);
Value *Ext = Builder->CreateZExt(Cmp, I.getType());
- return ReplaceInstUsesWith(I, Ext);
+ return replaceInstUsesWith(I, Ext);
}
return nullptr;
@@ -1415,10 +1454,10 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// Handle the integer rem common cases
if (Instruction *Common = commonIRemTransforms(I))
@@ -1490,11 +1529,11 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V = SimplifyVectorOp(I))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(),
DL, TLI, DT, AC))
- return ReplaceInstUsesWith(I, V);
+ return replaceInstUsesWith(I, V);
// Handle cases involving: rem X, (select Cond, Y, Z)
if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I))