| author | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:04 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2023-02-11 12:38:11 +0000 |
| commit | e3b557809604d036af6e00c60f012c2025b59a5e (patch) | |
| tree | 8a11ba2269a3b669601e2fd41145b174008f4da8 /llvm/lib/Transforms/InstCombine | |
| parent | 08e8dd7b9db7bb4a9de26d44c1cbfd24e869c014 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
16 files changed, 3646 insertions, 1714 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 4a459ec6c550..b68efc993723 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -576,8 +576,7 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { } } - assert((NextTmpIdx <= array_lengthof(TmpResult) + 1) && - "out-of-bound access"); + assert((NextTmpIdx <= std::size(TmpResult) + 1) && "out-of-bound access"); Value *Result; if (!SimpVect.empty()) @@ -849,6 +848,7 @@ static Instruction *foldNoWrapAdd(BinaryOperator &Add, Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Type *Ty = Add.getType(); Constant *Op1C; if (!match(Op1, m_ImmConstant(Op1C))) return nullptr; @@ -883,7 +883,14 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { if (match(Op0, m_Not(m_Value(X)))) return BinaryOperator::CreateSub(InstCombiner::SubOne(Op1C), X); + // (iN X s>> (N - 1)) + 1 --> zext (X > -1) const APInt *C; + unsigned BitWidth = Ty->getScalarSizeInBits(); + if (match(Op0, m_OneUse(m_AShr(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1)))) && + match(Op1, m_One())) + return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + if (!match(Op1, m_APInt(C))) return nullptr; @@ -911,7 +918,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // Is this add the last step in a convoluted sext? // add(zext(xor i16 X, -32768), -32768) --> sext X - Type *Ty = Add.getType(); if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); @@ -969,15 +975,6 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } - // If all bits affected by the add are included in a high-bit-mask, do the - // add before the mask op: - // (X & 0xFF00) + xx00 --> (X + xx00) & 0xFF00 - if (match(Op0, m_OneUse(m_And(m_Value(X), m_APInt(C2)))) && - C2->isNegative() && C2->isShiftedMask() && *C == (*C & *C2)) { - Value *NewAdd = Builder.CreateAdd(X, ConstantInt::get(Ty, *C)); - return BinaryOperator::CreateAnd(NewAdd, ConstantInt::get(Ty, *C2)); - } - return nullptr; } @@ -1132,6 +1129,35 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) { return nullptr; } +/// Try to reduce signed division by power-of-2 to an arithmetic shift right. +static Instruction *foldAddToAshr(BinaryOperator &Add) { + // Division must be by power-of-2, but not the minimum signed value. + Value *X; + const APInt *DivC; + if (!match(Add.getOperand(0), m_SDiv(m_Value(X), m_Power2(DivC))) || + DivC->isNegative()) + return nullptr; + + // Rounding is done by adding -1 if the dividend (X) is negative and has any + // low bits set. 
The canonical pattern for that is an "ugt" compare with SMIN: + // sext (icmp ugt (X & (DivC - 1)), SMIN) + const APInt *MaskC; + ICmpInst::Predicate Pred; + if (!match(Add.getOperand(1), + m_SExt(m_ICmp(Pred, m_And(m_Specific(X), m_APInt(MaskC)), + m_SignMask()))) || + Pred != ICmpInst::ICMP_UGT) + return nullptr; + + APInt SMin = APInt::getSignedMinValue(Add.getType()->getScalarSizeInBits()); + if (*MaskC != (SMin | (*DivC - 1))) + return nullptr; + + // (X / DivC) + sext ((X & (SMin | (DivC - 1)) >u SMin) --> X >>s log2(DivC) + return BinaryOperator::CreateAShr( + X, ConstantInt::get(Add.getType(), DivC->exactLogBase2())); +} + Instruction *InstCombinerImpl:: canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( BinaryOperator &I) { @@ -1234,7 +1260,7 @@ Instruction *InstCombinerImpl:: } /// This is a specialization of a more general transform from -/// SimplifyUsingDistributiveLaws. If that code can be made to work optimally +/// foldUsingDistributiveLaws. If that code can be made to work optimally /// for multi-use cases or propagating nsw/nuw, then we would not need this. static Instruction *factorizeMathWithShlOps(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { @@ -1270,6 +1296,45 @@ static Instruction *factorizeMathWithShlOps(BinaryOperator &I, return NewShl; } +/// Reduce a sequence of masked half-width multiplies to a single multiply. +/// ((XLow * YHigh) + (YLow * XHigh)) << HalfBits) + (XLow * YLow) --> X * Y +static Instruction *foldBoxMultiply(BinaryOperator &I) { + unsigned BitWidth = I.getType()->getScalarSizeInBits(); + // Skip the odd bitwidth types. + if ((BitWidth & 0x1)) + return nullptr; + + unsigned HalfBits = BitWidth >> 1; + APInt HalfMask = APInt::getMaxValue(HalfBits); + + // ResLo = (CrossSum << HalfBits) + (YLo * XLo) + Value *XLo, *YLo; + Value *CrossSum; + if (!match(&I, m_c_Add(m_Shl(m_Value(CrossSum), m_SpecificInt(HalfBits)), + m_Mul(m_Value(YLo), m_Value(XLo))))) + return nullptr; + + // XLo = X & HalfMask + // YLo = Y & HalfMask + // TODO: Refactor with SimplifyDemandedBits or KnownBits known leading zeros + // to enhance robustness + Value *X, *Y; + if (!match(XLo, m_And(m_Value(X), m_SpecificInt(HalfMask))) || + !match(YLo, m_And(m_Value(Y), m_SpecificInt(HalfMask)))) + return nullptr; + + // CrossSum = (X' * (Y >> Halfbits)) + (Y' * (X >> HalfBits)) + // X' can be either X or XLo in the pattern (and the same for Y') + if (match(CrossSum, + m_c_Add(m_c_Mul(m_LShr(m_Specific(Y), m_SpecificInt(HalfBits)), + m_CombineOr(m_Specific(X), m_Specific(XLo))), + m_c_Mul(m_LShr(m_Specific(X), m_SpecificInt(HalfBits)), + m_CombineOr(m_Specific(Y), m_Specific(YLo)))))) + return BinaryOperator::CreateMul(X, Y); + + return nullptr; +} + Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (Value *V = simplifyAddInst(I.getOperand(0), I.getOperand(1), I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), @@ -1286,9 +1351,12 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return Phi; // (A*B)+(A*C) -> A*(B+C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); + if (Instruction *R = foldBoxMultiply(I)) + return R; + if (Instruction *R = factorizeMathWithShlOps(I, Builder)) return R; @@ -1376,35 +1444,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateAnd(A, NewMask); } + // ZExt (B - A) + ZExt(A) --> ZExt(B) + if ((match(RHS, m_ZExt(m_Value(A))) && + match(LHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A))))) || 
+ (match(LHS, m_ZExt(m_Value(A))) && + match(RHS, m_ZExt(m_NUWSub(m_Value(B), m_Specific(A)))))) + return new ZExtInst(B, LHS->getType()); + // A+B --> A|B iff A and B have no bits set in common. if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); - // add (select X 0 (sub n A)) A --> select X A n - { - SelectInst *SI = dyn_cast<SelectInst>(LHS); - Value *A = RHS; - if (!SI) { - SI = dyn_cast<SelectInst>(RHS); - A = LHS; - } - if (SI && SI->hasOneUse()) { - Value *TV = SI->getTrueValue(); - Value *FV = SI->getFalseValue(); - Value *N; - - // Can we fold the add into the argument of the select? - // We check both true and false select arguments for a matching subtract. - if (match(FV, m_Zero()) && match(TV, m_Sub(m_Value(N), m_Specific(A)))) - // Fold the add into the true select value. - return SelectInst::Create(SI->getCondition(), N, A); - - if (match(TV, m_Zero()) && match(FV, m_Sub(m_Value(N), m_Specific(A)))) - // Fold the add into the false select value. - return SelectInst::Create(SI->getCondition(), A, N); - } - } - if (Instruction *Ext = narrowMathIfNoOverflow(I)) return Ext; @@ -1424,6 +1474,68 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { return &I; } + // (add A (or A, -A)) --> (and (add A, -1) A) + // (add A (or -A, A)) --> (and (add A, -1) A) + // (add (or A, -A) A) --> (and (add A, -1) A) + // (add (or -A, A) A) --> (and (add A, -1) A) + if (match(&I, m_c_BinOp(m_Value(A), m_OneUse(m_c_Or(m_Neg(m_Deferred(A)), + m_Deferred(A)))))) { + Value *Add = + Builder.CreateAdd(A, Constant::getAllOnesValue(A->getType()), "", + I.hasNoUnsignedWrap(), I.hasNoSignedWrap()); + return BinaryOperator::CreateAnd(Add, A); + } + + // Canonicalize ((A & -A) - 1) --> ((A - 1) & ~A) + // Forms all commutable operations, and simplifies ctpop -> cttz folds. 
+ if (match(&I, + m_Add(m_OneUse(m_c_And(m_Value(A), m_OneUse(m_Neg(m_Deferred(A))))), + m_AllOnes()))) { + Constant *AllOnes = ConstantInt::getAllOnesValue(RHS->getType()); + Value *Dec = Builder.CreateAdd(A, AllOnes); + Value *Not = Builder.CreateXor(A, AllOnes); + return BinaryOperator::CreateAnd(Dec, Not); + } + + // Disguised reassociation/factorization: + // ~(A * C1) + A + // ((A * -C1) - 1) + A + // ((A * -C1) + A) - 1 + // (A * (1 - C1)) - 1 + if (match(&I, + m_c_Add(m_OneUse(m_Not(m_OneUse(m_Mul(m_Value(A), m_APInt(C1))))), + m_Deferred(A)))) { + Type *Ty = I.getType(); + Constant *NewMulC = ConstantInt::get(Ty, 1 - *C1); + Value *NewMul = Builder.CreateMul(A, NewMulC); + return BinaryOperator::CreateAdd(NewMul, ConstantInt::getAllOnesValue(Ty)); + } + + // (A * -2**C) + B --> B - (A << C) + const APInt *NegPow2C; + if (match(&I, m_c_Add(m_OneUse(m_Mul(m_Value(A), m_NegatedPower2(NegPow2C))), + m_Value(B)))) { + Constant *ShiftAmtC = ConstantInt::get(Ty, NegPow2C->countTrailingZeros()); + Value *Shl = Builder.CreateShl(A, ShiftAmtC); + return BinaryOperator::CreateSub(B, Shl); + } + + // Canonicalize signum variant that ends in add: + // (A s>> (BW - 1)) + (zext (A s> 0)) --> (A s>> (BW - 1)) | (zext (A != 0)) + ICmpInst::Predicate Pred; + uint64_t BitWidth = Ty->getScalarSizeInBits(); + if (match(LHS, m_AShr(m_Value(A), m_SpecificIntAllowUndef(BitWidth - 1))) && + match(RHS, m_OneUse(m_ZExt( + m_OneUse(m_ICmp(Pred, m_Specific(A), m_ZeroInt()))))) && + Pred == CmpInst::ICMP_SGT) { + Value *NotZero = Builder.CreateIsNotNull(A, "isnotnull"); + Value *Zext = Builder.CreateZExt(NotZero, Ty, "isnotnull.zext"); + return BinaryOperator::CreateOr(LHS, Zext); + } + + if (Instruction *Ashr = foldAddToAshr(I)) + return Ashr; + // TODO(jingyue): Consider willNotOverflowSignedAdd and // willNotOverflowUnsignedAdd to reduce the number of invocations of // computeKnownBits. @@ -1665,6 +1777,11 @@ Instruction *InstCombinerImpl::visitFAdd(BinaryOperator &I) { return BinaryOperator::CreateFMulFMF(X, NewMulC, &I); } + // (-X - Y) + (X + Z) --> Z - Y + if (match(&I, m_c_FAdd(m_FSub(m_FNeg(m_Value(X)), m_Value(Y)), + m_c_FAdd(m_Deferred(X), m_Value(Z))))) + return BinaryOperator::CreateFSubFMF(Z, Y, &I); + if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } @@ -1879,7 +1996,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return TryToNarrowDeduceFlags(); // Should have been handled in Negator! // (A*B)-(A*C) -> A*(B-C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (I.getType()->isIntOrIntVectorTy(1)) @@ -1967,12 +2084,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } const APInt *Op0C; - if (match(Op0, m_APInt(Op0C)) && Op0C->isMask()) { - // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known - // zero. - KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnes()) - return BinaryOperator::CreateXor(Op1, Op0); + if (match(Op0, m_APInt(Op0C))) { + if (Op0C->isMask()) { + // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known + // zero. 
+ KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); + if ((*Op0C | RHSKnown.Zero).isAllOnes()) + return BinaryOperator::CreateXor(Op1, Op0); + } + + // C - ((C3 -nuw X) & C2) --> (C - (C2 & C3)) + (X & C2) when: + // (C3 - ((C2 & C3) - 1)) is pow2 + // ((C2 + C3) & ((C2 & C3) - 1)) == ((C2 & C3) - 1) + // C2 is negative pow2 || sub nuw + const APInt *C2, *C3; + BinaryOperator *InnerSub; + if (match(Op1, m_OneUse(m_And(m_BinOp(InnerSub), m_APInt(C2)))) && + match(InnerSub, m_Sub(m_APInt(C3), m_Value(X))) && + (InnerSub->hasNoUnsignedWrap() || C2->isNegatedPowerOf2())) { + APInt C2AndC3 = *C2 & *C3; + APInt C2AndC3Minus1 = C2AndC3 - 1; + APInt C2AddC3 = *C2 + *C3; + if ((*C3 - C2AndC3Minus1).isPowerOf2() && + C2AndC3Minus1.isSubsetOf(C2AddC3)) { + Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), *C2)); + return BinaryOperator::CreateAdd( + And, ConstantInt::get(I.getType(), *Op0C - C2AndC3)); + } + } } { @@ -2165,8 +2304,9 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { Value *A; const APInt *ShAmt; Type *Ty = I.getType(); + unsigned BitWidth = Ty->getScalarSizeInBits(); if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && - Op1->hasNUses(2) && *ShAmt == Ty->getScalarSizeInBits() - 1 && + Op1->hasNUses(2) && *ShAmt == BitWidth - 1 && match(Op0, m_OneUse(m_c_Xor(m_Specific(A), m_Specific(Op1))))) { // B = ashr i32 A, 31 ; smear the sign bit // sub (xor A, B), B ; flip bits if negative and subtract -1 (add 1) @@ -2185,7 +2325,6 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { const APInt *AddC, *AndC; if (match(Op0, m_Add(m_Value(X), m_APInt(AddC))) && match(Op1, m_And(m_Specific(X), m_APInt(AndC)))) { - unsigned BitWidth = Ty->getScalarSizeInBits(); unsigned Cttz = AddC->countTrailingZeros(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); if ((HighMask & *AndC).isZero()) @@ -2227,18 +2366,34 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { } // C - ctpop(X) => ctpop(~X) if C is bitwidth - if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && + if (match(Op0, m_SpecificInt(BitWidth)) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::ctpop, {I.getType()}, {Builder.CreateNot(X)})); + // Reduce multiplies for difference-of-squares by factoring: + // (X * X) - (Y * Y) --> (X + Y) * (X - Y) + if (match(Op0, m_OneUse(m_Mul(m_Value(X), m_Deferred(X)))) && + match(Op1, m_OneUse(m_Mul(m_Value(Y), m_Deferred(Y))))) { + auto *OBO0 = cast<OverflowingBinaryOperator>(Op0); + auto *OBO1 = cast<OverflowingBinaryOperator>(Op1); + bool PropagateNSW = I.hasNoSignedWrap() && OBO0->hasNoSignedWrap() && + OBO1->hasNoSignedWrap() && BitWidth > 2; + bool PropagateNUW = I.hasNoUnsignedWrap() && OBO0->hasNoUnsignedWrap() && + OBO1->hasNoUnsignedWrap() && BitWidth > 1; + Value *Add = Builder.CreateAdd(X, Y, "add", PropagateNUW, PropagateNSW); + Value *Sub = Builder.CreateSub(X, Y, "sub", PropagateNUW, PropagateNSW); + Value *Mul = Builder.CreateMul(Add, Sub, "", PropagateNUW, PropagateNSW); + return replaceInstUsesWith(I, Mul); + } + return TryToNarrowDeduceFlags(); } /// This eliminates floating-point negation in either 'fneg(X)' or /// 'fsub(-0.0, X)' form by combining into a constant operand. 
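A small standalone sketch (not from the patch) of why folding an fneg into a constant operand is exact: under the default IEEE-754 round-to-nearest mode, rounding is symmetric in sign, so negating a product equals multiplying by the negated constant.

```cpp
#include <cassert>

// Illustrative check, not part of the patch: with default IEEE-754 rounding,
// -(x * c) == x * (-c) for finite and infinite x (NaN payloads aside), so the
// negation can be absorbed into the constant multiplicand.
int main() {
  const double C = 0.3;
  for (double X : {1.5, -2.25, 1e308, -0.0, 7.0 / 3.0})
    assert(-(X * C) == X * (-C));
  return 0;
}
```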
-static Instruction *foldFNegIntoConstant(Instruction &I) { +static Instruction *foldFNegIntoConstant(Instruction &I, const DataLayout &DL) { // This is limited with one-use because fneg is assumed better for // reassociation and cheaper in codegen than fmul/fdiv. // TODO: Should the m_OneUse restriction be removed? @@ -2252,28 +2407,31 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { // Fold negation into constant operand. // -(X * C) --> X * (-C) if (match(FNegOp, m_FMul(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFMulFMF(X, NegC, &I); // -(X / C) --> X / (-C) if (match(FNegOp, m_FDiv(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(X, NegC, &I); // -(C / X) --> (-C) / X - if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) { - Instruction *FDiv = - BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (match(FNegOp, m_FDiv(m_Constant(C), m_Value(X)))) + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) { + Instruction *FDiv = BinaryOperator::CreateFDivFMF(NegC, X, &I); - // Intersect 'nsz' and 'ninf' because those special value exceptions may not - // apply to the fdiv. Everything else propagates from the fneg. - // TODO: We could propagate nsz/ninf from fdiv alone? - FastMathFlags FMF = I.getFastMathFlags(); - FastMathFlags OpFMF = FNegOp->getFastMathFlags(); - FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); - FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); - return FDiv; - } + // Intersect 'nsz' and 'ninf' because those special value exceptions may + // not apply to the fdiv. Everything else propagates from the fneg. + // TODO: We could propagate nsz/ninf from fdiv alone? + FastMathFlags FMF = I.getFastMathFlags(); + FastMathFlags OpFMF = FNegOp->getFastMathFlags(); + FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); + FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); + return FDiv; + } // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: // -(X + C) --> -X + -C --> -C - X if (I.hasNoSignedZeros() && match(FNegOp, m_FAdd(m_Value(X), m_Constant(C)))) - return BinaryOperator::CreateFSubFMF(ConstantExpr::getFNeg(C), X, &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFSubFMF(NegC, X, &I); return nullptr; } @@ -2301,7 +2459,7 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { getSimplifyQuery().getWithInstruction(&I))) return replaceInstUsesWith(I, V); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; Value *X, *Y; @@ -2314,18 +2472,26 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) return R; + Value *OneUse; + if (!match(Op, m_OneUse(m_Value(OneUse)))) + return nullptr; + // Try to eliminate fneg if at least 1 arm of the select is negated. Value *Cond; - if (match(Op, m_OneUse(m_Select(m_Value(Cond), m_Value(X), m_Value(Y))))) { + if (match(OneUse, m_Select(m_Value(Cond), m_Value(X), m_Value(Y)))) { // Unlike most transforms, this one is not safe to propagate nsz unless - // it is present on the original select. 
(We are conservatively intersecting - // the nsz flags from the select and root fneg instruction.) + // it is present on the original select. We union the flags from the select + // and fneg and then remove nsz if needed. auto propagateSelectFMF = [&](SelectInst *S, bool CommonOperand) { S->copyFastMathFlags(&I); - if (auto *OldSel = dyn_cast<SelectInst>(Op)) + if (auto *OldSel = dyn_cast<SelectInst>(Op)) { + FastMathFlags FMF = I.getFastMathFlags(); + FMF |= OldSel->getFastMathFlags(); + S->setFastMathFlags(FMF); if (!OldSel->hasNoSignedZeros() && !CommonOperand && !isGuaranteedNotToBeUndefOrPoison(OldSel->getCondition())) S->setHasNoSignedZeros(false); + } }; // -(Cond ? -P : Y) --> Cond ? P : -Y Value *P; @@ -2344,6 +2510,21 @@ Instruction *InstCombinerImpl::visitFNeg(UnaryOperator &I) { } } + // fneg (copysign x, y) -> copysign x, (fneg y) + if (match(OneUse, m_CopySign(m_Value(X), m_Value(Y)))) { + // The source copysign has an additional value input, so we can't propagate + // flags the copysign doesn't also have. + FastMathFlags FMF = I.getFastMathFlags(); + FMF &= cast<FPMathOperator>(OneUse)->getFastMathFlags(); + + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(FMF); + + Value *NegY = Builder.CreateFNeg(Y); + Value *NewCopySign = Builder.CreateCopySign(X, NegY); + return replaceInstUsesWith(I, NewCopySign); + } + return nullptr; } @@ -2370,7 +2551,7 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { if (match(&I, m_FNeg(m_Value(Op)))) return UnaryOperator::CreateFNegFMF(Op, &I); - if (Instruction *X = foldFNegIntoConstant(I)) + if (Instruction *X = foldFNegIntoConstant(I, DL)) return X; if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder)) @@ -2409,7 +2590,8 @@ Instruction *InstCombinerImpl::visitFSub(BinaryOperator &I) { // But don't transform constant expressions because there's an inverse fold // for X + (-Y) --> X - Y. if (match(Op1, m_ImmConstant(C))) - return BinaryOperator::CreateFAddFMF(Op0, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFAddFMF(Op0, NegC, &I); // X - (-Y) --> X + Y if (match(Op1, m_FNeg(m_Value(Y)))) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 8253c575bc37..97a001b2ed32 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -233,17 +233,13 @@ static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pre /// the right hand side as a pair. /// LHS and RHS are the left hand side and the right hand side ICmps and PredL /// and PredR are their predicates, respectively. -static -Optional<std::pair<unsigned, unsigned>> -getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, - Value *&D, Value *&E, ICmpInst *LHS, - ICmpInst *RHS, - ICmpInst::Predicate &PredL, - ICmpInst::Predicate &PredR) { +static std::optional<std::pair<unsigned, unsigned>> getMaskedTypeForICmpPair( + Value *&A, Value *&B, Value *&C, Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { // Don't allow pointers. Splat vectors are fine. 
if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) - return None; + return std::nullopt; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -274,7 +270,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if LHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredL)) - return None; + return std::nullopt; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); @@ -288,7 +284,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, A = R12; D = R11; } else { - return None; + return std::nullopt; } E = R2; R1 = nullptr; @@ -316,7 +312,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, // Bail if RHS was a icmp that can't be decomposed into an equality. if (!ICmpInst::isEquality(PredR)) - return None; + return std::nullopt; // Look for ANDs on the right side of the RHS icmp. if (!Ok) { @@ -336,7 +332,7 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R1; Ok = true; } else { - return None; + return std::nullopt; } assert(Ok && "Failed to find AND on the right side of the RHS icmp."); @@ -358,7 +354,8 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, unsigned LeftType = getMaskedICmpType(A, B, C, PredL); unsigned RightType = getMaskedICmpType(A, D, E, PredR); - return Optional<std::pair<unsigned, unsigned>>(std::make_pair(LeftType, RightType)); + return std::optional<std::pair<unsigned, unsigned>>( + std::make_pair(LeftType, RightType)); } /// Try to fold (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) into a single @@ -526,7 +523,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, InstCombiner::BuilderTy &Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - Optional<std::pair<unsigned, unsigned>> MaskPair = + std::optional<std::pair<unsigned, unsigned>> MaskPair = getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); if (!MaskPair) return nullptr; @@ -1016,10 +1013,10 @@ struct IntPart { }; /// Match an extraction of bits from an integer. -static Optional<IntPart> matchIntPart(Value *V) { +static std::optional<IntPart> matchIntPart(Value *V) { Value *X; if (!match(V, m_OneUse(m_Trunc(m_Value(X))))) - return None; + return std::nullopt; unsigned NumOriginalBits = X->getType()->getScalarSizeInBits(); unsigned NumExtractedBits = V->getType()->getScalarSizeInBits(); @@ -1056,10 +1053,10 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, if (Cmp0->getPredicate() != Pred || Cmp1->getPredicate() != Pred) return nullptr; - Optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0)); - Optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1)); - Optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0)); - Optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1)); + std::optional<IntPart> L0 = matchIntPart(Cmp0->getOperand(0)); + std::optional<IntPart> R0 = matchIntPart(Cmp0->getOperand(1)); + std::optional<IntPart> L1 = matchIntPart(Cmp1->getOperand(0)); + std::optional<IntPart> R1 = matchIntPart(Cmp1->getOperand(1)); if (!L0 || !R0 || !L1 || !R1) return nullptr; @@ -1094,7 +1091,7 @@ Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, /// common operand with the constant. Callers are expected to call this with /// Cmp0/Cmp1 switched to handle logic op commutativity. 
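A standalone sketch (not from the patch) of the constant-substitution idea in the doc comment above: when one compare pins a value to a constant via equality, that constant may be substituted for the common operand in the other compare. The values below are arbitrary illustrations.

```cpp
#include <cassert>

// Illustrative check, not part of the patch: in an 'and' of compares where one
// side is an equality with a constant, the other compare only matters when the
// equality holds, so the constant can replace the common operand:
//   (x == 42 && y > x)  ==  (x == 42 && y > 42)
int main() {
  for (int X : {0, 42, 100})
    for (int Y : {-5, 42, 43, 1000})
      assert(((X == 42) && (Y > X)) == ((X == 42) && (Y > 42)));
  return 0;
}
```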
static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, - bool IsAnd, + bool IsAnd, bool IsLogical, InstCombiner::BuilderTy &Builder, const SimplifyQuery &Q) { // Match an equality compare with a non-poison constant as Cmp0. @@ -1130,6 +1127,9 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, return nullptr; SubstituteCmp = Builder.CreateICmp(Pred1, Y, C); } + if (IsLogical) + return IsAnd ? Builder.CreateLogicalAnd(Cmp0, SubstituteCmp) + : Builder.CreateLogicalOr(Cmp0, SubstituteCmp); return Builder.CreateBinOp(IsAnd ? Instruction::And : Instruction::Or, Cmp0, SubstituteCmp); } @@ -1174,7 +1174,7 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, Type *Ty = V1->getType(); Value *NewV = V1; - Optional<ConstantRange> CR = CR1.exactUnionWith(CR2); + std::optional<ConstantRange> CR = CR1.exactUnionWith(CR2); if (!CR) { if (!(ICmp1->hasOneUse() && ICmp2->hasOneUse()) || CR1.isWrappedSet() || CR2.isWrappedSet()) @@ -1205,6 +1205,47 @@ Value *InstCombinerImpl::foldAndOrOfICmpsUsingRanges(ICmpInst *ICmp1, return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); } +/// Ignore all operations which only change the sign of a value, returning the +/// underlying magnitude value. +static Value *stripSignOnlyFPOps(Value *Val) { + match(Val, m_FNeg(m_Value(Val))); + match(Val, m_FAbs(m_Value(Val))); + match(Val, m_CopySign(m_Value(Val), m_Value())); + return Val; +} + +/// Matches canonical form of isnan, fcmp ord x, 0 +static bool matchIsNotNaN(FCmpInst::Predicate P, Value *LHS, Value *RHS) { + return P == FCmpInst::FCMP_ORD && match(RHS, m_AnyZeroFP()); +} + +/// Matches fcmp u__ x, +/-inf +static bool matchUnorderedInfCompare(FCmpInst::Predicate P, Value *LHS, + Value *RHS) { + return FCmpInst::isUnordered(P) && match(RHS, m_Inf()); +} + +/// and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf +/// +/// Clang emits this pattern for doing an isfinite check in __builtin_isnormal. 
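A standalone sketch (not from the patch) of the isfinite pattern named above, written at the C level for the "ult" flavour: a not-NaN test combined with an unordered compare against infinity collapses to a single ordered compare, because any comparison involving NaN is already false.

```cpp
#include <cassert>
#include <cmath>
#include <limits>

// Illustrative check, not part of the patch:
//   !isnan(x) && !(x >= +inf)   ==   x < +inf
// which mirrors "and (fcmp ord x, 0), (fcmp ult x, inf) -> fcmp olt x, inf".
int main() {
  const double Inf = std::numeric_limits<double>::infinity();
  const double NaN = std::numeric_limits<double>::quiet_NaN();
  for (double X : {0.0, -1.5, 1e308, Inf, -Inf, NaN})
    assert((!std::isnan(X) && !(X >= Inf)) == (X < Inf));
  return 0;
}
```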
+static Value *matchIsFiniteTest(InstCombiner::BuilderTy &Builder, FCmpInst *LHS, + FCmpInst *RHS) { + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + + if (!matchIsNotNaN(PredL, LHS0, LHS1) || + !matchUnorderedInfCompare(PredR, RHS0, RHS1)) + return nullptr; + + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + FastMathFlags FMF = LHS->getFastMathFlags(); + FMF &= RHS->getFastMathFlags(); + Builder.setFastMathFlags(FMF); + + return Builder.CreateFCmp(FCmpInst::getOrderedPredicate(PredR), RHS0, RHS1); +} + Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, bool IsLogicalSelect) { Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); @@ -1263,9 +1304,79 @@ Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, return Builder.CreateFCmp(PredL, LHS0, RHS0); } + if (IsAnd && stripSignOnlyFPOps(LHS0) == stripSignOnlyFPOps(RHS0)) { + // and (fcmp ord x, 0), (fcmp u* x, inf) -> fcmp o* x, inf + // and (fcmp ord x, 0), (fcmp u* fabs(x), inf) -> fcmp o* x, inf + if (Value *Left = matchIsFiniteTest(Builder, LHS, RHS)) + return Left; + if (Value *Right = matchIsFiniteTest(Builder, RHS, LHS)) + return Right; + } + return nullptr; } +/// or (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 | mask1) +/// and (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 & mask1) +/// xor (is_fpclass x, mask0), (is_fpclass x, mask1) +/// -> is_fpclass x, (mask0 ^ mask1) +Instruction *InstCombinerImpl::foldLogicOfIsFPClass(BinaryOperator &BO, + Value *Op0, Value *Op1) { + Value *ClassVal; + uint64_t ClassMask0, ClassMask1; + + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Value(ClassVal), m_ConstantInt(ClassMask0)))) && + match(Op1, m_OneUse(m_Intrinsic<Intrinsic::is_fpclass>( + m_Specific(ClassVal), m_ConstantInt(ClassMask1))))) { + unsigned NewClassMask; + switch (BO.getOpcode()) { + case Instruction::And: + NewClassMask = ClassMask0 & ClassMask1; + break; + case Instruction::Or: + NewClassMask = ClassMask0 | ClassMask1; + break; + case Instruction::Xor: + NewClassMask = ClassMask0 ^ ClassMask1; + break; + default: + llvm_unreachable("not a binary logic operator"); + } + + // TODO: Also check for special fcmps + auto *II = cast<IntrinsicInst>(Op0); + II->setArgOperand( + 1, ConstantInt::get(II->getArgOperand(1)->getType(), NewClassMask)); + return replaceInstUsesWith(BO, II); + } + + return nullptr; +} + +/// Look for the pattern that conditionally negates a value via math operations: +/// cond.splat = sext i1 cond +/// sub = add cond.splat, x +/// xor = xor sub, cond.splat +/// and rewrite it to do the same, but via logical operations: +/// value.neg = sub 0, value +/// cond = select i1 neg, value.neg, value +Instruction *InstCombinerImpl::canonicalizeConditionalNegationViaMathToSelect( + BinaryOperator &I) { + assert(I.getOpcode() == BinaryOperator::Xor && "Only for xor!"); + Value *Cond, *X; + // As per complexity ordering, `xor` is not commutative here. 
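A standalone sketch (not from the patch) of the branch-free conditional negation that the new canonicalization recognizes, assuming 32-bit two's-complement wraparound: with the splatted condition m = cond ? -1 : 0, the add/xor sequence computes exactly cond ? -x : x.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative check, not part of the patch: (x + m) ^ m is x when m == 0 and
// ~(x - 1) == -x (two's complement) when m == -1, i.e. cond ? -x : x.
int main() {
  for (bool Cond : {false, true})
    for (int32_t X : {0, 1, -7, 12345, INT32_MIN + 1}) {
      int32_t M = Cond ? -1 : 0;
      int32_t ViaMath =
          (int32_t)(((uint32_t)X + (uint32_t)M) ^ (uint32_t)M); // math form
      int32_t ViaSelect = Cond ? -X : X;                        // select form
      assert(ViaMath == ViaSelect);
    }
  return 0;
}
```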
+ if (!match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())) || + !match(I.getOperand(1), m_SExt(m_Value(Cond))) || + !Cond->getType()->isIntOrIntVectorTy(1) || + !match(I.getOperand(0), m_c_Add(m_SExt(m_Deferred(Cond)), m_Value(X)))) + return nullptr; + return SelectInst::Create(Cond, Builder.CreateNeg(X, X->getName() + ".neg"), + X); +} + /// This a limited reassociation for a special case (see above) where we are /// checking if two values are either both NAN (unordered) or not-NAN (ordered). /// This could be handled more generally in '-reassociation', but it seems like @@ -1430,11 +1541,33 @@ Instruction *InstCombinerImpl::foldCastedBitwiseLogic(BinaryOperator &I) { if (!Cast1) return nullptr; - // Both operands of the logic operation are casts. The casts must be of the - // same type for reduction. - auto CastOpcode = Cast0->getOpcode(); - if (CastOpcode != Cast1->getOpcode() || SrcTy != Cast1->getSrcTy()) + // Both operands of the logic operation are casts. The casts must be the + // same kind for reduction. + Instruction::CastOps CastOpcode = Cast0->getOpcode(); + if (CastOpcode != Cast1->getOpcode()) + return nullptr; + + // If the source types do not match, but the casts are matching extends, we + // can still narrow the logic op. + if (SrcTy != Cast1->getSrcTy()) { + Value *X, *Y; + if (match(Cast0, m_OneUse(m_ZExtOrSExt(m_Value(X)))) && + match(Cast1, m_OneUse(m_ZExtOrSExt(m_Value(Y))))) { + // Cast the narrower source to the wider source type. + unsigned XNumBits = X->getType()->getScalarSizeInBits(); + unsigned YNumBits = Y->getType()->getScalarSizeInBits(); + if (XNumBits < YNumBits) + X = Builder.CreateCast(CastOpcode, X, Y->getType()); + else + Y = Builder.CreateCast(CastOpcode, Y, X->getType()); + // Do the logic op in the intermediate width, then widen more. + Value *NarrowLogic = Builder.CreateBinOp(LogicOpc, X, Y); + return CastInst::Create(CastOpcode, NarrowLogic, DestTy); + } + + // Give up for other cast opcodes. return nullptr; + } Value *Cast0Src = Cast0->getOperand(0); Value *Cast1Src = Cast1->getOperand(0); @@ -1722,6 +1855,77 @@ static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, return nullptr; } +/// Try to reassociate a pair of binops so that values with one use only are +/// part of the same instruction. This may enable folds that are limited with +/// multi-use restrictions and makes it more likely to match other patterns that +/// are looking for a common operand. 
+static Instruction *reassociateForUses(BinaryOperator &BO, + InstCombinerImpl::BuilderTy &Builder) { + Instruction::BinaryOps Opcode = BO.getOpcode(); + Value *X, *Y, *Z; + if (match(&BO, + m_c_BinOp(Opcode, m_OneUse(m_BinOp(Opcode, m_Value(X), m_Value(Y))), + m_OneUse(m_Value(Z))))) { + if (!isa<Constant>(X) && !isa<Constant>(Y) && !isa<Constant>(Z)) { + // (X op Y) op Z --> (Y op Z) op X + if (!X->hasOneUse()) { + Value *YZ = Builder.CreateBinOp(Opcode, Y, Z); + return BinaryOperator::Create(Opcode, YZ, X); + } + // (X op Y) op Z --> (X op Z) op Y + if (!Y->hasOneUse()) { + Value *XZ = Builder.CreateBinOp(Opcode, X, Z); + return BinaryOperator::Create(Opcode, XZ, Y); + } + } + } + + return nullptr; +} + +// Match +// (X + C2) | C +// (X + C2) ^ C +// (X + C2) & C +// and convert to do the bitwise logic first: +// (X | C) + C2 +// (X ^ C) + C2 +// (X & C) + C2 +// iff bits affected by logic op are lower than last bit affected by math op +static Instruction *canonicalizeLogicFirst(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Type *Ty = I.getType(); + Instruction::BinaryOps OpC = I.getOpcode(); + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + Value *X; + const APInt *C, *C2; + + if (!(match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C2)))) && + match(Op1, m_APInt(C)))) + return nullptr; + + unsigned Width = Ty->getScalarSizeInBits(); + unsigned LastOneMath = Width - C2->countTrailingZeros(); + + switch (OpC) { + case Instruction::And: + if (C->countLeadingOnes() < LastOneMath) + return nullptr; + break; + case Instruction::Xor: + case Instruction::Or: + if (C->countLeadingZeros() < LastOneMath) + return nullptr; + break; + default: + llvm_unreachable("Unexpected BinaryOp!"); + } + + Value *NewBinOp = Builder.CreateBinOp(OpC, X, ConstantInt::get(Ty, *C)); + return BinaryOperator::CreateAdd(NewBinOp, ConstantInt::get(Ty, *C2)); +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -1754,7 +1958,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return X; // (A|B)&(A|C) -> A|(B&C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyBSwap(I, Builder)) @@ -2156,24 +2360,36 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); - // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 - unsigned FullShift = Ty->getScalarSizeInBits() - 1; - if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), - m_Value(Y)))) { + // Similarly, a 'not' of the bool translates to a swap of the select arms: + // ~sext(A) & Op1 --> A ? 0 : Op1 + // Op0 & ~sext(A) --> A ? 0 : Op0 + if (match(Op0, m_Not(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Constant::getNullValue(Ty), Op1); + if (match(Op1, m_Not(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Constant::getNullValue(Ty), Op0); + + // (iN X s>> (N-1)) & Y --> (X s< 0) ? 
Y : 0 -- with optional sext + if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( + m_AShr(m_Value(X), m_APIntAllowUndef(C)))), + m_Value(Y))) && + *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); return SelectInst::Create(IsNeg, Y, ConstantInt::getNullValue(Ty)); } // If there's a 'not' of the shifted value, swap the select operands: - // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y - if (match(&I, m_c_And(m_OneUse(m_Not( - m_AShr(m_Value(X), m_SpecificInt(FullShift)))), - m_Value(Y)))) { + // ~(iN X s>> (N-1)) & Y --> (X s< 0) ? 0 : Y -- with optional sext + if (match(&I, m_c_And(m_OneUse(m_SExtOrSelf( + m_Not(m_AShr(m_Value(X), m_APIntAllowUndef(C))))), + m_Value(Y))) && + *C == X->getType()->getScalarSizeInBits() - 1) { Value *IsNeg = Builder.CreateIsNeg(X, "isneg"); return SelectInst::Create(IsNeg, ConstantInt::getNullValue(Ty), Y); } // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions - if (sinkNotIntoOtherHandOfAndOrOr(I)) + if (sinkNotIntoOtherHandOfLogicalOp(I)) return &I; // An and recurrence w/loop invariant step is equivelent to (and start, step) @@ -2182,6 +2398,15 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (matchSimpleRecurrence(&I, PN, Start, Step) && DT.dominates(Step, PN)) return replaceInstUsesWith(I, Builder.CreateAnd(Start, Step)); + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + return nullptr; } @@ -2375,7 +2600,9 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// We have an expression of the form (A & C) | (B & D). If A is a scalar or /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. -Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { +/// We will detect (A & C) | ~(B | D) when the flag ABIsTheSame enabled. +Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B, + bool ABIsTheSame) { // We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); @@ -2383,7 +2610,7 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { return nullptr; // If A is the 'not' operand of B and has enough signbits, we have our answer. - if (match(B, m_Not(m_Specific(A)))) { + if (ABIsTheSame ? (A == B) : match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. if (Ty->isIntOrIntVectorTy(1)) return A; @@ -2403,6 +2630,10 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { return nullptr; } + // TODO: add support for sext and constant case + if (ABIsTheSame) + return nullptr; + // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) @@ -2451,14 +2682,17 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { /// We have an expression of the form (A & C) | (B & D). Try to simplify this /// to "A' ? C : D", where A' is a boolean or vector of booleans. +/// When InvertFalseVal is set to true, we try to match the pattern +/// where we have peeked through a 'not' op and A and B are the same: +/// (A & C) | ~(A | D) --> (A & C) | (~A & ~D) --> A' ? 
C : ~D Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, - Value *D) { + Value *D, bool InvertFalseVal) { // The potential condition of the select may be bitcasted. In that case, look // through its bitcast and the corresponding bitcast of the 'not' condition. Type *OrigType = A->getType(); A = peekThroughBitcast(A, true); B = peekThroughBitcast(B, true); - if (Value *Cond = getSelectCondition(A, B)) { + if (Value *Cond = getSelectCondition(A, B, InvertFalseVal)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will @@ -2469,11 +2703,13 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, unsigned Elts = VecTy->getElementCount().getKnownMinValue(); // For a fixed or scalable vector, get the size in bits of N x iM; for a // scalar this is just M. - unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinSize(); + unsigned SelEltSize = SelTy->getPrimitiveSizeInBits().getKnownMinValue(); Type *EltTy = Builder.getIntNTy(SelEltSize / Elts); SelTy = VectorType::get(EltTy, VecTy->getElementCount()); } Value *BitcastC = Builder.CreateBitCast(C, SelTy); + if (InvertFalseVal) + D = Builder.CreateNot(D); Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); @@ -2484,8 +2720,9 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, // (icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1) // (icmp ne X, 0) & (icmp uge Other, X) -> (icmp ugt Other, X-1) -Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, - IRBuilderBase &Builder) { +static Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, + bool IsAnd, bool IsLogical, + IRBuilderBase &Builder) { ICmpInst::Predicate LPred = IsAnd ? LHS->getInversePredicate() : LHS->getPredicate(); ICmpInst::Predicate RPred = @@ -2504,6 +2741,8 @@ Value *foldAndOrOfICmpEqZeroAndICmp(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else return nullptr; + if (IsLogical) + Other = Builder.CreateFreeze(Other); return Builder.CreateICmp( IsAnd ? ICmpInst::ICMP_ULT : ICmpInst::ICMP_UGE, Builder.CreateAdd(LHS0, Constant::getAllOnesValue(LHS0->getType())), @@ -2552,22 +2791,23 @@ Value *InstCombinerImpl::foldAndOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, IsAnd, IsLogical, Builder)) return V; - // TODO: One of these directions is fine with logical and/or, the other could - // be supported by inserting freeze. - if (!IsLogical) { - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, Builder)) - return V; - if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, Builder)) - return V; - } + if (Value *V = + foldAndOrOfICmpEqZeroAndICmp(LHS, RHS, IsAnd, IsLogical, Builder)) + return V; + // We can treat logical like bitwise here, because both operands are used on + // the LHS, and as such poison from both will propagate. + if (Value *V = foldAndOrOfICmpEqZeroAndICmp(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder)) + return V; - // TODO: Verify whether this is safe for logical and/or. 
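A standalone sketch (not from the patch) of the "(icmp eq X, 0) | (icmp ult Other, X) -> (icmp ule Other, X-1)" fold handled a little earlier, assuming unsigned 32-bit wraparound: if X is zero, X-1 wraps to the maximum value and the "<=" is vacuously true; otherwise the strict and inclusive compares coincide.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative check, not part of the patch: the eq-zero fold relies on
// unsigned wraparound of X - 1.
int main() {
  for (uint32_t X : {0u, 1u, 5u, 0xFFFFFFFFu})
    for (uint32_t Other : {0u, 4u, 5u, 0xFFFFFFFEu, 0xFFFFFFFFu})
      assert(((X == 0) || (Other < X)) == (Other <= X - 1));
  return 0;
}
```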
- if (!IsLogical) { - if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, Builder, Q)) - return V; - if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, Builder, Q)) - return V; - } + if (Value *V = + foldAndOrOfICmpsWithConstEq(LHS, RHS, IsAnd, IsLogical, Builder, Q)) + return V; + // We can convert this case to bitwise and, because both operands are used + // on the LHS, and as such poison from both will propagate. + if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, IsAnd, + /*IsLogical*/ false, Builder, Q)) + return V; if (Value *V = foldIsPowerOf2OrZero(LHS, RHS, IsAnd, Builder)) return V; @@ -2724,7 +2964,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return X; // (A&B)|(A&C) -> A&(B|C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); if (Value *V = SimplifyBSwap(I, Builder)) @@ -2777,6 +3017,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { return BinaryOperator::CreateMul(X, IncrementY); } + // X | (X ^ Y) --> X | Y (4 commuted patterns) + if (match(&I, m_c_Or(m_Value(X), m_c_Xor(m_Deferred(X), m_Value(Y))))) + return BinaryOperator::CreateOr(X, Y); + // (A & C) | (B & D) Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && @@ -2854,6 +3098,20 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { } } + if (match(Op0, m_And(m_Value(A), m_Value(C))) && + match(Op1, m_Not(m_Or(m_Value(B), m_Value(D)))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // (Cond & C) | ~(Cond | D) -> Cond ? C : ~D + if (Value *V = matchSelectFromAndOr(A, C, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(A, C, D, B, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, B, D, true)) + return replaceInstUsesWith(I, V); + if (Value *V = matchSelectFromAndOr(C, A, D, B, true)) + return replaceInstUsesWith(I, V); + } + // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A)))) @@ -2886,30 +3144,58 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { SwappedForXor = true; } - // A | ( A ^ B) -> A | B - // A | (~A ^ B) -> A | ~B - // (A & B) | (A ^ B) - // ~A | (A ^ B) -> ~(A & B) - // The swap above should always make Op0 the 'not' for the last case. if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) { - if (Op0 == A || Op0 == B) - return BinaryOperator::CreateOr(A, B); + // (A | ?) | (A ^ B) --> (A | ?) | B + // (B | ?) | (A ^ B) --> (B | ?) | A + if (match(Op0, m_c_Or(m_Specific(A), m_Value()))) + return BinaryOperator::CreateOr(Op0, B); + if (match(Op0, m_c_Or(m_Specific(B), m_Value()))) + return BinaryOperator::CreateOr(Op0, A); + // (A & B) | (A ^ B) --> A | B + // (B & A) | (A ^ B) --> A | B if (match(Op0, m_And(m_Specific(A), m_Specific(B))) || match(Op0, m_And(m_Specific(B), m_Specific(A)))) return BinaryOperator::CreateOr(A, B); + // ~A | (A ^ B) --> ~(A & B) + // ~B | (A ^ B) --> ~(A & B) + // The swap above should always make Op0 the 'not'. if ((Op0->hasOneUse() || Op1->hasOneUse()) && (match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B))))) return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + // Same as above, but peek through an 'and' to the common operand: + // ~(A & ?) | (A ^ B) --> ~((A & ?) & B) + // ~(B & ?) | (A ^ B) --> ~((B & ?) 
& A) + Instruction *And; + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + match(Op0, m_Not(m_CombineAnd(m_Instruction(And), + m_c_And(m_Specific(A), m_Value()))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(And, B)); + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + match(Op0, m_Not(m_CombineAnd(m_Instruction(And), + m_c_And(m_Specific(B), m_Value()))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(And, A)); + + // (~A | C) | (A ^ B) --> ~(A & B) | C + // (~B | C) | (A ^ B) --> ~(A & B) | C + if (Op0->hasOneUse() && Op1->hasOneUse() && + (match(Op0, m_c_Or(m_Not(m_Specific(A)), m_Value(C))) || + match(Op0, m_c_Or(m_Not(m_Specific(B)), m_Value(C))))) { + Value *Nand = Builder.CreateNot(Builder.CreateAnd(A, B), "nand"); + return BinaryOperator::CreateOr(Nand, C); + } + + // A | (~A ^ B) --> ~B | A + // B | (A ^ ~B) --> ~A | B if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { - Value *Not = Builder.CreateNot(B, B->getName() + ".not"); - return BinaryOperator::CreateOr(Not, Op0); + Value *NotB = Builder.CreateNot(B, B->getName() + ".not"); + return BinaryOperator::CreateOr(NotB, Op0); } if (Op1->hasOneUse() && match(B, m_Not(m_Specific(Op0)))) { - Value *Not = Builder.CreateNot(A, A->getName() + ".not"); - return BinaryOperator::CreateOr(Not, Op0); + Value *NotA = Builder.CreateNot(A, A->getName() + ".not"); + return BinaryOperator::CreateOr(NotA, Op0); } } @@ -3072,7 +3358,7 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { } // (~x) | y --> ~(x & (~y)) iff that gets rid of inversions - if (sinkNotIntoOtherHandOfAndOrOr(I)) + if (sinkNotIntoOtherHandOfLogicalOp(I)) return &I; // Improve "get low bit mask up to and including bit X" pattern: @@ -3121,6 +3407,15 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Builder.CreateOr(C, Builder.CreateAnd(A, B)), D); } + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + return nullptr; } @@ -3338,14 +3633,8 @@ static Instruction *visitMaskedMerge(BinaryOperator &I, // (~x) ^ y // or into // x ^ (~y) -static Instruction *sinkNotIntoXor(BinaryOperator &I, +static Instruction *sinkNotIntoXor(BinaryOperator &I, Value *X, Value *Y, InstCombiner::BuilderTy &Builder) { - Value *X, *Y; - // FIXME: one-use check is not needed in general, but currently we are unable - // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. (D35182) - if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) - return nullptr; - // We only want to do the transform if it is free to do. if (InstCombiner::isFreeToInvert(X, X->hasOneUse())) { // Ok, good. @@ -3358,6 +3647,41 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I, return BinaryOperator::CreateXor(NotX, Y, I.getName() + ".demorgan"); } +static Instruction *foldNotXor(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *X, *Y; + // FIXME: one-use check is not needed in general, but currently we are unable + // to fold 'not' into 'icmp', if that 'icmp' has multiple uses. 
(D35182) + if (!match(&I, m_Not(m_OneUse(m_Xor(m_Value(X), m_Value(Y)))))) + return nullptr; + + if (Instruction *NewXor = sinkNotIntoXor(I, X, Y, Builder)) + return NewXor; + + auto hasCommonOperand = [](Value *A, Value *B, Value *C, Value *D) { + return A == C || A == D || B == C || B == D; + }; + + Value *A, *B, *C, *D; + // Canonicalize ~((A & B) ^ (A | ?)) -> (A & B) | ~(A | ?) + // 4 commuted variants + if (match(X, m_And(m_Value(A), m_Value(B))) && + match(Y, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) { + Value *NotY = Builder.CreateNot(Y); + return BinaryOperator::CreateOr(X, NotY); + }; + + // Canonicalize ~((A | ?) ^ (A & B)) -> (A & B) | ~(A | ?) + // 4 commuted variants + if (match(Y, m_And(m_Value(A), m_Value(B))) && + match(X, m_Or(m_Value(C), m_Value(D))) && hasCommonOperand(A, B, C, D)) { + Value *NotX = Builder.CreateNot(X); + return BinaryOperator::CreateOr(Y, NotX); + }; + + return nullptr; +} + /// Canonicalize a shifty way to code absolute value to the more common pattern /// that uses negation and select. static Instruction *canonicalizeAbs(BinaryOperator &Xor, @@ -3392,39 +3716,127 @@ static Instruction *canonicalizeAbs(BinaryOperator &Xor, } // Transform +// z = ~(x &/| y) +// into: +// z = ((~x) |/& (~y)) +// iff both x and y are free to invert and all uses of z can be freely updated. +bool InstCombinerImpl::sinkNotIntoLogicalOp(Instruction &I) { + Value *Op0, *Op1; + if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1)))) + return false; + + // If this logic op has not been simplified yet, just bail out and let that + // happen first. Otherwise, the code below may wrongly invert. + if (Op0 == Op1) + return false; + + Instruction::BinaryOps NewOpc = + match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And; + bool IsBinaryOp = isa<BinaryOperator>(I); + + // Can our users be adapted? + if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) + return false; + + // And can the operands be adapted? + for (Value *Op : {Op0, Op1}) + if (!(InstCombiner::isFreeToInvert(Op, /*WillInvertAllUses=*/true) && + (match(Op, m_ImmConstant()) || + (isa<Instruction>(Op) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op), + /*IgnoredUser=*/&I))))) + return false; + + for (Value **Op : {&Op0, &Op1}) { + Value *NotOp; + if (auto *C = dyn_cast<Constant>(*Op)) { + NotOp = ConstantExpr::getNot(C); + } else { + Builder.SetInsertPoint( + &*cast<Instruction>(*Op)->getInsertionPointAfterDef()); + NotOp = Builder.CreateNot(*Op, (*Op)->getName() + ".not"); + (*Op)->replaceUsesWithIf( + NotOp, [NotOp](Use &U) { return U.getUser() != NotOp; }); + freelyInvertAllUsersOf(NotOp, /*IgnoredUser=*/&I); + } + *Op = NotOp; + } + + Builder.SetInsertPoint(I.getInsertionPointAfterDef()); + Value *NewLogicOp; + if (IsBinaryOp) + NewLogicOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); + else + NewLogicOp = + Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not"); + + replaceInstUsesWith(I, NewLogicOp); + // We can not just create an outer `not`, it will most likely be immediately + // folded back, reconstructing our initial pattern, and causing an + // infinite combine loop, so immediately manually fold it away. + freelyInvertAllUsersOf(NewLogicOp); + return true; +} + +// Transform // z = (~x) &/| y // into: // z = ~(x |/& (~y)) // iff y is free to invert and all uses of z can be freely updated. 
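A standalone sketch (not from the patch) of the identity behind the sinking transform described above: it is De Morgan's law, applied only when the extra inversion of the other operand is free and every user of the result can absorb the outer 'not'.

```cpp
#include <cassert>
#include <cstdint>

// Illustrative check, not part of the patch: (~x) & y == ~(x | ~y), and dually
// (~x) | y == ~(x & ~y).
int main() {
  for (uint32_t X : {0u, 0xF0F0F0F0u, 0xFFFFFFFFu, 0x12345678u})
    for (uint32_t Y : {0u, 0x0FF00FF0u, 0xFFFFFFFFu, 0xDEADBEEFu}) {
      assert((~X & Y) == ~(X | ~Y));
      assert((~X | Y) == ~(X & ~Y));
    }
  return 0;
}
```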
-bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { - Instruction::BinaryOps NewOpc; - switch (I.getOpcode()) { - case Instruction::And: - NewOpc = Instruction::Or; - break; - case Instruction::Or: - NewOpc = Instruction::And; - break; - default: +bool InstCombinerImpl::sinkNotIntoOtherHandOfLogicalOp(Instruction &I) { + Value *Op0, *Op1; + if (!match(&I, m_LogicalOp(m_Value(Op0), m_Value(Op1)))) return false; - }; + Instruction::BinaryOps NewOpc = + match(&I, m_LogicalAnd()) ? Instruction::Or : Instruction::And; + bool IsBinaryOp = isa<BinaryOperator>(I); - Value *X, *Y; - if (!match(&I, m_c_BinOp(m_Not(m_Value(X)), m_Value(Y)))) - return false; - - // Will we be able to fold the `not` into Y eventually? - if (!InstCombiner::isFreeToInvert(Y, Y->hasOneUse())) + Value *NotOp0 = nullptr; + Value *NotOp1 = nullptr; + Value **OpToInvert = nullptr; + if (match(Op0, m_Not(m_Value(NotOp0))) && + InstCombiner::isFreeToInvert(Op1, /*WillInvertAllUses=*/true) && + (match(Op1, m_ImmConstant()) || + (isa<Instruction>(Op1) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op1), + /*IgnoredUser=*/&I)))) { + Op0 = NotOp0; + OpToInvert = &Op1; + } else if (match(Op1, m_Not(m_Value(NotOp1))) && + InstCombiner::isFreeToInvert(Op0, /*WillInvertAllUses=*/true) && + (match(Op0, m_ImmConstant()) || + (isa<Instruction>(Op0) && + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(Op0), + /*IgnoredUser=*/&I)))) { + Op1 = NotOp1; + OpToInvert = &Op0; + } else return false; // And can our users be adapted? if (!InstCombiner::canFreelyInvertAllUsersOf(&I, /*IgnoredUser=*/nullptr)) return false; - Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); - Value *NewBinOp = - BinaryOperator::Create(NewOpc, X, NotY, I.getName() + ".not"); - Builder.Insert(NewBinOp); + if (auto *C = dyn_cast<Constant>(*OpToInvert)) { + *OpToInvert = ConstantExpr::getNot(C); + } else { + Builder.SetInsertPoint( + &*cast<Instruction>(*OpToInvert)->getInsertionPointAfterDef()); + Value *NotOpToInvert = + Builder.CreateNot(*OpToInvert, (*OpToInvert)->getName() + ".not"); + (*OpToInvert)->replaceUsesWithIf(NotOpToInvert, [NotOpToInvert](Use &U) { + return U.getUser() != NotOpToInvert; + }); + freelyInvertAllUsersOf(NotOpToInvert, /*IgnoredUser=*/&I); + *OpToInvert = NotOpToInvert; + } + + Builder.SetInsertPoint(&*I.getInsertionPointAfterDef()); + Value *NewBinOp; + if (IsBinaryOp) + NewBinOp = Builder.CreateBinOp(NewOpc, Op0, Op1, I.getName() + ".not"); + else + NewBinOp = Builder.CreateLogicalOp(NewOpc, Op0, Op1, I.getName() + ".not"); replaceInstUsesWith(I, NewBinOp); // We can not just create an outer `not`, it will most likely be immediately // folded back, reconstructing our initial pattern, and causing an @@ -3472,23 +3884,6 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // Is this a 'not' (~) fed by a binary operator? 
BinaryOperator *NotVal; if (match(NotOp, m_BinOp(NotVal))) { - if (NotVal->getOpcode() == Instruction::And || - NotVal->getOpcode() == Instruction::Or) { - // Apply DeMorgan's Law when inverts are free: - // ~(X & Y) --> (~X | ~Y) - // ~(X | Y) --> (~X & ~Y) - if (isFreeToInvert(NotVal->getOperand(0), - NotVal->getOperand(0)->hasOneUse()) && - isFreeToInvert(NotVal->getOperand(1), - NotVal->getOperand(1)->hasOneUse())) { - Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs"); - Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs"); - if (NotVal->getOpcode() == Instruction::And) - return BinaryOperator::CreateOr(NotX, NotY); - return BinaryOperator::CreateAnd(NotX, NotY); - } - } - // ~((-X) | Y) --> (X - 1) & (~Y) if (match(NotVal, m_OneUse(m_c_Or(m_OneUse(m_Neg(m_Value(X))), m_Value(Y))))) { @@ -3501,6 +3896,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { if (match(NotVal, m_AShr(m_Not(m_Value(X)), m_Value(Y)))) return BinaryOperator::CreateAShr(X, Y); + // Bit-hack form of a signbit test: + // iN ~X >>s (N-1) --> sext i1 (X > -1) to iN + unsigned FullShift = Ty->getScalarSizeInBits() - 1; + if (match(NotVal, m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))))) { + Value *IsNotNeg = Builder.CreateIsNotNeg(X, "isnotneg"); + return new SExtInst(IsNotNeg, Ty); + } + // If we are inverting a right-shifted constant, we may be able to eliminate // the 'not' by inverting the constant and using the opposite shift type. // Canonicalization rules ensure that only a negative constant uses 'ashr', @@ -3545,11 +3948,28 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { // not (cmp A, B) = !cmp A, B CmpInst::Predicate Pred; - if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + if (match(NotOp, m_Cmp(Pred, m_Value(), m_Value())) && + (NotOp->hasOneUse() || + InstCombiner::canFreelyInvertAllUsersOf(cast<Instruction>(NotOp), + /*IgnoredUser=*/nullptr))) { cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred)); - return replaceInstUsesWith(I, NotOp); + freelyInvertAllUsersOf(NotOp); + return &I; + } + + // Move a 'not' ahead of casts of a bool to enable logic reduction: + // not (bitcast (sext i1 X)) --> bitcast (sext (not i1 X)) + if (match(NotOp, m_OneUse(m_BitCast(m_OneUse(m_SExt(m_Value(X)))))) && X->getType()->isIntOrIntVectorTy(1)) { + Type *SextTy = cast<BitCastOperator>(NotOp)->getSrcTy(); + Value *NotX = Builder.CreateNot(X); + Value *Sext = Builder.CreateSExt(NotX, SextTy); + return CastInst::CreateBitOrPointerCast(Sext, Ty); } + if (auto *NotOpI = dyn_cast<Instruction>(NotOp)) + if (sinkNotIntoLogicalOp(*NotOpI)) + return &I; + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: // ~min(~X, ~Y) --> max(X, Y) // ~max(~X, Y) --> min(X, ~Y) @@ -3570,6 +3990,14 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); return replaceInstUsesWith(I, InvMaxMin); } + + if (II->getIntrinsicID() == Intrinsic::is_fpclass) { + ConstantInt *ClassMask = cast<ConstantInt>(II->getArgOperand(1)); + II->setArgOperand( + 1, ConstantInt::get(ClassMask->getType(), + ~ClassMask->getZExtValue() & fcAllFlags)); + return replaceInstUsesWith(I, II); + } } if (NotOp->hasOneUse()) { @@ -3602,7 +4030,7 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { } } - if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) + if (Instruction *NewXor = foldNotXor(I, Builder)) return NewXor; return nullptr; @@ -3629,7 +4057,7 @@ Instruction 
*InstCombinerImpl::visitXor(BinaryOperator &I) { return NewXor; // (A&B)^(A&C) -> A&(B^C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); // See if we can simplify any instructions used by the instruction whose sole @@ -3718,6 +4146,21 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { MaskedValueIsZero(X, *C, 0, &I)) return BinaryOperator::CreateXor(X, ConstantInt::get(Ty, *C ^ *RHSC)); + // When X is a power-of-two or zero and zero input is poison: + // ctlz(i32 X) ^ 31 --> cttz(X) + // cttz(i32 X) ^ 31 --> ctlz(X) + auto *II = dyn_cast<IntrinsicInst>(Op0); + if (II && II->hasOneUse() && *RHSC == Ty->getScalarSizeInBits() - 1) { + Intrinsic::ID IID = II->getIntrinsicID(); + if ((IID == Intrinsic::ctlz || IID == Intrinsic::cttz) && + match(II->getArgOperand(1), m_One()) && + isKnownToBeAPowerOfTwo(II->getArgOperand(0), /*OrZero */ true)) { + IID = (IID == Intrinsic::ctlz) ? Intrinsic::cttz : Intrinsic::ctlz; + Function *F = Intrinsic::getDeclaration(II->getModule(), IID, Ty); + return CallInst::Create(F, {II->getArgOperand(0), Builder.getTrue()}); + } + } + // If RHSC is inverting the remaining bits of shifted X, // canonicalize to a 'not' before the shift to help SCEV and codegen: // (X << C) ^ RHSC --> ~X << C @@ -3858,5 +4301,17 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { m_Value(Y)))) return BinaryOperator::CreateXor(Builder.CreateXor(X, Y), C1); + if (Instruction *R = reassociateForUses(I, Builder)) + return R; + + if (Instruction *Canonicalized = canonicalizeLogicFirst(I, Builder)) + return Canonicalized; + + if (Instruction *Folded = foldLogicOfIsFPClass(I, Op0, Op1)) + return Folded; + + if (Instruction *Folded = canonicalizeConditionalNegationViaMathToSelect(I)) + return Folded; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp index 0327efbf9614..e73667f9c02e 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp @@ -128,10 +128,9 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { if (Ordering != AtomicOrdering::Release && Ordering != AtomicOrdering::Monotonic) return nullptr; - auto *SI = new StoreInst(RMWI.getValOperand(), - RMWI.getPointerOperand(), &RMWI); - SI->setAtomic(Ordering, RMWI.getSyncScopeID()); - SI->setAlignment(DL.getABITypeAlign(RMWI.getType())); + new StoreInst(RMWI.getValOperand(), RMWI.getPointerOperand(), + /*isVolatile*/ false, RMWI.getAlign(), Ordering, + RMWI.getSyncScopeID(), &RMWI); return eraseInstFromFunction(RMWI); } @@ -152,13 +151,5 @@ Instruction *InstCombinerImpl::visitAtomicRMWInst(AtomicRMWInst &RMWI) { return replaceOperand(RMWI, 1, ConstantFP::getNegativeZero(RMWI.getType())); } - // Check if the required ordering is compatible with an atomic load. 
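[Editorial illustration, not part of the patch] The ctlz/cttz fold added to visitXor above works because BitWidth-1 is all-ones over the possible count values, so XOR with it is the same as subtracting from BitWidth-1. A standalone sketch using C++20 <bit> for 32-bit powers of two:

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      // For X a power of two: ctlz(X) ^ 31 == cttz(X) on i32, and vice versa.
      for (unsigned k = 0; k < 32; ++k) {
        uint32_t x = 1u << k;
        assert((std::countl_zero(x) ^ 31) == std::countr_zero(x));
        assert((std::countr_zero(x) ^ 31) == std::countl_zero(x));
      }
      return 0;
    }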
- if (Ordering != AtomicOrdering::Acquire && - Ordering != AtomicOrdering::Monotonic) - return nullptr; - - LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand(), "", - false, DL.getABITypeAlign(RMWI.getType()), - Ordering, RMWI.getSyncScopeID()); - return Load; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index bc01d2ef7fe2..fbf1327143a8 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/APSInt.h" #include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLFunctionalExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallVector.h" @@ -34,6 +32,7 @@ #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" @@ -71,6 +70,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <optional> #include <utility> #include <vector> @@ -135,7 +135,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. - if (AA->pointsToConstantMemory(MI->getDest())) { + if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) { // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setLength(Constant::getNullValue(MI->getLength()->getType())); return MI; @@ -223,6 +223,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) { S->setMetadata(LLVMContext::MD_mem_parallel_loop_access, LoopMemParallelMD); if (AccessGroupMD) S->setMetadata(LLVMContext::MD_access_group, AccessGroupMD); + S->copyMetadata(*MI, LLVMContext::MD_DIAssignID); if (auto *MT = dyn_cast<MemTransferInst>(MI)) { // non-atomics can be volatile @@ -252,7 +253,7 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. - if (AA->pointsToConstantMemory(MI->getDest())) { + if (!isModSet(AA->getModRefInfoMask(MI->getDest()))) { // Set the size of the copy to 0, it will be deleted on the next iteration. MI->setLength(Constant::getNullValue(MI->getLength()->getType())); return MI; @@ -294,9 +295,15 @@ Instruction *InstCombinerImpl::SimplifyAnyMemSet(AnyMemSetInst *MI) { Dest = Builder.CreateBitCast(Dest, NewDstPtrTy); // Extract the fill value and store. 
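[Editorial illustration, not part of the patch] The store built just below replicates the single memset fill byte across an integer of the access width by multiplying with 0x0101...01. A standalone sketch of that splat trick:

    #include <cassert>
    #include <cstdint>
    #include <cstring>

    int main() {
      const uint8_t FillByte = 0xAB;
      // Multiplying by 0x0101010101010101 copies the byte into every lane.
      uint64_t Splat = FillByte * 0x0101010101010101ULL;

      uint8_t Buf[8];
      std::memset(Buf, FillByte, sizeof(Buf));

      uint64_t FromMemset;
      std::memcpy(&FromMemset, Buf, sizeof(FromMemset));
      assert(FromMemset == Splat); // all bytes equal, so endianness is irrelevant
      return 0;
    }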
- uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; - StoreInst *S = Builder.CreateStore(ConstantInt::get(ITy, Fill), Dest, - MI->isVolatile()); + const uint64_t Fill = FillC->getZExtValue()*0x0101010101010101ULL; + Constant *FillVal = ConstantInt::get(ITy, Fill); + StoreInst *S = Builder.CreateStore(FillVal, Dest, MI->isVolatile()); + S->copyMetadata(*MI, LLVMContext::MD_DIAssignID); + for (auto *DAI : at::getAssignmentMarkers(S)) { + if (any_of(DAI->location_ops(), [&](Value *V) { return V == FillC; })) + DAI->replaceVariableLocationOp(FillC, FillVal); + } + S->setAlignment(Alignment); if (isa<AtomicMemSetInst>(MI)) S->setOrdering(AtomicOrdering::Unordered); @@ -328,7 +335,7 @@ Value *InstCombinerImpl::simplifyMaskedLoad(IntrinsicInst &II) { // If we can unconditionally load from this address, replace with a // load/select idiom. TODO: use DT for context sensitive query if (isDereferenceablePointer(LoadPtr, II.getType(), - II.getModule()->getDataLayout(), &II, nullptr)) { + II.getModule()->getDataLayout(), &II, &AC)) { LoadInst *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment, "unmaskedload"); LI->copyMetadata(II); @@ -661,10 +668,21 @@ static Instruction *foldCtpop(IntrinsicInst &II, InstCombinerImpl &IC) { // If all bits are zero except for exactly one fixed bit, then the result // must be 0 or 1, and we can get that answer by shifting to LSB: // ctpop (X & 32) --> (X & 32) >> 5 + // TODO: Investigate removing this as its likely unnecessary given the below + // `isKnownToBeAPowerOfTwo` check. if ((~Known.Zero).isPowerOf2()) return BinaryOperator::CreateLShr( Op0, ConstantInt::get(Ty, (~Known.Zero).exactLogBase2())); + // More generally we can also handle non-constant power of 2 patterns such as + // shl/shr(Pow2, X), (X & -X), etc... by transforming: + // ctpop(Pow2OrZero) --> icmp ne X, 0 + if (IC.isKnownToBeAPowerOfTwo(Op0, /* OrZero */ true)) + return CastInst::Create(Instruction::ZExt, + IC.Builder.CreateICmp(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Ty)), + Ty); + // FIXME: Try to simplify vectors of integers. 
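[Editorial illustration, not part of the patch] A standalone check of the ctpop reasoning above: when the operand is a power of two or zero, its population count is exactly the zero test, and for a single fixed bit it is a shift to the LSB. Sketch with C++20 <bit>:

    #include <bit>
    #include <cassert>
    #include <cstdint>

    int main() {
      // ctpop(Pow2OrZero) --> zext(X != 0)
      for (uint32_t x : {0u, 1u, 2u, 8u, 0x80000000u})
        assert(std::popcount(x) == (x != 0 ? 1 : 0));

      // ctpop(X & 32) --> (X & 32) >> 5
      for (uint32_t x = 0; x < 256; ++x)
        assert((uint32_t)std::popcount(x & 32u) == ((x & 32u) >> 5));
      return 0;
    }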
auto *IT = dyn_cast<IntegerType>(Ty); if (!IT) @@ -720,7 +738,7 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, auto *V1 = II.getArgOperand(0); auto *V2 = Constant::getNullValue(V1->getType()); - return Builder.CreateShuffleVector(V1, V2, makeArrayRef(Indexes)); + return Builder.CreateShuffleVector(V1, V2, ArrayRef(Indexes)); } // Returns true iff the 2 intrinsics have the same operands, limiting the @@ -812,9 +830,10 @@ InstCombinerImpl::foldIntrinsicWithOverflowCommon(IntrinsicInst *II) { return nullptr; } -static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, - const DataLayout &DL, AssumptionCache *AC, - DominatorTree *DT) { +static std::optional<bool> getKnownSign(Value *Op, Instruction *CxtI, + const DataLayout &DL, + AssumptionCache *AC, + DominatorTree *DT) { KnownBits Known = computeKnownBits(Op, DL, 0, AC, CxtI, DT); if (Known.isNonNegative()) return false; @@ -1266,7 +1285,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (match(IIOperand, m_Select(m_Value(), m_Neg(m_Value(X)), m_Deferred(X)))) return replaceOperand(*II, 0, X); - if (Optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { + if (std::optional<bool> Sign = getKnownSign(IIOperand, II, DL, &AC, &DT)) { // abs(x) -> x if x >= 0 if (!*Sign) return replaceInstUsesWith(*II, IIOperand); @@ -1297,11 +1316,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); // umin(x, 1) == zext(x != 0) if (match(I1, m_One())) { + assert(II->getType()->getScalarSizeInBits() != 1 && + "Expected simplify of umin with max constant"); Value *Zero = Constant::getNullValue(I0->getType()); Value *Cmp = Builder.CreateICmpNE(I0, Zero); return CastInst::Create(Instruction::ZExt, Cmp, II->getType()); } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::umax: { Value *I0 = II->getArgOperand(0), *I1 = II->getArgOperand(1); @@ -1322,7 +1343,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } // If both operands of unsigned min/max are sign-extended, it is still ok // to narrow the operation. - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::smax: case Intrinsic::smin: { @@ -1431,6 +1452,18 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } + case Intrinsic::bitreverse: { + // bitrev (zext i1 X to ?) --> X ? SignBitC : 0 + Value *X; + if (match(II->getArgOperand(0), m_ZExt(m_Value(X))) && + X->getType()->isIntOrIntVectorTy(1)) { + Type *Ty = II->getType(); + APInt SignBit = APInt::getSignMask(Ty->getScalarSizeInBits()); + return SelectInst::Create(X, ConstantInt::get(Ty, SignBit), + ConstantInt::getNullValue(Ty)); + } + break; + } case Intrinsic::bswap: { Value *IIOperand = II->getArgOperand(0); @@ -1829,6 +1862,63 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } + case Intrinsic::matrix_multiply: { + // Optimize negation in matrix multiplication. 
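[Editorial illustration, not part of the patch] For the bitreverse fold added above (bit-reversing a bool zero-extended to iN is a select between the sign bit and zero), a standalone sketch; `reverseBits` below is a plain stand-in for llvm.bitreverse.i32:

    #include <cassert>
    #include <cstdint>

    // 32-bit bit reversal, standing in for llvm.bitreverse.i32.
    static uint32_t reverseBits(uint32_t V) {
      uint32_t R = 0;
      for (int I = 0; I < 32; ++I)
        R |= ((V >> I) & 1u) << (31 - I);
      return R;
    }

    int main() {
      for (bool X : {false, true}) {
        uint32_t Zext = X ? 1u : 0u;               // zext i1 X to i32
        uint32_t Expected = X ? 0x80000000u : 0u;  // X ? SignBit : 0
        assert(reverseBits(Zext) == Expected);
      }
      return 0;
    }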
+ + // -A * -B -> A * B + Value *A, *B; + if (match(II->getArgOperand(0), m_FNeg(m_Value(A))) && + match(II->getArgOperand(1), m_FNeg(m_Value(B)))) { + replaceOperand(*II, 0, A); + replaceOperand(*II, 1, B); + return II; + } + + Value *Op0 = II->getOperand(0); + Value *Op1 = II->getOperand(1); + Value *OpNotNeg, *NegatedOp; + unsigned NegatedOpArg, OtherOpArg; + if (match(Op0, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op0; + NegatedOpArg = 0; + OtherOpArg = 1; + } else if (match(Op1, m_FNeg(m_Value(OpNotNeg)))) { + NegatedOp = Op1; + NegatedOpArg = 1; + OtherOpArg = 0; + } else + // Multiplication doesn't have a negated operand. + break; + + // Only optimize if the negated operand has only one use. + if (!NegatedOp->hasOneUse()) + break; + + Value *OtherOp = II->getOperand(OtherOpArg); + VectorType *RetTy = cast<VectorType>(II->getType()); + VectorType *NegatedOpTy = cast<VectorType>(NegatedOp->getType()); + VectorType *OtherOpTy = cast<VectorType>(OtherOp->getType()); + ElementCount NegatedCount = NegatedOpTy->getElementCount(); + ElementCount OtherCount = OtherOpTy->getElementCount(); + ElementCount RetCount = RetTy->getElementCount(); + // (-A) * B -> A * (-B), if it is cheaper to negate B and vice versa. + if (ElementCount::isKnownGT(NegatedCount, OtherCount) && + ElementCount::isKnownLT(OtherCount, RetCount)) { + Value *InverseOtherOp = Builder.CreateFNeg(OtherOp); + replaceOperand(*II, NegatedOpArg, OpNotNeg); + replaceOperand(*II, OtherOpArg, InverseOtherOp); + return II; + } + // (-A) * B -> -(A * B), if it is cheaper to negate the result + if (ElementCount::isKnownGT(NegatedCount, RetCount)) { + SmallVector<Value *, 5> NewArgs(II->args()); + NewArgs[NegatedOpArg] = OpNotNeg; + Instruction *NewMul = + Builder.CreateIntrinsic(II->getType(), IID, NewArgs, II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(NewMul, II)); + } + break; + } case Intrinsic::fmuladd: { // Canonicalize fast fmuladd to the separate fmul + fadd. if (II->isFast()) { @@ -1850,7 +1940,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return FAdd; } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::fma: { // fma fneg(x), fneg(y), z -> fma x, y, z @@ -1940,7 +2030,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceOperand(*II, 0, TVal); } - LLVM_FALLTHROUGH; + Value *Magnitude, *Sign; + if (match(II->getArgOperand(0), + m_CopySign(m_Value(Magnitude), m_Value(Sign)))) { + // fabs (copysign x, y) -> (fabs x) + CallInst *AbsSign = + Builder.CreateCall(II->getCalledFunction(), {Magnitude}); + AbsSign->copyFastMathFlags(II); + return replaceInstUsesWith(*II, AbsSign); + } + + [[fallthrough]]; } case Intrinsic::ceil: case Intrinsic::floor: @@ -1979,7 +2079,64 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::ptrauth_auth: + case Intrinsic::ptrauth_resign: { + // (sign|resign) + (auth|resign) can be folded by omitting the middle + // sign+auth component if the key and discriminator match. + bool NeedSign = II->getIntrinsicID() == Intrinsic::ptrauth_resign; + Value *Key = II->getArgOperand(1); + Value *Disc = II->getArgOperand(2); + // AuthKey will be the key we need to end up authenticating against in + // whatever we replace this sequence with. 
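[Editorial illustration, not part of the patch] A standalone check of the fabs(copysign x, y) --> fabs(x) fold above: copysign only rewrites the sign bit, which fabs discards anyway (NaN payloads aside):

    #include <cassert>
    #include <cmath>

    int main() {
      for (double X : {0.0, 1.5, -3.25, -0.0})
        for (double Y : {1.0, -1.0, -0.0})
          assert(std::fabs(std::copysign(X, Y)) == std::fabs(X));
      return 0;
    }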
+ Value *AuthKey = nullptr, *AuthDisc = nullptr, *BasePtr; + if (auto CI = dyn_cast<CallBase>(II->getArgOperand(0))) { + BasePtr = CI->getArgOperand(0); + if (CI->getIntrinsicID() == Intrinsic::ptrauth_sign) { + if (CI->getArgOperand(1) != Key || CI->getArgOperand(2) != Disc) + break; + } else if (CI->getIntrinsicID() == Intrinsic::ptrauth_resign) { + if (CI->getArgOperand(3) != Key || CI->getArgOperand(4) != Disc) + break; + AuthKey = CI->getArgOperand(1); + AuthDisc = CI->getArgOperand(2); + } else + break; + } else + break; + + unsigned NewIntrin; + if (AuthKey && NeedSign) { + // resign(0,1) + resign(1,2) = resign(0, 2) + NewIntrin = Intrinsic::ptrauth_resign; + } else if (AuthKey) { + // resign(0,1) + auth(1) = auth(0) + NewIntrin = Intrinsic::ptrauth_auth; + } else if (NeedSign) { + // sign(0) + resign(0, 1) = sign(1) + NewIntrin = Intrinsic::ptrauth_sign; + } else { + // sign(0) + auth(0) = nop + replaceInstUsesWith(*II, BasePtr); + eraseInstFromFunction(*II); + return nullptr; + } + + SmallVector<Value *, 4> CallArgs; + CallArgs.push_back(BasePtr); + if (AuthKey) { + CallArgs.push_back(AuthKey); + CallArgs.push_back(AuthDisc); + } + + if (NeedSign) { + CallArgs.push_back(II->getArgOperand(3)); + CallArgs.push_back(II->getArgOperand(4)); + } + + Function *NewFn = Intrinsic::getDeclaration(II->getModule(), NewIntrin); + return CallInst::Create(NewFn, CallArgs); + } case Intrinsic::arm_neon_vtbl1: case Intrinsic::aarch64_neon_tbl1: if (Value *V = simplifyNeonTbl1(*II, Builder)) @@ -2221,7 +2378,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Pred == ICmpInst::ICMP_NE && LHS->getOpcode() == Instruction::Load && LHS->getType()->isPointerTy() && isValidAssumeForContext(II, LHS, &DT)) { - MDNode *MD = MDNode::get(II->getContext(), None); + MDNode *MD = MDNode::get(II->getContext(), std::nullopt); LHS->setMetadata(LLVMContext::MD_nonnull, MD); return RemoveConditionFromAssume(II); @@ -2288,7 +2445,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { llvm::getKnowledgeFromBundle(cast<AssumeInst>(*II), BOI); if (BOI.End - BOI.Begin > 2) continue; // Prevent reducing knowledge in an align with offset since - // extracting a RetainedKnowledge form them looses offset + // extracting a RetainedKnowledge from them looses offset // information RetainedKnowledge CanonRK = llvm::simplifyRetainedKnowledge(cast<AssumeInst>(II), RK, @@ -2409,7 +2566,31 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Vec = II->getArgOperand(0); Value *Idx = II->getArgOperand(1); - auto *DstTy = dyn_cast<FixedVectorType>(II->getType()); + Type *ReturnType = II->getType(); + // (extract_vector (insert_vector InsertTuple, InsertValue, InsertIdx), + // ExtractIdx) + unsigned ExtractIdx = cast<ConstantInt>(Idx)->getZExtValue(); + Value *InsertTuple, *InsertIdx, *InsertValue; + if (match(Vec, m_Intrinsic<Intrinsic::vector_insert>(m_Value(InsertTuple), + m_Value(InsertValue), + m_Value(InsertIdx))) && + InsertValue->getType() == ReturnType) { + unsigned Index = cast<ConstantInt>(InsertIdx)->getZExtValue(); + // Case where we get the same index right after setting it. + // extract.vector(insert.vector(InsertTuple, InsertValue, Idx), Idx) --> + // InsertValue + if (ExtractIdx == Index) + return replaceInstUsesWith(CI, InsertValue); + // If we are getting a different index than what was set in the + // insert.vector intrinsic. We can just set the input tuple to the one up + // in the chain. 
extract.vector(insert.vector(InsertTuple, InsertValue, + // InsertIndex), ExtractIndex) + // --> extract.vector(InsertTuple, ExtractIndex) + else + return replaceOperand(CI, 0, InsertTuple); + } + + auto *DstTy = dyn_cast<FixedVectorType>(ReturnType); auto *VecTy = dyn_cast<FixedVectorType>(Vec->getType()); // Only canonicalize if the the destination vector and Vec are fixed @@ -2439,11 +2620,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Value *Vec = II->getArgOperand(0); if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { auto *OldBinOp = cast<BinaryOperator>(Vec); - if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(X)))) { + if (match(BO0, m_VecReverse(m_Value(X)))) { // rev(binop rev(X), rev(Y)) --> binop X, Y - if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(Y)))) + if (match(BO1, m_VecReverse(m_Value(Y)))) return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( OldBinOp->getOpcode(), X, Y, OldBinOp, @@ -2456,17 +2635,13 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { OldBinOp, OldBinOp->getName(), II)); } // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y - if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(Y))) && - isSplatValue(BO0)) + if (match(BO1, m_VecReverse(m_Value(Y))) && isSplatValue(BO0)) return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( OldBinOp->getOpcode(), BO0, Y, OldBinOp, OldBinOp->getName(), II)); } // rev(unop rev(X)) --> unop X - if (match(Vec, m_OneUse(m_UnOp( - m_Intrinsic<Intrinsic::experimental_vector_reverse>( - m_Value(X)))))) { + if (match(Vec, m_OneUse(m_UnOp(m_VecReverse(m_Value(X)))))) { auto *OldUnOp = cast<UnaryOperator>(Vec); auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); @@ -2504,7 +2679,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { return replaceInstUsesWith(CI, Res); } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_add: { if (IID == Intrinsic::vector_reduce_add) { @@ -2531,7 +2706,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_xor: { if (IID == Intrinsic::vector_reduce_xor) { @@ -2555,7 +2730,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_mul: { if (IID == Intrinsic::vector_reduce_mul) { @@ -2577,7 +2752,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_umin: case Intrinsic::vector_reduce_umax: { @@ -2604,7 +2779,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_smin: case Intrinsic::vector_reduce_smax: { @@ -2642,7 +2817,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::vector_reduce_fmax: case Intrinsic::vector_reduce_fmin: @@ -2679,9 +2854,9 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } default: { // Handle target specific intrinsics - Optional<Instruction *> V = targetInstCombineIntrinsic(*II); + std::optional<Instruction *> V = targetInstCombineIntrinsic(*II); if (V) - return V.value(); + return *V; break; } } @@ -2887,7 +3062,7 @@ bool InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, if 
(!Call.getType()->isPointerTy()) return Changed; - Optional<APInt> Size = getAllocSize(&Call, TLI); + std::optional<APInt> Size = getAllocSize(&Call, TLI); if (Size && *Size != 0) { // TODO: We really should just emit deref_or_null here and then // let the generic inference code combine that with nonnull. @@ -3078,6 +3253,30 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { Call, Builder.CreateBitOrPointerCast(ReturnedArg, CallTy)); } + // Drop unnecessary kcfi operand bundles from calls that were converted + // into direct calls. + auto Bundle = Call.getOperandBundle(LLVMContext::OB_kcfi); + if (Bundle && !Call.isIndirectCall()) { + DEBUG_WITH_TYPE(DEBUG_TYPE "-kcfi", { + if (CalleeF) { + ConstantInt *FunctionType = nullptr; + ConstantInt *ExpectedType = cast<ConstantInt>(Bundle->Inputs[0]); + + if (MDNode *MD = CalleeF->getMetadata(LLVMContext::MD_kcfi_type)) + FunctionType = mdconst::extract<ConstantInt>(MD->getOperand(0)); + + if (FunctionType && + FunctionType->getZExtValue() != ExpectedType->getZExtValue()) + dbgs() << Call.getModule()->getName() + << ": warning: kcfi: " << Call.getCaller()->getName() + << ": call to " << CalleeF->getName() + << " using a mismatching function pointer type\n"; + } + }); + + return CallBase::removeOperandBundle(&Call, LLVMContext::OB_kcfi); + } + if (isRemovableAlloc(&Call, &TLI)) return visitAllocSite(Call); @@ -3140,7 +3339,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { LiveGcValues.insert(BasePtr); LiveGcValues.insert(DerivedPtr); } - Optional<OperandBundleUse> Bundle = + std::optional<OperandBundleUse> Bundle = GCSP.getOperandBundle(LLVMContext::OB_gc_live); unsigned NumOfGCLives = LiveGcValues.size(); if (!Bundle || NumOfGCLives == Bundle->Inputs.size()) @@ -3148,8 +3347,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // We can reduce the size of gc live bundle. DenseMap<Value *, unsigned> Val2Idx; std::vector<Value *> NewLiveGc; - for (unsigned I = 0, E = Bundle->Inputs.size(); I < E; ++I) { - Value *V = Bundle->Inputs[I]; + for (Value *V : Bundle->Inputs) { if (Val2Idx.count(V)) continue; if (LiveGcValues.count(V)) { @@ -3289,6 +3487,10 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (CallerPAL.hasParamAttr(i, Attribute::SwiftError)) return false; + if (CallerPAL.hasParamAttr(i, Attribute::ByVal) != + Callee->getAttributes().hasParamAttr(i, Attribute::ByVal)) + return false; // Cannot transform to or from byval. + // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { @@ -3447,21 +3649,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { NV = NC = CastInst::CreateBitOrPointerCast(NC, OldRetTy); NC->setDebugLoc(Caller->getDebugLoc()); - // If this is an invoke/callbr instruction, we should insert it after the - // first non-phi instruction in the normal successor block. - if (InvokeInst *II = dyn_cast<InvokeInst>(Caller)) { - BasicBlock::iterator I = II->getNormalDest()->getFirstInsertionPt(); - InsertNewInstBefore(NC, *I); - } else if (CallBrInst *CBI = dyn_cast<CallBrInst>(Caller)) { - BasicBlock::iterator I = CBI->getDefaultDest()->getFirstInsertionPt(); - InsertNewInstBefore(NC, *I); - } else { - // Otherwise, it's a call, just insert cast right after the call. 
- InsertNewInstBefore(NC, *Caller); - } + Instruction *InsertPt = NewCall->getInsertionPointAfterDef(); + assert(InsertPt && "No place to insert cast"); + InsertNewInstBefore(NC, *InsertPt); Worklist.pushUsersToWorkList(*Caller); } else { - NV = UndefValue::get(Caller->getType()); + NV = PoisonValue::get(Caller->getType()); } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index a9a930555b3c..3f851a2b2182 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -14,9 +14,12 @@ #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfo.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" +#include <optional> + using namespace llvm; using namespace PatternMatch; @@ -118,14 +121,15 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, if (!AI.hasOneUse() && CastElTyAlign == AllocElTyAlign) return nullptr; // The alloc and cast types should be either both fixed or both scalable. - uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinSize(); - uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinSize(); + uint64_t AllocElTySize = DL.getTypeAllocSize(AllocElTy).getKnownMinValue(); + uint64_t CastElTySize = DL.getTypeAllocSize(CastElTy).getKnownMinValue(); if (CastElTySize == 0 || AllocElTySize == 0) return nullptr; // If the allocation has multiple uses, only promote it if we're not // shrinking the amount of memory being allocated. - uint64_t AllocElTyStoreSize = DL.getTypeStoreSize(AllocElTy).getKnownMinSize(); - uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinSize(); + uint64_t AllocElTyStoreSize = + DL.getTypeStoreSize(AllocElTy).getKnownMinValue(); + uint64_t CastElTyStoreSize = DL.getTypeStoreSize(CastElTy).getKnownMinValue(); if (!AI.hasOneUse() && CastElTyStoreSize < AllocElTyStoreSize) return nullptr; // See if we can satisfy the modulus by pulling a scale out of the array @@ -163,6 +167,10 @@ Instruction *InstCombinerImpl::PromoteCastOfAllocation(BitCastInst &CI, New->setAlignment(AI.getAlign()); New->takeName(&AI); New->setUsedWithInAlloca(AI.isUsedWithInAlloca()); + New->setMetadata(LLVMContext::MD_DIAssignID, + AI.getMetadata(LLVMContext::MD_DIAssignID)); + + replaceAllDbgUsesWith(AI, *New, *New, DT); // If the allocation has multiple real uses, insert a cast and change all // things that used it to use the new cast. This will also hack on CI, but it @@ -239,6 +247,11 @@ Value *InstCombinerImpl::EvaluateInDifferentType(Value *V, Type *Ty, Res = NPN; break; } + case Instruction::FPToUI: + case Instruction::FPToSI: + Res = CastInst::Create( + static_cast<Instruction::CastOps>(Opc), I->getOperand(0), Ty); + break; default: // TODO: Can handle more cases here. llvm_unreachable("Unreachable!"); @@ -483,6 +496,22 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombinerImpl &IC, return false; return true; } + case Instruction::FPToUI: + case Instruction::FPToSI: { + // If the integer type can hold the max FP value, it is safe to cast + // directly to that type. Otherwise, we may create poison via overflow + // that did not exist in the original code. + // + // The max FP value is pow(2, MaxExponent) * (1 + MaxFraction), so we need + // at least one more bit than the MaxExponent to hold the max FP value. 
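[Editorial illustration, not part of the patch] A standalone check of the bound used above for IEEE single precision, whose maximum unbiased exponent is 127: the largest finite float is just below 2^128, so fptoui needs strictly more than 127 result bits and fptosi one more for the sign:

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
      // FLT_MAX = (2 - 2^-23) * 2^127 < 2^128, so 128 unsigned bits suffice.
      double MaxFloat = std::numeric_limits<float>::max();
      assert(MaxFloat < std::ldexp(1.0, 128));  // fits below 2^128
      assert(MaxFloat >= std::ldexp(1.0, 127)); // but 127 bits are not enough
      return 0;
    }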
+ Type *InputTy = I->getOperand(0)->getType()->getScalarType(); + const fltSemantics &Semantics = InputTy->getFltSemantics(); + uint32_t MinBitWidth = APFloatBase::semanticsMaxExponent(Semantics); + // Extra sign bit needed. + if (I->getOpcode() == Instruction::FPToSI) + ++MinBitWidth; + return Ty->getScalarSizeInBits() > MinBitWidth; + } default: // TODO: Can handle more cases here. break; @@ -726,7 +755,7 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, InstCombiner::BuilderTy &Builder) { auto *Shuf = dyn_cast<ShuffleVectorInst>(Trunc.getOperand(0)); if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && - is_splat(Shuf->getShuffleMask()) && + all_equal(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask @@ -974,7 +1003,7 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { Attribute Attr = Trunc.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { if (Log2_32(*MaxVScale) < DestWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); return replaceInstUsesWith(Trunc, VScale); @@ -986,7 +1015,8 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { return nullptr; } -Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) { +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, + ZExtInst &Zext) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -1014,28 +1044,20 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. // zext (X == 0) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - // zext (X == 1) to i32 --> X iff X has only the low bit set. - // zext (X == 2) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 0) to i32 --> X iff X has only the low bit set. // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. - // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. - // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV->isZero() || Op1CV->isPowerOf2()) && - // This only works for EQ and NE - Cmp->isEquality()) { + if (Op1CV->isZero() && Cmp->isEquality() && + (Cmp->getOperand(0)->getType() == Zext.getType() || + Cmp->getPredicate() == ICmpInst::ICMP_NE)) { // If Op1C some other power of two, convert: KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); + // Exactly 1 possible 1? But not the high-bit because that is + // canonicalized to this form. APInt KnownZeroMask(~Known.Zero); - if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? 
- bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) { - // (X&4) == 2 --> false - // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Zext.getType(), isNE); - return replaceInstUsesWith(Zext, Res); - } - + if (KnownZeroMask.isPowerOf2() && + (Zext.getType()->getScalarSizeInBits() != + KnownZeroMask.logBase2() + 1)) { uint32_t ShAmt = KnownZeroMask.logBase2(); Value *In = Cmp->getOperand(0); if (ShAmt) { @@ -1045,10 +1067,9 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) In->getName() + ".lobit"); } - if (!Op1CV->isZero() == isNE) { // Toggle the low bit. - Constant *One = ConstantInt::get(In->getType(), 1); - In = Builder.CreateXor(In, One); - } + // Toggle the low bit for "X == 0". + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) + In = Builder.CreateXor(In, ConstantInt::get(In->getType(), 1)); if (Zext.getType() == In->getType()) return replaceInstUsesWith(Zext, In); @@ -1073,39 +1094,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) Value *And1 = Builder.CreateAnd(Lshr, ConstantInt::get(X->getType(), 1)); return replaceInstUsesWith(Zext, And1); } - - // icmp ne A, B is equal to xor A, B when A and B only really have one bit. - // It is also profitable to transform icmp eq into not(xor(A, B)) because - // that may lead to additional simplifications. - if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) { - Value *LHS = Cmp->getOperand(0); - Value *RHS = Cmp->getOperand(1); - - KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext); - KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext); - - if (KnownLHS == KnownRHS) { - APInt KnownBits = KnownLHS.Zero | KnownLHS.One; - APInt UnknownBit = ~KnownBits; - if (UnknownBit.countPopulation() == 1) { - Value *Result = Builder.CreateXor(LHS, RHS); - - // Mask off any bits that are set and won't be shifted away. - if (KnownLHS.One.uge(UnknownBit)) - Result = Builder.CreateAnd(Result, - ConstantInt::get(ITy, UnknownBit)); - - // Shift the bit we're testing down to the lsb. - Result = Builder.CreateLShr( - Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros())); - - if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) - Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1)); - Result->takeName(Cmp); - return replaceInstUsesWith(Zext, Result); - } - } - } } return nullptr; @@ -1235,23 +1223,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } } -Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { +Instruction *InstCombinerImpl::visitZExt(ZExtInst &Zext) { // If this zero extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this zext. - if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + if (Zext.hasOneUse() && isa<TruncInst>(Zext.user_back())) return nullptr; // If one of the common conversion will work, do it. - if (Instruction *Result = commonCastTransforms(CI)) + if (Instruction *Result = commonCastTransforms(Zext)) return Result; - Value *Src = CI.getOperand(0); - Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + Value *Src = Zext.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = Zext.getType(); // Try to extend the entire expression tree to the wide destination type. 
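[Editorial illustration, not part of the patch] A standalone check of the zext(icmp) shapes kept above, for an X whose only possibly-set bit is bit 1 (so X is 0 or 2):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t X : {0u, 2u}) {
        // zext (X != 0) to i32 --> X >> 1
        assert((uint32_t)(X != 0) == (X >> 1));
        // zext (X == 0) to i32 --> (X >> 1) ^ 1
        assert((uint32_t)(X == 0) == ((X >> 1) ^ 1u));
      }
      return 0;
    }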
unsigned BitsToClear; if (shouldChangeType(SrcTy, DestTy) && - canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &CI)) { + canEvaluateZExtd(Src, DestTy, BitsToClear, *this, &Zext)) { assert(BitsToClear <= SrcTy->getScalarSizeInBits() && "Can't clear more bits than in SrcTy"); @@ -1259,25 +1247,25 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid zero extend: " - << CI << '\n'); + << Zext << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, false); assert(Res->getType() == DestTy); // Preserve debug values referring to Src if the zext is its last use. if (auto *SrcOp = dyn_cast<Instruction>(Src)) if (SrcOp->hasOneUse()) - replaceAllDbgUsesWith(*SrcOp, *Res, CI, DT); + replaceAllDbgUsesWith(*SrcOp, *Res, Zext, DT); - uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits()-BitsToClear; + uint32_t SrcBitsKept = SrcTy->getScalarSizeInBits() - BitsToClear; uint32_t DestBitSize = DestTy->getScalarSizeInBits(); // If the high bits are already filled with zeros, just replace this // cast with the result. if (MaskedValueIsZero(Res, APInt::getHighBitsSet(DestBitSize, - DestBitSize-SrcBitsKept), - 0, &CI)) - return replaceInstUsesWith(CI, Res); + DestBitSize - SrcBitsKept), + 0, &Zext)) + return replaceInstUsesWith(Zext, Res); // We need to emit an AND to clear the high bits. Constant *C = ConstantInt::get(Res->getType(), @@ -1288,7 +1276,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { // If this is a TRUNC followed by a ZEXT then we are dealing with integral // types and if the sizes are just right we can convert this into a logical // 'and' which will be much cheaper than the pair of casts. - if (TruncInst *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast + if (auto *CSrc = dyn_cast<TruncInst>(Src)) { // A->B->C cast // TODO: Subsume this into EvaluateInDifferentType. // Get the sizes of the types involved. We know that the intermediate type @@ -1296,7 +1284,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { Value *A = CSrc->getOperand(0); unsigned SrcSize = A->getType()->getScalarSizeInBits(); unsigned MidSize = CSrc->getType()->getScalarSizeInBits(); - unsigned DstSize = CI.getType()->getScalarSizeInBits(); + unsigned DstSize = DestTy->getScalarSizeInBits(); // If we're actually extending zero bits, then if // SrcSize < DstSize: zext(a & mask) // SrcSize == DstSize: a & mask @@ -1305,7 +1293,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { APInt AndValue(APInt::getLowBitsSet(SrcSize, MidSize)); Constant *AndConst = ConstantInt::get(A->getType(), AndValue); Value *And = Builder.CreateAnd(A, AndConst, CSrc->getName() + ".mask"); - return new ZExtInst(And, CI.getType()); + return new ZExtInst(And, DestTy); } if (SrcSize == DstSize) { @@ -1314,7 +1302,7 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { AndValue)); } if (SrcSize > DstSize) { - Value *Trunc = Builder.CreateTrunc(A, CI.getType()); + Value *Trunc = Builder.CreateTrunc(A, DestTy); APInt AndValue(APInt::getLowBitsSet(DstSize, MidSize)); return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Trunc->getType(), @@ -1322,34 +1310,46 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { } } - if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) - return transformZExtICmp(Cmp, CI); + if (auto *Cmp = dyn_cast<ICmpInst>(Src)) + return transformZExtICmp(Cmp, Zext); // zext(trunc(X) & C) -> (X & zext(C)). 
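[Editorial illustration, not part of the patch] A standalone check of the fold named in the comment above: truncating, masking, and zero-extending back is just masking in the wide type:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t C = 0x00FF00FF;
      for (uint64_t X : {0ULL, 0x123456789ABCDEF0ULL, ~0ULL}) {
        uint64_t Narrow = (uint64_t)((uint32_t)X & C); // zext ((trunc X) & C)
        uint64_t Wide = X & (uint64_t)C;               // X & zext(C)
        assert(Narrow == Wide);
      }
      return 0;
    }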
Constant *C; Value *X; if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && - X->getType() == CI.getType()) - return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType())); + X->getType() == DestTy) + return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, DestTy)); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && - X->getType() == CI.getType()) { - Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); + X->getType() == DestTy) { + Constant *ZC = ConstantExpr::getZExt(C, DestTy); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } + // If we are truncating, masking, and then zexting back to the original type, + // that's just a mask. This is not handled by canEvaluateZextd if the + // intermediate values have extra uses. This could be generalized further for + // a non-constant mask operand. + // zext (and (trunc X), C) --> and X, (zext C) + if (match(Src, m_And(m_Trunc(m_Value(X)), m_Constant(C))) && + X->getType() == DestTy) { + Constant *ZextC = ConstantExpr::getZExt(C, DestTy); + return BinaryOperator::CreateAnd(X, ZextC); + } + if (match(Src, m_VScale(DL))) { - if (CI.getFunction() && - CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Zext.getFunction() && + Zext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = + Zext.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); if (Log2_32(*MaxVScale) < TypeWidth) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + return replaceInstUsesWith(Zext, VScale); } } } @@ -1359,48 +1359,44 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { } /// Transform (sext icmp) to bitwise / integer operations to eliminate the icmp. -Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, - Instruction &CI) { - Value *Op0 = ICI->getOperand(0), *Op1 = ICI->getOperand(1); - ICmpInst::Predicate Pred = ICI->getPredicate(); +Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *Cmp, + SExtInst &Sext) { + Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1); + ICmpInst::Predicate Pred = Cmp->getPredicate(); // Don't bother if Op1 isn't of vector or integer type. if (!Op1->getType()->isIntOrIntVectorTy()) return nullptr; - if ((Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) || - (Pred == ICmpInst::ICMP_SGT && match(Op1, m_AllOnes()))) { - // (x <s 0) ? -1 : 0 -> ashr x, 31 -> all ones if negative - // (x >s -1) ? 
-1 : 0 -> not (ashr x, 31) -> all ones if positive + if (Pred == ICmpInst::ICMP_SLT && match(Op1, m_ZeroInt())) { + // sext (x <s 0) --> ashr x, 31 (all ones if negative) Value *Sh = ConstantInt::get(Op0->getType(), Op0->getType()->getScalarSizeInBits() - 1); Value *In = Builder.CreateAShr(Op0, Sh, Op0->getName() + ".lobit"); - if (In->getType() != CI.getType()) - In = Builder.CreateIntCast(In, CI.getType(), true /*SExt*/); + if (In->getType() != Sext.getType()) + In = Builder.CreateIntCast(In, Sext.getType(), true /*SExt*/); - if (Pred == ICmpInst::ICMP_SGT) - In = Builder.CreateNot(In, In->getName() + ".not"); - return replaceInstUsesWith(CI, In); + return replaceInstUsesWith(Sext, In); } if (ConstantInt *Op1C = dyn_cast<ConstantInt>(Op1)) { // If we know that only one bit of the LHS of the icmp can be set and we // have an equality comparison with zero or a power of 2, we can transform // the icmp and sext into bitwise/integer operations. - if (ICI->hasOneUse() && - ICI->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ - KnownBits Known = computeKnownBits(Op0, 0, &CI); + if (Cmp->hasOneUse() && + Cmp->isEquality() && (Op1C->isZero() || Op1C->getValue().isPowerOf2())){ + KnownBits Known = computeKnownBits(Op0, 0, &Sext); APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { - Value *In = ICI->getOperand(0); + Value *In = Cmp->getOperand(0); // If the icmp tests for a known zero bit we can constant fold it. if (!Op1C->isZero() && Op1C->getValue() != KnownZeroMask) { Value *V = Pred == ICmpInst::ICMP_NE ? - ConstantInt::getAllOnesValue(CI.getType()) : - ConstantInt::getNullValue(CI.getType()); - return replaceInstUsesWith(CI, V); + ConstantInt::getAllOnesValue(Sext.getType()) : + ConstantInt::getNullValue(Sext.getType()); + return replaceInstUsesWith(Sext, V); } if (!Op1C->isZero() == (Pred == ICmpInst::ICMP_NE)) { @@ -1431,9 +1427,9 @@ Instruction *InstCombinerImpl::transformSExtICmp(ICmpInst *ICI, KnownZeroMask.getBitWidth() - 1), "sext"); } - if (CI.getType() == In->getType()) - return replaceInstUsesWith(CI, In); - return CastInst::CreateIntegerCast(In, CI.getType(), true/*SExt*/); + if (Sext.getType() == In->getType()) + return replaceInstUsesWith(Sext, In); + return CastInst::CreateIntegerCast(In, Sext.getType(), true/*SExt*/); } } } @@ -1496,22 +1492,22 @@ static bool canEvaluateSExtd(Value *V, Type *Ty) { return false; } -Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { +Instruction *InstCombinerImpl::visitSExt(SExtInst &Sext) { // If this sign extend is only used by a truncate, let the truncate be // eliminated before we try to optimize this sext. - if (CI.hasOneUse() && isa<TruncInst>(CI.user_back())) + if (Sext.hasOneUse() && isa<TruncInst>(Sext.user_back())) return nullptr; - if (Instruction *I = commonCastTransforms(CI)) + if (Instruction *I = commonCastTransforms(Sext)) return I; - Value *Src = CI.getOperand(0); - Type *SrcTy = Src->getType(), *DestTy = CI.getType(); + Value *Src = Sext.getOperand(0); + Type *SrcTy = Src->getType(), *DestTy = Sext.getType(); unsigned SrcBitSize = SrcTy->getScalarSizeInBits(); unsigned DestBitSize = DestTy->getScalarSizeInBits(); // If the value being extended is zero or positive, use a zext instead. - if (isKnownNonNegative(Src, DL, 0, &AC, &CI, &DT)) + if (isKnownNonNegative(Src, DL, 0, &AC, &Sext, &DT)) return CastInst::Create(Instruction::ZExt, Src, DestTy); // Try to extend the entire expression tree to the wide destination type. 
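[Editorial illustration, not part of the patch] A standalone check of the sign-bit idiom kept above: sign-extending "x <s 0" equals an arithmetic shift by the bit width minus one (C++20 defines >> on negative signed values as arithmetic, which this assumes):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int32_t X : {INT32_MIN, -7, -1, 0, 1, 42, INT32_MAX}) {
        int32_t Sext = (X < 0) ? -1 : 0; // sext (icmp slt X, 0) to i32
        int32_t Ashr = X >> 31;          // ashr X, 31
        assert(Sext == Ashr);
      }
      return 0;
    }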
@@ -1520,14 +1516,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { LLVM_DEBUG( dbgs() << "ICE: EvaluateInDifferentType converting expression type" " to avoid sign extend: " - << CI << '\n'); + << Sext << '\n'); Value *Res = EvaluateInDifferentType(Src, DestTy, true); assert(Res->getType() == DestTy); // If the high bits are already filled with sign bit, just replace this // cast with the result. - if (ComputeNumSignBits(Res, 0, &CI) > DestBitSize - SrcBitSize) - return replaceInstUsesWith(CI, Res); + if (ComputeNumSignBits(Res, 0, &Sext) > DestBitSize - SrcBitSize) + return replaceInstUsesWith(Sext, Res); // We need to emit a shl + ashr to do the sign extend. Value *ShAmt = ConstantInt::get(DestTy, DestBitSize-SrcBitSize); @@ -1540,7 +1536,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { // If the input has more sign bits than bits truncated, then convert // directly to final type. unsigned XBitSize = X->getType()->getScalarSizeInBits(); - if (ComputeNumSignBits(X, 0, &CI) > XBitSize - SrcBitSize) + if (ComputeNumSignBits(X, 0, &Sext) > XBitSize - SrcBitSize) return CastInst::CreateIntegerCast(X, DestTy, /* isSigned */ true); // If input is a trunc from the destination type, then convert into shifts. @@ -1563,8 +1559,8 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { } } - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) - return transformSExtICmp(ICI, CI); + if (auto *Cmp = dyn_cast<ICmpInst>(Src)) + return transformSExtICmp(Cmp, Sext); // If the input is a shl/ashr pair of a same constant, then this is a sign // extension from a smaller value. If we could trust arbitrary bitwidth @@ -1593,7 +1589,7 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { NumLowbitsLeft); NewShAmt = Constant::mergeUndefsWith(Constant::mergeUndefsWith(NewShAmt, BA), CA); - A = Builder.CreateShl(A, NewShAmt, CI.getName()); + A = Builder.CreateShl(A, NewShAmt, Sext.getName()); return BinaryOperator::CreateAShr(A, NewShAmt); } @@ -1616,13 +1612,14 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { } if (match(Src, m_VScale(DL))) { - if (CI.getFunction() && - CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { - Attribute Attr = CI.getFunction()->getFnAttribute(Attribute::VScaleRange); - if (Optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { + if (Sext.getFunction() && + Sext.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + Attribute Attr = + Sext.getFunction()->getFnAttribute(Attribute::VScaleRange); + if (std::optional<unsigned> MaxVScale = Attr.getVScaleRangeMax()) { if (Log2_32(*MaxVScale) < (SrcBitSize - 1)) { Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); - return replaceInstUsesWith(CI, VScale); + return replaceInstUsesWith(Sext, VScale); } } } @@ -1659,7 +1656,6 @@ static Type *shrinkFPConstant(ConstantFP *CFP) { // Determine if this is a vector of ConstantFPs and if so, return the minimal // type we can safely truncate all elements to. -// TODO: Make these support undef elements. static Type *shrinkFPConstantVector(Value *V) { auto *CV = dyn_cast<Constant>(V); auto *CVVTy = dyn_cast<FixedVectorType>(V->getType()); @@ -1673,6 +1669,9 @@ static Type *shrinkFPConstantVector(Value *V) { // For fixed-width vectors we find the minimal type by looking // through the constant values of the vector. 
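[Editorial illustration, not part of the patch] A standalone check of the shl + ashr pair emitted above when the widened value may have stale high bits, shown for i8 -> i32 (again assuming C++20 two's-complement conversions and arithmetic right shift):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int32_t X : {0, 1, 0x7F, -1, 0x1234ABCD, INT32_MIN}) {
        int32_t ViaCast = (int32_t)(int8_t)(uint8_t)X;          // sext (trunc X to i8)
        int32_t ViaShifts = (int32_t)((uint32_t)X << 24) >> 24; // shl 24; ashr 24
        assert(ViaCast == ViaShifts);
      }
      return 0;
    }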
for (unsigned i = 0; i != NumElts; ++i) { + if (isa<UndefValue>(CV->getAggregateElement(i))) + continue; + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) return nullptr; @@ -1688,7 +1687,7 @@ static Type *shrinkFPConstantVector(Value *V) { } // Make a vector type from the minimal type. - return FixedVectorType::get(MinType, NumElts); + return MinType ? FixedVectorType::get(MinType, NumElts) : nullptr; } /// Find the minimum FP type we can safely truncate to. @@ -2862,21 +2861,27 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { } } - // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as - // a byte-swap: - // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X) - // TODO: We should match the related pattern for bitreverse. - if (DestTy->isIntegerTy() && - DL.isLegalInteger(DestTy->getScalarSizeInBits()) && - SrcTy->getScalarSizeInBits() == 8 && - ShufElts.getKnownMinValue() % 2 == 0 && Shuf->hasOneUse() && - Shuf->isReverse()) { - assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); - assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op"); - Function *Bswap = - Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy); - Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); - return CallInst::Create(Bswap, { ScalarX }); + // A bitcasted-to-scalar and byte/bit reversing shuffle is better recognized + // as a byte/bit swap: + // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) -> bswap (bitcast X) + // bitcast <N x i1> (shuf X, undef, <N, N-1,...0>) -> bitreverse (bitcast X) + if (DestTy->isIntegerTy() && ShufElts.getKnownMinValue() % 2 == 0 && + Shuf->hasOneUse() && Shuf->isReverse()) { + unsigned IntrinsicNum = 0; + if (DL.isLegalInteger(DestTy->getScalarSizeInBits()) && + SrcTy->getScalarSizeInBits() == 8) { + IntrinsicNum = Intrinsic::bswap; + } else if (SrcTy->getScalarSizeInBits() == 1) { + IntrinsicNum = Intrinsic::bitreverse; + } + if (IntrinsicNum != 0) { + assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask"); + assert(match(ShufOp1, m_Undef()) && "Unexpected shuffle op"); + Function *BswapOrBitreverse = + Intrinsic::getDeclaration(CI.getModule(), IntrinsicNum, DestTy); + Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy); + return CallInst::Create(BswapOrBitreverse, {ScalarX}); + } } } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 158d2e8289e0..1480a0ff9e2f 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,6 +17,7 @@ #include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -281,7 +282,7 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( if (!GEP->isInBounds()) { Type *IntPtrTy = DL.getIntPtrType(GEP->getType()); unsigned PtrSize = IntPtrTy->getIntegerBitWidth(); - if (Idx->getType()->getPrimitiveSizeInBits().getFixedSize() > PtrSize) + if (Idx->getType()->getPrimitiveSizeInBits().getFixedValue() > PtrSize) Idx = Builder.CreateTrunc(Idx, IntPtrTy); } @@ -403,108 +404,6 @@ Instruction *InstCombinerImpl::foldCmpLoadFromIndexedGlobal( return nullptr; } -/// Return a value that can be used to compare the *offset* implied by a GEP to -/// 
zero. For example, if we have &A[i], we want to return 'i' for -/// "icmp ne i, 0". Note that, in general, indices can be complex, and scales -/// are involved. The above expression would also be legal to codegen as -/// "icmp ne (i*4), 0" (assuming A is a pointer to i32). -/// This latter form is less amenable to optimization though, and we are allowed -/// to generate the first by knowing that pointer arithmetic doesn't overflow. -/// -/// If we can't emit an optimized form for this expression, this returns null. -/// -static Value *evaluateGEPOffsetExpression(User *GEP, InstCombinerImpl &IC, - const DataLayout &DL) { - gep_type_iterator GTI = gep_type_begin(GEP); - - // Check to see if this gep only has a single variable index. If so, and if - // any constant indices are a multiple of its scale, then we can compute this - // in terms of the scale of the variable index. For example, if the GEP - // implies an offset of "12 + i*4", then we can codegen this as "3 + i", - // because the expression will cross zero at the same point. - unsigned i, e = GEP->getNumOperands(); - int64_t Offset = 0; - for (i = 1; i != e; ++i, ++GTI) { - if (ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i))) { - // Compute the aggregate offset of constant indices. - if (CI->isZero()) continue; - - // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); - } else { - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size*CI->getSExtValue(); - } - } else { - // Found our variable index. - break; - } - } - - // If there are no variable indices, we must have a constant offset, just - // evaluate it the general way. - if (i == e) return nullptr; - - Value *VariableIdx = GEP->getOperand(i); - // Determine the scale factor of the variable element. For example, this is - // 4 if the variable index is into an array of i32. - uint64_t VariableScale = DL.getTypeAllocSize(GTI.getIndexedType()); - - // Verify that there are no other variable indices. If so, emit the hard way. - for (++i, ++GTI; i != e; ++i, ++GTI) { - ConstantInt *CI = dyn_cast<ConstantInt>(GEP->getOperand(i)); - if (!CI) return nullptr; - - // Compute the aggregate offset of constant indices. - if (CI->isZero()) continue; - - // Handle a struct index, which adds its field offset to the pointer. - if (StructType *STy = GTI.getStructTypeOrNull()) { - Offset += DL.getStructLayout(STy)->getElementOffset(CI->getZExtValue()); - } else { - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size*CI->getSExtValue(); - } - } - - // Okay, we know we have a single variable index, which must be a - // pointer/array/vector index. If there is no offset, life is simple, return - // the index. - Type *IntPtrTy = DL.getIntPtrType(GEP->getOperand(0)->getType()); - unsigned IntPtrWidth = IntPtrTy->getIntegerBitWidth(); - if (Offset == 0) { - // Cast to intptrty in case a truncation occurs. If an extension is needed, - // we don't need to bother extending: the extension won't affect where the - // computation crosses zero. - if (VariableIdx->getType()->getPrimitiveSizeInBits().getFixedSize() > - IntPtrWidth) { - VariableIdx = IC.Builder.CreateTrunc(VariableIdx, IntPtrTy); - } - return VariableIdx; - } - - // Otherwise, there is an index. The computation we will do will be modulo - // the pointer size. 
- Offset = SignExtend64(Offset, IntPtrWidth); - VariableScale = SignExtend64(VariableScale, IntPtrWidth); - - // To do this transformation, any constant index must be a multiple of the - // variable scale factor. For example, we can evaluate "12 + 4*i" as "3 + i", - // but we can't evaluate "10 + 3*i" in terms of i. Check that the offset is a - // multiple of the variable scale. - int64_t NewOffs = Offset / (int64_t)VariableScale; - if (Offset != NewOffs*(int64_t)VariableScale) - return nullptr; - - // Okay, we can do this evaluation. Start by converting the index to intptr. - if (VariableIdx->getType() != IntPtrTy) - VariableIdx = IC.Builder.CreateIntCast(VariableIdx, IntPtrTy, - true /*Signed*/); - Constant *OffsetVal = ConstantInt::get(IntPtrTy, NewOffs); - return IC.Builder.CreateAdd(VariableIdx, OffsetVal, "offset"); -} - /// Returns true if we can rewrite Start as a GEP with pointer Base /// and some integer offset. The nodes that need to be re-written /// for this transformation will be added to Explored. @@ -732,8 +631,8 @@ static Value *rewriteGEPAsOffset(Type *ElemTy, Value *Start, Value *Base, // Cast base to the expected type. Value *NewVal = Builder.CreateBitOrPointerCast( Base, PtrTy, Start->getName() + "to.ptr"); - NewVal = Builder.CreateInBoundsGEP( - ElemTy, NewVal, makeArrayRef(NewInsts[Val]), Val->getName() + ".ptr"); + NewVal = Builder.CreateInBoundsGEP(ElemTy, NewVal, ArrayRef(NewInsts[Val]), + Val->getName() + ".ptr"); NewVal = Builder.CreateBitOrPointerCast( NewVal, Val->getType(), Val->getName() + ".conv"); Val->replaceAllUsesWith(NewVal); @@ -841,18 +740,9 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, RHS = RHS->stripPointerCasts(); Value *PtrBase = GEPLHS->getOperand(0); - // FIXME: Support vector pointer GEPs. - if (PtrBase == RHS && GEPLHS->isInBounds() && - !GEPLHS->getType()->isVectorTy()) { + if (PtrBase == RHS && GEPLHS->isInBounds()) { // ((gep Ptr, OFFSET) cmp Ptr) ---> (OFFSET cmp 0). - // This transformation (ignoring the base and scales) is valid because we - // know pointers can't overflow since the gep is inbounds. See if we can - // output an optimized form. - Value *Offset = evaluateGEPOffsetExpression(GEPLHS, *this, DL); - - // If not, synthesize the offset the hard way. - if (!Offset) - Offset = EmitGEPOffset(GEPLHS); + Value *Offset = EmitGEPOffset(GEPLHS); return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset, Constant::getNullValue(Offset->getType())); } @@ -926,8 +816,8 @@ Instruction *InstCombinerImpl::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, Type *LHSIndexTy = LOffset->getType(); Type *RHSIndexTy = ROffset->getType(); if (LHSIndexTy != RHSIndexTy) { - if (LHSIndexTy->getPrimitiveSizeInBits().getFixedSize() < - RHSIndexTy->getPrimitiveSizeInBits().getFixedSize()) { + if (LHSIndexTy->getPrimitiveSizeInBits().getFixedValue() < + RHSIndexTy->getPrimitiveSizeInBits().getFixedValue()) { ROffset = Builder.CreateTrunc(ROffset, LHSIndexTy); } else LOffset = Builder.CreateTrunc(LOffset, RHSIndexTy); @@ -1480,7 +1370,8 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { return nullptr; // Try to simplify this compare to T/F based on the dominating condition. 
- Optional<bool> Imp = isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); + std::optional<bool> Imp = + isImpliedCondition(DomCond, &Cmp, DL, TrueBB == CmpBB); if (Imp) return replaceInstUsesWith(Cmp, ConstantInt::get(Cmp.getType(), *Imp)); @@ -1548,16 +1439,34 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } + Type *SrcTy = X->getType(); unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), - SrcBits = X->getType()->getScalarSizeInBits(); + SrcBits = SrcTy->getScalarSizeInBits(); + + // TODO: Handle any shifted constant by subtracting trailing zeros. + // TODO: Handle non-equality predicates. + Value *Y; + if (Cmp.isEquality() && match(X, m_Shl(m_One(), m_Value(Y)))) { + // (trunc (1 << Y) to iN) == 0 --> Y u>= N + // (trunc (1 << Y) to iN) != 0 --> Y u< N + if (C.isZero()) { + auto NewPred = (Pred == Cmp.ICMP_EQ) ? Cmp.ICMP_UGE : Cmp.ICMP_ULT; + return new ICmpInst(NewPred, Y, ConstantInt::get(SrcTy, DstBits)); + } + // (trunc (1 << Y) to iN) == 2**C --> Y == C + // (trunc (1 << Y) to iN) != 2**C --> Y != C + if (C.isPowerOf2()) + return new ICmpInst(Pred, Y, ConstantInt::get(SrcTy, C.logBase2())); + } + if (Cmp.isEquality() && Trunc->hasOneUse()) { // Canonicalize to a mask and wider compare if the wide type is suitable: // (trunc X to i8) == C --> (X & 0xff) == (zext C) - if (!X->getType()->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { - Constant *Mask = ConstantInt::get(X->getType(), - APInt::getLowBitsSet(SrcBits, DstBits)); + if (!SrcTy->isVectorTy() && shouldChangeType(DstBits, SrcBits)) { + Constant *Mask = + ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcBits, DstBits)); Value *And = Builder.CreateAnd(X, Mask); - Constant *WideC = ConstantInt::get(X->getType(), C.zext(SrcBits)); + Constant *WideC = ConstantInt::get(SrcTy, C.zext(SrcBits)); return new ICmpInst(Pred, And, WideC); } @@ -1570,7 +1479,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, // Pull in the high bits from known-ones set. APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); + return new ICmpInst(Pred, X, ConstantInt::get(SrcTy, NewRHS)); } } @@ -1583,11 +1492,10 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, if (isSignBitCheck(Pred, C, TrueIfSigned) && match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) && DstBits == SrcBits - ShAmtC->getZExtValue()) { - return TrueIfSigned - ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp, - ConstantInt::getNullValue(X->getType())) - : new ICmpInst(ICmpInst::ICMP_SGT, ShOp, - ConstantInt::getAllOnesValue(X->getType())); + return TrueIfSigned ? 
new ICmpInst(ICmpInst::ICMP_SLT, ShOp, + ConstantInt::getNullValue(SrcTy)) + : new ICmpInst(ICmpInst::ICMP_SGT, ShOp, + ConstantInt::getAllOnesValue(SrcTy)); } return nullptr; @@ -1597,6 +1505,9 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, const APInt &C) { + if (Instruction *I = foldICmpXorShiftConst(Cmp, Xor, C)) + return I; + Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1660,6 +1571,37 @@ Instruction *InstCombinerImpl::foldICmpXorConstant(ICmpInst &Cmp, return nullptr; } +/// For power-of-2 C: +/// ((X s>> ShiftC) ^ X) u< C --> (X + C) u< (C << 1) +/// ((X s>> ShiftC) ^ X) u> (C - 1) --> (X + C) u> ((C << 1) - 1) +Instruction *InstCombinerImpl::foldICmpXorShiftConst(ICmpInst &Cmp, + BinaryOperator *Xor, + const APInt &C) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + APInt PowerOf2; + if (Pred == ICmpInst::ICMP_ULT) + PowerOf2 = C; + else if (Pred == ICmpInst::ICMP_UGT && !C.isMaxValue()) + PowerOf2 = C + 1; + else + return nullptr; + if (!PowerOf2.isPowerOf2()) + return nullptr; + Value *X; + const APInt *ShiftC; + if (!match(Xor, m_OneUse(m_c_Xor(m_Value(X), + m_AShr(m_Deferred(X), m_APInt(ShiftC)))))) + return nullptr; + uint64_t Shift = ShiftC->getLimitedValue(); + Type *XType = X->getType(); + if (Shift == 0 || PowerOf2.isMinSignedValue()) + return nullptr; + Value *Add = Builder.CreateAdd(X, ConstantInt::get(XType, PowerOf2)); + APInt Bound = + Pred == ICmpInst::ICMP_ULT ? PowerOf2 << 1 : ((PowerOf2 << 1) - 1); + return new ICmpInst(Pred, Add, ConstantInt::get(XType, Bound)); +} + /// Fold icmp (and (sh X, Y), C2), C1. Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, @@ -1780,7 +1722,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, APInt NewC2 = *C2; KnownBits Know = computeKnownBits(And->getOperand(0), 0, And); // Set high zeros of C2 to allow matching negated power-of-2. - NewC2 = *C2 + APInt::getHighBitsSet(C2->getBitWidth(), + NewC2 = *C2 | APInt::getHighBitsSet(C2->getBitWidth(), Know.countMinLeadingZeros()); // Restrict this fold only for single-use 'and' (PR10267). 
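The `foldICmpXorShiftConst` fold added above rewrites an ashr/xor range test into an add/compare. A minimal standalone sketch of why the i8 instance holds, assuming two's-complement wrap-around and an arithmetic right shift of negative values (guaranteed since C++20); the `main` harness, the fixed shift of 7, and the constant 16 are illustrative choices, not from the patch:

```cpp
// Brute-force check of one i8 instance of:
//   ((X s>> 7) ^ X) u< 16   <->   (X + 16) u< 32
#include <cassert>
#include <cstdint>

int main() {
  for (int I = 0; I < 256; ++I) {
    int8_t X = static_cast<int8_t>(I);
    uint8_t Xor = static_cast<uint8_t>((X >> 7) ^ X); // (X s>> 7) ^ X in i8
    uint8_t Add = static_cast<uint8_t>(X + 16);       // X + 16 with i8 wrap
    assert((Xor < 16) == (Add < 32));
  }
  return 0;
}
```

The ugt form in the comment is the logical negation of the ult form, so the same loop covers it.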
@@ -1904,6 +1846,20 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); } + // ((zext i1 X) & Y) == 0 --> !((trunc Y) & X) + // ((zext i1 X) & Y) != 0 --> ((trunc Y) & X) + // ((zext i1 X) & Y) == 1 --> ((trunc Y) & X) + // ((zext i1 X) & Y) != 1 --> !((trunc Y) & X) + if (match(And, m_OneUse(m_c_And(m_OneUse(m_ZExt(m_Value(X))), m_Value(Y)))) && + X->getType()->isIntOrIntVectorTy(1) && (C.isZero() || C.isOne())) { + Value *TruncY = Builder.CreateTrunc(Y, X->getType()); + if (C.isZero() ^ (Pred == CmpInst::ICMP_NE)) { + Value *And = Builder.CreateAnd(TruncY, X); + return BinaryOperator::CreateNot(And); + } + return BinaryOperator::CreateAnd(TruncY, X); + } + return nullptr; } @@ -1988,21 +1944,32 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, const APInt &C) { + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Type *MulTy = Mul->getType(); + Value *X = Mul->getOperand(0); + + // If there's no overflow: + // X * X == 0 --> X == 0 + // X * X != 0 --> X != 0 + if (Cmp.isEquality() && C.isZero() && X == Mul->getOperand(1) && + (Mul->hasNoUnsignedWrap() || Mul->hasNoSignedWrap())) + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); + const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; // If this is a test of the sign bit and the multiply is sign-preserving with - // a constant operand, use the multiply LHS operand instead. - ICmpInst::Predicate Pred = Cmp.getPredicate(); + // a constant operand, use the multiply LHS operand instead: + // (X * +MulC) < 0 --> X < 0 + // (X * -MulC) < 0 --> X > 0 if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - return new ICmpInst(Pred, Mul->getOperand(0), - Constant::getNullValue(Mul->getType())); + return new ICmpInst(Pred, X, ConstantInt::getNullValue(MulTy)); } - if (MulC->isZero() || !(Mul->hasNoSignedWrap() || Mul->hasNoUnsignedWrap())) + if (MulC->isZero() || (!Mul->hasNoSignedWrap() && !Mul->hasNoUnsignedWrap())) return nullptr; // If the multiply does not wrap, try to divide the compare constant by the @@ -2010,50 +1977,45 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, if (Cmp.isEquality()) { // (mul nsw X, MulC) == C --> X == C /s MulC if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.sdiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } // (mul nuw X, MulC) == C --> X == C /u MulC if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { - Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); - return new ICmpInst(Pred, Mul->getOperand(0), NewC); + Constant *NewC = ConstantInt::get(MulTy, C.udiv(*MulC)); + return new ICmpInst(Pred, X, NewC); } } + // With a matching no-overflow guarantee, fold the constants: + // (X * MulC) < C --> X < (C / MulC) + // (X * MulC) > C --> X > (C / MulC) + // TODO: Assert that Pred is not equal to SGE, SLE, UGE, ULE? Constant *NewC = nullptr; - - // FIXME: Add assert that Pred is not equal to ICMP_SGE, ICMP_SLE, - // ICMP_UGE, ICMP_ULE. - if (Mul->hasNoSignedWrap()) { - if (MulC->isNegative()) { - // MININT / -1 --> overflow. 
- if (C.isMinSignedValue() && MulC->isAllOnes()) - return nullptr; + // MININT / -1 --> overflow. + if (C.isMinSignedValue() && MulC->isAllOnes()) + return nullptr; + if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); - } + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_SGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); - } - - if (Mul->hasNoUnsignedWrap()) { + MulTy, APIntOps::RoundingSDiv(C, *MulC, APInt::Rounding::DOWN)); + } else { + assert(Mul->hasNoUnsignedWrap() && "Expected mul nuw"); if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::UP)); if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) NewC = ConstantInt::get( - Mul->getType(), - APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); + MulTy, APIntOps::RoundingUDiv(C, *MulC, APInt::Rounding::DOWN)); } - return NewC ? new ICmpInst(Pred, Mul->getOperand(0), NewC) : nullptr; + return NewC ? new ICmpInst(Pred, X, NewC) : nullptr; } /// Fold icmp (shl 1, Y), C. @@ -2080,39 +2042,21 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, Pred = ICmpInst::ICMP_UGT; } - // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 - // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 unsigned CLog2 = C.logBase2(); - if (CLog2 == TypeBits - 1) { - if (Pred == ICmpInst::ICMP_UGE) - Pred = ICmpInst::ICMP_EQ; - else if (Pred == ICmpInst::ICMP_ULT) - Pred = ICmpInst::ICMP_NE; - } return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C.isAllOnes()) { - // (1 << Y) <= -1 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); - - // (1 << Y) > -1 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!C) { - // (1 << Y) < 0 -> Y == 31 - // (1 << Y) <= 0 -> Y == 31 - if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) - return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); + // (1 << Y) > 0 -> Y != 31 + // (1 << Y) > C -> Y != 31 if C is negative. + if (Pred == ICmpInst::ICMP_SGT && C.sle(0)) + return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - // (1 << Y) >= 0 -> Y != 31 - // (1 << Y) > 0 -> Y != 31 - if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) - return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } - } else if (Cmp.isEquality() && CIsPowerOf2) { - return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); + // (1 << Y) < 0 -> Y == 31 + // (1 << Y) < 1 -> Y == 31 + // (1 << Y) < C -> Y == 31 if C is negative and not signed min. + // Exclude signed min by subtracting 1 and lower the upper bound to 0. 
+ if (Pred == ICmpInst::ICMP_SLT && (C-1).sle(0)) + return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); } return nullptr; @@ -2833,6 +2777,13 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, if (Pred == CmpInst::ICMP_SLT && C == *C2) return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax)); + // (X + -1) <u C --> X <=u C (if X is never null) + if (Pred == CmpInst::ICMP_ULT && C2->isAllOnes()) { + const SimplifyQuery Q = SQ.getWithInstruction(&Cmp); + if (llvm::isKnownNonZero(X, DL, 0, Q.AC, Q.CxtI, Q.DT)) + return new ICmpInst(ICmpInst::ICMP_ULE, X, ConstantInt::get(Ty, C)); + } + if (!Add->hasOneUse()) return nullptr; @@ -3095,7 +3046,7 @@ Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { ArrayRef<int> Mask; if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { // Check whether every element of Mask is the same constant - if (is_splat(Mask)) { + if (all_equal(Mask)) { auto *VecTy = cast<VectorType>(SrcType); auto *EltTy = cast<IntegerType>(VecTy->getElementType()); if (C->isSplat(EltTy->getBitWidth())) { @@ -3139,6 +3090,20 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstant(ICmpInst &Cmp) { if (auto *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0))) if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, II, *C)) return I; + + // (extractval ([s/u]subo X, Y), 0) == 0 --> X == Y + // (extractval ([s/u]subo X, Y), 0) != 0 --> X != Y + // TODO: This checks one-use, but that is not strictly necessary. + Value *Cmp0 = Cmp.getOperand(0); + Value *X, *Y; + if (C->isZero() && Cmp.isEquality() && Cmp0->hasOneUse() && + (match(Cmp0, + m_ExtractValue<0>(m_Intrinsic<Intrinsic::ssub_with_overflow>( + m_Value(X), m_Value(Y)))) || + match(Cmp0, + m_ExtractValue<0>(m_Intrinsic<Intrinsic::usub_with_overflow>( + m_Value(X), m_Value(Y)))))) + return new ICmpInst(Cmp.getPredicate(), X, Y); } if (match(Cmp.getOperand(1), m_APIntAllowUndef(C))) @@ -3174,10 +3139,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( } break; case Instruction::Add: { - // Replace ((add A, B) != C) with (A != C-B) if B & C are constants. - if (Constant *BOC = dyn_cast<Constant>(BOp1)) { + // (A + C2) == C --> A == (C - C2) + // (A + C2) != C --> A != (C - C2) + // TODO: Remove the one-use limitation? See discussion in D58633. + if (Constant *C2 = dyn_cast<Constant>(BOp1)) { if (BO->hasOneUse()) - return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); + return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, C2)); } else if (C.isZero()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. 
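Among the compare folds added above, `(X + -1) <u C --> X <=u C` relies on the known-non-zero precondition to rule out wrap-around of `X - 1`. A small brute-force sketch of the i8 case, assuming the usual unsigned wrap-around semantics (the `main` harness is illustrative only):

```cpp
// Checks (X + -1) u< C  <->  X u<= C for every non-zero i8 X and every i8 C.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned XI = 1; XI < 256; ++XI) {   // X == 0 is excluded by the fold
    for (unsigned CI = 0; CI < 256; ++CI) {
      uint8_t X = static_cast<uint8_t>(XI);
      uint8_t C = static_cast<uint8_t>(CI);
      uint8_t Dec = static_cast<uint8_t>(X - 1); // X + (-1); no wrap since X != 0
      assert((Dec < C) == (X <= C));
    }
  }
  return 0;
}
```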
@@ -3433,7 +3400,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpWithConstant(ICmpInst &Cmp, case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) return I; - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::SDiv: if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) return I; @@ -3580,8 +3547,8 @@ Instruction *InstCombinerImpl::foldSelectICmp(ICmpInst::Predicate Pred, auto SimplifyOp = [&](Value *Op, bool SelectCondIsTrue) -> Value * { if (Value *Res = simplifyICmpInst(Pred, Op, RHS, SQ)) return Res; - if (Optional<bool> Impl = isImpliedCondition(SI->getCondition(), Pred, Op, - RHS, DL, SelectCondIsTrue)) + if (std::optional<bool> Impl = isImpliedCondition( + SI->getCondition(), Pred, Op, RHS, DL, SelectCondIsTrue)) return ConstantInt::get(I.getType(), *Impl); return nullptr; }; @@ -4488,6 +4455,18 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, } } + // For unsigned predicates / eq / ne: + // icmp pred (x << 1), x --> icmp getSignedPredicate(pred) x, 0 + // icmp pred x, (x << 1) --> icmp getSignedPredicate(pred) 0, x + if (!ICmpInst::isSigned(Pred)) { + if (match(Op0, m_Shl(m_Specific(Op1), m_One()))) + return new ICmpInst(ICmpInst::getSignedPredicate(Pred), Op1, + Constant::getNullValue(Op1->getType())); + else if (match(Op1, m_Shl(m_Specific(Op0), m_One()))) + return new ICmpInst(ICmpInst::getSignedPredicate(Pred), + Constant::getNullValue(Op0->getType()), Op0); + } + if (Value *V = foldMultiplicationOverflowCheck(I)) return replaceInstUsesWith(I, V); @@ -4674,17 +4653,29 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } } - // Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B) - // and (B & (1<<X)-1) == (zext A) --> A == (trunc B) - ConstantInt *Cst1; - if ((Op0->hasOneUse() && match(Op0, m_ZExt(m_Value(A))) && - match(Op1, m_And(m_Value(B), m_ConstantInt(Cst1)))) || - (Op1->hasOneUse() && match(Op0, m_And(m_Value(B), m_ConstantInt(Cst1))) && - match(Op1, m_ZExt(m_Value(A))))) { - APInt Pow2 = Cst1->getValue() + 1; - if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && - Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) + if (match(Op1, m_ZExt(m_Value(A))) && + (Op0->hasOneUse() || Op1->hasOneUse())) { + // (B & (Pow2C-1)) == zext A --> A == trunc B + // (B & (Pow2C-1)) != zext A --> A != trunc B + const APInt *MaskC; + if (match(Op0, m_And(m_Value(B), m_LowBitMask(MaskC))) && + MaskC->countTrailingOnes() == A->getType()->getScalarSizeInBits()) return new ICmpInst(Pred, A, Builder.CreateTrunc(B, A->getType())); + + // Test if 2 values have different or same signbits: + // (X u>> BitWidth - 1) == zext (Y s> -1) --> (X ^ Y) < 0 + // (X u>> BitWidth - 1) != zext (Y s> -1) --> (X ^ Y) > -1 + unsigned OpWidth = Op0->getType()->getScalarSizeInBits(); + Value *X, *Y; + ICmpInst::Predicate Pred2; + if (match(Op0, m_LShr(m_Value(X), m_SpecificIntAllowUndef(OpWidth - 1))) && + match(A, m_ICmp(Pred2, m_Value(Y), m_AllOnes())) && + Pred2 == ICmpInst::ICMP_SGT && X->getType() == Y->getType()) { + Value *Xor = Builder.CreateXor(X, Y, "xor.signbits"); + Value *R = (Pred == ICmpInst::ICMP_EQ) ? 
Builder.CreateIsNeg(Xor) : + Builder.CreateIsNotNeg(Xor); + return replaceInstUsesWith(I, R); + } } // (A >> C) == (B >> C) --> (A^B) u< (1 << C) @@ -4708,6 +4699,7 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { } // (A << C) == (B << C) --> ((A^B) & (~0U >> C)) == 0 + ConstantInt *Cst1; if (match(Op0, m_OneUse(m_Shl(m_Value(A), m_ConstantInt(Cst1)))) && match(Op1, m_OneUse(m_Shl(m_Value(B), m_Specific(Cst1))))) { unsigned TypeBits = Cst1->getBitWidth(); @@ -4788,6 +4780,20 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) { Add, ConstantInt::get(A->getType(), C.shl(1))); } + // Canonicalize: + // Assume B_Pow2 != 0 + // 1. A & B_Pow2 != B_Pow2 -> A & B_Pow2 == 0 + // 2. A & B_Pow2 == B_Pow2 -> A & B_Pow2 != 0 + if (match(Op0, m_c_And(m_Specific(Op1), m_Value())) && + isKnownToBeAPowerOfTwo(Op1, /* OrZero */ false, 0, &I)) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op0, + ConstantInt::getNullValue(Op0->getType())); + + if (match(Op1, m_c_And(m_Specific(Op0), m_Value())) && + isKnownToBeAPowerOfTwo(Op0, /* OrZero */ false, 0, &I)) + return new ICmpInst(CmpInst::getInversePredicate(Pred), Op1, + ConstantInt::getNullValue(Op1->getType())); + return nullptr; } @@ -4993,7 +4999,7 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { return foldICmpWithZextOrSext(ICmp); } -static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { +static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS, bool IsSigned) { switch (BinaryOp) { default: llvm_unreachable("Unsupported binary op"); @@ -5001,7 +5007,8 @@ static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { case Instruction::Sub: return match(RHS, m_Zero()); case Instruction::Mul: - return match(RHS, m_One()); + return !(RHS->getType()->isIntOrIntVectorTy(1) && IsSigned) && + match(RHS, m_One()); } } @@ -5048,7 +5055,7 @@ bool InstCombinerImpl::OptimizeOverflowCheck(Instruction::BinaryOps BinaryOp, if (auto *LHSTy = dyn_cast<VectorType>(LHS->getType())) OverflowTy = VectorType::get(OverflowTy, LHSTy->getElementCount()); - if (isNeutralValue(BinaryOp, RHS)) { + if (isNeutralValue(BinaryOp, RHS, IsSigned)) { Result = LHS; Overflow = ConstantInt::getFalse(OverflowTy); return true; @@ -5746,7 +5753,7 @@ static Instruction *foldICmpUsingBoolRange(ICmpInst &I, return nullptr; } -llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +std::optional<std::pair<CmpInst::Predicate, Constant *>> InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C) { assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && @@ -5769,13 +5776,13 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, if (auto *CI = dyn_cast<ConstantInt>(C)) { // Bail out if the constant can't be safely incremented/decremented. if (!ConstantIsOk(CI)) - return llvm::None; + return std::nullopt; } else if (auto *FVTy = dyn_cast<FixedVectorType>(Type)) { unsigned NumElts = FVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { Constant *Elt = C->getAggregateElement(i); if (!Elt) - return llvm::None; + return std::nullopt; if (isa<UndefValue>(Elt)) continue; @@ -5784,14 +5791,14 @@ InstCombiner::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); if (!CI || !ConstantIsOk(CI)) - return llvm::None; + return std::nullopt; if (!SafeReplacementConstant) SafeReplacementConstant = CI; } } else { // ConstantExpr? 
- return llvm::None; + return std::nullopt; } // It may not be safe to change a compare predicate in the presence of @@ -5901,7 +5908,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_UGT: // icmp ugt -> icmp ult std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_ULT: // icmp ult i1 A, B -> ~A & B return BinaryOperator::CreateAnd(Builder.CreateNot(A), B); @@ -5909,7 +5916,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_SGT: // icmp sgt -> icmp slt std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SLT: // icmp slt i1 A, B -> A & ~B return BinaryOperator::CreateAnd(Builder.CreateNot(B), A); @@ -5917,7 +5924,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_UGE: // icmp uge -> icmp ule std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_ULE: // icmp ule i1 A, B -> ~A | B return BinaryOperator::CreateOr(Builder.CreateNot(A), B); @@ -5925,7 +5932,7 @@ static Instruction *canonicalizeICmpBool(ICmpInst &I, case ICmpInst::ICMP_SGE: // icmp sge -> icmp sle std::swap(A, B); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::ICMP_SLE: // icmp sle i1 A, B -> A | ~B return BinaryOperator::CreateOr(Builder.CreateNot(B), A); @@ -5986,6 +5993,31 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, const CmpInst::Predicate Pred = Cmp.getPredicate(); Value *LHS = Cmp.getOperand(0), *RHS = Cmp.getOperand(1); Value *V1, *V2; + + auto createCmpReverse = [&](CmpInst::Predicate Pred, Value *X, Value *Y) { + Value *V = Builder.CreateCmp(Pred, X, Y, Cmp.getName()); + if (auto *I = dyn_cast<Instruction>(V)) + I->copyIRFlags(&Cmp); + Module *M = Cmp.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + if (match(LHS, m_VecReverse(m_Value(V1)))) { + // cmp Pred, rev(V1), rev(V2) --> rev(cmp Pred, V1, V2) + if (match(RHS, m_VecReverse(m_Value(V2))) && + (LHS->hasOneUse() || RHS->hasOneUse())) + return createCmpReverse(Pred, V1, V2); + + // cmp Pred, rev(V1), RHSSplat --> rev(cmp Pred, V1, RHSSplat) + if (LHS->hasOneUse() && isSplatValue(RHS)) + return createCmpReverse(Pred, V1, RHS); + } + // cmp Pred, LHSSplat, rev(V2) --> rev(cmp Pred, LHSSplat, V2) + else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2))))) + return createCmpReverse(Pred, LHS, V2); + ArrayRef<int> M; if (!match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(M)))) return nullptr; @@ -6318,11 +6350,11 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } // (zext a) * (zext b) --> llvm.umul.with.overflow. - if (match(Op0, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (match(Op0, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { if (Instruction *R = processUMulZExtIdiom(I, Op0, Op1, *this)) return R; } - if (match(Op1, m_Mul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { + if (match(Op1, m_NUWMul(m_ZExt(m_Value(A)), m_ZExt(m_Value(B))))) { if (Instruction *R = processUMulZExtIdiom(I, Op1, Op0, *this)) return R; } @@ -6668,10 +6700,48 @@ static Instruction *foldFCmpReciprocalAndZero(FCmpInst &I, Instruction *LHSI, /// Optimize fabs(X) compared with zero. 
static Instruction *foldFabsWithFcmpZero(FCmpInst &I, InstCombinerImpl &IC) { Value *X; - if (!match(I.getOperand(0), m_FAbs(m_Value(X))) || - !match(I.getOperand(1), m_PosZeroFP())) + if (!match(I.getOperand(0), m_FAbs(m_Value(X)))) + return nullptr; + + const APFloat *C; + if (!match(I.getOperand(1), m_APFloat(C))) return nullptr; + if (!C->isPosZero()) { + if (!C->isSmallestNormalized()) + return nullptr; + + const Function *F = I.getFunction(); + DenormalMode Mode = F->getDenormalMode(C->getSemantics()); + if (Mode.Input == DenormalMode::PreserveSign || + Mode.Input == DenormalMode::PositiveZero) { + + auto replaceFCmp = [](FCmpInst *I, FCmpInst::Predicate P, Value *X) { + Constant *Zero = ConstantFP::getNullValue(X->getType()); + return new FCmpInst(P, X, Zero, "", I); + }; + + switch (I.getPredicate()) { + case FCmpInst::FCMP_OLT: + // fcmp olt fabs(x), smallest_normalized_number -> fcmp oeq x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_OEQ, X); + case FCmpInst::FCMP_UGE: + // fcmp uge fabs(x), smallest_normalized_number -> fcmp une x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_UNE, X); + case FCmpInst::FCMP_OGE: + // fcmp oge fabs(x), smallest_normalized_number -> fcmp one x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_ONE, X); + case FCmpInst::FCMP_ULT: + // fcmp ult fabs(x), smallest_normalized_number -> fcmp ueq x, 0.0 + return replaceFCmp(&I, FCmpInst::FCMP_UEQ, X); + default: + break; + } + } + + return nullptr; + } + auto replacePredAndOp0 = [&IC](FCmpInst *I, FCmpInst::Predicate P, Value *X) { I->setPredicate(P); return IC.replaceOperand(*I, 0, X); @@ -6828,6 +6898,26 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op1, m_AnyZeroFP()) && !match(Op1, m_PosZeroFP())) return replaceOperand(I, 1, ConstantFP::getNullValue(OpType)); + // Ignore signbit of bitcasted int when comparing equality to FP 0.0: + // fcmp oeq/une (bitcast X), 0.0 --> (and X, SignMaskC) ==/!= 0 + if (match(Op1, m_PosZeroFP()) && + match(Op0, m_OneUse(m_BitCast(m_Value(X)))) && + X->getType()->isVectorTy() == OpType->isVectorTy() && + X->getType()->getScalarSizeInBits() == OpType->getScalarSizeInBits()) { + ICmpInst::Predicate IntPred = ICmpInst::BAD_ICMP_PREDICATE; + if (Pred == FCmpInst::FCMP_OEQ) + IntPred = ICmpInst::ICMP_EQ; + else if (Pred == FCmpInst::FCMP_UNE) + IntPred = ICmpInst::ICMP_NE; + + if (IntPred != ICmpInst::BAD_ICMP_PREDICATE) { + Type *IntTy = X->getType(); + const APInt &SignMask = ~APInt::getSignMask(IntTy->getScalarSizeInBits()); + Value *MaskX = Builder.CreateAnd(X, ConstantInt::get(IntTy, SignMask)); + return new ICmpInst(IntPred, MaskX, ConstantInt::getNullValue(IntTy)); + } + } + // Handle fcmp with instruction LHS and constant RHS. 
Instruction *LHSI; Constant *RHSC; @@ -6866,10 +6956,9 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { if (match(Op0, m_FNeg(m_Value(X)))) { // fcmp pred (fneg X), C --> fcmp swap(pred) X, -C Constant *C; - if (match(Op1, m_Constant(C))) { - Constant *NegC = ConstantExpr::getFNeg(C); - return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); - } + if (match(Op1, m_Constant(C))) + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return new FCmpInst(I.getSwappedPredicate(), X, NegC, "", &I); } if (match(Op0, m_FPExt(m_Value(X)))) { @@ -6915,7 +7004,7 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { APFloat Fabs = TruncC; Fabs.clearSign(); if (!Lossy && - (!(Fabs < APFloat::getSmallestNormalized(FPSem)) || Fabs.isZero())) { + (Fabs.isZero() || !(Fabs < APFloat::getSmallestNormalized(FPSem)))) { Constant *NewC = ConstantFP::get(X->getType(), TruncC); return new FCmpInst(Pred, X, NewC, "", &I); } @@ -6942,6 +7031,24 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { } } + { + Value *CanonLHS = nullptr, *CanonRHS = nullptr; + match(Op0, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonLHS))); + match(Op1, m_Intrinsic<Intrinsic::canonicalize>(m_Value(CanonRHS))); + + // (canonicalize(x) == x) => (x == x) + if (CanonLHS == Op1) + return new FCmpInst(Pred, Op1, Op1, "", &I); + + // (x == canonicalize(x)) => (x == x) + if (CanonRHS == Op0) + return new FCmpInst(Pred, Op0, Op0, "", &I); + + // (canonicalize(x) == canonicalize(y)) => (x == y) + if (CanonLHS && CanonRHS) + return new FCmpInst(Pred, CanonLHS, CanonRHS, "", &I); + } + if (I.getType()->isVectorTy()) if (Instruction *Res = foldVectorCmp(I, Builder)) return Res; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 664226ec187b..f4e88b122383 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -106,7 +106,8 @@ public: Value *simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, bool Inverted); Instruction *visitAnd(BinaryOperator &I); Instruction *visitOr(BinaryOperator &I); - bool sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I); + bool sinkNotIntoLogicalOp(Instruction &I); + bool sinkNotIntoOtherHandOfLogicalOp(Instruction &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); Value *reassociateShiftAmtsOfTwoSameDirectionShifts( @@ -127,8 +128,8 @@ public: Instruction *commonCastTransforms(CastInst &CI); Instruction *commonPointerCastTransforms(CastInst &CI); Instruction *visitTrunc(TruncInst &CI); - Instruction *visitZExt(ZExtInst &CI); - Instruction *visitSExt(SExtInst &CI); + Instruction *visitZExt(ZExtInst &Zext); + Instruction *visitSExt(SExtInst &Sext); Instruction *visitFPTrunc(FPTruncInst &CI); Instruction *visitFPExt(CastInst &CI); Instruction *visitFPToUI(FPToUIInst &FI); @@ -167,6 +168,7 @@ public: Instruction *visitInsertValueInst(InsertValueInst &IV); Instruction *visitInsertElementInst(InsertElementInst &IE); Instruction *visitExtractElementInst(ExtractElementInst &EI); + Instruction *simplifyBinOpSplats(ShuffleVectorInst &SVI); Instruction *visitShuffleVectorInst(ShuffleVectorInst &SVI); Instruction *visitExtractValueInst(ExtractValueInst &EV); Instruction *visitLandingPadInst(LandingPadInst &LI); @@ -247,9 +249,9 @@ private: /// \return null if the transformation cannot be performed. 
If the /// transformation can be performed the new instruction that replaces the /// (zext icmp) pair will be returned. - Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI); + Instruction *transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext); - Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); + Instruction *transformSExtICmp(ICmpInst *Cmp, SExtInst &Sext); bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { @@ -329,7 +331,7 @@ private: Instruction *matchSAddSubSat(IntrinsicInst &MinMax1); Instruction *foldNot(BinaryOperator &I); - void freelyInvertAllUsersOf(Value *V); + void freelyInvertAllUsersOf(Value *V, Value *IgnoredUser = nullptr); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -360,14 +362,24 @@ private: Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd, bool IsLogicalSelect = false); + Instruction *foldLogicOfIsFPClass(BinaryOperator &Operator, Value *LHS, + Value *RHS); + + Instruction * + canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i); + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI, bool IsAnd, bool IsLogical = false); - Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D); - Value *getSelectCondition(Value *A, Value *B); + Value *matchSelectFromAndOr(Value *A, Value *B, Value *C, Value *D, + bool InvertFalseVal = false); + Value *getSelectCondition(Value *A, Value *B, bool ABIsTheSame); + Instruction *foldLShrOverflowBit(BinaryOperator &I); + Instruction *foldExtractOfOverflowIntrinsic(ExtractValueInst &EV); Instruction *foldIntrinsicWithOverflowCommon(IntrinsicInst *II); Instruction *foldFPSignBitOps(BinaryOperator &I); + Instruction *foldFDivConstantDivisor(BinaryOperator &I); // Optimize one of these forms: // and i1 Op, SI / select i1 Op, i1 SI, i1 false (if IsAnd = true) @@ -377,64 +389,6 @@ private: bool IsAnd); public: - /// Inserts an instruction \p New before instruction \p Old - /// - /// Also adds the new instruction to the worklist and returns \p New so that - /// it is suitable for use as the return from the visitation patterns. - Instruction *InsertNewInstBefore(Instruction *New, Instruction &Old) { - assert(New && !New->getParent() && - "New instruction already inserted into a basic block!"); - BasicBlock *BB = Old.getParent(); - BB->getInstList().insert(Old.getIterator(), New); // Insert inst - Worklist.add(New); - return New; - } - - /// Same as InsertNewInstBefore, but also sets the debug loc. - Instruction *InsertNewInstWith(Instruction *New, Instruction &Old) { - New->setDebugLoc(Old.getDebugLoc()); - return InsertNewInstBefore(New, Old); - } - - /// A combiner-aware RAUW-like routine. - /// - /// This method is to be used when an instruction is found to be dead, - /// replaceable with another preexisting expression. Here we add all uses of - /// I to the worklist, replace all uses of I with the new value, then return - /// I, so that the inst combiner will know that I was modified. - Instruction *replaceInstUsesWith(Instruction &I, Value *V) { - // If there are no uses to replace, then we return nullptr to indicate that - // no changes were made to the program. - if (I.use_empty()) return nullptr; - - Worklist.pushUsersToWorkList(I); // Add all modified instrs to worklist. - - // If we are replacing the instruction with itself, this must be in a - // segment of unreachable code, so just clobber the instruction. 
- if (&I == V) - V = PoisonValue::get(I.getType()); - - LLVM_DEBUG(dbgs() << "IC: Replacing " << I << "\n" - << " with " << *V << '\n'); - - I.replaceAllUsesWith(V); - MadeIRChange = true; - return &I; - } - - /// Replace operand of instruction and add old operand to the worklist. - Instruction *replaceOperand(Instruction &I, unsigned OpNum, Value *V) { - Worklist.addValue(I.getOperand(OpNum)); - I.setOperand(OpNum, V); - return &I; - } - - /// Replace use and add the previously used value to the worklist. - void replaceUse(Use &U, Value *NewValue) { - Worklist.addValue(U); - U = NewValue; - } - /// Create and insert the idiom we use to indicate a block is unreachable /// without having to rewrite the CFG from within InstCombine. void CreateNonTerminatorUnreachable(Instruction *InsertAt) { @@ -467,67 +421,6 @@ public: return nullptr; // Don't do anything with FI } - void computeKnownBits(const Value *V, KnownBits &Known, - unsigned Depth, const Instruction *CxtI) const { - llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT); - } - - KnownBits computeKnownBits(const Value *V, unsigned Depth, - const Instruction *CxtI) const { - return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT); - } - - bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero = false, - unsigned Depth = 0, - const Instruction *CxtI = nullptr) { - return llvm::isKnownToBeAPowerOfTwo(V, DL, OrZero, Depth, &AC, CxtI, &DT); - } - - bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth = 0, - const Instruction *CxtI = nullptr) const { - return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); - } - - unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0, - const Instruction *CxtI = nullptr) const { - return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedMul(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedMul(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedMul(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedAdd(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForUnsignedSub(const Value *LHS, - const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForUnsignedSub(LHS, RHS, DL, &AC, CxtI, &DT); - } - - OverflowResult computeOverflowForSignedSub(const Value *LHS, const Value *RHS, - const Instruction *CxtI) const { - return llvm::computeOverflowForSignedSub(LHS, RHS, DL, &AC, CxtI, &DT); - } - OverflowResult computeOverflow( Instruction::BinaryOps BinaryOp, bool IsSigned, Value *LHS, Value *RHS, Instruction *CxtI) const; @@ -543,7 +436,7 @@ public: /// -> "A*(B+C)") or expanding out if this results in simplifications (eg: "A /// & (B | C) -> (A&B) | (A&C)" if this is a win). Returns the simplified /// value, or null if it didn't simplify. - Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + Value *foldUsingDistributiveLaws(BinaryOperator &I); /// Tries to simplify add operations using the definition of remainder. 
/// @@ -559,8 +452,7 @@ public: /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). - Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *, - Value *, Value *, Value *); + Value *tryFactorizationFolds(BinaryOperator &I); /// Match a select chain which produces one of three values based on whether /// the LHS is less than, equal to, or greater than RHS respectively. @@ -647,7 +539,7 @@ public: /// If an integer typed PHI has only one use which is an IntToPtr operation, /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise /// insert a new pointer typed PHI and replace the original one. - Instruction *foldIntegerTypedPHI(PHINode &PN); + bool foldIntegerTypedPHI(PHINode &PN); /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the /// folded operation. @@ -716,6 +608,8 @@ public: const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, const APInt &C1, const APInt &C2); + Instruction *foldICmpXorShiftConst(ICmpInst &Cmp, BinaryOperator *Xor, + const APInt &C); Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, const APInt &C2); Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, @@ -731,6 +625,7 @@ public: Instruction *foldICmpBitCast(ICmpInst &Cmp); // Helpers of visitSelectInst(). + Instruction *foldSelectOfBools(SelectInst &SI); Instruction *foldSelectExtConst(SelectInst &Sel); Instruction *foldSelectOpOp(SelectInst &SI, Instruction *TI, Instruction *FI); Instruction *foldSelectIntoOp(SelectInst &SI, Value *, Value *); @@ -790,13 +685,13 @@ class Negator final { std::array<Value *, 2> getSortedOperandsOfBinOp(Instruction *I); - LLVM_NODISCARD Value *visitImpl(Value *V, unsigned Depth); + [[nodiscard]] Value *visitImpl(Value *V, unsigned Depth); - LLVM_NODISCARD Value *negate(Value *V, unsigned Depth); + [[nodiscard]] Value *negate(Value *V, unsigned Depth); /// Recurse depth-first and attempt to sink the negation. /// FIXME: use worklist? - LLVM_NODISCARD Optional<Result> run(Value *Root); + [[nodiscard]] std::optional<Result> run(Value *Root); Negator(const Negator &) = delete; Negator(Negator &&) = delete; @@ -806,8 +701,8 @@ class Negator final { public: /// Attempt to negate \p Root. Retuns nullptr if negation can't be performed, /// otherwise returns negated value. 
- LLVM_NODISCARD static Value *Negate(bool LHSIsZero, Value *Root, - InstCombinerImpl &IC); + [[nodiscard]] static Value *Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC); }; } // end namespace llvm diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index e03b7026f802..41bc65620ff6 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -28,30 +28,42 @@ using namespace PatternMatch; #define DEBUG_TYPE "instcombine" -STATISTIC(NumDeadStore, "Number of dead stores eliminated"); +STATISTIC(NumDeadStore, "Number of dead stores eliminated"); STATISTIC(NumGlobalCopies, "Number of allocas copied from constant global"); -/// isOnlyCopiedFromConstantGlobal - Recursively walk the uses of a (derived) +static cl::opt<unsigned> MaxCopiedFromConstantUsers( + "instcombine-max-copied-from-constant-users", cl::init(128), + cl::desc("Maximum users to visit in copy from constant transform"), + cl::Hidden); + +/// isOnlyCopiedFromConstantMemory - Recursively walk the uses of a (derived) /// pointer to an alloca. Ignore any reads of the pointer, return false if we /// see any stores or other unknown uses. If we see pointer arithmetic, keep /// track of whether it moves the pointer (with IsOffset) but otherwise traverse /// the uses. If we see a memcpy/memmove that targets an unoffseted pointer to -/// the alloca, and if the source pointer is a pointer to a constant global, we -/// can optimize this. +/// the alloca, and if the source pointer is a pointer to a constant memory +/// location, we can optimize this. static bool -isOnlyCopiedFromConstantMemory(AAResults *AA, - Value *V, MemTransferInst *&TheCopy, +isOnlyCopiedFromConstantMemory(AAResults *AA, AllocaInst *V, + MemTransferInst *&TheCopy, SmallVectorImpl<Instruction *> &ToDelete) { // We track lifetime intrinsics as we encounter them. If we decide to go - // ahead and replace the value with the global, this lets the caller quickly - // eliminate the markers. + // ahead and replace the value with the memory location, this lets the caller + // quickly eliminate the markers. + + using ValueAndIsOffset = PointerIntPair<Value *, 1, bool>; + SmallVector<ValueAndIsOffset, 32> Worklist; + SmallPtrSet<ValueAndIsOffset, 32> Visited; + Worklist.emplace_back(V, false); + while (!Worklist.empty()) { + ValueAndIsOffset Elem = Worklist.pop_back_val(); + if (!Visited.insert(Elem).second) + continue; + if (Visited.size() > MaxCopiedFromConstantUsers) + return false; - SmallVector<std::pair<Value *, bool>, 35> ValuesToInspect; - ValuesToInspect.emplace_back(V, false); - while (!ValuesToInspect.empty()) { - auto ValuePair = ValuesToInspect.pop_back_val(); - const bool IsOffset = ValuePair.second; - for (auto &U : ValuePair.first->uses()) { + const auto [Value, IsOffset] = Elem; + for (auto &U : Value->uses()) { auto *I = cast<Instruction>(U.getUser()); if (auto *LI = dyn_cast<LoadInst>(I)) { @@ -60,15 +72,22 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, continue; } - if (isa<BitCastInst>(I) || isa<AddrSpaceCastInst>(I)) { + if (isa<PHINode, SelectInst>(I)) { + // We set IsOffset=true, to forbid the memcpy from occurring after the + // phi: If one of the phi operands is not based on the alloca, we + // would incorrectly omit a write. + Worklist.emplace_back(I, true); + continue; + } + if (isa<BitCastInst, AddrSpaceCastInst>(I)) { // If uses of the bitcast are ok, we are ok. 
- ValuesToInspect.emplace_back(I, IsOffset); + Worklist.emplace_back(I, IsOffset); continue; } if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { // If the GEP has all zero indices, it doesn't offset the pointer. If it // doesn't, it does. - ValuesToInspect.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); + Worklist.emplace_back(I, IsOffset || !GEP->hasAllZeroIndices()); continue; } @@ -85,11 +104,12 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, if (IsArgOperand && Call->isInAllocaArgument(DataOpNo)) return false; - // If this is a readonly/readnone call site, then we know it is just a - // load (but one that potentially returns the value itself), so we can + // If this call site doesn't modify the memory, then we know it is just + // a load (but one that potentially returns the value itself), so we can // ignore it if we know that the value isn't captured. - if (Call->onlyReadsMemory() && - (Call->use_empty() || Call->doesNotCapture(DataOpNo))) + bool NoCapture = Call->doesNotCapture(DataOpNo); + if ((Call->onlyReadsMemory() && (Call->use_empty() || NoCapture)) || + (Call->onlyReadsMemory(DataOpNo) && NoCapture)) continue; // If this is being passed as a byval argument, the caller is making a @@ -111,12 +131,14 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, if (!MI) return false; + // If the transfer is volatile, reject it. + if (MI->isVolatile()) + return false; + // If the transfer is using the alloca as a source of the transfer, then // ignore it since it is a load (unless the transfer is volatile). - if (U.getOperandNo() == 1) { - if (MI->isVolatile()) return false; + if (U.getOperandNo() == 1) continue; - } // If we already have seen a copy, reject the second one. if (TheCopy) return false; @@ -128,8 +150,8 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, // If the memintrinsic isn't using the alloca as the dest, reject it. if (U.getOperandNo() != 0) return false; - // If the source of the memcpy/move is not a constant global, reject it. - if (!AA->pointsToConstantMemory(MI->getSource())) + // If the source of the memcpy/move is not constant, reject it. + if (isModSet(AA->getModRefInfoMask(MI->getSource()))) return false; // Otherwise, the transform is safe. Remember the copy instruction. @@ -139,9 +161,10 @@ isOnlyCopiedFromConstantMemory(AAResults *AA, return true; } -/// isOnlyCopiedFromConstantGlobal - Return true if the specified alloca is only -/// modified by a copy from a constant global. If we can prove this, we can -/// replace any uses of the alloca with uses of the global directly. +/// isOnlyCopiedFromConstantMemory - Return true if the specified alloca is only +/// modified by a copy from a constant memory location. If we can prove this, we +/// can replace any uses of the alloca with uses of the memory location +/// directly. static MemTransferInst * isOnlyCopiedFromConstantMemory(AAResults *AA, AllocaInst *AI, @@ -165,7 +188,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, } static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, - AllocaInst &AI) { + AllocaInst &AI, DominatorTree &DT) { // Check for array size of 1 (scalar allocation). if (!AI.isArrayAllocation()) { // i32 1 is the canonical array size for scalar allocations. 
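The rewritten walk above replaces the ad-hoc `ValuesToInspect` vector with a deduplicated `(pointer, is-offset)` worklist that gives up once `MaxCopiedFromConstantUsers` states have been seen. A rough standalone sketch of that traversal shape, with a generic `Node` type standing in for instructions (the names and the flat user list are assumptions for illustration, not the patch's data structures):

```cpp
#include <cstddef>
#include <set>
#include <utility>
#include <vector>

struct Node {                        // stand-in for an instruction user
  std::vector<Node *> Users;
};

// Walk all transitive users of Root, tracking an is-offset bit per node and
// bailing out once more than MaxUsers distinct states have been visited.
bool walkUsers(Node *Root, std::size_t MaxUsers) {
  using State = std::pair<Node *, bool>;       // bool models the IsOffset flag
  std::vector<State> Worklist{{Root, false}};
  std::set<State> Visited;
  while (!Worklist.empty()) {
    State S = Worklist.back();
    Worklist.pop_back();
    if (!Visited.insert(S).second)
      continue;                                // already seen with this offset state
    if (Visited.size() > MaxUsers)
      return false;                            // cap work on pathological use graphs
    for (Node *U : S.first->Users)
      Worklist.push_back({U, S.second});       // the real walk may set the bit for GEPs
  }
  return true;
}
```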
@@ -184,6 +207,8 @@ static Instruction *simplifyAllocaArraySize(InstCombinerImpl &IC, nullptr, AI.getName()); New->setAlignment(AI.getAlign()); + replaceAllDbgUsesWith(AI, *New, *New, DT); + // Scan to the end of the allocation instructions, to skip over a block of // allocas if possible...also skip interleaved debug info // @@ -234,31 +259,83 @@ namespace { // instruction. class PointerReplacer { public: - PointerReplacer(InstCombinerImpl &IC) : IC(IC) {} + PointerReplacer(InstCombinerImpl &IC, Instruction &Root) + : IC(IC), Root(Root) {} - bool collectUsers(Instruction &I); - void replacePointer(Instruction &I, Value *V); + bool collectUsers(); + void replacePointer(Value *V); private: + bool collectUsersRecursive(Instruction &I); void replace(Instruction *I); Value *getReplacement(Value *I); + bool isAvailable(Instruction *I) const { + return I == &Root || Worklist.contains(I); + } + SmallPtrSet<Instruction *, 32> ValuesToRevisit; SmallSetVector<Instruction *, 4> Worklist; MapVector<Value *, Value *> WorkMap; InstCombinerImpl &IC; + Instruction &Root; }; } // end anonymous namespace -bool PointerReplacer::collectUsers(Instruction &I) { - for (auto U : I.users()) { +bool PointerReplacer::collectUsers() { + if (!collectUsersRecursive(Root)) + return false; + + // Ensure that all outstanding (indirect) users of I + // are inserted into the Worklist. Return false + // otherwise. + for (auto *Inst : ValuesToRevisit) + if (!Worklist.contains(Inst)) + return false; + return true; +} + +bool PointerReplacer::collectUsersRecursive(Instruction &I) { + for (auto *U : I.users()) { auto *Inst = cast<Instruction>(&*U); if (auto *Load = dyn_cast<LoadInst>(Inst)) { if (Load->isVolatile()) return false; Worklist.insert(Load); - } else if (isa<GetElementPtrInst>(Inst) || isa<BitCastInst>(Inst)) { + } else if (auto *PHI = dyn_cast<PHINode>(Inst)) { + // All incoming values must be instructions for replacability + if (any_of(PHI->incoming_values(), + [](Value *V) { return !isa<Instruction>(V); })) + return false; + + // If at least one incoming value of the PHI is not in Worklist, + // store the PHI for revisiting and skip this iteration of the + // loop. 
+ if (any_of(PHI->incoming_values(), [this](Value *V) { + return !isAvailable(cast<Instruction>(V)); + })) { + ValuesToRevisit.insert(Inst); + continue; + } + + Worklist.insert(PHI); + if (!collectUsersRecursive(*PHI)) + return false; + } else if (auto *SI = dyn_cast<SelectInst>(Inst)) { + if (!isa<Instruction>(SI->getTrueValue()) || + !isa<Instruction>(SI->getFalseValue())) + return false; + + if (!isAvailable(cast<Instruction>(SI->getTrueValue())) || + !isAvailable(cast<Instruction>(SI->getFalseValue()))) { + ValuesToRevisit.insert(Inst); + continue; + } + Worklist.insert(SI); + if (!collectUsersRecursive(*SI)) + return false; + } else if (isa<GetElementPtrInst, BitCastInst>(Inst)) { Worklist.insert(Inst); - if (!collectUsers(*Inst)) + if (!collectUsersRecursive(*Inst)) return false; } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) { if (MI->isVolatile()) @@ -293,6 +370,14 @@ void PointerReplacer::replace(Instruction *I) { IC.InsertNewInstWith(NewI, *LT); IC.replaceInstUsesWith(*LT, NewI); WorkMap[LT] = NewI; + } else if (auto *PHI = dyn_cast<PHINode>(I)) { + Type *NewTy = getReplacement(PHI->getIncomingValue(0))->getType(); + auto *NewPHI = PHINode::Create(NewTy, PHI->getNumIncomingValues(), + PHI->getName(), PHI); + for (unsigned int I = 0; I < PHI->getNumIncomingValues(); ++I) + NewPHI->addIncoming(getReplacement(PHI->getIncomingValue(I)), + PHI->getIncomingBlock(I)); + WorkMap[PHI] = NewPHI; } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { auto *V = getReplacement(GEP->getPointerOperand()); assert(V && "Operand not replaced"); @@ -313,6 +398,13 @@ void PointerReplacer::replace(Instruction *I) { IC.InsertNewInstWith(NewI, *BC); NewI->takeName(BC); WorkMap[BC] = NewI; + } else if (auto *SI = dyn_cast<SelectInst>(I)) { + auto *NewSI = SelectInst::Create( + SI->getCondition(), getReplacement(SI->getTrueValue()), + getReplacement(SI->getFalseValue()), SI->getName(), nullptr, SI); + IC.InsertNewInstWith(NewSI, *SI); + NewSI->takeName(SI); + WorkMap[SI] = NewSI; } else if (auto *MemCpy = dyn_cast<MemTransferInst>(I)) { auto *SrcV = getReplacement(MemCpy->getRawSource()); // The pointer may appear in the destination of a copy, but we don't want to @@ -339,27 +431,27 @@ void PointerReplacer::replace(Instruction *I) { } } -void PointerReplacer::replacePointer(Instruction &I, Value *V) { +void PointerReplacer::replacePointer(Value *V) { #ifndef NDEBUG - auto *PT = cast<PointerType>(I.getType()); + auto *PT = cast<PointerType>(Root.getType()); auto *NT = cast<PointerType>(V->getType()); assert(PT != NT && PT->hasSameElementTypeAs(NT) && "Invalid usage"); #endif - WorkMap[&I] = V; + WorkMap[&Root] = V; for (Instruction *Workitem : Worklist) replace(Workitem); } Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { - if (auto *I = simplifyAllocaArraySize(*this, AI)) + if (auto *I = simplifyAllocaArraySize(*this, AI, DT)) return I; if (AI.getAllocatedType()->isSized()) { // Move all alloca's of zero byte objects to the entry block and merge them // together. Note that we only do this for alloca's, because malloc should // allocate and return a unique pointer, even for a zero byte allocation. - if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinSize() == 0) { + if (DL.getTypeAllocSize(AI.getAllocatedType()).getKnownMinValue() == 0) { // For a zero sized alloca there is no point in doing an array allocation. // This is helpful if the array size is a complicated expression not used // elsewhere. 
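The `PointerReplacer` extensions above let the copy-from-constant-memory transform look through phis and selects of pointers derived from the alloca. A sketch of the kind of source that typically produces this IR shape (whether the rewrite actually fires depends on the target and optimization level; the function below is only an illustration):

```cpp
// 'A' is initialized from constant data (clang typically emits an llvm.memcpy
// from a private constant global) and is only read afterwards; the ternary
// becomes a select of pointers into the alloca at the IR level.
int pickEnd(bool PickFirst) {
  int A[] = {1, 2, 3, 4, 5, 6, 7, 8};
  const int *P = PickFirst ? &A[0] : &A[7];
  return *P;
}
```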
@@ -377,7 +469,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { AllocaInst *EntryAI = dyn_cast<AllocaInst>(FirstInst); if (!EntryAI || !EntryAI->getAllocatedType()->isSized() || DL.getTypeAllocSize(EntryAI->getAllocatedType()) - .getKnownMinSize() != 0) { + .getKnownMinValue() != 0) { AI.moveBefore(FirstInst); return &AI; } @@ -395,11 +487,11 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { } // Check to see if this allocation is only modified by a memcpy/memmove from - // a constant whose alignment is equal to or exceeds that of the allocation. - // If this is the case, we can change all users to use the constant global - // instead. This is commonly produced by the CFE by constructs like "void - // foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' is only subsequently - // read. + // a memory location whose alignment is equal to or exceeds that of the + // allocation. If this is the case, we can change all users to use the + // constant memory location instead. This is commonly produced by the CFE by + // constructs like "void foo() { int A[] = {1,2,3,4,5,6,7,8,9...}; }" if 'A' + // is only subsequently read. SmallVector<Instruction *, 4> ToDelete; if (MemTransferInst *Copy = isOnlyCopiedFromConstantMemory(AA, &AI, ToDelete)) { Value *TheSrc = Copy->getSource(); @@ -415,7 +507,7 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { LLVM_DEBUG(dbgs() << " memcpy = " << *Copy << '\n'); unsigned SrcAddrSpace = TheSrc->getType()->getPointerAddressSpace(); auto *DestTy = PointerType::get(AI.getAllocatedType(), SrcAddrSpace); - if (AI.getType()->getAddressSpace() == SrcAddrSpace) { + if (AI.getAddressSpace() == SrcAddrSpace) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); @@ -426,13 +518,13 @@ Instruction *InstCombinerImpl::visitAllocaInst(AllocaInst &AI) { return NewI; } - PointerReplacer PtrReplacer(*this); - if (PtrReplacer.collectUsers(AI)) { + PointerReplacer PtrReplacer(*this, AI); + if (PtrReplacer.collectUsers()) { for (Instruction *Delete : ToDelete) eraseInstFromFunction(*Delete); Value *Cast = Builder.CreateBitCast(TheSrc, DestTy); - PtrReplacer.replacePointer(AI, Cast); + PtrReplacer.replacePointer(Cast); ++NumGlobalCopies; } } @@ -507,6 +599,7 @@ static StoreInst *combineStoreToNewValue(InstCombinerImpl &IC, StoreInst &SI, // here. switch (ID) { case LLVMContext::MD_dbg: + case LLVMContext::MD_DIAssignID: case LLVMContext::MD_tbaa: case LLVMContext::MD_prof: case LLVMContext::MD_fpmath: @@ -575,43 +668,43 @@ static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { /// later. However, it is risky in case some backend or other part of LLVM is /// relying on the exact type loaded to select appropriate atomic operations. static Instruction *combineLoadToOperationType(InstCombinerImpl &IC, - LoadInst &LI) { + LoadInst &Load) { // FIXME: We could probably with some care handle both volatile and ordered // atomic loads here but it isn't clear that this is important. - if (!LI.isUnordered()) + if (!Load.isUnordered()) return nullptr; - if (LI.use_empty()) + if (Load.use_empty()) return nullptr; // swifterror values can't be bitcasted. - if (LI.getPointerOperand()->isSwiftError()) + if (Load.getPointerOperand()->isSwiftError()) return nullptr; - const DataLayout &DL = IC.getDataLayout(); - // Fold away bit casts of the loaded value by loading the desired type. // Note that we should not do this for pointer<->integer casts, // because that would result in type punning. 
- if (LI.hasOneUse()) { + if (Load.hasOneUse()) { // Don't transform when the type is x86_amx, it makes the pass that lower // x86_amx type happy. - if (auto *BC = dyn_cast<BitCastInst>(LI.user_back())) { - assert(!LI.getType()->isX86_AMXTy() && - "load from x86_amx* should not happen!"); + Type *LoadTy = Load.getType(); + if (auto *BC = dyn_cast<BitCastInst>(Load.user_back())) { + assert(!LoadTy->isX86_AMXTy() && "Load from x86_amx* should not happen!"); if (BC->getType()->isX86_AMXTy()) return nullptr; } - if (auto* CI = dyn_cast<CastInst>(LI.user_back())) - if (CI->isNoopCast(DL) && LI.getType()->isPtrOrPtrVectorTy() == - CI->getDestTy()->isPtrOrPtrVectorTy()) - if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { - LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy()); - CI->replaceAllUsesWith(NewLoad); - IC.eraseInstFromFunction(*CI); - return &LI; - } + if (auto *CastUser = dyn_cast<CastInst>(Load.user_back())) { + Type *DestTy = CastUser->getDestTy(); + if (CastUser->isNoopCast(IC.getDataLayout()) && + LoadTy->isPtrOrPtrVectorTy() == DestTy->isPtrOrPtrVectorTy() && + (!Load.isAtomic() || isSupportedAtomicType(DestTy))) { + LoadInst *NewLoad = IC.combineLoadToNewType(Load, DestTy); + CastUser->replaceAllUsesWith(NewLoad); + IC.eraseInstFromFunction(*CastUser); + return &Load; + } + } } // FIXME: We should also canonicalize loads of vectors when their elements are @@ -639,7 +732,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { ".unpack"); NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( - UndefValue::get(T), NewLoad, 0, Name)); + PoisonValue::get(T), NewLoad, 0, Name)); } // We don't want to break loads with padding here as we'd loose @@ -654,13 +747,13 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *IdxType = Type::getInt32Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - Value *V = UndefValue::get(T); + Value *V = PoisonValue::get(T); for (unsigned i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), + auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices), Name + ".elt"); auto *L = IC.Builder.CreateAlignedLoad( ST->getElementType(i), Ptr, @@ -681,7 +774,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( - UndefValue::get(T), NewLoad, 0, Name)); + PoisonValue::get(T), NewLoad, 0, Name)); } // Bail out if the array is too large. 
Ideally we would like to optimize @@ -699,14 +792,14 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *IdxType = Type::getInt64Ty(T->getContext()); auto *Zero = ConstantInt::get(IdxType, 0); - Value *V = UndefValue::get(T); + Value *V = PoisonValue::get(T); uint64_t Offset = 0; for (uint64_t i = 0; i < NumElements; i++) { Value *Indices[2] = { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), + auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), Name + ".elt"); auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, commonAlignment(Align, Offset), @@ -769,10 +862,13 @@ static bool isObjectSizeLessThanOrEq(Value *V, uint64_t MaxSize, if (!CS) return false; - uint64_t TypeSize = DL.getTypeAllocSize(AI->getAllocatedType()); + TypeSize TS = DL.getTypeAllocSize(AI->getAllocatedType()); + if (TS.isScalable()) + return false; // Make sure that, even if the multiplication below would wrap as an // uint64_t, we still do the right thing. - if ((CS->getValue().zext(128) * APInt(128, TypeSize)).ugt(MaxSize)) + if ((CS->getValue().zext(128) * APInt(128, TS.getFixedValue())) + .ugt(MaxSize)) return false; continue; } @@ -849,7 +945,7 @@ static bool canReplaceGEPIdxWithZero(InstCombinerImpl &IC, if (!AllocTy || !AllocTy->isSized()) return false; const DataLayout &DL = IC.getDataLayout(); - uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedSize(); + uint64_t TyAllocSize = DL.getTypeAllocSize(AllocTy).getFixedValue(); // If there are more indices after the one we might replace with a zero, make // sure they're all non-negative. If any of them are negative, the overall @@ -1183,8 +1279,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(ST, Addr, makeArrayRef(Indices), - AddrName); + auto *Ptr = + IC.Builder.CreateInBoundsGEP(ST, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); @@ -1229,8 +1325,8 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { Zero, ConstantInt::get(IdxType, i), }; - auto *Ptr = IC.Builder.CreateInBoundsGEP(AT, Addr, makeArrayRef(Indices), - AddrName); + auto *Ptr = + IC.Builder.CreateInBoundsGEP(AT, Addr, ArrayRef(Indices), AddrName); auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, Offset); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); @@ -1372,7 +1468,7 @@ Instruction *InstCombinerImpl::visitStoreInst(StoreInst &SI) { // If we have a store to a location which is known constant, we can conclude // that the store must be storing the constant value (else the memory // wouldn't be constant), and this must be a noop. - if (AA->pointsToConstantMemory(Ptr)) + if (!isModSet(AA->getModRefInfoMask(Ptr))) return eraseInstFromFunction(SI); // Do really simple DSE, to catch cases where there are several consecutive @@ -1547,6 +1643,7 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(MergedLoc); + NewSI->mergeDIAssignID({&SI, OtherStore}); // If the two stores had AA tags, merge them. 
AAMDNodes AATags = SI.getAAMetadata(); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 8cb09cbac86f..97f129e200de 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -15,6 +15,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" @@ -139,9 +140,56 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, return nullptr; } +/// Reduce integer multiplication patterns that contain a (+/-1 << Z) factor. +/// Callers are expected to call this twice to handle commuted patterns. +static Value *foldMulShl1(BinaryOperator &Mul, bool CommuteOperands, + InstCombiner::BuilderTy &Builder) { + Value *X = Mul.getOperand(0), *Y = Mul.getOperand(1); + if (CommuteOperands) + std::swap(X, Y); + + const bool HasNSW = Mul.hasNoSignedWrap(); + const bool HasNUW = Mul.hasNoUnsignedWrap(); + + // X * (1 << Z) --> X << Z + Value *Z; + if (match(Y, m_Shl(m_One(), m_Value(Z)))) { + bool PropagateNSW = HasNSW && cast<ShlOperator>(Y)->hasNoSignedWrap(); + return Builder.CreateShl(X, Z, Mul.getName(), HasNUW, PropagateNSW); + } + + // Similar to above, but an increment of the shifted value becomes an add: + // X * ((1 << Z) + 1) --> (X * (1 << Z)) + X --> (X << Z) + X + // This increases uses of X, so it may require a freeze, but that is still + // expected to be an improvement because it removes the multiply. + BinaryOperator *Shift; + if (match(Y, m_OneUse(m_Add(m_BinOp(Shift), m_One()))) && + match(Shift, m_OneUse(m_Shl(m_One(), m_Value(Z))))) { + bool PropagateNSW = HasNSW && Shift->hasNoSignedWrap(); + Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Shl = Builder.CreateShl(FrX, Z, "mulshl", HasNUW, PropagateNSW); + return Builder.CreateAdd(Shl, FrX, Mul.getName(), HasNUW, PropagateNSW); + } + + // Similar to above, but a decrement of the shifted value is disguised as + // 'not' and becomes a sub: + // X * (~(-1 << Z)) --> X * ((1 << Z) - 1) --> (X << Z) - X + // This increases uses of X, so it may require a freeze, but that is still + // expected to be an improvement because it removes the multiply. 
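As an aside, the three rewrites described in the comments above all rest on simple shift identities: x * 2^z = x << z, x * (2^z + 1) = (x << z) + x, and x * (2^z - 1) = (x << z) - x, all taken modulo the bit width. A minimal standalone check of those identities (plain C++ with unsigned wraparound semantics; illustration only, not part of the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
      for (uint32_t x : {0u, 1u, 7u, 0x1234u, 0xFFFFFFFFu}) {
        for (unsigned z = 0; z < 32; ++z) {
          uint32_t pow2 = uint32_t(1) << z;
          // X * (1 << Z) --> X << Z
          assert(x * pow2 == (x << z));
          // X * ((1 << Z) + 1) --> (X << Z) + X
          assert(x * (pow2 + 1) == (x << z) + x);
          // X * ~(-1 << Z) == X * ((1 << Z) - 1) --> (X << Z) - X
          assert(x * ~(~uint32_t(0) << z) == (x << z) - x);
        }
      }
      return 0;
    }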
+ if (match(Y, m_OneUse(m_Not(m_OneUse(m_Shl(m_AllOnes(), m_Value(Z))))))) { + Value *FrX = Builder.CreateFreeze(X, X->getName() + ".fr"); + Value *Shl = Builder.CreateShl(FrX, Z, "mulshl"); + return Builder.CreateSub(Shl, FrX, Mul.getName()); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { - if (Value *V = simplifyMulInst(I.getOperand(0), I.getOperand(1), - SQ.getWithInstruction(&I))) + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + if (Value *V = + simplifyMulInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); if (SimplifyAssociativeOrCommutative(I)) @@ -153,18 +201,18 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Instruction *Phi = foldBinopWithPhiOperands(I)) return Phi; - if (Value *V = SimplifyUsingDistributiveLaws(I)) + if (Value *V = foldUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - unsigned BitWidth = I.getType()->getScalarSizeInBits(); + Type *Ty = I.getType(); + const unsigned BitWidth = Ty->getScalarSizeInBits(); + const bool HasNSW = I.hasNoSignedWrap(); + const bool HasNUW = I.hasNoUnsignedWrap(); - // X * -1 == 0 - X + // X * -1 --> 0 - X if (match(Op1, m_AllOnes())) { - BinaryOperator *BO = BinaryOperator::CreateNeg(Op0, I.getName()); - if (I.hasNoSignedWrap()) - BO->setHasNoSignedWrap(); - return BO; + return HasNSW ? BinaryOperator::CreateNSWNeg(Op0) + : BinaryOperator::CreateNeg(Op0); } // Also allow combining multiply instructions on vectors. @@ -179,10 +227,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { Constant *Shl = ConstantExpr::getShl(C1, C2); BinaryOperator *Mul = cast<BinaryOperator>(I.getOperand(0)); BinaryOperator *BO = BinaryOperator::CreateMul(NewOp, Shl); - if (I.hasNoUnsignedWrap() && Mul->hasNoUnsignedWrap()) + if (HasNUW && Mul->hasNoUnsignedWrap()) BO->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap() && Mul->hasNoSignedWrap() && - Shl->isNotMinSignedValue()) + if (HasNSW && Mul->hasNoSignedWrap() && Shl->isNotMinSignedValue()) BO->setHasNoSignedWrap(); return BO; } @@ -192,9 +239,9 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Constant *NewCst = ConstantExpr::getExactLogBase2(C1)) { BinaryOperator *Shl = BinaryOperator::CreateShl(NewOp, NewCst); - if (I.hasNoUnsignedWrap()) + if (HasNUW) Shl->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap()) { + if (HasNSW) { const APInt *V; if (match(NewCst, m_APInt(V)) && *V != V->getBitWidth() - 1) Shl->setHasNoSignedWrap(); @@ -211,6 +258,25 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { if (Value *NegOp0 = Negator::Negate(/*IsNegation*/ true, Op0, *this)) return BinaryOperator::CreateMul( NegOp0, ConstantExpr::getNeg(cast<Constant>(Op1)), I.getName()); + + // Try to convert multiply of extended operand to narrow negate and shift + // for better analysis. 
+ // This is valid if the shift amount (trailing zeros in the multiplier + // constant) clears more high bits than the bitwidth difference between + // source and destination types: + // ({z/s}ext X) * (-1<<C) --> (zext (-X)) << C + const APInt *NegPow2C; + Value *X; + if (match(Op0, m_ZExtOrSExt(m_Value(X))) && + match(Op1, m_APIntAllowUndef(NegPow2C))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned ShiftAmt = NegPow2C->countTrailingZeros(); + if (ShiftAmt >= BitWidth - SrcWidth) { + Value *N = Builder.CreateNeg(X, X->getName() + ".neg"); + Value *Z = Builder.CreateZExt(N, Ty, N->getName() + ".z"); + return BinaryOperator::CreateShl(Z, ConstantInt::get(Ty, ShiftAmt)); + } + } } if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) @@ -220,16 +286,29 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return replaceInstUsesWith(I, FoldedMul); // Simplify mul instructions with a constant RHS. - if (isa<Constant>(Op1)) { - // Canonicalize (X+C1)*CI -> X*CI+C1*CI. + Constant *MulC; + if (match(Op1, m_ImmConstant(MulC))) { + // Canonicalize (X+C1)*MulC -> X*MulC+C1*MulC. + // Canonicalize (X|C1)*MulC -> X*MulC+C1*MulC. Value *X; Constant *C1; - if (match(Op0, m_OneUse(m_Add(m_Value(X), m_Constant(C1))))) { - Value *Mul = Builder.CreateMul(C1, Op1); - // Only go forward with the transform if C1*CI simplifies to a tidier - // constant. - if (!match(Mul, m_Mul(m_Value(), m_Value()))) - return BinaryOperator::CreateAdd(Builder.CreateMul(X, Op1), Mul); + if ((match(Op0, m_OneUse(m_Add(m_Value(X), m_ImmConstant(C1))))) || + (match(Op0, m_OneUse(m_Or(m_Value(X), m_ImmConstant(C1)))) && + haveNoCommonBitsSet(X, C1, DL, &AC, &I, &DT))) { + // C1*MulC simplifies to a tidier constant. + Value *NewC = Builder.CreateMul(C1, MulC); + auto *BOp0 = cast<BinaryOperator>(Op0); + bool Op0NUW = + (BOp0->getOpcode() == Instruction::Or || BOp0->hasNoUnsignedWrap()); + Value *NewMul = Builder.CreateMul(X, MulC); + auto *BO = BinaryOperator::CreateAdd(NewMul, NewC); + if (HasNUW && Op0NUW) { + // If NewMulBO is constant we also can set BO to nuw. + if (auto *NewMulBO = dyn_cast<BinaryOperator>(NewMul)) + NewMulBO->setHasNoUnsignedWrap(); + BO->setHasNoUnsignedWrap(); + } + return BO; } } @@ -254,8 +333,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // -X * -Y --> X * Y if (match(Op0, m_Neg(m_Value(X))) && match(Op1, m_Neg(m_Value(Y)))) { auto *NewMul = BinaryOperator::CreateMul(X, Y); - if (I.hasNoSignedWrap() && - cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && + if (HasNSW && cast<OverflowingBinaryOperator>(Op0)->hasNoSignedWrap() && cast<OverflowingBinaryOperator>(Op1)->hasNoSignedWrap()) NewMul->setHasNoSignedWrap(); return NewMul; @@ -306,33 +384,15 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { // 2) X * Y --> X & Y, iff X, Y can be only {0,1}. 
// Note: We could use known bits to generalize this and related patterns with // shifts/truncs - Type *Ty = I.getType(); if (Ty->isIntOrIntVectorTy(1) || (match(Op0, m_And(m_Value(), m_One())) && match(Op1, m_And(m_Value(), m_One())))) return BinaryOperator::CreateAnd(Op0, Op1); - // X*(1 << Y) --> X << Y - // (1 << Y)*X --> X << Y - { - Value *Y; - BinaryOperator *BO = nullptr; - bool ShlNSW = false; - if (match(Op0, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op1, Y); - ShlNSW = cast<ShlOperator>(Op0)->hasNoSignedWrap(); - } else if (match(Op1, m_Shl(m_One(), m_Value(Y)))) { - BO = BinaryOperator::CreateShl(Op0, Y); - ShlNSW = cast<ShlOperator>(Op1)->hasNoSignedWrap(); - } - if (BO) { - if (I.hasNoUnsignedWrap()) - BO->setHasNoUnsignedWrap(); - if (I.hasNoSignedWrap() && ShlNSW) - BO->setHasNoSignedWrap(); - return BO; - } - } + if (Value *R = foldMulShl1(I, /* CommuteOperands */ false, Builder)) + return replaceInstUsesWith(I, R); + if (Value *R = foldMulShl1(I, /* CommuteOperands */ true, Builder)) + return replaceInstUsesWith(I, R); // (zext bool X) * (zext bool Y) --> zext (and X, Y) // (sext bool X) * (sext bool Y) --> zext (and X, Y) @@ -403,8 +463,7 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { m_One()), m_Deferred(X)))) { Value *Abs = Builder.CreateBinaryIntrinsic( - Intrinsic::abs, X, - ConstantInt::getBool(I.getContext(), I.hasNoSignedWrap())); + Intrinsic::abs, X, ConstantInt::getBool(I.getContext(), HasNSW)); Abs->takeName(&I); return replaceInstUsesWith(I, Abs); } @@ -413,12 +472,12 @@ Instruction *InstCombinerImpl::visitMul(BinaryOperator &I) { return Ext; bool Changed = false; - if (!I.hasNoSignedWrap() && willNotOverflowSignedMul(Op0, Op1, I)) { + if (!HasNSW && willNotOverflowSignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoSignedWrap(true); } - if (!I.hasNoUnsignedWrap() && willNotOverflowUnsignedMul(Op0, Op1, I)) { + if (!HasNUW && willNotOverflowUnsignedMul(Op0, Op1, I)) { Changed = true; I.setHasNoUnsignedWrap(true); } @@ -488,11 +547,19 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { if (match(Op1, m_SpecificFP(-1.0))) return UnaryOperator::CreateFNegFMF(Op0, &I); + // With no-nans: X * 0.0 --> copysign(0.0, X) + if (I.hasNoNaNs() && match(Op1, m_PosZeroFP())) { + CallInst *CopySign = Builder.CreateIntrinsic(Intrinsic::copysign, + {I.getType()}, {Op1, Op0}, &I); + return replaceInstUsesWith(I, CopySign); + } + // -X * C --> X * -C Value *X, *Y; Constant *C; if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) - return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFMulFMF(X, NegC, &I); // (select A, B, C) * (select A, D, E) --> select A, (B*D), (C*E) if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) @@ -596,14 +663,32 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { } } + // pow(X, Y) * X --> pow(X, Y+1) + // X * pow(X, Y) --> pow(X, Y+1) + if (match(&I, m_c_FMul(m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Value(X), + m_Value(Y))), + m_Deferred(X)))) { + Value *Y1 = + Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), 1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + if (I.isOnlyUserOfAnyOperand()) { - // pow(x, y) * pow(x, z) -> pow(x, y + z) + // pow(X, Y) * pow(X, Z) -> pow(X, Y + Z) if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && match(Op1, 
m_Intrinsic<Intrinsic::pow>(m_Specific(X), m_Value(Z)))) { auto *YZ = Builder.CreateFAddFMF(Y, Z, &I); auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, X, YZ, &I); return replaceInstUsesWith(I, NewPow); } + // pow(X, Y) * pow(Z, Y) -> pow(X * Z, Y) + if (match(Op0, m_Intrinsic<Intrinsic::pow>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::pow>(m_Value(Z), m_Specific(Y)))) { + auto *XZ = Builder.CreateFMulFMF(X, Z, &I); + auto *NewPow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, XZ, Y, &I); + return replaceInstUsesWith(I, NewPow); + } // powi(x, y) * powi(x, z) -> powi(x, y + z) if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && @@ -671,6 +756,15 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { } } + // Simplify FMUL recurrences starting with 0.0 to 0.0 if nnan and nsz are set. + // Given a phi node with entry value as 0 and it used in fmul operation, + // we can replace fmul with 0 safely and eleminate loop operation. + PHINode *PN = nullptr; + Value *Start = nullptr, *Step = nullptr; + if (matchSimpleRecurrence(&I, PN, Start, Step) && I.hasNoNaNs() && + I.hasNoSignedZeros() && match(Start, m_Zero())) + return replaceInstUsesWith(I, Start); + return nullptr; } @@ -773,6 +867,70 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, return Remainder.isMinValue(); } +static Instruction *foldIDivShl(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert((I.getOpcode() == Instruction::SDiv || + I.getOpcode() == Instruction::UDiv) && + "Expected integer divide"); + + bool IsSigned = I.getOpcode() == Instruction::SDiv; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Type *Ty = I.getType(); + + Instruction *Ret = nullptr; + Value *X, *Y, *Z; + + // With appropriate no-wrap constraints, remove a common factor in the + // dividend and divisor that is disguised as a left-shifted value. + if (match(Op1, m_Shl(m_Value(X), m_Value(Z))) && + match(Op0, m_c_Mul(m_Specific(X), m_Value(Y)))) { + // Both operands must have the matching no-wrap for this kind of division. + auto *Mul = cast<OverflowingBinaryOperator>(Op0); + auto *Shl = cast<OverflowingBinaryOperator>(Op1); + bool HasNUW = Mul->hasNoUnsignedWrap() && Shl->hasNoUnsignedWrap(); + bool HasNSW = Mul->hasNoSignedWrap() && Shl->hasNoSignedWrap(); + + // (X * Y) u/ (X << Z) --> Y u>> Z + if (!IsSigned && HasNUW) + Ret = BinaryOperator::CreateLShr(Y, Z); + + // (X * Y) s/ (X << Z) --> Y s/ (1 << Z) + if (IsSigned && HasNSW && (Op0->hasOneUse() || Op1->hasOneUse())) { + Value *Shl = Builder.CreateShl(ConstantInt::get(Ty, 1), Z); + Ret = BinaryOperator::CreateSDiv(Y, Shl); + } + } + + // With appropriate no-wrap constraints, remove a common factor in the + // dividend and divisor that is disguised as a left-shift amount. + if (match(Op0, m_Shl(m_Value(X), m_Value(Z))) && + match(Op1, m_Shl(m_Value(Y), m_Specific(Z)))) { + auto *Shl0 = cast<OverflowingBinaryOperator>(Op0); + auto *Shl1 = cast<OverflowingBinaryOperator>(Op1); + + // For unsigned div, we need 'nuw' on both shifts or + // 'nsw' on both shifts + 'nuw' on the dividend. + // (X << Z) / (Y << Z) --> X / Y + if (!IsSigned && + ((Shl0->hasNoUnsignedWrap() && Shl1->hasNoUnsignedWrap()) || + (Shl0->hasNoUnsignedWrap() && Shl0->hasNoSignedWrap() && + Shl1->hasNoSignedWrap()))) + Ret = BinaryOperator::CreateUDiv(X, Y); + + // For signed div, we need 'nsw' on both shifts + 'nuw' on the divisor. 
+ // (X << Z) / (Y << Z) --> X / Y + if (IsSigned && Shl0->hasNoSignedWrap() && Shl1->hasNoSignedWrap() && + Shl1->hasNoUnsignedWrap()) + Ret = BinaryOperator::CreateSDiv(X, Y); + } + + if (!Ret) + return nullptr; + + Ret->setIsExact(I.isExact()); + return Ret; +} + /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. @@ -919,6 +1077,41 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } + // (X << Z) / (X * Y) -> (1 << Z) / Y + // TODO: Handle sdiv. + if (!IsSigned && Op1->hasOneUse() && + match(Op0, m_NUWShl(m_Value(X), m_Value(Z))) && + match(Op1, m_c_Mul(m_Specific(X), m_Value(Y)))) + if (cast<OverflowingBinaryOperator>(Op1)->hasNoUnsignedWrap()) { + Instruction *NewDiv = BinaryOperator::CreateUDiv( + Builder.CreateShl(ConstantInt::get(Ty, 1), Z, "", /*NUW*/ true), Y); + NewDiv->setIsExact(I.isExact()); + return NewDiv; + } + + if (Instruction *R = foldIDivShl(I, Builder)) + return R; + + // With the appropriate no-wrap constraint, remove a multiply by the divisor + // after peeking through another divide: + // ((Op1 * X) / Y) / Op1 --> X / Y + if (match(Op0, m_BinOp(I.getOpcode(), m_c_Mul(m_Specific(Op1), m_Value(X)), + m_Value(Y)))) { + auto *InnerDiv = cast<PossiblyExactOperator>(Op0); + auto *Mul = cast<OverflowingBinaryOperator>(InnerDiv->getOperand(0)); + Instruction *NewDiv = nullptr; + if (!IsSigned && Mul->hasNoUnsignedWrap()) + NewDiv = BinaryOperator::CreateUDiv(X, Y); + else if (IsSigned && Mul->hasNoSignedWrap()) + NewDiv = BinaryOperator::CreateSDiv(X, Y); + + // Exact propagates only if both of the original divides are exact. + if (NewDiv) { + NewDiv->setIsExact(I.isExact() && InnerDiv->isExact()); + return NewDiv; + } + } + return nullptr; } @@ -1007,8 +1200,8 @@ static Instruction *narrowUDivURem(BinaryOperator &I, } Constant *C; - if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) || - (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) { + if (isa<Instruction>(N) && match(N, m_OneUse(m_ZExt(m_Value(X)))) && + match(D, m_Constant(C))) { // If the constant is the same in the smaller type, use the narrow version. Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); if (ConstantExpr::getZExt(TruncC, Ty) != C) @@ -1016,18 +1209,25 @@ static Instruction *narrowUDivURem(BinaryOperator &I, // udiv (zext X), C --> zext (udiv X, C') // urem (zext X), C --> zext (urem X, C') + return new ZExtInst(Builder.CreateBinOp(Opcode, X, TruncC), Ty); + } + if (isa<Instruction>(D) && match(D, m_OneUse(m_ZExt(m_Value(X)))) && + match(N, m_Constant(C))) { + // If the constant is the same in the smaller type, use the narrow version. + Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getZExt(TruncC, Ty) != C) + return nullptr; + // udiv C, (zext X) --> zext (udiv C', X) // urem C, (zext X) --> zext (urem C', X) - Value *NarrowOp = isa<Constant>(D) ? 
Builder.CreateBinOp(Opcode, X, TruncC) - : Builder.CreateBinOp(Opcode, TruncC, X); - return new ZExtInst(NarrowOp, Ty); + return new ZExtInst(Builder.CreateBinOp(Opcode, TruncC, X), Ty); } return nullptr; } Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { - if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifyUDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1086,6 +1286,16 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { return BinaryOperator::CreateUDiv(A, X); } + // Look through a right-shift to find the common factor: + // ((Op1 *nuw A) >> B) / Op1 --> A >> B + if (match(Op0, m_LShr(m_NUWMul(m_Specific(Op1), m_Value(A)), m_Value(B))) || + match(Op0, m_LShr(m_NUWMul(m_Value(A), m_Specific(Op1)), m_Value(B)))) { + Instruction *Lshr = BinaryOperator::CreateLShr(A, B); + if (I.isExact() && cast<PossiblyExactOperator>(Op0)->isExact()) + Lshr->setIsExact(); + return Lshr; + } + // Op1 udiv Op2 -> Op1 lshr log2(Op2), if log2() folds away. if (takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/false)) { Value *Res = takeLog2(Builder, Op1, /*Depth*/0, /*DoFold*/true); @@ -1097,7 +1307,7 @@ Instruction *InstCombinerImpl::visitUDiv(BinaryOperator &I) { } Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { - if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), + if (Value *V = simplifySDivInst(I.getOperand(0), I.getOperand(1), I.isExact(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); @@ -1121,20 +1331,25 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { if (match(Op1, m_SignMask())) return new ZExtInst(Builder.CreateICmpEQ(Op0, Op1), Ty); - // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative - // sdiv exact X, -1<<C --> -(ashr exact X, C) - if (I.isExact() && ((match(Op1, m_Power2()) && match(Op1, m_NonNegative())) || - match(Op1, m_NegatedPower2()))) { - bool DivisorWasNegative = match(Op1, m_NegatedPower2()); - if (DivisorWasNegative) - Op1 = ConstantExpr::getNeg(cast<Constant>(Op1)); - auto *AShr = BinaryOperator::CreateExactAShr( - Op0, ConstantExpr::getExactLogBase2(cast<Constant>(Op1)), I.getName()); - if (!DivisorWasNegative) - return AShr; - Builder.Insert(AShr); - AShr->setName(I.getName() + ".neg"); - return BinaryOperator::CreateNeg(AShr, I.getName()); + if (I.isExact()) { + // sdiv exact X, 1<<C --> ashr exact X, C iff 1<<C is non-negative + if (match(Op1, m_Power2()) && match(Op1, m_NonNegative())) { + Constant *C = ConstantExpr::getExactLogBase2(cast<Constant>(Op1)); + return BinaryOperator::CreateExactAShr(Op0, C); + } + + // sdiv exact X, (1<<ShAmt) --> ashr exact X, ShAmt (if shl is non-negative) + Value *ShAmt; + if (match(Op1, m_NSWShl(m_One(), m_Value(ShAmt)))) + return BinaryOperator::CreateExactAShr(Op0, ShAmt); + + // sdiv exact X, -1<<C --> -(ashr exact X, C) + if (match(Op1, m_NegatedPower2())) { + Constant *NegPow2C = ConstantExpr::getNeg(cast<Constant>(Op1)); + Constant *C = ConstantExpr::getExactLogBase2(NegPow2C); + Value *Ashr = Builder.CreateAShr(Op0, C, I.getName() + ".neg", true); + return BinaryOperator::CreateNeg(Ashr); + } } const APInt *Op1C; @@ -1184,12 +1399,17 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { ConstantInt::getAllOnesValue(Ty)); } - // If the sign bits of both operands are zero (i.e. we can prove they are - // unsigned inputs), turn this into a udiv. 
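As an aside to the sdiv-exact folds above: signed division truncates toward zero while an arithmetic shift rounds toward negative infinity, so "sdiv X, 1<<C" and "ashr X, C" can disagree for negative dividends unless the division is remainder-free, which is what the 'exact' flag asserts. A hedged plain-C++ illustration (relies on '>>' of a negative value being an arithmetic shift, guaranteed since C++20; not part of the patch):

    #include <cassert>

    int main() {
      // Exact case: -64 is a multiple of 16, so the two forms agree.
      assert(-64 / 16 == -4 && (-64 >> 4) == -4);
      // Non-exact case: division truncates toward zero, the shift rounds
      // toward -infinity, so the rewrite needs the 'exact' guarantee.
      assert(-65 / 16 == -4);
      assert((-65 >> 4) == -5);
      return 0;
    }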
- APInt Mask(APInt::getSignMask(Ty->getScalarSizeInBits())); - if (MaskedValueIsZero(Op0, Mask, 0, &I)) { - if (MaskedValueIsZero(Op1, Mask, 0, &I)) { - // X sdiv Y -> X udiv Y, iff X and Y don't have sign bit set + KnownBits KnownDividend = computeKnownBits(Op0, 0, &I); + if (!I.isExact() && + (match(Op1, m_Power2(Op1C)) || match(Op1, m_NegatedPower2(Op1C))) && + KnownDividend.countMinTrailingZeros() >= Op1C->countTrailingZeros()) { + I.setIsExact(); + return &I; + } + + if (KnownDividend.isNonNegative()) { + // If both operands are unsigned, turn this into a udiv. + if (isKnownNonNegative(Op1, DL, 0, &AC, &I, &DT)) { auto *BO = BinaryOperator::CreateUDiv(Op0, Op1, I.getName()); BO->setIsExact(I.isExact()); return BO; @@ -1219,15 +1439,28 @@ Instruction *InstCombinerImpl::visitSDiv(BinaryOperator &I) { } /// Remove negation and try to convert division into multiplication. -static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { +Instruction *InstCombinerImpl::foldFDivConstantDivisor(BinaryOperator &I) { Constant *C; if (!match(I.getOperand(1), m_Constant(C))) return nullptr; // -X / C --> X / -C Value *X; + const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(0), m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFDivFMF(X, ConstantExpr::getFNeg(C), &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(X, NegC, &I); + + // nnan X / +0.0 -> copysign(inf, X) + if (I.hasNoNaNs() && match(I.getOperand(1), m_Zero())) { + IRBuilder<> B(&I); + // TODO: nnan nsz X / -0.0 -> copysign(inf, X) + CallInst *CopySign = B.CreateIntrinsic( + Intrinsic::copysign, {C->getType()}, + {ConstantFP::getInfinity(I.getType()), I.getOperand(0)}, &I); + CopySign->takeName(&I); + return replaceInstUsesWith(I, CopySign); + } // If the constant divisor has an exact inverse, this is always safe. If not, // then we can still create a reciprocal if fast-math-flags allow it and the @@ -1239,7 +1472,6 @@ static Instruction *foldFDivConstantDivisor(BinaryOperator &I) { // on all targets. // TODO: Use Intrinsic::canonicalize or let function attributes tell us that // denorms are flushed? - const DataLayout &DL = I.getModule()->getDataLayout(); auto *RecipC = ConstantFoldBinaryOpOperands( Instruction::FDiv, ConstantFP::get(I.getType(), 1.0), C, DL); if (!RecipC || !RecipC->isNormalFP()) @@ -1257,15 +1489,16 @@ static Instruction *foldFDivConstantDividend(BinaryOperator &I) { // C / -X --> -C / X Value *X; + const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(1), m_FNeg(m_Value(X)))) - return BinaryOperator::CreateFDivFMF(ConstantExpr::getFNeg(C), X, &I); + if (Constant *NegC = ConstantFoldUnaryOpOperand(Instruction::FNeg, C, DL)) + return BinaryOperator::CreateFDivFMF(NegC, X, &I); if (!I.hasAllowReassoc() || !I.hasAllowReciprocal()) return nullptr; // Try to reassociate C / X expressions where X includes another constant. 
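Another aside: the new nnan folds that introduce copysign (X * 0.0 --> copysign(0.0, X) in visitFMul and X / +0.0 --> copysign(inf, X) above) follow directly from IEEE-754 sign rules. A hedged plain-C++ illustration assuming IEEE-754 doubles (is_iec559) and non-NaN inputs; illustration only, not part of the patch:

    #include <cassert>
    #include <cmath>
    #include <limits>

    int main() {
      const double inf = std::numeric_limits<double>::infinity();
      for (double x : {3.0, -3.0, 0.5, -0.5}) {
        // The product with 0.0 is a zero carrying the sign of x.
        assert(std::signbit(x * 0.0) == std::signbit(std::copysign(0.0, x)));
        // Dividing a finite non-NaN value by +0.0 gives an infinity
        // carrying the sign of the dividend.
        assert(x / 0.0 == std::copysign(inf, x));
      }
      return 0;
    }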
Constant *C2, *NewC = nullptr; - const DataLayout &DL = I.getModule()->getDataLayout(); if (match(I.getOperand(1), m_FMul(m_Value(X), m_Constant(C2)))) { // C / (X * C2) --> (C / C2) / X NewC = ConstantFoldBinaryOpOperands(Instruction::FDiv, C, C2, DL); @@ -1435,6 +1668,16 @@ Instruction *InstCombinerImpl::visitFDiv(BinaryOperator &I) { if (Instruction *Mul = foldFDivPowDivisor(I, Builder)) return Mul; + // pow(X, Y) / X --> pow(X, Y-1) + if (I.hasAllowReassoc() && + match(Op0, m_OneUse(m_Intrinsic<Intrinsic::pow>(m_Specific(Op1), + m_Value(Y))))) { + Value *Y1 = + Builder.CreateFAddFMF(Y, ConstantFP::get(I.getType(), -1.0), &I); + Value *Pow = Builder.CreateBinaryIntrinsic(Intrinsic::pow, Op1, Y1, &I); + return replaceInstUsesWith(I, Pow); + } + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index c573b03f31a6..e24abc48424d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -15,8 +15,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -130,7 +128,7 @@ std::array<Value *, 2> Negator::getSortedOperandsOfBinOp(Instruction *I) { // FIXME: can this be reworked into a worklist-based algorithm while preserving // the depth-first, early bailout traversal? -LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::visitImpl(Value *V, unsigned Depth) { // -(undef) -> undef. if (match(V, m_Undef())) return V; @@ -248,6 +246,19 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return nullptr; switch (I->getOpcode()) { + case Instruction::ZExt: { + // Negation of zext of signbit is signbit splat: + // 0 - (zext (i8 X u>> 7) to iN) --> sext (i8 X s>> 7) to iN + Value *SrcOp = I->getOperand(0); + unsigned SrcWidth = SrcOp->getType()->getScalarSizeInBits(); + const APInt &FullShift = APInt(SrcWidth, SrcWidth - 1); + if (IsTrulyNegation && + match(SrcOp, m_LShr(m_Value(X), m_SpecificIntAllowUndef(FullShift)))) { + Value *Ashr = Builder.CreateAShr(X, FullShift); + return Builder.CreateSExt(Ashr, I->getType()); + } + break; + } case Instruction::And: { Constant *ShAmt; // sub(y,and(lshr(x,C),1)) --> add(ashr(shl(x,(BW-1)-C),BW-1),y) @@ -382,7 +393,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { return Builder.CreateShl(NegOp0, I->getOperand(1), I->getName() + ".neg"); // Otherwise, `shl %x, C` can be interpreted as `mul %x, 1<<C`. auto *Op1C = dyn_cast<Constant>(I->getOperand(1)); - if (!Op1C) // Early return. + if (!Op1C || !IsTrulyNegation) return nullptr; return Builder.CreateMul( I->getOperand(0), @@ -399,7 +410,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { if (match(Ops[1], m_One())) return Builder.CreateNot(Ops[0], I->getName() + ".neg"); // Else, just defer to Instruction::Add handling. - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Instruction::Add: { // `add` is negatible if both of its operands are negatible. @@ -465,7 +476,7 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { llvm_unreachable("Can't get here. 
We always return from switch."); } -LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { +[[nodiscard]] Value *Negator::negate(Value *V, unsigned Depth) { NegatorMaxDepthVisited.updateMax(Depth); ++NegatorNumValuesVisited; @@ -502,20 +513,20 @@ LLVM_NODISCARD Value *Negator::negate(Value *V, unsigned Depth) { return NegatedV; } -LLVM_NODISCARD Optional<Negator::Result> Negator::run(Value *Root) { +[[nodiscard]] std::optional<Negator::Result> Negator::run(Value *Root) { Value *Negated = negate(Root, /*Depth=*/0); if (!Negated) { // We must cleanup newly-inserted instructions, to avoid any potential // endless combine looping. for (Instruction *I : llvm::reverse(NewInstructions)) I->eraseFromParent(); - return llvm::None; + return std::nullopt; } return std::make_pair(ArrayRef<Instruction *>(NewInstructions), Negated); } -LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, - InstCombinerImpl &IC) { +[[nodiscard]] Value *Negator::Negate(bool LHSIsZero, Value *Root, + InstCombinerImpl &IC) { ++NegatorTotalNegationsAttempted; LLVM_DEBUG(dbgs() << "Negator: attempting to sink negation into " << *Root << "\n"); @@ -525,7 +536,7 @@ LLVM_NODISCARD Value *Negator::Negate(bool LHSIsZero, Value *Root, Negator N(Root->getContext(), IC.getDataLayout(), IC.getAssumptionCache(), IC.getDominatorTree(), LHSIsZero); - Optional<Result> Res = N.run(Root); + std::optional<Result> Res = N.run(Root); if (!Res) { // Negation failed. LLVM_DEBUG(dbgs() << "Negator: failed to sink negation into " << *Root << "\n"); diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 90a796a0939e..7f59729f0085 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -20,6 +20,7 @@ #include "llvm/Support/CommandLine.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/Local.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -102,15 +103,15 @@ void InstCombinerImpl::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { // ptr_val_inc = ... // ... 
// -Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { +bool InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (!PN.getType()->isIntegerTy()) - return nullptr; + return false; if (!PN.hasOneUse()) - return nullptr; + return false; auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back()); if (!IntToPtr) - return nullptr; + return false; // Check if the pointer is actually used as pointer: auto HasPointerUse = [](Instruction *IIP) { @@ -131,11 +132,11 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { }; if (!HasPointerUse(IntToPtr)) - return nullptr; + return false; if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) != DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType())) - return nullptr; + return false; SmallVector<Value *, 4> AvailablePtrVals; for (auto Incoming : zip(PN.blocks(), PN.incoming_values())) { @@ -174,10 +175,10 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { // For a single use integer load: auto *LoadI = dyn_cast<LoadInst>(Arg); if (!LoadI) - return nullptr; + return false; if (!LoadI->hasOneUse()) - return nullptr; + return false; // Push the integer typed Load instruction into the available // value set, and fix it up later when the pointer typed PHI @@ -194,7 +195,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { for (PHINode &PtrPHI : BB->phis()) { // FIXME: consider handling this in AggressiveInstCombine if (NumPhis++ > MaxNumPhis) - return nullptr; + return false; if (&PtrPHI == &PN || PtrPHI.getType() != IntToPtr->getType()) continue; if (any_of(zip(PN.blocks(), AvailablePtrVals), @@ -211,16 +212,19 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { if (MatchingPtrPHI) { assert(MatchingPtrPHI->getType() == IntToPtr->getType() && "Phi's Type does not match with IntToPtr"); - // The PtrToCast + IntToPtr will be simplified later - return CastInst::CreateBitOrPointerCast(MatchingPtrPHI, - IntToPtr->getOperand(0)->getType()); + // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here, + // to make sure another transform can't undo it in the meantime. + replaceInstUsesWith(*IntToPtr, MatchingPtrPHI); + eraseInstFromFunction(*IntToPtr); + eraseInstFromFunction(PN); + return true; } // If it requires a conversion for every PHI operand, do not do it. if (all_of(AvailablePtrVals, [&](Value *V) { return (V->getType() != IntToPtr->getType()) || isa<IntToPtrInst>(V); })) - return nullptr; + return false; // If any of the operand that requires casting is a terminator // instruction, do not do it. Similarly, do not do the transform if the value @@ -239,7 +243,7 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { return true; return false; })) - return nullptr; + return false; PHINode *NewPtrPHI = PHINode::Create( IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); @@ -290,9 +294,12 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { NewPtrPHI->addIncoming(CI, IncomingBB); } - // The PtrToCast + IntToPtr will be simplified later - return CastInst::CreateBitOrPointerCast(NewPtrPHI, - IntToPtr->getOperand(0)->getType()); + // Explicitly replace the inttoptr (rather than inserting a ptrtoint) here, + // to make sure another transform can't undo it in the meantime. 
+ replaceInstUsesWith(*IntToPtr, NewPtrPHI); + eraseInstFromFunction(*IntToPtr); + eraseInstFromFunction(PN); + return true; } // Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and @@ -598,7 +605,7 @@ Instruction *InstCombinerImpl::foldPHIArgGEPIntoPHI(PHINode &PN) { Value *Base = FixedOperands[0]; GetElementPtrInst *NewGEP = GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, - makeArrayRef(FixedOperands).slice(1)); + ArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); PHIArgMergedDebugLoc(NewGEP, PN); return NewGEP; @@ -1322,7 +1329,7 @@ static Value *simplifyUsingControlFlow(InstCombiner &Self, PHINode &PN, // Check that edges outgoing from the idom's terminators dominate respective // inputs of the Phi. - Optional<bool> Invert; + std::optional<bool> Invert; for (auto Pair : zip(PN.incoming_values(), PN.blocks())) { auto *Input = cast<ConstantInt>(std::get<0>(Pair)); BasicBlock *Pred = std::get<1>(Pair); @@ -1412,8 +1419,8 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { // this PHI only has a single use (a PHI), and if that PHI only has one use (a // PHI)... break the cycle. if (PN.hasOneUse()) { - if (Instruction *Result = foldIntegerTypedPHI(PN)) - return Result; + if (foldIntegerTypedPHI(PN)) + return nullptr; Instruction *PHIUser = cast<Instruction>(PN.user_back()); if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ad96a5f475f1..e7d8208f94fd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,7 +12,6 @@ #include "InstCombineInternal.h" #include "llvm/ADT/APInt.h" -#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/AssumptionCache.h" @@ -20,6 +19,7 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/OverflowInstAnalysis.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" @@ -314,47 +314,95 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, TI->getType()); } - // Cond ? -X : -Y --> -(Cond ? X : Y) - Value *X, *Y; - if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && - (TI->hasOneUse() || FI->hasOneUse())) { - // Intersect FMF from the fneg instructions and union those with the select. - FastMathFlags FMF = TI->getFastMathFlags(); - FMF &= FI->getFastMathFlags(); - FMF |= SI.getFastMathFlags(); - Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); - if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) - NewSelI->setFastMathFlags(FMF); - Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); - NewFNeg->setFastMathFlags(FMF); - return NewFNeg; - } - - // Min/max intrinsic with a common operand can have the common operand pulled - // after the select. This is the same transform as below for binops, but - // specialized for intrinsic matching and without the restrictive uses clause. 
- auto *TII = dyn_cast<IntrinsicInst>(TI); - auto *FII = dyn_cast<IntrinsicInst>(FI); - if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID() && - (TII->hasOneUse() || FII->hasOneUse())) { - Value *T0, *T1, *F0, *F1; - if (match(TII, m_MaxOrMin(m_Value(T0), m_Value(T1))) && - match(FII, m_MaxOrMin(m_Value(F0), m_Value(F1)))) { - if (T0 == F0) { - Value *NewSel = Builder.CreateSelect(Cond, T1, F1, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); - } - if (T0 == F1) { - Value *NewSel = Builder.CreateSelect(Cond, T1, F0, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T0}); + Value *OtherOpT, *OtherOpF; + bool MatchIsOpZero; + auto getCommonOp = [&](Instruction *TI, Instruction *FI, bool Commute, + bool Swapped = false) -> Value * { + assert(!(Commute && Swapped) && + "Commute and Swapped can't set at the same time"); + if (!Swapped) { + if (TI->getOperand(0) == FI->getOperand(0)) { + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = true; + return TI->getOperand(0); + } else if (TI->getOperand(1) == FI->getOperand(1)) { + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = false; + return TI->getOperand(1); } - if (T1 == F0) { - Value *NewSel = Builder.CreateSelect(Cond, T0, F1, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + + if (!Commute && !Swapped) + return nullptr; + + // If we are allowing commute or swap of operands, then + // allow a cross-operand match. In that case, MatchIsOpZero + // means that TI's operand 0 (FI's operand 1) is the common op. + if (TI->getOperand(0) == FI->getOperand(1)) { + OtherOpT = TI->getOperand(1); + OtherOpF = FI->getOperand(0); + MatchIsOpZero = true; + return TI->getOperand(0); + } else if (TI->getOperand(1) == FI->getOperand(0)) { + OtherOpT = TI->getOperand(0); + OtherOpF = FI->getOperand(1); + MatchIsOpZero = false; + return TI->getOperand(1); + } + return nullptr; + }; + + if (TI->hasOneUse() || FI->hasOneUse()) { + // Cond ? -X : -Y --> -(Cond ? X : Y) + Value *X, *Y; + if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) { + // Intersect FMF from the fneg instructions and union those with the + // select. + FastMathFlags FMF = TI->getFastMathFlags(); + FMF &= FI->getFastMathFlags(); + FMF |= SI.getFastMathFlags(); + Value *NewSel = + Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI); + if (auto *NewSelI = dyn_cast<Instruction>(NewSel)) + NewSelI->setFastMathFlags(FMF); + Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel); + NewFNeg->setFastMathFlags(FMF); + return NewFNeg; + } + + // Min/max intrinsic with a common operand can have the common operand + // pulled after the select. This is the same transform as below for binops, + // but specialized for intrinsic matching and without the restrictive uses + // clause. 
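The common-operand rule described in the comment above has a simple scalar analogue: when two mins (or maxes) share an operand, the select can be hoisted past them. A hedged sketch with std::min standing in for the intrinsics (illustration only, not part of the patch):

    #include <algorithm>
    #include <cassert>

    int main() {
      int a = 7, b = 3, d = 9;
      for (bool cond : {false, true}) {
        // select(cond, min(a, b), min(a, d)) == min(select(cond, b, d), a)
        int original = cond ? std::min(a, b) : std::min(a, d);
        int folded = std::min(cond ? b : d, a);
        assert(folded == original);
      }
      return 0;
    }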
+ auto *TII = dyn_cast<IntrinsicInst>(TI); + auto *FII = dyn_cast<IntrinsicInst>(FI); + if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) { + if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) { + if (Value *MatchOp = getCommonOp(TI, FI, true)) { + Value *NewSel = + Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI); + return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp}); + } } - if (T1 == F1) { - Value *NewSel = Builder.CreateSelect(Cond, T0, F0, "minmaxop", &SI); - return CallInst::Create(TII->getCalledFunction(), {NewSel, T1}); + } + + // icmp with a common operand also can have the common operand + // pulled after the select. + ICmpInst::Predicate TPred, FPred; + if (match(TI, m_ICmp(TPred, m_Value(), m_Value())) && + match(FI, m_ICmp(FPred, m_Value(), m_Value()))) { + if (TPred == FPred || TPred == CmpInst::getSwappedPredicate(FPred)) { + bool Swapped = TPred != FPred; + if (Value *MatchOp = + getCommonOp(TI, FI, ICmpInst::isEquality(TPred), Swapped)) { + Value *NewSel = Builder.CreateSelect(Cond, OtherOpT, OtherOpF, + SI.getName() + ".v", &SI); + return new ICmpInst( + MatchIsOpZero ? TPred : CmpInst::getSwappedPredicate(TPred), + MatchOp, NewSel); + } } } } @@ -370,33 +418,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, return nullptr; // Figure out if the operations have any operands in common. - Value *MatchOp, *OtherOpT, *OtherOpF; - bool MatchIsOpZero; - if (TI->getOperand(0) == FI->getOperand(0)) { - MatchOp = TI->getOperand(0); - OtherOpT = TI->getOperand(1); - OtherOpF = FI->getOperand(1); - MatchIsOpZero = true; - } else if (TI->getOperand(1) == FI->getOperand(1)) { - MatchOp = TI->getOperand(1); - OtherOpT = TI->getOperand(0); - OtherOpF = FI->getOperand(0); - MatchIsOpZero = false; - } else if (!TI->isCommutative()) { - return nullptr; - } else if (TI->getOperand(0) == FI->getOperand(1)) { - MatchOp = TI->getOperand(0); - OtherOpT = TI->getOperand(1); - OtherOpF = FI->getOperand(0); - MatchIsOpZero = true; - } else if (TI->getOperand(1) == FI->getOperand(0)) { - MatchOp = TI->getOperand(1); - OtherOpT = TI->getOperand(0); - OtherOpF = FI->getOperand(1); - MatchIsOpZero = true; - } else { + Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative()); + if (!MatchOp) return nullptr; - } // If the select condition is a vector, the operands of the original select's // operands also must be vectors. This may not be the case for getelementptr @@ -442,44 +466,44 @@ Instruction *InstCombinerImpl::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, Value *FalseVal, bool Swapped) -> Instruction * { - if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { - if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { - if (unsigned SFO = getSelectFoldableOperands(TVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && FalseVal == TVI->getOperand(0)) - OpToFold = 1; - else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) - OpToFold = 2; + auto *TVI = dyn_cast<BinaryOperator>(TrueVal); + if (!TVI || !TVI->hasOneUse() || isa<Constant>(FalseVal)) + return nullptr; - if (OpToFold) { - FastMathFlags FMF; - // TODO: We probably ought to revisit cases where the select and FP - // instructions have different flags and add tests to ensure the - // behaviour is correct. 
- if (isa<FPMathOperator>(&SI)) - FMF = SI.getFastMathFlags(); - Constant *C = ConstantExpr::getBinOpIdentity( - TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); - Value *OOp = TVI->getOperand(2 - OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1. - const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa<Constant>(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect( - SI.getCondition(), Swapped ? C : OOp, Swapped ? OOp : C); - if (isa<FPMathOperator>(&SI)) - cast<Instruction>(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(TVI); - BinaryOperator *BO = - BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; - } - } - } - } + unsigned SFO = getSelectFoldableOperands(TVI); + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) + OpToFold = 1; + else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) + OpToFold = 2; + + if (!OpToFold) + return nullptr; + + // TODO: We probably ought to revisit cases where the select and FP + // instructions have different flags and add tests to ensure the + // behaviour is correct. + FastMathFlags FMF; + if (isa<FPMathOperator>(&SI)) + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); + Value *OOp = TVI->getOperand(2 - OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0, 1 and -1. + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { + Value *NewSel = Builder.CreateSelect(SI.getCondition(), Swapped ? C : OOp, + Swapped ? OOp : C); + if (isa<FPMathOperator>(&SI)) + cast<Instruction>(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; } return nullptr; }; @@ -779,19 +803,31 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, const Value *FalseVal, InstCombiner::BuilderTy &Builder) { ICmpInst::Predicate Pred = ICI->getPredicate(); - if (!ICmpInst::isUnsigned(Pred)) - return nullptr; + Value *A = ICI->getOperand(0); + Value *B = ICI->getOperand(1); // (b > a) ? 0 : a - b -> (b <= a) ? a - b : 0 + // (a == 0) ? 0 : a - 1 -> (a != 0) ? a - 1 : 0 if (match(TrueVal, m_Zero())) { Pred = ICmpInst::getInversePredicate(Pred); std::swap(TrueVal, FalseVal); } + if (!match(FalseVal, m_Zero())) return nullptr; - Value *A = ICI->getOperand(0); - Value *B = ICI->getOperand(1); + // ugt 0 is canonicalized to ne 0 and requires special handling + // (a != 0) ? a + -1 : 0 -> usub.sat(a, 1) + if (Pred == ICmpInst::ICMP_NE) { + if (match(B, m_Zero()) && match(TrueVal, m_Add(m_Specific(A), m_AllOnes()))) + return Builder.CreateBinaryIntrinsic(Intrinsic::usub_sat, A, + ConstantInt::get(A->getType(), 1)); + return nullptr; + } + + if (!ICmpInst::isUnsigned(Pred)) + return nullptr; + if (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_ULT) { // (b < a) ? a - b : 0 -> (a > b) ? a - b : 0 std::swap(A, B); @@ -952,8 +988,8 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - // Check if the condition value compares a value for equality against zero. 
- if (!ICI->isEquality() || !match(CmpRHS, m_Zero())) + // Check if the select condition compares a value for equality. + if (!ICI->isEquality()) return nullptr; Value *SelectArg = FalseVal; @@ -969,8 +1005,15 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, // Check that 'Count' is a call to intrinsic cttz/ctlz. Also check that the // input to the cttz/ctlz is used as LHS for the compare instruction. - if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Specific(CmpLHS))) && - !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Specific(CmpLHS)))) + Value *X; + if (!match(Count, m_Intrinsic<Intrinsic::cttz>(m_Value(X))) && + !match(Count, m_Intrinsic<Intrinsic::ctlz>(m_Value(X)))) + return nullptr; + + // (X == 0) ? BitWidth : ctz(X) + // (X == -1) ? BitWidth : ctz(~X) + if ((X != CmpLHS || !match(CmpRHS, m_Zero())) && + (!match(X, m_Not(m_Specific(CmpLHS))) || !match(CmpRHS, m_AllOnes()))) return nullptr; IntrinsicInst *II = cast<IntrinsicInst>(Count); @@ -1139,6 +1182,28 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, return nullptr; } +static bool replaceInInstruction(Value *V, Value *Old, Value *New, + InstCombiner &IC, unsigned Depth = 0) { + // Conservatively limit replacement to two instructions upwards. + if (Depth == 2) + return false; + + auto *I = dyn_cast<Instruction>(V); + if (!I || !I->hasOneUse() || !isSafeToSpeculativelyExecute(I)) + return false; + + bool Changed = false; + for (Use &U : I->operands()) { + if (U == Old) { + IC.replaceUse(U, New); + Changed = true; + } else { + Changed |= replaceInInstruction(U, Old, New, IC, Depth + 1); + } + } + return Changed; +} + /// If we have a select with an equality comparison, then we know the value in /// one of the arms of the select. See if substituting this value into an arm /// and simplifying the result yields the same value as the other arm. @@ -1157,10 +1222,7 @@ static Instruction *canonicalizeSPF(SelectInst &Sel, ICmpInst &Cmp, /// TODO: Wrapping flags could be preserved in some cases with better analysis. Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp) { - // Value equivalence substitution requires an all-or-nothing replacement. - // It does not make sense for a vector compare where each lane is chosen - // independently. - if (!Cmp.isEquality() || Cmp.getType()->isVectorTy()) + if (!Cmp.isEquality()) return nullptr; // Canonicalize the pattern to ICMP_EQ by swapping the select operands. @@ -1189,15 +1251,11 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // with different operands, which should not cause side-effects or trigger // undefined behavior). Only do this if CmpRHS is a constant, as // profitability is not clear for other cases. - // FIXME: The replacement could be performed recursively. - if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant())) - if (auto *I = dyn_cast<Instruction>(TrueVal)) - if (I->hasOneUse() && isSafeToSpeculativelyExecute(I)) - for (Use &U : I->operands()) - if (U == CmpLHS) { - replaceUse(U, CmpRHS); - return &Sel; - } + // FIXME: Support vectors. 
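The idea behind replaceInInstruction above is the usual equality-substitution argument: in the arm selected when X == C holds, X can be rewritten to C, now up to two instructions deep. A hedged scalar sketch in plain C++ (illustration only, not part of the patch):

    #include <cassert>

    // When the condition proves x == 5, x may be replaced by 5 in the true
    // arm; the mul+add below is two instructions deep, matching the
    // recursion limit mentioned above.
    int original(int x, int y) { return x == 5 ? (x * 2 + 1) : y; }
    int substituted(int x, int y) { return x == 5 ? (5 * 2 + 1) : y; }

    int main() {
      for (int x = -10; x <= 10; ++x)
        for (int y = -3; y <= 3; ++y)
          assert(original(x, y) == substituted(x, y));
      return 0;
    }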
+ if (match(CmpRHS, m_ImmConstant()) && !match(CmpLHS, m_ImmConstant()) && + !Cmp.getType()->isVectorTy()) + if (replaceInInstruction(TrueVal, CmpLHS, CmpRHS, *this)) + return &Sel; } if (TrueVal != CmpRHS && isGuaranteedNotToBeUndefOrPoison(CmpLHS, SQ.AC, &Sel, &DT)) @@ -1371,7 +1429,7 @@ static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, C2->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have signed max element[s]. C2 = InstCombiner::AddOne(C2); - LLVM_FALLTHROUGH; + [[fallthrough]]; case ICmpInst::Predicate::ICMP_SGE: // Also non-canonical, but here we don't need to change C2, // so we don't have any restrictions on C2, so we can just handle it. @@ -2307,6 +2365,41 @@ static Instruction *foldSelectToCopysign(SelectInst &Sel, } Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { + if (!isa<VectorType>(Sel.getType())) + return nullptr; + + Value *Cond = Sel.getCondition(); + Value *TVal = Sel.getTrueValue(); + Value *FVal = Sel.getFalseValue(); + Value *C, *X, *Y; + + if (match(Cond, m_VecReverse(m_Value(C)))) { + auto createSelReverse = [&](Value *C, Value *X, Value *Y) { + Value *V = Builder.CreateSelect(C, X, Y, Sel.getName(), &Sel); + if (auto *I = dyn_cast<Instruction>(V)) + I->copyIRFlags(&Sel); + Module *M = Sel.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + if (match(TVal, m_VecReverse(m_Value(X)))) { + // select rev(C), rev(X), rev(Y) --> rev(select C, X, Y) + if (match(FVal, m_VecReverse(m_Value(Y))) && + (Cond->hasOneUse() || TVal->hasOneUse() || FVal->hasOneUse())) + return createSelReverse(C, X, Y); + + // select rev(C), rev(X), FValSplat --> rev(select C, X, FValSplat) + if ((Cond->hasOneUse() || TVal->hasOneUse()) && isSplatValue(FVal)) + return createSelReverse(C, X, FVal); + } + // select rev(C), TValSplat, rev(Y) --> rev(select C, TValSplat, Y) + else if (isSplatValue(TVal) && match(FVal, m_VecReverse(m_Value(Y))) && + (Cond->hasOneUse() || FVal->hasOneUse())) + return createSelReverse(C, TVal, Y); + } + auto *VecTy = dyn_cast<FixedVectorType>(Sel.getType()); if (!VecTy) return nullptr; @@ -2323,10 +2416,6 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) { // A select of a "select shuffle" with a common operand can be rearranged // to select followed by "select shuffle". Because of poison, this only works // in the case of a shuffle with no undefined mask elements. 
- Value *Cond = Sel.getCondition(); - Value *TVal = Sel.getTrueValue(); - Value *FVal = Sel.getFalseValue(); - Value *X, *Y; ArrayRef<int> Mask; if (match(TVal, m_OneUse(m_Shuffle(m_Value(X), m_Value(Y), m_Mask(Mask)))) && !is_contained(Mask, UndefMaskElem) && @@ -2472,7 +2561,7 @@ Instruction *InstCombinerImpl::foldAndOrOfSelectUsingImpliedCond(Value *Op, assert(Op->getType()->isIntOrIntVectorTy(1) && "Op must be either i1 or vector of i1."); - Optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); + std::optional<bool> Res = isImpliedCondition(Op, CondVal, DL, IsAnd); if (!Res) return nullptr; @@ -2510,6 +2599,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, InstCombinerImpl &IC) { Value *CondVal = SI.getCondition(); + bool ChangedFMF = false; for (bool Swap : {false, true}) { Value *TrueVal = SI.getTrueValue(); Value *X = SI.getFalseValue(); @@ -2534,13 +2624,33 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } + if (!match(TrueVal, m_FNeg(m_Specific(X)))) + return nullptr; + + // Forward-propagate nnan and ninf from the fneg to the select. + // If all inputs are not those values, then the select is not either. + // Note: nsz is defined differently, so it may not be correct to propagate. + FastMathFlags FMF = cast<FPMathOperator>(TrueVal)->getFastMathFlags(); + if (FMF.noNaNs() && !SI.hasNoNaNs()) { + SI.setHasNoNaNs(true); + ChangedFMF = true; + } + if (FMF.noInfs() && !SI.hasNoInfs()) { + SI.setHasNoInfs(true); + ChangedFMF = true; + } + // With nsz, when 'Swap' is false: // fold (X < +/-0.0) ? -X : X or (X <= +/-0.0) ? -X : X to fabs(X) // fold (X > +/-0.0) ? -X : X or (X >= +/-0.0) ? -X : X to -fabs(x) // when 'Swap' is true: // fold (X > +/-0.0) ? X : -X or (X >= +/-0.0) ? X : -X to fabs(X) // fold (X < +/-0.0) ? X : -X or (X <= +/-0.0) ? X : -X to -fabs(X) - if (!match(TrueVal, m_FNeg(m_Specific(X))) || !SI.hasNoSignedZeros()) + // + // Note: We require "nnan" for this fold because fcmp ignores the signbit + // of NAN, but IEEE-754 specifies the signbit of NAN values with + // fneg/fabs operations. + if (!SI.hasNoSignedZeros() || !SI.hasNoNaNs()) return nullptr; if (Swap) @@ -2563,7 +2673,7 @@ static Instruction *foldSelectWithFCmpToFabs(SelectInst &SI, } } - return nullptr; + return ChangedFMF ? &SI : nullptr; } // Match the following IR pattern: @@ -2602,10 +2712,14 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, if (!match(XLowBits, m_And(m_Specific(X), m_APIntAllowUndef(LowBitMaskCst)))) return nullptr; + // Match even if the AND and ADD are swapped. 
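For context on the swapped and/add match added in foldRoundUpIntegerWithPow2Alignment: both spellings implement the familiar round-up-to-a-power-of-two-alignment idiom. A hedged plain-C++ sketch assuming unsigned arithmetic, a bias equal to the alignment, and values far from the top of the range (the constants the matcher actually accepts are checked separately in the code); illustration only, not part of the patch:

    #include <cassert>
    #include <cstdint>

    int main() {
      const uint32_t A = 16;              // power-of-two alignment
      const uint32_t LowMask = A - 1;
      const uint32_t HighMask = ~LowMask;

      for (uint32_t x = 0; x < 1000; ++x) {
        // Classic closed form for rounding up to a multiple of A.
        uint32_t classic = (x + A - 1) & HighMask;
        // Select-based spellings: keep x if already aligned, else bias and mask.
        uint32_t biased = (x + A) & HighMask;          // original form
        uint32_t biasedSwapped = (x & HighMask) + A;   // swapped form
        assert(((x & LowMask) == 0 ? x : biased) == classic);
        assert(((x & LowMask) == 0 ? x : biasedSwapped) == classic);
      }
      return 0;
    }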
const APInt *BiasCst, *HighBitMaskCst; if (!match(XBiasedHighBits, m_And(m_Add(m_Specific(X), m_APIntAllowUndef(BiasCst)), - m_APIntAllowUndef(HighBitMaskCst)))) + m_APIntAllowUndef(HighBitMaskCst))) && + !match(XBiasedHighBits, + m_Add(m_And(m_Specific(X), m_APIntAllowUndef(HighBitMaskCst)), + m_APIntAllowUndef(BiasCst)))) return nullptr; if (!LowBitMaskCst->isMask()) @@ -2635,200 +2749,392 @@ foldRoundUpIntegerWithPow2Alignment(SelectInst &SI, return R; } -Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { +namespace { +struct DecomposedSelect { + Value *Cond = nullptr; + Value *TrueVal = nullptr; + Value *FalseVal = nullptr; +}; +} // namespace + +/// Look for patterns like +/// %outer.cond = select i1 %inner.cond, i1 %alt.cond, i1 false +/// %inner.sel = select i1 %inner.cond, i8 %inner.sel.t, i8 %inner.sel.f +/// %outer.sel = select i1 %outer.cond, i8 %outer.sel.t, i8 %inner.sel +/// and rewrite it as +/// %inner.sel = select i1 %cond.alternative, i8 %sel.outer.t, i8 %sel.inner.t +/// %sel.outer = select i1 %cond.inner, i8 %inner.sel, i8 %sel.inner.f +static Instruction *foldNestedSelects(SelectInst &OuterSelVal, + InstCombiner::BuilderTy &Builder) { + // We must start with a `select`. + DecomposedSelect OuterSel; + match(&OuterSelVal, + m_Select(m_Value(OuterSel.Cond), m_Value(OuterSel.TrueVal), + m_Value(OuterSel.FalseVal))); + + // Canonicalize inversion of the outermost `select`'s condition. + if (match(OuterSel.Cond, m_Not(m_Value(OuterSel.Cond)))) + std::swap(OuterSel.TrueVal, OuterSel.FalseVal); + + // The condition of the outermost select must be an `and`/`or`. + if (!match(OuterSel.Cond, m_c_LogicalOp(m_Value(), m_Value()))) + return nullptr; + + // Depending on the logical op, inner select might be in different hand. + bool IsAndVariant = match(OuterSel.Cond, m_LogicalAnd()); + Value *InnerSelVal = IsAndVariant ? OuterSel.FalseVal : OuterSel.TrueVal; + + // Profitability check - avoid increasing instruction count. + if (none_of(ArrayRef<Value *>({OuterSelVal.getCondition(), InnerSelVal}), + [](Value *V) { return V->hasOneUse(); })) + return nullptr; + + // The appropriate hand of the outermost `select` must be a select itself. + DecomposedSelect InnerSel; + if (!match(InnerSelVal, + m_Select(m_Value(InnerSel.Cond), m_Value(InnerSel.TrueVal), + m_Value(InnerSel.FalseVal)))) + return nullptr; + + // Canonicalize inversion of the innermost `select`'s condition. + if (match(InnerSel.Cond, m_Not(m_Value(InnerSel.Cond)))) + std::swap(InnerSel.TrueVal, InnerSel.FalseVal); + + Value *AltCond = nullptr; + auto matchOuterCond = [OuterSel, &AltCond](auto m_InnerCond) { + return match(OuterSel.Cond, m_c_LogicalOp(m_InnerCond, m_Value(AltCond))); + }; + + // Finally, match the condition that was driving the outermost `select`, + // it should be a logical operation between the condition that was driving + // the innermost `select` (after accounting for the possible inversions + // of the condition), and some other condition. + if (matchOuterCond(m_Specific(InnerSel.Cond))) { + // Done! + } else if (Value * NotInnerCond; matchOuterCond(m_CombineAnd( + m_Not(m_Specific(InnerSel.Cond)), m_Value(NotInnerCond)))) { + // Done! + std::swap(InnerSel.TrueVal, InnerSel.FalseVal); + InnerSel.Cond = NotInnerCond; + } else // Not the pattern we were looking for. + return nullptr; + + Value *SelInner = Builder.CreateSelect( + AltCond, IsAndVariant ? OuterSel.TrueVal : InnerSel.FalseVal, + IsAndVariant ? 
InnerSel.TrueVal : OuterSel.FalseVal); + SelInner->takeName(InnerSelVal); + return SelectInst::Create(InnerSel.Cond, + IsAndVariant ? SelInner : InnerSel.TrueVal, + !IsAndVariant ? SelInner : InnerSel.FalseVal); +} + +Instruction *InstCombinerImpl::foldSelectOfBools(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); - if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, - SQ.getWithInstruction(&SI))) - return replaceInstUsesWith(SI, V); - - if (Instruction *I = canonicalizeSelectToShuffle(SI)) - return I; - - if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) - return I; - // Avoid potential infinite loops by checking for non-constant condition. // TODO: Can we assert instead by improving canonicalizeSelectToShuffle()? // Scalar select must have simplified? - if (SelType->isIntOrIntVectorTy(1) && !isa<Constant>(CondVal) && - TrueVal->getType() == CondVal->getType()) { - // Folding select to and/or i1 isn't poison safe in general. impliesPoison - // checks whether folding it does not convert a well-defined value into - // poison. - if (match(TrueVal, m_One())) { - if (impliesPoison(FalseVal, CondVal)) { - // Change: A = select B, true, C --> A = or B, C - return BinaryOperator::CreateOr(CondVal, FalseVal); - } + if (!SelType->isIntOrIntVectorTy(1) || isa<Constant>(CondVal) || + TrueVal->getType() != CondVal->getType()) + return nullptr; - if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) - if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) - if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, - /*IsSelectLogical*/ true)) - return replaceInstUsesWith(SI, V); - } - if (match(FalseVal, m_Zero())) { - if (impliesPoison(TrueVal, CondVal)) { - // Change: A = select B, C, false --> A = and B, C - return BinaryOperator::CreateAnd(CondVal, TrueVal); - } + auto *One = ConstantInt::getTrue(SelType); + auto *Zero = ConstantInt::getFalse(SelType); + Value *A, *B, *C, *D; - if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) - if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) - if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, - /*IsSelectLogical*/ true)) - return replaceInstUsesWith(SI, V); + // Folding select to and/or i1 isn't poison safe in general. impliesPoison + // checks whether folding it does not convert a well-defined value into + // poison. + if (match(TrueVal, m_One())) { + if (impliesPoison(FalseVal, CondVal)) { + // Change: A = select B, true, C --> A = or B, C + return BinaryOperator::CreateOr(CondVal, FalseVal); } - auto *One = ConstantInt::getTrue(SelType); - auto *Zero = ConstantInt::getFalse(SelType); + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(FalseVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); - // We match the "full" 0 or 1 constant here to avoid a potential infinite - // loop with vectors that may have undefined/poison elements. - // select a, false, b -> select !a, b, false - if (match(TrueVal, m_Specific(Zero))) { - Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); - return SelectInst::Create(NotCond, FalseVal, Zero); + // (A && B) || (C && B) --> (A || C) && B + if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && + match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) && + (CondVal->hasOneUse() || FalseVal->hasOneUse())) { + bool CondLogicAnd = isa<SelectInst>(CondVal); + bool FalseLogicAnd = isa<SelectInst>(FalseVal); + auto AndFactorization = [&](Value *Common, Value *InnerCond, + Value *InnerVal, + bool SelFirst = false) -> Instruction * { + Value *InnerSel = Builder.CreateSelect(InnerCond, One, InnerVal); + if (SelFirst) + std::swap(Common, InnerSel); + if (FalseLogicAnd || (CondLogicAnd && Common == A)) + return SelectInst::Create(Common, InnerSel, Zero); + else + return BinaryOperator::CreateAnd(Common, InnerSel); + }; + + if (A == C) + return AndFactorization(A, B, D); + if (A == D) + return AndFactorization(A, B, C); + if (B == C) + return AndFactorization(B, A, D); + if (B == D) + return AndFactorization(B, A, C, CondLogicAnd && FalseLogicAnd); } - // select a, b, true -> select !a, true, b - if (match(FalseVal, m_Specific(One))) { - Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); - return SelectInst::Create(NotCond, One, TrueVal); + } + + if (match(FalseVal, m_Zero())) { + if (impliesPoison(TrueVal, CondVal)) { + // Change: A = select B, C, false --> A = and B, C + return BinaryOperator::CreateAnd(CondVal, TrueVal); + } + + if (auto *LHS = dyn_cast<FCmpInst>(CondVal)) + if (auto *RHS = dyn_cast<FCmpInst>(TrueVal)) + if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true, + /*IsSelectLogical*/ true)) + return replaceInstUsesWith(SI, V); + + // (A || B) && (C || B) --> (A && C) || B + if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) && + (CondVal->hasOneUse() || TrueVal->hasOneUse())) { + bool CondLogicOr = isa<SelectInst>(CondVal); + bool TrueLogicOr = isa<SelectInst>(TrueVal); + auto OrFactorization = [&](Value *Common, Value *InnerCond, + Value *InnerVal, + bool SelFirst = false) -> Instruction * { + Value *InnerSel = Builder.CreateSelect(InnerCond, InnerVal, Zero); + if (SelFirst) + std::swap(Common, InnerSel); + if (TrueLogicOr || (CondLogicOr && Common == A)) + return SelectInst::Create(Common, One, InnerSel); + else + return BinaryOperator::CreateOr(Common, InnerSel); + }; + + if (A == C) + return OrFactorization(A, B, D); + if (A == D) + return OrFactorization(A, B, C); + if (B == C) + return OrFactorization(B, A, D); + if (B == D) + return OrFactorization(B, A, C, CondLogicOr && TrueLogicOr); } + } + + // We match the "full" 0 or 1 constant here to avoid a potential infinite + // loop with vectors that may have undefined/poison elements. + // select a, false, b -> select !a, b, false + if (match(TrueVal, m_Specific(Zero))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." + CondVal->getName()); + return SelectInst::Create(NotCond, FalseVal, Zero); + } + // select a, b, true -> select !a, true, b + if (match(FalseVal, m_Specific(One))) { + Value *NotCond = Builder.CreateNot(CondVal, "not." 
+ CondVal->getName()); + return SelectInst::Create(NotCond, One, TrueVal); + } + + // DeMorgan in select form: !a && !b --> !(a || b) + // select !a, !b, false --> not (select a, true, b) + if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || TrueVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); - // select a, a, b -> select a, true, b - if (CondVal == TrueVal) - return replaceOperand(SI, 1, One); - // select a, b, a -> select a, b, false - if (CondVal == FalseVal) - return replaceOperand(SI, 2, Zero); + // DeMorgan in select form: !a || !b --> !(a && b) + // select !a, true, !b --> not (select a, b, false) + if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && + (CondVal->hasOneUse() || FalseVal->hasOneUse()) && + !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) + return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); - // select a, !a, b -> select !a, b, false - if (match(TrueVal, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(TrueVal, FalseVal, Zero); - // select a, b, !a -> select !a, true, b - if (match(FalseVal, m_Not(m_Specific(CondVal)))) - return SelectInst::Create(FalseVal, One, TrueVal); + // select (select a, true, b), true, b -> select a, true, b + if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && + match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) + return replaceOperand(SI, 0, A); + // select (select a, b, false), b, false -> select a, b, false + if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && + match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) + return replaceOperand(SI, 0, A); - Value *A, *B; + // ~(A & B) & (A | B) --> A ^ B + if (match(&SI, m_c_LogicalAnd(m_Not(m_LogicalAnd(m_Value(A), m_Value(B))), + m_c_LogicalOr(m_Deferred(A), m_Deferred(B))))) + return BinaryOperator::CreateXor(A, B); - // DeMorgan in select form: !a && !b --> !(a || b) - // select !a, !b, false --> not (select a, true, b) - if (match(&SI, m_LogicalAnd(m_Not(m_Value(A)), m_Not(m_Value(B)))) && - (CondVal->hasOneUse() || TrueVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, One, B)); + // select (~a | c), a, b -> and a, (or c, freeze(b)) + if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && + CondVal->hasOneUse()) { + FalseVal = Builder.CreateFreeze(FalseVal); + return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); + } + // select (~c & b), a, b -> and b, (or freeze(a), c) + if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && + CondVal->hasOneUse()) { + TrueVal = Builder.CreateFreeze(TrueVal); + return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + } + + if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { + Use *Y = nullptr; + bool IsAnd = match(FalseVal, m_Zero()) ? true : false; + Value *Op1 = IsAnd ? 
TrueVal : FalseVal; + if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { + auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); + InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); + replaceUse(*Y, FI); + return replaceInstUsesWith(SI, Op1); + } + + if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) + if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, + /* IsAnd */ IsAnd)) + return I; - // DeMorgan in select form: !a || !b --> !(a && b) - // select !a, true, !b --> not (select a, b, false) - if (match(&SI, m_LogicalOr(m_Not(m_Value(A)), m_Not(m_Value(B)))) && - (CondVal->hasOneUse() || FalseVal->hasOneUse()) && - !match(A, m_ConstantExpr()) && !match(B, m_ConstantExpr())) - return BinaryOperator::CreateNot(Builder.CreateSelect(A, B, Zero)); + if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) + if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) + if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, + /* IsLogical */ true)) + return replaceInstUsesWith(SI, V); + } - // select (select a, true, b), true, b -> select a, true, b - if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(TrueVal, m_One()) && match(FalseVal, m_Specific(B))) + // select (a || b), c, false -> select a, c, false + // select c, (a || b), false -> select c, a, false + // if c implies that b is false. + if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(FalseVal, m_Zero())) { + std::optional<bool> Res = isImpliedCondition(TrueVal, B, DL); + if (Res && *Res == false) return replaceOperand(SI, 0, A); - // select (select a, b, false), b, false -> select a, b, false - if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && - match(TrueVal, m_Specific(B)) && match(FalseVal, m_Zero())) + } + if (match(TrueVal, m_LogicalOr(m_Value(A), m_Value(B))) && + match(FalseVal, m_Zero())) { + std::optional<bool> Res = isImpliedCondition(CondVal, B, DL); + if (Res && *Res == false) + return replaceOperand(SI, 1, A); + } + // select c, true, (a && b) -> select c, true, a + // select (a && b), true, c -> select a, true, c + // if c = false implies that b = true + if (match(TrueVal, m_One()) && + match(FalseVal, m_LogicalAnd(m_Value(A), m_Value(B)))) { + std::optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); + if (Res && *Res == true) + return replaceOperand(SI, 2, A); + } + if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) && + match(TrueVal, m_One())) { + std::optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); + if (Res && *Res == true) return replaceOperand(SI, 0, A); + } + if (match(TrueVal, m_One())) { Value *C; - // select (~a | c), a, b -> and a, (or c, freeze(b)) - if (match(CondVal, m_c_Or(m_Not(m_Specific(TrueVal)), m_Value(C))) && - CondVal->hasOneUse()) { - FalseVal = Builder.CreateFreeze(FalseVal); - return BinaryOperator::CreateAnd(TrueVal, Builder.CreateOr(C, FalseVal)); - } - // select (~c & b), a, b -> and b, (or freeze(a), c) - if (match(CondVal, m_c_And(m_Not(m_Value(C)), m_Specific(FalseVal))) && - CondVal->hasOneUse()) { - TrueVal = Builder.CreateFreeze(TrueVal); - return BinaryOperator::CreateAnd(FalseVal, Builder.CreateOr(C, TrueVal)); + + // (C && A) || (!C && B) --> sel C, A, B + // (A && C) || (!C && B) --> sel C, A, B + // (C && A) || (B && !C) --> sel C, A, B + // (A && C) || (B && !C) --> sel C, A, B (may require freeze) + if (match(FalseVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(B))) && + match(CondVal, m_c_LogicalAnd(m_Specific(C), m_Value(A)))) { + auto *SelCond = dyn_cast<SelectInst>(CondVal); + auto 
*SelFVal = dyn_cast<SelectInst>(FalseVal); + bool MayNeedFreeze = SelCond && SelFVal && + match(SelFVal->getTrueValue(), + m_Not(m_Specific(SelCond->getTrueValue()))); + if (MayNeedFreeze) + C = Builder.CreateFreeze(C); + return SelectInst::Create(C, A, B); } - if (!SelType->isVectorTy()) { - if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, One, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 1, S); - if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, Zero, SQ, - /* AllowRefinement */ true)) - return replaceOperand(SI, 2, S); + // (!C && A) || (C && B) --> sel C, B, A + // (A && !C) || (C && B) --> sel C, B, A + // (!C && A) || (B && C) --> sel C, B, A + // (A && !C) || (B && C) --> sel C, B, A (may require freeze) + if (match(CondVal, m_c_LogicalAnd(m_Not(m_Value(C)), m_Value(A))) && + match(FalseVal, m_c_LogicalAnd(m_Specific(C), m_Value(B)))) { + auto *SelCond = dyn_cast<SelectInst>(CondVal); + auto *SelFVal = dyn_cast<SelectInst>(FalseVal); + bool MayNeedFreeze = SelCond && SelFVal && + match(SelCond->getTrueValue(), + m_Not(m_Specific(SelFVal->getTrueValue()))); + if (MayNeedFreeze) + C = Builder.CreateFreeze(C); + return SelectInst::Create(C, B, A); } + } - if (match(FalseVal, m_Zero()) || match(TrueVal, m_One())) { - Use *Y = nullptr; - bool IsAnd = match(FalseVal, m_Zero()) ? true : false; - Value *Op1 = IsAnd ? TrueVal : FalseVal; - if (isCheckForZeroAndMulWithOverflow(CondVal, Op1, IsAnd, Y)) { - auto *FI = new FreezeInst(*Y, (*Y)->getName() + ".fr"); - InsertNewInstBefore(FI, *cast<Instruction>(Y->getUser())); - replaceUse(*Y, FI); - return replaceInstUsesWith(SI, Op1); - } + return nullptr; +} - if (auto *Op1SI = dyn_cast<SelectInst>(Op1)) - if (auto *I = foldAndOrOfSelectUsingImpliedCond(CondVal, *Op1SI, - /* IsAnd */ IsAnd)) - return I; +Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + Type *SelType = SI.getType(); - if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) - if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) - if (auto *V = foldAndOrOfICmps(ICmp0, ICmp1, SI, IsAnd, - /* IsLogical */ true)) - return replaceInstUsesWith(SI, V); - } + if (Value *V = simplifySelectInst(CondVal, TrueVal, FalseVal, + SQ.getWithInstruction(&SI))) + return replaceInstUsesWith(SI, V); - // select (select a, true, b), c, false -> select a, c, false - // select c, (select a, true, b), false -> select c, a, false - // if c implies that b is false. 
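Setting aside the freeze that the patch inserts when poison could be duplicated (MayNeedFreeze), the new (C && A) || (!C && B) --> select C, A, B rewrite in foldSelectOfBools is a plain Boolean identity. A small standalone check of both orientations (illustrative only, not code from this commit):

#include <cassert>

int main() {
  for (int Bits = 0; Bits != 8; ++Bits) {
    bool C = Bits & 1, A = Bits & 2, B = Bits & 4;
    // (C && A) || (!C && B)  ==  select C, A, B
    assert(((C && A) || (!C && B)) == (C ? A : B));
    // (!C && A) || (C && B)  ==  select C, B, A
    assert(((!C && A) || (C && B)) == (C ? B : A));
  }
  return 0;
}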
- if (match(CondVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(FalseVal, m_Zero())) { - Optional<bool> Res = isImpliedCondition(TrueVal, B, DL); - if (Res && *Res == false) - return replaceOperand(SI, 0, A); - } - if (match(TrueVal, m_Select(m_Value(A), m_One(), m_Value(B))) && - match(FalseVal, m_Zero())) { - Optional<bool> Res = isImpliedCondition(CondVal, B, DL); - if (Res && *Res == false) - return replaceOperand(SI, 1, A); - } - // select c, true, (select a, b, false) -> select c, true, a - // select (select a, b, false), true, c -> select a, true, c - // if c = false implies that b = true - if (match(TrueVal, m_One()) && - match(FalseVal, m_Select(m_Value(A), m_Value(B), m_Zero()))) { - Optional<bool> Res = isImpliedCondition(CondVal, B, DL, false); - if (Res && *Res == true) - return replaceOperand(SI, 2, A); - } - if (match(CondVal, m_Select(m_Value(A), m_Value(B), m_Zero())) && - match(TrueVal, m_One())) { - Optional<bool> Res = isImpliedCondition(FalseVal, B, DL, false); - if (Res && *Res == true) - return replaceOperand(SI, 0, A); - } + if (Instruction *I = canonicalizeSelectToShuffle(SI)) + return I; - // sel (sel c, a, false), true, (sel !c, b, false) -> sel c, a, b - // sel (sel !c, a, false), true, (sel c, b, false) -> sel c, b, a - Value *C1, *C2; - if (match(CondVal, m_Select(m_Value(C1), m_Value(A), m_Zero())) && - match(TrueVal, m_One()) && - match(FalseVal, m_Select(m_Value(C2), m_Value(B), m_Zero()))) { - if (match(C2, m_Not(m_Specific(C1)))) // first case - return SelectInst::Create(C1, A, B); - else if (match(C1, m_Not(m_Specific(C2)))) // second case - return SelectInst::Create(C2, B, A); - } + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, *this)) + return I; + + // If the type of select is not an integer type or if the condition and + // the selection type are not both scalar nor both vector types, there is no + // point in attempting to match these patterns. + Type *CondType = CondVal->getType(); + if (!isa<Constant>(CondVal) && SelType->isIntOrIntVectorTy() && + CondType->isVectorTy() == SelType->isVectorTy()) { + if (Value *S = simplifyWithOpReplaced(TrueVal, CondVal, + ConstantInt::getTrue(CondType), SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 1, S); + + if (Value *S = simplifyWithOpReplaced(FalseVal, CondVal, + ConstantInt::getFalse(CondType), SQ, + /* AllowRefinement */ true)) + return replaceOperand(SI, 2, S); + + // Handle patterns involving sext/zext + not explicitly, + // as simplifyWithOpReplaced() only looks past one instruction. + Value *NotCond; + + // select a, sext(!a), b -> select !a, b, 0 + // select a, zext(!a), b -> select !a, b, 0 + if (match(TrueVal, m_ZExtOrSExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, FalseVal, + Constant::getNullValue(SelType)); + + // select a, b, zext(!a) -> select !a, 1, b + if (match(FalseVal, m_ZExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, ConstantInt::get(SelType, 1), TrueVal); + + // select a, b, sext(!a) -> select !a, -1, b + if (match(FalseVal, m_SExt(m_CombineAnd(m_Value(NotCond), + m_Not(m_Specific(CondVal)))))) + return SelectInst::Create(NotCond, Constant::getAllOnesValue(SelType), + TrueVal); } + if (Instruction *R = foldSelectOfBools(SI)) + return R; + // Selecting between two integer or vector splat integer constants? 
// // Note that we don't handle a scalar select of vectors: @@ -2881,8 +3187,23 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); return replaceInstUsesWith(SI, NewSel); } + } + } + + if (isa<FPMathOperator>(SI)) { + // TODO: Try to forward-propagate FMF from select arms to the select. + + // Canonicalize select of FP values where NaN and -0.0 are not valid as + // minnum/maxnum intrinsics. + if (SI.hasNoNaNs() && SI.hasNoSignedZeros()) { + Value *X, *Y; + if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); - // NOTE: if we wanted to, this is where to detect MIN/MAX + if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) + return replaceInstUsesWith( + SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); } } @@ -2997,19 +3318,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } - // Canonicalize select of FP values where NaN and -0.0 are not valid as - // minnum/maxnum intrinsics. - if (isa<FPMathOperator>(SI) && SI.hasNoNaNs() && SI.hasNoSignedZeros()) { - Value *X, *Y; - if (match(&SI, m_OrdFMax(m_Value(X), m_Value(Y)))) - return replaceInstUsesWith( - SI, Builder.CreateBinaryIntrinsic(Intrinsic::maxnum, X, Y, &SI)); - - if (match(&SI, m_OrdFMin(m_Value(X), m_Value(Y)))) - return replaceInstUsesWith( - SI, Builder.CreateBinaryIntrinsic(Intrinsic::minnum, X, Y, &SI)); - } - // See if we can fold the select into a phi node if the condition is a select. if (auto *PN = dyn_cast<PHINode>(SI.getCondition())) // The true/false values have to be live in the PHI predecessor's blocks. @@ -3198,5 +3506,15 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { } } + if (Instruction *I = foldNestedSelects(SI, Builder)) + return I; + + // Match logical variants of the pattern, + // and transform them iff that gets rid of inversions. 
+ // (~x) | y --> ~(x & (~y)) + // (~x) & y --> ~(x | (~y)) + if (sinkNotIntoOtherHandOfLogicalOp(SI)) + return &SI; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 13c98b935adf..ec505381cc86 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -346,8 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, Value *X, *Y; auto matchFirstShift = [&](Value *V) { APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && - match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && + return match(V, + m_OneUse(m_BinOp(ShiftOpcode, m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); }; @@ -363,7 +363,7 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) Constant *ShiftSumC = ConstantExpr::getAdd(C0, C1); Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); - Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, C1); return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); } @@ -730,13 +730,34 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return BinaryOperator::Create( I.getOpcode(), Builder.CreateBinOp(I.getOpcode(), C2, C1), X); + bool IsLeftShift = I.getOpcode() == Instruction::Shl; + Type *Ty = I.getType(); + unsigned TypeBits = Ty->getScalarSizeInBits(); + + // (X / +DivC) >> (Width - 1) --> ext (X <= -DivC) + // (X / -DivC) >> (Width - 1) --> ext (X >= +DivC) + const APInt *DivC; + if (!IsLeftShift && match(C1, m_SpecificIntAllowUndef(TypeBits - 1)) && + match(Op0, m_SDiv(m_Value(X), m_APInt(DivC))) && !DivC->isZero() && + !DivC->isMinSignedValue()) { + Constant *NegDivC = ConstantInt::get(Ty, -(*DivC)); + ICmpInst::Predicate Pred = + DivC->isNegative() ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_SLE; + Value *Cmp = Builder.CreateICmp(Pred, X, NegDivC); + auto ExtOpcode = (I.getOpcode() == Instruction::AShr) ? Instruction::SExt + : Instruction::ZExt; + return CastInst::Create(ExtOpcode, Cmp, Ty); + } + const APInt *Op1C; if (!match(C1, m_APInt(Op1C))) return nullptr; + assert(!Op1C->uge(TypeBits) && + "Shift over the type width should have been removed already"); + // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. - bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( @@ -748,14 +769,6 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } - // See if we can simplify any instructions used by the instruction whose sole - // purpose is to compute bits we don't care about. 
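The new FoldShiftByConstant case above rewrites a sign-bit extraction of a signed division, (X / +DivC) >> (Width - 1), as a compare that is then zero- or sign-extended depending on whether the shift was lshr or ashr. A standalone sketch of the underlying scalar identity for the lshr/zext form, with i32 modeled as int32_t/uint32_t and DivC = 4 chosen only for illustration (not code from this commit):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t V = -1000; V <= 1000; ++V) {
    int32_t X = static_cast<int32_t>(V);
    // lshr (sdiv X, 4), 31   ==  zext (icmp sle X, -4)
    assert((static_cast<uint32_t>(X / 4) >> 31) ==
           static_cast<uint32_t>(X <= -4));
    // lshr (sdiv X, -4), 31  ==  zext (icmp sge X, 4)
    assert((static_cast<uint32_t>(X / -4) >> 31) ==
           static_cast<uint32_t>(X >= 4));
  }
  return 0;
}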
- Type *Ty = I.getType(); - unsigned TypeBits = Ty->getScalarSizeInBits(); - assert(!Op1C->uge(TypeBits) && - "Shift over the type width should have been removed already"); - (void)TypeBits; - if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; @@ -826,6 +839,74 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *C1, return nullptr; } +// Tries to perform +// (lshr (add (zext X), (zext Y)), K) +// -> (icmp ult (add X, Y), X) +// where +// - The add's operands are zexts from a K-bits integer to a bigger type. +// - The add is only used by the shr, or by iK (or narrower) truncates. +// - The lshr type has more than 2 bits (other types are boolean math). +// - K > 1 +// note that +// - The resulting add cannot have nuw/nsw, else on overflow we get a +// poison value and the transform isn't legal anymore. +Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) { + assert(I.getOpcode() == Instruction::LShr); + + Value *Add = I.getOperand(0); + Value *ShiftAmt = I.getOperand(1); + Type *Ty = I.getType(); + + if (Ty->getScalarSizeInBits() < 3) + return nullptr; + + const APInt *ShAmtAPInt = nullptr; + Value *X = nullptr, *Y = nullptr; + if (!match(ShiftAmt, m_APInt(ShAmtAPInt)) || + !match(Add, + m_Add(m_OneUse(m_ZExt(m_Value(X))), m_OneUse(m_ZExt(m_Value(Y)))))) + return nullptr; + + const unsigned ShAmt = ShAmtAPInt->getZExtValue(); + if (ShAmt == 1) + return nullptr; + + // X/Y are zexts from `ShAmt`-sized ints. + if (X->getType()->getScalarSizeInBits() != ShAmt || + Y->getType()->getScalarSizeInBits() != ShAmt) + return nullptr; + + // Make sure that `Add` is only used by `I` and `ShAmt`-truncates. + if (!Add->hasOneUse()) { + for (User *U : Add->users()) { + if (U == &I) + continue; + + TruncInst *Trunc = dyn_cast<TruncInst>(U); + if (!Trunc || Trunc->getType()->getScalarSizeInBits() > ShAmt) + return nullptr; + } + } + + // Insert at Add so that the newly created `NarrowAdd` will dominate it's + // users (i.e. `Add`'s users). + Instruction *AddInst = cast<Instruction>(Add); + Builder.SetInsertPoint(AddInst); + + Value *NarrowAdd = Builder.CreateAdd(X, Y, "add.narrowed"); + Value *Overflow = + Builder.CreateICmpULT(NarrowAdd, X, "add.narrowed.overflow"); + + // Replace the uses of the original add with a zext of the + // NarrowAdd's result. Note that all users at this stage are known to + // be ShAmt-sized truncs, or the lshr itself. + if (!Add->hasOneUse()) + replaceInstUsesWith(*AddInst, Builder.CreateZExt(NarrowAdd, Ty)); + + // Replace the LShr with a zext of the overflow check. + return new ZExtInst(Overflow, Ty); +} + Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -1046,11 +1127,21 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 - if (match(Op0, m_One()) && - match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) - return BinaryOperator::CreateLShr( - ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + if (match(Op0, m_One())) { + // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1 + if (match(Op1, m_Sub(m_SpecificInt(BitWidth - 1), m_Value(X)))) + return BinaryOperator::CreateLShr( + ConstantInt::get(Ty, APInt::getSignMask(BitWidth)), X); + + // The only way to shift out the 1 is with an over-shift, so that would + // be poison with or without "nuw". Undef is excluded because (undef << X) + // is not undef (it is zero). 
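foldLShrOverflowBit above turns the carry-out of a widened add back into an unsigned-overflow compare on the narrow type. A standalone sketch of the identity it relies on, with uint8_t/uint32_t as stand-ins for the narrow operands and the wider add; the one-use and trunc-only-user restrictions in the code are not modeled here (illustrative only, not code from this commit):

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned A = 0; A != 256; ++A) {
    for (unsigned B = 0; B != 256; ++B) {
      uint8_t X = A, Y = B;
      // lshr (add (zext X), (zext Y)), 8  ==  zext (icmp ult (add X, Y), X)
      uint32_t WideCarry =
          (static_cast<uint32_t>(X) + static_cast<uint32_t>(Y)) >> 8;
      uint32_t NarrowOverflow = static_cast<uint8_t>(X + Y) < X;
      assert(WideCarry == NarrowOverflow);
    }
  }
  return 0;
}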
+ Constant *ConstantOne = cast<Constant>(Op0); + if (!I.hasNoUnsignedWrap() && !ConstantOne->containsUndefElement()) { + I.setHasNoUnsignedWrap(); + return &I; + } + } return nullptr; } @@ -1068,10 +1159,17 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); + Value *X; const APInt *C; + unsigned BitWidth = Ty->getScalarSizeInBits(); + + // (iN (~X) u>> (N - 1)) --> zext (X > -1) + if (match(Op0, m_OneUse(m_Not(m_Value(X)))) && + match(Op1, m_SpecificIntAllowUndef(BitWidth - 1))) + return new ZExtInst(Builder.CreateIsNotNeg(X, "isnotneg"), Ty); + if (match(Op1, m_APInt(C))) { unsigned ShAmtC = C->getZExtValue(); - unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || @@ -1276,6 +1374,18 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } } + // Reduce add-carry of bools to logic: + // ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY) + Value *BoolX, *BoolY; + if (ShAmtC == 1 && match(Op0, m_Add(m_Value(X), m_Value(Y))) && + match(X, m_ZExt(m_Value(BoolX))) && match(Y, m_ZExt(m_Value(BoolY))) && + BoolX->getType()->isIntOrIntVectorTy(1) && + BoolY->getType()->isIntOrIntVectorTy(1) && + (X->hasOneUse() || Y->hasOneUse() || Op0->hasOneUse())) { + Value *And = Builder.CreateAnd(BoolX, BoolY); + return new ZExtInst(And, Ty); + } + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { @@ -1285,13 +1395,15 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } // Transform (x << y) >> y to x & (-1 >> y) - Value *X; if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_Specific(Op1))))) { Constant *AllOnes = ConstantInt::getAllOnesValue(Ty); Value *Mask = Builder.CreateLShr(AllOnes, Op1); return BinaryOperator::CreateAnd(Mask, X); } + if (Instruction *Overflow = foldLShrOverflowBit(I)) + return Overflow; + return nullptr; } @@ -1469,8 +1581,11 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return R; // See if we can turn a signed shr into an unsigned shr. 
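Two of the lshr folds added to visitLShr above reduce to simple scalar identities: (iN (~X) u>> (N - 1)) --> zext (X > -1) and ((zext BoolX) + (zext BoolY)) >> 1 --> zext (BoolX && BoolY). A standalone check with i32 modeled as int32_t/uint32_t (illustrative only, not code from this commit):

#include <cassert>
#include <cstdint>

int main() {
  for (int64_t V = -1000; V <= 1000; ++V) {
    int32_t X = static_cast<int32_t>(V);
    // lshr (xor X, -1), 31  ==  zext (icmp sgt X, -1)
    assert((static_cast<uint32_t>(~X) >> 31) ==
           static_cast<uint32_t>(X > -1));
  }
  for (int Bits = 0; Bits != 4; ++Bits) {
    bool BX = Bits & 1, BY = Bits & 2;
    // lshr (add (zext BX), (zext BY)), 1  ==  zext (and BX, BY)
    assert(((static_cast<uint32_t>(BX) + static_cast<uint32_t>(BY)) >> 1) ==
           static_cast<uint32_t>(BX && BY));
  }
  return 0;
}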
- if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) - return BinaryOperator::CreateLShr(Op0, Op1); + if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I)) { + Instruction *Lshr = BinaryOperator::CreateLShr(Op0, Op1); + Lshr->setIsExact(I.isExact()); + return Lshr; + } // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index febd0f51d25f..77d675422966 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -130,9 +130,6 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == MaxAnalysisRecursionDepth) return nullptr; - if (isa<ScalableVectorType>(VTy)) - return nullptr; - Instruction *I = dyn_cast<Instruction>(V); if (!I) { computeKnownBits(V, Known, Depth, CxtI); @@ -154,6 +151,20 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (Depth == 0 && !V->hasOneUse()) DemandedMask.setAllBits(); + // Update flags after simplifying an operand based on the fact that some high + // order bits are not demanded. + auto disableWrapFlagsBasedOnUnusedHighBits = [](Instruction *I, + unsigned NLZ) { + if (NLZ > 0) { + // Disable the nsw and nuw flags here: We can no longer guarantee that + // we won't wrap after simplification. Removing the nsw/nuw flags is + // legal here because the top bit is not demanded. + I->setHasNoSignedWrap(false); + I->setHasNoUnsignedWrap(false); + } + return I; + }; + // If the high-bits of an ADD/SUB/MUL are not demanded, then we do not care // about the high bits of the operands. auto simplifyOperandsBasedOnUnusedHighBits = [&](APInt &DemandedFromOps) { @@ -165,13 +176,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || ShrinkDemandedConstant(I, 1, DemandedFromOps) || SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { - if (NLZ > 0) { - // Disable the nsw and nuw flags here: We can no longer guarantee that - // we won't wrap after simplification. Removing the nsw/nuw flags is - // legal here because the top bit is not demanded. - I->setHasNoSignedWrap(false); - I->setHasNoUnsignedWrap(false); - } + disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); return true; } return false; @@ -397,7 +402,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } } } - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); @@ -416,7 +421,9 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (auto *DstVTy = dyn_cast<VectorType>(VTy)) { if (auto *SrcVTy = dyn_cast<VectorType>(I->getOperand(0)->getType())) { - if (cast<FixedVectorType>(DstVTy)->getNumElements() != + if (isa<ScalableVectorType>(DstVTy) || + isa<ScalableVectorType>(SrcVTy) || + cast<FixedVectorType>(DstVTy)->getNumElements() != cast<FixedVectorType>(SrcVTy)->getNumElements()) // Don't touch a bitcast between vectors of different element counts. 
return nullptr; @@ -461,7 +468,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!Known.hasConflict() && "Bits known to be one AND zero?"); break; } - case Instruction::Add: + case Instruction::Add: { if ((DemandedMask & 1) == 0) { // If we do not need the low bit, try to convert bool math to logic: // add iN (zext i1 X), (sext i1 Y) --> sext (~X & Y) to iN @@ -498,26 +505,68 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return Builder.CreateSExt(Or, VTy); } } - LLVM_FALLTHROUGH; + + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If low order bits are not demanded and known to be zero in one operand, + // then we don't need to demand them from the other operand, since they + // can't cause overflow into any bits that are demanded in the result. + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + APInt DemandedFromLHS = DemandedFromOps; + DemandedFromLHS.clearLowBits(NTZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If we are known to be adding zeros to every bit below + // the highest demanded bit, we just return the other side. + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Otherwise just compute the known bits of the result. + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(true, NSW, LHSKnown, RHSKnown); + break; + } case Instruction::Sub: { - APInt DemandedFromOps; - if (simplifyOperandsBasedOnUnusedHighBits(DemandedFromOps)) - return I; + // Right fill the mask of bits for the operands to demand the most + // significant bit and all those below it. + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + if (ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); + + // If low order bits are not demanded and are known to be zero in RHS, + // then we don't need to demand them from LHS, since they can't cause a + // borrow from any bits that are demanded in the result. + unsigned NTZ = (~DemandedMask & RHSKnown.Zero).countTrailingOnes(); + APInt DemandedFromLHS = DemandedFromOps; + DemandedFromLHS.clearLowBits(NTZ); + if (ShrinkDemandedConstant(I, 0, DemandedFromLHS) || + SimplifyDemandedBits(I, 0, DemandedFromLHS, LHSKnown, Depth + 1)) + return disableWrapFlagsBasedOnUnusedHighBits(I, NLZ); - // If we are known to be adding/subtracting zeros to every bit below + // If we are known to be subtracting zeros from every bit below // the highest demanded bit, we just return the other side. if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); // We can't do this with the LHS for subtraction, unless we are only // demanding the LSB. 
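The reworked Add/Sub handling in SimplifyDemandedUseBits above rests on two facts: when only low bits of a sum or difference are demanded, the high bits of the operands are irrelevant, and an operand known to be zero in every bit at or below the highest demanded bit cannot carry or borrow into the result. A standalone sketch with the demanded mask fixed at the low 8 bits of an i32; the types and constants are chosen only for illustration (not code from this commit):

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Xs[] = {0u, 1u, 0x1234u, 0xFFFFFFFFu, 0x80000001u};
  const uint32_t Ys[] = {0u, 7u, 0xABCDu, 0xFFFFFF00u, 0x7FFFFFFFu};
  for (uint32_t X : Xs) {
    for (uint32_t Y : Ys) {
      // Only the low 8 bits are demanded: the operands' high bits are dead.
      assert(static_cast<uint8_t>(X + Y) ==
             static_cast<uint8_t>(static_cast<uint8_t>(X) +
                                  static_cast<uint8_t>(Y)));
      assert(static_cast<uint8_t>(X - Y) ==
             static_cast<uint8_t>(static_cast<uint8_t>(X) -
                                  static_cast<uint8_t>(Y)));
      // An operand with no bits set at or below the highest demanded bit
      // cannot carry or borrow into the demanded bits, so it is droppable.
      assert(static_cast<uint8_t>(X + (Y << 8)) == static_cast<uint8_t>(X));
      assert(static_cast<uint8_t>(X - (Y << 8)) == static_cast<uint8_t>(X));
    }
  }
  return 0;
}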
- if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && - DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + if (DemandedFromOps.isOne() && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); // Otherwise just compute the known bits of the result. bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); - Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add, - NSW, LHSKnown, RHSKnown); + Known = KnownBits::computeForAddSub(false, NSW, LHSKnown, RHSKnown); break; } case Instruction::Mul: { @@ -747,18 +796,18 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // UDiv doesn't demand low bits that are zero in the divisor. const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { - // If the shift is exact, then it does demand the low bits. - if (cast<UDivOperator>(I)->isExact()) - break; - - // FIXME: Take the demanded mask of the result into account. + // TODO: Take the demanded mask of the result into account. unsigned RHSTrailingZeros = SA->countTrailingZeros(); APInt DemandedMaskIn = APInt::getHighBitsSet(BitWidth, BitWidth - RHSTrailingZeros); - if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) + if (SimplifyDemandedBits(I, 0, DemandedMaskIn, LHSKnown, Depth + 1)) { + // We can't guarantee that "exact" is still true after changing the + // the dividend. + I->dropPoisonGeneratingFlags(); return I; + } - // Propagate zero bits from the input. + // Increase high zero bits from the input. Known.Zero.setHighBits(std::min( BitWidth, LHSKnown.Zero.countLeadingOnes() + RHSTrailingZeros)); } else { @@ -922,10 +971,10 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } default: { // Handle target specific intrinsics - Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( + std::optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( *II, DemandedMask, Known, KnownBitsComputed); if (V) - return V.value(); + return *V; break; } } @@ -962,11 +1011,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( // this instruction has a simpler value in that context. switch (I->getOpcode()) { case Instruction::And: { - // If either the LHS or the RHS are Zero, the result is zero. computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown & RHSKnown; // If the client is only demanding bits that we know, return the known @@ -975,8 +1021,7 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( return Constant::getIntegerValue(ITy, Known.One); // If all of the demanded bits are known 1 on one side, return the other. - // These bits cannot contribute to the result of the 'and' in this - // context. + // These bits cannot contribute to the result of the 'and' in this context. if (DemandedMask.isSubsetOf(LHSKnown.Zero | RHSKnown.One)) return I->getOperand(0); if (DemandedMask.isSubsetOf(RHSKnown.Zero | LHSKnown.One)) @@ -985,14 +1030,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Or: { - // We can simplify (X|Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - - // If either the LHS or the RHS are One, the result is One. 
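The udiv handling above keeps demanding only the dividend bits that can affect the quotient when the divisor has trailing zero bits, while now dropping poison-generating flags such as exact if the dividend gets simplified. A standalone sketch of the underlying identity for a divisor of 8, chosen only for illustration (not code from this commit):

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X = 0; X != 100000; ++X) {
    // The divisor 8 has three trailing zero bits, so the low three bits of
    // the dividend cannot change the quotient.
    assert(X / 8u == (X & ~7u) / 8u);
  }
  return 0;
}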
computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown | RHSKnown; // If the client is only demanding bits that we know, return the known @@ -1000,9 +1039,10 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(ITy, Known.One); - // If all of the demanded bits are known zero on one side, return the - // other. These bits cannot contribute to the result of the 'or' in this - // context. + // We can simplify (X|Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + // If all of the demanded bits are known zero on one side, return the other. + // These bits cannot contribute to the result of the 'or' in this context. if (DemandedMask.isSubsetOf(LHSKnown.One | RHSKnown.Zero)) return I->getOperand(0); if (DemandedMask.isSubsetOf(RHSKnown.One | LHSKnown.Zero)) @@ -1011,13 +1051,8 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } case Instruction::Xor: { - // We can simplify (X^Y) -> X or Y in the user's context if we know that - // only bits from X or Y are demanded. - computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); - computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, - CxtI); - + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); Known = LHSKnown ^ RHSKnown; // If the client is only demanding bits that we know, return the known @@ -1025,8 +1060,9 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( if (DemandedMask.isSubsetOf(Known.Zero | Known.One)) return Constant::getIntegerValue(ITy, Known.One); - // If all of the demanded bits are known zero on one side, return the - // other. + // We can simplify (X^Y) -> X or Y in the user's context if we know that + // only bits from X or Y are demanded. + // If all of the demanded bits are known zero on one side, return the other. if (DemandedMask.isSubsetOf(RHSKnown.Zero)) return I->getOperand(0); if (DemandedMask.isSubsetOf(LHSKnown.Zero)) @@ -1034,6 +1070,34 @@ Value *InstCombinerImpl::SimplifyMultipleUseDemandedBits( break; } + case Instruction::Add: { + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + + // If an operand adds zeros to every bit below the highest demanded bit, + // that operand doesn't change the result. Return the other side. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + + computeKnownBits(I->getOperand(0), LHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + break; + } + case Instruction::Sub: { + unsigned NLZ = DemandedMask.countLeadingZeros(); + APInt DemandedFromOps = APInt::getLowBitsSet(BitWidth, BitWidth - NLZ); + + // If an operand subtracts zeros from every bit below the highest demanded + // bit, that operand doesn't change the result. Return the other side. + computeKnownBits(I->getOperand(1), RHSKnown, Depth + 1, CxtI); + if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + + break; + } case Instruction::AShr: { // Compute the Known bits to simplify things downstream. 
computeKnownBits(I, Known, Depth, CxtI); @@ -1632,11 +1696,11 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, } default: { // Handle target specific intrinsics - Optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( + std::optional<Value *> V = targetSimplifyDemandedVectorEltsIntrinsic( *II, DemandedElts, UndefElts, UndefElts2, UndefElts3, simplifyAndSetOp); if (V) - return V.value(); + return *V; break; } } // switch on IntrinsicID diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index b80c58183dd5..61e62adbe327 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -105,7 +105,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // 2) Possibly more ExtractElements with the same index. // 3) Another operand, which will feed back into the PHI. Instruction *PHIUser = nullptr; - for (auto U : PN->users()) { + for (auto *U : PN->users()) { if (ExtractElementInst *EU = dyn_cast<ExtractElementInst>(U)) { if (EI.getIndexOperand() == EU->getIndexOperand()) Extracts.push_back(EU); @@ -171,7 +171,7 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, } } - for (auto E : Extracts) + for (auto *E : Extracts) replaceInstUsesWith(*E, scalarPHI); return &EI; @@ -187,13 +187,12 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { ElementCount NumElts = cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); Type *DestTy = Ext.getType(); + unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); bool IsBigEndian = DL.isBigEndian(); // If we are casting an integer to vector and extracting a portion, that is // a shift-right and truncate. - // TODO: Allow FP dest type by casting the trunc to FP? - if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() && - isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) { + if (X->getType()->isIntegerTy()) { assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) && "Expected fixed vector type for bitcast from scalar integer"); @@ -202,10 +201,18 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8 if (IsBigEndian) ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC; - unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits(); - if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) { - Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); - return new TruncInst(Lshr, DestTy); + unsigned ShiftAmountC = ExtIndexC * DestWidth; + if (!ShiftAmountC || + (isDesirableIntType(X->getType()->getPrimitiveSizeInBits()) && + Ext.getVectorOperand()->hasOneUse())) { + if (ShiftAmountC) + X = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); + if (DestTy->isFloatingPointTy()) { + Type *DstIntTy = IntegerType::getIntNTy(X->getContext(), DestWidth); + Value *Trunc = Builder.CreateTrunc(X, DstIntTy); + return new BitCastInst(Trunc, DestTy); + } + return new TruncInst(X, DestTy); } } @@ -278,7 +285,6 @@ Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { return nullptr; unsigned SrcWidth = SrcTy->getScalarSizeInBits(); - unsigned DestWidth = DestTy->getPrimitiveSizeInBits(); unsigned ShAmt = Chunk * DestWidth; // TODO: This limitation is more strict than necessary. 
We could sum the @@ -393,6 +399,20 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { SQ.getWithInstruction(&EI))) return replaceInstUsesWith(EI, V); + // extractelt (select %x, %vec1, %vec2), %const -> + // select %x, %vec1[%const], %vec2[%const] + // TODO: Support constant folding of multiple select operands: + // extractelt (select %x, %vec1, %vec2), (select %x, %c1, %c2) + // If the extractelement will for instance try to do out of bounds accesses + // because of the values of %c1 and/or %c2, the sequence could be optimized + // early. This is currently not possible because constant folding will reach + // an unreachable assertion if it doesn't find a constant operand. + if (SelectInst *SI = dyn_cast<SelectInst>(EI.getVectorOperand())) + if (SI->getCondition()->getType()->isIntegerTy() && + isa<Constant>(EI.getIndexOperand())) + if (Instruction *R = FoldOpIntoSelect(EI, SI)) + return R; + // If extracting a specified index from the vector, see if we can recursively // find a previously computed scalar that was inserted into the vector. auto *IndexC = dyn_cast<ConstantInt>(Index); @@ -850,17 +870,16 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( if (NumAggElts > 2) return nullptr; - static constexpr auto NotFound = None; + static constexpr auto NotFound = std::nullopt; static constexpr auto FoundMismatch = nullptr; // Try to find a value of each element of an aggregate. // FIXME: deal with more complex, not one-dimensional, aggregate types - SmallVector<Optional<Instruction *>, 2> AggElts(NumAggElts, NotFound); + SmallVector<std::optional<Instruction *>, 2> AggElts(NumAggElts, NotFound); // Do we know values for each element of the aggregate? auto KnowAllElts = [&AggElts]() { - return all_of(AggElts, - [](Optional<Instruction *> Elt) { return Elt != NotFound; }); + return !llvm::is_contained(AggElts, NotFound); }; int Depth = 0; @@ -889,7 +908,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // Now, we may have already previously recorded the value for this element // of an aggregate. If we did, that means the CurrIVI will later be // overwritten with the already-recorded value. But if not, let's record it! - Optional<Instruction *> &Elt = AggElts[Indices.front()]; + std::optional<Instruction *> &Elt = AggElts[Indices.front()]; Elt = Elt.value_or(InsertedValue); // FIXME: should we handle chain-terminating undef base operand? @@ -919,7 +938,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( /// or different elements had different source aggregates. FoundMismatch }; - auto Describe = [](Optional<Value *> SourceAggregate) { + auto Describe = [](std::optional<Value *> SourceAggregate) { if (SourceAggregate == NotFound) return AggregateDescription::NotFound; if (*SourceAggregate == FoundMismatch) @@ -933,8 +952,8 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // If found, return the source aggregate from which the extraction was. // If \p PredBB is provided, does PHI translation of an \p Elt first. auto FindSourceAggregate = - [&](Instruction *Elt, unsigned EltIdx, Optional<BasicBlock *> UseBB, - Optional<BasicBlock *> PredBB) -> Optional<Value *> { + [&](Instruction *Elt, unsigned EltIdx, std::optional<BasicBlock *> UseBB, + std::optional<BasicBlock *> PredBB) -> std::optional<Value *> { // For now(?), only deal with, at most, a single level of PHI indirection. 
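The relaxed foldBitcastExtElt above now also handles a floating-point destination type by shifting, truncating, and then bitcasting the scalar source. A standalone sketch of the byte-level equivalence on a little-endian host, extracting element 1 of a <2 x float> view of an i64; the endianness assumption and the concrete values are illustrative only (not code from this commit):

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  // Build the i64 that a <2 x float> {1.5f, -2.25f} occupies in memory.
  float Elts[2] = {1.5f, -2.25f};
  uint64_t X;
  std::memcpy(&X, Elts, sizeof(X));

  // extractelement (bitcast i64 X to <2 x float>), 1
  //   --> bitcast (trunc (lshr X, 32) to i32) to float   (little endian)
  uint32_t Hi = static_cast<uint32_t>(X >> 32);
  float Extracted;
  std::memcpy(&Extracted, &Hi, sizeof(Extracted));
  assert(Extracted == Elts[1]);
  return 0;
}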
if (UseBB && PredBB) Elt = dyn_cast<Instruction>(Elt->DoPHITranslation(*UseBB, *PredBB)); @@ -961,9 +980,9 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // see if we can find appropriate source aggregate for each of the elements, // and see it's the same aggregate for each element. If so, return it. auto FindCommonSourceAggregate = - [&](Optional<BasicBlock *> UseBB, - Optional<BasicBlock *> PredBB) -> Optional<Value *> { - Optional<Value *> SourceAggregate; + [&](std::optional<BasicBlock *> UseBB, + std::optional<BasicBlock *> PredBB) -> std::optional<Value *> { + std::optional<Value *> SourceAggregate; for (auto I : enumerate(AggElts)) { assert(Describe(SourceAggregate) != AggregateDescription::FoundMismatch && @@ -975,7 +994,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // For this element, is there a plausible source aggregate? // FIXME: we could special-case undef element, IFF we know that in the // source aggregate said element isn't poison. - Optional<Value *> SourceAggregateForElement = + std::optional<Value *> SourceAggregateForElement = FindSourceAggregate(*I.value(), I.index(), UseBB, PredBB); // Okay, what have we found? Does that correlate with previous findings? @@ -1009,10 +1028,11 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( return *SourceAggregate; }; - Optional<Value *> SourceAggregate; + std::optional<Value *> SourceAggregate; // Can we find the source aggregate without looking at predecessors? - SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/None, /*PredBB=*/None); + SourceAggregate = FindCommonSourceAggregate(/*UseBB=*/std::nullopt, + /*PredBB=*/std::nullopt); if (Describe(SourceAggregate) != AggregateDescription::NotFound) { if (Describe(SourceAggregate) == AggregateDescription::FoundMismatch) return nullptr; // Conflicting source aggregates! @@ -1029,7 +1049,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( // they all should be defined in the same basic block. BasicBlock *UseBB = nullptr; - for (const Optional<Instruction *> &I : AggElts) { + for (const std::optional<Instruction *> &I : AggElts) { BasicBlock *BB = (*I)->getParent(); // If it's the first instruction we've encountered, record the basic block. if (!UseBB) { @@ -1495,6 +1515,71 @@ static Instruction *narrowInsElt(InsertElementInst &InsElt, return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType()); } +/// If we are inserting 2 halves of a value into adjacent elements of a vector, +/// try to convert to a single insert with appropriate bitcasts. +static Instruction *foldTruncInsEltPair(InsertElementInst &InsElt, + bool IsBigEndian, + InstCombiner::BuilderTy &Builder) { + Value *VecOp = InsElt.getOperand(0); + Value *ScalarOp = InsElt.getOperand(1); + Value *IndexOp = InsElt.getOperand(2); + + // Pattern depends on endian because we expect lower index is inserted first. + // Big endian: + // inselt (inselt BaseVec, (trunc (lshr X, BW/2), Index0), (trunc X), Index1 + // Little endian: + // inselt (inselt BaseVec, (trunc X), Index0), (trunc (lshr X, BW/2)), Index1 + // Note: It is not safe to do this transform with an arbitrary base vector + // because the bitcast of that vector to fewer/larger elements could + // allow poison to spill into an element that was not poison before. + // TODO: Detect smaller fractions of the scalar. + // TODO: One-use checks are conservative. 
+ auto *VTy = dyn_cast<FixedVectorType>(InsElt.getType()); + Value *Scalar0, *BaseVec; + uint64_t Index0, Index1; + if (!VTy || (VTy->getNumElements() & 1) || + !match(IndexOp, m_ConstantInt(Index1)) || + !match(VecOp, m_InsertElt(m_Value(BaseVec), m_Value(Scalar0), + m_ConstantInt(Index0))) || + !match(BaseVec, m_Undef())) + return nullptr; + + // The first insert must be to the index one less than this one, and + // the first insert must be to an even index. + if (Index0 + 1 != Index1 || Index0 & 1) + return nullptr; + + // For big endian, the high half of the value should be inserted first. + // For little endian, the low half of the value should be inserted first. + Value *X; + uint64_t ShAmt; + if (IsBigEndian) { + if (!match(ScalarOp, m_Trunc(m_Value(X))) || + !match(Scalar0, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt))))) + return nullptr; + } else { + if (!match(Scalar0, m_Trunc(m_Value(X))) || + !match(ScalarOp, m_Trunc(m_LShr(m_Specific(X), m_ConstantInt(ShAmt))))) + return nullptr; + } + + Type *SrcTy = X->getType(); + unsigned ScalarWidth = SrcTy->getScalarSizeInBits(); + unsigned VecEltWidth = VTy->getScalarSizeInBits(); + if (ScalarWidth != VecEltWidth * 2 || ShAmt != VecEltWidth) + return nullptr; + + // Bitcast the base vector to a vector type with the source element type. + Type *CastTy = FixedVectorType::get(SrcTy, VTy->getNumElements() / 2); + Value *CastBaseVec = Builder.CreateBitCast(BaseVec, CastTy); + + // Scale the insert index for a vector with half as many elements. + // bitcast (inselt (bitcast BaseVec), X, NewIndex) + uint64_t NewIndex = IsBigEndian ? Index1 / 2 : Index0 / 2; + Value *NewInsert = Builder.CreateInsertElement(CastBaseVec, X, NewIndex); + return new BitCastInst(NewInsert, VTy); +} + Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -1505,10 +1590,22 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { return replaceInstUsesWith(IE, V); // Canonicalize type of constant indices to i64 to simplify CSE - if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) + if (auto *IndexC = dyn_cast<ConstantInt>(IdxOp)) { if (auto *NewIdx = getPreferredVectorIndex(IndexC)) return replaceOperand(IE, 2, NewIdx); + Value *BaseVec, *OtherScalar; + uint64_t OtherIndexVal; + if (match(VecOp, m_OneUse(m_InsertElt(m_Value(BaseVec), + m_Value(OtherScalar), + m_ConstantInt(OtherIndexVal)))) && + !isa<Constant>(OtherScalar) && OtherIndexVal > IndexC->getZExtValue()) { + Value *NewIns = Builder.CreateInsertElement(BaseVec, ScalarOp, IdxOp); + return InsertElementInst::Create(NewIns, OtherScalar, + Builder.getInt64(OtherIndexVal)); + } + } + // If the scalar is bitcast and inserted into undef, do the insert in the // source type followed by bitcast. // TODO: Generalize for insert into any constant, not just undef? @@ -1622,6 +1719,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *Ext = narrowInsElt(IE, Builder)) return Ext; + if (Instruction *Ext = foldTruncInsEltPair(IE, DL.isBigEndian(), Builder)) + return Ext; + return nullptr; } @@ -1653,7 +1753,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // from an undefined element in an operand. 
if (llvm::is_contained(Mask, -1)) return false; - LLVM_FALLTHROUGH; + [[fallthrough]]; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -1700,8 +1800,8 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask, // Verify that 'CI' does not occur twice in Mask. A single 'insertelement' // can't put an element into multiple indices. bool SeenOnce = false; - for (int i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] == ElementNumber) { + for (int I : Mask) { + if (I == ElementNumber) { if (SeenOnce) return false; SeenOnce = true; @@ -1957,6 +2057,56 @@ static BinopElts getAlternateBinop(BinaryOperator *BO, const DataLayout &DL) { return {}; } +/// A select shuffle of a select shuffle with a shared operand can be reduced +/// to a single select shuffle. This is an obvious improvement in IR, and the +/// backend is expected to lower select shuffles efficiently. +static Instruction *foldSelectShuffleOfSelectShuffle(ShuffleVectorInst &Shuf) { + assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); + + Value *Op0 = Shuf.getOperand(0), *Op1 = Shuf.getOperand(1); + SmallVector<int, 16> Mask; + Shuf.getShuffleMask(Mask); + unsigned NumElts = Mask.size(); + + // Canonicalize a select shuffle with common operand as Op1. + auto *ShufOp = dyn_cast<ShuffleVectorInst>(Op0); + if (ShufOp && ShufOp->isSelect() && + (ShufOp->getOperand(0) == Op1 || ShufOp->getOperand(1) == Op1)) { + std::swap(Op0, Op1); + ShuffleVectorInst::commuteShuffleMask(Mask, NumElts); + } + + ShufOp = dyn_cast<ShuffleVectorInst>(Op1); + if (!ShufOp || !ShufOp->isSelect() || + (ShufOp->getOperand(0) != Op0 && ShufOp->getOperand(1) != Op0)) + return nullptr; + + Value *X = ShufOp->getOperand(0), *Y = ShufOp->getOperand(1); + SmallVector<int, 16> Mask1; + ShufOp->getShuffleMask(Mask1); + assert(Mask1.size() == NumElts && "Vector size changed with select shuffle"); + + // Canonicalize common operand (Op0) as X (first operand of first shuffle). + if (Y == Op0) { + std::swap(X, Y); + ShuffleVectorInst::commuteShuffleMask(Mask1, NumElts); + } + + // If the mask chooses from X (operand 0), it stays the same. + // If the mask chooses from the earlier shuffle, the other mask value is + // transferred to the combined select shuffle: + // shuf X, (shuf X, Y, M1), M --> shuf X, Y, M' + SmallVector<int, 16> NewMask(NumElts); + for (unsigned i = 0; i != NumElts; ++i) + NewMask[i] = Mask[i] < (signed)NumElts ? Mask[i] : Mask1[i]; + + // A select mask with undef elements might look like an identity mask. 
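A minimal IR sketch of the select-shuffle-of-select-shuffle reduction implemented here, assuming <4 x i32> operands with %x as the shared operand (illustrative masks):

  %s1 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 4, i32 1, i32 6, i32 3>
  %s2 = shufflevector <4 x i32> %x, <4 x i32> %s1, <4 x i32> <i32 0, i32 5, i32 6, i32 3>
    -->
  %s2 = shufflevector <4 x i32> %x, <4 x i32> %y, <4 x i32> <i32 0, i32 1, i32 6, i32 3>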
+ assert((ShuffleVectorInst::isSelectMask(NewMask) || + ShuffleVectorInst::isIdentityMask(NewMask)) && + "Unexpected shuffle mask"); + return new ShuffleVectorInst(X, Y, NewMask); +} + static Instruction *foldSelectShuffleWith1Binop(ShuffleVectorInst &Shuf) { assert(Shuf.isSelect() && "Must have select-equivalent shuffle"); @@ -2061,6 +2211,9 @@ Instruction *InstCombinerImpl::foldSelectShuffle(ShuffleVectorInst &Shuf) { return &Shuf; } + if (Instruction *I = foldSelectShuffleOfSelectShuffle(Shuf)) + return I; + if (Instruction *I = foldSelectShuffleWith1Binop(Shuf)) return I; @@ -2541,6 +2694,35 @@ static Instruction *foldIdentityPaddedShuffles(ShuffleVectorInst &Shuf) { return new ShuffleVectorInst(X, Y, NewMask); } +// Splatting the first element of the result of a BinOp, where any of the +// BinOp's operands are the result of a first element splat can be simplified to +// splatting the first element of the result of the BinOp +Instruction *InstCombinerImpl::simplifyBinOpSplats(ShuffleVectorInst &SVI) { + if (!match(SVI.getOperand(1), m_Undef()) || + !match(SVI.getShuffleMask(), m_ZeroMask())) + return nullptr; + + Value *Op0 = SVI.getOperand(0); + Value *X, *Y; + if (!match(Op0, m_BinOp(m_Shuffle(m_Value(X), m_Undef(), m_ZeroMask()), + m_Value(Y))) && + !match(Op0, m_BinOp(m_Value(X), + m_Shuffle(m_Value(Y), m_Undef(), m_ZeroMask())))) + return nullptr; + if (X->getType() != Y->getType()) + return nullptr; + + auto *BinOp = cast<BinaryOperator>(Op0); + if (!isSafeToSpeculativelyExecute(BinOp)) + return nullptr; + + Value *NewBO = Builder.CreateBinOp(BinOp->getOpcode(), X, Y); + if (auto NewBOI = dyn_cast<Instruction>(NewBO)) + NewBOI->copyIRFlags(BinOp); + + return new ShuffleVectorInst(NewBO, SVI.getShuffleMask()); +} + Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *LHS = SVI.getOperand(0); Value *RHS = SVI.getOperand(1); @@ -2549,7 +2731,9 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { SVI.getType(), ShufQuery)) return replaceInstUsesWith(SVI, V); - // Bail out for scalable vectors + if (Instruction *I = simplifyBinOpSplats(SVI)) + return I; + if (isa<ScalableVectorType>(LHS->getType())) return nullptr; @@ -2694,7 +2878,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { Value *V = LHS; unsigned MaskElems = Mask.size(); auto *SrcTy = cast<FixedVectorType>(V->getType()); - unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedSize(); + unsigned VecBitWidth = SrcTy->getPrimitiveSizeInBits().getFixedValue(); unsigned SrcElemBitWidth = DL.getTypeSizeInBits(SrcTy->getElementType()); assert(SrcElemBitWidth && "vector elements must have a bitwidth"); unsigned SrcNumElems = SrcTy->getNumElements(); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 71c763de43b4..fb6f4f96ea48 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -38,7 +38,6 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" @@ -99,16 +98,19 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> 
#include <cassert> #include <cstdint> #include <memory> +#include <optional> #include <string> #include <utility> #define DEBUG_TYPE "instcombine" #include "llvm/Transforms/Utils/InstructionWorklist.h" +#include <optional> using namespace llvm; using namespace llvm::PatternMatch; @@ -167,16 +169,16 @@ MaxArraySize("instcombine-maxarray-size", cl::init(1024), static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare", cl::Hidden, cl::init(true)); -Optional<Instruction *> +std::optional<Instruction *> InstCombiner::targetInstCombineIntrinsic(IntrinsicInst &II) { // Handle target specific intrinsics if (II.getCalledFunction()->isTargetIntrinsic()) { return TTI.instCombineIntrinsic(*this, II); } - return None; + return std::nullopt; } -Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( +std::optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( IntrinsicInst &II, APInt DemandedMask, KnownBits &Known, bool &KnownBitsComputed) { // Handle target specific intrinsics @@ -184,10 +186,10 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedUseBitsIntrinsic( return TTI.simplifyDemandedUseBitsIntrinsic(*this, II, DemandedMask, Known, KnownBitsComputed); } - return None; + return std::nullopt; } -Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( +std::optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( IntrinsicInst &II, APInt DemandedElts, APInt &UndefElts, APInt &UndefElts2, APInt &UndefElts3, std::function<void(Instruction *, unsigned, APInt, APInt &)> @@ -198,11 +200,11 @@ Optional<Value *> InstCombiner::targetSimplifyDemandedVectorEltsIntrinsic( *this, II, DemandedElts, UndefElts, UndefElts2, UndefElts3, SimplifyAndSetOp); } - return None; + return std::nullopt; } Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { - return llvm::EmitGEPOffset(&Builder, DL, GEP); + return llvm::emitGEPOffset(&Builder, DL, GEP); } /// Legal integers and common types are considered desirable. This is used to @@ -223,11 +225,12 @@ bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const { /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. -/// We don't want to convert from a legal to an illegal type or from a smaller -/// to a larger illegal type. A width of '1' is always treated as a desirable -/// type because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. Common/desirable widths are equally treated as -/// legal to convert to, in order to open up more combining opportunities. +/// We don't want to convert from a legal or desirable type (like i8) to an +/// illegal type or from a smaller to a larger illegal type. A width of '1' +/// is always treated as a desirable type because i1 is a fundamental type in +/// IR, and there are many specialized optimizations for i1 types. +/// Common/desirable widths are equally treated as legal to convert to, in +/// order to open up more combining opportunities. bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); @@ -238,9 +241,9 @@ bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, if (ToWidth < FromWidth && isDesirableIntType(ToWidth)) return true; - // If this is a legal integer from type, and the result would be an illegal - // type, don't do the transformation. 
- if (FromLegal && !ToLegal) + // If this is a legal or desiable integer from type, and the result would be + // an illegal type, don't do the transformation. + if ((FromLegal || isDesirableIntType(FromWidth)) && !ToLegal) return false; // Otherwise, if both are illegal, do not increase the size of the result. We @@ -367,14 +370,14 @@ static bool simplifyAssocCastAssoc(BinaryOperator *BinOp1, // inttoptr ( ptrtoint (x) ) --> x Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { auto *IntToPtr = dyn_cast<IntToPtrInst>(Val); - if (IntToPtr && DL.getPointerTypeSizeInBits(IntToPtr->getDestTy()) == + if (IntToPtr && DL.getTypeSizeInBits(IntToPtr->getDestTy()) == DL.getTypeSizeInBits(IntToPtr->getSrcTy())) { auto *PtrToInt = dyn_cast<PtrToIntInst>(IntToPtr->getOperand(0)); Type *CastTy = IntToPtr->getDestTy(); if (PtrToInt && CastTy->getPointerAddressSpace() == PtrToInt->getSrcTy()->getPointerAddressSpace() && - DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) == + DL.getTypeSizeInBits(PtrToInt->getSrcTy()) == DL.getTypeSizeInBits(PtrToInt->getDestTy())) { return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, "", PtrToInt); @@ -632,14 +635,14 @@ getBinOpsForFactorization(Instruction::BinaryOps TopOpcode, BinaryOperator *Op, /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). -Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, - Instruction::BinaryOps InnerOpcode, - Value *A, Value *B, Value *C, - Value *D) { +static Value *tryFactorization(BinaryOperator &I, const SimplifyQuery &SQ, + InstCombiner::BuilderTy &Builder, + Instruction::BinaryOps InnerOpcode, Value *A, + Value *B, Value *C, Value *D) { assert(A && B && C && D && "All values must be provided"); Value *V = nullptr; - Value *SimplifiedInst = nullptr; + Value *RetVal = nullptr; Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); @@ -647,7 +650,7 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, bool InnerCommutative = Instruction::isCommutative(InnerOpcode); // Does "X op' (Y op Z)" always equal "(X op' Y) op (X op' Z)"? - if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) + if (leftDistributesOverRight(InnerOpcode, TopLevelOpcode)) { // Does the instruction have the form "(A op' B) op (A op' D)" or, in the // commutative case, "(A op' B) op (C op' A)"? if (A == C || (InnerCommutative && A == D)) { @@ -656,17 +659,18 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, // Consider forming "A op' (B op D)". // If "B op D" simplifies then it can be formed with no cost. V = simplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I)); - // If "B op D" doesn't simplify then only go on if both of the existing + + // If "B op D" doesn't simplify then only go on if one of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. - if (!V && LHS->hasOneUse() && RHS->hasOneUse()) + if (!V && (LHS->hasOneUse() || RHS->hasOneUse())) V = Builder.CreateBinOp(TopLevelOpcode, B, D, RHS->getName()); - if (V) { - SimplifiedInst = Builder.CreateBinOp(InnerOpcode, A, V); - } + if (V) + RetVal = Builder.CreateBinOp(InnerOpcode, A, V); } + } // Does "(X op Y) op' Z" always equal "(X op' Z) op (Y op' Z)"? 
- if (!SimplifiedInst && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) + if (!RetVal && rightDistributesOverLeft(TopLevelOpcode, InnerOpcode)) { // Does the instruction have the form "(A op' B) op (C op' B)" or, in the // commutative case, "(A op' B) op (B op' D)"? if (B == D || (InnerCommutative && B == C)) { @@ -676,61 +680,94 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, // If "A op C" simplifies then it can be formed with no cost. V = simplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)); - // If "A op C" doesn't simplify then only go on if both of the existing + // If "A op C" doesn't simplify then only go on if one of the existing // operations "A op' B" and "C op' D" will be zapped as no longer used. - if (!V && LHS->hasOneUse() && RHS->hasOneUse()) + if (!V && (LHS->hasOneUse() || RHS->hasOneUse())) V = Builder.CreateBinOp(TopLevelOpcode, A, C, LHS->getName()); - if (V) { - SimplifiedInst = Builder.CreateBinOp(InnerOpcode, V, B); - } + if (V) + RetVal = Builder.CreateBinOp(InnerOpcode, V, B); } + } - if (SimplifiedInst) { - ++NumFactor; - SimplifiedInst->takeName(&I); + if (!RetVal) + return nullptr; - // Check if we can add NSW/NUW flags to SimplifiedInst. If so, set them. - if (BinaryOperator *BO = dyn_cast<BinaryOperator>(SimplifiedInst)) { - if (isa<OverflowingBinaryOperator>(SimplifiedInst)) { - bool HasNSW = false; - bool HasNUW = false; - if (isa<OverflowingBinaryOperator>(&I)) { - HasNSW = I.hasNoSignedWrap(); - HasNUW = I.hasNoUnsignedWrap(); - } + ++NumFactor; + RetVal->takeName(&I); - if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) { - HasNSW &= LOBO->hasNoSignedWrap(); - HasNUW &= LOBO->hasNoUnsignedWrap(); - } + // Try to add no-overflow flags to the final value. + if (isa<OverflowingBinaryOperator>(RetVal)) { + bool HasNSW = false; + bool HasNUW = false; + if (isa<OverflowingBinaryOperator>(&I)) { + HasNSW = I.hasNoSignedWrap(); + HasNUW = I.hasNoUnsignedWrap(); + } + if (auto *LOBO = dyn_cast<OverflowingBinaryOperator>(LHS)) { + HasNSW &= LOBO->hasNoSignedWrap(); + HasNUW &= LOBO->hasNoUnsignedWrap(); + } - if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) { - HasNSW &= ROBO->hasNoSignedWrap(); - HasNUW &= ROBO->hasNoUnsignedWrap(); - } + if (auto *ROBO = dyn_cast<OverflowingBinaryOperator>(RHS)) { + HasNSW &= ROBO->hasNoSignedWrap(); + HasNUW &= ROBO->hasNoUnsignedWrap(); + } - if (TopLevelOpcode == Instruction::Add && - InnerOpcode == Instruction::Mul) { - // We can propagate 'nsw' if we know that - // %Y = mul nsw i16 %X, C - // %Z = add nsw i16 %Y, %X - // => - // %Z = mul nsw i16 %X, C+1 - // - // iff C+1 isn't INT_MIN - const APInt *CInt; - if (match(V, m_APInt(CInt))) { - if (!CInt->isMinSignedValue()) - BO->setHasNoSignedWrap(HasNSW); - } + if (TopLevelOpcode == Instruction::Add && InnerOpcode == Instruction::Mul) { + // We can propagate 'nsw' if we know that + // %Y = mul nsw i16 %X, C + // %Z = add nsw i16 %Y, %X + // => + // %Z = mul nsw i16 %X, C+1 + // + // iff C+1 isn't INT_MIN + const APInt *CInt; + if (match(V, m_APInt(CInt)) && !CInt->isMinSignedValue()) + cast<Instruction>(RetVal)->setHasNoSignedWrap(HasNSW); - // nuw can be propagated with any constant or nuw value. - BO->setHasNoUnsignedWrap(HasNUW); - } - } + // nuw can be propagated with any constant or nuw value. 
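A minimal IR sketch of the factorization this helper performs ((A op' B) op (C op' D) with a common term), assuming i32 values; with this change it also fires when only one of the inner operations is single-use:

  %ab = mul i32 %a, %b
  %ac = mul i32 %a, %c
  %r  = add i32 %ab, %ac
    -->
  %bc = add i32 %b, %c
  %r  = mul i32 %a, %bc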
+ cast<Instruction>(RetVal)->setHasNoUnsignedWrap(HasNUW); } } - return SimplifiedInst; + return RetVal; +} + +Value *InstCombinerImpl::tryFactorizationFolds(BinaryOperator &I) { + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); + BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); + BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); + Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); + Value *A, *B, *C, *D; + Instruction::BinaryOps LHSOpcode, RHSOpcode; + + if (Op0) + LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); + if (Op1) + RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); + + // The instruction has the form "(A op' B) op (C op' D)". Try to factorize + // a common term. + if (Op0 && Op1 && LHSOpcode == RHSOpcode) + if (Value *V = tryFactorization(I, SQ, Builder, LHSOpcode, A, B, C, D)) + return V; + + // The instruction has the form "(A op' B) op (C)". Try to factorize common + // term. + if (Op0) + if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) + if (Value *V = + tryFactorization(I, SQ, Builder, LHSOpcode, A, B, RHS, Ident)) + return V; + + // The instruction has the form "(B) op (C op' D)". Try to factorize common + // term. + if (Op1) + if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) + if (Value *V = + tryFactorization(I, SQ, Builder, RHSOpcode, LHS, Ident, C, D)) + return V; + + return nullptr; } /// This tries to simplify binary operations which some other binary operation @@ -738,41 +775,15 @@ Value *InstCombinerImpl::tryFactorization(BinaryOperator &I, /// (eg "(A*B)+(A*C)" -> "A*(B+C)") or expanding out if this results in /// simplifications (eg: "A & (B | C) -> (A&B) | (A&C)" if this is a win). /// Returns the simplified value, or null if it didn't simplify. -Value *InstCombinerImpl::SimplifyUsingDistributiveLaws(BinaryOperator &I) { +Value *InstCombinerImpl::foldUsingDistributiveLaws(BinaryOperator &I) { Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); BinaryOperator *Op0 = dyn_cast<BinaryOperator>(LHS); BinaryOperator *Op1 = dyn_cast<BinaryOperator>(RHS); Instruction::BinaryOps TopLevelOpcode = I.getOpcode(); - { - // Factorization. - Value *A, *B, *C, *D; - Instruction::BinaryOps LHSOpcode, RHSOpcode; - if (Op0) - LHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op0, A, B); - if (Op1) - RHSOpcode = getBinOpsForFactorization(TopLevelOpcode, Op1, C, D); - - // The instruction has the form "(A op' B) op (C op' D)". Try to factorize - // a common term. - if (Op0 && Op1 && LHSOpcode == RHSOpcode) - if (Value *V = tryFactorization(I, LHSOpcode, A, B, C, D)) - return V; - - // The instruction has the form "(A op' B) op (C)". Try to factorize common - // term. - if (Op0) - if (Value *Ident = getIdentityValue(LHSOpcode, RHS)) - if (Value *V = tryFactorization(I, LHSOpcode, A, B, RHS, Ident)) - return V; - - // The instruction has the form "(B) op (C op' D)". Try to factorize common - // term. - if (Op1) - if (Value *Ident = getIdentityValue(RHSOpcode, LHS)) - if (Value *V = tryFactorization(I, RHSOpcode, LHS, Ident, C, D)) - return V; - } + // Factorization. + if (Value *R = tryFactorizationFolds(I)) + return R; // Expansion. if (Op0 && rightDistributesOverLeft(Op0->getOpcode(), TopLevelOpcode)) { @@ -876,6 +887,28 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Cond, *True = nullptr, *False = nullptr; + + // Special-case for add/negate combination. Replace the zero in the negation + // with the trailing add operand: + // (Cond ? 
TVal : -N) + Z --> Cond ? True : (Z - N) + // (Cond ? -N : FVal) + Z --> Cond ? (Z - N) : False + auto foldAddNegate = [&](Value *TVal, Value *FVal, Value *Z) -> Value * { + // We need an 'add' and exactly 1 arm of the select to have been simplified. + if (Opcode != Instruction::Add || (!True && !False) || (True && False)) + return nullptr; + + Value *N; + if (True && match(FVal, m_Neg(m_Value(N)))) { + Value *Sub = Builder.CreateSub(Z, N); + return Builder.CreateSelect(Cond, True, Sub, I.getName()); + } + if (False && match(TVal, m_Neg(m_Value(N)))) { + Value *Sub = Builder.CreateSub(Z, N); + return Builder.CreateSelect(Cond, Sub, False, I.getName()); + } + return nullptr; + }; + if (LHSIsSelect && RHSIsSelect && A == D) { // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) Cond = A; @@ -893,11 +926,15 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Cond = A; True = simplifyBinOp(Opcode, B, RHS, FMF, Q); False = simplifyBinOp(Opcode, C, RHS, FMF, Q); + if (Value *NewSel = foldAddNegate(B, C, RHS)) + return NewSel; } else if (RHSIsSelect && RHS->hasOneUse()) { // X op (D ? E : F) -> D ? (X op E) : (X op F) Cond = D; True = simplifyBinOp(Opcode, LHS, E, FMF, Q); False = simplifyBinOp(Opcode, LHS, F, FMF, Q); + if (Value *NewSel = foldAddNegate(E, F, LHS)) + return NewSel; } if (!True || !False) @@ -910,8 +947,10 @@ Value *InstCombinerImpl::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, /// Freely adapt every user of V as-if V was changed to !V. /// WARNING: only if canFreelyInvertAllUsersOf() said this can be done. -void InstCombinerImpl::freelyInvertAllUsersOf(Value *I) { - for (User *U : I->users()) { +void InstCombinerImpl::freelyInvertAllUsersOf(Value *I, Value *IgnoredUser) { + for (User *U : make_early_inc_range(I->users())) { + if (U == IgnoredUser) + continue; // Don't consider this user. switch (cast<Instruction>(U)->getOpcode()) { case Instruction::Select: { auto *SI = cast<SelectInst>(U); @@ -1033,6 +1072,9 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); } + if (auto *EI = dyn_cast<ExtractElementInst>(&I)) + return Builder.CreateExtractElement(SO, EI->getIndexOperand()); + assert(I.isBinaryOp() && "Unexpected opcode for select folding"); // Figure out if the constant is the left or the right argument. @@ -1133,22 +1175,6 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, SelectInst *SI, return SelectInst::Create(SI->getCondition(), NewTV, NewFV, "", nullptr, SI); } -static Value *foldOperationIntoPhiValue(BinaryOperator *I, Value *InV, - InstCombiner::BuilderTy &Builder) { - bool ConstIsRHS = isa<Constant>(I->getOperand(1)); - Constant *C = cast<Constant>(I->getOperand(ConstIsRHS)); - - Value *Op0 = InV, *Op1 = C; - if (!ConstIsRHS) - std::swap(Op0, Op1); - - Value *RI = Builder.CreateBinOp(I->getOpcode(), Op0, Op1, "phi.bo"); - auto *FPInst = dyn_cast<Instruction>(RI); - if (FPInst && isa<FPMathOperator>(FPInst)) - FPInst->copyFastMathFlags(I); - return RI; -} - Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { unsigned NumPHIValues = PN->getNumIncomingValues(); if (NumPHIValues == 0) @@ -1167,48 +1193,69 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // Otherwise, we can replace *all* users with the new PHI we form. } - // Check to see if all of the operands of the PHI are simple constants - // (constantint/constantfp/undef). 
If there is one non-constant value, - // remember the BB it is in. If there is more than one or if *it* is a PHI, - // bail out. We don't do arbitrary constant expressions here because moving - // their computation can be expensive without a cost model. - BasicBlock *NonConstBB = nullptr; + // Check to see whether the instruction can be folded into each phi operand. + // If there is one operand that does not fold, remember the BB it is in. + // If there is more than one or if *it* is a PHI, bail out. + SmallVector<Value *> NewPhiValues; + BasicBlock *NonSimplifiedBB = nullptr; + Value *NonSimplifiedInVal = nullptr; for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InVal = PN->getIncomingValue(i); - // For non-freeze, require constant operand - // For freeze, require non-undef, non-poison operand - if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant())) - continue; - if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal)) + BasicBlock *InBB = PN->getIncomingBlock(i); + + // NB: It is a precondition of this transform that the operands be + // phi translatable! This is usually trivially satisfied by limiting it + // to constant ops, and for selects we do a more sophisticated check. + SmallVector<Value *> Ops; + for (Value *Op : I.operands()) { + if (Op == PN) + Ops.push_back(InVal); + else + Ops.push_back(Op->DoPHITranslation(PN->getParent(), InBB)); + } + + // Don't consider the simplification successful if we get back a constant + // expression. That's just an instruction in hiding. + // Also reject the case where we simplify back to the phi node. We wouldn't + // be able to remove it in that case. + Value *NewVal = simplifyInstructionWithOperands( + &I, Ops, SQ.getWithInstruction(InBB->getTerminator())); + if (NewVal && NewVal != PN && !match(NewVal, m_ConstantExpr())) { + NewPhiValues.push_back(NewVal); continue; + } if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. - if (NonConstBB) return nullptr; // More than one non-const value. + if (NonSimplifiedBB) return nullptr; // More than one non-simplified value. - NonConstBB = PN->getIncomingBlock(i); + NonSimplifiedBB = InBB; + NonSimplifiedInVal = InVal; + NewPhiValues.push_back(nullptr); // If the InVal is an invoke at the end of the pred block, then we can't // insert a computation after it without breaking the edge. if (isa<InvokeInst>(InVal)) - if (cast<Instruction>(InVal)->getParent() == NonConstBB) + if (cast<Instruction>(InVal)->getParent() == NonSimplifiedBB) return nullptr; // If the incoming non-constant value is reachable from the phis block, // we'll push the operation across a loop backedge. This could result in // an infinite combine loop, and is generally non-profitable (especially // if the operation was originally outside the loop). - if (isPotentiallyReachable(PN->getParent(), NonConstBB, nullptr, &DT, LI)) + if (isPotentiallyReachable(PN->getParent(), NonSimplifiedBB, nullptr, &DT, + LI)) return nullptr; } - // If there is exactly one non-constant value, we can insert a copy of the + // If there is exactly one non-simplified value, we can insert a copy of the // operation in that block. However, if this is a critical edge, we would be // inserting the computation on some other paths (e.g. inside a loop). Only // do this if the pred block is unconditionally branching into the phi block. // Also, make sure that the pred block is not dead code. 
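A minimal IR sketch of the rewritten foldOpIntoPhi logic, assuming a two-predecessor phi where one incoming value folds to a constant and the other predecessor unconditionally branches into the phi block (illustrative names):

  join:
    %p = phi i32 [ 0, %entry ], [ %x, %pred ]
    %r = add i32 %p, 42
    -->
  pred:
    %r.clone = add i32 %x, 42    ; clone inserted before %pred's terminator
  join:
    %r = phi i32 [ 42, %entry ], [ %r.clone, %pred ]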
- if (NonConstBB != nullptr) { - BranchInst *BI = dyn_cast<BranchInst>(NonConstBB->getTerminator()); - if (!BI || !BI->isUnconditional() || !DT.isReachableFromEntry(NonConstBB)) + if (NonSimplifiedBB != nullptr) { + BranchInst *BI = dyn_cast<BranchInst>(NonSimplifiedBB->getTerminator()); + if (!BI || !BI->isUnconditional() || + !DT.isReachableFromEntry(NonSimplifiedBB)) return nullptr; } @@ -1219,83 +1266,23 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { // If we are going to have to insert a new computation, do so right before the // predecessor's terminator. - if (NonConstBB) - Builder.SetInsertPoint(NonConstBB->getTerminator()); - - // Next, add all of the operands to the PHI. - if (SelectInst *SI = dyn_cast<SelectInst>(&I)) { - // We only currently try to fold the condition of a select when it is a phi, - // not the true/false values. - Value *TrueV = SI->getTrueValue(); - Value *FalseV = SI->getFalseValue(); - BasicBlock *PhiTransBB = PN->getParent(); - for (unsigned i = 0; i != NumPHIValues; ++i) { - BasicBlock *ThisBB = PN->getIncomingBlock(i); - Value *TrueVInPred = TrueV->DoPHITranslation(PhiTransBB, ThisBB); - Value *FalseVInPred = FalseV->DoPHITranslation(PhiTransBB, ThisBB); - Value *InV = nullptr; - // Beware of ConstantExpr: it may eventually evaluate to getNullValue, - // even if currently isNullValue gives false. - Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i)); - // For vector constants, we cannot use isNullValue to fold into - // FalseVInPred versus TrueVInPred. When we have individual nonzero - // elements in the vector, we will incorrectly fold InC to - // `TrueVInPred`. - if (InC && isa<ConstantInt>(InC)) - InV = InC->isNullValue() ? FalseVInPred : TrueVInPred; - else { - // Generate the select in the same block as PN's current incoming block. - // Note: ThisBB need not be the NonConstBB because vector constants - // which are constants by definition are handled here. - // FIXME: This can lead to an increase in IR generation because we might - // generate selects for vector constant phi operand, that could not be - // folded to TrueVInPred or FalseVInPred as done for ConstantInt. For - // non-vector phis, this transformation was always profitable because - // the select would be generated exactly once in the NonConstBB. 
- Builder.SetInsertPoint(ThisBB->getTerminator()); - InV = Builder.CreateSelect(PN->getIncomingValue(i), TrueVInPred, - FalseVInPred, "phi.sel"); - } - NewPN->addIncoming(InV, ThisBB); - } - } else if (CmpInst *CI = dyn_cast<CmpInst>(&I)) { - Constant *C = cast<Constant>(I.getOperand(1)); - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = nullptr; - if (auto *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::getCompare(CI->getPredicate(), InC, C); - else - InV = Builder.CreateCmp(CI->getPredicate(), PN->getIncomingValue(i), - C, "phi.cmp"); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else if (auto *BO = dyn_cast<BinaryOperator>(&I)) { - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV = foldOperationIntoPhiValue(BO, PN->getIncomingValue(i), - Builder); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else if (isa<FreezeInst>(&I)) { - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV; - if (NonConstBB == PN->getIncomingBlock(i)) - InV = Builder.CreateFreeze(PN->getIncomingValue(i), "phi.fr"); - else - InV = PN->getIncomingValue(i); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); - } - } else { - CastInst *CI = cast<CastInst>(&I); - Type *RetTy = CI->getType(); - for (unsigned i = 0; i != NumPHIValues; ++i) { - Value *InV; - if (Constant *InC = dyn_cast<Constant>(PN->getIncomingValue(i))) - InV = ConstantExpr::getCast(CI->getOpcode(), InC, RetTy); + Instruction *Clone = nullptr; + if (NonSimplifiedBB) { + Clone = I.clone(); + for (Use &U : Clone->operands()) { + if (U == PN) + U = NonSimplifiedInVal; else - InV = Builder.CreateCast(CI->getOpcode(), PN->getIncomingValue(i), - I.getType(), "phi.cast"); - NewPN->addIncoming(InV, PN->getIncomingBlock(i)); + U = U->DoPHITranslation(PN->getParent(), NonSimplifiedBB); } + InsertNewInstBefore(Clone, *NonSimplifiedBB->getTerminator()); + } + + for (unsigned i = 0; i != NumPHIValues; ++i) { + if (NewPhiValues[i]) + NewPN->addIncoming(NewPhiValues[i], PN->getIncomingBlock(i)); + else + NewPN->addIncoming(Clone, PN->getIncomingBlock(i)); } for (User *U : make_early_inc_range(PN->users())) { @@ -1696,6 +1683,35 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { return new ShuffleVectorInst(NewBO0, NewBO1, Mask); } + auto createBinOpReverse = [&](Value *X, Value *Y) { + Value *V = Builder.CreateBinOp(Opcode, X, Y, Inst.getName()); + if (auto *BO = dyn_cast<BinaryOperator>(V)) + BO->copyIRFlags(&Inst); + Module *M = Inst.getModule(); + Function *F = Intrinsic::getDeclaration( + M, Intrinsic::experimental_vector_reverse, V->getType()); + return CallInst::Create(F, V); + }; + + // NOTE: Reverse shuffles don't require the speculative execution protection + // below because they don't affect which lanes take part in the computation. + + Value *V1, *V2; + if (match(LHS, m_VecReverse(m_Value(V1)))) { + // Op(rev(V1), rev(V2)) -> rev(Op(V1, V2)) + if (match(RHS, m_VecReverse(m_Value(V2))) && + (LHS->hasOneUse() || RHS->hasOneUse() || + (LHS == RHS && LHS->hasNUses(2)))) + return createBinOpReverse(V1, V2); + + // Op(rev(V1), RHSSplat)) -> rev(Op(V1, RHSSplat)) + if (LHS->hasOneUse() && isSplatValue(RHS)) + return createBinOpReverse(V1, RHS); + } + // Op(LHSSplat, rev(V2)) -> rev(Op(LHSSplat, V2)) + else if (isSplatValue(LHS) && match(RHS, m_OneUse(m_VecReverse(m_Value(V2))))) + return createBinOpReverse(LHS, V2); + // It may not be safe to reorder shuffles and things like div, urem, etc. 
// because we may trap when executing those ops on unknown vector elements. // See PR20059. @@ -1711,7 +1727,6 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // If both arguments of the binary operation are shuffles that use the same // mask and shuffle within a single vector, move the shuffle after the binop. - Value *V1, *V2; if (match(LHS, m_Shuffle(m_Value(V1), m_Undef(), m_Mask(Mask))) && match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(Mask))) && V1->getType() == V2->getType() && @@ -2228,7 +2243,7 @@ Instruction *InstCombinerImpl::visitGEPOfBitcast(BitCastInst *BCI, if (Instruction *I = visitBitCast(*BCI)) { if (I != BCI) { I->takeName(BCI); - BCI->getParent()->getInstList().insert(BCI->getIterator(), I); + I->insertInto(BCI->getParent(), BCI->getIterator()); replaceInstUsesWith(*BCI, I); } return &GEP; @@ -2434,10 +2449,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { NewGEP->setOperand(DI, NewPN); } - GEP.getParent()->getInstList().insert( - GEP.getParent()->getFirstInsertionPt(), NewGEP); - replaceOperand(GEP, 0, NewGEP); - PtrOp = NewGEP; + NewGEP->insertInto(GEP.getParent(), GEP.getParent()->getFirstInsertionPt()); + return replaceOperand(GEP, 0, NewGEP); } if (auto *Src = dyn_cast<GEPOperator>(PtrOp)) @@ -2450,7 +2463,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { unsigned AS = GEP.getPointerAddressSpace(); if (GEP.getOperand(1)->getType()->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) { - uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t TyAllocSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); bool Matched = false; uint64_t C; @@ -2580,8 +2593,9 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (GEPEltType->isSized() && StrippedPtrEltTy->isSized()) { // Check that changing the type amounts to dividing the index by a scale // factor. - uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); - uint64_t SrcSize = DL.getTypeAllocSize(StrippedPtrEltTy).getFixedSize(); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); + uint64_t SrcSize = + DL.getTypeAllocSize(StrippedPtrEltTy).getFixedValue(); if (ResSize && SrcSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2617,10 +2631,10 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { StrippedPtrEltTy->isArrayTy()) { // Check that changing to the array element type amounts to dividing the // index by a scale factor. 
- uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedSize(); + uint64_t ResSize = DL.getTypeAllocSize(GEPEltType).getFixedValue(); uint64_t ArrayEltSize = DL.getTypeAllocSize(StrippedPtrEltTy->getArrayElementType()) - .getFixedSize(); + .getFixedValue(); if (ResSize && ArrayEltSize % ResSize == 0) { Value *Idx = GEP.getOperand(1); unsigned BitWidth = Idx->getType()->getPrimitiveSizeInBits(); @@ -2681,7 +2695,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { BasePtrOffset.isNonNegative()) { APInt AllocSize( IdxWidth, - DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinSize()); + DL.getTypeAllocSize(AI->getAllocatedType()).getKnownMinValue()); if (BasePtrOffset.ule(AllocSize)) { return GetElementPtrInst::CreateInBounds( GEP.getSourceElementType(), PtrOp, Indices, GEP.getName()); @@ -2724,7 +2738,7 @@ static bool isRemovableWrite(CallBase &CB, Value *UsedV, // If the only possible side effect of the call is writing to the alloca, // and the result isn't used, we can safely remove any reads implied by the // call including those which might read the alloca itself. - Optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI); + std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(&CB, TLI); return Dest && Dest->Ptr == UsedV; } @@ -2732,7 +2746,7 @@ static bool isAllocSiteRemovable(Instruction *AI, SmallVectorImpl<WeakTrackingVH> &Users, const TargetLibraryInfo &TLI) { SmallVector<Instruction*, 4> Worklist; - const Optional<StringRef> Family = getAllocationFamily(AI, &TLI); + const std::optional<StringRef> Family = getAllocationFamily(AI, &TLI); Worklist.push_back(AI); do { @@ -2778,7 +2792,7 @@ static bool isAllocSiteRemovable(Instruction *AI, MemIntrinsic *MI = cast<MemIntrinsic>(II); if (MI->isVolatile() || MI->getRawDest() != PI) return false; - LLVM_FALLTHROUGH; + [[fallthrough]]; } case Intrinsic::assume: case Intrinsic::invariant_start: @@ -2808,7 +2822,7 @@ static bool isAllocSiteRemovable(Instruction *AI, continue; } - if (getReallocatedOperand(cast<CallBase>(I), &TLI) == PI && + if (getReallocatedOperand(cast<CallBase>(I)) == PI && getAllocationFamily(I, &TLI) == Family) { assert(Family); Users.emplace_back(I); @@ -2902,7 +2916,7 @@ Instruction *InstCombinerImpl::visitAllocSite(Instruction &MI) { Module *M = II->getModule(); Function *F = Intrinsic::getDeclaration(M, Intrinsic::donothing); InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), - None, "", II->getParent()); + std::nullopt, "", II->getParent()); } // Remove debug intrinsics which describe the value contained within the @@ -3052,7 +3066,7 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI, Value *Op) { // realloc() entirely. CallInst *CI = dyn_cast<CallInst>(Op); if (CI && CI->hasOneUse()) - if (Value *ReallocatedOp = getReallocatedOperand(CI, &TLI)) + if (Value *ReallocatedOp = getReallocatedOperand(CI)) return eraseInstFromFunction(*replaceInstUsesWith(*CI, ReallocatedOp)); // If we optimize for code size, try to move the call to free before the null @@ -3166,31 +3180,41 @@ Instruction *InstCombinerImpl::visitBranchInst(BranchInst &BI) { return visitUnconditionalBranchInst(BI); // Change br (not X), label True, label False to: br X, label False, True - Value *X = nullptr; - if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) && - !isa<Constant>(X)) { + Value *Cond = BI.getCondition(); + Value *X; + if (match(Cond, m_Not(m_Value(X))) && !isa<Constant>(X)) { // Swap Destinations and condition... 
BI.swapSuccessors(); return replaceOperand(BI, 0, X); } + // Canonicalize logical-and-with-invert as logical-or-with-invert. + // This is done by inverting the condition and swapping successors: + // br (X && !Y), T, F --> br !(X && !Y), F, T --> br (!X || Y), F, T + Value *Y; + if (isa<SelectInst>(Cond) && + match(Cond, + m_OneUse(m_LogicalAnd(m_Value(X), m_OneUse(m_Not(m_Value(Y))))))) { + Value *NotX = Builder.CreateNot(X, "not." + X->getName()); + Value *Or = Builder.CreateLogicalOr(NotX, Y); + BI.swapSuccessors(); + return replaceOperand(BI, 0, Or); + } + // If the condition is irrelevant, remove the use so that other // transforms on the condition become more effective. - if (!isa<ConstantInt>(BI.getCondition()) && - BI.getSuccessor(0) == BI.getSuccessor(1)) - return replaceOperand( - BI, 0, ConstantInt::getFalse(BI.getCondition()->getType())); + if (!isa<ConstantInt>(Cond) && BI.getSuccessor(0) == BI.getSuccessor(1)) + return replaceOperand(BI, 0, ConstantInt::getFalse(Cond->getType())); // Canonicalize, for example, fcmp_one -> fcmp_oeq. CmpInst::Predicate Pred; - if (match(&BI, m_Br(m_OneUse(m_FCmp(Pred, m_Value(), m_Value())), - m_BasicBlock(), m_BasicBlock())) && + if (match(Cond, m_OneUse(m_FCmp(Pred, m_Value(), m_Value()))) && !isCanonicalPredicate(Pred)) { // Swap destinations and condition. - CmpInst *Cond = cast<CmpInst>(BI.getCondition()); - Cond->setPredicate(CmpInst::getInversePredicate(Pred)); + auto *Cmp = cast<CmpInst>(Cond); + Cmp->setPredicate(CmpInst::getInversePredicate(Pred)); BI.swapSuccessors(); - Worklist.push(Cond); + Worklist.push(Cmp); return &BI; } @@ -3218,7 +3242,7 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { // Compute the number of leading bits we can ignore. // TODO: A better way to determine this would use ComputeNumSignBits(). - for (auto &C : SI.cases()) { + for (const auto &C : SI.cases()) { LeadingKnownZeros = std::min( LeadingKnownZeros, C.getCaseValue()->getValue().countLeadingZeros()); LeadingKnownOnes = std::min( @@ -3247,6 +3271,81 @@ Instruction *InstCombinerImpl::visitSwitchInst(SwitchInst &SI) { return nullptr; } +Instruction * +InstCombinerImpl::foldExtractOfOverflowIntrinsic(ExtractValueInst &EV) { + auto *WO = dyn_cast<WithOverflowInst>(EV.getAggregateOperand()); + if (!WO) + return nullptr; + + Intrinsic::ID OvID = WO->getIntrinsicID(); + const APInt *C = nullptr; + if (match(WO->getRHS(), m_APIntAllowUndef(C))) { + if (*EV.idx_begin() == 0 && (OvID == Intrinsic::smul_with_overflow || + OvID == Intrinsic::umul_with_overflow)) { + // extractvalue (any_mul_with_overflow X, -1), 0 --> -X + if (C->isAllOnes()) + return BinaryOperator::CreateNeg(WO->getLHS()); + // extractvalue (any_mul_with_overflow X, 2^n), 0 --> X << n + if (C->isPowerOf2()) { + return BinaryOperator::CreateShl( + WO->getLHS(), + ConstantInt::get(WO->getLHS()->getType(), C->logBase2())); + } + } + } + + // We're extracting from an overflow intrinsic. See if we're the only user. + // That allows us to simplify multiple result intrinsics to simpler things + // that just get one value. + if (!WO->hasOneUse()) + return nullptr; + + // Check if we're grabbing only the result of a 'with overflow' intrinsic + // and replace it with a traditional binary instruction. + if (*EV.idx_begin() == 0) { + Instruction::BinaryOps BinOp = WO->getBinaryOp(); + Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); + // Replace the old instruction's uses with poison. 
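Minimal IR sketches of two extractvalue folds collected into this new helper, assuming i32 operands (the constant-RHS multiply cases do not require a single use of the intrinsic):

  %m = call { i32, i1 } @llvm.smul.with.overflow.i32(i32 %x, i32 -1)
  %v = extractvalue { i32, i1 } %m, 0
    -->
  %v = sub i32 0, %x

  %m = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 8)
  %v = extractvalue { i32, i1 } %m, 0
    -->
  %v = shl i32 %x, 3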
+ replaceInstUsesWith(*WO, PoisonValue::get(WO->getType())); + eraseInstFromFunction(*WO); + return BinaryOperator::Create(BinOp, LHS, RHS); + } + + assert(*EV.idx_begin() == 1 && "Unexpected extract index for overflow inst"); + + // (usub LHS, RHS) overflows when LHS is unsigned-less-than RHS. + if (OvID == Intrinsic::usub_with_overflow) + return new ICmpInst(ICmpInst::ICMP_ULT, WO->getLHS(), WO->getRHS()); + + // smul with i1 types overflows when both sides are set: -1 * -1 == +1, but + // +1 is not possible because we assume signed values. + if (OvID == Intrinsic::smul_with_overflow && + WO->getLHS()->getType()->isIntOrIntVectorTy(1)) + return BinaryOperator::CreateAnd(WO->getLHS(), WO->getRHS()); + + // If only the overflow result is used, and the right hand side is a + // constant (or constant splat), we can remove the intrinsic by directly + // checking for overflow. + if (C) { + // Compute the no-wrap range for LHS given RHS=C, then construct an + // equivalent icmp, potentially using an offset. + ConstantRange NWR = ConstantRange::makeExactNoWrapRegion( + WO->getBinaryOp(), *C, WO->getNoWrapKind()); + + CmpInst::Predicate Pred; + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + auto *OpTy = WO->getRHS()->getType(); + auto *NewLHS = WO->getLHS(); + if (Offset != 0) + NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); + return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, + ConstantInt::get(OpTy, NewRHSC)); + } + + return nullptr; +} + Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *Agg = EV.getAggregateOperand(); @@ -3294,7 +3393,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Value *NewEV = Builder.CreateExtractValue(IV->getAggregateOperand(), EV.getIndices()); return InsertValueInst::Create(NewEV, IV->getInsertedValueOperand(), - makeArrayRef(insi, inse)); + ArrayRef(insi, inse)); } if (insi == inse) // The insert list is a prefix of the extract list @@ -3306,60 +3405,13 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // with // %E extractvalue { i32 } { i32 42 }, 0 return ExtractValueInst::Create(IV->getInsertedValueOperand(), - makeArrayRef(exti, exte)); + ArrayRef(exti, exte)); } - if (WithOverflowInst *WO = dyn_cast<WithOverflowInst>(Agg)) { - // extractvalue (any_mul_with_overflow X, -1), 0 --> -X - Intrinsic::ID OvID = WO->getIntrinsicID(); - if (*EV.idx_begin() == 0 && - (OvID == Intrinsic::smul_with_overflow || - OvID == Intrinsic::umul_with_overflow) && - match(WO->getArgOperand(1), m_AllOnes())) { - return BinaryOperator::CreateNeg(WO->getArgOperand(0)); - } - - // We're extracting from an overflow intrinsic, see if we're the only user, - // which allows us to simplify multiple result intrinsics to simpler - // things that just get one value. - if (WO->hasOneUse()) { - // Check if we're grabbing only the result of a 'with overflow' intrinsic - // and replace it with a traditional binary instruction. - if (*EV.idx_begin() == 0) { - Instruction::BinaryOps BinOp = WO->getBinaryOp(); - Value *LHS = WO->getLHS(), *RHS = WO->getRHS(); - // Replace the old instruction's uses with poison. 
- replaceInstUsesWith(*WO, PoisonValue::get(WO->getType())); - eraseInstFromFunction(*WO); - return BinaryOperator::Create(BinOp, LHS, RHS); - } - - assert(*EV.idx_begin() == 1 && - "unexpected extract index for overflow inst"); - // If only the overflow result is used, and the right hand side is a - // constant (or constant splat), we can remove the intrinsic by directly - // checking for overflow. - const APInt *C; - if (match(WO->getRHS(), m_APInt(C))) { - // Compute the no-wrap range for LHS given RHS=C, then construct an - // equivalent icmp, potentially using an offset. - ConstantRange NWR = - ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, - WO->getNoWrapKind()); + if (Instruction *R = foldExtractOfOverflowIntrinsic(EV)) + return R; - CmpInst::Predicate Pred; - APInt NewRHSC, Offset; - NWR.getEquivalentICmp(Pred, NewRHSC, Offset); - auto *OpTy = WO->getRHS()->getType(); - auto *NewLHS = WO->getLHS(); - if (Offset != 0) - NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); - return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, - ConstantInt::get(OpTy, NewRHSC)); - } - } - } - if (LoadInst *L = dyn_cast<LoadInst>(Agg)) + if (LoadInst *L = dyn_cast<LoadInst>(Agg)) { // If the (non-volatile) load only has one use, we can rewrite this to a // load from a GEP. This reduces the size of the load. If a load is used // only by extractvalue instructions then this either must have been @@ -3386,6 +3438,12 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // the wrong spot, so use replaceInstUsesWith(). return replaceInstUsesWith(EV, NL); } + } + + if (auto *PN = dyn_cast<PHINode>(Agg)) + if (Instruction *Res = foldOpIntoPhi(EV, PN)) + return Res; + // We could simplify extracts from other values. Note that nested extracts may // already be simplified implicitly by the above: extract (extract (insert) ) // will be translated into extract ( insert ( extract ) ) first and then just @@ -3771,7 +3829,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // poison. If the only source of new poison is flags, we can simply // strip them (since we know the only use is the freeze and nothing can // benefit from them.) - if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false)) + if (canCreateUndefOrPoison(cast<Operator>(OrigOp), + /*ConsiderFlagsAndMetadata*/ false)) return nullptr; // If operand is guaranteed not to be poison, there is no need to add freeze @@ -3779,7 +3838,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // poison. Use *MaybePoisonOperand = nullptr; for (Use &U : OrigOpInst->operands()) { - if (isGuaranteedNotToBeUndefOrPoison(U.get())) + if (isa<MetadataAsValue>(U.get()) || + isGuaranteedNotToBeUndefOrPoison(U.get())) continue; if (!MaybePoisonOperand) MaybePoisonOperand = &U; @@ -3787,7 +3847,7 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { return nullptr; } - OrigOpInst->dropPoisonGeneratingFlags(); + OrigOpInst->dropPoisonGeneratingFlagsAndMetadata(); // If all operands are guaranteed to be non-poison, we can drop freeze. 
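A minimal IR sketch of pushFreezeToPreventPoisonFromPropagating, assuming the frozen instruction has a single use and a single maybe-poison operand (illustrative names); poison-generating flags, and now metadata, are dropped:

  %a = add nsw i32 %x, 1
  %f = freeze i32 %a
    -->
  %x.fr = freeze i32 %x
  %a    = add i32 %x.fr, 1    ; nsw dropped, uses of %f replaced with %a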
if (!MaybePoisonOperand) @@ -3850,7 +3910,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, Instruction *I = dyn_cast<Instruction>(V); if (!I || canCreateUndefOrPoison(cast<Operator>(I), - /*ConsiderFlags*/ false)) + /*ConsiderFlagsAndMetadata*/ false)) return nullptr; DropFlags.push_back(I); @@ -3858,7 +3918,7 @@ Instruction *InstCombinerImpl::foldFreezeIntoRecurrence(FreezeInst &FI, } for (Instruction *I : DropFlags) - I->dropPoisonGeneratingFlags(); + I->dropPoisonGeneratingFlagsAndMetadata(); if (StartNeedsFreeze) { Builder.SetInsertPoint(StartBB->getTerminator()); @@ -3880,21 +3940,14 @@ bool InstCombinerImpl::freezeOtherUses(FreezeInst &FI) { // *all* uses if the operand is an invoke/callbr and the use is in a phi on // the normal/default destination. This is why the domination check in the // replacement below is still necessary. - Instruction *MoveBefore = nullptr; + Instruction *MoveBefore; if (isa<Argument>(Op)) { - MoveBefore = &FI.getFunction()->getEntryBlock().front(); - while (isa<AllocaInst>(MoveBefore)) - MoveBefore = MoveBefore->getNextNode(); - } else if (auto *PN = dyn_cast<PHINode>(Op)) { - MoveBefore = PN->getParent()->getFirstNonPHI(); - } else if (auto *II = dyn_cast<InvokeInst>(Op)) { - MoveBefore = II->getNormalDest()->getFirstNonPHI(); - } else if (auto *CB = dyn_cast<CallBrInst>(Op)) { - MoveBefore = CB->getDefaultDest()->getFirstNonPHI(); + MoveBefore = + &*FI.getFunction()->getEntryBlock().getFirstNonPHIOrDbgOrAlloca(); } else { - auto *I = cast<Instruction>(Op); - assert(!I->isTerminator() && "Cannot be a terminator"); - MoveBefore = I->getNextNode(); + MoveBefore = cast<Instruction>(Op)->getInsertionPointAfterDef(); + if (!MoveBefore) + return false; } bool Changed = false; @@ -3987,7 +4040,7 @@ static bool SoleWriteToDeadLocal(Instruction *I, TargetLibraryInfo &TLI) { // to allow reload along used path as described below. Otherwise, this // is simply a store to a dead allocation which will be removed. return false; - Optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI); + std::optional<MemoryLocation> Dest = MemoryLocation::getForDest(CB, TLI); if (!Dest) return false; auto *AI = dyn_cast<AllocaInst>(getUnderlyingObject(Dest->Ptr)); @@ -4103,7 +4156,7 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, SmallVector<DbgVariableIntrinsic *, 2> DIIClones; SmallSet<DebugVariable, 4> SunkVariables; - for (auto User : DbgUsersToSink) { + for (auto *User : DbgUsersToSink) { // A dbg.declare instruction should not be cloned, since there can only be // one per variable fragment. It should be left in the original place // because the sunk instruction is not an alloca (otherwise we could not be @@ -4118,6 +4171,11 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock, if (!SunkVariables.insert(DbgUserVariable).second) continue; + // Leave dbg.assign intrinsics in their original positions and there should + // be no need to insert a clone. + if (isa<DbgAssignIntrinsic>(User)) + continue; + DIIClones.emplace_back(cast<DbgVariableIntrinsic>(User->clone())); if (isa<DbgDeclareInst>(User) && isa<CastInst>(I)) DIIClones.back()->replaceVariableLocationOp(I, I->getOperand(0)); @@ -4190,9 +4248,9 @@ bool InstCombinerImpl::run() { // prove that the successor is not executed more frequently than our block. // Return the UserBlock if successful. 
auto getOptionalSinkBlockForInst = - [this](Instruction *I) -> Optional<BasicBlock *> { + [this](Instruction *I) -> std::optional<BasicBlock *> { if (!EnableCodeSinking) - return None; + return std::nullopt; BasicBlock *BB = I->getParent(); BasicBlock *UserParent = nullptr; @@ -4202,7 +4260,7 @@ bool InstCombinerImpl::run() { if (U->isDroppable()) continue; if (NumUsers > MaxSinkNumUsers) - return None; + return std::nullopt; Instruction *UserInst = cast<Instruction>(U); // Special handling for Phi nodes - get the block the use occurs in. @@ -4213,14 +4271,14 @@ bool InstCombinerImpl::run() { // sophisticated analysis (i.e finding NearestCommonDominator of // these use blocks). if (UserParent && UserParent != PN->getIncomingBlock(i)) - return None; + return std::nullopt; UserParent = PN->getIncomingBlock(i); } } assert(UserParent && "expected to find user block!"); } else { if (UserParent && UserParent != UserInst->getParent()) - return None; + return std::nullopt; UserParent = UserInst->getParent(); } @@ -4230,7 +4288,7 @@ bool InstCombinerImpl::run() { // Try sinking to another block. If that block is unreachable, then do // not bother. SimplifyCFG should handle it. if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) - return None; + return std::nullopt; auto *Term = UserParent->getTerminator(); // See if the user is one of our successors that has only one @@ -4242,7 +4300,7 @@ bool InstCombinerImpl::run() { // - the User will be executed at most once. // So sinking I down to User is always profitable or neutral. if (UserParent->getUniquePredecessor() != BB && !succ_empty(Term)) - return None; + return std::nullopt; assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); } @@ -4252,7 +4310,7 @@ bool InstCombinerImpl::run() { // No user or only has droppable users. if (!UserParent) - return None; + return std::nullopt; return UserParent; }; @@ -4312,7 +4370,7 @@ bool InstCombinerImpl::run() { InsertPos = InstParent->getFirstNonPHI()->getIterator(); } - InstParent->getInstList().insert(InsertPos, Result); + Result->insertInto(InstParent, InsertPos); // Push the new instruction and any users onto the worklist. Worklist.pushUsersToWorkList(*Result); @@ -4360,7 +4418,7 @@ public: const auto *MDScopeList = dyn_cast_or_null<MDNode>(ScopeList); if (!MDScopeList || !Container.insert(MDScopeList).second) return; - for (auto &MDOperand : MDScopeList->operands()) + for (const auto &MDOperand : MDScopeList->operands()) if (auto *MDScope = dyn_cast<MDNode>(MDOperand)) Container.insert(MDScope); }; @@ -4543,6 +4601,13 @@ static bool combineInstructionsOverFunction( bool MadeIRChange = false; if (ShouldLowerDbgDeclare) MadeIRChange = LowerDbgDeclare(F); + // LowerDbgDeclare calls RemoveRedundantDbgInstrs, but LowerDbgDeclare will + // almost never return true when running an assignment tracking build. Take + // this opportunity to do some clean up for assignment tracking builds too. + if (!MadeIRChange && isAssignmentTrackingEnabled(*F.getParent())) { + for (auto &BB : F) + RemoveRedundantDbgInstrs(&BB); + } // Iterate while there is work to do. unsigned Iteration = 0; |
