Diffstat (limited to 'lib/Transforms/InstCombine')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAddSub.cpp | 268
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 278
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp | 4
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCalls.cpp | 121
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCasts.cpp | 102
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 870
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineInternal.h | 116
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp | 93
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineMulDivRem.cpp | 77
-rw-r--r-- | lib/Transforms/InstCombine/InstCombinePHI.cpp | 6
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSelect.cpp | 455
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineShifts.cpp | 370
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp | 48
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineVectorOps.cpp | 171
-rw-r--r-- | lib/Transforms/InstCombine/InstructionCombining.cpp | 67
15 files changed, 2466 insertions, 580 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index ba15b023f2a3..8bc34825f8a7 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1097,6 +1097,107 @@ static Instruction *foldToUnsignedSaturatedAdd(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *
+InstCombiner::canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(
+    BinaryOperator &I) {
+  assert((I.getOpcode() == Instruction::Add ||
+          I.getOpcode() == Instruction::Or ||
+          I.getOpcode() == Instruction::Sub) &&
+         "Expecting add/or/sub instruction");
+
+  // We have a subtraction/addition between a (potentially truncated) *logical*
+  // right-shift of X and a "select".
+  Value *X, *Select;
+  Instruction *LowBitsToSkip, *Extract;
+  if (!match(&I, m_c_BinOp(m_TruncOrSelf(m_CombineAnd(
+                               m_LShr(m_Value(X), m_Instruction(LowBitsToSkip)),
+                               m_Instruction(Extract))),
+                           m_Value(Select))))
+    return nullptr;
+
+  // `add`/`or` is commutative; but for `sub`, "select" *must* be on RHS.
+  if (I.getOpcode() == Instruction::Sub && I.getOperand(1) != Select)
+    return nullptr;
+
+  Type *XTy = X->getType();
+  bool HadTrunc = I.getType() != XTy;
+
+  // If there was a truncation of extracted value, then we'll need to produce
+  // one extra instruction, so we need to ensure one instruction will go away.
+  if (HadTrunc && !match(&I, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+    return nullptr;
+
+  // Extraction should extract high NBits bits, with shift amount calculated as:
+  //   low bits to skip = shift bitwidth - high bits to extract
+  // The shift amount itself may be extended, and we need to look past zero-ext
+  // when matching NBits, that will matter for matching later.
+  Constant *C;
+  Value *NBits;
+  if (!match(
+          LowBitsToSkip,
+          m_ZExtOrSelf(m_Sub(m_Constant(C), m_ZExtOrSelf(m_Value(NBits))))) ||
+      !match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                                   APInt(C->getType()->getScalarSizeInBits(),
+                                         X->getType()->getScalarSizeInBits()))))
+    return nullptr;
+
+  // Sign-extending value can be zero-extended if we `sub`tract it,
+  // or sign-extended otherwise.
+  auto SkipExtInMagic = [&I](Value *&V) {
+    if (I.getOpcode() == Instruction::Sub)
+      match(V, m_ZExtOrSelf(m_Value(V)));
+    else
+      match(V, m_SExtOrSelf(m_Value(V)));
+  };
+
+  // Now, finally validate the sign-extending magic.
+  // `select` itself may be appropriately extended, look past that.
+  SkipExtInMagic(Select);
+
+  ICmpInst::Predicate Pred;
+  const APInt *Thr;
+  Value *SignExtendingValue, *Zero;
+  bool ShouldSignext;
+  // It must be a select between two values we will later establish to be a
+  // sign-extending value and a zero constant. The condition guarding the
+  // sign-extension must be based on a sign bit of the same X we had in `lshr`.
+  if (!match(Select, m_Select(m_ICmp(Pred, m_Specific(X), m_APInt(Thr)),
+                              m_Value(SignExtendingValue), m_Value(Zero))) ||
+      !isSignBitCheck(Pred, *Thr, ShouldSignext))
+    return nullptr;
+
+  // icmp-select pair is commutative.
+  if (!ShouldSignext)
+    std::swap(SignExtendingValue, Zero);
+
+  // If we should not perform sign-extension then we must add/or/subtract zero.
+  if (!match(Zero, m_Zero()))
+    return nullptr;
+  // Otherwise, it should be some constant, left-shifted by the same NBits we
+  // had in `lshr`. Said left-shift can also be appropriately extended.
+  // Again, we must look past zero-ext when looking for NBits.
+  SkipExtInMagic(SignExtendingValue);
+  Constant *SignExtendingValueBaseConstant;
+  if (!match(SignExtendingValue,
+             m_Shl(m_Constant(SignExtendingValueBaseConstant),
+                   m_ZExtOrSelf(m_Specific(NBits)))))
+    return nullptr;
+  // If we `sub`, then the constant should be one, else it should be all-ones.
+  if (I.getOpcode() == Instruction::Sub
+          ? !match(SignExtendingValueBaseConstant, m_One())
+          : !match(SignExtendingValueBaseConstant, m_AllOnes()))
+    return nullptr;
+
+  auto *NewAShr = BinaryOperator::CreateAShr(X, LowBitsToSkip,
+                                             Extract->getName() + ".sext");
+  NewAShr->copyIRFlags(Extract); // Preserve `exact`-ness.
+  if (!HadTrunc)
+    return NewAShr;
+
+  Builder.Insert(NewAShr);
+  return TruncInst::CreateTruncOrBitCast(NewAShr, I.getType());
+}
+
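; Editor's sketch in LLVM IR of the fold implemented above (hypothetical
; values %x/%nbits; this is the `or` flavor, whose magic constant is all-ones):
;   %skip  = sub i32 32, %nbits                  ; low bits to skip
;   %hi    = lshr i32 %x, %skip                  ; zero-filled high-bit extract
;   %isneg = icmp slt i32 %x, 0                  ; sign-bit check of the same %x
;   %magic = shl i32 -1, %nbits
;   %sel   = select i1 %isneg, i32 %magic, i32 0
;   %r     = or i32 %hi, %sel
; is canonicalized to a single sign-extending extract:
;   %r = ashr i32 %x, %skip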
 Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   if (Value *V = SimplifyAddInst(I.getOperand(0), I.getOperand(1),
                                  I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
@@ -1302,12 +1403,32 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) {
   if (Instruction *V = canonicalizeLowbitMask(I, Builder))
     return V;
 
+  if (Instruction *V =
+          canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+    return V;
+
   if (Instruction *SatAdd = foldToUnsignedSaturatedAdd(I))
     return SatAdd;
 
   return Changed ? &I : nullptr;
 }
 
+/// Eliminate an op from a linear interpolation (lerp) pattern.
+static Instruction *factorizeLerp(BinaryOperator &I,
+                                  InstCombiner::BuilderTy &Builder) {
+  Value *X, *Y, *Z;
+  if (!match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_Value(Y),
+                                            m_OneUse(m_FSub(m_FPOne(),
+                                                            m_Value(Z))))),
+                          m_OneUse(m_c_FMul(m_Value(X), m_Deferred(Z))))))
+    return nullptr;
+
+  // (Y * (1.0 - Z)) + (X * Z) --> Y + Z * (X - Y) [8 commuted variants]
+  Value *XY = Builder.CreateFSubFMF(X, Y, &I);
+  Value *MulZ = Builder.CreateFMulFMF(Z, XY, &I);
+  return BinaryOperator::CreateFAddFMF(Y, MulZ, &I);
+}
+
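; Editor's sketch of the lerp factorization in LLVM IR (hypothetical values;
; `fast` stands in for the reassoc/nsz flags the caller asserts):
;   %t0 = fsub fast float 1.0, %z
;   %t1 = fmul fast float %y, %t0
;   %t2 = fmul fast float %x, %z
;   %r  = fadd fast float %t1, %t2               ; y*(1-z) + x*z
; becomes
;   %d = fsub fast float %x, %y
;   %m = fmul fast float %z, %d
;   %r = fadd fast float %y, %m                  ; y + z*(x-y)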
 /// Factor a common operand out of fadd/fsub of fmul/fdiv.
 static Instruction *factorizeFAddFSub(BinaryOperator &I,
                                       InstCombiner::BuilderTy &Builder) {
@@ -1315,6 +1436,10 @@ static Instruction *factorizeFAddFSub(BinaryOperator &I,
          I.getOpcode() == Instruction::FSub) && "Expecting fadd/fsub");
   assert(I.hasAllowReassoc() && I.hasNoSignedZeros() &&
          "FP factorization requires FMF");
+
+  if (Instruction *Lerp = factorizeLerp(I, Builder))
+    return Lerp;
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Value *X, *Y, *Z;
   bool IsFMul;
@@ -1362,17 +1487,32 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) {
   if (Instruction *FoldedFAdd = foldBinOpIntoSelectOrPhi(I))
     return FoldedFAdd;
 
-  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
-  Value *X;
   // (-X) + Y --> Y - X
-  if (match(LHS, m_FNeg(m_Value(X))))
-    return BinaryOperator::CreateFSubFMF(RHS, X, &I);
-  // Y + (-X) --> Y - X
-  if (match(RHS, m_FNeg(m_Value(X))))
-    return BinaryOperator::CreateFSubFMF(LHS, X, &I);
+  Value *X, *Y;
+  if (match(&I, m_c_FAdd(m_FNeg(m_Value(X)), m_Value(Y))))
+    return BinaryOperator::CreateFSubFMF(Y, X, &I);
+
+  // Similar to above, but look through fmul/fdiv for the negated term.
+  // (-X * Y) + Z --> Z - (X * Y) [4 commuted variants]
+  Value *Z;
+  if (match(&I, m_c_FAdd(m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))),
+                         m_Value(Z)))) {
+    Value *XY = Builder.CreateFMulFMF(X, Y, &I);
+    return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+  }
+  // (-X / Y) + Z --> Z - (X / Y) [2 commuted variants]
+  // (X / -Y) + Z --> Z - (X / Y) [2 commuted variants]
+  if (match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y))),
+                         m_Value(Z))) ||
+      match(&I, m_c_FAdd(m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))),
+                         m_Value(Z)))) {
+    Value *XY = Builder.CreateFDivFMF(X, Y, &I);
+    return BinaryOperator::CreateFSubFMF(Z, XY, &I);
+  }
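; Editor's sketch of the fadd fold that looks through a negated fmul (the
; sign of a product moves exactly, so no fast-math flags are required):
;   %nx = fneg float %x
;   %m  = fmul float %nx, %y
;   %r  = fadd float %m, %z                      ; (-x * y) + z
; becomes
;   %m2 = fmul float %x, %y
;   %r  = fsub float %z, %m2                     ; z - (x * y)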
 
   // Check for (fadd double (sitofp x), y), see if we can merge this into an
   // integer add followed by a promotion.
+  Value *LHS = I.getOperand(0), *RHS = I.getOperand(1);
   if (SIToFPInst *LHSConv = dyn_cast<SIToFPInst>(LHS)) {
     Value *LHSIntVal = LHSConv->getOperand(0);
     Type *FPType = LHSConv->getType();
@@ -1631,37 +1771,50 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
 
   const APInt *Op0C;
   if (match(Op0, m_APInt(Op0C))) {
-    unsigned BitWidth = I.getType()->getScalarSizeInBits();
-    // -(X >>u 31) -> (X >>s 31)
-    // -(X >>s 31) -> (X >>u 31)
     if (Op0C->isNullValue()) {
+      Value *Op1Wide;
+      match(Op1, m_TruncOrSelf(m_Value(Op1Wide)));
+      bool HadTrunc = Op1Wide != Op1;
+      bool NoTruncOrTruncIsOneUse = !HadTrunc || Op1->hasOneUse();
+      unsigned BitWidth = Op1Wide->getType()->getScalarSizeInBits();
+
       Value *X;
       const APInt *ShAmt;
-      if (match(Op1, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
+      // -(X >>u 31) -> (X >>s 31)
+      if (NoTruncOrTruncIsOneUse &&
+          match(Op1Wide, m_LShr(m_Value(X), m_APInt(ShAmt))) &&
           *ShAmt == BitWidth - 1) {
-        Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
-        return BinaryOperator::CreateAShr(X, ShAmtOp);
+        Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+        Instruction *NewShift = BinaryOperator::CreateAShr(X, ShAmtOp);
+        NewShift->copyIRFlags(Op1Wide);
+        if (!HadTrunc)
+          return NewShift;
+        Builder.Insert(NewShift);
+        return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
       }
-      if (match(Op1, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
+      // -(X >>s 31) -> (X >>u 31)
+      if (NoTruncOrTruncIsOneUse &&
+          match(Op1Wide, m_AShr(m_Value(X), m_APInt(ShAmt))) &&
           *ShAmt == BitWidth - 1) {
-        Value *ShAmtOp = cast<Instruction>(Op1)->getOperand(1);
-        return BinaryOperator::CreateLShr(X, ShAmtOp);
+        Value *ShAmtOp = cast<Instruction>(Op1Wide)->getOperand(1);
+        Instruction *NewShift = BinaryOperator::CreateLShr(X, ShAmtOp);
+        NewShift->copyIRFlags(Op1Wide);
+        if (!HadTrunc)
+          return NewShift;
+        Builder.Insert(NewShift);
+        return TruncInst::CreateTruncOrBitCast(NewShift, Op1->getType());
       }
 
-      if (Op1->hasOneUse()) {
+      if (!HadTrunc && Op1->hasOneUse()) {
         Value *LHS, *RHS;
         SelectPatternFlavor SPF = matchSelectPattern(Op1, LHS, RHS).Flavor;
         if (SPF == SPF_ABS || SPF == SPF_NABS) {
           // This is a negate of an ABS/NABS pattern. Just swap the operands
           // of the select.
-          SelectInst *SI = cast<SelectInst>(Op1);
-          Value *TrueVal = SI->getTrueValue();
-          Value *FalseVal = SI->getFalseValue();
-          SI->setTrueValue(FalseVal);
-          SI->setFalseValue(TrueVal);
+          cast<SelectInst>(Op1)->swapValues();
           // Don't swap prof metadata, we didn't change the branch behavior.
-          return replaceInstUsesWith(I, SI);
+          return replaceInstUsesWith(I, Op1);
         }
       }
     }
@@ -1686,6 +1839,23 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       return BinaryOperator::CreateNeg(Y);
   }
 
+  // (sub (or A, B) (and A, B)) --> (xor A, B)
+  {
+    Value *A, *B;
+    if (match(Op1, m_And(m_Value(A), m_Value(B))) &&
+        match(Op0, m_c_Or(m_Specific(A), m_Specific(B))))
+      return BinaryOperator::CreateXor(A, B);
+  }
+
+  // (sub (and A, B) (or A, B)) --> neg (xor A, B)
+  {
+    Value *A, *B;
+    if (match(Op0, m_And(m_Value(A), m_Value(B))) &&
+        match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+        (Op0->hasOneUse() || Op1->hasOneUse()))
+      return BinaryOperator::CreateNeg(Builder.CreateXor(A, B));
+  }
+
   // (sub (or A, B), (xor A, B)) --> (and A, B)
   {
     Value *A, *B;
@@ -1694,6 +1864,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       return BinaryOperator::CreateAnd(A, B);
   }
 
+  // (sub (xor A, B) (or A, B)) --> neg (and A, B)
+  {
+    Value *A, *B;
+    if (match(Op0, m_Xor(m_Value(A), m_Value(B))) &&
+        match(Op1, m_c_Or(m_Specific(A), m_Specific(B))) &&
+        (Op0->hasOneUse() || Op1->hasOneUse()))
+      return BinaryOperator::CreateNeg(Builder.CreateAnd(A, B));
+  }
+
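; Editor's sketch of the new sub-of-bitwise-ops identities, which follow from
; (a | b) == (a ^ b) + (a & b):
;   %o = or i32 %a, %b
;   %n = and i32 %a, %b
;   %r = sub i32 %o, %n              ; --> xor i32 %a, %b
; and with the operands swapped:
;   %r2 = sub i32 %n, %o             ; --> sub i32 0, (a ^ b), i.e. neg(xor)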
   {
     Value *Y;
     // ((X | Y) - X) --> (~X & Y)
@@ -1778,7 +1957,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
       std::swap(LHS, RHS);
     // LHS is now O above and expected to have at least 2 uses (the min/max)
     // NotA is epected to have 2 uses from the min/max and 1 from the sub.
-    if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+    if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
         !NotA->hasNUsesOrMore(4)) {
       // Note: We don't generate the inverse max/min, just create the not of
       // it and let other folds do the rest.
@@ -1826,6 +2005,10 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) {
     return SelectInst::Create(Cmp, Neg, A);
   }
 
+  if (Instruction *V =
+          canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+    return V;
+
   if (Instruction *Ext = narrowMathIfNoOverflow(I))
     return Ext;
 
@@ -1865,6 +2048,22 @@ static Instruction *foldFNegIntoConstant(Instruction &I) {
   return nullptr;
 }
 
+static Instruction *hoistFNegAboveFMulFDiv(Instruction &I,
+                                           InstCombiner::BuilderTy &Builder) {
+  Value *FNeg;
+  if (!match(&I, m_FNeg(m_Value(FNeg))))
+    return nullptr;
+
+  Value *X, *Y;
+  if (match(FNeg, m_OneUse(m_FMul(m_Value(X), m_Value(Y)))))
+    return BinaryOperator::CreateFMulFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+  if (match(FNeg, m_OneUse(m_FDiv(m_Value(X), m_Value(Y)))))
+    return BinaryOperator::CreateFDivFMF(Builder.CreateFNegFMF(X, &I), Y, &I);
+
+  return nullptr;
+}
+
 Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
   Value *Op = I.getOperand(0);
 
@@ -1882,6 +2081,9 @@ Instruction *InstCombiner::visitFNeg(UnaryOperator &I) {
       match(Op, m_OneUse(m_FSub(m_Value(X), m_Value(Y)))))
     return BinaryOperator::CreateFSubFMF(Y, X, &I);
 
+  if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+    return R;
+
   return nullptr;
 }
 
@@ -1903,6 +2105,9 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
   if (Instruction *X = foldFNegIntoConstant(I))
     return X;
 
+  if (Instruction *R = hoistFNegAboveFMulFDiv(I, Builder))
+    return R;
+
   Value *X, *Y;
   Constant *C;
 
@@ -1944,6 +2149,21 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) {
   if (match(Op1, m_OneUse(m_FPExt(m_FNeg(m_Value(Y))))))
     return BinaryOperator::CreateFAddFMF(Op0, Builder.CreateFPExt(Y, Ty), &I);
 
+  // Similar to above, but look through fmul/fdiv of the negated value:
+  // Op0 - (-X * Y) --> Op0 + (X * Y)
+  // Op0 - (Y * -X) --> Op0 + (X * Y)
+  if (match(Op1, m_OneUse(m_c_FMul(m_FNeg(m_Value(X)), m_Value(Y))))) {
+    Value *FMul = Builder.CreateFMulFMF(X, Y, &I);
+    return BinaryOperator::CreateFAddFMF(Op0, FMul, &I);
+  }
+  // Op0 - (-X / Y) --> Op0 + (X / Y)
+  // Op0 - (X / -Y) --> Op0 + (X / Y)
+  if (match(Op1, m_OneUse(m_FDiv(m_FNeg(m_Value(X)), m_Value(Y)))) ||
+      match(Op1, m_OneUse(m_FDiv(m_Value(X), m_FNeg(m_Value(Y)))))) {
+    Value *FDiv = Builder.CreateFDivFMF(X, Y, &I);
+    return BinaryOperator::CreateFAddFMF(Op0, FDiv, &I);
+  }
+
   // Handle special cases for FSub with selects feeding the operation
   if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1))
     return replaceInstUsesWith(I, V);
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 2b9859b602f4..4a30b60ca931 100644
--- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -160,16 +160,14 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op,
 }
 
 /// Emit a computation of: (V >= Lo && V < Hi) if Inside is true, otherwise
-/// (V < Lo || V >= Hi). This method expects that Lo <= Hi. IsSigned indicates
+/// (V < Lo || V >= Hi). This method expects that Lo < Hi. IsSigned indicates
 /// whether to treat V, Lo, and Hi as signed or not.
 Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi,
                                      bool isSigned, bool Inside) {
-  assert((isSigned ? Lo.sle(Hi) : Lo.ule(Hi)) &&
-         "Lo is not <= Hi in range emission code!");
+  assert((isSigned ? Lo.slt(Hi) : Lo.ult(Hi)) &&
+         "Lo is not < Hi in range emission code!");
 
   Type *Ty = V->getType();
-  if (Lo == Hi)
-    return Inside ? ConstantInt::getFalse(Ty) : ConstantInt::getTrue(Ty);
 
   // V >= Min && V <  Hi --> V <  Hi
   // V <  Min || V >= Hi --> V >= Hi
@@ -1051,9 +1049,103 @@ static Value *foldIsPowerOf2(ICmpInst *Cmp0, ICmpInst *Cmp1, bool JoinedByAnd,
   return nullptr;
 }
 
+/// Commuted variants are assumed to be handled by calling this function again
+/// with the parameters swapped.
+static Value *foldUnsignedUnderflowCheck(ICmpInst *ZeroICmp,
+                                         ICmpInst *UnsignedICmp, bool IsAnd,
+                                         const SimplifyQuery &Q,
+                                         InstCombiner::BuilderTy &Builder) {
+  Value *ZeroCmpOp;
+  ICmpInst::Predicate EqPred;
+  if (!match(ZeroICmp, m_ICmp(EqPred, m_Value(ZeroCmpOp), m_Zero())) ||
+      !ICmpInst::isEquality(EqPred))
+    return nullptr;
+
+  auto IsKnownNonZero = [&](Value *V) {
+    return isKnownNonZero(V, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT);
+  };
+
+  ICmpInst::Predicate UnsignedPred;
+
+  Value *A, *B;
+  if (match(UnsignedICmp,
+            m_c_ICmp(UnsignedPred, m_Specific(ZeroCmpOp), m_Value(A))) &&
+      match(ZeroCmpOp, m_c_Add(m_Specific(A), m_Value(B))) &&
+      (ZeroICmp->hasOneUse() || UnsignedICmp->hasOneUse())) {
+    if (UnsignedICmp->getOperand(0) != ZeroCmpOp)
+      UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+    auto GetKnownNonZeroAndOther = [&](Value *&NonZero, Value *&Other) {
+      if (!IsKnownNonZero(NonZero))
+        std::swap(NonZero, Other);
+      return IsKnownNonZero(NonZero);
+    };
+
+    // Given  ZeroCmpOp = (A + B)
+    //   ZeroCmpOp <= A && ZeroCmpOp != 0  -->  (0-B) <  A
+    //   ZeroCmpOp >  A || ZeroCmpOp == 0  -->  (0-B) >= A
+    //
+    //   ZeroCmpOp <  A && ZeroCmpOp != 0  -->  (0-X) <  Y  iff
+    //   ZeroCmpOp >= A || ZeroCmpOp == 0  -->  (0-X) >= Y  iff
+    //   with X being the value (A/B) that is known to be non-zero,
+    //   and Y being remaining value.
+    if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+        IsAnd)
+      return Builder.CreateICmpULT(Builder.CreateNeg(B), A);
+    if (UnsignedPred == ICmpInst::ICMP_ULT && EqPred == ICmpInst::ICMP_NE &&
+        IsAnd && GetKnownNonZeroAndOther(B, A))
+      return Builder.CreateICmpULT(Builder.CreateNeg(B), A);
+    if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+        !IsAnd)
+      return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
+    if (UnsignedPred == ICmpInst::ICMP_UGE && EqPred == ICmpInst::ICMP_EQ &&
+        !IsAnd && GetKnownNonZeroAndOther(B, A))
+      return Builder.CreateICmpUGE(Builder.CreateNeg(B), A);
+  }
+
+  Value *Base, *Offset;
+  if (!match(ZeroCmpOp, m_Sub(m_Value(Base), m_Value(Offset))))
+    return nullptr;
+
+  if (!match(UnsignedICmp,
+             m_c_ICmp(UnsignedPred, m_Specific(Base), m_Specific(Offset))) ||
+      !ICmpInst::isUnsigned(UnsignedPred))
+    return nullptr;
+  if (UnsignedICmp->getOperand(0) != Base)
+    UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);
+
+  // Base >=/> Offset && (Base - Offset) != 0  <-->  Base > Offset
+  // (no overflow and not null)
+  if ((UnsignedPred == ICmpInst::ICMP_UGE ||
+       UnsignedPred == ICmpInst::ICMP_UGT) &&
+      EqPred == ICmpInst::ICMP_NE && IsAnd)
+    return Builder.CreateICmpUGT(Base, Offset);
+
+  // Base <=/< Offset || (Base - Offset) == 0  <-->  Base <= Offset
+  // (overflow or null)
+  if ((UnsignedPred == ICmpInst::ICMP_ULE ||
+       UnsignedPred == ICmpInst::ICMP_ULT) &&
+      EqPred == ICmpInst::ICMP_EQ && !IsAnd)
+    return Builder.CreateICmpULE(Base, Offset);
+
+  // Base <= Offset && (Base - Offset) != 0  -->  Base < Offset
+  if (UnsignedPred == ICmpInst::ICMP_ULE && EqPred == ICmpInst::ICMP_NE &&
+      IsAnd)
+    return Builder.CreateICmpULT(Base, Offset);
+
+  // Base > Offset || (Base - Offset) == 0  -->  Base >= Offset
+  if (UnsignedPred == ICmpInst::ICMP_UGT && EqPred == ICmpInst::ICMP_EQ &&
+      !IsAnd)
+    return Builder.CreateICmpUGE(Base, Offset);
+
+  return nullptr;
+}
+
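; Editor's sketch of the underflow check on the sub form (hypothetical values):
;   %d   = sub i32 %base, %off
;   %cmp = icmp uge i32 %base, %off              ; no unsigned underflow...
;   %nz  = icmp ne i32 %d, 0                     ; ...and a non-zero result
;   %r   = and i1 %cmp, %nz
; becomes a single strict comparison:
;   %r = icmp ugt i32 %base, %off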
 /// Fold (icmp)&(icmp) if possible.
 Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                     Instruction &CxtI) {
+  const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+
   // Fold (!iszero(A & K1) & !iszero(A & K2)) ->  (A & (K1 | K2)) == (K1 | K2)
   // if K1 and K2 are a one-bit mask.
   if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, true, CxtI))
@@ -1096,6 +1188,13 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   if (Value *V = foldIsPowerOf2(LHS, RHS, true /* JoinedByAnd */, Builder))
     return V;
 
+  if (Value *X =
+          foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/true, Q, Builder))
+    return X;
+  if (Value *X =
+          foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder))
+    return X;
+
   // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2).
   Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0);
   ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1));
@@ -1196,16 +1295,22 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     default:
       llvm_unreachable("Unknown integer condition code!");
     case ICmpInst::ICMP_ULT:
-      if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13
+      // (X != 13 & X u< 14) -> X < 13
+      if (LHSC->getValue() == (RHSC->getValue() - 1))
         return Builder.CreateICmpULT(LHS0, LHSC);
-      if (LHSC->isZero()) // (X != 0 & X u< 14) -> X-1 u< 13
+      if (LHSC->isZero()) // (X != 0 & X u< C) -> X-1 u< C-1
         return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
                                false, true);
       break; // (X != 13 & X u< 15) -> no change
     case ICmpInst::ICMP_SLT:
-      if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13
+      // (X != 13 & X s< 14) -> X < 13
+      if (LHSC->getValue() == (RHSC->getValue() - 1))
         return Builder.CreateICmpSLT(LHS0, LHSC);
-      break; // (X != 13 & X s< 15) -> no change
+      // (X != INT_MIN & X s< C) -> X-(INT_MIN+1) u< (C-(INT_MIN+1))
+      if (LHSC->isMinValue(true))
+        return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+                               true, true);
+      break; // (X != 13 & X s< 15) -> no change
     case ICmpInst::ICMP_NE:
       // Potential folds for this case should already be handled.
       break;
@@ -1216,10 +1321,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     default:
       llvm_unreachable("Unknown integer condition code!");
     case ICmpInst::ICMP_NE:
-      if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14
+      // (X u> 13 & X != 14) -> X u> 14
+      if (RHSC->getValue() == (LHSC->getValue() + 1))
        return Builder.CreateICmp(PredL, LHS0, RHSC);
+      // X u> C & X != UINT_MAX -> (X-(C+1)) u< UINT_MAX-(C+1)
+      if (RHSC->isMaxValue(false))
+        return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+                               false, true);
       break; // (X u> 13 & X != 15) -> no change
-    case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1
+    case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) u< 1
       return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
                              false, true);
     }
@@ -1229,10 +1339,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     default:
       llvm_unreachable("Unknown integer condition code!");
     case ICmpInst::ICMP_NE:
-      if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14
+      // (X s> 13 & X != 14) -> X s> 14
+      if (RHSC->getValue() == (LHSC->getValue() + 1))
        return Builder.CreateICmp(PredL, LHS0, RHSC);
+      // X s> C & X != INT_MAX -> (X-(C+1)) u< INT_MAX-(C+1)
+      if (RHSC->isMaxValue(true))
+        return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
+                               true, true);
       break; // (X s> 13 & X != 15) -> no change
-    case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1
+    case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) u< 1
       return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(),
                              true, true);
     }
@@ -1352,8 +1467,8 @@ static Instruction *matchDeMorgansLaws(BinaryOperator &I,
   Value *A, *B;
   if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) &&
       match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) &&
-      !IsFreeToInvert(A, A->hasOneUse()) &&
-      !IsFreeToInvert(B, B->hasOneUse())) {
+      !isFreeToInvert(A, A->hasOneUse()) &&
+      !isFreeToInvert(B, B->hasOneUse())) {
     Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan");
     return BinaryOperator::CreateNot(AndOr);
   }
@@ -1770,13 +1885,13 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
 
   // (A ^ B) & ((B ^ C) ^ A) -> (A ^ B) & ~C
   if (match(Op0, m_Xor(m_Value(A), m_Value(B))))
     if (match(Op1, m_Xor(m_Xor(m_Specific(B), m_Value(C)), m_Specific(A))))
-      if (Op1->hasOneUse() || IsFreeToInvert(C, C->hasOneUse()))
+      if (Op1->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
         return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(C));
 
   // ((A ^ C) ^ B) & (B ^ A) -> (B ^ A) & ~C
   if (match(Op0, m_Xor(m_Xor(m_Value(A), m_Value(C)), m_Value(B))))
     if (match(Op1, m_Xor(m_Specific(B), m_Specific(A))))
-      if (Op0->hasOneUse() || IsFreeToInvert(C, C->hasOneUse()))
+      if (Op0->hasOneUse() || isFreeToInvert(C, C->hasOneUse()))
         return BinaryOperator::CreateAnd(Op1, Builder.CreateNot(C));
 
   // (A | B) & ((~A) ^ B) -> (A & B)
@@ -1844,6 +1959,20 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) {
       A->getType()->isIntOrIntVectorTy(1))
     return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType()));
 
+  // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0.
+  {
+    Value *X, *Y;
+    const APInt *ShAmt;
+    Type *Ty = I.getType();
+    if (match(&I, m_c_And(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
+                                          m_APInt(ShAmt))),
+                          m_Deferred(X))) &&
+        *ShAmt == Ty->getScalarSizeInBits() - 1) {
+      Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
+      return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty));
+    }
+  }
+
   return nullptr;
 }
 
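; Editor's sketch of the and(ashr(subNSW)) fold just added above (i32 case):
;   %d = sub nsw i32 %y, %x
;   %s = ashr i32 %d, 31                         ; all-ones iff %x s> %y
;   %r = and i32 %s, %x
; becomes
;   %c = icmp sgt i32 %x, %y
;   %r = select i1 %c, i32 %x, i32 0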
@@ -2057,6 +2186,8 @@ Value *InstCombiner::matchSelectFromAndOr(Value *A, Value *C, Value *B,
 /// Fold (icmp)|(icmp) if possible.
 Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
                                    Instruction &CxtI) {
+  const SimplifyQuery Q = SQ.getWithInstruction(&CxtI);
+
   // Fold (iszero(A & K1) | iszero(A & K2)) ->  (A & (K1 | K2)) != (K1 | K2)
   // if K1 and K2 are a one-bit mask.
   if (Value *V = foldAndOrOfICmpsOfAndWithPow2(LHS, RHS, false, CxtI))
@@ -2182,6 +2313,13 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
   if (Value *V = foldIsPowerOf2(LHS, RHS, false /* JoinedByAnd */, Builder))
     return V;
 
+  if (Value *X =
+          foldUnsignedUnderflowCheck(LHS, RHS, /*IsAnd=*/false, Q, Builder))
+    return X;
+  if (Value *X =
+          foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder))
+    return X;
+
   // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2).
   if (!LHSC || !RHSC)
     return nullptr;
@@ -2251,8 +2389,19 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     case ICmpInst::ICMP_EQ:
       // Potential folds for this case should already be handled.
       break;
-    case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change
-    case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change
+    case ICmpInst::ICMP_UGT:
+      // (X == 0 || X u> C) -> (X-1) u>= C
+      if (LHSC->isMinValue(false))
+        return insertRangeTest(LHS0, LHSC->getValue() + 1,
+                               RHSC->getValue() + 1, false, false);
+      // (X == 13 | X u> 14) -> no change
+      break;
+    case ICmpInst::ICMP_SGT:
+      // (X == INT_MIN || X s> C) -> (X-(INT_MIN+1)) u>= C-INT_MIN
+      if (LHSC->isMinValue(true))
+        return insertRangeTest(LHS0, LHSC->getValue() + 1,
+                               RHSC->getValue() + 1, true, false);
+      // (X == 13 | X s> 14) -> no change
       break;
     }
     break;
@@ -2261,6 +2410,10 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     default:
       llvm_unreachable("Unknown integer condition code!");
     case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change
+      // (X u< C || X == UINT_MAX) => (X-C) u>= UINT_MAX-C
+      if (RHSC->isMaxValue(false))
+        return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(),
+                               false, false);
       break;
     case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2
       assert(!RHSC->isMaxValue(false) && "Missed icmp simplification");
@@ -2272,9 +2425,14 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS,
     switch (PredR) {
     default:
       llvm_unreachable("Unknown integer condition code!");
-    case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change
+    case ICmpInst::ICMP_EQ:
+      // (X s< C || X == INT_MAX) => (X-C) u>= INT_MAX-C
+      if (RHSC->isMaxValue(true))
+        return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(),
+                               true, false);
+      // (X s< 13 | X == 14) -> no change
       break;
-    case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2
+    case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) u> 2
       assert(!RHSC->isMaxValue(true) && "Missed icmp simplification");
       return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1,
                              true, false);
@@ -2552,6 +2710,25 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) {
     }
   }
 
+  // or(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? -1 : X.
+  {
+    Value *X, *Y;
+    const APInt *ShAmt;
+    Type *Ty = I.getType();
+    if (match(&I, m_c_Or(m_OneUse(m_AShr(m_NSWSub(m_Value(Y), m_Value(X)),
+                                         m_APInt(ShAmt))),
+                         m_Deferred(X))) &&
+        *ShAmt == Ty->getScalarSizeInBits() - 1) {
+      Value *NewICmpInst = Builder.CreateICmpSGT(X, Y);
+      return SelectInst::Create(NewICmpInst, ConstantInt::getAllOnesValue(Ty),
+                                X);
+    }
+  }
+
+  if (Instruction *V =
+          canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract(I))
+    return V;
+
   return nullptr;
 }
 
@@ -2617,7 +2794,11 @@ static Instruction *foldXorToXor(BinaryOperator &I,
   return nullptr;
 }
 
-Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
+Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS,
+                                    BinaryOperator &I) {
+  assert(I.getOpcode() == Instruction::Xor && I.getOperand(0) == LHS &&
+         I.getOperand(1) == RHS && "Should be 'xor' with these operands");
+
   if (predicatesFoldable(LHS->getPredicate(), RHS->getPredicate())) {
     if (LHS->getOperand(0) == RHS->getOperand(1) &&
         LHS->getOperand(1) == RHS->getOperand(0))
@@ -2672,14 +2853,35 @@ Value *InstCombiner::foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS) {
     // TODO: If OrICmp is false, the whole thing is false (InstSimplify?).
     if (Value *AndICmp = SimplifyBinOp(Instruction::And, LHS, RHS, SQ)) {
       // TODO: Independently handle cases where the 'and' side is a constant.
-      if (OrICmp == LHS && AndICmp == RHS && RHS->hasOneUse()) {
-        // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS
-        RHS->setPredicate(RHS->getInversePredicate());
-        return Builder.CreateAnd(LHS, RHS);
+      ICmpInst *X = nullptr, *Y = nullptr;
+      if (OrICmp == LHS && AndICmp == RHS) {
+        // (LHS | RHS) & !(LHS & RHS) --> LHS & !RHS --> X & !Y
+        X = LHS;
+        Y = RHS;
       }
-      if (OrICmp == RHS && AndICmp == LHS && LHS->hasOneUse()) {
-        // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS
-        LHS->setPredicate(LHS->getInversePredicate());
+      if (OrICmp == RHS && AndICmp == LHS) {
+        // !(LHS & RHS) & (LHS | RHS) --> !LHS & RHS --> !Y & X
+        X = RHS;
+        Y = LHS;
+      }
+      if (X && Y && (Y->hasOneUse() || canFreelyInvertAllUsersOf(Y, &I))) {
+        // Invert the predicate of 'Y', thus inverting its output.
+        Y->setPredicate(Y->getInversePredicate());
+        // So, are there other uses of Y?
+        if (!Y->hasOneUse()) {
+          // We need to adapt other uses of Y though. Get a value that matches
+          // the original value of Y before inversion. While this increases
+          // immediate instruction count, we have just ensured that all the
+          // users are freely-invertible, so that 'not' *will* get folded away.
+          BuilderTy::InsertPointGuard Guard(Builder);
+          // Set insertion point to right after the Y.
+          Builder.SetInsertPoint(Y->getParent(), ++(Y->getIterator()));
+          Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not");
+          // Replace all uses of Y (excluding the one in NotY!) with NotY.
+          Y->replaceUsesWithIf(NotY,
+                               [NotY](Use &U) { return U.getUser() != NotY; });
+        }
+        // All done.
         return Builder.CreateAnd(LHS, RHS);
       }
     }
@@ -2747,9 +2949,9 @@ static Instruction *sinkNotIntoXor(BinaryOperator &I,
     return nullptr;
 
   // We only want to do the transform if it is free to do.
-  if (IsFreeToInvert(X, X->hasOneUse())) {
+  if (isFreeToInvert(X, X->hasOneUse())) {
     // Ok, good.
-  } else if (IsFreeToInvert(Y, Y->hasOneUse())) {
+  } else if (isFreeToInvert(Y, Y->hasOneUse())) {
     std::swap(X, Y);
   } else
     return nullptr;
@@ -2827,9 +3029,9 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
     // Apply DeMorgan's Law when inverts are free:
     //  ~(X & Y) --> (~X | ~Y)
     //  ~(X | Y) --> (~X & ~Y)
-    if (IsFreeToInvert(NotVal->getOperand(0),
+    if (isFreeToInvert(NotVal->getOperand(0),
                        NotVal->getOperand(0)->hasOneUse()) &&
-        IsFreeToInvert(NotVal->getOperand(1),
+        isFreeToInvert(NotVal->getOperand(1),
                        NotVal->getOperand(1)->hasOneUse())) {
       Value *NotX = Builder.CreateNot(NotVal->getOperand(0), "notlhs");
       Value *NotY = Builder.CreateNot(NotVal->getOperand(1), "notrhs");
@@ -3004,7 +3206,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
 
   if (auto *LHS = dyn_cast<ICmpInst>(I.getOperand(0)))
     if (auto *RHS = dyn_cast<ICmpInst>(I.getOperand(1)))
-      if (Value *V = foldXorOfICmps(LHS, RHS))
+      if (Value *V = foldXorOfICmps(LHS, RHS, I))
         return replaceInstUsesWith(I, V);
 
   if (Instruction *CastedXor = foldCastedBitwiseLogic(I))
@@ -3052,7 +3254,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) {
   if (SelectPatternResult::isMinOrMax(SPF)) {
     // It's possible we get here before the not has been simplified, so make
     // sure the input to the not isn't freely invertible.
-    if (match(LHS, m_Not(m_Value(X))) && !IsFreeToInvert(X, X->hasOneUse())) {
+    if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) {
       Value *NotY = Builder.CreateNot(RHS);
       return SelectInst::Create(
           Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY);
     }
 
     // It's possible we get here before the not has been simplified, so make
     // sure the input to the not isn't freely invertible.
-    if (match(RHS, m_Not(m_Value(Y))) && !IsFreeToInvert(Y, Y->hasOneUse())) {
+    if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) {
       Value *NotX = Builder.CreateNot(LHS);
       return SelectInst::Create(
           Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y);
     }
 
     // If both sides are freely invertible, then we can get rid of the xor
     // completely.
-    if (IsFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
-        IsFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) {
+    if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+        isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) {
       Value *NotLHS = Builder.CreateNot(LHS);
       Value *NotRHS = Builder.CreateNot(RHS);
       return SelectInst::Create(
diff --git a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
index 5f37a00f56cf..825f4b468b0a 100644
--- a/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAtomicRMW.cpp
@@ -124,7 +124,7 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
     auto *SI = new StoreInst(RMWI.getValOperand(),
                              RMWI.getPointerOperand(), &RMWI);
     SI->setAtomic(Ordering, RMWI.getSyncScopeID());
-    SI->setAlignment(DL.getABITypeAlignment(RMWI.getType()));
+    SI->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType())));
     return eraseInstFromFunction(RMWI);
   }
 
@@ -154,6 +154,6 @@ Instruction *InstCombiner::visitAtomicRMWInst(AtomicRMWInst &RMWI) {
 
   LoadInst *Load = new LoadInst(RMWI.getType(), RMWI.getPointerOperand());
   Load->setAtomic(Ordering, RMWI.getSyncScopeID());
-  Load->setAlignment(DL.getABITypeAlignment(RMWI.getType()));
+  Load->setAlignment(MaybeAlign(DL.getABITypeAlignment(RMWI.getType())));
   return Load;
 }
diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 4b3333affa72..c650d242cd50 100644
--- a/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -185,7 +185,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
   Value *Dest = Builder.CreateBitCast(MI->getArgOperand(0), NewDstPtrTy);
   LoadInst *L = Builder.CreateLoad(IntType, Src);
   // Alignment from the mem intrinsic will be better, so use it.
-  L->setAlignment(CopySrcAlign);
+  L->setAlignment(
+      MaybeAlign(CopySrcAlign)); // FIXME: Check if we can use Align instead.
   if (CopyMD)
     L->setMetadata(LLVMContext::MD_tbaa, CopyMD);
   MDNode *LoopMemParallelMD =
@@ -198,7 +199,8 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
 
   StoreInst *S = Builder.CreateStore(L, Dest);
   // Alignment from the mem intrinsic will be better, so use it.
-  S->setAlignment(CopyDstAlign);
+  S->setAlignment(
+      MaybeAlign(CopyDstAlign)); // FIXME: Check if we can use Align instead.
   if (CopyMD)
     S->setMetadata(LLVMContext::MD_tbaa, CopyMD);
   if (LoopMemParallelMD)
@@ -223,9 +225,10 @@ Instruction *InstCombiner::SimplifyAnyMemTransfer(AnyMemTransferInst *MI) {
 }
 
 Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
-  unsigned Alignment = getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT);
-  if (MI->getDestAlignment() < Alignment) {
-    MI->setDestAlignment(Alignment);
+  const unsigned KnownAlignment =
+      getKnownAlignment(MI->getDest(), DL, MI, &AC, &DT);
+  if (MI->getDestAlignment() < KnownAlignment) {
+    MI->setDestAlignment(KnownAlignment);
     return MI;
   }
 
@@ -243,13 +246,9 @@ Instruction *InstCombiner::SimplifyAnyMemSet(AnyMemSetInst *MI) {
   ConstantInt *FillC = dyn_cast<ConstantInt>(MI->getValue());
   if (!LenC || !FillC || !FillC->getType()->isIntegerTy(8))
     return nullptr;
-  uint64_t Len = LenC->getLimitedValue();
-  Alignment = MI->getDestAlignment();
+  const uint64_t Len = LenC->getLimitedValue();
   assert(Len && "0-sized memory setting should be removed already.");
-
-  // Alignment 0 is identity for alignment 1 for memset, but not store.
-  if (Alignment == 0)
-    Alignment = 1;
+  const Align Alignment = assumeAligned(MI->getDestAlignment());
 
   // If it is an atomic and alignment is less than the size then we will
   // introduce the unaligned memory access which will be later transformed
@@ -1060,9 +1059,9 @@ Value *InstCombiner::simplifyMaskedLoad(IntrinsicInst &II) {
 
   // If we can unconditionally load from this address, replace with a
   // load/select idiom. TODO: use DT for context sensitive query
-  if (isDereferenceableAndAlignedPointer(LoadPtr, II.getType(), Alignment,
-                                         II.getModule()->getDataLayout(),
-                                         &II, nullptr)) {
+  if (isDereferenceableAndAlignedPointer(
+          LoadPtr, II.getType(), MaybeAlign(Alignment),
+          II.getModule()->getDataLayout(), &II, nullptr)) {
     Value *LI = Builder.CreateAlignedLoad(II.getType(), LoadPtr, Alignment,
                                           "unmaskedload");
     return Builder.CreateSelect(II.getArgOperand(2), LI, II.getArgOperand(3));
@@ -1086,7 +1085,8 @@ Instruction *InstCombiner::simplifyMaskedStore(IntrinsicInst &II) {
   // If the mask is all ones, this is a plain vector store of the 1st argument.
   if (ConstMask->isAllOnesValue()) {
     Value *StorePtr = II.getArgOperand(1);
-    unsigned Alignment = cast<ConstantInt>(II.getArgOperand(2))->getZExtValue();
+    MaybeAlign Alignment(
+        cast<ConstantInt>(II.getArgOperand(2))->getZExtValue());
     return new StoreInst(II.getArgOperand(0), StorePtr, false, Alignment);
   }
 
@@ -2234,6 +2234,15 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return replaceInstUsesWith(*II, Add);
     }
 
+    // Try to simplify the underlying FMul.
+    if (Value *V = SimplifyFMulInst(II->getArgOperand(0), II->getArgOperand(1),
+                                    II->getFastMathFlags(),
+                                    SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
+      FAdd->copyFastMathFlags(II);
+      return FAdd;
+    }
+
     LLVM_FALLTHROUGH;
   }
   case Intrinsic::fma: {
@@ -2258,9 +2267,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
       return II;
     }
 
-    // fma x, 1, z -> fadd x, z
-    if (match(Src1, m_FPOne())) {
-      auto *FAdd = BinaryOperator::CreateFAdd(Src0, II->getArgOperand(2));
+    // Try to simplify the underlying FMul. We can only apply simplifications
+    // that do not require rounding.
+    if (Value *V = SimplifyFMAFMul(II->getArgOperand(0), II->getArgOperand(1),
+                                   II->getFastMathFlags(),
+                                   SQ.getWithInstruction(II))) {
+      auto *FAdd = BinaryOperator::CreateFAdd(V, II->getArgOperand(2));
       FAdd->copyFastMathFlags(II);
       return FAdd;
     }
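; Editor's note with a sketch: the SimplifyFMAFMul call subsumes the removed
; special case `fma x, 1, z -> fadd x, z`, since x * 1.0 needs no rounding:
;   %r = call float @llvm.fmuladd.f32(float %x, float 1.0, float %z)
; becomes
;   %r = fadd float %x, %z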
@@ -2331,7 +2343,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // Turn PPC VSX loads into normal loads.
     Value *Ptr = Builder.CreateBitCast(II->getArgOperand(0),
                                        PointerType::getUnqual(II->getType()));
-    return new LoadInst(II->getType(), Ptr, Twine(""), false, 1);
+    return new LoadInst(II->getType(), Ptr, Twine(""), false, Align::None());
   }
   case Intrinsic::ppc_altivec_stvx:
   case Intrinsic::ppc_altivec_stvxl:
@@ -2349,7 +2361,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // Turn PPC VSX stores into normal stores.
     Type *OpPtrTy = PointerType::getUnqual(II->getArgOperand(0)->getType());
     Value *Ptr = Builder.CreateBitCast(II->getArgOperand(1), OpPtrTy);
-    return new StoreInst(II->getArgOperand(0), Ptr, false, 1);
+    return new StoreInst(II->getArgOperand(0), Ptr, false, Align::None());
   }
   case Intrinsic::ppc_qpx_qvlfs:
     // Turn PPC QPX qvlfs -> load if the pointer is known aligned.
@@ -3885,6 +3897,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
     // Asan needs to poison memory to detect invalid access which is possible
     // even for empty lifetime range.
     if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) ||
+        II->getFunction()->hasFnAttribute(Attribute::SanitizeMemory) ||
         II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress))
       break;
 
@@ -3950,10 +3963,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   case Intrinsic::experimental_gc_relocate: {
+    auto &GCR = *cast<GCRelocateInst>(II);
+
+    // If we have two copies of the same pointer in the statepoint argument
+    // list, canonicalize to one. This may let us common gc.relocates.
+    if (GCR.getBasePtr() == GCR.getDerivedPtr() &&
+        GCR.getBasePtrIndex() != GCR.getDerivedPtrIndex()) {
+      auto *OpIntTy = GCR.getOperand(2)->getType();
+      II->setOperand(2, ConstantInt::get(OpIntTy, GCR.getBasePtrIndex()));
+      return II;
+    }
+
     // Translate facts known about a pointer before relocating into
     // facts about the relocate value, while being careful to
     // preserve relocation semantics.
-    Value *DerivedPtr = cast<GCRelocateInst>(II)->getDerivedPtr();
+    Value *DerivedPtr = GCR.getDerivedPtr();
 
     // Remove the relocation if unused, note that this check is required
     // to prevent the cases below from looping forever.
@@ -4177,10 +4201,58 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) {
   return nullptr;
 }
 
+static void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) {
+  unsigned NumArgs = Call.getNumArgOperands();
+  ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0));
+  ConstantInt *Op1C =
+      (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1));
+  // Bail out if the allocation size is zero.
+  if ((Op0C && Op0C->isNullValue()) || (Op1C && Op1C->isNullValue()))
+    return;
+
+  if (isMallocLikeFn(&Call, TLI) && Op0C) {
+    if (isOpNewLikeFn(&Call, TLI))
+      Call.addAttribute(AttributeList::ReturnIndex,
+                        Attribute::getWithDereferenceableBytes(
+                            Call.getContext(), Op0C->getZExtValue()));
+    else
+      Call.addAttribute(AttributeList::ReturnIndex,
+                        Attribute::getWithDereferenceableOrNullBytes(
+                            Call.getContext(), Op0C->getZExtValue()));
+  } else if (isReallocLikeFn(&Call, TLI) && Op1C) {
+    Call.addAttribute(AttributeList::ReturnIndex,
+                      Attribute::getWithDereferenceableOrNullBytes(
+                          Call.getContext(), Op1C->getZExtValue()));
+  } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) {
+    bool Overflow;
+    const APInt &N = Op0C->getValue();
+    APInt Size = N.umul_ov(Op1C->getValue(), Overflow);
+    if (!Overflow)
+      Call.addAttribute(AttributeList::ReturnIndex,
+                        Attribute::getWithDereferenceableOrNullBytes(
+                            Call.getContext(), Size.getZExtValue()));
+  } else if (isStrdupLikeFn(&Call, TLI)) {
+    uint64_t Len = GetStringLength(Call.getOperand(0));
+    if (Len) {
+      // strdup
+      if (NumArgs == 1)
+        Call.addAttribute(AttributeList::ReturnIndex,
+                          Attribute::getWithDereferenceableOrNullBytes(
+                              Call.getContext(), Len));
+      // strndup
+      else if (NumArgs == 2 && Op1C)
+        Call.addAttribute(
+            AttributeList::ReturnIndex,
+            Attribute::getWithDereferenceableOrNullBytes(
+                Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1)));
+    }
+  }
+}
+
 /// Improvements for call, callbr and invoke instructions.
 Instruction *InstCombiner::visitCallBase(CallBase &Call) {
-  if (isAllocLikeFn(&Call, &TLI))
-    return visitAllocSite(Call);
+  if (isAllocationFn(&Call, &TLI))
+    annotateAnyAllocSite(Call, &TLI);
 
   bool Changed = false;
 
@@ -4312,6 +4384,9 @@ Instruction *InstCombiner::visitCallBase(CallBase &Call) {
       if (I) return eraseInstFromFunction(*I);
   }
 
+  if (isAllocLikeFn(&Call, &TLI))
+    return visitAllocSite(Call);
+
   return Changed ? &Call : nullptr;
 }
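; Editor's sketch of the new allocation-site annotation: an allocator call
; with a constant size is now annotated rather than visited directly, e.g.
;   %p = call i8* @malloc(i64 40)
; becomes
;   %p = call dereferenceable_or_null(40) i8* @malloc(i64 40)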
diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp
index 2c9ba203fbf3..65aaef28d87a 100644
--- a/lib/Transforms/InstCombine/InstCombineCasts.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp
@@ -140,7 +140,7 @@ Instruction *InstCombiner::PromoteCastOfAllocation(BitCastInst &CI,
   }
 
   AllocaInst *New = AllocaBuilder.CreateAlloca(CastElTy, Amt);
-  New->setAlignment(AI.getAlignment());
+  New->setAlignment(MaybeAlign(AI.getAlignment()));
   New->takeName(&AI);
   New->setUsedWithInAlloca(AI.isUsedWithInAlloca());
 
@@ -1531,16 +1531,16 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
   // what we can and cannot do safely varies from operation to operation, and
   // is explained below in the various case statements.
   Type *Ty = FPT.getType();
-  BinaryOperator *OpI = dyn_cast<BinaryOperator>(FPT.getOperand(0));
-  if (OpI && OpI->hasOneUse()) {
-    Type *LHSMinType = getMinimumFPType(OpI->getOperand(0));
-    Type *RHSMinType = getMinimumFPType(OpI->getOperand(1));
-    unsigned OpWidth = OpI->getType()->getFPMantissaWidth();
+  auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
+  if (BO && BO->hasOneUse()) {
+    Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
+    Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
+    unsigned OpWidth = BO->getType()->getFPMantissaWidth();
     unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
     unsigned RHSWidth = RHSMinType->getFPMantissaWidth();
     unsigned SrcWidth = std::max(LHSWidth, RHSWidth);
     unsigned DstWidth = Ty->getFPMantissaWidth();
-    switch (OpI->getOpcode()) {
+    switch (BO->getOpcode()) {
     default: break;
     case Instruction::FAdd:
     case Instruction::FSub:
@@ -1563,10 +1563,10 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
       // could be tightened for those cases, but they are rare (the main
      // case of interest here is (float)((double)float + float)).
       if (OpWidth >= 2*DstWidth+1 && DstWidth >= SrcWidth) {
-        Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
-        Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
-        Instruction *RI = BinaryOperator::Create(OpI->getOpcode(), LHS, RHS);
-        RI->copyFastMathFlags(OpI);
+        Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+        Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+        Instruction *RI = BinaryOperator::Create(BO->getOpcode(), LHS, RHS);
+        RI->copyFastMathFlags(BO);
         return RI;
       }
       break;
@@ -1577,9 +1577,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
     case Instruction::FMul:
       // For multiplication, the infinitely precise result has at most
      // LHSWidth + RHSWidth significant bits; if OpWidth is sufficient
      // that such a value can be exactly represented, then no double
      // rounding can possibly occur; we can safely perform the operation
      // in the destination format if it can represent both sources.
       if (OpWidth >= LHSWidth + RHSWidth && DstWidth >= SrcWidth) {
-        Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
-        Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
-        return BinaryOperator::CreateFMulFMF(LHS, RHS, OpI);
+        Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+        Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+        return BinaryOperator::CreateFMulFMF(LHS, RHS, BO);
       }
       break;
    case Instruction::FDiv:
@@ -1590,9 +1590,9 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
      // condition used here is a good conservative first pass.
      // TODO: Tighten bound via rigorous analysis of the unbalanced case.
       if (OpWidth >= 2*DstWidth && DstWidth >= SrcWidth) {
-        Value *LHS = Builder.CreateFPTrunc(OpI->getOperand(0), Ty);
-        Value *RHS = Builder.CreateFPTrunc(OpI->getOperand(1), Ty);
-        return BinaryOperator::CreateFDivFMF(LHS, RHS, OpI);
+        Value *LHS = Builder.CreateFPTrunc(BO->getOperand(0), Ty);
+        Value *RHS = Builder.CreateFPTrunc(BO->getOperand(1), Ty);
+        return BinaryOperator::CreateFDivFMF(LHS, RHS, BO);
       }
       break;
     case Instruction::FRem: {
@@ -1604,14 +1604,14 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) {
         break;
       Value *LHS, *RHS;
       if (LHSWidth == SrcWidth) {
-        LHS = Builder.CreateFPTrunc(OpI->getOperand(0), LHSMinType);
-        RHS = Builder.CreateFPTrunc(OpI->getOperand(1), LHSMinType);
+        LHS = Builder.CreateFPTrunc(BO->getOperand(0), LHSMinType);
+        RHS = Builder.CreateFPTrunc(BO->getOperand(1), LHSMinType);
       } else {
-        LHS = Builder.CreateFPTrunc(OpI->getOperand(0), RHSMinType);
-        RHS = Builder.CreateFPTrunc(OpI->getOperand(1), RHSMinType);
+        LHS = Builder.CreateFPTrunc(BO->getOperand(0), RHSMinType);
+        RHS = Builder.CreateFPTrunc(BO->getOperand(1), RHSMinType);
       }
 
-      Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, OpI);
+      Value *ExactResult = Builder.CreateFRemFMF(LHS, RHS, BO);
       return CastInst::CreateFPCast(ExactResult, Ty);
     }
   }
@@ -2338,8 +2338,23 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
     // If we found a path from the src to dest, create the getelementptr now.
     if (SrcElTy == DstElTy) {
       SmallVector<Value *, 8> Idxs(NumZeros + 1, Builder.getInt32(0));
-      return GetElementPtrInst::CreateInBounds(SrcPTy->getElementType(), Src,
-                                               Idxs);
+      GetElementPtrInst *GEP =
+          GetElementPtrInst::Create(SrcPTy->getElementType(), Src, Idxs);
+
+      // If the source pointer is dereferenceable, then assume it points to an
+      // allocated object and apply "inbounds" to the GEP.
+      bool CanBeNull;
+      if (Src->getPointerDereferenceableBytes(DL, CanBeNull)) {
+        // In a non-default address space (not 0), a null pointer can not be
+        // assumed inbounds, so ignore that case (dereferenceable_or_null).
+        // The reason is that 'null' is not treated differently in these address
+        // spaces, and we consequently ignore the 'gep inbounds' special case
+        // for 'null' which allows 'inbounds' on 'null' if the indices are
+        // zeros.
+        if (SrcPTy->getAddressSpace() == 0 || !CanBeNull)
+          GEP->setIsInBounds();
+      }
+      return GEP;
     }
   }
 
@@ -2391,28 +2406,47 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) {
     }
   }
 
-  if (ShuffleVectorInst *SVI = dyn_cast<ShuffleVectorInst>(Src)) {
+  if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) {
     // Okay, we have (bitcast (shuffle ..)).  Check to see if this is
     // a bitcast to a vector with the same # elts.
-    if (SVI->hasOneUse() && DestTy->isVectorTy() &&
-        DestTy->getVectorNumElements() == SVI->getType()->getNumElements() &&
-        SVI->getType()->getNumElements() ==
-            SVI->getOperand(0)->getType()->getVectorNumElements()) {
+    Value *ShufOp0 = Shuf->getOperand(0);
+    Value *ShufOp1 = Shuf->getOperand(1);
+    unsigned NumShufElts = Shuf->getType()->getVectorNumElements();
+    unsigned NumSrcVecElts = ShufOp0->getType()->getVectorNumElements();
+    if (Shuf->hasOneUse() && DestTy->isVectorTy() &&
+        DestTy->getVectorNumElements() == NumShufElts &&
+        NumShufElts == NumSrcVecElts) {
       BitCastInst *Tmp;
       // If either of the operands is a cast from CI.getType(), then
      // evaluating the shuffle in the casted destination's type will allow
      // us to eliminate at least one cast.
-      if (((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(0))) &&
+      if (((Tmp = dyn_cast<BitCastInst>(ShufOp0)) &&
            Tmp->getOperand(0)->getType() == DestTy) ||
-          ((Tmp = dyn_cast<BitCastInst>(SVI->getOperand(1))) &&
+          ((Tmp = dyn_cast<BitCastInst>(ShufOp1)) &&
            Tmp->getOperand(0)->getType() == DestTy)) {
-        Value *LHS = Builder.CreateBitCast(SVI->getOperand(0), DestTy);
-        Value *RHS = Builder.CreateBitCast(SVI->getOperand(1), DestTy);
+        Value *LHS = Builder.CreateBitCast(ShufOp0, DestTy);
+        Value *RHS = Builder.CreateBitCast(ShufOp1, DestTy);
         // Return a new shuffle vector.  Use the same element ID's, as we
         // know the vector types match #elts.
-        return new ShuffleVectorInst(LHS, RHS, SVI->getOperand(2));
+        return new ShuffleVectorInst(LHS, RHS, Shuf->getOperand(2));
       }
     }
+
+    // A bitcasted-to-scalar and byte-reversing shuffle is better recognized as
+    // a byte-swap:
+    // bitcast <N x i8> (shuf X, undef, <N, N-1,...0>) --> bswap (bitcast X)
+    // TODO: We should match the related pattern for bitreverse.
+    if (DestTy->isIntegerTy() &&
+        DL.isLegalInteger(DestTy->getScalarSizeInBits()) &&
+        SrcTy->getScalarSizeInBits() == 8 && NumShufElts % 2 == 0 &&
+        Shuf->hasOneUse() && Shuf->isReverse()) {
+      assert(ShufOp0->getType() == SrcTy && "Unexpected shuffle mask");
+      assert(isa<UndefValue>(ShufOp1) && "Unexpected shuffle op");
+      Function *Bswap =
+          Intrinsic::getDeclaration(CI.getModule(), Intrinsic::bswap, DestTy);
+      Value *ScalarX = Builder.CreateBitCast(ShufOp0, DestTy);
+      return IntrinsicInst::Create(Bswap, { ScalarX });
+    }
   }
 
   // Handle the A->B->A cast, and there is an intervening PHI node.
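; Editor's sketch of the byte-swap recognition added above:
;   %rev = shufflevector <4 x i8> %x, <4 x i8> undef,
;                        <4 x i32> <i32 3, i32 2, i32 1, i32 0>
;   %r   = bitcast <4 x i8> %rev to i32
; becomes
;   %xi = bitcast <4 x i8> %x to i32
;   %r  = call i32 @llvm.bswap.i32(i32 %xi)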
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 3a4283ae5406..a9f64feb600c 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -69,34 +69,6 @@ static bool hasBranchUse(ICmpInst &I) {
   return false;
 }
 
-/// Given an exploded icmp instruction, return true if the comparison only
-/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the
-/// result of the comparison is true when the input value is signed.
-static bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS,
-                           bool &TrueIfSigned) {
-  switch (Pred) {
-  case ICmpInst::ICMP_SLT:   // True if LHS s< 0
-    TrueIfSigned = true;
-    return RHS.isNullValue();
-  case ICmpInst::ICMP_SLE:   // True if LHS s<= RHS and RHS == -1
-    TrueIfSigned = true;
-    return RHS.isAllOnesValue();
-  case ICmpInst::ICMP_SGT:   // True if LHS s> -1
-    TrueIfSigned = false;
-    return RHS.isAllOnesValue();
-  case ICmpInst::ICMP_UGT:
-    // True if LHS u> RHS and RHS == high-bit-mask - 1
-    TrueIfSigned = true;
-    return RHS.isMaxSignedValue();
-  case ICmpInst::ICMP_UGE:
-    // True if LHS u>= RHS and RHS == high-bit-mask (2^7, 2^15, 2^31, etc)
-    TrueIfSigned = true;
-    return RHS.isSignMask();
-  default:
-    return false;
-  }
-}
-
 /// Returns true if the exploded icmp can be expressed as a signed comparison
 /// to zero and updates the predicate accordingly.
 /// The signedness of the comparison is preserved.
@@ -832,6 +804,10 @@ getAsConstantIndexedAddress(Value *V, const DataLayout &DL) {
 static Instruction *transformToIndexedCompare(GEPOperator *GEPLHS, Value *RHS,
                                               ICmpInst::Predicate Cond,
                                               const DataLayout &DL) {
+  // FIXME: Support vector of pointers.
+  if (GEPLHS->getType()->isVectorTy())
+    return nullptr;
+
   if (!GEPLHS->hasAllConstantIndices())
     return nullptr;
 
@@ -882,7 +858,9 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     RHS = RHS->stripPointerCasts();
 
   Value *PtrBase = GEPLHS->getOperand(0);
-  if (PtrBase == RHS && GEPLHS->isInBounds()) {
+  // FIXME: Support vector pointer GEPs.
+  if (PtrBase == RHS && GEPLHS->isInBounds() &&
+      !GEPLHS->getType()->isVectorTy()) {
     // ((gep Ptr, OFFSET) cmp Ptr)   ---> (OFFSET cmp 0).
     // This transformation (ignoring the base and scales) is valid because we
     // know pointers can't overflow since the gep is inbounds.  See if we can
@@ -894,6 +872,37 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     Offset = EmitGEPOffset(GEPLHS);
     return new ICmpInst(ICmpInst::getSignedPredicate(Cond), Offset,
                         Constant::getNullValue(Offset->getType()));
+  }
+
+  if (GEPLHS->isInBounds() && ICmpInst::isEquality(Cond) &&
+      isa<Constant>(RHS) && cast<Constant>(RHS)->isNullValue() &&
+      !NullPointerIsDefined(I.getFunction(),
+                            RHS->getType()->getPointerAddressSpace())) {
+    // For most address spaces, an allocation can't be placed at null, but null
+    // itself is treated as a 0 size allocation in the in bounds rules.  Thus,
+    // the only valid inbounds address derived from null, is null itself.
+    // Thus, we have four cases to consider:
+    // 1) Base == nullptr, Offset == 0 -> inbounds, null
+    // 2) Base == nullptr, Offset != 0 -> poison as the result is out of bounds
+    // 3) Base != nullptr, Offset == (-base) -> poison (crossing allocations)
+    // 4) Base != nullptr, Offset != (-base) -> nonnull (and possibly poison)
+    //
+    // (Note if we're indexing a type of size 0, that simply collapses into one
+    //  of the buckets above.)
+    //
+    // In general, we're allowed to make values less poison (i.e. remove
+    // sources of full UB), so in this case, we just select between the two
+    // non-poison cases (1 and 4 above).
+    //
+    // For vectors, we apply the same reasoning on a per-lane basis.
+    auto *Base = GEPLHS->getPointerOperand();
+    if (GEPLHS->getType()->isVectorTy() && Base->getType()->isPointerTy()) {
+      int NumElts = GEPLHS->getType()->getVectorNumElements();
+      Base = Builder.CreateVectorSplat(NumElts, Base);
+    }
+    return new ICmpInst(Cond, Base,
+                        ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+                            cast<Constant>(RHS), Base->getType()));
   } else if (GEPOperator *GEPRHS = dyn_cast<GEPOperator>(RHS)) {
     // If the base pointers are different, but the indices are the same, just
     // compare the base pointer.
@@ -916,11 +925,13 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     // If we're comparing GEPs with two base pointers that only differ in type
     // and both GEPs have only constant indices or just one use, then fold
     // the compare with the adjusted indices.
+    // FIXME: Support vector of pointers.
     if (GEPLHS->isInBounds() && GEPRHS->isInBounds() &&
         (GEPLHS->hasAllConstantIndices() || GEPLHS->hasOneUse()) &&
         (GEPRHS->hasAllConstantIndices() || GEPRHS->hasOneUse()) &&
         PtrBase->stripPointerCasts() ==
-            GEPRHS->getOperand(0)->stripPointerCasts()) {
+            GEPRHS->getOperand(0)->stripPointerCasts() &&
+        !GEPLHS->getType()->isVectorTy()) {
       Value *LOffset = EmitGEPOffset(GEPLHS);
       Value *ROffset = EmitGEPOffset(GEPRHS);
 
@@ -949,12 +960,14 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS,
     }
 
    // If one of the GEPs has all zero indices, recurse.
-    if (GEPLHS->hasAllZeroIndices())
+    // FIXME: Handle vector of pointers.
+ if (!GEPLHS->getType()->isVectorTy() && GEPLHS->hasAllZeroIndices()) return foldGEPICmp(GEPRHS, GEPLHS->getOperand(0), ICmpInst::getSwappedPredicate(Cond), I); // If the other GEP has all zero indices, recurse. - if (GEPRHS->hasAllZeroIndices()) + // FIXME: Handle vector of pointers. + if (!GEPRHS->getType()->isVectorTy() && GEPRHS->hasAllZeroIndices()) return foldGEPICmp(GEPLHS, GEPRHS->getOperand(0), Cond, I); bool GEPsInBounds = GEPLHS->isInBounds() && GEPRHS->isInBounds(); @@ -964,15 +977,20 @@ Instruction *InstCombiner::foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, unsigned DiffOperand = 0; // The operand that differs. for (unsigned i = 1, e = GEPRHS->getNumOperands(); i != e; ++i) if (GEPLHS->getOperand(i) != GEPRHS->getOperand(i)) { - if (GEPLHS->getOperand(i)->getType()->getPrimitiveSizeInBits() != - GEPRHS->getOperand(i)->getType()->getPrimitiveSizeInBits()) { + Type *LHSType = GEPLHS->getOperand(i)->getType(); + Type *RHSType = GEPRHS->getOperand(i)->getType(); + // FIXME: Better support for vector of pointers. + if (LHSType->getPrimitiveSizeInBits() != + RHSType->getPrimitiveSizeInBits() || + (GEPLHS->getType()->isVectorTy() && + (!LHSType->isVectorTy() || !RHSType->isVectorTy()))) { // Irreconcilable differences. NumDifferences = 2; break; - } else { - if (NumDifferences++) break; - DiffOperand = i; } + + if (NumDifferences++) break; + DiffOperand = i; } if (NumDifferences == 0) // SAME GEP? @@ -1317,6 +1335,59 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } +/// If we have: +/// icmp eq/ne (urem/srem %x, %y), 0 +/// iff %y is a power-of-two, we can replace this with a bit test: +/// icmp eq/ne (and %x, (add %y, -1)), 0 +Instruction *InstCombiner::foldIRemByPowerOfTwoToBitTest(ICmpInst &I) { + // This fold is only valid for equality predicates. + if (!I.isEquality()) + return nullptr; + ICmpInst::Predicate Pred; + Value *X, *Y, *Zero; + if (!match(&I, m_ICmp(Pred, m_OneUse(m_IRem(m_Value(X), m_Value(Y))), + m_CombineAnd(m_Zero(), m_Value(Zero))))) + return nullptr; + if (!isKnownToBeAPowerOfTwo(Y, /*OrZero*/ true, 0, &I)) + return nullptr; + // This may increase instruction count, we don't enforce that Y is a constant. + Value *Mask = Builder.CreateAdd(Y, Constant::getAllOnesValue(Y->getType())); + Value *Masked = Builder.CreateAnd(X, Mask); + return ICmpInst::Create(Instruction::ICmp, Pred, Masked, Zero); +} + +/// Fold equality-comparison between zero and any (maybe truncated) right-shift +/// by one-less-than-bitwidth into a sign test on the original value. +Instruction *InstCombiner::foldSignBitTest(ICmpInst &I) { + Instruction *Val; + ICmpInst::Predicate Pred; + if (!I.isEquality() || !match(&I, m_ICmp(Pred, m_Instruction(Val), m_Zero()))) + return nullptr; + + Value *X; + Type *XTy; + + Constant *C; + if (match(Val, m_TruncOrSelf(m_Shr(m_Value(X), m_Constant(C))))) { + XTy = X->getType(); + unsigned XBitWidth = XTy->getScalarSizeInBits(); + if (!match(C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ, + APInt(XBitWidth, XBitWidth - 1)))) + return nullptr; + } else if (isa<BinaryOperator>(Val) && + (X = reassociateShiftAmtsOfTwoSameDirectionShifts( + cast<BinaryOperator>(Val), SQ.getWithInstruction(Val), + /*AnalyzeForSignBitExtraction=*/true))) { + XTy = X->getType(); + } else + return nullptr; + + return ICmpInst::Create(Instruction::ICmp, + Pred == ICmpInst::ICMP_EQ ? 
ICmpInst::ICMP_SGE + : ICmpInst::ICMP_SLT, + X, ConstantInt::getNullValue(XTy)); +} + // Handle icmp pred X, 0 Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); @@ -1335,6 +1406,9 @@ Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { } } + if (Instruction *New = foldIRemByPowerOfTwoToBitTest(Cmp)) + return New; + // Given: // icmp eq/ne (urem %x, %y), 0 // Iff %x has 0 or 1 bits set, and %y has at least 2 bits set, omit 'urem': @@ -2179,6 +2253,44 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, return nullptr; } +Instruction *InstCombiner::foldICmpSRemConstant(ICmpInst &Cmp, + BinaryOperator *SRem, + const APInt &C) { + // Match an 'is positive' or 'is negative' comparison of remainder by a + // constant power-of-2 value: + // (X % pow2C) sgt/slt 0 + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + if (Pred != ICmpInst::ICMP_SGT && Pred != ICmpInst::ICMP_SLT) + return nullptr; + + // TODO: The one-use check is standard because we do not typically want to + // create longer instruction sequences, but this might be a special-case + // because srem is not good for analysis or codegen. + if (!SRem->hasOneUse()) + return nullptr; + + const APInt *DivisorC; + if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + return nullptr; + + // Mask off the sign bit and the modulo bits (low-bits). + Type *Ty = SRem->getType(); + APInt SignMask = APInt::getSignMask(Ty->getScalarSizeInBits()); + Constant *MaskC = ConstantInt::get(Ty, SignMask | (*DivisorC - 1)); + Value *And = Builder.CreateAnd(SRem->getOperand(0), MaskC); + + // For 'is positive?' check that the sign-bit is clear and at least 1 masked + // bit is set. Example: + // (i8 X % 32) s> 0 --> (X & 159) s> 0 + if (Pred == ICmpInst::ICMP_SGT) + return new ICmpInst(ICmpInst::ICMP_SGT, And, ConstantInt::getNullValue(Ty)); + + // For 'is negative?' check that the sign-bit is set and at least 1 masked + // bit is set. Example: + // (i16 X % 4) s< 0 --> (X & 32771) u> 32768 + return new ICmpInst(ICmpInst::ICMP_UGT, And, ConstantInt::get(Ty, SignMask)); +} + /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, @@ -2387,6 +2499,11 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, const APInt *C2; APInt SubResult; + // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0 + if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality()) + return new ICmpInst(Cmp.getPredicate(), Y, + ConstantInt::get(Y->getType(), 0)); + // (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C) if (match(X, m_APInt(C2)) && ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) || @@ -2509,20 +2626,49 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, // TODO: Generalize this to work with other comparison idioms or ensure // they get canonicalized into this form. - // select i1 (a == b), i32 Equal, i32 (select i1 (a < b), i32 Less, i32 - // Greater), where Equal, Less and Greater are placeholders for any three - // constants. 
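For illustration, the two new rem folds above as hand-written IR (names invented):
  define i1 @urem_pow2_bittest(i32 %x, i32 %n) {
    %y = shl i32 1, %n                  ; a power of two (or zero)
    %r = urem i32 %x, %y
    %c = icmp eq i32 %r, 0
    ; becomes: %m = add i32 %y, -1 ; %a = and i32 %x, %m ; icmp eq i32 %a, 0
    ret i1 %c
  }
  define i1 @srem_is_negative(i16 %x) {
    %r = srem i16 %x, 4
    %c = icmp slt i16 %r, 0
    ; becomes: %a = and i16 %x, -32765      ; 0x8003 = sign bit | (4 - 1)
    ;          %c = icmp ugt i16 %a, -32768 ; 0x8000 = the sign bit alone
    ret i1 %c
  }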
- ICmpInst::Predicate PredA, PredB; - if (match(SI->getTrueValue(), m_ConstantInt(Equal)) && - match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) && - PredA == ICmpInst::ICMP_EQ && - match(SI->getFalseValue(), - m_Select(m_ICmp(PredB, m_Specific(LHS), m_Specific(RHS)), - m_ConstantInt(Less), m_ConstantInt(Greater))) && - PredB == ICmpInst::ICMP_SLT) { - return true; + // select i1 (a == b), + // i32 Equal, + // i32 (select i1 (a < b), i32 Less, i32 Greater) + // where Equal, Less and Greater are placeholders for any three constants. + ICmpInst::Predicate PredA; + if (!match(SI->getCondition(), m_ICmp(PredA, m_Value(LHS), m_Value(RHS))) || + !ICmpInst::isEquality(PredA)) + return false; + Value *EqualVal = SI->getTrueValue(); + Value *UnequalVal = SI->getFalseValue(); + // We still can get non-canonical predicate here, so canonicalize. + if (PredA == ICmpInst::ICMP_NE) + std::swap(EqualVal, UnequalVal); + if (!match(EqualVal, m_ConstantInt(Equal))) + return false; + ICmpInst::Predicate PredB; + Value *LHS2, *RHS2; + if (!match(UnequalVal, m_Select(m_ICmp(PredB, m_Value(LHS2), m_Value(RHS2)), + m_ConstantInt(Less), m_ConstantInt(Greater)))) + return false; + // We can get predicate mismatch here, so canonicalize if possible: + // First, ensure that 'LHS' match. + if (LHS2 != LHS) { + // x sgt y <--> y slt x + std::swap(LHS2, RHS2); + PredB = ICmpInst::getSwappedPredicate(PredB); + } + if (LHS2 != LHS) + return false; + // We also need to canonicalize 'RHS'. + if (PredB == ICmpInst::ICMP_SGT && isa<Constant>(RHS2)) { + // x sgt C-1 <--> x sge C <--> not(x slt C) + auto FlippedStrictness = + getFlippedStrictnessPredicateAndConstant(PredB, cast<Constant>(RHS2)); + if (!FlippedStrictness) + return false; + assert(FlippedStrictness->first == ICmpInst::ICMP_SGE && "Sanity check"); + RHS2 = FlippedStrictness->second; + // And kind-of perform the result swap. + std::swap(Less, Greater); + PredB = ICmpInst::ICMP_SLT; } - return false; + return PredB == ICmpInst::ICMP_SLT && RHS == RHS2; } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, @@ -2702,6 +2848,10 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; + case Instruction::SRem: + if (Instruction *I = foldICmpSRemConstant(Cmp, BO, *C)) + return I; + break; case Instruction::UDiv: if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; @@ -2926,6 +3076,28 @@ Instruction *InstCombiner::foldICmpEqIntrinsicWithConstant(ICmpInst &Cmp, } break; } + + case Intrinsic::uadd_sat: { + // uadd.sat(a, b) == 0 -> (a | b) == 0 + if (C.isNullValue()) { + Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); + return replaceInstUsesWith(Cmp, Builder.CreateICmp( + Cmp.getPredicate(), Or, Constant::getNullValue(Ty))); + + } + break; + } + + case Intrinsic::usub_sat: { + // usub.sat(a, b) == 0 -> a <= b + if (C.isNullValue()) { + ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ + ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + return ICmpInst::Create(Instruction::ICmp, NewPred, + II->getArgOperand(0), II->getArgOperand(1)); + } + break; + } default: break; } @@ -3275,6 +3447,7 @@ foldICmpWithTruncSignExtendedVal(ICmpInst &I, // we should move shifts to the same hand of 'and', i.e. rewrite as // icmp eq/ne (and (x shift (Q+K)), y), 0 iff (Q+K) u< bitwidth(x) // We are only interested in opposite logical shifts here. +// One of the shifts can be truncated. 
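The saturating-intrinsic equality folds above, sketched as hand-written IR (names invented):
  declare i32 @llvm.uadd.sat.i32(i32, i32)
  declare i32 @llvm.usub.sat.i32(i32, i32)
  define i1 @uaddsat_eq_zero(i32 %a, i32 %b) {
    %s = call i32 @llvm.uadd.sat.i32(i32 %a, i32 %b)
    %c = icmp eq i32 %s, 0    ; becomes: %o = or i32 %a, %b ; icmp eq i32 %o, 0
    ret i1 %c
  }
  define i1 @usubsat_eq_zero(i32 %a, i32 %b) {
    %s = call i32 @llvm.usub.sat.i32(i32 %a, i32 %b)
    %c = icmp eq i32 %s, 0    ; becomes: %c = icmp ule i32 %a, %b
    ret i1 %c
  }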
// If we can, we want to end up creating 'lshr' shift. static Value * foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, @@ -3284,55 +3457,215 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, return nullptr; auto m_AnyLogicalShift = m_LogicalShift(m_Value(), m_Value()); - auto m_AnyLShr = m_LShr(m_Value(), m_Value()); - - // Look for an 'and' of two (opposite) logical shifts. - // Pick the single-use shift as XShift. - Value *XShift, *YShift; - if (!match(I.getOperand(0), - m_c_And(m_OneUse(m_CombineAnd(m_AnyLogicalShift, m_Value(XShift))), - m_CombineAnd(m_AnyLogicalShift, m_Value(YShift))))) + + // Look for an 'and' of two logical shifts, one of which may be truncated. + // We use m_TruncOrSelf() on the RHS to correctly handle commutative case. + Instruction *XShift, *MaybeTruncation, *YShift; + if (!match( + I.getOperand(0), + m_c_And(m_CombineAnd(m_AnyLogicalShift, m_Instruction(XShift)), + m_CombineAnd(m_TruncOrSelf(m_CombineAnd( + m_AnyLogicalShift, m_Instruction(YShift))), + m_Instruction(MaybeTruncation))))) return nullptr; - // If YShift is a single-use 'lshr', swap the shifts around. - if (match(YShift, m_OneUse(m_AnyLShr))) + // We potentially looked past 'trunc', but only when matching YShift, + // therefore YShift must have the widest type. + Instruction *WidestShift = YShift; + // Therefore XShift must have the shallowest type. + // Or they both have identical types if there was no truncation. + Instruction *NarrowestShift = XShift; + + Type *WidestTy = WidestShift->getType(); + assert(NarrowestShift->getType() == I.getOperand(0)->getType() && + "We did not look past any shifts while matching XShift though."); + bool HadTrunc = WidestTy != I.getOperand(0)->getType(); + + // If YShift is a 'lshr', swap the shifts around. + if (match(YShift, m_LShr(m_Value(), m_Value()))) std::swap(XShift, YShift); // The shifts must be in opposite directions. - Instruction::BinaryOps XShiftOpcode = - cast<BinaryOperator>(XShift)->getOpcode(); - if (XShiftOpcode == cast<BinaryOperator>(YShift)->getOpcode()) + auto XShiftOpcode = XShift->getOpcode(); + if (XShiftOpcode == YShift->getOpcode()) return nullptr; // Do not care about same-direction shifts here. Value *X, *XShAmt, *Y, *YShAmt; - match(XShift, m_BinOp(m_Value(X), m_Value(XShAmt))); - match(YShift, m_BinOp(m_Value(Y), m_Value(YShAmt))); + match(XShift, m_BinOp(m_Value(X), m_ZExtOrSelf(m_Value(XShAmt)))); + match(YShift, m_BinOp(m_Value(Y), m_ZExtOrSelf(m_Value(YShAmt)))); + + // If one of the values being shifted is a constant, then we will end with + // and+icmp, and [zext+]shift instrs will be constant-folded. If they are not, + // however, we will need to ensure that we won't increase instruction count. + if (!isa<Constant>(X) && !isa<Constant>(Y)) { + // At least one of the hands of the 'and' should be one-use shift. + if (!match(I.getOperand(0), + m_c_And(m_OneUse(m_AnyLogicalShift), m_Value()))) + return nullptr; + if (HadTrunc) { + // Due to the 'trunc', we will need to widen X. For that either the old + // 'trunc' or the shift amt in the non-truncated shift should be one-use. + if (!MaybeTruncation->hasOneUse() && + !NarrowestShift->getOperand(1)->hasOneUse()) + return nullptr; + } + } + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now. + if (XShAmt->getType() != YShAmt->getType()) + return nullptr; // Can we fold (XShAmt+YShAmt) ? 
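Concretely, the rewrite this function aims for looks like this (hand-written IR, names invented; shift amounts chosen so the summed amount 11 stays below the bit width):
  define i1 @opposite_shifts(i32 %x, i32 %y) {
    %xs = shl i32 %x, 1
    %ys = lshr i32 %y, 10
    %a = and i32 %xs, %ys
    %c = icmp ne i32 %a, 0
    ; becomes: %t0 = lshr i32 %y, 11 ; %t1 = and i32 %t0, %x
    ;          %c = icmp ne i32 %t1, 0
    ret i1 %c
  }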
- Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, XShAmt, YShAmt, - SQ.getWithInstruction(&I)); + auto *NewShAmt = dyn_cast_or_null<Constant>( + SimplifyAddInst(XShAmt, YShAmt, /*isNSW=*/false, + /*isNUW=*/false, SQ.getWithInstruction(&I))); if (!NewShAmt) return nullptr; + NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, WidestTy); + unsigned WidestBitWidth = WidestTy->getScalarSizeInBits(); + // Is the new shift amount smaller than the bit width? // FIXME: could also rely on ConstantRange. - unsigned BitWidth = X->getType()->getScalarSizeInBits(); - if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, - APInt(BitWidth, BitWidth)))) + if (!match(NewShAmt, + m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT, + APInt(WidestBitWidth, WidestBitWidth)))) return nullptr; - // All good, we can do this fold. The shift is the same that was for X. + + // An extra legality check is needed if we had trunc-of-lshr. + if (HadTrunc && match(WidestShift, m_LShr(m_Value(), m_Value()))) { + auto CanFold = [NewShAmt, WidestBitWidth, NarrowestShift, SQ, + WidestShift]() { + // It isn't obvious whether it's worth it to analyze non-constants here. + // Also, let's basically give up on non-splat cases, pessimizing vectors. + // If *any* of these preconditions matches we can perform the fold. + Constant *NewShAmtSplat = NewShAmt->getType()->isVectorTy() + ? NewShAmt->getSplatValue() + : NewShAmt; + // If it's edge-case shift (by 0 or by WidestBitWidth-1) we can fold. + if (NewShAmtSplat && + (NewShAmtSplat->isNullValue() || + NewShAmtSplat->getUniqueInteger() == WidestBitWidth - 1)) + return true; + // We consider *min* leading zeros so a single outlier + // blocks the transform as opposed to allowing it. + if (auto *C = dyn_cast<Constant>(NarrowestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. + unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: NewShAmt u<= countLeadingZeros(C) + if (NewShAmtSplat && NewShAmtSplat->getUniqueInteger().ule(MinLeadZero)) + return true; + } + if (auto *C = dyn_cast<Constant>(WidestShift->getOperand(0))) { + KnownBits Known = computeKnownBits(C, SQ.DL); + unsigned MinLeadZero = Known.countMinLeadingZeros(); + // If the value being shifted has at most lowest bit set we can fold. + unsigned MaxActiveBits = Known.getBitWidth() - MinLeadZero; + if (MaxActiveBits <= 1) + return true; + // Precondition: ((WidestBitWidth-1)-NewShAmt) u<= countLeadingZeros(C) + if (NewShAmtSplat) { + APInt AdjNewShAmt = + (WidestBitWidth - 1) - NewShAmtSplat->getUniqueInteger(); + if (AdjNewShAmt.ule(MinLeadZero)) + return true; + } + } + return false; // Can't tell if it's ok. + }; + if (!CanFold()) + return nullptr; + } + + // All good, we can do this fold. + X = Builder.CreateZExt(X, WidestTy); + Y = Builder.CreateZExt(Y, WidestTy); + // The shift is the same that was for X. Value *T0 = XShiftOpcode == Instruction::BinaryOps::LShr ? 
Builder.CreateLShr(X, NewShAmt) : Builder.CreateShl(X, NewShAmt); Value *T1 = Builder.CreateAnd(T0, Y); return Builder.CreateICmp(I.getPredicate(), T1, - Constant::getNullValue(X->getType())); + Constant::getNullValue(WidestTy)); +} + +/// Fold +/// (-1 u/ x) u< y +/// ((x * y) u/ x) != y +/// to +/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// Note that the comparison is commutative, while inverted (u>=, ==) predicate +/// will mean that we are looking for the opposite answer. +Value *InstCombiner::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { + ICmpInst::Predicate Pred; + Value *X, *Y; + Instruction *Mul; + bool NeedNegation; + // Look for: (-1 u/ x) u</u>= y + if (!I.isEquality() && + match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Value(Y)))) { + Mul = nullptr; + // Canonicalize as-if y was on RHS. + if (I.getOperand(1) != Y) + Pred = I.getSwappedPredicate(); + + // Are we checking that overflow does not happen, or does happen? + switch (Pred) { + case ICmpInst::Predicate::ICMP_ULT: + NeedNegation = false; + break; // OK + case ICmpInst::Predicate::ICMP_UGE: + NeedNegation = true; + break; // OK + default: + return nullptr; // Wrong predicate. + } + } else // Look for: ((x * y) u/ x) !=/== y + if (I.isEquality() && + match(&I, m_c_ICmp(Pred, m_Value(Y), + m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + m_Value(X)), + m_Instruction(Mul)), + m_Deferred(X)))))) { + NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; + } else + return nullptr; + + BuilderTy::InsertPointGuard Guard(Builder); + // If the pattern included (x * y), we'll want to insert new instructions + // right before that original multiplication so that we can replace it. + bool MulHadOtherUses = Mul && !Mul->hasOneUse(); + if (MulHadOtherUses) + Builder.SetInsertPoint(Mul); + + Function *F = Intrinsic::getDeclaration( + I.getModule(), Intrinsic::umul_with_overflow, X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + + // If the multiplication was used elsewhere, to ensure that we don't leave + // "duplicate" instructions, replace uses of that original multiplication + // with the multiplication result from the with.overflow intrinsic. + if (MulHadOtherUses) + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + + Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + if (NeedNegation) // This technically increases instruction count. + Res = Builder.CreateNot(Res, "umul.not.ov"); + + return Res; } /// Try to fold icmp (binop), X or icmp X, (binop). /// TODO: A large part of this logic is duplicated in InstSimplify's /// simplifyICmpWithBinOp(). We should be able to share that and avoid the code /// duplication. -Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { +Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I, const SimplifyQuery &SQ) { + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); // Special logic for binary operators. @@ -3345,13 +3678,13 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { Value *X; // Convert add-with-unsigned-overflow comparisons into a 'not' with compare. 
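The (-1 u/ x) u< y form of the overflow check above, as hand-written IR (names invented):
  declare { i32, i1 } @llvm.umul.with.overflow.i32(i32, i32)
  define i1 @mul_overflows(i32 %x, i32 %y) {
    %inv = udiv i32 -1, %x
    %c = icmp ult i32 %inv, %y
    ; becomes: %m = call { i32, i1 } @llvm.umul.with.overflow.i32(i32 %x, i32 %y)
    ;          %c = extractvalue { i32, i1 } %m, 1
    ret i1 %c
  }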
- // (Op1 + X) <u Op1 --> ~Op1 <u X - // Op0 >u (Op0 + X) --> X >u ~Op0 + // (Op1 + X) u</u>= Op1 --> ~Op1 u</u>= X if (match(Op0, m_OneUse(m_c_Add(m_Specific(Op1), m_Value(X)))) && - Pred == ICmpInst::ICMP_ULT) + (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE)) return new ICmpInst(Pred, Builder.CreateNot(Op1), X); + // Op0 u>/u<= (Op0 + X) --> X u>/u<= ~Op0 if (match(Op1, m_OneUse(m_c_Add(m_Specific(Op0), m_Value(X)))) && - Pred == ICmpInst::ICMP_UGT) + (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE)) return new ICmpInst(Pred, X, Builder.CreateNot(Op0)); bool NoOp0WrapProblem = false, NoOp1WrapProblem = false; @@ -3378,21 +3711,21 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) { D = BO1->getOperand(1); } - // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow. + // icmp (A+B), A -> icmp B, 0 for equalities or if there is no overflow. + // icmp (A+B), B -> icmp A, 0 for equalities or if there is no overflow. if ((A == Op1 || B == Op1) && NoOp0WrapProblem) return new ICmpInst(Pred, A == Op1 ? B : A, Constant::getNullValue(Op1->getType())); - // icmp X, (X+Y) -> icmp 0, Y for equalities or if there is no overflow. + // icmp C, (C+D) -> icmp 0, D for equalities or if there is no overflow. + // icmp D, (C+D) -> icmp 0, C for equalities or if there is no overflow. if ((C == Op0 || D == Op0) && NoOp1WrapProblem) return new ICmpInst(Pred, Constant::getNullValue(Op0->getType()), C == Op0 ? D : C); - // icmp (X+Y), (X+Z) -> icmp Y, Z for equalities or if there is no overflow. + // icmp (A+B), (A+D) -> icmp B, D for equalities or if there is no overflow. if (A && C && (A == C || A == D || B == C || B == D) && NoOp0WrapProblem && - NoOp1WrapProblem && - // Try not to increase register pressure. - BO0->hasOneUse() && BO1->hasOneUse()) { + NoOp1WrapProblem) { // Determine Y and Z in the form icmp (X+Y), (X+Z). 
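The widened add-with-unsigned-overflow fold above, as hand-written IR (names invented):
  define i1 @add_overflows(i32 %a, i32 %x) {
    %s = add i32 %a, %x
    %c = icmp ult i32 %s, %a
    ; becomes: %n = xor i32 %a, -1 ; %c = icmp ult i32 %n, %x
    ret i1 %c
  }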
Value *Y, *Z;
 if (A == C) {
@@ -3416,39 +3749,39 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
 return new ICmpInst(Pred, Y, Z);
 }
- // icmp slt (X + -1), Y -> icmp sle X, Y
+ // icmp slt (A + -1), Op1 -> icmp sle A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLT &&
 match(B, m_AllOnes()))
 return new ICmpInst(CmpInst::ICMP_SLE, A, Op1);
- // icmp sge (X + -1), Y -> icmp sgt X, Y
+ // icmp sge (A + -1), Op1 -> icmp sgt A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGE &&
 match(B, m_AllOnes()))
 return new ICmpInst(CmpInst::ICMP_SGT, A, Op1);
- // icmp sle (X + 1), Y -> icmp slt X, Y
+ // icmp sle (A + 1), Op1 -> icmp slt A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SLE && match(B, m_One()))
 return new ICmpInst(CmpInst::ICMP_SLT, A, Op1);
- // icmp sgt (X + 1), Y -> icmp sge X, Y
+ // icmp sgt (A + 1), Op1 -> icmp sge A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_SGT && match(B, m_One()))
 return new ICmpInst(CmpInst::ICMP_SGE, A, Op1);
- // icmp sgt X, (Y + -1) -> icmp sge X, Y
+ // icmp sgt Op0, (C + -1) -> icmp sge Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGT &&
 match(D, m_AllOnes()))
 return new ICmpInst(CmpInst::ICMP_SGE, Op0, C);
- // icmp sle X, (Y + -1) -> icmp slt X, Y
+ // icmp sle Op0, (C + -1) -> icmp slt Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLE &&
 match(D, m_AllOnes()))
 return new ICmpInst(CmpInst::ICMP_SLT, Op0, C);
- // icmp sge X, (Y + 1) -> icmp sgt X, Y
+ // icmp sge Op0, (C + 1) -> icmp sgt Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SGE && match(D, m_One()))
 return new ICmpInst(CmpInst::ICMP_SGT, Op0, C);
- // icmp slt X, (Y + 1) -> icmp sle X, Y
+ // icmp slt Op0, (C + 1) -> icmp sle Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
 return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
@@ -3456,33 +3789,33 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
 // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
 // wouldn't happen even if they were implemented.
 //
- // icmp ult (X - 1), Y -> icmp ule X, Y
- // icmp uge (X - 1), Y -> icmp ugt X, Y
- // icmp ugt X, (Y - 1) -> icmp uge X, Y
- // icmp ule X, (Y - 1) -> icmp ult X, Y
+ // icmp ult (A - 1), Op1 -> icmp ule A, Op1
+ // icmp uge (A - 1), Op1 -> icmp ugt A, Op1
+ // icmp ugt Op0, (C - 1) -> icmp uge Op0, C
+ // icmp ule Op0, (C - 1) -> icmp ult Op0, C
- // icmp ule (X + 1), Y -> icmp ult X, Y
+ // icmp ule (A + 1), Op1 -> icmp ult A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
 return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
- // icmp ugt (X + 1), Y -> icmp uge X, Y
+ // icmp ugt (A + 1), Op1 -> icmp uge A, Op1
 if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
 return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
- // icmp uge X, (Y + 1) -> icmp ugt X, Y
+ // icmp uge Op0, (C + 1) -> icmp ugt Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
 return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
- // icmp ult X, (Y + 1) -> icmp ule X, Y
+ // icmp ult Op0, (C + 1) -> icmp ule Op0, C
 if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
 return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
 // if C1 has greater magnitude than C2:
- // icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
+ // icmp (A + C1), (C + C2) -> icmp (A + C3), C
 // s.t. C3 = C1 - C2
 //
 // if C2 has greater magnitude than C1:
- // icmp (X + C1), (Y + C2) -> icmp X, (Y + C3)
+ // icmp (A + C1), (C + C2) -> icmp A, (C + C3)
 // s.t. C3 = C2 - C1
 if (A && C && NoOp0WrapProblem && NoOp1WrapProblem &&
 (BO0->hasOneUse() || BO1->hasOneUse()) && !I.isUnsigned())
@@ -3520,29 +3853,35 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
 D = BO1->getOperand(1);
 }
- // icmp (X-Y), X -> icmp 0, Y for equalities or if there is no overflow.
+ // icmp (A-B), A -> icmp 0, B for equalities or if there is no overflow.
 if (A == Op1 && NoOp0WrapProblem)
 return new ICmpInst(Pred, Constant::getNullValue(Op1->getType()), B);
- // icmp X, (X-Y) -> icmp Y, 0 for equalities or if there is no overflow.
+ // icmp C, (C-D) -> icmp D, 0 for equalities or if there is no overflow.
 if (C == Op0 && NoOp1WrapProblem)
 return new ICmpInst(Pred, D, Constant::getNullValue(Op0->getType()));
- // (A - B) >u A --> A <u B
- if (A == Op1 && Pred == ICmpInst::ICMP_UGT)
- return new ICmpInst(ICmpInst::ICMP_ULT, A, B);
- // C <u (C - D) --> C <u D
- if (C == Op0 && Pred == ICmpInst::ICMP_ULT)
- return new ICmpInst(ICmpInst::ICMP_ULT, C, D);
-
- // icmp (Y-X), (Z-X) -> icmp Y, Z for equalities or if there is no overflow.
- if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem &&
- // Try not to increase register pressure.
- BO0->hasOneUse() && BO1->hasOneUse())
+ // Convert sub-with-unsigned-overflow comparisons into a comparison of args.
+ // (A - B) u>/u<= A --> B u>/u<= A
+ if (A == Op1 && (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULE))
+ return new ICmpInst(Pred, B, A);
+ // C u</u>= (C - D) --> C u</u>= D
+ if (C == Op0 && (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_UGE))
+ return new ICmpInst(Pred, C, D);
+ // (A - B) u>=/u< A --> B u>/u<= A iff B != 0
+ if (A == Op1 && (Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_ULT) &&
+ isKnownNonZero(B, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), B, A);
+ // C u<=/u> (C - D) --> C u</u>= D iff D != 0
+ if (C == Op0 && (Pred == ICmpInst::ICMP_ULE || Pred == ICmpInst::ICMP_UGT) &&
+ isKnownNonZero(D, Q.DL, /*Depth=*/0, Q.AC, Q.CxtI, Q.DT))
+ return new ICmpInst(CmpInst::getFlippedStrictnessPredicate(Pred), C, D);
+
+ // icmp (A-B), (C-B) -> icmp A, C for equalities or if there is no overflow.
+ if (B && D && B == D && NoOp0WrapProblem && NoOp1WrapProblem)
 return new ICmpInst(Pred, A, C);
- // icmp (X-Y), (X-Z) -> icmp Z, Y for equalities or if there is no overflow.
- if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem &&
- // Try not to increase register pressure.
- BO0->hasOneUse() && BO1->hasOneUse())
+
+ // icmp (A-B), (A-D) -> icmp D, B for equalities or if there is no overflow.
+ if (A && C && A == C && NoOp0WrapProblem && NoOp1WrapProblem)
 return new ICmpInst(Pred, D, B);
 
 // icmp (0-X) < cst --> x > -cst
@@ -3677,6 +4016,9 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
 }
 }
 
+ if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+ return replaceInstUsesWith(I, V);
+
 if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
 return replaceInstUsesWith(I, V);
 
@@ -3953,125 +4295,140 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
 return nullptr;
 }
 
-/// Handle icmp (cast x to y), (cast/cst). We only handle extending casts so
-/// far.
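One of the new strictness-flipping sub folds above, as hand-written IR (names invented; the 'or 1' only serves to make %b provably non-zero for isKnownNonZero):
  define i1 @sub_uge(i32 %a, i32 %b0) {
    %b = or i32 %b0, 1
    %s = sub i32 %a, %b
    %c = icmp uge i32 %s, %a
    ; becomes: %c = icmp ugt i32 %b, %a
    ret i1 %c
  }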
-Instruction *InstCombiner::foldICmpWithCastAndCast(ICmpInst &ICmp) { - const CastInst *LHSCI = cast<CastInst>(ICmp.getOperand(0)); - Value *LHSCIOp = LHSCI->getOperand(0); - Type *SrcTy = LHSCIOp->getType(); - Type *DestTy = LHSCI->getType(); - - // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the - // integer type is the same size as the pointer type. - const auto& CompatibleSizes = [&](Type* SrcTy, Type* DestTy) -> bool { - if (isa<VectorType>(SrcTy)) { - SrcTy = cast<VectorType>(SrcTy)->getElementType(); - DestTy = cast<VectorType>(DestTy)->getElementType(); - } - return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); - }; - if (LHSCI->getOpcode() == Instruction::PtrToInt && - CompatibleSizes(SrcTy, DestTy)) { - Value *RHSOp = nullptr; - if (auto *RHSC = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { - Value *RHSCIOp = RHSC->getOperand(0); - if (RHSCIOp->getType()->getPointerAddressSpace() == - LHSCIOp->getType()->getPointerAddressSpace()) { - RHSOp = RHSC->getOperand(0); - // If the pointer types don't match, insert a bitcast. - if (LHSCIOp->getType() != RHSOp->getType()) - RHSOp = Builder.CreateBitCast(RHSOp, LHSCIOp->getType()); - } - } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { - RHSOp = ConstantExpr::getIntToPtr(RHSC, SrcTy); - } - - if (RHSOp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSOp); - } - - // The code below only handles extension cast instructions, so far. - // Enforce this. - if (LHSCI->getOpcode() != Instruction::ZExt && - LHSCI->getOpcode() != Instruction::SExt) +static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, + InstCombiner::BuilderTy &Builder) { + assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0"); + auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0)); + Value *X; + if (!match(CastOp0, m_ZExtOrSExt(m_Value(X)))) return nullptr; - bool isSignedExt = LHSCI->getOpcode() == Instruction::SExt; - bool isSignedCmp = ICmp.isSigned(); - - if (auto *CI = dyn_cast<CastInst>(ICmp.getOperand(1))) { - // Not an extension from the same type? - Value *RHSCIOp = CI->getOperand(0); - if (RHSCIOp->getType() != LHSCIOp->getType()) - return nullptr; - + bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt; + bool IsSignedCmp = ICmp.isSigned(); + if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) { // If the signedness of the two casts doesn't agree (i.e. one is a sext // and the other is a zext), then we can't handle this. - if (CI->getOpcode() != LHSCI->getOpcode()) + // TODO: This is too strict. We can handle some predicates (equality?). + if (CastOp0->getOpcode() != CastOp1->getOpcode()) return nullptr; - // Deal with equality cases early. + // Not an extension from the same type? + Value *Y = CastOp1->getOperand(0); + Type *XTy = X->getType(), *YTy = Y->getType(); + if (XTy != YTy) { + // One of the casts must have one use because we are creating a new cast. + if (!CastOp0->hasOneUse() && !CastOp1->hasOneUse()) + return nullptr; + // Extend the narrower operand to the type of the wider operand. 
+ if (XTy->getScalarSizeInBits() < YTy->getScalarSizeInBits()) + X = Builder.CreateCast(CastOp0->getOpcode(), X, YTy); + else if (YTy->getScalarSizeInBits() < XTy->getScalarSizeInBits()) + Y = Builder.CreateCast(CastOp0->getOpcode(), Y, XTy); + else + return nullptr; + } + + // (zext X) == (zext Y) --> X == Y + // (sext X) == (sext Y) --> X == Y if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getPredicate(), X, Y); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedCmp && isSignedExt) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, RHSCIOp); + if (IsSignedCmp && IsSignedExt) + return new ICmpInst(ICmp.getPredicate(), X, Y); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, RHSCIOp); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y); } - // If we aren't dealing with a constant on the RHS, exit early. + // Below here, we are only folding a compare with constant. auto *C = dyn_cast<Constant>(ICmp.getOperand(1)); if (!C) return nullptr; // Compute the constant that would happen if we truncated to SrcTy then // re-extended to DestTy. + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy); - Constant *Res2 = ConstantExpr::getCast(LHSCI->getOpcode(), Res1, DestTy); + Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy); // If the re-extended constant didn't change... if (Res2 == C) { - // Deal with equality cases early. if (ICmp.isEquality()) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getPredicate(), X, Res1); // A signed comparison of sign extended values simplifies into a // signed comparison. - if (isSignedExt && isSignedCmp) - return new ICmpInst(ICmp.getPredicate(), LHSCIOp, Res1); + if (IsSignedExt && IsSignedCmp) + return new ICmpInst(ICmp.getPredicate(), X, Res1); // The other three cases all fold into an unsigned comparison. - return new ICmpInst(ICmp.getUnsignedPredicate(), LHSCIOp, Res1); + return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1); } // The re-extended constant changed, partly changed (in the case of a vector), // or could not be determined to be equal (in the case of a constant // expression), so the constant cannot be represented in the shorter type. - // Consequently, we cannot emit a simple comparison. // All the cases that fold to true or false will have already been handled // by SimplifyICmpInst, so only deal with the tricky case. + if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C)) + return nullptr; + + // Is source op positive? + // icmp ult (sext X), C --> icmp sgt X, -1 + if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) + return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy)); + + // Is source op negative? + // icmp ugt (sext X), C --> icmp slt X, 0 + assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); + return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy)); +} - if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C)) +/// Handle icmp (cast x), (cast or constant). 
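The same-signedness cast fold above, as hand-written IR (names invented):
  define i1 @zext_zext_ult(i8 %x, i8 %y) {
    %zx = zext i8 %x to i32
    %zy = zext i8 %y to i32
    %c = icmp ult i32 %zx, %zy
    ; becomes: %c = icmp ult i8 %x, %y
    ret i1 %c
  }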
+Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) { + auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); + if (!CastOp0) + return nullptr; + if (!isa<Constant>(ICmp.getOperand(1)) && !isa<CastInst>(ICmp.getOperand(1))) return nullptr; - // Evaluate the comparison for LT (we invert for GT below). LE and GE cases - // should have been folded away previously and not enter in here. + Value *Op0Src = CastOp0->getOperand(0); + Type *SrcTy = CastOp0->getSrcTy(); + Type *DestTy = CastOp0->getDestTy(); - // We're performing an unsigned comp with a sign extended value. - // This is true if the input is >= 0. [aka >s -1] - Constant *NegOne = Constant::getAllOnesValue(SrcTy); - Value *Result = Builder.CreateICmpSGT(LHSCIOp, NegOne, ICmp.getName()); + // Turn icmp (ptrtoint x), (ptrtoint/c) into a compare of the input if the + // integer type is the same size as the pointer type. + auto CompatibleSizes = [&](Type *SrcTy, Type *DestTy) { + if (isa<VectorType>(SrcTy)) { + SrcTy = cast<VectorType>(SrcTy)->getElementType(); + DestTy = cast<VectorType>(DestTy)->getElementType(); + } + return DL.getPointerTypeSizeInBits(SrcTy) == DestTy->getIntegerBitWidth(); + }; + if (CastOp0->getOpcode() == Instruction::PtrToInt && + CompatibleSizes(SrcTy, DestTy)) { + Value *NewOp1 = nullptr; + if (auto *PtrToIntOp1 = dyn_cast<PtrToIntOperator>(ICmp.getOperand(1))) { + Value *PtrSrc = PtrToIntOp1->getOperand(0); + if (PtrSrc->getType()->getPointerAddressSpace() == + Op0Src->getType()->getPointerAddressSpace()) { + NewOp1 = PtrToIntOp1->getOperand(0); + // If the pointer types don't match, insert a bitcast. + if (Op0Src->getType() != NewOp1->getType()) + NewOp1 = Builder.CreateBitCast(NewOp1, Op0Src->getType()); + } + } else if (auto *RHSC = dyn_cast<Constant>(ICmp.getOperand(1))) { + NewOp1 = ConstantExpr::getIntToPtr(RHSC, SrcTy); + } - // Finally, return the value computed. - if (ICmp.getPredicate() == ICmpInst::ICMP_ULT) - return replaceInstUsesWith(ICmp, Result); + if (NewOp1) + return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1); + } - assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!"); - return BinaryOperator::CreateNot(Result); + return foldICmpWithZextOrSext(ICmp, Builder); } static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) { @@ -4791,41 +5148,35 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return nullptr; } -/// If we have an icmp le or icmp ge instruction with a constant operand, turn -/// it into the appropriate icmp lt or icmp gt instruction. This transform -/// allows them to be folded in visitICmpInst. 
-static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { - ICmpInst::Predicate Pred = I.getPredicate(); - if (Pred != ICmpInst::ICMP_SLE && Pred != ICmpInst::ICMP_SGE && - Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_UGE) - return nullptr; +llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, + Constant *C) { + assert(ICmpInst::isRelational(Pred) && ICmpInst::isIntPredicate(Pred) && + "Only for relational integer predicates."); - Value *Op0 = I.getOperand(0); - Value *Op1 = I.getOperand(1); - auto *Op1C = dyn_cast<Constant>(Op1); - if (!Op1C) - return nullptr; + Type *Type = C->getType(); + bool IsSigned = ICmpInst::isSigned(Pred); + + CmpInst::Predicate UnsignedPred = ICmpInst::getUnsignedPredicate(Pred); + bool WillIncrement = + UnsignedPred == ICmpInst::ICMP_ULE || UnsignedPred == ICmpInst::ICMP_UGT; - // Check if the constant operand can be safely incremented/decremented without - // overflowing/underflowing. For scalars, SimplifyICmpInst has already handled - // the edge cases for us, so we just assert on them. For vectors, we must - // handle the edge cases. - Type *Op1Type = Op1->getType(); - bool IsSigned = I.isSigned(); - bool IsLE = (Pred == ICmpInst::ICMP_SLE || Pred == ICmpInst::ICMP_ULE); - auto *CI = dyn_cast<ConstantInt>(Op1C); - if (CI) { - // A <= MAX -> TRUE ; A >= MIN -> TRUE - assert(IsLE ? !CI->isMaxValue(IsSigned) : !CI->isMinValue(IsSigned)); - } else if (Op1Type->isVectorTy()) { - // TODO? If the edge cases for vectors were guaranteed to be handled as they - // are for scalar, we could remove the min/max checks. However, to do that, - // we would have to use insertelement/shufflevector to replace edge values. - unsigned NumElts = Op1Type->getVectorNumElements(); + // Check if the constant operand can be safely incremented/decremented + // without overflowing/underflowing. + auto ConstantIsOk = [WillIncrement, IsSigned](ConstantInt *C) { + return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); + }; + + if (auto *CI = dyn_cast<ConstantInt>(C)) { + // Bail out if the constant can't be safely incremented/decremented. + if (!ConstantIsOk(CI)) + return llvm::None; + } else if (Type->isVectorTy()) { + unsigned NumElts = Type->getVectorNumElements(); for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = Op1C->getAggregateElement(i); + Constant *Elt = C->getAggregateElement(i); if (!Elt) - return nullptr; + return llvm::None; if (isa<UndefValue>(Elt)) continue; @@ -4833,20 +5184,43 @@ static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { // Bail out if we can't determine if this constant is min/max or if we // know that this constant is min/max. auto *CI = dyn_cast<ConstantInt>(Elt); - if (!CI || (IsLE ? CI->isMaxValue(IsSigned) : CI->isMinValue(IsSigned))) - return nullptr; + if (!CI || !ConstantIsOk(CI)) + return llvm::None; } } else { // ConstantExpr? - return nullptr; + return llvm::None; } - // Increment or decrement the constant and set the new comparison predicate: - // ULE -> ULT ; UGE -> UGT ; SLE -> SLT ; SGE -> SGT - Constant *OneOrNegOne = ConstantInt::get(Op1Type, IsLE ? 1 : -1, true); - CmpInst::Predicate NewPred = IsLE ? ICmpInst::ICMP_ULT: ICmpInst::ICMP_UGT; - NewPred = IsSigned ? ICmpInst::getSignedPredicate(NewPred) : NewPred; - return new ICmpInst(NewPred, Op0, ConstantExpr::getAdd(Op1C, OneOrNegOne)); + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); + + // Increment or decrement the constant. 
+ Constant *OneOrNegOne = ConstantInt::get(Type, WillIncrement ? 1 : -1, true); + Constant *NewC = ConstantExpr::getAdd(C, OneOrNegOne); + + return std::make_pair(NewPred, NewC); +} + +/// If we have an icmp le or icmp ge instruction with a constant operand, turn +/// it into the appropriate icmp lt or icmp gt instruction. This transform +/// allows them to be folded in visitICmpInst. +static ICmpInst *canonicalizeCmpWithConstant(ICmpInst &I) { + ICmpInst::Predicate Pred = I.getPredicate(); + if (ICmpInst::isEquality(Pred) || !ICmpInst::isIntPredicate(Pred) || + isCanonicalPredicate(Pred)) + return nullptr; + + Value *Op0 = I.getOperand(0); + Value *Op1 = I.getOperand(1); + auto *Op1C = dyn_cast<Constant>(Op1); + if (!Op1C) + return nullptr; + + auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, Op1C); + if (!FlippedStrictness) + return nullptr; + + return new ICmpInst(FlippedStrictness->first, Op0, FlippedStrictness->second); } /// Integer compare with boolean values can always be turned into bitwise ops. @@ -5002,6 +5376,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; + const SimplifyQuery Q = SQ.getWithInstruction(&I); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); unsigned Op0Cplxity = getComplexity(Op0); unsigned Op1Cplxity = getComplexity(Op1); @@ -5016,8 +5391,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Changed = true; } - if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, - SQ.getWithInstruction(&I))) + if (Value *V = SimplifyICmpInst(I.getPredicate(), Op0, Op1, Q)) return replaceInstUsesWith(I, V); // Comparing -val or val with non-zero is the same as just comparing val @@ -5050,6 +5424,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; + if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -5098,6 +5475,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstant(I)) return Res; + // Try to match comparison as a sign bit test. Intentionally do this after + // foldICmpInstWithConstant() to potentially let other folds to happen first. + if (Instruction *New = foldSignBitTest(I)) + return New; + if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; @@ -5124,20 +5506,8 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpBitCast(I, Builder)) return Res; - if (isa<CastInst>(Op0)) { - // Handle the special case of: icmp (cast bool to X), <cst> - // This comes up when you have code like - // int X = A < B; - // if (X) ... - // For generality, we handle any zero-extension of any operand comparison - // with a constant or another cast from the same type. 
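The relational-predicate canonicalization above, as hand-written IR (names invented):
  define i1 @sle_to_slt(i32 %x) {
    %c = icmp sle i32 %x, 9    ; becomes: icmp slt i32 %x, 10
    ret i1 %c
  }
  define i1 @uge_to_ugt(i32 %x) {
    %c = icmp uge i32 %x, 7    ; becomes: icmp ugt i32 %x, 6
    ret i1 %c
  }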
- if (isa<Constant>(Op1) || isa<CastInst>(Op1)) - if (Instruction *R = foldICmpWithCastAndCast(I)) - return R; - } - - if (Instruction *Res = foldICmpBinOp(I)) - return Res; + if (Instruction *R = foldICmpWithCastOp(I)) + return R; if (Instruction *Res = foldICmpWithMinMax(I)) return Res; diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index 434b0d591215..1dbc06d92e7a 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -113,6 +113,48 @@ static inline bool isCanonicalPredicate(CmpInst::Predicate Pred) { } } +/// Given an exploded icmp instruction, return true if the comparison only +/// checks the sign bit. If it only checks the sign bit, set TrueIfSigned if the +/// result of the comparison is true when the input value is signed. +inline bool isSignBitCheck(ICmpInst::Predicate Pred, const APInt &RHS, + bool &TrueIfSigned) { + switch (Pred) { + case ICmpInst::ICMP_SLT: // True if LHS s< 0 + TrueIfSigned = true; + return RHS.isNullValue(); + case ICmpInst::ICMP_SLE: // True if LHS s<= -1 + TrueIfSigned = true; + return RHS.isAllOnesValue(); + case ICmpInst::ICMP_SGT: // True if LHS s> -1 + TrueIfSigned = false; + return RHS.isAllOnesValue(); + case ICmpInst::ICMP_SGE: // True if LHS s>= 0 + TrueIfSigned = false; + return RHS.isNullValue(); + case ICmpInst::ICMP_UGT: + // True if LHS u> RHS and RHS == sign-bit-mask - 1 + TrueIfSigned = true; + return RHS.isMaxSignedValue(); + case ICmpInst::ICMP_UGE: + // True if LHS u>= RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = true; + return RHS.isMinSignedValue(); + case ICmpInst::ICMP_ULT: + // True if LHS u< RHS and RHS == sign-bit-mask (2^7, 2^15, 2^31, etc) + TrueIfSigned = false; + return RHS.isMinSignedValue(); + case ICmpInst::ICMP_ULE: + // True if LHS u<= RHS and RHS == sign-bit-mask - 1 + TrueIfSigned = false; + return RHS.isMaxSignedValue(); + default: + return false; + } +} + +llvm::Optional<std::pair<CmpInst::Predicate, Constant *>> +getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, Constant *C); + /// Return the source operand of a potentially bitcasted value while optionally /// checking if it has one use. If there is no bitcast or the one use check is /// not met, return the input value itself. @@ -139,32 +181,17 @@ static inline Constant *SubOne(Constant *C) { /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. -static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { +/// +/// See also: canFreelyInvertAllUsersOf() +static inline bool isFreeToInvert(Value *V, bool WillInvertAllUses) { // ~(~(X)) -> X. if (match(V, m_Not(m_Value()))) return true; // Constants can be considered to be not'ed values. - if (isa<ConstantInt>(V)) + if (match(V, m_AnyIntegralConstant())) return true; - // A vector of constant integers can be inverted easily. - if (V->getType()->isVectorTy() && isa<Constant>(V)) { - unsigned NumElts = V->getType()->getVectorNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - - if (isa<UndefValue>(Elt)) - continue; - - if (!isa<ConstantInt>(Elt)) - return false; - } - return true; - } - // Compares can be inverted if all of their uses are being modified to use the // ~V. 
if (isa<CmpInst>(V)) @@ -185,6 +212,32 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } +/// Given i1 V, can every user of V be freely adapted if V is changed to !V ? +/// +/// See also: isFreeToInvert() +static inline bool canFreelyInvertAllUsersOf(Value *V, Value *IgnoredUser) { + // Look at every user of V. + for (User *U : V->users()) { + if (U == IgnoredUser) + continue; // Don't consider this user. + + auto *I = cast<Instruction>(U); + switch (I->getOpcode()) { + case Instruction::Select: + case Instruction::Br: + break; // Free to invert by swapping true/false values/destinations. + case Instruction::Xor: // Can invert 'xor' if it's a 'not', by ignoring it. + if (!match(I, m_Not(m_Value()))) + return false; // Not a 'not'. + break; + default: + return false; // Don't know, likely not freely invertible. + } + // So far all users were free to invert... + } + return true; // Can freely invert all users! +} + /// Some binary operators require special handling to avoid poison and undefined /// behavior. If a constant vector has undef elements, replace those undefs with /// identity constants if possible because those are always safe to execute. @@ -337,6 +390,13 @@ public: Instruction *visitOr(BinaryOperator &I); Instruction *visitXor(BinaryOperator &I); Instruction *visitShl(BinaryOperator &I); + Value *reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction = false); + Instruction *canonicalizeCondSignextOfHighBitExtractToSignextHighBitExtract( + BinaryOperator &I); + Instruction *foldVariableSignZeroExtensionOfVariableHighBitExtract( + BinaryOperator &OldAShr); Instruction *visitAShr(BinaryOperator &I); Instruction *visitLShr(BinaryOperator &I); Instruction *commonShiftTransforms(BinaryOperator &I); @@ -541,6 +601,7 @@ private: Instruction *narrowMathIfNoOverflow(BinaryOperator &I); Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); + Instruction *matchSAddSubSat(SelectInst &MinMax1); /// Determine if a pair of casts can be replaced by a single cast. /// @@ -557,7 +618,7 @@ private: Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); + Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &I); /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). /// NOTE: Unlike most of instcombine, this returns a Value which should @@ -725,7 +786,7 @@ public: Value *LHS, Value *RHS, Instruction *CxtI) const; /// Maximum size of array considered when transforming. - uint64_t MaxArraySizeForCombine; + uint64_t MaxArraySizeForCombine = 0; private: /// Performs a few simplifications for operators which are associative @@ -798,7 +859,8 @@ private: int DmaskIdx = -1); Value *SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, - APInt &UndefElts, unsigned Depth = 0); + APInt &UndefElts, unsigned Depth = 0, + bool AllowMultipleUsers = false); /// Canonicalize the position of binops relative to shufflevector. 
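The three user kinds canFreelyInvertAllUsersOf() above accepts, as hand-written IR (names invented):
  define i32 @invertible_users(i1 %v, i32 %a, i32 %b) {
    %not = xor i1 %v, true               ; a 'not': drops out once %v is inverted
    %sel = select i1 %v, i32 %a, i32 %b  ; invert by swapping the arms
    br i1 %v, label %t, label %f         ; invert by swapping the destinations
  t:
    ret i32 %a
  f:
    ret i32 %b
  }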
Instruction *foldVectorBinop(BinaryOperator &Inst); @@ -847,17 +909,21 @@ private: Constant *RHSC); Instruction *foldICmpAddOpConst(Value *X, const APInt &C, ICmpInst::Predicate Pred); - Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); + Instruction *foldICmpWithCastOp(ICmpInst &ICI); Instruction *foldICmpUsingKnownBits(ICmpInst &Cmp); Instruction *foldICmpWithDominatingICmp(ICmpInst &Cmp); Instruction *foldICmpWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstant(ICmpInst &Cmp); Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); - Instruction *foldICmpBinOp(ICmpInst &Cmp); + Instruction *foldICmpBinOp(ICmpInst &Cmp, const SimplifyQuery &SQ); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldIRemByPowerOfTwoToBitTest(ICmpInst &I); + Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); + Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, @@ -874,6 +940,8 @@ private: const APInt &C); Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, const APInt &C); + Instruction *foldICmpSRemConstant(ICmpInst &Cmp, BinaryOperator *UDiv, + const APInt &C); Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, const APInt &C); Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 054fb7da09a2..3a0e05832fcb 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -175,7 +175,7 @@ static bool isDereferenceableForAllocaSize(const Value *V, const AllocaInst *AI, uint64_t AllocaSize = DL.getTypeStoreSize(AI->getAllocatedType()); if (!AllocaSize) return false; - return isDereferenceableAndAlignedPointer(V, AI->getAlignment(), + return isDereferenceableAndAlignedPointer(V, Align(AI->getAlignment()), APInt(64, AllocaSize), DL); } @@ -197,7 +197,7 @@ static Instruction *simplifyAllocaArraySize(InstCombiner &IC, AllocaInst &AI) { if (C->getValue().getActiveBits() <= 64) { Type *NewTy = ArrayType::get(AI.getAllocatedType(), C->getZExtValue()); AllocaInst *New = IC.Builder.CreateAlloca(NewTy, nullptr, AI.getName()); - New->setAlignment(AI.getAlignment()); + New->setAlignment(MaybeAlign(AI.getAlignment())); // Scan to the end of the allocation instructions, to skip over a block of // allocas if possible...also skip interleaved debug info @@ -345,7 +345,8 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { if (AI.getAllocatedType()->isSized()) { // If the alignment is 0 (unspecified), assign it the preferred alignment. if (AI.getAlignment() == 0) - AI.setAlignment(DL.getPrefTypeAlignment(AI.getAllocatedType())); + AI.setAlignment( + MaybeAlign(DL.getPrefTypeAlignment(AI.getAllocatedType()))); // Move all alloca's of zero byte objects to the entry block and merge them // together. Note that we only do this for alloca's, because malloc should @@ -377,12 +378,12 @@ Instruction *InstCombiner::visitAllocaInst(AllocaInst &AI) { // assign it the preferred alignment. 
if (EntryAI->getAlignment() == 0) EntryAI->setAlignment( - DL.getPrefTypeAlignment(EntryAI->getAllocatedType())); + MaybeAlign(DL.getPrefTypeAlignment(EntryAI->getAllocatedType()))); // Replace this zero-sized alloca with the one at the start of the entry // block after ensuring that the address will be aligned enough for both // types. - unsigned MaxAlign = std::max(EntryAI->getAlignment(), - AI.getAlignment()); + const MaybeAlign MaxAlign( + std::max(EntryAI->getAlignment(), AI.getAlignment())); EntryAI->setAlignment(MaxAlign); if (AI.getType() != EntryAI->getType()) return new BitCastInst(EntryAI, AI.getType()); @@ -455,9 +456,6 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT Value *Ptr = LI.getPointerOperand(); unsigned AS = LI.getPointerAddressSpace(); - SmallVector<std::pair<unsigned, MDNode *>, 8> MD; - LI.getAllMetadata(MD); - Value *NewPtr = nullptr; if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && NewPtr->getType()->getPointerElementType() == NewTy && @@ -467,48 +465,7 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); - MDBuilder MDB(NewLoad->getContext()); - for (const auto &MDPair : MD) { - unsigned ID = MDPair.first; - MDNode *N = MDPair.second; - // Note, essentially every kind of metadata should be preserved here! This - // routine is supposed to clone a load instruction changing *only its type*. - // The only metadata it makes sense to drop is metadata which is invalidated - // when the pointer type changes. This should essentially never be the case - // in LLVM, but we explicitly switch over only known metadata to be - // conservatively correct. If you are adding metadata to LLVM which pertains - // to loads, you almost certainly want to add it here. - switch (ID) { - case LLVMContext::MD_dbg: - case LLVMContext::MD_tbaa: - case LLVMContext::MD_prof: - case LLVMContext::MD_fpmath: - case LLVMContext::MD_tbaa_struct: - case LLVMContext::MD_invariant_load: - case LLVMContext::MD_alias_scope: - case LLVMContext::MD_noalias: - case LLVMContext::MD_nontemporal: - case LLVMContext::MD_mem_parallel_loop_access: - case LLVMContext::MD_access_group: - // All of these directly apply. - NewLoad->setMetadata(ID, N); - break; - - case LLVMContext::MD_nonnull: - copyNonnullMetadata(LI, N, *NewLoad); - break; - case LLVMContext::MD_align: - case LLVMContext::MD_dereferenceable: - case LLVMContext::MD_dereferenceable_or_null: - // These only directly apply if the new type is also a pointer. - if (NewTy->isPointerTy()) - NewLoad->setMetadata(ID, N); - break; - case LLVMContext::MD_range: - copyRangeMetadata(IC.getDataLayout(), LI, N, *NewLoad); - break; - } - } + copyMetadataForLoad(*NewLoad, LI); return NewLoad; } @@ -1004,9 +961,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { LoadAlign != 0 ? LoadAlign : DL.getABITypeAlignment(LI.getType()); if (KnownAlign > EffectiveLoadAlign) - LI.setAlignment(KnownAlign); + LI.setAlignment(MaybeAlign(KnownAlign)); else if (LoadAlign == 0) - LI.setAlignment(EffectiveLoadAlign); + LI.setAlignment(MaybeAlign(EffectiveLoadAlign)); // Replace GEP indices if possible. 
if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Op, LI)) { @@ -1063,11 +1020,11 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { // if (SelectInst *SI = dyn_cast<SelectInst>(Op)) { // load (select (Cond, &V1, &V2)) --> select(Cond, load &V1, load &V2). - unsigned Align = LI.getAlignment(); - if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), Align, - DL, SI) && - isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), Align, - DL, SI)) { + const MaybeAlign Alignment(LI.getAlignment()); + if (isSafeToLoadUnconditionally(SI->getOperand(1), LI.getType(), + Alignment, DL, SI) && + isSafeToLoadUnconditionally(SI->getOperand(2), LI.getType(), + Alignment, DL, SI)) { LoadInst *V1 = Builder.CreateLoad(LI.getType(), SI->getOperand(1), SI->getOperand(1)->getName() + ".val"); @@ -1075,9 +1032,9 @@ Instruction *InstCombiner::visitLoadInst(LoadInst &LI) { Builder.CreateLoad(LI.getType(), SI->getOperand(2), SI->getOperand(2)->getName() + ".val"); assert(LI.isUnordered() && "implied by above"); - V1->setAlignment(Align); + V1->setAlignment(Alignment); V1->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); - V2->setAlignment(Align); + V2->setAlignment(Alignment); V2->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); return SelectInst::Create(SI->getCondition(), V1, V2); } @@ -1399,15 +1356,15 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { return eraseInstFromFunction(SI); // Attempt to improve the alignment. - unsigned KnownAlign = getOrEnforceKnownAlignment( - Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT); - unsigned StoreAlign = SI.getAlignment(); - unsigned EffectiveStoreAlign = - StoreAlign != 0 ? StoreAlign : DL.getABITypeAlignment(Val->getType()); + const Align KnownAlign = Align(getOrEnforceKnownAlignment( + Ptr, DL.getPrefTypeAlignment(Val->getType()), DL, &SI, &AC, &DT)); + const MaybeAlign StoreAlign = MaybeAlign(SI.getAlignment()); + const Align EffectiveStoreAlign = + StoreAlign ? *StoreAlign : Align(DL.getABITypeAlignment(Val->getType())); if (KnownAlign > EffectiveStoreAlign) SI.setAlignment(KnownAlign); - else if (StoreAlign == 0) + else if (!StoreAlign) SI.setAlignment(EffectiveStoreAlign); // Try to canonicalize the stored type. @@ -1622,8 +1579,8 @@ bool InstCombiner::mergeStoreIntoSuccessor(StoreInst &SI) { // Advance to a place where it is safe to insert the new store and insert it. BBI = DestBB->getFirstInsertionPt(); - StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), - SI.isVolatile(), SI.getAlignment(), + StoreInst *NewSI = new StoreInst(MergedVal, SI.getOperand(1), SI.isVolatile(), + MaybeAlign(SI.getAlignment()), SI.getOrdering(), SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); NewSI->setDebugLoc(MergedLoc); diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index cc753ce05313..0b9128a9f5a1 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -124,6 +124,50 @@ static Constant *getLogBase2(Type *Ty, Constant *C) { return ConstantVector::get(Elts); } +// TODO: This is a specific form of a much more general pattern. +// We could detect a select with any binop identity constant, or we +// could use SimplifyBinOp to see if either arm of the select reduces. +// But that needs to be done carefully and/or while removing potential +// reverse canonicalizations as in InstCombiner::foldSelectIntoOp(). 
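+// Editorial sketch of the basic integer case (illustrative IR, not from this
+// patch):
+//   %s = select i1 %c, i32 1, i32 -1
+//   %r = mul i32 %s, %x
+// becomes
+//   %n = sub i32 0, %x
+//   %r = select i1 %c, i32 %x, i32 %n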
+static Value *foldMulSelectToNegate(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Value *Cond, *OtherOp; + + // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp + // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp)); + + // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp + // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), + m_Value(OtherOp)))) + return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp); + + // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp + // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(1.0), + m_SpecificFP(-1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, OtherOp, Builder.CreateFNeg(OtherOp)); + } + + // fmul (select Cond, -1.0, 1.0), OtherOp --> select Cond, -OtherOp, OtherOp + // fmul OtherOp, (select Cond, -1.0, 1.0) --> select Cond, -OtherOp, OtherOp + if (match(&I, m_c_FMul(m_OneUse(m_Select(m_Value(Cond), m_SpecificFP(-1.0), + m_SpecificFP(1.0))), + m_Value(OtherOp)))) { + IRBuilder<>::FastMathFlagGuard FMFGuard(Builder); + Builder.setFastMathFlags(I.getFastMathFlags()); + return Builder.CreateSelect(Cond, Builder.CreateFNeg(OtherOp), OtherOp); + } + + return nullptr; +} + Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Value *V = SimplifyMulInst(I.getOperand(0), I.getOperand(1), SQ.getWithInstruction(&I))) @@ -213,6 +257,9 @@ Instruction *InstCombiner::visitMul(BinaryOperator &I) { if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) return FoldedMul; + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + // Simplify mul instructions with a constant RHS. if (isa<Constant>(Op1)) { // Canonicalize (X+C1)*CI -> X*CI+C1*CI. @@ -358,6 +405,9 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (Instruction *FoldedMul = foldBinOpIntoSelectOrPhi(I)) return FoldedMul; + if (Value *FoldedMul = foldMulSelectToNegate(I, Builder)) + return replaceInstUsesWith(I, FoldedMul); + // X * -1.0 --> -X Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); if (match(Op1, m_SpecificFP(-1.0))) @@ -373,16 +423,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { if (match(Op0, m_FNeg(m_Value(X))) && match(Op1, m_Constant(C))) return BinaryOperator::CreateFMulFMF(X, ConstantExpr::getFNeg(C), &I); - // Sink negation: -X * Y --> -(X * Y) - // But don't transform constant expressions because there's an inverse fold. - if (match(Op0, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op0)) - return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op1, &I), &I); - - // Sink negation: Y * -X --> -(X * Y) - // But don't transform constant expressions because there's an inverse fold. 
- if (match(Op1, m_OneUse(m_FNeg(m_Value(X)))) && !isa<ConstantExpr>(Op1)) - return BinaryOperator::CreateFNegFMF(Builder.CreateFMulFMF(X, Op0, &I), &I); - // fabs(X) * fabs(X) -> X * X if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::fabs>(m_Value(X)))) return BinaryOperator::CreateFMulFMF(X, X, &I); @@ -1211,8 +1251,8 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { !IsTan && match(Op0, m_Intrinsic<Intrinsic::cos>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::sin>(m_Specific(X))); - if ((IsTan || IsCot) && hasUnaryFloatFn(&TLI, I.getType(), LibFunc_tan, - LibFunc_tanf, LibFunc_tanl)) { + if ((IsTan || IsCot) && + hasFloatFn(&TLI, I.getType(), LibFunc_tan, LibFunc_tanf, LibFunc_tanl)) { IRBuilder<> B(&I); IRBuilder<>::FastMathFlagGuard FMFGuard(B); B.setFastMathFlags(I.getFastMathFlags()); @@ -1244,6 +1284,17 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { return &I; } + // X / fabs(X) -> copysign(1.0, X) + // fabs(X) / X -> copysign(1.0, X) + if (I.hasNoNaNs() && I.hasNoInfs() && + (match(&I, + m_FDiv(m_Value(X), m_Intrinsic<Intrinsic::fabs>(m_Deferred(X)))) || + match(&I, m_FDiv(m_Intrinsic<Intrinsic::fabs>(m_Value(X)), + m_Deferred(X))))) { + Value *V = Builder.CreateBinaryIntrinsic( + Intrinsic::copysign, ConstantFP::get(I.getType(), 1.0), X, &I); + return replaceInstUsesWith(I, V); + } return nullptr; } @@ -1309,6 +1360,8 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { + // This may increase instruction count, we don't enforce that Y is a + // constant. Constant *N1 = Constant::getAllOnesValue(Ty); Value *Add = Builder.CreateAdd(Op1, N1); return BinaryOperator::CreateAnd(Op0, Add); diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 5820ab726637..e0376b7582f3 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -542,7 +542,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // visitLoadInst will propagate an alignment onto the load when TD is around, // and if TD isn't around, we can't handle the mixed case. bool isVolatile = FirstLI->isVolatile(); - unsigned LoadAlignment = FirstLI->getAlignment(); + MaybeAlign LoadAlignment(FirstLI->getAlignment()); unsigned LoadAddrSpace = FirstLI->getPointerAddressSpace(); // We can't sink the load if the loaded value could be modified between the @@ -574,10 +574,10 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { // If some of the loads have an alignment specified but not all of them, // we can't do the transformation. 
-    if ((LoadAlignment != 0) != (LI->getAlignment() != 0))
+    if ((LoadAlignment.hasValue()) != (LI->getAlignment() != 0))
       return nullptr;
 
-    LoadAlignment = std::min(LoadAlignment, LI->getAlignment());
+    LoadAlignment = std::min(LoadAlignment, MaybeAlign(LI->getAlignment()));
 
     // If the PHI is of volatile loads and the load block has multiple
     // successors, sinking it would remove a load of the volatile value from
diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp
index aefaf5af1750..9fc871e49b30 100644
--- a/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -785,6 +785,41 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
   return nullptr;
 }
 
+/// Fold the following code sequence:
+/// \code
+///   int a = ctlz(x & -x);
+///   x ? 31 - a : a;
+/// \endcode
+///
+/// into:
+///   cttz(x)
+static Instruction *foldSelectCtlzToCttz(ICmpInst *ICI, Value *TrueVal,
+                                         Value *FalseVal,
+                                         InstCombiner::BuilderTy &Builder) {
+  unsigned BitWidth = TrueVal->getType()->getScalarSizeInBits();
+  if (!ICI->isEquality() || !match(ICI->getOperand(1), m_Zero()))
+    return nullptr;
+
+  if (ICI->getPredicate() == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  if (!match(FalseVal,
+             m_Xor(m_Deferred(TrueVal), m_SpecificInt(BitWidth - 1))))
+    return nullptr;
+
+  if (!match(TrueVal, m_Intrinsic<Intrinsic::ctlz>()))
+    return nullptr;
+
+  Value *X = ICI->getOperand(0);
+  auto *II = cast<IntrinsicInst>(TrueVal);
+  if (!match(II->getOperand(0), m_c_And(m_Specific(X), m_Neg(m_Specific(X)))))
+    return nullptr;
+
+  Function *F = Intrinsic::getDeclaration(II->getModule(), Intrinsic::cttz,
+                                          II->getType());
+  return CallInst::Create(F, {X, II->getArgOperand(1)});
+}
+
 /// Attempt to fold a cttz/ctlz followed by an icmp plus select into a single
 /// call to cttz/ctlz with flag 'is_zero_undef' cleared.
 ///
@@ -973,8 +1008,7 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp,
   // If we are swapping the select operands, swap the metadata too.
   assert(Sel.getTrueValue() == RHS && Sel.getFalseValue() == LHS &&
          "Unexpected results from matchSelectPattern");
-  Sel.setTrueValue(LHS);
-  Sel.setFalseValue(RHS);
+  Sel.swapValues();
   Sel.swapProfMetadata();
   return &Sel;
 }
@@ -1056,17 +1090,293 @@ static Instruction *canonicalizeAbsNabs(SelectInst &Sel, ICmpInst &Cmp,
   }
 
   // We are swapping the select operands, so swap the metadata too.
-  Sel.setTrueValue(FVal);
-  Sel.setFalseValue(TVal);
+  Sel.swapValues();
   Sel.swapProfMetadata();
   return &Sel;
 }
 
+static Value *simplifyWithOpReplaced(Value *V, Value *Op, Value *ReplaceOp,
+                                     const SimplifyQuery &Q) {
+  // If this is a binary operator, try to simplify it with the replaced op
+  // because we know Op and ReplaceOp are equivalent.
+  // For example: V = X + 1, Op = X, ReplaceOp = 42
+  // Simplifies as: add(42, 1) --> 43
+  if (auto *BO = dyn_cast<BinaryOperator>(V)) {
+    if (BO->getOperand(0) == Op)
+      return SimplifyBinOp(BO->getOpcode(), ReplaceOp, BO->getOperand(1), Q);
+    if (BO->getOperand(1) == Op)
+      return SimplifyBinOp(BO->getOpcode(), BO->getOperand(0), ReplaceOp, Q);
+  }
+
+  return nullptr;
+}
+
+/// If we have a select with an equality comparison, then we know the value in
+/// one of the arms of the select. See if substituting this value into an arm
+/// and simplifying the result yields the same value as the other arm.
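+/// For example (editorial sketch): in (X == 42) ? 43 : (X + 1), substituting
+/// 42 for X in the false arm simplifies it to 43, which is exactly the true
+/// arm, so the whole select can be replaced with X + 1.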
+///
+/// To make this transform safe, we must drop poison-generating flags
+/// (nsw, etc) if we simplified to a binop because the select may be guarding
+/// that poison from propagating. If the existing binop already had no
+/// poison-generating flags, then this transform can be done by instsimplify.
+///
+/// Consider:
+///   %cmp = icmp eq i32 %x, 2147483647
+///   %add = add nsw i32 %x, 1
+///   %sel = select i1 %cmp, i32 -2147483648, i32 %add
+///
+/// We can't replace %sel with %add unless we strip away the flags.
+/// TODO: Wrapping flags could be preserved in some cases with better analysis.
+static Value *foldSelectValueEquivalence(SelectInst &Sel, ICmpInst &Cmp,
+                                         const SimplifyQuery &Q) {
+  if (!Cmp.isEquality())
+    return nullptr;
+
+  // Canonicalize the pattern to ICMP_EQ by swapping the select operands.
+  Value *TrueVal = Sel.getTrueValue(), *FalseVal = Sel.getFalseValue();
+  if (Cmp.getPredicate() == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  // Try each equivalence substitution possibility.
+  // We have an 'EQ' comparison, so the select's false value will propagate.
+  // Example:
+  // (X == 42) ? 43 : (X + 1) --> (X == 42) ? (X + 1) : (X + 1) --> X + 1
+  // (X == 42) ? (X + 1) : 43 --> (X == 42) ? (42 + 1) : 43 --> 43
+  Value *CmpLHS = Cmp.getOperand(0), *CmpRHS = Cmp.getOperand(1);
+  if (simplifyWithOpReplaced(FalseVal, CmpLHS, CmpRHS, Q) == TrueVal ||
+      simplifyWithOpReplaced(FalseVal, CmpRHS, CmpLHS, Q) == TrueVal ||
+      simplifyWithOpReplaced(TrueVal, CmpLHS, CmpRHS, Q) == FalseVal ||
+      simplifyWithOpReplaced(TrueVal, CmpRHS, CmpLHS, Q) == FalseVal) {
+    if (auto *FalseInst = dyn_cast<Instruction>(FalseVal))
+      FalseInst->dropPoisonGeneratingFlags();
+    return FalseVal;
+  }
+  return nullptr;
+}
+
+// See if this is a pattern like:
+//   %old_cmp1 = icmp slt i32 %x, C2
+//   %old_replacement = select i1 %old_cmp1, i32 %target_low, i32 %target_high
+//   %old_x_offset = add i32 %x, C1
+//   %old_cmp0 = icmp ult i32 %old_x_offset, C0
+//   %r = select i1 %old_cmp0, i32 %x, i32 %old_replacement
+// This can be rewritten as a more canonical pattern:
+//   %new_cmp1 = icmp slt i32 %x, -C1
+//   %new_cmp2 = icmp sge i32 %x, C0-C1
+//   %new_clamped_low = select i1 %new_cmp1, i32 %target_low, i32 %x
+//   %r = select i1 %new_cmp2, i32 %target_high, i32 %new_clamped_low
+// Iff -C1 s<= C2 s<= C0-C1
+// The ULT predicate can also be UGT iff C0 != -1 (+invert result)
+// The SLT predicate can also be SGT iff C2 != INT_MAX (+invert result)
+static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0,
+                                          InstCombiner::BuilderTy &Builder) {
+  Value *X = Sel0.getTrueValue();
+  Value *Sel1 = Sel0.getFalseValue();
+
+  // First match the condition of the outermost select.
+  // Said condition must be one-use.
+  if (!Cmp0.hasOneUse())
+    return nullptr;
+  Value *Cmp00 = Cmp0.getOperand(0);
+  Constant *C0;
+  if (!match(Cmp0.getOperand(1),
+             m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))
+    return nullptr;
+  // Canonicalize Cmp0 into the form we expect.
+  // FIXME: we shouldn't care about lanes that are 'undef' in the end?
+  switch (Cmp0.getPredicate()) {
+  case ICmpInst::Predicate::ICMP_ULT:
+    break; // Great!
+  case ICmpInst::Predicate::ICMP_ULE:
+    // We'd have to increment C0 by one, and for that it must not have an
+    // all-ones element, but then it would have been canonicalized to 'ult'
+    // before we get here. So we can't do anything useful with 'ule'.
+    return nullptr;
+  case ICmpInst::Predicate::ICMP_UGT:
+    // We want to canonicalize it to 'ult', so we'll need to increment C0,
+    // which again means it must not have any all-ones elements.
+    if (!match(C0,
+               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE,
+                                  APInt::getAllOnesValue(
+                                      C0->getType()->getScalarSizeInBits()))))
+      return nullptr; // Can't do, have all-ones element[s].
+    C0 = AddOne(C0);
+    std::swap(X, Sel1);
+    break;
+  case ICmpInst::Predicate::ICMP_UGE:
+    // The only way we'd get this predicate is if this `icmp` has extra uses,
+    // but then we won't be able to do this fold.
+    return nullptr;
+  default:
+    return nullptr; // Unknown predicate.
+  }
+
+  // Now that we've canonicalized the ICmp, we know the X we expect;
+  // the select, on the other hand, should be one-use.
+  if (!Sel1->hasOneUse())
+    return nullptr;
+
+  // We can now finish matching the condition of the outermost select:
+  // it should either be the X itself, or an addition of some constant to X.
+  Constant *C1;
+  if (Cmp00 == X)
+    C1 = ConstantInt::getNullValue(Sel0.getType());
+  else if (!match(Cmp00,
+                  m_Add(m_Specific(X),
+                        m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1)))))
+    return nullptr;
+
+  Value *Cmp1;
+  ICmpInst::Predicate Pred1;
+  Constant *C2;
+  Value *ReplacementLow, *ReplacementHigh;
+  if (!match(Sel1, m_Select(m_Value(Cmp1), m_Value(ReplacementLow),
+                            m_Value(ReplacementHigh))) ||
+      !match(Cmp1,
+             m_ICmp(Pred1, m_Specific(X),
+                    m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C2)))))
+    return nullptr;
+
+  if (!Cmp1->hasOneUse() && (Cmp00 == X || !Cmp00->hasOneUse()))
+    return nullptr; // Not enough one-use instructions for the fold.
+  // FIXME: this restriction could be relaxed if Cmp1 can be reused as one of
+  // two comparisons we'll need to build.
+
+  // Canonicalize Cmp1 into the form we expect.
+  // FIXME: we shouldn't care about lanes that are 'undef' in the end?
+  switch (Pred1) {
+  case ICmpInst::Predicate::ICMP_SLT:
+    break;
+  case ICmpInst::Predicate::ICMP_SLE:
+    // We'd have to increment C2 by one, and for that it must not have a signed
+    // max element, but then it would have been canonicalized to 'slt' before
+    // we get here. So we can't do anything useful with 'sle'.
+    return nullptr;
+  case ICmpInst::Predicate::ICMP_SGT:
+    // We want to canonicalize it to 'slt', so we'll need to increment C2,
+    // which again means it must not have any signed max elements.
+    if (!match(C2,
+               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE,
+                                  APInt::getSignedMaxValue(
+                                      C2->getType()->getScalarSizeInBits()))))
+      return nullptr; // Can't do, have signed max element[s].
+    C2 = AddOne(C2);
+    LLVM_FALLTHROUGH;
+  case ICmpInst::Predicate::ICMP_SGE:
+    // Also non-canonical, but here we don't need to change C2,
+    // so there are no restrictions on C2 and we can just handle it.
+    std::swap(ReplacementLow, ReplacementHigh);
+    break;
+  default:
+    return nullptr; // Unknown predicate.
+  }
+
+  // The thresholds of this clamp-like pattern.
+  auto *ThresholdLowIncl = ConstantExpr::getNeg(C1);
+  auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1);
+
+  // The fold has precondition 1: C2 s>= ThresholdLow
+  auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2,
+                                         ThresholdLowIncl);
+  if (!match(Precond1, m_One()))
+    return nullptr;
+  // The fold has precondition 2: C2 s<= ThresholdHigh
+  auto *Precond2 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SLE, C2,
+                                         ThresholdHighExcl);
+  if (!match(Precond2, m_One()))
+    return nullptr;
+
+  // All good, finally emit the new pattern.
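+  // Editorial sketch with concrete values (not from this patch): for
+  //   %off  = add i32 %x, 32768
+  //   %cmp0 = icmp ult i32 %off, 65536
+  //   %cmp1 = icmp slt i32 %x, 0
+  //   %sel1 = select i1 %cmp1, i32 -32768, i32 32767
+  //   %r    = select i1 %cmp0, i32 %x, i32 %sel1
+  // we get C1 = 32768, C0 = 65536, C2 = 0, so ThresholdLowIncl = -32768 and
+  // ThresholdHighExcl = 32768, and the emitted clamp is:
+  //   %low = select (icmp slt i32 %x, -32768), i32 -32768, i32 %x
+  //   %r   = select (icmp sge i32 %x, 32768), i32 32767, i32 %low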
+  Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl);
+  Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl);
+  Value *MaybeReplacedLow =
+      Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X);
+  Instruction *MaybeReplacedHigh =
+      SelectInst::Create(ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow);
+
+  return MaybeReplacedHigh;
+}
+
+// If we have
+//   %cmp = icmp [canonical predicate] i32 %x, C0
+//   %r = select i1 %cmp, i32 %y, i32 C1
+// Where C0 != C1 and %x may be different from %y, see if the constant that we
+// will have if we flip the strictness of the predicate (i.e. without changing
+// the result) is identical to the C1 in the select. If it matches we can
+// change the original comparison to one with a swapped predicate, reuse the
+// constant, and swap the hands of the select.
+static Instruction *
+tryToReuseConstantFromSelectInComparison(SelectInst &Sel, ICmpInst &Cmp,
+                                         InstCombiner::BuilderTy &Builder) {
+  ICmpInst::Predicate Pred;
+  Value *X;
+  Constant *C0;
+  if (!match(&Cmp, m_OneUse(m_ICmp(
+                       Pred, m_Value(X),
+                       m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0))))))
+    return nullptr;
+
+  // If the comparison predicate is non-relational, we won't be able to do
+  // anything.
+  if (ICmpInst::isEquality(Pred))
+    return nullptr;
+
+  // If the comparison predicate is non-canonical, then we certainly won't be
+  // able to make it canonical; canonicalizeCmpWithConstant() already tried.
+  if (!isCanonicalPredicate(Pred))
+    return nullptr;
+
+  // If the [input] type of the comparison and the select type are different,
+  // let's abort for now. We could try to compare constants with
+  // trunc/[zs]ext though.
+  if (C0->getType() != Sel.getType())
+    return nullptr;
+
+  // FIXME: are there any magic icmp predicate+constant pairs we must not touch?
+
+  Value *SelVal0, *SelVal1; // We do not care which one is from where.
+  match(&Sel, m_Select(m_Value(), m_Value(SelVal0), m_Value(SelVal1)));
+  // At least one of these values we are selecting between must be a constant
+  // else we'll never succeed.
+  if (!match(SelVal0, m_AnyIntegralConstant()) &&
+      !match(SelVal1, m_AnyIntegralConstant()))
+    return nullptr;
+
+  // Does this constant C match any of the `select` values?
+  auto MatchesSelectValue = [SelVal0, SelVal1](Constant *C) {
+    return C->isElementWiseEqual(SelVal0) || C->isElementWiseEqual(SelVal1);
+  };
+
+  // If C0 *already* matches the true/false value of the select, we are done.
+  if (MatchesSelectValue(C0))
+    return nullptr;
+
+  // Check the constant we'd have with flipped-strictness predicate.
+  auto FlippedStrictness = getFlippedStrictnessPredicateAndConstant(Pred, C0);
+  if (!FlippedStrictness)
+    return nullptr;
+
+  // If said constant doesn't match either, then there is no hope.
+  if (!MatchesSelectValue(FlippedStrictness->second))
+    return nullptr;
+
+  // It matched! Let's insert the new comparison just before the select.
+  InstCombiner::BuilderTy::InsertPointGuard Guard(Builder);
+  Builder.SetInsertPoint(&Sel);
+
+  Pred = ICmpInst::getSwappedPredicate(Pred); // Yes, swapped.
+  Value *NewCmp = Builder.CreateICmp(Pred, X, FlippedStrictness->second,
+                                     Cmp.getName() + ".inv");
+  Sel.setCondition(NewCmp);
+  Sel.swapValues();
+  Sel.swapProfMetadata();
+
+  return &Sel;
+}
+
 /// Visit a SelectInst that has an ICmpInst as its first operand.
Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); + if (Value *V = foldSelectValueEquivalence(SI, *ICI, SQ)) + return replaceInstUsesWith(SI, V); if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -1074,12 +1384,21 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, Builder)) return NewAbs; + if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder)) + return NewAbs; + + if (Instruction *NewSel = + tryToReuseConstantFromSelectInComparison(SI, *ICI, Builder)) + return NewSel; + bool Changed = adjustMinMax(SI, *ICI); if (Value *V = foldSelectICmpAnd(SI, ICI, Builder)) return replaceInstUsesWith(SI, V); // NOTE: if we wanted to, this is where to detect integer MIN/MAX + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); @@ -1149,6 +1468,9 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, foldSelectICmpAndAnd(SI.getType(), ICI, TrueVal, FalseVal, Builder)) return V; + if (Instruction *V = foldSelectCtlzToCttz(ICI, TrueVal, FalseVal, Builder)) + return V; + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); @@ -1253,6 +1575,16 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, } } + // max(max(A, B), min(A, B)) --> max(A, B) + // min(min(A, B), max(A, B)) --> min(A, B) + // TODO: This could be done in instsimplify. + if (SPF1 == SPF2 && + ((SPF1 == SPF_UMIN && match(C, m_c_UMax(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_SMIN && match(C, m_c_SMax(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_UMAX && match(C, m_c_UMin(m_Specific(A), m_Specific(B)))) || + (SPF1 == SPF_SMAX && match(C, m_c_SMin(m_Specific(A), m_Specific(B)))))) + return replaceInstUsesWith(Outer, Inner); + // ABS(ABS(X)) -> ABS(X) // NABS(NABS(X)) -> NABS(X) // TODO: This could be done in instsimplify. @@ -1280,7 +1612,7 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, return true; } - if (IsFreeToInvert(V, !V->hasNUsesOrMore(3))) { + if (isFreeToInvert(V, !V->hasNUsesOrMore(3))) { NotV = nullptr; return true; } @@ -1492,6 +1824,30 @@ static Instruction *canonicalizeSelectToShuffle(SelectInst &SI) { ConstantVector::get(Mask)); } +/// If we have a select of vectors with a scalar condition, try to convert that +/// to a vector select by splatting the condition. A splat may get folded with +/// other operations in IR and having all operands of a select be vector types +/// is likely better for vector codegen. +static Instruction *canonicalizeScalarSelectOfVecs( + SelectInst &Sel, InstCombiner::BuilderTy &Builder) { + Type *Ty = Sel.getType(); + if (!Ty->isVectorTy()) + return nullptr; + + // We can replace a single-use extract with constant index. + Value *Cond = Sel.getCondition(); + if (!match(Cond, m_OneUse(m_ExtractElement(m_Value(), m_ConstantInt())))) + return nullptr; + + // select (extelt V, Index), T, F --> select (splat V, Index), T, F + // Splatting the extracted condition reduces code (we could directly create a + // splat shuffle of the source vector to eliminate the intermediate step). 
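+  // Editorial sketch (illustrative IR, not from this patch):
+  //   %c = extractelement <4 x i1> %cv, i32 0
+  //   %r = select i1 %c, <4 x i32> %t, <4 x i32> %f
+  // becomes a vector select on a <4 x i1> splat of %c, which
+  // CreateVectorSplat builds as an insertelement plus a zero-mask
+  // shufflevector.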
+  unsigned NumElts = Ty->getVectorNumElements();
+  Value *SplatCond = Builder.CreateVectorSplat(NumElts, Cond);
+  Sel.setCondition(SplatCond);
+  return &Sel;
+}
+
 /// Reuse bitcasted operands between a compare and select:
 /// select (cmp (bitcast C), (bitcast D)), (bitcast' C), (bitcast' D) -->
 /// bitcast (select (cmp (bitcast C), (bitcast D)), (bitcast C), (bitcast D))
@@ -1648,6 +2004,71 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X,
   return nullptr;
 }
 
+/// Match a sadd_sat or ssub_sat that uses min/max to clamp the value.
+Instruction *InstCombiner::matchSAddSubSat(SelectInst &MinMax1) {
+  Type *Ty = MinMax1.getType();
+
+  // We are looking for a tree of:
+  // max(INT_MIN, min(INT_MAX, add(sext(A), sext(B))))
+  // Where the min and max could be reversed
+  Instruction *MinMax2;
+  BinaryOperator *AddSub;
+  const APInt *MinValue, *MaxValue;
+  if (match(&MinMax1, m_SMin(m_Instruction(MinMax2), m_APInt(MaxValue)))) {
+    if (!match(MinMax2, m_SMax(m_BinOp(AddSub), m_APInt(MinValue))))
+      return nullptr;
+  } else if (match(&MinMax1,
+                   m_SMax(m_Instruction(MinMax2), m_APInt(MinValue)))) {
+    if (!match(MinMax2, m_SMin(m_BinOp(AddSub), m_APInt(MaxValue))))
+      return nullptr;
+  } else
+    return nullptr;
+
+  // Check that the constants form a saturating clamp, and that the new type
+  // would be sensible to convert to.
+  if (!(*MaxValue + 1).isPowerOf2() || -*MinValue != *MaxValue + 1)
+    return nullptr;
+  // In what bitwidth can this be treated as saturating arithmetic?
+  unsigned NewBitWidth = (*MaxValue + 1).logBase2() + 1;
+  // FIXME: This isn't quite right for vectors, but using the scalar type is a
+  // good first approximation for what should be done there.
+  if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
+    return nullptr;
+
+  // Also make sure that the number of uses is as expected. The "3"s are for
+  // the two items of min/max (the compare and the select).
+  if (MinMax2->hasNUsesOrMore(3) || AddSub->hasNUsesOrMore(3))
+    return nullptr;
+
+  // Create the new type (which can be a vector type)
+  Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
+  // Match the two extends from the add/sub
+  Value *A, *B;
+  if (!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
+    return nullptr;
+  // And check the incoming values are of a type smaller than or equal to the
+  // size of the saturation. Otherwise the higher bits can cause different
+  // results.
+  if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
+      B->getType()->getScalarSizeInBits() > NewBitWidth)
+    return nullptr;
+
+  Intrinsic::ID IntrinsicID;
+  if (AddSub->getOpcode() == Instruction::Add)
+    IntrinsicID = Intrinsic::sadd_sat;
+  else if (AddSub->getOpcode() == Instruction::Sub)
+    IntrinsicID = Intrinsic::ssub_sat;
+  else
+    return nullptr;
+
+  // Finally create and return the sat intrinsic, truncated to the new type
+  Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
+  Value *AT = Builder.CreateSExt(A, NewTy);
+  Value *BT = Builder.CreateSExt(B, NewTy);
+  Value *Sat = Builder.CreateCall(F, {AT, BT});
+  return CastInst::Create(Instruction::SExt, Sat, Ty);
+}
+
 /// Reduce a sequence of min/max with a common operand.
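/// For example (editorial sketch, assuming the usual factoring direction):
/// smax(smax(%a, %b), smax(%c, %b)) can be reduced to smax(smax(%a, %c), %b).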
static Instruction *factorizeMinMaxTree(SelectPatternFlavor SPF, Value *LHS, Value *RHS, @@ -1788,6 +2209,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = canonicalizeSelectToShuffle(SI)) return I; + if (Instruction *I = canonicalizeScalarSelectOfVecs(SI, Builder)) + return I; + // Canonicalize a one-use integer compare with a non-canonical predicate by // inverting the predicate and swapping the select operands. This matches a // compare canonicalization for conditional branches. @@ -2013,16 +2437,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { (LHS->getType()->isFPOrFPVectorTy() && ((CmpLHS != LHS && CmpLHS != RHS) || (CmpRHS != LHS && CmpRHS != RHS)))) { - CmpInst::Predicate Pred = getMinMaxPred(SPF, SPR.Ordered); + CmpInst::Predicate MinMaxPred = getMinMaxPred(SPF, SPR.Ordered); Value *Cmp; - if (CmpInst::isIntPredicate(Pred)) { - Cmp = Builder.CreateICmp(Pred, LHS, RHS); + if (CmpInst::isIntPredicate(MinMaxPred)) { + Cmp = Builder.CreateICmp(MinMaxPred, LHS, RHS); } else { IRBuilder<>::FastMathFlagGuard FMFG(Builder); - auto FMF = cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); + auto FMF = + cast<FPMathOperator>(SI.getCondition())->getFastMathFlags(); Builder.setFastMathFlags(FMF); - Cmp = Builder.CreateFCmp(Pred, LHS, RHS); + Cmp = Builder.CreateFCmp(MinMaxPred, LHS, RHS); } Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); @@ -2040,9 +2465,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { Value *A; if (match(X, m_Not(m_Value(A))) && !X->hasNUsesOrMore(3) && - !IsFreeToInvert(A, A->hasOneUse()) && + !isFreeToInvert(A, A->hasOneUse()) && // Passing false to only consider m_Not and constants. - IsFreeToInvert(Y, false)) { + isFreeToInvert(Y, false)) { Value *B = Builder.CreateNot(Y); Value *NewMinMax = createMinMax(Builder, getInverseMinMaxFlavor(SPF), A, B); @@ -2070,6 +2495,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *I = factorizeMinMaxTree(SPF, LHS, RHS, Builder)) return I; + if (Instruction *I = matchSAddSubSat(SI)) + return I; } } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index c821292400cd..64294838644f 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -25,50 +25,275 @@ using namespace PatternMatch; // we should rewrite it as // x shiftopcode (Q+K) iff (Q+K) u< bitwidth(x) // This is valid for any shift, but they must be identical. -static Instruction * -reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0, - const SimplifyQuery &SQ) { - // Look for: (x shiftopcode ShAmt0) shiftopcode ShAmt1 - Value *X, *ShAmt1, *ShAmt0; +// +// AnalyzeForSignBitExtraction indicates that we will only analyze whether this +// pattern has any 2 right-shifts that sum to 1 less than original bit width. +Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( + BinaryOperator *Sh0, const SimplifyQuery &SQ, + bool AnalyzeForSignBitExtraction) { + // Look for a shift of some instruction, ignore zext of shift amount if any. + Instruction *Sh0Op0; + Value *ShAmt0; + if (!match(Sh0, + m_Shift(m_Instruction(Sh0Op0), m_ZExtOrSelf(m_Value(ShAmt0))))) + return nullptr; + + // If there is a truncation between the two shifts, we must make note of it + // and look through it. 
The truncation imposes additional constraints on the
+  // transform.
   Instruction *Sh1;
-  if (!match(Sh0, m_Shift(m_CombineAnd(m_Shift(m_Value(X), m_Value(ShAmt1)),
-                                       m_Instruction(Sh1)),
-                          m_Value(ShAmt0))))
+  Value *Trunc = nullptr;
+  match(Sh0Op0,
+        m_CombineOr(m_CombineAnd(m_Trunc(m_Instruction(Sh1)), m_Value(Trunc)),
+                    m_Instruction(Sh1)));
+
+  // Inner shift: (x shiftopcode ShAmt1)
+  // Like with the other shift, ignore zext of the shift amount if any.
+  Value *X, *ShAmt1;
+  if (!match(Sh1, m_Shift(m_Value(X), m_ZExtOrSelf(m_Value(ShAmt1)))))
+    return nullptr;
+
+  // We have two shift amounts from two different shifts. The types of those
+  // shift amounts may not match. If that's the case let's bail out now.
+  if (ShAmt0->getType() != ShAmt1->getType())
+    return nullptr;
+
+  // We are only looking for signbit extraction if we have two right shifts.
+  bool HadTwoRightShifts = match(Sh0, m_Shr(m_Value(), m_Value())) &&
+                           match(Sh1, m_Shr(m_Value(), m_Value()));
+  // ... and if it's not two right-shifts, we know the answer already.
+  if (AnalyzeForSignBitExtraction && !HadTwoRightShifts)
     return nullptr;
 
-  // The shift opcodes must be identical.
+  // The shift opcodes must be identical, unless we are just checking whether
+  // this pattern can be interpreted as a sign-bit-extraction.
   Instruction::BinaryOps ShiftOpcode = Sh0->getOpcode();
-  if (ShiftOpcode != Sh1->getOpcode())
+  bool IdenticalShOpcodes = Sh0->getOpcode() == Sh1->getOpcode();
+  if (!IdenticalShOpcodes && !AnalyzeForSignBitExtraction)
     return nullptr;
+
+  // If we saw a truncation, we'll need to produce an extra instruction,
+  // and for that one of the operands of the shift must be one-use,
+  // unless of course we don't actually plan to produce any instructions here.
+  if (Trunc && !AnalyzeForSignBitExtraction &&
+      !match(Sh0, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+    return nullptr;
+
   // Can we fold (ShAmt0+ShAmt1) ?
-  Value *NewShAmt = SimplifyBinOp(Instruction::BinaryOps::Add, ShAmt0, ShAmt1,
-                                  SQ.getWithInstruction(Sh0));
+  auto *NewShAmt = dyn_cast_or_null<Constant>(
+      SimplifyAddInst(ShAmt0, ShAmt1, /*isNSW=*/false, /*isNUW=*/false,
+                      SQ.getWithInstruction(Sh0)));
   if (!NewShAmt)
     return nullptr; // Did not simplify.
-  // Is the new shift amount smaller than the bit width?
-  // FIXME: could also rely on ConstantRange.
-  unsigned BitWidth = X->getType()->getScalarSizeInBits();
+  unsigned NewShAmtBitWidth = NewShAmt->getType()->getScalarSizeInBits();
+  unsigned XBitWidth = X->getType()->getScalarSizeInBits();
+  // Is the new shift amount smaller than the bit width of inner/new shift?
   if (!match(NewShAmt, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_ULT,
-                                          APInt(BitWidth, BitWidth))))
-    return nullptr;
+                                          APInt(NewShAmtBitWidth, XBitWidth))))
+    return nullptr; // FIXME: could perform constant-folding.
+
+  // If there was a truncation, and we have a right-shift, we can only fold if
+  // we are left with the original sign bit. Likewise, if we were just checking
+  // that this is a sign-bit extraction, this is the place to check it.
+  // FIXME: zero shift amount is also legal here, but we can't *easily* check
+  // more than one predicate so it's not really worth it.
+  if (HadTwoRightShifts && (Trunc || AnalyzeForSignBitExtraction)) {
+    // If it's not a sign bit extraction, then we're done.
+    if (!match(NewShAmt,
+               m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                                  APInt(NewShAmtBitWidth, XBitWidth - 1))))
+      return nullptr;
+    // If it is, and that was the question, return the base value.
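+    // (Editorial sketch: on i32, (lshr (lshr %x, 24), 7) has 24 + 7 == 31,
+    // one less than the bit width, so the shift pair extracts the sign bit
+    // of %x.)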
+    if (AnalyzeForSignBitExtraction)
+      return X;
+  }
+
+  assert(IdenticalShOpcodes && "Should not get here with different shifts.");
+
   // All good, we can do this fold.
+  NewShAmt = ConstantExpr::getZExtOrBitCast(NewShAmt, X->getType());
+
   BinaryOperator *NewShift = BinaryOperator::Create(ShiftOpcode, X, NewShAmt);
-  // If both of the original shifts had the same flag set, preserve the flag.
-  if (ShiftOpcode == Instruction::BinaryOps::Shl) {
-    NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
-                                   Sh1->hasNoUnsignedWrap());
-    NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
-                                 Sh1->hasNoSignedWrap());
-  } else {
-    NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+
+  // The flags can only be propagated if there wasn't a trunc.
+  if (!Trunc) {
+    // If the pattern did not involve trunc, and both of the original shifts
+    // had the same flag set, preserve the flag.
+    if (ShiftOpcode == Instruction::BinaryOps::Shl) {
+      NewShift->setHasNoUnsignedWrap(Sh0->hasNoUnsignedWrap() &&
+                                     Sh1->hasNoUnsignedWrap());
+      NewShift->setHasNoSignedWrap(Sh0->hasNoSignedWrap() &&
+                                   Sh1->hasNoSignedWrap());
+    } else {
+      NewShift->setIsExact(Sh0->isExact() && Sh1->isExact());
+    }
+  }
+
+  Instruction *Ret = NewShift;
+  if (Trunc) {
+    Builder.Insert(NewShift);
+    Ret = CastInst::Create(Instruction::Trunc, NewShift, Sh0->getType());
+  }
+
+  return Ret;
+}
+
+// Try to replace `undef` constants in C with Replacement.
+static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) {
+  if (C && match(C, m_Undef()))
+    return Replacement;
+
+  if (auto *CV = dyn_cast<ConstantVector>(C)) {
+    llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands());
+    for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) {
+      Constant *EltC = CV->getOperand(i);
+      NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC;
+    }
+    return ConstantVector::get(NewOps);
+  }
+
+  // Don't know how to deal with this constant.
+  return C;
+}
+
+// If we have a pattern that leaves only some low bits set, and then performs a
+// left-shift of those bits, and if none of the bits that remain after the
+// final shift are modified by the mask, we can omit the mask.
+//
+// There are many variants of this pattern:
+//   a)  (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
+//   b)  (x & (~(-1 << MaskShAmt))) << ShiftShAmt
+//   c)  (x & (-1 >> MaskShAmt)) << ShiftShAmt
+//   d)  (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt
+//   e)  ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt
+//   f)  ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt
+// All these patterns can be simplified to just:
+//   x << ShiftShAmt
+// iff:
+//   a,b)     (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
+//   c,d,e,f) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
+static Instruction *
+dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
+                                     const SimplifyQuery &Q,
+                                     InstCombiner::BuilderTy &Builder) {
+  assert(OuterShift->getOpcode() == Instruction::BinaryOps::Shl &&
+         "The input must be 'shl'!");
+
+  Value *Masked, *ShiftShAmt;
+  match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt)));
+
+  Type *NarrowestTy = OuterShift->getType();
+  Type *WidestTy = Masked->getType();
+  // The mask must be computed in a type twice as wide to ensure
+  // that no bits are lost if the sum-of-shifts is wider than the base type.
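+  // Editorial sketch of variant (e) with a common variable amount (not from
+  // this patch):
+  //   %r = shl i32 (lshr (shl i32 %x, %n), %n), %n
+  // has ShiftShAmt - MaskShAmt == 0 s>= 0, so no residual mask is needed and
+  // it folds to plain (shl i32 %x, %n). When the amounts differ, the residual
+  // mask is computed in the wider type introduced next.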
+  Type *ExtendedTy = WidestTy->getExtendedType();
+
+  Value *MaskShAmt;
+
+  // ((1 << MaskShAmt) - 1)
+  auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
+  // (~(-1 << MaskShAmt))
+  auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
+  // (-1 >> MaskShAmt)
+  auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
+  // ((-1 << MaskShAmt) >> MaskShAmt)
+  auto MaskD =
+      m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt));
+
+  Value *X;
+  Constant *NewMask;
+
+  if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
+    // Can we simplify (MaskShAmt+ShiftShAmt) ?
+    auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst(
+        MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+    if (!SumOfShAmts)
+      return nullptr; // Did not simplify.
+    // In this pattern SumOfShAmts correlates with the number of low bits
+    // that shall remain in the root value (OuterShift).
+
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with the final
+    // shift bitwidth to ensure that the value remains undef when creating the
+    // subsequent shift op.
+    SumOfShAmts = replaceUndefsWith(
+        SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(),
+                                      ExtendedTy->getScalarSizeInBits()));
+    auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy);
+    // And compute the mask as usual: ~(-1 << (SumOfShAmts))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    auto *ExtendedInvertedMask =
+        ConstantExpr::getShl(ExtendedAllOnes, ExtendedSumOfShAmts);
+    NewMask = ConstantExpr::getNot(ExtendedInvertedMask);
+  } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) ||
+             match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)),
+                                 m_Deferred(MaskShAmt)))) {
+    // Can we simplify (ShiftShAmt-MaskShAmt) ?
+    auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst(
+        ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q));
+    if (!ShAmtsDiff)
+      return nullptr; // Did not simplify.
+    // In this pattern ShAmtsDiff correlates with the number of high bits that
+    // shall be unset in the root value (OuterShift).
+
+    // An extend of an undef value becomes zero because the high bits are never
+    // completely unknown. Replace the `undef` shift amounts with the negated
+    // bitwidth of the innermost shift to ensure that the value remains undef
+    // when creating the subsequent shift op.
+    unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits();
+    ShAmtsDiff = replaceUndefsWith(
+        ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(),
+                                     -WidestTyBitWidth));
+    auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt(
+        ConstantExpr::getSub(ConstantInt::get(ShAmtsDiff->getType(),
+                                              WidestTyBitWidth,
+                                              /*isSigned=*/false),
+                             ShAmtsDiff),
+        ExtendedTy);
+    // And compute the mask as usual: (-1 l>> (NumHighBitsToClear))
+    auto *ExtendedAllOnes = ConstantExpr::getAllOnesValue(ExtendedTy);
+    NewMask =
+        ConstantExpr::getLShr(ExtendedAllOnes, ExtendedNumHighBitsToClear);
+  } else
+    return nullptr; // Don't know anything about this pattern.
+
+  NewMask = ConstantExpr::getTrunc(NewMask, NarrowestTy);
+
+  // Does this mask have any unset bits? If not, then we can just not apply it.
+  bool NeedMask = !match(NewMask, m_AllOnes());
+
+  // If we need to apply a mask, there are several more restrictions we have.
+  if (NeedMask) {
+    // The old masking instruction must go away.
+    if (!Masked->hasOneUse())
+      return nullptr;
+    // The original "masking" instruction must not have been `ashr`.
+    if (match(Masked, m_AShr(m_Value(), m_Value())))
+      return nullptr;
+  }
+
+  // No 'NUW'/'NSW'! We no longer know that we won't shift out non-0 bits.
+  auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X,
+                                          OuterShift->getOperand(1));
+
+  if (!NeedMask)
+    return NewShift;
+
+  Builder.Insert(NewShift);
+  return BinaryOperator::Create(Instruction::And, NewShift, NewMask);
 }
 
 Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   assert(Op0->getType() == Op1->getType());
 
+  // If the shift amount is a one-use `sext`, we can demote it to `zext`.
+  Value *Y;
+  if (match(Op1, m_OneUse(m_SExt(m_Value(Y))))) {
+    Value *NewExt = Builder.CreateZExt(Y, I.getType(), Op1->getName());
+    return BinaryOperator::Create(I.getOpcode(), Op0, NewExt);
+  }
+
   // See if we can fold away this shift.
   if (SimplifyDemandedInstructionBits(I))
     return &I;
@@ -83,8 +308,8 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) {
   if (Instruction *Res = FoldShiftByConstant(Op0, CUI, I))
     return Res;
 
-  if (Instruction *NewShift =
-          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ))
+  if (auto *NewShift = cast_or_null<Instruction>(
+          reassociateShiftAmtsOfTwoSameDirectionShifts(&I, SQ)))
     return NewShift;
 
   // (C1 shift (A add C2)) -> ((C1 shift C2) shift A)
@@ -618,9 +843,10 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1,
 }
 
 Instruction *InstCombiner::visitShl(BinaryOperator &I) {
+  const SimplifyQuery Q = SQ.getWithInstruction(&I);
+
   if (Value *V = SimplifyShlInst(I.getOperand(0), I.getOperand(1),
-                                 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
-                                 SQ.getWithInstruction(&I)))
+                                 I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), Q))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *X = foldVectorBinop(I))
@@ -629,6 +855,9 @@
   if (Instruction *V = commonShiftTransforms(I))
     return V;
 
+  if (Instruction *V = dropRedundantMaskingOfLeftShiftInput(&I, Q, Builder))
+    return V;
+
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
   Type *Ty = I.getType();
   unsigned BitWidth = Ty->getScalarSizeInBits();
 
   const APInt *ShAmtAPInt;
   if (match(Op1, m_APInt(ShAmtAPInt))) {
     unsigned ShAmt = ShAmtAPInt->getZExtValue();
-    unsigned BitWidth = Ty->getScalarSizeInBits();
 
     // shl (zext X), ShAmt --> zext (shl X, ShAmt)
     // This is only valid if X would have zeros shifted out.
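    // For instance (editorial sketch): if the top 4 bits of %x are known
    // zero, shl (zext i16 %x to i32), 4 --> zext (shl i16 %x, 4) to i32.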
     Value *X;
-    if (match(Op0, m_ZExt(m_Value(X)))) {
+    if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
       if (ShAmt < SrcWidth &&
           MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I))
@@ -719,6 +947,12 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) {
       // (X * C2) << C1 --> X * (C2 << C1)
       if (match(Op0, m_Mul(m_Value(X), m_Constant(C2))))
         return BinaryOperator::CreateMul(X, ConstantExpr::getShl(C2, C1));
+
+      // shl (zext i1 X), C1 --> select (X, 1 << C1, 0)
+      if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) {
+        auto *NewC = ConstantExpr::getShl(ConstantInt::get(Ty, 1), C1);
+        return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty));
+      }
     }
 
     // (1 << (C - x)) -> ((1 << C) >> x) if C is bitwidth - 1
@@ -859,6 +1093,75 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) {
   return nullptr;
 }
 
+Instruction *
+InstCombiner::foldVariableSignZeroExtensionOfVariableHighBitExtract(
+    BinaryOperator &OldAShr) {
+  assert(OldAShr.getOpcode() == Instruction::AShr &&
+         "Must be called with arithmetic right-shift instruction only.");
+
+  // Check that constant C is a splat of the element-wise bitwidth of V.
+  auto BitWidthSplat = [](Constant *C, Value *V) {
+    return match(
+        C, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_EQ,
+                              APInt(C->getType()->getScalarSizeInBits(),
+                                    V->getType()->getScalarSizeInBits())));
+  };
+
+  // It should look like variable-length sign-extension on the outside:
+  //   (Val << (bitwidth(Val)-Nbits)) a>> (bitwidth(Val)-Nbits)
+  Value *NBits;
+  Instruction *MaybeTrunc;
+  Constant *C1, *C2;
+  if (!match(&OldAShr,
+             m_AShr(m_Shl(m_Instruction(MaybeTrunc),
+                          m_ZExtOrSelf(m_Sub(m_Constant(C1),
+                                             m_ZExtOrSelf(m_Value(NBits))))),
+                    m_ZExtOrSelf(m_Sub(m_Constant(C2),
+                                       m_ZExtOrSelf(m_Deferred(NBits)))))) ||
+      !BitWidthSplat(C1, &OldAShr) || !BitWidthSplat(C2, &OldAShr))
+    return nullptr;
+
+  // There may or may not be a truncation after the outer two shifts.
+  Instruction *HighBitExtract;
+  match(MaybeTrunc, m_TruncOrSelf(m_Instruction(HighBitExtract)));
+  bool HadTrunc = MaybeTrunc != HighBitExtract;
+
+  // And finally, the innermost part of the pattern must be a right-shift.
+  Value *X, *NumLowBitsToSkip;
+  if (!match(HighBitExtract, m_Shr(m_Value(X), m_Value(NumLowBitsToSkip))))
+    return nullptr;
+
+  // Said right-shift must extract high NBits bits - C0 must be its bitwidth.
+  Constant *C0;
+  if (!match(NumLowBitsToSkip,
+             m_ZExtOrSelf(
+                 m_Sub(m_Constant(C0), m_ZExtOrSelf(m_Specific(NBits))))) ||
+      !BitWidthSplat(C0, HighBitExtract))
+    return nullptr;
+
+  // Since the NBits is identical for all shifts, if the outermost and
+  // innermost shifts are identical, then the outermost shifts are redundant.
+  // If we had truncation, do keep it though.
+  if (HighBitExtract->getOpcode() == OldAShr.getOpcode())
+    return replaceInstUsesWith(OldAShr, MaybeTrunc);
+
+  // Else, if there was a truncation, then we need to ensure that one
+  // instruction will go away.
+  if (HadTrunc && !match(&OldAShr, m_c_BinOp(m_OneUse(m_Value()), m_Value())))
+    return nullptr;
+
+  // Finally, bypass the two innermost shifts, and perform the outermost shift
+  // on the operands of the innermost shift.
+  Instruction *NewAShr =
+      BinaryOperator::Create(OldAShr.getOpcode(), X, NumLowBitsToSkip);
+  NewAShr->copyIRFlags(HighBitExtract); // We can preserve 'exact'-ness.
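+  // (Editorial sketch: with %skip = sub i32 32, %nbits and
+  //   %hi = lshr i32 %x, %skip ; %s = shl i32 %hi, %skip ; ashr i32 %s, %skip
+  // this emits ashr i32 %x, %skip directly, skipping the inner lshr/shl pair.)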
+  if (!HadTrunc)
+    return NewAShr;
+
+  Builder.Insert(NewAShr);
+  return TruncInst::CreateTruncOrBitCast(NewAShr, OldAShr.getType());
+}
+
 Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
   if (Value *V = SimplifyAShrInst(I.getOperand(0), I.getOperand(1), I.isExact(),
                                   SQ.getWithInstruction(&I)))
@@ -933,6 +1236,9 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) {
     }
   }
 
+  if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I))
+    return R;
+
   // See if we can turn a signed shr into an unsigned shr.
   if (MaskedValueIsZero(Op0, APInt::getSignMask(BitWidth), 0, &I))
     return BinaryOperator::CreateLShr(Op0, Op1);
 
diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
index e0d85c4b49ae..d30ab8001897 100644
--- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
+++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp
@@ -971,6 +971,13 @@ InstCombiner::simplifyShrShlDemandedBits(Instruction *Shr, const APInt &ShrOp1,
 Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
                                                            APInt DemandedElts,
                                                            int DMaskIdx) {
+
+  // FIXME: Allow v3i16/v3f16 in buffer intrinsics when the types are fully supported.
+  if (DMaskIdx < 0 &&
+      II->getType()->getScalarSizeInBits() != 32 &&
+      DemandedElts.getActiveBits() == 3)
+    return nullptr;
+
   unsigned VWidth = II->getType()->getVectorNumElements();
   if (VWidth == 1)
     return nullptr;
@@ -1067,16 +1074,22 @@ Value *InstCombiner::simplifyAMDGCNMemoryIntrinsicDemanded(IntrinsicInst *II,
 }
 
 /// The specified value produces a vector with any number of elements.
+/// This method analyzes which elements of the operand are undef and returns
+/// that information in UndefElts.
+///
 /// DemandedElts contains the set of elements that are actually used by the
-/// caller. This method analyzes which elements of the operand are undef and
-/// returns that information in UndefElts.
+/// caller, and by default (AllowMultipleUsers equals false) the value is
+/// simplified only if it has a single user. If AllowMultipleUsers is set
+/// to true, DemandedElts refers to the union of sets of elements that are
+/// used by all users.
 ///
 /// If the information about demanded elements can be used to simplify the
 /// operation, the operation is simplified and the resultant value is
 /// returned. This returns null if no change was made.
 Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
                                                 APInt &UndefElts,
-                                                unsigned Depth) {
+                                                unsigned Depth,
+                                                bool AllowMultipleUsers) {
   unsigned VWidth = V->getType()->getVectorNumElements();
   APInt EltMask(APInt::getAllOnesValue(VWidth));
   assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!");
@@ -1130,19 +1143,21 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts,
   if (Depth == 10)
     return nullptr;
 
-  // If multiple users are using the root value, proceed with
-  // simplification conservatively assuming that all elements
-  // are needed.
-  if (!V->hasOneUse()) {
-    // Quit if we find multiple users of a non-root value though.
-    // They'll be handled when it's their turn to be visited by
-    // the main instcombine process.
-    if (Depth != 0)
-      // TODO: Just compute the UndefElts information recursively.
-      return nullptr;
+  if (!AllowMultipleUsers) {
+    // If multiple users are using the root value, proceed with
+    // simplification conservatively assuming that all elements
+    // are needed.
+ if (!V->hasOneUse()) { + // Quit if we find multiple users of a non-root value though. + // They'll be handled when it's their turn to be visited by + // the main instcombine process. + if (Depth != 0) + // TODO: Just compute the UndefElts information recursively. + return nullptr; - // Conservatively assume that all elements are needed. - DemandedElts = EltMask; + // Conservatively assume that all elements are needed. + DemandedElts = EltMask; + } } Instruction *I = dyn_cast<Instruction>(V); @@ -1674,8 +1689,11 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, case Intrinsic::amdgcn_buffer_load_format: case Intrinsic::amdgcn_raw_buffer_load: case Intrinsic::amdgcn_raw_buffer_load_format: + case Intrinsic::amdgcn_raw_tbuffer_load: case Intrinsic::amdgcn_struct_buffer_load: case Intrinsic::amdgcn_struct_buffer_load_format: + case Intrinsic::amdgcn_struct_tbuffer_load: + case Intrinsic::amdgcn_tbuffer_load: return simplifyAMDGCNMemoryIntrinsicDemanded(II, DemandedElts); default: { if (getAMDGPUImageDMaskIntrinsic(II->getIntrinsicID())) diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index dc9abdd7f47a..9c890748e5ab 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -253,6 +253,69 @@ static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, return nullptr; } +/// Find elements of V demanded by UserInstr. +static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + // Conservatively assume that all elements are needed. + APInt UsedElts(APInt::getAllOnesValue(VWidth)); + + switch (UserInstr->getOpcode()) { + case Instruction::ExtractElement: { + ExtractElementInst *EEI = cast<ExtractElementInst>(UserInstr); + assert(EEI->getVectorOperand() == V); + ConstantInt *EEIIndexC = dyn_cast<ConstantInt>(EEI->getIndexOperand()); + if (EEIIndexC && EEIIndexC->getValue().ult(VWidth)) { + UsedElts = APInt::getOneBitSet(VWidth, EEIIndexC->getZExtValue()); + } + break; + } + case Instruction::ShuffleVector: { + ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(UserInstr); + unsigned MaskNumElts = UserInstr->getType()->getVectorNumElements(); + + UsedElts = APInt(VWidth, 0); + for (unsigned i = 0; i < MaskNumElts; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (MaskVal == -1u || MaskVal >= 2 * VWidth) + continue; + if (Shuffle->getOperand(0) == V && (MaskVal < VWidth)) + UsedElts.setBit(MaskVal); + if (Shuffle->getOperand(1) == V && + ((MaskVal >= VWidth) && (MaskVal < 2 * VWidth))) + UsedElts.setBit(MaskVal - VWidth); + } + break; + } + default: + break; + } + return UsedElts; +} + +/// Find union of elements of V demanded by all its users. +/// If it is known by querying findDemandedEltsBySingleUser that +/// no user demands an element of V, then the corresponding bit +/// remains unset in the returned value. 
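+/// For example (editorial sketch): if a <4 x float> %v is used only by
+/// extractelement %v, i32 0 and extractelement %v, i32 1, the returned mask
+/// has only bits 0 and 1 set.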
+static APInt findDemandedEltsByAllUsers(Value *V) { + unsigned VWidth = V->getType()->getVectorNumElements(); + + APInt UnionUsedElts(VWidth, 0); + for (const Use &U : V->uses()) { + if (Instruction *I = dyn_cast<Instruction>(U.getUser())) { + UnionUsedElts |= findDemandedEltsBySingleUser(V, I); + } else { + UnionUsedElts = APInt::getAllOnesValue(VWidth); + break; + } + + if (UnionUsedElts.isAllOnesValue()) + break; + } + + return UnionUsedElts; +} + Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Value *SrcVec = EI.getVectorOperand(); Value *Index = EI.getIndexOperand(); @@ -271,19 +334,35 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { return nullptr; // This instruction only demands the single element from the input vector. - // If the input vector has a single use, simplify it based on this use - // property. - if (SrcVec->hasOneUse() && NumElts != 1) { - APInt UndefElts(NumElts, 0); - APInt DemandedElts(NumElts, 0); - DemandedElts.setBit(IndexC->getZExtValue()); - if (Value *V = SimplifyDemandedVectorElts(SrcVec, DemandedElts, - UndefElts)) { - EI.setOperand(0, V); - return &EI; + if (NumElts != 1) { + // If the input vector has a single use, simplify it based on this use + // property. + if (SrcVec->hasOneUse()) { + APInt UndefElts(NumElts, 0); + APInt DemandedElts(NumElts, 0); + DemandedElts.setBit(IndexC->getZExtValue()); + if (Value *V = + SimplifyDemandedVectorElts(SrcVec, DemandedElts, UndefElts)) { + EI.setOperand(0, V); + return &EI; + } + } else { + // If the input vector has multiple uses, simplify it based on a union + // of all elements used. + APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); + if (!DemandedElts.isAllOnesValue()) { + APInt UndefElts(NumElts, 0); + if (Value *V = SimplifyDemandedVectorElts( + SrcVec, DemandedElts, UndefElts, 0 /* Depth */, + true /* AllowMultipleUsers */)) { + if (V != SrcVec) { + SrcVec->replaceAllUsesWith(V); + return &EI; + } + } + } } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) return I; @@ -766,6 +845,55 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); } +/// Try to fold an extract+insert element into an existing identity shuffle by +/// changing the shuffle's mask to include the index of this insert element. +static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) { + // Check if the vector operand of this insert is an identity shuffle. + auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0)); + if (!Shuf || !isa<UndefValue>(Shuf->getOperand(1)) || + !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding())) + return nullptr; + + // Check for a constant insertion index. + uint64_t IdxC; + if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC))) + return nullptr; + + // Check if this insert's scalar op is extracted from the identity shuffle's + // input vector. + Value *Scalar = InsElt.getOperand(1); + Value *X = Shuf->getOperand(0); + if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC)))) + return nullptr; + + // Replace the shuffle mask element at the index of this extract+insert with + // that same index value. 
@@ -766,6 +845,55 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) {
   return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask);
 }
 
+/// Try to fold an extract+insert element into an existing identity shuffle by
+/// changing the shuffle's mask to include the index of this insert element.
+static Instruction *foldInsEltIntoIdentityShuffle(InsertElementInst &InsElt) {
+  // Check if the vector operand of this insert is an identity shuffle.
+  auto *Shuf = dyn_cast<ShuffleVectorInst>(InsElt.getOperand(0));
+  if (!Shuf || !isa<UndefValue>(Shuf->getOperand(1)) ||
+      !(Shuf->isIdentityWithExtract() || Shuf->isIdentityWithPadding()))
+    return nullptr;
+
+  // Check for a constant insertion index.
+  uint64_t IdxC;
+  if (!match(InsElt.getOperand(2), m_ConstantInt(IdxC)))
+    return nullptr;
+
+  // Check if this insert's scalar op is extracted from the identity shuffle's
+  // input vector.
+  Value *Scalar = InsElt.getOperand(1);
+  Value *X = Shuf->getOperand(0);
+  if (!match(Scalar, m_ExtractElement(m_Specific(X), m_SpecificInt(IdxC))))
+    return nullptr;
+
+  // Replace the shuffle mask element at the index of this extract+insert with
+  // that same index value.
+  // For example:
+  // inselt (shuf X, IdMask), (extelt X, IdxC), IdxC --> shuf X, IdMask'
+  unsigned NumMaskElts = Shuf->getType()->getVectorNumElements();
+  SmallVector<Constant *, 16> NewMaskVec(NumMaskElts);
+  Type *I32Ty = IntegerType::getInt32Ty(Shuf->getContext());
+  Constant *NewMaskEltC = ConstantInt::get(I32Ty, IdxC);
+  Constant *OldMask = Shuf->getMask();
+  for (unsigned i = 0; i != NumMaskElts; ++i) {
+    if (i != IdxC) {
+      // All mask elements besides the inserted element remain the same.
+      NewMaskVec[i] = OldMask->getAggregateElement(i);
+    } else if (OldMask->getAggregateElement(i) == NewMaskEltC) {
+      // If the mask element was already set, there's nothing to do
+      // (demanded elements analysis may unset it later).
+      return nullptr;
+    } else {
+      assert(isa<UndefValue>(OldMask->getAggregateElement(i)) &&
+             "Unexpected shuffle mask element for identity shuffle");
+      NewMaskVec[i] = NewMaskEltC;
+    }
+  }
+
+  Constant *NewMask = ConstantVector::get(NewMaskVec);
+  return new ShuffleVectorInst(X, Shuf->getOperand(1), NewMask);
+}
+
 /// If we have an insertelement instruction feeding into another insertelement
 /// and the 2nd is inserting a constant into the vector, canonicalize that
 /// constant insertion before the insertion of a variable:
@@ -987,6 +1115,9 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) {
   if (Instruction *Splat = foldInsEltIntoSplat(IE))
     return Splat;
 
+  if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE))
+    return IdentityShuf;
+
   return nullptr;
 }
 
@@ -1009,17 +1140,23 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
   if (Depth == 0) return false;
 
   switch (I->getOpcode()) {
+    case Instruction::UDiv:
+    case Instruction::SDiv:
+    case Instruction::URem:
+    case Instruction::SRem:
+      // Propagating an undefined shuffle mask element to integer div/rem is not
+      // allowed because those opcodes can create immediate undefined behavior
+      // from an undefined element in an operand.
+      if (llvm::any_of(Mask, [](int M){ return M == -1; }))
+        return false;
+      LLVM_FALLTHROUGH;
    case Instruction::Add:
     case Instruction::FAdd:
     case Instruction::Sub:
     case Instruction::FSub:
     case Instruction::Mul:
     case Instruction::FMul:
-    case Instruction::UDiv:
-    case Instruction::SDiv:
     case Instruction::FDiv:
-    case Instruction::URem:
-    case Instruction::SRem:
     case Instruction::FRem:
     case Instruction::Shl:
     case Instruction::LShr:
@@ -1040,9 +1177,7 @@ static bool canEvaluateShuffled(Value *V, ArrayRef<int> Mask,
     case Instruction::FPExt:
     case Instruction::GetElementPtr: {
       // Bail out if we would create longer vector ops. We could allow creating
-      // longer vector ops, but that may result in more expensive codegen. We
-      // would also need to limit the transform to avoid undefined behavior for
-      // integer div/rem.
+      // longer vector ops, but that may result in more expensive codegen.
       Type *ITy = I->getType();
       if (ITy->isVectorTy() && Mask.size() > ITy->getVectorNumElements())
         return false;
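The div/rem restriction added to canEvaluateShuffled above exists because an undef shuffle-mask element becomes an undef lane in the shuffled operand, and an undef divisor is immediate undefined behavior, unlike an undef operand to add or mul. A standalone sketch of that guard (plain C++; the enum and function name are illustrative only):

    #include <algorithm>
    #include <vector>

    enum class Op { Add, Mul, FDiv, UDiv, SDiv, URem, SRem };

    // Integer division and remainder refuse masks containing the -1
    // (undef) sentinel; every other opcode tolerates undef lanes.
    bool maskIsSafeFor(Op O, const std::vector<int> &Mask) {
      bool IsIntDivRem = O == Op::UDiv || O == Op::SDiv ||
                         O == Op::URem || O == Op::SRem;
      if (!IsIntDivRem)
        return true;
      return std::none_of(Mask.begin(), Mask.end(),
                          [](int M) { return M == -1; });
    }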
diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp
index 385f4926b845..ecb486c544e0 100644
--- a/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -200,8 +200,8 @@ bool InstCombiner::shouldChangeType(Type *From, Type *To) const {
 // where both B and C should be ConstantInts, results in a constant that does
 // not overflow. This function only handles the Add and Sub opcodes. For
 // all other opcodes, the function conservatively returns false.
-static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
-  OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+static bool maintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
+  auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
   if (!OBO || !OBO->hasNoSignedWrap())
     return false;
 
@@ -224,10 +224,15 @@ static bool MaintainNoSignedWrap(BinaryOperator &I, Value *B, Value *C) {
 }
 
 static bool hasNoUnsignedWrap(BinaryOperator &I) {
-  OverflowingBinaryOperator *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+  auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
   return OBO && OBO->hasNoUnsignedWrap();
 }
 
+static bool hasNoSignedWrap(BinaryOperator &I) {
+  auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I);
+  return OBO && OBO->hasNoSignedWrap();
+}
+
 /// Conservatively clears subclassOptionalData after a reassociation or
 /// commutation. We preserve fast-math flags when applicable as they can be
 /// preserved.
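maintainNoSignedWrap, together with the new hasNoSignedWrap helper, decides whether "(A op B) op C" can keep nsw after being reassociated to "A op (B op C)": the constant fold of B and C must not itself wrap, and only Add and Sub are analyzed. The heart of the Add case, sketched with the Clang/GCC overflow builtin on a fixed 64-bit width (an assumption; the real code works on APInt at the type's actual width):

    #include <cstdint>

    // nsw survives reassociation of the Add case only if the folded
    // constant B + C stays representable in the signed domain.
    bool addFoldKeepsNSW(std::int64_t B, std::int64_t C) {
      std::int64_t Folded;
      return !__builtin_add_overflow(B, C, &Folded);
    }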
@@ -332,22 +337,21 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) {
         // It simplifies to V.  Form "A op V".
         I.setOperand(0, A);
         I.setOperand(1, V);
-        // Conservatively clear the optional flags, since they may not be
-        // preserved by the reassociation.
         bool IsNUW = hasNoUnsignedWrap(I) && hasNoUnsignedWrap(*Op0);
-        bool IsNSW = MaintainNoSignedWrap(I, B, C);
+        bool IsNSW = maintainNoSignedWrap(I, B, C) && hasNoSignedWrap(*Op0);
+        // Conservatively clear all optional flags since they may not be
+        // preserved by the reassociation. Reset nsw/nuw based on the above
+        // analysis.
         ClearSubclassDataAfterReassociation(I);
+        // Note: this is only valid because SimplifyBinOp doesn't look at
+        // the operands to Op0.
         if (IsNUW)
           I.setHasNoUnsignedWrap(true);
-        if (IsNSW &&
-            (!Op0 || (isa<BinaryOperator>(Op0) && Op0->hasNoSignedWrap()))) {
-          // Note: this is only valid because SimplifyBinOp doesn't look at
-          // the operands to Op0.
+        if (IsNSW)
           I.setHasNoSignedWrap(true);
-        }
 
         Changed = true;
         ++NumReassoc;
@@ -610,7 +614,6 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I,
       HasNUW &= ROBO->hasNoUnsignedWrap();
     }
 
-    const APInt *CInt;
     if (TopLevelOpcode == Instruction::Add &&
         InnerOpcode == Instruction::Mul) {
       // We can propagate 'nsw' if we know that
@@ -620,6 +623,7 @@ Value *InstCombiner::tryFactorization(BinaryOperator &I,
       //      %Z = mul nsw i16 %X, C+1
       //
       // iff C+1 isn't INT_MIN
+      const APInt *CInt;
       if (match(V, m_APInt(CInt))) {
         if (!CInt->isMinSignedValue())
           BO->setHasNoSignedWrap(HasNSW);
@@ -763,12 +767,16 @@ Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I,
   if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) &&
       match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) {
     bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse();
+
+    FastMathFlags FMF;
     BuilderTy::FastMathFlagGuard Guard(Builder);
-    if (isa<FPMathOperator>(&I))
-      Builder.setFastMathFlags(I.getFastMathFlags());
+    if (isa<FPMathOperator>(&I)) {
+      FMF = I.getFastMathFlags();
+      Builder.setFastMathFlags(FMF);
+    }
 
-    Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I));
-    Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I));
+    Value *V1 = SimplifyBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I));
+    Value *V2 = SimplifyBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I));
     if (V1 && V2)
       SI = Builder.CreateSelect(A, V2, V1);
     else if (V2 && SelectsHaveOneUse)
@@ -1659,7 +1667,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
       // to an index of zero, so replace it with zero if it is not zero already.
       Type *EltTy = GTI.getIndexedType();
       if (EltTy->isSized() && DL.getTypeAllocSize(EltTy) == 0)
-        if (!isa<Constant>(*I) || !cast<Constant>(*I)->isNullValue()) {
+        if (!isa<Constant>(*I) || !match(I->get(), m_Zero())) {
           *I = Constant::getNullValue(NewIndexType);
           MadeChange = true;
         }
@@ -2549,9 +2557,7 @@ Instruction *InstCombiner::visitReturnInst(ReturnInst &RI) {
 Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
   // Change br (not X), label True, label False to: br X, label False, True
   Value *X = nullptr;
-  BasicBlock *TrueDest;
-  BasicBlock *FalseDest;
-  if (match(&BI, m_Br(m_Not(m_Value(X)), TrueDest, FalseDest)) &&
+  if (match(&BI, m_Br(m_Not(m_Value(X)), m_BasicBlock(), m_BasicBlock())) &&
       !isa<Constant>(X)) {
     // Swap Destinations and condition...
     BI.setCondition(X);
@@ -2569,8 +2575,8 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) {
 
   // Canonicalize, for example, icmp_ne -> icmp_eq or fcmp_one -> fcmp_oeq.
   CmpInst::Predicate Pred;
-  if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())), TrueDest,
-                      FalseDest)) &&
+  if (match(&BI, m_Br(m_OneUse(m_Cmp(Pred, m_Value(), m_Value())),
+                      m_BasicBlock(), m_BasicBlock())) &&
      !isCanonicalPredicate(Pred)) {
     // Swap destinations and condition.
     CmpInst *Cond = cast<CmpInst>(BI.getCondition());
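The SimplifySelectsFeedingBinaryOp hunk above folds two selects that share a condition through the binary operator, now threading the instruction's fast-math flags into both SimplifyBinOp queries so floating-point arms simplify under the right assumptions. The integer shape of the fold, as a trivial sketch:

    // (a ? b : c) + (a ? d : e)  ==>  a ? (b + d) : (c + e)
    // One select survives, and each arm can then simplify on its own.
    int addOfSelects(bool a, int b, int c, int d, int e) {
      return a ? (b + d) : (c + e);
    }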
@@ -3156,6 +3162,21 @@ static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) {
   findDbgUsers(DbgUsers, I);
   for (auto *DII : reverse(DbgUsers)) {
     if (DII->getParent() == SrcBlock) {
+      if (isa<DbgDeclareInst>(DII)) {
+        // A dbg.declare instruction should not be cloned, since there can
+        // only be one per variable fragment. It should be left in the
+        // original place, because the sunk instruction is not an alloca
+        // (otherwise we could not be here).
+        // But we need to update the arguments of the dbg.declare so that it
+        // does not point at the sunk instruction.
+        if (!isa<CastInst>(I))
+          continue; // dbg.declare points at something it shouldn't
+
+        DII->setOperand(
+            0, MetadataAsValue::get(I->getContext(),
+                                    ValueAsMetadata::get(I->getOperand(0))));
+        continue;
+      }
+
       // dbg.value is in the same basic block as the sunk inst, see if we can
       // salvage it. Clone a new copy of the instruction: on success we need
       // both salvaged and unsalvaged copies.
@@ -3580,7 +3601,7 @@ bool InstructionCombiningPass::runOnFunction(Function &F) {
   // Required analyses.
   auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
-  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+  auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
   auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
   auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();