diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2020-01-17 20:45:01 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2020-01-17 20:45:01 +0000 |
| commit | 706b4fc47bbc608932d3b491ae19a3b9cde9497b (patch) | |
| tree | 4adf86a776049cbf7f69a1929c4babcbbef925eb /llvm/lib/Transforms/InstCombine | |
| parent | 7cc9cf2bf09f069cb2dd947ead05d0b54301fb71 (diff) | |
Notes
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
14 files changed, 1087 insertions, 308 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 8bc34825f8a7..ec976a971e3c 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -890,6 +890,10 @@ Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { if (match(Op0, m_ZExt(m_Value(X))) && X->getType()->getScalarSizeInBits() == 1) return SelectInst::Create(X, AddOne(Op1C), Op1); + // sext(bool) + C -> bool ? C - 1 : C + if (match(Op0, m_SExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, SubOne(Op1C), Op1); // ~X + C --> (C-1) - X if (match(Op0, m_Not(m_Value(X)))) @@ -1288,12 +1292,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { return BinaryOperator::CreateSub(RHS, A); } - // Canonicalize sext to zext for better value tracking potential. - // add A, sext(B) --> sub A, zext(B) - if (match(&I, m_c_Add(m_Value(A), m_OneUse(m_SExt(m_Value(B))))) && - B->getType()->isIntOrIntVectorTy(1)) - return BinaryOperator::CreateSub(A, Builder.CreateZExt(B, Ty)); - // A + -B --> A - B if (match(RHS, m_Neg(m_Value(B)))) return BinaryOperator::CreateSub(LHS, B); @@ -1587,7 +1585,7 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, - Type *Ty) { + Type *Ty, bool IsNUW) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize // this. bool Swapped = false; @@ -1655,6 +1653,15 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, // Emit the offset of the GEP and an intptr_t. Value *Result = EmitGEPOffset(GEP1); + // If this is a single inbounds GEP and the original sub was nuw, + // then the final multiplication is also nuw. 
We match an extra add zero + // here, because that's what EmitGEPOffset() generates. + Instruction *I; + if (IsNUW && !GEP2 && !Swapped && GEP1->isInBounds() && + match(Result, m_Add(m_Instruction(I), m_Zero())) && + I->getOpcode() == Instruction::Mul) + I->setHasNoUnsignedWrap(); + // If we had a constant expression GEP on the other side offsetting the // pointer, subtract it from the offset we have. if (GEP2) { @@ -1881,6 +1888,74 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Y, Builder.CreateNot(Op1, Op1->getName() + ".not")); } + { + // (sub (and Op1, (neg X)), Op1) --> neg (and Op1, (add X, -1)) + Value *X; + if (match(Op0, m_OneUse(m_c_And(m_Specific(Op1), + m_OneUse(m_Neg(m_Value(X))))))) { + return BinaryOperator::CreateNeg(Builder.CreateAnd( + Op1, Builder.CreateAdd(X, Constant::getAllOnesValue(I.getType())))); + } + } + + { + // (sub (and Op1, C), Op1) --> neg (and Op1, ~C) + Constant *C; + if (match(Op0, m_OneUse(m_And(m_Specific(Op1), m_Constant(C))))) { + return BinaryOperator::CreateNeg( + Builder.CreateAnd(Op1, Builder.CreateNot(C))); + } + } + + { + // If we have a subtraction between some value and a select between + // said value and something else, sink subtraction into select hands, i.e.: + // sub (select %Cond, %TrueVal, %FalseVal), %Op1 + // -> + // select %Cond, (sub %TrueVal, %Op1), (sub %FalseVal, %Op1) + // or + // sub %Op0, (select %Cond, %TrueVal, %FalseVal) + // -> + // select %Cond, (sub %Op0, %TrueVal), (sub %Op0, %FalseVal) + // This will result in select between new subtraction and 0. 
+ auto SinkSubIntoSelect = + [Ty = I.getType()](Value *Select, Value *OtherHandOfSub, + auto SubBuilder) -> Instruction * { + Value *Cond, *TrueVal, *FalseVal; + if (!match(Select, m_OneUse(m_Select(m_Value(Cond), m_Value(TrueVal), + m_Value(FalseVal))))) + return nullptr; + if (OtherHandOfSub != TrueVal && OtherHandOfSub != FalseVal) + return nullptr; + // While it is really tempting to just create two subtractions and let + // InstCombine fold one of those to 0, it isn't possible to do so + // because of worklist visitation order. So ugly it is. + bool OtherHandOfSubIsTrueVal = OtherHandOfSub == TrueVal; + Value *NewSub = SubBuilder(OtherHandOfSubIsTrueVal ? FalseVal : TrueVal); + Constant *Zero = Constant::getNullValue(Ty); + SelectInst *NewSel = + SelectInst::Create(Cond, OtherHandOfSubIsTrueVal ? Zero : NewSub, + OtherHandOfSubIsTrueVal ? NewSub : Zero); + // Preserve prof metadata if any. + NewSel->copyMetadata(cast<Instruction>(*Select)); + return NewSel; + }; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op0, /*OtherHandOfSub=*/Op1, + [Builder = &Builder, Op1](Value *OtherHandOfSelect) { + return Builder->CreateSub(OtherHandOfSelect, + /*OtherHandOfSub=*/Op1); + })) + return NewSel; + if (Instruction *NewSel = SinkSubIntoSelect( + /*Select=*/Op1, /*OtherHandOfSub=*/Op0, + [Builder = &Builder, Op0](Value *OtherHandOfSelect) { + return Builder->CreateSub(/*OtherHandOfSub=*/Op0, + OtherHandOfSelect); + })) + return NewSel; + } + if (Op1->hasOneUse()) { Value *X = nullptr, *Y = nullptr, *Z = nullptr; Constant *C = nullptr; @@ -1896,14 +1971,16 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Builder.CreateNot(Y, Y->getName() + ".not")); // 0 - (X sdiv C) -> (X sdiv -C) provided the negation doesn't overflow. - // TODO: This could be extended to match arbitrary vector constants. 
- const APInt *DivC; - if (match(Op0, m_Zero()) && match(Op1, m_SDiv(m_Value(X), m_APInt(DivC))) && - !DivC->isMinSignedValue() && *DivC != 1) { - Constant *NegDivC = ConstantInt::get(I.getType(), -(*DivC)); - Instruction *BO = BinaryOperator::CreateSDiv(X, NegDivC); - BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); - return BO; + if (match(Op0, m_Zero())) { + Constant *Op11C; + if (match(Op1, m_SDiv(m_Value(X), m_Constant(Op11C))) && + !Op11C->containsUndefElement() && Op11C->isNotMinSignedValue() && + Op11C->isNotOneValue()) { + Instruction *BO = + BinaryOperator::CreateSDiv(X, ConstantExpr::getNeg(Op11C)); + BO->setIsExact(cast<BinaryOperator>(Op1)->isExact()); + return BO; + } } // 0 - (X << Y) -> (-X << Y) when X is freely negatable. @@ -1921,6 +1998,14 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Add->setHasNoSignedWrap(I.hasNoSignedWrap()); return Add; } + // sub [nsw] X, zext(bool Y) -> add [nsw] X, sext(bool Y) + // 'nuw' is dropped in favor of the canonical form. 
+ if (match(Op1, m_ZExt(m_Value(Y))) && Y->getType()->isIntOrIntVectorTy(1)) { + Value *Sext = Builder.CreateSExt(Y, I.getType()); + BinaryOperator *Add = BinaryOperator::CreateAdd(Op0, Sext); + Add->setHasNoSignedWrap(I.hasNoSignedWrap()); + return Add; + } // X - A*-B -> X + A*B // X - -A*B -> X + A*B @@ -1975,13 +2060,15 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Value *LHSOp, *RHSOp; if (match(Op0, m_PtrToInt(m_Value(LHSOp))) && match(Op1, m_PtrToInt(m_Value(RHSOp)))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + I.hasNoUnsignedWrap())) return replaceInstUsesWith(I, Res); // trunc(p)-trunc(q) -> trunc(p-q) if (match(Op0, m_Trunc(m_PtrToInt(m_Value(LHSOp)))) && match(Op1, m_Trunc(m_PtrToInt(m_Value(RHSOp))))) - if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType())) + if (Value *Res = OptimizePointerDifference(LHSOp, RHSOp, I.getType(), + /* IsNUW */ false)) return replaceInstUsesWith(I, Res); // Canonicalize a shifty way to code absolute value to the common pattern. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 4a30b60ca931..cc0a9127f8b1 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -3279,6 +3279,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { NotLHS, NotRHS); } } + + // Pull 'not' into operands of select if both operands are one-use compares. + // Inverting the predicates eliminates the 'not' operation. + // Example: + // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> + // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) + // TODO: Canonicalize by hoisting 'not' into an arm of the select if only + // 1 select operand is a cmp? 
+ if (auto *Sel = dyn_cast<SelectInst>(Op0)) { + auto *CmpT = dyn_cast<CmpInst>(Sel->getTrueValue()); + auto *CmpF = dyn_cast<CmpInst>(Sel->getFalseValue()); + if (CmpT && CmpF && CmpT->hasOneUse() && CmpF->hasOneUse()) { + CmpT->setPredicate(CmpT->getInversePredicate()); + CmpF->setPredicate(CmpF->getInversePredicate()); + return replaceInstUsesWith(I, Sel); + } + } } if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index c650d242cd50..f463c5fa1138 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -40,6 +40,12 @@ #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsX86.h" +#include "llvm/IR/IntrinsicsARM.h" +#include "llvm/IR/IntrinsicsAArch64.h" +#include "llvm/IR/IntrinsicsNVPTX.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsPowerPC.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" @@ -2279,6 +2285,35 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::copysign: { + if (SignBitMustBeZero(II->getArgOperand(1), &TLI)) { + // If we know that the sign argument is positive, reduce to FABS: + // copysign X, Pos --> fabs X + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, + II->getArgOperand(0), II); + return replaceInstUsesWith(*II, Fabs); + } + // TODO: There should be a ValueTracking sibling like SignBitMustBeOne. 
+ const APFloat *C; + if (match(II->getArgOperand(1), m_APFloat(C)) && C->isNegative()) { + // If we know that the sign argument is negative, reduce to FNABS: + // copysign X, Neg --> fneg (fabs X) + Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, + II->getArgOperand(0), II); + return replaceInstUsesWith(*II, Builder.CreateFNegFMF(Fabs, II)); + } + + // Propagate sign argument through nested calls: + // copysign X, (copysign ?, SignArg) --> copysign X, SignArg + Value *SignArg; + if (match(II->getArgOperand(1), + m_Intrinsic<Intrinsic::copysign>(m_Value(), m_Value(SignArg)))) { + II->setArgOperand(1, SignArg); + return II; + } + + break; + } case Intrinsic::fabs: { Value *Cond; Constant *LHS, *RHS; @@ -2452,6 +2487,64 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // TODO should we convert this to an AND if the RHS is constant? } break; + case Intrinsic::x86_bmi_pext_32: + case Intrinsic::x86_bmi_pext_64: + if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + if (MaskC->isNullValue()) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + if (MaskC->isAllOnesValue()) + return replaceInstUsesWith(CI, II->getArgOperand(0)); + + if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Src = SrcC->getZExtValue(); + uint64_t Mask = MaskC->getZExtValue(); + uint64_t Result = 0; + uint64_t BitToSet = 1; + + while (Mask) { + // Isolate lowest set bit. + uint64_t BitToTest = Mask & -Mask; + if (BitToTest & Src) + Result |= BitToSet; + + BitToSet <<= 1; + // Clear lowest set bit. 
+ Mask &= Mask - 1; + } + + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + } + break; + case Intrinsic::x86_bmi_pdep_32: + case Intrinsic::x86_bmi_pdep_64: + if (auto *MaskC = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + if (MaskC->isNullValue()) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + if (MaskC->isAllOnesValue()) + return replaceInstUsesWith(CI, II->getArgOperand(0)); + + if (auto *SrcC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Src = SrcC->getZExtValue(); + uint64_t Mask = MaskC->getZExtValue(); + uint64_t Result = 0; + uint64_t BitToTest = 1; + + while (Mask) { + // Isolate lowest set bit. + uint64_t BitToSet = Mask & -Mask; + if (BitToTest & Src) + Result |= BitToSet; + + BitToTest <<= 1; + // Clear lowest set bit. + Mask &= Mask - 1; + } + + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + } + break; case Intrinsic::x86_vcvtph2ps_128: case Intrinsic::x86_vcvtph2ps_256: { @@ -3308,6 +3401,60 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::arm_mve_pred_i2v: { + Value *Arg = II->getArgOperand(0); + Value *ArgArg; + if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg))) && + II->getType() == ArgArg->getType()) + return replaceInstUsesWith(*II, ArgArg); + Constant *XorMask; + if (match(Arg, + m_Xor(m_Intrinsic<Intrinsic::arm_mve_pred_v2i>(m_Value(ArgArg)), + m_Constant(XorMask))) && + II->getType() == ArgArg->getType()) { + if (auto *CI = dyn_cast<ConstantInt>(XorMask)) { + if (CI->getValue().trunc(16).isAllOnesValue()) { + auto TrueVector = Builder.CreateVectorSplat( + II->getType()->getVectorNumElements(), Builder.getTrue()); + return BinaryOperator::Create(Instruction::Xor, ArgArg, TrueVector); + } + } + } + KnownBits ScalarKnown(32); + if (SimplifyDemandedBits(II, 0, APInt::getLowBitsSet(32, 16), + ScalarKnown, 0)) + return II; + break; + } + case Intrinsic::arm_mve_pred_v2i: { + Value 
*Arg = II->getArgOperand(0); + Value *ArgArg; + if (match(Arg, m_Intrinsic<Intrinsic::arm_mve_pred_i2v>(m_Value(ArgArg)))) + return replaceInstUsesWith(*II, ArgArg); + if (!II->getMetadata(LLVMContext::MD_range)) { + Type *IntTy32 = Type::getInt32Ty(II->getContext()); + Metadata *M[] = { + ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0)), + ConstantAsMetadata::get(ConstantInt::get(IntTy32, 0xFFFF)) + }; + II->setMetadata(LLVMContext::MD_range, MDNode::get(II->getContext(), M)); + return II; + } + break; + } + case Intrinsic::arm_mve_vadc: + case Intrinsic::arm_mve_vadc_predicated: { + unsigned CarryOp = + (II->getIntrinsicID() == Intrinsic::arm_mve_vadc_predicated) ? 3 : 2; + assert(II->getArgOperand(CarryOp)->getType()->getScalarSizeInBits() == 32 && + "Bad type for intrinsic!"); + + KnownBits CarryKnown(32); + if (SimplifyDemandedBits(II, CarryOp, APInt::getOneBitSet(32, 29), + CarryKnown)) + return II; + break; + } case Intrinsic::amdgcn_rcp: { Value *Src = II->getArgOperand(0); @@ -3317,7 +3464,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (const ConstantFP *C = dyn_cast<ConstantFP>(Src)) { const APFloat &ArgVal = C->getValueAPF(); - APFloat Val(ArgVal.getSemantics(), 1.0); + APFloat Val(ArgVal.getSemantics(), 1); APFloat::opStatus Status = Val.divide(ArgVal, APFloat::rmNearestTiesToEven); // Only do this if it was exact and therefore not dependent on the @@ -3872,7 +4019,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return eraseInstFromFunction(CI); // Bail if we cross over an intrinsic with side effects, such as - // llvm.stacksave, llvm.read_register, or llvm.setjmp. + // llvm.stacksave, or llvm.read_register. if (II2->mayHaveSideEffects()) { CannotRemove = true; break; @@ -4019,12 +4166,12 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { // Is this guard followed by another guard? We scan forward over a small // fixed window of instructions to handle common cases with conditions // computed between guards. 
- Instruction *NextInst = II->getNextNode(); + Instruction *NextInst = II->getNextNonDebugInstruction(); for (unsigned i = 0; i < GuardWideningWindow; i++) { // Note: Using context-free form to avoid compile time blow up if (!isSafeToSpeculativelyExecute(NextInst)) break; - NextInst = NextInst->getNextNode(); + NextInst = NextInst->getNextNonDebugInstruction(); } Value *NextCond = nullptr; if (match(NextInst, @@ -4032,18 +4179,18 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { Value *CurrCond = II->getArgOperand(0); // Remove a guard that it is immediately preceded by an identical guard. - if (CurrCond == NextCond) - return eraseInstFromFunction(*NextInst); - // Otherwise canonicalize guard(a); guard(b) -> guard(a & b). - Instruction* MoveI = II->getNextNode(); - while (MoveI != NextInst) { - auto *Temp = MoveI; - MoveI = MoveI->getNextNode(); - Temp->moveBefore(II); + if (CurrCond != NextCond) { + Instruction *MoveI = II->getNextNonDebugInstruction(); + while (MoveI != NextInst) { + auto *Temp = MoveI; + MoveI = MoveI->getNextNonDebugInstruction(); + Temp->moveBefore(II); + } + II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); } - II->setArgOperand(0, Builder.CreateAnd(CurrCond, NextCond)); - return eraseInstFromFunction(*NextInst); + eraseInstFromFunction(*NextInst); + return II; } break; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 65aaef28d87a..71b7f279e5fa 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -18,6 +18,7 @@ #include "llvm/IR/DIBuilder.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" +#include <numeric> using namespace llvm; using namespace PatternMatch; @@ -843,33 +844,33 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { return nullptr; } -Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, +Instruction 
*InstCombiner::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, bool DoTransform) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. const APInt *Op1CV; - if (match(ICI->getOperand(1), m_APInt(Op1CV))) { + if (match(Cmp->getOperand(1), m_APInt(Op1CV))) { // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { - if (!DoTransform) return ICI; + if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || + (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { + if (!DoTransform) return Cmp; - Value *In = ICI->getOperand(0); + Value *In = Cmp->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), In->getType()->getScalarSizeInBits() - 1); In = Builder.CreateLShr(In, Sh, In->getName() + ".lobit"); - if (In->getType() != CI.getType()) - In = Builder.CreateIntCast(In, CI.getType(), false /*ZExt*/); + if (In->getType() != Zext.getType()) + In = Builder.CreateIntCast(In, Zext.getType(), false /*ZExt*/); - if (ICI->getPredicate() == ICmpInst::ICMP_SGT) { + if (Cmp->getPredicate() == ICmpInst::ICMP_SGT) { Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One, In->getName() + ".not"); } - return replaceInstUsesWith(CI, In); + return replaceInstUsesWith(Zext, In); } // zext (X == 0) to i32 --> X^1 iff X has only the low bit set. @@ -882,24 +883,24 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. 
if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && // This only works for EQ and NE - ICI->isEquality()) { + Cmp->isEquality()) { // If Op1C some other power of two, convert: - KnownBits Known = computeKnownBits(ICI->getOperand(0), 0, &CI); + KnownBits Known = computeKnownBits(Cmp->getOperand(0), 0, &Zext); APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? - if (!DoTransform) return ICI; + if (!DoTransform) return Cmp; - bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; + bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(CI.getType(), isNE); - return replaceInstUsesWith(CI, Res); + Constant *Res = ConstantInt::get(Zext.getType(), isNE); + return replaceInstUsesWith(Zext, Res); } uint32_t ShAmt = KnownZeroMask.logBase2(); - Value *In = ICI->getOperand(0); + Value *In = Cmp->getOperand(0); if (ShAmt) { // Perform a logical shr by shiftamt. // Insert the shift to put the result in the low bit. @@ -912,11 +913,11 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, In = Builder.CreateXor(In, One); } - if (CI.getType() == In->getType()) - return replaceInstUsesWith(CI, In); + if (Zext.getType() == In->getType()) + return replaceInstUsesWith(Zext, In); - Value *IntCast = Builder.CreateIntCast(In, CI.getType(), false); - return replaceInstUsesWith(CI, IntCast); + Value *IntCast = Builder.CreateIntCast(In, Zext.getType(), false); + return replaceInstUsesWith(Zext, IntCast); } } } @@ -924,19 +925,19 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // icmp ne A, B is equal to xor A, B when A and B only really have one bit. // It is also profitable to transform icmp eq into not(xor(A, B)) because that // may lead to additional simplifications. 
- if (ICI->isEquality() && CI.getType() == ICI->getOperand(0)->getType()) { - if (IntegerType *ITy = dyn_cast<IntegerType>(CI.getType())) { - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); + if (Cmp->isEquality() && Zext.getType() == Cmp->getOperand(0)->getType()) { + if (IntegerType *ITy = dyn_cast<IntegerType>(Zext.getType())) { + Value *LHS = Cmp->getOperand(0); + Value *RHS = Cmp->getOperand(1); - KnownBits KnownLHS = computeKnownBits(LHS, 0, &CI); - KnownBits KnownRHS = computeKnownBits(RHS, 0, &CI); + KnownBits KnownLHS = computeKnownBits(LHS, 0, &Zext); + KnownBits KnownRHS = computeKnownBits(RHS, 0, &Zext); if (KnownLHS.Zero == KnownRHS.Zero && KnownLHS.One == KnownRHS.One) { APInt KnownBits = KnownLHS.Zero | KnownLHS.One; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { - if (!DoTransform) return ICI; + if (!DoTransform) return Cmp; Value *Result = Builder.CreateXor(LHS, RHS); @@ -949,10 +950,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, Result = Builder.CreateLShr( Result, ConstantInt::get(ITy, UnknownBit.countTrailingZeros())); - if (ICI->getPredicate() == ICmpInst::ICMP_EQ) + if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) Result = Builder.CreateXor(Result, ConstantInt::get(ITy, 1)); - Result->takeName(ICI); - return replaceInstUsesWith(CI, Result); + Result->takeName(Cmp); + return replaceInstUsesWith(Zext, Result); } } } @@ -1172,8 +1173,8 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { } } - if (ICmpInst *ICI = dyn_cast<ICmpInst>(Src)) - return transformZExtICmp(ICI, CI); + if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) + return transformZExtICmp(Cmp, CI); BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); if (SrcI && SrcI->getOpcode() == Instruction::Or) { @@ -1188,7 +1189,9 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName()); Value *RCast 
= Builder.CreateZExt(RHS, CI.getType(), RHS->getName()); - BinaryOperator *Or = BinaryOperator::Create(Instruction::Or, LCast, RCast); + Value *Or = Builder.CreateOr(LCast, RCast, CI.getName()); + if (auto *OrInst = dyn_cast<Instruction>(Or)) + Builder.SetInsertPoint(OrInst); // Perform the elimination. if (auto *LZExt = dyn_cast<ZExtInst>(LCast)) @@ -1196,7 +1199,7 @@ Instruction *InstCombiner::visitZExt(ZExtInst &CI) { if (auto *RZExt = dyn_cast<ZExtInst>(RCast)) transformZExtICmp(RHS, *RZExt); - return Or; + return replaceInstUsesWith(CI, Or); } } @@ -1621,6 +1624,11 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { Value *X; Instruction *Op = dyn_cast<Instruction>(FPT.getOperand(0)); if (Op && Op->hasOneUse()) { + // FIXME: The FMF should propagate from the fptrunc, not the source op. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + if (isa<FPMathOperator>(Op)) + Builder.setFastMathFlags(Op->getFastMathFlags()); + if (match(Op, m_FNeg(m_Value(X)))) { Value *InnerTrunc = Builder.CreateFPTrunc(X, Ty); @@ -1630,6 +1638,24 @@ Instruction *InstCombiner::visitFPTrunc(FPTruncInst &FPT) { return BinaryOperator::CreateFNegFMF(InnerTrunc, Op); return UnaryOperator::CreateFNegFMF(InnerTrunc, Op); } + + // If we are truncating a select that has an extended operand, we can + // narrow the other operand and do the select as a narrow op. 
+ Value *Cond, *X, *Y; + if (match(Op, m_Select(m_Value(Cond), m_FPExt(m_Value(X)), m_Value(Y))) && + X->getType() == Ty) { + // fptrunc (select Cond, (fpext X), Y --> select Cond, X, (fptrunc Y) + Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); + Value *Sel = Builder.CreateSelect(Cond, X, NarrowY, "narrow.sel", Op); + return replaceInstUsesWith(FPT, Sel); + } + if (match(Op, m_Select(m_Value(Cond), m_Value(Y), m_FPExt(m_Value(X)))) && + X->getType() == Ty) { + // fptrunc (select Cond, Y, (fpext X) --> select Cond, (fptrunc Y), X + Value *NarrowY = Builder.CreateFPTrunc(Y, Ty); + Value *Sel = Builder.CreateSelect(Cond, NarrowY, X, "narrow.sel", Op); + return replaceInstUsesWith(FPT, Sel); + } } if (auto *II = dyn_cast<IntrinsicInst>(FPT.getOperand(0))) { @@ -1808,7 +1834,7 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { Type *Ty = CI.getType(); unsigned AS = CI.getPointerAddressSpace(); - if (Ty->getScalarSizeInBits() == DL.getIndexSizeInBits(AS)) + if (Ty->getScalarSizeInBits() == DL.getPointerSizeInBits(AS)) return commonPointerCastTransforms(CI); Type *PtrTy = DL.getIntPtrType(CI.getContext(), AS); @@ -1820,12 +1846,24 @@ Instruction *InstCombiner::visitPtrToInt(PtrToIntInst &CI) { } /// This input value (which is known to have vector type) is being zero extended -/// or truncated to the specified vector type. +/// or truncated to the specified vector type. Since the zext/trunc is done +/// using an integer type, we have a (bitcast(cast(bitcast))) pattern, +/// endianness will impact which end of the vector that is extended or +/// truncated. +/// +/// A vector is always stored with index 0 at the lowest address, which +/// corresponds to the most significant bits for a big endian stored integer and +/// the least significant bits for little endian. A trunc/zext of an integer +/// impacts the big end of the integer. 
Thus, we need to add/remove elements at +/// the front of the vector for big endian targets, and the back of the vector +/// for little endian targets. +/// /// Try to replace it with a shuffle (and vector/vector bitcast) if possible. /// /// The source and destination vector types may have different element types. -static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy, - InstCombiner &IC) { +static Instruction *optimizeVectorResizeWithIntegerBitCasts(Value *InVal, + VectorType *DestTy, + InstCombiner &IC) { // We can only do this optimization if the output is a multiple of the input // element size, or the input is a multiple of the output element size. // Convert the input type to have the same element type as the output. @@ -1844,31 +1882,53 @@ static Instruction *optimizeVectorResize(Value *InVal, VectorType *DestTy, InVal = IC.Builder.CreateBitCast(InVal, SrcTy); } + bool IsBigEndian = IC.getDataLayout().isBigEndian(); + unsigned SrcElts = SrcTy->getNumElements(); + unsigned DestElts = DestTy->getNumElements(); + + assert(SrcElts != DestElts && "Element counts should be different."); + // Now that the element types match, get the shuffle mask and RHS of the // shuffle to use, which depends on whether we're increasing or decreasing the // size of the input. - SmallVector<uint32_t, 16> ShuffleMask; + SmallVector<uint32_t, 16> ShuffleMaskStorage; + ArrayRef<uint32_t> ShuffleMask; Value *V2; - if (SrcTy->getNumElements() > DestTy->getNumElements()) { - // If we're shrinking the number of elements, just shuffle in the low - // elements from the input and use undef as the second shuffle input. - V2 = UndefValue::get(SrcTy); - for (unsigned i = 0, e = DestTy->getNumElements(); i != e; ++i) - ShuffleMask.push_back(i); + // Produce an identity shuffle mask for the src vector. 
+ ShuffleMaskStorage.resize(SrcElts); + std::iota(ShuffleMaskStorage.begin(), ShuffleMaskStorage.end(), 0); + if (SrcElts > DestElts) { + // If we're shrinking the number of elements (rewriting an integer + // truncate), just shuffle in the elements corresponding to the least + // significant bits from the input and use undef as the second shuffle + // input. + V2 = UndefValue::get(SrcTy); + // Make sure the shuffle mask selects the "least significant bits" by + // keeping elements from back of the src vector for big endian, and from the + // front for little endian. + ShuffleMask = ShuffleMaskStorage; + if (IsBigEndian) + ShuffleMask = ShuffleMask.take_back(DestElts); + else + ShuffleMask = ShuffleMask.take_front(DestElts); } else { - // If we're increasing the number of elements, shuffle in all of the - // elements from InVal and fill the rest of the result elements with zeros - // from a constant zero. + // If we're increasing the number of elements (rewriting an integer zext), + // shuffle in all of the elements from InVal. Fill the rest of the result + // elements with zeros from a constant zero. V2 = Constant::getNullValue(SrcTy); - unsigned SrcElts = SrcTy->getNumElements(); - for (unsigned i = 0, e = SrcElts; i != e; ++i) - ShuffleMask.push_back(i); - - // The excess elements reference the first element of the zero input. - for (unsigned i = 0, e = DestTy->getNumElements()-SrcElts; i != e; ++i) - ShuffleMask.push_back(SrcElts); + // Use first elt from V2 when indicating zero in the shuffle mask. + uint32_t NullElt = SrcElts; + // Extend with null values in the "most significant bits" by adding elements + // in front of the src vector for big endian, and at the back for little + // endian. 
+ unsigned DeltaElts = DestElts - SrcElts; + if (IsBigEndian) + ShuffleMaskStorage.insert(ShuffleMaskStorage.begin(), DeltaElts, NullElt); + else + ShuffleMaskStorage.append(DeltaElts, NullElt); + ShuffleMask = ShuffleMaskStorage; } return new ShuffleVectorInst(InVal, V2, @@ -2217,6 +2277,31 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { } } + // Check that each user of each old PHI node is something that we can + // rewrite, so that all of the old PHI nodes can be cleaned up afterwards. + for (auto *OldPN : OldPhiNodes) { + for (User *V : OldPN->users()) { + if (auto *SI = dyn_cast<StoreInst>(V)) { + if (!SI->isSimple() || SI->getOperand(0) != OldPN) + return nullptr; + } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { + // Verify it's a B->A cast. + Type *TyB = BCI->getOperand(0)->getType(); + Type *TyA = BCI->getType(); + if (TyA != DestTy || TyB != SrcTy) + return nullptr; + } else if (auto *PHI = dyn_cast<PHINode>(V)) { + // As long as the user is another old PHI node, then even if we don't + // rewrite it, the PHI web we're considering won't have any users + // outside itself, so it'll be dead. + if (OldPhiNodes.count(PHI) == 0) + return nullptr; + } else { + return nullptr; + } + } + } + // For each old PHI node, create a corresponding new PHI node with a type A. SmallDenseMap<PHINode *, PHINode *> NewPNodes; for (auto *OldPN : OldPhiNodes) { @@ -2234,9 +2319,14 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { if (auto *C = dyn_cast<Constant>(V)) { NewV = ConstantExpr::getBitCast(C, DestTy); } else if (auto *LI = dyn_cast<LoadInst>(V)) { - Builder.SetInsertPoint(LI->getNextNode()); - NewV = Builder.CreateBitCast(LI, DestTy); - Worklist.Add(LI); + // Explicitly perform load combine to make sure no opposing transform + // can remove the bitcast in the meantime and trigger an infinite loop. 
+ Builder.SetInsertPoint(LI); + NewV = combineLoadToNewType(*LI, DestTy); + // Remove the old load and its use in the old phi, which itself becomes + // dead once the whole transform finishes. + replaceInstUsesWith(*LI, UndefValue::get(LI->getType())); + eraseInstFromFunction(*LI); } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { NewV = BCI->getOperand(0); } else if (auto *PrevPN = dyn_cast<PHINode>(V)) { @@ -2259,26 +2349,33 @@ Instruction *InstCombiner::optimizeBitCastFromPhi(CastInst &CI, PHINode *PN) { Instruction *RetVal = nullptr; for (auto *OldPN : OldPhiNodes) { PHINode *NewPN = NewPNodes[OldPN]; - for (User *V : OldPN->users()) { + for (auto It = OldPN->user_begin(), End = OldPN->user_end(); It != End; ) { + User *V = *It; + // We may remove this user, advance to avoid iterator invalidation. + ++It; if (auto *SI = dyn_cast<StoreInst>(V)) { - if (SI->isSimple() && SI->getOperand(0) == OldPN) { - Builder.SetInsertPoint(SI); - auto *NewBC = - cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); - SI->setOperand(0, NewBC); - Worklist.Add(SI); - assert(hasStoreUsersOnly(*NewBC)); - } + assert(SI->isSimple() && SI->getOperand(0) == OldPN); + Builder.SetInsertPoint(SI); + auto *NewBC = + cast<BitCastInst>(Builder.CreateBitCast(NewPN, SrcTy)); + SI->setOperand(0, NewBC); + Worklist.Add(SI); + assert(hasStoreUsersOnly(*NewBC)); } else if (auto *BCI = dyn_cast<BitCastInst>(V)) { - // Verify it's a B->A cast. 
Type *TyB = BCI->getOperand(0)->getType(); Type *TyA = BCI->getType(); - if (TyA == DestTy && TyB == SrcTy) { - Instruction *I = replaceInstUsesWith(*BCI, NewPN); - if (BCI == &CI) - RetVal = I; - } + assert(TyA == DestTy && TyB == SrcTy); + (void) TyA; + (void) TyB; + Instruction *I = replaceInstUsesWith(*BCI, NewPN); + if (BCI == &CI) + RetVal = I; + } else if (auto *PHI = dyn_cast<PHINode>(V)) { + assert(OldPhiNodes.count(PHI) > 0); + (void) PHI; + } else { + llvm_unreachable("all uses should be handled"); } } } @@ -2374,8 +2471,8 @@ Instruction *InstCombiner::visitBitCast(BitCastInst &CI) { CastInst *SrcCast = cast<CastInst>(Src); if (BitCastInst *BCIn = dyn_cast<BitCastInst>(SrcCast->getOperand(0))) if (isa<VectorType>(BCIn->getOperand(0)->getType())) - if (Instruction *I = optimizeVectorResize(BCIn->getOperand(0), - cast<VectorType>(DestTy), *this)) + if (Instruction *I = optimizeVectorResizeWithIntegerBitCasts( + BCIn->getOperand(0), cast<VectorType>(DestTy), *this)) return I; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index a9f64feb600c..f38dc436722d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -2566,9 +2566,6 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, Type *Ty = Add->getType(); CmpInst::Predicate Pred = Cmp.getPredicate(); - if (!Add->hasOneUse()) - return nullptr; - // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE // are canonicalized to SGT/SLT/UGT/ULT. 
@@ -2602,6 +2599,9 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); } + if (!Add->hasOneUse()) + return nullptr; + // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 @@ -3364,6 +3364,23 @@ static Value *foldICmpWithLowBitMaskedVal(ICmpInst &I, llvm_unreachable("All possible folds are handled."); } + // The mask value may be a vector constant that has undefined elements. But it + // may not be safe to propagate those undefs into the new compare, so replace + // those elements by copying an existing, defined, and safe scalar constant. + Type *OpTy = M->getType(); + auto *VecC = dyn_cast<Constant>(M); + if (OpTy->isVectorTy() && VecC && VecC->containsUndefElement()) { + Constant *SafeReplacementConstant = nullptr; + for (unsigned i = 0, e = OpTy->getVectorNumElements(); i != e; ++i) { + if (!isa<UndefValue>(VecC->getAggregateElement(i))) { + SafeReplacementConstant = VecC->getAggregateElement(i); + break; + } + } + assert(SafeReplacementConstant && "Failed to find undef replacement"); + M = Constant::replaceUndefsWith(VecC, SafeReplacementConstant); + } + return Builder.CreateICmp(DstPred, X, M); } @@ -4930,7 +4947,7 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { // Get scalar or pointer size. unsigned BitWidth = Ty->isIntOrIntVectorTy() ? Ty->getScalarSizeInBits() - : DL.getIndexTypeSizeInBits(Ty->getScalarType()); + : DL.getPointerTypeSizeInBits(Ty->getScalarType()); if (!BitWidth) return nullptr; @@ -5167,6 +5184,7 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, return WillIncrement ? !C->isMaxValue(IsSigned) : !C->isMinValue(IsSigned); }; + Constant *SafeReplacementConstant = nullptr; if (auto *CI = dyn_cast<ConstantInt>(C)) { // Bail out if the constant can't be safely incremented/decremented. 
if (!ConstantIsOk(CI)) @@ -5186,12 +5204,23 @@ llvm::getFlippedStrictnessPredicateAndConstant(CmpInst::Predicate Pred, auto *CI = dyn_cast<ConstantInt>(Elt); if (!CI || !ConstantIsOk(CI)) return llvm::None; + + if (!SafeReplacementConstant) + SafeReplacementConstant = CI; } } else { // ConstantExpr? return llvm::None; } + // It may not be safe to change a compare predicate in the presence of + // undefined elements, so replace those elements with the first safe constant + // that we found. + if (C->containsUndefElement()) { + assert(SafeReplacementConstant && "Replacement constant not set"); + C = Constant::replaceUndefsWith(C, SafeReplacementConstant); + } + CmpInst::Predicate NewPred = CmpInst::getFlippedStrictnessPredicate(Pred); // Increment or decrement the constant. @@ -5374,6 +5403,36 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, return nullptr; } +// extract(uadd.with.overflow(A, B), 0) ult A +// -> extract(uadd.with.overflow(A, B), 1) +static Instruction *foldICmpOfUAddOv(ICmpInst &I) { + CmpInst::Predicate Pred = I.getPredicate(); + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + + Value *UAddOv; + Value *A, *B; + auto UAddOvResultPat = m_ExtractValue<0>( + m_Intrinsic<Intrinsic::uadd_with_overflow>(m_Value(A), m_Value(B))); + if (match(Op0, UAddOvResultPat) && + ((Pred == ICmpInst::ICMP_ULT && (Op1 == A || Op1 == B)) || + (Pred == ICmpInst::ICMP_EQ && match(Op1, m_ZeroInt()) && + (match(A, m_One()) || match(B, m_One()))) || + (Pred == ICmpInst::ICMP_NE && match(Op1, m_AllOnes()) && + (match(A, m_AllOnes()) || match(B, m_AllOnes()))))) + // extract(uadd.with.overflow(A, B), 0) < A + // extract(uadd.with.overflow(A, 1), 0) == 0 + // extract(uadd.with.overflow(A, -1), 0) != -1 + UAddOv = cast<ExtractValueInst>(Op0)->getAggregateOperand(); + else if (match(Op1, UAddOvResultPat) && + Pred == ICmpInst::ICMP_UGT && (Op0 == A || Op0 == B)) + // A > extract(uadd.with.overflow(A, B), 0) + UAddOv = cast<ExtractValueInst>(Op1)->getAggregateOperand(); + 
else + return nullptr; + + return ExtractValueInst::Create(UAddOv, 1); +} + Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { bool Changed = false; const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -5562,6 +5621,9 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpEquality(I)) return Res; + if (Instruction *Res = foldICmpOfUAddOv(I)) + return Res; + // The 'cmpxchg' instruction returns an aggregate containing the old value and // an i1 which indicates whether or not we successfully did the swap. // diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h index 1dbc06d92e7a..1a746cb87abb 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h @@ -369,7 +369,8 @@ public: Instruction *visitFNeg(UnaryOperator &I); Instruction *visitAdd(BinaryOperator &I); Instruction *visitFAdd(BinaryOperator &I); - Value *OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty); + Value *OptimizePointerDifference( + Value *LHS, Value *RHS, Type *Ty, bool isNUW); Instruction *visitSub(BinaryOperator &I); Instruction *visitFSub(BinaryOperator &I); Instruction *visitMul(BinaryOperator &I); @@ -446,6 +447,7 @@ public: Instruction *visitLandingPadInst(LandingPadInst &LI); Instruction *visitVAStartInst(VAStartInst &I); Instruction *visitVACopyInst(VACopyInst &I); + Instruction *visitFreeze(FreezeInst &I); /// Specify what to return for unhandled instructions. Instruction *visitInstruction(Instruction &I) { return nullptr; } @@ -465,6 +467,9 @@ public: /// \return true if successful. 
bool replacePointer(Instruction &I, Value *V); + LoadInst *combineLoadToNewType(LoadInst &LI, Type *NewTy, + const Twine &Suffix = ""); + private: bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const; bool shouldChangeType(Type *From, Type *To) const; @@ -705,7 +710,7 @@ public: Instruction *eraseInstFromFunction(Instruction &I) { LLVM_DEBUG(dbgs() << "IC: ERASE " << I << '\n'); assert(I.use_empty() && "Cannot erase instruction that is used!"); - salvageDebugInfo(I); + salvageDebugInfoOrMarkUndef(I); // Make sure that we reprocess all operands now that we reduced their // use counts. diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 3a0e05832fcb..ebf9d24eecc4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -449,8 +449,8 @@ static bool isSupportedAtomicType(Type *Ty) { /// /// Note that this will create all of the instructions with whatever insert /// point the \c InstCombiner currently is using. 
-static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewTy, - const Twine &Suffix = "") { +LoadInst *InstCombiner::combineLoadToNewType(LoadInst &LI, Type *NewTy, + const Twine &Suffix) { assert((!LI.isAtomic() || isSupportedAtomicType(NewTy)) && "can't fold an atomic load to requested type"); @@ -460,10 +460,17 @@ static LoadInst *combineLoadToNewType(InstCombiner &IC, LoadInst &LI, Type *NewT if (!(match(Ptr, m_BitCast(m_Value(NewPtr))) && NewPtr->getType()->getPointerElementType() == NewTy && NewPtr->getType()->getPointerAddressSpace() == AS)) - NewPtr = IC.Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); + NewPtr = Builder.CreateBitCast(Ptr, NewTy->getPointerTo(AS)); - LoadInst *NewLoad = IC.Builder.CreateAlignedLoad( - NewTy, NewPtr, LI.getAlignment(), LI.isVolatile(), LI.getName() + Suffix); + unsigned Align = LI.getAlignment(); + if (!Align) + // If old load did not have an explicit alignment specified, + // manually preserve the implied (ABI) alignment of the load. + // Else we may inadvertently incorrectly over-promise alignment. + Align = getDataLayout().getABITypeAlignment(LI.getType()); + + LoadInst *NewLoad = Builder.CreateAlignedLoad( + NewTy, NewPtr, Align, LI.isVolatile(), LI.getName() + Suffix); NewLoad->setAtomic(LI.getOrdering(), LI.getSyncScopeID()); copyMetadataForLoad(*NewLoad, LI); return NewLoad; @@ -526,7 +533,7 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value /// Returns true if instruction represent minmax pattern like: /// select ((cmp load V1, load V2), V1, V2). -static bool isMinMaxWithLoads(Value *V) { +static bool isMinMaxWithLoads(Value *V, Type *&LoadTy) { assert(V->getType()->isPointerTy() && "Expected pointer type."); // Ignore possible ty* to ixx* bitcast. 
V = peekThroughBitcast(V); @@ -540,6 +547,7 @@ static bool isMinMaxWithLoads(Value *V) { if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), m_Value(LHS), m_Value(RHS)))) return false; + LoadTy = L1->getType(); return (match(L1, m_Load(m_Specific(LHS))) && match(L2, m_Load(m_Specific(RHS)))) || (match(L1, m_Load(m_Specific(RHS))) && @@ -585,20 +593,22 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // size is a legal integer type. // Do not perform canonicalization if minmax pattern is found (to avoid // infinite loop). + Type *Dummy; if (!Ty->isIntegerTy() && Ty->isSized() && + !(Ty->isVectorTy() && Ty->getVectorIsScalable()) && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && DL.typeSizeEqualsStoreSize(Ty) && !DL.isNonIntegralPointerType(Ty) && !isMinMaxWithLoads( - peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) { + peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true), + Dummy)) { if (all_of(LI.users(), [&LI](User *U) { auto *SI = dyn_cast<StoreInst>(U); return SI && SI->getPointerOperand() != &LI && !SI->getPointerOperand()->isSwiftError(); })) { - LoadInst *NewLoad = combineLoadToNewType( - IC, LI, - Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty))); + LoadInst *NewLoad = IC.combineLoadToNewType( + LI, Type::getIntNTy(LI.getContext(), DL.getTypeStoreSizeInBits(Ty))); // Replace all the stores with stores of the newly loaded value. 
for (auto UI = LI.user_begin(), UE = LI.user_end(); UI != UE;) { auto *SI = cast<StoreInst>(*UI++); @@ -620,7 +630,7 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { if (auto* CI = dyn_cast<CastInst>(LI.user_back())) if (CI->isNoopCast(DL)) if (!LI.isAtomic() || isSupportedAtomicType(CI->getDestTy())) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, CI->getDestTy()); + LoadInst *NewLoad = IC.combineLoadToNewType(LI, CI->getDestTy()); CI->replaceAllUsesWith(NewLoad); IC.eraseInstFromFunction(*CI); return &LI; @@ -648,8 +658,8 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { // If the struct only have one element, we unpack. auto NumElements = ST->getNumElements(); if (NumElements == 1) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, ST->getTypeAtIndex(0U), - ".unpack"); + LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U), + ".unpack"); AAMDNodes AAMD; LI.getAAMetadata(AAMD); NewLoad->setAAMetadata(AAMD); @@ -698,7 +708,7 @@ static Instruction *unpackLoadToAggregate(InstCombiner &IC, LoadInst &LI) { auto *ET = AT->getElementType(); auto NumElements = AT->getNumElements(); if (NumElements == 1) { - LoadInst *NewLoad = combineLoadToNewType(IC, LI, ET, ".unpack"); + LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); AAMDNodes AAMD; LI.getAAMetadata(AAMD); NewLoad->setAAMetadata(AAMD); @@ -1322,7 +1332,14 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, auto *LI = cast<LoadInst>(SI.getValueOperand()); if (!LI->getType()->isIntegerTy()) return false; - if (!isMinMaxWithLoads(LoadAddr)) + Type *CmpLoadTy; + if (!isMinMaxWithLoads(LoadAddr, CmpLoadTy)) + return false; + + // Make sure we're not changing the size of the load/store. 
+ const auto &DL = IC.getDataLayout(); + if (DL.getTypeStoreSizeInBits(LI->getType()) != + DL.getTypeStoreSizeInBits(CmpLoadTy)) return false; if (!all_of(LI->users(), [LI, LoadAddr](User *U) { @@ -1334,8 +1351,7 @@ static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, return false; IC.Builder.SetInsertPoint(LI); - LoadInst *NewLI = combineLoadToNewType( - IC, *LI, LoadAddr->getType()->getPointerElementType()); + LoadInst *NewLI = IC.combineLoadToNewType(*LI, CmpLoadTy); // Replace all the stores with stores of the newly loaded value. for (auto *UI : LI->users()) { auto *USI = cast<StoreInst>(UI); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 0b9128a9f5a1..2774e46151fa 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -1239,6 +1239,14 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Value *YZ = Builder.CreateFMulFMF(Y, Op0, &I); return BinaryOperator::CreateFDivFMF(YZ, X, &I); } + // Z / (1.0 / Y) => (Y * Z) + // + // This is a special case of Z / (X / Y) => (Y * Z) / X, with X = 1.0. The + // m_OneUse check is avoided because even in the case of the multiple uses + // for 1.0/Y, the number of instructions remain the same and a division is + // replaced by a multiplication. + if (match(Op1, m_FDiv(m_SpecificFP(1.0), m_Value(Y)))) + return BinaryOperator::CreateFMulFMF(Y, Op0, &I); } if (I.hasAllowReassoc() && Op0->hasOneUse() && Op1->hasOneUse()) { @@ -1368,8 +1376,10 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { } // 1 urem X -> zext(X != 1) - if (match(Op0, m_One())) - return CastInst::CreateZExtOrBitCast(Builder.CreateICmpNE(Op1, Op0), Ty); + if (match(Op0, m_One())) { + Value *Cmp = Builder.CreateICmpNE(Op1, ConstantInt::get(Ty, 1)); + return CastInst::CreateZExtOrBitCast(Cmp, Ty); + } // X urem C -> X < C ? X : X - C, where C >= signbit. 
if (match(Op1, m_Negative())) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index e0376b7582f3..74e015a4f1d4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -14,9 +14,10 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace llvm::PatternMatch; @@ -180,13 +181,14 @@ Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { "Not enough available ptr typed incoming values"); PHINode *MatchingPtrPHI = nullptr; unsigned NumPhis = 0; - for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI()); - II != EI; II++, NumPhis++) { + for (auto II = BB->begin(); II != BB->end(); II++, NumPhis++) { // FIXME: consider handling this in AggressiveInstCombine + PHINode *PtrPHI = dyn_cast<PHINode>(II); + if (!PtrPHI) + break; if (NumPhis > MaxNumPhis) return nullptr; - PHINode *PtrPHI = dyn_cast<PHINode>(II); - if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) + if (PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) continue; MatchingPtrPHI = PtrPHI; for (unsigned i = 0; i != PtrPHI->getNumIncomingValues(); ++i) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index 9fc871e49b30..05a624fde86b 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -704,16 +704,24 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI, assert((Pred == ICmpInst::ICMP_UGE || Pred == ICmpInst::ICMP_UGT) && "Unexpected isUnsigned predicate!"); - // Account 
for swapped form of subtraction: ((a > b) ? b - a : 0). + // Ensure the sub is of the form: + // (a > b) ? a - b : 0 -> usub.sat(a, b) + // (a > b) ? b - a : 0 -> -usub.sat(a, b) + // Checking for both a-b and a+(-b) as a constant. bool IsNegative = false; - if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A)))) + const APInt *C; + if (match(TrueVal, m_Sub(m_Specific(B), m_Specific(A))) || + (match(A, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(B), m_SpecificInt(-*C))))) IsNegative = true; - else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B)))) + else if (!match(TrueVal, m_Sub(m_Specific(A), m_Specific(B))) && + !(match(B, m_APInt(C)) && + match(TrueVal, m_Add(m_Specific(A), m_SpecificInt(-*C))))) return nullptr; - // If sub is used anywhere else, we wouldn't be able to eliminate it - // afterwards. - if (!TrueVal->hasOneUse()) + // If we are adding a negate and the sub and icmp are used anywhere else, we + // would end up with more instructions. + if (IsNegative && !TrueVal->hasOneUse() && !ICI->hasOneUse()) return nullptr; // (a > b) ? a - b : 0 -> usub.sat(a, b) @@ -781,6 +789,13 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal, return Builder.CreateBinaryIntrinsic( Intrinsic::uadd_sat, BO->getOperand(0), BO->getOperand(1)); } + // The overflow may be detected via the add wrapping round. + if (match(Cmp0, m_c_Add(m_Specific(Cmp1), m_Value(Y))) && + match(FVal, m_c_Add(m_Specific(Cmp1), m_Specific(Y)))) { + // ((X + Y) u< X) ? -1 : (X + Y) --> uadd.sat(X, Y) + // ((X + Y) u< Y) ? -1 : (X + Y) --> uadd.sat(X, Y) + return Builder.CreateBinaryIntrinsic(Intrinsic::uadd_sat, Cmp1, Y); + } return nullptr; } @@ -1725,6 +1740,128 @@ static Instruction *foldAddSubSelect(SelectInst &SI, return nullptr; } +/// Turn X + Y overflows ? -1 : X + Y -> uadd_sat X, Y +/// And X - Y overflows ? 0 : X - Y -> usub_sat X, Y +/// Along with a number of patterns similar to: +/// X + Y overflows ? (X < 0 ? 
INTMIN : INTMAX) : X + Y --> sadd_sat X, Y +/// X - Y overflows ? (X > 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y +static Instruction * +foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { + Value *CondVal = SI.getCondition(); + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + + WithOverflowInst *II; + if (!match(CondVal, m_ExtractValue<1>(m_WithOverflowInst(II))) || + !match(FalseVal, m_ExtractValue<0>(m_Specific(II)))) + return nullptr; + + Value *X = II->getLHS(); + Value *Y = II->getRHS(); + + auto IsSignedSaturateLimit = [&](Value *Limit, bool IsAdd) { + Type *Ty = Limit->getType(); + + ICmpInst::Predicate Pred; + Value *TrueVal, *FalseVal, *Op; + const APInt *C; + if (!match(Limit, m_Select(m_ICmp(Pred, m_Value(Op), m_APInt(C)), + m_Value(TrueVal), m_Value(FalseVal)))) + return false; + + auto IsZeroOrOne = [](const APInt &C) { + return C.isNullValue() || C.isOneValue(); + }; + auto IsMinMax = [&](Value *Min, Value *Max) { + APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); + return match(Min, m_SpecificInt(MinVal)) && + match(Max, m_SpecificInt(MaxVal)); + }; + + if (Op != X && Op != Y) + return false; + + if (IsAdd) { + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X + Y overflows ? (X >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? 
INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + if (Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(FalseVal, TrueVal)) + return true; + } else { + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == X && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 2) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SLT && IsZeroOrOne(*C) && + IsMinMax(FalseVal, TrueVal)) + return true; + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + if (Op == Y && Pred == ICmpInst::ICMP_SGT && IsZeroOrOne(*C + 1) && + IsMinMax(TrueVal, FalseVal)) + return true; + } + + return false; + }; + + Intrinsic::ID NewIntrinsicID; + if (II->getIntrinsicID() == Intrinsic::uadd_with_overflow && + match(TrueVal, m_AllOnes())) + // X + Y overflows ? -1 : X + Y -> uadd_sat X, Y + NewIntrinsicID = Intrinsic::uadd_sat; + else if (II->getIntrinsicID() == Intrinsic::usub_with_overflow && + match(TrueVal, m_Zero())) + // X - Y overflows ? 0 : X - Y -> usub_sat X, Y + NewIntrinsicID = Intrinsic::usub_sat; + else if (II->getIntrinsicID() == Intrinsic::sadd_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/true)) + // X + Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s 0 ? 
INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 0 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y <s 1 ? INTMIN : INTMAX) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s 0 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + // X + Y overflows ? (Y >s -1 ? INTMAX : INTMIN) : X + Y --> sadd_sat X, Y + NewIntrinsicID = Intrinsic::sadd_sat; + else if (II->getIntrinsicID() == Intrinsic::ssub_with_overflow && + IsSignedSaturateLimit(TrueVal, /*IsAdd=*/false)) + // X - Y overflows ? (X <s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X <s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (X >s -2 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 0 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y <s 1 ? INTMAX : INTMIN) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s 0 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + // X - Y overflows ? (Y >s -1 ? INTMIN : INTMAX) : X - Y --> ssub_sat X, Y + NewIntrinsicID = Intrinsic::ssub_sat; + else + return nullptr; + + Function *F = + Intrinsic::getDeclaration(SI.getModule(), NewIntrinsicID, SI.getType()); + return CallInst::Create(F, {X, Y}); +} + Instruction *InstCombiner::foldSelectExtConst(SelectInst &Sel) { Constant *C; if (!match(Sel.getTrueValue(), m_Constant(C)) && @@ -2296,7 +2433,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { // See if we are selecting two values based on a comparison of the two values. 
if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { - if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { + Value *Cmp0 = FCI->getOperand(0), *Cmp1 = FCI->getOperand(1); + if ((Cmp0 == TrueVal && Cmp1 == FalseVal) || + (Cmp0 == FalseVal && Cmp1 == TrueVal)) { // Canonicalize to use ordered comparisons by swapping the select // operands. // @@ -2305,30 +2444,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { FCmpInst::Predicate InvPred = FCI->getInversePredicate(); IRBuilder<>::FastMathFlagGuard FMFG(Builder); + // FIXME: The FMF should propagate from the select, not the fcmp. Builder.setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, TrueVal, FalseVal, - FCI->getName() + ".inv"); - - return SelectInst::Create(NewCond, FalseVal, TrueVal, - SI.getName() + ".p"); - } - - // NOTE: if we wanted to, this is where to detect MIN/MAX - } else if (FCI->getOperand(0) == FalseVal && FCI->getOperand(1) == TrueVal){ - // Canonicalize to use ordered comparisons by swapping the select - // operands. - // - // e.g. - // (X ugt Y) ? X : Y -> (X ole Y) ? 
X : Y - if (FCI->hasOneUse() && FCmpInst::isUnordered(FCI->getPredicate())) { - FCmpInst::Predicate InvPred = FCI->getInversePredicate(); - IRBuilder<>::FastMathFlagGuard FMFG(Builder); - Builder.setFastMathFlags(FCI->getFastMathFlags()); - Value *NewCond = Builder.CreateFCmp(InvPred, FalseVal, TrueVal, + Value *NewCond = Builder.CreateFCmp(InvPred, Cmp0, Cmp1, FCI->getName() + ".inv"); - - return SelectInst::Create(NewCond, FalseVal, TrueVal, - SI.getName() + ".p"); + Value *NewSel = Builder.CreateSelect(NewCond, FalseVal, TrueVal); + return replaceInstUsesWith(SI, NewSel); } // NOTE: if we wanted to, this is where to detect MIN/MAX @@ -2391,6 +2512,8 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *Add = foldAddSubSelect(SI, Builder)) return Add; + if (Instruction *Add = foldOverflowingAddSubSelect(SI, Builder)) + return Add; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) auto *TI = dyn_cast<Instruction>(TrueVal); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index 64294838644f..fbff5dd4a8cd 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -138,24 +138,6 @@ Value *InstCombiner::reassociateShiftAmtsOfTwoSameDirectionShifts( return Ret; } -// Try to replace `undef` constants in C with Replacement. -static Constant *replaceUndefsWith(Constant *C, Constant *Replacement) { - if (C && match(C, m_Undef())) - return Replacement; - - if (auto *CV = dyn_cast<ConstantVector>(C)) { - llvm::SmallVector<Constant *, 32> NewOps(CV->getNumOperands()); - for (unsigned i = 0, NumElts = NewOps.size(); i != NumElts; ++i) { - Constant *EltC = CV->getOperand(i); - NewOps[i] = EltC && match(EltC, m_Undef()) ? Replacement : EltC; - } - return ConstantVector::get(NewOps); - } - - // Don't know how to deal with this constant. 
- return C; -} - // If we have some pattern that leaves only some low bits set, and then performs // left-shift of those bits, if none of the bits that are left after the final // shift are modified by the mask, we can omit the mask. @@ -180,10 +162,20 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, "The input must be 'shl'!"); Value *Masked, *ShiftShAmt; - match(OuterShift, m_Shift(m_Value(Masked), m_Value(ShiftShAmt))); + match(OuterShift, + m_Shift(m_Value(Masked), m_ZExtOrSelf(m_Value(ShiftShAmt)))); + + // *If* there is a truncation between an outer shift and a possibly-mask, + // then said truncation *must* be one-use, else we can't perform the fold. + Value *Trunc; + if (match(Masked, m_CombineAnd(m_Trunc(m_Value(Masked)), m_Value(Trunc))) && + !Trunc->hasOneUse()) + return nullptr; Type *NarrowestTy = OuterShift->getType(); Type *WidestTy = Masked->getType(); + bool HadTrunc = WidestTy != NarrowestTy; + // The mask must be computed in a type twice as wide to ensure // that no bits are lost if the sum-of-shifts is wider than the base type. Type *ExtendedTy = WidestTy->getExtendedType(); @@ -204,6 +196,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, Constant *NewMask; if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) { + // Peek through an optional zext of the shift amount. + match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt))); + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now. + if (MaskShAmt->getType() != ShiftShAmt->getType()) + return nullptr; + // Can we simplify (MaskShAmt+ShiftShAmt) ? auto *SumOfShAmts = dyn_cast_or_null<Constant>(SimplifyAddInst( MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); @@ -216,7 +216,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // completely unknown. 
Replace the the `undef` shift amounts with final // shift bitwidth to ensure that the value remains undef when creating the // subsequent shift op. - SumOfShAmts = replaceUndefsWith( + SumOfShAmts = Constant::replaceUndefsWith( SumOfShAmts, ConstantInt::get(SumOfShAmts->getType()->getScalarType(), ExtendedTy->getScalarSizeInBits())); auto *ExtendedSumOfShAmts = ConstantExpr::getZExt(SumOfShAmts, ExtendedTy); @@ -228,6 +228,14 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, } else if (match(Masked, m_c_And(m_CombineOr(MaskC, MaskD), m_Value(X))) || match(Masked, m_Shr(m_Shl(m_Value(X), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)))) { + // Peek through an optional zext of the shift amount. + match(MaskShAmt, m_ZExtOrSelf(m_Value(MaskShAmt))); + + // We have two shift amounts from two different shifts. The types of those + // shift amounts may not match. If that's the case let's bailout now. + if (MaskShAmt->getType() != ShiftShAmt->getType()) + return nullptr; + // Can we simplify (ShiftShAmt-MaskShAmt) ? auto *ShAmtsDiff = dyn_cast_or_null<Constant>(SimplifySubInst( ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false, Q)); @@ -241,7 +249,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // bitwidth of innermost shift to ensure that the value remains undef when // creating the subsequent shift op. unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); - ShAmtsDiff = replaceUndefsWith( + ShAmtsDiff = Constant::replaceUndefsWith( ShAmtsDiff, ConstantInt::get(ShAmtsDiff->getType()->getScalarType(), -WidestTyBitWidth)); auto *ExtendedNumHighBitsToClear = ConstantExpr::getZExt( @@ -272,10 +280,15 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return nullptr; } + // If we need to apply truncation, let's do it first, since we can. + // We have already ensured that the old truncation will go away. + if (HadTrunc) + X = Builder.CreateTrunc(X, NarrowestTy); + // No 'NUW'/'NSW'! 
We no longer know that we won't shift-out non-0 bits. + // We didn't change the Type of this outermost shift, so we can just do it. auto *NewShift = BinaryOperator::Create(OuterShift->getOpcode(), X, OuterShift->getOperand(1)); - if (!NeedMask) return NewShift; @@ -283,6 +296,50 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, return BinaryOperator::Create(Instruction::And, NewShift, NewMask); } +/// If we have a shift-by-constant of a bitwise logic op that itself has a +/// shift-by-constant operand with identical opcode, we may be able to convert +/// that into 2 independent shifts followed by the logic op. This eliminates a +/// a use of an intermediate value (reduces dependency chain). +static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + assert(I.isShift() && "Expected a shift as input"); + auto *LogicInst = dyn_cast<BinaryOperator>(I.getOperand(0)); + if (!LogicInst || !LogicInst->isBitwiseLogicOp() || !LogicInst->hasOneUse()) + return nullptr; + + const APInt *C0, *C1; + if (!match(I.getOperand(1), m_APInt(C1))) + return nullptr; + + Instruction::BinaryOps ShiftOpcode = I.getOpcode(); + Type *Ty = I.getType(); + + // Find a matching one-use shift by constant. The fold is not valid if the sum + // of the shift values equals or exceeds bitwidth. + // TODO: Remove the one-use check if the other logic operand (Y) is constant. + Value *X, *Y; + auto matchFirstShift = [&](Value *V) { + return !isa<ConstantExpr>(V) && + match(V, m_OneUse(m_Shift(m_Value(X), m_APInt(C0)))) && + cast<BinaryOperator>(V)->getOpcode() == ShiftOpcode && + (*C0 + *C1).ult(Ty->getScalarSizeInBits()); + }; + + // Logic ops are commutative, so check each operand for a match. 
+ if (matchFirstShift(LogicInst->getOperand(0))) + Y = LogicInst->getOperand(1); + else if (matchFirstShift(LogicInst->getOperand(1))) + Y = LogicInst->getOperand(0); + else + return nullptr; + + // shift (logic (shift X, C0), Y), C1 -> logic (shift X, C0+C1), (shift Y, C1) + Constant *ShiftSumC = ConstantInt::get(Ty, *C0 + *C1); + Value *NewShift1 = Builder.CreateBinOp(ShiftOpcode, X, ShiftSumC); + Value *NewShift2 = Builder.CreateBinOp(ShiftOpcode, Y, I.getOperand(1)); + return BinaryOperator::Create(LogicInst->getOpcode(), NewShift1, NewShift2); +} + Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); assert(Op0->getType() == Op1->getType()); @@ -335,6 +392,9 @@ Instruction *InstCombiner::commonShiftTransforms(BinaryOperator &I) { return &I; } + if (Instruction *Logic = foldShiftOfShiftedLogic(I, Builder)) + return Logic; + return nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index d30ab8001897..47ce83974c8d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -14,6 +14,8 @@ #include "InstCombineInternal.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IntrinsicsAMDGPU.h" +#include "llvm/IR/IntrinsicsX86.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/KnownBits.h" @@ -348,8 +350,36 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, assert(!LHSKnown.hasConflict() && "Bits known to be one AND zero?"); // If the operands are constants, see if we can simplify them. 
- if (ShrinkDemandedConstant(I, 1, DemandedMask) || - ShrinkDemandedConstant(I, 2, DemandedMask)) + // This is similar to ShrinkDemandedConstant, but for a select we want to + // try to keep the selected constants the same as icmp value constants, if + // we can. This helps not break apart (or helps put back together) + // canonical patterns like min and max. + auto CanonicalizeSelectConstant = [](Instruction *I, unsigned OpNo, + APInt DemandedMask) { + const APInt *SelC; + if (!match(I->getOperand(OpNo), m_APInt(SelC))) + return false; + + // Get the constant out of the ICmp, if there is one. + const APInt *CmpC; + ICmpInst::Predicate Pred; + if (!match(I->getOperand(0), m_c_ICmp(Pred, m_APInt(CmpC), m_Value())) || + CmpC->getBitWidth() != SelC->getBitWidth()) + return ShrinkDemandedConstant(I, OpNo, DemandedMask); + + // If the constant is already the same as the ICmp, leave it as-is. + if (*CmpC == *SelC) + return false; + // If the constants are not already the same, but can be with the demand + // mask, use the constant value from the ICmp. + if ((*CmpC & DemandedMask) == (*SelC & DemandedMask)) { + I->setOperand(OpNo, ConstantInt::get(I->getType(), *CmpC)); + return true; + } + return ShrinkDemandedConstant(I, OpNo, DemandedMask); + }; + if (CanonicalizeSelectConstant(I, 1, DemandedMask) || + CanonicalizeSelectConstant(I, 2, DemandedMask)) return I; // Only known if known in both the LHS and RHS. 
@@ -1247,30 +1277,57 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } case Instruction::ShuffleVector: { - ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I); - unsigned LHSVWidth = - Shuffle->getOperand(0)->getType()->getVectorNumElements(); - APInt LeftDemanded(LHSVWidth, 0), RightDemanded(LHSVWidth, 0); + auto *Shuffle = cast<ShuffleVectorInst>(I); + assert(Shuffle->getOperand(0)->getType() == + Shuffle->getOperand(1)->getType() && + "Expected shuffle operands to have same type"); + unsigned OpWidth = + Shuffle->getOperand(0)->getType()->getVectorNumElements(); + APInt LeftDemanded(OpWidth, 0), RightDemanded(OpWidth, 0); for (unsigned i = 0; i < VWidth; i++) { if (DemandedElts[i]) { unsigned MaskVal = Shuffle->getMaskValue(i); if (MaskVal != -1u) { - assert(MaskVal < LHSVWidth * 2 && + assert(MaskVal < OpWidth * 2 && "shufflevector mask index out of range!"); - if (MaskVal < LHSVWidth) + if (MaskVal < OpWidth) LeftDemanded.setBit(MaskVal); else - RightDemanded.setBit(MaskVal - LHSVWidth); + RightDemanded.setBit(MaskVal - OpWidth); } } } - APInt LHSUndefElts(LHSVWidth, 0); + APInt LHSUndefElts(OpWidth, 0); simplifyAndSetOp(I, 0, LeftDemanded, LHSUndefElts); - APInt RHSUndefElts(LHSVWidth, 0); + APInt RHSUndefElts(OpWidth, 0); simplifyAndSetOp(I, 1, RightDemanded, RHSUndefElts); + // If this shuffle does not change the vector length and the elements + // demanded by this shuffle are an identity mask, then this shuffle is + // unnecessary. + // + // We are assuming canonical form for the mask, so the source vector is + // operand 0 and operand 1 is not used. + // + // Note that if an element is demanded and this shuffle mask is undefined + // for that element, then the shuffle is not considered an identity + // operation. The shuffle prevents poison from the operand vector from + // leaking to the result by replacing poison with an undefined value. 
+ if (VWidth == OpWidth) { + bool IsIdentityShuffle = true; + for (unsigned i = 0; i < VWidth; i++) { + unsigned MaskVal = Shuffle->getMaskValue(i); + if (DemandedElts[i] && i != MaskVal) { + IsIdentityShuffle = false; + break; + } + } + if (IsIdentityShuffle) + return Shuffle->getOperand(0); + } + bool NewUndefElts = false; unsigned LHSIdx = -1u, LHSValIdx = -1u; unsigned RHSIdx = -1u, RHSValIdx = -1u; @@ -1283,23 +1340,23 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, } else if (!DemandedElts[i]) { NewUndefElts = true; UndefElts.setBit(i); - } else if (MaskVal < LHSVWidth) { + } else if (MaskVal < OpWidth) { if (LHSUndefElts[MaskVal]) { NewUndefElts = true; UndefElts.setBit(i); } else { - LHSIdx = LHSIdx == -1u ? i : LHSVWidth; - LHSValIdx = LHSValIdx == -1u ? MaskVal : LHSVWidth; + LHSIdx = LHSIdx == -1u ? i : OpWidth; + LHSValIdx = LHSValIdx == -1u ? MaskVal : OpWidth; LHSUniform = LHSUniform && (MaskVal == i); } } else { - if (RHSUndefElts[MaskVal - LHSVWidth]) { + if (RHSUndefElts[MaskVal - OpWidth]) { NewUndefElts = true; UndefElts.setBit(i); } else { - RHSIdx = RHSIdx == -1u ? i : LHSVWidth; - RHSValIdx = RHSValIdx == -1u ? MaskVal - LHSVWidth : LHSVWidth; - RHSUniform = RHSUniform && (MaskVal - LHSVWidth == i); + RHSIdx = RHSIdx == -1u ? i : OpWidth; + RHSValIdx = RHSValIdx == -1u ? MaskVal - OpWidth : OpWidth; + RHSUniform = RHSUniform && (MaskVal - OpWidth == i); } } } @@ -1308,20 +1365,20 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, // this constant vector to single insertelement instruction. // shufflevector V, C, <v1, v2, .., ci, .., vm> -> // insertelement V, C[ci], ci-n - if (LHSVWidth == Shuffle->getType()->getNumElements()) { + if (OpWidth == Shuffle->getType()->getNumElements()) { Value *Op = nullptr; Constant *Value = nullptr; unsigned Idx = -1u; // Find constant vector with the single element in shuffle (LHS or RHS). 
- if (LHSIdx < LHSVWidth && RHSUniform) { + if (LHSIdx < OpWidth && RHSUniform) { if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(0))) { Op = Shuffle->getOperand(1); Value = CV->getOperand(LHSValIdx); Idx = LHSIdx; } } - if (RHSIdx < LHSVWidth && LHSUniform) { + if (RHSIdx < OpWidth && LHSUniform) { if (auto *CV = dyn_cast<ConstantVector>(Shuffle->getOperand(1))) { Op = Shuffle->getOperand(0); Value = CV->getOperand(RHSValIdx); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 9c890748e5ab..f604c9dc32ca 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -1390,20 +1390,6 @@ static Value *evaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { llvm_unreachable("failed to reorder elements of vector instruction!"); } -static void recognizeIdentityMask(const SmallVectorImpl<int> &Mask, - bool &isLHSID, bool &isRHSID) { - isLHSID = isRHSID = true; - - for (unsigned i = 0, e = Mask.size(); i != e; ++i) { - if (Mask[i] < 0) continue; // Ignore undef values. - // Is this an identity shuffle of the LHS value? - isLHSID &= (Mask[i] == (int)i); - - // Is this an identity shuffle of the RHS value? - isRHSID &= (Mask[i]-e == i); - } -} - // Returns true if the shuffle is extracting a contiguous range of values from // LHS, for example: // +--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+--+ @@ -1560,9 +1546,11 @@ static Instruction *foldSelectShuffle(ShuffleVectorInst &Shuf, if (!Shuf.isSelect()) return nullptr; - // Canonicalize to choose from operand 0 first. + // Canonicalize to choose from operand 0 first unless operand 1 is undefined. + // Commuting undef to operand 0 conflicts with another canonicalization. 
unsigned NumElts = Shuf.getType()->getVectorNumElements(); - if (Shuf.getMaskValue(0) >= (int)NumElts) { + if (!isa<UndefValue>(Shuf.getOperand(1)) && + Shuf.getMaskValue(0) >= (int)NumElts) { // TODO: Can we assert that both operands of a shuffle-select are not undef // (otherwise, it would have been folded by instsimplify? Shuf.commute(); @@ -1753,7 +1741,8 @@ static Instruction *foldIdentityExtractShuffle(ShuffleVectorInst &Shuf) { return new ShuffleVectorInst(X, Y, ConstantVector::get(NewMask)); } -/// Try to replace a shuffle with an insertelement. +/// Try to replace a shuffle with an insertelement or try to replace a shuffle +/// operand with the operand of an insertelement. static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { Value *V0 = Shuf.getOperand(0), *V1 = Shuf.getOperand(1); SmallVector<int, 16> Mask = Shuf.getShuffleMask(); @@ -1765,6 +1754,31 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf) { if (NumElts != (int)(V0->getType()->getVectorNumElements())) return nullptr; + // This is a specialization of a fold in SimplifyDemandedVectorElts. We may + // not be able to handle it there if the insertelement has >1 use. + // If the shuffle has an insertelement operand but does not choose the + // inserted scalar element from that value, then we can replace that shuffle + // operand with the source vector of the insertelement. + Value *X; + uint64_t IdxC; + if (match(V0, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + // shuf (inselt X, ?, IdxC), ?, Mask --> shuf X, ?, Mask + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { + Shuf.setOperand(0, X); + return &Shuf; + } + } + if (match(V1, m_InsertElement(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { + // Offset the index constant by the vector width because we are checking for + // accesses to the 2nd vector input of the shuffle. 
+ IdxC += NumElts; + // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask + if (none_of(Mask, [IdxC](int MaskElt) { return MaskElt == (int)IdxC; })) { + Shuf.setOperand(1, X); + return &Shuf; + } + } + // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { // We need an insertelement with a constant index. @@ -1891,29 +1905,21 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI))) return replaceInstUsesWith(SVI, V); - // Canonicalize shuffle(x ,x,mask) -> shuffle(x, undef,mask') - // Canonicalize shuffle(undef,x,mask) -> shuffle(x, undef,mask'). + // shuffle x, x, mask --> shuffle x, undef, mask' unsigned VWidth = SVI.getType()->getVectorNumElements(); unsigned LHSWidth = LHS->getType()->getVectorNumElements(); SmallVector<int, 16> Mask = SVI.getShuffleMask(); Type *Int32Ty = Type::getInt32Ty(SVI.getContext()); - if (LHS == RHS || isa<UndefValue>(LHS)) { + if (LHS == RHS) { + assert(!isa<UndefValue>(RHS) && "Shuffle with 2 undef ops not simplified?"); // Remap any references to RHS to use LHS. SmallVector<Constant*, 16> Elts; - for (unsigned i = 0, e = LHSWidth; i != VWidth; ++i) { - if (Mask[i] < 0) { - Elts.push_back(UndefValue::get(Int32Ty)); - continue; - } - - if ((Mask[i] >= (int)e && isa<UndefValue>(RHS)) || - (Mask[i] < (int)e && isa<UndefValue>(LHS))) { - Mask[i] = -1; // Turn into undef. + for (unsigned i = 0; i != VWidth; ++i) { + // Propagate undef elements or force mask to LHS. + if (Mask[i] < 0) Elts.push_back(UndefValue::get(Int32Ty)); - } else { - Mask[i] = Mask[i] % e; // Force to LHS. 
- Elts.push_back(ConstantInt::get(Int32Ty, Mask[i])); - } + else + Elts.push_back(ConstantInt::get(Int32Ty, Mask[i] % LHSWidth)); } SVI.setOperand(0, SVI.getOperand(1)); SVI.setOperand(1, UndefValue::get(RHS->getType())); @@ -1921,6 +1927,12 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return &SVI; } + // shuffle undef, x, mask --> shuffle x, undef, mask' + if (isa<UndefValue>(LHS)) { + SVI.commute(); + return &SVI; + } + if (Instruction *I = canonicalizeInsertSplat(SVI, Builder)) return I; @@ -1948,16 +1960,6 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Instruction *I = foldIdentityPaddedShuffles(SVI)) return I; - if (VWidth == LHSWidth) { - // Analyze the shuffle, are the LHS or RHS and identity shuffles? - bool isLHSID, isRHSID; - recognizeIdentityMask(Mask, isLHSID, isRHSID); - - // Eliminate identity shuffles. - if (isLHSID) return replaceInstUsesWith(SVI, LHS); - if (isRHSID) return replaceInstUsesWith(SVI, RHS); - } - if (isa<UndefValue>(RHS) && canEvaluateShuffled(LHS, Mask)) { Value *V = evaluateInDifferentElementOrder(LHS, Mask); return replaceInstUsesWith(SVI, V); @@ -2235,12 +2237,5 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return new ShuffleVectorInst(newLHS, newRHS, ConstantVector::get(Elts)); } - // If the result mask is an identity, replace uses of this instruction with - // corresponding argument. - bool isLHSID, isRHSID; - recognizeIdentityMask(newMask, isLHSID, isRHSID); - if (isLHSID && VWidth == LHSOp0Width) return replaceInstUsesWith(SVI, newLHS); - if (isRHSID && VWidth == RHSOp0Width) return replaceInstUsesWith(SVI, newRHS); - return MadeChange ? 
&SVI : nullptr; } diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index ecb486c544e0..801c09a317a7 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -86,6 +86,7 @@ #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/InitializePasses.h" #include "llvm/Pass.h" #include "llvm/Support/CBindingWrapping.h" #include "llvm/Support/Casting.h" @@ -121,6 +122,9 @@ STATISTIC(NumReassoc , "Number of reassociations"); DEBUG_COUNTER(VisitCounter, "instcombine-visit", "Controls which instructions are visited"); +static constexpr unsigned InstCombineDefaultMaxIterations = 1000; +static constexpr unsigned InstCombineDefaultInfiniteLoopThreshold = 1000; + static cl::opt<bool> EnableCodeSinking("instcombine-code-sinking", cl::desc("Enable code sinking"), cl::init(true)); @@ -129,6 +133,17 @@ static cl::opt<bool> EnableExpensiveCombines("expensive-combines", cl::desc("Enable expensive instruction combines")); +static cl::opt<unsigned> LimitMaxIterations( + "instcombine-max-iterations", + cl::desc("Limit the maximum number of instruction combining iterations"), + cl::init(InstCombineDefaultMaxIterations)); + +static cl::opt<unsigned> InfiniteLoopDetectionThreshold( + "instcombine-infinite-loop-threshold", + cl::desc("Number of instruction combining iterations considered an " + "infinite loop"), + cl::init(InstCombineDefaultInfiniteLoopThreshold), cl::Hidden); + static cl::opt<unsigned> MaxArraySize("instcombine-maxarray-size", cl::init(1024), cl::desc("Maximum array size considered when doing a combine")); @@ -759,35 +774,52 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, Value *RHS) { - Instruction::BinaryOps Opcode = I.getOpcode(); - // (op (select (a, b, c)), 
(select (a, d, e))) -> (select (a, (op b, d), (op - // c, e))) - Value *A, *B, *C, *D, *E; - Value *SI = nullptr; - if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) && - match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) { - bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse(); + Value *A, *B, *C, *D, *E, *F; + bool LHSIsSelect = match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))); + bool RHSIsSelect = match(RHS, m_Select(m_Value(D), m_Value(E), m_Value(F))); + if (!LHSIsSelect && !RHSIsSelect) + return nullptr; - FastMathFlags FMF; - BuilderTy::FastMathFlagGuard Guard(Builder); - if (isa<FPMathOperator>(&I)) { - FMF = I.getFastMathFlags(); - Builder.setFastMathFlags(FMF); - } + FastMathFlags FMF; + BuilderTy::FastMathFlagGuard Guard(Builder); + if (isa<FPMathOperator>(&I)) { + FMF = I.getFastMathFlags(); + Builder.setFastMathFlags(FMF); + } - Value *V1 = SimplifyBinOp(Opcode, C, E, FMF, SQ.getWithInstruction(&I)); - Value *V2 = SimplifyBinOp(Opcode, B, D, FMF, SQ.getWithInstruction(&I)); - if (V1 && V2) - SI = Builder.CreateSelect(A, V2, V1); - else if (V2 && SelectsHaveOneUse) - SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E)); - else if (V1 && SelectsHaveOneUse) - SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1); + Instruction::BinaryOps Opcode = I.getOpcode(); + SimplifyQuery Q = SQ.getWithInstruction(&I); + + Value *Cond, *True = nullptr, *False = nullptr; + if (LHSIsSelect && RHSIsSelect && A == D) { + // (A ? B : C) op (A ? E : F) -> A ? (B op E) : (C op F) + Cond = A; + True = SimplifyBinOp(Opcode, B, E, FMF, Q); + False = SimplifyBinOp(Opcode, C, F, FMF, Q); - if (SI) - SI->takeName(&I); + if (LHS->hasOneUse() && RHS->hasOneUse()) { + if (False && !True) + True = Builder.CreateBinOp(Opcode, B, E); + else if (True && !False) + False = Builder.CreateBinOp(Opcode, C, F); + } + } else if (LHSIsSelect && LHS->hasOneUse()) { + // (A ? B : C) op Y -> A ? 
(B op Y) : (C op Y) + Cond = A; + True = SimplifyBinOp(Opcode, B, RHS, FMF, Q); + False = SimplifyBinOp(Opcode, C, RHS, FMF, Q); + } else if (RHSIsSelect && RHS->hasOneUse()) { + // X op (D ? E : F) -> D ? (X op E) : (X op F) + Cond = D; + True = SimplifyBinOp(Opcode, LHS, E, FMF, Q); + False = SimplifyBinOp(Opcode, LHS, F, FMF, Q); } + if (!True || !False) + return nullptr; + + Value *SI = Builder.CreateSelect(Cond, True, False); + SI->takeName(&I); return SI; } @@ -1526,11 +1558,13 @@ Instruction *InstCombiner::foldVectorBinop(BinaryOperator &Inst) { // If this is a widening shuffle, we must be able to extend with undef // elements. If the original binop does not produce an undef in the high // lanes, then this transform is not safe. + // Similarly for undef lanes due to the shuffle mask, we can only + // transform binops that preserve undef. // TODO: We could shuffle those non-undef constant values into the // result by using a constant vector (rather than an undef vector) // as operand 1 of the new binop, but that might be too aggressive // for target-independent shuffle creation. - if (I >= SrcVecNumElts) { + if (I >= SrcVecNumElts || ShMask[I] < 0) { Constant *MaybeUndef = ConstOp1 ? ConstantExpr::get(Opcode, UndefScalar, CElt) : ConstantExpr::get(Opcode, CElt, UndefScalar); @@ -1615,6 +1649,15 @@ Instruction *InstCombiner::narrowMathIfNoOverflow(BinaryOperator &BO) { return CastInst::Create(CastOpc, NarrowBO, BO.getType()); } +static bool isMergedGEPInBounds(GEPOperator &GEP1, GEPOperator &GEP2) { + // At least one GEP must be inbounds. 
+ if (!GEP1.isInBounds() && !GEP2.isInBounds()) + return false; + + return (GEP1.isInBounds() || GEP1.hasAllZeroIndices()) && + (GEP2.isInBounds() || GEP2.hasAllZeroIndices()); +} + Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end()); Type *GEPType = GEP.getType(); @@ -1724,8 +1767,11 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // The first two arguments can vary for any GEP, the rest have to be // static for struct slots - if (J > 1 && CurTy->isStructTy()) - return nullptr; + if (J > 1) { + assert(CurTy && "No current type?"); + if (CurTy->isStructTy()) + return nullptr; + } DI = J; } else { @@ -1885,6 +1931,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Update the GEP in place if possible. if (Src->getNumOperands() == 2) { + GEP.setIsInBounds(isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP))); GEP.setOperand(0, Src->getOperand(0)); GEP.setOperand(1, Sum); return &GEP; @@ -1901,7 +1948,7 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { } if (!Indices.empty()) - return GEP.isInBounds() && Src->isInBounds() + return isMergedGEPInBounds(*Src, *cast<GEPOperator>(&GEP)) ? 
GetElementPtrInst::CreateInBounds( Src->getSourceElementType(), Src->getOperand(0), Indices, GEP.getName()) @@ -2154,15 +2201,17 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // of a bitcasted pointer to vector or array of the same dimensions: // gep (bitcast <c x ty>* X to [c x ty]*), Y, Z --> gep X, Y, Z // gep (bitcast [c x ty]* X to <c x ty>*), Y, Z --> gep X, Y, Z - auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy) { + auto areMatchingArrayAndVecTypes = [](Type *ArrTy, Type *VecTy, + const DataLayout &DL) { return ArrTy->getArrayElementType() == VecTy->getVectorElementType() && - ArrTy->getArrayNumElements() == VecTy->getVectorNumElements(); + ArrTy->getArrayNumElements() == VecTy->getVectorNumElements() && + DL.getTypeAllocSize(ArrTy) == DL.getTypeAllocSize(VecTy); }; if (GEP.getNumOperands() == 3 && ((GEPEltType->isArrayTy() && SrcEltType->isVectorTy() && - areMatchingArrayAndVecTypes(GEPEltType, SrcEltType)) || + areMatchingArrayAndVecTypes(GEPEltType, SrcEltType, DL)) || (GEPEltType->isVectorTy() && SrcEltType->isArrayTy() && - areMatchingArrayAndVecTypes(SrcEltType, GEPEltType)))) { + areMatchingArrayAndVecTypes(SrcEltType, GEPEltType, DL)))) { // Create a new GEP here, as using `setOperand()` followed by // `setSourceElementType()` won't actually update the type of the @@ -2401,12 +2450,13 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { replaceInstUsesWith(*C, ConstantInt::get(Type::getInt1Ty(C->getContext()), C->isFalseWhenEqual())); - } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I) || - isa<AddrSpaceCastInst>(I)) { - replaceInstUsesWith(*I, UndefValue::get(I->getType())); } else if (auto *SI = dyn_cast<StoreInst>(I)) { for (auto *DII : DIIs) ConvertDebugDeclareToDebugValue(DII, SI, *DIB); + } else { + // Casts, GEP, or anything else: we're about to delete this instruction, + // so it can not have any valid uses. 
+ replaceInstUsesWith(*I, UndefValue::get(I->getType())); } eraseInstFromFunction(*I); } @@ -3111,6 +3161,15 @@ Instruction *InstCombiner::visitLandingPadInst(LandingPadInst &LI) { return nullptr; } +Instruction *InstCombiner::visitFreeze(FreezeInst &I) { + Value *Op0 = I.getOperand(0); + + if (Value *V = SimplifyFreezeInst(Op0, SQ.getWithInstruction(&I))) + return replaceInstUsesWith(I, V); + + return nullptr; +} + /// Try to move the specified instruction from its current block into the /// beginning of DestBlock, which can only happen if it's safe to move the /// instruction past all of the instructions between it and the end of its @@ -3322,10 +3381,6 @@ bool InstCombiner::run() { // Move the name to the new instruction first. Result->takeName(I); - // Push the new instruction and any users onto the worklist. - Worklist.AddUsersToWorkList(*Result); - Worklist.Add(Result); - // Insert the new instruction into the basic block... BasicBlock *InstParent = I->getParent(); BasicBlock::iterator InsertPos = I->getIterator(); @@ -3337,6 +3392,10 @@ bool InstCombiner::run() { InstParent->getInstList().insert(InsertPos, Result); + // Push the new instruction and any users onto the worklist. 
+ Worklist.AddUsersToWorkList(*Result); + Worklist.Add(Result); + eraseInstFromFunction(*I); } else { LLVM_DEBUG(dbgs() << "IC: Mod = " << OrigI << '\n' @@ -3392,8 +3451,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; LLVM_DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); - if (!salvageDebugInfo(*Inst)) - replaceDbgUsesWithUndef(Inst); + salvageDebugInfoOrMarkUndef(*Inst); Inst->eraseFromParent(); MadeIRChange = true; continue; @@ -3507,10 +3565,11 @@ static bool combineInstructionsOverFunction( Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, - ProfileSummaryInfo *PSI, bool ExpensiveCombines = true, - LoopInfo *LI = nullptr) { + ProfileSummaryInfo *PSI, bool ExpensiveCombines, unsigned MaxIterations, + LoopInfo *LI) { auto &DL = F.getParent()->getDataLayout(); ExpensiveCombines |= EnableExpensiveCombines; + MaxIterations = std::min(MaxIterations, LimitMaxIterations.getValue()); /// Builder - This is an IRBuilder that automatically inserts new /// instructions into the worklist when they are created. @@ -3529,9 +3588,23 @@ static bool combineInstructionsOverFunction( MadeIRChange = LowerDbgDeclare(F); // Iterate while there is work to do. 
- int Iteration = 0; + unsigned Iteration = 0; while (true) { ++Iteration; + + if (Iteration > InfiniteLoopDetectionThreshold) { + report_fatal_error( + "Instruction Combining seems stuck in an infinite loop after " + + Twine(InfiniteLoopDetectionThreshold) + " iterations."); + } + + if (Iteration > MaxIterations) { + LLVM_DEBUG(dbgs() << "\n\n[IC] Iteration limit #" << MaxIterations + << " on " << F.getName() + << " reached; stopping before reaching a fixpoint\n"); + break; + } + LLVM_DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); @@ -3543,11 +3616,19 @@ static bool combineInstructionsOverFunction( if (!IC.run()) break; + + MadeIRChange = true; } - return MadeIRChange || Iteration > 1; + return MadeIRChange; } +InstCombinePass::InstCombinePass(bool ExpensiveCombines) + : ExpensiveCombines(ExpensiveCombines), MaxIterations(LimitMaxIterations) {} + +InstCombinePass::InstCombinePass(bool ExpensiveCombines, unsigned MaxIterations) + : ExpensiveCombines(ExpensiveCombines), MaxIterations(MaxIterations) {} + PreservedAnalyses InstCombinePass::run(Function &F, FunctionAnalysisManager &AM) { auto &AC = AM.getResult<AssumptionAnalysis>(F); @@ -3565,8 +3646,9 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto *BFI = (PSI && PSI->hasProfileSummary()) ? &AM.getResult<BlockFrequencyAnalysis>(F) : nullptr; - if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, - BFI, PSI, ExpensiveCombines, LI)) + if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI, + PSI, ExpensiveCombines, MaxIterations, + LI)) // No changes, all analyses are preserved. 
return PreservedAnalyses::all(); @@ -3615,12 +3697,26 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : nullptr; - return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, - BFI, PSI, ExpensiveCombines, LI); + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, BFI, + PSI, ExpensiveCombines, MaxIterations, + LI); } char InstructionCombiningPass::ID = 0; +InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines) + : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines), + MaxIterations(InstCombineDefaultMaxIterations) { + initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); +} + +InstructionCombiningPass::InstructionCombiningPass(bool ExpensiveCombines, + unsigned MaxIterations) + : FunctionPass(ID), ExpensiveCombines(ExpensiveCombines), + MaxIterations(MaxIterations) { + initializeInstructionCombiningPassPass(*PassRegistry::getPassRegistry()); +} + INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -3647,6 +3743,11 @@ FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines) { return new InstructionCombiningPass(ExpensiveCombines); } +FunctionPass *llvm::createInstructionCombiningPass(bool ExpensiveCombines, + unsigned MaxIterations) { + return new InstructionCombiningPass(ExpensiveCombines, MaxIterations); +} + void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createInstructionCombiningPass()); } |
