about summary refs log tree commit diff
path: root/lib/Transforms/InstCombine/InstCombineAddSub.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineAddSub.cpp')
-rw-r--r--  lib/Transforms/InstCombine/InstCombineAddSub.cpp | 268
1 file changed, 244 insertions(+), 24 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index ba15b023f2a3..8bc34825f8a7 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1097,6 +1097,107 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+Instruction *
+InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
+ BinaryOperator &I) {
+ assert((I.getOpcode() == Instruction::Add ||
+ I.getOpcode() == Instruction::Or ||
+ I.getOpcode() == Instruction::Sub) &&
+ "Expecting add/or/sub instruction");
+
+ // We have a subtraction/addition between a (potentially truncated) *logical*
+ // right-shift of X and a "select".
+ Value *X, *Select;
+ Instruction *LowBitsToSkip, *Extract;
+ if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
+ m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
+ m_Instruction(Extract))),
+ m_Value(Select))))
+ return nullptr;
+
+ // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS.
+ if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select)
+ return nullptr;
+
+ Type *XTy = X->getType();
+ bool HadTrunc = I.getType() != XTy;
+
+ // If there was a truncation of extracted value, then we'll need to produce
+ // one extra instruction, so we need to ensure one instruction will go away.
+ if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+ return nullptr;
+
+ // Extraction should extract high NBits bits, with shift amount calculated as:
+ // low bits to skip = shift bitwidth - high bits to extract
+ // The shift amount itself may be extended, and we need to look past zero-ext
+ // when matching NBits, that will matter for matching later.
+ Constant *C;
+ Value *NBits;
+ if (!match(
+ LowBitsToSkip,
+ m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) ||
+ !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+ APInt(C->getType()->getScalarSizeInBits(),
+ X->getType()->getScalarSizeInBits()))))
+ return nullptr;
+
+ // Sign-extending value can be zero-extended if we `sub`tract it,
+ // or sign-extended otherwise.
+ auto SkipExtInMagic = [&I](Value *&V) {
+ if (I.getOpcode() == Instruction::Sub)
+ match(V, m_ZExtOrSelf(m_Value(V)));
+ else
+ match(V, m_SExtOrSelf(m_Value(V)));
+ };
+
+ // Now, finally validate the sign-extending magic.
+ // `select` itself may be appropriately extended, look past that.
+ SkipExtInMagic(Select);
+
+ ICmpInst::Predicate Pred;
+ const APInt *Thr;
+ Value *SignExtendingValue, *Zero;
+ bool ShouldSignext;
+ // It must be a select between two values we will later establish to be a
+ // sign-extending value and a zero constant. The condition guarding the
+ // sign-extension must be based on a sign bit of the same X we had in `lshr`.
+ if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)),
+ m_Value(SignExtendingValue), m_Value(Zero))) ||
+ !isSignBitCheck(Pred, *Thr, ShouldSignext))
+ return nullptr;
+
+ // icmp-select pair is commutative.
+ if (!ShouldSignext)
+ std::swap(SignExtendingValue, Zero);
+
+ // If we should not perform sign-extension then we must add/or/subtract zero.
+ if (!match(Zero, m_Zero()))
+ return nullptr;
+ // Otherwise, it should be some constant, left-shifted by the same NBits we
+ // had in `lshr`. Said left-shift can also be appropriately extended.
+ // Again, we must look past zero-ext when looking for NBits.
+ SkipExtInMagic(SignExtendingValue);
+ Constant *SignExtendingValueBaseConstant;
+ if (!match(SignExtendingValue,
+ m_Shl(m_Constant(SignExtendingValueBaseConstant),
+ m_ZExtOrSelf(m_Specific(NBits)))))
+ return nullptr;
+ // If we `sub`, then the constant should be one, else it should be all-ones.
+ if (I.getOpcode() == Instruction::Sub
+ ? !match(SignExtendingValueBaseConstant, m_One())
+ : !match(SignExtendingValueBaseConstant, m_AllOnes()))
+ return nullptr;
+
+ auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip,
+ Extract->getName() + ".sext");
+ NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness.
+ if (!HadTrunc)
+ return NewAShr;
+
+ Builder.Insert(NewAShr);
+ return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
+}
+
Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1302,12 +1403,32 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
if (Instruction *V = canonicalizeLowbitMask(I, Builder))
return V;
+ if (Instruction *V =
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+ return V;
+
if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I))
return SatAdd;
return Changed ? &I : nullptr;
}
+/// Eliminate an op from a linear interpolation (lerp) pattern.
+static Instruction *factorizeLerp(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X, *Y, *Z;
+ if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y),
+ m_OneUse(m_FSub(m_FPOne(),
+ m_Value(Z))))),
+ m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z))))))
+ return nullptr;
+
+ // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants]
+ Value *XY = Builder.CreateFSubFMF(X, Y, &I);
+ Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I);
+ return BinaryOperator::CreateFAddFMF(Y, MulZ, &I);
+}
+
/// Factor a common operand out of fadd/fsub of fmul/fdiv.
static Instruction *factorizeFAddFSub(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
@@ -1315,6 +1436,10 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I,
I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub");
assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
"FP factorization requires FMF");
+
+ if (Instruction *Lerp = factorizeLerp(I, Builder))
+ return Lerp;
+
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Value *X, *Y, *Z;
bool IsFMul;
@@ -1362,17 +1487,32 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I))
return FoldedFAdd;
- Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
- Value *X;
// (-X) + Y --> Y - X
- if (match(LHS, m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFSubFMF(RHS, X, &I);
- // Y + (-X) --> Y - X
- if (match(RHS, m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFSubFMF(LHS, X, &I);
+ Value *X, *Y;
+ if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y))))
+ return BinaryOperator::CreateFSubFMF(Y, X, &I);
+
+ // Similar to above, but look through fmul/fdiv for the negated term.
+ // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants]
+ Value *Z;
+ if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))),
+ m_Value(Z)))) {
+ Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+ return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+ }
+ // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants]
+ // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants]
+ if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))),
+ m_Value(Z))) ||
+ match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))),
+ m_Value(Z)))) {
+ Value *XY = Builder.CreateFDivFMF(X, Y, &I);
+ return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+ }
// Check for (fadd double (sitofp x), y), see if we can merge this into an
// integer add followed by a promotion.
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
Value *LHSIntVal = LHSConv->getOperand(0);
Type *FPType = LHSConv->getType();
@@ -1631,37 +1771,50 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
const APInt *Op0C;
if (match(Op0, m_APInt(Op0C))) {
- unsigned BitWidth = I.getType()->getScalarSizeInBits();
- // -(X >>u 31) -> (X >>s 31)
- // -(X >>s 31) -> (X >>u 31)
if (Op0C->isNullValue()) {
+ Value *Op1Wide;
+ match(Op1, m_TruncOrSelf(m_Value(Op1Wide)));
+ bool HadTrunc = Op1Wide != Op1;
+ bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse();
+ unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits();
+
Value *X;
const APInt *ShAmt;
- if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
+ // -(X >>u 31) -> (X >>s 31)
+ if (NoTruncOrTruncIsOneUse &&
+ match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
*ShAmt == BitWidth - 1) {
- Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
- return BinaryOperator::CreateAShr(X, ShAmtOp);
+ Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+ Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp);
+ NewShift->copyIRFlags(Op1Wide);
+ if (!HadTrunc)
+ return NewShift;
+ Builder.Insert(NewShift);
+ return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
}
- if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
+ // -(X >>s 31) -> (X >>u 31)
+ if (NoTruncOrTruncIsOneUse &&
+ match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
*ShAmt == BitWidth - 1) {
- Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
- return BinaryOperator::CreateLShr(X, ShAmtOp);
+ Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+ Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp);
+ NewShift->copyIRFlags(Op1Wide);
+ if (!HadTrunc)
+ return NewShift;
+ Builder.Insert(NewShift);
+ return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
}
- if (Op1->hasOneUse()) {
+ if (!HadTrunc && Op1->hasOneUse()) {
Value *LHS, *RHS;
SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
if (SPF == SPF_ABS || SPF == SPF_NABS) {
// This is a negate of an ABS/NABS pattern. Just swap the operands
// of the select.
- SelectInst *SI = cast<SelectInst>(Op1);
- Value *TrueVal = SI->getTrueValue();
- Value *FalseVal = SI->getFalseValue();
- SI->setTrueValue(FalseVal);
- SI->setFalseValue(TrueVal);
+ cast<SelectInst>(Op1)->swapValues();
// Don't swap prof metadata, we didn't change the branch behavior.
- return replaceInstUsesWith(I, SI);
+ return replaceInstUsesWith(I, Op1);
}
}
}
@@ -1686,6 +1839,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return BinaryOperator::CreateNeg(Y);
}
+ // (sub (or A, B) (and A, B)) --> (xor A, B)
+ {
+ Value *A, *B;
+ if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
+ match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
+ return BinaryOperator::CreateXor(A, B);
+ }
+
+ // (sub (and A, B) (or A, B)) --> neg (xor A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse()))
+ return BinaryOperator::CreateNeg(Builder.CreateXor(A, B));
+ }
+
// (sub (or A, B), (xor A, B)) --> (and A, B)
{
Value *A, *B;
@@ -1694,6 +1864,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return BinaryOperator::CreateAnd(A, B);
}
+ // (sub (xor A, B) (or A, B)) --> neg (and A, B)
+ {
+ Value *A, *B;
+ if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+ match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse()))
+ return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B));
+ }
+
{
Value *Y;
// ((X | Y) - X) --> (~X & Y)
@@ -1778,7 +1957,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
std::swap(LHS, RHS);
// LHS is now O above and expected to have at least 2 uses (the min/max)
// NotA is epected to have 2 uses from the min/max and 1 from the sub.
- if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+ if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
!NotA->hasNUsesOrMore(4)) {
// Note: We don't generate the inverse max/min, just create the not of
// it and let other folds do the rest.
@@ -1826,6 +2005,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
return SelectInst::Create(Cmp, Neg, A);
}
+ if (Instruction *V =
+ canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+ return V;
+
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1865,6 +2048,22 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
return nullptr;
}
+static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *FNeg;
+ if (!match(&I, m_FNeg(m_Value(FNeg))))
+ return nullptr;
+
+ Value *X, *Y;
+ if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
+ return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+ if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
+ return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+ return nullptr;
+}
+
Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
Value *Op = I.getOperand(0);
@@ -1882,6 +2081,9 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
return BinaryOperator::CreateFSubFMF(Y, X, &I);
+ if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+ return R;
+
return nullptr;
}
@@ -1903,6 +2105,9 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
if (Instruction *X = foldFNegIntoConstant(I))
return X;
+ if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+ return R;
+
Value *X, *Y;
Constant *C;
@@ -1944,6 +2149,21 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y))))))
return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I);
+ // Similar to above, but look through fmul/fdiv of the negated value:
+ // Op0 - (-X * Y) --> Op0 + (X * Y)
+ // Op0 - (Y * -X) --> Op0 + (X * Y)
+ if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) {
+ Value *FMul = Builder.CreateFMulFMF(X, Y, &I);
+ return BinaryOperator::CreateFAddFMF(Op0, FMul, &I);
+ }
+ // Op0 - (-X / Y) --> Op0 + (X / Y)
+ // Op0 - (X / -Y) --> Op0 + (X / Y)
+ if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) ||
+ match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) {
+ Value *FDiv = Builder.CreateFDivFMF(X, Y, &I);
+ return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I);
+ }
+
// Handle special cases for FSub with selects feeding the operation
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
return replaceInstUsesWith(I, V);