|  |  |  |
|---|---|---|
| author | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-12-09 13:28:42 +0000 |
| commit | b1c73532ee8997fe5dfbeb7d223027bdf99758a0 (patch) | |
| tree | 7d6e51c294ab6719475d660217aa0c0ad0526292 /llvm/lib/Transforms/InstCombine | |
| parent | 7fa27ce4a07f19b07799a767fc29416f3b625afb (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
15 files changed, 2895 insertions, 1786 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index 91ca44e0f11e..719a2678fc18 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -830,15 +830,15 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add,
   // (sext (X +nsw NarrowC)) + C --> (sext X) + (sext(NarrowC) + C)
   Constant *NarrowC;
   if (match(Op0, m_OneUse(m_SExt(m_NSWAdd(m_Value(X), m_Constant(NarrowC)))))) {
-    Constant *WideC = ConstantExpr::getSExt(NarrowC, Ty);
-    Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+    Value *WideC = Builder.CreateSExt(NarrowC, Ty);
+    Value *NewC = Builder.CreateAdd(WideC, Op1C);
     Value *WideX = Builder.CreateSExt(X, Ty);
     return BinaryOperator::CreateAdd(WideX, NewC);
   }
   // (zext (X +nuw NarrowC)) + C --> (zext X) + (zext(NarrowC) + C)
   if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_Constant(NarrowC)))))) {
-    Constant *WideC = ConstantExpr::getZExt(NarrowC, Ty);
-    Constant *NewC = ConstantExpr::getAdd(WideC, Op1C);
+    Value *WideC = Builder.CreateZExt(NarrowC, Ty);
+    Value *NewC = Builder.CreateAdd(WideC, Op1C);
     Value *WideX = Builder.CreateZExt(X, Ty);
     return BinaryOperator::CreateAdd(WideX, NewC);
   }
@@ -903,8 +903,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {

   // (X | Op01C) + Op1C --> X + (Op01C + Op1C) iff the `or` is actually an `add`
   Constant *Op01C;
-  if (match(Op0, m_Or(m_Value(X), m_ImmConstant(Op01C))) &&
-      haveNoCommonBitsSet(X, Op01C, DL, &AC, &Add, &DT))
+  if (match(Op0, m_DisjointOr(m_Value(X), m_ImmConstant(Op01C))))
     return BinaryOperator::CreateAdd(X, ConstantExpr::getAdd(Op01C, Op1C));

   // (X | C2) + C --> (X | C2) ^ C2 iff (C2 == -C)
@@ -995,6 +994,69 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) {
   return nullptr;
 }

+// match variations of a^2 + 2*a*b + b^2
+//
+// to reuse the code between the FP and Int versions, the instruction OpCodes
+// and constant types have been turned into template parameters.
+//
+// Mul2Rhs: The constant to perform the multiplicative equivalent of X*2 with;
+// should be `m_SpecificFP(2.0)` for FP and `m_SpecificInt(1)` for Int
+// (we're matching `X<<1` instead of `X*2` for Int)
+template <bool FP, typename Mul2Rhs>
+static bool matchesSquareSum(BinaryOperator &I, Mul2Rhs M2Rhs, Value *&A,
+                             Value *&B) {
+  constexpr unsigned MulOp = FP ? Instruction::FMul : Instruction::Mul;
+  constexpr unsigned AddOp = FP ? Instruction::FAdd : Instruction::Add;
+  constexpr unsigned Mul2Op = FP ? Instruction::FMul : Instruction::Shl;
+
+  // (a * a) + (((a * 2) + b) * b)
+  if (match(&I, m_c_BinOp(
+                    AddOp, m_OneUse(m_BinOp(MulOp, m_Value(A), m_Deferred(A))),
+                    m_OneUse(m_BinOp(
+                        MulOp,
+                        m_c_BinOp(AddOp, m_BinOp(Mul2Op, m_Deferred(A), M2Rhs),
+                                  m_Value(B)),
+                        m_Deferred(B))))))
+    return true;
+
+  // ((a * b) * 2) or ((a * 2) * b)
+  // +
+  // (a * a + b * b) or (b * b + a * a)
+  return match(
+      &I,
+      m_c_BinOp(AddOp,
+                m_CombineOr(
+                    m_OneUse(m_BinOp(
+                        Mul2Op, m_BinOp(MulOp, m_Value(A), m_Value(B)), M2Rhs)),
+                    m_OneUse(m_BinOp(MulOp, m_BinOp(Mul2Op, m_Value(A), M2Rhs),
+                                     m_Value(B)))),
+                m_OneUse(m_c_BinOp(
+                    AddOp, m_BinOp(MulOp, m_Deferred(A), m_Deferred(A)),
+                    m_BinOp(MulOp, m_Deferred(B), m_Deferred(B))))));
+}
+
+// Fold integer variations of a^2 + 2*a*b + b^2 -> (a + b)^2
+Instruction *InstCombinerImpl::foldSquareSumInt(BinaryOperator &I) {
+  Value *A, *B;
+  if (matchesSquareSum</*FP*/ false>(I, m_SpecificInt(1), A, B)) {
+    Value *AB = Builder.CreateAdd(A, B);
+    return BinaryOperator::CreateMul(AB, AB);
+  }
+  return nullptr;
+}
+
+// Fold floating point variations of a^2 + 2*a*b + b^2 -> (a + b)^2
+// Requires `nsz` and `reassoc`.
+Instruction *InstCombinerImpl::foldSquareSumFP(BinaryOperator &I) {
+  assert(I.hasAllowReassoc() && I.hasNoSignedZeros() && "Assumption mismatch");
+  Value *A, *B;
+  if (matchesSquareSum</*FP*/ true>(I, m_SpecificFP(2.0), A, B)) {
+    Value *AB = Builder.CreateFAddFMF(A, B, &I);
+    return BinaryOperator::CreateFMulFMF(AB, AB, &I);
+  }
+  return nullptr;
+}
+
 // Matches multiplication expression Op * C where C is a constant. Returns the
 // constant value in C and the other operand in Op. Returns true if such a
 // match is found.
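The integer version of this new fold needs no overflow preconditions because a^2 + 2ab + b^2 == (a+b)^2 is a ring identity, so it holds in wrapping two's-complement arithmetic as well (the FP version, by contrast, is guarded by `reassoc` and `nsz`). A standalone sanity check of that assumption — my sketch, not part of the patch:

```cpp
#include <cassert>
#include <cstdint>

// Exhaustively verify a*a + 2*a*b + b*b == (a+b)*(a+b) in 8-bit
// modular arithmetic, i.e. the identity foldSquareSumInt relies on
// even when the intermediate products wrap.
int main() {
  for (unsigned a = 0; a < 256; ++a)
    for (unsigned b = 0; b < 256; ++b) {
      uint8_t lhs = uint8_t(a * a) + uint8_t(2 * a * b) + uint8_t(b * b);
      uint8_t rhs = uint8_t((a + b) * (a + b));
      assert(lhs == rhs);
    }
}
```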
@@ -1146,6 +1208,21 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
   return nullptr;
 }

+// Transform:
+//  (add A, (shl (neg B), Y))
+//      -> (sub A, (shl B, Y))
+static Instruction *combineAddSubWithShlAddSub(InstCombiner::BuilderTy &Builder,
+                                               const BinaryOperator &I) {
+  Value *A, *B, *Cnt;
+  if (match(&I,
+            m_c_Add(m_OneUse(m_Shl(m_OneUse(m_Neg(m_Value(B))), m_Value(Cnt))),
+                    m_Value(A)))) {
+    Value *NewShl = Builder.CreateShl(B, Cnt);
+    return BinaryOperator::CreateSub(A, NewShl);
+  }
+  return nullptr;
+}
+
 /// Try to reduce signed division by power-of-2 to an arithmetic shift right.
 static Instruction *foldAddToAshr(BinaryOperator &Add) {
   // Division must be by power-of-2, but not the minimum signed value.
@@ -1156,18 +1233,28 @@ static Instruction *foldAddToAshr(BinaryOperator &Add) {
     return nullptr;

   // Rounding is done by adding -1 if the dividend (X) is negative and has any
-  // low bits set. The canonical pattern for that is an "ugt" compare with SMIN:
-  // sext (icmp ugt (X & (DivC - 1)), SMIN)
-  const APInt *MaskC;
+  // low bits set. It recognizes two canonical patterns:
+  // 1. For an 'ugt' cmp with the signed minimum value (SMIN), the
+  //    pattern is: sext (icmp ugt (X & (DivC - 1)), SMIN).
+  // 2. For an 'eq' cmp, the pattern is:
+  //    sext (icmp eq (X & (SMIN + 1)), SMIN + 1).
+  // Note that, by the time we end up here, if possible, ugt has been
+  // canonicalized into eq.
+  const APInt *MaskC, *MaskCCmp;
   ICmpInst::Predicate Pred;
   if (!match(Add.getOperand(1),
              m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)),
-                           m_SignMask()))) ||
-      Pred != ICmpInst::ICMP_UGT)
+                           m_APInt(MaskCCmp)))))
+    return nullptr;
+
+  if ((Pred != ICmpInst::ICMP_UGT || !MaskCCmp->isSignMask()) &&
+      (Pred != ICmpInst::ICMP_EQ || *MaskCCmp != *MaskC))
     return nullptr;

   APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits());
-  if (*MaskC != (SMin | (*DivC - 1)))
+  bool IsMaskValid = Pred == ICmpInst::ICMP_UGT
+                         ? (*MaskC == (SMin | (*DivC - 1)))
+                         : (*DivC == 2 && *MaskC == SMin + 1);
+  if (!IsMaskValid)
     return nullptr;

   // (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC)
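The sext(icmp) term matched here is the -1 correction that reconciles signed division (rounds toward zero) with an arithmetic shift (rounds toward negative infinity). A self-contained illustration of the underlying arithmetic for DivC = 4 — my example, not patch code (C++20 guarantees `>>` on a negative signed value is an arithmetic shift):

```cpp
#include <cassert>
#include <cstdint>

// x/4 + (-1 if x is negative with low bits set) == x >> 2,
// the identity that lets foldAddToAshr replace the divide.
int main() {
  for (int32_t x = -1024; x <= 1024; ++x) {
    int32_t correction = (x < 0 && (x & 3) != 0) ? -1 : 0;
    assert(x / 4 + correction == (x >> 2));
  }
}
```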
@@ -1327,8 +1414,10 @@ static Instruction *foldBoxMultiply(BinaryOperator &I) {
   // ResLo = (CrossSum << HalfBits) + (YLo * XLo)
   Value *XLo, *YLo;
   Value *CrossSum;
+  // Require one-use on the multiply to avoid increasing the number of
+  // multiplications.
   if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)),
-                         m_Mul(m_Value(YLo), m_Value(XLo)))))
+                         m_OneUse(m_Mul(m_Value(YLo), m_Value(XLo))))))
     return nullptr;

   // XLo = X & HalfMask
@@ -1386,6 +1475,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   if (Instruction *R = foldBinOpShiftWithShift(I))
     return R;

+  if (Instruction *R = combineAddSubWithShlAddSub(Builder, I))
+    return R;
+
   Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1))
@@ -1406,7 +1498,11 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
       return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B));

     // -A + B --> B - A
-    return BinaryOperator::CreateSub(RHS, A);
+    auto *Sub = BinaryOperator::CreateSub(RHS, A);
+    auto *OB0 = cast<OverflowingBinaryOperator>(LHS);
+    Sub->setHasNoSignedWrap(I.hasNoSignedWrap() && OB0->hasNoSignedWrap());
+
+    return Sub;
   }

   // A + -B --> A - B
@@ -1485,8 +1581,9 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
     return replaceInstUsesWith(I, Constant::getNullValue(I.getType()));

   // A+B --> A|B iff A and B have no bits set in common.
-  if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT))
-    return BinaryOperator::CreateOr(LHS, RHS);
+  WithCache<const Value *> LHSCache(LHS), RHSCache(RHS);
+  if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(&I)))
+    return BinaryOperator::CreateDisjointOr(LHS, RHS);

   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;
@@ -1576,15 +1673,33 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
                         m_c_UMin(m_Deferred(A), m_Deferred(B))))))
     return BinaryOperator::CreateWithCopiedFlags(Instruction::Add, A, B, &I);

+  // (~X) + (~Y) --> -2 - (X + Y)
+  {
+    // To ensure we can save instructions we need to ensure that we consume
+    // both LHS/RHS (i.e. they have a `not`).
+    bool ConsumesLHS, ConsumesRHS;
+    if (isFreeToInvert(LHS, LHS->hasOneUse(), ConsumesLHS) && ConsumesLHS &&
+        isFreeToInvert(RHS, RHS->hasOneUse(), ConsumesRHS) && ConsumesRHS) {
+      Value *NotLHS = getFreelyInverted(LHS, LHS->hasOneUse(), &Builder);
+      Value *NotRHS = getFreelyInverted(RHS, RHS->hasOneUse(), &Builder);
+      assert(NotLHS != nullptr && NotRHS != nullptr &&
+             "isFreeToInvert desynced with getFreelyInverted");
+      Value *LHSPlusRHS = Builder.CreateAdd(NotLHS, NotRHS);
+      return BinaryOperator::CreateSub(ConstantInt::get(RHS->getType(), -2),
+                                       LHSPlusRHS);
+    }
+  }
+
   // TODO(jingyue): Consider willNotOverflowSignedAdd and
   // willNotOverflowUnsignedAdd to reduce the number of invocations of
   // computeKnownBits.
   bool Changed = false;
-  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHS, RHS, I)) {
+  if (!I.hasNoSignedWrap() && willNotOverflowSignedAdd(LHSCache, RHSCache, I)) {
     Changed = true;
     I.setHasNoSignedWrap(true);
   }
-  if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedAdd(LHS, RHS, I)) {
+  if (!I.hasNoUnsignedWrap() &&
+      willNotOverflowUnsignedAdd(LHSCache, RHSCache, I)) {
     Changed = true;
     I.setHasNoUnsignedWrap(true);
   }
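The new (~X) + (~Y) --> -2 - (X + Y) rewrite follows from the two's-complement identity ~a == -a - 1, so ~x + ~y == -(x + y) - 2. A minimal check of that identity in wrapping arithmetic — my sketch, not patch code:

```cpp
#include <cassert>
#include <cstdint>

// Exhaustive 8-bit check that ~x + ~y == -2 - (x + y) in wrapping
// two's-complement arithmetic, the identity behind the visitAdd fold.
int main() {
  for (uint32_t x = 0; x < 256; ++x)
    for (uint32_t y = 0; y < 256; ++y) {
      uint8_t lhs = uint8_t(~x) + uint8_t(~y);
      uint8_t rhs = uint8_t(-2) - uint8_t(x + y);
      assert(lhs == rhs);
    }
}
```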
@@ -1610,11 +1725,14 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) {
   // ctpop(A) + ctpop(B) => ctpop(A | B) if A and B have no bits set in common.
   if (match(LHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(A)))) &&
       match(RHS, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(B)))) &&
-      haveNoCommonBitsSet(A, B, DL, &AC, &I, &DT))
+      haveNoCommonBitsSet(A, B, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(
         I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()},
                                    {Builder.CreateOr(A, B)}));

+  if (Instruction *Res = foldSquareSumInt(I))
+    return Res;
+
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;

@@ -1755,10 +1873,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
         // instcombined.
         if (ConstantFP *CFP = dyn_cast<ConstantFP>(RHS))
           if (IsValidPromotion(FPType, LHSIntVal->getType())) {
-            Constant *CI =
-                ConstantExpr::getFPToSI(CFP, LHSIntVal->getType());
+            Constant *CI = ConstantFoldCastOperand(Instruction::FPToSI, CFP,
+                                                   LHSIntVal->getType(), DL);
             if (LHSConv->hasOneUse() &&
-                ConstantExpr::getSIToFP(CI, I.getType()) == CFP &&
+                ConstantFoldCastOperand(Instruction::SIToFP, CI, I.getType(),
+                                        DL) == CFP &&
                 willNotOverflowSignedAdd(LHSIntVal, CI, I)) {
               // Insert the new integer add.
               Value *NewAdd = Builder.CreateNSWAdd(LHSIntVal, CI, "addconv");
@@ -1794,6 +1913,9 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) {
   if (Instruction *F = factorizeFAddFSub(I, Builder))
     return F;

+  if (Instruction *F = foldSquareSumFP(I))
+    return F;
+
   // Try to fold fadd into start value of reduction intrinsic.
   if (match(&I, m_c_FAdd(m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_fadd>(
                              m_AnyZeroFP(), m_Value(X))),
@@ -2017,14 +2139,16 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {

     // C-(X+C2) --> (C-C2)-X
     if (match(Op1, m_Add(m_Value(X), m_ImmConstant(C2)))) {
-      // C-C2 never overflows, and C-(X+C2), (X+C2) has NSW
-      // => (C-C2)-X can have NSW
+      // C-C2 never overflows, and C-(X+C2), (X+C2) has NSW/NUW
+      // => (C-C2)-X can have NSW/NUW
       bool WillNotSOV = willNotOverflowSignedSub(C, C2, I);
       BinaryOperator *Res =
           BinaryOperator::CreateSub(ConstantExpr::getSub(C, C2), X);
       auto *OBO1 = cast<OverflowingBinaryOperator>(Op1);
       Res->setHasNoSignedWrap(I.hasNoSignedWrap() && OBO1->hasNoSignedWrap() &&
                               WillNotSOV);
+      Res->setHasNoUnsignedWrap(I.hasNoUnsignedWrap() &&
+                                OBO1->hasNoUnsignedWrap());
       return Res;
     }
   }
@@ -2058,7 +2182,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {
              m_Select(m_Value(), m_Specific(Op1), m_Specific(&I))) ||
          match(UI, m_Select(m_Value(), m_Specific(&I), m_Specific(Op1)));
        })) {
-    if (Value *NegOp1 = Negator::Negate(IsNegation, Op1, *this))
+    if (Value *NegOp1 = Negator::Negate(IsNegation, /* IsNSW */ IsNegation &&
+                                                        I.hasNoSignedWrap(),
+                                        Op1, *this))
       return BinaryOperator::CreateAdd(NegOp1, Op0);
   }
   if (IsNegation)
@@ -2093,19 +2219,50 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) {

   // ((X - Y) - Op1)  -->  X - (Y + Op1)
   if (match(Op0, m_OneUse(m_Sub(m_Value(X), m_Value(Y))))) {
-    Value *Add = Builder.CreateAdd(Y, Op1);
-    return BinaryOperator::CreateSub(X, Add);
+    OverflowingBinaryOperator *LHSSub = cast<OverflowingBinaryOperator>(Op0);
+    bool HasNUW = I.hasNoUnsignedWrap() && LHSSub->hasNoUnsignedWrap();
+    bool HasNSW = HasNUW && I.hasNoSignedWrap() && LHSSub->hasNoSignedWrap();
+    Value *Add = Builder.CreateAdd(Y, Op1, "", /* HasNUW */ HasNUW,
+                                   /* HasNSW */ HasNSW);
+    BinaryOperator *Sub = BinaryOperator::CreateSub(X, Add);
+    Sub->setHasNoUnsignedWrap(HasNUW);
+    Sub->setHasNoSignedWrap(HasNSW);
+    return Sub;
+  }
+
+  {
+    // (X + Z) - (Y + Z) --> (X - Y)
+    // This is done in other passes, but we want to be able to consume this
+    // pattern in InstCombine so we can generate it without creating infinite
+    // loops.
+    if (match(Op0, m_Add(m_Value(X), m_Value(Z))) &&
+        match(Op1, m_c_Add(m_Value(Y), m_Specific(Z))))
+      return BinaryOperator::CreateSub(X, Y);
+
+    // (X + C0) - (Y + C1) --> (X - Y) + (C0 - C1)
+    Constant *CX, *CY;
+    if (match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(CX)))) &&
+        match(Op1, m_OneUse(m_Add(m_Value(Y), m_ImmConstant(CY))))) {
+      Value *OpsSub = Builder.CreateSub(X, Y);
+      Constant *ConstsSub = ConstantExpr::getSub(CX, CY);
+      return BinaryOperator::CreateAdd(OpsSub, ConstsSub);
+    }
   }

   // (~X) - (~Y) --> Y - X
-  // This is placed after the other reassociations and explicitly excludes a
-  // sub-of-sub pattern to avoid infinite looping.
-  if (isFreeToInvert(Op0, Op0->hasOneUse()) &&
-      isFreeToInvert(Op1, Op1->hasOneUse()) &&
-      !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) {
-    Value *NotOp0 = Builder.CreateNot(Op0);
-    Value *NotOp1 = Builder.CreateNot(Op1);
-    return BinaryOperator::CreateSub(NotOp1, NotOp0);
+  {
+    // Need to ensure we can consume at least one of the `not` instructions,
+    // otherwise this can infinite loop.
+    bool ConsumesOp0, ConsumesOp1;
+    if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
+        isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
+        (ConsumesOp0 || ConsumesOp1)) {
+      Value *NotOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
+      Value *NotOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
+      assert(NotOp0 != nullptr && NotOp1 != nullptr &&
+             "isFreeToInvert desynced with getFreelyInverted");
+      return BinaryOperator::CreateSub(NotOp1, NotOp0);
+    }
   }

   auto m_AddRdx = [](Value *&Vec) {
@@ -2520,18 +2677,33 @@ static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) {
   return nullptr;
 }

-static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
-                                           InstCombiner::BuilderTy &Builder) {
-  Value *FNeg;
-  if (!match(&I, m_FNeg(m_Value(FNeg))))
-    return nullptr;
-
+Instruction *InstCombinerImpl::hoistFNegAboveFMulFDiv(Value *FNegOp,
+                                                      Instruction &FMFSource) {
   Value *X, *Y;
-  if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
-    return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+  if (match(FNegOp, m_FMul(m_Value(X), m_Value(Y)))) {
+    return cast<Instruction>(Builder.CreateFMulFMF(
+        Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+  }
+
+  if (match(FNegOp, m_FDiv(m_Value(X), m_Value(Y)))) {
+    return cast<Instruction>(Builder.CreateFDivFMF(
+        Builder.CreateFNegFMF(X, &FMFSource), Y, &FMFSource));
+  }
+
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(FNegOp)) {
+    // Make sure to preserve flags and metadata on the call.
+    if (II->getIntrinsicID() == Intrinsic::ldexp) {
+      FastMathFlags FMF = FMFSource.getFastMathFlags() | II->getFastMathFlags();
+      IRBuilder<>::FastMathFlagGuard FMFGuard(Builder);
+      Builder.setFastMathFlags(FMF);

-  if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
-    return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+      CallInst *New = Builder.CreateCall(
+          II->getCalledFunction(),
+          {Builder.CreateFNeg(II->getArgOperand(0)), II->getArgOperand(1)});
+      New->copyMetadata(*II);
+      return New;
+    }
+  }

   return nullptr;
 }
@@ -2553,13 +2725,13 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
       match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
     return BinaryOperator::CreateFSubFMF(Y, X, &I);

-  if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
-    return R;
-
   Value *OneUse;
   if (!match(Op, m_OneUse(m_Value(OneUse))))
     return nullptr;

+  if (Instruction *R = hoistFNegAboveFMulFDiv(OneUse, I))
+    return replaceInstUsesWith(I, R);
+
   // Try to eliminate fneg if at least 1 arm of the select is negated.
   Value *Cond;
   if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) {
@@ -2569,8 +2741,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) {
     auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) {
       S->copyFastMathFlags(&I);
       if (auto *OldSel = dyn_cast<SelectInst>(Op)) {
-        FastMathFlags FMF = I.getFastMathFlags();
-        FMF |= OldSel->getFastMathFlags();
+        FastMathFlags FMF = I.getFastMathFlags() | OldSel->getFastMathFlags();
         S->setFastMathFlags(FMF);
         if (!OldSel->hasNoSignedZeros() && !CommonOperand &&
             !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition()))
@@ -2638,9 +2809,6 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) {
   if (Instruction *X = foldFNegIntoConstant(I, DL))
     return X;

-  if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
-    return R;
-
   Value *X, *Y;
   Constant *C;

diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 8a1fb6b7f17e..6002f599ca71 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1099,39 +1099,6 @@ static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
       return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
   }

-  Value *Base, *Offset;
-  if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset))))
-    return nullptr;
-
-  if (!match(UnsignedICmp,
-             m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
-      !ICmpInst::isUnsigned(UnsignedPred))
-    return nullptr;
-
-  // Base >=/> Offset && (Base - Offset) != 0  <-->  Base > Offset
-  // (no overflow and not null)
-  if ((UnsignedPred == ICmpInst::ICMP_UGE ||
-       UnsignedPred == ICmpInst::ICMP_UGT) &&
-      EqPred == ICmpInst::ICMP_NE && IsAnd)
-    return Builder.CreateICmpUGT(Base, Offset);
-
-  // Base <=/< Offset || (Base - Offset) == 0  <-->  Base <= Offset
-  // (overflow or null)
-  if ((UnsignedPred == ICmpInst::ICMP_ULE ||
-       UnsignedPred == ICmpInst::ICMP_ULT) &&
-      EqPred == ICmpInst::ICMP_EQ && !IsAnd)
-    return Builder.CreateICmpULE(Base, Offset);
-
-  // Base <= Offset && (Base - Offset) != 0  -->  Base < Offset
-  if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
-      IsAnd)
-    return Builder.CreateICmpULT(Base, Offset);
-
-  // Base > Offset || (Base - Offset) == 0  -->  Base >= Offset
-  if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
-      !IsAnd)
-    return Builder.CreateICmpUGE(Base, Offset);
-
   return nullptr;
 }

@@ -1179,13 +1146,40 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1,
     return nullptr;

   CmpInst::Predicate Pred = IsAnd ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
-  if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred)
-    return nullptr;
+  auto GetMatchPart = [&](ICmpInst *Cmp,
+                          unsigned OpNo) -> std::optional<IntPart> {
+    if (Pred == Cmp->getPredicate())
+      return matchIntPart(Cmp->getOperand(OpNo));
+
+    const APInt *C;
+    // (icmp eq (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ult (xor x, y), 1 << C) so also look for that.
+    if (Pred == CmpInst::ICMP_EQ && Cmp->getPredicate() == CmpInst::ICMP_ULT) {
+      if (!match(Cmp->getOperand(1), m_Power2(C)) ||
+          !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+        return std::nullopt;
+    }

-  std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0));
-  std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1));
-  std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0));
-  std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1));
+    // (icmp ne (lshr x, C), (lshr y, C)) gets optimized to:
+    // (icmp ugt (xor x, y), (1 << C) - 1) so also look for that.
+    else if (Pred == CmpInst::ICMP_NE &&
+             Cmp->getPredicate() == CmpInst::ICMP_UGT) {
+      if (!match(Cmp->getOperand(1), m_LowBitMask(C)) ||
+          !match(Cmp->getOperand(0), m_Xor(m_Value(), m_Value())))
+        return std::nullopt;
+    } else {
+      return std::nullopt;
+    }
+
+    unsigned From = Pred == CmpInst::ICMP_NE ? C->popcount() : C->countr_zero();
+    Instruction *I = cast<Instruction>(Cmp->getOperand(0));
+    return {{I->getOperand(OpNo), From, C->getBitWidth() - From}};
+  };
+
+  std::optional<IntPart> L0 = GetMatchPart(Cmp0, 0);
+  std::optional<IntPart> R0 = GetMatchPart(Cmp0, 1);
+  std::optional<IntPart> L1 = GetMatchPart(Cmp1, 0);
+  std::optional<IntPart> R1 = GetMatchPart(Cmp1, 1);
   if (!L0 || !R0 || !L1 || !R1)
     return nullptr;

@@ -1616,7 +1610,7 @@ static Instruction *reassociateFCmps(BinaryOperator &BO,
 /// (~A & ~B) == (~(A | B))
 /// (~A | ~B) == (~(A & B))
 static Instruction *matchDeMorgansLaws(BinaryOperator &I,
-                                       InstCombiner::BuilderTy &Builder) {
+                                       InstCombiner &IC) {
   const Instruction::BinaryOps Opcode = I.getOpcode();
   assert((Opcode == Instruction::And || Opcode == Instruction::Or) &&
          "Trying to match De Morgan's Laws with something other than and/or");
@@ -1629,10 +1623,10 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,

   Value *A, *B;
   if (match(Op0, m_OneUse(m_Not(m_Value(A)))) &&
       match(Op1, m_OneUse(m_Not(m_Value(B)))) &&
-      !InstCombiner::isFreeToInvert(A, A->hasOneUse()) &&
-      !InstCombiner::isFreeToInvert(B, B->hasOneUse())) {
+      !IC.isFreeToInvert(A, A->hasOneUse()) &&
+      !IC.isFreeToInvert(B, B->hasOneUse())) {
     Value *AndOr =
-        Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan");
+        IC.Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan");
     return BinaryOperator::CreateNot(AndOr);
   }

@@ -1644,8 +1638,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   Value *C;
   if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) &&
       match(Op1, m_Not(m_Value(C)))) {
-    Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C);
-    return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO));
+    Value *FlippedBO = IC.Builder.CreateBinOp(FlippedOpcode, B, C);
+    return BinaryOperator::Create(Opcode, A, IC.Builder.CreateNot(FlippedBO));
   }

   return nullptr;
@@ -1669,7 +1663,7 @@ bool InstCombinerImpl::shouldOptimizeCast(CastInst *CI) {

 /// Fold {and,or,xor} (cast X), C.
 static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
-                                          InstCombiner::BuilderTy &Builder) {
+                                          InstCombinerImpl &IC) {
   Constant *C = dyn_cast<Constant>(Logic.getOperand(1));
   if (!C)
     return nullptr;
@@ -1684,21 +1678,17 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast,
   // instruction may be cheaper (particularly in the case of vectors).
   Value *X;
   if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) {
-    Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
-    Constant *ZextTruncC = ConstantExpr::getZExt(TruncC, DestTy);
-    if (ZextTruncC == C) {
+    if (Constant *TruncC = IC.getLosslessUnsignedTrunc(C, SrcTy)) {
       // LogicOpc (zext X), C --> zext (LogicOpc X, C)
-      Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+      Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
       return new ZExtInst(NewOp, DestTy);
     }
   }

   if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) {
-    Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy);
-    Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy);
-    if (SextTruncC == C) {
+    if (Constant *TruncC = IC.getLosslessSignedTrunc(C, SrcTy)) {
       // LogicOpc (sext X), C --> sext (LogicOpc X, C)
-      Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC);
+      Value *NewOp = IC.Builder.CreateBinOp(LogicOpc, X, TruncC);
       return new SExtInst(NewOp, DestTy);
     }
   }
@@ -1756,7 +1746,7 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
   if (!SrcTy->isIntOrIntVectorTy())
     return nullptr;

-  if (Instruction *Ret = foldLogicCastConstant(I, Cast0, Builder))
+  if (Instruction *Ret = foldLogicCastConstant(I, Cast0, *this))
     return Ret;

   CastInst *Cast1 = dyn_cast<CastInst>(Op1);
@@ -1802,29 +1792,6 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) {
     return CastInst::Create(CastOpcode, NewOp, DestTy);
   }

-  // For now, only 'and'/'or' have optimizations after this.
-  if (LogicOpc == Instruction::Xor)
-    return nullptr;
-
-  // If this is logic(cast(icmp), cast(icmp)), try to fold this even if the
-  // cast is otherwise not optimizable. This happens for vector sexts.
-  ICmpInst *ICmp0 = dyn_cast<ICmpInst>(Cast0Src);
-  ICmpInst *ICmp1 = dyn_cast<ICmpInst>(Cast1Src);
-  if (ICmp0 && ICmp1) {
-    if (Value *Res =
-            foldAndOrOfICmps(ICmp0, ICmp1, I, LogicOpc == Instruction::And))
-      return CastInst::Create(CastOpcode, Res, DestTy);
-    return nullptr;
-  }
-
-  // If this is logic(cast(fcmp), cast(fcmp)), try to fold this even if the
-  // cast is otherwise not optimizable. This happens for vector sexts.
-  FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src);
-  FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src);
-  if (FCmp0 && FCmp1)
-    if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And))
-      return CastInst::Create(CastOpcode, R, DestTy);
-
   return nullptr;
 }

@@ -2160,10 +2127,10 @@ Instruction *InstCombinerImpl::foldBinOpOfDisplacedShifts(BinaryOperator &I) {
   Constant *ShiftedC1, *ShiftedC2, *AddC;
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
-  if (!match(&I,
-             m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)),
-                       m_Shift(m_ImmConstant(ShiftedC2),
-                               m_Add(m_Deferred(ShAmt), m_ImmConstant(AddC))))))
+  if (!match(&I, m_c_BinOp(m_Shift(m_ImmConstant(ShiftedC1), m_Value(ShAmt)),
+                           m_Shift(m_ImmConstant(ShiftedC2),
+                                   m_AddLike(m_Deferred(ShAmt),
+                                             m_ImmConstant(AddC))))))
     return nullptr;

   // Make sure the add constant is a valid shift amount.
@@ -2254,6 +2221,14 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
     return SelectInst::Create(Cmp, ConstantInt::getNullValue(Ty), Y);
   }

+  // Canonicalize:
+  // (X +/- Y) & Y --> ~X & Y when Y is a power of 2.
+  if (match(&I, m_c_And(m_Value(Y), m_OneUse(m_CombineOr(
+                                        m_c_Add(m_Value(X), m_Deferred(Y)),
+                                        m_Sub(m_Value(X), m_Deferred(Y)))))) &&
+      isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, /*Depth*/ 0, &I))
+    return BinaryOperator::CreateAnd(Builder.CreateNot(X), Y);
+
   const APInt *C;
   if (match(Op1, m_APInt(C))) {
     const APInt *XorC;
@@ -2300,13 +2275,6 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {

     const APInt *AddC;
     if (match(Op0, m_Add(m_Value(X), m_APInt(AddC)))) {
-      // If we add zeros to every bit below a mask, the add has no effect:
-      // (X + AddC) & LowMaskC --> X & LowMaskC
-      unsigned Ctlz = C->countl_zero();
-      APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz));
-      if ((*AddC & LowMask).isZero())
-        return BinaryOperator::CreateAnd(X, Op1);
-
       // If we are masking the result of the add down to exactly one bit and
       // the constant we are adding has no bits set below that bit, then the
       // add is flipping a single bit. Example:
@@ -2455,6 +2423,28 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
     }
   }

+  // If we are clearing the sign bit of a floating-point value, convert this to
+  // fabs, then cast back to integer.
+  //
+  // This is a generous interpretation for noimplicitfloat, this is not a true
+  // floating-point operation.
+  //
+  // Assumes any IEEE-represented type has the sign bit in the high bit.
+  // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
+  Value *CastOp;
+  if (match(Op0, m_BitCast(m_Value(CastOp))) &&
+      match(Op1, m_MaxSignedValue()) &&
+      !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+          Attribute::NoImplicitFloat)) {
+    Type *EltTy = CastOp->getType()->getScalarType();
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+        EltTy->getPrimitiveSizeInBits() ==
+            I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
+      return new BitCastInst(FAbs, I.getType());
+    }
+  }
+
   if (match(&I, m_And(m_OneUse(m_Shl(m_ZExt(m_Value(X)), m_Value(Y))),
                       m_SignMask())) &&
       match(Y, m_SpecificInt_ICMP(
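The fabs fold above relies on an IEEE-754 layout fact: masking off the high (sign) bit of a float's bit pattern yields the bit pattern of fabs of that float. A small check of that assumption — my sketch, requires C++20 for std::bit_cast:

```cpp
#include <bit>
#include <cassert>
#include <cmath>
#include <cstdint>

// Clearing the sign bit (AND with the max signed value 0x7fffffff,
// the m_MaxSignedValue pattern) is bitwise-identical to fabs(x),
// including for -0.0 and negative infinity.
int main() {
  for (float x : {-1.5f, -0.0f, 1.5f, -INFINITY}) {
    uint32_t bits = std::bit_cast<uint32_t>(x) & 0x7fffffffu;
    assert(std::bit_cast<float>(bits) == std::fabs(x));
  }
}
```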
@@ -2479,21 +2469,21 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {

   if (I.getType()->isIntOrIntVectorTy(1)) {
     if (auto *SI0 = dyn_cast<SelectInst>(Op0)) {
-      if (auto *I =
+      if (auto *R =
               foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ true))
-        return I;
+        return R;
     }
     if (auto *SI1 = dyn_cast<SelectInst>(Op1)) {
-      if (auto *I =
+      if (auto *R =
               foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ true))
-        return I;
+        return R;
     }
   }

   if (Instruction *FoldedLogic = foldBinOpIntoSelectOrPhi(I))
     return FoldedLogic;

-  if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
+  if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this))
     return DeMorgan;

   {
@@ -2513,16 +2503,24 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
       return BinaryOperator::CreateAnd(Op1, B);

     // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
-    if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
-      if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
-        if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
-          return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C));
+    if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+        match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) {
+      Value *NotC = Op1->hasOneUse()
+                        ? Builder.CreateNot(C)
+                        : getFreelyInverted(C, C->hasOneUse(), &Builder);
+      if (NotC != nullptr)
+        return BinaryOperator::CreateAnd(Op0, NotC);
+    }

     // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C
-    if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
-      if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
-        if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
-          return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C));
+    if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))) &&
+        match(Op1, m_Xor(m_Specific(B), m_Specific(A)))) {
+      Value *NotC = Op0->hasOneUse()
+                        ? Builder.CreateNot(C)
+                        : getFreelyInverted(C, C->hasOneUse(), &Builder);
+      if (NotC != nullptr)
+        return BinaryOperator::CreateAnd(Op1, NotC);
+    }

     // (A | B) & (~A ^ B) -> A & B
     // (A | B) & (B ^ ~A) -> A & B
@@ -2621,23 +2619,34 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) {
   // with binop identity constant. But creating a select with non-constant
   // arm may not be reversible due to poison semantics. Is that a good
   // canonicalization?
-  Value *A;
-  if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) &&
-      A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, Op1, Constant::getNullValue(Ty));
-  if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) &&
+  Value *A, *B;
+  if (match(&I, m_c_And(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) &&
       A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, Op0, Constant::getNullValue(Ty));
+    return SelectInst::Create(A, B, Constant::getNullValue(Ty));

   // Similarly, a 'not' of the bool translates to a swap of the select arms:
-  // ~sext(A) & Op1 --> A ? 0 : Op1
-  // Op0 & ~sext(A) --> A ? 0 : Op0
-  if (match(Op0, m_Not(m_SExt(m_Value(A)))) &&
+  // ~sext(A) & B / B & ~sext(A) --> A ? 0 : B
+  if (match(&I, m_c_And(m_Not(m_SExt(m_Value(A))), m_Value(B))) &&
       A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, Constant::getNullValue(Ty), Op1);
-  if (match(Op1, m_Not(m_SExt(m_Value(A)))) &&
+    return SelectInst::Create(A, Constant::getNullValue(Ty), B);
+
+  // and(zext(A), B) -> A ? (B & 1) : 0
+  if (match(&I, m_c_And(m_OneUse(m_ZExt(m_Value(A))), m_Value(B))) &&
       A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, Constant::getNullValue(Ty), Op0);
+    return SelectInst::Create(A, Builder.CreateAnd(B, ConstantInt::get(Ty, 1)),
+                              Constant::getNullValue(Ty));
+
+  // (-1 + A) & B --> A ? 0 : B where A is 0/1.
+  if (match(&I, m_c_And(m_OneUse(m_Add(m_ZExtOrSelf(m_Value(A)), m_AllOnes())),
+                        m_Value(B)))) {
+    if (A->getType()->isIntOrIntVectorTy(1))
+      return SelectInst::Create(A, Constant::getNullValue(Ty), B);
+    if (computeKnownBits(A, /* Depth */ 0, &I).countMaxActiveBits() <= 1) {
+      return SelectInst::Create(
+          Builder.CreateICmpEQ(A, Constant::getNullValue(A->getType())), B,
+          Constant::getNullValue(Ty));
+    }
+  }

   // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 -- with optional sext
   if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf(
@@ -2698,105 +2707,178 @@ Instruction *InstCombinerImpl::matchBSwapOrBitReverse(Instruction &I,
 }

 /// Match UB-safe variants of the funnel shift intrinsic.
-static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC) {
+static Instruction *matchFunnelShift(Instruction &Or, InstCombinerImpl &IC,
+                                     const DominatorTree &DT) {
   // TODO: Can we reduce the code duplication between this and the related
   // rotate matching code under visitSelect and visitTrunc?
   unsigned Width = Or.getType()->getScalarSizeInBits();

+  Instruction *Or0, *Or1;
+  if (!match(Or.getOperand(0), m_Instruction(Or0)) ||
+      !match(Or.getOperand(1), m_Instruction(Or1)))
+    return nullptr;
+
+  bool IsFshl = true; // Sub on LSHR.
+  SmallVector<Value *, 3> FShiftArgs;
+
   // First, find an or'd pair of opposite shifts:
   // or (lshr ShVal0, ShAmt0), (shl ShVal1, ShAmt1)
-  BinaryOperator *Or0, *Or1;
-  if (!match(Or.getOperand(0), m_BinOp(Or0)) ||
-      !match(Or.getOperand(1), m_BinOp(Or1)))
-    return nullptr;
+  if (isa<BinaryOperator>(Or0) && isa<BinaryOperator>(Or1)) {
+    Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
+    if (!match(Or0,
+               m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
+        !match(Or1,
+               m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
+        Or0->getOpcode() == Or1->getOpcode())
+      return nullptr;

-  Value *ShVal0, *ShVal1, *ShAmt0, *ShAmt1;
-  if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal0), m_Value(ShAmt0)))) ||
-      !match(Or1, m_OneUse(m_LogicalShift(m_Value(ShVal1), m_Value(ShAmt1)))) ||
-      Or0->getOpcode() == Or1->getOpcode())
-    return nullptr;
+    // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
+    if (Or0->getOpcode() == BinaryOperator::LShr) {
+      std::swap(Or0, Or1);
+      std::swap(ShVal0, ShVal1);
+      std::swap(ShAmt0, ShAmt1);
+    }
+    assert(Or0->getOpcode() == BinaryOperator::Shl &&
+           Or1->getOpcode() == BinaryOperator::LShr &&
+           "Illegal or(shift,shift) pair");

-  // Canonicalize to or(shl(ShVal0, ShAmt0), lshr(ShVal1, ShAmt1)).
-  if (Or0->getOpcode() == BinaryOperator::LShr) {
-    std::swap(Or0, Or1);
-    std::swap(ShVal0, ShVal1);
-    std::swap(ShAmt0, ShAmt1);
-  }
-  assert(Or0->getOpcode() == BinaryOperator::Shl &&
-         Or1->getOpcode() == BinaryOperator::LShr &&
-         "Illegal or(shift,shift) pair");
+    // Match the shift amount operands for a funnel shift pattern. This always
+    // matches a subtraction on the R operand.
+    auto matchShiftAmount = [&](Value *L, Value *R, unsigned Width) -> Value * {
+      // Check for constant shift amounts that sum to the bitwidth.
+      const APInt *LI, *RI;
+      if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
+        if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
+          return ConstantInt::get(L->getType(), *LI);
+
+      Constant *LC, *RC;
+      if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
+          match(L,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(R,
+                m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
+          match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
+        return ConstantExpr::mergeUndefsWith(LC, RC);
+
+      // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
+      // We limit this to X < Width in case the backend re-expands the
+      // intrinsic, and has to reintroduce a shift modulo operation (InstCombine
+      // might remove it after this fold). This still doesn't guarantee that the
+      // final codegen will match this original pattern.
+      if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
+        KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
+        return KnownL.getMaxValue().ult(Width) ? L : nullptr;
+      }
-    const APInt *LI, *RI;
-    if (match(L, m_APIntAllowUndef(LI)) && match(R, m_APIntAllowUndef(RI)))
-      if (LI->ult(Width) && RI->ult(Width) && (*LI + *RI) == Width)
-        return ConstantInt::get(L->getType(), *LI);
-
-    Constant *LC, *RC;
-    if (match(L, m_Constant(LC)) && match(R, m_Constant(RC)) &&
-        match(L, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(R, m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, APInt(Width, Width))) &&
-        match(ConstantExpr::getAdd(LC, RC), m_SpecificIntAllowUndef(Width)))
-      return ConstantExpr::mergeUndefsWith(LC, RC);
-
-    // (shl ShVal, X) | (lshr ShVal, (Width - x)) iff X < Width.
-    // We limit this to X < Width in case the backend re-expands the intrinsic,
-    // and has to reintroduce a shift modulo operation (InstCombine might remove
-    // it after this fold). This still doesn't guarantee that the final codegen
-    // will match this original pattern.
-    if (match(R, m_OneUse(m_Sub(m_SpecificInt(Width), m_Specific(L))))) {
-      KnownBits KnownL = IC.computeKnownBits(L, /*Depth*/ 0, &Or);
-      return KnownL.getMaxValue().ult(Width) ? L : nullptr;
-    }

-    // For non-constant cases, the following patterns currently only work for
-    // rotation patterns.
-    // TODO: Add general funnel-shift compatible patterns.
-    if (ShVal0 != ShVal1)
-      return nullptr;
+      // For non-constant cases, the following patterns currently only work for
+      // rotation patterns.
+      // TODO: Add general funnel-shift compatible patterns.
+      if (ShVal0 != ShVal1)
+        return nullptr;

-    // For non-constant cases we don't support non-pow2 shift masks.
-    // TODO: Is it worth matching urem as well?
+      // For non-constant cases we don't support non-pow2 shift masks.
+      // TODO: Is it worth matching urem as well?
-    if (!isPowerOf2_32(Width))
-      return nullptr;
+      if (!isPowerOf2_32(Width))
+        return nullptr;

-    // The shift amount may be masked with negation:
-    // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
-    Value *X;
-    unsigned Mask = Width - 1;
-    if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
-        match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
-      return X;
+      // The shift amount may be masked with negation:
+      // (shl ShVal, (X & (Width - 1))) | (lshr ShVal, ((-X) & (Width - 1)))
+      Value *X;
+      unsigned Mask = Width - 1;
+      if (match(L, m_And(m_Value(X), m_SpecificInt(Mask))) &&
+          match(R, m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask))))
+        return X;

-    // Similar to above, but the shift amount may be extended after masking,
-    // so return the extended value as the parameter for the intrinsic.
-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
-                       m_SpecificInt(Mask))))
-      return L;
+      // Similar to above, but the shift amount may be extended after masking,
+      // so return the extended value as the parameter for the intrinsic.
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R,
+                m_And(m_Neg(m_ZExt(m_And(m_Specific(X), m_SpecificInt(Mask)))),
+                      m_SpecificInt(Mask))))
+        return L;

-    if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
-        match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
-      return L;
+      if (match(L, m_ZExt(m_And(m_Value(X), m_SpecificInt(Mask)))) &&
+          match(R, m_ZExt(m_And(m_Neg(m_Specific(X)), m_SpecificInt(Mask)))))
+        return L;

-    return nullptr;
-  };
+      return nullptr;
+    };

-  Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
-  bool IsFshl = true; // Sub on LSHR.
-  if (!ShAmt) {
-    ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
-    IsFshl = false; // Sub on SHL.
+    Value *ShAmt = matchShiftAmount(ShAmt0, ShAmt1, Width);
+    if (!ShAmt) {
+      ShAmt = matchShiftAmount(ShAmt1, ShAmt0, Width);
+      IsFshl = false; // Sub on SHL.
+    }
+    if (!ShAmt)
+      return nullptr;
+
+    FShiftArgs = {ShVal0, ShVal1, ShAmt};
+  } else if (isa<ZExtInst>(Or0) || isa<ZExtInst>(Or1)) {
+    // If there are two 'or' instructions concat variables in opposite order:
+    //
+    // Slot1 and Slot2 are all zero bits.
+    // | Slot1 | Low | Slot2 | High |
+    // LowHigh = or (shl (zext Low), ZextLowShlAmt), (zext High)
+    // | Slot2 | High | Slot1 | Low |
+    // HighLow = or (shl (zext High), ZextHighShlAmt), (zext Low)
+    //
+    // the latter 'or' can be safely converted to
+    // -> HighLow = fshl LowHigh, LowHigh, ZextHighShlAmt
+    // if ZextLowShlAmt + ZextHighShlAmt == Width.
+    if (!isa<ZExtInst>(Or1))
+      std::swap(Or0, Or1);
+
+    Value *High, *ZextHigh, *Low;
+    const APInt *ZextHighShlAmt;
+    if (!match(Or0,
+               m_OneUse(m_Shl(m_Value(ZextHigh), m_APInt(ZextHighShlAmt)))))
+      return nullptr;
+
+    if (!match(Or1, m_ZExt(m_Value(Low))) ||
+        !match(ZextHigh, m_ZExt(m_Value(High))))
+      return nullptr;
+
+    unsigned HighSize = High->getType()->getScalarSizeInBits();
+    unsigned LowSize = Low->getType()->getScalarSizeInBits();
+    // Make sure High does not overlap with Low and most significant bits of
+    // High aren't shifted out.
+    if (ZextHighShlAmt->ult(LowSize) || ZextHighShlAmt->ugt(Width - HighSize))
+      return nullptr;
+
+    for (User *U : ZextHigh->users()) {
+      Value *X, *Y;
+      if (!match(U, m_Or(m_Value(X), m_Value(Y))))
+        continue;
+
+      if (!isa<ZExtInst>(Y))
+        std::swap(X, Y);
+
+      const APInt *ZextLowShlAmt;
+      if (!match(X, m_Shl(m_Specific(Or1), m_APInt(ZextLowShlAmt))) ||
+          !match(Y, m_Specific(ZextHigh)) || !DT.dominates(U, &Or))
+        continue;
+
+      // HighLow is good concat. If sum of two shifts amount equals to Width,
+      // LowHigh must also be a good concat.
+      if (*ZextLowShlAmt + *ZextHighShlAmt != Width)
+        continue;
+
+      // Low must not overlap with High and most significant bits of Low must
+      // not be shifted out.
+      assert(ZextLowShlAmt->uge(HighSize) &&
+             ZextLowShlAmt->ule(Width - LowSize) && "Invalid concat");
+
+      FShiftArgs = {U, U, ConstantInt::get(Or0->getType(), *ZextHighShlAmt)};
+      break;
+    }
   }
-  if (!ShAmt)
+
+  if (FShiftArgs.empty())
     return nullptr;

   Intrinsic::ID IID = IsFshl ? Intrinsic::fshl : Intrinsic::fshr;
   Function *F = Intrinsic::getDeclaration(Or.getModule(), IID, Or.getType());
-  return CallInst::Create(F, {ShVal0, ShVal1, ShAmt});
+  return CallInst::Create(F, FShiftArgs);
 }

 /// Attempt to combine or(zext(x),shl(zext(y),bw/2) concat packing patterns.
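The rotation shape this function recognizes in its simplest constant form is or(shl(x, n), lshr(x, Width - n)), which is exactly a rotate-left and becomes an llvm.fshl call. A quick check of that equivalence — my sketch, requires C++20 for std::rotl:

```cpp
#include <bit>
#include <cassert>
#include <cstdint>

// or(shl(x, n), lshr(x, 32 - n)) is rotate-left for 0 < n < 32,
// the pattern matchFunnelShift turns into llvm.fshl(x, x, n).
int main() {
  const uint32_t x = 0xDEADBEEF;
  for (unsigned n = 1; n < 32; ++n) {
    uint32_t manual = (x << n) | (x >> (32 - n));
    assert(manual == std::rotl(x, int(n)));
  }
}
```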
@@ -3272,14 +3354,14 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   Type *Ty = I.getType();
   if (Ty->isIntOrIntVectorTy(1)) {
     if (auto *SI0 = dyn_cast<SelectInst>(Op0)) {
-      if (auto *I =
+      if (auto *R =
               foldAndOrOfSelectUsingImpliedCond(Op1, *SI0, /* IsAnd */ false))
-        return I;
+        return R;
     }
     if (auto *SI1 = dyn_cast<SelectInst>(Op1)) {
-      if (auto *I =
+      if (auto *R =
               foldAndOrOfSelectUsingImpliedCond(Op0, *SI1, /* IsAnd */ false))
-        return I;
+        return R;
     }
   }

@@ -3290,7 +3372,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
                                           /*MatchBitReversals*/ true))
     return BitOp;

-  if (Instruction *Funnel = matchFunnelShift(I, *this))
+  if (Instruction *Funnel = matchFunnelShift(I, *this, DT))
     return Funnel;

   if (Instruction *Concat = matchOrConcat(I, Builder))
@@ -3311,9 +3393,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {

   // If the operands have no common bits set:
   // or (mul X, Y), X --> add (mul X, Y), X --> mul X, (Y + 1)
-  if (match(&I,
-            m_c_Or(m_OneUse(m_Mul(m_Value(X), m_Value(Y))), m_Deferred(X))) &&
-      haveNoCommonBitsSet(Op0, Op1, DL)) {
+  if (match(&I, m_c_DisjointOr(m_OneUse(m_Mul(m_Value(X), m_Value(Y))),
+                               m_Deferred(X)))) {
     Value *IncrementY = Builder.CreateAdd(Y, ConstantInt::get(Ty, 1));
     return BinaryOperator::CreateMul(X, IncrementY);
   }
@@ -3435,7 +3516,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (match(Op0, m_And(m_Or(m_Specific(Op1), m_Value(C)), m_Value(A))))
     return BinaryOperator::CreateOr(Op1, Builder.CreateAnd(A, C));

-  if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder))
+  if (Instruction *DeMorgan = matchDeMorgansLaws(I, *this))
     return DeMorgan;

   // Canonicalize xor to the RHS.
@@ -3581,12 +3662,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   // with binop identity constant. But creating a select with non-constant
   // arm may not be reversible due to poison semantics. Is that a good
   // canonicalization?
-  if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) &&
+  if (match(&I, m_c_Or(m_OneUse(m_SExt(m_Value(A))), m_Value(B))) &&
       A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op1);
-  if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) &&
-      A->getType()->isIntOrIntVectorTy(1))
-    return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), Op0);
+    return SelectInst::Create(A, ConstantInt::getAllOnesValue(Ty), B);

   // Note: If we've gotten to the point of visiting the outer OR, then the
   // inner one couldn't be simplified. If it was a constant, then it won't
@@ -3628,6 +3706,26 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
     }
   }

+  {
+    // ((A & B) ^ A) | ((A & B) ^ B) -> A ^ B
+    // (A ^ (A & B)) | (B ^ (A & B)) -> A ^ B
+    // ((A & B) ^ B) | ((A & B) ^ A) -> A ^ B
+    // (B ^ (A & B)) | (A ^ (A & B)) -> A ^ B
+    const auto TryXorOpt = [&](Value *Lhs, Value *Rhs) -> Instruction * {
+      if (match(Lhs, m_c_Xor(m_And(m_Value(A), m_Value(B)), m_Deferred(A))) &&
+          match(Rhs,
+                m_c_Xor(m_And(m_Specific(A), m_Specific(B)), m_Deferred(B)))) {
+        return BinaryOperator::CreateXor(A, B);
+      }
+      return nullptr;
+    };
+
+    if (Instruction *Result = TryXorOpt(Op0, Op1))
+      return Result;
+    if (Instruction *Result = TryXorOpt(Op1, Op0))
+      return Result;
+  }
+
   if (Instruction *V =
           canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
     return V;
@@ -3720,6 +3818,31 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
   if (Instruction *Res = foldBinOpOfDisplacedShifts(I))
     return Res;

+  // If we are setting the sign bit of a floating-point value, convert
+  // this to fneg(fabs), then cast back to integer.
+  //
+  // If the result isn't immediately cast back to a float, this will increase
+  // the number of instructions. This is still probably a better canonical form
+  // as it enables FP value tracking.
+  //
+  // Assumes any IEEE-represented type has the sign bit in the high bit.
+  //
+  // This is generous interpretation of noimplicitfloat, this is not a true
+  // floating-point operation.
+  Value *CastOp;
+  if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+      !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+          Attribute::NoImplicitFloat)) {
+    Type *EltTy = CastOp->getType()->getScalarType();
+    if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+        EltTy->getPrimitiveSizeInBits() ==
+            I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+      Value *FAbs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, CastOp);
+      Value *FNegFAbs = Builder.CreateFNeg(FAbs);
+      return new BitCastInst(FNegFAbs, I.getType());
+    }
+  }
+
   return nullptr;
 }

@@ -3931,26 +4054,6 @@ static Instruction *visitMaskedMerge(BinaryOperator &I,
   return nullptr;
 }

-// Transform
-//   ~(x ^ y)
-// into:
-//   (~x) ^ y
-// or into
-//   x ^ (~y)
-static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y,
-                                   InstCombiner::BuilderTy &Builder) {
-  // We only want to do the transform if it is free to do.
-  if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) {
-    // Ok, good.
-  } else if (InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) {
-    std::swap(X, Y);
-  } else
-    return nullptr;
-
-  Value *NotX = Builder.CreateNot(X, X->getName() + ".not");
-  return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan");
-}
-
 static Instruction *foldNotXor(BinaryOperator &I,
                                InstCombiner::BuilderTy &Builder) {
   Value *X, *Y;
@@ -3959,9 +4062,6 @@ static Instruction *foldNotXor(BinaryOperator &I,
   if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y))))))
     return nullptr;

-  if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder))
-    return NewXor;
-
   auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) {
     return A == C || A == D || B == C || B == D;
   };
@@ -4023,13 +4123,13 @@ static bool canFreelyInvert(InstCombiner &IC, Value *Op,
                             Instruction *IgnoredUser) {
   auto *I = dyn_cast<Instruction>(Op);
   return I && IC.isFreeToInvert(I, /*WillInvertAllUses=*/true) &&
-         InstCombiner::canFreelyInvertAllUsersOf(I, IgnoredUser);
+         IC.canFreelyInvertAllUsersOf(I, IgnoredUser);
 }

 static Value *freelyInvert(InstCombinerImpl &IC, Value *Op,
                            Instruction *IgnoredUser) {
   auto *I = cast<Instruction>(Op);
-  IC.Builder.SetInsertPoint(&*I->getInsertionPointAfterDef());
+  IC.Builder.SetInsertPoint(*I->getInsertionPointAfterDef());
   Value *NotOp = IC.Builder.CreateNot(Op, Op->getName() + ".not");
   Op->replaceUsesWithIf(NotOp,
                         [NotOp](Use &U) { return U.getUser() != NotOp; });
@@ -4067,7 +4167,7 @@ bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) {
   Op0 = freelyInvert(*this, Op0, &I);
   Op1 = freelyInvert(*this, Op1, &I);

-  Builder.SetInsertPoint(I.getInsertionPointAfterDef());
+  Builder.SetInsertPoint(*I.getInsertionPointAfterDef());
   Value *NewLogicOp;
   if (IsBinaryOp)
     NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
@@ -4115,7 +4215,7 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) {

   *OpToInvert = freelyInvert(*this, *OpToInvert, &I);

-  Builder.SetInsertPoint(&*I.getInsertionPointAfterDef());
+  Builder.SetInsertPoint(*I.getInsertionPointAfterDef());
   Value *NewBinOp;
   if (IsBinaryOp)
     NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not");
@@ -4259,15 +4359,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
   // ~max(~X, Y) --> min(X, ~Y)
   auto *II = dyn_cast<IntrinsicInst>(NotOp);
   if (II && II->hasOneUse()) {
-    if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) &&
-        isFreeToInvert(X, X->hasOneUse()) &&
-        isFreeToInvert(Y, Y->hasOneUse())) {
-      Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
-      Value *NotX = Builder.CreateNot(X);
-      Value *NotY = Builder.CreateNot(Y);
-      Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY);
-      return replaceInstUsesWith(I, InvMaxMin);
-    }
     if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) {
       Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
       Value *NotY = Builder.CreateNot(Y);
@@ -4317,6 +4408,11 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
   if (Instruction *NewXor = foldNotXor(I, Builder))
     return NewXor;

+  // TODO: Could handle multi-use better by checking if all uses of NotOp
+  // (other than I) can be inverted.
+  if (Value *R = getFreelyInverted(NotOp, NotOp->hasOneUse(), &Builder))
+    return replaceInstUsesWith(I, R);
+
   return nullptr;
 }

@@ -4366,7 +4462,7 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   Value *M;
   if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
                         m_c_And(m_Deferred(M), m_Value()))))
-    return BinaryOperator::CreateOr(Op0, Op1);
+    return BinaryOperator::CreateDisjointOr(Op0, Op1);

   if (Instruction *Xor = visitMaskedMerge(I, Builder))
     return Xor;
@@ -4466,6 +4562,27 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
       // a 'not' op and moving it before the shift. Doing that requires
       // preventing the inverse fold in canShiftBinOpWithConstantRHS().
     }
+
+    // If we are XORing the sign bit of a floating-point value, convert
+    // this to fneg, then cast back to integer.
+    //
+    // This is generous interpretation of noimplicitfloat, this is not a true
+    // floating-point operation.
+    //
+    // Assumes any IEEE-represented type has the sign bit in the high bit.
+    // TODO: Unify with APInt matcher. This version allows undef unlike m_APInt
+    Value *CastOp;
+    if (match(Op0, m_BitCast(m_Value(CastOp))) && match(Op1, m_SignMask()) &&
+        !Builder.GetInsertBlock()->getParent()->hasFnAttribute(
+            Attribute::NoImplicitFloat)) {
+      Type *EltTy = CastOp->getType()->getScalarType();
+      if (EltTy->isFloatingPointTy() && EltTy->isIEEE() &&
+          EltTy->getPrimitiveSizeInBits() ==
+              I.getType()->getScalarType()->getPrimitiveSizeInBits()) {
+        Value *FNeg = Builder.CreateFNeg(CastOp);
+        return new BitCastInst(FNeg, I.getType());
+      }
+    }
   }

   // FIXME: This should not be limited to scalar (pull into APInt match above).
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index d3ec6a7aa667..255ce6973a16 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -89,12 +89,6 @@ static cl::opt<unsigned> GuardWideningWindow(
     cl::desc("How wide an instruction window to bypass looking for "
              "another guard"));

-namespace llvm {
-/// enable preservation of attributes in assume like:
-/// call void @llvm.assume(i1 true) [ "nonnull"(i32* %PTR) ]
-extern cl::opt<bool> EnableKnowledgeRetention;
-} // namespace llvm
-
 /// Return the specified type promoted as it would be to pass though a va_arg
 /// area.
 static Type *getPromotedType(Type *Ty) {
@@ -174,14 +168,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
     return nullptr;

   // Use an integer load+store unless we can find something better.
-  unsigned SrcAddrSp =
-      cast<PointerType>(MI->getArgOperand(1)->getType())->getAddressSpace();
-  unsigned DstAddrSp =
-      cast<PointerType>(MI->getArgOperand(0)->getType())->getAddressSpace();
-
   IntegerType* IntType = IntegerType::get(MI->getContext(), Size<<3);
-  Type *NewSrcPtrTy = PointerType::get(IntType, SrcAddrSp);
-  Type *NewDstPtrTy = PointerType::get(IntType, DstAddrSp);

   // If the memcpy has metadata describing the members, see if we can get the
   // TBAA tag describing our copy.
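With opaque pointers the bitcasts above become dead weight, but the underlying memcpy fold is unchanged: a small power-of-2 constant-length copy is rewritten into one integer load and one integer store. A C++ analogue of that idea — my sketch; the real transform operates on LLVM IR, not C++:

```cpp
#include <cstdint>
#include <cstring>

// What the fold amounts to for Size == 4: route the copy through a
// single 32-bit integer value (IntegerType::get(Ctx, Size << 3))
// instead of a call to memcpy.
void copy4(void *dst, const void *src) {
  uint32_t tmp;               // the i32 temporary
  std::memcpy(&tmp, src, 4);  // becomes: load i32
  std::memcpy(dst, &tmp, 4);  // becomes: store i32
}
```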
@@ -200,8 +187,8 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
       CopyMD = cast<MDNode>(M->getOperand(2));
   }

-  Value *Src = Builder.CreateBitCast(MI->getArgOperand(1), NewSrcPtrTy);
-  Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
+  Value *Src = MI->getArgOperand(1);
+  Value *Dest = MI->getArgOperand(0);
   LoadInst *L = Builder.CreateLoad(IntType, Src);
   // Alignment from the mem intrinsic will be better, so use it.
   L->setAlignment(*CopySrcAlign);
@@ -291,9 +278,6 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
     Type *ITy = IntegerType::get(MI->getContext(), Len*8);  // n=1 -> i8.

     Value *Dest = MI->getDest();
-    unsigned DstAddrSp = cast<PointerType>(Dest->getType())->getAddressSpace();
-    Type *NewDstPtrTy = PointerType::get(ITy, DstAddrSp);
-    Dest = Builder.CreateBitCast(Dest, NewDstPtrTy);

     // Extract the fill value and store.
     const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL;
@@ -301,7 +285,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) {
     StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile());
     S->copyMetadata(*MI, LLVMContext::MD_DIAssignID);
     for (auto *DAI : at::getAssignmentMarkers(S)) {
-      if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; }))
+      if (llvm::is_contained(DAI->location_ops(), FillC))
         DAI->replaceVariableLocationOp(FillC, FillVal);
     }

@@ -500,8 +484,6 @@ static Instruction *simplifyInvariantGroupIntrinsic(IntrinsicInst &II,
   if (Result->getType()->getPointerAddressSpace() !=
       II.getType()->getPointerAddressSpace())
     Result = IC.Builder.CreateAddrSpaceCast(Result, II.getType());
-  if (Result->getType() != II.getType())
-    Result = IC.Builder.CreateBitCast(Result, II.getType());

   return cast<Instruction>(Result);
 }
@@ -532,6 +514,8 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {
     return IC.replaceInstUsesWith(II, ConstantInt::getNullValue(II.getType()));
   }

+  Constant *C;
+
   if (IsTZ) {
     // cttz(-x) -> cttz(x)
     if (match(Op0, m_Neg(m_Value(X))))
@@ -567,6 +551,38 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) {

     if (match(Op0, m_Intrinsic<Intrinsic::abs>(m_Value(X))))
       return IC.replaceOperand(II, 0, X);
+
+    // cttz(shl(%const, %val), 1) --> add(cttz(%const, 1), %val)
+    if (match(Op0, m_Shl(m_ImmConstant(C), m_Value(X))) &&
+        match(Op1, m_One())) {
+      Value *ConstCttz =
+          IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1);
+      return BinaryOperator::CreateAdd(ConstCttz, X);
+    }
+
+    // cttz(lshr exact (%const, %val), 1) --> sub(cttz(%const, 1), %val)
+    if (match(Op0, m_Exact(m_LShr(m_ImmConstant(C), m_Value(X)))) &&
+        match(Op1, m_One())) {
+      Value *ConstCttz =
+          IC.Builder.CreateBinaryIntrinsic(Intrinsic::cttz, C, Op1);
+      return BinaryOperator::CreateSub(ConstCttz, X);
+    }
+  } else {
+    // ctlz(lshr(%const, %val), 1) --> add(ctlz(%const, 1), %val)
+    if (match(Op0, m_LShr(m_ImmConstant(C), m_Value(X))) &&
+        match(Op1, m_One())) {
+      Value *ConstCtlz =
+          IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1);
+      return BinaryOperator::CreateAdd(ConstCtlz, X);
+    }
+
+    // ctlz(shl nuw (%const, %val), 1) --> sub(ctlz(%const, 1), %val)
+    if (match(Op0, m_NUWShl(m_ImmConstant(C), m_Value(X))) &&
+        match(Op1, m_One())) {
+      Value *ConstCtlz =
+          IC.Builder.CreateBinaryIntrinsic(Intrinsic::ctlz, C, Op1);
+      return BinaryOperator::CreateSub(ConstCtlz, X);
+    }
   }

   KnownBits Known = IC.computeKnownBits(Op0, 0, &II);
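The new cttz fold exploits that shifting a nonzero constant left by x adds exactly x trailing zeros (the zero-is-poison flag of `1` excludes the case where everything is shifted out). A standalone check of that identity — my sketch, requires C++20 for std::countr_zero:

```cpp
#include <bit>
#include <cassert>
#include <cstdint>

// countr_zero(c << x) == countr_zero(c) + x while c << x stays nonzero,
// mirroring cttz(shl(C, X), 1) --> add(cttz(C, 1), X).
int main() {
  const uint32_t c = 0x90; // trailing zeros: 4, highest set bit: 7
  for (uint32_t x = 0; x < 24; ++x)
    assert(std::countr_zero(c << x) == std::countr_zero(c) + int(x));
}
```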
*InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { Value *FAbsSrc; if (match(Src0, m_FAbs(m_Value(FAbsSrc)))) { - II.setArgOperand(1, ConstantInt::get(Src1->getType(), fabs(Mask))); + II.setArgOperand(1, ConstantInt::get(Src1->getType(), inverse_fabs(Mask))); return replaceOperand(II, 0, FAbsSrc); } - // TODO: is.fpclass(x, fcInf) -> fabs(x) == inf + if ((OrderedMask == fcInf || OrderedInvertedMask == fcInf) && + (IsOrdered || IsUnordered) && !IsStrict) { + // is.fpclass(x, fcInf) -> fcmp oeq fabs(x), +inf + // is.fpclass(x, ~fcInf) -> fcmp one fabs(x), +inf + // is.fpclass(x, fcInf|fcNan) -> fcmp ueq fabs(x), +inf + // is.fpclass(x, ~(fcInf|fcNan)) -> fcmp une fabs(x), +inf + Constant *Inf = ConstantFP::getInfinity(Src0->getType()); + FCmpInst::Predicate Pred = + IsUnordered ? FCmpInst::FCMP_UEQ : FCmpInst::FCMP_OEQ; + if (OrderedInvertedMask == fcInf) + Pred = IsUnordered ? FCmpInst::FCMP_UNE : FCmpInst::FCMP_ONE; + + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, Src0); + Value *CmpInf = Builder.CreateFCmp(Pred, Fabs, Inf); + CmpInf->takeName(&II); + return replaceInstUsesWith(II, CmpInf); + } if ((OrderedMask == fcPosInf || OrderedMask == fcNegInf) && (IsOrdered || IsUnordered) && !IsStrict) { @@ -992,8 +1024,7 @@ Instruction *InstCombinerImpl::foldIntrinsicIsFPClass(IntrinsicInst &II) { return replaceInstUsesWith(II, FCmp); } - KnownFPClass Known = computeKnownFPClass( - Src0, DL, Mask, 0, &getTargetLibraryInfo(), &AC, &II, &DT); + KnownFPClass Known = computeKnownFPClass(Src0, Mask, &II); // Clear test bits we know must be false from the source value. // fp_class (nnan x), qnan|snan|other -> fp_class (nnan x), other @@ -1030,6 +1061,20 @@ static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +static std::optional<bool> getKnownSignOrZero(Value *Op, Instruction *CxtI, + const DataLayout &DL, + AssumptionCache *AC, + DominatorTree *DT) { + if (std::optional<bool> Sign = getKnownSign(Op, CxtI, DL, AC, DT)) + return Sign; + + Value *X, *Y; + if (match(Op, m_NSWSub(m_Value(X), m_Value(Y)))) + return isImpliedByDomCondition(ICmpInst::ICMP_SLE, X, Y, CxtI, DL); + + return std::nullopt; +} + /// Return true if two values \p Op0 and \p Op1 are known to have the same sign. 
static bool signBitMustBeTheSame(Value *Op0, Value *Op1, Instruction *CxtI,
                                 const DataLayout &DL, AssumptionCache *AC,
@@ -1530,12 +1575,15 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X))))
       return replaceOperand(*II, 0, X);
-    if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) {
-      // abs(x) -> x if x >= 0
-      if (!*Sign)
+    if (std::optional<bool> Known =
+            getKnownSignOrZero(IIOperand, II, DL, &AC, &DT)) {
+      // abs(x) -> x if x >= 0 (include abs(x-y) --> x - y where x >= y)
+      // abs(x) -> x if x > 0 (include abs(x-y) --> x - y where x > y)
+      if (!*Known)
         return replaceInstUsesWith(*II, IIOperand);
       // abs(x) -> -x if x < 0
+      // abs(x) -> -x if x <= 0 (include abs(x-y) --> y - x where x <= y)
       if (IntMinIsPoison)
         return BinaryOperator::CreateNSWNeg(IIOperand);
       return BinaryOperator::CreateNeg(IIOperand);
@@ -1580,8 +1628,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Constant *C;
     if (match(I0, m_ZExt(m_Value(X))) && match(I1, m_Constant(C)) &&
         I0->hasOneUse()) {
-      Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
-      if (ConstantExpr::getZExt(NarrowC, II->getType()) == C) {
+      if (Constant *NarrowC = getLosslessUnsignedTrunc(C, X->getType())) {
         Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
         return CastInst::Create(Instruction::ZExt, NarrowMaxMin, II->getType());
       }
@@ -1603,13 +1650,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
     Constant *C;
     if (match(I0, m_SExt(m_Value(X))) && match(I1, m_Constant(C)) &&
         I0->hasOneUse()) {
-      Constant *NarrowC = ConstantExpr::getTrunc(C, X->getType());
-      if (ConstantExpr::getSExt(NarrowC, II->getType()) == C) {
+      if (Constant *NarrowC = getLosslessSignedTrunc(C, X->getType())) {
         Value *NarrowMaxMin = Builder.CreateBinaryIntrinsic(IID, X, NarrowC);
         return CastInst::Create(Instruction::SExt, NarrowMaxMin, II->getType());
       }
     }
+    // umin(i1 X, i1 Y) -> and i1 X, Y
+    // smax(i1 X, i1 Y) -> and i1 X, Y
+    if ((IID == Intrinsic::umin || IID == Intrinsic::smax) &&
+        II->getType()->isIntOrIntVectorTy(1)) {
+      return BinaryOperator::CreateAnd(I0, I1);
+    }
+
+    // umax(i1 X, i1 Y) -> or i1 X, Y
+    // smin(i1 X, i1 Y) -> or i1 X, Y
+    if ((IID == Intrinsic::umax || IID == Intrinsic::smin) &&
+        II->getType()->isIntOrIntVectorTy(1)) {
+      return BinaryOperator::CreateOr(I0, I1);
+    }
+
     if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
       // smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y)
       // smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y)
@@ -1672,12 +1732,12 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * {
     Value *A;
     if (match(X, m_OneUse(m_Not(m_Value(A)))) &&
-        !isFreeToInvert(A, A->hasOneUse()) &&
-        isFreeToInvert(Y, Y->hasOneUse())) {
-      Value *NotY = Builder.CreateNot(Y);
-      Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID);
-      Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY);
-      return BinaryOperator::CreateNot(InvMaxMin);
+        !isFreeToInvert(A, A->hasOneUse())) {
+      if (Value *NotY = getFreelyInverted(Y, Y->hasOneUse(), &Builder)) {
+        Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID);
+        Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY);
+        return BinaryOperator::CreateNot(InvMaxMin);
+      }
     }
     return nullptr;
   };
@@ -1929,6 +1989,52 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       return &CI;
     break;
   }
+  case Intrinsic::ptrmask: {
+    unsigned BitWidth =
DL.getPointerTypeSizeInBits(II->getType()); + KnownBits Known(BitWidth); + if (SimplifyDemandedInstructionBits(*II, Known)) + return II; + + Value *InnerPtr, *InnerMask; + bool Changed = false; + // Combine: + // (ptrmask (ptrmask p, A), B) + // -> (ptrmask p, (and A, B)) + if (match(II->getArgOperand(0), + m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(InnerPtr), + m_Value(InnerMask))))) { + assert(II->getArgOperand(1)->getType() == InnerMask->getType() && + "Mask types must match"); + // TODO: If InnerMask == Op1, we could copy attributes from inner + // callsite -> outer callsite. + Value *NewMask = Builder.CreateAnd(II->getArgOperand(1), InnerMask); + replaceOperand(CI, 0, InnerPtr); + replaceOperand(CI, 1, NewMask); + Changed = true; + } + + // See if we can deduce non-null. + if (!CI.hasRetAttr(Attribute::NonNull) && + (Known.isNonZero() || + isKnownNonZero(II, DL, /*Depth*/ 0, &AC, II, &DT))) { + CI.addRetAttr(Attribute::NonNull); + Changed = true; + } + + unsigned NewAlignmentLog = + std::min(Value::MaxAlignmentExponent, + std::min(BitWidth - 1, Known.countMinTrailingZeros())); + // Known bits will capture if we had alignment information associated with + // the pointer argument. + if (NewAlignmentLog > Log2(CI.getRetAlign().valueOrOne())) { + CI.addRetAttr(Attribute::getWithAlignment( + CI.getContext(), Align(uint64_t(1) << NewAlignmentLog))); + Changed = true; + } + if (Changed) + return &CI; + break; + } case Intrinsic::uadd_with_overflow: case Intrinsic::sadd_with_overflow: { if (Instruction *I = foldIntrinsicWithOverflowCommon(II)) @@ -2493,10 +2599,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { VectorType *NewVT = cast<VectorType>(II->getType()); if (Constant *CV0 = dyn_cast<Constant>(Arg0)) { if (Constant *CV1 = dyn_cast<Constant>(Arg1)) { - CV0 = ConstantExpr::getIntegerCast(CV0, NewVT, /*isSigned=*/!Zext); - CV1 = ConstantExpr::getIntegerCast(CV1, NewVT, /*isSigned=*/!Zext); - - return replaceInstUsesWith(CI, ConstantExpr::getMul(CV0, CV1)); + Value *V0 = Builder.CreateIntCast(CV0, NewVT, /*isSigned=*/!Zext); + Value *V1 = Builder.CreateIntCast(CV1, NewVT, /*isSigned=*/!Zext); + return replaceInstUsesWith(CI, Builder.CreateMul(V0, V1)); } // Couldn't simplify - canonicalize constant to the RHS. @@ -2950,24 +3055,27 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceOperand(CI, 0, InsertTuple); } - auto *DstTy = dyn_cast<FixedVectorType>(ReturnType); - auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); + auto *DstTy = dyn_cast<VectorType>(ReturnType); + auto *VecTy = dyn_cast<VectorType>(Vec->getType()); - // Only canonicalize if the the destination vector and Vec are fixed - // vectors. if (DstTy && VecTy) { - unsigned DstNumElts = DstTy->getNumElements(); - unsigned VecNumElts = VecTy->getNumElements(); + auto DstEltCnt = DstTy->getElementCount(); + auto VecEltCnt = VecTy->getElementCount(); unsigned IdxN = cast<ConstantInt>(Idx)->getZExtValue(); // Extracting the entirety of Vec is a nop. - if (VecNumElts == DstNumElts) { + if (DstEltCnt == VecTy->getElementCount()) { replaceInstUsesWith(CI, Vec); return eraseInstFromFunction(CI); } + // Only canonicalize to shufflevector if the destination vector and + // Vec are fixed vectors. 
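      // Illustrative sketch (assumed IR, not part of this patch): extracting
      // a low prefix of a fixed vector becomes a prefix shuffle, e.g.
      //   %r = call <2 x i32> @llvm.vector.extract.v2i32.v4i32(<4 x i32> %v, i64 0)
      // -> %r = shufflevector <4 x i32> %v, <4 x i32> poison, <2 x i32> <i32 0, i32 1>
      // A scalable source or destination has no constant-length shuffle mask,
      // so we bail out below.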
+ if (VecEltCnt.isScalable() || DstEltCnt.isScalable()) + break; + SmallVector<int, 8> Mask; - for (unsigned i = 0; i != DstNumElts; ++i) + for (unsigned i = 0; i != DstEltCnt.getKnownMinValue(); ++i) Mask.push_back(IdxN + i); Value *Shuffle = Builder.CreateShuffleVector(Vec, Mask); @@ -3943,9 +4051,9 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); - Instruction *InsertPt = NewCall->getInsertionPointAfterDef(); - assert(InsertPt && "No place to insert cast"); - InsertNewInstBefore(NC, *InsertPt); + auto OptInsertPt = NewCall->getInsertionPointAfterDef(); + assert(OptInsertPt && "No place to insert cast"); + InsertNewInstBefore(NC, *OptInsertPt); Worklist.pushUsersToWorkList(*Caller); } else { NV = PoisonValue::get(Caller->getType()); @@ -3972,8 +4080,6 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Instruction * InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, IntrinsicInst &Tramp) { - Value *Callee = Call.getCalledOperand(); - Type *CalleeTy = Callee->getType(); FunctionType *FTy = Call.getFunctionType(); AttributeList Attrs = Call.getAttributes(); @@ -4070,12 +4176,8 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Replace the trampoline call with a direct call. Let the generic // code sort out any function type mismatches. - FunctionType *NewFTy = FunctionType::get(FTy->getReturnType(), NewTypes, - FTy->isVarArg()); - Constant *NewCallee = - NestF->getType() == PointerType::getUnqual(NewFTy) ? - NestF : ConstantExpr::getBitCast(NestF, - PointerType::getUnqual(NewFTy)); + FunctionType *NewFTy = + FunctionType::get(FTy->getReturnType(), NewTypes, FTy->isVarArg()); AttributeList NewPAL = AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(), Attrs.getRetAttrs(), NewArgAttrs); @@ -4085,19 +4187,18 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, Instruction *NewCaller; if (InvokeInst *II = dyn_cast<InvokeInst>(&Call)) { - NewCaller = InvokeInst::Create(NewFTy, NewCallee, - II->getNormalDest(), II->getUnwindDest(), - NewArgs, OpBundles); + NewCaller = InvokeInst::Create(NewFTy, NestF, II->getNormalDest(), + II->getUnwindDest(), NewArgs, OpBundles); cast<InvokeInst>(NewCaller)->setCallingConv(II->getCallingConv()); cast<InvokeInst>(NewCaller)->setAttributes(NewPAL); } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(&Call)) { NewCaller = - CallBrInst::Create(NewFTy, NewCallee, CBI->getDefaultDest(), + CallBrInst::Create(NewFTy, NestF, CBI->getDefaultDest(), CBI->getIndirectDests(), NewArgs, OpBundles); cast<CallBrInst>(NewCaller)->setCallingConv(CBI->getCallingConv()); cast<CallBrInst>(NewCaller)->setAttributes(NewPAL); } else { - NewCaller = CallInst::Create(NewFTy, NewCallee, NewArgs, OpBundles); + NewCaller = CallInst::Create(NewFTy, NestF, NewArgs, OpBundles); cast<CallInst>(NewCaller)->setTailCallKind( cast<CallInst>(Call).getTailCallKind()); cast<CallInst>(NewCaller)->setCallingConv( @@ -4113,7 +4214,6 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Replace the trampoline call with a direct call. Since there is no 'nest' // parameter, there is no need to adjust the argument list. Let the generic // code sort out any function type mismatches. 
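  // Illustrative sketch (assumed IR, not part of this patch): when the callee
  // takes no 'nest' parameter, a call through the adjusted trampoline
  //   %p = call ptr @llvm.adjust.trampoline(ptr %tramp) ; %tramp set up for @f
  //   call void %p(i32 %x)
  // simply becomes a direct call:
  //   call void @f(i32 %x)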
- Constant *NewCallee = ConstantExpr::getBitCast(NestF, CalleeTy); - Call.setCalledFunction(FTy, NewCallee); + Call.setCalledFunction(FTy, NestF); return &Call; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 5c84f666616d..6629ca840a67 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -29,11 +29,8 @@ using namespace PatternMatch; /// true for, actually insert the code to evaluate the expression. Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, bool isSigned) { - if (Constant *C = dyn_cast<Constant>(V)) { - C = ConstantExpr::getIntegerCast(C, Ty, isSigned /*Sext or ZExt*/); - // If we got a constantexpr back, try to simplify it with DL info. - return ConstantFoldConstant(C, DL, &TLI); - } + if (Constant *C = dyn_cast<Constant>(V)) + return ConstantFoldIntegerCast(C, Ty, isSigned, DL); // Otherwise, it must be an instruction. Instruction *I = cast<Instruction>(V); @@ -112,7 +109,7 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, } Res->takeName(I); - return InsertNewInstWith(Res, *I); + return InsertNewInstWith(Res, I->getIterator()); } Instruction::CastOps @@ -217,7 +214,8 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { /// free to be evaluated in that type. This is a helper for canEvaluate*. static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { if (isa<Constant>(V)) - return true; + return match(V, m_ImmConstant()); + Value *X; if ((match(V, m_ZExtOrSExt(m_Value(X))) || match(V, m_Trunc(m_Value(X)))) && X->getType() == Ty) @@ -229,7 +227,6 @@ static bool canAlwaysEvaluateInType(Value *V, Type *Ty) { /// Filter out values that we can not evaluate in the destination type for free. /// This is a helper for canEvaluate*. static bool canNotEvaluateInType(Value *V, Type *Ty) { - assert(!isa<Constant>(V) && "Constant should already be handled."); if (!isa<Instruction>(V)) return true; // We don't extend or shrink something that has multiple uses -- doing so @@ -505,11 +502,13 @@ Instruction *InstCombinerImpl::narrowFunnelShift(TruncInst &Trunc) { if (!MaskedValueIsZero(ShVal1, HiBitMask, 0, &Trunc)) return nullptr; - // We have an unnecessarily wide rotate! - // trunc (or (shl ShVal0, ShAmt), (lshr ShVal1, BitWidth - ShAmt)) - // Narrow the inputs and convert to funnel shift intrinsic: - // llvm.fshl.i8(trunc(ShVal), trunc(ShVal), trunc(ShAmt)) - Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + // Adjust the width of ShAmt for narrowed funnel shift operation: + // - Zero-extend if ShAmt is narrower than the destination type. + // - Truncate if ShAmt is wider, discarding non-significant high-order bits. + // This prepares ShAmt for llvm.fshl.i8(trunc(ShVal), trunc(ShVal), + // zext/trunc(ShAmt)). + Value *NarrowShAmt = Builder.CreateZExtOrTrunc(ShAmt, DestTy); + Value *X, *Y; X = Y = Builder.CreateTrunc(ShVal0, DestTy); if (ShVal0 != ShVal1) @@ -582,13 +581,15 @@ Instruction *InstCombinerImpl::narrowBinOp(TruncInst &Trunc) { APInt(SrcWidth, MaxShiftAmt)))) { auto *OldShift = cast<Instruction>(Trunc.getOperand(0)); bool IsExact = OldShift->isExact(); - auto *ShAmt = ConstantExpr::getIntegerCast(C, A->getType(), true); - ShAmt = Constant::mergeUndefsWith(ShAmt, C); - Value *Shift = - OldShift->getOpcode() == Instruction::AShr - ? 
Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) - : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); - return CastInst::CreateTruncOrBitCast(Shift, DestTy); + if (Constant *ShAmt = ConstantFoldIntegerCast(C, A->getType(), + /*IsSigned*/ true, DL)) { + ShAmt = Constant::mergeUndefsWith(ShAmt, C); + Value *Shift = + OldShift->getOpcode() == Instruction::AShr + ? Builder.CreateAShr(A, ShAmt, OldShift->getName(), IsExact) + : Builder.CreateLShr(A, ShAmt, OldShift->getName(), IsExact); + return CastInst::CreateTruncOrBitCast(Shift, DestTy); + } } } break; @@ -904,19 +905,18 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. // zext (X != 0) to i32 --> X iff X has only the low bit set. // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. - if (Op1CV->isZero() && Cmp->isEquality() && - (Cmp->getOperand(0)->getType() == Zext.getType() || - Cmp->getPredicate() == ICmpInst::ICMP_NE)) { - // If Op1C some other power of two, convert: - KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); + if (Op1CV->isZero() && Cmp->isEquality()) { // Exactly 1 possible 1? But not the high-bit because that is // canonicalized to this form. + KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); APInt KnownZeroMask(~Known.Zero); - if (KnownZeroMask.isPowerOf2() && - (Zext.getType()->getScalarSizeInBits() != - KnownZeroMask.logBase2() + 1)) { - uint32_t ShAmt = KnownZeroMask.logBase2(); + uint32_t ShAmt = KnownZeroMask.logBase2(); + bool IsExpectShAmt = KnownZeroMask.isPowerOf2() && + (Zext.getType()->getScalarSizeInBits() != ShAmt + 1); + if (IsExpectShAmt && + (Cmp->getOperand(0)->getType() == Zext.getType() || + Cmp->getPredicate() == ICmpInst::ICMP_NE || ShAmt == 0)) { Value *In = Cmp->getOperand(0); if (ShAmt) { // Perform a logical shr by shiftamt. @@ -1184,14 +1184,14 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { Value *X; if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && X->getType() == DestTy) - return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy)); + return BinaryOperator::CreateAnd(X, Builder.CreateZExt(C, DestTy)); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && X->getType() == DestTy) { - Constant *ZC = ConstantExpr::getZExt(C, DestTy); + Value *ZC = Builder.CreateZExt(C, DestTy); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } @@ -1202,7 +1202,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // zext (and (trunc X), C) --> and X, (zext C) if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) && X->getType() == DestTy) { - Constant *ZextC = ConstantExpr::getZExt(C, DestTy); + Value *ZextC = Builder.CreateZExt(C, DestTy); return BinaryOperator::CreateAnd(X, ZextC); } @@ -1221,6 +1221,22 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { } } + if (!Zext.hasNonNeg()) { + // If this zero extend is only used by a shift, add nneg flag. 
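    // Illustrative sketch (assumed IR, not part of this patch):
    //   %amt = zext i8 %n to i32
    //   %s   = shl i32 %x, %amt
    // A shift amount of 32 or more would make the shl poison, so %n is
    // effectively < 32 and its sign bit must be clear; the zext can be nneg.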
+ if (Zext.hasOneUse() && + SrcTy->getScalarSizeInBits() > + Log2_64_Ceil(DestTy->getScalarSizeInBits()) && + match(Zext.user_back(), m_Shift(m_Value(), m_Specific(&Zext)))) { + Zext.setNonNeg(); + return &Zext; + } + + if (isKnownNonNegative(Src, SQ.getWithInstruction(&Zext))) { + Zext.setNonNeg(); + return &Zext; + } + } + return nullptr; } @@ -1373,8 +1389,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { unsigned DestBitSize = DestTy->getScalarSizeInBits(); // If the value being extended is zero or positive, use a zext instead. - if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT)) - return CastInst::Create(Instruction::ZExt, Src, DestTy); + if (isKnownNonNegative(Src, SQ.getWithInstruction(&Sext))) { + auto CI = CastInst::Create(Instruction::ZExt, Src, DestTy); + CI->setNonNeg(true); + return CI; + } // Try to extend the entire expression tree to the wide destination type. if (shouldChangeType(SrcTy, DestTy) && canEvaluateSExtd(Src, DestTy)) { @@ -1445,9 +1464,11 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { // TODO: Eventually this could be subsumed by EvaluateInDifferentType. Constant *BA = nullptr, *CA = nullptr; if (match(Src, m_AShr(m_Shl(m_Trunc(m_Value(A)), m_Constant(BA)), - m_Constant(CA))) && + m_ImmConstant(CA))) && BA->isElementWiseEqual(CA) && A->getType() == DestTy) { - Constant *WideCurrShAmt = ConstantExpr::getSExt(CA, DestTy); + Constant *WideCurrShAmt = + ConstantFoldCastOperand(Instruction::SExt, CA, DestTy, DL); + assert(WideCurrShAmt && "Constant folding of ImmConstant cannot fail"); Constant *NumLowbitsLeft = ConstantExpr::getSub( ConstantInt::get(DestTy, SrcTy->getScalarSizeInBits()), WideCurrShAmt); Constant *NewShAmt = ConstantExpr::getSub( @@ -1915,29 +1936,6 @@ Instruction *InstCombinerImpl::visitIntToPtr(IntToPtrInst &CI) { return nullptr; } -/// Implement the transforms for cast of pointer (bitcast/ptrtoint) -Instruction *InstCombinerImpl::commonPointerCastTransforms(CastInst &CI) { - Value *Src = CI.getOperand(0); - - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Src)) { - // If casting the result of a getelementptr instruction with no offset, turn - // this into a cast of the original pointer! - if (GEP->hasAllZeroIndices() && - // If CI is an addrspacecast and GEP changes the poiner type, merging - // GEP into CI would undo canonicalizing addrspacecast with different - // pointer types, causing infinite loops. - (!isa<AddrSpaceCastInst>(CI) || - GEP->getType() == GEP->getPointerOperandType())) { - // Changing the cast operand is usually not a good idea but it is safe - // here because the pointer operand is being replaced with another - // pointer operand so the opcode doesn't need to change. - return replaceOperand(CI, 0, GEP->getOperand(0)); - } - } - - return commonCastTransforms(CI); -} - Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { // If the destination integer type is not the intptr_t type for this target, // do a ptrtoint to intptr_t then do a trunc or zext. This allows the cast @@ -1955,6 +1953,15 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); } + // (ptrtoint (ptrmask P, M)) + // -> (and (ptrtoint P), M) + // This is generally beneficial as `and` is better supported than `ptrmask`. 
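  // Illustrative sketch (assumed IR, not part of this patch):
  //   %mp = call ptr @llvm.ptrmask.p0.i64(ptr %p, i64 -16)
  //   %i  = ptrtoint ptr %mp to i64
  // ->
  //   %pi = ptrtoint ptr %p to i64
  //   %i  = and i64 %pi, -16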
+ Value *Ptr, *Mask; + if (match(SrcOp, m_OneUse(m_Intrinsic<Intrinsic::ptrmask>(m_Value(Ptr), + m_Value(Mask)))) && + Mask->getType() == Ty) + return BinaryOperator::CreateAnd(Builder.CreatePtrToInt(Ptr, Ty), Mask); + if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. // While this can increase the number of instructions it doesn't actually @@ -1979,7 +1986,7 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return InsertElementInst::Create(Vec, NewCast, Index); } - return commonPointerCastTransforms(CI); + return commonCastTransforms(CI); } /// This input value (which is known to have vector type) is being zero extended @@ -2136,9 +2143,12 @@ static bool collectInsertionElements(Value *V, unsigned Shift, Type *ElementIntTy = IntegerType::get(C->getContext(), ElementSize); for (unsigned i = 0; i != NumElts; ++i) { - unsigned ShiftI = Shift+i*ElementSize; - Constant *Piece = ConstantExpr::getLShr(C, ConstantInt::get(C->getType(), - ShiftI)); + unsigned ShiftI = Shift + i * ElementSize; + Constant *Piece = ConstantFoldBinaryInstruction( + Instruction::LShr, C, ConstantInt::get(C->getType(), ShiftI)); + if (!Piece) + return false; + Piece = ConstantExpr::getTrunc(Piece, ElementIntTy); if (!collectInsertionElements(Piece, ShiftI, Elements, VecEltTy, isBigEndian)) @@ -2701,11 +2711,9 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (Instruction *I = foldBitCastSelect(CI, Builder)) return I; - if (SrcTy->isPointerTy()) - return commonPointerCastTransforms(CI); return commonCastTransforms(CI); } Instruction *InstCombinerImpl::visitAddrSpaceCast(AddrSpaceCastInst &CI) { - return commonPointerCastTransforms(CI); + return commonCastTransforms(CI); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 656f04370e17..e42e011bd436 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -12,12 +12,14 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APSInt.h" +#include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Utils/Local.h" #include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" @@ -26,6 +28,7 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <bitset> using namespace llvm; using namespace PatternMatch; @@ -412,7 +415,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( /// Returns true if we can rewrite Start as a GEP with pointer Base /// and some integer offset. The nodes that need to be re-written /// for this transformation will be added to Explored. 
-static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, +static bool canRewriteGEPAsOffset(Value *Start, Value *Base, const DataLayout &DL, SetVector<Value *> &Explored) { SmallVector<Value *, 16> WorkList(1, Start); @@ -440,27 +443,15 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, continue; } - if (!isa<IntToPtrInst>(V) && !isa<PtrToIntInst>(V) && - !isa<GetElementPtrInst>(V) && !isa<PHINode>(V)) + if (!isa<GetElementPtrInst>(V) && !isa<PHINode>(V)) // We've found some value that we can't explore which is different from // the base. Therefore we can't do this transformation. return false; - if (isa<IntToPtrInst>(V) || isa<PtrToIntInst>(V)) { - auto *CI = cast<CastInst>(V); - if (!CI->isNoopCast(DL)) - return false; - - if (!Explored.contains(CI->getOperand(0))) - WorkList.push_back(CI->getOperand(0)); - } - if (auto *GEP = dyn_cast<GEPOperator>(V)) { - // We're limiting the GEP to having one index. This will preserve - // the original pointer type. We could handle more cases in the - // future. - if (GEP->getNumIndices() != 1 || !GEP->isInBounds() || - GEP->getSourceElementType() != ElemTy) + // Only allow inbounds GEPs with at most one variable offset. + auto IsNonConst = [](Value *V) { return !isa<ConstantInt>(V); }; + if (!GEP->isInBounds() || count_if(GEP->indices(), IsNonConst) > 1) return false; if (!Explored.contains(GEP->getOperand(0))) @@ -514,7 +505,8 @@ static bool canRewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, static void setInsertionPoint(IRBuilder<> &Builder, Value *V, bool Before = true) { if (auto *PHI = dyn_cast<PHINode>(V)) { - Builder.SetInsertPoint(&*PHI->getParent()->getFirstInsertionPt()); + BasicBlock *Parent = PHI->getParent(); + Builder.SetInsertPoint(Parent, Parent->getFirstInsertionPt()); return; } if (auto *I = dyn_cast<Instruction>(V)) { @@ -526,7 +518,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, if (auto *A = dyn_cast<Argument>(V)) { // Set the insertion point in the entry block. BasicBlock &Entry = A->getParent()->getEntryBlock(); - Builder.SetInsertPoint(&*Entry.getFirstInsertionPt()); + Builder.SetInsertPoint(&Entry, Entry.getFirstInsertionPt()); return; } // Otherwise, this is a constant and we don't need to set a new @@ -536,7 +528,7 @@ static void setInsertionPoint(IRBuilder<> &Builder, Value *V, /// Returns a re-written value of Start as an indexed GEP using Base as a /// pointer. -static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, +static Value *rewriteGEPAsOffset(Value *Start, Value *Base, const DataLayout &DL, SetVector<Value *> &Explored, InstCombiner &IC) { @@ -567,36 +559,18 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Create all the other instructions. for (Value *Val : Explored) { - if (NewInsts.contains(Val)) continue; - if (auto *CI = dyn_cast<CastInst>(Val)) { - // Don't get rid of the intermediate variable here; the store can grow - // the map which will invalidate the reference to the input value. - Value *V = NewInsts[CI->getOperand(0)]; - NewInsts[CI] = V; - continue; - } if (auto *GEP = dyn_cast<GEPOperator>(Val)) { - Value *Index = NewInsts[GEP->getOperand(1)] ? NewInsts[GEP->getOperand(1)] - : GEP->getOperand(1); setInsertionPoint(Builder, GEP); - // Indices might need to be sign extended. GEPs will magically do - // this, but we need to do it ourselves here. 
- if (Index->getType()->getScalarSizeInBits() != - NewInsts[GEP->getOperand(0)]->getType()->getScalarSizeInBits()) { - Index = Builder.CreateSExtOrTrunc( - Index, NewInsts[GEP->getOperand(0)]->getType(), - GEP->getOperand(0)->getName() + ".sext"); - } - - auto *Op = NewInsts[GEP->getOperand(0)]; + Value *Op = NewInsts[GEP->getOperand(0)]; + Value *OffsetV = emitGEPOffset(&Builder, DL, GEP); if (isa<ConstantInt>(Op) && cast<ConstantInt>(Op)->isZero()) - NewInsts[GEP] = Index; + NewInsts[GEP] = OffsetV; else NewInsts[GEP] = Builder.CreateNSWAdd( - Op, Index, GEP->getOperand(0)->getName() + ".add"); + Op, OffsetV, GEP->getOperand(0)->getName() + ".add"); continue; } if (isa<PHINode>(Val)) @@ -624,23 +598,14 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, } } - PointerType *PtrTy = - ElemTy->getPointerTo(Start->getType()->getPointerAddressSpace()); for (Value *Val : Explored) { if (Val == Base) continue; - // Depending on the type, for external users we have to emit - // a GEP or a GEP + ptrtoint. setInsertionPoint(Builder, Val, false); - - // Cast base to the expected type. - Value *NewVal = Builder.CreateBitOrPointerCast( - Base, PtrTy, Start->getName() + "to.ptr"); - NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]), - Val->getName() + ".ptr"); - NewVal = Builder.CreateBitOrPointerCast( - NewVal, Val->getType(), Val->getName() + ".conv"); + // Create GEP for external users. + Value *NewVal = Builder.CreateInBoundsGEP( + Builder.getInt8Ty(), Base, NewInsts[Val], Val->getName() + ".ptr"); IC.replaceInstUsesWith(*cast<Instruction>(Val), NewVal); // Add old instruction to worklist for DCE. We don't directly remove it // here because the original compare is one of the users. @@ -650,48 +615,6 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, return NewInsts[Start]; } -/// Looks through GEPs, IntToPtrInsts and PtrToIntInsts in order to express -/// the input Value as a constant indexed GEP. Returns a pair containing -/// the GEPs Pointer and Index. -static std::pair<Value *, Value *> -getAsConstantIndexedAddress(Type *ElemTy, Value *V, const DataLayout &DL) { - Type *IndexType = IntegerType::get(V->getContext(), - DL.getIndexTypeSizeInBits(V->getType())); - - Constant *Index = ConstantInt::getNullValue(IndexType); - while (true) { - if (GEPOperator *GEP = dyn_cast<GEPOperator>(V)) { - // We accept only inbouds GEPs here to exclude the possibility of - // overflow. - if (!GEP->isInBounds()) - break; - if (GEP->hasAllConstantIndices() && GEP->getNumIndices() == 1 && - GEP->getSourceElementType() == ElemTy) { - V = GEP->getOperand(0); - Constant *GEPIndex = static_cast<Constant *>(GEP->getOperand(1)); - Index = ConstantExpr::getAdd( - Index, ConstantExpr::getSExtOrTrunc(GEPIndex, IndexType)); - continue; - } - break; - } - if (auto *CI = dyn_cast<IntToPtrInst>(V)) { - if (!CI->isNoopCast(DL)) - break; - V = CI->getOperand(0); - continue; - } - if (auto *CI = dyn_cast<PtrToIntInst>(V)) { - if (!CI->isNoopCast(DL)) - break; - V = CI->getOperand(0); - continue; - } - break; - } - return {V, Index}; -} - /// Converts (CMP GEPLHS, RHS) if this change would make RHS a constant. /// We can look through PHIs, GEPs and casts in order to determine a common base /// between GEPLHS and RHS. 
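// Illustrative sketch (assumed IR, not part of this patch): given a common
// base pointer,
//   %g = getelementptr inbounds i8, ptr %base, i64 8
//   %c = icmp ult ptr %g, %rhs   ; %rhs reachable from %base via GEPs/PHIs
// both sides can be rewritten as integer byte offsets from %base and the
// pointers compared as signed offsets instead, roughly
//   %c = icmp slt i64 8, %rhs.off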
@@ -706,14 +629,19 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, if (!GEPLHS->hasAllConstantIndices()) return nullptr; - Type *ElemTy = GEPLHS->getSourceElementType(); - Value *PtrBase, *Index; - std::tie(PtrBase, Index) = getAsConstantIndexedAddress(ElemTy, GEPLHS, DL); + APInt Offset(DL.getIndexTypeSizeInBits(GEPLHS->getType()), 0); + Value *PtrBase = + GEPLHS->stripAndAccumulateConstantOffsets(DL, Offset, + /*AllowNonInbounds*/ false); + + // Bail if we looked through addrspacecast. + if (PtrBase->getType() != GEPLHS->getType()) + return nullptr; // The set of nodes that will take part in this transformation. SetVector<Value *> Nodes; - if (!canRewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes)) + if (!canRewriteGEPAsOffset(RHS, PtrBase, DL, Nodes)) return nullptr; // We know we can re-write this as @@ -722,13 +650,14 @@ static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS, // can't have overflow on either side. We can therefore re-write // this as: // OFFSET1 cmp OFFSET2 - Value *NewRHS = rewriteGEPAsOffset(ElemTy, RHS, PtrBase, DL, Nodes, IC); + Value *NewRHS = rewriteGEPAsOffset(RHS, PtrBase, DL, Nodes, IC); // RewriteGEPAsOffset has replaced RHS and all of its uses with a re-written // GEP having PtrBase as the pointer base, and has returned in NewRHS the // offset. Since Index is the offset of LHS to the base pointer, we will now // compare the offsets instead of comparing the pointers. - return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Index, NewRHS); + return new ICmpInst(ICmpInst::getSignedPredicate(Cond), + IC.Builder.getInt(Offset), NewRHS); } /// Fold comparisons between a GEP instruction and something else. At this point @@ -844,17 +773,6 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, return transformToIndexedCompare(GEPLHS, RHS, Cond, DL, *this); } - // If one of the GEPs has all zero indices, recurse. - // FIXME: Handle vector of pointers. - if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) - return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), - ICmpInst::getSwappedPredicate(Cond), I); - - // If the other GEP has all zero indices, recurse. - // FIXME: Handle vector of pointers. - if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) - return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); - bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); if (GEPLHS->getNumOperands() == GEPRHS->getNumOperands() && GEPLHS->getSourceElementType() == GEPRHS->getSourceElementType()) { @@ -894,8 +812,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, // Only lower this if the icmp is the only user of the GEP or if we expect // the result to fold to a constant! 
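    // Illustrative sketch (assumed IR, not part of this patch):
    //   %a = getelementptr inbounds i32, ptr %p, i64 %i
    //   %b = getelementptr inbounds i32, ptr %p, i64 %j
    //   %c = icmp ult ptr %a, %b
    // lowers to a compare of the emitted byte offsets, roughly
    //   %c = icmp slt i64 (mul i64 %i, 4), (mul i64 %j, 4)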
if ((GEPsInBounds || CmpInst::isEquality(Cond)) && - (isa<ConstantExpr>(GEPLHS) || GEPLHS->hasOneUse()) && - (isa<ConstantExpr>(GEPRHS) || GEPRHS->hasOneUse())) { + (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) && + (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse())) { // ((gep Ptr, OFFSET1) cmp (gep Ptr, OFFSET2) ---> (OFFSET1 cmp OFFSET2) Value *L = EmitGEPOffset(GEPLHS); Value *R = EmitGEPOffset(GEPRHS); @@ -1285,9 +1203,9 @@ Instruction *InstCombinerImpl::foldICmpWithZero(ICmpInst &Cmp) { if (Pred == ICmpInst::ICMP_SGT) { Value *A, *B; if (match(Cmp.getOperand(0), m_SMin(m_Value(A), m_Value(B)))) { - if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) + if (isKnownPositive(A, SQ.getWithInstruction(&Cmp))) return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) + if (isKnownPositive(B, SQ.getWithInstruction(&Cmp))) return new ICmpInst(Pred, A, Cmp.getOperand(1)); } } @@ -1554,6 +1472,61 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, return nullptr; } +/// Fold icmp (trunc X), (trunc Y). +/// Fold icmp (trunc X), (zext Y). +Instruction * +InstCombinerImpl::foldICmpTruncWithTruncOrExt(ICmpInst &Cmp, + const SimplifyQuery &Q) { + if (Cmp.isSigned()) + return nullptr; + + Value *X, *Y; + ICmpInst::Predicate Pred; + bool YIsZext = false; + // Try to match icmp (trunc X), (trunc Y) + if (match(&Cmp, m_ICmp(Pred, m_Trunc(m_Value(X)), m_Trunc(m_Value(Y))))) { + if (X->getType() != Y->getType() && + (!Cmp.getOperand(0)->hasOneUse() || !Cmp.getOperand(1)->hasOneUse())) + return nullptr; + if (!isDesirableIntType(X->getType()->getScalarSizeInBits()) && + isDesirableIntType(Y->getType()->getScalarSizeInBits())) { + std::swap(X, Y); + Pred = Cmp.getSwappedPredicate(Pred); + } + } + // Try to match icmp (trunc X), (zext Y) + else if (match(&Cmp, m_c_ICmp(Pred, m_Trunc(m_Value(X)), + m_OneUse(m_ZExt(m_Value(Y)))))) + + YIsZext = true; + else + return nullptr; + + Type *TruncTy = Cmp.getOperand(0)->getType(); + unsigned TruncBits = TruncTy->getScalarSizeInBits(); + + // If this transform will end up changing from desirable types -> undesirable + // types skip it. + if (isDesirableIntType(TruncBits) && + !isDesirableIntType(X->getType()->getScalarSizeInBits())) + return nullptr; + + // Check if the trunc is unneeded. + KnownBits KnownX = llvm::computeKnownBits(X, /*Depth*/ 0, Q); + if (KnownX.countMaxActiveBits() > TruncBits) + return nullptr; + + if (!YIsZext) { + // If Y is also a trunc, make sure it is unneeded. + KnownBits KnownY = llvm::computeKnownBits(Y, /*Depth*/ 0, Q); + if (KnownY.countMaxActiveBits() > TruncBits) + return nullptr; + } + + Value *NewY = Builder.CreateZExtOrTrunc(Y, X->getType()); + return new ICmpInst(Pred, X, NewY); +} + /// Fold icmp (xor X, Y), C. Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, @@ -1944,19 +1917,18 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return nullptr; } -/// Fold icmp eq/ne (or (xor (X1, X2), xor(X3, X4))), 0. -static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, - InstCombiner::BuilderTy &Builder) { - // Are we using xors to bitwise check for a pair or pairs of (in)equalities? - // Convert to a shorter form that has more potential to be folded even - // further. 
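// Illustrative sketch (assumed IR, not part of this patch):
//   %x = xor i32 %a, %b
//   %s = sub i32 %c, %d
//   %o = or i32 %x, %s
//   %r = icmp eq i32 %o, 0
// folds to
//   %e1 = icmp eq i32 %a, %b
//   %e2 = icmp eq i32 %c, %d
//   %r  = and i1 %e1, %e2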
- // ((X1 ^ X2) || (X3 ^ X4)) == 0 --> (X1 == X2) && (X3 == X4) - // ((X1 ^ X2) || (X3 ^ X4)) != 0 --> (X1 != X2) || (X3 != X4) - // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) == 0 --> +/// Fold icmp eq/ne (or (xor/sub (X1, X2), xor/sub (X3, X4))), 0. +static Value *foldICmpOrXorSubChain(ICmpInst &Cmp, BinaryOperator *Or, + InstCombiner::BuilderTy &Builder) { + // Are we using xors or subs to bitwise check for a pair or pairs of + // (in)equalities? Convert to a shorter form that has more potential to be + // folded even further. + // ((X1 ^/- X2) || (X3 ^/- X4)) == 0 --> (X1 == X2) && (X3 == X4) + // ((X1 ^/- X2) || (X3 ^/- X4)) != 0 --> (X1 != X2) || (X3 != X4) + // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) == 0 --> // (X1 == X2) && (X3 == X4) && (X5 == X6) - // ((X1 ^ X2) || (X3 ^ X4) || (X5 ^ X6)) != 0 --> + // ((X1 ^/- X2) || (X3 ^/- X4) || (X5 ^/- X6)) != 0 --> // (X1 != X2) || (X3 != X4) || (X5 != X6) - // TODO: Implement for sub SmallVector<std::pair<Value *, Value *>, 2> CmpValues; SmallVector<Value *, 16> WorkList(1, Or); @@ -1967,9 +1939,16 @@ static Value *foldICmpOrXorChain(ICmpInst &Cmp, BinaryOperator *Or, if (match(OrOperatorArgument, m_OneUse(m_Xor(m_Value(Lhs), m_Value(Rhs))))) { CmpValues.emplace_back(Lhs, Rhs); - } else { - WorkList.push_back(OrOperatorArgument); + return; } + + if (match(OrOperatorArgument, + m_OneUse(m_Sub(m_Value(Lhs), m_Value(Rhs))))) { + CmpValues.emplace_back(Lhs, Rhs); + return; + } + + WorkList.push_back(OrOperatorArgument); }; Value *CurrentValue = WorkList.pop_back_val(); @@ -2082,7 +2061,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, return BinaryOperator::Create(BOpc, CmpP, CmpQ); } - if (Value *V = foldICmpOrXorChain(Cmp, Or, Builder)) + if (Value *V = foldICmpOrXorSubChain(Cmp, Or, Builder)) return replaceInstUsesWith(Cmp, V); return nullptr; @@ -2443,7 +2422,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, // constant-value-based preconditions in the folds below, then we could assert // those conditions rather than checking them. This is difficult because of // undef/poison (PR34838). - if (IsAShr) { + if (IsAShr && Shr->hasOneUse()) { if (IsExact || Pred == CmpInst::ICMP_SLT || Pred == CmpInst::ICMP_ULT) { // When ShAmtC can be shifted losslessly: // icmp PRED (ashr exact X, ShAmtC), C --> icmp PRED X, (C << ShAmtC) @@ -2483,7 +2462,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, ConstantInt::getAllOnesValue(ShrTy)); } } - } else { + } else if (!IsAShr) { if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) @@ -2888,19 +2867,97 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); } +static Value *createLogicFromTable(const std::bitset<4> &Table, Value *Op0, + Value *Op1, IRBuilderBase &Builder, + bool HasOneUse) { + auto FoldConstant = [&](bool Val) { + Constant *Res = Val ? Builder.getTrue() : Builder.getFalse(); + if (Op0->getType()->isVectorTy()) + Res = ConstantVector::getSplat( + cast<VectorType>(Op0->getType())->getElementCount(), Res); + return Res; + }; + + switch (Table.to_ulong()) { + case 0: // 0 0 0 0 + return FoldConstant(false); + case 1: // 0 0 0 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateOr(Op0, Op1)) : nullptr; + case 2: // 0 0 1 0 + return HasOneUse ? 
Builder.CreateAnd(Builder.CreateNot(Op0), Op1) : nullptr; + case 3: // 0 0 1 1 + return Builder.CreateNot(Op0); + case 4: // 0 1 0 0 + return HasOneUse ? Builder.CreateAnd(Op0, Builder.CreateNot(Op1)) : nullptr; + case 5: // 0 1 0 1 + return Builder.CreateNot(Op1); + case 6: // 0 1 1 0 + return Builder.CreateXor(Op0, Op1); + case 7: // 0 1 1 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateAnd(Op0, Op1)) : nullptr; + case 8: // 1 0 0 0 + return Builder.CreateAnd(Op0, Op1); + case 9: // 1 0 0 1 + return HasOneUse ? Builder.CreateNot(Builder.CreateXor(Op0, Op1)) : nullptr; + case 10: // 1 0 1 0 + return Op1; + case 11: // 1 0 1 1 + return HasOneUse ? Builder.CreateOr(Builder.CreateNot(Op0), Op1) : nullptr; + case 12: // 1 1 0 0 + return Op0; + case 13: // 1 1 0 1 + return HasOneUse ? Builder.CreateOr(Op0, Builder.CreateNot(Op1)) : nullptr; + case 14: // 1 1 1 0 + return Builder.CreateOr(Op0, Op1); + case 15: // 1 1 1 1 + return FoldConstant(true); + default: + llvm_unreachable("Invalid Operation"); + } + return nullptr; +} + /// Fold icmp (add X, Y), C. Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, const APInt &C) { Value *Y = Add->getOperand(1); + Value *X = Add->getOperand(0); + + Value *Op0, *Op1; + Instruction *Ext0, *Ext1; + const CmpInst::Predicate Pred = Cmp.getPredicate(); + if (match(Add, + m_Add(m_CombineAnd(m_Instruction(Ext0), m_ZExtOrSExt(m_Value(Op0))), + m_CombineAnd(m_Instruction(Ext1), + m_ZExtOrSExt(m_Value(Op1))))) && + Op0->getType()->isIntOrIntVectorTy(1) && + Op1->getType()->isIntOrIntVectorTy(1)) { + unsigned BW = C.getBitWidth(); + std::bitset<4> Table; + auto ComputeTable = [&](bool Op0Val, bool Op1Val) { + int Res = 0; + if (Op0Val) + Res += isa<ZExtInst>(Ext0) ? 1 : -1; + if (Op1Val) + Res += isa<ZExtInst>(Ext1) ? 1 : -1; + return ICmpInst::compare(APInt(BW, Res, true), C, Pred); + }; + + Table[0] = ComputeTable(false, false); + Table[1] = ComputeTable(false, true); + Table[2] = ComputeTable(true, false); + Table[3] = ComputeTable(true, true); + if (auto *Cond = + createLogicFromTable(Table, Op0, Op1, Builder, Add->hasOneUse())) + return replaceInstUsesWith(Cmp, Cond); + } const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) return nullptr; // Fold icmp pred (add X, C2), C. - Value *X = Add->getOperand(0); Type *Ty = Add->getType(); - const CmpInst::Predicate Pred = Cmp.getPredicate(); // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE @@ -3172,18 +3229,6 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { } } - // Test to see if the operands of the icmp are casted versions of other - // values. If the ptr->ptr cast can be stripped off both arguments, do so. - if (DstType->isPointerTy() && (isa<Constant>(Op1) || isa<BitCastInst>(Op1))) { - // If operand #1 is a bitcast instruction, it must also be a ptr->ptr cast - // so eliminate it as well. - if (auto *BC2 = dyn_cast<BitCastInst>(Op1)) - Op1 = BC2->getOperand(0); - - Op1 = Builder.CreateBitCast(Op1, SrcType); - return new ICmpInst(Pred, BCSrcOp, Op1); - } - const APInt *C; if (!match(Cmp.getOperand(1), m_APInt(C)) || !DstType->isIntegerTy() || !SrcType->isIntOrIntVectorTy()) @@ -3196,10 +3241,12 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0 // Example: are all elements equal? --> are zero elements not equal? 
// TODO: Try harder to reduce compare of 2 freely invertible operands? - if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && - isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { - Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), DstType); - return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); + if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse()) { + if (Value *NotBCSrcOp = + getFreelyInverted(BCSrcOp, BCSrcOp->hasOneUse(), &Builder)) { + Value *Cast = Builder.CreateBitCast(NotBCSrcOp, DstType); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(DstType)); + } } // If this is checking if all elements of an extended vector are clear or not, @@ -3878,21 +3925,9 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { return nullptr; switch (LHSI->getOpcode()) { - case Instruction::GetElementPtr: - // icmp pred GEP (P, int 0, int 0, int 0), null -> icmp pred P, null - if (RHSC->isNullValue() && - cast<GetElementPtrInst>(LHSI)->hasAllZeroIndices()) - return new ICmpInst( - I.getPredicate(), LHSI->getOperand(0), - Constant::getNullValue(LHSI->getOperand(0)->getType())); - break; case Instruction::PHI: - // Only fold icmp into the PHI if the phi and icmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. - if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) - return NV; + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) + return NV; break; case Instruction::IntToPtr: // icmp pred inttoptr(X), null -> icmp pred X, 0 @@ -4243,7 +4278,12 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; - NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy); + if (NewShAmt->getType() != WidestTy) { + NewShAmt = + ConstantFoldCastOperand(Instruction::ZExt, NewShAmt, WidestTy, SQ.DL); + if (!NewShAmt) + return nullptr; + } unsigned WidestBitWidth = WidestTy->getScalarSizeInBits(); // Is the new shift amount smaller than the bit width? @@ -4424,6 +4464,65 @@ static Instruction *foldICmpXNegX(ICmpInst &I, return nullptr; } +static Instruction *foldICmpAndXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + // Normalize and operand as operand 0. + CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_And(m_Specific(Op0), m_Value()))) { + std::swap(Op0, Op1); + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + if (!match(Op0, m_c_And(m_Specific(Op1), m_Value(A)))) + return nullptr; + + // (icmp (X & Y) u< X --> (X & Y) != X + if (Pred == ICmpInst::ICMP_ULT) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + // (icmp (X & Y) u>= X --> (X & Y) == X + if (Pred == ICmpInst::ICMP_UGE) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + + return nullptr; +} + +static Instruction *foldICmpOrXX(ICmpInst &I, const SimplifyQuery &Q, + InstCombinerImpl &IC) { + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; + + // Normalize or operand as operand 0. 
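  // Illustrative sketch of the folds below (assumed IR, not part of this
  // patch): X is always u<= (X | Y), so the remaining cases collapse to
  // equality checks:
  //   icmp ule (or i8 %x, %y), %x --> icmp eq (or i8 %x, %y), %x
  //   icmp ugt (or i8 %x, %y), %x --> icmp ne (or i8 %x, %y), %x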
+ CmpInst::Predicate Pred = I.getPredicate(); + if (match(Op1, m_c_Or(m_Specific(Op0), m_Value(A)))) { + std::swap(Op0, Op1); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else if (!match(Op0, m_c_Or(m_Specific(Op1), m_Value(A)))) { + return nullptr; + } + + // icmp (X | Y) u<= X --> (X | Y) == X + if (Pred == ICmpInst::ICMP_ULE) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); + + // icmp (X | Y) u> X --> (X | Y) != X + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + + if (ICmpInst::isEquality(Pred) && Op0->hasOneUse()) { + // icmp (X | Y) eq/ne Y --> (X & ~Y) eq/ne 0 if Y is freely invertible + if (Value *NotOp1 = + IC.getFreelyInverted(Op1, Op1->hasOneUse(), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateAnd(A, NotOp1), + Constant::getNullValue(Op1->getType())); + // icmp (X | Y) eq/ne Y --> (~X | Y) eq/ne -1 if X is freely invertible. + if (Value *NotA = IC.getFreelyInverted(A, A->hasOneUse(), &IC.Builder)) + return new ICmpInst(Pred, IC.Builder.CreateOr(Op1, NotA), + Constant::getAllOnesValue(Op1->getType())); + } + return nullptr; +} + static Instruction *foldICmpXorXX(ICmpInst &I, const SimplifyQuery &Q, InstCombinerImpl &IC) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1), *A; @@ -4746,6 +4845,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (Instruction * R = foldICmpXorXX(I, Q, *this)) return R; + if (Instruction *R = foldICmpOrXX(I, Q, *this)) + return R; { // Try to remove shared multiplier from comparison: @@ -4915,6 +5016,9 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldICmpAndXX(I, Q, *this)) + return R; + if (Value *V = foldICmpWithTruncSignExtendedVal(I, Builder)) return replaceInstUsesWith(I, V); @@ -4924,88 +5028,153 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, return nullptr; } -/// Fold icmp Pred min|max(X, Y), X. -static Instruction *foldICmpWithMinMax(ICmpInst &Cmp) { - ICmpInst::Predicate Pred = Cmp.getPredicate(); - Value *Op0 = Cmp.getOperand(0); - Value *X = Cmp.getOperand(1); - - // Canonicalize minimum or maximum operand to LHS of the icmp. - if (match(X, m_c_SMin(m_Specific(Op0), m_Value())) || - match(X, m_c_SMax(m_Specific(Op0), m_Value())) || - match(X, m_c_UMin(m_Specific(Op0), m_Value())) || - match(X, m_c_UMax(m_Specific(Op0), m_Value()))) { - std::swap(Op0, X); - Pred = Cmp.getSwappedPredicate(); - } - - Value *Y; - if (match(Op0, m_c_SMin(m_Specific(X), m_Value(Y)))) { - // smin(X, Y) == X --> X s<= Y - // smin(X, Y) s>= X --> X s<= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); - - // smin(X, Y) != X --> X s> Y - // smin(X, Y) s< X --> X s> Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SLT) - return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); - - // These cases should be handled in InstSimplify: - // smin(X, Y) s<= X --> true - // smin(X, Y) s> X --> false +/// Fold icmp Pred min|max(X, Y), Z. 
+Instruction * +InstCombinerImpl::foldICmpWithMinMaxImpl(Instruction &I, + MinMaxIntrinsic *MinMax, Value *Z, + ICmpInst::Predicate Pred) { + Value *X = MinMax->getLHS(); + Value *Y = MinMax->getRHS(); + if (ICmpInst::isSigned(Pred) && !MinMax->isSigned()) return nullptr; - } - - if (match(Op0, m_c_SMax(m_Specific(X), m_Value(Y)))) { - // smax(X, Y) == X --> X s>= Y - // smax(X, Y) s<= X --> X s>= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); - - // smax(X, Y) != X --> X s< Y - // smax(X, Y) s> X --> X s< Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); - - // These cases should be handled in InstSimplify: - // smax(X, Y) s>= X --> true - // smax(X, Y) s< X --> false + if (ICmpInst::isUnsigned(Pred) && MinMax->isSigned()) return nullptr; + SimplifyQuery Q = SQ.getWithInstruction(&I); + auto IsCondKnownTrue = [](Value *Val) -> std::optional<bool> { + if (!Val) + return std::nullopt; + if (match(Val, m_One())) + return true; + if (match(Val, m_Zero())) + return false; + return std::nullopt; + }; + auto CmpXZ = IsCondKnownTrue(simplifyICmpInst(Pred, X, Z, Q)); + auto CmpYZ = IsCondKnownTrue(simplifyICmpInst(Pred, Y, Z, Q)); + if (!CmpXZ.has_value() && !CmpYZ.has_value()) + return nullptr; + if (!CmpXZ.has_value()) { + std::swap(X, Y); + std::swap(CmpXZ, CmpYZ); } - if (match(Op0, m_c_UMin(m_Specific(X), m_Value(Y)))) { - // umin(X, Y) == X --> X u<= Y - // umin(X, Y) u>= X --> X u<= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_UGE) - return new ICmpInst(ICmpInst::ICMP_ULE, X, Y); - - // umin(X, Y) != X --> X u> Y - // umin(X, Y) u< X --> X u> Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_ULT) - return new ICmpInst(ICmpInst::ICMP_UGT, X, Y); + auto FoldIntoCmpYZ = [&]() -> Instruction * { + if (CmpYZ.has_value()) + return replaceInstUsesWith(I, ConstantInt::getBool(I.getType(), *CmpYZ)); + return ICmpInst::Create(Instruction::ICmp, Pred, Y, Z); + }; - // These cases should be handled in InstSimplify: - // umin(X, Y) u<= X --> true - // umin(X, Y) u> X --> false - return nullptr; + switch (Pred) { + case ICmpInst::ICMP_EQ: + case ICmpInst::ICMP_NE: { + // If X == Z: + // Expr Result + // min(X, Y) == Z X <= Y + // max(X, Y) == Z X >= Y + // min(X, Y) != Z X > Y + // max(X, Y) != Z X < Y + if ((Pred == ICmpInst::ICMP_EQ) == *CmpXZ) { + ICmpInst::Predicate NewPred = + ICmpInst::getNonStrictPredicate(MinMax->getPredicate()); + if (Pred == ICmpInst::ICMP_NE) + NewPred = ICmpInst::getInversePredicate(NewPred); + return ICmpInst::Create(Instruction::ICmp, NewPred, X, Y); + } + // Otherwise (X != Z): + ICmpInst::Predicate NewPred = MinMax->getPredicate(); + auto MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); + if (!MinMaxCmpXZ.has_value()) { + std::swap(X, Y); + std::swap(CmpXZ, CmpYZ); + // Re-check pre-condition X != Z + if (!CmpXZ.has_value() || (Pred == ICmpInst::ICMP_EQ) == *CmpXZ) + break; + MinMaxCmpXZ = IsCondKnownTrue(simplifyICmpInst(NewPred, X, Z, Q)); + } + if (!MinMaxCmpXZ.has_value()) + break; + if (*MinMaxCmpXZ) { + // Expr Fact Result + // min(X, Y) == Z X < Z false + // max(X, Y) == Z X > Z false + // min(X, Y) != Z X < Z true + // max(X, Y) != Z X > Z true + return replaceInstUsesWith( + I, ConstantInt::getBool(I.getType(), Pred == ICmpInst::ICMP_NE)); + } else { + // Expr Fact Result + // min(X, Y) == Z X > Z Y == Z + // max(X, Y) == Z X < Z Y == Z + // min(X, Y) != Z X > Z Y != Z + // max(X, Y) != Z X < Z Y != 
Z + return FoldIntoCmpYZ(); + } + break; + } + case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_ULT: + case ICmpInst::ICMP_SLE: + case ICmpInst::ICMP_ULE: + case ICmpInst::ICMP_SGT: + case ICmpInst::ICMP_UGT: + case ICmpInst::ICMP_SGE: + case ICmpInst::ICMP_UGE: { + bool IsSame = MinMax->getPredicate() == ICmpInst::getStrictPredicate(Pred); + if (*CmpXZ) { + if (IsSame) { + // Expr Fact Result + // min(X, Y) < Z X < Z true + // min(X, Y) <= Z X <= Z true + // max(X, Y) > Z X > Z true + // max(X, Y) >= Z X >= Z true + return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); + } else { + // Expr Fact Result + // max(X, Y) < Z X < Z Y < Z + // max(X, Y) <= Z X <= Z Y <= Z + // min(X, Y) > Z X > Z Y > Z + // min(X, Y) >= Z X >= Z Y >= Z + return FoldIntoCmpYZ(); + } + } else { + if (IsSame) { + // Expr Fact Result + // min(X, Y) < Z X >= Z Y < Z + // min(X, Y) <= Z X > Z Y <= Z + // max(X, Y) > Z X <= Z Y > Z + // max(X, Y) >= Z X < Z Y >= Z + return FoldIntoCmpYZ(); + } else { + // Expr Fact Result + // max(X, Y) < Z X >= Z false + // max(X, Y) <= Z X > Z false + // min(X, Y) > Z X <= Z false + // min(X, Y) >= Z X < Z false + return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + } + } + break; + } + default: + break; } - if (match(Op0, m_c_UMax(m_Specific(X), m_Value(Y)))) { - // umax(X, Y) == X --> X u>= Y - // umax(X, Y) u<= X --> X u>= Y - if (Pred == CmpInst::ICMP_EQ || Pred == CmpInst::ICMP_ULE) - return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); + return nullptr; +} +Instruction *InstCombinerImpl::foldICmpWithMinMax(ICmpInst &Cmp) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Lhs = Cmp.getOperand(0); + Value *Rhs = Cmp.getOperand(1); - // umax(X, Y) != X --> X u< Y - // umax(X, Y) u> X --> X u< Y - if (Pred == CmpInst::ICMP_NE || Pred == CmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); + if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Lhs)) { + if (Instruction *Res = foldICmpWithMinMaxImpl(Cmp, MinMax, Rhs, Pred)) + return Res; + } - // These cases should be handled in InstSimplify: - // umax(X, Y) u>= X --> true - // umax(X, Y) u< X --> false - return nullptr; + if (MinMaxIntrinsic *MinMax = dyn_cast<MinMaxIntrinsic>(Rhs)) { + if (Instruction *Res = foldICmpWithMinMaxImpl( + Cmp, MinMax, Lhs, ICmpInst::getSwappedPredicate(Pred))) + return Res; } return nullptr; @@ -5173,35 +5342,6 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); } - // Test if 2 values have different or same signbits: - // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 - // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 - // (X s>> BitWidth - 1) == sext (Y s> -1) --> (X ^ Y) < 0 - // (X s>> BitWidth - 1) != sext (Y s> -1) --> (X ^ Y) > -1 - Instruction *ExtI; - if (match(Op1, m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(A)))) && - (Op0->hasOneUse() || Op1->hasOneUse())) { - unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); - Instruction *ShiftI; - Value *X, *Y; - ICmpInst::Predicate Pred2; - if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), - m_Shr(m_Value(X), - m_SpecificIntAllowUndef(OpWidth - 1)))) && - match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && - Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { - unsigned ExtOpc = ExtI->getOpcode(); - unsigned ShiftOpc = ShiftI->getOpcode(); - if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || - (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { - 
Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); - Value *R = (Pred == ICmpInst::ICMP_EQ) ? Builder.CreateIsNeg(Xor) - : Builder.CreateIsNotNeg(Xor); - return replaceInstUsesWith(I, R); - } - } - } - // (A >> C) == (B >> C) --> (A^B) u< (1 << C) // For lshr and ashr pairs. const APInt *AP1, *AP2; @@ -5307,6 +5447,40 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { Pred, A, Builder.CreateIntrinsic(Op0->getType(), Intrinsic::fshl, {A, A, B})); + // Canonicalize: + // icmp eq/ne OneUse(A ^ Cst), B --> icmp eq/ne (A ^ B), Cst + Constant *Cst; + if (match(&I, m_c_ICmp(PredUnused, + m_OneUse(m_Xor(m_Value(A), m_ImmConstant(Cst))), + m_CombineAnd(m_Value(B), m_Unless(m_ImmConstant()))))) + return new ICmpInst(Pred, Builder.CreateXor(A, B), Cst); + + { + // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) + auto m_Matcher = + m_CombineOr(m_CombineOr(m_c_Add(m_Value(B), m_Deferred(A)), + m_c_Xor(m_Value(B), m_Deferred(A))), + m_Sub(m_Value(B), m_Deferred(A))); + std::optional<bool> IsZero = std::nullopt; + if (match(&I, m_c_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), + m_Deferred(A)))) + IsZero = false; + // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) + else if (match(&I, + m_ICmp(PredUnused, m_OneUse(m_c_And(m_Value(A), m_Matcher)), + m_Zero()))) + IsZero = true; + + if (IsZero && isKnownToBeAPowerOfTwo(A, /* OrZero */ true, /*Depth*/ 0, &I)) + // (icmp eq/ne (and (add/sub/xor X, P2), P2), P2) + // -> (icmp eq/ne (and X, P2), 0) + // (icmp eq/ne (and (add/sub/xor X, P2), P2), 0) + // -> (icmp eq/ne (and X, P2), P2) + return new ICmpInst(Pred, Builder.CreateAnd(B, A), + *IsZero ? A + : ConstantInt::getNullValue(A->getType())); + } + return nullptr; } @@ -5383,8 +5557,8 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { // icmp Pred (ext X), (ext Y) Value *Y; if (match(ICmp.getOperand(1), m_ZExtOrSExt(m_Value(Y)))) { - bool IsZext0 = isa<ZExtOperator>(ICmp.getOperand(0)); - bool IsZext1 = isa<ZExtOperator>(ICmp.getOperand(1)); + bool IsZext0 = isa<ZExtInst>(ICmp.getOperand(0)); + bool IsZext1 = isa<ZExtInst>(ICmp.getOperand(1)); if (IsZext0 != IsZext1) { // If X and Y and both i1 @@ -5396,11 +5570,16 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { return new ICmpInst(ICmp.getPredicate(), Builder.CreateOr(X, Y), Constant::getNullValue(X->getType())); - // If we have mismatched casts, treat the zext of a non-negative source as - // a sext to simulate matching casts. Otherwise, we are done. - // TODO: Can we handle some predicates (equality) without non-negative? - if ((IsZext0 && isKnownNonNegative(X, DL, 0, &AC, &ICmp, &DT)) || - (IsZext1 && isKnownNonNegative(Y, DL, 0, &AC, &ICmp, &DT))) + // If we have mismatched casts and zext has the nneg flag, we can + // treat the "zext nneg" as "sext". Otherwise, we cannot fold and quit. + + auto *NonNegInst0 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(0)); + auto *NonNegInst1 = dyn_cast<PossiblyNonNegInst>(ICmp.getOperand(1)); + + bool IsNonNeg0 = NonNegInst0 && NonNegInst0->hasNonNeg(); + bool IsNonNeg1 = NonNegInst1 && NonNegInst1->hasNonNeg(); + + if ((IsZext0 && IsNonNeg0) || (IsZext1 && IsNonNeg1)) IsSignedExt = true; else return nullptr; @@ -5442,25 +5621,20 @@ Instruction *InstCombinerImpl::foldICmpWithZextOrSext(ICmpInst &ICmp) { if (!C) return nullptr; - // Compute the constant that would happen if we truncated to SrcTy then - // re-extended to DestTy. + // If a lossless truncate is possible... 
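+  // (Illustrative sketch, not part of the patch: for
+  //    icmp slt (sext i8 %x to i32), 100
+  //  truncating 100 to i8 and sign-extending it back reproduces 100, so the
+  //  compare can instead be performed directly as `icmp slt i8 %x, 100`.)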
Type *SrcTy = CastOp0->getSrcTy(); - Type *DestTy = CastOp0->getDestTy(); - Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); - Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); - - // If the re-extended constant didn't change... - if (Res2 == C) { + Constant *Res = getLosslessTrunc(C, SrcTy, CastOp0->getOpcode()); + if (Res) { if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), X, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res); // A signed comparison of sign extended values simplifies into a // signed comparison. if (IsSignedExt && IsSignedCmp) - return new ICmpInst(ICmp.getPredicate(), X, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res); } // The re-extended constant changed, partly changed (in the case of a vector), @@ -5518,13 +5692,8 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { Value *NewOp1 = nullptr; if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { Value *PtrSrc = PtrToIntOp1->getOperand(0); - if (PtrSrc->getType()->getPointerAddressSpace() == - Op0Src->getType()->getPointerAddressSpace()) { + if (PtrSrc->getType() == Op0Src->getType()) NewOp1 = PtrToIntOp1->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (Op0Src->getType() != NewOp1->getType()) - NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); - } } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); } @@ -5641,22 +5810,20 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, /// \returns Instruction which must replace the compare instruction, NULL if no /// replacement required. static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, - Value *OtherVal, + const APInt *OtherVal, InstCombinerImpl &IC) { // Don't bother doing this transformation for pointers, don't do it for // vectors. if (!isa<IntegerType>(MulVal->getType())) return nullptr; - assert(I.getOperand(0) == MulVal || I.getOperand(1) == MulVal); - assert(I.getOperand(0) == OtherVal || I.getOperand(1) == OtherVal); auto *MulInstr = dyn_cast<Instruction>(MulVal); if (!MulInstr) return nullptr; assert(MulInstr->getOpcode() == Instruction::Mul); - auto *LHS = cast<ZExtOperator>(MulInstr->getOperand(0)), - *RHS = cast<ZExtOperator>(MulInstr->getOperand(1)); + auto *LHS = cast<ZExtInst>(MulInstr->getOperand(0)), + *RHS = cast<ZExtInst>(MulInstr->getOperand(1)); assert(LHS->getOpcode() == Instruction::ZExt); assert(RHS->getOpcode() == Instruction::ZExt); Value *A = LHS->getOperand(0), *B = RHS->getOperand(0); @@ -5709,70 +5876,26 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, // Recognize patterns switch (I.getPredicate()) { - case ICmpInst::ICMP_EQ: - case ICmpInst::ICMP_NE: - // Recognize pattern: - // mulval = mul(zext A, zext B) - // cmp eq/neq mulval, and(mulval, mask), mask selects low MulWidth bits. 
-    ConstantInt *CI;
-    Value *ValToMask;
-    if (match(OtherVal, m_And(m_Value(ValToMask), m_ConstantInt(CI)))) {
-      if (ValToMask != MulVal)
-        return nullptr;
-      const APInt &CVal = CI->getValue() + 1;
-      if (CVal.isPowerOf2()) {
-        unsigned MaskWidth = CVal.logBase2();
-        if (MaskWidth == MulWidth)
-          break; // Recognized
-      }
-    }
-    return nullptr;
-
-  case ICmpInst::ICMP_UGT:
+  case ICmpInst::ICMP_UGT: {
     // Recognize pattern:
     //   mulval = mul(zext A, zext B)
     //   cmp ugt mulval, max
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getMaxValue(MulWidth);
-      MaxVal = MaxVal.zext(CI->getBitWidth());
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
-    return nullptr;
-
-  case ICmpInst::ICMP_UGE:
-    // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp uge mulval, max+1
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
-    return nullptr;
-
-  case ICmpInst::ICMP_ULE:
-    // Recognize pattern:
-    //   mulval = mul(zext A, zext B)
-    //   cmp ule mulval, max
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getMaxValue(MulWidth);
-      MaxVal = MaxVal.zext(CI->getBitWidth());
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
+    APInt MaxVal = APInt::getMaxValue(MulWidth);
+    MaxVal = MaxVal.zext(OtherVal->getBitWidth());
+    if (MaxVal.eq(*OtherVal))
+      break; // Recognized
     return nullptr;
+  }

-  case ICmpInst::ICMP_ULT:
+  case ICmpInst::ICMP_ULT: {
     // Recognize pattern:
     //   mulval = mul(zext A, zext B)
     //   cmp ult mulval, max + 1
-    if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) {
-      APInt MaxVal = APInt::getOneBitSet(CI->getBitWidth(), MulWidth);
-      if (MaxVal.eq(CI->getValue()))
-        break; // Recognized
-    }
+    APInt MaxVal = APInt::getOneBitSet(OtherVal->getBitWidth(), MulWidth);
+    if (MaxVal.eq(*OtherVal))
+      break; // Recognized
     return nullptr;
+  }

   default:
     return nullptr;
@@ -5798,7 +5921,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
   if (MulVal->hasNUsesOrMore(2)) {
     Value *Mul = Builder.CreateExtractValue(Call, 0, "umul.value");
     for (User *U : make_early_inc_range(MulVal->users())) {
-      if (U == &I || U == OtherVal)
+      if (U == &I)
        continue;
      if (TruncInst *TI = dyn_cast<TruncInst>(U)) {
        if (TI->getType()->getPrimitiveSizeInBits() == MulWidth)
@@ -5819,34 +5942,10 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
         IC.addToWorklist(cast<Instruction>(U));
     }
   }
-  if (isa<Instruction>(OtherVal))
-    IC.addToWorklist(cast<Instruction>(OtherVal));

   // The original icmp gets replaced with the overflow value, maybe inverted
   // depending on predicate.
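+  // (Illustrative: with an i16 multiply widened through i32,
+  //    icmp ugt i32 %m, 65535   maps directly to the extracted overflow bit,
+  //    icmp ult i32 %m, 65536   is its negation and therefore gets a `not`.)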
- bool Inverse = false; - switch (I.getPredicate()) { - case ICmpInst::ICMP_NE: - break; - case ICmpInst::ICMP_EQ: - Inverse = true; - break; - case ICmpInst::ICMP_UGT: - case ICmpInst::ICMP_UGE: - if (I.getOperand(0) == MulVal) - break; - Inverse = true; - break; - case ICmpInst::ICMP_ULT: - case ICmpInst::ICMP_ULE: - if (I.getOperand(1) == MulVal) - break; - Inverse = true; - break; - default: - llvm_unreachable("Unexpected predicate"); - } - if (Inverse) { + if (I.getPredicate() == ICmpInst::ICMP_ULT) { Value *Res = Builder.CreateExtractValue(Call, 1); return BinaryOperator::CreateNot(Res); } @@ -6015,13 +6114,19 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { KnownBits Op0Known(BitWidth); KnownBits Op1Known(BitWidth); - if (SimplifyDemandedBits(&I, 0, - getDemandedBitsLHSMask(I, BitWidth), - Op0Known, 0)) - return &I; + { + // Don't use dominating conditions when folding icmp using known bits. This + // may convert signed into unsigned predicates in ways that other passes + // (especially IndVarSimplify) may not be able to reliably undo. + SQ.DC = nullptr; + auto _ = make_scope_exit([&]() { SQ.DC = &DC; }); + if (SimplifyDemandedBits(&I, 0, getDemandedBitsLHSMask(I, BitWidth), + Op0Known, 0)) + return &I; - if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) - return &I; + if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0)) + return &I; + } // Given the known and unknown bits, compute a range that the LHS could be // in. Compute the Min, Max and RHS values based on the known bits. For the @@ -6269,57 +6374,70 @@ Instruction *InstCombinerImpl::foldICmpUsingBoolRange(ICmpInst &I) { Y->getType()->isIntOrIntVectorTy(1) && Pred == ICmpInst::ICMP_ULE) return BinaryOperator::CreateOr(Builder.CreateIsNull(X), Y); + // icmp eq/ne X, (zext/sext (icmp eq/ne X, C)) + ICmpInst::Predicate Pred1, Pred2; const APInt *C; - if (match(I.getOperand(0), m_c_Add(m_ZExt(m_Value(X)), m_SExt(m_Value(Y)))) && - match(I.getOperand(1), m_APInt(C)) && - X->getType()->isIntOrIntVectorTy(1) && - Y->getType()->isIntOrIntVectorTy(1)) { - unsigned BitWidth = C->getBitWidth(); - Pred = I.getPredicate(); - APInt Zero = APInt::getZero(BitWidth); - APInt MinusOne = APInt::getAllOnes(BitWidth); - APInt One(BitWidth, 1); - if ((C->sgt(Zero) && Pred == ICmpInst::ICMP_SGT) || - (C->slt(Zero) && Pred == ICmpInst::ICMP_SLT)) - return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if ((C->sgt(One) && Pred == ICmpInst::ICMP_SLT) || - (C->slt(MinusOne) && Pred == ICmpInst::ICMP_SGT)) - return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - - if (I.getOperand(0)->hasOneUse()) { - APInt NewC = *C; - // canonicalize predicate to eq/ne - if ((*C == Zero && Pred == ICmpInst::ICMP_SLT) || - (*C != Zero && *C != MinusOne && Pred == ICmpInst::ICMP_UGT)) { - // x s< 0 in [-1, 1] --> x == -1 - // x u> 1(or any const !=0 !=-1) in [-1, 1] --> x == -1 - NewC = MinusOne; - Pred = ICmpInst::ICMP_EQ; - } else if ((*C == MinusOne && Pred == ICmpInst::ICMP_SGT) || - (*C != Zero && *C != One && Pred == ICmpInst::ICMP_ULT)) { - // x s> -1 in [-1, 1] --> x != -1 - // x u< -1 in [-1, 1] --> x != -1 - Pred = ICmpInst::ICMP_NE; - } else if (*C == Zero && Pred == ICmpInst::ICMP_SGT) { - // x s> 0 in [-1, 1] --> x == 1 - NewC = One; - Pred = ICmpInst::ICMP_EQ; - } else if (*C == One && Pred == ICmpInst::ICMP_SLT) { - // x s< 1 in [-1, 1] --> x != 1 - Pred = ICmpInst::ICMP_NE; + Instruction *ExtI; + if (match(&I, m_c_ICmp(Pred1, m_Value(X), + 
+                         m_CombineAnd(m_Instruction(ExtI),
+                                      m_ZExtOrSExt(m_ICmp(Pred2, m_Deferred(X),
+                                                          m_APInt(C)))))) &&
+      ICmpInst::isEquality(Pred1) && ICmpInst::isEquality(Pred2)) {
+    bool IsSExt = ExtI->getOpcode() == Instruction::SExt;
+    bool HasOneUse = ExtI->hasOneUse() && ExtI->getOperand(0)->hasOneUse();
+    auto CreateRangeCheck = [&] {
+      Value *CmpV1 =
+          Builder.CreateICmp(Pred1, X, Constant::getNullValue(X->getType()));
+      Value *CmpV2 = Builder.CreateICmp(
+          Pred1, X, ConstantInt::getSigned(X->getType(), IsSExt ? -1 : 1));
+      return BinaryOperator::Create(
+          Pred1 == ICmpInst::ICMP_EQ ? Instruction::Or : Instruction::And,
+          CmpV1, CmpV2);
+    };
+    if (C->isZero()) {
+      if (Pred2 == ICmpInst::ICMP_EQ) {
+        // icmp eq X, (zext/sext (icmp eq X, 0)) --> false
+        // icmp ne X, (zext/sext (icmp eq X, 0)) --> true
+        return replaceInstUsesWith(
+            I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
+      } else if (!IsSExt || HasOneUse) {
+        // icmp eq X, (zext (icmp ne X, 0)) --> X == 0 || X == 1
+        // icmp ne X, (zext (icmp ne X, 0)) --> X != 0 && X != 1
+        // icmp eq X, (sext (icmp ne X, 0)) --> X == 0 || X == -1
+        // icmp ne X, (sext (icmp ne X, 0)) --> X != 0 && X != -1
+        return CreateRangeCheck();
+      }
+    } else if (IsSExt ? C->isAllOnes() : C->isOne()) {
+      if (Pred2 == ICmpInst::ICMP_NE) {
+        // icmp eq X, (zext (icmp ne X, 1)) --> false
+        // icmp ne X, (zext (icmp ne X, 1)) --> true
+        // icmp eq X, (sext (icmp ne X, -1)) --> false
+        // icmp ne X, (sext (icmp ne X, -1)) --> true
+        return replaceInstUsesWith(
+            I, ConstantInt::getBool(I.getType(), Pred1 == ICmpInst::ICMP_NE));
+      } else if (!IsSExt || HasOneUse) {
+        // icmp eq X, (zext (icmp eq X, 1)) --> X == 0 || X == 1
+        // icmp ne X, (zext (icmp eq X, 1)) --> X != 0 && X != 1
+        // icmp eq X, (sext (icmp eq X, -1)) --> X == 0 || X == -1
+        // icmp ne X, (sext (icmp eq X, -1)) --> X != 0 && X != -1
+        return CreateRangeCheck();
+      }
+    } else {
+      // when C != 0 && C != 1:
+      // icmp eq X, (zext (icmp eq X, C)) --> icmp eq X, 0
+      // icmp eq X, (zext (icmp ne X, C)) --> icmp eq X, 1
+      // icmp ne X, (zext (icmp eq X, C)) --> icmp ne X, 0
+      // icmp ne X, (zext (icmp ne X, C)) --> icmp ne X, 1
+      // when C != 0 && C != -1:
+      // icmp eq X, (sext (icmp eq X, C)) --> icmp eq X, 0
+      // icmp eq X, (sext (icmp ne X, C)) --> icmp eq X, -1
+      // icmp ne X, (sext (icmp eq X, C)) --> icmp ne X, 0
+      // icmp ne X, (sext (icmp ne X, C)) --> icmp ne X, -1
+      return ICmpInst::Create(
+          Instruction::ICmp, Pred1, X,
+          ConstantInt::getSigned(X->getType(), Pred2 == ICmpInst::ICMP_NE
+                                                   ? (IsSExt ? -1 : 1)
+                                                   : 0));
+    }
   }

@@ -6783,6 +6901,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
   if (Instruction *Res = foldICmpUsingKnownBits(I))
     return Res;

+  if (Instruction *Res = foldICmpTruncWithTruncOrExt(I, Q))
+    return Res;
+
   // Test if the ICmpInst instruction is used exclusively by a select as
   // part of a minimum or maximum operation. If so, refrain from doing
   // any other folding. This helps out other analyses which understand
@@ -6913,38 +7034,40 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
     return Res;

   {
-    Value *A, *B;
-    // Transform (A & ~B) == 0 --> (A & B) != 0
-    // and       (A & ~B) != 0 --> (A & B) == 0
+    Value *X, *Y;
+    // Transform (X & ~Y) == 0 --> (X & Y) != 0
+    // and       (X & ~Y) != 0 --> (X & Y) == 0
     // if X is a power of 2.
-    if (match(Op0, m_And(m_Value(A), m_Not(m_Value(B)))) &&
-        match(Op1, m_Zero()) &&
-        isKnownToBeAPowerOfTwo(A, false, 0, &I) && I.isEquality())
-      return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(A, B),
+    if (match(Op0, m_And(m_Value(X), m_Not(m_Value(Y)))) &&
+        match(Op1, m_Zero()) && isKnownToBeAPowerOfTwo(X, false, 0, &I) &&
+        I.isEquality())
+      return new ICmpInst(I.getInversePredicate(), Builder.CreateAnd(X, Y),
                           Op1);

-    // ~X < ~Y --> Y < X
-    // ~X < C -->  X > ~C
-    if (match(Op0, m_Not(m_Value(A)))) {
-      if (match(Op1, m_Not(m_Value(B))))
-        return new ICmpInst(I.getPredicate(), B, A);
-
-      const APInt *C;
-      if (match(Op1, m_APInt(C)))
-        return new ICmpInst(I.getSwappedPredicate(), A,
-                            ConstantInt::get(Op1->getType(), ~(*C)));
+    // Op0 pred Op1 -> ~Op1 pred ~Op0, if this allows us to drop an instruction.
+    if (Op0->getType()->isIntOrIntVectorTy()) {
+      bool ConsumesOp0, ConsumesOp1;
+      if (isFreeToInvert(Op0, Op0->hasOneUse(), ConsumesOp0) &&
+          isFreeToInvert(Op1, Op1->hasOneUse(), ConsumesOp1) &&
+          (ConsumesOp0 || ConsumesOp1)) {
+        Value *InvOp0 = getFreelyInverted(Op0, Op0->hasOneUse(), &Builder);
+        Value *InvOp1 = getFreelyInverted(Op1, Op1->hasOneUse(), &Builder);
+        assert(InvOp0 && InvOp1 &&
+               "Mismatch between isFreeToInvert and getFreelyInverted");
+        return new ICmpInst(I.getSwappedPredicate(), InvOp0, InvOp1);
+      }
     }

     Instruction *AddI = nullptr;
-    if (match(&I, m_UAddWithOverflow(m_Value(A), m_Value(B),
+    if (match(&I, m_UAddWithOverflow(m_Value(X), m_Value(Y),
                                      m_Instruction(AddI))) &&
-        isa<IntegerType>(A->getType())) {
+        isa<IntegerType>(X->getType())) {
       Value *Result;
       Constant *Overflow;
       // m_UAddWithOverflow can match patterns that do not include an explicit
       // "add" instruction, so check the opcode of the matched op.
       if (AddI->getOpcode() == Instruction::Add &&
-          OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, A, B, *AddI,
+          OptimizeOverflowCheck(Instruction::Add, /*Signed*/ false, X, Y, *AddI,
                                 Result, Overflow)) {
         replaceInstUsesWith(*AddI, Result);
         eraseInstFromFunction(*AddI);
@@ -6952,14 +7075,37 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) {
       }
     }

-    // (zext a) * (zext b) --> llvm.umul.with.overflow.
-    if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) {
-      if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this))
+    // (zext X) * (zext Y) --> llvm.umul.with.overflow.
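+    // A sketch of the idiom being matched (illustrative IR, not from this
+    // patch):
+    //   %ax  = zext i16 %a to i32
+    //   %bx  = zext i16 %b to i32
+    //   %m   = mul nuw i32 %ax, %bx
+    //   %ovf = icmp ugt i32 %m, 65535  ; true iff the i16 multiply overflows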
+ if (match(Op0, m_NUWMul(m_ZExt(m_Value(X)), m_ZExt(m_Value(Y)))) && + match(Op1, m_APInt(C))) { + if (Instruction *R = processUMulZExtIdiom(I, Op0, C, *this)) return R; } - if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { - if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) - return R; + + // Signbit test folds + // Fold (X u>> BitWidth - 1 Pred ZExt(i1)) --> X s< 0 Pred i1 + // Fold (X s>> BitWidth - 1 Pred SExt(i1)) --> X s< 0 Pred i1 + Instruction *ExtI; + if ((I.isUnsigned() || I.isEquality()) && + match(Op1, + m_CombineAnd(m_Instruction(ExtI), m_ZExtOrSExt(m_Value(Y)))) && + Y->getType()->getScalarSizeInBits() == 1 && + (Op0->hasOneUse() || Op1->hasOneUse())) { + unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Instruction *ShiftI; + if (match(Op0, m_CombineAnd(m_Instruction(ShiftI), + m_Shr(m_Value(X), m_SpecificIntAllowUndef( + OpWidth - 1))))) { + unsigned ExtOpc = ExtI->getOpcode(); + unsigned ShiftOpc = ShiftI->getOpcode(); + if ((ExtOpc == Instruction::ZExt && ShiftOpc == Instruction::LShr) || + (ExtOpc == Instruction::SExt && ShiftOpc == Instruction::AShr)) { + Value *SLTZero = + Builder.CreateICmpSLT(X, Constant::getNullValue(X->getType())); + Value *Cmp = Builder.CreateICmp(Pred, SLTZero, Y, I.getName()); + return replaceInstUsesWith(I, Cmp); + } + } } } @@ -7177,17 +7323,14 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, } // Okay, now we know that the FP constant fits in the range [SMIN, SMAX] or - // [0, UMAX], but it may still be fractional. See if it is fractional by - // casting the FP value to the integer value and back, checking for equality. + // [0, UMAX], but it may still be fractional. Check whether this is the case + // using the IsExact flag. // Don't do this for zero, because -0.0 is not fractional. - Constant *RHSInt = LHSUnsigned - ? ConstantExpr::getFPToUI(RHSC, IntTy) - : ConstantExpr::getFPToSI(RHSC, IntTy); + APSInt RHSInt(IntWidth, LHSUnsigned); + bool IsExact; + RHS.convertToInteger(RHSInt, APFloat::rmTowardZero, &IsExact); if (!RHS.isZero()) { - bool Equal = LHSUnsigned - ? ConstantExpr::getUIToFP(RHSInt, RHSC->getType()) == RHSC - : ConstantExpr::getSIToFP(RHSInt, RHSC->getType()) == RHSC; - if (!Equal) { + if (!IsExact) { // If we had a comparison against a fractional value, we have to adjust // the compare predicate and sometimes the value. RHSC is rounded towards // zero at this point. @@ -7253,7 +7396,7 @@ Instruction *InstCombinerImpl::foldFCmpIntToFPConst(FCmpInst &I, // Lower this FP comparison into an appropriate integer version of the // comparison. - return new ICmpInst(Pred, LHSI->getOperand(0), RHSInt); + return new ICmpInst(Pred, LHSI->getOperand(0), Builder.getInt(RHSInt)); } /// Fold (C / X) < 0.0 --> X < 0.0 if possible. Swap predicate if necessary. @@ -7532,12 +7675,8 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op0, m_Instruction(LHSI)) && match(Op1, m_Constant(RHSC))) { switch (LHSI->getOpcode()) { case Instruction::PHI: - // Only fold fcmp into the PHI if the phi and fcmp are in the same - // block. If in the same block, we're encouraging jump threading. If - // not, we are just pessimizing the code by making an i1 phi. 
- if (LHSI->getParent() == I.getParent()) - if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) - return NV; + if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI))) + return NV; break; case Instruction::SIToFP: case Instruction::UIToFP: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 701579e1de48..bb620ad8d41c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -16,6 +16,7 @@ #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/PostOrderIterator.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" @@ -73,6 +74,10 @@ public: virtual ~InstCombinerImpl() = default; + /// Perform early cleanup and prepare the InstCombine worklist. + bool prepareWorklist(Function &F, + ReversePostOrderTraversal<BasicBlock *> &RPOT); + /// Run the combiner over the entire worklist until it is empty. /// /// \returns true if the IR is changed. @@ -93,6 +98,7 @@ public: Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); + Instruction *foldFMulReassoc(BinaryOperator &I); Instruction *visitFMul(BinaryOperator &I); Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); @@ -126,7 +132,6 @@ public: Instruction *FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I); Instruction *commonCastTransforms(CastInst &CI); - Instruction *commonPointerCastTransforms(CastInst &CI); Instruction *visitTrunc(TruncInst &CI); Instruction *visitZExt(ZExtInst &Zext); Instruction *visitSExt(SExtInst &Sext); @@ -193,6 +198,44 @@ public: LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy, const Twine &Suffix = ""); + KnownFPClass computeKnownFPClass(Value *Val, FastMathFlags FMF, + FPClassTest Interested = fcAllFlags, + const Instruction *CtxI = nullptr, + unsigned Depth = 0) const { + return llvm::computeKnownFPClass(Val, FMF, DL, Interested, Depth, &TLI, &AC, + CtxI, &DT); + } + + KnownFPClass computeKnownFPClass(Value *Val, + FPClassTest Interested = fcAllFlags, + const Instruction *CtxI = nullptr, + unsigned Depth = 0) const { + return llvm::computeKnownFPClass(Val, DL, Interested, Depth, &TLI, &AC, + CtxI, &DT); + } + + /// Check if fmul \p MulVal, +0.0 will yield +0.0 (or signed zero is + /// ignorable). 
+ bool fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF, + const Instruction *CtxI) const; + + Constant *getLosslessTrunc(Constant *C, Type *TruncTy, unsigned ExtOp) { + Constant *TruncC = ConstantExpr::getTrunc(C, TruncTy); + Constant *ExtTruncC = + ConstantFoldCastOperand(ExtOp, TruncC, C->getType(), DL); + if (ExtTruncC && ExtTruncC == C) + return TruncC; + return nullptr; + } + + Constant *getLosslessUnsignedTrunc(Constant *C, Type *TruncTy) { + return getLosslessTrunc(C, TruncTy, Instruction::ZExt); + } + + Constant *getLosslessSignedTrunc(Constant *C, Type *TruncTy) { + return getLosslessTrunc(C, TruncTy, Instruction::SExt); + } + private: bool annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI); bool isDesirableIntType(unsigned BitWidth) const; @@ -252,13 +295,15 @@ private: Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext); - bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowSignedAdd(const WithCache<const Value *> &LHS, + const WithCache<const Value *> &RHS, const Instruction &CxtI) const { return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; } - bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, + bool willNotOverflowUnsignedAdd(const WithCache<const Value *> &LHS, + const WithCache<const Value *> &RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; @@ -387,15 +432,17 @@ private: Instruction *foldAndOrOfSelectUsingImpliedCond(Value *Op, SelectInst &SI, bool IsAnd); + Instruction *hoistFNegAboveFMulFDiv(Value *FNegOp, Instruction &FMFSource); + public: /// Create and insert the idiom we use to indicate a block is unreachable /// without having to rewrite the CFG from within InstCombine. void CreateNonTerminatorUnreachable(Instruction *InsertAt) { auto &Ctx = InsertAt->getContext(); auto *SI = new StoreInst(ConstantInt::getTrue(Ctx), - PoisonValue::get(Type::getInt1PtrTy(Ctx)), + PoisonValue::get(PointerType::getUnqual(Ctx)), /*isVolatile*/ false, Align(1)); - InsertNewInstBefore(SI, *InsertAt); + InsertNewInstBefore(SI, InsertAt->getIterator()); } /// Combiner aware instruction erasure. @@ -412,6 +459,7 @@ public: // use counts. SmallVector<Value *> Ops(I.operands()); Worklist.remove(&I); + DC.removeValue(&I); I.eraseFromParent(); for (Value *Op : Ops) Worklist.handleUseCountDecrement(Op); @@ -498,6 +546,7 @@ public: /// Tries to simplify operands to an integer instruction based on its /// demanded bits. bool SimplifyDemandedInstructionBits(Instruction &Inst); + bool SimplifyDemandedInstructionBits(Instruction &Inst, KnownBits &Known); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, APInt &UndefElts, unsigned Depth = 0, @@ -535,6 +584,9 @@ public: Instruction *foldAddWithConstant(BinaryOperator &Add); + Instruction *foldSquareSumInt(BinaryOperator &I); + Instruction *foldSquareSumFP(BinaryOperator &I); + /// Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. 
   Instruction *foldPHIArgOpIntoPHI(PHINode &PN);
@@ -580,6 +632,9 @@ public:
   Instruction *foldICmpInstWithConstantAllowUndef(ICmpInst &Cmp,
                                                   const APInt &C);
   Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ);
+  Instruction *foldICmpWithMinMaxImpl(Instruction &I, MinMaxIntrinsic *MinMax,
+                                      Value *Z, ICmpInst::Predicate Pred);
+  Instruction *foldICmpWithMinMax(ICmpInst &Cmp);
   Instruction *foldICmpEquality(ICmpInst &Cmp);
   Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I);
   Instruction *foldSignBitTest(ICmpInst &I);
@@ -593,6 +648,8 @@ public:
                                       ConstantInt *C);
   Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc,
                                      const APInt &C);
+  Instruction *foldICmpTruncWithTruncOrExt(ICmpInst &Cmp,
+                                           const SimplifyQuery &Q);
   Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And,
                                    const APInt &C);
   Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor,
@@ -667,8 +724,12 @@ public:
   bool tryToSinkInstruction(Instruction *I, BasicBlock *DestBlock);
   bool removeInstructionsBeforeUnreachable(Instruction &I);
-  bool handleUnreachableFrom(Instruction *I);
-  bool handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
+  void addDeadEdge(BasicBlock *From, BasicBlock *To,
+                   SmallVectorImpl<BasicBlock *> &Worklist);
+  void handleUnreachableFrom(Instruction *I,
+                             SmallVectorImpl<BasicBlock *> &Worklist);
+  void handlePotentiallyDeadBlocks(SmallVectorImpl<BasicBlock *> &Worklist);
+  void handlePotentiallyDeadSuccessors(BasicBlock *BB, BasicBlock *LiveSucc);
   void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr);
 };

@@ -679,16 +740,11 @@ class Negator final {
   using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>;
   BuilderTy Builder;

-  const DataLayout &DL;
-  AssumptionCache &AC;
-  const DominatorTree &DT;
-
   const bool IsTrulyNegation;

   SmallDenseMap<Value *, Value *> NegationsCache;

-  Negator(LLVMContext &C, const DataLayout &DL, AssumptionCache &AC,
-          const DominatorTree &DT, bool IsTrulyNegation);
+  Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation);

 #if LLVM_ENABLE_STATS
   unsigned NumValuesVisitedInThisNegator = 0;
@@ -700,13 +756,13 @@ class Negator final {

   std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I);

-  [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth);
+  [[nodiscard]] Value *visitImpl(Value *V, bool IsNSW, unsigned Depth);

-  [[nodiscard]] Value *negate(Value *V, unsigned Depth);
+  [[nodiscard]] Value *negate(Value *V, bool IsNSW, unsigned Depth);

   /// Recurse depth-first and attempt to sink the negation.
   /// FIXME: use worklist?
-  [[nodiscard]] std::optional<Result> run(Value *Root);
+  [[nodiscard]] std::optional<Result> run(Value *Root, bool IsNSW);

   Negator(const Negator &) = delete;
   Negator(Negator &&) = delete;
@@ -716,7 +772,7 @@ class Negator final {

 public:
   /// Attempt to negate \p Root. Returns nullptr if negation can't be performed,
   /// otherwise returns the negated value.
- [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root, + [[nodiscard]] static Value *Negate(bool LHSIsZero, bool IsNSW, Value *Root, InstCombinerImpl &IC); }; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 6aa20ee26b9a..b72b68c68d98 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -36,6 +36,13 @@ static cl::opt<unsigned> MaxCopiedFromConstantUsers( cl::desc("Maximum users to visit in copy from constant transform"), cl::Hidden); +namespace llvm { +cl::opt<bool> EnableInferAlignmentPass( + "enable-infer-alignment-pass", cl::init(true), cl::Hidden, cl::ZeroOrMore, + cl::desc("Enable the InferAlignment pass, disabling alignment inference in " + "InstCombine")); +} + /// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived) /// pointer to an alloca. Ignore any reads of the pointer, return false if we /// see any stores or other unknown uses. If we see pointer arithmetic, keep @@ -224,7 +231,7 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, Value *Idx[2] = {NullIdx, NullIdx}; Instruction *GEP = GetElementPtrInst::CreateInBounds( NewTy, New, Idx, New->getName() + ".sub"); - IC.InsertNewInstBefore(GEP, *It); + IC.InsertNewInstBefore(GEP, It); // Now make everything use the getelementptr instead of the original // allocation. @@ -380,7 +387,7 @@ void PointerReplacer::replace(Instruction *I) { NewI->takeName(LT); copyMetadataForLoad(*NewI, *LT); - IC.InsertNewInstWith(NewI, *LT); + IC.InsertNewInstWith(NewI, LT->getIterator()); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; } else if (auto *PHI = dyn_cast<PHINode>(I)) { @@ -398,7 +405,7 @@ void PointerReplacer::replace(Instruction *I) { Indices.append(GEP->idx_begin(), GEP->idx_end()); auto *NewI = GetElementPtrInst::Create(GEP->getSourceElementType(), V, Indices); - IC.InsertNewInstWith(NewI, *GEP); + IC.InsertNewInstWith(NewI, GEP->getIterator()); NewI->takeName(GEP); WorkMap[GEP] = NewI; } else if (auto *BC = dyn_cast<BitCastInst>(I)) { @@ -407,14 +414,14 @@ void PointerReplacer::replace(Instruction *I) { auto *NewT = PointerType::get(BC->getType()->getContext(), V->getType()->getPointerAddressSpace()); auto *NewI = new BitCastInst(V, NewT); - IC.InsertNewInstWith(NewI, *BC); + IC.InsertNewInstWith(NewI, BC->getIterator()); NewI->takeName(BC); WorkMap[BC] = NewI; } else if (auto *SI = dyn_cast<SelectInst>(I)) { auto *NewSI = SelectInst::Create( SI->getCondition(), getReplacement(SI->getTrueValue()), getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI); - IC.InsertNewInstWith(NewSI, *SI); + IC.InsertNewInstWith(NewSI, SI->getIterator()); NewSI->takeName(SI); WorkMap[SI] = NewSI; } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { @@ -449,7 +456,7 @@ void PointerReplacer::replace(Instruction *I) { ASC->getType()->getPointerAddressSpace()) { auto *NewI = new AddrSpaceCastInst(V, ASC->getType(), ""); NewI->takeName(ASC); - IC.InsertNewInstWith(NewI, *ASC); + IC.InsertNewInstWith(NewI, ASC->getIterator()); NewV = NewI; } IC.replaceInstUsesWith(*ASC, NewV); @@ -507,8 +514,6 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { // types. 
const Align MaxAlign = std::max(EntryAI->getAlign(), AI.getAlign()); EntryAI->setAlignment(MaxAlign); - if (AI.getType() != EntryAI->getType()) - return new BitCastInst(EntryAI, AI.getType()); return replaceInstUsesWith(AI, EntryAI); } } @@ -534,13 +539,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { LLVM_DEBUG(dbgs() << "Found alloca equal to global: " << AI << '\n'); LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); - auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); if (AI.getAddressSpace() == SrcAddrSpace) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); - Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - Instruction *NewI = replaceInstUsesWith(AI, Cast); + Instruction *NewI = replaceInstUsesWith(AI, TheSrc); eraseInstFromFunction(*Copy); ++NumGlobalCopies; return NewI; @@ -551,8 +554,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); - Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - PtrReplacer.replacePointer(Cast); + PtrReplacer.replacePointer(TheSrc); ++NumGlobalCopies; } } @@ -582,16 +584,9 @@ LoadInst *InstCombinerImpl::combineLoadToNewType(LoadInst &LI, Type *NewTy, assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); - Value *Ptr = LI.getPointerOperand(); - unsigned AS = LI.getPointerAddressSpace(); - Type *NewPtrTy = NewTy->getPointerTo(AS); - Value *NewPtr = nullptr; - if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && - NewPtr->getType() == NewPtrTy)) - NewPtr = Builder.CreateBitCast(Ptr, NewPtrTy); - - LoadInst *NewLoad = Builder.CreateAlignedLoad( - NewTy, NewPtr, LI.getAlign(), LI.isVolatile(), LI.getName() + Suffix); + LoadInst *NewLoad = + Builder.CreateAlignedLoad(NewTy, LI.getPointerOperand(), LI.getAlign(), + LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); copyMetadataForLoad(*NewLoad, LI); return NewLoad; @@ -606,13 +601,11 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, "can't fold an atomic store of requested type"); Value *Ptr = SI.getPointerOperand(); - unsigned AS = SI.getPointerAddressSpace(); SmallVector<std::pair<unsigned, MDNode *>, 8> MD; SI.getAllMetadata(MD); - StoreInst *NewStore = IC.Builder.CreateAlignedStore( - V, IC.Builder.CreateBitCast(Ptr, V->getType()->getPointerTo(AS)), - SI.getAlign(), SI.isVolatile()); + StoreInst *NewStore = + IC.Builder.CreateAlignedStore(V, Ptr, SI.getAlign(), SI.isVolatile()); NewStore->setAtomic(SI.getOrdering(), SI.getSyncScopeID()); for (const auto &MDPair : MD) { unsigned ID = MDPair.first; @@ -655,29 +648,6 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, return NewStore; } -/// Returns true if instruction represent minmax pattern like: -/// select ((cmp load V1, load V2), V1, V2). -static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { - assert(V->getType()->isPointerTy() && "Expected pointer type."); - // Ignore possible ty* to ixx* bitcast. - V = InstCombiner::peekThroughBitcast(V); - // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax - // pattern. 
- CmpInst::Predicate Pred; - Instruction *L1; - Instruction *L2; - Value *LHS; - Value *RHS; - if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), - m_Value(LHS), m_Value(RHS)))) - return false; - LoadTy = L1->getType(); - return (match(L1, m_Load(m_Specific(LHS))) && - match(L2, m_Load(m_Specific(RHS)))) || - (match(L1, m_Load(m_Specific(RHS))) && - match(L2, m_Load(m_Specific(LHS)))); -} - /// Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// @@ -818,7 +788,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { return nullptr; const DataLayout &DL = IC.getDataLayout(); - auto EltSize = DL.getTypeAllocSize(ET); + TypeSize EltSize = DL.getTypeAllocSize(ET); const auto Align = LI.getAlign(); auto *Addr = LI.getPointerOperand(); @@ -826,7 +796,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *Zero = ConstantInt::get(IdxType, 0); Value *V = PoisonValue::get(T); - uint64_t Offset = 0; + TypeSize Offset = TypeSize::get(0, ET->isScalableTy()); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, @@ -834,9 +804,9 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { }; auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), Name + ".elt"); + auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue()); auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, - commonAlignment(Align, Offset), - Name + ".unpack"); + EltAlign, Name + ".unpack"); L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); Offset += EltSize; @@ -971,7 +941,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, Type *SourceElementType = GEPI->getSourceElementType(); // Size information about scalable vectors is not available, so we cannot // deduce whether indexing at n is undefined behaviour or not. Bail out. - if (isa<ScalableVectorType>(SourceElementType)) + if (SourceElementType->isScalableTy()) return false; Type *AllocTy = GetElementPtrInst::getIndexedType(SourceElementType, Ops); @@ -1020,7 +990,7 @@ static Instruction *replaceGEPIdxWithZero(InstCombinerImpl &IC, Value *Ptr, Instruction *NewGEPI = GEPI->clone(); NewGEPI->setOperand(Idx, ConstantInt::get(GEPI->getOperand(Idx)->getType(), 0)); - IC.InsertNewInstBefore(NewGEPI, *GEPI); + IC.InsertNewInstBefore(NewGEPI, GEPI->getIterator()); return NewGEPI; } } @@ -1062,11 +1032,13 @@ Instruction *InstCombinerImpl::visitLoadInst(LoadInst &LI) { if (Instruction *Res = combineLoadToOperationType(*this, LI)) return Res; - // Attempt to improve the alignment. - Align KnownAlign = getOrEnforceKnownAlignment( - Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); - if (KnownAlign > LI.getAlign()) - LI.setAlignment(KnownAlign); + if (!EnableInferAlignmentPass) { + // Attempt to improve the alignment. + Align KnownAlign = getOrEnforceKnownAlignment( + Op, DL.getPrefTypeAlign(LI.getType()), DL, &LI, &AC, &DT); + if (KnownAlign > LI.getAlign()) + LI.setAlignment(KnownAlign); + } // Replace GEP indices if possible. 
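   // (For example, an in-bounds variable index into a single-element array can
   // only be 0, so it may be replaced by the constant 0; an illustrative case
   // of the check below, not an exhaustive description.)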
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) @@ -1337,7 +1309,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { return false; const DataLayout &DL = IC.getDataLayout(); - auto EltSize = DL.getTypeAllocSize(AT->getElementType()); + TypeSize EltSize = DL.getTypeAllocSize(AT->getElementType()); const auto Align = SI.getAlign(); SmallString<16> EltName = V->getName(); @@ -1349,7 +1321,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *IdxType = Type::getInt64Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - uint64_t Offset = 0; + TypeSize Offset = TypeSize::get(0, AT->getElementType()->isScalableTy()); for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, @@ -1358,7 +1330,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); - auto EltAlign = commonAlignment(Align, Offset); + auto EltAlign = commonAlignment(Align, Offset.getKnownMinValue()); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); NS->setAAMetadata(SI.getAAMetadata()); Offset += EltSize; @@ -1399,58 +1371,6 @@ static bool equivalentAddressValues(Value *A, Value *B) { return false; } -/// Converts store (bitcast (load (bitcast (select ...)))) to -/// store (load (select ...)), where select is minmax: -/// select ((cmp load V1, load V2), V1, V2). -static bool removeBitcastsFromLoadStoreOnMinMax(InstCombinerImpl &IC, - StoreInst &SI) { - // bitcast? - if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) - return false; - // load? integer? - Value *LoadAddr; - if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) - return false; - auto *LI = cast<LoadInst>(SI.getValueOperand()); - if (!LI->getType()->isIntegerTy()) - return false; - Type *CmpLoadTy; - if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy)) - return false; - - // Make sure the type would actually change. - // This condition can be hit with chains of bitcasts. - if (LI->getType() == CmpLoadTy) - return false; - - // Make sure we're not changing the size of the load/store. - const auto &DL = IC.getDataLayout(); - if (DL.getTypeStoreSizeInBits(LI->getType()) != - DL.getTypeStoreSizeInBits(CmpLoadTy)) - return false; - - if (!all_of(LI->users(), [LI, LoadAddr](User *U) { - auto *SI = dyn_cast<StoreInst>(U); - return SI && SI->getPointerOperand() != LI && - InstCombiner::peekThroughBitcast(SI->getPointerOperand()) != - LoadAddr && - !SI->getPointerOperand()->isSwiftError(); - })) - return false; - - IC.Builder.SetInsertPoint(LI); - LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy); - // Replace all the stores with stores of the newly loaded value. - for (auto *UI : LI->users()) { - auto *USI = cast<StoreInst>(UI); - IC.Builder.SetInsertPoint(USI); - combineStoreToNewValue(IC, *USI, NewLI); - } - IC.replaceInstUsesWith(*LI, PoisonValue::get(LI->getType())); - IC.eraseInstFromFunction(*LI); - return true; -} - Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); @@ -1459,19 +1379,18 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { if (combineStoreToValueType(*this, SI)) return eraseInstFromFunction(SI); - // Attempt to improve the alignment. 
- const Align KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); - if (KnownAlign > SI.getAlign()) - SI.setAlignment(KnownAlign); + if (!EnableInferAlignmentPass) { + // Attempt to improve the alignment. + const Align KnownAlign = getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlign(Val->getType()), DL, &SI, &AC, &DT); + if (KnownAlign > SI.getAlign()) + SI.setAlignment(KnownAlign); + } // Try to canonicalize the stored type. if (unpackStoreToAggregate(*this, SI)) return eraseInstFromFunction(SI); - if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) - return eraseInstFromFunction(SI); - // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) return replaceOperand(SI, 1, NewGEPI); @@ -1508,8 +1427,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { --BBI; // Don't count debug info directives, lest they affect codegen, // and we skip pointer-to-pointer bitcasts, which are NOPs. - if (BBI->isDebugOrPseudoInst() || - (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { + if (BBI->isDebugOrPseudoInst()) { ScanInsts++; continue; } @@ -1560,11 +1478,15 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { // This is a non-terminator unreachable marker. Don't remove it. if (isa<UndefValue>(Ptr)) { - // Remove all instructions after the marker and guaranteed-to-transfer - // instructions before the marker. - if (handleUnreachableFrom(SI.getNextNode()) || - removeInstructionsBeforeUnreachable(SI)) + // Remove guaranteed-to-transfer instructions before the marker. + if (removeInstructionsBeforeUnreachable(SI)) return &SI; + + // Remove all instructions after the marker and handle dead blocks this + // implies. + SmallVector<BasicBlock *> Worklist; + handleUnreachableFrom(SI.getNextNode(), Worklist); + handlePotentiallyDeadBlocks(Worklist); return nullptr; } @@ -1626,8 +1548,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { if (OtherBr->isUnconditional()) { --BBI; // Skip over debugging info and pseudo probes. - while (BBI->isDebugOrPseudoInst() || - (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { + while (BBI->isDebugOrPseudoInst()) { if (BBI==OtherBB->begin()) return false; --BBI; @@ -1681,7 +1602,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { Builder.SetInsertPoint(OtherStore); PN->addIncoming(Builder.CreateBitOrPointerCast(MergedVal, PN->getType()), OtherBB); - MergedVal = InsertNewInstBefore(PN, DestBB->front()); + MergedVal = InsertNewInstBefore(PN, DestBB->begin()); PN->setDebugLoc(MergedLoc); } @@ -1690,7 +1611,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), SI.getAlign(), SI.getOrdering(), SI.getSyncScopeID()); - InsertNewInstBefore(NewSI, *BBI); + InsertNewInstBefore(NewSI, BBI); NewSI->setDebugLoc(MergedLoc); NewSI->mergeDIAssignID({&SI, OtherStore}); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 50458e2773e6..8d5866e98a8e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -258,9 +258,14 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Op0->hasOneUse() && match(Op1, m_NegatedPower2())) { // Interpret X * (-1<<C) as (-X) * (1<<C) and try to sink the negation. 
// The "* (1<<C)" thus becomes a potential shifting opportunity. - if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this)) - return BinaryOperator::CreateMul( - NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName()); + if (Value *NegOp0 = + Negator::Negate(/*IsNegation*/ true, HasNSW, Op0, *this)) { + auto *Op1C = cast<Constant>(Op1); + return replaceInstUsesWith( + I, Builder.CreateMul(NegOp0, ConstantExpr::getNeg(Op1C), "", + /* HasNUW */ false, + HasNSW && Op1C->isNotMinSignedValue())); + } // Try to convert multiply of extended operand to narrow negate and shift // for better analysis. @@ -295,9 +300,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC. Value *X; Constant *C1; - if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) || - (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) && - haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) { + if (match(Op0, m_OneUse(m_AddLike(m_Value(X), m_ImmConstant(C1))))) { // C1*MulC simplifies to a tidier constant. Value *NewC = Builder.CreateMul(C1, MulC); auto *BOp0 = cast<BinaryOperator>(Op0); @@ -555,6 +558,180 @@ Instruction *InstCombinerImpl::foldFPSignBitOps(BinaryOperator &I) { return nullptr; } +Instruction *InstCombinerImpl::foldFMulReassoc(BinaryOperator &I) { + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *X, *Y; + Constant *C; + + // Reassociate constant RHS with another constant to form constant + // expression. + if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { + Constant *C1; + if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { + // (C1 / X) * C --> (C * C1) / X + Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL); + if (CC1 && CC1->isNormalFP()) + return BinaryOperator::CreateFDivFMF(CC1, X, &I); + } + if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { + // (X / C1) * C --> X * (C / C1) + Constant *CDivC1 = + ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL); + if (CDivC1 && CDivC1->isNormalFP()) + return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); + + // If the constant was a denormal, try reassociating differently. + // (X / C1) * C --> X / (C1 / C) + Constant *C1DivC = + ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL); + if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP()) + return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); + } + + // We do not need to match 'fadd C, X' and 'fsub X, C' because they are + // canonicalized to 'fadd X, C'. Distributing the multiply may allow + // further folds and (X * C) + C2 is 'fma'. 
+ if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { + // (X + C1) * C --> (X * C) + (C * C1) + if (Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFAddFMF(XC, CC1, &I); + } + } + if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { + // (C1 - X) * C --> (C * C1) - (X * C) + if (Constant *CC1 = + ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL)) { + Value *XC = Builder.CreateFMulFMF(X, C, &I); + return BinaryOperator::CreateFSubFMF(CC1, XC, &I); + } + } + } + + Value *Z; + if (match(&I, + m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), m_Value(Z)))) { + // Sink division: (X / Y) * Z --> (X * Z) / Y + Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); + return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); + } + + // sqrt(X) * sqrt(Y) -> sqrt(X * Y) + // nnan disallows the possibility of returning a number if both operands are + // negative (in that case, we should return NaN). + if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) && + match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) { + Value *XY = Builder.CreateFMulFMF(X, Y, &I); + Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); + return replaceInstUsesWith(I, Sqrt); + } + + // The following transforms are done irrespective of the number of uses + // for the expression "1.0/sqrt(X)". + // 1) 1.0/sqrt(X) * X -> X/sqrt(X) + // 2) X * 1.0/sqrt(X) -> X/sqrt(X) + // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it + // has the necessary (reassoc) fast-math-flags. + if (I.hasNoSignedZeros() && + match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Sqrt(m_Value(X))) && Op1 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + if (I.hasNoSignedZeros() && + match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && + match(Y, m_Sqrt(m_Value(X))) && Op0 == X) + return BinaryOperator::CreateFDivFMF(X, Y, &I); + + // Like the similar transform in instsimplify, this requires 'nsz' because + // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. 
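+  // (E.g. sqrt(-0.0) * sqrt(-0.0) evaluates to +0.0, not -0.0, so the folds
+  // below are only sound when the sign of zero may be ignored.)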
+ if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && Op0->hasNUses(2)) { + // Peek through fdiv to find squaring of square root: + // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y + if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(XX, Y, &I); + } + // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) + if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) { + Value *XX = Builder.CreateFMulFMF(X, X, &I); + return BinaryOperator::CreateFDivFMF(Y, XX, &I); + } + } + + // pow(X, Y) * X --> pow(X, Y+1) + // X * pow(X, Y) --> pow(X, Y+1) + if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), + m_Value(Y))), + m_Deferred(X)))) { + Value *Y1 = Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + + if (I.isOnlyUserOfAnyOperand()) { + // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { + auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); + return replaceInstUsesWith(I, NewPow); + } + // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { + auto *XZ = Builder.CreateFMulFMF(X, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); + return replaceInstUsesWith(I, NewPow); + } + + // powi(x, y) * powi(x, z) -> powi(x, y + z) + if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && + Y->getType() == Z->getType()) { + auto *YZ = Builder.CreateAdd(Y, Z); + auto *NewPow = Builder.CreateIntrinsic( + Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); + return replaceInstUsesWith(I, NewPow); + } + + // exp(X) * exp(Y) -> exp(X + Y) + if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I); + return replaceInstUsesWith(I, Exp); + } + + // exp2(X) * exp2(Y) -> exp2(X + Y) + if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) && + match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) { + Value *XY = Builder.CreateFAddFMF(X, Y, &I); + Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I); + return replaceInstUsesWith(I, Exp2); + } + } + + // (X*Y) * X => (X*X) * Y where Y != X + // The purpose is two-fold: + // 1) to form a power expression (of X). + // 2) potentially shorten the critical path: After transformation, the + // latency of the instruction Y is amortized by the expression of X*X, + // and therefore Y is in a "less critical" position compared to what it + // was before the transformation. 
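+  // (Sketch: for T = X*Y followed by T*X, emitting S = X*X and then S*Y means
+  // the X*X subexpression no longer waits on Y; illustrative, not new logic.)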
+ if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && Op1 != Y) { + Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && Op0 != Y) { + Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); + return BinaryOperator::CreateFMulFMF(XX, Y, &I); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = simplifyFMulInst(I.getOperand(0), I.getOperand(1), I.getFastMathFlags(), @@ -602,176 +779,9 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) return replaceInstUsesWith(I, V); - if (I.hasAllowReassoc()) { - // Reassociate constant RHS with another constant to form constant - // expression. - if (match(Op1, m_Constant(C)) && C->isFiniteNonZeroFP()) { - Constant *C1; - if (match(Op0, m_OneUse(m_FDiv(m_Constant(C1), m_Value(X))))) { - // (C1 / X) * C --> (C * C1) / X - Constant *CC1 = - ConstantFoldBinaryOpOperands(Instruction::FMul, C, C1, DL); - if (CC1 && CC1->isNormalFP()) - return BinaryOperator::CreateFDivFMF(CC1, X, &I); - } - if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { - // (X / C1) * C --> X * (C / C1) - Constant *CDivC1 = - ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C1, DL); - if (CDivC1 && CDivC1->isNormalFP()) - return BinaryOperator::CreateFMulFMF(X, CDivC1, &I); - - // If the constant was a denormal, try reassociating differently. - // (X / C1) * C --> X / (C1 / C) - Constant *C1DivC = - ConstantFoldBinaryOpOperands(Instruction::FDiv, C1, C, DL); - if (C1DivC && Op0->hasOneUse() && C1DivC->isNormalFP()) - return BinaryOperator::CreateFDivFMF(X, C1DivC, &I); - } - - // We do not need to match 'fadd C, X' and 'fsub X, C' because they are - // canonicalized to 'fadd X, C'. Distributing the multiply may allow - // further folds and (X * C) + C2 is 'fma'. - if (match(Op0, m_OneUse(m_FAdd(m_Value(X), m_Constant(C1))))) { - // (X + C1) * C --> (X * C) + (C * C1) - if (Constant *CC1 = ConstantFoldBinaryOpOperands( - Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFAddFMF(XC, CC1, &I); - } - } - if (match(Op0, m_OneUse(m_FSub(m_Constant(C1), m_Value(X))))) { - // (C1 - X) * C --> (C * C1) - (X * C) - if (Constant *CC1 = ConstantFoldBinaryOpOperands( - Instruction::FMul, C, C1, DL)) { - Value *XC = Builder.CreateFMulFMF(X, C, &I); - return BinaryOperator::CreateFSubFMF(CC1, XC, &I); - } - } - } - - Value *Z; - if (match(&I, m_c_FMul(m_OneUse(m_FDiv(m_Value(X), m_Value(Y))), - m_Value(Z)))) { - // Sink division: (X / Y) * Z --> (X * Z) / Y - Value *NewFMul = Builder.CreateFMulFMF(X, Z, &I); - return BinaryOperator::CreateFDivFMF(NewFMul, Y, &I); - } - - // sqrt(X) * sqrt(Y) -> sqrt(X * Y) - // nnan disallows the possibility of returning a number if both operands are - // negative (in that case, we should return NaN). - if (I.hasNoNaNs() && match(Op0, m_OneUse(m_Sqrt(m_Value(X)))) && - match(Op1, m_OneUse(m_Sqrt(m_Value(Y))))) { - Value *XY = Builder.CreateFMulFMF(X, Y, &I); - Value *Sqrt = Builder.CreateUnaryIntrinsic(Intrinsic::sqrt, XY, &I); - return replaceInstUsesWith(I, Sqrt); - } - - // The following transforms are done irrespective of the number of uses - // for the expression "1.0/sqrt(X)". 
- // 1) 1.0/sqrt(X) * X -> X/sqrt(X) - // 2) X * 1.0/sqrt(X) -> X/sqrt(X) - // We always expect the backend to reduce X/sqrt(X) to sqrt(X), if it - // has the necessary (reassoc) fast-math-flags. - if (I.hasNoSignedZeros() && - match(Op0, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Sqrt(m_Value(X))) && Op1 == X) - return BinaryOperator::CreateFDivFMF(X, Y, &I); - if (I.hasNoSignedZeros() && - match(Op1, (m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) && - match(Y, m_Sqrt(m_Value(X))) && Op0 == X) - return BinaryOperator::CreateFDivFMF(X, Y, &I); - - // Like the similar transform in instsimplify, this requires 'nsz' because - // sqrt(-0.0) = -0.0, and -0.0 * -0.0 does not simplify to -0.0. - if (I.hasNoNaNs() && I.hasNoSignedZeros() && Op0 == Op1 && - Op0->hasNUses(2)) { - // Peek through fdiv to find squaring of square root: - // (X / sqrt(Y)) * (X / sqrt(Y)) --> (X * X) / Y - if (match(Op0, m_FDiv(m_Value(X), m_Sqrt(m_Value(Y))))) { - Value *XX = Builder.CreateFMulFMF(X, X, &I); - return BinaryOperator::CreateFDivFMF(XX, Y, &I); - } - // (sqrt(Y) / X) * (sqrt(Y) / X) --> Y / (X * X) - if (match(Op0, m_FDiv(m_Sqrt(m_Value(Y)), m_Value(X)))) { - Value *XX = Builder.CreateFMulFMF(X, X, &I); - return BinaryOperator::CreateFDivFMF(Y, XX, &I); - } - } - - // pow(X, Y) * X --> pow(X, Y+1) - // X * pow(X, Y) --> pow(X, Y+1) - if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), - m_Value(Y))), - m_Deferred(X)))) { - Value *Y1 = - Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); - Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); - return replaceInstUsesWith(I, Pow); - } - - if (I.isOnlyUserOfAnyOperand()) { - // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) - if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { - auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); - auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); - return replaceInstUsesWith(I, NewPow); - } - // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) - if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { - auto *XZ = Builder.CreateFMulFMF(X, Z, &I); - auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); - return replaceInstUsesWith(I, NewPow); - } - - // powi(x, y) * powi(x, z) -> powi(x, y + z) - if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && - match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && - Y->getType() == Z->getType()) { - auto *YZ = Builder.CreateAdd(Y, Z); - auto *NewPow = Builder.CreateIntrinsic( - Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); - return replaceInstUsesWith(I, NewPow); - } - - // exp(X) * exp(Y) -> exp(X + Y) - if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && - match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { - Value *XY = Builder.CreateFAddFMF(X, Y, &I); - Value *Exp = Builder.CreateUnaryIntrinsic(Intrinsic::exp, XY, &I); - return replaceInstUsesWith(I, Exp); - } - - // exp2(X) * exp2(Y) -> exp2(X + Y) - if (match(Op0, m_Intrinsic<Intrinsic::exp2>(m_Value(X))) && - match(Op1, m_Intrinsic<Intrinsic::exp2>(m_Value(Y)))) { - Value *XY = Builder.CreateFAddFMF(X, Y, &I); - Value *Exp2 = Builder.CreateUnaryIntrinsic(Intrinsic::exp2, XY, &I); - return replaceInstUsesWith(I, Exp2); - } - } - - // (X*Y) * X => (X*X) * Y where Y != X - // The purpose is two-fold: - // 1) to form a power 
expression (of X). - // 2) potentially shorten the critical path: After transformation, the - // latency of the instruction Y is amortized by the expression of X*X, - // and therefore Y is in a "less critical" position compared to what it - // was before the transformation. - if (match(Op0, m_OneUse(m_c_FMul(m_Specific(Op1), m_Value(Y)))) && - Op1 != Y) { - Value *XX = Builder.CreateFMulFMF(Op1, Op1, &I); - return BinaryOperator::CreateFMulFMF(XX, Y, &I); - } - if (match(Op1, m_OneUse(m_c_FMul(m_Specific(Op0), m_Value(Y)))) && - Op0 != Y) { - Value *XX = Builder.CreateFMulFMF(Op0, Op0, &I); - return BinaryOperator::CreateFMulFMF(XX, Y, &I); - } - } + if (I.hasAllowReassoc()) + if (Instruction *FoldedMul = foldFMulReassoc(I)) + return FoldedMul; // log2(X * 0.5) * Y = log2(X) * Y - Y if (I.isFast()) { @@ -802,7 +812,7 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { I.hasNoSignedZeros() && match(Start, m_Zero())) return replaceInstUsesWith(I, Start); - // minimun(X, Y) * maximum(X, Y) => X * Y. + // minimum(X, Y) * maximum(X, Y) => X * Y. if (match(&I, m_c_FMul(m_Intrinsic<Intrinsic::maximum>(m_Value(X), m_Value(Y)), m_c_Intrinsic<Intrinsic::minimum>(m_Deferred(X), @@ -918,8 +928,7 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, return Remainder.isMinValue(); } -static Instruction *foldIDivShl(BinaryOperator &I, - InstCombiner::BuilderTy &Builder) { +static Value *foldIDivShl(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { assert((I.getOpcode() == Instruction::SDiv || I.getOpcode() == Instruction::UDiv) && "Expected integer divide"); @@ -928,7 +937,6 @@ static Instruction *foldIDivShl(BinaryOperator &I, Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); - Instruction *Ret = nullptr; Value *X, *Y, *Z; // With appropriate no-wrap constraints, remove a common factor in the @@ -943,12 +951,12 @@ static Instruction *foldIDivShl(BinaryOperator &I, // (X * Y) u/ (X << Z) --> Y u>> Z if (!IsSigned && HasNUW) - Ret = BinaryOperator::CreateLShr(Y, Z); + return Builder.CreateLShr(Y, Z, "", I.isExact()); // (X * Y) s/ (X << Z) --> Y s/ (1 << Z) if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) { Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z); - Ret = BinaryOperator::CreateSDiv(Y, Shl); + return Builder.CreateSDiv(Y, Shl, "", I.isExact()); } } @@ -966,20 +974,38 @@ static Instruction *foldIDivShl(BinaryOperator &I, ((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) || (Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap()))) - Ret = BinaryOperator::CreateUDiv(X, Y); + return Builder.CreateUDiv(X, Y, "", I.isExact()); // For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor. // (X << Z) / (Y << Z) --> X / Y if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() && Shl1->hasNoUnsignedWrap()) - Ret = BinaryOperator::CreateSDiv(X, Y); + return Builder.CreateSDiv(X, Y, "", I.isExact()); } - if (!Ret) - return nullptr; + // If X << Y and X << Z does not overflow, then: + // (X << Y) / (X << Z) -> (1 << Y) / (1 << Z) -> 1 << Y >> Z + if (match(Op0, m_Shl(m_Value(X), m_Value(Y))) && + match(Op1, m_Shl(m_Specific(X), m_Value(Z)))) { + auto *Shl0 = cast<OverflowingBinaryOperator>(Op0); + auto *Shl1 = cast<OverflowingBinaryOperator>(Op1); - Ret->setIsExact(I.isExact()); - return Ret; + if (IsSigned ? 
(Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap()) + : (Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap())) { + Constant *One = ConstantInt::get(X->getType(), 1); + // Only preserve the nsw flag if dividend has nsw + // or divisor has nsw and operator is sdiv. + Value *Dividend = Builder.CreateShl( + One, Y, "shl.dividend", + /*HasNUW*/ true, + /*HasNSW*/ + IsSigned ? (Shl0->hasNoUnsignedWrap() || Shl1->hasNoUnsignedWrap()) + : Shl0->hasNoSignedWrap()); + return Builder.CreateLShr(Dividend, Z, "", I.isExact()); + } + } + + return nullptr; } /// This function implements the transforms common to both integer division @@ -1156,8 +1182,8 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { return NewDiv; } - if (Instruction *R = foldIDivShl(I, Builder)) - return R; + if (Value *R = foldIDivShl(I, Builder)) + return replaceInstUsesWith(I, R); // With the appropriate no-wrap constraint, remove a multiply by the divisor // after peeking through another divide: @@ -1263,7 +1289,7 @@ static Value *takeLog2(IRBuilderBase &Builder, Value *Op, unsigned Depth, /// If we have zero-extended operands of an unsigned div or rem, we may be able /// to narrow the operation (sink the zext below the math). static Instruction *narrowUDivURem(BinaryOperator &I, - InstCombiner::BuilderTy &Builder) { + InstCombinerImpl &IC) { Instruction::BinaryOps Opcode = I.getOpcode(); Value *N = I.getOperand(0); Value *D = I.getOperand(1); @@ -1273,7 +1299,7 @@ static Instruction *narrowUDivURem(BinaryOperator &I, X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) { // udiv (zext X), (zext Y) --> zext (udiv X, Y) // urem (zext X), (zext Y) --> zext (urem X, Y) - Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y); + Value *NarrowOp = IC.Builder.CreateBinOp(Opcode, X, Y); return new ZExtInst(NarrowOp, Ty); } @@ -1281,24 +1307,24 @@ static Instruction *narrowUDivURem(BinaryOperator &I, if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. - Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); - if (ConstantExpr::getZExt(TruncC, Ty) != C) + Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType()); + if (!TruncC) return nullptr; // udiv (zext X), C --> zext (udiv X, C') // urem (zext X), C --> zext (urem X, C') - return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty); + return new ZExtInst(IC.Builder.CreateBinOp(Opcode, X, TruncC), Ty); } if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. 
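The "same constant in the smaller type" test above is a trunc-then-zext round trip; a minimal model with i64 to i8 standing in for the wide and narrow types (the helper name here is made up for this sketch):

#include <cassert>
#include <cstdint>

// A constant can be narrowed for the udiv/urem fold only if zero-extending
// the truncated value reproduces it exactly.
bool losslessUnsignedTrunc(uint64_t C) {
  uint8_t Narrow = static_cast<uint8_t>(C);  // trunc i64 -> i8
  return static_cast<uint64_t>(Narrow) == C; // zext i8 -> i64 and compare
}

int main() {
  assert(losslessUnsignedTrunc(200));  // fits: the fold may proceed
  assert(!losslessUnsignedTrunc(300)); // lossy: bail out, as the code does
  return 0;
}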
-    Constant *TruncC = ConstantExpr::getTrunc(C, X->getType());
-    if (ConstantExpr::getZExt(TruncC, Ty) != C)
+    Constant *TruncC = IC.getLosslessUnsignedTrunc(C, X->getType());
+    if (!TruncC)
       return nullptr;
 
     // udiv C, (zext X) --> zext (udiv C', X)
     // urem C, (zext X) --> zext (urem C', X)
-    return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty);
+    return new ZExtInst(IC.Builder.CreateBinOp(Opcode, TruncC, X), Ty);
   }
 
   return nullptr;
@@ -1346,7 +1372,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) {
     return CastInst::CreateZExtOrBitCast(Cmp, Ty);
   }
 
-  if (Instruction *NarrowDiv = narrowUDivURem(I, Builder))
+  if (Instruction *NarrowDiv = narrowUDivURem(I, *this))
     return NarrowDiv;
 
   // If the udiv operands are non-overflowing multiplies with a common operand,
@@ -1405,7 +1431,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
   // sdiv Op0, (sext i1 X) --> -Op0 (because if X is 0, the op is undefined)
   if (match(Op1, m_AllOnes()) ||
       (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))
-    return BinaryOperator::CreateNeg(Op0);
+    return BinaryOperator::CreateNSWNeg(Op0);
 
   // X / INT_MIN --> X == INT_MIN
   if (match(Op1, m_SignMask()))
@@ -1428,7 +1454,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
       Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1));
       Constant *C = ConstantExpr::getExactLogBase2(NegPow2C);
       Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true);
-      return BinaryOperator::CreateNeg(Ashr);
+      return BinaryOperator::CreateNSWNeg(Ashr);
     }
   }
 
@@ -1490,7 +1516,7 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
 
   if (KnownDividend.isNonNegative()) {
     // If both operands are unsigned, turn this into a udiv.
-    if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) {
+    if (isKnownNonNegative(Op1, SQ.getWithInstruction(&I))) {
       auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName());
       BO->setIsExact(I.isExact());
       return BO;
@@ -1516,6 +1542,13 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) {
     }
   }
 
+  // -X / X --> X == INT_MIN ? 1 : -1
+  if (isKnownNegation(Op0, Op1)) {
+    APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+    Value *Cond = Builder.CreateICmpEQ(Op0, ConstantInt::get(Ty, MinVal));
+    return SelectInst::Create(Cond, ConstantInt::get(Ty, 1),
+                              ConstantInt::getAllOnesValue(Ty));
+  }
   return nullptr;
 }
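The select produced by the "-X / X" fold above exists only for the INT_MIN corner, where wrapping negation maps the value to itself and the quotient becomes +1. A small reference check (standalone, not part of the patch):

#include <cassert>
#include <cstdint>

// Reference semantics of the fold: -X divided by X is -1 for every value
// except INT_MIN, whose (wrapping) negation is itself, giving +1.
int32_t negOverSelf(int32_t X) { return X == INT32_MIN ? 1 : -1; }

int main() {
  assert(negOverSelf(7) == -7 / 7);
  assert(negOverSelf(-7) == 7 / -7);
  // Negating INT32_MIN is UB in C++, but `sub 0, INT_MIN` wraps in LLVM IR,
  // and INT_MIN / INT_MIN == 1, matching the select's true arm.
  assert(negOverSelf(INT32_MIN) == 1);
  return 0;
}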
@@ -1759,6 +1792,21 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) {
     return replaceInstUsesWith(I, Pow);
   }
 
+  // powi(X, Y) / X --> powi(X, Y-1)
+  // This is legal when (Y - 1) can't wrap around, in which case reassoc and
+  // nnan are required.
+  // TODO: Multi-use cases may also be better off creating powi(x, y-1).
+  if (I.hasAllowReassoc() && I.hasNoNaNs() &&
+      match(Op0, m_OneUse(m_Intrinsic<Intrinsic::powi>(m_Specific(Op1),
+                                                       m_Value(Y)))) &&
+      willNotOverflowSignedSub(Y, ConstantInt::get(Y->getType(), 1), I)) {
+    Constant *NegOne = ConstantInt::getAllOnesValue(Y->getType());
+    Value *Y1 = Builder.CreateAdd(Y, NegOne);
+    Type *Types[] = {Op1->getType(), Y1->getType()};
+    Value *Pow = Builder.CreateIntrinsic(Intrinsic::powi, Types, {Op1, Y1}, &I);
+    return replaceInstUsesWith(I, Pow);
+  }
+
   return nullptr;
 }
 
@@ -1936,7 +1984,7 @@ Instruction *InstCombinerImpl::visitURem(BinaryOperator &I) {
   if (Instruction *common = commonIRemTransforms(I))
     return common;
 
-  if (Instruction *NarrowRem = narrowUDivURem(I, Builder))
+  if (Instruction *NarrowRem = narrowUDivURem(I, *this))
     return NarrowRem;
 
   // X urem Y -> X and Y-1, where Y is a power of 2,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
index e24abc48424d..513b185c83a4 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp
@@ -20,7 +20,6 @@
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/Twine.h"
-#include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/TargetFolder.h"
 #include "llvm/Analysis/ValueTracking.h"
 #include "llvm/IR/Constant.h"
@@ -98,14 +97,13 @@ static cl::opt<unsigned>
                   cl::desc("What is the maximal lookup depth when trying to "
                            "check for viability of negation sinking."));
 
-Negator::Negator(LLVMContext &C, const DataLayout &DL_, AssumptionCache &AC_,
-                 const DominatorTree &DT_, bool IsTrulyNegation_)
-    : Builder(C, TargetFolder(DL_),
+Negator::Negator(LLVMContext &C, const DataLayout &DL, bool IsTrulyNegation_)
+    : Builder(C, TargetFolder(DL),
               IRBuilderCallbackInserter([&](Instruction *I) {
                 ++NegatorNumInstructionsCreatedTotal;
                 NewInstructions.push_back(I);
               })),
-      DL(DL_), AC(AC_), DT(DT_), IsTrulyNegation(IsTrulyNegation_) {}
+      IsTrulyNegation(IsTrulyNegation_) {}
 
 #if LLVM_ENABLE_STATS
 Negator::~Negator() {
@@ -128,7 +126,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
 
 // FIXME: can this be reworked into a worklist-based algorithm while preserving
 // the depth-first, early bailout traversal?
-[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) {
+[[nodiscard]] Value *Negator::visitImpl(Value *V, bool IsNSW, unsigned Depth) {
   // -(undef) -> undef.
   if (match(V, m_Undef()))
     return V;
@@ -237,7 +235,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
     // However, only do this either if the old `sub` doesn't stick around, or
     // it was subtracting from a constant. Otherwise, this isn't profitable.
     return Builder.CreateSub(I->getOperand(1), I->getOperand(0),
-                             I->getName() + ".neg");
+                             I->getName() + ".neg", /* HasNUW */ false,
+                             IsNSW && I->hasNoSignedWrap());
   }
 
   // Some other cases, while they still don't require recursion,
@@ -302,7 +301,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) {
   switch (I->getOpcode()) {
   case Instruction::Freeze: {
     // `freeze` is negatible if its operand is negatible.
-    Value *NegOp = negate(I->getOperand(0), Depth + 1);
+    Value *NegOp = negate(I->getOperand(0), IsNSW, Depth + 1);
     if (!NegOp) // Early return.
return nullptr; return Builder.CreateFreeze(NegOp, I->getName() + ".neg"); @@ -313,7 +312,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { SmallVector<Value *, 4> NegatedIncomingValues(PHI->getNumOperands()); for (auto I : zip(PHI->incoming_values(), NegatedIncomingValues)) { if (!(std::get<1>(I) = - negate(std::get<0>(I), Depth + 1))) // Early return. + negate(std::get<0>(I), IsNSW, Depth + 1))) // Early return. return nullptr; } // All incoming values are indeed negatible. Create negated PHI node. @@ -336,10 +335,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return NewSelect; } // `select` is negatible if both hands of `select` are negatible. - Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1); if (!NegOp1) // Early return. return nullptr; - Value *NegOp2 = negate(I->getOperand(2), Depth + 1); + Value *NegOp2 = negate(I->getOperand(2), IsNSW, Depth + 1); if (!NegOp2) return nullptr; // Do preserve the metadata! @@ -349,10 +348,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { case Instruction::ShuffleVector: { // `shufflevector` is negatible if both operands are negatible. auto *Shuf = cast<ShuffleVectorInst>(I); - Value *NegOp0 = negate(I->getOperand(0), Depth + 1); + Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1); if (!NegOp0) // Early return. return nullptr; - Value *NegOp1 = negate(I->getOperand(1), Depth + 1); + Value *NegOp1 = negate(I->getOperand(1), IsNSW, Depth + 1); if (!NegOp1) return nullptr; return Builder.CreateShuffleVector(NegOp0, NegOp1, Shuf->getShuffleMask(), @@ -361,7 +360,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { case Instruction::ExtractElement: { // `extractelement` is negatible if source operand is negatible. auto *EEI = cast<ExtractElementInst>(I); - Value *NegVector = negate(EEI->getVectorOperand(), Depth + 1); + Value *NegVector = negate(EEI->getVectorOperand(), IsNSW, Depth + 1); if (!NegVector) // Early return. return nullptr; return Builder.CreateExtractElement(NegVector, EEI->getIndexOperand(), @@ -371,10 +370,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // `insertelement` is negatible if both the source vector and // element-to-be-inserted are negatible. auto *IEI = cast<InsertElementInst>(I); - Value *NegVector = negate(IEI->getOperand(0), Depth + 1); + Value *NegVector = negate(IEI->getOperand(0), IsNSW, Depth + 1); if (!NegVector) // Early return. return nullptr; - Value *NegNewElt = negate(IEI->getOperand(1), Depth + 1); + Value *NegNewElt = negate(IEI->getOperand(1), IsNSW, Depth + 1); if (!NegNewElt) // Early return. return nullptr; return Builder.CreateInsertElement(NegVector, NegNewElt, IEI->getOperand(2), @@ -382,15 +381,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { } case Instruction::Trunc: { // `trunc` is negatible if its operand is negatible. - Value *NegOp = negate(I->getOperand(0), Depth + 1); + Value *NegOp = negate(I->getOperand(0), /* IsNSW */ false, Depth + 1); if (!NegOp) // Early return. return nullptr; return Builder.CreateTrunc(NegOp, I->getType(), I->getName() + ".neg"); } case Instruction::Shl: { // `shl` is negatible if the first operand is negatible. 
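For the `shl` case that follows, sinking the negation is always value-correct because negation distributes over a left shift in two's complement; only the nsw flag needs the extra bookkeeping the code performs. A quick unsigned (UB-free) check, standalone and illustrative only:

#include <cassert>
#include <cstdint>

int main() {
  // -(x << c) == (-x) << c  (mod 2^32)
  for (uint32_t x = 0; x < 1024; ++x)
    for (unsigned c = 0; c < 32; ++c)
      assert(static_cast<uint32_t>(-(x << c)) == ((-x) << c));
  return 0;
}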
- if (Value *NegOp0 = negate(I->getOperand(0), Depth + 1)) - return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); + IsNSW &= I->hasNoSignedWrap(); + if (Value *NegOp0 = negate(I->getOperand(0), IsNSW, Depth + 1)) + return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg", + /* HasNUW */ false, IsNSW); // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); if (!Op1C || !IsTrulyNegation) @@ -398,11 +399,10 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return Builder.CreateMul( I->getOperand(0), ConstantExpr::getShl(Constant::getAllOnesValue(Op1C->getType()), Op1C), - I->getName() + ".neg"); + I->getName() + ".neg", /* HasNUW */ false, IsNSW); } case Instruction::Or: { - if (!haveNoCommonBitsSet(I->getOperand(0), I->getOperand(1), DL, &AC, I, - &DT)) + if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) return nullptr; // Don't know how to handle `or` in general. std::array<Value *, 2> Ops = getSortedOperandsOfBinOp(I); // `or`/`add` are interchangeable when operands have no common bits set. @@ -417,7 +417,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { SmallVector<Value *, 2> NegatedOps, NonNegatedOps; for (Value *Op : I->operands()) { // Can we sink the negation into this operand? - if (Value *NegOp = negate(Op, Depth + 1)) { + if (Value *NegOp = negate(Op, /* IsNSW */ false, Depth + 1)) { NegatedOps.emplace_back(NegOp); // Successfully negated operand! continue; } @@ -446,9 +446,11 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // `xor` is negatible if one of its operands is invertible. // FIXME: InstCombineInverter? But how to connect Inverter and Negator? if (auto *C = dyn_cast<Constant>(Ops[1])) { - Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); - return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), - I->getName() + ".neg"); + if (IsTrulyNegation) { + Value *Xor = Builder.CreateXor(Ops[0], ConstantExpr::getNot(C)); + return Builder.CreateAdd(Xor, ConstantInt::get(Xor->getType(), 1), + I->getName() + ".neg"); + } } return nullptr; } @@ -458,16 +460,17 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { Value *NegatedOp, *OtherOp; // First try the second operand, in case it's a constant it will be best to // just invert it instead of sinking the `neg` deeper. - if (Value *NegOp1 = negate(Ops[1], Depth + 1)) { + if (Value *NegOp1 = negate(Ops[1], /* IsNSW */ false, Depth + 1)) { NegatedOp = NegOp1; OtherOp = Ops[0]; - } else if (Value *NegOp0 = negate(Ops[0], Depth + 1)) { + } else if (Value *NegOp0 = negate(Ops[0], /* IsNSW */ false, Depth + 1)) { NegatedOp = NegOp0; OtherOp = Ops[1]; } else // Can't negate either of them. return nullptr; - return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg"); + return Builder.CreateMul(NegatedOp, OtherOp, I->getName() + ".neg", + /* HasNUW */ false, IsNSW && I->hasNoSignedWrap()); } default: return nullptr; // Don't know, likely not negatible for free. @@ -476,7 +479,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { llvm_unreachable("Can't get here. 
We always return from switch."); } -[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::negate(Value *V, bool IsNSW, unsigned Depth) { NegatorMaxDepthVisited.updateMax(Depth); ++NegatorNumValuesVisited; @@ -506,15 +509,16 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { #endif // No luck. Try negating it for real. - Value *NegatedV = visitImpl(V, Depth); + Value *NegatedV = visitImpl(V, IsNSW, Depth); // And cache the (real) result for the future. NegationsCache[V] = NegatedV; return NegatedV; } -[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) { - Value *Negated = negate(Root, /*Depth=*/0); +[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root, + bool IsNSW) { + Value *Negated = negate(Root, IsNSW, /*Depth=*/0); if (!Negated) { // We must cleanup newly-inserted instructions, to avoid any potential // endless combine looping. @@ -525,7 +529,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); } -[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root, +[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, bool IsNSW, Value *Root, InstCombinerImpl &IC) { ++NegatorTotalNegationsAttempted; LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root @@ -534,9 +538,8 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { if (!NegatorEnabled || !DebugCounter::shouldExecute(NegatorCounter)) return nullptr; - Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), - IC.getDominatorTree(), LHSIsZero); - std::optional<Result> Res = N.run(Root); + Negator N(Root->getContext(), IC.getDataLayout(), LHSIsZero); + std::optional<Result> Res = N.run(Root, IsNSW); if (!Res) { // Negation failed. LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root << "\n"); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 2f6aa85062a5..20b34c1379d5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -248,7 +248,7 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { PHINode *NewPtrPHI = PHINode::Create( IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); - InsertNewInstBefore(NewPtrPHI, PN); + InsertNewInstBefore(NewPtrPHI, PN.getIterator()); SmallDenseMap<Value *, Instruction *> Casts; for (auto Incoming : zip(PN.blocks(), AvailablePtrVals)) { auto *IncomingBB = std::get<0>(Incoming); @@ -285,10 +285,10 @@ bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (isa<PHINode>(IncomingI)) InsertPos = BB->getFirstInsertionPt(); assert(InsertPos != BB->end() && "should have checked above"); - InsertNewInstBefore(CI, *InsertPos); + InsertNewInstBefore(CI, InsertPos); } else { auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); - InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt()); + InsertNewInstBefore(CI, InsertBB->getFirstInsertionPt()); } } NewPtrPHI->addIncoming(CI, IncomingBB); @@ -353,7 +353,7 @@ InstCombinerImpl::foldPHIArgInsertValueInstructionIntoPHI(PHINode &PN) { NewOperand->addIncoming( cast<InsertValueInst>(std::get<1>(Incoming))->getOperand(OpIdx), std::get<0>(Incoming)); - InsertNewInstBefore(NewOperand, PN); + InsertNewInstBefore(NewOperand, PN.getIterator()); } // And finally, create `insertvalue` over the newly-formed PHI nodes. 
@@ -391,7 +391,7 @@ InstCombinerImpl::foldPHIArgExtractValueInstructionIntoPHI(PHINode &PN) { NewAggregateOperand->addIncoming( cast<ExtractValueInst>(std::get<1>(Incoming))->getAggregateOperand(), std::get<0>(Incoming)); - InsertNewInstBefore(NewAggregateOperand, PN); + InsertNewInstBefore(NewAggregateOperand, PN.getIterator()); // And finally, create `extractvalue` over the newly-formed PHI nodes. auto *NewEVI = ExtractValueInst::Create(NewAggregateOperand, @@ -450,7 +450,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { NewLHS = PHINode::Create(LHSType, PN.getNumIncomingValues(), FirstInst->getOperand(0)->getName() + ".pn"); NewLHS->addIncoming(InLHS, PN.getIncomingBlock(0)); - InsertNewInstBefore(NewLHS, PN); + InsertNewInstBefore(NewLHS, PN.getIterator()); LHSVal = NewLHS; } @@ -458,7 +458,7 @@ Instruction *InstCombinerImpl::foldPHIArgBinOpIntoPHI(PHINode &PN) { NewRHS = PHINode::Create(RHSType, PN.getNumIncomingValues(), FirstInst->getOperand(1)->getName() + ".pn"); NewRHS->addIncoming(InRHS, PN.getIncomingBlock(0)); - InsertNewInstBefore(NewRHS, PN); + InsertNewInstBefore(NewRHS, PN.getIterator()); RHSVal = NewRHS; } @@ -581,7 +581,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { Value *FirstOp = FirstInst->getOperand(I); PHINode *NewPN = PHINode::Create(FirstOp->getType(), E, FirstOp->getName() + ".pn"); - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); NewPN->addIncoming(FirstOp, PN.getIncomingBlock(0)); OperandPhis[I] = NewPN; @@ -769,7 +769,7 @@ Instruction *InstCombinerImpl::foldPHIArgLoadIntoPHI(PHINode &PN) { NewLI->setOperand(0, InVal); delete NewPN; } else { - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); } // If this was a volatile load that we are merging, make sure to loop through @@ -825,8 +825,8 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { NumZexts++; } else if (auto *C = dyn_cast<Constant>(V)) { // Make sure that constants can fit in the new type. - Constant *Trunc = ConstantExpr::getTrunc(C, NarrowType); - if (ConstantExpr::getZExt(Trunc, C->getType()) != C) + Constant *Trunc = getLosslessUnsignedTrunc(C, NarrowType); + if (!Trunc) return nullptr; NewIncoming.push_back(Trunc); NumConsts++; @@ -853,7 +853,7 @@ Instruction *InstCombinerImpl::foldPHIArgZextsIntoPHI(PHINode &Phi) { for (unsigned I = 0; I != NumIncomingValues; ++I) NewPhi->addIncoming(NewIncoming[I], Phi.getIncomingBlock(I)); - InsertNewInstBefore(NewPhi, Phi); + InsertNewInstBefore(NewPhi, Phi.getIterator()); return CastInst::CreateZExtOrBitCast(NewPhi, Phi.getType()); } @@ -943,7 +943,7 @@ Instruction *InstCombinerImpl::foldPHIArgOpIntoPHI(PHINode &PN) { PhiVal = InVal; delete NewPN; } else { - InsertNewInstBefore(NewPN, PN); + InsertNewInstBefore(NewPN, PN.getIterator()); PhiVal = NewPN; } @@ -996,8 +996,8 @@ static bool isDeadPHICycle(PHINode *PN, /// Return true if this phi node is always equal to NonPhiInVal. /// This happens with mutually cyclic phi nodes like: /// z = some value; x = phi (y, z); y = phi (x, z) -static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, - SmallPtrSetImpl<PHINode*> &ValueEqualPHIs) { +static bool PHIsEqualValue(PHINode *PN, Value *&NonPhiInVal, + SmallPtrSetImpl<PHINode *> &ValueEqualPHIs) { // See if we already saw this PHI node. if (!ValueEqualPHIs.insert(PN).second) return true; @@ -1010,8 +1010,11 @@ static bool PHIsEqualValue(PHINode *PN, Value *NonPhiInVal, // the value. 
   for (Value *Op : PN->incoming_values()) {
     if (PHINode *OpPN = dyn_cast<PHINode>(Op)) {
-      if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs))
-        return false;
+      if (!PHIsEqualValue(OpPN, NonPhiInVal, ValueEqualPHIs)) {
+        if (NonPhiInVal)
+          return false;
+        NonPhiInVal = OpPN;
+      }
     } else if (Op != NonPhiInVal)
       return false;
   }
 
@@ -1368,7 +1371,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN,
   // sinking.
   auto InsertPt = BB->getFirstInsertionPt();
   if (InsertPt != BB->end()) {
-    Self.Builder.SetInsertPoint(&*InsertPt);
+    Self.Builder.SetInsertPoint(&*BB, InsertPt);
     return Self.Builder.CreateNot(Cond);
   }
 
@@ -1437,22 +1440,45 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) {
   // are induction variable analysis (sometimes) and ADCE, which is only run
   // late.
   if (PHIUser->hasOneUse() &&
-      (isa<BinaryOperator>(PHIUser) || isa<GetElementPtrInst>(PHIUser)) &&
+      (isa<BinaryOperator>(PHIUser) || isa<UnaryOperator>(PHIUser) ||
+       isa<GetElementPtrInst>(PHIUser)) &&
       PHIUser->user_back() == &PN) {
     return replaceInstUsesWith(PN, PoisonValue::get(PN.getType()));
   }
-  // When a PHI is used only to be compared with zero, it is safe to replace
-  // an incoming value proved as known nonzero with any non-zero constant.
-  // For example, in the code below, the incoming value %v can be replaced
-  // with any non-zero constant based on the fact that the PHI is only used to
-  // be compared with zero and %v is a known non-zero value:
-  // %v = select %cond, 1, 2
-  // %p = phi [%v, BB] ...
-  // icmp eq, %p, 0
-  auto *CmpInst = dyn_cast<ICmpInst>(PHIUser);
-  // FIXME: To be simple, handle only integer type for now.
-  if (CmpInst && isa<IntegerType>(PN.getType()) && CmpInst->isEquality() &&
-      match(CmpInst->getOperand(1), m_Zero())) {
+  }
+
+  // When a PHI is used only to be compared with zero, it is safe to replace
+  // an incoming value proved as known nonzero with any non-zero constant.
+  // For example, in the code below, the incoming value %v can be replaced
+  // with any non-zero constant based on the fact that the PHI is only used to
+  // be compared with zero and %v is a known non-zero value:
+  // %v = select %cond, 1, 2
+  // %p = phi [%v, BB] ...
+  // icmp eq, %p, 0
+  // FIXME: To be simple, handle only integer type for now.
+  // This handles a small number of uses to keep the complexity down, and an
+  // icmp(or(phi)) can equally be replaced with any non-zero constant as the
+  // "or" will only add bits.
+  if (!PN.hasNUsesOrMore(3)) {
+    SmallVector<Instruction *> DropPoisonFlags;
+    bool AllUsesOfPhiEndsInCmp = all_of(PN.users(), [&](User *U) {
+      auto *CmpInst = dyn_cast<ICmpInst>(U);
+      if (!CmpInst) {
+        // This is always correct, as OR only adds bits and we are checking
+        // against 0.
+        if (U->hasOneUse() && match(U, m_c_Or(m_Specific(&PN), m_Value()))) {
+          DropPoisonFlags.push_back(cast<Instruction>(U));
+          CmpInst = dyn_cast<ICmpInst>(U->user_back());
+        }
+      }
+      if (!CmpInst || !isa<IntegerType>(PN.getType()) ||
+          !CmpInst->isEquality() || !match(CmpInst->getOperand(1), m_Zero())) {
+        return false;
+      }
+      return true;
+    });
+    // All uses of the PHI result in a compare with zero.
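Why the `or` look-through above is sound: an equality compare against zero observes nothing but zero-ness, and or-ing bits onto a non-zero value keeps it non-zero. Checked exhaustively on i8 in this standalone sketch:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t p = 1; p < 256; ++p)   // any non-zero incoming value
    for (uint32_t x = 0; x < 256; ++x) // any value or'd onto the phi
      assert((p | x) != 0);            // the icmp-with-0 result is unchanged
  return 0;
}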
+ if (AllUsesOfPhiEndsInCmp) { ConstantInt *NonZeroConst = nullptr; bool MadeChange = false; for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { @@ -1461,9 +1487,11 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (isKnownNonZero(VA, DL, 0, &AC, CtxI, &DT)) { if (!NonZeroConst) NonZeroConst = getAnyNonZeroConstInt(PN); - if (NonZeroConst != VA) { replaceOperand(PN, I, NonZeroConst); + // The "disjoint" flag may no longer hold after the transform. + for (Instruction *I : DropPoisonFlags) + I->dropPoisonGeneratingFlags(); MadeChange = true; } } @@ -1478,7 +1506,9 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // z = some value; x = phi (y, z); y = phi (x, z) // where the phi nodes don't necessarily need to be in the same block. Do a // quick check to see if the PHI node only contains a single non-phi value, if - // so, scan to see if the phi cycle is actually equal to that value. + // so, scan to see if the phi cycle is actually equal to that value. If the + // phi has no non-phi values then allow the "NonPhiInVal" to be set later if + // one of the phis itself does not have a single input. { unsigned InValNo = 0, NumIncomingVals = PN.getNumIncomingValues(); // Scan for the first non-phi operand. @@ -1486,25 +1516,25 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { isa<PHINode>(PN.getIncomingValue(InValNo))) ++InValNo; - if (InValNo != NumIncomingVals) { - Value *NonPhiInVal = PN.getIncomingValue(InValNo); + Value *NonPhiInVal = + InValNo != NumIncomingVals ? PN.getIncomingValue(InValNo) : nullptr; - // Scan the rest of the operands to see if there are any conflicts, if so - // there is no need to recursively scan other phis. + // Scan the rest of the operands to see if there are any conflicts, if so + // there is no need to recursively scan other phis. + if (NonPhiInVal) for (++InValNo; InValNo != NumIncomingVals; ++InValNo) { Value *OpVal = PN.getIncomingValue(InValNo); if (OpVal != NonPhiInVal && !isa<PHINode>(OpVal)) break; } - // If we scanned over all operands, then we have one unique value plus - // phi values. Scan PHI nodes to see if they all merge in each other or - // the value. - if (InValNo == NumIncomingVals) { - SmallPtrSet<PHINode*, 16> ValueEqualPHIs; - if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) - return replaceInstUsesWith(PN, NonPhiInVal); - } + // If we scanned over all operands, then we have one unique value plus + // phi values. Scan PHI nodes to see if they all merge in each other or + // the value. + if (InValNo == NumIncomingVals) { + SmallPtrSet<PHINode *, 16> ValueEqualPHIs; + if (PHIsEqualValue(&PN, NonPhiInVal, ValueEqualPHIs)) + return replaceInstUsesWith(PN, NonPhiInVal); } } @@ -1512,11 +1542,12 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // the blocks in the same order. This will help identical PHIs be eliminated // by other passes. Other passes shouldn't depend on this for correctness // however. 
- PHINode *FirstPN = cast<PHINode>(PN.getParent()->begin()); - if (&PN != FirstPN) - for (unsigned I = 0, E = FirstPN->getNumIncomingValues(); I != E; ++I) { + auto Res = PredOrder.try_emplace(PN.getParent()); + if (!Res.second) { + const auto &Preds = Res.first->second; + for (unsigned I = 0, E = PN.getNumIncomingValues(); I != E; ++I) { BasicBlock *BBA = PN.getIncomingBlock(I); - BasicBlock *BBB = FirstPN->getIncomingBlock(I); + BasicBlock *BBB = Preds[I]; if (BBA != BBB) { Value *VA = PN.getIncomingValue(I); unsigned J = PN.getBasicBlockIndex(BBB); @@ -1531,6 +1562,10 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // this in this case. } } + } else { + // Remember the block order of the first encountered phi node. + append_range(Res.first->second, PN.blocks()); + } // Is there an identical PHI node in this basic block? for (PHINode &IdenticalPN : PN.getParent()->phis()) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 661c50062223..2dda46986f0f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -689,34 +689,40 @@ static Value *foldSelectICmpLshrAshr(const ICmpInst *IC, Value *TrueVal, } /// We want to turn: -/// (select (icmp eq (and X, C1), 0), Y, (or Y, C2)) +/// (select (icmp eq (and X, C1), 0), Y, (BinOp Y, C2)) /// into: -/// (or (shl (and X, C1), C3), Y) +/// IF C2 u>= C1 +/// (BinOp Y, (shl (and X, C1), C3)) +/// ELSE +/// (BinOp Y, (lshr (and X, C1), C3)) /// iff: +/// 0 on the RHS is the identity value (i.e add, xor, shl, etc...) /// C1 and C2 are both powers of 2 /// where: -/// C3 = Log(C2) - Log(C1) +/// IF C2 u>= C1 +/// C3 = Log(C2) - Log(C1) +/// ELSE +/// C3 = Log(C1) - Log(C2) /// /// This transform handles cases where: /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, +static Value *foldSelectICmpAndBinOp(const ICmpInst *IC, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy &Builder) { // Only handle integer compares. Also, if this is a vector select, we need a // vector compare. if (!TrueVal->getType()->isIntOrIntVectorTy() || - TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); Value *CmpRHS = IC->getOperand(1); - Value *V; unsigned C1Log; - bool IsEqualZero; bool NeedAnd = false; + CmpInst::Predicate Pred = IC->getPredicate(); if (IC->isEquality()) { if (!match(CmpRHS, m_Zero())) return nullptr; @@ -725,49 +731,49 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, if (!match(CmpLHS, m_And(m_Value(), m_Power2(C1)))) return nullptr; - V = CmpLHS; C1Log = C1->logBase2(); - IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_EQ; - } else if (IC->getPredicate() == ICmpInst::ICMP_SLT || - IC->getPredicate() == ICmpInst::ICMP_SGT) { - // We also need to recognize (icmp slt (trunc (X)), 0) and - // (icmp sgt (trunc (X)), -1). 
- IsEqualZero = IC->getPredicate() == ICmpInst::ICMP_SGT; - if ((IsEqualZero && !match(CmpRHS, m_AllOnes())) || - (!IsEqualZero && !match(CmpRHS, m_Zero()))) - return nullptr; - - if (!match(CmpLHS, m_OneUse(m_Trunc(m_Value(V))))) + } else { + APInt C1; + if (!decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, CmpLHS, C1) || + !C1.isPowerOf2()) return nullptr; - C1Log = CmpLHS->getType()->getScalarSizeInBits() - 1; + C1Log = C1.logBase2(); NeedAnd = true; - } else { - return nullptr; } + Value *Y, *V = CmpLHS; + BinaryOperator *BinOp; const APInt *C2; - bool OrOnTrueVal = false; - bool OrOnFalseVal = match(FalseVal, m_Or(m_Specific(TrueVal), m_Power2(C2))); - if (!OrOnFalseVal) - OrOnTrueVal = match(TrueVal, m_Or(m_Specific(FalseVal), m_Power2(C2))); - - if (!OrOnFalseVal && !OrOnTrueVal) + bool NeedXor; + if (match(FalseVal, m_BinOp(m_Specific(TrueVal), m_Power2(C2)))) { + Y = TrueVal; + BinOp = cast<BinaryOperator>(FalseVal); + NeedXor = Pred == ICmpInst::ICMP_NE; + } else if (match(TrueVal, m_BinOp(m_Specific(FalseVal), m_Power2(C2)))) { + Y = FalseVal; + BinOp = cast<BinaryOperator>(TrueVal); + NeedXor = Pred == ICmpInst::ICMP_EQ; + } else { return nullptr; + } - Value *Y = OrOnFalseVal ? TrueVal : FalseVal; + // Check that 0 on RHS is identity value for this binop. + auto *IdentityC = + ConstantExpr::getBinOpIdentity(BinOp->getOpcode(), BinOp->getType(), + /*AllowRHSConstant*/ true); + if (IdentityC == nullptr || !IdentityC->isNullValue()) + return nullptr; unsigned C2Log = C2->logBase2(); - bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); bool NeedShift = C1Log != C2Log; bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != V->getType()->getScalarSizeInBits(); // Make sure we don't create more instructions than we save. - Value *Or = OrOnFalseVal ? FalseVal : TrueVal; - if ((NeedShift + NeedXor + NeedZExtTrunc) > - (IC->hasOneUse() + Or->hasOneUse())) + if ((NeedShift + NeedXor + NeedZExtTrunc + NeedAnd) > + (IC->hasOneUse() + BinOp->hasOneUse())) return nullptr; if (NeedAnd) { @@ -788,7 +794,7 @@ static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, if (NeedXor) V = Builder.CreateXor(V, *C2); - return Builder.CreateOr(V, Y); + return Builder.CreateBinOp(BinOp->getOpcode(), Y, V); } /// Canonicalize a set or clear of a masked set of constant bits to @@ -870,7 +876,7 @@ static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) { auto *FalseValI = cast<Instruction>(FalseVal); auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"), - *FalseValI); + FalseValI->getIterator()); IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ? 0 : 1, FrY); return IC.replaceInstUsesWith(SI, FalseValI); } @@ -1303,45 +1309,28 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, return nullptr; // InstSimplify already performed this fold if it was possible subject to - // current poison-generating flags. Try the transform again with - // poison-generating flags temporarily dropped. 
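Backing up to the generalized foldSelectICmpAndBinOp above: one concrete instance with `or` as the binop, C1 = 4 and C2 = 16 (so C3 = log2(16) - log2(4) = 2 and the shift is left), checked exhaustively on i8. This covers values only; the pass additionally insists that 0 is the binop's right identity. Standalone sketch:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X < 256; ++X)
    for (uint32_t Y = 0; Y < 256; ++Y) {
      uint32_t Sel = ((X & 4) == 0) ? Y : (Y | 16); // original select form
      uint32_t Folded = Y | ((X & 4) << 2);         // branchless rewrite
      assert(Sel == Folded);
    }
  return 0;
}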
- bool WasNUW = false, WasNSW = false, WasExact = false, WasInBounds = false; - if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(FalseVal)) { - WasNUW = OBO->hasNoUnsignedWrap(); - WasNSW = OBO->hasNoSignedWrap(); - FalseInst->setHasNoUnsignedWrap(false); - FalseInst->setHasNoSignedWrap(false); - } - if (auto *PEO = dyn_cast<PossiblyExactOperator>(FalseVal)) { - WasExact = PEO->isExact(); - FalseInst->setIsExact(false); - } - if (auto *GEP = dyn_cast<GetElementPtrInst>(FalseVal)) { - WasInBounds = GEP->isInBounds(); - GEP->setIsInBounds(false); - } + // current poison-generating flags. Check whether dropping poison-generating + // flags enables the transform. // Try each equivalence substitution possibility. // We have an 'EQ' comparison, so the select's false value will propagate. // Example: // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1 + SmallVector<Instruction *> DropFlags; if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, SQ, - /* AllowRefinement */ false) == TrueVal || + /* AllowRefinement */ false, + &DropFlags) == TrueVal || simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, SQ, - /* AllowRefinement */ false) == TrueVal) { + /* AllowRefinement */ false, + &DropFlags) == TrueVal) { + for (Instruction *I : DropFlags) { + I->dropPoisonGeneratingFlagsAndMetadata(); + Worklist.add(I); + } + return replaceInstUsesWith(Sel, FalseVal); } - // Restore poison-generating flags if the transform did not apply. - if (WasNUW) - FalseInst->setHasNoUnsignedWrap(); - if (WasNSW) - FalseInst->setHasNoSignedWrap(); - if (WasExact) - FalseInst->setIsExact(); - if (WasInBounds) - cast<GetElementPtrInst>(FalseInst)->setIsInBounds(); - return nullptr; } @@ -1506,8 +1495,13 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(ReplacementLow, m_ImmConstant(LowC)) || !match(ReplacementHigh, m_ImmConstant(HighC))) return nullptr; - ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); - ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); + const DataLayout &DL = Sel0.getModule()->getDataLayout(); + ReplacementLow = + ConstantFoldCastOperand(Instruction::SExt, LowC, X->getType(), DL); + ReplacementHigh = + ConstantFoldCastOperand(Instruction::SExt, HighC, X->getType(), DL); + assert(ReplacementLow && ReplacementHigh && + "Constant folding of ImmConstant cannot fail"); } // All good, finally emit the new pattern. @@ -1797,7 +1791,7 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *V = foldSelectZeroOrOnes(ICI, TrueVal, FalseVal, Builder)) return V; - if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectICmpAndBinOp(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectICmpLshrAshr(ICI, TrueVal, FalseVal, Builder)) @@ -2094,9 +2088,8 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { // If the constant is the same after truncation to the smaller type and // extension to the original type, we can narrow the select. 
Type *SelType = Sel.getType(); - Constant *TruncC = ConstantExpr::getTrunc(C, SmallType); - Constant *ExtC = ConstantExpr::getCast(ExtOpcode, TruncC, SelType); - if (ExtC == C && ExtInst->hasOneUse()) { + Constant *TruncC = getLosslessTrunc(C, SmallType, ExtOpcode); + if (TruncC && ExtInst->hasOneUse()) { Value *TruncCVal = cast<Value>(TruncC); if (ExtInst == Sel.getFalseValue()) std::swap(X, TruncCVal); @@ -2107,23 +2100,6 @@ Instruction *InstCombinerImpl::foldSelectExtConst(SelectInst &Sel) { return CastInst::Create(Instruction::CastOps(ExtOpcode), NewSel, SelType); } - // If one arm of the select is the extend of the condition, replace that arm - // with the extension of the appropriate known bool value. - if (Cond == X) { - if (ExtInst == Sel.getTrueValue()) { - // select X, (sext X), C --> select X, -1, C - // select X, (zext X), C --> select X, 1, C - Constant *One = ConstantInt::getTrue(SmallType); - Constant *AllOnesOrOne = ConstantExpr::getCast(ExtOpcode, One, SelType); - return SelectInst::Create(Cond, AllOnesOrOne, C, "", nullptr, &Sel); - } else { - // select X, C, (sext X) --> select X, C, 0 - // select X, C, (zext X) --> select X, C, 0 - Constant *Zero = ConstantInt::getNullValue(SelType); - return SelectInst::Create(Cond, C, Zero, "", nullptr, &Sel); - } - } - return nullptr; } @@ -2561,7 +2537,7 @@ static Instruction *foldSelectToPhiImpl(SelectInst &Sel, BasicBlock *BB, return nullptr; } - Builder.SetInsertPoint(&*BB->begin()); + Builder.SetInsertPoint(BB, BB->begin()); auto *PN = Builder.CreatePHI(Sel.getType(), Inputs.size()); for (auto *Pred : predecessors(BB)) PN->addIncoming(Inputs[Pred], Pred); @@ -2584,6 +2560,61 @@ static Instruction *foldSelectToPhi(SelectInst &Sel, const DominatorTree &DT, return nullptr; } +/// Tries to reduce a pattern that arises when calculating the remainder of the +/// Euclidean division. When the divisor is a power of two and is guaranteed not +/// to be negative, a signed remainder can be folded with a bitwise and. +/// +/// (x % n) < 0 ? (x % n) + n : (x % n) +/// -> x & (n - 1) +static Instruction *foldSelectWithSRem(SelectInst &SI, InstCombinerImpl &IC, + IRBuilderBase &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + ICmpInst::Predicate Pred; + Value *Op, *RemRes, *Remainder; + const APInt *C; + bool TrueIfSigned = false; + + if (!(match(CondVal, m_ICmp(Pred, m_Value(RemRes), m_APInt(C))) && + IC.isSignBitCheck(Pred, *C, TrueIfSigned))) + return nullptr; + + // If the sign bit is not set, we have a SGE/SGT comparison, and the operands + // of the select are inverted. 
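The identity behind this fold, before the matching code below: for a power-of-two divisor, the sign-adjusted srem is exactly the low bits of the dividend. A standalone check for n = 8 (assumes the usual two's-complement representation):

#include <cassert>
#include <cstdint>

int main() {
  const int32_t N = 8; // power-of-two divisor
  for (int32_t x = -1000; x <= 1000; ++x) {
    int32_t Rem = x % N;                      // srem: sign follows x
    int32_t Euclid = Rem < 0 ? Rem + N : Rem; // the select pattern
    assert(Euclid == (x & (N - 1)));          // the bitwise-and rewrite
  }
  return 0;
}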
+  if (!TrueIfSigned)
+    std::swap(TrueVal, FalseVal);
+
+  auto FoldToBitwiseAnd = [&](Value *Remainder) -> Instruction * {
+    Value *Add = Builder.CreateAdd(
+        Remainder, Constant::getAllOnesValue(RemRes->getType()));
+    return BinaryOperator::CreateAnd(Op, Add);
+  };
+
+  // Match the general case:
+  // %rem = srem i32 %x, %n
+  // %cnd = icmp slt i32 %rem, 0
+  // %add = add i32 %rem, %n
+  // %sel = select i1 %cnd, i32 %add, i32 %rem
+  if (match(TrueVal, m_Add(m_Value(RemRes), m_Value(Remainder))) &&
+      match(RemRes, m_SRem(m_Value(Op), m_Specific(Remainder))) &&
+      IC.isKnownToBeAPowerOfTwo(Remainder, /*OrZero*/ true) &&
+      FalseVal == RemRes)
+    return FoldToBitwiseAnd(Remainder);
+
+  // Match the case where one arm has been replaced by the constant 1:
+  // %rem = srem i32 %n, 2
+  // %cnd = icmp slt i32 %rem, 0
+  // %sel = select i1 %cnd, i32 1, i32 %rem
+  if (match(TrueVal, m_One()) &&
+      match(RemRes, m_SRem(m_Value(Op), m_SpecificInt(2))) &&
+      FalseVal == RemRes)
+    return FoldToBitwiseAnd(ConstantInt::get(RemRes->getType(), 2));
+
+  return nullptr;
+}
+
 static Value *foldSelectWithFrozenICmp(SelectInst &Sel,
                                        InstCombiner::BuilderTy &Builder) {
   FreezeInst *FI = dyn_cast<FreezeInst>(Sel.getCondition());
   if (!FI)
     return nullptr;
@@ -2860,8 +2891,15 @@ static Instruction *foldNestedSelects(SelectInst &OuterSelVal,
     std::swap(InnerSel.TrueVal, InnerSel.FalseVal);
 
   Value *AltCond = nullptr;
-  auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) {
-    return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond)));
+  auto matchOuterCond = [OuterSel, IsAndVariant, &AltCond](auto m_InnerCond) {
+    // An unsimplified select condition can match both LogicalAnd and LogicalOr
+    // (select true, true, false). Since below we assume that LogicalAnd implies
+    // that InnerSel matches the FVal, and vice versa for LogicalOr, we can't
+    // match the alternative pattern here.
+    return IsAndVariant ?
match(OuterSel.Cond, + m_c_LogicalAnd(m_InnerCond, m_Value(AltCond))) + : match(OuterSel.Cond, + m_c_LogicalOr(m_InnerCond, m_Value(AltCond))); }; // Finally, match the condition that was driving the outermost `select`, @@ -3024,31 +3062,37 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) return replaceOperand(SI, 0, A); - // select a, (select ~a, true, b), false -> select a, b, false - if (match(TrueVal, m_c_LogicalOr(m_Not(m_Specific(CondVal)), m_Value(B))) && - match(FalseVal, m_Zero())) - return replaceOperand(SI, 1, B); - // select a, true, (select ~a, b, false) -> select a, true, b - if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Specific(CondVal)), m_Value(B))) && - match(TrueVal, m_One())) - return replaceOperand(SI, 2, B); // ~(A & B) & (A | B) --> A ^ B if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), m_c_LogicalOr(m_Deferred(A), m_Deferred(B))))) return BinaryOperator::CreateXor(A, B); - // select (~a | c), a, b -> and a, (or c, freeze(b)) - if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && - CondVal->hasOneUse()) { - FalseVal = Builder.CreateFreeze(FalseVal); - return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + // select (~a | c), a, b -> select a, (select c, true, b), false + if (match(CondVal, + m_OneUse(m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))))) { + Value *OrV = Builder.CreateSelect(C, One, FalseVal); + return SelectInst::Create(TrueVal, OrV, Zero); + } + // select (c & b), a, b -> select b, (select ~c, true, a), false + if (match(CondVal, m_OneUse(m_c_And(m_Value(C), m_Specific(FalseVal))))) { + if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) { + Value *OrV = Builder.CreateSelect(NotC, One, TrueVal); + return SelectInst::Create(FalseVal, OrV, Zero); + } + } + // select (a | c), a, b -> select a, true, (select ~c, b, false) + if (match(CondVal, m_OneUse(m_c_Or(m_Specific(TrueVal), m_Value(C))))) { + if (Value *NotC = getFreelyInverted(C, C->hasOneUse(), &Builder)) { + Value *AndV = Builder.CreateSelect(NotC, FalseVal, Zero); + return SelectInst::Create(TrueVal, One, AndV); + } } - // select (~c & b), a, b -> and b, (or freeze(a), c) - if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && - CondVal->hasOneUse()) { - TrueVal = Builder.CreateFreeze(TrueVal); - return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + // select (c & ~b), a, b -> select b, true, (select c, a, false) + if (match(CondVal, + m_OneUse(m_c_And(m_Value(C), m_Not(m_Specific(FalseVal)))))) { + Value *AndV = Builder.CreateSelect(C, TrueVal, Zero); + return SelectInst::Create(FalseVal, One, AndV); } if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { @@ -3057,7 +3101,7 @@ Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { Value *Op1 = IsAnd ? 
TrueVal : FalseVal; if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); - InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); + InsertNewInstBefore(FI, cast<Instruction>(Y->getUser())->getIterator()); replaceUse(*Y, FI); return replaceInstUsesWith(SI, Op1); } @@ -3272,6 +3316,31 @@ static Instruction *foldBitCeil(SelectInst &SI, IRBuilderBase &Builder) { Masked); } +bool InstCombinerImpl::fmulByZeroIsZero(Value *MulVal, FastMathFlags FMF, + const Instruction *CtxI) const { + KnownFPClass Known = computeKnownFPClass(MulVal, FMF, fcNegative, CtxI); + + return Known.isKnownNeverNaN() && Known.isKnownNeverInfinity() && + (FMF.noSignedZeros() || Known.signBitIsZeroOrNaN()); +} + +static bool matchFMulByZeroIfResultEqZero(InstCombinerImpl &IC, Value *Cmp0, + Value *Cmp1, Value *TrueVal, + Value *FalseVal, Instruction &CtxI, + bool SelectIsNSZ) { + Value *MulRHS; + if (match(Cmp1, m_PosZeroFP()) && + match(TrueVal, m_c_FMul(m_Specific(Cmp0), m_Value(MulRHS)))) { + FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags(); + // nsz must be on the select, it must be ignored on the multiply. We + // need nnan and ninf on the multiply for the other value. + FMF.setNoSignedZeros(SelectIsNSZ); + return IC.fmulByZeroIsZero(MulRHS, FMF, &CtxI); + } + + return false; +} + Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); @@ -3303,28 +3372,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { ConstantInt::getFalse(CondType), SQ, /* AllowRefinement */ true)) return replaceOperand(SI, 2, S); - - // Handle patterns involving sext/zext + not explicitly, - // as simplifyWithOpReplaced() only looks past one instruction. - Value *NotCond; - - // select a, sext(!a), b -> select !a, b, 0 - // select a, zext(!a), b -> select !a, b, 0 - if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, FalseVal, - Constant::getNullValue(SelType)); - - // select a, b, zext(!a) -> select !a, 1, b - if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal); - - // select a, b, sext(!a) -> select !a, -1, b - if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond), - m_Not(m_Specific(CondVal)))))) - return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType), - TrueVal); } if (Instruction *R = foldSelectOfBools(SI)) @@ -3362,7 +3409,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + auto *SIFPOp = dyn_cast<FPMathOperator>(&SI); + if (auto *FCmp = dyn_cast<FCmpInst>(CondVal)) { + FCmpInst::Predicate Pred = FCmp->getPredicate(); Value *Cmp0 = FCmp->getOperand(0), *Cmp1 = FCmp->getOperand(1); // Are we selecting a value based on a comparison of the two values? if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || @@ -3372,7 +3422,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // // e.g. // (X ugt Y) ? X : Y -> (X ole Y) ? Y : X - if (FCmp->hasOneUse() && FCmpInst::isUnordered(FCmp->getPredicate())) { + if (FCmp->hasOneUse() && FCmpInst::isUnordered(Pred)) { FCmpInst::Predicate InvPred = FCmp->getInversePredicate(); IRBuilder<>::FastMathFlagGuard FMFG(Builder); // FIXME: The FMF should propagate from the select, not the fcmp. 
@@ -3383,14 +3433,47 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
       return replaceInstUsesWith(SI, NewSel);
     }
   }
+
+  if (SIFPOp) {
+    // Fold out scale-if-equals-zero pattern.
+    //
+    // This pattern appears in code with denormal range checks after it's
+    // assumed denormals are treated as zero. This drops a canonicalization.
+
+    // TODO: Could relax the signed zero logic. We just need to know the sign
+    // of the result matches (fmul x, y has the same sign as x).
+    //
+    // TODO: Handle always-canonicalizing variant that selects some value or 1
+    // scaling factor in the fmul visitor.
+
+    // TODO: Handle ldexp too
+
+    Value *MatchCmp0 = nullptr;
+    Value *MatchCmp1 = nullptr;
+
+    // (select (fcmp [ou]eq x, 0.0), (fmul x, K), x) => x
+    // (select (fcmp [ou]ne x, 0.0), x, (fmul x, K)) => x
+    if (Pred == CmpInst::FCMP_OEQ || Pred == CmpInst::FCMP_UEQ) {
+      MatchCmp0 = FalseVal;
+      MatchCmp1 = TrueVal;
+    } else if (Pred == CmpInst::FCMP_ONE || Pred == CmpInst::FCMP_UNE) {
+      MatchCmp0 = TrueVal;
+      MatchCmp1 = FalseVal;
+    }
+
+    if (Cmp0 == MatchCmp0 &&
+        matchFMulByZeroIfResultEqZero(*this, Cmp0, Cmp1, MatchCmp1, MatchCmp0,
+                                      SI, SIFPOp->hasNoSignedZeros()))
+      return replaceInstUsesWith(SI, Cmp0);
+  }
   }
 
-  if (isa<FPMathOperator>(SI)) {
+  if (SIFPOp) {
     // TODO: Try to forward-propagate FMF from select arms to the select.
 
     // Canonicalize select of FP values where NaN and -0.0 are not valid as
     // minnum/maxnum intrinsics.
-    if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) {
+    if (SIFPOp->hasNoNaNs() && SIFPOp->hasNoSignedZeros()) {
       Value *X, *Y;
       if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y))))
         return replaceInstUsesWith(
@@ -3430,6 +3513,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   if (Instruction *I = foldSelectExtConst(SI))
     return I;
 
+  if (Instruction *I = foldSelectWithSRem(SI, *this, Builder))
+    return I;
+
   // Fold (select C, (gep Ptr, Idx), Ptr) -> (gep Ptr, (select C, Idx, 0))
   // Fold (select C, Ptr, (gep Ptr, Idx)) -> (gep Ptr, (select C, 0, Idx))
   auto SelectGepWithBase = [&](GetElementPtrInst *Gep, Value *Base,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 89dad455f015..b7958978c450 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -136,9 +136,14 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts(
   assert(IdenticalShOpcodes && "Should not get here with different shifts.");
 
-  // All good, we can do this fold.
-  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+  if (NewShAmt->getType() != X->getType()) {
+    NewShAmt = ConstantFoldCastOperand(Instruction::ZExt, NewShAmt,
+                                       X->getType(), SQ.DL);
+    if (!NewShAmt)
+      return nullptr;
+  }
 
+  // All good, we can do this fold.
   BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
 
   // The flags can only be propagated if there wasn't a trunc.
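Returning to the scale-if-equals-zero select fold above: the nnan/ninf requirement on the multiply is what guarantees x * K == x when x is zero, since an infinite or NaN factor would break it. A minimal illustration (standalone, not part of the patch):

#include <cassert>
#include <cmath>

int main() {
  double K = 42.5; // stands in for a known-finite, non-NaN scale factor
  double x = 0.0;
  double Sel = (x == 0.0) ? x * K : x; // the select being folded away
  assert(Sel == x);                    // 0.0 * K == 0.0 for finite K
  // Why ninf/nnan are required: an infinite factor turns zero into NaN.
  assert(std::isnan(0.0 * INFINITY));
  return 0;
}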
@@ -245,7 +250,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, SumOfShAmts = Constant::replaceUndefsWith( SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), ExtendedTy->getScalarSizeInBits())); - auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); + auto *ExtendedSumOfShAmts = ConstantFoldCastOperand( + Instruction::ZExt, SumOfShAmts, ExtendedTy, Q.DL); + if (!ExtendedSumOfShAmts) + return nullptr; + // And compute the mask as usual: ~(-1 << (SumOfShAmts)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); auto *ExtendedInvertedMask = @@ -278,16 +287,22 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, ShAmtsDiff = Constant::replaceUndefsWith( ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -WidestTyBitWidth)); - auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( + auto *ExtendedNumHighBitsToClear = ConstantFoldCastOperand( + Instruction::ZExt, ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(), WidestTyBitWidth, /*isSigned=*/false), ShAmtsDiff), - ExtendedTy); + ExtendedTy, Q.DL); + if (!ExtendedNumHighBitsToClear) + return nullptr; + // And compute the mask as usual: (-1 l>> (NumHighBitsToClear)) auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy); - NewMask = - ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear); + NewMask = ConstantFoldBinaryOpOperands(Instruction::LShr, ExtendedAllOnes, + ExtendedNumHighBitsToClear, Q.DL); + if (!NewMask) + return nullptr; } else return nullptr; // Don't know anything about this pattern. @@ -545,8 +560,8 @@ static bool canEvaluateShiftedShift(unsigned OuterShAmt, bool IsOuterShl, /// this succeeds, getShiftedValue() will be called to produce the value. static bool canEvaluateShifted(Value *V, unsigned NumBits, bool IsLeftShift, InstCombinerImpl &IC, Instruction *CxtI) { - // We can always evaluate constants shifted. - if (isa<Constant>(V)) + // We can always evaluate immediate constants. + if (match(V, m_ImmConstant())) return true; Instruction *I = dyn_cast<Instruction>(V); @@ -709,13 +724,13 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, case Instruction::Mul: { assert(!isLeftShift && "Unexpected shift direction!"); auto *Neg = BinaryOperator::CreateNeg(I->getOperand(0)); - IC.InsertNewInstWith(Neg, *I); + IC.InsertNewInstWith(Neg, I->getIterator()); unsigned TypeWidth = I->getType()->getScalarSizeInBits(); APInt Mask = APInt::getLowBitsSet(TypeWidth, TypeWidth - NumBits); auto *And = BinaryOperator::CreateAnd(Neg, ConstantInt::get(I->getType(), Mask)); And->takeName(I); - return IC.InsertNewInstWith(And, *I); + return IC.InsertNewInstWith(And, I->getIterator()); } } } @@ -745,7 +760,7 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, // (C2 >> X) >> C1 --> (C2 >> C1) >> X Constant *C2; Value *X; - if (match(Op0, m_BinOp(I.getOpcode(), m_Constant(C2), m_Value(X)))) + if (match(Op0, m_BinOp(I.getOpcode(), m_ImmConstant(C2), m_Value(X)))) return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); @@ -928,6 +943,60 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { return new ZExtInst(Overflow, Ty); } +// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits. +static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) { + assert(I.isShift() && "Expected a shift as input"); + // We already have all the flags. 
+ if (I.getOpcode() == Instruction::Shl) {
+ if (I.hasNoUnsignedWrap() && I.hasNoSignedWrap())
+ return false;
+ } else {
+ if (I.isExact())
+ return false;
+
+ // shr (shl X, Y), Y
+ if (match(I.getOperand(0), m_Shl(m_Value(), m_Specific(I.getOperand(1))))) {
+ I.setIsExact();
+ return true;
+ }
+ }
+
+ // Compute what we know about shift count.
+ KnownBits KnownCnt = computeKnownBits(I.getOperand(1), /* Depth */ 0, Q);
+ unsigned BitWidth = KnownCnt.getBitWidth();
+ // Since shift produces a poison value if RHS is equal to or larger than the
+ // bit width, we can safely assume that RHS is less than the bit width.
+ uint64_t MaxCnt = KnownCnt.getMaxValue().getLimitedValue(BitWidth - 1);
+
+ KnownBits KnownAmt = computeKnownBits(I.getOperand(0), /* Depth */ 0, Q);
+ bool Changed = false;
+
+ if (I.getOpcode() == Instruction::Shl) {
+ // If we have at least as many leading zeros as the maximum shift count, we
+ // have nuw.
+ if (!I.hasNoUnsignedWrap() && MaxCnt <= KnownAmt.countMinLeadingZeros()) {
+ I.setHasNoUnsignedWrap();
+ Changed = true;
+ }
+ // If we have more sign bits than the maximum shift count, we have nsw.
+ if (!I.hasNoSignedWrap()) {
+ if (MaxCnt < KnownAmt.countMinSignBits() ||
+ MaxCnt < ComputeNumSignBits(I.getOperand(0), Q.DL, /*Depth*/ 0, Q.AC,
+ Q.CxtI, Q.DT)) {
+ I.setHasNoSignedWrap();
+ Changed = true;
+ }
+ }
+ return Changed;
+ }
+
+ // If we have at least as many trailing zeros as the maximum shift count,
+ // then the shift is exact.
+ Changed = MaxCnt <= KnownAmt.countMinTrailingZeros();
+ I.setIsExact(Changed);
+
+ return Changed;
+}
+
Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
const SimplifyQuery Q = SQ.getWithInstruction(&I);
@@ -976,7 +1045,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>?,exact C1) << C --> X << (C - C1)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
return NewShl;
}
@@ -997,7 +1070,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
// If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C)
Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt);
auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff);
- NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap());
+ NewShl->setHasNoUnsignedWrap(
+ I.hasNoUnsignedWrap() ||
+ (ShrAmt &&
+ cast<Instruction>(Op0)->getOpcode() == Instruction::LShr &&
+ I.hasNoSignedWrap()));
NewShl->setHasNoSignedWrap(I.hasNoSignedWrap());
Builder.Insert(NewShl);
APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
@@ -1108,22 +1185,11 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
Value *NewShift = Builder.CreateShl(X, Op1);
return BinaryOperator::CreateSub(NewLHS, NewShift);
}
-
- // If the shifted-out value is known-zero, then this is a NUW shift.
- if (!I.hasNoUnsignedWrap() &&
- MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0,
- &I)) {
- I.setHasNoUnsignedWrap();
- return &I;
- }
-
- // If the shifted-out value is all signbits, then this is a NSW shift.
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) {
- I.setHasNoSignedWrap();
- return &I;
- }
}
+ if (setShiftFlags(I, Q))
+ return &I;
+
// Transform (x >> y) << y to x & (-1 << y)
// Valid for any type of right-shift.
Value *X; @@ -1161,15 +1227,6 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Value *NegX = Builder.CreateNeg(X, "neg"); return BinaryOperator::CreateAnd(NegX, X); } - - // The only way to shift out the 1 is with an over-shift, so that would - // be poison with or without "nuw". Undef is excluded because (undef << X) - // is not undef (it is zero). - Constant *ConstantOne = cast<Constant>(Op0); - if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { - I.setHasNoUnsignedWrap(); - return &I; - } } return nullptr; @@ -1235,9 +1292,10 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { unsigned ShlAmtC = C1->getZExtValue(); Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) + // (X <<nuw C1) >>u C --> X <<nuw/nsw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); + NewShl->setHasNoSignedWrap(ShAmtC > 0); return NewShl; } if (Op0->hasOneUse()) { @@ -1370,12 +1428,13 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { if (Op0->hasOneUse()) { APInt NewMulC = MulC->lshr(ShAmtC); // if c is divisible by (1 << ShAmtC): - // lshr (mul nuw x, MulC), ShAmtC -> mul nuw x, (MulC >> ShAmtC) + // lshr (mul nuw x, MulC), ShAmtC -> mul nuw nsw x, (MulC >> ShAmtC) if (MulC->eq(NewMulC.shl(ShAmtC))) { auto *NewMul = BinaryOperator::CreateNUWMul(X, ConstantInt::get(Ty, NewMulC)); - BinaryOperator *OrigMul = cast<BinaryOperator>(Op0); - NewMul->setHasNoSignedWrap(OrigMul->hasNoSignedWrap()); + assert(ShAmtC != 0 && + "lshr X, 0 should be handled by simplifyLShrInst."); + NewMul->setHasNoSignedWrap(true); return NewMul; } } @@ -1414,15 +1473,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *And = Builder.CreateAnd(BoolX, BoolY); return new ZExtInst(And, Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. - if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Transform (x << y) >> y to x & (-1 >> y) if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); @@ -1581,15 +1637,12 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { if (match(Op0, m_OneUse(m_NSWSub(m_Value(X), m_Value(Y))))) return new SExtInst(Builder.CreateICmpSLT(X, Y), Ty); } - - // If the shifted-out value is known-zero, then this is an exact shift. - if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { - I.setIsExact(); - return &I; - } } + const SimplifyQuery Q = SQ.getWithInstruction(&I); + if (setShiftFlags(I, Q)) + return &I; + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` // as the pattern to splat the lowest bit. // FIXME: iff X is already masked, we don't need the one-use check. 
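The new `setShiftFlags` helper in this file derives flags from known bits: `shl` gets `nuw` when the operand has at least as many leading zeros as the largest possible shift amount, `nsw` when it has strictly more sign bits, and `lshr`/`ashr` get `exact` when there are at least that many trailing zeros. A brute-force C++ check of those three invariants at bit width 8 (a sketch of the math, not the LLVM implementation; it assumes two's-complement wrapping and arithmetic right shift, which C++20 guarantees):

```cpp
#include <cassert>
#include <cstdint>

static int leadingZeros(uint8_t V) {
  int N = 0;
  for (int B = 7; B >= 0 && !((V >> B) & 1); --B) ++N;
  return N;
}
static int trailingZeros(uint8_t V) {
  int N = 0;
  for (int B = 0; B < 8 && !((V >> B) & 1); ++B) ++N;
  return N;
}
// Leading bits equal to the sign bit, including the sign bit itself.
static int signBits(uint8_t V) {
  int S = (V >> 7) & 1;
  int N = 0;
  for (int B = 7; B >= 0 && ((V >> B) & 1) == S; --B) ++N;
  return N;
}

int main() {
  for (int X = 0; X < 256; ++X)
    for (int C = 0; C < 8; ++C) { // shift amounts >= width would be poison
      uint8_t V = (uint8_t)X;
      // nuw: enough leading zeros means shl pushes out no set bits.
      if (C <= leadingZeros(V)) {
        uint8_t Wrapped = (uint8_t)(V << C);
        assert((uint8_t)(Wrapped >> C) == V);
      }
      // nsw: strictly more sign bits than the shift count means the signed
      // value survives; an arithmetic shift right undoes the shl.
      if (C < signBits(V)) {
        int8_t Orig = (int8_t)V;
        int8_t Shifted = (int8_t)(uint8_t)(V << C);
        assert((int8_t)(Shifted >> C) == Orig);
      }
      // exact: enough trailing zeros means the right shift drops only zeros.
      if (C <= trailingZeros(V))
        assert((uint8_t)((uint8_t)(V >> C) << C) == V);
    }
}
```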
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 00eece9534b0..046ce9d1207e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -24,6 +24,12 @@ using namespace llvm::PatternMatch; #define DEBUG_TYPE "instcombine" +static cl::opt<bool> + VerifyKnownBits("instcombine-verify-known-bits", + cl::desc("Verify that computeKnownBits() and " + "SimplifyDemandedBits() are consistent"), + cl::Hidden, cl::init(false)); + /// Check to see if the specified operand of the specified instruction is a /// constant integer. If so, check to see if there are any bits set in the /// constant that are not demanded. If so, shrink the constant and return true. @@ -48,15 +54,20 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, return true; } +/// Returns the bitwidth of the given scalar or pointer type. For vector types, +/// returns the element type's bitwidth. +static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { + if (unsigned BitWidth = Ty->getScalarSizeInBits()) + return BitWidth; + return DL.getPointerTypeSizeInBits(Ty); +} /// Inst is an integer instruction that SimplifyDemandedBits knows about. See if /// the instruction has any properties that allow us to simplify its operands. -bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { - unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); - KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnes(BitWidth)); - +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst, + KnownBits &Known) { + APInt DemandedMask(APInt::getAllOnes(Known.getBitWidth())); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); if (!V) return false; @@ -65,6 +76,13 @@ bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { return true; } +/// Inst is an integer instruction that SimplifyDemandedBits knows about. See if +/// the instruction has any properties that allow us to simplify its operands. +bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { + KnownBits Known(getBitWidth(Inst.getType(), DL)); + return SimplifyDemandedInstructionBits(Inst, Known); +} + /// This form of SimplifyDemandedBits simplifies the specified instruction /// operand if possible, updating it in place. It returns true if it made any /// change and false otherwise. @@ -95,8 +113,8 @@ bool InstCombinerImpl::SimplifyDemandedBits(Instruction *I, unsigned OpNo, /// expression. /// Known.One and Known.Zero always follow the invariant that: /// Known.One & Known.Zero == 0. -/// That is, a bit can't be both 1 and 0. Note that the bits in Known.One and -/// Known.Zero may only be accurate for those bits set in DemandedMask. Note +/// That is, a bit can't be both 1 and 0. The bits in Known.One and Known.Zero +/// are accurate even for bits not in DemandedMask. Note /// also that the bitwidth of V, DemandedMask, Known.Zero and Known.One must all /// be the same. /// @@ -143,7 +161,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return SimplifyMultipleUseDemandedBits(I, DemandedMask, Known, Depth, CxtI); KnownBits LHSKnown(BitWidth), RHSKnown(BitWidth); - // If this is the root being simplified, allow it to have multiple uses, // just set the DemandedMask to all bits so that we can try to simplify the // operands. 
This allows visitTruncInst (for example) to simplify the
@@ -196,7 +213,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -220,13 +237,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// If either the LHS or the RHS are One, the result is One.
if (SimplifyDemandedBits(I, 1, DemandedMask, RHSKnown, Depth + 1) ||
SimplifyDemandedBits(I, 0, DemandedMask & ~RHSKnown.One, LHSKnown,
- Depth + 1))
+ Depth + 1)) {
+ // Disjoint flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -244,6 +264,16 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (ShrinkDemandedConstant(I, 1, DemandedMask))
return I;
+ // Infer disjoint flag if no common bits are set.
+ if (!cast<PossiblyDisjointInst>(I)->isDisjoint()) {
+ WithCache<const Value *> LHSCache(I->getOperand(0), LHSKnown),
+ RHSCache(I->getOperand(1), RHSKnown);
+ if (haveNoCommonBitsSet(LHSCache, RHSCache, SQ.getWithInstruction(I))) {
+ cast<PossiblyDisjointInst>(I)->setIsDisjoint(true);
+ return I;
+ }
+ }
+
break;
}
case Instruction::Xor: {
@@ -265,7 +295,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown,
- Depth, DL, &AC, CxtI, &DT);
+ Depth, SQ.getWithInstruction(CxtI));
// If the client is only demanding bits that we know, return the known
// constant.
@@ -284,9 +314,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// e.g. (A & C1)^(B & C2) -> (A & C1)|(B & C2) iff C1&C2 == 0
if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.Zero)) {
Instruction *Or =
- BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1),
- I->getName());
- return InsertNewInstWith(Or, *I);
+ BinaryOperator::CreateOr(I->getOperand(0), I->getOperand(1));
+ if (DemandedMask.isAllOnes())
+ cast<PossiblyDisjointInst>(Or)->setIsDisjoint(true);
+ Or->takeName(I);
+ return InsertNewInstWith(Or, I->getIterator());
}
// If all of the demanded bits on one side are known, and all of the set
@@ -298,7 +330,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Constant *AndC = Constant::getIntegerValue(VTy, ~RHSKnown.One & DemandedMask);
Instruction *And = BinaryOperator::CreateAnd(I->getOperand(0), AndC);
- return InsertNewInstWith(And, *I);
+ return InsertNewInstWith(And, I->getIterator());
}
// If the RHS is a constant, see if we can change it.
Don't alter a -1 @@ -330,11 +362,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Constant *AndC = ConstantInt::get(VTy, NewMask & AndRHS->getValue()); Instruction *NewAnd = BinaryOperator::CreateAnd(I->getOperand(0), AndC); - InsertNewInstWith(NewAnd, *I); + InsertNewInstWith(NewAnd, I->getIterator()); Constant *XorC = ConstantInt::get(VTy, NewMask & XorRHS->getValue()); Instruction *NewXor = BinaryOperator::CreateXor(NewAnd, XorC); - return InsertNewInstWith(NewXor, *I); + return InsertNewInstWith(NewXor, I->getIterator()); } } break; @@ -411,36 +443,21 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); KnownBits InputKnown(SrcBitWidth); - if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, InputDemandedMask, InputKnown, Depth + 1)) { + // For zext nneg, we may have dropped the instruction which made the + // input non-negative. + I->dropPoisonGeneratingFlags(); return I; + } assert(InputKnown.getBitWidth() == SrcBitWidth && "Src width changed?"); + if (I->getOpcode() == Instruction::ZExt && I->hasNonNeg() && + !InputKnown.isNegative()) + InputKnown.makeNonNegative(); Known = InputKnown.zextOrTrunc(BitWidth); - assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - break; - } - case Instruction::BitCast: - if (!I->getOperand(0)->getType()->isIntOrIntVectorTy()) - return nullptr; // vector->int or fp->int? - - if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { - if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { - if (isa<ScalableVectorType>(DstVTy) || - isa<ScalableVectorType>(SrcVTy) || - cast<FixedVectorType>(DstVTy)->getNumElements() != - cast<FixedVectorType>(SrcVTy)->getNumElements()) - // Don't touch a bitcast between vectors of different element counts. - return nullptr; - } else - // Don't touch a scalar-to-vector bitcast. - return nullptr; - } else if (I->getOperand(0)->getType()->isVectorTy()) - // Don't touch a vector-to-scalar bitcast. - return nullptr; - if (SimplifyDemandedBits(I, 0, DemandedMask, Known, Depth + 1)) - return I; assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; + } case Instruction::SExt: { // Compute the bits in the result that are not present in the input. unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); @@ -461,8 +478,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (InputKnown.isNonNegative() || DemandedMask.getActiveBits() <= SrcBitWidth) { // Convert to ZExt cast. 
- CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy, I->getName());
- return InsertNewInstWith(NewCast, *I);
+ CastInst *NewCast = new ZExtInst(I->getOperand(0), VTy);
+ NewCast->takeName(I);
+ return InsertNewInstWith(NewCast, I->getIterator());
}
// If the sign bit of the input is known set or clear, then we know the
@@ -586,7 +604,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (match(I->getOperand(1), m_APInt(C)) && C->countr_zero() == CTZ) {
Constant *ShiftC = ConstantInt::get(VTy, CTZ);
Instruction *Shl = BinaryOperator::CreateShl(I->getOperand(0), ShiftC);
- return InsertNewInstWith(Shl, *I);
+ return InsertNewInstWith(Shl, I->getIterator());
}
}
// For a squared value "X * X", the bottom 2 bits are 0 and X[0] because:
@@ -595,7 +613,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (I->getOperand(0) == I->getOperand(1) && DemandedMask.ult(4)) {
Constant *One = ConstantInt::get(VTy, 1);
Instruction *And1 = BinaryOperator::CreateAnd(I->getOperand(0), One);
- return InsertNewInstWith(And1, *I);
+ return InsertNewInstWith(And1, I->getIterator());
}
computeKnownBits(I, Known, Depth, CxtI);
@@ -624,10 +642,12 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask.countr_zero() >= ShiftAmt &&
match(I->getOperand(0), m_LShr(m_ImmConstant(C), m_Value(X)))) {
Constant *LeftShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
- Constant *NewC = ConstantExpr::getShl(C, LeftShiftAmtC);
- if (ConstantExpr::getLShr(NewC, LeftShiftAmtC) == C) {
+ Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::Shl, C,
+ LeftShiftAmtC, DL);
+ if (ConstantFoldBinaryOpOperands(Instruction::LShr, NewC, LeftShiftAmtC,
+ DL) == C) {
Instruction *Lshr = BinaryOperator::CreateLShr(NewC, X);
- return InsertNewInstWith(Lshr, *I);
+ return InsertNewInstWith(Lshr, I->getIterator());
}
}
@@ -688,24 +708,23 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
Constant *C;
if (match(I->getOperand(0), m_Shl(m_ImmConstant(C), m_Value(X)))) {
Constant *RightShiftAmtC = ConstantInt::get(VTy, ShiftAmt);
- Constant *NewC = ConstantExpr::getLShr(C, RightShiftAmtC);
- if (ConstantExpr::getShl(NewC, RightShiftAmtC) == C) {
+ Constant *NewC = ConstantFoldBinaryOpOperands(Instruction::LShr, C,
+ RightShiftAmtC, DL);
+ if (ConstantFoldBinaryOpOperands(Instruction::Shl, NewC,
+ RightShiftAmtC, DL) == C) {
Instruction *Shl = BinaryOperator::CreateShl(NewC, X);
- return InsertNewInstWith(Shl, *I);
+ return InsertNewInstWith(Shl, I->getIterator());
}
}
}
// Unsigned shift right.
APInt DemandedMaskIn(DemandedMask.shl(ShiftAmt));
-
- // If the shift is exact, then it does demand the low bits (and knows that
- // they are zero).
- if (cast<LShrOperator>(I)->isExact())
- DemandedMaskIn.setLowBits(ShiftAmt);
-
- if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+ // exact flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
Known.Zero.lshrInPlace(ShiftAmt);
Known.One.lshrInPlace(ShiftAmt);
@@ -733,7 +752,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
// Perform the logical shift right.
Instruction *NewVal = BinaryOperator::CreateLShr(
I->getOperand(0), I->getOperand(1), I->getName());
- return InsertNewInstWith(NewVal, *I);
+ return InsertNewInstWith(NewVal, I->getIterator());
}
const APInt *SA;
@@ -747,13 +766,11 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
if (DemandedMask.countl_zero() <= ShiftAmt)
DemandedMaskIn.setSignBit();
- // If the shift is exact, then it does demand the low bits (and knows that
- // they are zero).
- if (cast<AShrOperator>(I)->isExact())
- DemandedMaskIn.setLowBits(ShiftAmt);
-
- if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1))
+ if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) {
+ // exact flag may no longer hold.
+ I->dropPoisonGeneratingFlags();
return I;
+ }
assert(!Known.hasConflict() && "Bits known to be one AND zero?");
// Compute the new bits that are at the top now plus sign bits.
@@ -770,7 +787,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0),
I->getOperand(1));
LShr->setIsExact(cast<BinaryOperator>(I)->isExact());
- return InsertNewInstWith(LShr, *I);
+ LShr->takeName(I);
+ return InsertNewInstWith(LShr, I->getIterator());
} else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one.
Known.One |= HighBits;
}
@@ -867,7 +885,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
match(II->getArgOperand(0), m_Not(m_Value(X)))) {
Function *Ctpop = Intrinsic::getDeclaration(
II->getModule(), Intrinsic::ctpop, VTy);
- return InsertNewInstWith(CallInst::Create(Ctpop, {X}), *I);
+ return InsertNewInstWith(CallInst::Create(Ctpop, {X}), I->getIterator());
}
break;
}
@@ -894,10 +912,52 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask,
NewVal = BinaryOperator::CreateShl(
II->getArgOperand(0), ConstantInt::get(VTy, NTZ - NLZ));
NewVal->takeName(I);
- return InsertNewInstWith(NewVal, *I);
+ return InsertNewInstWith(NewVal, I->getIterator());
}
break;
}
+ case Intrinsic::ptrmask: {
+ unsigned MaskWidth = I->getOperand(1)->getType()->getScalarSizeInBits();
+ RHSKnown = KnownBits(MaskWidth);
+ // If either the LHS or the RHS are Zero, the result is zero.
+ if (SimplifyDemandedBits(I, 0, DemandedMask, LHSKnown, Depth + 1) ||
+ SimplifyDemandedBits(
+ I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth),
+ RHSKnown, Depth + 1))
+ return I;
+
+ // TODO: Should be 1-extend
+ RHSKnown = RHSKnown.anyextOrTrunc(BitWidth);
+ assert(!RHSKnown.hasConflict() && "Bits known to be one AND zero?");
+ assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?");
+
+ Known = LHSKnown & RHSKnown;
+ KnownBitsComputed = true;
+
+ // If the client is only demanding bits we know to be zero, return
+ // `llvm.ptrmask(p, 0)`. We can't return `null` here due to pointer
+ // provenance, but making the mask zero will be easily optimizable in
+ // the backend.
+ if (DemandedMask.isSubsetOf(Known.Zero) &&
+ !match(I->getOperand(1), m_Zero()))
+ return replaceOperand(
+ *I, 1, Constant::getNullValue(I->getOperand(1)->getType()));
+
+ // Mask in demanded space does nothing.
+ // NOTE: We may have attributes associated with the return value of the
+ // llvm.ptrmask intrinsic that will be lost when we just return the
+ // operand. We should try to preserve them.
+ if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero))
+ return I->getOperand(0);
+
+ // If the RHS is a constant, see if we can simplify it.
+ if (ShrinkDemandedConstant( + I, 1, (DemandedMask & ~LHSKnown.Zero).zextOrTrunc(MaskWidth))) + return I; + + break; + } + case Intrinsic::fshr: case Intrinsic::fshl: { const APInt *SA; @@ -918,7 +978,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, SimplifyDemandedBits(I, 1, DemandedMaskRHS, RHSKnown, Depth + 1)) return I; } else { // fshl is a rotate - // Avoid converting rotate into funnel shift. + // Avoid converting rotate into funnel shift. // Only simplify if one operand is constant. LHSKnown = computeKnownBits(I->getOperand(0), Depth + 1, I); if (DemandedMaskLHS.isSubsetOf(LHSKnown.Zero | LHSKnown.One) && @@ -982,10 +1042,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } + if (V->getType()->isPointerTy()) { + Align Alignment = V->getPointerAlignment(DL); + Known.Zero.setLowBits(Log2(Alignment)); + } + // If the client is only demanding bits that we know, return the known - // constant. - if (DemandedMask.isSubsetOf(Known.Zero|Known.One)) + // constant. We can't directly simplify pointers as a constant because of + // pointer provenance. + // TODO: We could return `(inttoptr const)` for pointers. + if (!V->getType()->isPointerTy() && DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(VTy, Known.One); + + if (VerifyKnownBits) { + KnownBits ReferenceKnown = computeKnownBits(V, Depth, CxtI); + if (Known != ReferenceKnown) { + errs() << "Mismatched known bits for " << *V << " in " + << I->getFunction()->getName() << "\n"; + errs() << "computeKnownBits(): " << ReferenceKnown << "\n"; + errs() << "SimplifyDemandedBits(): " << Known << "\n"; + std::abort(); + } + } + return nullptr; } @@ -1009,8 +1088,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::And: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown & RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1029,8 +1109,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::Or: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown | RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. 
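Several hunks above replace hand-rolled `KnownBits` merging with `analyzeKnownBitsFromAndXorOr`, and the new `-instcombine-verify-known-bits` option cross-checks `SimplifyDemandedBits()` against `computeKnownBits()`. The combination rules themselves are small; below is a toy 8-bit C++ model (not the LLVM `KnownBits` class) that brute-force checks them, along with the no-common-bits fact behind the new `disjoint` inference on `or` earlier in this file's diff:

```cpp
#include <cassert>
#include <cstdint>

struct Known { uint8_t Zero = 0, One = 0; }; // bits known 0 / known 1

// Combination rules for bitwise ops on known bits:
static Known knownAnd(Known A, Known B) {
  return {(uint8_t)(A.Zero | B.Zero), (uint8_t)(A.One & B.One)};
}
static Known knownOr(Known A, Known B) {
  return {(uint8_t)(A.Zero & B.Zero), (uint8_t)(A.One | B.One)};
}
static Known knownXor(Known A, Known B) {
  uint8_t Z = (A.Zero & B.Zero) | (A.One & B.One); // equal known bits -> 0
  uint8_t O = (A.Zero & B.One) | (A.One & B.Zero); // differing known bits -> 1
  return {Z, O};
}

// True if concrete value V is consistent with partial knowledge K.
static bool matches(uint8_t V, Known K) {
  return (V & K.Zero) == 0 && (V & K.One) == K.One;
}

int main() {
  // Fix some arbitrary partial knowledge, then check every consistent pair.
  Known KA{/*Zero=*/0b00001100, /*One=*/0b10000000};
  Known KB{/*Zero=*/0b01000001, /*One=*/0b00000010};
  for (int A = 0; A < 256; ++A) {
    if (!matches((uint8_t)A, KA)) continue;
    for (int B = 0; B < 256; ++B) {
      if (!matches((uint8_t)B, KB)) continue;
      assert(matches((uint8_t)(A & B), knownAnd(KA, KB)));
      assert(matches((uint8_t)(A | B), knownOr(KA, KB)));
      assert(matches((uint8_t)(A ^ B), knownXor(KA, KB)));
      // With no common set bits, | and + agree, which is what the
      // setIsDisjoint() inference on `or` relies on.
      if ((A & B) == 0)
        assert((uint8_t)(A | B) == (uint8_t)(A + B));
    }
  }
}
```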
@@ -1051,8 +1132,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( case Instruction::Xor: { computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); - Known = LHSKnown ^ RHSKnown; - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + Known = analyzeKnownBitsFromAndXorOr(cast<Operator>(I), LHSKnown, RHSKnown, + Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); // If the client is only demanding bits that we know, return the known // constant. @@ -1085,7 +1167,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); Known = KnownBits::computeForAddSub(/*Add*/ true, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::Sub: { @@ -1101,7 +1183,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = KnownBits::computeForAddSub(/*Add*/ false, NSW, LHSKnown, RHSKnown); - computeKnownBitsFromAssume(I, Known, Depth, SQ.getWithInstruction(CxtI)); + computeKnownBitsFromContext(I, Known, Depth, SQ.getWithInstruction(CxtI)); break; } case Instruction::AShr: { @@ -1219,7 +1301,7 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits( New->setIsExact(true); } - return InsertNewInstWith(New, *Shl); + return InsertNewInstWith(New, Shl->getIterator()); } return nullptr; @@ -1549,7 +1631,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, Instruction *New = InsertElementInst::Create( Op, Value, ConstantInt::get(Type::getInt64Ty(I->getContext()), Idx), Shuffle->getName()); - InsertNewInstWith(New, *Shuffle); + InsertNewInstWith(New, Shuffle->getIterator()); return New; } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 4a5ffef2b08e..c8b58c51d4e6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -132,7 +132,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // Create a scalar PHI node that will replace the vector PHI node // just before the current PHI node. PHINode *scalarPHI = cast<PHINode>(InsertNewInstWith( - PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), *PN)); + PHINode::Create(EI.getType(), PN->getNumIncomingValues(), ""), PN->getIterator())); // Scalarize each PHI operand. 
for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { Value *PHIInVal = PN->getIncomingValue(i); @@ -148,10 +148,10 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, Value *Op = InsertNewInstWith( ExtractElementInst::Create(B0->getOperand(opId), Elt, B0->getOperand(opId)->getName() + ".Elt"), - *B0); + B0->getIterator()); Value *newPHIUser = InsertNewInstWith( BinaryOperator::CreateWithCopiedFlags(B0->getOpcode(), - scalarPHI, Op, B0), *B0); + scalarPHI, Op, B0), B0->getIterator()); scalarPHI->addIncoming(newPHIUser, inBB); } else { // Scalarize PHI input: @@ -165,7 +165,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, InsertPos = inBB->getFirstInsertionPt(); } - InsertNewInstWith(newEI, *InsertPos); + InsertNewInstWith(newEI, InsertPos); scalarPHI->addIncoming(newEI, inBB); } @@ -441,7 +441,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { if (IndexC->getValue().getActiveBits() <= BitWidth) Idx = ConstantInt::get(Ty, IndexC->getValue().zextOrTrunc(BitWidth)); else - Idx = UndefValue::get(Ty); + Idx = PoisonValue::get(Ty); return replaceInstUsesWith(EI, Idx); } } @@ -742,7 +742,7 @@ static bool replaceExtractElements(InsertElementInst *InsElt, if (ExtVecOpInst && !isa<PHINode>(ExtVecOpInst)) WideVec->insertAfter(ExtVecOpInst); else - IC.InsertNewInstWith(WideVec, *ExtElt->getParent()->getFirstInsertionPt()); + IC.InsertNewInstWith(WideVec, ExtElt->getParent()->getFirstInsertionPt()); // Replace extracts from the original narrow vector with extracts from the new // wide vector. @@ -751,7 +751,7 @@ static bool replaceExtractElements(InsertElementInst *InsElt, if (!OldExt || OldExt->getParent() != WideVec->getParent()) continue; auto *NewExt = ExtractElementInst::Create(WideVec, OldExt->getOperand(1)); - IC.InsertNewInstWith(NewExt, *OldExt); + IC.InsertNewInstWith(NewExt, OldExt->getIterator()); IC.replaceInstUsesWith(*OldExt, NewExt); // Add the old extracts to the worklist for DCE. We can't remove the // extracts directly, because they may still be used by the calling code. @@ -1121,7 +1121,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // Note that the same block can be a predecessor more than once, // and we need to preserve that invariant for the PHI node. BuilderTy::InsertPointGuard Guard(Builder); - Builder.SetInsertPoint(UseBB->getFirstNonPHI()); + Builder.SetInsertPoint(UseBB, UseBB->getFirstNonPHIIt()); auto *PHI = Builder.CreatePHI(AggTy, Preds.size(), OrigIVI.getName() + ".merged"); for (BasicBlock *Pred : Preds) @@ -2122,8 +2122,8 @@ static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) { NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i]; // A select mask with undef elements might look like an identity mask. - assert((ShuffleVectorInst::isSelectMask(NewMask) || - ShuffleVectorInst::isIdentityMask(NewMask)) && + assert((ShuffleVectorInst::isSelectMask(NewMask, NumElts) || + ShuffleVectorInst::isIdentityMask(NewMask, NumElts)) && "Unexpected shuffle mask"); return new ShuffleVectorInst(X, Y, NewMask); } @@ -2197,9 +2197,9 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, !match(Op1, m_Undef()) || match(Mask, m_ZeroMask()) || IndexC == 0) return nullptr; - // Insert into element 0 of an undef vector. - UndefValue *UndefVec = UndefValue::get(Shuf.getType()); - Value *NewIns = Builder.CreateInsertElement(UndefVec, X, (uint64_t)0); + // Insert into element 0 of a poison vector. 
+ PoisonValue *PoisonVec = PoisonValue::get(Shuf.getType()); + Value *NewIns = Builder.CreateInsertElement(PoisonVec, X, (uint64_t)0); // Splat from element 0. Any mask element that is undefined remains undefined. // For example: diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index afd6e034f46d..f072f5cec309 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -130,13 +130,6 @@ STATISTIC(NumReassoc , "Number of reassociations"); DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); -// FIXME: these limits eventually should be as low as 2. -#ifndef NDEBUG -static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 100; -#else -static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000; -#endif - static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); @@ -145,12 +138,6 @@ static cl::opt<unsigned> MaxSinkNumUsers( "instcombine-max-sink-users", cl::init(32), cl::desc("Maximum number of undroppable users for instruction sinking")); -static cl::opt<unsigned> InfiniteLoopDetectionThreshold( - "instcombine-infinite-loop-threshold", - cl::desc("Number of instruction combining iterations considered an " - "infinite loop"), - cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden); - static cl::opt<unsigned> MaxArraySize("instcombine-maxarray-size", cl::init(1024), cl::desc("Maximum array size considered when doing a combine")); @@ -358,15 +345,19 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, // Fold the constants together in the destination type: // (op (cast (op X, C2)), C1) --> (op (cast X), FoldedC) + const DataLayout &DL = IC.getDataLayout(); Type *DestTy = C1->getType(); - Constant *CastC2 = ConstantExpr::getCast(CastOpcode, C2, DestTy); - Constant *FoldedC = - ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, IC.getDataLayout()); + Constant *CastC2 = ConstantFoldCastOperand(CastOpcode, C2, DestTy, DL); + if (!CastC2) + return false; + Constant *FoldedC = ConstantFoldBinaryOpOperands(AssocOpcode, C1, CastC2, DL); if (!FoldedC) return false; IC.replaceOperand(*Cast, 0, BinOp2->getOperand(0)); IC.replaceOperand(*BinOp1, 1, FoldedC); + BinOp1->dropPoisonGeneratingFlags(); + Cast->dropPoisonGeneratingFlags(); return true; } @@ -542,12 +533,12 @@ bool InstCombinerImpl::SimplifyAssociativeOrCommutative(BinaryOperator &I) { BinaryOperator::Create(Opcode, A, B); if (isa<FPMathOperator>(NewBO)) { - FastMathFlags Flags = I.getFastMathFlags(); - Flags &= Op0->getFastMathFlags(); - Flags &= Op1->getFastMathFlags(); - NewBO->setFastMathFlags(Flags); + FastMathFlags Flags = I.getFastMathFlags() & + Op0->getFastMathFlags() & + Op1->getFastMathFlags(); + NewBO->setFastMathFlags(Flags); } - InsertNewInstWith(NewBO, I); + InsertNewInstWith(NewBO, I.getIterator()); NewBO->takeName(Op1); replaceOperand(I, 0, NewBO); replaceOperand(I, 1, CRes); @@ -749,7 +740,16 @@ static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ, // 2) BinOp1 == BinOp2 (if BinOp == `add`, then also requires `shl`). 
// // -> (BinOp (logic_shift (BinOp X, Y)), Mask) +// +// (Binop1 (Binop2 (arithmetic_shift X, Amt), Mask), (arithmetic_shift Y, Amt)) +// IFF +// 1) Binop1 is bitwise logical operator `and`, `or` or `xor` +// 2) Binop2 is `not` +// +// -> (arithmetic_shift Binop1((not X), Y), Amt) + Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { + const DataLayout &DL = I.getModule()->getDataLayout(); auto IsValidBinOpc = [](unsigned Opc) { switch (Opc) { default: @@ -768,11 +768,13 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { // constraints. auto IsCompletelyDistributable = [](unsigned BinOpc1, unsigned BinOpc2, unsigned ShOpc) { + assert(ShOpc != Instruction::AShr); return (BinOpc1 != Instruction::Add && BinOpc2 != Instruction::Add) || ShOpc == Instruction::Shl; }; auto GetInvShift = [](unsigned ShOpc) { + assert(ShOpc != Instruction::AShr); return ShOpc == Instruction::LShr ? Instruction::Shl : Instruction::LShr; }; @@ -796,23 +798,23 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { // Otherwise, need mask that meets the below requirement. // (logic_shift (inv_logic_shift Mask, ShAmt), ShAmt) == Mask - return ConstantExpr::get( - ShOpc, ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift), - CShift) == CMask; + Constant *MaskInvShift = + ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL); + return ConstantFoldBinaryOpOperands(ShOpc, MaskInvShift, CShift, DL) == + CMask; }; auto MatchBinOp = [&](unsigned ShOpnum) -> Instruction * { Constant *CMask, *CShift; Value *X, *Y, *ShiftedX, *Mask, *Shift; if (!match(I.getOperand(ShOpnum), - m_OneUse(m_LogicalShift(m_Value(Y), m_Value(Shift))))) + m_OneUse(m_Shift(m_Value(Y), m_Value(Shift))))) return nullptr; if (!match(I.getOperand(1 - ShOpnum), m_BinOp(m_Value(ShiftedX), m_Value(Mask)))) return nullptr; - if (!match(ShiftedX, - m_OneUse(m_LogicalShift(m_Value(X), m_Specific(Shift))))) + if (!match(ShiftedX, m_OneUse(m_Shift(m_Value(X), m_Specific(Shift))))) return nullptr; // Make sure we are matching instruction shifts and not ConstantExpr @@ -836,6 +838,18 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { if (!IsValidBinOpc(I.getOpcode()) || !IsValidBinOpc(BinOpc)) return nullptr; + if (ShOpc == Instruction::AShr) { + if (Instruction::isBitwiseLogicOp(I.getOpcode()) && + BinOpc == Instruction::Xor && match(Mask, m_AllOnes())) { + Value *NotX = Builder.CreateNot(X); + Value *NewBinOp = Builder.CreateBinOp(I.getOpcode(), Y, NotX); + return BinaryOperator::Create( + static_cast<Instruction::BinaryOps>(ShOpc), NewBinOp, Shift); + } + + return nullptr; + } + // If BinOp1 == BinOp2 and it's bitwise or shl with add, then just // distribute to drop the shift irrelevant of constants. 
if (BinOpc == I.getOpcode() && @@ -857,7 +871,8 @@ Instruction *InstCombinerImpl::foldBinOpShiftWithShift(BinaryOperator &I) { if (!CanDistributeBinops(I.getOpcode(), BinOpc, ShOpc, CMask, CShift)) return nullptr; - Constant *NewCMask = ConstantExpr::get(GetInvShift(ShOpc), CMask, CShift); + Constant *NewCMask = + ConstantFoldBinaryOpOperands(GetInvShift(ShOpc), CMask, CShift, DL); Value *NewBinOp2 = Builder.CreateBinOp( static_cast<Instruction::BinaryOps>(BinOpc), X, NewCMask); Value *NewBinOp1 = Builder.CreateBinOp(I.getOpcode(), Y, NewBinOp2); @@ -924,13 +939,17 @@ InstCombinerImpl::foldBinOpOfSelectAndCastOfSelectCondition(BinaryOperator &I) { // If the value used in the zext/sext is the select condition, or the negated // of the select condition, the binop can be simplified. - if (CondVal == A) - return SelectInst::Create(CondVal, NewFoldedConst(false, TrueVal), + if (CondVal == A) { + Value *NewTrueVal = NewFoldedConst(false, TrueVal); + return SelectInst::Create(CondVal, NewTrueVal, NewFoldedConst(true, FalseVal)); + } - if (match(A, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(CondVal, NewFoldedConst(true, TrueVal), + if (match(A, m_Not(m_Specific(CondVal)))) { + Value *NewTrueVal = NewFoldedConst(true, TrueVal); + return SelectInst::Create(CondVal, NewTrueVal, NewFoldedConst(false, FalseVal)); + } return nullptr; } @@ -1167,6 +1186,8 @@ void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { break; case Instruction::Xor: replaceInstUsesWith(cast<Instruction>(*U), I); + // Add to worklist for DCE. + addToWorklist(cast<Instruction>(U)); break; default: llvm_unreachable("Got unexpected user - out of sync with " @@ -1268,7 +1289,7 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI, Value *NewOp, InstCombiner &IC) { Instruction *Clone = I.clone(); Clone->replaceUsesOfWith(SI, NewOp); - IC.InsertNewInstBefore(Clone, *SI); + IC.InsertNewInstBefore(Clone, SI->getIterator()); return Clone; } @@ -1302,6 +1323,21 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return nullptr; } + // Test if a FCmpInst instruction is used exclusively by a select as + // part of a minimum or maximum operation. If so, refrain from doing + // any other folding. This helps out other analyses which understand + // non-obfuscated minimum and maximum idioms. And in this case, at + // least one of the comparison operands has at least one user besides + // the compare (the select), which would often largely negate the + // benefit of folding anyway. + if (auto *CI = dyn_cast<FCmpInst>(SI->getCondition())) { + if (CI->hasOneUse()) { + Value *Op0 = CI->getOperand(0), *Op1 = CI->getOperand(1); + if ((TV == Op0 && FV == Op1) || (FV == Op0 && TV == Op1)) + return nullptr; + } + } + // Make sure that one of the select arms constant folds successfully. Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ true); Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm*/ false); @@ -1316,6 +1352,47 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } +static Value *simplifyInstructionWithPHI(Instruction &I, PHINode *PN, + Value *InValue, BasicBlock *InBB, + const DataLayout &DL, + const SimplifyQuery SQ) { + // NB: It is a precondition of this transform that the operands be + // phi translatable! 
This is usually trivially satisfied by limiting it + // to constant ops, and for selects we do a more sophisticated check. + SmallVector<Value *> Ops; + for (Value *Op : I.operands()) { + if (Op == PN) + Ops.push_back(InValue); + else + Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); + } + + // Don't consider the simplification successful if we get back a constant + // expression. That's just an instruction in hiding. + // Also reject the case where we simplify back to the phi node. We wouldn't + // be able to remove it in that case. + Value *NewVal = simplifyInstructionWithOperands( + &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); + if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) + return NewVal; + + // Check if incoming PHI value can be replaced with constant + // based on implied condition. + BranchInst *TerminatorBI = dyn_cast<BranchInst>(InBB->getTerminator()); + const ICmpInst *ICmp = dyn_cast<ICmpInst>(&I); + if (TerminatorBI && TerminatorBI->isConditional() && + TerminatorBI->getSuccessor(0) != TerminatorBI->getSuccessor(1) && ICmp) { + bool LHSIsTrue = TerminatorBI->getSuccessor(0) == PN->getParent(); + std::optional<bool> ImpliedCond = + isImpliedCondition(TerminatorBI->getCondition(), ICmp->getPredicate(), + Ops[0], Ops[1], DL, LHSIsTrue); + if (ImpliedCond) + return ConstantInt::getBool(I.getType(), ImpliedCond.value()); + } + + return nullptr; +} + Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) @@ -1344,29 +1421,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { Value *InVal = PN->getIncomingValue(i); BasicBlock *InBB = PN->getIncomingBlock(i); - // NB: It is a precondition of this transform that the operands be - // phi translatable! This is usually trivially satisfied by limiting it - // to constant ops, and for selects we do a more sophisticated check. - SmallVector<Value *> Ops; - for (Value *Op : I.operands()) { - if (Op == PN) - Ops.push_back(InVal); - else - Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); - } - - // Don't consider the simplification successful if we get back a constant - // expression. That's just an instruction in hiding. - // Also reject the case where we simplify back to the phi node. We wouldn't - // be able to remove it in that case. - Value *NewVal = simplifyInstructionWithOperands( - &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); - if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) { + if (auto *NewVal = simplifyInstructionWithPHI(I, PN, InVal, InBB, DL, SQ)) { NewPhiValues.push_back(NewVal); continue; } - if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. if (NonSimplifiedBB) return nullptr; // More than one non-simplified value. NonSimplifiedBB = InBB; @@ -1402,7 +1461,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // Okay, we can do the transformation: create the new PHI node. 
PHINode *NewPN = PHINode::Create(I.getType(), PN->getNumIncomingValues()); - InsertNewInstBefore(NewPN, *PN); + InsertNewInstBefore(NewPN, PN->getIterator()); NewPN->takeName(PN); NewPN->setDebugLoc(PN->getDebugLoc()); @@ -1417,7 +1476,7 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { else U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB); } - InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator()); + InsertNewInstBefore(Clone, NonSimplifiedBB->getTerminator()->getIterator()); } for (unsigned i = 0; i != NumPHIValues; ++i) { @@ -1848,8 +1907,8 @@ Instruction *InstCombinerImpl::narrowMathIfNoOverflow(BinaryOperator &BO) { Constant *WideC; if (!Op0->hasOneUse() || !match(Op1, m_Constant(WideC))) return nullptr; - Constant *NarrowC = ConstantExpr::getTrunc(WideC, X->getType()); - if (ConstantExpr::getCast(CastOpc, NarrowC, BO.getType()) != WideC) + Constant *NarrowC = getLosslessTrunc(WideC, X->getType(), CastOpc); + if (!NarrowC) return nullptr; Y = NarrowC; } @@ -1940,7 +1999,7 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), 0); if (NumVarIndices != Src->getNumIndices()) { // FIXME: getIndexedOffsetInType() does not handled scalable vectors. - if (isa<ScalableVectorType>(BaseType)) + if (BaseType->isScalableTy()) return nullptr; SmallVector<Value *> ConstantIndices; @@ -2048,12 +2107,126 @@ Instruction *InstCombinerImpl::visitGEPOfGEP(GetElementPtrInst &GEP, return nullptr; } +Value *InstCombiner::getFreelyInvertedImpl(Value *V, bool WillInvertAllUses, + BuilderTy *Builder, + bool &DoesConsume, unsigned Depth) { + static Value *const NonNull = reinterpret_cast<Value *>(uintptr_t(1)); + // ~(~(X)) -> X. + Value *A, *B; + if (match(V, m_Not(m_Value(A)))) { + DoesConsume = true; + return A; + } + + Constant *C; + // Constants can be considered to be not'ed values. + if (match(V, m_ImmConstant(C))) + return ConstantExpr::getNot(C); + + if (Depth++ >= MaxAnalysisRecursionDepth) + return nullptr; + + // The rest of the cases require that we invert all uses so don't bother + // doing the analysis if we know we can't use the result. + if (!WillInvertAllUses) + return nullptr; + + // Compares can be inverted if all of their uses are being modified to use + // the ~V. + if (auto *I = dyn_cast<CmpInst>(V)) { + if (Builder != nullptr) + return Builder->CreateCmp(I->getInversePredicate(), I->getOperand(0), + I->getOperand(1)); + return NonNull; + } + + // If `V` is of the form `A + B` then `-1 - V` can be folded into + // `(-1 - B) - A` if we are willing to invert all of the uses. + if (match(V, m_Add(m_Value(A), m_Value(B)))) { + if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateSub(BV, A) : NonNull; + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateSub(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `A ^ ~B` then `~(A ^ ~B)` can be folded + // into `A ^ B` if we are willing to invert all of the uses. + if (match(V, m_Xor(m_Value(A), m_Value(B)))) { + if (auto *BV = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateXor(A, BV) : NonNull; + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? 
Builder->CreateXor(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `B - A` then `-1 - V` can be folded into + // `A + (-1 - B)` if we are willing to invert all of the uses. + if (match(V, m_Sub(m_Value(A), m_Value(B)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateAdd(AV, B) : NonNull; + return nullptr; + } + + // If `V` is of the form `(~A) s>> B` then `~((~A) s>> B)` can be folded + // into `A s>> B` if we are willing to invert all of the uses. + if (match(V, m_AShr(m_Value(A), m_Value(B)))) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateAShr(AV, B) : NonNull; + return nullptr; + } + + // Treat lshr with non-negative operand as ashr. + if (match(V, m_LShr(m_Value(A), m_Value(B))) && + isKnownNonNegative(A, SQ.getWithInstruction(cast<Instruction>(V)), + Depth)) { + if (auto *AV = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) + return Builder ? Builder->CreateAShr(AV, B) : NonNull; + return nullptr; + } + + Value *Cond; + // LogicOps are special in that we canonicalize them at the cost of an + // instruction. + bool IsSelect = match(V, m_Select(m_Value(Cond), m_Value(A), m_Value(B))) && + !shouldAvoidAbsorbingNotIntoSelect(*cast<SelectInst>(V)); + // Selects/min/max with invertible operands are freely invertible + if (IsSelect || match(V, m_MaxOrMin(m_Value(A), m_Value(B)))) { + if (!getFreelyInvertedImpl(B, B->hasOneUse(), /*Builder*/ nullptr, + DoesConsume, Depth)) + return nullptr; + if (Value *NotA = getFreelyInvertedImpl(A, A->hasOneUse(), Builder, + DoesConsume, Depth)) { + if (Builder != nullptr) { + Value *NotB = getFreelyInvertedImpl(B, B->hasOneUse(), Builder, + DoesConsume, Depth); + assert(NotB != nullptr && + "Unable to build inverted value for known freely invertable op"); + if (auto *II = dyn_cast<IntrinsicInst>(V)) + return Builder->CreateBinaryIntrinsic( + getInverseMinMaxIntrinsic(II->getIntrinsicID()), NotA, NotB); + return Builder->CreateSelect(Cond, NotA, NotB); + } + return NonNull; + } + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Value *PtrOp = GEP.getOperand(0); SmallVector<Value *, 8> Indices(GEP.indices()); Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); - bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); + bool IsGEPSrcEleScalable = GEPEltType->isScalableTy(); if (Value *V = simplifyGEPInst(GEPEltType, PtrOp, Indices, GEP.isInBounds(), SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); @@ -2221,7 +2394,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { NewGEP->setOperand(DI, NewPN); } - NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); + NewGEP->insertBefore(*GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); return replaceOperand(GEP, 0, NewGEP); } @@ -2264,11 +2437,43 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { return CastInst::CreatePointerBitCastOrAddrSpaceCast(Y, GEPType); } } - // We do not handle pointer-vector geps here. if (GEPType->isVectorTy()) return nullptr; + if (GEP.getNumIndices() == 1) { + // Try to replace ADD + GEP with GEP + GEP. 
+ Value *Idx1, *Idx2; + if (match(GEP.getOperand(1), + m_OneUse(m_Add(m_Value(Idx1), m_Value(Idx2))))) { + // %idx = add i64 %idx1, %idx2 + // %gep = getelementptr i32, ptr %ptr, i64 %idx + // as: + // %newptr = getelementptr i32, ptr %ptr, i64 %idx1 + // %newgep = getelementptr i32, ptr %newptr, i64 %idx2 + auto *NewPtr = Builder.CreateGEP(GEP.getResultElementType(), + GEP.getPointerOperand(), Idx1); + return GetElementPtrInst::Create(GEP.getResultElementType(), NewPtr, + Idx2); + } + ConstantInt *C; + if (match(GEP.getOperand(1), m_OneUse(m_SExt(m_OneUse(m_NSWAdd( + m_Value(Idx1), m_ConstantInt(C))))))) { + // %add = add nsw i32 %idx1, idx2 + // %sidx = sext i32 %add to i64 + // %gep = getelementptr i32, ptr %ptr, i64 %sidx + // as: + // %newptr = getelementptr i32, ptr %ptr, i32 %idx1 + // %newgep = getelementptr i32, ptr %newptr, i32 idx2 + auto *NewPtr = Builder.CreateGEP( + GEP.getResultElementType(), GEP.getPointerOperand(), + Builder.CreateSExt(Idx1, GEP.getOperand(1)->getType())); + return GetElementPtrInst::Create( + GEP.getResultElementType(), NewPtr, + Builder.CreateSExt(C, GEP.getOperand(1)->getType())); + } + } + if (!GEP.isInBounds()) { unsigned IdxWidth = DL.getIndexSizeInBits(PtrOp->getType()->getPointerAddressSpace()); @@ -2362,6 +2567,26 @@ static bool isAllocSiteRemovable(Instruction *AI, unsigned OtherIndex = (ICI->getOperand(0) == PI) ? 1 : 0; if (!isNeverEqualToUnescapedAlloc(ICI->getOperand(OtherIndex), TLI, AI)) return false; + + // Do not fold compares to aligned_alloc calls, as they may have to + // return null in case the required alignment cannot be satisfied, + // unless we can prove that both alignment and size are valid. + auto AlignmentAndSizeKnownValid = [](CallBase *CB) { + // Check if alignment and size of a call to aligned_alloc is valid, + // that is alignment is a power-of-2 and the size is a multiple of the + // alignment. + const APInt *Alignment; + const APInt *Size; + return match(CB->getArgOperand(0), m_APInt(Alignment)) && + match(CB->getArgOperand(1), m_APInt(Size)) && + Alignment->isPowerOf2() && Size->urem(*Alignment).isZero(); + }; + auto *CB = dyn_cast<CallBase>(AI); + LibFunc TheLibFunc; + if (CB && TLI.getLibFunc(*CB->getCalledFunction(), TheLibFunc) && + TLI.has(TheLibFunc) && TheLibFunc == LibFunc_aligned_alloc && + !AlignmentAndSizeKnownValid(CB)) + return false; Users.emplace_back(I); continue; } @@ -2451,9 +2676,10 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If we are removing an alloca with a dbg.declare, insert dbg.value calls // before each store. SmallVector<DbgVariableIntrinsic *, 8> DVIs; + SmallVector<DPValue *, 8> DPVs; std::unique_ptr<DIBuilder> DIB; if (isa<AllocaInst>(MI)) { - findDbgUsers(DVIs, &MI); + findDbgUsers(DVIs, &MI, &DPVs); DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); } @@ -2493,6 +2719,9 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { for (auto *DVI : DVIs) if (DVI->isAddressOfVariable()) ConvertDebugDeclareToDebugValue(DVI, SI, *DIB); + for (auto *DPV : DPVs) + if (DPV->isAddressOfVariable()) + ConvertDebugDeclareToDebugValue(DPV, SI, *DIB); } else { // Casts, GEP, or anything else: we're about to delete this instruction, // so it can not have any valid uses. @@ -2531,9 +2760,15 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { // If there is a dead store to `%a` in @trivially_inlinable_no_op, the // "arg0" dbg.value may be stale after the call. 
  // the DW_OP_deref dbg.value causes large gaps in location coverage.
+  //
+  // FIXME: the Assignment Tracking project has now likely made this
+  // redundant (and it's sometimes harmful).
   for (auto *DVI : DVIs)
     if (DVI->isAddressOfVariable() || DVI->getExpression()->startsWithDeref())
       DVI->eraseFromParent();
+  for (auto *DPV : DPVs)
+    if (DPV->isAddressOfVariable() || DPV->getExpression()->startsWithDeref())
+      DPV->eraseFromParent();
 
   return eraseInstFromFunction(MI);
 }
@@ -2612,7 +2847,7 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI,
   for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) {
     if (&Instr == FreeInstrBBTerminator)
       break;
-    Instr.moveBefore(TI);
+    Instr.moveBeforePreserving(TI);
   }
   assert(FreeInstrBB->size() == 1 &&
          "Only the branch instruction should remain");
@@ -2746,55 +2981,77 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) {
   return nullptr;
 }
 
+void InstCombinerImpl::addDeadEdge(BasicBlock *From, BasicBlock *To,
+                                   SmallVectorImpl<BasicBlock *> &Worklist) {
+  if (!DeadEdges.insert({From, To}).second)
+    return;
+
+  // Replace phi node operands in the successor with poison.
+  for (PHINode &PN : To->phis())
+    for (Use &U : PN.incoming_values())
+      if (PN.getIncomingBlock(U) == From && !isa<PoisonValue>(U)) {
+        replaceUse(U, PoisonValue::get(PN.getType()));
+        addToWorklist(&PN);
+        MadeIRChange = true;
+      }
+
+  Worklist.push_back(To);
+}
+
 // Under the assumption that I is unreachable, remove it and following
-// instructions.
-bool InstCombinerImpl::handleUnreachableFrom(Instruction *I) {
-  bool Changed = false;
+// instructions. Changes are reported directly to MadeIRChange.
+void InstCombinerImpl::handleUnreachableFrom(
+    Instruction *I, SmallVectorImpl<BasicBlock *> &Worklist) {
   BasicBlock *BB = I->getParent();
   for (Instruction &Inst : make_early_inc_range(
            make_range(std::next(BB->getTerminator()->getReverseIterator()),
                       std::next(I->getReverseIterator())))) {
     if (!Inst.use_empty() && !Inst.getType()->isTokenTy()) {
       replaceInstUsesWith(Inst, PoisonValue::get(Inst.getType()));
-      Changed = true;
+      MadeIRChange = true;
     }
     if (Inst.isEHPad() || Inst.getType()->isTokenTy())
       continue;
+    // RemoveDIs: erase debug-info on this instruction manually.
+    Inst.dropDbgValues();
     eraseInstFromFunction(Inst);
-    Changed = true;
+    MadeIRChange = true;
   }
 
-  // Replace phi node operands in successor blocks with poison.
+  // RemoveDIs: to match behaviour in dbg.value mode, drop debug-info on the
+  // terminator too.
+  BB->getTerminator()->dropDbgValues();
+
+  // Handle potentially dead successors.
   for (BasicBlock *Succ : successors(BB))
-    for (PHINode &PN : Succ->phis())
-      for (Use &U : PN.incoming_values())
-        if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
-          replaceUse(U, PoisonValue::get(PN.getType()));
-          addToWorklist(&PN);
-          Changed = true;
-        }
+    addDeadEdge(BB, Succ, Worklist);
+}
 
-  // TODO: Successor blocks may also be dead.
-  return Changed;
+void InstCombinerImpl::handlePotentiallyDeadBlocks(
+    SmallVectorImpl<BasicBlock *> &Worklist) {
+  while (!Worklist.empty()) {
+    BasicBlock *BB = Worklist.pop_back_val();
+    if (!all_of(predecessors(BB), [&](BasicBlock *Pred) {
+          return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred);
+        }))
+      continue;
+
+    handleUnreachableFrom(&BB->front(), Worklist);
+  }
 }
 
-bool InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
+void InstCombinerImpl::handlePotentiallyDeadSuccessors(BasicBlock *BB,
                                                        BasicBlock *LiveSucc) {
-  bool Changed = false;
+  SmallVector<BasicBlock *> Worklist;
   for (BasicBlock *Succ : successors(BB)) {
     // The live successor isn't dead.
     if (Succ == LiveSucc)
       continue;
-    if (!all_of(predecessors(Succ), [&](BasicBlock *Pred) {
-          return DT.dominates(BasicBlockEdge(BB, Succ),
-                              BasicBlockEdge(Pred, Succ));
-        }))
-      continue;
-
-    Changed |= handleUnreachableFrom(&Succ->front());
+    addDeadEdge(BB, Succ, Worklist);
  }
-  return Changed;
+
+  handlePotentiallyDeadBlocks(Worklist);
 }
 
 Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
@@ -2840,14 +3097,17 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) {
     return &BI;
   }
 
-  if (isa<UndefValue>(Cond) &&
-      handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr))
-    return &BI;
-  if (auto *CI = dyn_cast<ConstantInt>(Cond))
-    if (handlePotentiallyDeadSuccessors(BI.getParent(),
-                                        BI.getSuccessor(!CI->getZExtValue())))
-      return &BI;
+  if (isa<UndefValue>(Cond)) {
+    handlePotentiallyDeadSuccessors(BI.getParent(), /*LiveSucc*/ nullptr);
+    return nullptr;
+  }
+  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+    handlePotentiallyDeadSuccessors(BI.getParent(),
+                                    BI.getSuccessor(!CI->getZExtValue()));
+    return nullptr;
+  }
+
+  DC.registerBranch(&BI);
   return nullptr;
 }
 
@@ -2866,14 +3126,6 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     return replaceOperand(SI, 0, Op0);
   }
 
-  if (isa<UndefValue>(Cond) &&
-      handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr))
-    return &SI;
-  if (auto *CI = dyn_cast<ConstantInt>(Cond))
-    if (handlePotentiallyDeadSuccessors(
-            SI.getParent(), SI.findCaseValue(CI)->getCaseSuccessor()))
-      return &SI;
-
   KnownBits Known = computeKnownBits(Cond, 0, &SI);
   unsigned LeadingKnownZeros = Known.countMinLeadingZeros();
   unsigned LeadingKnownOnes = Known.countMinLeadingOnes();
@@ -2906,6 +3158,16 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) {
     return replaceOperand(SI, 0, NewCond);
   }
 
+  if (isa<UndefValue>(Cond)) {
+    handlePotentiallyDeadSuccessors(SI.getParent(), /*LiveSucc*/ nullptr);
+    return nullptr;
+  }
+  if (auto *CI = dyn_cast<ConstantInt>(Cond)) {
+    handlePotentiallyDeadSuccessors(SI.getParent(),
+                                    SI.findCaseValue(CI)->getCaseSuccessor());
+    return nullptr;
+  }
+
   return nullptr;
 }
 
@@ -3532,7 +3794,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI,
   Value *StartV = StartU->get();
   BasicBlock *StartBB = PN->getIncomingBlock(*StartU);
   bool StartNeedsFreeze = !isGuaranteedNotToBeUndefOrPoison(StartV);
-  // We can't insert freeze if the the start value is the result of the
+  // We can't insert freeze if the start value is the result of the
   // terminator (e.g. an invoke).
   if (StartNeedsFreeze && StartBB->getTerminator() == StartV)
     return nullptr;
@@ -3583,19 +3845,27 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) {
   // *all* uses if the operand is an invoke/callbr and the use is in a phi on
   // the normal/default destination. This is why the domination check in the
  // replacement below is still necessary.
-  Instruction *MoveBefore;
+  BasicBlock::iterator MoveBefore;
   if (isa<Argument>(Op)) {
     MoveBefore =
-        &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca();
+        FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca();
   } else {
-    MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef();
-    if (!MoveBefore)
+    auto MoveBeforeOpt = cast<Instruction>(Op)->getInsertionPointAfterDef();
+    if (!MoveBeforeOpt)
       return false;
+    MoveBefore = *MoveBeforeOpt;
   }
 
+  // Don't move to the position of a debug intrinsic.
+  if (isa<DbgInfoIntrinsic>(MoveBefore))
+    MoveBefore = MoveBefore->getNextNonDebugInstruction()->getIterator();
+  // Re-point the iterator to come after any debug-info records, if we're
+  // running in "RemoveDIs" mode.
+  MoveBefore.setHeadBit(false);
+
   bool Changed = false;
-  if (&FI != MoveBefore) {
-    FI.moveBefore(MoveBefore);
+  if (&FI != &*MoveBefore) {
+    FI.moveBefore(*MoveBefore->getParent(), MoveBefore);
     Changed = true;
   }
 
@@ -3798,7 +4068,7 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
   /// the new position.
 
   BasicBlock::iterator InsertPos = DestBlock->getFirstInsertionPt();
-  I->moveBefore(&*InsertPos);
+  I->moveBefore(*DestBlock, InsertPos);
   ++NumSunkInst;
 
   // Also sink all related debug uses from the source basic block. Otherwise we
@@ -3808,10 +4078,19 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
   // here, but that computation has been sunk.
   SmallVector<DbgVariableIntrinsic *, 2> DbgUsers;
   findDbgUsers(DbgUsers, I);
-  // Process the sinking DbgUsers in reverse order, as we only want to clone the
-  // last appearing debug intrinsic for each given variable.
+
+  // For all debug values in the destination block, the sunk instruction
+  // will still be available, so they do not need to be dropped.
+  SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSalvage;
+  SmallVector<DPValue *, 2> DPValuesToSalvage;
+  for (auto &DbgUser : DbgUsers)
+    if (DbgUser->getParent() != DestBlock)
+      DbgUsersToSalvage.push_back(DbgUser);
+
+  // Process the sinking DbgUsersToSalvage in reverse order, as we only want
+  // to clone the last appearing debug intrinsic for each given variable.
   SmallVector<DbgVariableIntrinsic *, 2> DbgUsersToSink;
-  for (DbgVariableIntrinsic *DVI : DbgUsers)
+  for (DbgVariableIntrinsic *DVI : DbgUsersToSalvage)
     if (DVI->getParent() == SrcBlock)
       DbgUsersToSink.push_back(DVI);
   llvm::sort(DbgUsersToSink,
@@ -3847,7 +4126,10 @@ bool InstCombinerImpl::tryToSinkInstruction(Instruction *I,
 
   // Perform salvaging without the clones, then sink the clones.
   if (!DIIClones.empty()) {
-    salvageDebugInfoForDbgValues(*I, DbgUsers);
+    // RemoveDIs: pass in an empty vector of DPValues until we get to
+    // instrumenting this pass.
+    SmallVector<DPValue *, 1> DummyDPValues;
+    salvageDebugInfoForDbgValues(*I, DbgUsersToSalvage, DummyDPValues);
     // The clones are in reverse order of original appearance. Reverse again to
     // maintain the original order.
     for (auto &DIIClone : llvm::reverse(DIIClones)) {
@@ -4093,43 +4375,52 @@ public:
   }
 };
 
-/// Populate the IC worklist from a function, by walking it in depth-first
-/// order and adding all reachable code to the worklist.
+/// Populate the IC worklist from a function, by walking it in reverse
+/// post-order and adding all reachable code to the worklist.
 ///
 /// This has a couple of tricks to make the code faster and more powerful. In
/// particular, we constant fold and DCE instructions as we go, to avoid adding
/// them to the worklist (this significantly speeds up instcombine on code where
/// many instructions are dead or constant). Additionally, if we find a branch
/// whose condition is a known constant, we only visit the reachable successors.
-static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
-                                          const TargetLibraryInfo *TLI,
-                                          InstructionWorklist &ICWorklist) {
+bool InstCombinerImpl::prepareWorklist(
+    Function &F, ReversePostOrderTraversal<BasicBlock *> &RPOT) {
   bool MadeIRChange = false;
-  SmallPtrSet<BasicBlock *, 32> Visited;
-  SmallVector<BasicBlock*, 256> Worklist;
-  Worklist.push_back(&F.front());
-
+  SmallPtrSet<BasicBlock *, 32> LiveBlocks;
   SmallVector<Instruction *, 128> InstrsForInstructionWorklist;
   DenseMap<Constant *, Constant *> FoldedConstants;
   AliasScopeTracker SeenAliasScopes;
 
-  do {
-    BasicBlock *BB = Worklist.pop_back_val();
+  auto HandleOnlyLiveSuccessor = [&](BasicBlock *BB, BasicBlock *LiveSucc) {
+    for (BasicBlock *Succ : successors(BB))
+      if (Succ != LiveSucc && DeadEdges.insert({BB, Succ}).second)
+        for (PHINode &PN : Succ->phis())
+          for (Use &U : PN.incoming_values())
+            if (PN.getIncomingBlock(U) == BB && !isa<PoisonValue>(U)) {
+              U.set(PoisonValue::get(PN.getType()));
+              MadeIRChange = true;
+            }
+  };
 
-    // We have now visited this block! If we've already been here, ignore it.
-    if (!Visited.insert(BB).second)
+  for (BasicBlock *BB : RPOT) {
+    if (!BB->isEntryBlock() && all_of(predecessors(BB), [&](BasicBlock *Pred) {
+          return DeadEdges.contains({Pred, BB}) || DT.dominates(BB, Pred);
+        })) {
+      HandleOnlyLiveSuccessor(BB, nullptr);
       continue;
+    }
+    LiveBlocks.insert(BB);
 
     for (Instruction &Inst : llvm::make_early_inc_range(*BB)) {
       // ConstantProp instruction if trivially constant.
       if (!Inst.use_empty() &&
          (Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0))))
-        if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) {
+        if (Constant *C = ConstantFoldInstruction(&Inst, DL, &TLI)) {
          LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst
                            << '\n');
          Inst.replaceAllUsesWith(C);
          ++NumConstProp;
-          if (isInstructionTriviallyDead(&Inst, TLI))
+          if (isInstructionTriviallyDead(&Inst, &TLI))
            Inst.eraseFromParent();
          MadeIRChange = true;
          continue;
@@ -4143,7 +4434,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
          auto *C = cast<Constant>(U);
          Constant *&FoldRes = FoldedConstants[C];
          if (!FoldRes)
-            FoldRes = ConstantFoldConstant(C, DL, TLI);
+            FoldRes = ConstantFoldConstant(C, DL, &TLI);
 
          if (FoldRes != C) {
            LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst
@@ -4163,37 +4454,39 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
      }
    }
 
-    // Recursively visit successors. If this is a branch or switch on a
-    // constant, only visit the reachable successor.
+    // If this is a branch or switch on a constant, mark only the single
+    // live successor. Otherwise assume all successors are live.
     Instruction *TI = BB->getTerminator();
     if (BranchInst *BI = dyn_cast<BranchInst>(TI); BI && BI->isConditional()) {
-      if (isa<UndefValue>(BI->getCondition()))
+      if (isa<UndefValue>(BI->getCondition())) {
         // Branch on undef is UB.
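// Editorial sketch (illustrative IR, not from this patch): given
//   br i1 undef, label %t, label %f
// executing the branch is immediate UB, so no outgoing edge can be taken.
// Passing LiveSucc == nullptr below therefore makes HandleOnlyLiveSuccessor
// record every edge leaving BB as dead and poison the matching phi inputs
// in %t and %f.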
+        HandleOnlyLiveSuccessor(BB, nullptr);
         continue;
+      }
       if (auto *Cond = dyn_cast<ConstantInt>(BI->getCondition())) {
         bool CondVal = Cond->getZExtValue();
-        BasicBlock *ReachableBB = BI->getSuccessor(!CondVal);
-        Worklist.push_back(ReachableBB);
+        HandleOnlyLiveSuccessor(BB, BI->getSuccessor(!CondVal));
         continue;
       }
     } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
-      if (isa<UndefValue>(SI->getCondition()))
+      if (isa<UndefValue>(SI->getCondition())) {
         // Switch on undef is UB.
+        HandleOnlyLiveSuccessor(BB, nullptr);
         continue;
+      }
       if (auto *Cond = dyn_cast<ConstantInt>(SI->getCondition())) {
-        Worklist.push_back(SI->findCaseValue(Cond)->getCaseSuccessor());
+        HandleOnlyLiveSuccessor(BB,
+                                SI->findCaseValue(Cond)->getCaseSuccessor());
         continue;
       }
     }
-
-    append_range(Worklist, successors(TI));
-  } while (!Worklist.empty());
+  }
 
   // Remove instructions inside unreachable blocks. This prevents the
   // instcombine code from having to deal with some bad special cases, and
   // reduces use counts of instructions.
   for (BasicBlock &BB : F) {
-    if (Visited.count(&BB))
+    if (LiveBlocks.count(&BB))
       continue;
 
     unsigned NumDeadInstInBB;
@@ -4210,11 +4503,11 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
   // of the function down. This jives well with the way that it adds all uses
   // of instructions to the worklist after doing a transformation, thus avoiding
   // some N^2 behavior in pathological cases.
-  ICWorklist.reserve(InstrsForInstructionWorklist.size());
+  Worklist.reserve(InstrsForInstructionWorklist.size());
   for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) {
     // DCE instruction if trivially dead. As we iterate in reverse program
     // order here, we will clean up whole chains of dead instructions.
-    if (isInstructionTriviallyDead(Inst, TLI) ||
+    if (isInstructionTriviallyDead(Inst, &TLI) ||
         SeenAliasScopes.isNoAliasScopeDeclDead(Inst)) {
       ++NumDeadInst;
       LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n');
@@ -4224,7 +4517,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL,
       continue;
     }
 
-    ICWorklist.push(Inst);
+    Worklist.push(Inst);
   }
 
   return MadeIRChange;
@@ -4234,7 +4527,7 @@ static bool combineInstructionsOverFunction(
     Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA,
     AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
     DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI,
-    ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
+    ProfileSummaryInfo *PSI, LoopInfo *LI, const InstCombineOptions &Opts) {
   auto &DL = F.getParent()->getDataLayout();
 
   /// Builder - This is an IRBuilder that automatically inserts new
        AC.registerAssumption(Assume);
      }));
 
+  ReversePostOrderTraversal<BasicBlock *> RPOT(&F.front());
+
   // Lower dbg.declare intrinsics otherwise their value may be clobbered
   // by instcombiner.
   bool MadeIRChange = false;
@@ -4256,35 +4551,33 @@ static bool combineInstructionsOverFunction(
   // Iterate while there is work to do.
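// Editorial summary of the rewritten loop below (hedged against the code,
// not authoritative): each iteration rebuilds the worklist with
// IC.prepareWorklist(F, RPOT) and then runs IC.run(); the loop exits once
// an iteration makes no change. With Opts.VerifyFixpoint set, exceeding
// Opts.MaxIterations becomes a report_fatal_error rather than the quiet
// early break used when fixpoint verification is off.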
   unsigned Iteration = 0;
   while (true) {
-    ++NumWorklistIterations;
     ++Iteration;
 
-    if (Iteration > InfiniteLoopDetectionThreshold) {
-      report_fatal_error(
-          "Instruction Combining seems stuck in an infinite loop after " +
-          Twine(InfiniteLoopDetectionThreshold) + " iterations.");
-    }
-
-    if (Iteration > MaxIterations) {
-      LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations
+    if (Iteration > Opts.MaxIterations && !Opts.VerifyFixpoint) {
+      LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << Opts.MaxIterations
                         << " on " << F.getName()
-                        << " reached; stopping before reaching a fixpoint\n");
+                        << " reached; stopping without verifying fixpoint\n");
       break;
     }
 
+    ++NumWorklistIterations;
     LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on "
                       << F.getName() << "\n");
 
-    MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist);
-
     InstCombinerImpl IC(Worklist, Builder, F.hasMinSize(), AA, AC, TLI, TTI, DT,
                         ORE, BFI, PSI, DL, LI);
     IC.MaxArraySizeForCombine = MaxArraySize;
-
-    if (!IC.run())
+    bool MadeChangeInThisIteration = IC.prepareWorklist(F, RPOT);
+    MadeChangeInThisIteration |= IC.run();
+    if (!MadeChangeInThisIteration)
       break;
 
     MadeIRChange = true;
+    if (Iteration > Opts.MaxIterations) {
+      report_fatal_error(
+          "Instruction Combining did not reach a fixpoint after " +
+          Twine(Opts.MaxIterations) + " iterations");
+    }
   }
 
   if (Iteration == 1)
@@ -4307,7 +4600,8 @@ void InstCombinePass::printPipeline(
       OS, MapClassName2PassName);
   OS << '<';
   OS << "max-iterations=" << Options.MaxIterations << ";";
-  OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info";
+  OS << (Options.UseLoopInfo ? "" : "no-") << "use-loop-info;";
+  OS << (Options.VerifyFixpoint ? "" : "no-") << "verify-fixpoint";
   OS << '>';
 }
 
@@ -4333,7 +4627,7 @@ PreservedAnalyses InstCombinePass::run(Function &F,
                            &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr;
 
   if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
-                                       BFI, PSI, Options.MaxIterations, LI))
+                                       BFI, PSI, LI, Options))
     // No changes, all analyses are preserved.
    return PreservedAnalyses::all();
 
@@ -4382,8 +4676,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
       nullptr;
 
   return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, TTI, DT, ORE,
-                                         BFI, PSI,
-                                         InstCombineDefaultMaxIterations, LI);
+                                         BFI, PSI, LI, InstCombineOptions());
 }
 
 char InstructionCombiningPass::ID = 0;
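A minimal usage sketch (editorial; assumes an opt binary built from a tree containing this change, and a placeholder input file in.ll): the options emitted by printPipeline above can be passed back in through a textual pass pipeline, e.g.

  opt -passes='instcombine<max-iterations=2;no-verify-fixpoint>' -S in.ll

which allows at most two iterations and stops quietly, whereas the same pipeline with 'verify-fixpoint' aborts via report_fatal_error ("Instruction Combining did not reach a fixpoint after ... iterations") if the limit is exceeded.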