Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp | 102
1 file changed, 93 insertions, 9 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index b68efc993723..91ca44e0f11e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -797,7 +797,7 @@ static Value *checkForNegativeOperand(BinaryOperator &I,
// LHS = XOR(Y, C1), Y = AND(Z, C2), C1 == (C2 + 1) => LHS == NEG(OR(Z, ~C2))
// ADD(LHS, RHS) == SUB(RHS, OR(Z, ~C2))
if (match(LHS, m_Xor(m_Value(Y), m_APInt(C1))))
- if (C1->countTrailingZeros() == 0)
+ if (C1->countr_zero() == 0)
if (match(Y, m_And(m_Value(Z), m_APInt(C2))) && *C1 == (*C2 + 1)) {
Value *NewOr = Builder.CreateOr(Z, ~(*C2));
return Builder.CreateSub(RHS, NewOr, "sub");
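This hunk only migrates countTrailingZeros() to the newer countr_zero() spelling; the fold itself rests on the identity that, for odd C1 == C2 + 1, XOR(AND(Z, C2), C1) == NEG(OR(Z, ~C2)) in two's complement. A minimal standalone C++ sketch (not part of the patch) that checks this exhaustively at 8 bits:

    #include <cassert>
    #include <cstdint>

    int main() {
      // C2 even <=> C1 = C2 + 1 odd, matching the countr_zero(C1) == 0 guard.
      for (unsigned z = 0; z < 256; ++z)
        for (unsigned c2 = 0; c2 < 256; c2 += 2) {
          uint8_t Z = (uint8_t)z, C2 = (uint8_t)c2, C1 = (uint8_t)(C2 + 1);
          uint8_t LHS = (uint8_t)((Z & C2) ^ C1);              // XOR(AND(Z, C2), C1)
          uint8_t Neg = (uint8_t)-(uint8_t)(Z | (uint8_t)~C2); // NEG(OR(Z, ~C2))
          assert(LHS == Neg);
        }
    }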
@@ -880,8 +880,15 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
return SelectInst::Create(X, InstCombiner::SubOne(Op1C), Op1);
// ~X + C --> (C-1) - X
- if (match(Op0, m_Not(m_Value(X))))
- return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);
+ if (match(Op0, m_Not(m_Value(X)))) {
+ // ~X + C has NSW and (C-1) won't overflow => (C-1)-X can have NSW
+ auto *COne = ConstantInt::get(Op1C->getType(), 1);
+ bool WillNotSOV = willNotOverflowSignedSub(Op1C, COne, Add);
+ BinaryOperator *Res =
+ BinaryOperator::CreateSub(ConstantExpr::getSub(Op1C, COne), X);
+ Res->setHasNoSignedWrap(Add.hasNoSignedWrap() && WillNotSOV);
+ return Res;
+ }
// (iN X s>> (N - 1)) + 1 --> zext (X > -1)
const APInt *C;
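The new flag propagation follows from the wrapping identity ~X + C == (C - 1) - X: if the original add had nsw and C - 1 itself cannot overflow signed (the willNotOverflowSignedSub check), the replacement sub keeps nsw. A small standalone C++ check of the base identity, separate from the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int x = -128; x <= 127; ++x)
        for (int c = -128; c <= 127; ++c) {
          int8_t X = (int8_t)x, C = (int8_t)c;
          // ~X + C == (C - 1) - X in two's complement.
          assert((int8_t)(~X + C) == (int8_t)((int8_t)(C - 1) - X));
        }
    }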
@@ -975,6 +982,16 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
}
}
+ // Fold (add (zext (add X, -1)), 1) -> (zext X) if X is non-zero.
+ // TODO: There's a general form for any constant on the outer add.
+ if (C->isOne()) {
+ if (match(Op0, m_ZExt(m_Add(m_Value(X), m_AllOnes())))) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Add);
+ if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return new ZExtInst(X, Ty);
+ }
+ }
+
return nullptr;
}
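The isKnownNonZero guard is what makes this sound: X - 1 only wraps at X == 0, so for non-zero X the narrow decrement and the outer increment cancel. Illustrated with a standalone C++ sketch over i8 zero-extended to i32 (not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned x = 1; x < 256; ++x) {  // X known non-zero
        uint8_t X = (uint8_t)x;
        uint32_t Folded = (uint32_t)X;                     // zext X
        uint32_t Orig   = (uint32_t)(uint8_t)(X - 1) + 1;  // zext(X - 1) + 1
        assert(Orig == Folded);                            // fails only at X == 0
      }
    }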
@@ -1366,6 +1383,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *X = foldNoWrapAdd(I, Builder))
return X;
+ if (Instruction *R = foldBinOpShiftWithShift(I))
+ return R;
+
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1))
@@ -1421,6 +1441,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
Value *Sub = Builder.CreateSub(A, B);
return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2));
}
+
+ // Canonicalize a constant sub operand as an add operand for better folding:
+ // (C1 - A) + B --> (B - A) + C1
+ if (match(&I, m_c_Add(m_OneUse(m_Sub(m_ImmConstant(C1), m_Value(A))),
+ m_Value(B)))) {
+ Value *Sub = Builder.CreateSub(B, A, "reass.sub");
+ return BinaryOperator::CreateAdd(Sub, C1);
+ }
}
// X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1)
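The reassociation (C1 - A) + B --> (B - A) + C1 added above holds unconditionally in wrapping arithmetic; a standalone C++ spot-check with an arbitrary immediate (not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint8_t C1 = 37;  // stand-in for an arbitrary immediate constant
      for (unsigned a = 0; a < 256; ++a)
        for (unsigned b = 0; b < 256; ++b) {
          uint8_t A = (uint8_t)a, B = (uint8_t)b;
          assert((uint8_t)((uint8_t)(C1 - A) + B) ==
                 (uint8_t)((uint8_t)(B - A) + C1));
        }
    }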
@@ -1439,7 +1467,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
// (A & 2^C1) + A => A & (2^C1 - 1) iff bit C1 in A is a sign bit
if (match(&I, m_c_Add(m_And(m_Value(A), m_APInt(C1)), m_Deferred(A))) &&
- C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countLeadingZeros())) {
+ C1->isPowerOf2() && (ComputeNumSignBits(A) > C1->countl_zero())) {
Constant *NewMask = ConstantInt::get(RHS->getType(), *C1 - 1);
return BinaryOperator::CreateAnd(A, NewMask);
}
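The guard here says bit C1 of A must be a sign bit, i.e. bits [C1..N-1] are all copies of each other; then adding A to (A & 2^C1) either adds zero or carries the high bits away, leaving only the low mask. A standalone C++ sketch at 8 bits with an assumed C1 of 3 (illustration only):

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned C1 = 3;  // the power-of-2 mask is 1 << C1
      for (unsigned a = 0; a < 256; ++a) {
        uint8_t A = (uint8_t)a;
        // Bits [C1..7] all equal <=> sign-extending from bit C1 reproduces A.
        bool BitC1IsSignBit =
            (int8_t)((int8_t)(uint8_t)(A << (7 - C1)) >> (7 - C1)) == (int8_t)A;
        if (BitC1IsSignBit)
          assert((uint8_t)((A & (1u << C1)) + A) == (A & ((1u << C1) - 1)));
      }
    }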
@@ -1451,6 +1479,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))))
return new ZExtInst(B, LHS->getType());
+ // zext(A) + sext(A) --> 0 if A is i1
+ if (match(&I, m_c_BinOp(m_ZExt(m_Value(A)), m_SExt(m_Deferred(A)))) &&
+ A->getType()->isIntOrIntVectorTy(1))
+ return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));
+
// A+B --> A|B iff A and B have no bits set in common.
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
return BinaryOperator::CreateOr(LHS, RHS);
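The i1 special case above is simple arithmetic: zext of a boolean is 0 or 1 and sext is 0 or -1, so the two always cancel. A trivial standalone check in C++:

    #include <cassert>

    int main() {
      for (bool A : {false, true}) {
        int ZExt = A ? 1 : 0;   // zext i1 -> 0 or 1
        int SExt = A ? -1 : 0;  // sext i1 -> 0 or -1
        assert(ZExt + SExt == 0);
      }
    }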
@@ -1515,7 +1548,7 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
const APInt *NegPow2C;
if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))),
m_Value(B)))) {
- Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros());
+ Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countr_zero());
Value *Shl = Builder.CreateShl(A, ShiftAmtC);
return BinaryOperator::CreateSub(B, Shl);
}
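Multiplying by a negated power of two is the same as negating a shift, so the add becomes a sub of a shl. A standalone C++ sketch using -16 (countr_zero == 4) as the assumed constant:

    #include <cassert>
    #include <cstdint>

    int main() {
      const unsigned ShAmt = 4;              // countr_zero of the constant below
      const uint8_t NegPow2 = (uint8_t)-16;  // a negated power of 2
      for (unsigned a = 0; a < 256; ++a)
        for (unsigned b = 0; b < 256; ++b) {
          uint8_t A = (uint8_t)a, B = (uint8_t)b;
          assert((uint8_t)(A * NegPow2 + B) ==
                 (uint8_t)(B - (uint8_t)(A << ShAmt)));
        }
    }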
@@ -1536,6 +1569,13 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Instruction *Ashr = foldAddToAshr(I))
return Ashr;
+ // min(A, B) + max(A, B) => A + B.
+ if (match(&I, m_CombineOr(m_c_Add(m_SMax(m_Value(A), m_Value(B)),
+ m_c_SMin(m_Deferred(A), m_Deferred(B))),
+ m_c_Add(m_UMax(m_Value(A), m_Value(B)),
+ m_c_UMin(m_Deferred(A), m_Deferred(B))))))
+ return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
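The min/max fold above uses the fact that min(A, B) + max(A, B) selects each operand exactly once, in some order. A standalone C++ spot-check (illustration, not patch code):

    #include <algorithm>
    #include <cassert>

    int main() {
      for (int A = -8; A <= 8; ++A)
        for (int B = -8; B <= 8; ++B)
          assert(std::min(A, B) + std::max(A, B) == A + B);
    }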
@@ -1575,6 +1615,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateOr(A, B)}));
+ if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
+ return Res;
+
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
return Changed ? &I : nullptr;
}
@@ -1786,6 +1832,20 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return replaceInstUsesWith(I, V);
}
+ // minimum(X, Y) + maximum(X, Y) => X + Y.
+ if (match(&I,
+ m_c_FAdd(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)),
+ m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X),
+ m_Deferred(Y))))) {
+ BinaryOperator *Result = BinaryOperator::CreateFAddFMF(X, Y, &I);
+ // We cannot preserve ninf if nnan flag is not set.
+ // If X is NaN and Y is Inf, the original program computed NaN + NaN, while
+ // the optimized version computes NaN + Inf, which is poison when ninf is set.
+ if (!Result->hasNoNaNs())
+ Result->setHasNoInfs(false);
+ return Result;
+ }
+
return nullptr;
}
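The flag subtlety can be made concrete: llvm.minimum and llvm.maximum propagate NaN, so with X = NaN and Y = Inf the original sum is NaN + NaN while the folded sum is NaN + Inf. Both are NaN numerically, but under ninf the folded add would see an Inf operand and be poison, hence ninf is dropped unless nnan is set. A standalone C++ illustration (not patch code):

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
      double X = std::numeric_limits<double>::quiet_NaN();
      double Y = std::numeric_limits<double>::infinity();
      assert(std::isnan(X + X));  // original: NaN + NaN
      assert(std::isnan(X + Y));  // folded:   NaN + Inf, still NaN numerically
    }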
@@ -1956,8 +2016,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
Constant *C2;
// C-(X+C2) --> (C-C2)-X
- if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2))))
- return BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
+ if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) {
+ // If C-C2 never overflows and both C-(X+C2) and (X+C2) have NSW,
+ // => (C-C2)-X can have NSW
+ bool WillNotSOV = willNotOverflowSignedSub(C, C2, I);
+ BinaryOperator *Res =
+ BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
+ auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
+ Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() &&
+ WillNotSOV);
+ return Res;
+ }
}
auto TryToNarrowDeduceFlags = [this, &I, &Op0, &Op1]() -> Instruction * {
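As with the ~X + C change earlier, the base identity C - (X + C2) == (C - C2) - X always holds in wrapping arithmetic; the patch only adds nsw when C - C2 provably cannot overflow and both original ops had nsw. A standalone C++ check of the identity, with assumed sample constants:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int8_t C = 50, C2 = 7;  // arbitrary immediates for illustration
      for (int x = -128; x <= 127; ++x) {
        int8_t X = (int8_t)x;
        assert((int8_t)(C - (int8_t)(X + C2)) ==
               (int8_t)((int8_t)(C - C2) - X));
      }
    }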
@@ -2325,7 +2394,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
const APInt *AddC, *AndC;
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
- unsigned Cttz = AddC->countTrailingZeros();
+ unsigned Cttz = AddC->countr_zero();
APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
if ((HighMask & *AndC).isZero())
return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC)));
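This fold works because AndC is confined to the bits below countr_zero(AddC): adding AddC cannot disturb those low bits, and subtracting a bit-subset of a value is the same as masking it off. A standalone C++ sketch with assumed constants AddC = 0x30 (cttz == 4) and AndC = 0x0B:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint8_t AddC = 0x30;  // low 4 bits clear
      const uint8_t AndC = 0x0B;  // only low-4-bit mask bits, HighMask & AndC == 0
      for (unsigned x = 0; x < 256; ++x) {
        uint8_t X = (uint8_t)x;
        uint8_t Op0 = (uint8_t)(X + AddC);  // X + AddC
        uint8_t Op1 = (uint8_t)(X & AndC);  // X & AndC
        assert((uint8_t)(Op0 - Op1) == (uint8_t)(Op0 & (uint8_t)~AndC));
      }
    }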
@@ -2388,6 +2457,21 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
return replaceInstUsesWith(I, Mul);
}
+ // smax(X, Y) nsw/nuw - smin(X, Y) --> abs(X nsw - Y)
+ if (match(Op0, m_OneUse(m_c_SMax(m_Value(X), m_Value(Y)))) &&
+ match(Op1, m_OneUse(m_c_SMin(m_Specific(X), m_Specific(Y))))) {
+ if (I.hasNoUnsignedWrap() || I.hasNoSignedWrap()) {
+ Value *Sub =
+ Builder.CreateSub(X, Y, "sub", /*HasNUW=*/false, /*HasNSW=*/true);
+ Value *Call =
+ Builder.CreateBinaryIntrinsic(Intrinsic::abs, Sub, Builder.getTrue());
+ return replaceInstUsesWith(I, Call);
+ }
+ }
+
+ if (Instruction *Res = foldBinOpOfSelectAndCastOfSelectCondition(I))
+ return Res;
+
return TryToNarrowDeduceFlags();
}
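The nsw/nuw requirement on the outer sub guarantees X - Y cannot wrap, which is why the builder can emit the sub with nsw and call llvm.abs with the int-min-is-poison flag set (the Builder.getTrue() argument). The underlying identity, spot-checked in standalone C++ over a wrap-free range:

    #include <algorithm>
    #include <cassert>
    #include <cstdlib>

    int main() {
      for (int X = -8; X <= 8; ++X)
        for (int Y = -8; Y <= 8; ++Y)
          assert(std::max(X, Y) - std::min(X, Y) == std::abs(X - Y));
    }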
@@ -2567,7 +2651,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
// Note that if this fsub was really an fneg, the fadd with -0.0 will get
// killed later. We still limit that particular transform with 'hasOneUse'
// because an fneg is assumed better/cheaper than a generic fsub.
- if (I.hasNoSignedZeros() || CannotBeNegativeZero(Op0, SQ.TLI)) {
+ if (I.hasNoSignedZeros() || cannotBeNegativeZero(Op0, SQ.DL, SQ.TLI)) {
if (match(Op1, m_OneUse(m_FSub(m_Value(X), m_Value(Y))))) {
Value *NewSub = Builder.CreateFSubFMF(Y, X, &I);
return BinaryOperator::CreateFAddFMF(Op0, NewSub, &I);
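The cannotBeNegativeZero guard (here just renamed and given the DataLayout) protects against the one case where commuting the inner sub changes results: a -0.0 on the left turns into +0.0. A standalone C++ illustration:

    #include <cassert>
    #include <cmath>

    int main() {
      double Op0 = -0.0, X = 1.0, Y = 1.0;
      assert(std::signbit(Op0 - (X - Y)));   // -0.0 - (+0.0) == -0.0
      assert(!std::signbit(Op0 + (Y - X)));  // -0.0 + (+0.0) == +0.0
    }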