aboutsummaryrefslogtreecommitdiff
path: root/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
diff options
context:
space:
mode:
authorDimitry Andric <dim@FreeBSD.org>2023-07-26 19:03:47 +0000
committerDimitry Andric <dim@FreeBSD.org>2023-07-26 19:04:23 +0000
commit7fa27ce4a07f19b07799a767fc29416f3b625afb (patch)
tree27825c83636c4de341eb09a74f49f5d38a15d165 /llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
parente3b557809604d036af6e00c60f012c2025b59a5e (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp')
-rw-r--r--llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp244
1 files changed, 223 insertions, 21 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 97f129e200de..50458e2773e6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -185,6 +185,9 @@ static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
return nullptr;
}
+static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
+ bool AssumeNonZero, bool DoFold);
+
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
if (Value *V =
@@ -270,7 +273,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (match(Op0, m_ZExtOrSExt(m_Value(X))) &&
match(Op1, m_APIntAllowUndef(NegPow2C))) {
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
- unsigned ShiftAmt = NegPow2C->countTrailingZeros();
+ unsigned ShiftAmt = NegPow2C->countr_zero();
if (ShiftAmt >= BitWidth - SrcWidth) {
Value *N = Builder.CreateNeg(X, X->getName() + ".neg");
Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z");
@@ -471,6 +474,40 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
+ // min(X, Y) * max(X, Y) => X * Y.
+ if (match(&I, m_CombineOr(m_c_Mul(m_SMax(m_Value(X), m_Value(Y)),
+ m_c_SMin(m_Deferred(X), m_Deferred(Y))),
+ m_c_Mul(m_UMax(m_Value(X), m_Value(Y)),
+ m_c_UMin(m_Deferred(X), m_Deferred(Y))))))
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Mul, X, Y, &I);
+
+ // (mul Op0 Op1):
+ // if Log2(Op0) folds away ->
+ // (shl Op1, Log2(Op0))
+ // if Log2(Op1) folds away ->
+ // (shl Op0, Log2(Op1))
+ if (takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op0, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ true);
+ BinaryOperator *Shl = BinaryOperator::CreateShl(Op1, Res);
+ // We can only propagate nuw flag.
+ Shl->setHasNoUnsignedWrap(HasNUW);
+ return Shl;
+ }
+ if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ false,
+ /*DoFold*/ true);
+ BinaryOperator *Shl = BinaryOperator::CreateShl(Op0, Res);
+ // We can only propagate nuw flag.
+ Shl->setHasNoUnsignedWrap(HasNUW);
+ return Shl;
+ }
+
bool Changed = false;
if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
Changed = true;
@@ -765,6 +802,20 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
I.hasNoSignedZeros() && match(Start, m_Zero()))
return replaceInstUsesWith(I, Start);
+ // minimum(X, Y) * maximum(X, Y) => X * Y.
+ if (match(&I,
+ m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFMulFMF(X, Y, &I);
+ // We cannot preserve ninf if nnan flag is not set.
+ // If X is NaN and Y is Inf then in original program we had NaN * NaN,
+ // while in optimized version NaN * Inf and this is a poison with ninf flag.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
@@ -976,9 +1027,9 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
ConstantInt::get(Ty, Product));
}
+ APInt Quotient(C2->getBitWidth(), /*val=*/0ULL, IsSigned);
if ((IsSigned && match(Op0, m_NSWMul(m_Value(X), m_APInt(C1)))) ||
(!IsSigned && match(Op0, m_NUWMul(m_Value(X), m_APInt(C1))))) {
- APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
// (X * C1) / C2 -> X / (C2 / C1) if C2 is a multiple of C1.
if (isMultiple(*C2, *C1, Quotient, IsSigned)) {
@@ -1003,7 +1054,6 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
C1->ult(C1->getBitWidth() - 1)) ||
(!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) &&
C1->ult(C1->getBitWidth()))) {
- APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned);
APInt C1Shifted = APInt::getOneBitSet(
C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue()));
@@ -1026,6 +1076,23 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
}
}
+ // Distribute div over add to eliminate a matching div/mul pair:
+ // ((X * C2) + C1) / C2 --> X + C1/C2
+ // We need a multiple of the divisor for a signed add constant, but
+ // unsigned is fine with any constant pair.
+ if (IsSigned &&
+ match(Op0, m_NSWAdd(m_NSWMul(m_Value(X), m_SpecificInt(*C2)),
+ m_APInt(C1))) &&
+ isMultiple(*C1, *C2, Quotient, IsSigned)) {
+ return BinaryOperator::CreateNSWAdd(X, ConstantInt::get(Ty, Quotient));
+ }
+ if (!IsSigned &&
+ match(Op0, m_NUWAdd(m_NUWMul(m_Value(X), m_SpecificInt(*C2)),
+ m_APInt(C1)))) {
+ return BinaryOperator::CreateNUWAdd(X,
+ ConstantInt::get(Ty, C1->udiv(*C2)));
+ }
+
if (!C2->isZero()) // avoid X udiv 0
if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I))
return FoldedDiv;
@@ -1121,7 +1188,7 @@ static const unsigned MaxDepth = 6;
// actual instructions, otherwise return a non-null dummy value. Return nullptr
// on failure.
static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
- bool DoFold) {
+ bool AssumeNonZero, bool DoFold) {
auto IfFold = [DoFold](function_ref<Value *()> Fn) {
if (!DoFold)
return reinterpret_cast<Value *>(-1);
@@ -1147,14 +1214,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: Require one use?
Value *X, *Y;
if (match(Op, m_ZExt(m_Value(X))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
return IfFold([&]() { return Builder.CreateZExt(LogX, Op->getType()); });
// log2(X << Y) -> log2(X) + Y
// FIXME: Require one use unless X is 1?
- if (match(Op, m_Shl(m_Value(X), m_Value(Y))))
- if (Value *LogX = takeLog2(Builder, X, Depth, DoFold))
- return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ if (match(Op, m_Shl(m_Value(X), m_Value(Y)))) {
+ auto *BO = cast<OverflowingBinaryOperator>(Op);
+ // nuw will be set if the `shl` is trivially non-zero.
+ if (AssumeNonZero || BO->hasNoUnsignedWrap() || BO->hasNoSignedWrap())
+ if (Value *LogX = takeLog2(Builder, X, Depth, AssumeNonZero, DoFold))
+ return IfFold([&]() { return Builder.CreateAdd(LogX, Y); });
+ }
// log2(Cond ? X : Y) -> Cond ? log2(X) : log2(Y)
// FIXME: missed optimization: if one of the hands of select is/contains
@@ -1162,8 +1233,10 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// FIXME: can both hands contain undef?
// FIXME: Require one use?
if (SelectInst *SI = dyn_cast<SelectInst>(Op))
- if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth, DoFold))
+ if (Value *LogX = takeLog2(Builder, SI->getOperand(1), Depth,
+ AssumeNonZero, DoFold))
+ if (Value *LogY = takeLog2(Builder, SI->getOperand(2), Depth,
+ AssumeNonZero, DoFold))
return IfFold([&]() {
return Builder.CreateSelect(SI->getOperand(0), LogX, LogY);
});
@@ -1171,13 +1244,18 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth,
// log2(umin(X, Y)) -> umin(log2(X), log2(Y))
// log2(umax(X, Y)) -> umax(log2(X), log2(Y))
auto *MinMax = dyn_cast<MinMaxIntrinsic>(Op);
- if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned())
- if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth, DoFold))
- if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth, DoFold))
+ if (MinMax && MinMax->hasOneUse() && !MinMax->isSigned()) {
+ // Use AssumeNonZero as false here. Otherwise we can hit case where
+ // log2(umax(X, Y)) != umax(log2(X), log2(Y)) (because overflow).
+ if (Value *LogX = takeLog2(Builder, MinMax->getLHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
+ if (Value *LogY = takeLog2(Builder, MinMax->getRHS(), Depth,
+ /*AssumeNonZero*/ false, DoFold))
return IfFold([&]() {
- return Builder.CreateBinaryIntrinsic(
- MinMax->getIntrinsicID(), LogX, LogY);
+ return Builder.CreateBinaryIntrinsic(MinMax->getIntrinsicID(), LogX,
+ LogY);
});
+ }
return nullptr;
}
@@ -1297,8 +1375,10 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
- if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
- Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
+ if (takeLog2(Builder, Op1, /*Depth*/ 0, /*AssumeNonZero*/ true,
+ /*DoFold*/ false)) {
+ Value *Res = takeLog2(Builder, Op1, /*Depth*/ 0,
+ /*AssumeNonZero*/ true, /*DoFold*/ true);
return replaceInstUsesWith(
I, Builder.CreateLShr(Op0, Res, I.getName(), I.isExact()));
}
@@ -1359,7 +1439,8 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
// (sext X) sdiv C --> sext (X sdiv C)
Value *Op0Src;
if (match(Op0, m_OneUse(m_SExt(m_Value(Op0Src)))) &&
- Op0Src->getType()->getScalarSizeInBits() >= Op1C->getMinSignedBits()) {
+ Op0Src->getType()->getScalarSizeInBits() >=
+ Op1C->getSignificantBits()) {
// In the general case, we need to make sure that the dividend is not the
// minimum signed value because dividing that by -1 is UB. But here, we
@@ -1402,7 +1483,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
KnownBits KnownDividend = computeKnownBits(Op0, 0, &I);
if (!I.isExact() &&
(match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) &&
- KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) {
+ KnownDividend.countMinTrailingZeros() >= Op1C->countr_zero()) {
I.setIsExact();
return &I;
}
@@ -1681,6 +1762,111 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
return nullptr;
}
+// Variety of transform for:
+// (urem/srem (mul X, Y), (mul X, Z))
+// (urem/srem (shl X, Y), (shl X, Z))
+// (urem/srem (shl Y, X), (shl Z, X))
+// NB: The shift cases are really just extensions of the mul case. We treat
+// shift as Val * (1 << Amt).
+static Instruction *simplifyIRemMulShl(BinaryOperator &I,
+ InstCombinerImpl &IC) {
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *X = nullptr;
+ APInt Y, Z;
+ bool ShiftByX = false;
+
+ // If V is not nullptr, it will be matched using m_Specific.
+ auto MatchShiftOrMulXC = [](Value *Op, Value *&V, APInt &C) -> bool {
+ const APInt *Tmp = nullptr;
+ if ((!V && match(Op, m_Mul(m_Value(V), m_APInt(Tmp)))) ||
+ (V && match(Op, m_Mul(m_Specific(V), m_APInt(Tmp)))))
+ C = *Tmp;
+ else if ((!V && match(Op, m_Shl(m_Value(V), m_APInt(Tmp)))) ||
+ (V && match(Op, m_Shl(m_Specific(V), m_APInt(Tmp)))))
+ C = APInt(Tmp->getBitWidth(), 1) << *Tmp;
+ if (Tmp != nullptr)
+ return true;
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
+ };
+
+ auto MatchShiftCX = [](Value *Op, APInt &C, Value *&V) -> bool {
+ const APInt *Tmp = nullptr;
+ if ((!V && match(Op, m_Shl(m_APInt(Tmp), m_Value(V)))) ||
+ (V && match(Op, m_Shl(m_APInt(Tmp), m_Specific(V))))) {
+ C = *Tmp;
+ return true;
+ }
+
+ // Reset `V` so we don't start with specific value on next match attempt.
+ V = nullptr;
+ return false;
+ };
+
+ if (MatchShiftOrMulXC(Op0, X, Y) && MatchShiftOrMulXC(Op1, X, Z)) {
+ // pass
+ } else if (MatchShiftCX(Op0, Y, X) && MatchShiftCX(Op1, Z, X)) {
+ ShiftByX = true;
+ } else {
+ return nullptr;
+ }
+
+ bool IsSRem = I.getOpcode() == Instruction::SRem;
+
+ OverflowingBinaryOperator *BO0 = cast<OverflowingBinaryOperator>(Op0);
+ // TODO: We may be able to deduce more about nsw/nuw of BO0/BO1 based on Y >=
+ // Z or Z >= Y.
+ bool BO0HasNSW = BO0->hasNoSignedWrap();
+ bool BO0HasNUW = BO0->hasNoUnsignedWrap();
+ bool BO0NoWrap = IsSRem ? BO0HasNSW : BO0HasNUW;
+
+ APInt RemYZ = IsSRem ? Y.srem(Z) : Y.urem(Z);
+ // (rem (mul nuw/nsw X, Y), (mul X, Z))
+ // if (rem Y, Z) == 0
+ // -> 0
+ if (RemYZ.isZero() && BO0NoWrap)
+ return IC.replaceInstUsesWith(I, ConstantInt::getNullValue(I.getType()));
+
+ // Helper function to emit either (RemSimplificationC << X) or
+ // (RemSimplificationC * X) depending on whether we matched Op0/Op1 as
+ // (shl V, X) or (mul V, X) respectively.
+ auto CreateMulOrShift =
+ [&](const APInt &RemSimplificationC) -> BinaryOperator * {
+ Value *RemSimplification =
+ ConstantInt::get(I.getType(), RemSimplificationC);
+ return ShiftByX ? BinaryOperator::CreateShl(RemSimplification, X)
+ : BinaryOperator::CreateMul(X, RemSimplification);
+ };
+
+ OverflowingBinaryOperator *BO1 = cast<OverflowingBinaryOperator>(Op1);
+ bool BO1HasNSW = BO1->hasNoSignedWrap();
+ bool BO1HasNUW = BO1->hasNoUnsignedWrap();
+ bool BO1NoWrap = IsSRem ? BO1HasNSW : BO1HasNUW;
+ // (rem (mul X, Y), (mul nuw/nsw X, Z))
+ // if (rem Y, Z) == Y
+ // -> (mul nuw/nsw X, Y)
+ if (RemYZ == Y && BO1NoWrap) {
+ BinaryOperator *BO = CreateMulOrShift(Y);
+ // Copy any overflow flags from Op0.
+ BO->setHasNoSignedWrap(IsSRem || BO0HasNSW);
+ BO->setHasNoUnsignedWrap(!IsSRem || BO0HasNUW);
+ return BO;
+ }
+
+ // (rem (mul nuw/nsw X, Y), (mul {nsw} X, Z))
+ // if Y >= Z
+ // -> (mul {nuw} nsw X, (rem Y, Z))
+ if (Y.uge(Z) && (IsSRem ? (BO0HasNSW && BO1HasNSW) : BO0HasNUW)) {
+ BinaryOperator *BO = CreateMulOrShift(RemYZ);
+ BO->setHasNoSignedWrap();
+ BO->setHasNoUnsignedWrap(BO0HasNUW);
+ return BO;
+ }
+
+ return nullptr;
+}
+
/// This function implements the transforms common to both integer remainder
/// instructions (urem and srem). It is called by the visitors to those integer
/// remainder instructions.
@@ -1733,6 +1919,9 @@ Instruction *InstCombinerImpl::commonIRemTransforms(BinaryOperator &I) {
}
}
+ if (Instruction *R = simplifyIRemMulShl(I, *this))
+ return R;
+
return nullptr;
}
@@ -1782,8 +1971,21 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
// urem Op0, (sext i1 X) --> (Op0 == -1) ? 0 : Op0
Value *X;
if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
- Value *Cmp = Builder.CreateICmpEQ(Op0, ConstantInt::getAllOnesValue(Ty));
- return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Op0);
+ Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen");
+ Value *Cmp =
+ Builder.CreateICmpEQ(FrozenOp0, ConstantInt::getAllOnesValue(Ty));
+ return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0);
+ }
+
+ // For "(X + 1) % Op1" and if (X u< Op1) => (X + 1) == Op1 ? 0 : X + 1 .
+ if (match(Op0, m_Add(m_Value(X), m_One()))) {
+ Value *Val =
+ simplifyICmpInst(ICmpInst::ICMP_ULT, X, Op1, SQ.getWithInstruction(&I));
+ if (Val && match(Val, m_One())) {
+ Value *FrozenOp0 = Builder.CreateFreeze(Op0, Op0->getName() + ".frozen");
+ Value *Cmp = Builder.CreateICmpEQ(FrozenOp0, Op1);
+ return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), FrozenOp0);
+ }
}
return nullptr;