author     Dimitry Andric <dim@FreeBSD.org>    2023-02-11 12:38:04 +0000
committer  Dimitry Andric <dim@FreeBSD.org>    2023-02-11 12:38:11 +0000
commit     e3b557809604d036af6e00c60f012c2025b59a5e (patch)
tree       8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Transforms/InstCombine
parent     08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp           | 326
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp         | 695
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp        |  17
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp            | 303
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp            | 295
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp         | 561
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineInternal.h           | 167
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp  | 267
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp        | 413
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp          |  33
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp              |  49
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp           | 930
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp           | 157
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 178
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp        | 246
-rw-r--r--  llvm/lib/Transforms/InstCombine/InstructionCombining.cpp        | 723
16 files changed, 3646 insertions, 1714 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 4a459ec6c550..b68efc993723 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -576,8 +576,7 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
}
}
- assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) &&
- "out-of-bound access");
+ assert((NextTmpIdx <= std::size(TmpResult) + 1) && "out-of-bound access");
Value *Result;
if (!SimpVect.empty())
@@ -849,6 +848,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1);
+ Type *Ty = Add.getType();
Constant *Op1C;
if (!match(Op1, m_ImmConstant(Op1C)))
return nullptr;
@@ -883,7 +883,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
if (match(Op0, m_Not(m_Value(X))))
return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X);
+ // (iN X s>> (N - 1)) + 1 --> zext (X > -1)
const APInt *C;
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+ if (match(Op0, m_OneUse(m_AShr(m_Value(X),
+ m_SpecificIntAllowUndef(BitWidth - 1)))) &&
+ match(Op1, m_One()))
+ return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
+
if (!match(Op1, m_APInt(C)))
return nullptr;
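
A minimal standalone check of the identity behind the new fold above ((iN X s>> (N - 1)) + 1 --> zext (X > -1)), sketched in plain C++ rather than IR; it assumes 32-bit int and an arithmetic right shift for signed values, and the helper names are illustrative only:

#include <cassert>
#include <cstdint>

// (x >> 31) + 1 is 1 for non-negative x and 0 for negative x,
// which is exactly zext(x > -1).
static int addOfSignShift(int32_t x) { return (x >> 31) + 1; }
static int zextIsNotNeg(int32_t x) { return x > -1 ? 1 : 0; }

int main() {
  for (int32_t x : {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX})
    assert(addOfSignShift(x) == zextIsNotNeg(x));
}
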
@@ -911,7 +918,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
// Is this add the last step in a convoluted sext?
// add(zext(xor i16 X, -32768), -32768) --> sext X
- Type *Ty = Add.getType();
if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) &&
C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C)
return CastInst::Create(Instruction::SExt, X, Ty);
@@ -969,15 +975,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
}
}
- // If all bits affected by the add are included in a high-bit-mask, do the
- // add before the mask op:
- // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00
- if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) &&
- C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) {
- Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C));
- return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2));
- }
-
return nullptr;
}
@@ -1132,6 +1129,35 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
return nullptr;
}
+/// Try to reduce signed division by power-of-2 to an arithmetic shift right.
+static Instruction *foldAddToAshr(BinaryOperator &Add) {
+ // Division must be by power-of-2, but not the minimum signed value.
+ Value *X;
+ const APInt *DivC;
+ if (!match(Add.getOperand(0), m_SDiv(m_Value(X), m_Power2(DivC))) ||
+ DivC->isNegative())
+ return nullptr;
+
+ // Rounding is done by adding -1 if the dividend (X) is negative and has any
+ // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
+ // sext (icmp ugt (X & (DivC - 1)), SMIN)
+ const APInt *MaskC;
+ ICmpInst::Predicate Pred;
+ if (!match(Add.getOperand(1),
+ m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
+ m_SignMask()))) ||
+ Pred != ICmpInst::ICMP_UGT)
+ return nullptr;
+
+ APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
+ if (*MaskC != (SMin | (*DivC - 1)))
+ return nullptr;
+
+ // (X / DivC) + sext ((X & (SMin | (DivC - 1))) >u SMin) --> X >>s log2(DivC)
+ return BinaryOperator::CreateAShr(
+ X, ConstantInt::get(Add.getType(), DivC->exactLogBase2()));
+}
+
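
A quick C++ sanity check of the rounding argument used by foldAddToAshr, assuming 32-bit int, truncating sdiv, and an arithmetic right shift; the function names are made up for illustration:

#include <cassert>
#include <cstdint>

// (x / 4) + sext((x & (SMin | 3)) u> SMin)  ==  x >> 2
static int32_t divPlusCorrection(int32_t x) {
  uint32_t Masked = (uint32_t)x & 0x80000003u;  // SMin | (DivC - 1)
  int32_t Corr = Masked > 0x80000000u ? -1 : 0; // sext of the ugt compare
  return x / 4 + Corr;
}

int main() {
  for (int32_t x : {-9, -8, -7, -1, 0, 1, 7, 8, 9, INT32_MAX})
    assert(divPlusCorrection(x) == (x >> 2));
}
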
Instruction *InstCombinerImpl::
canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
BinaryOperator &I) {
@@ -1234,7 +1260,7 @@ Instruction *InstCombinerImpl::
}
/// This is a specialization of a more general transform from
-/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally
+/// foldUsingDistributiveLaws. If that code can be made to work optimally
/// for multi-use cases or propagating nsw/nuw, then we would not need this.
static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
InstCombiner::BuilderTy &Builder) {
@@ -1270,6 +1296,45 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I,
return NewShl;
}
+/// Reduce a sequence of masked half-width multiplies to a single multiply.
+/// (((XLow * YHigh) + (YLow * XHigh)) << HalfBits) + (XLow * YLow) --> X * Y
+static Instruction *foldBoxMultiply(BinaryOperator &I) {
+ unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ // Skip the odd bitwidth types.
+ if ((BitWidth & 0x1))
+ return nullptr;
+
+ unsigned HalfBits = BitWidth >> 1;
+ APInt HalfMask = APInt::getMaxValue(HalfBits);
+
+ // ResLo = (CrossSum << HalfBits) + (YLo * XLo)
+ Value *XLo, *YLo;
+ Value *CrossSum;
+ if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
+ m_Mul(m_Value(YLo), m_Value(XLo)))))
+ return nullptr;
+
+ // XLo = X & HalfMask
+ // YLo = Y & HalfMask
+ // TODO: Refactor with SimplifyDemandedBits or KnownBits known leading zeros
+ // to enhance robustness
+ Value *X, *Y;
+ if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) ||
+ !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask))))
+ return nullptr;
+
+ // CrossSum = (X' * (Y >> HalfBits)) + (Y' * (X >> HalfBits))
+ // X' can be either X or XLo in the pattern (and the same for Y')
+ if (match(CrossSum,
+ m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)),
+ m_CombineOr(m_Specific(X), m_Specific(XLo))),
+ m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)),
+ m_CombineOr(m_Specific(Y), m_Specific(YLo))))))
+ return BinaryOperator::CreateMul(X, Y);
+
+ return nullptr;
+}
+
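
The identity that foldBoxMultiply recombines can be spot-checked in plain C++ for 32-bit values; the high*high partial product is irrelevant because it only affects bits above the result width. The helper name is illustrative:

#include <cassert>
#include <cstdint>

static uint32_t boxMultiply(uint32_t x, uint32_t y) {
  uint32_t XLo = x & 0xFFFF, XHi = x >> 16;
  uint32_t YLo = y & 0xFFFF, YHi = y >> 16;
  uint32_t CrossSum = XLo * YHi + YLo * XHi; // the two middle products
  return (CrossSum << 16) + XLo * YLo;       // ResLo from the comment above
}

int main() {
  assert(boxMultiply(0xDEADBEEFu, 0x12345678u) == 0xDEADBEEFu * 0x12345678u);
  assert(boxMultiply(65537u, 65535u) == 65537u * 65535u);
}
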
Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1),
I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1286,9 +1351,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return Phi;
// (A*B)+(A*C) -> A*(B+C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
+ if (Instruction *R = foldBoxMultiply(I))
+ return R;
+
if (Instruction *R = factorizeMathWithShlOps(I, Builder))
return R;
@@ -1376,35 +1444,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return BinaryOperator::CreateAnd(A, NewMask);
}
+ // ZExt (B - A) + ZExt(A) --> ZExt(B)
+ if ((match(RHS, m_ZExt(m_Value(A))) &&
+ match(LHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))) ||
+ (match(LHS, m_ZExt(m_Value(A))) &&
+ match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))))
+ return new ZExtInst(B, LHS->getType());
+
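
The nuw flag on the inner sub is what makes the zext fold above safe; a small C++ illustration under the assumption b >= a, so the 32-bit subtraction cannot wrap:

#include <cassert>
#include <cstdint>

int main() {
  uint32_t a = 1000, b = 123456;                  // b >= a, i.e. "sub nuw"
  uint64_t Sum = (uint64_t)(b - a) + (uint64_t)a; // zext(b - a) + zext(a)
  assert(Sum == (uint64_t)b);                     // == zext(b)
}
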
// A+B --> A|B iff A and B have no bits set in common.
if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
return BinaryOperator::CreateOr(LHS, RHS);
- // add (select X 0 (sub n A)) A --> select X A n
- {
- SelectInst *SI = dyn_cast<SelectInst>(LHS);
- Value *A = RHS;
- if (!SI) {
- SI = dyn_cast<SelectInst>(RHS);
- A = LHS;
- }
- if (SI && SI->hasOneUse()) {
- Value *TV = SI->getTrueValue();
- Value *FV = SI->getFalseValue();
- Value *N;
-
- // Can we fold the add into the argument of the select?
- // We check both true and false select arguments for a matching subtract.
- if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A))))
- // Fold the add into the true select value.
- return SelectInst::Create(SI->getCondition(), N, A);
-
- if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A))))
- // Fold the add into the false select value.
- return SelectInst::Create(SI->getCondition(), A, N);
- }
- }
-
if (Instruction *Ext = narrowMathIfNoOverflow(I))
return Ext;
@@ -1424,6 +1474,68 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
return &I;
}
+ // (add A (or A, -A)) --> (and (add A, -1) A)
+ // (add A (or -A, A)) --> (and (add A, -1) A)
+ // (add (or A, -A) A) --> (and (add A, -1) A)
+ // (add (or -A, A) A) --> (and (add A, -1) A)
+ if (match(&I, m_c_BinOp(m_Value(A), m_OneUse(m_c_Or(m_Neg(m_Deferred(A)),
+ m_Deferred(A)))))) {
+ Value *Add =
+ Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()), "",
+ I.hasNoUnsignedWrap(), I.hasNoSignedWrap());
+ return BinaryOperator::CreateAnd(Add, A);
+ }
+
+ // Canonicalize ((A & -A) - 1) --> ((A - 1) & ~A)
+ // Forms all commutable operations, and simplifies ctpop -> cttz folds.
+ if (match(&I,
+ m_Add(m_OneUse(m_c_And(m_Value(A), m_OneUse(m_Neg(m_Deferred(A))))),
+ m_AllOnes()))) {
+ Constant *AllOnes = ConstantInt::getAllOnesValue(RHS->getType());
+ Value *Dec = Builder.CreateAdd(A, AllOnes);
+ Value *Not = Builder.CreateXor(A, AllOnes);
+ return BinaryOperator::CreateAnd(Dec, Not);
+ }
+
+ // Disguised reassociation/factorization:
+ // ~(A * C1) + A
+ // ((A * -C1) - 1) + A
+ // ((A * -C1) + A) - 1
+ // (A * (1 - C1)) - 1
+ if (match(&I,
+ m_c_Add(m_OneUse(m_Not(m_OneUse(m_Mul(m_Value(A), m_APInt(C1))))),
+ m_Deferred(A)))) {
+ Type *Ty = I.getType();
+ Constant *NewMulC = ConstantInt::get(Ty, 1 - *C1);
+ Value *NewMul = Builder.CreateMul(A, NewMulC);
+ return BinaryOperator::CreateAdd(NewMul, ConstantInt::getAllOnesValue(Ty));
+ }
+
+ // (A * -2**C) + B --> B - (A << C)
+ const APInt *NegPow2C;
+ if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))),
+ m_Value(B)))) {
+ Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros());
+ Value *Shl = Builder.CreateShl(A, ShiftAmtC);
+ return BinaryOperator::CreateSub(B, Shl);
+ }
+
+ // Canonicalize signum variant that ends in add:
+ // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0))
+ ICmpInst::Predicate Pred;
+ uint64_t BitWidth = Ty->getScalarSizeInBits();
+ if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) &&
+ match(RHS, m_OneUse(m_ZExt(
+ m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) &&
+ Pred == CmpInst::ICMP_SGT) {
+ Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull");
+ Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext");
+ return BinaryOperator::CreateOr(LHS, Zext);
+ }
+
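
A short C++ check for the signum canonicalization just above, assuming 32-bit int and an arithmetic right shift: replacing the trailing add with an or (and the s> 0 compare with != 0) preserves the value:

#include <cassert>
#include <cstdint>

static int32_t signumAdd(int32_t a) { return (a >> 31) + (a > 0 ? 1 : 0); }
static int32_t signumOr(int32_t a)  { return (a >> 31) | (a != 0 ? 1 : 0); }

int main() {
  for (int32_t a : {INT32_MIN, -5, -1, 0, 1, 5, INT32_MAX})
    assert(signumAdd(a) == signumOr(a));
}
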
+ if (Instruction *Ashr = foldAddToAshr(I))
+ return Ashr;
+
// TODO(jingyue): Consider willNotOverflowSignedAdd and
// willNotOverflowUnsignedAdd to reduce the number of invocations of
// computeKnownBits.
@@ -1665,6 +1777,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
return BinaryOperator::CreateFMulFMF(X, NewMulC, &I);
}
+ // (-X - Y) + (X + Z) --> Z - Y
+ if (match(&I, m_c_FAdd(m_FSub(m_FNeg(m_Value(X)), m_Value(Y)),
+ m_c_FAdd(m_Deferred(X), m_Value(Z)))))
+ return BinaryOperator::CreateFSubFMF(Z, Y, &I);
+
if (Value *V = FAddCombine(Builder).simplify(&I))
return replaceInstUsesWith(I, V);
}
@@ -1879,7 +1996,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
return TryToNarrowDeduceFlags(); // Should have been handled in Negator!
// (A*B)-(A*C) -> A*(B-C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
if (I.getType()->isIntOrIntVectorTy(1))
@@ -1967,12 +2084,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
const APInt *Op0C;
- if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) {
- // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
- // zero.
- KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
- if ((*Op0C | RHSKnown.Zero).isAllOnes())
- return BinaryOperator::CreateXor(Op1, Op0);
+ if (match(Op0, m_APInt(Op0C))) {
+ if (Op0C->isMask()) {
+ // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known
+ // zero.
+ KnownBits RHSKnown = computeKnownBits(Op1, 0, &I);
+ if ((*Op0C | RHSKnown.Zero).isAllOnes())
+ return BinaryOperator::CreateXor(Op1, Op0);
+ }
+
+ // C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when:
+ // (C3 - ((C2 & C3) - 1)) is pow2
+ // ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1)
+ // C2 is negative pow2 || sub nuw
+ const APInt *C2, *C3;
+ BinaryOperator *InnerSub;
+ if (match(Op1, m_OneUse(m_And(m_BinOp(InnerSub), m_APInt(C2)))) &&
+ match(InnerSub, m_Sub(m_APInt(C3), m_Value(X))) &&
+ (InnerSub->hasNoUnsignedWrap() || C2->isNegatedPowerOf2())) {
+ APInt C2AndC3 = *C2 & *C3;
+ APInt C2AndC3Minus1 = C2AndC3 - 1;
+ APInt C2AddC3 = *C2 + *C3;
+ if ((*C3 - C2AndC3Minus1).isPowerOf2() &&
+ C2AndC3Minus1.isSubsetOf(C2AddC3)) {
+ Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), *C2));
+ return BinaryOperator::CreateAdd(
+ And, ConstantInt::get(I.getType(), *Op0C - C2AndC3));
+ }
+ }
}
{
@@ -2165,8 +2304,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
Value *A;
const APInt *ShAmt;
Type *Ty = I.getType();
+ unsigned BitWidth = Ty->getScalarSizeInBits();
if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) &&
- Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 &&
+ Op1->hasNUses(2) && *ShAmt == BitWidth - 1 &&
match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) {
// B = ashr i32 A, 31 ; smear the sign bit
// sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1)
@@ -2185,7 +2325,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
const APInt *AddC, *AndC;
if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) &&
match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) {
- unsigned BitWidth = Ty->getScalarSizeInBits();
unsigned Cttz = AddC->countTrailingZeros();
APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz));
if ((HighMask & *AndC).isZero())
@@ -2227,18 +2366,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
}
// C - ctpop(X) => ctpop(~X) if C is bitwidth
- if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) &&
+ if (match(Op0, m_SpecificInt(BitWidth)) &&
match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X)))))
return replaceInstUsesWith(
I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
{Builder.CreateNot(X)}));
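
The ctpop fold above is the counting identity popcount(x) + popcount(~x) == bitwidth; a one-loop C++ check, using the GCC/Clang builtin (assumed available):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x : {0u, 1u, 0xF0F0F0F0u, 0xFFFFFFFFu})
    assert(32 - __builtin_popcount(x) == __builtin_popcount(~x));
}
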
+ // Reduce multiplies for difference-of-squares by factoring:
+ // (X * X) - (Y * Y) --> (X + Y) * (X - Y)
+ if (match(Op0, m_OneUse(m_Mul(m_Value(X), m_Deferred(X)))) &&
+ match(Op1, m_OneUse(m_Mul(m_Value(Y), m_Deferred(Y))))) {
+ auto *OBO0 = cast<OverflowingBinaryOperator>(Op0);
+ auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
+ bool PropagateNSW = I.hasNoSignedWrap() && OBO0->hasNoSignedWrap() &&
+ OBO1->hasNoSignedWrap() && BitWidth > 2;
+ bool PropagateNUW = I.hasNoUnsignedWrap() && OBO0->hasNoUnsignedWrap() &&
+ OBO1->hasNoUnsignedWrap() && BitWidth > 1;
+ Value *Add = Builder.CreateAdd(X, Y, "add", PropagateNUW, PropagateNSW);
+ Value *Sub = Builder.CreateSub(X, Y, "sub", PropagateNUW, PropagateNSW);
+ Value *Mul = Builder.CreateMul(Add, Sub, "", PropagateNUW, PropagateNSW);
+ return replaceInstUsesWith(I, Mul);
+ }
+
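
The difference-of-squares rewrite above is the usual factorization; it can be spot-checked with wrapping unsigned arithmetic in C++ (the nsw/nuw conditions in the code only govern which flags may be propagated, not the value):

#include <cassert>
#include <cstdint>

int main() {
  uint32_t x = 0xDEADBEEFu, y = 0x1337u;
  assert(x * x - y * y == (x + y) * (x - y)); // holds modulo 2^32
}
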
return TryToNarrowDeduceFlags();
}
/// This eliminates floating-point negation in either 'fneg(X)' or
/// 'fsub(-0.0, X)' form by combining into a constant operand.
-static Instruction *foldFNegIntoConstant(Instruction &I) {
+static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
// This is limited with one-use because fneg is assumed better for
// reassociation and cheaper in codegen than fmul/fdiv.
// TODO: Should the m_OneUse restriction be removed?
@@ -2252,28 +2407,31 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
// Fold negation into constant operand.
// -(X * C) --> X * (-C)
if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFMulFMF(X, NegC, &I);
// -(X / C) --> X / (-C)
if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFDivFMF(X, NegC, &I);
// -(C / X) --> (-C) / X
- if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) {
- Instruction *FDiv =
- BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
+ if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X))))
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) {
+ Instruction *FDiv = BinaryOperator::CreateFDivFMF(NegC, X, &I);
- // Intersect 'nsz' and 'ninf' because those special value exceptions may not
- // apply to the fdiv. Everything else propagates from the fneg.
- // TODO: We could propagate nsz/ninf from fdiv alone?
- FastMathFlags FMF = I.getFastMathFlags();
- FastMathFlags OpFMF = FNegOp->getFastMathFlags();
- FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
- FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
- return FDiv;
- }
+ // Intersect 'nsz' and 'ninf' because those special value exceptions may
+ // not apply to the fdiv. Everything else propagates from the fneg.
+ // TODO: We could propagate nsz/ninf from fdiv alone?
+ FastMathFlags FMF = I.getFastMathFlags();
+ FastMathFlags OpFMF = FNegOp->getFastMathFlags();
+ FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros());
+ FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs());
+ return FDiv;
+ }
// With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]:
// -(X + C) --> -X + -C --> -C - X
if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C))))
- return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFSubFMF(NegC, X, &I);
return nullptr;
}
@@ -2301,7 +2459,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
getSimplifyQuery().getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
- if (Instruction *X = foldFNegIntoConstant(I))
+ if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
Value *X, *Y;
@@ -2314,18 +2472,26 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
return R;
+ Value *OneUse;
+ if (!match(Op, m_OneUse(m_Value(OneUse))))
+ return nullptr;
+
// Try to eliminate fneg if at least 1 arm of the select is negated.
Value *Cond;
- if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) {
+ if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
// Unlike most transforms, this one is not safe to propagate nsz unless
- // it is present on the original select. (We are conservatively intersecting
- // the nsz flags from the select and root fneg instruction.)
+ // it is present on the original select. We union the flags from the select
+ // and fneg and then remove nsz if needed.
auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
S->copyFastMathFlags(&I);
- if (auto *OldSel = dyn_cast<SelectInst>(Op))
+ if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
+ FastMathFlags FMF = I.getFastMathFlags();
+ FMF |= OldSel->getFastMathFlags();
+ S->setFastMathFlags(FMF);
if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
!isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
S->setHasNoSignedZeros(false);
+ }
};
// -(Cond ? -P : Y) --> Cond ? P : -Y
Value *P;
@@ -2344,6 +2510,21 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
}
}
+ // fneg (copysign x, y) -> copysign x, (fneg y)
+ if (match(OneUse, m_CopySign(m_Value(X), m_Value(Y)))) {
+ // The source copysign has an additional value input, so we can't propagate
+ // flags the copysign doesn't also have.
+ FastMathFlags FMF = I.getFastMathFlags();
+ FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags();
+
+ IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+ Builder.setFastMathFlags(FMF);
+
+ Value *NegY = Builder.CreateFNeg(Y);
+ Value *NewCopySign = Builder.CreateCopySign(X, NegY);
+ return replaceInstUsesWith(I, NewCopySign);
+ }
+
return nullptr;
}
@@ -2370,7 +2551,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
if (match(&I, m_FNeg(m_Value(Op))))
return UnaryOperator::CreateFNegFMF(Op, &I);
- if (Instruction *X = foldFNegIntoConstant(I))
+ if (Instruction *X = foldFNegIntoConstant(I, DL))
return X;
if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
@@ -2409,7 +2590,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
// But don't transform constant expressions because there's an inverse fold
// for X + (-Y) --> X - Y.
if (match(Op1, m_ImmConstant(C)))
- return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFAddFMF(Op0, NegC, &I);
// X - (-Y) --> X + Y
if (match(Op1, m_FNeg(m_Value(Y))))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8253c575bc37..97a001b2ed32 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -233,17 +233,13 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre
/// the right hand side as a pair.
/// LHS and RHS are the left hand side and the right hand side ICmps and PredL
/// and PredR are their predicates, respectively.
-static
-Optional<std::pair<unsigned, unsigned>>
-getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
- Value *&D, Value *&E, ICmpInst *LHS,
- ICmpInst *RHS,
- ICmpInst::Predicate &PredL,
- ICmpInst::Predicate &PredR) {
+static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair(
+ Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS,
+ ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) {
// Don't allow pointers. Splat vectors are fine.
if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() ||
!RHS->getOperand(0)->getType()->isIntOrIntVectorTy())
- return None;
+ return std::nullopt;
// Here comes the tricky part:
// LHS might be of the form L11 & L12 == X, X == L21 & L22,
@@ -274,7 +270,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
// Bail if LHS was a icmp that can't be decomposed into an equality.
if (!ICmpInst::isEquality(PredL))
- return None;
+ return std::nullopt;
Value *R1 = RHS->getOperand(0);
Value *R2 = RHS->getOperand(1);
@@ -288,7 +284,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
A = R12;
D = R11;
} else {
- return None;
+ return std::nullopt;
}
E = R2;
R1 = nullptr;
@@ -316,7 +312,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
// Bail if RHS was a icmp that can't be decomposed into an equality.
if (!ICmpInst::isEquality(PredR))
- return None;
+ return std::nullopt;
// Look for ANDs on the right side of the RHS icmp.
if (!Ok) {
@@ -336,7 +332,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
E = R1;
Ok = true;
} else {
- return None;
+ return std::nullopt;
}
assert(Ok && "Failed to find AND on the right side of the RHS icmp.");
@@ -358,7 +354,8 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C,
unsigned LeftType = getMaskedICmpType(A, B, C, PredL);
unsigned RightType = getMaskedICmpType(A, D, E, PredR);
- return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType));
+ return std::optional<std::pair<unsigned, unsigned>>(
+ std::make_pair(LeftType, RightType));
}
/// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single
@@ -526,7 +523,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
InstCombiner::BuilderTy &Builder) {
Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr;
ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
- Optional<std::pair<unsigned, unsigned>> MaskPair =
+ std::optional<std::pair<unsigned, unsigned>> MaskPair =
getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR);
if (!MaskPair)
return nullptr;
@@ -1016,10 +1013,10 @@ struct IntPart {
};
/// Match an extraction of bits from an integer.
-static Optional<IntPart> matchIntPart(Value *V) {
+static std::optional<IntPart> matchIntPart(Value *V) {
Value *X;
if (!match(V, m_OneUse(m_Trunc(m_Value(X)))))
- return None;
+ return std::nullopt;
unsigned NumOriginalBits = X->getType()->getScalarSizeInBits();
unsigned NumExtractedBits = V->getType()->getScalarSizeInBits();
@@ -1056,10 +1053,10 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
return nullptr;
- Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
- Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
- Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
- Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+ std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
+ std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
+ std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
+ std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
if (!L0 || !R0 || !L1 || !R1)
return nullptr;
@@ -1094,7 +1091,7 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
/// common operand with the constant. Callers are expected to call this with
/// Cmp0/Cmp1 switched to handle logic op commutativity.
static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
- bool IsAnd,
+ bool IsAnd, bool IsLogical,
InstCombiner::BuilderTy &Builder,
const SimplifyQuery &Q) {
// Match an equality compare with a non-poison constant as Cmp0.
@@ -1130,6 +1127,9 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
return nullptr;
SubstituteCmp = Builder.CreateICmp(Pred1, Y, C);
}
+ if (IsLogical)
+ return IsAnd ? Builder.CreateLogicalAnd(Cmp0, SubstituteCmp)
+ : Builder.CreateLogicalOr(Cmp0, SubstituteCmp);
return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0,
SubstituteCmp);
}
@@ -1174,7 +1174,7 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
Type *Ty = V1->getType();
Value *NewV = V1;
- Optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
+ std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2);
if (!CR) {
if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() ||
CR2.isWrappedSet())
@@ -1205,6 +1205,47 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1,
return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC));
}
+/// Ignore all operations which only change the sign of a value, returning the
+/// underlying magnitude value.
+static Value *stripSignOnlyFPOps(Value *Val) {
+ match(Val, m_FNeg(m_Value(Val)));
+ match(Val, m_FAbs(m_Value(Val)));
+ match(Val, m_CopySign(m_Value(Val), m_Value()));
+ return Val;
+}
+
+/// Matches canonical form of isnan, fcmp ord x, 0
+static bool matchIsNotNaN(FCmpInst::Predicate P, Value *LHS, Value *RHS) {
+ return P == FCmpInst::FCMP_ORD && match(RHS, m_AnyZeroFP());
+}
+
+/// Matches fcmp u__ x, +/-inf
+static bool matchUnorderedInfCompare(FCmpInst::Predicate P, Value *LHS,
+ Value *RHS) {
+ return FCmpInst::isUnordered(P) && match(RHS, m_Inf());
+}
+
+/// and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf
+///
+/// Clang emits this pattern for doing an isfinite check in __builtin_isnormal.
+static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS,
+ FCmpInst *RHS) {
+ Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
+ Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1);
+ FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate();
+
+ if (!matchIsNotNaN(PredL, LHS0, LHS1) ||
+ !matchUnorderedInfCompare(PredR, RHS0, RHS1))
+ return nullptr;
+
+ IRBuilder<>::FastMathFlagGuard FMFG(Builder);
+ FastMathFlags FMF = LHS->getFastMathFlags();
+ FMF &= RHS->getFastMathFlags();
+ Builder.setFastMathFlags(FMF);
+
+ return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1);
+}
+
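
A C++ illustration of why the isfinite pattern above collapses to a single ordered compare: the ordered not-NaN test AND the unordered less-than-infinity test equal a plain ordered less-than, since comparisons against NaN are false:

#include <cassert>
#include <cmath>
#include <limits>

int main() {
  const double Inf = std::numeric_limits<double>::infinity();
  for (double x : {0.0, -3.5, Inf, -Inf, std::nan("")}) {
    // and (fcmp ord x, 0), (fcmp ult x, inf)
    bool Input = !std::isnan(x) && (std::isnan(x) || x < Inf);
    // fcmp olt x, inf -- false for NaN, like any ordered compare
    bool Output = x < Inf;
    assert(Input == Output);
  }
}
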
Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
bool IsAnd, bool IsLogicalSelect) {
Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1);
@@ -1263,9 +1304,79 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS,
return Builder.CreateFCmp(PredL, LHS0, RHS0);
}
+ if (IsAnd && stripSignOnlyFPOps(LHS0) == stripSignOnlyFPOps(RHS0)) {
+ // and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf
+ // and (fcmp ord x, 0), (fcmp u* fabs(x), inf) -> fcmp o* x, inf
+ if (Value *Left = matchIsFiniteTest(Builder, LHS, RHS))
+ return Left;
+ if (Value *Right = matchIsFiniteTest(Builder, RHS, LHS))
+ return Right;
+ }
+
return nullptr;
}
+/// or (is_fpclass x, mask0), (is_fpclass x, mask1)
+/// -> is_fpclass x, (mask0 | mask1)
+/// and (is_fpclass x, mask0), (is_fpclass x, mask1)
+/// -> is_fpclass x, (mask0 & mask1)
+/// xor (is_fpclass x, mask0), (is_fpclass x, mask1)
+/// -> is_fpclass x, (mask0 ^ mask1)
+Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO,
+ Value *Op0, Value *Op1) {
+ Value *ClassVal;
+ uint64_t ClassMask0, ClassMask1;
+
+ if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>(
+ m_Value(ClassVal), m_ConstantInt(ClassMask0)))) &&
+ match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>(
+ m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) {
+ unsigned NewClassMask;
+ switch (BO.getOpcode()) {
+ case Instruction::And:
+ NewClassMask = ClassMask0 & ClassMask1;
+ break;
+ case Instruction::Or:
+ NewClassMask = ClassMask0 | ClassMask1;
+ break;
+ case Instruction::Xor:
+ NewClassMask = ClassMask0 ^ ClassMask1;
+ break;
+ default:
+ llvm_unreachable("not a binary logic operator");
+ }
+
+ // TODO: Also check for special fcmps
+ auto *II = cast<IntrinsicInst>(Op0);
+ II->setArgOperand(
+ 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask));
+ return replaceInstUsesWith(BO, II);
+ }
+
+ return nullptr;
+}
+
+/// Look for the pattern that conditionally negates a value via math operations:
+/// cond.splat = sext i1 cond
+/// sub = add cond.splat, x
+/// xor = xor sub, cond.splat
+/// and rewrite it to do the same, but via logical operations:
+/// value.neg = sub 0, value
+/// cond = select i1 neg, value.neg, value
+Instruction *InstCombinerImpl::canonicalizeConditionalNegationViaMathToSelect(
+ BinaryOperator &I) {
+ assert(I.getOpcode() == BinaryOperator::Xor && "Only for xor!");
+ Value *Cond, *X;
+ // As per complexity ordering, `xor` is not commutative here.
+ if (!match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())) ||
+ !match(I.getOperand(1), m_SExt(m_Value(Cond))) ||
+ !Cond->getType()->isIntOrIntVectorTy(1) ||
+ !match(I.getOperand(0), m_c_Add(m_SExt(m_Deferred(Cond)), m_Value(X))))
+ return nullptr;
+ return SelectInst::Create(Cond, Builder.CreateNeg(X, X->getName() + ".neg"),
+ X);
+}
+
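
The conditional-negation-via-math pattern described above can be spot-checked in C++: with the sign-extended condition M, (x + M) ^ M equals cond ? -x : x under two's-complement arithmetic. Names are illustrative:

#include <cassert>
#include <cstdint>

static int32_t negViaMath(int32_t x, bool Cond) {
  int32_t M = Cond ? -1 : 0; // sext i1 cond
  return (x + M) ^ M;        // add with the splat, then xor with it
}

int main() {
  for (int32_t x : {-7, -1, 0, 1, 42})
    for (bool Cond : {false, true})
      assert(negViaMath(x, Cond) == (Cond ? -x : x));
}
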
/// This is a limited reassociation for a special case (see above) where we are
/// checking if two values are either both NAN (unordered) or not-NAN (ordered).
/// This could be handled more generally in '-reassociation', but it seems like
@@ -1430,11 +1541,33 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
if (!Cast1)
return nullptr;
- // Both operands of the logic operation are casts. The casts must be of the
- // same type for reduction.
- auto CastOpcode = Cast0->getOpcode();
- if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy())
+ // Both operands of the logic operation are casts. The casts must be the
+ // same kind for reduction.
+ Instruction::CastOps CastOpcode = Cast0->getOpcode();
+ if (CastOpcode != Cast1->getOpcode())
+ return nullptr;
+
+ // If the source types do not match, but the casts are matching extends, we
+ // can still narrow the logic op.
+ if (SrcTy != Cast1->getSrcTy()) {
+ Value *X, *Y;
+ if (match(Cast0, m_OneUse(m_ZExtOrSExt(m_Value(X)))) &&
+ match(Cast1, m_OneUse(m_ZExtOrSExt(m_Value(Y))))) {
+ // Cast the narrower source to the wider source type.
+ unsigned XNumBits = X->getType()->getScalarSizeInBits();
+ unsigned YNumBits = Y->getType()->getScalarSizeInBits();
+ if (XNumBits < YNumBits)
+ X = Builder.CreateCast(CastOpcode, X, Y->getType());
+ else
+ Y = Builder.CreateCast(CastOpcode, Y, X->getType());
+ // Do the logic op in the intermediate width, then widen more.
+ Value *NarrowLogic = Builder.CreateBinOp(LogicOpc, X, Y);
+ return CastInst::Create(CastOpcode, NarrowLogic, DestTy);
+ }
+
+ // Give up for other cast opcodes.
return nullptr;
+ }
Value *Cast0Src = Cast0->getOperand(0);
Value *Cast1Src = Cast1->getOperand(0);
@@ -1722,6 +1855,77 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I,
return nullptr;
}
+/// Try to reassociate a pair of binops so that values with one use only are
+/// part of the same instruction. This may enable folds that are limited with
+/// multi-use restrictions and makes it more likely to match other patterns that
+/// are looking for a common operand.
+static Instruction *reassociateForUses(BinaryOperator &BO,
+ InstCombinerImpl::BuilderTy &Builder) {
+ Instruction::BinaryOps Opcode = BO.getOpcode();
+ Value *X, *Y, *Z;
+ if (match(&BO,
+ m_c_BinOp(Opcode, m_OneUse(m_BinOp(Opcode, m_Value(X), m_Value(Y))),
+ m_OneUse(m_Value(Z))))) {
+ if (!isa<Constant>(X) && !isa<Constant>(Y) && !isa<Constant>(Z)) {
+ // (X op Y) op Z --> (Y op Z) op X
+ if (!X->hasOneUse()) {
+ Value *YZ = Builder.CreateBinOp(Opcode, Y, Z);
+ return BinaryOperator::Create(Opcode, YZ, X);
+ }
+ // (X op Y) op Z --> (X op Z) op Y
+ if (!Y->hasOneUse()) {
+ Value *XZ = Builder.CreateBinOp(Opcode, X, Z);
+ return BinaryOperator::Create(Opcode, XZ, Y);
+ }
+ }
+ }
+
+ return nullptr;
+}
+
+// Match
+// (X + C2) | C
+// (X + C2) ^ C
+// (X + C2) & C
+// and convert to do the bitwise logic first:
+// (X | C) + C2
+// (X ^ C) + C2
+// (X & C) + C2
+// iff bits affected by logic op are lower than last bit affected by math op
+static Instruction *canonicalizeLogicFirst(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Type *Ty = I.getType();
+ Instruction::BinaryOps OpC = I.getOpcode();
+ Value *Op0 = I.getOperand(0);
+ Value *Op1 = I.getOperand(1);
+ Value *X;
+ const APInt *C, *C2;
+
+ if (!(match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C2)))) &&
+ match(Op1, m_APInt(C))))
+ return nullptr;
+
+ unsigned Width = Ty->getScalarSizeInBits();
+ unsigned LastOneMath = Width - C2->countTrailingZeros();
+
+ switch (OpC) {
+ case Instruction::And:
+ if (C->countLeadingOnes() < LastOneMath)
+ return nullptr;
+ break;
+ case Instruction::Xor:
+ case Instruction::Or:
+ if (C->countLeadingZeros() < LastOneMath)
+ return nullptr;
+ break;
+ default:
+ llvm_unreachable("Unexpected BinaryOp!");
+ }
+
+ Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C));
+ return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2));
+}
+
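
A concrete instance of canonicalizeLogicFirst's profitability condition, sketched in C++ with illustrative constants: when every bit the or-constant can set lies below every bit the add can change, the two operations commute:

#include <cassert>
#include <cstdint>

int main() {
  // 16 has 4 trailing zeros and 7 < 2^4, so the affected bit ranges are disjoint.
  for (uint32_t x : {0u, 5u, 0xFFu, 0xFFFFFFF0u, 0xFFFFFFFFu})
    assert(((x + 16) | 7) == ((x | 7) + 16)); // (X + C2) | C --> (X | C) + C2
}
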
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -1754,7 +1958,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
return X;
// (A|B)&(A|C) -> A|(B&C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
if (Value *V = SimplifyBSwap(I, Builder))
@@ -2156,24 +2360,36 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
A->getType()->isIntOrIntVectorTy(1))
return SelectInst::Create(A, Op0, Constant::getNullValue(Ty));
- // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0
- unsigned FullShift = Ty->getScalarSizeInBits() - 1;
- if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))),
- m_Value(Y)))) {
+ // Similarly, a 'not' of the bool translates to a swap of the select arms:
+ // ~sext(A) & Op1 --> A ? 0 : Op1
+ // Op0 & ~sext(A) --> A ? 0 : Op0
+ if (match(Op0, m_Not(m_SExt(m_Value(A)))) &&
+ A->getType()->isIntOrIntVectorTy(1))
+ return SelectInst::Create(A, Constant::getNullValue(Ty), Op1);
+ if (match(Op1, m_Not(m_SExt(m_Value(A)))) &&
+ A->getType()->isIntOrIntVectorTy(1))
+ return SelectInst::Create(A, Constant::getNullValue(Ty), Op0);
+
+ // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 -- with optional sext
+ if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf(
+ m_AShr(m_Value(X), m_APIntAllowUndef(C)))),
+ m_Value(Y))) &&
+ *C == X->getType()->getScalarSizeInBits() - 1) {
Value *IsNeg = Builder.CreateIsNeg(X, "isneg");
return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty));
}
// If there's a 'not' of the shifted value, swap the select operands:
- // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y
- if (match(&I, m_c_And(m_OneUse(m_Not(
- m_AShr(m_Value(X), m_SpecificInt(FullShift)))),
- m_Value(Y)))) {
+ // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y -- with optional sext
+ if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf(
+ m_Not(m_AShr(m_Value(X), m_APIntAllowUndef(C))))),
+ m_Value(Y))) &&
+ *C == X->getType()->getScalarSizeInBits() - 1) {
Value *IsNeg = Builder.CreateIsNeg(X, "isneg");
return SelectInst::Create(IsNeg, ConstantInt::getNullValue(Ty), Y);
}
// (~x) & y --> ~(x | (~y)) iff that gets rid of inversions
- if (sinkNotIntoOtherHandOfAndOrOr(I))
+ if (sinkNotIntoOtherHandOfLogicalOp(I))
return &I;
// An and recurrence w/loop invariant step is equivalent to (and start, step)
@@ -2182,6 +2398,15 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN))
return replaceInstUsesWith(I, Builder.CreateAnd(Start, Step));
+ if (Instruction *R = reassociateForUses(I, Builder))
+ return R;
+
+ if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder))
+ return Canonicalized;
+
+ if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1))
+ return Folded;
+
return nullptr;
}
@@ -2375,7 +2600,9 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) {
/// We have an expression of the form (A & C) | (B & D). If A is a scalar or
/// vector composed of all-zeros or all-ones values and is the bitwise 'not' of
/// B, it can be used as the condition operand of a select instruction.
-Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
+/// We will detect (A & C) | ~(B | D) when the flag ABIsTheSame is enabled.
+Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B,
+ bool ABIsTheSame) {
// We may have peeked through bitcasts in the caller.
// Exit immediately if we don't have (vector) integer types.
Type *Ty = A->getType();
@@ -2383,7 +2610,7 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
return nullptr;
// If A is the 'not' operand of B and has enough signbits, we have our answer.
- if (match(B, m_Not(m_Specific(A)))) {
+ if (ABIsTheSame ? (A == B) : match(B, m_Not(m_Specific(A)))) {
// If these are scalars or vectors of i1, A can be used directly.
if (Ty->isIntOrIntVectorTy(1))
return A;
@@ -2403,6 +2630,10 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
return nullptr;
}
+ // TODO: add support for sext and constant case
+ if (ABIsTheSame)
+ return nullptr;
+
// If both operands are constants, see if the constants are inverse bitmasks.
Constant *AConst, *BConst;
if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
@@ -2451,14 +2682,17 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
/// We have an expression of the form (A & C) | (B & D). Try to simplify this
/// to "A' ? C : D", where A' is a boolean or vector of booleans.
+/// When InvertFalseVal is set to true, we try to match the pattern
+/// where we have peeked through a 'not' op and A and B are the same:
+/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ? C : ~D
Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
- Value *D) {
+ Value *D, bool InvertFalseVal) {
// The potential condition of the select may be bitcasted. In that case, look
// through its bitcast and the corresponding bitcast of the 'not' condition.
Type *OrigType = A->getType();
A = peekThroughBitcast(A, true);
B = peekThroughBitcast(B, true);
- if (Value *Cond = getSelectCondition(A, B)) {
+ if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) {
// ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
// If this is a vector, we may need to cast to match the condition's length.
// The bitcasts will either all exist or all not exist. The builder will
@@ -2469,11 +2703,13 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
unsigned Elts = VecTy->getElementCount().getKnownMinValue();
// For a fixed or scalable vector, get the size in bits of N x iM; for a
// scalar this is just M.
- unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinSize();
+ unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinValue();
Type *EltTy = Builder.getIntNTy(SelEltSize / Elts);
SelTy = VectorType::get(EltTy, VecTy->getElementCount());
}
Value *BitcastC = Builder.CreateBitCast(C, SelTy);
+ if (InvertFalseVal)
+ D = Builder.CreateNot(D);
Value *BitcastD = Builder.CreateBitCast(D, SelTy);
Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
return Builder.CreateBitCast(Select, OrigType);
@@ -2484,8 +2720,9 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
// (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1)
// (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1)
-Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
- IRBuilderBase &Builder) {
+static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS,
+ bool IsAnd, bool IsLogical,
+ IRBuilderBase &Builder) {
ICmpInst::Predicate LPred =
IsAnd ? LHS->getInversePredicate() : LHS->getPredicate();
ICmpInst::Predicate RPred =
@@ -2504,6 +2741,8 @@ Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd,
else
return nullptr;
+ if (IsLogical)
+ Other = Builder.CreateFreeze(Other);
return Builder.CreateICmp(
IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE,
Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())),
@@ -2552,22 +2791,23 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder))
return V;
- // TODO: One of these directions is fine with logical and/or, the other could
- // be supported by inserting freeze.
- if (!IsLogical) {
- if (Value *V = foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, Builder))
- return V;
- if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, Builder))
- return V;
- }
+ if (Value *V =
+ foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder))
+ return V;
+ // We can treat logical like bitwise here, because both operands are used on
+ // the LHS, and as such poison from both will propagate.
+ if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd,
+ /*IsLogical*/ false, Builder))
+ return V;
- // TODO: Verify whether this is safe for logical and/or.
- if (!IsLogical) {
- if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, Builder, Q))
- return V;
- if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, Builder, Q))
- return V;
- }
+ if (Value *V =
+ foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, IsLogical, Builder, Q))
+ return V;
+ // We can convert this case to bitwise and, because both operands are used
+ // on the LHS, and as such poison from both will propagate.
+ if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd,
+ /*IsLogical*/ false, Builder, Q))
+ return V;
if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, IsAnd, Builder))
return V;
@@ -2724,7 +2964,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return X;
// (A&B)|(A&C) -> A&(B|C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
if (Value *V = SimplifyBSwap(I, Builder))
@@ -2777,6 +3017,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
return BinaryOperator::CreateMul(X, IncrementY);
}
+ // X | (X ^ Y) --> X | Y (4 commuted patterns)
+ if (match(&I, m_c_Or(m_Value(X), m_c_Xor(m_Deferred(X), m_Value(Y)))))
+ return BinaryOperator::CreateOr(X, Y);
+
// (A & C) | (B & D)
Value *A, *B, *C, *D;
if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
@@ -2854,6 +3098,20 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
}
}
+ if (match(Op0, m_And(m_Value(A), m_Value(C))) &&
+ match(Op1, m_Not(m_Or(m_Value(B), m_Value(D)))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ // (Cond & C) | ~(Cond | D) -> Cond ? C : ~D
+ if (Value *V = matchSelectFromAndOr(A, C, B, D, true))
+ return replaceInstUsesWith(I, V);
+ if (Value *V = matchSelectFromAndOr(A, C, D, B, true))
+ return replaceInstUsesWith(I, V);
+ if (Value *V = matchSelectFromAndOr(C, A, B, D, true))
+ return replaceInstUsesWith(I, V);
+ if (Value *V = matchSelectFromAndOr(C, A, D, B, true))
+ return replaceInstUsesWith(I, V);
+ }
+
// (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C
if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
@@ -2886,30 +3144,58 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
SwappedForXor = true;
}
- // A | ( A ^ B) -> A | B
- // A | (~A ^ B) -> A | ~B
- // (A & B) | (A ^ B)
- // ~A | (A ^ B) -> ~(A & B)
- // The swap above should always make Op0 the 'not' for the last case.
if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) {
- if (Op0 == A || Op0 == B)
- return BinaryOperator::CreateOr(A, B);
+ // (A | ?) | (A ^ B) --> (A | ?) | B
+ // (B | ?) | (A ^ B) --> (B | ?) | A
+ if (match(Op0, m_c_Or(m_Specific(A), m_Value())))
+ return BinaryOperator::CreateOr(Op0, B);
+ if (match(Op0, m_c_Or(m_Specific(B), m_Value())))
+ return BinaryOperator::CreateOr(Op0, A);
+ // (A & B) | (A ^ B) --> A | B
+ // (B & A) | (A ^ B) --> A | B
if (match(Op0, m_And(m_Specific(A), m_Specific(B))) ||
match(Op0, m_And(m_Specific(B), m_Specific(A))))
return BinaryOperator::CreateOr(A, B);
+ // ~A | (A ^ B) --> ~(A & B)
+ // ~B | (A ^ B) --> ~(A & B)
+ // The swap above should always make Op0 the 'not'.
if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
(match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B)))))
return BinaryOperator::CreateNot(Builder.CreateAnd(A, B));
+ // Same as above, but peek through an 'and' to the common operand:
+ // ~(A & ?) | (A ^ B) --> ~((A & ?) & B)
+ // ~(B & ?) | (A ^ B) --> ~((B & ?) & A)
+ Instruction *And;
+ if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
+ match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
+ m_c_And(m_Specific(A), m_Value())))))
+ return BinaryOperator::CreateNot(Builder.CreateAnd(And, B));
+ if ((Op0->hasOneUse() || Op1->hasOneUse()) &&
+ match(Op0, m_Not(m_CombineAnd(m_Instruction(And),
+ m_c_And(m_Specific(B), m_Value())))))
+ return BinaryOperator::CreateNot(Builder.CreateAnd(And, A));
+
+ // (~A | C) | (A ^ B) --> ~(A & B) | C
+ // (~B | C) | (A ^ B) --> ~(A & B) | C
+ if (Op0->hasOneUse() && Op1->hasOneUse() &&
+ (match(Op0, m_c_Or(m_Not(m_Specific(A)), m_Value(C))) ||
+ match(Op0, m_c_Or(m_Not(m_Specific(B)), m_Value(C))))) {
+ Value *Nand = Builder.CreateNot(Builder.CreateAnd(A, B), "nand");
+ return BinaryOperator::CreateOr(Nand, C);
+ }
+
+ // A | (~A ^ B) --> ~B | A
+ // B | (A ^ ~B) --> ~A | B
if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) {
- Value *Not = Builder.CreateNot(B, B->getName() + ".not");
- return BinaryOperator::CreateOr(Not, Op0);
+ Value *NotB = Builder.CreateNot(B, B->getName() + ".not");
+ return BinaryOperator::CreateOr(NotB, Op0);
}
if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) {
- Value *Not = Builder.CreateNot(A, A->getName() + ".not");
- return BinaryOperator::CreateOr(Not, Op0);
+ Value *NotA = Builder.CreateNot(A, A->getName() + ".not");
+ return BinaryOperator::CreateOr(NotA, Op0);
}
}
@@ -3072,7 +3358,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
}
// (~x) | y --> ~(x & (~y)) iff that gets rid of inversions
- if (sinkNotIntoOtherHandOfAndOrOr(I))
+ if (sinkNotIntoOtherHandOfLogicalOp(I))
return &I;
// Improve "get low bit mask up to and including bit X" pattern:
@@ -3121,6 +3407,15 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
Builder.CreateOr(C, Builder.CreateAnd(A, B)), D);
}
+ if (Instruction *R = reassociateForUses(I, Builder))
+ return R;
+
+ if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder))
+ return Canonicalized;
+
+ if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1))
+ return Folded;
+
return nullptr;
}
@@ -3338,14 +3633,8 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
// (~x) ^ y
// or into
// x ^ (~y)
-static Instruction *sinkNotIntoXor(BinaryOperator &I,
+static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y,
InstCombiner::BuilderTy &Builder) {
- Value *X, *Y;
- // FIXME: one-use check is not needed in general, but currently we are unable
- // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182)
- if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y))))))
- return nullptr;
-
// We only want to do the transform if it is free to do.
if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) {
// Ok, good.
@@ -3358,6 +3647,41 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I,
return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan");
}
+static Instruction *foldNotXor(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X, *Y;
+ // FIXME: one-use check is not needed in general, but currently we are unable
+ // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182)
+ if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y))))))
+ return nullptr;
+
+ if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder))
+ return NewXor;
+
+ auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) {
+ return A == C || A == D || B == C || B == D;
+ };
+
+ Value *A, *B, *C, *D;
+ // Canonicalize ~((A & B) ^ (A | ?)) -> (A & B) | ~(A | ?)
+ // 4 commuted variants
+ if (match(X, m_And(m_Value(A), m_Value(B))) &&
+ match(Y, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) {
+ Value *NotY = Builder.CreateNot(Y);
+ return BinaryOperator::CreateOr(X, NotY);
+ };
+
+ // Canonicalize ~((A | ?) ^ (A & B)) -> (A & B) | ~(A | ?)
+ // 4 commuted variants
+ if (match(Y, m_And(m_Value(A), m_Value(B))) &&
+ match(X, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) {
+ Value *NotX = Builder.CreateNot(X);
+ return BinaryOperator::CreateOr(Y, NotX);
+ };
+
+ return nullptr;
+}
+
/// Canonicalize a shifty way to code absolute value to the more common pattern
/// that uses negation and select.
static Instruction *canonicalizeAbs(BinaryOperator &Xor,
@@ -3392,39 +3716,127 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor,
}
// Transform
+// z = ~(x &/| y)
+// into:
+// z = ((~x) |/& (~y))
+// iff both x and y are free to invert and all uses of z can be freely updated.
+bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) {
+ Value *Op0, *Op1;
+ if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1))))
+ return false;
+
+ // If this logic op has not been simplified yet, just bail out and let that
+ // happen first. Otherwise, the code below may wrongly invert.
+ if (Op0 == Op1)
+ return false;
+
+ Instruction::BinaryOps NewOpc =
+ match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And;
+ bool IsBinaryOp = isa<BinaryOperator>(I);
+
+ // Can our users be adapted?
+ if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
+ return false;
+
+ // And can the operands be adapted?
+ for (Value *Op : {Op0, Op1})
+ if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) &&
+ (match(Op, m_ImmConstant()) ||
+ (isa<Instruction>(Op) &&
+ InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op),
+ /*IgnoredUser=*/&I)))))
+ return false;
+
+ for (Value **Op : {&Op0, &Op1}) {
+ Value *NotOp;
+ if (auto *C = dyn_cast<Constant>(*Op)) {
+ NotOp = ConstantExpr::getNot(C);
+ } else {
+ Builder.SetInsertPoint(
+ &*cast<Instruction>(*Op)->getInsertionPointAfterDef());
+ NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not");
+ (*Op)->replaceUsesWithIf(
+ NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; });
+ freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I);
+ }
+ *Op = NotOp;
+ }
+
+ Builder.SetInsertPoint(I.getInsertionPointAfterDef());
+ Value *NewLogicOp;
+ if (IsBinaryOp)
+ NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
+ else
+ NewLogicOp =
+ Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not");
+
+ replaceInstUsesWith(I, NewLogicOp);
+ // We can not just create an outer `not`, it will most likely be immediately
+ // folded back, reconstructing our initial pattern, and causing an
+ // infinite combine loop, so immediately manually fold it away.
+ freelyInvertAllUsersOf(NewLogicOp);
+ return true;
+}
+
+// Transform
// z = (~x) &/| y
// into:
// z = ~(x |/& (~y))
// iff y is free to invert and all uses of z can be freely updated.
-bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) {
- Instruction::BinaryOps NewOpc;
- switch (I.getOpcode()) {
- case Instruction::And:
- NewOpc = Instruction::Or;
- break;
- case Instruction::Or:
- NewOpc = Instruction::And;
- break;
- default:
+bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) {
+ Value *Op0, *Op1;
+ if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1))))
return false;
- };
+ Instruction::BinaryOps NewOpc =
+ match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And;
+ bool IsBinaryOp = isa<BinaryOperator>(I);
- Value *X, *Y;
- if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y))))
- return false;
-
- // Will we be able to fold the `not` into Y eventually?
- if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse()))
+ Value *NotOp0 = nullptr;
+ Value *NotOp1 = nullptr;
+ Value **OpToInvert = nullptr;
+ if (match(Op0, m_Not(m_Value(NotOp0))) &&
+ InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) &&
+ (match(Op1, m_ImmConstant()) ||
+ (isa<Instruction>(Op1) &&
+ InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1),
+ /*IgnoredUser=*/&I)))) {
+ Op0 = NotOp0;
+ OpToInvert = &Op1;
+ } else if (match(Op1, m_Not(m_Value(NotOp1))) &&
+ InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) &&
+ (match(Op0, m_ImmConstant()) ||
+ (isa<Instruction>(Op0) &&
+ InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0),
+ /*IgnoredUser=*/&I)))) {
+ Op1 = NotOp1;
+ OpToInvert = &Op0;
+ } else
return false;
// And can our users be adapted?
if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr))
return false;
- Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not");
- Value *NewBinOp =
- BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not");
- Builder.Insert(NewBinOp);
+ if (auto *C = dyn_cast<Constant>(*OpToInvert)) {
+ *OpToInvert = ConstantExpr::getNot(C);
+ } else {
+ Builder.SetInsertPoint(
+ &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef());
+ Value *NotOpToInvert =
+ Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not");
+ (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) {
+ return U.getUser() != NotOpToInvert;
+ });
+ freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I);
+ *OpToInvert = NotOpToInvert;
+ }
+
+ Builder.SetInsertPoint(&*I.getInsertionPointAfterDef());
+ Value *NewBinOp;
+ if (IsBinaryOp)
+ NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
+ else
+ NewBinOp = Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not");
replaceInstUsesWith(I, NewBinOp);
// We can not just create an outer `not`, it will most likely be immediately
// folded back, reconstructing our initial pattern, and causing an
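
The sink-not rewrite above is De Morgan's law in disguise: (~x) & y == ~(x | ~y) and (~x) | y == ~(x & ~y). A minimal standalone C++ sketch (illustration only, not taken from this patch) that spot-checks the identities on a few 32-bit values:

#include <cassert>
#include <cstdint>

int main() {
  // De Morgan identities behind the sink-not-into-other-hand fold:
  //   (~x) & y == ~(x | ~y)   and   (~x) | y == ~(x & ~y)
  for (uint32_t x : {0u, 1u, 0xdeadbeefu, 0xffffffffu})
    for (uint32_t y : {0u, 7u, 0x12345678u, 0xffffffffu}) {
      assert((~x & y) == ~(x | ~y));
      assert((~x | y) == ~(x & ~y));
    }
  return 0;
}
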
@@ -3472,23 +3884,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
// Is this a 'not' (~) fed by a binary operator?
BinaryOperator *NotVal;
if (match(NotOp, m_BinOp(NotVal))) {
- if (NotVal->getOpcode() == Instruction::And ||
- NotVal->getOpcode() == Instruction::Or) {
- // Apply DeMorgan's Law when inverts are free:
- // ~(X & Y) --> (~X | ~Y)
- // ~(X | Y) --> (~X & ~Y)
- if (isFreeToInvert(NotVal->getOperand(0),
- NotVal->getOperand(0)->hasOneUse()) &&
- isFreeToInvert(NotVal->getOperand(1),
- NotVal->getOperand(1)->hasOneUse())) {
- Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs");
- Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs");
- if (NotVal->getOpcode() == Instruction::And)
- return BinaryOperator::CreateOr(NotX, NotY);
- return BinaryOperator::CreateAnd(NotX, NotY);
- }
- }
-
// ~((-X) | Y) --> (X - 1) & (~Y)
if (match(NotVal,
m_OneUse(m_c_Or(m_OneUse(m_Neg(m_Value(X))), m_Value(Y))))) {
@@ -3501,6 +3896,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y))))
return BinaryOperator::CreateAShr(X, Y);
+ // Bit-hack form of a signbit test:
+ // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN
+ unsigned FullShift = Ty->getScalarSizeInBits() - 1;
+ if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) {
+ Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg");
+ return new SExtInst(IsNotNeg, Ty);
+ }
+
// If we are inverting a right-shifted constant, we may be able to eliminate
// the 'not' by inverting the constant and using the opposite shift type.
// Canonicalization rules ensure that only a negative constant uses 'ashr',
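
The sign-bit bit-hack fold a few lines above relies on an arithmetic shift of ~X by the full bit width minus one being the sign-extended form of (X > -1). A small standalone C++ sketch (names like viaShift/viaCompare are made up for illustration), assuming arithmetic right shift on signed integers, which is guaranteed since C++20 and is the behavior of mainstream compilers before that:

#include <cassert>
#include <cstdint>

// Identity used by the fold for i32:  ~X >>s 31  ==  sext(X > -1)
int32_t viaShift(int32_t x)   { return ~x >> 31; }            // arithmetic shift assumed
int32_t viaCompare(int32_t x) { return (x > -1) ? -1 : 0; }   // sext i1 (X > -1)

int main() {
  for (int32_t x : {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX})
    assert(viaShift(x) == viaCompare(x));
  return 0;
}
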
@@ -3545,11 +3948,28 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
// not (cmp A, B) = !cmp A, B
CmpInst::Predicate Pred;
- if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) {
+ if (match(NotOp, m_Cmp(Pred, m_Value(), m_Value())) &&
+ (NotOp->hasOneUse() ||
+ InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(NotOp),
+ /*IgnoredUser=*/nullptr))) {
cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred));
- return replaceInstUsesWith(I, NotOp);
+ freelyInvertAllUsersOf(NotOp);
+ return &I;
+ }
+
+ // Move a 'not' ahead of casts of a bool to enable logic reduction:
+ // not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X))
+  if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) &&
+      X->getType()->isIntOrIntVectorTy(1)) {
+ Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy();
+ Value *NotX = Builder.CreateNot(X);
+ Value *Sext = Builder.CreateSExt(NotX, SextTy);
+ return CastInst::CreateBitOrPointerCast(Sext, Ty);
}
+ if (auto *NotOpI = dyn_cast<Instruction>(NotOp))
+ if (sinkNotIntoLogicalOp(*NotOpI))
+ return &I;
+
// Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max:
// ~min(~X, ~Y) --> max(X, Y)
// ~max(~X, Y) --> min(X, ~Y)
@@ -3570,6 +3990,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY);
return replaceInstUsesWith(I, InvMaxMin);
}
+
+ if (II->getIntrinsicID() == Intrinsic::is_fpclass) {
+ ConstantInt *ClassMask = cast<ConstantInt>(II->getArgOperand(1));
+ II->setArgOperand(
+ 1, ConstantInt::get(ClassMask->getType(),
+ ~ClassMask->getZExtValue() & fcAllFlags));
+ return replaceInstUsesWith(I, II);
+ }
}
if (NotOp->hasOneUse()) {
@@ -3602,7 +4030,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
}
}
- if (Instruction *NewXor = sinkNotIntoXor(I, Builder))
+ if (Instruction *NewXor = foldNotXor(I, Builder))
return NewXor;
return nullptr;
@@ -3629,7 +4057,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
return NewXor;
// (A&B)^(A&C) -> A&(B^C) etc
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
// See if we can simplify any instructions used by the instruction whose sole
@@ -3718,6 +4146,21 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
MaskedValueIsZero(X, *C, 0, &I))
return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC));
+ // When X is a power-of-two or zero and zero input is poison:
+ // ctlz(i32 X) ^ 31 --> cttz(X)
+ // cttz(i32 X) ^ 31 --> ctlz(X)
+ auto *II = dyn_cast<IntrinsicInst>(Op0);
+ if (II && II->hasOneUse() && *RHSC == Ty->getScalarSizeInBits() - 1) {
+ Intrinsic::ID IID = II->getIntrinsicID();
+ if ((IID == Intrinsic::ctlz || IID == Intrinsic::cttz) &&
+ match(II->getArgOperand(1), m_One()) &&
+ isKnownToBeAPowerOfTwo(II->getArgOperand(0), /*OrZero */ true)) {
+ IID = (IID == Intrinsic::ctlz) ? Intrinsic::cttz : Intrinsic::ctlz;
+ Function *F = Intrinsic::getDeclaration(II->getModule(), IID, Ty);
+ return CallInst::Create(F, {II->getArgOperand(0), Builder.getTrue()});
+ }
+ }
+
// If RHSC is inverting the remaining bits of shifted X,
// canonicalize to a 'not' before the shift to help SCEV and codegen:
// (X << C) ^ RHSC --> ~X << C
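
The ctlz/cttz xor fold above works because, for a power of two X, ctlz(X) and cttz(X) sum to the bit width minus one, and xor with an all-ones constant such as 31 is the same as subtracting from it. A standalone C++ spot check, assuming the GCC/Clang __builtin_clz/__builtin_ctz builtins are available:

#include <cassert>
#include <cstdint>

int main() {
  // For a power of two x = 1 << k:  clz(x) = 31 - k  and  ctz(x) = k.
  // Since 31 is all-ones in 5 bits, 31 - k == 31 ^ k, so clz(x) ^ 31 == ctz(x).
  for (int k = 0; k < 32; ++k) {
    uint32_t x = 1u << k;
    assert((__builtin_clz(x) ^ 31) == __builtin_ctz(x));
    assert((__builtin_ctz(x) ^ 31) == __builtin_clz(x));
  }
  return 0;
}
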
@@ -3858,5 +4301,17 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
m_Value(Y))))
return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1);
+ if (Instruction *R = reassociateForUses(I, Builder))
+ return R;
+
+ if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder))
+ return Canonicalized;
+
+ if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1))
+ return Folded;
+
+ if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I))
+ return Folded;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
index 0327efbf9614..e73667f9c02e 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
@@ -128,10 +128,9 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
if (Ordering != AtomicOrdering::Release &&
Ordering != AtomicOrdering::Monotonic)
return nullptr;
- auto *SI = new StoreInst(RMWI.getValOperand(),
- RMWI.getPointerOperand(), &RMWI);
- SI->setAtomic(Ordering, RMWI.getSyncScopeID());
- SI->setAlignment(DL.getABITypeAlign(RMWI.getType()));
+ new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(),
+ /*isVolatile*/ false, RMWI.getAlign(), Ordering,
+ RMWI.getSyncScopeID(), &RMWI);
return eraseInstFromFunction(RMWI);
}
@@ -152,13 +151,5 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
return replaceOperand(RMWI, 1, ConstantFP::getNegativeZero(RMWI.getType()));
}
- // Check if the required ordering is compatible with an atomic load.
- if (Ordering != AtomicOrdering::Acquire &&
- Ordering != AtomicOrdering::Monotonic)
- return nullptr;
-
- LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "",
- false, DL.getABITypeAlign(RMWI.getType()),
- Ordering, RMWI.getSyncScopeID());
- return Load;
+ return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index bc01d2ef7fe2..fbf1327143a8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -15,8 +15,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/ArrayRef.h"
-#include "llvm/ADT/None.h"
-#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
@@ -34,6 +32,7 @@
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/GlobalVariable.h"
@@ -71,6 +70,7 @@
#include <algorithm>
#include <cassert>
#include <cstdint>
+#include <optional>
#include <utility>
#include <vector>
@@ -135,7 +135,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
// If we have a store to a location which is known constant, we can conclude
// that the store must be storing the constant value (else the memory
// wouldn't be constant), and this must be a noop.
- if (AA->pointsToConstantMemory(MI->getDest())) {
+ if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) {
// Set the size of the copy to 0, it will be deleted on the next iteration.
MI->setLength(Constant::getNullValue(MI->getLength()->getType()));
return MI;
@@ -223,6 +223,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD);
if (AccessGroupMD)
S->setMetadata(LLVMContext::MD_access_group, AccessGroupMD);
+ S->copyMetadata(*MI, LLVMContext::MD_DIAssignID);
if (auto *MT = dyn_cast<MemTransferInst>(MI)) {
// non-atomics can be volatile
@@ -252,7 +253,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
// If we have a store to a location which is known constant, we can conclude
// that the store must be storing the constant value (else the memory
// wouldn't be constant), and this must be a noop.
- if (AA->pointsToConstantMemory(MI->getDest())) {
+ if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) {
// Set the size of the copy to 0, it will be deleted on the next iteration.
MI->setLength(Constant::getNullValue(MI->getLength()->getType()));
return MI;
@@ -294,9 +295,15 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
Dest = Builder.CreateBitCast(Dest, NewDstPtrTy);
// Extract the fill value and store.
- uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL;
- StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest,
- MI->isVolatile());
+ const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL;
+ Constant *FillVal = ConstantInt::get(ITy, Fill);
+ StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile());
+ S->copyMetadata(*MI, LLVMContext::MD_DIAssignID);
+ for (auto *DAI : at::getAssignmentMarkers(S)) {
+ if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; }))
+ DAI->replaceVariableLocationOp(FillC, FillVal);
+ }
+
S->setAlignment(Alignment);
if (isa<AtomicMemSetInst>(MI))
S->setOrdering(AtomicOrdering::Unordered);
@@ -328,7 +335,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) {
// If we can unconditionally load from this address, replace with a
// load/select idiom. TODO: use DT for context sensitive query
if (isDereferenceablePointer(LoadPtr, II.getType(),
- II.getModule()->getDataLayout(), &II, nullptr)) {
+ II.getModule()->getDataLayout(), &II, &AC)) {
LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
"unmaskedload");
LI->copyMetadata(II);
@@ -661,10 +668,21 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) {
// If all bits are zero except for exactly one fixed bit, then the result
// must be 0 or 1, and we can get that answer by shifting to LSB:
// ctpop (X & 32) --> (X & 32) >> 5
+  // TODO: Investigate removing this as it's likely unnecessary given the below
+ // `isKnownToBeAPowerOfTwo` check.
if ((~Known.Zero).isPowerOf2())
return BinaryOperator::CreateLShr(
Op0, ConstantInt::get(Ty, (~Known.Zero).exactLogBase2()));
+ // More generally we can also handle non-constant power of 2 patterns such as
+ // shl/shr(Pow2, X), (X & -X), etc... by transforming:
+ // ctpop(Pow2OrZero) --> icmp ne X, 0
+ if (IC.isKnownToBeAPowerOfTwo(Op0, /* OrZero */ true))
+ return CastInst::Create(Instruction::ZExt,
+ IC.Builder.CreateICmp(ICmpInst::ICMP_NE, Op0,
+ Constant::getNullValue(Ty)),
+ Ty);
+
// FIXME: Try to simplify vectors of integers.
auto *IT = dyn_cast<IntegerType>(Ty);
if (!IT)
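
The new ctpop fold above generalizes the single-known-bit case: whenever the operand is known to be a power of two or zero, its population count is exactly (X != 0). A minimal C++ illustration (a sketch, not the patch's code), using y & -y, which isolates the lowest set bit, as one such power-of-two-or-zero pattern:

#include <cassert>
#include <cstdint>

int main() {
  // y & -y is always a power of two or zero, so popcount(x) collapses to (x != 0).
  for (uint32_t y : {0u, 1u, 6u, 40u, 0x80000000u, 0xffffffffu}) {
    uint32_t x = y & -y;
    assert(__builtin_popcount(x) == (x != 0 ? 1 : 0));
  }
  return 0;
}
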
@@ -720,7 +738,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II,
auto *V1 = II.getArgOperand(0);
auto *V2 = Constant::getNullValue(V1->getType());
- return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes));
+ return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes));
}
// Returns true iff the 2 intrinsics have the same operands, limiting the
@@ -812,9 +830,10 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) {
return nullptr;
}
-static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
- const DataLayout &DL, AssumptionCache *AC,
- DominatorTree *DT) {
+static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI,
+ const DataLayout &DL,
+ AssumptionCache *AC,
+ DominatorTree *DT) {
KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT);
if (Known.isNonNegative())
return false;
@@ -1266,7 +1285,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
return replaceOperand(*II, 0, X);
- if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) {
+ if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) {
// abs(x) -> x if x >= 0
if (!*Sign)
return replaceInstUsesWith(*II, IIOperand);
@@ -1297,11 +1316,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
// umin(x, 1) == zext(x != 0)
if (match(I1, m_One())) {
+ assert(II->getType()->getScalarSizeInBits() != 1 &&
+ "Expected simplify of umin with max constant");
Value *Zero = Constant::getNullValue(I0->getType());
Value *Cmp = Builder.CreateICmpNE(I0, Zero);
return CastInst::Create(Instruction::ZExt, Cmp, II->getType());
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::umax: {
Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1);
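
The umin fold above (umin(x, 1) --> zext(x != 0)) is a simple saturate-to-boolean identity. A two-assert C++ spot check, assuming unsigned 32-bit operands:

#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  // umin(x, 1) is 0 exactly when x is 0 and 1 otherwise, i.e. zext(x != 0).
  for (uint32_t x : {0u, 1u, 2u, 0xffffffffu})
    assert(std::min(x, 1u) == (uint32_t)(x != 0));
  return 0;
}
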
@@ -1322,7 +1343,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
// If both operands of unsigned min/max are sign-extended, it is still ok
// to narrow the operation.
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::smax:
case Intrinsic::smin: {
@@ -1431,6 +1452,18 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::bitreverse: {
+ // bitrev (zext i1 X to ?) --> X ? SignBitC : 0
+ Value *X;
+ if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) &&
+ X->getType()->isIntOrIntVectorTy(1)) {
+ Type *Ty = II->getType();
+ APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits());
+ return SelectInst::Create(X, ConstantInt::get(Ty, SignBit),
+ ConstantInt::getNullValue(Ty));
+ }
+ break;
+ }
case Intrinsic::bswap: {
Value *IIOperand = II->getArgOperand(0);
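
For the bitreverse fold above: reversing the bits of a zero-extended i1 produces either the sign bit (bool true) or zero (bool false), which is exactly the select the combine emits. A standalone C++ sketch with a hand-rolled 32-bit reversal standing in for llvm.bitreverse.i32:

#include <cassert>
#include <cstdint>

// Portable stand-in for llvm.bitreverse.i32 (illustration only).
uint32_t bitreverse32(uint32_t v) {
  uint32_t r = 0;
  for (int i = 0; i < 32; ++i)
    r |= ((v >> i) & 1u) << (31 - i);
  return r;
}

int main() {
  // bitreverse(zext i1 X) is the sign bit when X is true, zero when X is false.
  for (bool x : {false, true})
    assert(bitreverse32(x ? 1u : 0u) == (x ? 0x80000000u : 0u));
  return 0;
}
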
@@ -1829,6 +1862,63 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
break;
}
+ case Intrinsic::matrix_multiply: {
+ // Optimize negation in matrix multiplication.
+
+ // -A * -B -> A * B
+ Value *A, *B;
+ if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) &&
+ match(II->getArgOperand(1), m_FNeg(m_Value(B)))) {
+ replaceOperand(*II, 0, A);
+ replaceOperand(*II, 1, B);
+ return II;
+ }
+
+ Value *Op0 = II->getOperand(0);
+ Value *Op1 = II->getOperand(1);
+ Value *OpNotNeg, *NegatedOp;
+ unsigned NegatedOpArg, OtherOpArg;
+ if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) {
+ NegatedOp = Op0;
+ NegatedOpArg = 0;
+ OtherOpArg = 1;
+ } else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) {
+ NegatedOp = Op1;
+ NegatedOpArg = 1;
+ OtherOpArg = 0;
+ } else
+ // Multiplication doesn't have a negated operand.
+ break;
+
+ // Only optimize if the negated operand has only one use.
+ if (!NegatedOp->hasOneUse())
+ break;
+
+ Value *OtherOp = II->getOperand(OtherOpArg);
+ VectorType *RetTy = cast<VectorType>(II->getType());
+ VectorType *NegatedOpTy = cast<VectorType>(NegatedOp->getType());
+ VectorType *OtherOpTy = cast<VectorType>(OtherOp->getType());
+ ElementCount NegatedCount = NegatedOpTy->getElementCount();
+ ElementCount OtherCount = OtherOpTy->getElementCount();
+ ElementCount RetCount = RetTy->getElementCount();
+ // (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa.
+ if (ElementCount::isKnownGT(NegatedCount, OtherCount) &&
+ ElementCount::isKnownLT(OtherCount, RetCount)) {
+ Value *InverseOtherOp = Builder.CreateFNeg(OtherOp);
+ replaceOperand(*II, NegatedOpArg, OpNotNeg);
+ replaceOperand(*II, OtherOpArg, InverseOtherOp);
+ return II;
+ }
+ // (-A) * B -> -(A * B), if it is cheaper to negate the result
+ if (ElementCount::isKnownGT(NegatedCount, RetCount)) {
+ SmallVector<Value *, 5> NewArgs(II->args());
+ NewArgs[NegatedOpArg] = OpNotNeg;
+ Instruction *NewMul =
+ Builder.CreateIntrinsic(II->getType(), IID, NewArgs, II);
+ return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II));
+ }
+ break;
+ }
case Intrinsic::fmuladd: {
// Canonicalize fast fmuladd to the separate fmul + fadd.
if (II->isFast()) {
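
The matrix_multiply change above only moves the fneg to whichever place is cheapest to negate (the smaller operand, or the result); the underlying algebra is plain sign propagation through multiplication. A scalar C++ sketch of the identities involved (the real fold operates on flattened matrix operands, so the interesting part is the element-count heuristic, not the math):

#include <cassert>

int main() {
  // Sign-propagation identities the fold relies on:
  //   (-a) * (-b) == a * b,   (-a) * b == a * (-b) == -(a * b)
  for (double a : {0.0, 1.5, -2.25})
    for (double b : {0.0, 3.0, -0.5}) {
      assert((-a) * (-b) == a * b);
      assert((-a) * b == a * (-b));
      assert((-a) * b == -(a * b));
    }
  return 0;
}
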
@@ -1850,7 +1940,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return FAdd;
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::fma: {
// fma fneg(x), fneg(y), z -> fma x, y, z
@@ -1940,7 +2030,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return replaceOperand(*II, 0, TVal);
}
- LLVM_FALLTHROUGH;
+ Value *Magnitude, *Sign;
+ if (match(II->getArgOperand(0),
+ m_CopySign(m_Value(Magnitude), m_Value(Sign)))) {
+ // fabs (copysign x, y) -> (fabs x)
+ CallInst *AbsSign =
+ Builder.CreateCall(II->getCalledFunction(), {Magnitude});
+ AbsSign->copyFastMathFlags(II);
+ return replaceInstUsesWith(*II, AbsSign);
+ }
+
+ [[fallthrough]];
}
case Intrinsic::ceil:
case Intrinsic::floor:
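
The fabs-of-copysign fold above drops the copysign because fabs discards whatever sign bit copysign installed. A small C++ check using the standard math functions (NaN payloads aside):

#include <cassert>
#include <cmath>

int main() {
  // copysign only rewrites the sign bit, and fabs then clears it:
  //   fabs(copysign(x, y)) == fabs(x)
  for (double x : {0.0, -0.0, 1.5, -2.25})
    for (double y : {-1.0, 1.0})
      assert(std::fabs(std::copysign(x, y)) == std::fabs(x));
  return 0;
}
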
@@ -1979,7 +2079,64 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
break;
}
+ case Intrinsic::ptrauth_auth:
+ case Intrinsic::ptrauth_resign: {
+ // (sign|resign) + (auth|resign) can be folded by omitting the middle
+ // sign+auth component if the key and discriminator match.
+ bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign;
+ Value *Key = II->getArgOperand(1);
+ Value *Disc = II->getArgOperand(2);
+ // AuthKey will be the key we need to end up authenticating against in
+ // whatever we replace this sequence with.
+ Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr;
+ if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) {
+ BasePtr = CI->getArgOperand(0);
+ if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) {
+ if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc)
+ break;
+ } else if (CI->getIntrinsicID() == Intrinsic::ptrauth_resign) {
+ if (CI->getArgOperand(3) != Key || CI->getArgOperand(4) != Disc)
+ break;
+ AuthKey = CI->getArgOperand(1);
+ AuthDisc = CI->getArgOperand(2);
+ } else
+ break;
+ } else
+ break;
+
+ unsigned NewIntrin;
+ if (AuthKey && NeedSign) {
+ // resign(0,1) + resign(1,2) = resign(0, 2)
+ NewIntrin = Intrinsic::ptrauth_resign;
+ } else if (AuthKey) {
+ // resign(0,1) + auth(1) = auth(0)
+ NewIntrin = Intrinsic::ptrauth_auth;
+ } else if (NeedSign) {
+ // sign(0) + resign(0, 1) = sign(1)
+ NewIntrin = Intrinsic::ptrauth_sign;
+ } else {
+ // sign(0) + auth(0) = nop
+ replaceInstUsesWith(*II, BasePtr);
+ eraseInstFromFunction(*II);
+ return nullptr;
+ }
+
+ SmallVector<Value *, 4> CallArgs;
+ CallArgs.push_back(BasePtr);
+ if (AuthKey) {
+ CallArgs.push_back(AuthKey);
+ CallArgs.push_back(AuthDisc);
+ }
+
+ if (NeedSign) {
+ CallArgs.push_back(II->getArgOperand(3));
+ CallArgs.push_back(II->getArgOperand(4));
+ }
+
+ Function *NewFn = Intrinsic::getDeclaration(II->getModule(), NewIntrin);
+ return CallInst::Create(NewFn, CallArgs);
+ }
case Intrinsic::arm_neon_vtbl1:
case Intrinsic::aarch64_neon_tbl1:
if (Value *V = simplifyNeonTbl1(*II, Builder))
@@ -2221,7 +2378,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load &&
LHS->getType()->isPointerTy() &&
isValidAssumeForContext(II, LHS, &DT)) {
- MDNode *MD = MDNode::get(II->getContext(), None);
+ MDNode *MD = MDNode::get(II->getContext(), std::nullopt);
LHS->setMetadata(LLVMContext::MD_nonnull, MD);
return RemoveConditionFromAssume(II);
@@ -2288,7 +2445,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
llvm::getKnowledgeFromBundle(cast<AssumeInst>(*II), BOI);
if (BOI.End - BOI.Begin > 2)
continue; // Prevent reducing knowledge in an align with offset since
- // extracting a RetainedKnowledge form them looses offset
+                   // extracting a RetainedKnowledge from them loses offset
// information
RetainedKnowledge CanonRK =
llvm::simplifyRetainedKnowledge(cast<AssumeInst>(II), RK,
@@ -2409,7 +2566,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *Vec = II->getArgOperand(0);
Value *Idx = II->getArgOperand(1);
- auto *DstTy = dyn_cast<FixedVectorType>(II->getType());
+ Type *ReturnType = II->getType();
+ // (extract_vector (insert_vector InsertTuple, InsertValue, InsertIdx),
+ // ExtractIdx)
+ unsigned ExtractIdx = cast<ConstantInt>(Idx)->getZExtValue();
+ Value *InsertTuple, *InsertIdx, *InsertValue;
+ if (match(Vec, m_Intrinsic<Intrinsic::vector_insert>(m_Value(InsertTuple),
+ m_Value(InsertValue),
+ m_Value(InsertIdx))) &&
+ InsertValue->getType() == ReturnType) {
+ unsigned Index = cast<ConstantInt>(InsertIdx)->getZExtValue();
+ // Case where we get the same index right after setting it.
+ // extract.vector(insert.vector(InsertTuple, InsertValue, Idx), Idx) -->
+ // InsertValue
+ if (ExtractIdx == Index)
+ return replaceInstUsesWith(CI, InsertValue);
+ // If we are getting a different index than what was set in the
+ // insert.vector intrinsic. We can just set the input tuple to the one up
+ // in the chain. extract.vector(insert.vector(InsertTuple, InsertValue,
+ // InsertIndex), ExtractIndex)
+ // --> extract.vector(InsertTuple, ExtractIndex)
+ else
+ return replaceOperand(CI, 0, InsertTuple);
+ }
+
+ auto *DstTy = dyn_cast<FixedVectorType>(ReturnType);
auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType());
  // Only canonicalize if the destination vector and Vec are fixed
@@ -2439,11 +2620,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
Value *Vec = II->getArgOperand(0);
if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) {
auto *OldBinOp = cast<BinaryOperator>(Vec);
- if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>(
- m_Value(X)))) {
+ if (match(BO0, m_VecReverse(m_Value(X)))) {
// rev(binop rev(X), rev(Y)) --> binop X, Y
- if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>(
- m_Value(Y))))
+ if (match(BO1, m_VecReverse(m_Value(Y))))
return replaceInstUsesWith(CI,
BinaryOperator::CreateWithCopiedFlags(
OldBinOp->getOpcode(), X, Y, OldBinOp,
@@ -2456,17 +2635,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
OldBinOp, OldBinOp->getName(), II));
}
// rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y
- if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>(
- m_Value(Y))) &&
- isSplatValue(BO0))
+ if (match(BO1, m_VecReverse(m_Value(Y))) && isSplatValue(BO0))
return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags(
OldBinOp->getOpcode(), BO0, Y,
OldBinOp, OldBinOp->getName(), II));
}
// rev(unop rev(X)) --> unop X
- if (match(Vec, m_OneUse(m_UnOp(
- m_Intrinsic<Intrinsic::experimental_vector_reverse>(
- m_Value(X)))))) {
+ if (match(Vec, m_OneUse(m_UnOp(m_VecReverse(m_Value(X)))))) {
auto *OldUnOp = cast<UnaryOperator>(Vec);
auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags(
OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II);
@@ -2504,7 +2679,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
return replaceInstUsesWith(CI, Res);
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_add: {
if (IID == Intrinsic::vector_reduce_add) {
@@ -2531,7 +2706,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_xor: {
if (IID == Intrinsic::vector_reduce_xor) {
@@ -2555,7 +2730,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_mul: {
if (IID == Intrinsic::vector_reduce_mul) {
@@ -2577,7 +2752,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_umin:
case Intrinsic::vector_reduce_umax: {
@@ -2604,7 +2779,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_smin:
case Intrinsic::vector_reduce_smax: {
@@ -2642,7 +2817,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::vector_reduce_fmax:
case Intrinsic::vector_reduce_fmin:
@@ -2679,9 +2854,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
}
default: {
// Handle target specific intrinsics
- Optional<Instruction *> V = targetInstCombineIntrinsic(*II);
+ std::optional<Instruction *> V = targetInstCombineIntrinsic(*II);
if (V)
- return V.value();
+ return *V;
break;
}
}
@@ -2887,7 +3062,7 @@ bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call,
if (!Call.getType()->isPointerTy())
return Changed;
- Optional<APInt> Size = getAllocSize(&Call, TLI);
+ std::optional<APInt> Size = getAllocSize(&Call, TLI);
if (Size && *Size != 0) {
// TODO: We really should just emit deref_or_null here and then
// let the generic inference code combine that with nonnull.
@@ -3078,6 +3253,30 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy));
}
+ // Drop unnecessary kcfi operand bundles from calls that were converted
+ // into direct calls.
+ auto Bundle = Call.getOperandBundle(LLVMContext::OB_kcfi);
+ if (Bundle && !Call.isIndirectCall()) {
+ DEBUG_WITH_TYPE(DEBUG_TYPE "-kcfi", {
+ if (CalleeF) {
+ ConstantInt *FunctionType = nullptr;
+ ConstantInt *ExpectedType = cast<ConstantInt>(Bundle->Inputs[0]);
+
+ if (MDNode *MD = CalleeF->getMetadata(LLVMContext::MD_kcfi_type))
+ FunctionType = mdconst::extract<ConstantInt>(MD->getOperand(0));
+
+ if (FunctionType &&
+ FunctionType->getZExtValue() != ExpectedType->getZExtValue())
+ dbgs() << Call.getModule()->getName()
+ << ": warning: kcfi: " << Call.getCaller()->getName()
+ << ": call to " << CalleeF->getName()
+ << " using a mismatching function pointer type\n";
+ }
+ });
+
+ return CallBase::removeOperandBundle(&Call, LLVMContext::OB_kcfi);
+ }
+
if (isRemovableAlloc(&Call, &TLI))
return visitAllocSite(Call);
@@ -3140,7 +3339,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
LiveGcValues.insert(BasePtr);
LiveGcValues.insert(DerivedPtr);
}
- Optional<OperandBundleUse> Bundle =
+ std::optional<OperandBundleUse> Bundle =
GCSP.getOperandBundle(LLVMContext::OB_gc_live);
unsigned NumOfGCLives = LiveGcValues.size();
if (!Bundle || NumOfGCLives == Bundle->Inputs.size())
@@ -3148,8 +3347,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) {
// We can reduce the size of gc live bundle.
DenseMap<Value *, unsigned> Val2Idx;
std::vector<Value *> NewLiveGc;
- for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) {
- Value *V = Bundle->Inputs[I];
+ for (Value *V : Bundle->Inputs) {
if (Val2Idx.count(V))
continue;
if (LiveGcValues.count(V)) {
@@ -3289,6 +3487,10 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
if (CallerPAL.hasParamAttr(i, Attribute::SwiftError))
return false;
+ if (CallerPAL.hasParamAttr(i, Attribute::ByVal) !=
+ Callee->getAttributes().hasParamAttr(i, Attribute::ByVal))
+ return false; // Cannot transform to or from byval.
+
// If the parameter is passed as a byval argument, then we have to have a
// sized type and the sized type has to have the same size as the old type.
if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) {
@@ -3447,21 +3649,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) {
NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy);
NC->setDebugLoc(Caller->getDebugLoc());
- // If this is an invoke/callbr instruction, we should insert it after the
- // first non-phi instruction in the normal successor block.
- if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) {
- BasicBlock::iterator I = II->getNormalDest()->getFirstInsertionPt();
- InsertNewInstBefore(NC, *I);
- } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) {
- BasicBlock::iterator I = CBI->getDefaultDest()->getFirstInsertionPt();
- InsertNewInstBefore(NC, *I);
- } else {
- // Otherwise, it's a call, just insert cast right after the call.
- InsertNewInstBefore(NC, *Caller);
- }
+ Instruction *InsertPt = NewCall->getInsertionPointAfterDef();
+ assert(InsertPt && "No place to insert cast");
+ InsertNewInstBefore(NC, *InsertPt);
Worklist.pushUsersToWorkList(*Caller);
} else {
- NV = UndefValue::get(Caller->getType());
+ NV = PoisonValue::get(Caller->getType());
}
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
index a9a930555b3c..3f851a2b2182 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -14,9 +14,12 @@
#include "llvm/ADT/SetVector.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugInfo.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/Support/KnownBits.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
+#include <optional>
+
using namespace llvm;
using namespace PatternMatch;
@@ -118,14 +121,15 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI,
if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr;
// The alloc and cast types should be either both fixed or both scalable.
- uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize();
- uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize();
+ uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue();
+ uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue();
if (CastElTySize == 0 || AllocElTySize == 0) return nullptr;
// If the allocation has multiple uses, only promote it if we're not
// shrinking the amount of memory being allocated.
- uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize();
- uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize();
+ uint64_t AllocElTyStoreSize =
+ DL.getTypeStoreSize(AllocElTy).getKnownMinValue();
+ uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue();
if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr;
// See if we can satisfy the modulus by pulling a scale out of the array
@@ -163,6 +167,10 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI,
New->setAlignment(AI.getAlign());
New->takeName(&AI);
New->setUsedWithInAlloca(AI.isUsedWithInAlloca());
+ New->setMetadata(LLVMContext::MD_DIAssignID,
+ AI.getMetadata(LLVMContext::MD_DIAssignID));
+
+ replaceAllDbgUsesWith(AI, *New, *New, DT);
// If the allocation has multiple real uses, insert a cast and change all
// things that used it to use the new cast. This will also hack on CI, but it
@@ -239,6 +247,11 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty,
Res = NPN;
break;
}
+ case Instruction::FPToUI:
+ case Instruction::FPToSI:
+ Res = CastInst::Create(
+ static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty);
+ break;
default:
// TODO: Can handle more cases here.
llvm_unreachable("Unreachable!");
@@ -483,6 +496,22 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC,
return false;
return true;
}
+ case Instruction::FPToUI:
+ case Instruction::FPToSI: {
+ // If the integer type can hold the max FP value, it is safe to cast
+ // directly to that type. Otherwise, we may create poison via overflow
+ // that did not exist in the original code.
+ //
+ // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need
+ // at least one more bit than the MaxExponent to hold the max FP value.
+ Type *InputTy = I->getOperand(0)->getType()->getScalarType();
+ const fltSemantics &Semantics = InputTy->getFltSemantics();
+ uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics);
+ // Extra sign bit needed.
+ if (I->getOpcode() == Instruction::FPToSI)
+ ++MinBitWidth;
+ return Ty->getScalarSizeInBits() > MinBitWidth;
+ }
default:
// TODO: Can handle more cases here.
break;
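
The new FPToUI/FPToSI case above sizes the destination from the source type's maximum exponent; half precision makes a convenient worked example, since its largest finite value (65504) needs 16 unsigned bits, or 17 once a sign bit is required. A small C++ sketch of that arithmetic (half is used purely as an illustration; the code itself queries the fltSemantics of the actual source type):

#include <cassert>
#include <cmath>

int main() {
  // IEEE half: maxExponent = 15, largest finite value = (2 - 2^-10) * 2^15 = 65504.
  const double MaxHalf = (2.0 - std::ldexp(1.0, -10)) * std::ldexp(1.0, 15);
  assert(MaxHalf == 65504.0);
  assert(MaxHalf <  std::ldexp(1.0, 16)); // fits in a 16-bit unsigned integer
  assert(MaxHalf >= std::ldexp(1.0, 15)); // too big for 15 bits, or for signed i16
  return 0;
}
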
@@ -726,7 +755,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc,
InstCombiner::BuilderTy &Builder) {
auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0));
if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) &&
- is_splat(Shuf->getShuffleMask()) &&
+ all_equal(Shuf->getShuffleMask()) &&
Shuf->getType() == Shuf->getOperand(0)->getType()) {
// trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask
// trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask
@@ -974,7 +1003,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
Attribute Attr =
Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange);
- if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
+ if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
if (Log2_32(*MaxVScale) < DestWidth) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
return replaceInstUsesWith(Trunc, VScale);
@@ -986,7 +1015,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) {
return nullptr;
}
-Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) {
+Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp,
+ ZExtInst &Zext) {
// If we are just checking for a icmp eq of a single bit and zext'ing it
// to an integer, then shift the bit to the appropriate place and then
// cast to integer to avoid the comparison.
@@ -1014,28 +1044,20 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext)
// zext (X == 0) to i32 --> X^1 iff X has only the low bit set.
// zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
- // zext (X == 1) to i32 --> X iff X has only the low bit set.
- // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set.
// zext (X != 0) to i32 --> X iff X has only the low bit set.
// zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set.
- // zext (X != 1) to i32 --> X^1 iff X has only the low bit set.
- // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set.
- if ((Op1CV->isZero() || Op1CV->isPowerOf2()) &&
- // This only works for EQ and NE
- Cmp->isEquality()) {
+ if (Op1CV->isZero() && Cmp->isEquality() &&
+ (Cmp->getOperand(0)->getType() == Zext.getType() ||
+ Cmp->getPredicate() == ICmpInst::ICMP_NE)) {
// If Op1C some other power of two, convert:
KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext);
+ // Exactly 1 possible 1? But not the high-bit because that is
+ // canonicalized to this form.
APInt KnownZeroMask(~Known.Zero);
- if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1?
- bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE;
- if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) {
- // (X&4) == 2 --> false
- // (X&4) != 2 --> true
- Constant *Res = ConstantInt::get(Zext.getType(), isNE);
- return replaceInstUsesWith(Zext, Res);
- }
-
+ if (KnownZeroMask.isPowerOf2() &&
+ (Zext.getType()->getScalarSizeInBits() !=
+ KnownZeroMask.logBase2() + 1)) {
uint32_t ShAmt = KnownZeroMask.logBase2();
Value *In = Cmp->getOperand(0);
if (ShAmt) {
@@ -1045,10 +1067,9 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext)
In->getName() + ".lobit");
}
- if (!Op1CV->isZero() == isNE) { // Toggle the low bit.
- Constant *One = ConstantInt::get(In->getType(), 1);
- In = Builder.CreateXor(In, One);
- }
+ // Toggle the low bit for "X == 0".
+ if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
+ In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1));
if (Zext.getType() == In->getType())
return replaceInstUsesWith(Zext, In);
@@ -1073,39 +1094,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext)
Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1));
return replaceInstUsesWith(Zext, And1);
}
-
- // icmp ne A, B is equal to xor A, B when A and B only really have one bit.
- // It is also profitable to transform icmp eq into not(xor(A, B)) because
- // that may lead to additional simplifications.
- if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) {
- Value *LHS = Cmp->getOperand(0);
- Value *RHS = Cmp->getOperand(1);
-
- KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext);
- KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext);
-
- if (KnownLHS == KnownRHS) {
- APInt KnownBits = KnownLHS.Zero | KnownLHS.One;
- APInt UnknownBit = ~KnownBits;
- if (UnknownBit.countPopulation() == 1) {
- Value *Result = Builder.CreateXor(LHS, RHS);
-
- // Mask off any bits that are set and won't be shifted away.
- if (KnownLHS.One.uge(UnknownBit))
- Result = Builder.CreateAnd(Result,
- ConstantInt::get(ITy, UnknownBit));
-
- // Shift the bit we're testing down to the lsb.
- Result = Builder.CreateLShr(
- Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros()));
-
- if (Cmp->getPredicate() == ICmpInst::ICMP_EQ)
- Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1));
- Result->takeName(Cmp);
- return replaceInstUsesWith(Zext, Result);
- }
- }
- }
}
return nullptr;
@@ -1235,23 +1223,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear,
}
}
-Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
+Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) {
// If this zero extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this zext.
- if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
+ if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back()))
return nullptr;
// If one of the common conversion will work, do it.
- if (Instruction *Result = commonCastTransforms(CI))
+ if (Instruction *Result = commonCastTransforms(Zext))
return Result;
- Value *Src = CI.getOperand(0);
- Type *SrcTy = Src->getType(), *DestTy = CI.getType();
+ Value *Src = Zext.getOperand(0);
+ Type *SrcTy = Src->getType(), *DestTy = Zext.getType();
// Try to extend the entire expression tree to the wide destination type.
unsigned BitsToClear;
if (shouldChangeType(SrcTy, DestTy) &&
- canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) {
+ canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) {
assert(BitsToClear <= SrcTy->getScalarSizeInBits() &&
"Can't clear more bits than in SrcTy");
@@ -1259,25 +1247,25 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
" to avoid zero extend: "
- << CI << '\n');
+ << Zext << '\n');
Value *Res = EvaluateInDifferentType(Src, DestTy, false);
assert(Res->getType() == DestTy);
// Preserve debug values referring to Src if the zext is its last use.
if (auto *SrcOp = dyn_cast<Instruction>(Src))
if (SrcOp->hasOneUse())
- replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT);
+ replaceAllDbgUsesWith(*SrcOp, *Res, Zext, DT);
- uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear;
+ uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear;
uint32_t DestBitSize = DestTy->getScalarSizeInBits();
// If the high bits are already filled with zeros, just replace this
// cast with the result.
if (MaskedValueIsZero(Res,
APInt::getHighBitsSet(DestBitSize,
- DestBitSize-SrcBitsKept),
- 0, &CI))
- return replaceInstUsesWith(CI, Res);
+ DestBitSize - SrcBitsKept),
+ 0, &Zext))
+ return replaceInstUsesWith(Zext, Res);
// We need to emit an AND to clear the high bits.
Constant *C = ConstantInt::get(Res->getType(),
@@ -1288,7 +1276,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
// If this is a TRUNC followed by a ZEXT then we are dealing with integral
// types and if the sizes are just right we can convert this into a logical
// 'and' which will be much cheaper than the pair of casts.
- if (TruncInst *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast
+ if (auto *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast
// TODO: Subsume this into EvaluateInDifferentType.
// Get the sizes of the types involved. We know that the intermediate type
@@ -1296,7 +1284,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
Value *A = CSrc->getOperand(0);
unsigned SrcSize = A->getType()->getScalarSizeInBits();
unsigned MidSize = CSrc->getType()->getScalarSizeInBits();
- unsigned DstSize = CI.getType()->getScalarSizeInBits();
+ unsigned DstSize = DestTy->getScalarSizeInBits();
// If we're actually extending zero bits, then if
// SrcSize < DstSize: zext(a & mask)
// SrcSize == DstSize: a & mask
@@ -1305,7 +1293,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize));
Constant *AndConst = ConstantInt::get(A->getType(), AndValue);
Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask");
- return new ZExtInst(And, CI.getType());
+ return new ZExtInst(And, DestTy);
}
if (SrcSize == DstSize) {
@@ -1314,7 +1302,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
AndValue));
}
if (SrcSize > DstSize) {
- Value *Trunc = Builder.CreateTrunc(A, CI.getType());
+ Value *Trunc = Builder.CreateTrunc(A, DestTy);
APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize));
return BinaryOperator::CreateAnd(Trunc,
ConstantInt::get(Trunc->getType(),
@@ -1322,34 +1310,46 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
}
}
- if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src))
- return transformZExtICmp(Cmp, CI);
+ if (auto *Cmp = dyn_cast<ICmpInst>(Src))
+ return transformZExtICmp(Cmp, Zext);
// zext(trunc(X) & C) -> (X & zext(C)).
Constant *C;
Value *X;
if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) &&
- X->getType() == CI.getType())
- return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType()));
+ X->getType() == DestTy)
+ return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy));
// zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)).
Value *And;
if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) &&
match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) &&
- X->getType() == CI.getType()) {
- Constant *ZC = ConstantExpr::getZExt(C, CI.getType());
+ X->getType() == DestTy) {
+ Constant *ZC = ConstantExpr::getZExt(C, DestTy);
return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC);
}
+ // If we are truncating, masking, and then zexting back to the original type,
+ // that's just a mask. This is not handled by canEvaluateZextd if the
+ // intermediate values have extra uses. This could be generalized further for
+ // a non-constant mask operand.
+ // zext (and (trunc X), C) --> and X, (zext C)
+ if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) &&
+ X->getType() == DestTy) {
+ Constant *ZextC = ConstantExpr::getZExt(C, DestTy);
+ return BinaryOperator::CreateAnd(X, ZextC);
+ }
+
if (match(Src, m_VScale(DL))) {
- if (CI.getFunction() &&
- CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
- Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange);
- if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
+ if (Zext.getFunction() &&
+ Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
+ Attribute Attr =
+ Zext.getFunction()->getFnAttribute(Attribute::VScaleRange);
+ if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
unsigned TypeWidth = Src->getType()->getScalarSizeInBits();
if (Log2_32(*MaxVScale) < TypeWidth) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
- return replaceInstUsesWith(CI, VScale);
+ return replaceInstUsesWith(Zext, VScale);
}
}
}
@@ -1359,48 +1359,44 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) {
}
/// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp.
-Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI,
- Instruction &CI) {
- Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1);
- ICmpInst::Predicate Pred = ICI->getPredicate();
+Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp,
+ SExtInst &Sext) {
+ Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1);
+ ICmpInst::Predicate Pred = Cmp->getPredicate();
// Don't bother if Op1 isn't of vector or integer type.
if (!Op1->getType()->isIntOrIntVectorTy())
return nullptr;
- if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) ||
- (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) {
- // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative
- // (x >s -1) ? -1 : 0 -> not (ashr x, 31) -> all ones if positive
+ if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) {
+ // sext (x <s 0) --> ashr x, 31 (all ones if negative)
Value *Sh = ConstantInt::get(Op0->getType(),
Op0->getType()->getScalarSizeInBits() - 1);
Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit");
- if (In->getType() != CI.getType())
- In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/);
+ if (In->getType() != Sext.getType())
+ In = Builder.CreateIntCast(In, Sext.getType(), true /*SExt*/);
- if (Pred == ICmpInst::ICMP_SGT)
- In = Builder.CreateNot(In, In->getName() + ".not");
- return replaceInstUsesWith(CI, In);
+ return replaceInstUsesWith(Sext, In);
}
if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) {
// If we know that only one bit of the LHS of the icmp can be set and we
// have an equality comparison with zero or a power of 2, we can transform
// the icmp and sext into bitwise/integer operations.
- if (ICI->hasOneUse() &&
- ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
- KnownBits Known = computeKnownBits(Op0, 0, &CI);
+ if (Cmp->hasOneUse() &&
+ Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){
+ KnownBits Known = computeKnownBits(Op0, 0, &Sext);
APInt KnownZeroMask(~Known.Zero);
if (KnownZeroMask.isPowerOf2()) {
- Value *In = ICI->getOperand(0);
+ Value *In = Cmp->getOperand(0);
// If the icmp tests for a known zero bit we can constant fold it.
if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) {
Value *V = Pred == ICmpInst::ICMP_NE ?
- ConstantInt::getAllOnesValue(CI.getType()) :
- ConstantInt::getNullValue(CI.getType());
- return replaceInstUsesWith(CI, V);
+ ConstantInt::getAllOnesValue(Sext.getType()) :
+ ConstantInt::getNullValue(Sext.getType());
+ return replaceInstUsesWith(Sext, V);
}
if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) {
@@ -1431,9 +1427,9 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI,
KnownZeroMask.getBitWidth() - 1), "sext");
}
- if (CI.getType() == In->getType())
- return replaceInstUsesWith(CI, In);
- return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/);
+ if (Sext.getType() == In->getType())
+ return replaceInstUsesWith(Sext, In);
+ return CastInst::CreateIntegerCast(In, Sext.getType(), true/*SExt*/);
}
}
}
@@ -1496,22 +1492,22 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) {
return false;
}
-Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
+Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) {
// If this sign extend is only used by a truncate, let the truncate be
// eliminated before we try to optimize this sext.
- if (CI.hasOneUse() && isa<TruncInst>(CI.user_back()))
+ if (Sext.hasOneUse() && isa<TruncInst>(Sext.user_back()))
return nullptr;
- if (Instruction *I = commonCastTransforms(CI))
+ if (Instruction *I = commonCastTransforms(Sext))
return I;
- Value *Src = CI.getOperand(0);
- Type *SrcTy = Src->getType(), *DestTy = CI.getType();
+ Value *Src = Sext.getOperand(0);
+ Type *SrcTy = Src->getType(), *DestTy = Sext.getType();
unsigned SrcBitSize = SrcTy->getScalarSizeInBits();
unsigned DestBitSize = DestTy->getScalarSizeInBits();
// If the value being extended is zero or positive, use a zext instead.
- if (isKnownNonNegative(Src, DL, 0, &AC, &CI, &DT))
+ if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT))
return CastInst::Create(Instruction::ZExt, Src, DestTy);
// Try to extend the entire expression tree to the wide destination type.
@@ -1520,14 +1516,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
LLVM_DEBUG(
dbgs() << "ICE: EvaluateInDifferentType converting expression type"
" to avoid sign extend: "
- << CI << '\n');
+ << Sext << '\n');
Value *Res = EvaluateInDifferentType(Src, DestTy, true);
assert(Res->getType() == DestTy);
// If the high bits are already filled with sign bit, just replace this
// cast with the result.
- if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize)
- return replaceInstUsesWith(CI, Res);
+ if (ComputeNumSignBits(Res, 0, &Sext) > DestBitSize - SrcBitSize)
+ return replaceInstUsesWith(Sext, Res);
// We need to emit a shl + ashr to do the sign extend.
Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize);
@@ -1540,7 +1536,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
// If the input has more sign bits than bits truncated, then convert
// directly to final type.
unsigned XBitSize = X->getType()->getScalarSizeInBits();
- if (ComputeNumSignBits(X, 0, &CI) > XBitSize - SrcBitSize)
+ if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize)
return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true);
// If input is a trunc from the destination type, then convert into shifts.
@@ -1563,8 +1559,8 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
}
}
- if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src))
- return transformSExtICmp(ICI, CI);
+ if (auto *Cmp = dyn_cast<ICmpInst>(Src))
+ return transformSExtICmp(Cmp, Sext);
// If the input is a shl/ashr pair of a same constant, then this is a sign
// extension from a smaller value. If we could trust arbitrary bitwidth
@@ -1593,7 +1589,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
NumLowbitsLeft);
NewShAmt =
Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA);
- A = Builder.CreateShl(A, NewShAmt, CI.getName());
+ A = Builder.CreateShl(A, NewShAmt, Sext.getName());
return BinaryOperator::CreateAShr(A, NewShAmt);
}
@@ -1616,13 +1612,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) {
}
if (match(Src, m_VScale(DL))) {
- if (CI.getFunction() &&
- CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
- Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange);
- if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
+ if (Sext.getFunction() &&
+ Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) {
+ Attribute Attr =
+ Sext.getFunction()->getFnAttribute(Attribute::VScaleRange);
+ if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) {
if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) {
Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1));
- return replaceInstUsesWith(CI, VScale);
+ return replaceInstUsesWith(Sext, VScale);
}
}
}
@@ -1659,7 +1656,6 @@ static Type *shrinkFPConstant(ConstantFP *CFP) {
// Determine if this is a vector of ConstantFPs and if so, return the minimal
// type we can safely truncate all elements to.
-// TODO: Make these support undef elements.
static Type *shrinkFPConstantVector(Value *V) {
auto *CV = dyn_cast<Constant>(V);
auto *CVVTy = dyn_cast<FixedVectorType>(V->getType());
@@ -1673,6 +1669,9 @@ static Type *shrinkFPConstantVector(Value *V) {
// For fixed-width vectors we find the minimal type by looking
// through the constant values of the vector.
for (unsigned i = 0; i != NumElts; ++i) {
+ if (isa<UndefValue>(CV->getAggregateElement(i)))
+ continue;
+
auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i));
if (!CFP)
return nullptr;
@@ -1688,7 +1687,7 @@ static Type *shrinkFPConstantVector(Value *V) {
}
// Make a vector type from the minimal type.
- return FixedVectorType::get(MinType, NumElts);
+ return MinType ? FixedVectorType::get(MinType, NumElts) : nullptr;
}
/// Find the minimum FP type we can safely truncate to.
@@ -2862,21 +2861,27 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) {
}
}
- // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as
- // a byte-swap:
- // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X)
- // TODO: We should match the related pattern for bitreverse.
- if (DestTy->isIntegerTy() &&
- DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
- SrcTy->getScalarSizeInBits() == 8 &&
- ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() &&
- Shuf->isReverse()) {
- assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
- assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op");
- Function *Bswap =
- Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy);
- Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
- return CallInst::Create(Bswap, { ScalarX });
+ // A bitcasted-to-scalar and byte/bit reversing shuffle is better recognized
+ // as a byte/bit swap:
+ // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) -> bswap (bitcast X)
+ // bitcast <N x i1> (shuf X, undef, <N, N-1,...0>) -> bitreverse (bitcast X)
+ if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 &&
+ Shuf->hasOneUse() && Shuf->isReverse()) {
+ unsigned IntrinsicNum = 0;
+ if (DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
+ SrcTy->getScalarSizeInBits() == 8) {
+ IntrinsicNum = Intrinsic::bswap;
+ } else if (SrcTy->getScalarSizeInBits() == 1) {
+ IntrinsicNum = Intrinsic::bitreverse;
+ }
+ if (IntrinsicNum != 0) {
+ assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
+ assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op");
+ Function *BswapOrBitreverse =
+ Intrinsic::getDeclaration(CI.getModule(), IntrinsicNum, DestTy);
+ Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
+ return CallInst::Create(BswapOrBitreverse, {ScalarX});
+ }
}
}
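
The reversed-shuffle bitcast fold above recognizes that reversing a byte vector and then reinterpreting it as a scalar is a byte swap (the new i1 variant is the analogous bit reversal). A standalone C++ sketch of the byte case, assuming the GCC/Clang __builtin_bswap32 builtin:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Reversing a <4 x i8> and bitcasting to i32 equals bswap of the direct bitcast,
  // independent of host endianness.
  uint8_t bytes[4] = {0x12, 0x34, 0x56, 0x78};
  uint8_t reversed[4];
  std::reverse_copy(bytes, bytes + 4, reversed);

  uint32_t direct, swapped;
  std::memcpy(&direct, bytes, 4);
  std::memcpy(&swapped, reversed, 4);

  assert(swapped == __builtin_bswap32(direct));
  return 0;
}
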
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 158d2e8289e0..1480a0ff9e2f 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -17,6 +17,7 @@
#include "llvm/Analysis/CmpInstAnalysis.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/GetElementPtrTypeIterator.h"
@@ -281,7 +282,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
if (!GEP->isInBounds()) {
Type *IntPtrTy = DL.getIntPtrType(GEP->getType());
unsigned PtrSize = IntPtrTy->getIntegerBitWidth();
- if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize)
+ if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize)
Idx = Builder.CreateTrunc(Idx, IntPtrTy);
}
@@ -403,108 +404,6 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal(
return nullptr;
}
-/// Return a value that can be used to compare the *offset* implied by a GEP to
-/// zero. For example, if we have &A[i], we want to return 'i' for
-/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales
-/// are involved. The above expression would also be legal to codegen as
-/// "icmp ne (i*4), 0" (assuming A is a pointer to i32).
-/// This latter form is less amenable to optimization though, and we are allowed
-/// to generate the first by knowing that pointer arithmetic doesn't overflow.
-///
-/// If we can't emit an optimized form for this expression, this returns null.
-///
-static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC,
- const DataLayout &DL) {
- gep_type_iterator GTI = gep_type_begin(GEP);
-
- // Check to see if this gep only has a single variable index. If so, and if
- // any constant indices are a multiple of its scale, then we can compute this
- // in terms of the scale of the variable index. For example, if the GEP
- // implies an offset of "12 + i*4", then we can codegen this as "3 + i",
- // because the expression will cross zero at the same point.
- unsigned i, e = GEP->getNumOperands();
- int64_t Offset = 0;
- for (i = 1; i != e; ++i, ++GTI) {
- if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) {
- // Compute the aggregate offset of constant indices.
- if (CI->isZero()) continue;
-
- // Handle a struct index, which adds its field offset to the pointer.
- if (StructType *STy = GTI.getStructTypeOrNull()) {
- Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
- } else {
- uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
- Offset += Size*CI->getSExtValue();
- }
- } else {
- // Found our variable index.
- break;
- }
- }
-
- // If there are no variable indices, we must have a constant offset, just
- // evaluate it the general way.
- if (i == e) return nullptr;
-
- Value *VariableIdx = GEP->getOperand(i);
- // Determine the scale factor of the variable element. For example, this is
- // 4 if the variable index is into an array of i32.
- uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType());
-
- // Verify that there are no other variable indices. If so, emit the hard way.
- for (++i, ++GTI; i != e; ++i, ++GTI) {
- ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i));
- if (!CI) return nullptr;
-
- // Compute the aggregate offset of constant indices.
- if (CI->isZero()) continue;
-
- // Handle a struct index, which adds its field offset to the pointer.
- if (StructType *STy = GTI.getStructTypeOrNull()) {
- Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue());
- } else {
- uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType());
- Offset += Size*CI->getSExtValue();
- }
- }
-
- // Okay, we know we have a single variable index, which must be a
- // pointer/array/vector index. If there is no offset, life is simple, return
- // the index.
- Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType());
- unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth();
- if (Offset == 0) {
- // Cast to intptrty in case a truncation occurs. If an extension is needed,
- // we don't need to bother extending: the extension won't affect where the
- // computation crosses zero.
- if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() >
- IntPtrWidth) {
- VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy);
- }
- return VariableIdx;
- }
-
- // Otherwise, there is an index. The computation we will do will be modulo
- // the pointer size.
- Offset = SignExtend64(Offset, IntPtrWidth);
- VariableScale = SignExtend64(VariableScale, IntPtrWidth);
-
- // To do this transformation, any constant index must be a multiple of the
- // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i",
- // but we can't evaluate "10 + 3*i" in terms of i. Check that the offset is a
- // multiple of the variable scale.
- int64_t NewOffs = Offset / (int64_t)VariableScale;
- if (Offset != NewOffs*(int64_t)VariableScale)
- return nullptr;
-
- // Okay, we can do this evaluation. Start by converting the index to intptr.
- if (VariableIdx->getType() != IntPtrTy)
- VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy,
- true /*Signed*/);
- Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs);
- return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset");
-}
-
/// Returns true if we can rewrite Start as a GEP with pointer Base
/// and some integer offset. The nodes that need to be re-written
/// for this transformation will be added to Explored.
@@ -732,8 +631,8 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base,
// Cast base to the expected type.
Value *NewVal = Builder.CreateBitOrPointerCast(
Base, PtrTy, Start->getName() + "to.ptr");
- NewVal = Builder.CreateInBoundsGEP(
- ElemTy, NewVal, makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr");
+ NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]),
+ Val->getName() + ".ptr");
NewVal = Builder.CreateBitOrPointerCast(
NewVal, Val->getType(), Val->getName() + ".conv");
Val->replaceAllUsesWith(NewVal);
@@ -841,18 +740,9 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
RHS = RHS->stripPointerCasts();
Value *PtrBase = GEPLHS->getOperand(0);
- // FIXME: Support vector pointer GEPs.
- if (PtrBase == RHS && GEPLHS->isInBounds() &&
- !GEPLHS->getType()->isVectorTy()) {
+ if (PtrBase == RHS && GEPLHS->isInBounds()) {
// ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0).
- // This transformation (ignoring the base and scales) is valid because we
- // know pointers can't overflow since the gep is inbounds. See if we can
- // output an optimized form.
- Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL);
-
- // If not, synthesize the offset the hard way.
- if (!Offset)
- Offset = EmitGEPOffset(GEPLHS);
+ Value *Offset = EmitGEPOffset(GEPLHS);
return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
Constant::getNullValue(Offset->getType()));
}
@@ -926,8 +816,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
Type *LHSIndexTy = LOffset->getType();
Type *RHSIndexTy = ROffset->getType();
if (LHSIndexTy != RHSIndexTy) {
- if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() <
- RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) {
+ if (LHSIndexTy->getPrimitiveSizeInBits().getFixedValue() <
+ RHSIndexTy->getPrimitiveSizeInBits().getFixedValue()) {
ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy);
} else
LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy);
@@ -1480,7 +1370,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
return nullptr;
// Try to simplify this compare to T/F based on the dominating condition.
- Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB);
+ std::optional<bool> Imp =
+ isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB);
if (Imp)
return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp));
@@ -1548,16 +1439,34 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
ConstantInt::get(V->getType(), 1));
}
+ Type *SrcTy = X->getType();
unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
- SrcBits = X->getType()->getScalarSizeInBits();
+ SrcBits = SrcTy->getScalarSizeInBits();
+
+ // TODO: Handle any shifted constant by subtracting trailing zeros.
+ // TODO: Handle non-equality predicates.
+ Value *Y;
+ if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) {
+ // (trunc (1 << Y) to iN) == 0 --> Y u>= N
+ // (trunc (1 << Y) to iN) != 0 --> Y u< N
+ if (C.isZero()) {
+ auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT;
+ return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits));
+ }
+ // (trunc (1 << Y) to iN) == 2**C --> Y == C
+ // (trunc (1 << Y) to iN) != 2**C --> Y != C
+ if (C.isPowerOf2())
+ return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2()));
+ }
+
if (Cmp.isEquality() && Trunc->hasOneUse()) {
// Canonicalize to a mask and wider compare if the wide type is suitable:
// (trunc X to i8) == C --> (X & 0xff) == (zext C)
- if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
- Constant *Mask = ConstantInt::get(X->getType(),
- APInt::getLowBitsSet(SrcBits, DstBits));
+ if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) {
+ Constant *Mask =
+ ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits));
Value *And = Builder.CreateAnd(X, Mask);
- Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits));
+ Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits));
return new ICmpInst(Pred, And, WideC);
}
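Not part of the patch: a standalone C++ sketch that brute-forces the new trunc fold for a 32-bit source and 8-bit destination. Shift amounts stay below the source width, mirroring the fact that an oversized shl is poison in IR; the constants here are only for illustration.

  #include <cassert>
  #include <cstdint>

  int main() {
    const unsigned DstBits = 8;
    for (unsigned Y = 0; Y < 32; ++Y) {
      uint8_t Trunc = static_cast<uint8_t>(1u << Y); // trunc (1 << Y) to i8
      // (trunc (1 << Y) to i8) == 0  <=>  Y u>= 8
      assert((Trunc == 0) == (Y >= DstBits));
      // (trunc (1 << Y) to i8) == 2**C  <=>  Y == C   (for C < 8)
      for (unsigned C = 0; C < DstBits; ++C)
        assert((Trunc == (1u << C)) == (Y == C));
    }
    return 0;
  }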
@@ -1570,7 +1479,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
// Pull in the high bits from known-ones set.
APInt NewRHS = C.zext(SrcBits);
NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits);
- return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS));
+ return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, NewRHS));
}
}
@@ -1583,11 +1492,10 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
if (isSignBitCheck(Pred, C, TrueIfSigned) &&
match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
DstBits == SrcBits - ShAmtC->getZExtValue()) {
- return TrueIfSigned
- ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
- ConstantInt::getNullValue(X->getType()))
- : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
- ConstantInt::getAllOnesValue(X->getType()));
+ return TrueIfSigned ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
+ ConstantInt::getNullValue(SrcTy))
+ : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
+ ConstantInt::getAllOnesValue(SrcTy));
}
return nullptr;
@@ -1597,6 +1505,9 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
BinaryOperator *Xor,
const APInt &C) {
+ if (Instruction *I = foldICmpXorShiftConst(Cmp, Xor, C))
+ return I;
+
Value *X = Xor->getOperand(0);
Value *Y = Xor->getOperand(1);
const APInt *XorC;
@@ -1660,6 +1571,37 @@ Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp,
return nullptr;
}
+/// For power-of-2 C:
+/// ((X s>> ShiftC) ^ X) u< C --> (X + C) u< (C << 1)
+/// ((X s>> ShiftC) ^ X) u> (C - 1) --> (X + C) u> ((C << 1) - 1)
+Instruction *InstCombinerImpl::foldICmpXorShiftConst(ICmpInst &Cmp,
+ BinaryOperator *Xor,
+ const APInt &C) {
+ CmpInst::Predicate Pred = Cmp.getPredicate();
+ APInt PowerOf2;
+ if (Pred == ICmpInst::ICMP_ULT)
+ PowerOf2 = C;
+ else if (Pred == ICmpInst::ICMP_UGT && !C.isMaxValue())
+ PowerOf2 = C + 1;
+ else
+ return nullptr;
+ if (!PowerOf2.isPowerOf2())
+ return nullptr;
+ Value *X;
+ const APInt *ShiftC;
+ if (!match(Xor, m_OneUse(m_c_Xor(m_Value(X),
+ m_AShr(m_Deferred(X), m_APInt(ShiftC))))))
+ return nullptr;
+ uint64_t Shift = ShiftC->getLimitedValue();
+ Type *XType = X->getType();
+ if (Shift == 0 || PowerOf2.isMinSignedValue())
+ return nullptr;
+ Value *Add = Builder.CreateAdd(X, ConstantInt::get(XType, PowerOf2));
+ APInt Bound =
+ Pred == ICmpInst::ICMP_ULT ? PowerOf2 << 1 : ((PowerOf2 << 1) - 1);
+ return new ICmpInst(Pred, Add, ConstantInt::get(XType, Bound));
+}
+
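Not part of the patch: the identity behind foldICmpXorShiftConst is not obvious, so a standalone C++ sketch that brute-forces the ULT form over every i8 value, every shift amount in [1, 7], and every admissible power-of-2 constant may help. It assumes two's-complement narrowing and an arithmetic right shift for signed integers (guaranteed since C++20), which is what s>> denotes in IR.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned Shift = 1; Shift < 8; ++Shift) {
      // C must be a power of 2 and not the signed minimum, so 128 is excluded.
      for (unsigned C = 1; C <= 64; C <<= 1) {
        for (unsigned X = 0; X < 256; ++X) {
          uint8_t Ashr = static_cast<uint8_t>(static_cast<int8_t>(X) >> Shift);
          uint8_t Lhs = static_cast<uint8_t>(Ashr ^ X);
          uint8_t Add = static_cast<uint8_t>(X + C);
          // ((X s>> Shift) ^ X) u< C  <=>  (X + C) u< (C << 1)
          assert((Lhs < C) == (Add < (C << 1)));
        }
      }
    }
    return 0;
  }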
/// Fold icmp (and (sh X, Y), C2), C1.
Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp,
BinaryOperator *And,
@@ -1780,7 +1722,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp,
APInt NewC2 = *C2;
KnownBits Know = computeKnownBits(And->getOperand(0), 0, And);
// Set high zeros of C2 to allow matching negated power-of-2.
- NewC2 = *C2 + APInt::getHighBitsSet(C2->getBitWidth(),
+ NewC2 = *C2 | APInt::getHighBitsSet(C2->getBitWidth(),
Know.countMinLeadingZeros());
// Restrict this fold only for single-use 'and' (PR10267).
@@ -1904,6 +1846,20 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1))));
}
+ // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X)
+ // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X)
+ // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X)
+ // ((zext i1 X) & Y) != 1 --> !((trunc Y) & X)
+ if (match(And, m_OneUse(m_c_And(m_OneUse(m_ZExt(m_Value(X))), m_Value(Y)))) &&
+ X->getType()->isIntOrIntVectorTy(1) && (C.isZero() || C.isOne())) {
+ Value *TruncY = Builder.CreateTrunc(Y, X->getType());
+ if (C.isZero() ^ (Pred == CmpInst::ICMP_NE)) {
+ Value *And = Builder.CreateAnd(TruncY, X);
+ return BinaryOperator::CreateNot(And);
+ }
+ return BinaryOperator::CreateAnd(TruncY, X);
+ }
+
return nullptr;
}
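Not part of the patch: a small C++ sketch of the reasoning behind the new zext-i1 fold above. Since zext i1 X is either 0 or 1, the whole and-expression reduces to the low bit of Y gated by X, which is exactly what the trunc-then-and form expresses.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (int XBit = 0; XBit <= 1; ++XBit) {
      bool X = XBit != 0;
      for (unsigned Y = 0; Y < 256; ++Y) {
        uint8_t ZextAnd = static_cast<uint8_t>((X ? 1u : 0u) & Y); // (zext i1 X) & Y
        bool TruncAnd = ((Y & 1u) != 0) && X;                      // (trunc Y to i1) & X
        // ((zext i1 X) & Y) == 0  <=>  !((trunc Y) & X)
        assert((ZextAnd == 0) == !TruncAnd);
        // ((zext i1 X) & Y) == 1  <=>  ((trunc Y) & X)
        assert((ZextAnd == 1) == TruncAnd);
      }
    }
    return 0;
  }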
@@ -1988,21 +1944,32 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp,
Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
BinaryOperator *Mul,
const APInt &C) {
+ ICmpInst::Predicate Pred = Cmp.getPredicate();
+ Type *MulTy = Mul->getType();
+ Value *X = Mul->getOperand(0);
+
+ // If there's no overflow:
+ // X * X == 0 --> X == 0
+ // X * X != 0 --> X != 0
+ if (Cmp.isEquality() && C.isZero() && X == Mul->getOperand(1) &&
+ (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap()))
+ return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy));
+
const APInt *MulC;
if (!match(Mul->getOperand(1), m_APInt(MulC)))
return nullptr;
// If this is a test of the sign bit and the multiply is sign-preserving with
- // a constant operand, use the multiply LHS operand instead.
- ICmpInst::Predicate Pred = Cmp.getPredicate();
+ // a constant operand, use the multiply LHS operand instead:
+ // (X * +MulC) < 0 --> X < 0
+ // (X * -MulC) < 0 --> X > 0
if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) {
if (MulC->isNegative())
Pred = ICmpInst::getSwappedPredicate(Pred);
- return new ICmpInst(Pred, Mul->getOperand(0),
- Constant::getNullValue(Mul->getType()));
+ return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy));
}
- if (MulC->isZero() || !(Mul->hasNoSignedWrap() || Mul->hasNoUnsignedWrap()))
+ if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap()))
return nullptr;
// If the multiply does not wrap, try to divide the compare constant by the
@@ -2010,50 +1977,45 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp,
if (Cmp.isEquality()) {
// (mul nsw X, MulC) == C --> X == C /s MulC
if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) {
- Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC));
- return new ICmpInst(Pred, Mul->getOperand(0), NewC);
+ Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC));
+ return new ICmpInst(Pred, X, NewC);
}
// (mul nuw X, MulC) == C --> X == C /u MulC
if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) {
- Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC));
- return new ICmpInst(Pred, Mul->getOperand(0), NewC);
+ Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC));
+ return new ICmpInst(Pred, X, NewC);
}
}
+ // With a matching no-overflow guarantee, fold the constants:
+ // (X * MulC) < C --> X < (C / MulC)
+ // (X * MulC) > C --> X > (C / MulC)
+ // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE?
Constant *NewC = nullptr;
-
- // FIXME: Add assert that Pred is not equal to ICMP_SGE, ICMP_SLE,
- // ICMP_UGE, ICMP_ULE.
-
if (Mul->hasNoSignedWrap()) {
- if (MulC->isNegative()) {
- // MININT / -1 --> overflow.
- if (C.isMinSignedValue() && MulC->isAllOnes())
- return nullptr;
+ // MININT / -1 --> overflow.
+ if (C.isMinSignedValue() && MulC->isAllOnes())
+ return nullptr;
+ if (MulC->isNegative())
Pred = ICmpInst::getSwappedPredicate(Pred);
- }
+
if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE)
NewC = ConstantInt::get(
- Mul->getType(),
- APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
+ MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP));
if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT)
NewC = ConstantInt::get(
- Mul->getType(),
- APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
- }
-
- if (Mul->hasNoUnsignedWrap()) {
+ MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN));
+ } else {
+ assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw");
if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)
NewC = ConstantInt::get(
- Mul->getType(),
- APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
+ MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP));
if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT)
NewC = ConstantInt::get(
- Mul->getType(),
- APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
+ MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN));
}
- return NewC ? new ICmpInst(Pred, Mul->getOperand(0), NewC) : nullptr;
+ return NewC ? new ICmpInst(Pred, X, NewC) : nullptr;
}
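Not part of the patch: a standalone C++ sketch of why an upward-rounded signed division is the right replacement constant for the strict signed less-than case with a positive multiplier. Pairs whose product would overflow i8 are skipped, mirroring the nsw requirement; the CeilDiv helper is illustration-only and only handles a positive divisor.

  #include <cassert>
  #include <cstdint>

  // Signed ceiling division for a positive divisor (RoundingSDiv, Rounding::UP).
  static int CeilDiv(int A, int B) {
    int Q = A / B;           // C++ division truncates toward zero
    if (A % B != 0 && A > 0) // round up only when the dividend is positive
      ++Q;
    return Q;
  }

  int main() {
    for (int MulC = 1; MulC <= 11; MulC += 2) {
      for (int C = -128; C <= 127; ++C) {
        for (int X = -128; X <= 127; ++X) {
          int Product = X * MulC;
          if (Product < -128 || Product > 127)
            continue; // would not be 'mul nsw' in i8
          // (X * MulC) s< C  <=>  X s< ceil(C / MulC)
          assert((Product < C) == (X < CeilDiv(C, MulC)));
        }
      }
    }
    return 0;
  }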
/// Fold icmp (shl 1, Y), C.
@@ -2080,39 +2042,21 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl,
Pred = ICmpInst::ICMP_UGT;
}
- // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31
- // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31
unsigned CLog2 = C.logBase2();
- if (CLog2 == TypeBits - 1) {
- if (Pred == ICmpInst::ICMP_UGE)
- Pred = ICmpInst::ICMP_EQ;
- else if (Pred == ICmpInst::ICMP_ULT)
- Pred = ICmpInst::ICMP_NE;
- }
return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2));
} else if (Cmp.isSigned()) {
Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1);
- if (C.isAllOnes()) {
- // (1 << Y) <= -1 -> Y == 31
- if (Pred == ICmpInst::ICMP_SLE)
- return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
-
- // (1 << Y) > -1 -> Y != 31
- if (Pred == ICmpInst::ICMP_SGT)
- return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
- } else if (!C) {
- // (1 << Y) < 0 -> Y == 31
- // (1 << Y) <= 0 -> Y == 31
- if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE)
- return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
+ // (1 << Y) > 0 -> Y != 31
+ // (1 << Y) > C -> Y != 31 if C is negative.
+ if (Pred == ICmpInst::ICMP_SGT && C.sle(0))
+ return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
- // (1 << Y) >= 0 -> Y != 31
- // (1 << Y) > 0 -> Y != 31
- if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE)
- return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne);
- }
- } else if (Cmp.isEquality() && CIsPowerOf2) {
- return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2()));
+ // (1 << Y) < 0 -> Y == 31
+ // (1 << Y) < 1 -> Y == 31
+ // (1 << Y) < C -> Y == 31 if C is negative and not signed min.
+    // Exclude signed min by subtracting 1 and lowering the upper bound to 0.
+ if (Pred == ICmpInst::ICMP_SLT && (C-1).sle(0))
+ return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne);
}
return nullptr;
@@ -2833,6 +2777,13 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
if (Pred == CmpInst::ICMP_SLT && C == *C2)
return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax));
+ // (X + -1) <u C --> X <=u C (if X is never null)
+ if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) {
+ const SimplifyQuery Q = SQ.getWithInstruction(&Cmp);
+ if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C));
+ }
+
if (!Add->hasOneUse())
return nullptr;
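Not part of the patch: a quick exhaustive C++ check of the new add fold above. When X is known to be non-zero, the decrement cannot wrap below zero, so comparing X - 1 u< C is the same as X u<= C.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned C = 0; C < 256; ++C) {
      for (unsigned X = 1; X < 256; ++X) {       // X == 0 is excluded by isKnownNonZero
        uint8_t Dec = static_cast<uint8_t>(X - 1); // X + (-1) in i8
        // (X + -1) u< C  <=>  X u<= C
        assert((Dec < C) == (X <= C));
      }
    }
    return 0;
  }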
@@ -3095,7 +3046,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) {
ArrayRef<int> Mask;
if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) {
// Check whether every element of Mask is the same constant
- if (is_splat(Mask)) {
+ if (all_equal(Mask)) {
auto *VecTy = cast<VectorType>(SrcType);
auto *EltTy = cast<IntegerType>(VecTy->getElementType());
if (C->isSplat(EltTy->getBitWidth())) {
@@ -3139,6 +3090,20 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) {
if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)))
if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C))
return I;
+
+ // (extractval ([s/u]subo X, Y), 0) == 0 --> X == Y
+ // (extractval ([s/u]subo X, Y), 0) != 0 --> X != Y
+ // TODO: This checks one-use, but that is not strictly necessary.
+ Value *Cmp0 = Cmp.getOperand(0);
+ Value *X, *Y;
+ if (C->isZero() && Cmp.isEquality() && Cmp0->hasOneUse() &&
+ (match(Cmp0,
+ m_ExtractValue<0>(m_Intrinsic<Intrinsic::ssub_with_overflow>(
+ m_Value(X), m_Value(Y)))) ||
+ match(Cmp0,
+ m_ExtractValue<0>(m_Intrinsic<Intrinsic::usub_with_overflow>(
+ m_Value(X), m_Value(Y))))))
+ return new ICmpInst(Cmp.getPredicate(), X, Y);
}
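Not part of the patch: a short C++ illustration of the extractvalue fold above, using the GCC/Clang __builtin_sub_overflow builtin as a stand-in for llvm.ssub.with.overflow. The wrapped difference (result element 0) is zero exactly when the operands are equal, whether or not the overflow flag is set.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (int X = -128; X <= 127; ++X) {
      for (int Y = -128; Y <= 127; ++Y) {
        int8_t Diff;
        (void)__builtin_sub_overflow(static_cast<int8_t>(X),
                                     static_cast<int8_t>(Y), &Diff);
        // (extractvalue (ssub.with.overflow X, Y), 0) == 0  <=>  X == Y
        assert((Diff == 0) == (X == Y));
      }
    }
    return 0;
  }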
if (match(Cmp.getOperand(1), m_APIntAllowUndef(C)))
@@ -3174,10 +3139,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant(
}
break;
case Instruction::Add: {
- // Replace ((add A, B) != C) with (A != C-B) if B & C are constants.
- if (Constant *BOC = dyn_cast<Constant>(BOp1)) {
+ // (A + C2) == C --> A == (C - C2)
+ // (A + C2) != C --> A != (C - C2)
+ // TODO: Remove the one-use limitation? See discussion in D58633.
+ if (Constant *C2 = dyn_cast<Constant>(BOp1)) {
if (BO->hasOneUse())
- return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC));
+ return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, C2));
} else if (C.isZero()) {
// Replace ((add A, B) != 0) with (A != -B) if A or B is
// efficiently invertible, or if the add has just this one use.
@@ -3433,7 +3400,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp,
case Instruction::UDiv:
if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C))
return I;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Instruction::SDiv:
if (Instruction *I = foldICmpDivConstant(Cmp, BO, C))
return I;
@@ -3580,8 +3547,8 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred,
auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * {
if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ))
return Res;
- if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op,
- RHS, DL, SelectCondIsTrue))
+ if (std::optional<bool> Impl = isImpliedCondition(
+ SI->getCondition(), Pred, Op, RHS, DL, SelectCondIsTrue))
return ConstantInt::get(I.getType(), *Impl);
return nullptr;
};
@@ -4488,6 +4455,18 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
}
}
+ // For unsigned predicates / eq / ne:
+ // icmp pred (x << 1), x --> icmp getSignedPredicate(pred) x, 0
+ // icmp pred x, (x << 1) --> icmp getSignedPredicate(pred) 0, x
+ if (!ICmpInst::isSigned(Pred)) {
+ if (match(Op0, m_Shl(m_Specific(Op1), m_One())))
+ return new ICmpInst(ICmpInst::getSignedPredicate(Pred), Op1,
+ Constant::getNullValue(Op1->getType()));
+ else if (match(Op1, m_Shl(m_Specific(Op0), m_One())))
+ return new ICmpInst(ICmpInst::getSignedPredicate(Pred),
+ Constant::getNullValue(Op0->getType()), Op0);
+ }
+
if (Value *V = foldMultiplicationOverflowCheck(I))
return replaceInstUsesWith(I, V);
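Not part of the patch: an exhaustive C++ check of the new shift-by-one fold for i8, covering a few of the unsigned orderings. Comparing x << 1 against x with an unsigned predicate matches comparing x against 0 with the corresponding signed predicate.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned V = 0; V < 256; ++V) {
      uint8_t X = static_cast<uint8_t>(V);
      uint8_t Shl = static_cast<uint8_t>(X << 1);
      int8_t SX = static_cast<int8_t>(X);
      // icmp ult (x << 1), x  <=>  icmp slt x, 0
      assert((Shl < X) == (SX < 0));
      // icmp ugt (x << 1), x  <=>  icmp sgt x, 0
      assert((Shl > X) == (SX > 0));
      // icmp ule (x << 1), x  <=>  icmp sle x, 0
      assert((Shl <= X) == (SX <= 0));
      // icmp eq (x << 1), x   <=>  icmp eq x, 0
      assert((Shl == X) == (SX == 0));
    }
    return 0;
  }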
@@ -4674,17 +4653,29 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
}
- // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B)
- // and (B & (1<<X)-1) == (zext A) --> A == (trunc B)
- ConstantInt *Cst1;
- if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) &&
- match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) ||
- (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) &&
- match(Op1, m_ZExt(m_Value(A))))) {
- APInt Pow2 = Cst1->getValue() + 1;
- if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) &&
- Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth())
+ if (match(Op1, m_ZExt(m_Value(A))) &&
+ (Op0->hasOneUse() || Op1->hasOneUse())) {
+ // (B & (Pow2C-1)) == zext A --> A == trunc B
+ // (B & (Pow2C-1)) != zext A --> A != trunc B
+ const APInt *MaskC;
+ if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) &&
+ MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits())
return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType()));
+
+ // Test if 2 values have different or same signbits:
+ // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0
+ // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1
+ unsigned OpWidth = Op0->getType()->getScalarSizeInBits();
+ Value *X, *Y;
+ ICmpInst::Predicate Pred2;
+ if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) &&
+ match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) &&
+ Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) {
+ Value *Xor = Builder.CreateXor(X, Y, "xor.signbits");
+ Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) :
+ Builder.CreateIsNotNeg(Xor);
+ return replaceInstUsesWith(I, R);
+ }
}
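Not part of the patch: an exhaustive C++ check of the sign-bit identity used above, for i8 operands. The logical shift isolates X's sign bit as 0 or 1, the zext'd compare does the same for "Y is non-negative", and the two agree exactly when X ^ Y is negative.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned XV = 0; XV < 256; ++XV) {
      for (unsigned YV = 0; YV < 256; ++YV) {
        uint8_t X = static_cast<uint8_t>(XV), Y = static_cast<uint8_t>(YV);
        uint8_t SignX = static_cast<uint8_t>(X >> 7);                        // X u>> (BitWidth - 1)
        uint8_t NonNegY = static_cast<uint8_t>(static_cast<int8_t>(Y) > -1); // zext (Y s> -1)
        int8_t Xor = static_cast<int8_t>(X ^ Y);
        // (X u>> 7) == zext (Y s> -1)  <=>  (X ^ Y) s< 0
        assert((SignX == NonNegY) == (Xor < 0));
      }
    }
    return 0;
  }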
// (A >> C) == (B >> C) --> (A^B) u< (1 << C)
@@ -4708,6 +4699,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
// (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0
+ ConstantInt *Cst1;
if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) &&
match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) {
unsigned TypeBits = Cst1->getBitWidth();
@@ -4788,6 +4780,20 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
Add, ConstantInt::get(A->getType(), C.shl(1)));
}
+ // Canonicalize:
+ // Assume B_Pow2 != 0
+ // 1. A & B_Pow2 != B_Pow2 -> A & B_Pow2 == 0
+ // 2. A & B_Pow2 == B_Pow2 -> A & B_Pow2 != 0
+ if (match(Op0, m_c_And(m_Specific(Op1), m_Value())) &&
+ isKnownToBeAPowerOfTwo(Op1, /* OrZero */ false, 0, &I))
+ return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0,
+ ConstantInt::getNullValue(Op0->getType()));
+
+ if (match(Op1, m_c_And(m_Specific(Op0), m_Value())) &&
+ isKnownToBeAPowerOfTwo(Op0, /* OrZero */ false, 0, &I))
+ return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1,
+ ConstantInt::getNullValue(Op1->getType()));
+
return nullptr;
}
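Not part of the patch: an exhaustive C++ check of the new canonicalization for i8. When P is a non-zero power of two, A & P can only be 0 or P, so testing it against P is the inverted test against 0.

  #include <cassert>
  #include <cstdint>

  int main() {
    for (unsigned P = 1; P < 256; P <<= 1) { // non-zero powers of two
      for (unsigned A = 0; A < 256; ++A) {
        unsigned Masked = A & P;
        // (A & P) != P  <=>  (A & P) == 0
        assert((Masked != P) == (Masked == 0));
        // (A & P) == P  <=>  (A & P) != 0
        assert((Masked == P) == (Masked != 0));
      }
    }
    return 0;
  }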
@@ -4993,7 +4999,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
return foldICmpWithZextOrSext(ICmp);
}
-static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
+static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS, bool IsSigned) {
switch (BinaryOp) {
default:
llvm_unreachable("Unsupported binary op");
@@ -5001,7 +5007,8 @@ static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {
case Instruction::Sub:
return match(RHS, m_Zero());
case Instruction::Mul:
- return match(RHS, m_One());
+ return !(RHS->getType()->isIntOrIntVectorTy(1) && IsSigned) &&
+ match(RHS, m_One());
}
}
@@ -5048,7 +5055,7 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp,
if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType()))
OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount());
- if (isNeutralValue(BinaryOp, RHS)) {
+ if (isNeutralValue(BinaryOp, RHS, IsSigned)) {
Result = LHS;
Overflow = ConstantInt::getFalse(OverflowTy);
return true;
@@ -5746,7 +5753,7 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I,
return nullptr;
}
-llvm::Optional<std::pair<CmpInst::Predicate, Constant *>>
+std::optional<std::pair<CmpInst::Predicate, Constant *>>
InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
Constant *C) {
assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) &&
@@ -5769,13 +5776,13 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
if (auto *CI = dyn_cast<ConstantInt>(C)) {
// Bail out if the constant can't be safely incremented/decremented.
if (!ConstantIsOk(CI))
- return llvm::None;
+ return std::nullopt;
} else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) {
unsigned NumElts = FVTy->getNumElements();
for (unsigned i = 0; i != NumElts; ++i) {
Constant *Elt = C->getAggregateElement(i);
if (!Elt)
- return llvm::None;
+ return std::nullopt;
if (isa<UndefValue>(Elt))
continue;
@@ -5784,14 +5791,14 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred,
// know that this constant is min/max.
auto *CI = dyn_cast<ConstantInt>(Elt);
if (!CI || !ConstantIsOk(CI))
- return llvm::None;
+ return std::nullopt;
if (!SafeReplacementConstant)
SafeReplacementConstant = CI;
}
} else {
// ConstantExpr?
- return llvm::None;
+ return std::nullopt;
}
// It may not be safe to change a compare predicate in the presence of
@@ -5901,7 +5908,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
case ICmpInst::ICMP_UGT:
// icmp ugt -> icmp ult
std::swap(A, B);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case ICmpInst::ICMP_ULT:
// icmp ult i1 A, B -> ~A & B
return BinaryOperator::CreateAnd(Builder.CreateNot(A), B);
@@ -5909,7 +5916,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
case ICmpInst::ICMP_SGT:
// icmp sgt -> icmp slt
std::swap(A, B);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case ICmpInst::ICMP_SLT:
// icmp slt i1 A, B -> A & ~B
return BinaryOperator::CreateAnd(Builder.CreateNot(B), A);
@@ -5917,7 +5924,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
case ICmpInst::ICMP_UGE:
// icmp uge -> icmp ule
std::swap(A, B);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case ICmpInst::ICMP_ULE:
// icmp ule i1 A, B -> ~A | B
return BinaryOperator::CreateOr(Builder.CreateNot(A), B);
@@ -5925,7 +5932,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I,
case ICmpInst::ICMP_SGE:
// icmp sge -> icmp sle
std::swap(A, B);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case ICmpInst::ICMP_SLE:
// icmp sle i1 A, B -> A | ~B
return BinaryOperator::CreateOr(Builder.CreateNot(B), A);
@@ -5986,6 +5993,31 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
const CmpInst::Predicate Pred = Cmp.getPredicate();
Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1);
Value *V1, *V2;
+
+ auto createCmpReverse = [&](CmpInst::Predicate Pred, Value *X, Value *Y) {
+ Value *V = Builder.CreateCmp(Pred, X, Y, Cmp.getName());
+ if (auto *I = dyn_cast<Instruction>(V))
+ I->copyIRFlags(&Cmp);
+ Module *M = Cmp.getModule();
+ Function *F = Intrinsic::getDeclaration(
+ M, Intrinsic::experimental_vector_reverse, V->getType());
+ return CallInst::Create(F, V);
+ };
+
+ if (match(LHS, m_VecReverse(m_Value(V1)))) {
+ // cmp Pred, rev(V1), rev(V2) --> rev(cmp Pred, V1, V2)
+ if (match(RHS, m_VecReverse(m_Value(V2))) &&
+ (LHS->hasOneUse() || RHS->hasOneUse()))
+ return createCmpReverse(Pred, V1, V2);
+
+ // cmp Pred, rev(V1), RHSSplat --> rev(cmp Pred, V1, RHSSplat)
+ if (LHS->hasOneUse() && isSplatValue(RHS))
+ return createCmpReverse(Pred, V1, RHS);
+ }
+ // cmp Pred, LHSSplat, rev(V2) --> rev(cmp Pred, LHSSplat, V2)
+ else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2)))))
+ return createCmpReverse(Pred, LHS, V2);
+
ArrayRef<int> M;
if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M))))
return nullptr;
@@ -6318,11 +6350,11 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
}
// (zext a) * (zext b) --> llvm.umul.with.overflow.
- if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+ if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
return R;
}
- if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
+ if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this))
return R;
}
@@ -6668,10 +6700,48 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI,
/// Optimize fabs(X) compared with zero.
static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) {
Value *X;
- if (!match(I.getOperand(0), m_FAbs(m_Value(X))) ||
- !match(I.getOperand(1), m_PosZeroFP()))
+ if (!match(I.getOperand(0), m_FAbs(m_Value(X))))
+ return nullptr;
+
+ const APFloat *C;
+ if (!match(I.getOperand(1), m_APFloat(C)))
return nullptr;
+ if (!C->isPosZero()) {
+ if (!C->isSmallestNormalized())
+ return nullptr;
+
+ const Function *F = I.getFunction();
+ DenormalMode Mode = F->getDenormalMode(C->getSemantics());
+ if (Mode.Input == DenormalMode::PreserveSign ||
+ Mode.Input == DenormalMode::PositiveZero) {
+
+ auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
+ Constant *Zero = ConstantFP::getNullValue(X->getType());
+ return new FCmpInst(P, X, Zero, "", I);
+ };
+
+ switch (I.getPredicate()) {
+ case FCmpInst::FCMP_OLT:
+ // fcmp olt fabs(x), smallest_normalized_number -> fcmp oeq x, 0.0
+ return replaceFCmp(&I, FCmpInst::FCMP_OEQ, X);
+ case FCmpInst::FCMP_UGE:
+ // fcmp uge fabs(x), smallest_normalized_number -> fcmp une x, 0.0
+ return replaceFCmp(&I, FCmpInst::FCMP_UNE, X);
+ case FCmpInst::FCMP_OGE:
+ // fcmp oge fabs(x), smallest_normalized_number -> fcmp one x, 0.0
+ return replaceFCmp(&I, FCmpInst::FCMP_ONE, X);
+ case FCmpInst::FCMP_ULT:
+ // fcmp ult fabs(x), smallest_normalized_number -> fcmp ueq x, 0.0
+ return replaceFCmp(&I, FCmpInst::FCMP_UEQ, X);
+ default:
+ break;
+ }
+ }
+
+ return nullptr;
+ }
+
auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) {
I->setPredicate(P);
return IC.replaceOperand(*I, 0, X);
@@ -6828,6 +6898,26 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP()))
return replaceOperand(I, 1, ConstantFP::getNullValue(OpType));
+ // Ignore signbit of bitcasted int when comparing equality to FP 0.0:
+ // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0
+ if (match(Op1, m_PosZeroFP()) &&
+ match(Op0, m_OneUse(m_BitCast(m_Value(X)))) &&
+ X->getType()->isVectorTy() == OpType->isVectorTy() &&
+ X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) {
+ ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE;
+ if (Pred == FCmpInst::FCMP_OEQ)
+ IntPred = ICmpInst::ICMP_EQ;
+ else if (Pred == FCmpInst::FCMP_UNE)
+ IntPred = ICmpInst::ICMP_NE;
+
+ if (IntPred != ICmpInst::BAD_ICMP_PREDICATE) {
+ Type *IntTy = X->getType();
+ const APInt &SignMask = ~APInt::getSignMask(IntTy->getScalarSizeInBits());
+ Value *MaskX = Builder.CreateAnd(X, ConstantInt::get(IntTy, SignMask));
+ return new ICmpInst(IntPred, MaskX, ConstantInt::getNullValue(IntTy));
+ }
+ }
+
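Not part of the patch: a small C++20 sketch of the bitcast fold above, using std::bit_cast to play the role of the IR bitcast. A float compares oeq to +0.0 exactly when all of its bits apart from the sign are zero, since -0.0 == +0.0 and every NaN compares unordered. The Check helper and sample list are illustration-only, and the sketch assumes default IEEE denormal handling (no FTZ/DAZ).

  #include <bit>
  #include <cassert>
  #include <cstdint>
  #include <limits>

  static void Check(float F) {
    uint32_t Bits = std::bit_cast<uint32_t>(F);
    bool Oeq = (F == 0.0f);                          // fcmp oeq F, 0.0 (false for NaN)
    bool MaskedZero = (Bits & 0x7fffffffu) == 0;     // clear the sign bit, test the rest
    assert(Oeq == MaskedZero);
  }

  int main() {
    const float Samples[] = {0.0f, -0.0f, 1.0f, -1.0f,
                             std::numeric_limits<float>::denorm_min(),
                             -std::numeric_limits<float>::denorm_min(),
                             std::numeric_limits<float>::infinity(),
                             -std::numeric_limits<float>::infinity(),
                             std::numeric_limits<float>::quiet_NaN()};
    for (float F : Samples)
      Check(F);
    return 0;
  }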
// Handle fcmp with instruction LHS and constant RHS.
Instruction *LHSI;
Constant *RHSC;
@@ -6866,10 +6956,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
if (match(Op0, m_FNeg(m_Value(X)))) {
// fcmp pred (fneg X), C --> fcmp swap(pred) X, -C
Constant *C;
- if (match(Op1, m_Constant(C))) {
- Constant *NegC = ConstantExpr::getFNeg(C);
- return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
- }
+ if (match(Op1, m_Constant(C)))
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I);
}
if (match(Op0, m_FPExt(m_Value(X)))) {
@@ -6915,7 +7004,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
APFloat Fabs = TruncC;
Fabs.clearSign();
if (!Lossy &&
- (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) {
+ (Fabs.isZero() || !(Fabs < APFloat::getSmallestNormalized(FPSem)))) {
Constant *NewC = ConstantFP::get(X->getType(), TruncC);
return new FCmpInst(Pred, X, NewC, "", &I);
}
@@ -6942,6 +7031,24 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
}
}
+ {
+ Value *CanonLHS = nullptr, *CanonRHS = nullptr;
+ match(Op0, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonLHS)));
+ match(Op1, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonRHS)));
+
+ // (canonicalize(x) == x) => (x == x)
+ if (CanonLHS == Op1)
+ return new FCmpInst(Pred, Op1, Op1, "", &I);
+
+ // (x == canonicalize(x)) => (x == x)
+ if (CanonRHS == Op0)
+ return new FCmpInst(Pred, Op0, Op0, "", &I);
+
+ // (canonicalize(x) == canonicalize(y)) => (x == y)
+ if (CanonLHS && CanonRHS)
+ return new FCmpInst(Pred, CanonLHS, CanonRHS, "", &I);
+ }
+
if (I.getType()->isVectorTy())
if (Instruction *Res = foldVectorCmp(I, Builder))
return Res;
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index 664226ec187b..f4e88b122383 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -106,7 +106,8 @@ public:
Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted);
Instruction *visitAnd(BinaryOperator &I);
Instruction *visitOr(BinaryOperator &I);
- bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I);
+ bool sinkNotIntoLogicalOp(Instruction &I);
+ bool sinkNotIntoOtherHandOfLogicalOp(Instruction &I);
Instruction *visitXor(BinaryOperator &I);
Instruction *visitShl(BinaryOperator &I);
Value *reassociateShiftAmtsOfTwoSameDirectionShifts(
@@ -127,8 +128,8 @@ public:
Instruction *commonCastTransforms(CastInst &CI);
Instruction *commonPointerCastTransforms(CastInst &CI);
Instruction *visitTrunc(TruncInst &CI);
- Instruction *visitZExt(ZExtInst &CI);
- Instruction *visitSExt(SExtInst &CI);
+ Instruction *visitZExt(ZExtInst &Zext);
+ Instruction *visitSExt(SExtInst &Sext);
Instruction *visitFPTrunc(FPTruncInst &CI);
Instruction *visitFPExt(CastInst &CI);
Instruction *visitFPToUI(FPToUIInst &FI);
@@ -167,6 +168,7 @@ public:
Instruction *visitInsertValueInst(InsertValueInst &IV);
Instruction *visitInsertElementInst(InsertElementInst &IE);
Instruction *visitExtractElementInst(ExtractElementInst &EI);
+ Instruction *simplifyBinOpSplats(ShuffleVectorInst &SVI);
Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI);
Instruction *visitExtractValueInst(ExtractValueInst &EV);
Instruction *visitLandingPadInst(LandingPadInst &LI);
@@ -247,9 +249,9 @@ private:
/// \return null if the transformation cannot be performed. If the
/// transformation can be performed the new instruction that replaces the
/// (zext icmp) pair will be returned.
- Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI);
+ Instruction *transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext);
- Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI);
+ Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext);
bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS,
const Instruction &CxtI) const {
@@ -329,7 +331,7 @@ private:
Instruction *matchSAddSubSat(IntrinsicInst &MinMax1);
Instruction *foldNot(BinaryOperator &I);
- void freelyInvertAllUsersOf(Value *V);
+ void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
/// Determine if a pair of casts can be replaced by a single cast.
///
@@ -360,14 +362,24 @@ private:
Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd,
bool IsLogicalSelect = false);
+ Instruction *foldLogicOfIsFPClass(BinaryOperator &Operator, Value *LHS,
+ Value *RHS);
+
+ Instruction *
+ canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);
+
Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS,
Instruction *CxtI, bool IsAnd,
bool IsLogical = false);
- Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D);
- Value *getSelectCondition(Value *A, Value *B);
+ Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D,
+ bool InvertFalseVal = false);
+ Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame);
+ Instruction *foldLShrOverflowBit(BinaryOperator &I);
+ Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV);
Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II);
Instruction *foldFPSignBitOps(BinaryOperator &I);
+ Instruction *foldFDivConstantDivisor(BinaryOperator &I);
// Optimize one of these forms:
// and i1 Op, SI / select i1 Op, i1 SI, i1 false (if IsAnd = true)
@@ -377,64 +389,6 @@ private:
bool IsAnd);
public:
- /// Inserts an instruction \p New before instruction \p Old
- ///
- /// Also adds the new instruction to the worklist and returns \p New so that
- /// it is suitable for use as the return from the visitation patterns.
- Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) {
- assert(New && !New->getParent() &&
- "New instruction already inserted into a basic block!");
- BasicBlock *BB = Old.getParent();
- BB->getInstList().insert(Old.getIterator(), New); // Insert inst
- Worklist.add(New);
- return New;
- }
-
- /// Same as InsertNewInstBefore, but also sets the debug loc.
- Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) {
- New->setDebugLoc(Old.getDebugLoc());
- return InsertNewInstBefore(New, Old);
- }
-
- /// A combiner-aware RAUW-like routine.
- ///
- /// This method is to be used when an instruction is found to be dead,
- /// replaceable with another preexisting expression. Here we add all uses of
- /// I to the worklist, replace all uses of I with the new value, then return
- /// I, so that the inst combiner will know that I was modified.
- Instruction *replaceInstUsesWith(Instruction &I, Value *V) {
- // If there are no uses to replace, then we return nullptr to indicate that
- // no changes were made to the program.
- if (I.use_empty()) return nullptr;
-
- Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist.
-
- // If we are replacing the instruction with itself, this must be in a
- // segment of unreachable code, so just clobber the instruction.
- if (&I == V)
- V = PoisonValue::get(I.getType());
-
- LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n"
- << " with " << *V << '\n');
-
- I.replaceAllUsesWith(V);
- MadeIRChange = true;
- return &I;
- }
-
- /// Replace operand of instruction and add old operand to the worklist.
- Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) {
- Worklist.addValue(I.getOperand(OpNum));
- I.setOperand(OpNum, V);
- return &I;
- }
-
- /// Replace use and add the previously used value to the worklist.
- void replaceUse(Use &U, Value *NewValue) {
- Worklist.addValue(U);
- U = NewValue;
- }
-
/// Create and insert the idiom we use to indicate a block is unreachable
/// without having to rewrite the CFG from within InstCombine.
void CreateNonTerminatorUnreachable(Instruction *InsertAt) {
@@ -467,67 +421,6 @@ public:
return nullptr; // Don't do anything with FI
}
- void computeKnownBits(const Value *V, KnownBits &Known,
- unsigned Depth, const Instruction *CxtI) const {
- llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT);
- }
-
- KnownBits computeKnownBits(const Value *V, unsigned Depth,
- const Instruction *CxtI) const {
- return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT);
- }
-
- bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false,
- unsigned Depth = 0,
- const Instruction *CxtI = nullptr) {
- return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT);
- }
-
- bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0,
- const Instruction *CxtI = nullptr) const {
- return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT);
- }
-
- unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0,
- const Instruction *CxtI = nullptr) const {
- return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForUnsignedMul(const Value *LHS,
- const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForSignedMul(const Value *LHS,
- const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForUnsignedAdd(const Value *LHS,
- const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForSignedAdd(const Value *LHS,
- const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForUnsignedSub(const Value *LHS,
- const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
- OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS,
- const Instruction *CxtI) const {
- return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT);
- }
-
OverflowResult computeOverflow(
Instruction::BinaryOps BinaryOp, bool IsSigned,
Value *LHS, Value *RHS, Instruction *CxtI) const;
@@ -543,7 +436,7 @@ public:
/// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A
/// & (B | C) -> (A&B) | (A&C)" if this is a win). Returns the simplified
/// value, or null if it didn't simplify.
- Value *SimplifyUsingDistributiveLaws(BinaryOperator &I);
+ Value *foldUsingDistributiveLaws(BinaryOperator &I);
/// Tries to simplify add operations using the definition of remainder.
///
@@ -559,8 +452,7 @@ public:
/// This tries to simplify binary operations by factorizing out common terms
/// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
- Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *,
- Value *, Value *, Value *);
+ Value *tryFactorizationFolds(BinaryOperator &I);
/// Match a select chain which produces one of three values based on whether
/// the LHS is less than, equal to, or greater than RHS respectively.
@@ -647,7 +539,7 @@ public:
/// If an integer typed PHI has only one use which is an IntToPtr operation,
/// replace the PHI with an existing pointer typed PHI if it exists. Otherwise
/// insert a new pointer typed PHI and replace the original one.
- Instruction *foldIntegerTypedPHI(PHINode &PN);
+ bool foldIntegerTypedPHI(PHINode &PN);
/// Helper function for FoldPHIArgXIntoPHI() to set debug location for the
/// folded operation.
@@ -716,6 +608,8 @@ public:
const APInt &C1);
Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And,
const APInt &C1, const APInt &C2);
+ Instruction *foldICmpXorShiftConst(ICmpInst &Cmp, BinaryOperator *Xor,
+ const APInt &C);
Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1,
const APInt &C2);
Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1,
@@ -731,6 +625,7 @@ public:
Instruction *foldICmpBitCast(ICmpInst &Cmp);
// Helpers of visitSelectInst().
+ Instruction *foldSelectOfBools(SelectInst &SI);
Instruction *foldSelectExtConst(SelectInst &Sel);
Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI);
Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *);
@@ -790,13 +685,13 @@ class Negator final {
std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I);
- LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth);
+ [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth);
- LLVM_NODISCARD Value *negate(Value *V, unsigned Depth);
+ [[nodiscard]] Value *negate(Value *V, unsigned Depth);
/// Recurse depth-first and attempt to sink the negation.
/// FIXME: use worklist?
- LLVM_NODISCARD Optional<Result> run(Value *Root);
+ [[nodiscard]] std::optional<Result> run(Value *Root);
Negator(const Negator &) = delete;
Negator(Negator &&) = delete;
@@ -806,8 +701,8 @@ class Negator final {
public:
   /// Attempt to negate \p Root. Returns nullptr if negation can't be performed,
/// otherwise returns negated value.
- LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root,
- InstCombinerImpl &IC);
+ [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root,
+ InstCombinerImpl &IC);
};
} // end namespace llvm
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
index e03b7026f802..41bc65620ff6 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp
@@ -28,30 +28,42 @@ using namespace PatternMatch;
#define DEBUG_TYPE "instcombine"
-STATISTIC(NumDeadStore, "Number of dead stores eliminated");
+STATISTIC(NumDeadStore, "Number of dead stores eliminated");
STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global");
-/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived)
+static cl::opt<unsigned> MaxCopiedFromConstantUsers(
+ "instcombine-max-copied-from-constant-users", cl::init(128),
+ cl::desc("Maximum users to visit in copy from constant transform"),
+ cl::Hidden);
+
+/// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived)
/// pointer to an alloca. Ignore any reads of the pointer, return false if we
/// see any stores or other unknown uses. If we see pointer arithmetic, keep
/// track of whether it moves the pointer (with IsOffset) but otherwise traverse
/// the uses. If we see a memcpy/memmove that targets an unoffseted pointer to
-/// the alloca, and if the source pointer is a pointer to a constant global, we
-/// can optimize this.
+/// the alloca, and if the source pointer is a pointer to a constant memory
+/// location, we can optimize this.
static bool
-isOnlyCopiedFromConstantMemory(AAResults *AA,
- Value *V, MemTransferInst *&TheCopy,
+isOnlyCopiedFromConstantMemory(AAResults *AA, AllocaInst *V,
+ MemTransferInst *&TheCopy,
SmallVectorImpl<Instruction *> &ToDelete) {
// We track lifetime intrinsics as we encounter them. If we decide to go
- // ahead and replace the value with the global, this lets the caller quickly
- // eliminate the markers.
+ // ahead and replace the value with the memory location, this lets the caller
+ // quickly eliminate the markers.
+
+ using ValueAndIsOffset = PointerIntPair<Value *, 1, bool>;
+ SmallVector<ValueAndIsOffset, 32> Worklist;
+ SmallPtrSet<ValueAndIsOffset, 32> Visited;
+ Worklist.emplace_back(V, false);
+ while (!Worklist.empty()) {
+ ValueAndIsOffset Elem = Worklist.pop_back_val();
+ if (!Visited.insert(Elem).second)
+ continue;
+ if (Visited.size() > MaxCopiedFromConstantUsers)
+ return false;
- SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect;
- ValuesToInspect.emplace_back(V, false);
- while (!ValuesToInspect.empty()) {
- auto ValuePair = ValuesToInspect.pop_back_val();
- const bool IsOffset = ValuePair.second;
- for (auto &U : ValuePair.first->uses()) {
+ const auto [Value, IsOffset] = Elem;
+ for (auto &U : Value->uses()) {
auto *I = cast<Instruction>(U.getUser());
if (auto *LI = dyn_cast<LoadInst>(I)) {
@@ -60,15 +72,22 @@ isOnlyCopiedFromConstantMemory(AAResults *AA,
continue;
}
- if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) {
+ if (isa<PHINode, SelectInst>(I)) {
+ // We set IsOffset=true, to forbid the memcpy from occurring after the
+ // phi: If one of the phi operands is not based on the alloca, we
+ // would incorrectly omit a write.
+ Worklist.emplace_back(I, true);
+ continue;
+ }
+ if (isa<BitCastInst, AddrSpaceCastInst>(I)) {
// If uses of the bitcast are ok, we are ok.
- ValuesToInspect.emplace_back(I, IsOffset);
+ Worklist.emplace_back(I, IsOffset);
continue;
}
if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
// If the GEP has all zero indices, it doesn't offset the pointer. If it
// doesn't, it does.
- ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices());
+ Worklist.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices());
continue;
}
@@ -85,11 +104,12 @@ isOnlyCopiedFromConstantMemory(AAResults *AA,
if (IsArgOperand && Call->isInAllocaArgument(DataOpNo))
return false;
- // If this is a readonly/readnone call site, then we know it is just a
- // load (but one that potentially returns the value itself), so we can
+ // If this call site doesn't modify the memory, then we know it is just
+ // a load (but one that potentially returns the value itself), so we can
// ignore it if we know that the value isn't captured.
- if (Call->onlyReadsMemory() &&
- (Call->use_empty() || Call->doesNotCapture(DataOpNo)))
+ bool NoCapture = Call->doesNotCapture(DataOpNo);
+ if ((Call->onlyReadsMemory() && (Call->use_empty() || NoCapture)) ||
+ (Call->onlyReadsMemory(DataOpNo) && NoCapture))
continue;
// If this is being passed as a byval argument, the caller is making a
@@ -111,12 +131,14 @@ isOnlyCopiedFromConstantMemory(AAResults *AA,
if (!MI)
return false;
+ // If the transfer is volatile, reject it.
+ if (MI->isVolatile())
+ return false;
+
// If the transfer is using the alloca as a source of the transfer, then
// ignore it since it is a load (unless the transfer is volatile).
- if (U.getOperandNo() == 1) {
- if (MI->isVolatile()) return false;
+ if (U.getOperandNo() == 1)
continue;
- }
// If we already have seen a copy, reject the second one.
if (TheCopy) return false;
@@ -128,8 +150,8 @@ isOnlyCopiedFromConstantMemory(AAResults *AA,
// If the memintrinsic isn't using the alloca as the dest, reject it.
if (U.getOperandNo() != 0) return false;
- // If the source of the memcpy/move is not a constant global, reject it.
- if (!AA->pointsToConstantMemory(MI->getSource()))
+ // If the source of the memcpy/move is not constant, reject it.
+ if (isModSet(AA->getModRefInfoMask(MI->getSource())))
return false;
// Otherwise, the transform is safe. Remember the copy instruction.
@@ -139,9 +161,10 @@ isOnlyCopiedFromConstantMemory(AAResults *AA,
return true;
}
-/// isOnlyCopiedFromConstantGlobal - Return true if the specified alloca is only
-/// modified by a copy from a constant global. If we can prove this, we can
-/// replace any uses of the alloca with uses of the global directly.
+/// isOnlyCopiedFromConstantMemory - Return true if the specified alloca is only
+/// modified by a copy from a constant memory location. If we can prove this, we
+/// can replace any uses of the alloca with uses of the memory location
+/// directly.
static MemTransferInst *
isOnlyCopiedFromConstantMemory(AAResults *AA,
AllocaInst *AI,
@@ -165,7 +188,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI,
}
static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
- AllocaInst &AI) {
+ AllocaInst &AI, DominatorTree &DT) {
// Check for array size of 1 (scalar allocation).
if (!AI.isArrayAllocation()) {
// i32 1 is the canonical array size for scalar allocations.
@@ -184,6 +207,8 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC,
nullptr, AI.getName());
New->setAlignment(AI.getAlign());
+ replaceAllDbgUsesWith(AI, *New, *New, DT);
+
// Scan to the end of the allocation instructions, to skip over a block of
// allocas if possible...also skip interleaved debug info
//
@@ -234,31 +259,83 @@ namespace {
// instruction.
class PointerReplacer {
public:
- PointerReplacer(InstCombinerImpl &IC) : IC(IC) {}
+ PointerReplacer(InstCombinerImpl &IC, Instruction &Root)
+ : IC(IC), Root(Root) {}
- bool collectUsers(Instruction &I);
- void replacePointer(Instruction &I, Value *V);
+ bool collectUsers();
+ void replacePointer(Value *V);
private:
+ bool collectUsersRecursive(Instruction &I);
void replace(Instruction *I);
Value *getReplacement(Value *I);
+ bool isAvailable(Instruction *I) const {
+ return I == &Root || Worklist.contains(I);
+ }
+ SmallPtrSet<Instruction *, 32> ValuesToRevisit;
SmallSetVector<Instruction *, 4> Worklist;
MapVector<Value *, Value *> WorkMap;
InstCombinerImpl &IC;
+ Instruction &Root;
};
} // end anonymous namespace
-bool PointerReplacer::collectUsers(Instruction &I) {
- for (auto U : I.users()) {
+bool PointerReplacer::collectUsers() {
+ if (!collectUsersRecursive(Root))
+ return false;
+
+ // Ensure that all outstanding (indirect) users of I
+ // are inserted into the Worklist. Return false
+ // otherwise.
+ for (auto *Inst : ValuesToRevisit)
+ if (!Worklist.contains(Inst))
+ return false;
+ return true;
+}
+
+bool PointerReplacer::collectUsersRecursive(Instruction &I) {
+ for (auto *U : I.users()) {
auto *Inst = cast<Instruction>(&*U);
if (auto *Load = dyn_cast<LoadInst>(Inst)) {
if (Load->isVolatile())
return false;
Worklist.insert(Load);
- } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) {
+ } else if (auto *PHI = dyn_cast<PHINode>(Inst)) {
+      // All incoming values must be instructions for replaceability
+ if (any_of(PHI->incoming_values(),
+ [](Value *V) { return !isa<Instruction>(V); }))
+ return false;
+
+ // If at least one incoming value of the PHI is not in Worklist,
+ // store the PHI for revisiting and skip this iteration of the
+ // loop.
+ if (any_of(PHI->incoming_values(), [this](Value *V) {
+ return !isAvailable(cast<Instruction>(V));
+ })) {
+ ValuesToRevisit.insert(Inst);
+ continue;
+ }
+
+ Worklist.insert(PHI);
+ if (!collectUsersRecursive(*PHI))
+ return false;
+ } else if (auto *SI = dyn_cast<SelectInst>(Inst)) {
+ if (!isa<Instruction>(SI->getTrueValue()) ||
+ !isa<Instruction>(SI->getFalseValue()))
+ return false;
+
+ if (!isAvailable(cast<Instruction>(SI->getTrueValue())) ||
+ !isAvailable(cast<Instruction>(SI->getFalseValue()))) {
+ ValuesToRevisit.insert(Inst);
+ continue;
+ }
+ Worklist.insert(SI);
+ if (!collectUsersRecursive(*SI))
+ return false;
+ } else if (isa<GetElementPtrInst, BitCastInst>(Inst)) {
Worklist.insert(Inst);
- if (!collectUsers(*Inst))
+ if (!collectUsersRecursive(*Inst))
return false;
} else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) {
if (MI->isVolatile())
@@ -293,6 +370,14 @@ void PointerReplacer::replace(Instruction *I) {
IC.InsertNewInstWith(NewI, *LT);
IC.replaceInstUsesWith(*LT, NewI);
WorkMap[LT] = NewI;
+ } else if (auto *PHI = dyn_cast<PHINode>(I)) {
+ Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType();
+ auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(),
+ PHI->getName(), PHI);
+ for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I)
+ NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)),
+ PHI->getIncomingBlock(I));
+ WorkMap[PHI] = NewPHI;
} else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) {
auto *V = getReplacement(GEP->getPointerOperand());
assert(V && "Operand not replaced");
@@ -313,6 +398,13 @@ void PointerReplacer::replace(Instruction *I) {
IC.InsertNewInstWith(NewI, *BC);
NewI->takeName(BC);
WorkMap[BC] = NewI;
+ } else if (auto *SI = dyn_cast<SelectInst>(I)) {
+ auto *NewSI = SelectInst::Create(
+ SI->getCondition(), getReplacement(SI->getTrueValue()),
+ getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI);
+ IC.InsertNewInstWith(NewSI, *SI);
+ NewSI->takeName(SI);
+ WorkMap[SI] = NewSI;
} else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) {
auto *SrcV = getReplacement(MemCpy->getRawSource());
// The pointer may appear in the destination of a copy, but we don't want to
@@ -339,27 +431,27 @@ void PointerReplacer::replace(Instruction *I) {
}
}
-void PointerReplacer::replacePointer(Instruction &I, Value *V) {
+void PointerReplacer::replacePointer(Value *V) {
#ifndef NDEBUG
- auto *PT = cast<PointerType>(I.getType());
+ auto *PT = cast<PointerType>(Root.getType());
auto *NT = cast<PointerType>(V->getType());
assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage");
#endif
- WorkMap[&I] = V;
+ WorkMap[&Root] = V;
for (Instruction *Workitem : Worklist)
replace(Workitem);
}
Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
- if (auto *I = simplifyAllocaArraySize(*this, AI))
+ if (auto *I = simplifyAllocaArraySize(*this, AI, DT))
return I;
if (AI.getAllocatedType()->isSized()) {
    // Move all allocas of zero-byte objects to the entry block and merge them
    // together. Note that we only do this for allocas, because malloc should
// allocate and return a unique pointer, even for a zero byte allocation.
- if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinSize() == 0) {
+ if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinValue() == 0) {
// For a zero sized alloca there is no point in doing an array allocation.
// This is helpful if the array size is a complicated expression not used
// elsewhere.
@@ -377,7 +469,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
AllocaInst *EntryAI = dyn_cast<AllocaInst>(FirstInst);
if (!EntryAI || !EntryAI->getAllocatedType()->isSized() ||
DL.getTypeAllocSize(EntryAI->getAllocatedType())
- .getKnownMinSize() != 0) {
+ .getKnownMinValue() != 0) {
AI.moveBefore(FirstInst);
return &AI;
}
@@ -395,11 +487,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
}
// Check to see if this allocation is only modified by a memcpy/memmove from
- // a constant whose alignment is equal to or exceeds that of the allocation.
- // If this is the case, we can change all users to use the constant global
- // instead. This is commonly produced by the CFE by constructs like "void
- // foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' is only subsequently
- // read.
+ // a memory location whose alignment is equal to or exceeds that of the
+ // allocation. If this is the case, we can change all users to use the
+ // constant memory location instead. This is commonly produced by the CFE by
+ // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A'
+ // is only subsequently read.
SmallVector<Instruction *, 4> ToDelete;
if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) {
Value *TheSrc = Copy->getSource();
@@ -415,7 +507,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n');
unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace();
auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace);
- if (AI.getType()->getAddressSpace() == SrcAddrSpace) {
+ if (AI.getAddressSpace() == SrcAddrSpace) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
@@ -426,13 +518,13 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) {
return NewI;
}
- PointerReplacer PtrReplacer(*this);
- if (PtrReplacer.collectUsers(AI)) {
+ PointerReplacer PtrReplacer(*this, AI);
+ if (PtrReplacer.collectUsers()) {
for (Instruction *Delete : ToDelete)
eraseInstFromFunction(*Delete);
Value *Cast = Builder.CreateBitCast(TheSrc, DestTy);
- PtrReplacer.replacePointer(AI, Cast);
+ PtrReplacer.replacePointer(Cast);
++NumGlobalCopies;
}
}
@@ -507,6 +599,7 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI,
// here.
switch (ID) {
case LLVMContext::MD_dbg:
+ case LLVMContext::MD_DIAssignID:
case LLVMContext::MD_tbaa:
case LLVMContext::MD_prof:
case LLVMContext::MD_fpmath:
@@ -575,43 +668,43 @@ static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) {
/// later. However, it is risky in case some backend or other part of LLVM is
/// relying on the exact type loaded to select appropriate atomic operations.
static Instruction *combineLoadToOperationType(InstCombinerImpl &IC,
- LoadInst &LI) {
+ LoadInst &Load) {
// FIXME: We could probably with some care handle both volatile and ordered
// atomic loads here but it isn't clear that this is important.
- if (!LI.isUnordered())
+ if (!Load.isUnordered())
return nullptr;
- if (LI.use_empty())
+ if (Load.use_empty())
return nullptr;
// swifterror values can't be bitcasted.
- if (LI.getPointerOperand()->isSwiftError())
+ if (Load.getPointerOperand()->isSwiftError())
return nullptr;
- const DataLayout &DL = IC.getDataLayout();
-
// Fold away bit casts of the loaded value by loading the desired type.
// Note that we should not do this for pointer<->integer casts,
// because that would result in type punning.
- if (LI.hasOneUse()) {
+ if (Load.hasOneUse()) {
// Don't transform when the type is x86_amx, it makes the pass that lower
// x86_amx type happy.
- if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) {
- assert(!LI.getType()->isX86_AMXTy() &&
- "load from x86_amx* should not happen!");
+ Type *LoadTy = Load.getType();
+ if (auto *BC = dyn_cast<BitCastInst>(Load.user_back())) {
+ assert(!LoadTy->isX86_AMXTy() && "Load from x86_amx* should not happen!");
if (BC->getType()->isX86_AMXTy())
return nullptr;
}
- if (auto* CI = dyn_cast<CastInst>(LI.user_back()))
- if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() ==
- CI->getDestTy()->isPtrOrPtrVectorTy())
- if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) {
- LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy());
- CI->replaceAllUsesWith(NewLoad);
- IC.eraseInstFromFunction(*CI);
- return &LI;
- }
+ if (auto *CastUser = dyn_cast<CastInst>(Load.user_back())) {
+ Type *DestTy = CastUser->getDestTy();
+ if (CastUser->isNoopCast(IC.getDataLayout()) &&
+ LoadTy->isPtrOrPtrVectorTy() == DestTy->isPtrOrPtrVectorTy() &&
+ (!Load.isAtomic() || isSupportedAtomicType(DestTy))) {
+ LoadInst *NewLoad = IC.combineLoadToNewType(Load, DestTy);
+ CastUser->replaceAllUsesWith(NewLoad);
+ IC.eraseInstFromFunction(*CastUser);
+ return &Load;
+ }
+ }
}
// FIXME: We should also canonicalize loads of vectors when their elements are
@@ -639,7 +732,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
".unpack");
NewLoad->setAAMetadata(LI.getAAMetadata());
return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue(
- UndefValue::get(T), NewLoad, 0, Name));
+ PoisonValue::get(T), NewLoad, 0, Name));
}
  // We don't want to break loads with padding here as we'd lose
@@ -654,13 +747,13 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
auto *IdxType = Type::getInt32Ty(T->getContext());
auto *Zero = ConstantInt::get(IdxType, 0);
- Value *V = UndefValue::get(T);
+ Value *V = PoisonValue::get(T);
for (unsigned i = 0; i < NumElements; i++) {
Value *Indices[2] = {
Zero,
ConstantInt::get(IdxType, i),
};
- auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices),
+ auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices),
Name + ".elt");
auto *L = IC.Builder.CreateAlignedLoad(
ST->getElementType(i), Ptr,
@@ -681,7 +774,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack");
NewLoad->setAAMetadata(LI.getAAMetadata());
return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue(
- UndefValue::get(T), NewLoad, 0, Name));
+ PoisonValue::get(T), NewLoad, 0, Name));
}
// Bail out if the array is too large. Ideally we would like to optimize
@@ -699,14 +792,14 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) {
auto *IdxType = Type::getInt64Ty(T->getContext());
auto *Zero = ConstantInt::get(IdxType, 0);
- Value *V = UndefValue::get(T);
+ Value *V = PoisonValue::get(T);
uint64_t Offset = 0;
for (uint64_t i = 0; i < NumElements; i++) {
Value *Indices[2] = {
Zero,
ConstantInt::get(IdxType, i),
};
- auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices),
+ auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices),
Name + ".elt");
auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr,
commonAlignment(Align, Offset),
@@ -769,10 +862,13 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize,
if (!CS)
return false;
- uint64_t TypeSize = DL.getTypeAllocSize(AI->getAllocatedType());
+ TypeSize TS = DL.getTypeAllocSize(AI->getAllocatedType());
+ if (TS.isScalable())
+ return false;
// Make sure that, even if the multiplication below would wrap as an
// uint64_t, we still do the right thing.
- if ((CS->getValue().zext(128) * APInt(128, TypeSize)).ugt(MaxSize))
+ if ((CS->getValue().zext(128) * APInt(128, TS.getFixedValue()))
+ .ugt(MaxSize))
return false;
continue;
}
@@ -849,7 +945,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC,
if (!AllocTy || !AllocTy->isSized())
return false;
const DataLayout &DL = IC.getDataLayout();
- uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize();
+ uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedValue();
// If there are more indices after the one we might replace with a zero, make
// sure they're all non-negative. If any of them are negative, the overall
@@ -1183,8 +1279,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
Zero,
ConstantInt::get(IdxType, i),
};
- auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices),
- AddrName);
+ auto *Ptr =
+ IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices), AddrName);
auto *Val = IC.Builder.CreateExtractValue(V, i, EltName);
auto EltAlign = commonAlignment(Align, SL->getElementOffset(i));
llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign);
@@ -1229,8 +1325,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) {
Zero,
ConstantInt::get(IdxType, i),
};
- auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices),
- AddrName);
+ auto *Ptr =
+ IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName);
auto *Val = IC.Builder.CreateExtractValue(V, i, EltName);
auto EltAlign = commonAlignment(Align, Offset);
Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign);
@@ -1372,7 +1468,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) {
// If we have a store to a location which is known constant, we can conclude
// that the store must be storing the constant value (else the memory
// wouldn't be constant), and this must be a noop.
- if (AA->pointsToConstantMemory(Ptr))
+ if (!isModSet(AA->getModRefInfoMask(Ptr)))
return eraseInstFromFunction(SI);
// Do really simple DSE, to catch cases where there are several consecutive
@@ -1547,6 +1643,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) {
SI.getOrdering(), SI.getSyncScopeID());
InsertNewInstBefore(NewSI, *BBI);
NewSI->setDebugLoc(MergedLoc);
+ NewSI->mergeDIAssignID({&SI, OtherStore});
// If the two stores had AA tags, merge them.
AAMDNodes AATags = SI.getAAMetadata();
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 8cb09cbac86f..97f129e200de 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -15,6 +15,7 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
@@ -139,9 +140,56 @@ static Value *foldMulSelectToNegate(BinaryOperator &I,
return nullptr;
}
+/// Reduce integer multiplication patterns that contain a (+/-1 << Z) factor.
+/// Callers are expected to call this twice to handle commuted patterns.
+static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands,
+ InstCombiner::BuilderTy &Builder) {
+ Value *X = Mul.getOperand(0), *Y = Mul.getOperand(1);
+ if (CommuteOperands)
+ std::swap(X, Y);
+
+ const bool HasNSW = Mul.hasNoSignedWrap();
+ const bool HasNUW = Mul.hasNoUnsignedWrap();
+
+ // X * (1 << Z) --> X << Z
+ Value *Z;
+ if (match(Y, m_Shl(m_One(), m_Value(Z)))) {
+ bool PropagateNSW = HasNSW && cast<ShlOperator>(Y)->hasNoSignedWrap();
+ return Builder.CreateShl(X, Z, Mul.getName(), HasNUW, PropagateNSW);
+ }
+
+ // Similar to above, but an increment of the shifted value becomes an add:
+ // X * ((1 << Z) + 1) --> (X * (1 << Z)) + X --> (X << Z) + X
+ // This increases uses of X, so it may require a freeze, but that is still
+ // expected to be an improvement because it removes the multiply.
+ BinaryOperator *Shift;
+ if (match(Y, m_OneUse(m_Add(m_BinOp(Shift), m_One()))) &&
+ match(Shift, m_OneUse(m_Shl(m_One(), m_Value(Z))))) {
+ bool PropagateNSW = HasNSW && Shift->hasNoSignedWrap();
+ Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr");
+ Value *Shl = Builder.CreateShl(FrX, Z, "mulshl", HasNUW, PropagateNSW);
+ return Builder.CreateAdd(Shl, FrX, Mul.getName(), HasNUW, PropagateNSW);
+ }
+
+ // Similar to above, but a decrement of the shifted value is disguised as
+ // 'not' and becomes a sub:
+ // X * (~(-1 << Z)) --> X * ((1 << Z) - 1) --> (X << Z) - X
+ // This increases uses of X, so it may require a freeze, but that is still
+ // expected to be an improvement because it removes the multiply.
+ if (match(Y, m_OneUse(m_Not(m_OneUse(m_Shl(m_AllOnes(), m_Value(Z))))))) {
+ Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr");
+ Value *Shl = Builder.CreateShl(FrX, Z, "mulshl");
+ return Builder.CreateSub(Shl, FrX, Mul.getName());
+ }
+
+ return nullptr;
+}
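
As a quick sanity check of the algebra behind foldMulShl1, here is a standalone C++ sketch (not code from this patch), assuming unsigned 32-bit arithmetic so the rewrites stay exact even when intermediate values wrap; the test values are arbitrary:

    // Illustrative only: the three identities foldMulShl1 exploits.
    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t X : {0u, 1u, 7u, 0x1234u}) {
        for (uint32_t Z = 0; Z < 8; ++Z) {
          assert(X * (1u << Z) == (X << Z));           // X * (1 << Z)       --> X << Z
          assert(X * ((1u << Z) + 1) == (X << Z) + X); // X * ((1 << Z) + 1) --> (X << Z) + X
          assert(X * ~(~0u << Z) == (X << Z) - X);     // X * ~(-1 << Z)     --> (X << Z) - X
        }
      }
      return 0;
    }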
+
Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
- if (Value *V = simplifyMulInst(I.getOperand(0), I.getOperand(1),
- SQ.getWithInstruction(&I)))
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ if (Value *V =
+ simplifyMulInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+ SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
if (SimplifyAssociativeOrCommutative(I))
@@ -153,18 +201,18 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Instruction *Phi = foldBinopWithPhiOperands(I))
return Phi;
- if (Value *V = SimplifyUsingDistributiveLaws(I))
+ if (Value *V = foldUsingDistributiveLaws(I))
return replaceInstUsesWith(I, V);
- Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
- unsigned BitWidth = I.getType()->getScalarSizeInBits();
+ Type *Ty = I.getType();
+ const unsigned BitWidth = Ty->getScalarSizeInBits();
+ const bool HasNSW = I.hasNoSignedWrap();
+ const bool HasNUW = I.hasNoUnsignedWrap();
- // X * -1 == 0 - X
+ // X * -1 --> 0 - X
if (match(Op1, m_AllOnes())) {
- BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName());
- if (I.hasNoSignedWrap())
- BO->setHasNoSignedWrap();
- return BO;
+ return HasNSW ? BinaryOperator::CreateNSWNeg(Op0)
+ : BinaryOperator::CreateNeg(Op0);
}
// Also allow combining multiply instructions on vectors.
@@ -179,10 +227,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
Constant *Shl = ConstantExpr::getShl(C1, C2);
BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0));
BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl);
- if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap())
+ if (HasNUW && Mul->hasNoUnsignedWrap())
BO->setHasNoUnsignedWrap();
- if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() &&
- Shl->isNotMinSignedValue())
+ if (HasNSW && Mul->hasNoSignedWrap() && Shl->isNotMinSignedValue())
BO->setHasNoSignedWrap();
return BO;
}
@@ -192,9 +239,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) {
BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst);
- if (I.hasNoUnsignedWrap())
+ if (HasNUW)
Shl->setHasNoUnsignedWrap();
- if (I.hasNoSignedWrap()) {
+ if (HasNSW) {
const APInt *V;
if (match(NewCst, m_APInt(V)) && *V != V->getBitWidth() - 1)
Shl->setHasNoSignedWrap();
@@ -211,6 +258,25 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this))
return BinaryOperator::CreateMul(
NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName());
+
+ // Try to convert multiply of extended operand to narrow negate and shift
+ // for better analysis.
+  // This is valid if the shift amount (trailing zeros in the multiplier
+  // constant) clears at least as many high bits as the bitwidth difference
+  // between the source and destination types:
+ // ({z/s}ext X) * (-1<<C) --> (zext (-X)) << C
+ const APInt *NegPow2C;
+ Value *X;
+ if (match(Op0, m_ZExtOrSExt(m_Value(X))) &&
+ match(Op1, m_APIntAllowUndef(NegPow2C))) {
+ unsigned SrcWidth = X->getType()->getScalarSizeInBits();
+ unsigned ShiftAmt = NegPow2C->countTrailingZeros();
+ if (ShiftAmt >= BitWidth - SrcWidth) {
+ Value *N = Builder.CreateNeg(X, X->getName() + ".neg");
+ Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z");
+ return BinaryOperator::CreateShl(Z, ConstantInt::get(Ty, ShiftAmt));
+ }
+ }
}
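
A standalone sketch (not from the patch) of the narrowed-negate rewrite above, assuming an i8 value zero-extended to i32 and a multiplier whose trailing-zero count (24 here) is at least the 32 - 8 bit-width gap:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t NegPow2C = 0xFF000000u; // -(1 << 24); 24 trailing zeros >= 32 - 8
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = static_cast<uint8_t>(V);
        uint32_t Mul = static_cast<uint32_t>(X) * NegPow2C;                   // (zext X) * (-1 << 24)
        uint32_t Shl = static_cast<uint32_t>(static_cast<uint8_t>(-X)) << 24; // (zext (-X)) << 24
        assert(Mul == Shl);
      }
      return 0;
    }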
if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I))
@@ -220,16 +286,29 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
return replaceInstUsesWith(I, FoldedMul);
// Simplify mul instructions with a constant RHS.
- if (isa<Constant>(Op1)) {
- // Canonicalize (X+C1)*CI -> X*CI+C1*CI.
+ Constant *MulC;
+ if (match(Op1, m_ImmConstant(MulC))) {
+ // Canonicalize (X+C1)*MulC -> X*MulC+C1*MulC.
+ // Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC.
Value *X;
Constant *C1;
- if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) {
- Value *Mul = Builder.CreateMul(C1, Op1);
- // Only go forward with the transform if C1*CI simplifies to a tidier
- // constant.
- if (!match(Mul, m_Mul(m_Value(), m_Value())))
- return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul);
+ if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) ||
+ (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) &&
+ haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) {
+ // C1*MulC simplifies to a tidier constant.
+ Value *NewC = Builder.CreateMul(C1, MulC);
+ auto *BOp0 = cast<BinaryOperator>(Op0);
+ bool Op0NUW =
+ (BOp0->getOpcode() == Instruction::Or || BOp0->hasNoUnsignedWrap());
+ Value *NewMul = Builder.CreateMul(X, MulC);
+ auto *BO = BinaryOperator::CreateAdd(NewMul, NewC);
+ if (HasNUW && Op0NUW) {
+        // Even if NewMul folded to a constant, BO can still be marked nuw;
+        // when NewMul is still an instruction, it can carry nuw as well.
+ if (auto *NewMulBO = dyn_cast<BinaryOperator>(NewMul))
+ NewMulBO->setHasNoUnsignedWrap();
+ BO->setHasNoUnsignedWrap();
+ }
+ return BO;
}
}
@@ -254,8 +333,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// -X * -Y --> X * Y
if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) {
auto *NewMul = BinaryOperator::CreateMul(X, Y);
- if (I.hasNoSignedWrap() &&
- cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() &&
+ if (HasNSW && cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() &&
cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap())
NewMul->setHasNoSignedWrap();
return NewMul;
@@ -306,33 +384,15 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
// 2) X * Y --> X & Y, iff X, Y can be only {0,1}.
// Note: We could use known bits to generalize this and related patterns with
// shifts/truncs
- Type *Ty = I.getType();
if (Ty->isIntOrIntVectorTy(1) ||
(match(Op0, m_And(m_Value(), m_One())) &&
match(Op1, m_And(m_Value(), m_One()))))
return BinaryOperator::CreateAnd(Op0, Op1);
- // X*(1 << Y) --> X << Y
- // (1 << Y)*X --> X << Y
- {
- Value *Y;
- BinaryOperator *BO = nullptr;
- bool ShlNSW = false;
- if (match(Op0, m_Shl(m_One(), m_Value(Y)))) {
- BO = BinaryOperator::CreateShl(Op1, Y);
- ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap();
- } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) {
- BO = BinaryOperator::CreateShl(Op0, Y);
- ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap();
- }
- if (BO) {
- if (I.hasNoUnsignedWrap())
- BO->setHasNoUnsignedWrap();
- if (I.hasNoSignedWrap() && ShlNSW)
- BO->setHasNoSignedWrap();
- return BO;
- }
- }
+ if (Value *R = foldMulShl1(I, /* CommuteOperands */ false, Builder))
+ return replaceInstUsesWith(I, R);
+ if (Value *R = foldMulShl1(I, /* CommuteOperands */ true, Builder))
+ return replaceInstUsesWith(I, R);
// (zext bool X) * (zext bool Y) --> zext (and X, Y)
// (sext bool X) * (sext bool Y) --> zext (and X, Y)
@@ -403,8 +463,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
m_One()),
m_Deferred(X)))) {
Value *Abs = Builder.CreateBinaryIntrinsic(
- Intrinsic::abs, X,
- ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap()));
+ Intrinsic::abs, X, ConstantInt::getBool(I.getContext(), HasNSW));
Abs->takeName(&I);
return replaceInstUsesWith(I, Abs);
}
@@ -413,12 +472,12 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) {
return Ext;
bool Changed = false;
- if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) {
+ if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) {
Changed = true;
I.setHasNoSignedWrap(true);
}
- if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) {
+ if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I)) {
Changed = true;
I.setHasNoUnsignedWrap(true);
}
@@ -488,11 +547,19 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
if (match(Op1, m_SpecificFP(-1.0)))
return UnaryOperator::CreateFNegFMF(Op0, &I);
+ // With no-nans: X * 0.0 --> copysign(0.0, X)
+ if (I.hasNoNaNs() && match(Op1, m_PosZeroFP())) {
+ CallInst *CopySign = Builder.CreateIntrinsic(Intrinsic::copysign,
+ {I.getType()}, {Op1, Op0}, &I);
+ return replaceInstUsesWith(I, CopySign);
+ }
+
// -X * C --> X * -C
Value *X, *Y;
Constant *C;
if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C)))
- return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFMulFMF(X, NegC, &I);
// (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E)
if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
@@ -596,14 +663,32 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
}
}
+ // pow(X, Y) * X --> pow(X, Y+1)
+ // X * pow(X, Y) --> pow(X, Y+1)
+ if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X),
+ m_Value(Y))),
+ m_Deferred(X)))) {
+ Value *Y1 =
+ Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I);
+ Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I);
+ return replaceInstUsesWith(I, Pow);
+ }
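
Under the reassociation-style fast-math assumptions these folds require, they are ordinary exponent algebra; a rough numerical sketch (standalone C++, not LLVM code, with an arbitrarily chosen tolerance):

    #include <cassert>
    #include <cmath>

    int main() {
      const double X = 1.7, W = 2.5, Y = 3.25, Z = -0.5;
      auto Near = [](double A, double B) { return std::fabs(A - B) < 1e-9; };
      assert(Near(std::pow(X, Y) * X, std::pow(X, Y + 1.0)));            // pow(X,Y) * X        --> pow(X, Y+1)
      assert(Near(std::pow(X, Y) * std::pow(X, Z), std::pow(X, Y + Z))); // pow(X,Y) * pow(X,Z) --> pow(X, Y+Z)
      assert(Near(std::pow(X, Y) * std::pow(W, Y), std::pow(X * W, Y))); // pow(X,Y) * pow(W,Y) --> pow(X*W, Y)
      return 0;
    }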
+
if (I.isOnlyUserOfAnyOperand()) {
- // pow(x, y) * pow(x, z) -> pow(x, y + z)
+ // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z)
if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) {
auto *YZ = Builder.CreateFAddFMF(Y, Z, &I);
auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I);
return replaceInstUsesWith(I, NewPow);
}
+ // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y)
+ if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) &&
+ match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) {
+ auto *XZ = Builder.CreateFMulFMF(X, Z, &I);
+ auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I);
+ return replaceInstUsesWith(I, NewPow);
+ }
// powi(x, y) * powi(x, z) -> powi(x, y + z)
if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) &&
@@ -671,6 +756,15 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) {
}
}
+ // Simplify FMUL recurrences starting with 0.0 to 0.0 if nnan and nsz are set.
+  // Given a phi node whose entry value is 0.0 and which is used in an fmul
+  // operation, we can safely replace the fmul with 0.0 and eliminate the
+  // loop computation.
+ PHINode *PN = nullptr;
+ Value *Start = nullptr, *Step = nullptr;
+ if (matchSimpleRecurrence(&I, PN, Start, Step) && I.hasNoNaNs() &&
+ I.hasNoSignedZeros() && match(Start, m_Zero()))
+ return replaceInstUsesWith(I, Start);
+
return nullptr;
}
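
A toy version of the recurrence the new fold removes (hypothetical loop, not from the patch): seeded with 0.0 and multiplied each iteration, the value stays exactly 0.0 once NaNs and signed zeros are assumed away:

    #include <cassert>

    int main() {
      double Phi = 0.0;  // recurrence start value, as matched by matchSimpleRecurrence
      double Step = 3.5; // any finite, non-NaN step (nnan/nsz assumed)
      for (int I = 0; I < 16; ++I)
        Phi = Phi * Step;
      assert(Phi == 0.0);
      return 0;
    }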
@@ -773,6 +867,70 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient,
return Remainder.isMinValue();
}
+static Instruction *foldIDivShl(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder) {
+ assert((I.getOpcode() == Instruction::SDiv ||
+ I.getOpcode() == Instruction::UDiv) &&
+ "Expected integer divide");
+
+ bool IsSigned = I.getOpcode() == Instruction::SDiv;
+ Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ Type *Ty = I.getType();
+
+ Instruction *Ret = nullptr;
+ Value *X, *Y, *Z;
+
+ // With appropriate no-wrap constraints, remove a common factor in the
+ // dividend and divisor that is disguised as a left-shifted value.
+ if (match(Op1, m_Shl(m_Value(X), m_Value(Z))) &&
+ match(Op0, m_c_Mul(m_Specific(X), m_Value(Y)))) {
+ // Both operands must have the matching no-wrap for this kind of division.
+ auto *Mul = cast<OverflowingBinaryOperator>(Op0);
+ auto *Shl = cast<OverflowingBinaryOperator>(Op1);
+ bool HasNUW = Mul->hasNoUnsignedWrap() && Shl->hasNoUnsignedWrap();
+ bool HasNSW = Mul->hasNoSignedWrap() && Shl->hasNoSignedWrap();
+
+ // (X * Y) u/ (X << Z) --> Y u>> Z
+ if (!IsSigned && HasNUW)
+ Ret = BinaryOperator::CreateLShr(Y, Z);
+
+ // (X * Y) s/ (X << Z) --> Y s/ (1 << Z)
+ if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) {
+ Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z);
+ Ret = BinaryOperator::CreateSDiv(Y, Shl);
+ }
+ }
+
+ // With appropriate no-wrap constraints, remove a common factor in the
+ // dividend and divisor that is disguised as a left-shift amount.
+ if (match(Op0, m_Shl(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_Shl(m_Value(Y), m_Specific(Z)))) {
+ auto *Shl0 = cast<OverflowingBinaryOperator>(Op0);
+ auto *Shl1 = cast<OverflowingBinaryOperator>(Op1);
+
+ // For unsigned div, we need 'nuw' on both shifts or
+ // 'nsw' on both shifts + 'nuw' on the dividend.
+ // (X << Z) / (Y << Z) --> X / Y
+ if (!IsSigned &&
+ ((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) ||
+ (Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() &&
+ Shl1->hasNoSignedWrap())))
+ Ret = BinaryOperator::CreateUDiv(X, Y);
+
+ // For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor.
+ // (X << Z) / (Y << Z) --> X / Y
+ if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() &&
+ Shl1->hasNoUnsignedWrap())
+ Ret = BinaryOperator::CreateSDiv(X, Y);
+ }
+
+ if (!Ret)
+ return nullptr;
+
+ Ret->setIsExact(I.isExact());
+ return Ret;
+}
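
A quick unsigned sanity check of the two shift-cancellation rewrites above (standalone sketch, not from the patch); the nuw requirement is modeled by choosing operands small enough that nothing wraps:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t X = 3, Y = 20, Z = 2;     // small values, so no wrap occurs
      assert((X * Y) / (X << Z) == (Y >> Z));  // (X * Y) u/ (X << Z)  --> Y u>> Z
      assert((X << Z) / (Y << Z) == X / Y);    // (X << Z) u/ (Y << Z) --> X u/ Y
      return 0;
    }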
+
/// This function implements the transforms common to both integer division
/// instructions (udiv and sdiv). It is called by the visitors to those integer
/// division instructions.
@@ -919,6 +1077,41 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) {
}
}
+ // (X << Z) / (X * Y) -> (1 << Z) / Y
+ // TODO: Handle sdiv.
+ if (!IsSigned && Op1->hasOneUse() &&
+ match(Op0, m_NUWShl(m_Value(X), m_Value(Z))) &&
+ match(Op1, m_c_Mul(m_Specific(X), m_Value(Y))))
+ if (cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap()) {
+ Instruction *NewDiv = BinaryOperator::CreateUDiv(
+ Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*NUW*/ true), Y);
+ NewDiv->setIsExact(I.isExact());
+ return NewDiv;
+ }
+
+ if (Instruction *R = foldIDivShl(I, Builder))
+ return R;
+
+ // With the appropriate no-wrap constraint, remove a multiply by the divisor
+ // after peeking through another divide:
+ // ((Op1 * X) / Y) / Op1 --> X / Y
+ if (match(Op0, m_BinOp(I.getOpcode(), m_c_Mul(m_Specific(Op1), m_Value(X)),
+ m_Value(Y)))) {
+ auto *InnerDiv = cast<PossiblyExactOperator>(Op0);
+ auto *Mul = cast<OverflowingBinaryOperator>(InnerDiv->getOperand(0));
+ Instruction *NewDiv = nullptr;
+ if (!IsSigned && Mul->hasNoUnsignedWrap())
+ NewDiv = BinaryOperator::CreateUDiv(X, Y);
+ else if (IsSigned && Mul->hasNoSignedWrap())
+ NewDiv = BinaryOperator::CreateSDiv(X, Y);
+
+ // Exact propagates only if both of the original divides are exact.
+ if (NewDiv) {
+ NewDiv->setIsExact(I.isExact() && InnerDiv->isExact());
+ return NewDiv;
+ }
+ }
+
return nullptr;
}
@@ -1007,8 +1200,8 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
}
Constant *C;
- if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) ||
- (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) {
+ if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) &&
+ match(D, m_Constant(C))) {
// If the constant is the same in the smaller type, use the narrow version.
Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
if (ConstantExpr::getZExt(TruncC, Ty) != C)
@@ -1016,18 +1209,25 @@ static Instruction *narrowUDivURem(BinaryOperator &I,
// udiv (zext X), C --> zext (udiv X, C')
// urem (zext X), C --> zext (urem X, C')
+ return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty);
+ }
+ if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) &&
+ match(N, m_Constant(C))) {
+ // If the constant is the same in the smaller type, use the narrow version.
+ Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
+ if (ConstantExpr::getZExt(TruncC, Ty) != C)
+ return nullptr;
+
// udiv C, (zext X) --> zext (udiv C', X)
// urem C, (zext X) --> zext (urem C', X)
- Value *NarrowOp = isa<Constant>(D) ? Builder.CreateBinOp(Opcode, X, TruncC)
- : Builder.CreateBinOp(Opcode, TruncC, X);
- return new ZExtInst(NarrowOp, Ty);
+ return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty);
}
return nullptr;
}
Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
- if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1),
+ if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1086,6 +1286,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
return BinaryOperator::CreateUDiv(A, X);
}
+ // Look through a right-shift to find the common factor:
+ // ((Op1 *nuw A) >> B) / Op1 --> A >> B
+ if (match(Op0, m_LShr(m_NUWMul(m_Specific(Op1), m_Value(A)), m_Value(B))) ||
+ match(Op0, m_LShr(m_NUWMul(m_Value(A), m_Specific(Op1)), m_Value(B)))) {
+ Instruction *Lshr = BinaryOperator::CreateLShr(A, B);
+ if (I.isExact() && cast<PossiblyExactOperator>(Op0)->isExact())
+ Lshr->setIsExact();
+ return Lshr;
+ }
+
// Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away.
if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) {
Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true);
@@ -1097,7 +1307,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
}
Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
- if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1),
+ if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), I.isExact(),
SQ.getWithInstruction(&I)))
return replaceInstUsesWith(I, V);
@@ -1121,20 +1331,25 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
if (match(Op1, m_SignMask()))
return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty);
- // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative
- // sdiv exact X, -1<<C --> -(ashr exact X, C)
- if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) ||
- match(Op1, m_NegatedPower2()))) {
- bool DivisorWasNegative = match(Op1, m_NegatedPower2());
- if (DivisorWasNegative)
- Op1 = ConstantExpr::getNeg(cast<Constant>(Op1));
- auto *AShr = BinaryOperator::CreateExactAShr(
- Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName());
- if (!DivisorWasNegative)
- return AShr;
- Builder.Insert(AShr);
- AShr->setName(I.getName() + ".neg");
- return BinaryOperator::CreateNeg(AShr, I.getName());
+ if (I.isExact()) {
+ // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative
+ if (match(Op1, m_Power2()) && match(Op1, m_NonNegative())) {
+ Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op1));
+ return BinaryOperator::CreateExactAShr(Op0, C);
+ }
+
+ // sdiv exact X, (1<<ShAmt) --> ashr exact X, ShAmt (if shl is non-negative)
+ Value *ShAmt;
+ if (match(Op1, m_NSWShl(m_One(), m_Value(ShAmt))))
+ return BinaryOperator::CreateExactAShr(Op0, ShAmt);
+
+ // sdiv exact X, -1<<C --> -(ashr exact X, C)
+ if (match(Op1, m_NegatedPower2())) {
+ Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1));
+ Constant *C = ConstantExpr::getExactLogBase2(NegPow2C);
+ Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true);
+ return BinaryOperator::CreateNeg(Ashr);
+ }
}
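
Because the division is exact (no remainder is discarded), dividing by a power of two is just an arithmetic shift; a small signed sketch of the two cases, assuming arithmetic right shift of negative values (guaranteed since C++20):

    #include <cassert>
    #include <cstdint>

    int main() {
      const int32_t X = -64;          // exactly divisible by 16
      assert(X / 16 == (X >> 4));     // sdiv exact X, 1<<4  --> ashr exact X, 4
      assert(X / -16 == -(X >> 4));   // sdiv exact X, -1<<4 --> -(ashr exact X, 4)
      return 0;
    }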
const APInt *Op1C;
@@ -1184,12 +1399,17 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
ConstantInt::getAllOnesValue(Ty));
}
- // If the sign bits of both operands are zero (i.e. we can prove they are
- // unsigned inputs), turn this into a udiv.
- APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits()));
- if (MaskedValueIsZero(Op0, Mask, 0, &I)) {
- if (MaskedValueIsZero(Op1, Mask, 0, &I)) {
- // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set
+ KnownBits KnownDividend = computeKnownBits(Op0, 0, &I);
+ if (!I.isExact() &&
+ (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) &&
+ KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) {
+ I.setIsExact();
+ return &I;
+ }
+
+ if (KnownDividend.isNonNegative()) {
+ // If both operands are unsigned, turn this into a udiv.
+ if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) {
auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
BO->setIsExact(I.isExact());
return BO;
@@ -1219,15 +1439,28 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
}
/// Remove negation and try to convert division into multiplication.
-static Instruction *foldFDivConstantDivisor(BinaryOperator &I) {
+Instruction *InstCombinerImpl::foldFDivConstantDivisor(BinaryOperator &I) {
Constant *C;
if (!match(I.getOperand(1), m_Constant(C)))
return nullptr;
// -X / C --> X / -C
Value *X;
+ const DataLayout &DL = I.getModule()->getDataLayout();
if (match(I.getOperand(0), m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFDivFMF(X, NegC, &I);
+
+ // nnan X / +0.0 -> copysign(inf, X)
+ if (I.hasNoNaNs() && match(I.getOperand(1), m_Zero())) {
+ IRBuilder<> B(&I);
+ // TODO: nnan nsz X / -0.0 -> copysign(inf, X)
+ CallInst *CopySign = B.CreateIntrinsic(
+ Intrinsic::copysign, {C->getType()},
+ {ConstantFP::getInfinity(I.getType()), I.getOperand(0)}, &I);
+ CopySign->takeName(&I);
+ return replaceInstUsesWith(I, CopySign);
+ }
// If the constant divisor has an exact inverse, this is always safe. If not,
// then we can still create a reciprocal if fast-math-flags allow it and the
@@ -1239,7 +1472,6 @@ static Instruction *foldFDivConstantDivisor(BinaryOperator &I) {
// on all targets.
// TODO: Use Intrinsic::canonicalize or let function attributes tell us that
// denorms are flushed?
- const DataLayout &DL = I.getModule()->getDataLayout();
auto *RecipC = ConstantFoldBinaryOpOperands(
Instruction::FDiv, ConstantFP::get(I.getType(), 1.0), C, DL);
if (!RecipC || !RecipC->isNormalFP())
@@ -1257,15 +1489,16 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) {
// C / -X --> -C / X
Value *X;
+ const DataLayout &DL = I.getModule()->getDataLayout();
if (match(I.getOperand(1), m_FNeg(m_Value(X))))
- return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I);
+ if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL))
+ return BinaryOperator::CreateFDivFMF(NegC, X, &I);
if (!I.hasAllowReassoc() || !I.hasAllowReciprocal())
return nullptr;
// Try to reassociate C / X expressions where X includes another constant.
Constant *C2, *NewC = nullptr;
- const DataLayout &DL = I.getModule()->getDataLayout();
if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) {
// C / (X * C2) --> (C / C2) / X
NewC = ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C2, DL);
@@ -1435,6 +1668,16 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
if (Instruction *Mul = foldFDivPowDivisor(I, Builder))
return Mul;
+ // pow(X, Y) / X --> pow(X, Y-1)
+ if (I.hasAllowReassoc() &&
+ match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1),
+ m_Value(Y))))) {
+ Value *Y1 =
+ Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), -1.0), &I);
+ Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, Op1, Y1, &I);
+ return replaceInstUsesWith(I, Pow);
+ }
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index c573b03f31a6..e24abc48424d 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -15,8 +15,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/None.h"
-#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
@@ -130,7 +128,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
// FIXME: can this be reworked into a worklist-based algorithm while preserving
// the depth-first, early bailout traversal?
-LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
+[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) {
// -(undef) -> undef.
if (match(V, m_Undef()))
return V;
@@ -248,6 +246,19 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
return nullptr;
switch (I->getOpcode()) {
+ case Instruction::ZExt: {
+ // Negation of zext of signbit is signbit splat:
+ // 0 - (zext (i8 X u>> 7) to iN) --> sext (i8 X s>> 7) to iN
+ Value *SrcOp = I->getOperand(0);
+ unsigned SrcWidth = SrcOp->getType()->getScalarSizeInBits();
+ const APInt &FullShift = APInt(SrcWidth, SrcWidth - 1);
+ if (IsTrulyNegation &&
+ match(SrcOp, m_LShr(m_Value(X), m_SpecificIntAllowUndef(FullShift)))) {
+ Value *Ashr = Builder.CreateAShr(X, FullShift);
+ return Builder.CreateSExt(Ashr, I->getType());
+ }
+ break;
+ }
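
The new ZExt case rests on a signbit-splat identity; an exhaustive i8-to-i32 check as a standalone sketch (not LLVM code), again assuming arithmetic right shift of negative values:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (unsigned V = 0; V < 256; ++V) {
        uint8_t X = static_cast<uint8_t>(V);
        int32_t Neg = 0 - static_cast<int32_t>(X >> 7);  // 0 - (zext (X u>> 7) to i32)
        int32_t Splat = static_cast<int8_t>(X) >> 7;     // sext (X s>> 7) to i32
        assert(Neg == Splat);
      }
      return 0;
    }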
case Instruction::And: {
Constant *ShAmt;
// sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y)
@@ -382,7 +393,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg");
// Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`.
auto *Op1C = dyn_cast<Constant>(I->getOperand(1));
- if (!Op1C) // Early return.
+ if (!Op1C || !IsTrulyNegation)
return nullptr;
return Builder.CreateMul(
I->getOperand(0),
@@ -399,7 +410,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
if (match(Ops[1], m_One()))
return Builder.CreateNot(Ops[0], I->getName() + ".neg");
// Else, just defer to Instruction::Add handling.
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Instruction::Add: {
// `add` is negatible if both of its operands are negatible.
@@ -465,7 +476,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) {
llvm_unreachable("Can't get here. We always return from switch.");
}
-LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) {
+[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) {
NegatorMaxDepthVisited.updateMax(Depth);
++NegatorNumValuesVisited;
@@ -502,20 +513,20 @@ LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) {
return NegatedV;
}
-LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) {
+[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) {
Value *Negated = negate(Root, /*Depth=*/0);
if (!Negated) {
// We must cleanup newly-inserted instructions, to avoid any potential
// endless combine looping.
for (Instruction *I : llvm::reverse(NewInstructions))
I->eraseFromParent();
- return llvm::None;
+ return std::nullopt;
}
return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated);
}
-LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root,
- InstCombinerImpl &IC) {
+[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root,
+ InstCombinerImpl &IC) {
++NegatorTotalNegationsAttempted;
LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root
<< "\n");
@@ -525,7 +536,7 @@ LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root,
Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(),
IC.getDominatorTree(), LHSIsZero);
- Optional<Result> Res = N.run(Root);
+ std::optional<Result> Res = N.run(Root);
if (!Res) { // Negation failed.
LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root
<< "\n");
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 90a796a0939e..7f59729f0085 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -20,6 +20,7 @@
#include "llvm/Support/CommandLine.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
+#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -102,15 +103,15 @@ void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) {
// ptr_val_inc = ...
// ...
//
-Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
+bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
if (!PN.getType()->isIntegerTy())
- return nullptr;
+ return false;
if (!PN.hasOneUse())
- return nullptr;
+ return false;
auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back());
if (!IntToPtr)
- return nullptr;
+ return false;
// Check if the pointer is actually used as pointer:
auto HasPointerUse = [](Instruction *IIP) {
@@ -131,11 +132,11 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
};
if (!HasPointerUse(IntToPtr))
- return nullptr;
+ return false;
if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) !=
DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType()))
- return nullptr;
+ return false;
SmallVector<Value *, 4> AvailablePtrVals;
for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) {
@@ -174,10 +175,10 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
// For a single use integer load:
auto *LoadI = dyn_cast<LoadInst>(Arg);
if (!LoadI)
- return nullptr;
+ return false;
if (!LoadI->hasOneUse())
- return nullptr;
+ return false;
// Push the integer typed Load instruction into the available
// value set, and fix it up later when the pointer typed PHI
@@ -194,7 +195,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
for (PHINode &PtrPHI : BB->phis()) {
// FIXME: consider handling this in AggressiveInstCombine
if (NumPhis++ > MaxNumPhis)
- return nullptr;
+ return false;
if (&PtrPHI == &PN || PtrPHI.getType() != IntToPtr->getType())
continue;
if (any_of(zip(PN.blocks(), AvailablePtrVals),
@@ -211,16 +212,19 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
if (MatchingPtrPHI) {
assert(MatchingPtrPHI->getType() == IntToPtr->getType() &&
"Phi's Type does not match with IntToPtr");
- // The PtrToCast + IntToPtr will be simplified later
- return CastInst::CreateBitOrPointerCast(MatchingPtrPHI,
- IntToPtr->getOperand(0)->getType());
+ // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here,
+ // to make sure another transform can't undo it in the meantime.
+ replaceInstUsesWith(*IntToPtr, MatchingPtrPHI);
+ eraseInstFromFunction(*IntToPtr);
+ eraseInstFromFunction(PN);
+ return true;
}
// If it requires a conversion for every PHI operand, do not do it.
if (all_of(AvailablePtrVals, [&](Value *V) {
return (V->getType() != IntToPtr->getType()) || isa<IntToPtrInst>(V);
}))
- return nullptr;
+ return false;
// If any of the operand that requires casting is a terminator
// instruction, do not do it. Similarly, do not do the transform if the value
@@ -239,7 +243,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
return true;
return false;
}))
- return nullptr;
+ return false;
PHINode *NewPtrPHI = PHINode::Create(
IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr");
@@ -290,9 +294,12 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) {
NewPtrPHI->addIncoming(CI, IncomingBB);
}
- // The PtrToCast + IntToPtr will be simplified later
- return CastInst::CreateBitOrPointerCast(NewPtrPHI,
- IntToPtr->getOperand(0)->getType());
+ // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here,
+ // to make sure another transform can't undo it in the meantime.
+ replaceInstUsesWith(*IntToPtr, NewPtrPHI);
+ eraseInstFromFunction(*IntToPtr);
+ eraseInstFromFunction(PN);
+ return true;
}
// Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and
@@ -598,7 +605,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) {
Value *Base = FixedOperands[0];
GetElementPtrInst *NewGEP =
GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base,
- makeArrayRef(FixedOperands).slice(1));
+ ArrayRef(FixedOperands).slice(1));
if (AllInBounds) NewGEP->setIsInBounds();
PHIArgMergedDebugLoc(NewGEP, PN);
return NewGEP;
@@ -1322,7 +1329,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
// Check that edges outgoing from the idom's terminators dominate respective
// inputs of the Phi.
- Optional<bool> Invert;
+ std::optional<bool> Invert;
for (auto Pair : zip(PN.incoming_values(), PN.blocks())) {
auto *Input = cast<ConstantInt>(std::get<0>(Pair));
BasicBlock *Pred = std::get<1>(Pair);
@@ -1412,8 +1419,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
// this PHI only has a single use (a PHI), and if that PHI only has one use (a
// PHI)... break the cycle.
if (PN.hasOneUse()) {
- if (Instruction *Result = foldIntegerTypedPHI(PN))
- return Result;
+ if (foldIntegerTypedPHI(PN))
+ return nullptr;
Instruction *PHIUser = cast<Instruction>(PN.user_back());
if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index ad96a5f475f1..e7d8208f94fd 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -12,7 +12,6 @@
#include "InstCombineInternal.h"
#include "llvm/ADT/APInt.h"
-#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/AssumptionCache.h"
@@ -20,6 +19,7 @@
#include "llvm/Analysis/InstructionSimplify.h"
#include "llvm/Analysis/OverflowInstAnalysis.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/ConstantRange.h"
@@ -314,47 +314,95 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
TI->getType());
}
- // Cond ? -X : -Y --> -(Cond ? X : Y)
- Value *X, *Y;
- if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) &&
- (TI->hasOneUse() || FI->hasOneUse())) {
- // Intersect FMF from the fneg instructions and union those with the select.
- FastMathFlags FMF = TI->getFastMathFlags();
- FMF &= FI->getFastMathFlags();
- FMF |= SI.getFastMathFlags();
- Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
- if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
- NewSelI->setFastMathFlags(FMF);
- Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
- NewFNeg->setFastMathFlags(FMF);
- return NewFNeg;
- }
-
- // Min/max intrinsic with a common operand can have the common operand pulled
- // after the select. This is the same transform as below for binops, but
- // specialized for intrinsic matching and without the restrictive uses clause.
- auto *TII = dyn_cast<IntrinsicInst>(TI);
- auto *FII = dyn_cast<IntrinsicInst>(FI);
- if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID() &&
- (TII->hasOneUse() || FII->hasOneUse())) {
- Value *T0, *T1, *F0, *F1;
- if (match(TII, m_MaxOrMin(m_Value(T0), m_Value(T1))) &&
- match(FII, m_MaxOrMin(m_Value(F0), m_Value(F1)))) {
- if (T0 == F0) {
- Value *NewSel = Builder.CreateSelect(Cond, T1, F1, "minmaxop", &SI);
- return CallInst::Create(TII->getCalledFunction(), {NewSel, T0});
- }
- if (T0 == F1) {
- Value *NewSel = Builder.CreateSelect(Cond, T1, F0, "minmaxop", &SI);
- return CallInst::Create(TII->getCalledFunction(), {NewSel, T0});
+ Value *OtherOpT, *OtherOpF;
+ bool MatchIsOpZero;
+ auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute,
+ bool Swapped = false) -> Value * {
+    assert(!(Commute && Swapped) &&
+           "Commute and Swapped can't be set at the same time");
+ if (!Swapped) {
+ if (TI->getOperand(0) == FI->getOperand(0)) {
+ OtherOpT = TI->getOperand(1);
+ OtherOpF = FI->getOperand(1);
+ MatchIsOpZero = true;
+ return TI->getOperand(0);
+ } else if (TI->getOperand(1) == FI->getOperand(1)) {
+ OtherOpT = TI->getOperand(0);
+ OtherOpF = FI->getOperand(0);
+ MatchIsOpZero = false;
+ return TI->getOperand(1);
}
- if (T1 == F0) {
- Value *NewSel = Builder.CreateSelect(Cond, T0, F1, "minmaxop", &SI);
- return CallInst::Create(TII->getCalledFunction(), {NewSel, T1});
+ }
+
+ if (!Commute && !Swapped)
+ return nullptr;
+
+ // If we are allowing commute or swap of operands, then
+ // allow a cross-operand match. In that case, MatchIsOpZero
+ // means that TI's operand 0 (FI's operand 1) is the common op.
+ if (TI->getOperand(0) == FI->getOperand(1)) {
+ OtherOpT = TI->getOperand(1);
+ OtherOpF = FI->getOperand(0);
+ MatchIsOpZero = true;
+ return TI->getOperand(0);
+ } else if (TI->getOperand(1) == FI->getOperand(0)) {
+ OtherOpT = TI->getOperand(0);
+ OtherOpF = FI->getOperand(1);
+ MatchIsOpZero = false;
+ return TI->getOperand(1);
+ }
+ return nullptr;
+ };
+
+ if (TI->hasOneUse() || FI->hasOneUse()) {
+ // Cond ? -X : -Y --> -(Cond ? X : Y)
+ Value *X, *Y;
+ if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) {
+ // Intersect FMF from the fneg instructions and union those with the
+ // select.
+ FastMathFlags FMF = TI->getFastMathFlags();
+ FMF &= FI->getFastMathFlags();
+ FMF |= SI.getFastMathFlags();
+ Value *NewSel =
+ Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
+ if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
+ NewSelI->setFastMathFlags(FMF);
+ Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
+ NewFNeg->setFastMathFlags(FMF);
+ return NewFNeg;
+ }
+
+ // Min/max intrinsic with a common operand can have the common operand
+ // pulled after the select. This is the same transform as below for binops,
+ // but specialized for intrinsic matching and without the restrictive uses
+ // clause.
+ auto *TII = dyn_cast<IntrinsicInst>(TI);
+ auto *FII = dyn_cast<IntrinsicInst>(FI);
+ if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) {
+ if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) {
+ if (Value *MatchOp = getCommonOp(TI, FI, true)) {
+ Value *NewSel =
+ Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI);
+ return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
+ }
}
- if (T1 == F1) {
- Value *NewSel = Builder.CreateSelect(Cond, T0, F0, "minmaxop", &SI);
- return CallInst::Create(TII->getCalledFunction(), {NewSel, T1});
+ }
+
+ // icmp with a common operand also can have the common operand
+ // pulled after the select.
+ ICmpInst::Predicate TPred, FPred;
+ if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) &&
+ match(FI, m_ICmp(FPred, m_Value(), m_Value()))) {
+ if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) {
+ bool Swapped = TPred != FPred;
+ if (Value *MatchOp =
+ getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) {
+ Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF,
+ SI.getName() + ".v", &SI);
+ return new ICmpInst(
+ MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred),
+ MatchOp, NewSel);
+ }
}
}
}
@@ -370,33 +418,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
return nullptr;
// Figure out if the operations have any operands in common.
- Value *MatchOp, *OtherOpT, *OtherOpF;
- bool MatchIsOpZero;
- if (TI->getOperand(0) == FI->getOperand(0)) {
- MatchOp = TI->getOperand(0);
- OtherOpT = TI->getOperand(1);
- OtherOpF = FI->getOperand(1);
- MatchIsOpZero = true;
- } else if (TI->getOperand(1) == FI->getOperand(1)) {
- MatchOp = TI->getOperand(1);
- OtherOpT = TI->getOperand(0);
- OtherOpF = FI->getOperand(0);
- MatchIsOpZero = false;
- } else if (!TI->isCommutative()) {
- return nullptr;
- } else if (TI->getOperand(0) == FI->getOperand(1)) {
- MatchOp = TI->getOperand(0);
- OtherOpT = TI->getOperand(1);
- OtherOpF = FI->getOperand(0);
- MatchIsOpZero = true;
- } else if (TI->getOperand(1) == FI->getOperand(0)) {
- MatchOp = TI->getOperand(1);
- OtherOpT = TI->getOperand(0);
- OtherOpF = FI->getOperand(1);
- MatchIsOpZero = true;
- } else {
+ Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative());
+ if (!MatchOp)
return nullptr;
- }
// If the select condition is a vector, the operands of the original select's
// operands also must be vectors. This may not be the case for getelementptr
@@ -442,44 +466,44 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal,
auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal,
Value *FalseVal,
bool Swapped) -> Instruction * {
- if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) {
- if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) {
- if (unsigned SFO = getSelectFoldableOperands(TVI)) {
- unsigned OpToFold = 0;
- if ((SFO & 1) && FalseVal == TVI->getOperand(0))
- OpToFold = 1;
- else if ((SFO & 2) && FalseVal == TVI->getOperand(1))
- OpToFold = 2;
+ auto *TVI = dyn_cast<BinaryOperator>(TrueVal);
+ if (!TVI || !TVI->hasOneUse() || isa<Constant>(FalseVal))
+ return nullptr;
- if (OpToFold) {
- FastMathFlags FMF;
- // TODO: We probably ought to revisit cases where the select and FP
- // instructions have different flags and add tests to ensure the
- // behaviour is correct.
- if (isa<FPMathOperator>(&SI))
- FMF = SI.getFastMathFlags();
- Constant *C = ConstantExpr::getBinOpIdentity(
- TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros());
- Value *OOp = TVI->getOperand(2 - OpToFold);
- // Avoid creating select between 2 constants unless it's selecting
- // between 0, 1 and -1.
- const APInt *OOpC;
- bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
- if (!isa<Constant>(OOp) ||
- (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
- Value *NewSel = Builder.CreateSelect(
- SI.getCondition(), Swapped ? C : OOp, Swapped ? OOp : C);
- if (isa<FPMathOperator>(&SI))
- cast<Instruction>(NewSel)->setFastMathFlags(FMF);
- NewSel->takeName(TVI);
- BinaryOperator *BO =
- BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
- BO->copyIRFlags(TVI);
- return BO;
- }
- }
- }
- }
+ unsigned SFO = getSelectFoldableOperands(TVI);
+ unsigned OpToFold = 0;
+ if ((SFO & 1) && FalseVal == TVI->getOperand(0))
+ OpToFold = 1;
+ else if ((SFO & 2) && FalseVal == TVI->getOperand(1))
+ OpToFold = 2;
+
+ if (!OpToFold)
+ return nullptr;
+
+ // TODO: We probably ought to revisit cases where the select and FP
+ // instructions have different flags and add tests to ensure the
+ // behaviour is correct.
+ FastMathFlags FMF;
+ if (isa<FPMathOperator>(&SI))
+ FMF = SI.getFastMathFlags();
+ Constant *C = ConstantExpr::getBinOpIdentity(
+ TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros());
+ Value *OOp = TVI->getOperand(2 - OpToFold);
+ // Avoid creating select between 2 constants unless it's selecting
+ // between 0, 1 and -1.
+ const APInt *OOpC;
+ bool OOpIsAPInt = match(OOp, m_APInt(OOpC));
+ if (!isa<Constant>(OOp) ||
+ (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) {
+ Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp,
+ Swapped ? OOp : C);
+ if (isa<FPMathOperator>(&SI))
+ cast<Instruction>(NewSel)->setFastMathFlags(FMF);
+ NewSel->takeName(TVI);
+ BinaryOperator *BO =
+ BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel);
+ BO->copyIRFlags(TVI);
+ return BO;
}
return nullptr;
};
@@ -779,19 +803,31 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
const Value *FalseVal,
InstCombiner::BuilderTy &Builder) {
ICmpInst::Predicate Pred = ICI->getPredicate();
- if (!ICmpInst::isUnsigned(Pred))
- return nullptr;
+ Value *A = ICI->getOperand(0);
+ Value *B = ICI->getOperand(1);
// (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0
+ // (a == 0) ? 0 : a - 1 -> (a != 0) ? a - 1 : 0
if (match(TrueVal, m_Zero())) {
Pred = ICmpInst::getInversePredicate(Pred);
std::swap(TrueVal, FalseVal);
}
+
if (!match(FalseVal, m_Zero()))
return nullptr;
- Value *A = ICI->getOperand(0);
- Value *B = ICI->getOperand(1);
+ // ugt 0 is canonicalized to ne 0 and requires special handling
+ // (a != 0) ? a + -1 : 0 -> usub.sat(a, 1)
+ if (Pred == ICmpInst::ICMP_NE) {
+ if (match(B, m_Zero()) && match(TrueVal, m_Add(m_Specific(A), m_AllOnes())))
+ return Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A,
+ ConstantInt::get(A->getType(), 1));
+ return nullptr;
+ }
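
The special `ne 0` case matches saturating-subtract semantics; a minimal sketch with a hand-rolled stand-in for the intrinsic (USubSat is a hypothetical helper, not an LLVM API):

    #include <cassert>
    #include <cstdint>

    // Illustrative stand-in for llvm.usub.sat on i32.
    static uint32_t USubSat(uint32_t A, uint32_t B) { return A > B ? A - B : 0; }

    int main() {
      for (uint32_t A : {0u, 1u, 2u, 100u, 0xFFFFFFFFu}) {
        uint32_t Sel = (A != 0) ? A - 1 : 0;  // (a != 0) ? a + -1 : 0
        assert(Sel == USubSat(A, 1));         // --> usub.sat(a, 1)
      }
      return 0;
    }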
+
+ if (!ICmpInst::isUnsigned(Pred))
+ return nullptr;
+
if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) {
// (b < a) ? a - b : 0 -> (a > b) ? a - b : 0
std::swap(A, B);
@@ -952,8 +988,8 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
Value *CmpLHS = ICI->getOperand(0);
Value *CmpRHS = ICI->getOperand(1);
- // Check if the condition value compares a value for equality against zero.
- if (!ICI->isEquality() || !match(CmpRHS, m_Zero()))
+ // Check if the select condition compares a value for equality.
+ if (!ICI->isEquality())
return nullptr;
Value *SelectArg = FalseVal;
@@ -969,8 +1005,15 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal,
// Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the
// input to the cttz/ctlz is used as LHS for the compare instruction.
- if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) &&
- !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS))))
+ Value *X;
+ if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Value(X))) &&
+ !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Value(X))))
+ return nullptr;
+
+ // (X == 0) ? BitWidth : ctz(X)
+ // (X == -1) ? BitWidth : ctz(~X)
+ if ((X != CmpLHS || !match(CmpRHS, m_Zero())) &&
+ (!match(X, m_Not(m_Specific(CmpLHS))) || !match(CmpRHS, m_AllOnes())))
return nullptr;
IntrinsicInst *II = cast<IntrinsicInst>(Count);
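
The relaxed equality check above works because a count intrinsic whose zero input is defined already subsumes the select: (X == 0) ? BitWidth : cttz(X) is cttz with the zero case defined as the bit width, and the new X == -1 form is the same statement applied to ~X. A C++20 sketch (illustrative; cttz_nonzero stands in for cttz with is_zero_poison=true):

#include <bit>
#include <cassert>
#include <cstdint>

// Models cttz with is_zero_poison=true: only meaningful for nonzero input.
static unsigned cttz_nonzero(uint32_t X) {
  return unsigned(std::countr_zero(X));
}

int main() {
  for (uint64_t V = 0; V <= 0xFFFF; ++V) {
    uint32_t X = uint32_t(V);
    uint32_t NotX = ~X;
    // (X == 0) ? 32 : cttz(X)  ==  cttz with the zero case defined as 32.
    unsigned Sel = (X == 0) ? 32u : cttz_nonzero(X);
    assert(Sel == unsigned(std::countr_zero(X)));
    // (X == -1) ? 32 : cttz(~X)  --  the same identity applied to ~X.
    unsigned SelNot = (X == ~0u) ? 32u : cttz_nonzero(NotX);
    assert(SelNot == unsigned(std::countr_zero(NotX)));
  }
}
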
@@ -1139,6 +1182,28 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
return nullptr;
}
+static bool replaceInInstruction(Value *V, Value *Old, Value *New,
+ InstCombiner &IC, unsigned Depth = 0) {
+ // Conservatively limit replacement to two instructions upwards.
+ if (Depth == 2)
+ return false;
+
+ auto *I = dyn_cast<Instruction>(V);
+ if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I))
+ return false;
+
+ bool Changed = false;
+ for (Use &U : I->operands()) {
+ if (U == Old) {
+ IC.replaceUse(U, New);
+ Changed = true;
+ } else {
+ Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1);
+ }
+ }
+ return Changed;
+}
+
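The helper leans on plain value equivalence: inside the true arm of select (x == C), f(x), y, the value x is known to equal C, so pushing C through a short chain of speculatable instructions cannot change the result. A hedged C++ analogy with a two-instruction chain (f and the constant 42 are made up for illustration):

#include <cassert>
#include <cstdint>

// A two-instruction chain, mirroring the Depth == 2 replacement limit.
static uint32_t f(uint32_t X) { return X * X + X; }

int main() {
  const uint32_t C = 42;
  for (uint64_t V = 0; V <= 0xFFFF; ++V) {
    uint32_t X = uint32_t(V), Y = uint32_t(~V);
    uint32_t Sel = (X == C) ? f(X) : Y;      // select (x == C), f(x), y
    uint32_t SelSubst = (X == C) ? f(C) : Y; // x replaced by C in the arm
    assert(Sel == SelSubst);
  }
}
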
/// If we have a select with an equality comparison, then we know the value in
/// one of the arms of the select. See if substituting this value into an arm
/// and simplifying the result yields the same value as the other arm.
@@ -1157,10 +1222,7 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp,
/// TODO: Wrapping flags could be preserved in some cases with better analysis.
Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
ICmpInst &Cmp) {
- // Value equivalence substitution requires an all-or-nothing replacement.
- // It does not make sense for a vector compare where each lane is chosen
- // independently.
- if (!Cmp.isEquality() || Cmp.getType()->isVectorTy())
+ if (!Cmp.isEquality())
return nullptr;
// Canonicalize the pattern to ICMP_EQ by swapping the select operands.
@@ -1189,15 +1251,11 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel,
// with different operands, which should not cause side-effects or trigger
// undefined behavior). Only do this if CmpRHS is a constant, as
// profitability is not clear for other cases.
- // FIXME: The replacement could be performed recursively.
- if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()))
- if (auto *I = dyn_cast<Instruction>(TrueVal))
- if (I->hasOneUse() && isSafeToSpeculativelyExecute(I))
- for (Use &U : I->operands())
- if (U == CmpLHS) {
- replaceUse(U, CmpRHS);
- return &Sel;
- }
+ // FIXME: Support vectors.
+ if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) &&
+ !Cmp.getType()->isVectorTy())
+ if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this))
+ return &Sel;
}
if (TrueVal != CmpRHS &&
isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT))
@@ -1371,7 +1429,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
C2->getType()->getScalarSizeInBits()))))
return nullptr; // Can't do, have signed max element[s].
C2 = InstCombiner::AddOne(C2);
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case ICmpInst::Predicate::ICMP_SGE:
// Also non-canonical, but here we don't need to change C2,
// so we don't have any restrictions on C2, so we can just handle it.
@@ -2307,6 +2365,41 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel,
}
Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
+ if (!isa<VectorType>(Sel.getType()))
+ return nullptr;
+
+ Value *Cond = Sel.getCondition();
+ Value *TVal = Sel.getTrueValue();
+ Value *FVal = Sel.getFalseValue();
+ Value *C, *X, *Y;
+
+ if (match(Cond, m_VecReverse(m_Value(C)))) {
+ auto createSelReverse = [&](Value *C, Value *X, Value *Y) {
+ Value *V = Builder.CreateSelect(C, X, Y, Sel.getName(), &Sel);
+ if (auto *I = dyn_cast<Instruction>(V))
+ I->copyIRFlags(&Sel);
+ Module *M = Sel.getModule();
+ Function *F = Intrinsic::getDeclaration(
+ M, Intrinsic::experimental_vector_reverse, V->getType());
+ return CallInst::Create(F, V);
+ };
+
+ if (match(TVal, m_VecReverse(m_Value(X)))) {
+ // select rev(C), rev(X), rev(Y) --> rev(select C, X, Y)
+ if (match(FVal, m_VecReverse(m_Value(Y))) &&
+ (Cond->hasOneUse() || TVal->hasOneUse() || FVal->hasOneUse()))
+ return createSelReverse(C, X, Y);
+
+ // select rev(C), rev(X), FValSplat --> rev(select C, X, FValSplat)
+ if ((Cond->hasOneUse() || TVal->hasOneUse()) && isSplatValue(FVal))
+ return createSelReverse(C, X, FVal);
+ }
+ // select rev(C), TValSplat, rev(Y) --> rev(select C, TValSplat, Y)
+ else if (isSplatValue(TVal) && match(FVal, m_VecReverse(m_Value(Y))) &&
+ (Cond->hasOneUse() || FVal->hasOneUse()))
+ return createSelReverse(C, TVal, Y);
+ }
+
auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType());
if (!VecTy)
return nullptr;
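
The rev(select) rewrites above rest on the fact that a lanewise select commutes with reversing all of its operands. A small C++ sketch of that identity on four-element arrays (illustrative only, not part of the patch):

#include <algorithm>
#include <array>
#include <cassert>

int main() {
  std::array<bool, 4> C{true, false, true, true};
  std::array<int, 4> X{1, 2, 3, 4}, Y{5, 6, 7, 8};

  // Lanewise select, like an IR select with a vector condition.
  auto Sel = [](std::array<bool, 4> Cond, std::array<int, 4> A,
                std::array<int, 4> B) {
    std::array<int, 4> R{};
    for (int I = 0; I < 4; ++I)
      R[I] = Cond[I] ? A[I] : B[I];
    return R;
  };
  auto Rev = [](auto V) {
    std::reverse(V.begin(), V.end());
    return V;
  };

  // select rev(C), rev(X), rev(Y)  ==  rev(select C, X, Y)
  assert(Sel(Rev(C), Rev(X), Rev(Y)) == Rev(Sel(C, X, Y)));
}
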
@@ -2323,10 +2416,6 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
// A select of a "select shuffle" with a common operand can be rearranged
// to select followed by "select shuffle". Because of poison, this only works
// in the case of a shuffle with no undefined mask elements.
- Value *Cond = Sel.getCondition();
- Value *TVal = Sel.getTrueValue();
- Value *FVal = Sel.getFalseValue();
- Value *X, *Y;
ArrayRef<int> Mask;
if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) &&
!is_contained(Mask, UndefMaskElem) &&
@@ -2472,7 +2561,7 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op,
assert(Op->getType()->isIntOrIntVectorTy(1) &&
"Op must be either i1 or vector of i1.");
- Optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
+ std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd);
if (!Res)
return nullptr;
@@ -2510,6 +2599,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
InstCombinerImpl &IC) {
Value *CondVal = SI.getCondition();
+ bool ChangedFMF = false;
for (bool Swap : {false, true}) {
Value *TrueVal = SI.getTrueValue();
Value *X = SI.getFalseValue();
@@ -2534,13 +2624,33 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
}
}
+ if (!match(TrueVal, m_FNeg(m_Specific(X))))
+ return nullptr;
+
+ // Forward-propagate nnan and ninf from the fneg to the select.
+ // If all inputs are not those values, then the select is not either.
+ // Note: nsz is defined differently, so it may not be correct to propagate.
+ FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags();
+ if (FMF.noNaNs() && !SI.hasNoNaNs()) {
+ SI.setHasNoNaNs(true);
+ ChangedFMF = true;
+ }
+ if (FMF.noInfs() && !SI.hasNoInfs()) {
+ SI.setHasNoInfs(true);
+ ChangedFMF = true;
+ }
+
// With nsz, when 'Swap' is false:
// fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X)
// fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x)
// when 'Swap' is true:
// fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X)
// fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X)
- if (!match(TrueVal, m_FNeg(m_Specific(X))) || !SI.hasNoSignedZeros())
+ //
+ // Note: We require "nnan" for this fold because fcmp ignores the signbit
+ // of NAN, but IEEE-754 specifies the signbit of NAN values with
+ // fneg/fabs operations.
+ if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs())
return nullptr;
if (Swap)
@@ -2563,7 +2673,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI,
}
}
- return nullptr;
+ return ChangedFMF ? &SI : nullptr;
}
// Match the following IR pattern:
@@ -2602,10 +2712,14 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst))))
return nullptr;
+ // Match even if the AND and ADD are swapped.
const APInt *BiasCst, *HighBitMaskCst;
if (!match(XBiasedHighBits,
m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)),
- m_APIntAllowUndef(HighBitMaskCst))))
+ m_APIntAllowUndef(HighBitMaskCst))) &&
+ !match(XBiasedHighBits,
+ m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)),
+ m_APIntAllowUndef(BiasCst))))
return nullptr;
if (!LowBitMaskCst->isMask())
@@ -2635,200 +2749,392 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI,
return R;
}
-Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
+namespace {
+struct DecomposedSelect {
+ Value *Cond = nullptr;
+ Value *TrueVal = nullptr;
+ Value *FalseVal = nullptr;
+};
+} // namespace
+
+/// Look for patterns like
+/// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false
+/// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f
+/// %outer.sel = select i1 %outer.cond, i8 %outer.sel.t, i8 %inner.sel
+/// and rewrite it as
+/// %inner.sel = select i1 %cond.alternative, i8 %sel.outer.t, i8 %sel.inner.t
+/// %sel.outer = select i1 %cond.inner, i8 %inner.sel, i8 %sel.inner.f
+static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
+ InstCombiner::BuilderTy &Builder) {
+ // We must start with a `select`.
+ DecomposedSelect OuterSel;
+ match(&OuterSelVal,
+ m_Select(m_Value(OuterSel.Cond), m_Value(OuterSel.TrueVal),
+ m_Value(OuterSel.FalseVal)));
+
+ // Canonicalize inversion of the outermost `select`'s condition.
+ if (match(OuterSel.Cond, m_Not(m_Value(OuterSel.Cond))))
+ std::swap(OuterSel.TrueVal, OuterSel.FalseVal);
+
+ // The condition of the outermost select must be an `and`/`or`.
+ if (!match(OuterSel.Cond, m_c_LogicalOp(m_Value(), m_Value())))
+ return nullptr;
+
+ // Depending on the logical op, inner select might be in different hand.
+ bool IsAndVariant = match(OuterSel.Cond, m_LogicalAnd());
+ Value *InnerSelVal = IsAndVariant ? OuterSel.FalseVal : OuterSel.TrueVal;
+
+ // Profitability check - avoid increasing instruction count.
+ if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}),
+ [](Value *V) { return V->hasOneUse(); }))
+ return nullptr;
+
+ // The appropriate hand of the outermost `select` must be a select itself.
+ DecomposedSelect InnerSel;
+ if (!match(InnerSelVal,
+ m_Select(m_Value(InnerSel.Cond), m_Value(InnerSel.TrueVal),
+ m_Value(InnerSel.FalseVal))))
+ return nullptr;
+
+ // Canonicalize inversion of the innermost `select`'s condition.
+ if (match(InnerSel.Cond, m_Not(m_Value(InnerSel.Cond))))
+ std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
+
+ Value *AltCond = nullptr;
+ auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) {
+ return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond)));
+ };
+
+ // Finally, match the condition that was driving the outermost `select`,
+ // it should be a logical operation between the condition that was driving
+ // the innermost `select` (after accounting for the possible inversions
+ // of the condition), and some other condition.
+ if (matchOuterCond(m_Specific(InnerSel.Cond))) {
+ // Done!
+ } else if (Value * NotInnerCond; matchOuterCond(m_CombineAnd(
+ m_Not(m_Specific(InnerSel.Cond)), m_Value(NotInnerCond)))) {
+ // Done!
+ std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
+ InnerSel.Cond = NotInnerCond;
+ } else // Not the pattern we were looking for.
+ return nullptr;
+
+ Value *SelInner = Builder.CreateSelect(
+ AltCond, IsAndVariant ? OuterSel.TrueVal : InnerSel.FalseVal,
+ IsAndVariant ? InnerSel.TrueVal : OuterSel.FalseVal);
+ SelInner->takeName(InnerSelVal);
+ return SelectInst::Create(InnerSel.Cond,
+ IsAndVariant ? SelInner : InnerSel.TrueVal,
+ !IsAndVariant ? SelInner : InnerSel.FalseVal);
+}
+
+Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) {
Value *CondVal = SI.getCondition();
Value *TrueVal = SI.getTrueValue();
Value *FalseVal = SI.getFalseValue();
Type *SelType = SI.getType();
- if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal,
- SQ.getWithInstruction(&SI)))
- return replaceInstUsesWith(SI, V);
-
- if (Instruction *I = canonicalizeSelectToShuffle(SI))
- return I;
-
- if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
- return I;
-
// Avoid potential infinite loops by checking for non-constant condition.
// TODO: Can we assert instead by improving canonicalizeSelectToShuffle()?
// Scalar select must have simplified?
- if (SelType->isIntOrIntVectorTy(1) && !isa<Constant>(CondVal) &&
- TrueVal->getType() == CondVal->getType()) {
- // Folding select to and/or i1 isn't poison safe in general. impliesPoison
- // checks whether folding it does not convert a well-defined value into
- // poison.
- if (match(TrueVal, m_One())) {
- if (impliesPoison(FalseVal, CondVal)) {
- // Change: A = select B, true, C --> A = or B, C
- return BinaryOperator::CreateOr(CondVal, FalseVal);
- }
+ if (!SelType->isIntOrIntVectorTy(1) || isa<Constant>(CondVal) ||
+ TrueVal->getType() != CondVal->getType())
+ return nullptr;
- if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
- if (auto *RHS = dyn_cast<FCmpInst>(FalseVal))
- if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false,
- /*IsSelectLogical*/ true))
- return replaceInstUsesWith(SI, V);
- }
- if (match(FalseVal, m_Zero())) {
- if (impliesPoison(TrueVal, CondVal)) {
- // Change: A = select B, C, false --> A = and B, C
- return BinaryOperator::CreateAnd(CondVal, TrueVal);
- }
+ auto *One = ConstantInt::getTrue(SelType);
+ auto *Zero = ConstantInt::getFalse(SelType);
+ Value *A, *B, *C, *D;
- if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
- if (auto *RHS = dyn_cast<FCmpInst>(TrueVal))
- if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true,
- /*IsSelectLogical*/ true))
- return replaceInstUsesWith(SI, V);
+ // Folding select to and/or i1 isn't poison safe in general. impliesPoison
+ // checks whether folding it does not convert a well-defined value into
+ // poison.
+ if (match(TrueVal, m_One())) {
+ if (impliesPoison(FalseVal, CondVal)) {
+ // Change: A = select B, true, C --> A = or B, C
+ return BinaryOperator::CreateOr(CondVal, FalseVal);
}
- auto *One = ConstantInt::getTrue(SelType);
- auto *Zero = ConstantInt::getFalse(SelType);
+ if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
+ if (auto *RHS = dyn_cast<FCmpInst>(FalseVal))
+ if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false,
+ /*IsSelectLogical*/ true))
+ return replaceInstUsesWith(SI, V);
- // We match the "full" 0 or 1 constant here to avoid a potential infinite
- // loop with vectors that may have undefined/poison elements.
- // select a, false, b -> select !a, b, false
- if (match(TrueVal, m_Specific(Zero))) {
- Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return SelectInst::Create(NotCond, FalseVal, Zero);
+ // (A && B) || (C && B) --> (A || C) && B
+ if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) &&
+ match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) &&
+ (CondVal->hasOneUse() || FalseVal->hasOneUse())) {
+ bool CondLogicAnd = isa<SelectInst>(CondVal);
+ bool FalseLogicAnd = isa<SelectInst>(FalseVal);
+ auto AndFactorization = [&](Value *Common, Value *InnerCond,
+ Value *InnerVal,
+ bool SelFirst = false) -> Instruction * {
+ Value *InnerSel = Builder.CreateSelect(InnerCond, One, InnerVal);
+ if (SelFirst)
+ std::swap(Common, InnerSel);
+ if (FalseLogicAnd || (CondLogicAnd && Common == A))
+ return SelectInst::Create(Common, InnerSel, Zero);
+ else
+ return BinaryOperator::CreateAnd(Common, InnerSel);
+ };
+
+ if (A == C)
+ return AndFactorization(A, B, D);
+ if (A == D)
+ return AndFactorization(A, B, C);
+ if (B == C)
+ return AndFactorization(B, A, D);
+ if (B == D)
+ return AndFactorization(B, A, C, CondLogicAnd && FalseLogicAnd);
}
- // select a, b, true -> select !a, true, b
- if (match(FalseVal, m_Specific(One))) {
- Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
- return SelectInst::Create(NotCond, One, TrueVal);
+ }
+
+ if (match(FalseVal, m_Zero())) {
+ if (impliesPoison(TrueVal, CondVal)) {
+ // Change: A = select B, C, false --> A = and B, C
+ return BinaryOperator::CreateAnd(CondVal, TrueVal);
+ }
+
+ if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
+ if (auto *RHS = dyn_cast<FCmpInst>(TrueVal))
+ if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true,
+ /*IsSelectLogical*/ true))
+ return replaceInstUsesWith(SI, V);
+
+ // (A || B) && (C || B) --> (A && C) || B
+ if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
+ match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) &&
+ (CondVal->hasOneUse() || TrueVal->hasOneUse())) {
+ bool CondLogicOr = isa<SelectInst>(CondVal);
+ bool TrueLogicOr = isa<SelectInst>(TrueVal);
+ auto OrFactorization = [&](Value *Common, Value *InnerCond,
+ Value *InnerVal,
+ bool SelFirst = false) -> Instruction * {
+ Value *InnerSel = Builder.CreateSelect(InnerCond, InnerVal, Zero);
+ if (SelFirst)
+ std::swap(Common, InnerSel);
+ if (TrueLogicOr || (CondLogicOr && Common == A))
+ return SelectInst::Create(Common, One, InnerSel);
+ else
+ return BinaryOperator::CreateOr(Common, InnerSel);
+ };
+
+ if (A == C)
+ return OrFactorization(A, B, D);
+ if (A == D)
+ return OrFactorization(A, B, C);
+ if (B == C)
+ return OrFactorization(B, A, D);
+ if (B == D)
+ return OrFactorization(B, A, C, CondLogicOr && TrueLogicOr);
}
+ }
+
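Both factorizations introduced above are the ordinary distributive laws for booleans; the select forms are only there to keep the transform poison-safe. A quick C++ truth-table check (sketch; it ignores the poison aspects):

#include <cassert>

int main() {
  for (bool A : {false, true})
    for (bool B : {false, true})
      for (bool C : {false, true}) {
        // (A && B) || (C && B)  -->  (A || C) && B
        assert(((A && B) || (C && B)) == ((A || C) && B));
        // (A || B) && (C || B)  -->  (A && C) || B
        assert(((A || B) && (C || B)) == ((A && C) || B));
      }
}
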
+ // We match the "full" 0 or 1 constant here to avoid a potential infinite
+ // loop with vectors that may have undefined/poison elements.
+ // select a, false, b -> select !a, b, false
+ if (match(TrueVal, m_Specific(Zero))) {
+ Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
+ return SelectInst::Create(NotCond, FalseVal, Zero);
+ }
+ // select a, b, true -> select !a, true, b
+ if (match(FalseVal, m_Specific(One))) {
+ Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName());
+ return SelectInst::Create(NotCond, One, TrueVal);
+ }
+
+ // DeMorgan in select form: !a && !b --> !(a || b)
+ // select !a, !b, false --> not (select a, true, b)
+ if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
+ (CondVal->hasOneUse() || TrueVal->hasOneUse()) &&
+ !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
+ return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B));
- // select a, a, b -> select a, true, b
- if (CondVal == TrueVal)
- return replaceOperand(SI, 1, One);
- // select a, b, a -> select a, b, false
- if (CondVal == FalseVal)
- return replaceOperand(SI, 2, Zero);
+ // DeMorgan in select form: !a || !b --> !(a && b)
+ // select !a, true, !b --> not (select a, b, false)
+ if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
+ (CondVal->hasOneUse() || FalseVal->hasOneUse()) &&
+ !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
+ return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero));
- // select a, !a, b -> select !a, b, false
- if (match(TrueVal, m_Not(m_Specific(CondVal))))
- return SelectInst::Create(TrueVal, FalseVal, Zero);
- // select a, b, !a -> select !a, true, b
- if (match(FalseVal, m_Not(m_Specific(CondVal))))
- return SelectInst::Create(FalseVal, One, TrueVal);
+ // select (select a, true, b), true, b -> select a, true, b
+ if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
+ match(TrueVal, m_One()) && match(FalseVal, m_Specific(B)))
+ return replaceOperand(SI, 0, A);
+ // select (select a, b, false), b, false -> select a, b, false
+ if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
+ match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
+ return replaceOperand(SI, 0, A);
- Value *A, *B;
+ // ~(A & B) & (A | B) --> A ^ B
+ if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))),
+ m_c_LogicalOr(m_Deferred(A), m_Deferred(B)))))
+ return BinaryOperator::CreateXor(A, B);
- // DeMorgan in select form: !a && !b --> !(a || b)
- // select !a, !b, false --> not (select a, true, b)
- if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
- (CondVal->hasOneUse() || TrueVal->hasOneUse()) &&
- !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
- return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B));
+ // select (~a | c), a, b -> and a, (or c, freeze(b))
+ if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) &&
+ CondVal->hasOneUse()) {
+ FalseVal = Builder.CreateFreeze(FalseVal);
+ return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal));
+ }
+ // select (~c & b), a, b -> and b, (or freeze(a), c)
+ if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) &&
+ CondVal->hasOneUse()) {
+ TrueVal = Builder.CreateFreeze(TrueVal);
+ return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
+ }
+
+ if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
+ Use *Y = nullptr;
+ bool IsAnd = match(FalseVal, m_Zero()) ? true : false;
+ Value *Op1 = IsAnd ? TrueVal : FalseVal;
+ if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
+ auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
+ InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
+ replaceUse(*Y, FI);
+ return replaceInstUsesWith(SI, Op1);
+ }
+
+ if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
+ if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
+ /* IsAnd */ IsAnd))
+ return I;
- // DeMorgan in select form: !a || !b --> !(a && b)
- // select !a, true, !b --> not (select a, b, false)
- if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) &&
- (CondVal->hasOneUse() || FalseVal->hasOneUse()) &&
- !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr()))
- return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero));
+ if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
+ if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
+ if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
+ /* IsLogical */ true))
+ return replaceInstUsesWith(SI, V);
+ }
- // select (select a, true, b), true, b -> select a, true, b
- if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
- match(TrueVal, m_One()) && match(FalseVal, m_Specific(B)))
+ // select (a || b), c, false -> select a, c, false
+ // select c, (a || b), false -> select c, a, false
+ // if c implies that b is false.
+ if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
+ match(FalseVal, m_Zero())) {
+ std::optional<bool> Res = isImpliedCondition(TrueVal, B, DL);
+ if (Res && *Res == false)
return replaceOperand(SI, 0, A);
- // select (select a, b, false), b, false -> select a, b, false
- if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
- match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero()))
+ }
+ if (match(TrueVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
+ match(FalseVal, m_Zero())) {
+ std::optional<bool> Res = isImpliedCondition(CondVal, B, DL);
+ if (Res && *Res == false)
+ return replaceOperand(SI, 1, A);
+ }
+ // select c, true, (a && b) -> select c, true, a
+ // select (a && b), true, c -> select a, true, c
+ // if c = false implies that b = true
+ if (match(TrueVal, m_One()) &&
+ match(FalseVal, m_LogicalAnd(m_Value(A), m_Value(B)))) {
+ std::optional<bool> Res = isImpliedCondition(CondVal, B, DL, false);
+ if (Res && *Res == true)
+ return replaceOperand(SI, 2, A);
+ }
+ if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) &&
+ match(TrueVal, m_One())) {
+ std::optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false);
+ if (Res && *Res == true)
return replaceOperand(SI, 0, A);
+ }
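
The four implied-condition folds above share one observation: when one condition being true (or false) forces B to a known value, the B term drops out of the logical and/or. A concrete C++ sketch where C being true implies B is false (the comparisons are made up for illustration):

#include <cassert>

int main() {
  for (int X = -20; X <= 20; ++X)
    for (bool A : {false, true}) {
      bool B = X > 10; // B
      bool C = X < 5;  // C true implies B false
      // select (A || B), C, false  -->  select A, C, false
      assert(((A || B) && C) == (A && C));
    }
}
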
+ if (match(TrueVal, m_One())) {
Value *C;
- // select (~a | c), a, b -> and a, (or c, freeze(b))
- if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) &&
- CondVal->hasOneUse()) {
- FalseVal = Builder.CreateFreeze(FalseVal);
- return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal));
- }
- // select (~c & b), a, b -> and b, (or freeze(a), c)
- if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) &&
- CondVal->hasOneUse()) {
- TrueVal = Builder.CreateFreeze(TrueVal);
- return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal));
+
+ // (C && A) || (!C && B) --> sel C, A, B
+ // (A && C) || (!C && B) --> sel C, A, B
+ // (C && A) || (B && !C) --> sel C, A, B
+ // (A && C) || (B && !C) --> sel C, A, B (may require freeze)
+ if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(B))) &&
+ match(CondVal, m_c_LogicalAnd(m_Specific(C), m_Value(A)))) {
+ auto *SelCond = dyn_cast<SelectInst>(CondVal);
+ auto *SelFVal = dyn_cast<SelectInst>(FalseVal);
+ bool MayNeedFreeze = SelCond && SelFVal &&
+ match(SelFVal->getTrueValue(),
+ m_Not(m_Specific(SelCond->getTrueValue())));
+ if (MayNeedFreeze)
+ C = Builder.CreateFreeze(C);
+ return SelectInst::Create(C, A, B);
}
- if (!SelType->isVectorTy()) {
- if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 1, S);
- if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ,
- /* AllowRefinement */ true))
- return replaceOperand(SI, 2, S);
+ // (!C && A) || (C && B) --> sel C, B, A
+ // (A && !C) || (C && B) --> sel C, B, A
+ // (!C && A) || (B && C) --> sel C, B, A
+ // (A && !C) || (B && C) --> sel C, B, A (may require freeze)
+ if (match(CondVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(A))) &&
+ match(FalseVal, m_c_LogicalAnd(m_Specific(C), m_Value(B)))) {
+ auto *SelCond = dyn_cast<SelectInst>(CondVal);
+ auto *SelFVal = dyn_cast<SelectInst>(FalseVal);
+ bool MayNeedFreeze = SelCond && SelFVal &&
+ match(SelCond->getTrueValue(),
+ m_Not(m_Specific(SelFVal->getTrueValue())));
+ if (MayNeedFreeze)
+ C = Builder.CreateFreeze(C);
+ return SelectInst::Create(C, B, A);
}
+ }
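
The two blocks above recognize a select spelled as a logical or of two guarded arms; stripped of the freeze/poison handling, the core identity is an ordinary mux. C++ truth-table sketch (illustrative only):

#include <cassert>

int main() {
  for (bool A : {false, true})
    for (bool B : {false, true})
      for (bool C : {false, true}) {
        // (C && A) || (!C && B)  -->  select C, A, B
        assert(((C && A) || (!C && B)) == (C ? A : B));
        // (!C && A) || (C && B)  -->  select C, B, A
        assert(((!C && A) || (C && B)) == (C ? B : A));
      }
}
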
- if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) {
- Use *Y = nullptr;
- bool IsAnd = match(FalseVal, m_Zero()) ? true : false;
- Value *Op1 = IsAnd ? TrueVal : FalseVal;
- if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) {
- auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr");
- InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser()));
- replaceUse(*Y, FI);
- return replaceInstUsesWith(SI, Op1);
- }
+ return nullptr;
+}
- if (auto *Op1SI = dyn_cast<SelectInst>(Op1))
- if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI,
- /* IsAnd */ IsAnd))
- return I;
+Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
+ Value *CondVal = SI.getCondition();
+ Value *TrueVal = SI.getTrueValue();
+ Value *FalseVal = SI.getFalseValue();
+ Type *SelType = SI.getType();
- if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
- if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
- if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd,
- /* IsLogical */ true))
- return replaceInstUsesWith(SI, V);
- }
+ if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal,
+ SQ.getWithInstruction(&SI)))
+ return replaceInstUsesWith(SI, V);
- // select (select a, true, b), c, false -> select a, c, false
- // select c, (select a, true, b), false -> select c, a, false
- // if c implies that b is false.
- if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
- match(FalseVal, m_Zero())) {
- Optional<bool> Res = isImpliedCondition(TrueVal, B, DL);
- if (Res && *Res == false)
- return replaceOperand(SI, 0, A);
- }
- if (match(TrueVal, m_Select(m_Value(A), m_One(), m_Value(B))) &&
- match(FalseVal, m_Zero())) {
- Optional<bool> Res = isImpliedCondition(CondVal, B, DL);
- if (Res && *Res == false)
- return replaceOperand(SI, 1, A);
- }
- // select c, true, (select a, b, false) -> select c, true, a
- // select (select a, b, false), true, c -> select a, true, c
- // if c = false implies that b = true
- if (match(TrueVal, m_One()) &&
- match(FalseVal, m_Select(m_Value(A), m_Value(B), m_Zero()))) {
- Optional<bool> Res = isImpliedCondition(CondVal, B, DL, false);
- if (Res && *Res == true)
- return replaceOperand(SI, 2, A);
- }
- if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) &&
- match(TrueVal, m_One())) {
- Optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false);
- if (Res && *Res == true)
- return replaceOperand(SI, 0, A);
- }
+ if (Instruction *I = canonicalizeSelectToShuffle(SI))
+ return I;
- // sel (sel c, a, false), true, (sel !c, b, false) -> sel c, a, b
- // sel (sel !c, a, false), true, (sel c, b, false) -> sel c, b, a
- Value *C1, *C2;
- if (match(CondVal, m_Select(m_Value(C1), m_Value(A), m_Zero())) &&
- match(TrueVal, m_One()) &&
- match(FalseVal, m_Select(m_Value(C2), m_Value(B), m_Zero()))) {
- if (match(C2, m_Not(m_Specific(C1)))) // first case
- return SelectInst::Create(C1, A, B);
- else if (match(C1, m_Not(m_Specific(C2)))) // second case
- return SelectInst::Create(C2, B, A);
- }
+ if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this))
+ return I;
+
+  // If the type of the select is not an integer type, or if the condition and
+  // the select type are not both scalar or both vector types, there is no
+  // point in attempting to match these patterns.
+ Type *CondType = CondVal->getType();
+ if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() &&
+ CondType->isVectorTy() == SelType->isVectorTy()) {
+ if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal,
+ ConstantInt::getTrue(CondType), SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(SI, 1, S);
+
+ if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal,
+ ConstantInt::getFalse(CondType), SQ,
+ /* AllowRefinement */ true))
+ return replaceOperand(SI, 2, S);
+
+ // Handle patterns involving sext/zext + not explicitly,
+ // as simplifyWithOpReplaced() only looks past one instruction.
+ Value *NotCond;
+
+ // select a, sext(!a), b -> select !a, b, 0
+ // select a, zext(!a), b -> select !a, b, 0
+ if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond),
+ m_Not(m_Specific(CondVal))))))
+ return SelectInst::Create(NotCond, FalseVal,
+ Constant::getNullValue(SelType));
+
+ // select a, b, zext(!a) -> select !a, 1, b
+ if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond),
+ m_Not(m_Specific(CondVal))))))
+ return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal);
+
+ // select a, b, sext(!a) -> select !a, -1, b
+ if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond),
+ m_Not(m_Specific(CondVal))))))
+ return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType),
+ TrueVal);
}
+ if (Instruction *R = foldSelectOfBools(SI))
+ return R;
+
// Selecting between two integer or vector splat integer constants?
//
// Note that we don't handle a scalar select of vectors:
@@ -2881,8 +3187,23 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal);
return replaceInstUsesWith(SI, NewSel);
}
+ }
+ }
+
+ if (isa<FPMathOperator>(SI)) {
+ // TODO: Try to forward-propagate FMF from select arms to the select.
+
+ // Canonicalize select of FP values where NaN and -0.0 are not valid as
+ // minnum/maxnum intrinsics.
+ if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
+ Value *X, *Y;
+ if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
+ return replaceInstUsesWith(
+ SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI));
- // NOTE: if we wanted to, this is where to detect MIN/MAX
+ if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y))))
+ return replaceInstUsesWith(
+ SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI));
}
}
@@ -2997,19 +3318,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
- // Canonicalize select of FP values where NaN and -0.0 are not valid as
- // minnum/maxnum intrinsics.
- if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
- Value *X, *Y;
- if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
- return replaceInstUsesWith(
- SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI));
-
- if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y))))
- return replaceInstUsesWith(
- SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI));
- }
-
// See if we can fold the select into a phi node if the condition is a select.
if (auto *PN = dyn_cast<PHINode>(SI.getCondition()))
// The true/false values have to be live in the PHI predecessor's blocks.
@@ -3198,5 +3506,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
}
}
+ if (Instruction *I = foldNestedSelects(SI, Builder))
+ return I;
+
+ // Match logical variants of the pattern,
+ // and transform them iff that gets rid of inversions.
+ // (~x) | y --> ~(x & (~y))
+ // (~x) & y --> ~(x | (~y))
+ if (sinkNotIntoOtherHandOfLogicalOp(SI))
+ return &SI;
+
return nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 13c98b935adf..ec505381cc86 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -346,8 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
Value *X, *Y;
auto matchFirstShift = [&](Value *V) {
APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits());
- return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) &&
- match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) &&
+ return match(V,
+ m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) &&
match(ConstantExpr::getAdd(C0, C1),
m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold));
};
@@ -363,7 +363,7 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I,
// shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1)
Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1);
Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC);
- Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1));
+ Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1);
return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2);
}
@@ -730,13 +730,34 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
return BinaryOperator::Create(
I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X);
+ bool IsLeftShift = I.getOpcode() == Instruction::Shl;
+ Type *Ty = I.getType();
+ unsigned TypeBits = Ty->getScalarSizeInBits();
+
+ // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC)
+ // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC)
+ const APInt *DivC;
+ if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) &&
+ match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() &&
+ !DivC->isMinSignedValue()) {
+ Constant *NegDivC = ConstantInt::get(Ty, -(*DivC));
+ ICmpInst::Predicate Pred =
+ DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE;
+ Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC);
+ auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt
+ : Instruction::ZExt;
+ return CastInst::Create(ExtOpcode, Cmp, Ty);
+ }
+
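The new sdiv fold encodes the fact that, for a positive divisor, a truncating quotient is negative exactly when the dividend is at most the negated divisor (and symmetrically for a negative divisor). An exhaustive i8 check in C++ (a sketch, not part of the patch; the sdiv-overflow case is skipped just as the fold rejects the excluded divisors):

#include <cassert>
#include <cstdint>

int main() {
  for (int DivC = -127; DivC <= 127; ++DivC) {
    if (DivC == 0)
      continue; // DivC == 0 and INT8_MIN are rejected by the fold
    for (int X = -128; X <= 127; ++X) {
      if (X == -128 && DivC == -1)
        continue; // sdiv overflow is poison in IR, so it is irrelevant here
      // Sign bit of the truncating quotient, as the shift would compute it.
      unsigned SignBit = (uint8_t(int8_t(X / DivC)) >> 7) & 1u;
      // (X / +DivC) >> 7  -->  zext (X <= -DivC)
      // (X / -DivC) >> 7  -->  zext (X >= +DivC), i.e. X >= -DivC below
      unsigned Cmp = DivC < 0 ? unsigned(X >= -DivC) : unsigned(X <= -DivC);
      assert(SignBit == Cmp);
    }
  }
}
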
const APInt *Op1C;
if (!match(C1, m_APInt(Op1C)))
return nullptr;
+ assert(!Op1C->uge(TypeBits) &&
+ "Shift over the type width should have been removed already");
+
// See if we can propagate this shift into the input, this covers the trivial
// cast of lshr(shl(x,c1),c2) as well as other more complex cases.
- bool IsLeftShift = I.getOpcode() == Instruction::Shl;
if (I.getOpcode() != Instruction::AShr &&
canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) {
LLVM_DEBUG(
@@ -748,14 +769,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL));
}
- // See if we can simplify any instructions used by the instruction whose sole
- // purpose is to compute bits we don't care about.
- Type *Ty = I.getType();
- unsigned TypeBits = Ty->getScalarSizeInBits();
- assert(!Op1C->uge(TypeBits) &&
- "Shift over the type width should have been removed already");
- (void)TypeBits;
-
if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I))
return FoldedShift;
@@ -826,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1,
return nullptr;
}
+// Tries to perform
+// (lshr (add (zext X), (zext Y)), K)
+// -> (icmp ult (add X, Y), X)
+// where
+// - The add's operands are zexts from a K-bits integer to a bigger type.
+// - The add is only used by the shr, or by iK (or narrower) truncates.
+// - The lshr type has more than 2 bits (other types are boolean math).
+// - K > 1
+// note that
+// - The resulting add cannot have nuw/nsw, else on overflow we get a
+// poison value and the transform isn't legal anymore.
+Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
+ assert(I.getOpcode() == Instruction::LShr);
+
+ Value *Add = I.getOperand(0);
+ Value *ShiftAmt = I.getOperand(1);
+ Type *Ty = I.getType();
+
+ if (Ty->getScalarSizeInBits() < 3)
+ return nullptr;
+
+ const APInt *ShAmtAPInt = nullptr;
+ Value *X = nullptr, *Y = nullptr;
+ if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) ||
+ !match(Add,
+ m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y))))))
+ return nullptr;
+
+ const unsigned ShAmt = ShAmtAPInt->getZExtValue();
+ if (ShAmt == 1)
+ return nullptr;
+
+ // X/Y are zexts from `ShAmt`-sized ints.
+ if (X->getType()->getScalarSizeInBits() != ShAmt ||
+ Y->getType()->getScalarSizeInBits() != ShAmt)
+ return nullptr;
+
+ // Make sure that `Add` is only used by `I` and `ShAmt`-truncates.
+ if (!Add->hasOneUse()) {
+ for (User *U : Add->users()) {
+ if (U == &I)
+ continue;
+
+ TruncInst *Trunc = dyn_cast<TruncInst>(U);
+ if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt)
+ return nullptr;
+ }
+ }
+
+  // Insert at Add so that the newly created `NarrowAdd` will dominate its
+  // users (i.e. `Add`'s users).
+ Instruction *AddInst = cast<Instruction>(Add);
+ Builder.SetInsertPoint(AddInst);
+
+ Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed");
+ Value *Overflow =
+ Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow");
+
+ // Replace the uses of the original add with a zext of the
+ // NarrowAdd's result. Note that all users at this stage are known to
+ // be ShAmt-sized truncs, or the lshr itself.
+ if (!Add->hasOneUse())
+ replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty));
+
+ // Replace the LShr with a zext of the overflow check.
+ return new ZExtInst(Overflow, Ty);
+}
+
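The transform above is the usual carry-out trick: the high half of a widened add is nonzero exactly when the narrow add wraps, which is detectable as (X + Y) < X. Exhaustive i8 check in C++ (illustrative only, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X <= 0xFF; ++X)
    for (unsigned Y = 0; Y <= 0xFF; ++Y) {
      // lshr (add (zext i8 X to i16), (zext i8 Y to i16)), 8
      uint16_t Wide = uint16_t(X + Y);
      uint16_t HighBit = uint16_t(Wide >> 8);
      // zext (icmp ult (add i8 X, Y), X) to i16
      uint8_t Narrow = uint8_t(X + Y);
      uint16_t Overflow = uint16_t(Narrow < X ? 1 : 0);
      assert(HighBit == Overflow);
    }
}
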
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -1046,11 +1127,21 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
}
}
- // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
- if (match(Op0, m_One()) &&
- match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
- return BinaryOperator::CreateLShr(
- ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+ if (match(Op0, m_One())) {
+ // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
+ if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X))))
+ return BinaryOperator::CreateLShr(
+ ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X);
+
+ // The only way to shift out the 1 is with an over-shift, so that would
+ // be poison with or without "nuw". Undef is excluded because (undef << X)
+ // is not undef (it is zero).
+ Constant *ConstantOne = cast<Constant>(Op0);
+ if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) {
+ I.setHasNoUnsignedWrap();
+ return &I;
+ }
+ }
return nullptr;
}
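
The (1 << (C - x)) rewrite earlier in this hunk relies on the shift identity below; a quick C++ check for i32 (illustrative only):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned X = 0; X < 32; ++X)
    // (1 << (31 - X))  -->  (signmask >> X), for in-range shift amounts
    assert((uint32_t(1) << (31 - X)) == (uint32_t(0x80000000u) >> X));
}
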
@@ -1068,10 +1159,17 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
Type *Ty = I.getType();
+ Value *X;
const APInt *C;
+ unsigned BitWidth = Ty->getScalarSizeInBits();
+
+ // (iN (~X) u>> (N - 1)) --> zext (X > -1)
+ if (match(Op0, m_OneUse(m_Not(m_Value(X)))) &&
+ match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)))
+ return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty);
+
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();
- unsigned BitWidth = Ty->getScalarSizeInBits();
auto *II = dyn_cast<IntrinsicInst>(Op0);
if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC &&
(II->getIntrinsicID() == Intrinsic::ctlz ||
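
The new (~X) u>> (N - 1) fold in this hunk uses the fact that the top bit of ~X is the complement of X's sign bit. Exhaustive i8 check in C++ (sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  for (int X = -128; X <= 127; ++X)
    // (i8 (~X) u>> 7)  -->  zext (X > -1)
    assert((uint8_t(~X) >> 7) == unsigned(X > -1));
}
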
@@ -1276,6 +1374,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
}
}
+ // Reduce add-carry of bools to logic:
+ // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY)
+ Value *BoolX, *BoolY;
+ if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) &&
+ match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) &&
+ BoolX->getType()->isIntOrIntVectorTy(1) &&
+ BoolY->getType()->isIntOrIntVectorTy(1) &&
+ (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) {
+ Value *And = Builder.CreateAnd(BoolX, BoolY);
+ return new ZExtInst(And, Ty);
+ }
+
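The add-carry-of-bools reduction above is a two-line truth table: the sum of two 0/1 values reaches bit 1 only when both are 1. C++ sketch (illustrative only):

#include <cassert>

int main() {
  for (int A = 0; A <= 1; ++A)
    for (int B = 0; B <= 1; ++B)
      // ((zext i1 A) + (zext i1 B)) >> 1  -->  zext (A && B)
      assert(((A + B) >> 1) == (A && B));
}
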
// If the shifted-out value is known-zero, then this is an exact shift.
if (!I.isExact() &&
MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) {
@@ -1285,13 +1395,15 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) {
}
// Transform (x << y) >> y to x & (-1 >> y)
- Value *X;
if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) {
Constant *AllOnes = ConstantInt::getAllOnesValue(Ty);
Value *Mask = Builder.CreateLShr(AllOnes, Op1);
return BinaryOperator::CreateAnd(Mask, X);
}
+ if (Instruction *Overflow = foldLShrOverflowBit(I))
+ return Overflow;
+
return nullptr;
}
@@ -1469,8 +1581,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) {
return R;
// See if we can turn a signed shr into an unsigned shr.
- if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
- return BinaryOperator::CreateLShr(Op0, Op1);
+ if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) {
+ Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1);
+ Lshr->setIsExact(I.isExact());
+ return Lshr;
+ }
// ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1
if (match(Op0, m_OneUse(m_Not(m_Value(X))))) {
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index febd0f51d25f..77d675422966 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -130,9 +130,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == MaxAnalysisRecursionDepth)
return nullptr;
- if (isa<ScalableVectorType>(VTy))
- return nullptr;
-
Instruction *I = dyn_cast<Instruction>(V);
if (!I) {
computeKnownBits(V, Known, Depth, CxtI);
@@ -154,6 +151,20 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (Depth == 0 && !V->hasOneUse())
DemandedMask.setAllBits();
+ // Update flags after simplifying an operand based on the fact that some high
+ // order bits are not demanded.
+ auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I,
+ unsigned NLZ) {
+ if (NLZ > 0) {
+ // Disable the nsw and nuw flags here: We can no longer guarantee that
+ // we won't wrap after simplification. Removing the nsw/nuw flags is
+ // legal here because the top bit is not demanded.
+ I->setHasNoSignedWrap(false);
+ I->setHasNoUnsignedWrap(false);
+ }
+ return I;
+ };
+
// If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care
// about the high bits of the operands.
auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) {
@@ -165,13 +176,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) ||
ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) {
- if (NLZ > 0) {
- // Disable the nsw and nuw flags here: We can no longer guarantee that
- // we won't wrap after simplification. Removing the nsw/nuw flags is
- // legal here because the top bit is not demanded.
- I->setHasNoSignedWrap(false);
- I->setHasNoUnsignedWrap(false);
- }
+ disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
return true;
}
return false;
@@ -397,7 +402,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
}
}
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Instruction::ZExt: {
unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits();
@@ -416,7 +421,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (auto *DstVTy = dyn_cast<VectorType>(VTy)) {
if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) {
- if (cast<FixedVectorType>(DstVTy)->getNumElements() !=
+ if (isa<ScalableVectorType>(DstVTy) ||
+ isa<ScalableVectorType>(SrcVTy) ||
+ cast<FixedVectorType>(DstVTy)->getNumElements() !=
cast<FixedVectorType>(SrcVTy)->getNumElements())
// Don't touch a bitcast between vectors of different element counts.
return nullptr;
@@ -461,7 +468,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
break;
}
- case Instruction::Add:
+ case Instruction::Add: {
if ((DemandedMask & 1) == 0) {
// If we do not need the low bit, try to convert bool math to logic:
// add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN
@@ -498,26 +505,68 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
return Builder.CreateSExt(Or, VTy);
}
}
- LLVM_FALLTHROUGH;
+
+ // Right fill the mask of bits for the operands to demand the most
+ // significant bit and all those below it.
+ unsigned NLZ = DemandedMask.countLeadingZeros();
+ APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
+ if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
+ return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
+
+ // If low order bits are not demanded and known to be zero in one operand,
+ // then we don't need to demand them from the other operand, since they
+ // can't cause overflow into any bits that are demanded in the result.
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ APInt DemandedFromLHS = DemandedFromOps;
+ DemandedFromLHS.clearLowBits(NTZ);
+ if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
+ SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1))
+ return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
+
+ // If we are known to be adding zeros to every bit below
+ // the highest demanded bit, we just return the other side.
+ if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
+ return I->getOperand(0);
+ if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
+ return I->getOperand(1);
+
+ // Otherwise just compute the known bits of the result.
+ bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
+ Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown);
+ break;
+ }
case Instruction::Sub: {
- APInt DemandedFromOps;
- if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps))
- return I;
+ // Right fill the mask of bits for the operands to demand the most
+ // significant bit and all those below it.
+ unsigned NLZ = DemandedMask.countLeadingZeros();
+ APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
+ if (ShrinkDemandedConstant(I, 1, DemandedFromOps) ||
+ SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1))
+ return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
+
+ // If low order bits are not demanded and are known to be zero in RHS,
+ // then we don't need to demand them from LHS, since they can't cause a
+ // borrow from any bits that are demanded in the result.
+ unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes();
+ APInt DemandedFromLHS = DemandedFromOps;
+ DemandedFromLHS.clearLowBits(NTZ);
+ if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) ||
+ SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1))
+ return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ);
- // If we are known to be adding/subtracting zeros to every bit below
+ // If we are known to be subtracting zeros from every bit below
// the highest demanded bit, we just return the other side.
if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
// We can't do this with the LHS for subtraction, unless we are only
// demanding the LSB.
- if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) &&
- DemandedFromOps.isSubsetOf(LHSKnown.Zero))
+ if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero))
return I->getOperand(1);
// Otherwise just compute the known bits of the result.
bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap();
- Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add,
- NSW, LHSKnown, RHSKnown);
+ Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown);
break;
}
case Instruction::Mul: {
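
The rewritten Add/Sub handling in this hunk rests on a carry/borrow argument: operand bits that are known zero at and below the highest demanded bit can neither change the demanded bits directly nor propagate a carry (or borrow) into them. A small C++ illustration with a 0x0f demanded mask (sketch, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  const uint8_t Demanded = 0x0f; // highest demanded bit is bit 3
  for (unsigned L = 0; L <= 0xFF; ++L)
    for (unsigned R = 0; R <= 0xFF; R += 0x10) { // R is zero in bits [0..3]
      // Adding or subtracting R cannot affect the demanded low bits of L.
      assert(uint8_t((L + R) & Demanded) == uint8_t(L & Demanded));
      assert(uint8_t((L - R) & Demanded) == uint8_t(L & Demanded));
    }
}
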
@@ -747,18 +796,18 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// UDiv doesn't demand low bits that are zero in the divisor.
const APInt *SA;
if (match(I->getOperand(1), m_APInt(SA))) {
- // If the shift is exact, then it does demand the low bits.
- if (cast<UDivOperator>(I)->isExact())
- break;
-
- // FIXME: Take the demanded mask of the result into account.
+ // TODO: Take the demanded mask of the result into account.
unsigned RHSTrailingZeros = SA->countTrailingZeros();
APInt DemandedMaskIn =
APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros);
- if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) {
+        // We can't guarantee that "exact" is still true after changing
+        // the dividend.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
- // Propagate zero bits from the input.
+ // Increase high zero bits from the input.
Known.Zero.setHighBits(std::min(
BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros));
} else {
@@ -922,10 +971,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
}
default: {
// Handle target specific intrinsics
- Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
+ std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic(
*II, DemandedMask, Known, KnownBitsComputed);
if (V)
- return V.value();
+ return *V;
break;
}
}
@@ -962,11 +1011,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
// this instruction has a simpler value in that context.
switch (I->getOpcode()) {
case Instruction::And: {
- // If either the LHS or the RHS are Zero, the result is zero.
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
- computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
- CxtI);
-
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown & RHSKnown;
// If the client is only demanding bits that we know, return the known
@@ -975,8 +1021,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
return Constant::getIntegerValue(ITy, Known.One);
// If all of the demanded bits are known 1 on one side, return the other.
- // These bits cannot contribute to the result of the 'and' in this
- // context.
+ // These bits cannot contribute to the result of the 'and' in this context.
if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One))
@@ -985,14 +1030,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
break;
}
case Instruction::Or: {
- // We can simplify (X|Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
- // If either the LHS or the RHS are One, the result is One.
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
- computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
- CxtI);
-
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown | RHSKnown;
// If the client is only demanding bits that we know, return the known
@@ -1000,9 +1039,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
- // If all of the demanded bits are known zero on one side, return the
- // other. These bits cannot contribute to the result of the 'or' in this
- // context.
+ // We can simplify (X|Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+ // If all of the demanded bits are known zero on one side, return the other.
+ // These bits cannot contribute to the result of the 'or' in this context.
if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
@@ -1011,13 +1051,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
break;
}
case Instruction::Xor: {
- // We can simplify (X^Y) -> X or Y in the user's context if we know that
- // only bits from X or Y are demanded.
-
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
- computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1,
- CxtI);
-
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
Known = LHSKnown ^ RHSKnown;
// If the client is only demanding bits that we know, return the known
@@ -1025,8 +1060,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
if (DemandedMask.isSubsetOf(Known.Zero | Known.One))
return Constant::getIntegerValue(ITy, Known.One);
- // If all of the demanded bits are known zero on one side, return the
- // other.
+ // We can simplify (X^Y) -> X or Y in the user's context if we know that
+ // only bits from X or Y are demanded.
+ // If all of the demanded bits are known zero on one side, return the other.
if (DemandedMask.isSubsetOf(RHSKnown.Zero))
return I->getOperand(0);
if (DemandedMask.isSubsetOf(LHSKnown.Zero))
@@ -1034,6 +1070,34 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits(
break;
}
+ case Instruction::Add: {
+ unsigned NLZ = DemandedMask.countLeadingZeros();
+ APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
+
+ // If an operand adds zeros to every bit below the highest demanded bit,
+ // that operand doesn't change the result. Return the other side.
+ computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
+ if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
+ return I->getOperand(0);
+
+ computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI);
+ if (DemandedFromOps.isSubsetOf(LHSKnown.Zero))
+ return I->getOperand(1);
+
+ break;
+ }
+ case Instruction::Sub: {
+ unsigned NLZ = DemandedMask.countLeadingZeros();
+ APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ);
+
+ // If an operand subtracts zeros from every bit below the highest demanded
+ // bit, that operand doesn't change the result. Return the other side.
+ computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI);
+ if (DemandedFromOps.isSubsetOf(RHSKnown.Zero))
+ return I->getOperand(0);
+
+ break;
+ }
case Instruction::AShr: {
// Compute the Known bits to simplify things downstream.
computeKnownBits(I, Known, Depth, CxtI);
@@ -1632,11 +1696,11 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V,
}
default: {
// Handle target specific intrinsics
- Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
+ std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic(
*II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
simplifyAndSetOp);
if (V)
- return V.value();
+ return *V;
break;
}
} // switch on IntrinsicID
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index b80c58183dd5..61e62adbe327 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -105,7 +105,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
// 2) Possibly more ExtractElements with the same index.
// 3) Another operand, which will feed back into the PHI.
Instruction *PHIUser = nullptr;
- for (auto U : PN->users()) {
+ for (auto *U : PN->users()) {
if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) {
if (EI.getIndexOperand() == EU->getIndexOperand())
Extracts.push_back(EU);
@@ -171,7 +171,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI,
}
}
- for (auto E : Extracts)
+ for (auto *E : Extracts)
replaceInstUsesWith(*E, scalarPHI);
return &EI;
@@ -187,13 +187,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
ElementCount NumElts =
cast<VectorType>(Ext.getVectorOperandType())->getElementCount();
Type *DestTy = Ext.getType();
+ unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
bool IsBigEndian = DL.isBigEndian();
// If we are casting an integer to vector and extracting a portion, that is
// a shift-right and truncate.
- // TODO: Allow FP dest type by casting the trunc to FP?
- if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() &&
- isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) {
+ if (X->getType()->isIntegerTy()) {
assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) &&
"Expected fixed vector type for bitcast from scalar integer");
@@ -202,10 +201,18 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
// BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8
if (IsBigEndian)
ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC;
- unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits();
- if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) {
- Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
- return new TruncInst(Lshr, DestTy);
+ unsigned ShiftAmountC = ExtIndexC * DestWidth;
+ if (!ShiftAmountC ||
+ (isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) &&
+ Ext.getVectorOperand()->hasOneUse())) {
+ if (ShiftAmountC)
+ X = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset");
+ if (DestTy->isFloatingPointTy()) {
+ Type *DstIntTy = IntegerType::getIntNTy(X->getContext(), DestWidth);
+ Value *Trunc = Builder.CreateTrunc(X, DstIntTy);
+ return new BitCastInst(Trunc, DestTy);
+ }
+ return new TruncInst(X, DestTy);
}
}
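
A hypothetical example of the extended extract-from-bitcast fold (names invented; shown for a little-endian target), where a floating-point destination is now handled by truncating to an integer of the same width and then bitcasting:

    define float @extract_float(i64 %x) {
      %v = bitcast i64 %x to <2 x float>
      %e = extractelement <2 x float> %v, i64 0
      ret float %e
      ; expected on little endian, approximately:
      ;   %bits = trunc i64 %x to i32
      ;   %e    = bitcast i32 %bits to float
    }
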
@@ -278,7 +285,6 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) {
return nullptr;
unsigned SrcWidth = SrcTy->getScalarSizeInBits();
- unsigned DestWidth = DestTy->getPrimitiveSizeInBits();
unsigned ShAmt = Chunk * DestWidth;
// TODO: This limitation is more strict than necessary. We could sum the
@@ -393,6 +399,20 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
SQ.getWithInstruction(&EI)))
return replaceInstUsesWith(EI, V);
+ // extractelt (select %x, %vec1, %vec2), %const ->
+ // select %x, %vec1[%const], %vec2[%const]
+ // TODO: Support constant folding of multiple select operands:
+ // extractelt (select %x, %vec1, %vec2), (select %x, %c1, %c2)
+ // If the extractelement would, for instance, perform an out-of-bounds access
+ // because of the values of %c1 and/or %c2, the sequence could be optimized
+ // early. This is currently not possible because constant folding will reach
+ // an unreachable assertion if it doesn't find a constant operand.
+ if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand()))
+ if (SI->getCondition()->getType()->isIntegerTy() &&
+ isa<Constant>(EI.getIndexOperand()))
+ if (Instruction *R = FoldOpIntoSelect(EI, SI))
+ return R;
+
// If extracting a specified index from the vector, see if we can recursively
// find a previously computed scalar that was inserted into the vector.
auto *IndexC = dyn_cast<ConstantInt>(Index);
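
A sketch of the new extractelement-of-select fold (hypothetical names and constants): one select arm is a constant vector, so FoldOpIntoSelect can push the extract into both arms:

    define i32 @extract_of_select(i1 %c, <4 x i32> %a) {
      %sel = select i1 %c, <4 x i32> %a, <4 x i32> <i32 1, i32 2, i32 3, i32 4>
      %e = extractelement <4 x i32> %sel, i64 2
      ret i32 %e
      ; expected, approximately:
      ;   %ea = extractelement <4 x i32> %a, i64 2
      ;   %e  = select i1 %c, i32 %ea, i32 3
    }
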
@@ -850,17 +870,16 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
if (NumAggElts > 2)
return nullptr;
- static constexpr auto NotFound = None;
+ static constexpr auto NotFound = std::nullopt;
static constexpr auto FoundMismatch = nullptr;
// Try to find a value of each element of an aggregate.
// FIXME: deal with more complex, not one-dimensional, aggregate types
- SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
+ SmallVector<std::optional<Instruction *>, 2> AggElts(NumAggElts, NotFound);
// Do we know values for each element of the aggregate?
auto KnowAllElts = [&AggElts]() {
- return all_of(AggElts,
- [](Optional<Instruction *> Elt) { return Elt != NotFound; });
+ return !llvm::is_contained(AggElts, NotFound);
};
int Depth = 0;
@@ -889,7 +908,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// Now, we may have already previously recorded the value for this element
// of an aggregate. If we did, that means the CurrIVI will later be
// overwritten with the already-recorded value. But if not, let's record it!
- Optional<Instruction *> &Elt = AggElts[Indices.front()];
+ std::optional<Instruction *> &Elt = AggElts[Indices.front()];
Elt = Elt.value_or(InsertedValue);
// FIXME: should we handle chain-terminating undef base operand?
@@ -919,7 +938,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
/// or different elements had different source aggregates.
FoundMismatch
};
- auto Describe = [](Optional<Value *> SourceAggregate) {
+ auto Describe = [](std::optional<Value *> SourceAggregate) {
if (SourceAggregate == NotFound)
return AggregateDescription::NotFound;
if (*SourceAggregate == FoundMismatch)
@@ -933,8 +952,8 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// If found, return the source aggregate from which the extraction was.
// If \p PredBB is provided, does PHI translation of an \p Elt first.
auto FindSourceAggregate =
- [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB,
- Optional<BasicBlock *> PredBB) -> Optional<Value *> {
+ [&](Instruction *Elt, unsigned EltIdx, std::optional<BasicBlock *> UseBB,
+ std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
// For now(?), only deal with, at most, a single level of PHI indirection.
if (UseBB && PredBB)
Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB));
@@ -961,9 +980,9 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// see if we can find appropriate source aggregate for each of the elements,
// and see it's the same aggregate for each element. If so, return it.
auto FindCommonSourceAggregate =
- [&](Optional<BasicBlock *> UseBB,
- Optional<BasicBlock *> PredBB) -> Optional<Value *> {
- Optional<Value *> SourceAggregate;
+ [&](std::optional<BasicBlock *> UseBB,
+ std::optional<BasicBlock *> PredBB) -> std::optional<Value *> {
+ std::optional<Value *> SourceAggregate;
for (auto I : enumerate(AggElts)) {
assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch &&
@@ -975,7 +994,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// For this element, is there a plausible source aggregate?
// FIXME: we could special-case undef element, IFF we know that in the
// source aggregate said element isn't poison.
- Optional<Value *> SourceAggregateForElement =
+ std::optional<Value *> SourceAggregateForElement =
FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB);
// Okay, what have we found? Does that correlate with previous findings?
@@ -1009,10 +1028,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
return *SourceAggregate;
};
- Optional<Value *> SourceAggregate;
+ std::optional<Value *> SourceAggregate;
// Can we find the source aggregate without looking at predecessors?
- SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None);
+ SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/std::nullopt,
+ /*PredBB=*/std::nullopt);
if (Describe(SourceAggregate) != AggregateDescription::NotFound) {
if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch)
return nullptr; // Conflicting source aggregates!
@@ -1029,7 +1049,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse(
// they all should be defined in the same basic block.
BasicBlock *UseBB = nullptr;
- for (const Optional<Instruction *> &I : AggElts) {
+ for (const std::optional<Instruction *> &I : AggElts) {
BasicBlock *BB = (*I)->getParent();
// If it's the first instruction we've encountered, record the basic block.
if (!UseBB) {
@@ -1495,6 +1515,71 @@ static Instruction *narrowInsElt(InsertElementInst &InsElt,
return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType());
}
+/// If we are inserting 2 halves of a value into adjacent elements of a vector,
+/// try to convert to a single insert with appropriate bitcasts.
+static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt,
+ bool IsBigEndian,
+ InstCombiner::BuilderTy &Builder) {
+ Value *VecOp = InsElt.getOperand(0);
+ Value *ScalarOp = InsElt.getOperand(1);
+ Value *IndexOp = InsElt.getOperand(2);
+
+ // Pattern depends on endianness because we expect the lower index to be
+ // inserted first.
+ // Big endian:
+ // inselt (inselt BaseVec, (trunc (lshr X, BW/2)), Index0), (trunc X), Index1
+ // Little endian:
+ // inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1
+ // Note: It is not safe to do this transform with an arbitrary base vector
+ // because the bitcast of that vector to fewer/larger elements could
+ // allow poison to spill into an element that was not poison before.
+ // TODO: Detect smaller fractions of the scalar.
+ // TODO: One-use checks are conservative.
+ auto *VTy = dyn_cast<FixedVectorType>(InsElt.getType());
+ Value *Scalar0, *BaseVec;
+ uint64_t Index0, Index1;
+ if (!VTy || (VTy->getNumElements() & 1) ||
+ !match(IndexOp, m_ConstantInt(Index1)) ||
+ !match(VecOp, m_InsertElt(m_Value(BaseVec), m_Value(Scalar0),
+ m_ConstantInt(Index0))) ||
+ !match(BaseVec, m_Undef()))
+ return nullptr;
+
+ // The first insert must be to the index one less than this one, and
+ // the first insert must be to an even index.
+ if (Index0 + 1 != Index1 || Index0 & 1)
+ return nullptr;
+
+ // For big endian, the high half of the value should be inserted first.
+ // For little endian, the low half of the value should be inserted first.
+ Value *X;
+ uint64_t ShAmt;
+ if (IsBigEndian) {
+ if (!match(ScalarOp, m_Trunc(m_Value(X))) ||
+ !match(Scalar0, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+ return nullptr;
+ } else {
+ if (!match(Scalar0, m_Trunc(m_Value(X))) ||
+ !match(ScalarOp, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt)))))
+ return nullptr;
+ }
+
+ Type *SrcTy = X->getType();
+ unsigned ScalarWidth = SrcTy->getScalarSizeInBits();
+ unsigned VecEltWidth = VTy->getScalarSizeInBits();
+ if (ScalarWidth != VecEltWidth * 2 || ShAmt != VecEltWidth)
+ return nullptr;
+
+ // Bitcast the base vector to a vector type with the source element type.
+ Type *CastTy = FixedVectorType::get(SrcTy, VTy->getNumElements() / 2);
+ Value *CastBaseVec = Builder.CreateBitCast(BaseVec, CastTy);
+
+ // Scale the insert index for a vector with half as many elements.
+ // bitcast (inselt (bitcast BaseVec), X, NewIndex)
+ uint64_t NewIndex = IsBigEndian ? Index1 / 2 : Index0 / 2;
+ Value *NewInsert = Builder.CreateInsertElement(CastBaseVec, X, NewIndex);
+ return new BitCastInst(NewInsert, VTy);
+}
+
Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
Value *VecOp = IE.getOperand(0);
Value *ScalarOp = IE.getOperand(1);
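
A hypothetical little-endian example of foldTruncInsEltPair (names invented): the two 16-bit halves of %x, inserted at adjacent even/odd indices of an undef base, become a single 32-bit insert plus bitcasts:

    define <4 x i16> @insert_pair(i32 %x) {
      %lo = trunc i32 %x to i16
      %sh = lshr i32 %x, 16
      %hi = trunc i32 %sh to i16
      %v0 = insertelement <4 x i16> undef, i16 %lo, i64 0
      %v1 = insertelement <4 x i16> %v0, i16 %hi, i64 1
      ret <4 x i16> %v1
      ; expected on little endian, approximately:
      ;   %bc  = bitcast <4 x i16> undef to <2 x i32>
      ;   %ins = insertelement <2 x i32> %bc, i32 %x, i64 0
      ;   %r   = bitcast <2 x i32> %ins to <4 x i16>
    }
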
@@ -1505,10 +1590,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
return replaceInstUsesWith(IE, V);
// Canonicalize type of constant indices to i64 to simplify CSE
- if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp))
+ if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) {
if (auto *NewIdx = getPreferredVectorIndex(IndexC))
return replaceOperand(IE, 2, NewIdx);
+ Value *BaseVec, *OtherScalar;
+ uint64_t OtherIndexVal;
+ if (match(VecOp, m_OneUse(m_InsertElt(m_Value(BaseVec),
+ m_Value(OtherScalar),
+ m_ConstantInt(OtherIndexVal)))) &&
+ !isa<Constant>(OtherScalar) && OtherIndexVal > IndexC->getZExtValue()) {
+ Value *NewIns = Builder.CreateInsertElement(BaseVec, ScalarOp, IdxOp);
+ return InsertElementInst::Create(NewIns, OtherScalar,
+ Builder.getInt64(OtherIndexVal));
+ }
+ }
+
// If the scalar is bitcast and inserted into undef, do the insert in the
// source type followed by bitcast.
// TODO: Generalize for insert into any constant, not just undef?
@@ -1622,6 +1719,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) {
if (Instruction *Ext = narrowInsElt(IE, Builder))
return Ext;
+ if (Instruction *Ext = foldTruncInsEltPair(IE, DL.isBigEndian(), Builder))
+ return Ext;
+
return nullptr;
}
@@ -1653,7 +1753,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// from an undefined element in an operand.
if (llvm::is_contained(Mask, -1))
return false;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
case Instruction::Add:
case Instruction::FAdd:
case Instruction::Sub:
@@ -1700,8 +1800,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
// Verify that 'CI' does not occur twice in Mask. A single 'insertelement'
// can't put an element into multiple indices.
bool SeenOnce = false;
- for (int i = 0, e = Mask.size(); i != e; ++i) {
- if (Mask[i] == ElementNumber) {
+ for (int I : Mask) {
+ if (I == ElementNumber) {
if (SeenOnce)
return false;
SeenOnce = true;
@@ -1957,6 +2057,56 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) {
return {};
}
+/// A select shuffle of a select shuffle with a shared operand can be reduced
+/// to a single select shuffle. This is an obvious improvement in IR, and the
+/// backend is expected to lower select shuffles efficiently.
+static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) {
+ assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
+
+ Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1);
+ SmallVector<int, 16> Mask;
+ Shuf.getShuffleMask(Mask);
+ unsigned NumElts = Mask.size();
+
+ // Canonicalize a select shuffle with common operand as Op1.
+ auto *ShufOp = dyn_cast<ShuffleVectorInst>(Op0);
+ if (ShufOp && ShufOp->isSelect() &&
+ (ShufOp->getOperand(0) == Op1 || ShufOp->getOperand(1) == Op1)) {
+ std::swap(Op0, Op1);
+ ShuffleVectorInst::commuteShuffleMask(Mask, NumElts);
+ }
+
+ ShufOp = dyn_cast<ShuffleVectorInst>(Op1);
+ if (!ShufOp || !ShufOp->isSelect() ||
+ (ShufOp->getOperand(0) != Op0 && ShufOp->getOperand(1) != Op0))
+ return nullptr;
+
+ Value *X = ShufOp->getOperand(0), *Y = ShufOp->getOperand(1);
+ SmallVector<int, 16> Mask1;
+ ShufOp->getShuffleMask(Mask1);
+ assert(Mask1.size() == NumElts && "Vector size changed with select shuffle");
+
+ // Canonicalize common operand (Op0) as X (first operand of first shuffle).
+ if (Y == Op0) {
+ std::swap(X, Y);
+ ShuffleVectorInst::commuteShuffleMask(Mask1, NumElts);
+ }
+
+ // If the mask chooses from X (operand 0), it stays the same.
+ // If the mask chooses from the earlier shuffle, the other mask value is
+ // transferred to the combined select shuffle:
+ // shuf X, (shuf X, Y, M1), M --> shuf X, Y, M'
+ SmallVector<int, 16> NewMask(NumElts);
+ for (unsigned i = 0; i != NumElts; ++i)
+ NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i];
+
+ // A select mask with undef elements might look like an identity mask.
+ assert((ShuffleVectorInst::isSelectMask(NewMask) ||
+ ShuffleVectorInst::isIdentityMask(NewMask)) &&
+ "Unexpected shuffle mask");
+ return new ShuffleVectorInst(X, Y, NewMask);
+}
+
static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) {
assert(Shuf.isSelect() && "Must have select-equivalent shuffle");
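
A sketch of foldSelectShuffleOfSelectShuffle (hypothetical masks and names): both shuffles are select shuffles that share %x, so they collapse into a single select shuffle whose mask takes the outer choice where it picks %x and the inner choice elsewhere:

    define <4 x i32> @select_shuf_of_select_shuf(<4 x i32> %x, <4 x i32> %y) {
      %s1 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 5, i32 2, i32 7>
      %s2 = shufflevector <4 x i32> %x, <4 x i32> %s1, <4 x i32> <i32 0, i32 1, i32 6, i32 7>
      ret <4 x i32> %s2
      ; expected: shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 1, i32 2, i32 7>
    }
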
@@ -2061,6 +2211,9 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) {
return &Shuf;
}
+ if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf))
+ return I;
+
if (Instruction *I = foldSelectShuffleWith1Binop(Shuf))
return I;
@@ -2541,6 +2694,35 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) {
return new ShuffleVectorInst(X, Y, NewMask);
}
+// Splatting the first element of the result of a BinOp, where any of the
+// BinOp's operands is the result of a first-element splat, can be simplified
+// to splatting the first element of the result of the BinOp.
+Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) {
+ if (!match(SVI.getOperand(1), m_Undef()) ||
+ !match(SVI.getShuffleMask(), m_ZeroMask()))
+ return nullptr;
+
+ Value *Op0 = SVI.getOperand(0);
+ Value *X, *Y;
+ if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()),
+ m_Value(Y))) &&
+ !match(Op0, m_BinOp(m_Value(X),
+ m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask()))))
+ return nullptr;
+ if (X->getType() != Y->getType())
+ return nullptr;
+
+ auto *BinOp = cast<BinaryOperator>(Op0);
+ if (!isSafeToSpeculativelyExecute(BinOp))
+ return nullptr;
+
+ Value *NewBO = Builder.CreateBinOp(BinOp->getOpcode(), X, Y);
+ if (auto *NewBOI = dyn_cast<Instruction>(NewBO))
+ NewBOI->copyIRFlags(BinOp);
+
+ return new ShuffleVectorInst(NewBO, SVI.getShuffleMask());
+}
+
Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *LHS = SVI.getOperand(0);
Value *RHS = SVI.getOperand(1);
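
A minimal sketch of simplifyBinOpSplats (hypothetical names): splatting lane 0 of a binop whose operand is itself a lane-0 splat becomes the binop of the original operands followed by a single splat:

    define <4 x i32> @splat_of_binop_of_splat(<4 x i32> %x, <4 x i32> %y) {
      %xs = shufflevector <4 x i32> %x, <4 x i32> undef, <4 x i32> zeroinitializer
      %b  = add <4 x i32> %xs, %y
      %r  = shufflevector <4 x i32> %b, <4 x i32> undef, <4 x i32> zeroinitializer
      ret <4 x i32> %r
      ; expected, approximately:
      ;   %b2 = add <4 x i32> %x, %y
      ;   %r  = shufflevector <4 x i32> %b2, <4 x i32> poison, <4 x i32> zeroinitializer
    }
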
@@ -2549,7 +2731,9 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
SVI.getType(), ShufQuery))
return replaceInstUsesWith(SVI, V);
- // Bail out for scalable vectors
+ if (Instruction *I = simplifyBinOpSplats(SVI))
+ return I;
+
if (isa<ScalableVectorType>(LHS->getType()))
return nullptr;
@@ -2694,7 +2878,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) {
Value *V = LHS;
unsigned MaskElems = Mask.size();
auto *SrcTy = cast<FixedVectorType>(V->getType());
- unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize();
+ unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedValue();
unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType());
assert(SrcElemBitWidth && "vector elements must have a bitwidth");
unsigned SrcNumElems = SrcTy->getNumElements();
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 71c763de43b4..fb6f4f96ea48 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -38,7 +38,6 @@
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
-#include "llvm/ADT/None.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/Statistic.h"
@@ -99,16 +98,19 @@
#include "llvm/Support/KnownBits.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/InstCombine/InstCombine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Local.h"
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <memory>
+#include <optional>
#include <string>
#include <utility>
#define DEBUG_TYPE "instcombine"
#include "llvm/Transforms/Utils/InstructionWorklist.h"
+#include <optional>
using namespace llvm;
using namespace llvm::PatternMatch;
@@ -167,16 +169,16 @@ MaxArraySize("instcombine-maxarray-size", cl::init(1024),
static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare",
cl::Hidden, cl::init(true));
-Optional<Instruction *>
+std::optional<Instruction *>
InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) {
// Handle target specific intrinsics
if (II.getCalledFunction()->isTargetIntrinsic()) {
return TTI.instCombineIntrinsic(*this, II);
}
- return None;
+ return std::nullopt;
}
-Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic(
+std::optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic(
IntrinsicInst &II, APInt DemandedMask, KnownBits &Known,
bool &KnownBitsComputed) {
// Handle target specific intrinsics
@@ -184,10 +186,10 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic(
return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known,
KnownBitsComputed);
}
- return None;
+ return std::nullopt;
}
-Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
+std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2,
APInt &UndefElts3,
std::function<void(Instruction *, unsigned, APInt, APInt &)>
@@ -198,11 +200,11 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic(
*this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3,
SimplifyAndSetOp);
}
- return None;
+ return std::nullopt;
}
Value *InstCombinerImpl::EmitGEPOffset(User *GEP) {
- return llvm::EmitGEPOffset(&Builder, DL, GEP);
+ return llvm::emitGEPOffset(&Builder, DL, GEP);
}
/// Legal integers and common types are considered desirable. This is used to
@@ -223,11 +225,12 @@ bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const {
/// Return true if it is desirable to convert an integer computation from a
/// given bit width to a new bit width.
-/// We don't want to convert from a legal to an illegal type or from a smaller
-/// to a larger illegal type. A width of '1' is always treated as a desirable
-/// type because i1 is a fundamental type in IR, and there are many specialized
-/// optimizations for i1 types. Common/desirable widths are equally treated as
-/// legal to convert to, in order to open up more combining opportunities.
+/// We don't want to convert from a legal or desirable type (like i8) to an
+/// illegal type or from a smaller to a larger illegal type. A width of '1'
+/// is always treated as a desirable type because i1 is a fundamental type in
+/// IR, and there are many specialized optimizations for i1 types.
+/// Common/desirable widths are equally treated as legal to convert to, in
+/// order to open up more combining opportunities.
bool InstCombinerImpl::shouldChangeType(unsigned FromWidth,
unsigned ToWidth) const {
bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth);
@@ -238,9 +241,9 @@ bool InstCombinerImpl::shouldChangeType(unsigned FromWidth,
if (ToWidth < FromWidth && isDesirableIntType(ToWidth))
return true;
- // If this is a legal integer from type, and the result would be an illegal
- // type, don't do the transformation.
- if (FromLegal && !ToLegal)
+ // If this is a legal or desirable integer from type, and the result would be
+ // an illegal type, don't do the transformation.
+ if ((FromLegal || isDesirableIntType(FromWidth)) && !ToLegal)
return false;
// Otherwise, if both are illegal, do not increase the size of the result. We
@@ -367,14 +370,14 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1,
// inttoptr ( ptrtoint (x) ) --> x
Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) {
auto *IntToPtr = dyn_cast<IntToPtrInst>(Val);
- if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) ==
+ if (IntToPtr && DL.getTypeSizeInBits(IntToPtr->getDestTy()) ==
DL.getTypeSizeInBits(IntToPtr->getSrcTy())) {
auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0));
Type *CastTy = IntToPtr->getDestTy();
if (PtrToInt &&
CastTy->getPointerAddressSpace() ==
PtrToInt->getSrcTy()->getPointerAddressSpace() &&
- DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) ==
+ DL.getTypeSizeInBits(PtrToInt->getSrcTy()) ==
DL.getTypeSizeInBits(PtrToInt->getDestTy())) {
return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy,
"", PtrToInt);
@@ -632,14 +635,14 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op,
/// This tries to simplify binary operations by factorizing out common terms
/// (e. g. "(A*B)+(A*C)" -> "A*(B+C)").
-Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
- Instruction::BinaryOps InnerOpcode,
- Value *A, Value *B, Value *C,
- Value *D) {
+static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ,
+ InstCombiner::BuilderTy &Builder,
+ Instruction::BinaryOps InnerOpcode, Value *A,
+ Value *B, Value *C, Value *D) {
assert(A && B && C && D && "All values must be provided");
Value *V = nullptr;
- Value *SimplifiedInst = nullptr;
+ Value *RetVal = nullptr;
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
@@ -647,7 +650,7 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
bool InnerCommutative = Instruction::isCommutative(InnerOpcode);
// Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"?
- if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode))
+ if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) {
// Does the instruction have the form "(A op' B) op (A op' D)" or, in the
// commutative case, "(A op' B) op (C op' A)"?
if (A == C || (InnerCommutative && A == D)) {
@@ -656,17 +659,18 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
// Consider forming "A op' (B op D)".
// If "B op D" simplifies then it can be formed with no cost.
V = simplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I));
- // If "B op D" doesn't simplify then only go on if both of the existing
+
+ // If "B op D" doesn't simplify then only go on if one of the existing
// operations "A op' B" and "C op' D" will be zapped as no longer used.
- if (!V && LHS->hasOneUse() && RHS->hasOneUse())
+ if (!V && (LHS->hasOneUse() || RHS->hasOneUse()))
V = Builder.CreateBinOp(TopLevelOpcode, B, D, RHS->getName());
- if (V) {
- SimplifiedInst = Builder.CreateBinOp(InnerOpcode, A, V);
- }
+ if (V)
+ RetVal = Builder.CreateBinOp(InnerOpcode, A, V);
}
+ }
// Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"?
- if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode))
+ if (!RetVal && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) {
// Does the instruction have the form "(A op' B) op (C op' B)" or, in the
// commutative case, "(A op' B) op (B op' D)"?
if (B == D || (InnerCommutative && B == C)) {
@@ -676,61 +680,94 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
// If "A op C" simplifies then it can be formed with no cost.
V = simplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I));
- // If "A op C" doesn't simplify then only go on if both of the existing
+ // If "A op C" doesn't simplify then only go on if one of the existing
// operations "A op' B" and "C op' D" will be zapped as no longer used.
- if (!V && LHS->hasOneUse() && RHS->hasOneUse())
+ if (!V && (LHS->hasOneUse() || RHS->hasOneUse()))
V = Builder.CreateBinOp(TopLevelOpcode, A, C, LHS->getName());
- if (V) {
- SimplifiedInst = Builder.CreateBinOp(InnerOpcode, V, B);
- }
+ if (V)
+ RetVal = Builder.CreateBinOp(InnerOpcode, V, B);
}
+ }
- if (SimplifiedInst) {
- ++NumFactor;
- SimplifiedInst->takeName(&I);
+ if (!RetVal)
+ return nullptr;
- // Check if we can add NSW/NUW flags to SimplifiedInst. If so, set them.
- if (BinaryOperator *BO = dyn_cast<BinaryOperator>(SimplifiedInst)) {
- if (isa<OverflowingBinaryOperator>(SimplifiedInst)) {
- bool HasNSW = false;
- bool HasNUW = false;
- if (isa<OverflowingBinaryOperator>(&I)) {
- HasNSW = I.hasNoSignedWrap();
- HasNUW = I.hasNoUnsignedWrap();
- }
+ ++NumFactor;
+ RetVal->takeName(&I);
- if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) {
- HasNSW &= LOBO->hasNoSignedWrap();
- HasNUW &= LOBO->hasNoUnsignedWrap();
- }
+ // Try to add no-overflow flags to the final value.
+ if (isa<OverflowingBinaryOperator>(RetVal)) {
+ bool HasNSW = false;
+ bool HasNUW = false;
+ if (isa<OverflowingBinaryOperator>(&I)) {
+ HasNSW = I.hasNoSignedWrap();
+ HasNUW = I.hasNoUnsignedWrap();
+ }
+ if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) {
+ HasNSW &= LOBO->hasNoSignedWrap();
+ HasNUW &= LOBO->hasNoUnsignedWrap();
+ }
- if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) {
- HasNSW &= ROBO->hasNoSignedWrap();
- HasNUW &= ROBO->hasNoUnsignedWrap();
- }
+ if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) {
+ HasNSW &= ROBO->hasNoSignedWrap();
+ HasNUW &= ROBO->hasNoUnsignedWrap();
+ }
- if (TopLevelOpcode == Instruction::Add &&
- InnerOpcode == Instruction::Mul) {
- // We can propagate 'nsw' if we know that
- // %Y = mul nsw i16 %X, C
- // %Z = add nsw i16 %Y, %X
- // =>
- // %Z = mul nsw i16 %X, C+1
- //
- // iff C+1 isn't INT_MIN
- const APInt *CInt;
- if (match(V, m_APInt(CInt))) {
- if (!CInt->isMinSignedValue())
- BO->setHasNoSignedWrap(HasNSW);
- }
+ if (TopLevelOpcode == Instruction::Add && InnerOpcode == Instruction::Mul) {
+ // We can propagate 'nsw' if we know that
+ // %Y = mul nsw i16 %X, C
+ // %Z = add nsw i16 %Y, %X
+ // =>
+ // %Z = mul nsw i16 %X, C+1
+ //
+ // iff C+1 isn't INT_MIN
+ const APInt *CInt;
+ if (match(V, m_APInt(CInt)) && !CInt->isMinSignedValue())
+ cast<Instruction>(RetVal)->setHasNoSignedWrap(HasNSW);
- // nuw can be propagated with any constant or nuw value.
- BO->setHasNoUnsignedWrap(HasNUW);
- }
- }
+ // nuw can be propagated with any constant or nuw value.
+ cast<Instruction>(RetVal)->setHasNoUnsignedWrap(HasNUW);
}
}
- return SimplifiedInst;
+ return RetVal;
+}
+
+Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) {
+ Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
+ BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
+ BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
+ Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
+ Value *A, *B, *C, *D;
+ Instruction::BinaryOps LHSOpcode, RHSOpcode;
+
+ if (Op0)
+ LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
+ if (Op1)
+ RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
+
+ // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
+ // a common term.
+ if (Op0 && Op1 && LHSOpcode == RHSOpcode)
+ if (Value *V = tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, D))
+ return V;
+
+ // The instruction has the form "(A op' B) op (C)". Try to factorize common
+ // term.
+ if (Op0)
+ if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
+ if (Value *V =
+ tryFactorization(I, SQ, Builder, LHSOpcode, A, B, RHS, Ident))
+ return V;
+
+ // The instruction has the form "(B) op (C op' D)". Try to factorize common
+ // term.
+ if (Op1)
+ if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
+ if (Value *V =
+ tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D))
+ return V;
+
+ return nullptr;
}
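
To illustrate the relaxed one-use requirement in the factorization above, a hypothetical example (names invented): %ab has a second use, but %ac is still removable, so the fold is now allowed where it previously required both multiplies to be single-use:

    define i8 @factorize(i8 %a, i8 %b, i8 %c, ptr %p) {
      %ab = mul i8 %a, %b
      %ac = mul i8 %a, %c
      store i8 %ab, ptr %p        ; extra use of %ab
      %r = add i8 %ab, %ac
      ret i8 %r
      ; expected, approximately:
      ;   %bc = add i8 %b, %c
      ;   %r  = mul i8 %a, %bc
    }
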
/// This tries to simplify binary operations which some other binary operation
@@ -738,41 +775,15 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I,
/// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in
/// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win).
/// Returns the simplified value, or null if it didn't simplify.
-Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) {
+Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) {
Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS);
BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS);
Instruction::BinaryOps TopLevelOpcode = I.getOpcode();
- {
- // Factorization.
- Value *A, *B, *C, *D;
- Instruction::BinaryOps LHSOpcode, RHSOpcode;
- if (Op0)
- LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B);
- if (Op1)
- RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D);
-
- // The instruction has the form "(A op' B) op (C op' D)". Try to factorize
- // a common term.
- if (Op0 && Op1 && LHSOpcode == RHSOpcode)
- if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D))
- return V;
-
- // The instruction has the form "(A op' B) op (C)". Try to factorize common
- // term.
- if (Op0)
- if (Value *Ident = getIdentityValue(LHSOpcode, RHS))
- if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident))
- return V;
-
- // The instruction has the form "(B) op (C op' D)". Try to factorize common
- // term.
- if (Op1)
- if (Value *Ident = getIdentityValue(RHSOpcode, LHS))
- if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D))
- return V;
- }
+ // Factorization.
+ if (Value *R = tryFactorizationFolds(I))
+ return R;
// Expansion.
if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) {
@@ -876,6 +887,28 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
SimplifyQuery Q = SQ.getWithInstruction(&I);
Value *Cond, *True = nullptr, *False = nullptr;
+
+ // Special-case for add/negate combination. Replace the zero in the negation
+ // with the trailing add operand:
+ // (Cond ? TVal : -N) + Z --> Cond ? True : (Z - N)
+ // (Cond ? -N : FVal) + Z --> Cond ? (Z - N) : False
+ auto foldAddNegate = [&](Value *TVal, Value *FVal, Value *Z) -> Value * {
+ // We need an 'add' and exactly 1 arm of the select to have been simplified.
+ if (Opcode != Instruction::Add || (!True && !False) || (True && False))
+ return nullptr;
+
+ Value *N;
+ if (True && match(FVal, m_Neg(m_Value(N)))) {
+ Value *Sub = Builder.CreateSub(Z, N);
+ return Builder.CreateSelect(Cond, True, Sub, I.getName());
+ }
+ if (False && match(TVal, m_Neg(m_Value(N)))) {
+ Value *Sub = Builder.CreateSub(Z, N);
+ return Builder.CreateSelect(Cond, Sub, False, I.getName());
+ }
+ return nullptr;
+ };
+
if (LHSIsSelect && RHSIsSelect && A == D) {
// (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F)
Cond = A;
@@ -893,11 +926,15 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
Cond = A;
True = simplifyBinOp(Opcode, B, RHS, FMF, Q);
False = simplifyBinOp(Opcode, C, RHS, FMF, Q);
+ if (Value *NewSel = foldAddNegate(B, C, RHS))
+ return NewSel;
} else if (RHSIsSelect && RHS->hasOneUse()) {
// X op (D ? E : F) -> D ? (X op E) : (X op F)
Cond = D;
True = simplifyBinOp(Opcode, LHS, E, FMF, Q);
False = simplifyBinOp(Opcode, LHS, F, FMF, Q);
+ if (Value *NewSel = foldAddNegate(E, F, LHS))
+ return NewSel;
}
if (!True || !False)
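
A sketch of the add/negate special case above (hypothetical names): one select arm simplifies with the add, and the negated arm is rewritten as a subtraction from the other add operand:

    define i32 @add_of_select_with_neg(i1 %c, i32 %n, i32 %z) {
      %neg = sub i32 0, %n
      %sel = select i1 %c, i32 0, i32 %neg
      %r = add i32 %sel, %z
      ret i32 %r
      ; expected, approximately:
      ;   %sub = sub i32 %z, %n
      ;   %r   = select i1 %c, i32 %z, i32 %sub
    }
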
@@ -910,8 +947,10 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
/// Freely adapt every user of V as-if V was changed to !V.
/// WARNING: only if canFreelyInvertAllUsersOf() said this can be done.
-void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) {
- for (User *U : I->users()) {
+void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) {
+ for (User *U : make_early_inc_range(I->users())) {
+ if (U == IgnoredUser)
+ continue; // Don't consider this user.
switch (cast<Instruction>(U)->getOpcode()) {
case Instruction::Select: {
auto *SI = cast<SelectInst>(U);
@@ -1033,6 +1072,9 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO,
return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1));
}
+ if (auto *EI = dyn_cast<ExtractElementInst>(&I))
+ return Builder.CreateExtractElement(SO, EI->getIndexOperand());
+
assert(I.isBinaryOp() && "Unexpected opcode for select folding");
// Figure out if the constant is the left or the right argument.
@@ -1133,22 +1175,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI,
return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI);
}
-static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV,
- InstCombiner::BuilderTy &Builder) {
- bool ConstIsRHS = isa<Constant>(I->getOperand(1));
- Constant *C = cast<Constant>(I->getOperand(ConstIsRHS));
-
- Value *Op0 = InV, *Op1 = C;
- if (!ConstIsRHS)
- std::swap(Op0, Op1);
-
- Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phi.bo");
- auto *FPInst = dyn_cast<Instruction>(RI);
- if (FPInst && isa<FPMathOperator>(FPInst))
- FPInst->copyFastMathFlags(I);
- return RI;
-}
-
Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
unsigned NumPHIValues = PN->getNumIncomingValues();
if (NumPHIValues == 0)
@@ -1167,48 +1193,69 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
// Otherwise, we can replace *all* users with the new PHI we form.
}
- // Check to see if all of the operands of the PHI are simple constants
- // (constantint/constantfp/undef). If there is one non-constant value,
- // remember the BB it is in. If there is more than one or if *it* is a PHI,
- // bail out. We don't do arbitrary constant expressions here because moving
- // their computation can be expensive without a cost model.
- BasicBlock *NonConstBB = nullptr;
+ // Check to see whether the instruction can be folded into each phi operand.
+ // If there is one operand that does not fold, remember the BB it is in.
+ // If there is more than one or if *it* is a PHI, bail out.
+ SmallVector<Value *> NewPhiValues;
+ BasicBlock *NonSimplifiedBB = nullptr;
+ Value *NonSimplifiedInVal = nullptr;
for (unsigned i = 0; i != NumPHIValues; ++i) {
Value *InVal = PN->getIncomingValue(i);
- // For non-freeze, require constant operand
- // For freeze, require non-undef, non-poison operand
- if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant()))
- continue;
- if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal))
+ BasicBlock *InBB = PN->getIncomingBlock(i);
+
+ // NB: It is a precondition of this transform that the operands be
+ // phi translatable! This is usually trivially satisfied by limiting it
+ // to constant ops, and for selects we do a more sophisticated check.
+ SmallVector<Value *> Ops;
+ for (Value *Op : I.operands()) {
+ if (Op == PN)
+ Ops.push_back(InVal);
+ else
+ Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB));
+ }
+
+ // Don't consider the simplification successful if we get back a constant
+ // expression. That's just an instruction in hiding.
+ // Also reject the case where we simplify back to the phi node. We wouldn't
+ // be able to remove it in that case.
+ Value *NewVal = simplifyInstructionWithOperands(
+ &I, Ops, SQ.getWithInstruction(InBB->getTerminator()));
+ if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) {
+ NewPhiValues.push_back(NewVal);
continue;
+ }
if (isa<PHINode>(InVal)) return nullptr; // Itself a phi.
- if (NonConstBB) return nullptr; // More than one non-const value.
+ if (NonSimplifiedBB) return nullptr; // More than one non-simplified value.
- NonConstBB = PN->getIncomingBlock(i);
+ NonSimplifiedBB = InBB;
+ NonSimplifiedInVal = InVal;
+ NewPhiValues.push_back(nullptr);
// If the InVal is an invoke at the end of the pred block, then we can't
// insert a computation after it without breaking the edge.
if (isa<InvokeInst>(InVal))
- if (cast<Instruction>(InVal)->getParent() == NonConstBB)
+ if (cast<Instruction>(InVal)->getParent() == NonSimplifiedBB)
return nullptr;
// If the incoming non-constant value is reachable from the phis block,
// we'll push the operation across a loop backedge. This could result in
// an infinite combine loop, and is generally non-profitable (especially
// if the operation was originally outside the loop).
- if (isPotentiallyReachable(PN->getParent(), NonConstBB, nullptr, &DT, LI))
+ if (isPotentiallyReachable(PN->getParent(), NonSimplifiedBB, nullptr, &DT,
+ LI))
return nullptr;
}
- // If there is exactly one non-constant value, we can insert a copy of the
+ // If there is exactly one non-simplified value, we can insert a copy of the
// operation in that block. However, if this is a critical edge, we would be
// inserting the computation on some other paths (e.g. inside a loop). Only
// do this if the pred block is unconditionally branching into the phi block.
// Also, make sure that the pred block is not dead code.
- if (NonConstBB != nullptr) {
- BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator());
- if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB))
+ if (NonSimplifiedBB != nullptr) {
+ BranchInst *BI = dyn_cast<BranchInst>(NonSimplifiedBB->getTerminator());
+ if (!BI || !BI->isUnconditional() ||
+ !DT.isReachableFromEntry(NonSimplifiedBB))
return nullptr;
}
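
A minimal sketch of the kind of fold this reworked foldOpIntoPhi performs (hypothetical names; the phi has a single use and the add has a constant operand, as this path expects): all but one incoming value simplifies, and a clone of the instruction is inserted in the remaining predecessor:

    define i8 @add_into_phi(i1 %c, i8 %x) {
    entry:
      br i1 %c, label %if, label %join
    if:
      br label %join
    join:
      %p = phi i8 [ 0, %entry ], [ %x, %if ]
      %r = add i8 %p, 1
      ret i8 %r
      ; expected, approximately:
      ;   if:   %x1 = add i8 %x, 1
      ;   join: %p  = phi i8 [ 1, %entry ], [ %x1, %if ]
      ;         ret i8 %p
    }
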
@@ -1219,83 +1266,23 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) {
// If we are going to have to insert a new computation, do so right before the
// predecessor's terminator.
- if (NonConstBB)
- Builder.SetInsertPoint(NonConstBB->getTerminator());
-
- // Next, add all of the operands to the PHI.
- if (SelectInst *SI = dyn_cast<SelectInst>(&I)) {
- // We only currently try to fold the condition of a select when it is a phi,
- // not the true/false values.
- Value *TrueV = SI->getTrueValue();
- Value *FalseV = SI->getFalseValue();
- BasicBlock *PhiTransBB = PN->getParent();
- for (unsigned i = 0; i != NumPHIValues; ++i) {
- BasicBlock *ThisBB = PN->getIncomingBlock(i);
- Value *TrueVInPred = TrueV->DoPHITranslation(PhiTransBB, ThisBB);
- Value *FalseVInPred = FalseV->DoPHITranslation(PhiTransBB, ThisBB);
- Value *InV = nullptr;
- // Beware of ConstantExpr: it may eventually evaluate to getNullValue,
- // even if currently isNullValue gives false.
- Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i));
- // For vector constants, we cannot use isNullValue to fold into
- // FalseVInPred versus TrueVInPred. When we have individual nonzero
- // elements in the vector, we will incorrectly fold InC to
- // `TrueVInPred`.
- if (InC && isa<ConstantInt>(InC))
- InV = InC->isNullValue() ? FalseVInPred : TrueVInPred;
- else {
- // Generate the select in the same block as PN's current incoming block.
- // Note: ThisBB need not be the NonConstBB because vector constants
- // which are constants by definition are handled here.
- // FIXME: This can lead to an increase in IR generation because we might
- // generate selects for vector constant phi operand, that could not be
- // folded to TrueVInPred or FalseVInPred as done for ConstantInt. For
- // non-vector phis, this transformation was always profitable because
- // the select would be generated exactly once in the NonConstBB.
- Builder.SetInsertPoint(ThisBB->getTerminator());
- InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred,
- FalseVInPred, "phi.sel");
- }
- NewPN->addIncoming(InV, ThisBB);
- }
- } else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) {
- Constant *C = cast<Constant>(I.getOperand(1));
- for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV = nullptr;
- if (auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i)))
- InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C);
- else
- InV = Builder.CreateCmp(CI->getPredicate(), PN->getIncomingValue(i),
- C, "phi.cmp");
- NewPN->addIncoming(InV, PN->getIncomingBlock(i));
- }
- } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
- for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i),
- Builder);
- NewPN->addIncoming(InV, PN->getIncomingBlock(i));
- }
- } else if (isa<FreezeInst>(&I)) {
- for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV;
- if (NonConstBB == PN->getIncomingBlock(i))
- InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr");
- else
- InV = PN->getIncomingValue(i);
- NewPN->addIncoming(InV, PN->getIncomingBlock(i));
- }
- } else {
- CastInst *CI = cast<CastInst>(&I);
- Type *RetTy = CI->getType();
- for (unsigned i = 0; i != NumPHIValues; ++i) {
- Value *InV;
- if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)))
- InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy);
+ Instruction *Clone = nullptr;
+ if (NonSimplifiedBB) {
+ Clone = I.clone();
+ for (Use &U : Clone->operands()) {
+ if (U == PN)
+ U = NonSimplifiedInVal;
else
- InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i),
- I.getType(), "phi.cast");
- NewPN->addIncoming(InV, PN->getIncomingBlock(i));
+ U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB);
}
+ InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator());
+ }
+
+ for (unsigned i = 0; i != NumPHIValues; ++i) {
+ if (NewPhiValues[i])
+ NewPN->addIncoming(NewPhiValues[i], PN->getIncomingBlock(i));
+ else
+ NewPN->addIncoming(Clone, PN->getIncomingBlock(i));
}
for (User *U : make_early_inc_range(PN->users())) {
@@ -1696,6 +1683,35 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
return new ShuffleVectorInst(NewBO0, NewBO1, Mask);
}
+ auto createBinOpReverse = [&](Value *X, Value *Y) {
+ Value *V = Builder.CreateBinOp(Opcode, X, Y, Inst.getName());
+ if (auto *BO = dyn_cast<BinaryOperator>(V))
+ BO->copyIRFlags(&Inst);
+ Module *M = Inst.getModule();
+ Function *F = Intrinsic::getDeclaration(
+ M, Intrinsic::experimental_vector_reverse, V->getType());
+ return CallInst::Create(F, V);
+ };
+
+ // NOTE: Reverse shuffles don't require the speculative execution protection
+ // below because they don't affect which lanes take part in the computation.
+
+ Value *V1, *V2;
+ if (match(LHS, m_VecReverse(m_Value(V1)))) {
+ // Op(rev(V1), rev(V2)) -> rev(Op(V1, V2))
+ if (match(RHS, m_VecReverse(m_Value(V2))) &&
+ (LHS->hasOneUse() || RHS->hasOneUse() ||
+ (LHS == RHS && LHS->hasNUses(2))))
+ return createBinOpReverse(V1, V2);
+
+ // Op(rev(V1), RHSSplat) -> rev(Op(V1, RHSSplat))
+ if (LHS->hasOneUse() && isSplatValue(RHS))
+ return createBinOpReverse(V1, RHS);
+ }
+ // Op(LHSSplat, rev(V2)) -> rev(Op(LHSSplat, V2))
+ else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2)))))
+ return createBinOpReverse(LHS, V2);
+
// It may not be safe to reorder shuffles and things like div, urem, etc.
// because we may trap when executing those ops on unknown vector elements.
// See PR20059.
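
A hypothetical example of the new reverse-shuffle fold (names invented): a binop of two vector reverses becomes the reverse of the binop on the original vectors:

    declare <4 x i32> @llvm.experimental.vector.reverse.v4i32(<4 x i32>)

    define <4 x i32> @binop_of_reverses(<4 x i32> %a, <4 x i32> %b) {
      %ra = call <4 x i32> @llvm.experimental.vector.reverse.v4i32(<4 x i32> %a)
      %rb = call <4 x i32> @llvm.experimental.vector.reverse.v4i32(<4 x i32> %b)
      %r = add <4 x i32> %ra, %rb
      ret <4 x i32> %r
      ; expected, approximately:
      ;   %add = add <4 x i32> %a, %b
      ;   %r   = call <4 x i32> @llvm.experimental.vector.reverse.v4i32(<4 x i32> %add)
    }
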
@@ -1711,7 +1727,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
// If both arguments of the binary operation are shuffles that use the same
// mask and shuffle within a single vector, move the shuffle after the binop.
- Value *V1, *V2;
if (match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))) &&
match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) &&
V1->getType() == V2->getType() &&
@@ -2228,7 +2243,7 @@ Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI,
if (Instruction *I = visitBitCast(*BCI)) {
if (I != BCI) {
I->takeName(BCI);
- BCI->getParent()->getInstList().insert(BCI->getIterator(), I);
+ I->insertInto(BCI->getParent(), BCI->getIterator());
replaceInstUsesWith(*BCI, I);
}
return &GEP;
@@ -2434,10 +2449,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
NewGEP->setOperand(DI, NewPN);
}
- GEP.getParent()->getInstList().insert(
- GEP.getParent()->getFirstInsertionPt(), NewGEP);
- replaceOperand(GEP, 0, NewGEP);
- PtrOp = NewGEP;
+ NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt());
+ return replaceOperand(GEP, 0, NewGEP);
}
if (auto *Src = dyn_cast<GEPOperator>(PtrOp))
@@ -2450,7 +2463,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
unsigned AS = GEP.getPointerAddressSpace();
if (GEP.getOperand(1)->getType()->getScalarSizeInBits() ==
DL.getIndexSizeInBits(AS)) {
- uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedSize();
+ uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedValue();
bool Matched = false;
uint64_t C;
@@ -2580,8 +2593,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) {
// Check that changing the type amounts to dividing the index by a scale
// factor.
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize();
- uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy).getFixedSize();
+ uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue();
+ uint64_t SrcSize =
+ DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue();
if (ResSize && SrcSize % ResSize == 0) {
Value *Idx = GEP.getOperand(1);
unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
@@ -2617,10 +2631,10 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
StrippedPtrEltTy->isArrayTy()) {
// Check that changing to the array element type amounts to dividing the
// index by a scale factor.
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize();
+ uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue();
uint64_t ArrayEltSize =
DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType())
- .getFixedSize();
+ .getFixedValue();
if (ResSize && ArrayEltSize % ResSize == 0) {
Value *Idx = GEP.getOperand(1);
unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits();
@@ -2681,7 +2695,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) {
BasePtrOffset.isNonNegative()) {
APInt AllocSize(
IdxWidth,
- DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinSize());
+ DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue());
if (BasePtrOffset.ule(AllocSize)) {
return GetElementPtrInst::CreateInBounds(
GEP.getSourceElementType(), PtrOp, Indices, GEP.getName());
@@ -2724,7 +2738,7 @@ static bool isRemovableWrite(CallBase &CB, Value *UsedV,
// If the only possible side effect of the call is writing to the alloca,
// and the result isn't used, we can safely remove any reads implied by the
// call including those which might read the alloca itself.
- Optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI);
+ std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI);
return Dest && Dest->Ptr == UsedV;
}
@@ -2732,7 +2746,7 @@ static bool isAllocSiteRemovable(Instruction *AI,
SmallVectorImpl<WeakTrackingVH> &Users,
const TargetLibraryInfo &TLI) {
SmallVector<Instruction*, 4> Worklist;
- const Optional<StringRef> Family = getAllocationFamily(AI, &TLI);
+ const std::optional<StringRef> Family = getAllocationFamily(AI, &TLI);
Worklist.push_back(AI);
do {
@@ -2778,7 +2792,7 @@ static bool isAllocSiteRemovable(Instruction *AI,
MemIntrinsic *MI = cast<MemIntrinsic>(II);
if (MI->isVolatile() || MI->getRawDest() != PI)
return false;
- LLVM_FALLTHROUGH;
+ [[fallthrough]];
}
case Intrinsic::assume:
case Intrinsic::invariant_start:
@@ -2808,7 +2822,7 @@ static bool isAllocSiteRemovable(Instruction *AI,
continue;
}
- if (getReallocatedOperand(cast<CallBase>(I), &TLI) == PI &&
+ if (getReallocatedOperand(cast<CallBase>(I)) == PI &&
getAllocationFamily(I, &TLI) == Family) {
assert(Family);
Users.emplace_back(I);
@@ -2902,7 +2916,7 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) {
Module *M = II->getModule();
Function *F = Intrinsic::getDeclaration(M, Intrinsic::donothing);
InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(),
- None, "", II->getParent());
+ std::nullopt, "", II->getParent());
}
// Remove debug intrinsics which describe the value contained within the
@@ -3052,7 +3066,7 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) {
// realloc() entirely.
CallInst *CI = dyn_cast<CallInst>(Op);
if (CI && CI->hasOneUse())
- if (Value *ReallocatedOp = getReallocatedOperand(CI, &TLI))
+ if (Value *ReallocatedOp = getReallocatedOperand(CI))
return eraseInstFromFunction(*replaceInstUsesWith(*CI, ReallocatedOp));
// If we optimize for code size, try to move the call to free before the null
@@ -3166,31 +3180,41 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
return visitUnconditionalBranchInst(BI);
// Change br (not X), label True, label False to: br X, label False, True
- Value *X = nullptr;
- if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) &&
- !isa<Constant>(X)) {
+ Value *Cond = BI.getCondition();
+ Value *X;
+ if (match(Cond, m_Not(m_Value(X))) && !isa<Constant>(X)) {
// Swap Destinations and condition...
BI.swapSuccessors();
return replaceOperand(BI, 0, X);
}
+ // Canonicalize logical-and-with-invert as logical-or-with-invert.
+ // This is done by inverting the condition and swapping successors:
+ // br (X && !Y), T, F --> br !(X && !Y), F, T --> br (!X || Y), F, T
+ Value *Y;
+ if (isa<SelectInst>(Cond) &&
+ match(Cond,
+ m_OneUse(m_LogicalAnd(m_Value(X), m_OneUse(m_Not(m_Value(Y))))))) {
+ Value *NotX = Builder.CreateNot(X, "not." + X->getName());
+ Value *Or = Builder.CreateLogicalOr(NotX, Y);
+ BI.swapSuccessors();
+ return replaceOperand(BI, 0, Or);
+ }
+
// If the condition is irrelevant, remove the use so that other
// transforms on the condition become more effective.
- if (!isa<ConstantInt>(BI.getCondition()) &&
- BI.getSuccessor(0) == BI.getSuccessor(1))
- return replaceOperand(
- BI, 0, ConstantInt::getFalse(BI.getCondition()->getType()));
+ if (!isa<ConstantInt>(Cond) && BI.getSuccessor(0) == BI.getSuccessor(1))
+ return replaceOperand(BI, 0, ConstantInt::getFalse(Cond->getType()));
// Canonicalize, for example, fcmp_one -> fcmp_oeq.
CmpInst::Predicate Pred;
- if (match(&BI, m_Br(m_OneUse(m_FCmp(Pred, m_Value(), m_Value())),
- m_BasicBlock(), m_BasicBlock())) &&
+ if (match(Cond, m_OneUse(m_FCmp(Pred, m_Value(), m_Value()))) &&
!isCanonicalPredicate(Pred)) {
// Swap destinations and condition.
- CmpInst *Cond = cast<CmpInst>(BI.getCondition());
- Cond->setPredicate(CmpInst::getInversePredicate(Pred));
+ auto *Cmp = cast<CmpInst>(Cond);
+ Cmp->setPredicate(CmpInst::getInversePredicate(Pred));
BI.swapSuccessors();
- Worklist.push(Cond);
+ Worklist.push(Cmp);
return &BI;
}
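
A sketch of the new branch canonicalization (hypothetical names): a branch on a logical-and-with-invert is rewritten as a branch on a logical-or-with-invert with the successors swapped:

    define void @br_logical_and_not(i1 %x, i1 %y) {
    entry:
      %noty = xor i1 %y, true
      %cond = select i1 %x, i1 %noty, i1 false   ; x && !y
      br i1 %cond, label %t, label %f
    t:
      ret void
    f:
      ret void
      ; expected, approximately:
      ;   %notx = xor i1 %x, true
      ;   %or   = select i1 %notx, i1 true, i1 %y  ; !x || y
      ;   br i1 %or, label %f, label %t
    }
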
@@ -3218,7 +3242,7 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
// Compute the number of leading bits we can ignore.
// TODO: A better way to determine this would use ComputeNumSignBits().
- for (auto &C : SI.cases()) {
+ for (const auto &C : SI.cases()) {
LeadingKnownZeros = std::min(
LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros());
LeadingKnownOnes = std::min(
@@ -3247,6 +3271,81 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
return nullptr;
}
+Instruction *
+InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) {
+ auto *WO = dyn_cast<WithOverflowInst>(EV.getAggregateOperand());
+ if (!WO)
+ return nullptr;
+
+ Intrinsic::ID OvID = WO->getIntrinsicID();
+ const APInt *C = nullptr;
+ if (match(WO->getRHS(), m_APIntAllowUndef(C))) {
+ if (*EV.idx_begin() == 0 && (OvID == Intrinsic::smul_with_overflow ||
+ OvID == Intrinsic::umul_with_overflow)) {
+ // extractvalue (any_mul_with_overflow X, -1), 0 --> -X
+ if (C->isAllOnes())
+ return BinaryOperator::CreateNeg(WO->getLHS());
+ // extractvalue (any_mul_with_overflow X, 2^n), 0 --> X << n
+ if (C->isPowerOf2()) {
+ return BinaryOperator::CreateShl(
+ WO->getLHS(),
+ ConstantInt::get(WO->getLHS()->getType(), C->logBase2()));
+ }
+ }
+ }
+
+ // We're extracting from an overflow intrinsic. See if we're the only user.
+ // That allows us to simplify multiple result intrinsics to simpler things
+ // that just get one value.
+ if (!WO->hasOneUse())
+ return nullptr;
+
+ // Check if we're grabbing only the result of a 'with overflow' intrinsic
+ // and replace it with a traditional binary instruction.
+ if (*EV.idx_begin() == 0) {
+ Instruction::BinaryOps BinOp = WO->getBinaryOp();
+ Value *LHS = WO->getLHS(), *RHS = WO->getRHS();
+ // Replace the old instruction's uses with poison.
+ replaceInstUsesWith(*WO, PoisonValue::get(WO->getType()));
+ eraseInstFromFunction(*WO);
+ return BinaryOperator::Create(BinOp, LHS, RHS);
+ }
+
+ assert(*EV.idx_begin() == 1 && "Unexpected extract index for overflow inst");
+
+ // (usub LHS, RHS) overflows when LHS is unsigned-less-than RHS.
+ if (OvID == Intrinsic::usub_with_overflow)
+ return new ICmpInst(ICmpInst::ICMP_ULT, WO->getLHS(), WO->getRHS());
+
+ // smul with i1 types overflows when both sides are set: -1 * -1 == +1, but
+ // +1 is not possible because we assume signed values.
+ if (OvID == Intrinsic::smul_with_overflow &&
+ WO->getLHS()->getType()->isIntOrIntVectorTy(1))
+ return BinaryOperator::CreateAnd(WO->getLHS(), WO->getRHS());
+
+ // If only the overflow result is used, and the right hand side is a
+ // constant (or constant splat), we can remove the intrinsic by directly
+ // checking for overflow.
+ if (C) {
+ // Compute the no-wrap range for LHS given RHS=C, then construct an
+ // equivalent icmp, potentially using an offset.
+ ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(
+ WO->getBinaryOp(), *C, WO->getNoWrapKind());
+
+ CmpInst::Predicate Pred;
+ APInt NewRHSC, Offset;
+ NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
+ auto *OpTy = WO->getRHS()->getType();
+ auto *NewLHS = WO->getLHS();
+ if (Offset != 0)
+ NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset));
+ return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS,
+ ConstantInt::get(OpTy, NewRHSC));
+ }
+
+ return nullptr;
+}
+
Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
Value *Agg = EV.getAggregateOperand();
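
To illustrate one of the folds now grouped in foldExtractOfOverflowIntrinsic (hypothetical names): extracting only the overflow bit of usub.with.overflow becomes an unsigned compare:

    declare { i8, i1 } @llvm.usub.with.overflow.i8(i8, i8)

    define i1 @usub_overflow_bit(i8 %x, i8 %y) {
      %ws = call { i8, i1 } @llvm.usub.with.overflow.i8(i8 %x, i8 %y)
      %ov = extractvalue { i8, i1 } %ws, 1
      ret i1 %ov
      ; expected: %ov = icmp ult i8 %x, %y
    }
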
@@ -3294,7 +3393,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
Value *NewEV = Builder.CreateExtractValue(IV->getAggregateOperand(),
EV.getIndices());
return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(),
- makeArrayRef(insi, inse));
+ ArrayRef(insi, inse));
}
if (insi == inse)
// The insert list is a prefix of the extract list
@@ -3306,60 +3405,13 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
// with
// %E extractvalue { i32 } { i32 42 }, 0
return ExtractValueInst::Create(IV->getInsertedValueOperand(),
- makeArrayRef(exti, exte));
+ ArrayRef(exti, exte));
}
- if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) {
- // extractvalue (any_mul_with_overflow X, -1), 0 --> -X
- Intrinsic::ID OvID = WO->getIntrinsicID();
- if (*EV.idx_begin() == 0 &&
- (OvID == Intrinsic::smul_with_overflow ||
- OvID == Intrinsic::umul_with_overflow) &&
- match(WO->getArgOperand(1), m_AllOnes())) {
- return BinaryOperator::CreateNeg(WO->getArgOperand(0));
- }
-
- // We're extracting from an overflow intrinsic, see if we're the only user,
- // which allows us to simplify multiple result intrinsics to simpler
- // things that just get one value.
- if (WO->hasOneUse()) {
- // Check if we're grabbing only the result of a 'with overflow' intrinsic
- // and replace it with a traditional binary instruction.
- if (*EV.idx_begin() == 0) {
- Instruction::BinaryOps BinOp = WO->getBinaryOp();
- Value *LHS = WO->getLHS(), *RHS = WO->getRHS();
- // Replace the old instruction's uses with poison.
- replaceInstUsesWith(*WO, PoisonValue::get(WO->getType()));
- eraseInstFromFunction(*WO);
- return BinaryOperator::Create(BinOp, LHS, RHS);
- }
-
- assert(*EV.idx_begin() == 1 &&
- "unexpected extract index for overflow inst");
- // If only the overflow result is used, and the right hand side is a
- // constant (or constant splat), we can remove the intrinsic by directly
- // checking for overflow.
- const APInt *C;
- if (match(WO->getRHS(), m_APInt(C))) {
- // Compute the no-wrap range for LHS given RHS=C, then construct an
- // equivalent icmp, potentially using an offset.
- ConstantRange NWR =
- ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C,
- WO->getNoWrapKind());
+ if (Instruction *R = foldExtractOfOverflowIntrinsic(EV))
+ return R;
- CmpInst::Predicate Pred;
- APInt NewRHSC, Offset;
- NWR.getEquivalentICmp(Pred, NewRHSC, Offset);
- auto *OpTy = WO->getRHS()->getType();
- auto *NewLHS = WO->getLHS();
- if (Offset != 0)
- NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset));
- return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS,
- ConstantInt::get(OpTy, NewRHSC));
- }
- }
- }
- if (LoadInst *L = dyn_cast<LoadInst>(Agg))
+ if (LoadInst *L = dyn_cast<LoadInst>(Agg)) {
// If the (non-volatile) load only has one use, we can rewrite this to a
// load from a GEP. This reduces the size of the load. If a load is used
// only by extractvalue instructions then this either must have been
@@ -3386,6 +3438,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) {
// the wrong spot, so use replaceInstUsesWith().
return replaceInstUsesWith(EV, NL);
}
+ }
+
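+  // If the aggregate comes from a phi, try to perform the extract on the
+  // incoming values instead; foldOpIntoPhi handles the profitability checks.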
+ if (auto *PN = dyn_cast<PHINode>(Agg))
+ if (Instruction *Res = foldOpIntoPhi(EV, PN))
+ return Res;
+
// We could simplify extracts from other values. Note that nested extracts may
// already be simplified implicitly by the above: extract (extract (insert) )
// will be translated into extract ( insert ( extract ) ) first and then just
@@ -3771,7 +3829,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) {
// poison. If the only source of new poison is flags, we can simply
// strip them (since we know the only use is the freeze and nothing can
// benefit from them.)
- if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false))
+ if (canCreateUndefOrPoison(cast<Operator>(OrigOp),
+ /*ConsiderFlagsAndMetadata*/ false))
return nullptr;
// If operand is guaranteed not to be poison, there is no need to add freeze
@@ -3779,7 +3838,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) {
// poison.
Use *MaybePoisonOperand = nullptr;
for (Use &U : OrigOpInst->operands()) {
- if (isGuaranteedNotToBeUndefOrPoison(U.get()))
+ if (isa<MetadataAsValue>(U.get()) ||
+ isGuaranteedNotToBeUndefOrPoison(U.get()))
continue;
if (!MaybePoisonOperand)
MaybePoisonOperand = &U;
@@ -3787,7 +3847,7 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) {
return nullptr;
}
- OrigOpInst->dropPoisonGeneratingFlags();
+ OrigOpInst->dropPoisonGeneratingFlagsAndMetadata();
// If all operands are guaranteed to be non-poison, we can drop freeze.
if (!MaybePoisonOperand)
@@ -3850,7 +3910,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI,
Instruction *I = dyn_cast<Instruction>(V);
if (!I || canCreateUndefOrPoison(cast<Operator>(I),
- /*ConsiderFlags*/ false))
+ /*ConsiderFlagsAndMetadata*/ false))
return nullptr;
DropFlags.push_back(I);
@@ -3858,7 +3918,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI,
}
for (Instruction *I : DropFlags)
- I->dropPoisonGeneratingFlags();
+ I->dropPoisonGeneratingFlagsAndMetadata();
if (StartNeedsFreeze) {
Builder.SetInsertPoint(StartBB->getTerminator());
@@ -3880,21 +3940,14 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) {
// *all* uses if the operand is an invoke/callbr and the use is in a phi on
// the normal/default destination. This is why the domination check in the
// replacement below is still necessary.
- Instruction *MoveBefore = nullptr;
+ Instruction *MoveBefore;
if (isa<Argument>(Op)) {
- MoveBefore = &FI.getFunction()->getEntryBlock().front();
- while (isa<AllocaInst>(MoveBefore))
- MoveBefore = MoveBefore->getNextNode();
- } else if (auto *PN = dyn_cast<PHINode>(Op)) {
- MoveBefore = PN->getParent()->getFirstNonPHI();
- } else if (auto *II = dyn_cast<InvokeInst>(Op)) {
- MoveBefore = II->getNormalDest()->getFirstNonPHI();
- } else if (auto *CB = dyn_cast<CallBrInst>(Op)) {
- MoveBefore = CB->getDefaultDest()->getFirstNonPHI();
+ MoveBefore =
+ &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca();
} else {
- auto *I = cast<Instruction>(Op);
- assert(!I->isTerminator() && "Cannot be a terminator");
- MoveBefore = I->getNextNode();
+ MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef();
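+    // If there is no legal insertion point after the definition, give up.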
+ if (!MoveBefore)
+ return false;
}
bool Changed = false;
@@ -3987,7 +4040,7 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) {
// to allow reload along used path as described below. Otherwise, this
// is simply a store to a dead allocation which will be removed.
return false;
- Optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI);
+ std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI);
if (!Dest)
return false;
auto *AI = dyn_cast<AllocaInst>(getUnderlyingObject(Dest->Ptr));
@@ -4103,7 +4156,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock,
SmallVector<DbgVariableIntrinsic *, 2> DIIClones;
SmallSet<DebugVariable, 4> SunkVariables;
- for (auto User : DbgUsersToSink) {
+ for (auto *User : DbgUsersToSink) {
// A dbg.declare instruction should not be cloned, since there can only be
// one per variable fragment. It should be left in the original place
// because the sunk instruction is not an alloca (otherwise we could not be
@@ -4118,6 +4171,11 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock,
if (!SunkVariables.insert(DbgUserVariable).second)
continue;
+    // Leave dbg.assign intrinsics in their original positions; there is no
+    // need to insert a clone.
+ if (isa<DbgAssignIntrinsic>(User))
+ continue;
+
DIIClones.emplace_back(cast<DbgVariableIntrinsic>(User->clone()));
if (isa<DbgDeclareInst>(User) && isa<CastInst>(I))
DIIClones.back()->replaceVariableLocationOp(I, I->getOperand(0));
@@ -4190,9 +4248,9 @@ bool InstCombinerImpl::run() {
// prove that the successor is not executed more frequently than our block.
// Return the UserBlock if successful.
auto getOptionalSinkBlockForInst =
- [this](Instruction *I) -> Optional<BasicBlock *> {
+ [this](Instruction *I) -> std::optional<BasicBlock *> {
if (!EnableCodeSinking)
- return None;
+ return std::nullopt;
BasicBlock *BB = I->getParent();
BasicBlock *UserParent = nullptr;
@@ -4202,7 +4260,7 @@ bool InstCombinerImpl::run() {
if (U->isDroppable())
continue;
if (NumUsers > MaxSinkNumUsers)
- return None;
+ return std::nullopt;
Instruction *UserInst = cast<Instruction>(U);
// Special handling for Phi nodes - get the block the use occurs in.
@@ -4213,14 +4271,14 @@ bool InstCombinerImpl::run() {
// sophisticated analysis (i.e finding NearestCommonDominator of
// these use blocks).
if (UserParent && UserParent != PN->getIncomingBlock(i))
- return None;
+ return std::nullopt;
UserParent = PN->getIncomingBlock(i);
}
}
assert(UserParent && "expected to find user block!");
} else {
if (UserParent && UserParent != UserInst->getParent())
- return None;
+ return std::nullopt;
UserParent = UserInst->getParent();
}
@@ -4230,7 +4288,7 @@ bool InstCombinerImpl::run() {
// Try sinking to another block. If that block is unreachable, then do
// not bother. SimplifyCFG should handle it.
if (UserParent == BB || !DT.isReachableFromEntry(UserParent))
- return None;
+ return std::nullopt;
auto *Term = UserParent->getTerminator();
// See if the user is one of our successors that has only one
@@ -4242,7 +4300,7 @@ bool InstCombinerImpl::run() {
// - the User will be executed at most once.
// So sinking I down to User is always profitable or neutral.
if (UserParent->getUniquePredecessor() != BB && !succ_empty(Term))
- return None;
+ return std::nullopt;
assert(DT.dominates(BB, UserParent) && "Dominance relation broken?");
}
@@ -4252,7 +4310,7 @@ bool InstCombinerImpl::run() {
// No user or only has droppable users.
if (!UserParent)
- return None;
+ return std::nullopt;
return UserParent;
};
@@ -4312,7 +4370,7 @@ bool InstCombinerImpl::run() {
InsertPos = InstParent->getFirstNonPHI()->getIterator();
}
- InstParent->getInstList().insert(InsertPos, Result);
+ Result->insertInto(InstParent, InsertPos);
// Push the new instruction and any users onto the worklist.
Worklist.pushUsersToWorkList(*Result);
@@ -4360,7 +4418,7 @@ public:
const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList);
if (!MDScopeList || !Container.insert(MDScopeList).second)
return;
- for (auto &MDOperand : MDScopeList->operands())
+ for (const auto &MDOperand : MDScopeList->operands())
if (auto *MDScope = dyn_cast<MDNode>(MDOperand))
Container.insert(MDScope);
};
@@ -4543,6 +4601,13 @@ static bool combineInstructionsOverFunction(
bool MadeIRChange = false;
if (ShouldLowerDbgDeclare)
MadeIRChange = LowerDbgDeclare(F);
+  // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will
+  // almost never return true when running an assignment tracking build. Take
+  // this opportunity to do some cleanup for assignment tracking builds too.
+ if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) {
+ for (auto &BB : F)
+ RemoveRedundantDbgInstrs(&BB);
+ }
// Iterate while there is work to do.
unsigned Iteration = 0;