diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-07-26 19:03:47 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-07-26 19:04:23 +0000 |
| commit | 7fa27ce4a07f19b07799a767fc29416f3b625afb (patch) | |
| tree | 27825c83636c4de341eb09a74f49f5d38a15d165 /llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | |
| parent | e3b557809604d036af6e00c60f012c2025b59a5e (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 244 |
1 files changed, 223 insertions, 21 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 97f129e200de..50458e2773e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, return nullptr; } +static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, + bool AssumeNonZero, bool DoFold); + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (Value *V = @@ -270,7 +273,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (match(Op0, m_ZExtOrSExt(m_Value(X))) && match(Op1, m_APIntAllowUndef(NegPow2C))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - unsigned ShiftAmt = NegPow2C->countTrailingZeros(); + unsigned ShiftAmt = NegPow2C->countr_zero(); if (ShiftAmt >= BitWidth - SrcWidth) { Value *N = Builder.CreateNeg(X, X->getName() + ".neg"); Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z"); @@ -471,6 +474,40 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; + if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I)) + return Res; + + // min(X, Y) * max(X, Y) => X * Y. 
+ if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)), + m_c_SMin(m_Deferred(X), m_Deferred(Y))), + m_c_Mul(m_UMax(m_Value(X), m_Value(Y)), + m_c_UMin(m_Deferred(X), m_Deferred(Y)))))) + return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I); + + // (mul Op0 Op1): + // if Log2(Op0) folds away -> + // (shl Op1, Log2(Op0)) + // if Log2(Op1) folds away -> + // (shl Op0, Log2(Op1)) + if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res); + // We can only propagate nuw flag. + Shl->setHasNoUnsignedWrap(HasNUW); + return Shl; + } + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false, + /*DoFold*/ true); + BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res); + // We can only propagate nuw flag. + Shl->setHasNoUnsignedWrap(HasNUW); + return Shl; + } + bool Changed = false; if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; @@ -765,6 +802,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { I.hasNoSignedZeros() && match(Start, m_Zero())) return replaceInstUsesWith(I, Start); + // minimum(X, Y) * maximum(X, Y) => X * Y. + if (match(&I, + m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), + m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), + m_Deferred(Y))))) { + BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I); + // We cannot preserve ninf if nnan flag is not set. + // If X is NaN and Y is Inf then in original program we had NaN * NaN, + // while in optimized version NaN * Inf and this is a poison with ninf flag. 
+ if (!Result->hasNoNaNs()) + Result->setHasNoInfs(false); + return Result; + } + return nullptr; } @@ -976,9 +1027,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { ConstantInt::get(Ty, Product)); } + APInt Quotient(C2->getBitWidth(), /*val=*/0ULL, IsSigned); if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) || (!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) { - APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); // (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1. if (isMultiple(*C2, *C1, Quotient, IsSigned)) { @@ -1003,7 +1054,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { C1->ult(C1->getBitWidth() - 1)) || (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) && C1->ult(C1->getBitWidth()))) { - APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); APInt C1Shifted = APInt::getOneBitSet( C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue())); @@ -1026,6 +1076,23 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } + // Distribute div over add to eliminate a matching div/mul pair: + // ((X * C2) + C1) / C2 --> X + C1/C2 + // We need a multiple of the divisor for a signed add constant, but + // unsigned is fine with any constant pair. + if (IsSigned && + match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1))) && + isMultiple(*C1, *C2, Quotient, IsSigned)) { + return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient)); + } + if (!IsSigned && + match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)), + m_APInt(C1)))) { + return BinaryOperator::CreateNUWAdd(X, + ConstantInt::get(Ty, C1->udiv(*C2))); + } + if (!C2->isZero()) // avoid X udiv 0 if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) return FoldedDiv; @@ -1121,7 +1188,7 @@ static const unsigned MaxDepth = 6; // actual instructions, otherwise return a non-null dummy value. Return nullptr // on failure. 
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, - bool DoFold) { + bool AssumeNonZero, bool DoFold) { auto IfFold = [DoFold](function_ref<Value *()> Fn) { if (!DoFold) return reinterpret_cast<Value *>(-1); @@ -1147,14 +1214,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: Require one use? Value *X, *Y; if (match(Op, m_ZExt(m_Value(X)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); }); // log2(X << Y) -> log2(X) + Y // FIXME: Require one use unless X is 1? - if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) - if (Value *LogX = takeLog2(Builder, X, Depth, DoFold)) - return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) { + auto *BO = cast<OverflowingBinaryOperator>(Op); + // nuw will be set if the `shl` is trivially non-zero. + if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap()) + if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold)) + return IfFold([&]() { return Builder.CreateAdd(LogX, Y); }); + } // log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y) // FIXME: missed optimization: if one of the hands of select is/contains @@ -1162,8 +1233,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // FIXME: can both hands contain undef? // FIXME: Require one use? 
if (SelectInst *SI = dyn_cast<SelectInst>(Op)) - if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold)) + if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, + AssumeNonZero, DoFold)) + if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, + AssumeNonZero, DoFold)) return IfFold([&]() { return Builder.CreateSelect(SI->getOperand(0), LogX, LogY); }); @@ -1171,13 +1244,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, // log2(umin(X, Y)) -> umin(log2(X), log2(Y)) // log2(umax(X, Y)) -> umax(log2(X), log2(Y)) auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op); - if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) - if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold)) - if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold)) + if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) { + // Use AssumeNonZero as false here. Otherwise we can hit case where + // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow). + if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) + if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, + /*AssumeNonZero*/ false, DoFold)) return IfFold([&]() { - return Builder.CreateBinaryIntrinsic( - MinMax->getIntrinsicID(), LogX, LogY); + return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX, + LogY); }); + } return nullptr; } @@ -1297,8 +1375,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. 
- if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { - Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); + if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true, + /*DoFold*/ false)) { + Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, + /*AssumeNonZero*/ true, /*DoFold*/ true); return replaceInstUsesWith( I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact())); } @@ -1359,7 +1439,8 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { // (sext X) sdiv C --> sext (X sdiv C) Value *Op0Src; if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) && - Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) { + Op0Src->getType()->getScalarSizeInBits() >= + Op1C->getSignificantBits()) { // In the general case, we need to make sure that the dividend is not the // minimum signed value because dividing that by -1 is UB. But here, we @@ -1402,7 +1483,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { KnownBits KnownDividend = computeKnownBits(Op0, 0, &I); if (!I.isExact() && (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) && - KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) { + KnownDividend.countMinTrailingZeros() >= Op1C->countr_zero()) { I.setIsExact(); return &I; } @@ -1681,6 +1762,111 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { return nullptr; } +// Variety of transform for: +// (urem/srem (mul X, Y), (mul X, Z)) +// (urem/srem (shl X, Y), (shl X, Z)) +// (urem/srem (shl Y, X), (shl Z, X)) +// NB: The shift cases are really just extensions of the mul case. We treat +// shift as Val * (1 << Amt). +static Instruction *simplifyIRemMulShl(BinaryOperator &I, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr; + APInt Y, Z; + bool ShiftByX = false; + + // If V is not nullptr, it will be matched using m_Specific. 
+ auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp))))) + C = *Tmp; + else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) || + (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp))))) + C = APInt(Tmp->getBitWidth(), 1) << *Tmp; + if (Tmp != nullptr) + return true; + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; + }; + + auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool { + const APInt *Tmp = nullptr; + if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) || + (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) { + C = *Tmp; + return true; + } + + // Reset `V` so we don't start with specific value on next match attempt. + V = nullptr; + return false; + }; + + if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) { + // pass + } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) { + ShiftByX = true; + } else { + return nullptr; + } + + bool IsSRem = I.getOpcode() == Instruction::SRem; + + OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0); + // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >= + // Z or Z >= Y. + bool BO0HasNSW = BO0->hasNoSignedWrap(); + bool BO0HasNUW = BO0->hasNoUnsignedWrap(); + bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW; + + APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z); + // (rem (mul nuw/nsw X, Y), (mul X, Z)) + // if (rem Y, Z) == 0 + // -> 0 + if (RemYZ.isZero() && BO0NoWrap) + return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType())); + + // Helper function to emit either (RemSimplificationC << X) or + // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as + // (shl V, X) or (mul V, X) respectively. 
+ auto CreateMulOrShift = + [&](const APInt &RemSimplificationC) -> BinaryOperator * { + Value *RemSimplification = + ConstantInt::get(I.getType(), RemSimplificationC); + return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X) + : BinaryOperator::CreateMul(X, RemSimplification); + }; + + OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1); + bool BO1HasNSW = BO1->hasNoSignedWrap(); + bool BO1HasNUW = BO1->hasNoUnsignedWrap(); + bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW; + // (rem (mul X, Y), (mul nuw/nsw X, Z)) + // if (rem Y, Z) == Y + // -> (mul nuw/nsw X, Y) + if (RemYZ == Y && BO1NoWrap) { + BinaryOperator *BO = CreateMulOrShift(Y); + // Copy any overflow flags from Op0. + BO->setHasNoSignedWrap(IsSRem || BO0HasNSW); + BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW); + return BO; + } + + // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z)) + // if Y >= Z + // -> (mul {nuw} nsw X, (rem Y, Z)) + if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) { + BinaryOperator *BO = CreateMulOrShift(RemYZ); + BO->setHasNoSignedWrap(); + BO->setHasNoUnsignedWrap(BO0HasNUW); + return BO; + } + + return nullptr; +} + /// This function implements the transforms common to both integer remainder /// instructions (urem and srem). It is called by the visitors to those integer /// remainder instructions. @@ -1733,6 +1919,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) { } } + if (Instruction *R = simplifyIRemMulShl(I, *this)) + return R; + return nullptr; } @@ -1782,8 +1971,21 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) { // urem Op0, (sext i1 X) --> (Op0 == -1) ? 
0 : Op0 Value *X; if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty)); - return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0); + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = + Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty)); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } + + // For "(X + 1) % Op1" and if (X u< Op1) => (X + 1) == Op1 ? 0 : X + 1 . + if (match(Op0, m_Add(m_Value(X), m_One()))) { + Value *Val = + simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I)); + if (Val && match(Val, m_One())) { + Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen"); + Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1); + return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0); + } } return nullptr; |
