diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 165 |
1 files changed, 102 insertions, 63 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 160792b0a0000..788097f33f121 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -45,28 +45,28 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, // (PowerOfTwo >>u B) --> isExact since shifting out the result would make it // inexact. Similarly for <<. - if (BinaryOperator *I = dyn_cast<BinaryOperator>(V)) - if (I->isLogicalShift() && - isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, - IC.getAssumptionCache(), &CxtI, - IC.getDominatorTree())) { - // We know that this is an exact/nuw shift and that the input is a - // non-zero context as well. - if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { - I->setOperand(0, V2); - MadeChange = true; - } + BinaryOperator *I = dyn_cast<BinaryOperator>(V); + if (I && I->isLogicalShift() && + isKnownToBeAPowerOfTwo(I->getOperand(0), IC.getDataLayout(), false, 0, + IC.getAssumptionCache(), &CxtI, + IC.getDominatorTree())) { + // We know that this is an exact/nuw shift and that the input is a + // non-zero context as well. + if (Value *V2 = simplifyValueKnownNonZero(I->getOperand(0), IC, CxtI)) { + I->setOperand(0, V2); + MadeChange = true; + } - if (I->getOpcode() == Instruction::LShr && !I->isExact()) { - I->setIsExact(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::LShr && !I->isExact()) { + I->setIsExact(); + MadeChange = true; + } - if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { - I->setHasNoUnsignedWrap(); - MadeChange = true; - } + if (I->getOpcode() == Instruction::Shl && !I->hasNoUnsignedWrap()) { + I->setHasNoUnsignedWrap(); + MadeChange = true; } + } // TODO: Lots more we could do here: // If V is a phi node, we can call this on each of its operands. @@ -177,13 +177,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyMulInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUsingDistributiveLaws(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // X * -1 == 0 - X if (match(Op1, m_AllOnes())) { @@ -323,7 +323,7 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (PossiblyExactOperator *SDiv = dyn_cast<PossiblyExactOperator>(BO)) if (SDiv->isExact()) { if (Op1BO == Op1C) - return ReplaceInstUsesWith(I, Op0BO); + return replaceInstUsesWith(I, Op0BO); return BinaryOperator::CreateNeg(Op0BO); } @@ -374,10 +374,13 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { APInt Negative2(I.getType()->getPrimitiveSizeInBits(), (uint64_t)-2, true); Value *BoolCast = nullptr, *OtherOp = nullptr; - if (MaskedValueIsZero(Op0, Negative2, 0, &I)) - BoolCast = Op0, OtherOp = Op1; - else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) - BoolCast = Op1, OtherOp = Op0; + if (MaskedValueIsZero(Op0, Negative2, 0, &I)) { + BoolCast = Op0; + OtherOp = Op1; + } else if (MaskedValueIsZero(Op1, Negative2, 0, &I)) { + BoolCast = Op1; + OtherOp = Op0; + } if (BoolCast) { Value *V = Builder->CreateSub(Constant::getNullValue(I.getType()), @@ -536,14 +539,14 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) std::swap(Op0, Op1); if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); bool AllowReassociate = I.hasUnsafeAlgebra(); @@ -574,7 +577,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // Try to simplify "MDC * Constant" if (isFMulOrFDivWithConstant(Op0)) if (Value *V = foldFMulConst(cast<Instruction>(Op0), C, &I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // (MDC +/- C1) * C => (MDC * C) +/- (C1 * C) Instruction *FAddSub = dyn_cast<Instruction>(Op0); @@ -612,11 +615,22 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } - // sqrt(X) * sqrt(X) -> X - if (AllowReassociate && (Op0 == Op1)) - if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) - if (II->getIntrinsicID() == Intrinsic::sqrt) - return ReplaceInstUsesWith(I, II->getOperand(0)); + if (Op0 == Op1) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op0)) { + // sqrt(X) * sqrt(X) -> X + if (AllowReassociate && II->getIntrinsicID() == Intrinsic::sqrt) + return replaceInstUsesWith(I, II->getOperand(0)); + + // fabs(X) * fabs(X) -> X * X + if (II->getIntrinsicID() == Intrinsic::fabs) { + Instruction *FMulVal = BinaryOperator::CreateFMul(II->getOperand(0), + II->getOperand(0), + I.getName()); + FMulVal->copyFastMathFlags(&I); + return FMulVal; + } + } + } // Under unsafe algebra do: // X * log2(0.5*Y) = X*log2(Y) - X @@ -641,7 +655,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *FMulVal = Builder->CreateFMul(OpX, Log2); Value *FSub = Builder->CreateFSub(FMulVal, OpX); FSub->takeName(&I); - return ReplaceInstUsesWith(I, FSub); + return replaceInstUsesWith(I, FSub); } } @@ -661,7 +675,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (N1) { Value *FMul = Builder->CreateFMul(N0, N1); FMul->takeName(&I); - return ReplaceInstUsesWith(I, FMul); + return replaceInstUsesWith(I, FMul); } if (Opnd0->hasOneUse()) { @@ -669,7 +683,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *T = Builder->CreateFMul(N0, Opnd1); Value *Neg = Builder->CreateFNeg(T); Neg->takeName(&I); - return ReplaceInstUsesWith(I, Neg); + return replaceInstUsesWith(I, Neg); } } @@ -698,7 +712,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { Value *R = Builder->CreateFMul(T, Y); R->takeName(&I); - return ReplaceInstUsesWith(I, R); + return replaceInstUsesWith(I, R); } } } @@ -1043,10 +1057,10 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyUDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) @@ -1116,27 +1130,43 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySDivInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer div common cases if (Instruction *Common = commonIDivTransforms(I)) return Common; - // sdiv X, -1 == -X - if (match(Op1, m_AllOnes())) - return BinaryOperator::CreateNeg(Op0); + const APInt *Op1C; + if (match(Op1, m_APInt(Op1C))) { + // sdiv X, -1 == -X + if (Op1C->isAllOnesValue()) + return BinaryOperator::CreateNeg(Op0); - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { - // sdiv X, C --> ashr exact X, log2(C) - if (I.isExact() && RHS->getValue().isNonNegative() && - RHS->getValue().isPowerOf2()) { - Value *ShAmt = llvm::ConstantInt::get(RHS->getType(), - RHS->getValue().exactLogBase2()); + // sdiv exact X, C --> ashr exact X, log2(C) + if (I.isExact() && Op1C->isNonNegative() && Op1C->isPowerOf2()) { + Value *ShAmt = ConstantInt::get(Op1->getType(), Op1C->exactLogBase2()); return BinaryOperator::CreateExactAShr(Op0, ShAmt, I.getName()); } + + // If the dividend is sign-extended and the constant divisor is small enough + // to fit in the source type, shrink the division to the narrower type: + // (sext X) sdiv C --> sext (X sdiv C) + Value *Op0Src; + if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) && + Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) { + + // In the general case, we need to make sure that the dividend is not the + // minimum signed value because dividing that by -1 is UB. But here, we + // know that the -1 divisor case is already handled above. + + Constant *NarrowDivisor = + ConstantExpr::getTrunc(cast<Constant>(Op1), Op0Src->getType()); + Value *NarrowOp = Builder->CreateSDiv(Op0Src, NarrowDivisor); + return new SExtInst(NarrowOp, Op0->getType()); + } } if (Constant *RHS = dyn_cast<Constant>(Op1)) { @@ -1214,11 +1244,11 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (isa<Constant>(Op0)) if (SelectInst *SI = dyn_cast<SelectInst>(Op1)) @@ -1363,8 +1393,17 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; } else if (isa<PHINode>(Op0I)) { - if (Instruction *NV = FoldOpIntoPhi(I)) - return NV; + using namespace llvm::PatternMatch; + const APInt *Op1Int; + if (match(Op1, m_APInt(Op1Int)) && !Op1Int->isMinValue() && + (I.getOpcode() == Instruction::URem || + !Op1Int->isMinSignedValue())) { + // FoldOpIntoPhi will speculate instructions to the end of the PHI's + // predecessor blocks, so do this only if we know the srem or urem + // will not fault. + if (Instruction *NV = FoldOpIntoPhi(I)) + return NV; + } } // See if we can fold away this rem instruction. @@ -1380,10 +1419,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyURemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Instruction *common = commonIRemTransforms(I)) return common; @@ -1405,7 +1444,7 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (match(Op0, m_One())) { Value *Cmp = Builder->CreateICmpNE(Op1, Op0); Value *Ext = Builder->CreateZExt(Cmp, I.getType()); - return ReplaceInstUsesWith(I, Ext); + return replaceInstUsesWith(I, Ext); } return nullptr; @@ -1415,10 +1454,10 @@ Instruction *InstCombiner::visitSRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifySRemInst(Op0, Op1, DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle the integer rem common cases if (Instruction *Common = commonIRemTransforms(I)) @@ -1490,11 +1529,11 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = SimplifyVectorOp(I)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), DL, TLI, DT, AC)) - return ReplaceInstUsesWith(I, V); + return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) |