| field | value | date |
|---|---|---|
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-11-19 20:06:13 +0000 |
| commit | c0981da47d5696fe36474fcf86b4ce03ae3ff818 (patch) | |
| tree | f42add1021b9f2ac6a69ac7cf6c4499962739a45 /llvm/lib/Transforms/InstCombine | |
| parent | 344a3780b2e33f6ca763666c380202b18aab72a3 (diff) | |
Diffstat (limited to 'llvm/lib/Transforms/InstCombine')
15 files changed, 2468 insertions, 1600 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp index d01a021bf3f4..eb1b8a29cfc5 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -939,7 +939,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { // add (xor X, LowMaskC), C --> sub (LowMaskC + C), X if (C2->isMask()) { KnownBits LHSKnown = computeKnownBits(X, 0, &Add); - if ((*C2 | LHSKnown.Zero).isAllOnesValue()) + if ((*C2 | LHSKnown.Zero).isAllOnes()) return BinaryOperator::CreateSub(ConstantInt::get(Ty, *C2 + *C), X); } @@ -963,7 +963,7 @@ Instruction *InstCombinerImpl::foldAddWithConstant(BinaryOperator &Add) { } } - if (C->isOneValue() && Op0->hasOneUse()) { + if (C->isOne() && Op0->hasOneUse()) { // add (sext i1 X), 1 --> zext (not X) // TODO: The smallest IR representation is (select X, 0, 1), and that would // not require the one-use check. But we need to remove a transform in @@ -1355,6 +1355,17 @@ Instruction *InstCombinerImpl::visitAdd(BinaryOperator &I) { if (match(RHS, m_OneUse(m_c_Add(m_Value(A), m_Specific(LHS))))) return BinaryOperator::CreateAdd(A, Builder.CreateShl(LHS, 1, "reass.add")); + { + // (A + C1) + (C2 - B) --> (A - B) + (C1 + C2) + Constant *C1, *C2; + if (match(&I, m_c_Add(m_Add(m_Value(A), m_ImmConstant(C1)), + m_Sub(m_ImmConstant(C2), m_Value(B)))) && + (LHS->hasOneUse() || RHS->hasOneUse())) { + Value *Sub = Builder.CreateSub(A, B); + return BinaryOperator::CreateAdd(Sub, ConstantExpr::getAdd(C1, C2)); + } + } + // X % C0 + (( X / C0 ) % C1) * C0 => X % (C0 * C1) if (Value *V = SimplifyAddWithRemainder(I)) return replaceInstUsesWith(I, V); @@ -1817,12 +1828,8 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { if (match(Op0, m_AllOnes())) return BinaryOperator::CreateNot(Op1); - // (~X) - (~Y) --> Y - X - Value *X, *Y; - if (match(Op0, m_Not(m_Value(X))) && match(Op1, m_Not(m_Value(Y)))) - return BinaryOperator::CreateSub(Y, X); - // (X + -1) - Y --> ~Y + X + Value *X, *Y; if (match(Op0, m_OneUse(m_Add(m_Value(X), m_AllOnes())))) return BinaryOperator::CreateAdd(Builder.CreateNot(Op1), X); @@ -1843,6 +1850,17 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return BinaryOperator::CreateSub(X, Add); } + // (~X) - (~Y) --> Y - X + // This is placed after the other reassociations and explicitly excludes a + // sub-of-sub pattern to avoid infinite looping. + if (isFreeToInvert(Op0, Op0->hasOneUse()) && + isFreeToInvert(Op1, Op1->hasOneUse()) && + !match(Op0, m_Sub(m_ImmConstant(), m_Value()))) { + Value *NotOp0 = Builder.CreateNot(Op0); + Value *NotOp1 = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(NotOp1, NotOp0); + } + auto m_AddRdx = [](Value *&Vec) { return m_OneUse(m_Intrinsic<Intrinsic::vector_reduce_add>(m_Value(Vec))); }; @@ -1892,7 +1910,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { // Turn this into a xor if LHS is 2^n-1 and the remaining bits are known // zero. 
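Aside: the "turn this into a xor" comment above relies on subtraction from an all-ones low-bit mask (2^n - 1) never borrowing. A minimal standalone check of that identity, with a hypothetical 8-bit mask (not part of the patch):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // If Op0 is 2^n - 1 and Op1 has no bits outside that mask, the
  // subtraction never borrows, so sub and xor agree bit-for-bit.
  const uint32_t Mask = 0xFF; // 2^8 - 1
  for (uint32_t y = 0; y <= Mask; ++y)
    assert(Mask - y == (Mask ^ y));
}
```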
KnownBits RHSKnown = computeKnownBits(Op1, 0, &I); - if ((*Op0C | RHSKnown.Zero).isAllOnesValue()) + if ((*Op0C | RHSKnown.Zero).isAllOnes()) return BinaryOperator::CreateXor(Op1, Op0); } @@ -2039,12 +2057,31 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return BinaryOperator::CreateAnd( Op0, Builder.CreateNot(Y, Y->getName() + ".not")); + // ~X - Min/Max(~X, Y) -> ~Min/Max(X, ~Y) - X + // ~X - Min/Max(Y, ~X) -> ~Min/Max(X, ~Y) - X + // Min/Max(~X, Y) - ~X -> X - ~Min/Max(X, ~Y) + // Min/Max(Y, ~X) - ~X -> X - ~Min/Max(X, ~Y) + // As long as Y is freely invertible, this will be neutral or a win. + // Note: We don't generate the inverse max/min, just create the 'not' of + // it and let other folds do the rest. + if (match(Op0, m_Not(m_Value(X))) && + match(Op1, m_c_MaxOrMin(m_Specific(Op0), m_Value(Y))) && + !Op0->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op1); + return BinaryOperator::CreateSub(Not, X); + } + if (match(Op1, m_Not(m_Value(X))) && + match(Op0, m_c_MaxOrMin(m_Specific(Op1), m_Value(Y))) && + !Op1->hasNUsesOrMore(3) && isFreeToInvert(Y, Y->hasOneUse())) { + Value *Not = Builder.CreateNot(Op0); + return BinaryOperator::CreateSub(X, Not); + } + + // TODO: This is the same logic as above but handles the cmp-select idioms + // for min/max, so the use checks are increased to account for the + // extra instructions. If we canonicalize to intrinsics, this block + // can likely be removed. { - // ~A - Min/Max(~A, O) -> Max/Min(A, ~O) - A - // ~A - Min/Max(O, ~A) -> Max/Min(A, ~O) - A - // Min/Max(~A, O) - ~A -> A - Max/Min(A, ~O) - // Min/Max(O, ~A) - ~A -> A - Max/Min(A, ~O) - // So long as O here is freely invertible, this will be neutral or a win. Value *LHS, *RHS, *A; Value *NotA = Op0, *MinMax = Op1; SelectPatternFlavor SPF = matchSelectPattern(MinMax, LHS, RHS).Flavor; @@ -2057,12 +2094,10 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { match(NotA, m_Not(m_Value(A))) && (NotA == LHS || NotA == RHS)) { if (NotA == LHS) std::swap(LHS, RHS); - // LHS is now O above and expected to have at least 2 uses (the min/max) - // NotA is epected to have 2 uses from the min/max and 1 from the sub. + // LHS is now Y above and expected to have at least 2 uses (the min/max) + // NotA is expected to have 2 uses from the min/max and 1 from the sub. if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && !NotA->hasNUsesOrMore(4)) { - // Note: We don't generate the inverse max/min, just create the not of - // it and let other folds do the rest. Value *Not = Builder.CreateNot(MinMax); if (NotA == Op0) return BinaryOperator::CreateSub(Not, A); @@ -2119,7 +2154,7 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { unsigned BitWidth = Ty->getScalarSizeInBits(); unsigned Cttz = AddC->countTrailingZeros(); APInt HighMask(APInt::getHighBitsSet(BitWidth, BitWidth - Cttz)); - if ((HighMask & *AndC).isNullValue()) + if ((HighMask & *AndC).isZero()) return BinaryOperator::CreateAnd(Op0, ConstantInt::get(Ty, ~(*AndC))); } @@ -2133,6 +2168,19 @@ Instruction *InstCombinerImpl::visitSub(BinaryOperator &I) { return replaceInstUsesWith( I, Builder.CreateIntrinsic(Intrinsic::umin, {I.getType()}, {Op0, Y})); + // umax(X, Op1) - Op1 --> usub.sat(X, Op1) + // TODO: The one-use restriction is not strictly necessary, but it may + // require improving other pattern matching and/or codegen. 
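The umax-to-usub.sat rewrite above rests on a simple unsigned identity: max(x, y) - y is exactly x - y clamped at zero. A standalone sketch; the usub_sat helper is mine, mirroring what the llvm.usub.sat intrinsic computes:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Stand-in for the llvm.usub.sat intrinsic on unsigned operands.
uint32_t usub_sat(uint32_t a, uint32_t b) { return a > b ? a - b : 0; }

int main() {
  for (uint32_t x : {0u, 5u, 100u, 0xFFFFFFFFu})
    for (uint32_t y : {0u, 5u, 200u})
      assert(std::max(x, y) - y == usub_sat(x, y));
}
```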
+ if (match(Op0, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op1))))) + return replaceInstUsesWith( + I, Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op1})); + + // Op0 - umax(X, Op0) --> 0 - usub.sat(X, Op0) + if (match(Op1, m_OneUse(m_c_UMax(m_Value(X), m_Specific(Op0))))) { + Value *USub = Builder.CreateIntrinsic(Intrinsic::usub_sat, {Ty}, {X, Op0}); + return BinaryOperator::CreateNeg(USub); + } + // C - ctpop(X) => ctpop(~X) if C is bitwidth if (match(Op0, m_SpecificInt(Ty->getScalarSizeInBits())) && match(Op1, m_OneUse(m_Intrinsic<Intrinsic::ctpop>(m_Value(X))))) @@ -2173,8 +2221,8 @@ static Instruction *foldFNegIntoConstant(Instruction &I) { // TODO: We could propagate nsz/ninf from fdiv alone? FastMathFlags FMF = I.getFastMathFlags(); FastMathFlags OpFMF = FNegOp->getFastMathFlags(); - FDiv->setHasNoSignedZeros(FMF.noSignedZeros() & OpFMF.noSignedZeros()); - FDiv->setHasNoInfs(FMF.noInfs() & OpFMF.noInfs()); + FDiv->setHasNoSignedZeros(FMF.noSignedZeros() && OpFMF.noSignedZeros()); + FDiv->setHasNoInfs(FMF.noInfs() && OpFMF.noInfs()); return FDiv; } // With NSZ [ counter-example with -0.0: -(-0.0 + 0.0) != 0.0 + -0.0 ]: diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 120852c44474..06c9bf650f37 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -185,14 +185,15 @@ enum MaskedICmpType { /// satisfies. static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, ICmpInst::Predicate Pred) { - ConstantInt *ACst = dyn_cast<ConstantInt>(A); - ConstantInt *BCst = dyn_cast<ConstantInt>(B); - ConstantInt *CCst = dyn_cast<ConstantInt>(C); + const APInt *ConstA = nullptr, *ConstB = nullptr, *ConstC = nullptr; + match(A, m_APInt(ConstA)); + match(B, m_APInt(ConstB)); + match(C, m_APInt(ConstC)); bool IsEq = (Pred == ICmpInst::ICMP_EQ); - bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); - bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); + bool IsAPow2 = ConstA && ConstA->isPowerOf2(); + bool IsBPow2 = ConstB && ConstB->isPowerOf2(); unsigned MaskVal = 0; - if (CCst && CCst->isZero()) { + if (ConstC && ConstC->isZero()) { // if C is zero, then both A and B qualify as mask MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); @@ -211,7 +212,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, if (IsAPow2) MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) : (Mask_AllZeros | AMask_Mixed)); - } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { + } else if (ConstA && ConstC && ConstC->isSubsetOf(*ConstA)) { MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); } @@ -221,7 +222,7 @@ static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, if (IsBPow2) MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) : (Mask_AllZeros | BMask_Mixed)); - } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { + } else if (ConstB && ConstC && ConstC->isSubsetOf(*ConstB)) { MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); } @@ -269,9 +270,9 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { - // vectors are not (yet?) supported. Don't support pointers either. - if (!LHS->getOperand(0)->getType()->isIntegerTy() || - !RHS->getOperand(0)->getType()->isIntegerTy()) + // Don't allow pointers. 
Splat vectors are fine. + if (!LHS->getOperand(0)->getType()->isIntOrIntVectorTy() || + !RHS->getOperand(0)->getType()->isIntOrIntVectorTy()) return None; // Here comes the tricky part: @@ -367,9 +368,9 @@ getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, } else { return None; } + + assert(Ok && "Failed to find AND on the right side of the RHS icmp."); } - if (!Ok) - return None; if (L11 == A) { B = L12; @@ -619,8 +620,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // Remaining cases assume at least that B and D are constant, and depend on // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. - ConstantInt *BCst, *DCst; - if (!match(B, m_ConstantInt(BCst)) || !match(D, m_ConstantInt(DCst))) + const APInt *ConstB, *ConstD; + if (!match(B, m_APInt(ConstB)) || !match(D, m_APInt(ConstD))) return nullptr; if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { @@ -629,11 +630,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) // Only valid if one of the masks is a superset of the other (check "B&D" is // the same as either B or D). - APInt NewMask = BCst->getValue() & DCst->getValue(); - - if (NewMask == BCst->getValue()) + APInt NewMask = *ConstB & *ConstD; + if (NewMask == *ConstB) return LHS; - else if (NewMask == DCst->getValue()) + else if (NewMask == *ConstD) return RHS; } @@ -642,11 +642,10 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is // the same as either B or D). - APInt NewMask = BCst->getValue() | DCst->getValue(); - - if (NewMask == BCst->getValue()) + APInt NewMask = *ConstB | *ConstD; + if (NewMask == *ConstB) return LHS; - else if (NewMask == DCst->getValue()) + else if (NewMask == *ConstD) return RHS; } @@ -661,23 +660,21 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // We can't simply use C and E because we might actually handle // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. - ConstantInt *CCst, *ECst; - if (!match(C, m_ConstantInt(CCst)) || !match(E, m_ConstantInt(ECst))) + const APInt *OldConstC, *OldConstE; + if (!match(C, m_APInt(OldConstC)) || !match(E, m_APInt(OldConstE))) return nullptr; - if (PredL != NewCC) - CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (PredR != NewCC) - ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + + const APInt ConstC = PredL != NewCC ? *ConstB ^ *OldConstC : *OldConstC; + const APInt ConstE = PredR != NewCC ? *ConstD ^ *OldConstE : *OldConstE; // If there is a conflict, we should actually return a false for the // whole construct. 
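For intuition on the merged compare and the conflict test above: two masked equality checks collapse into one as long as the masks agree wherever they overlap, i.e. (B & D) & (C ^ E) == 0. A small exhaustive check with hypothetical masks of my choosing:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // (A & B) == C && (A & D) == E --> (A & (B | D)) == (C | E)
  const uint32_t B = 0x0F, C = 0x09, D = 0x3C, E = 0x28;
  assert(((B & D) & (C ^ E)) == 0); // masks agree on their overlap
  for (uint32_t a = 0; a < 0x100; ++a) {
    bool separate = ((a & B) == C) && ((a & D) == E);
    bool merged = (a & (B | D)) == (C | E);
    assert(separate == merged);
  }
}
```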
- if (((BCst->getValue() & DCst->getValue()) & - (CCst->getValue() ^ ECst->getValue())).getBoolValue()) + if (((*ConstB & *ConstD) & (ConstC ^ ConstE)).getBoolValue()) return ConstantInt::get(LHS->getType(), !IsAnd); Value *NewOr1 = Builder.CreateOr(B, D); - Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); Value *NewAnd = Builder.CreateAnd(A, NewOr1); + Constant *NewOr2 = ConstantInt::get(A->getType(), ConstC | ConstE); return Builder.CreateICmp(NewCC, NewAnd, NewOr2); } @@ -777,20 +774,6 @@ foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, return Builder.CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); } - // Special case: get the ordering right when the values wrap around zero. - // Ie, we assumed the constants were unsigned when swapping earlier. - if (C1->isNullValue() && C2->isAllOnesValue()) - std::swap(C1, C2); - - if (*C1 == *C2 - 1) { - // (X == 13 || X == 14) --> X - 13 <=u 1 - // (X != 13 && X != 14) --> X - 13 >u 1 - // An 'add' is the canonical IR form, so favor that over a 'sub'. - Value *Add = Builder.CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); - auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; - return Builder.CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); - } - return nullptr; } @@ -923,7 +906,7 @@ static Value *foldSignedTruncationCheck(ICmpInst *ICmp0, ICmpInst *ICmp1, if (!tryToDecompose(OtherICmp, X0, UnsetBitsMask)) return nullptr; - assert(!UnsetBitsMask.isNullValue() && "empty mask makes no sense."); + assert(!UnsetBitsMask.isZero() && "empty mask makes no sense."); // Are they working on the same value? Value *X; @@ -1113,8 +1096,8 @@ static Value *extractIntPart(const IntPart &P, IRBuilderBase &Builder) { /// (icmp eq X0, Y0) & (icmp eq X1, Y1) -> icmp eq X01, Y01 /// (icmp ne X0, Y0) | (icmp ne X1, Y1) -> icmp ne X01, Y01 /// where X0, X1 and Y0, Y1 are adjacent parts extracted from an integer. -static Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd, - InstCombiner::BuilderTy &Builder) { +Value *InstCombinerImpl::foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, + bool IsAnd) { if (!Cmp0->hasOneUse() || !Cmp1->hasOneUse()) return nullptr; @@ -1202,6 +1185,51 @@ static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp); } +/// Fold (icmp Pred1 V1, C1) & (icmp Pred2 V2, C2) +/// or (icmp Pred1 V1, C1) | (icmp Pred2 V2, C2) +/// into a single comparison using range-based reasoning. +static Value *foldAndOrOfICmpsUsingRanges( + ICmpInst::Predicate Pred1, Value *V1, const APInt &C1, + ICmpInst::Predicate Pred2, Value *V2, const APInt &C2, + IRBuilderBase &Builder, bool IsAnd) { + // Look through add of a constant offset on V1, V2, or both operands. This + // allows us to interpret the V + C' < C'' range idiom into a proper range. + const APInt *Offset1 = nullptr, *Offset2 = nullptr; + if (V1 != V2) { + Value *X; + if (match(V1, m_Add(m_Value(X), m_APInt(Offset1)))) + V1 = X; + if (match(V2, m_Add(m_Value(X), m_APInt(Offset2)))) + V2 = X; + } + + if (V1 != V2) + return nullptr; + + ConstantRange CR1 = ConstantRange::makeExactICmpRegion(Pred1, C1); + if (Offset1) + CR1 = CR1.subtract(*Offset1); + + ConstantRange CR2 = ConstantRange::makeExactICmpRegion(Pred2, C2); + if (Offset2) + CR2 = CR2.subtract(*Offset2); + + Optional<ConstantRange> CR = + IsAnd ? 
CR1.exactIntersectWith(CR2) : CR1.exactUnionWith(CR2); + if (!CR) + return nullptr; + + CmpInst::Predicate NewPred; + APInt NewC, Offset; + CR->getEquivalentICmp(NewPred, NewC, Offset); + + Type *Ty = V1->getType(); + Value *NewV = V1; + if (Offset != 0) + NewV = Builder.CreateAdd(NewV, ConstantInt::get(Ty, Offset)); + return Builder.CreateICmp(NewPred, NewV, ConstantInt::get(Ty, NewC)); +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &And) { @@ -1262,170 +1290,64 @@ Value *InstCombinerImpl::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/true, Q, Builder)) return X; - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true, Builder)) + if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/true)) return X; // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); - ConstantInt *LHSC, *RHSC; - if (!match(LHS->getOperand(1), m_ConstantInt(LHSC)) || - !match(RHS->getOperand(1), m_ConstantInt(RHSC))) - return nullptr; - - if (LHSC == RHSC && PredL == PredR) { - // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) - // where C is a power of 2 or - // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || - (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { - Value *NewOr = Builder.CreateOr(LHS0, RHS0); - return Builder.CreateICmp(PredL, NewOr, LHSC); - } + // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) + // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + if (PredL == ICmpInst::ICMP_EQ && match(LHS->getOperand(1), m_ZeroInt()) && + PredR == ICmpInst::ICMP_EQ && match(RHS->getOperand(1), m_ZeroInt()) && + LHS0->getType() == RHS0->getType()) { + Value *NewOr = Builder.CreateOr(LHS0, RHS0); + return Builder.CreateICmp(PredL, NewOr, + Constant::getNullValue(NewOr->getType())); } + const APInt *LHSC, *RHSC; + if (!match(LHS->getOperand(1), m_APInt(LHSC)) || + !match(RHS->getOperand(1), m_APInt(RHSC))) + return nullptr; + // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && RHS->hasOneUse()) { Value *V; - ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr; + const APInt *AndC, *SmallC = nullptr, *BigC = nullptr; // (trunc x) == C1 & (and x, CA) == C2 // (and x, CA) == C2 & (trunc x) == C1 if (match(RHS0, m_Trunc(m_Value(V))) && - match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + match(LHS0, m_And(m_Specific(V), m_APInt(AndC)))) { SmallC = RHSC; BigC = LHSC; } else if (match(LHS0, m_Trunc(m_Value(V))) && - match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + match(RHS0, m_And(m_Specific(V), m_APInt(AndC)))) { SmallC = LHSC; BigC = RHSC; } if (SmallC && BigC) { - unsigned BigBitSize = BigC->getType()->getBitWidth(); - unsigned SmallBitSize = SmallC->getType()->getBitWidth(); + unsigned BigBitSize = BigC->getBitWidth(); + unsigned SmallBitSize = SmallC->getBitWidth(); // Check that the low bits are zero. 
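The new foldAndOrOfICmpsUsingRanges path subsumes the hand-written constant ladders deleted further down: an exact intersection or union of two ConstantRanges turns back into a single compare, possibly with an add. For example, the classic adjacent-values test becomes one unsigned range check (constants here are illustrative):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // (x != 13 && x != 14) --> (x - 13) >u 1
  for (uint32_t x = 0; x < 64; ++x) {
    bool anded = (x != 13) && (x != 14);
    bool range = (x - 13u) > 1u; // unsigned wrap makes this a range test
    assert(anded == range);
  }
}
```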
APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & AndC->getValue()).isNullValue() && - (Low & BigC->getValue()).isNullValue()) { - Value *NewAnd = Builder.CreateAnd(V, Low | AndC->getValue()); - APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); - Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); + if ((Low & *AndC).isZero() && (Low & *BigC).isZero()) { + Value *NewAnd = Builder.CreateAnd(V, Low | *AndC); + APInt N = SmallC->zext(BigBitSize) | *BigC; + Value *NewVal = ConstantInt::get(NewAnd->getType(), N); return Builder.CreateICmp(PredL, NewAnd, NewVal); } } } - // From here on, we only handle: - // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. - if (LHS0 != RHS0) - return nullptr; - - // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. - if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || - PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || - PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || - PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) - return nullptr; - - // We can't fold (ugt x, C) & (sgt x, C2). - if (!predicatesFoldable(PredL, PredR)) - return nullptr; - - // Ensure that the larger constant is on the RHS. - bool ShouldSwap; - if (CmpInst::isSigned(PredL) || - (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) - ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); - else - ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); - - if (ShouldSwap) { - std::swap(LHS, RHS); - std::swap(LHSC, RHSC); - std::swap(PredL, PredR); - } - - // At this point, we know we have two icmp instructions - // comparing a value against two constants and and'ing the result - // together. Because of the above check, we know that we only have - // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know - // (from the icmp folding check above), that the two constants - // are not equal and that the larger constant is on the RHS - assert(LHSC != RHSC && "Compares not folded above?"); - - switch (PredL) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_ULT: - // (X != 13 & X u< 14) -> X < 13 - if (LHSC->getValue() == (RHSC->getValue() - 1)) - return Builder.CreateICmpULT(LHS0, LHSC); - if (LHSC->isZero()) // (X != 0 & X u< C) -> X-1 u< C-1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - break; // (X != 13 & X u< 15) -> no change - case ICmpInst::ICMP_SLT: - // (X != 13 & X s< 14) -> X < 13 - if (LHSC->getValue() == (RHSC->getValue() - 1)) - return Builder.CreateICmpSLT(LHS0, LHSC); - // (X != INT_MIN & X s< C) -> X-(INT_MIN+1) u< (C-(INT_MIN+1)) - if (LHSC->isMinValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - true, true); - break; // (X != 13 & X s< 15) -> no change - case ICmpInst::ICMP_NE: - // Potential folds for this case should already be handled. 
- break; - } - break; - case ICmpInst::ICMP_UGT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - // (X u> 13 & X != 14) -> X u> 14 - if (RHSC->getValue() == (LHSC->getValue() + 1)) - return Builder.CreateICmp(PredL, LHS0, RHSC); - // X u> C & X != UINT_MAX -> (X-(C+1)) u< UINT_MAX-(C+1) - if (RHSC->isMaxValue(false)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) u< 1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - false, true); - } - break; - case ICmpInst::ICMP_SGT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: - // (X s> 13 & X != 14) -> X s> 14 - if (RHSC->getValue() == (LHSC->getValue() + 1)) - return Builder.CreateICmp(PredL, LHS0, RHSC); - // X s> C & X != INT_MAX -> (X-(C+1)) u< INT_MAX-(C+1) - if (RHSC->isMaxValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), - true, true); - break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) u< 1 - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, - true); - } - break; - } - - return nullptr; + return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, + Builder, /* IsAnd */ true); } Value *InstCombinerImpl::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, @@ -1496,15 +1418,15 @@ static Instruction *reassociateFCmps(BinaryOperator &BO, std::swap(Op0, Op1); // Match inner binop and the predicate for combining 2 NAN checks into 1. - BinaryOperator *BO1; + Value *BO10, *BO11; FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred || - !match(Op1, m_BinOp(BO1)) || BO1->getOpcode() != Opcode) + !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11)))) return nullptr; // The inner logic op must have a matching fcmp operand. - Value *BO10 = BO1->getOperand(0), *BO11 = BO1->getOperand(1), *Y; + Value *Y; if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) || Pred != NanPred || X->getType() != Y->getType()) std::swap(BO10, BO11); @@ -1524,27 +1446,42 @@ static Instruction *reassociateFCmps(BinaryOperator &BO, return BinaryOperator::Create(Opcode, NewFCmp, BO11); } -/// Match De Morgan's Laws: +/// Match variations of De Morgan's Laws: /// (~A & ~B) == (~(A | B)) /// (~A | ~B) == (~(A & B)) static Instruction *matchDeMorgansLaws(BinaryOperator &I, InstCombiner::BuilderTy &Builder) { - auto Opcode = I.getOpcode(); + const Instruction::BinaryOps Opcode = I.getOpcode(); assert((Opcode == Instruction::And || Opcode == Instruction::Or) && "Trying to match De Morgan's Laws with something other than and/or"); // Flip the logic operation. - Opcode = (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; + const Instruction::BinaryOps FlippedOpcode = + (Opcode == Instruction::And) ? 
Instruction::Or : Instruction::And; + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Value *A, *B; - if (match(I.getOperand(0), m_OneUse(m_Not(m_Value(A)))) && - match(I.getOperand(1), m_OneUse(m_Not(m_Value(B)))) && + if (match(Op0, m_OneUse(m_Not(m_Value(A)))) && + match(Op1, m_OneUse(m_Not(m_Value(B)))) && !InstCombiner::isFreeToInvert(A, A->hasOneUse()) && !InstCombiner::isFreeToInvert(B, B->hasOneUse())) { - Value *AndOr = Builder.CreateBinOp(Opcode, A, B, I.getName() + ".demorgan"); + Value *AndOr = + Builder.CreateBinOp(FlippedOpcode, A, B, I.getName() + ".demorgan"); return BinaryOperator::CreateNot(AndOr); } + // The 'not' ops may require reassociation. + // (A & ~B) & ~C --> A & ~(B | C) + // (~B & A) & ~C --> A & ~(B | C) + // (A | ~B) | ~C --> A | ~(B & C) + // (~B | A) | ~C --> A | ~(B & C) + Value *C; + if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) && + match(Op1, m_Not(m_Value(C)))) { + Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C); + return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO)); + } + return nullptr; } @@ -1778,6 +1715,72 @@ Instruction *InstCombinerImpl::narrowMaskedBinOp(BinaryOperator &And) { return new ZExtInst(Builder.CreateAnd(NewBO, X), Ty); } +/// Try folding relatively complex patterns for both And and Or operations +/// with all And and Or swapped. +static Instruction *foldComplexAndOrPatterns(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + const Instruction::BinaryOps Opcode = I.getOpcode(); + assert(Opcode == Instruction::And || Opcode == Instruction::Or); + + // Flip the logic operation. + const Instruction::BinaryOps FlippedOpcode = + (Opcode == Instruction::And) ? Instruction::Or : Instruction::And; + + Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + Value *A, *B, *C; + + // (~(A | B) & C) | ... --> ... + // (~(A & B) | C) & ... --> ... + // TODO: One use checks are conservative. We just need to check that a total + // number of multiple used values does not exceed reduction + // in operations. + if (match(Op0, m_c_BinOp(FlippedOpcode, + m_Not(m_BinOp(Opcode, m_Value(A), m_Value(B))), + m_Value(C)))) { + // (~(A | B) & C) | (~(A | C) & B) --> (B ^ C) & ~A + // (~(A & B) | C) & (~(A & C) | B) --> ~((B ^ C) & A) + if (match(Op1, + m_OneUse(m_c_BinOp(FlippedOpcode, + m_OneUse(m_Not(m_c_BinOp(Opcode, m_Specific(A), + m_Specific(C)))), + m_Specific(B))))) { + Value *Xor = Builder.CreateXor(B, C); + return (Opcode == Instruction::Or) + ? BinaryOperator::CreateAnd(Xor, Builder.CreateNot(A)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, A)); + } + + // (~(A | B) & C) | (~(B | C) & A) --> (A ^ C) & ~B + // (~(A & B) | C) & (~(B & C) | A) --> ~((A ^ C) & B) + if (match(Op1, + m_OneUse(m_c_BinOp(FlippedOpcode, + m_OneUse(m_Not(m_c_BinOp(Opcode, m_Specific(B), + m_Specific(C)))), + m_Specific(A))))) { + Value *Xor = Builder.CreateXor(A, C); + return (Opcode == Instruction::Or) + ? 
BinaryOperator::CreateAnd(Xor, Builder.CreateNot(B)) + : BinaryOperator::CreateNot(Builder.CreateAnd(Xor, B)); + } + + // (~(A | B) & C) | ~(A | C) --> ~((B & C) | A) + // (~(A & B) | C) & ~(A & C) --> ~((B | C) & A) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(A), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, B, C), A)); + + // (~(A | B) & C) | ~(B | C) --> ~((A & C) | B) + // (~(A & B) | C) & ~(B & C) --> ~((A | C) & B) + if (match(Op1, m_OneUse(m_Not(m_OneUse( + m_c_BinOp(Opcode, m_Specific(B), m_Specific(C))))))) + return BinaryOperator::CreateNot(Builder.CreateBinOp( + Opcode, Builder.CreateBinOp(FlippedOpcode, A, C), B)); + } + + return nullptr; +} + // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. @@ -1803,6 +1806,9 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { if (Instruction *Xor = foldAndToXor(I, Builder)) return Xor; + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + // (A|B)&(A|C) -> A|(B&C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -1883,7 +1889,7 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { // (X + AddC) & LowMaskC --> X & LowMaskC unsigned Ctlz = C->countLeadingZeros(); APInt LowMask(APInt::getLowBitsSet(Width, Width - Ctlz)); - if ((*AddC & LowMask).isNullValue()) + if ((*AddC & LowMask).isZero()) return BinaryOperator::CreateAnd(X, Op1); // If we are masking the result of the add down to exactly one bit and @@ -1896,44 +1902,37 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateXor(NewAnd, Op1); } } - } - ConstantInt *AndRHS; - if (match(Op1, m_ConstantInt(AndRHS))) { - const APInt &AndRHSMask = AndRHS->getValue(); - - // Optimize a variety of ((val OP C1) & C2) combinations... - if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth - // of X and OP behaves well when given trunc(C1) and X. - // TODO: Do this for vectors by using m_APInt instead of m_ConstantInt. - switch (Op0I->getOpcode()) { - default: - break; + // ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2) if C2 fits in the + // bitwidth of X and OP behaves well when given trunc(C1) and X. + auto isSuitableBinOpcode = [](BinaryOperator *B) { + switch (B->getOpcode()) { case Instruction::Xor: case Instruction::Or: case Instruction::Mul: case Instruction::Add: case Instruction::Sub: - Value *X; - ConstantInt *C1; - // TODO: The one use restrictions could be relaxed a little if the AND - // is going to be removed. 
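The narrowing fold being rewritten here, ((C1 OP zext(X)) & C2) -> zext((C1 OP X) & C2), works because when C2 fits in X's width, the mask discards every bit the widened operation could produce above that width. A sketch with hypothetical i8/i32 constants:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 0x1234, C2 = 0x7F; // C2 fits in 8 bits
  for (uint32_t v = 0; v < 256; ++v) {
    uint8_t x = (uint8_t)v;
    uint32_t wide = (C1 + (uint32_t)x) & C2;
    // Same computation done in 8 bits, then zero-extended.
    uint32_t narrow = (uint32_t)(uint8_t)(((uint8_t)C1 + x) & (uint8_t)C2);
    assert(wide == narrow);
  }
}
```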
- if (match(Op0I, m_OneUse(m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), - m_ConstantInt(C1))))) { - if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { - auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); - Value *BinOp; - Value *Op0LHS = Op0I->getOperand(0); - if (isa<ZExtInst>(Op0LHS)) - BinOp = Builder.CreateBinOp(Op0I->getOpcode(), X, TruncC1); - else - BinOp = Builder.CreateBinOp(Op0I->getOpcode(), TruncC1, X); - auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); - auto *And = Builder.CreateAnd(BinOp, TruncC2); - return new ZExtInst(And, Ty); - } - } + return true; + default: + return false; + } + }; + BinaryOperator *BO; + if (match(Op0, m_OneUse(m_BinOp(BO))) && isSuitableBinOpcode(BO)) { + Value *X; + const APInt *C1; + // TODO: The one-use restrictions could be relaxed a little if the AND + // is going to be removed. + if (match(BO, m_c_BinOp(m_OneUse(m_ZExt(m_Value(X))), m_APInt(C1))) && + C->isIntN(X->getType()->getScalarSizeInBits())) { + unsigned XWidth = X->getType()->getScalarSizeInBits(); + Constant *TruncC1 = ConstantInt::get(X->getType(), C1->trunc(XWidth)); + Value *BinOp = isa<ZExtInst>(BO->getOperand(0)) + ? Builder.CreateBinOp(BO->getOpcode(), X, TruncC1) + : Builder.CreateBinOp(BO->getOpcode(), TruncC1, X); + Constant *TruncC = ConstantInt::get(X->getType(), C->trunc(XWidth)); + Value *And = Builder.CreateAnd(BinOp, TruncC); + return new ZExtInst(And, Ty); } } } @@ -2071,13 +2070,13 @@ Instruction *InstCombinerImpl::visitAnd(BinaryOperator &I) { A->getType()->isIntOrIntVectorTy(1)) return SelectInst::Create(A, Op0, Constant::getNullValue(Ty)); - // and(ashr(subNSW(Y, X), ScalarSizeInBits(Y)-1), X) --> X s> Y ? X : 0. - if (match(&I, m_c_And(m_OneUse(m_AShr( - m_NSWSub(m_Value(Y), m_Value(X)), - m_SpecificInt(Ty->getScalarSizeInBits() - 1))), - m_Deferred(X)))) { - Value *NewICmpInst = Builder.CreateICmpSGT(X, Y); - return SelectInst::Create(NewICmpInst, X, ConstantInt::getNullValue(Ty)); + // (iN X s>> (N-1)) & Y --> (X s< 0) ? Y : 0 + unsigned FullShift = Ty->getScalarSizeInBits() - 1; + if (match(&I, m_c_And(m_OneUse(m_AShr(m_Value(X), m_SpecificInt(FullShift))), + m_Value(Y)))) { + Constant *Zero = ConstantInt::getNullValue(Ty); + Value *Cmp = Builder.CreateICmpSLT(X, Zero, "isneg"); + return SelectInst::Create(Cmp, Y, Zero); } // (~x) & y --> ~(x | (~y)) iff that gets rid of inversions @@ -2284,28 +2283,38 @@ static bool areInverseVectorBitmasks(Constant *C1, Constant *C2) { /// vector composed of all-zeros or all-ones values and is the bitwise 'not' of /// B, it can be used as the condition operand of a select instruction. Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) { - // Step 1: We may have peeked through bitcasts in the caller. + // We may have peeked through bitcasts in the caller. // Exit immediately if we don't have (vector) integer types. Type *Ty = A->getType(); if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy()) return nullptr; - // Step 2: We need 0 or all-1's bitmasks. - if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits()) - return nullptr; - - // Step 3: If B is the 'not' value of A, we have our answer. - if (match(A, m_Not(m_Specific(B)))) { + // If A is the 'not' operand of B and has enough signbits, we have our answer. + if (match(B, m_Not(m_Specific(A)))) { // If these are scalars or vectors of i1, A can be used directly. 
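A note on the "(iN X s>> (N-1)) & Y" fold a few hunks up: an arithmetic shift by the full bit width minus one produces all-zeros or all-ones depending on the sign bit, so the mask acts as a select on X's sign. Illustrative check; it relies on arithmetic right shift of negative values, which is guaranteed since C++20 and universal in practice before that:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  for (int32_t x : {-7, -1, 0, 1, 42}) {
    uint32_t splat = (uint32_t)(x >> 31); // 0 or 0xFFFFFFFF
    for (uint32_t y : {0u, 0xDEADBEEFu})
      assert((splat & y) == (x < 0 ? y : 0u));
  }
}
```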
if (Ty->isIntOrIntVectorTy(1)) return A; - return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty)); + + // If we look through a vector bitcast, the caller will bitcast the operands + // to match the condition's number of bits (N x i1). + // To make this poison-safe, disallow bitcast from wide element to narrow + // element. That could allow poison in lanes where it was not present in the + // original code. + A = peekThroughBitcast(A); + if (A->getType()->isIntOrIntVectorTy()) { + unsigned NumSignBits = ComputeNumSignBits(A); + if (NumSignBits == A->getType()->getScalarSizeInBits() && + NumSignBits <= Ty->getScalarSizeInBits()) + return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType())); + } + return nullptr; } // If both operands are constants, see if the constants are inverse bitmasks. Constant *AConst, *BConst; if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst))) - if (AConst == ConstantExpr::getNot(BConst)) + if (AConst == ConstantExpr::getNot(BConst) && + ComputeNumSignBits(A) == Ty->getScalarSizeInBits()) return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty)); // Look for more complex patterns. The 'not' op may be hidden behind various @@ -2349,10 +2358,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B, B = peekThroughBitcast(B, true); if (Value *Cond = getSelectCondition(A, B)) { // ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D)) + // If this is a vector, we may need to cast to match the condition's length. // The bitcasts will either all exist or all not exist. The builder will // not create unnecessary casts if the types already match. - Value *BitcastC = Builder.CreateBitCast(C, A->getType()); - Value *BitcastD = Builder.CreateBitCast(D, A->getType()); + Type *SelTy = A->getType(); + if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) { + unsigned Elts = VecTy->getElementCount().getKnownMinValue(); + Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts); + SelTy = VectorType::get(EltTy, VecTy->getElementCount()); + } + Value *BitcastC = Builder.CreateBitCast(C, SelTy); + Value *BitcastD = Builder.CreateBitCast(D, SelTy); Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD); return Builder.CreateBitCast(Select, OrigType); } @@ -2374,8 +2390,9 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); Value *LHS1 = LHS->getOperand(1), *RHS1 = RHS->getOperand(1); - auto *LHSC = dyn_cast<ConstantInt>(LHS1); - auto *RHSC = dyn_cast<ConstantInt>(RHS1); + const APInt *LHSC = nullptr, *RHSC = nullptr; + match(LHS1, m_APInt(LHSC)); + match(RHS1, m_APInt(RHSC)); // Fold (icmp ult/ule (A + C1), C3) | (icmp ult/ule (A + C2), C3) // --> (icmp ult/ule ((A & ~(C1 ^ C2)) + max(C1, C2)), C3) @@ -2389,40 +2406,41 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // This implies all values in the two ranges differ by exactly one bit. 
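getSelectCondition and matchSelectFromAndOr exist because an all-zeros/all-ones word used as a mask is a select in disguise. A standalone illustration of the "((bc Cond) & C) | ((bc ~Cond) & D)" shape, with arbitrary test constants:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C = 0xAAAA5555, D = 0x12345678;
  for (uint32_t A : {0u, 0xFFFFFFFFu}) { // all-zeros or all-ones "condition"
    uint32_t masked = (A & C) | (~A & D);
    uint32_t sel = A ? C : D;
    assert(masked == sel);
  }
}
```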
if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && - LHSC->getType() == RHSC->getType() && - LHSC->getValue() == (RHSC->getValue())) { + LHSC->getBitWidth() == RHSC->getBitWidth() && *LHSC == *RHSC) { Value *AddOpnd; - ConstantInt *LAddC, *RAddC; - if (match(LHS0, m_Add(m_Value(AddOpnd), m_ConstantInt(LAddC))) && - match(RHS0, m_Add(m_Specific(AddOpnd), m_ConstantInt(RAddC))) && - LAddC->getValue().ugt(LHSC->getValue()) && - RAddC->getValue().ugt(LHSC->getValue())) { + const APInt *LAddC, *RAddC; + if (match(LHS0, m_Add(m_Value(AddOpnd), m_APInt(LAddC))) && + match(RHS0, m_Add(m_Specific(AddOpnd), m_APInt(RAddC))) && + LAddC->ugt(*LHSC) && RAddC->ugt(*LHSC)) { - APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + APInt DiffC = *LAddC ^ *RAddC; if (DiffC.isPowerOf2()) { - ConstantInt *MaxAddC = nullptr; - if (LAddC->getValue().ult(RAddC->getValue())) + const APInt *MaxAddC = nullptr; + if (LAddC->ult(*RAddC)) MaxAddC = RAddC; else MaxAddC = LAddC; - APInt RRangeLow = -RAddC->getValue(); - APInt RRangeHigh = RRangeLow + LHSC->getValue(); - APInt LRangeLow = -LAddC->getValue(); - APInt LRangeHigh = LRangeLow + LHSC->getValue(); + APInt RRangeLow = -*RAddC; + APInt RRangeHigh = RRangeLow + *LHSC; + APInt LRangeLow = -*LAddC; + APInt LRangeHigh = LRangeLow + *LHSC; APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSC->getValue())) { - Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); + RangeDiff.ugt(*LHSC)) { + Type *Ty = AddOpnd->getType(); + Value *MaskC = ConstantInt::get(Ty, ~DiffC); Value *NewAnd = Builder.CreateAnd(AddOpnd, MaskC); - Value *NewAdd = Builder.CreateAdd(NewAnd, MaxAddC); - return Builder.CreateICmp(LHS->getPredicate(), NewAdd, LHSC); + Value *NewAdd = Builder.CreateAdd(NewAnd, + ConstantInt::get(Ty, *MaxAddC)); + return Builder.CreateICmp(LHS->getPredicate(), NewAdd, + ConstantInt::get(Ty, *LHSC)); } } } @@ -2496,14 +2514,13 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, foldUnsignedUnderflowCheck(RHS, LHS, /*IsAnd=*/false, Q, Builder)) return X; - if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false, Builder)) + if (Value *X = foldEqOfParts(LHS, RHS, /*IsAnd=*/false)) return X; // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - // TODO: Remove this when foldLogOpOfMaskedICmps can handle vectors. - if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_Zero()) && - PredR == ICmpInst::ICMP_NE && match(RHS1, m_Zero()) && - LHS0->getType()->isIntOrIntVectorTy() && + // TODO: Remove this when foldLogOpOfMaskedICmps can handle undefs. + if (PredL == ICmpInst::ICMP_NE && match(LHS1, m_ZeroInt()) && + PredR == ICmpInst::ICMP_NE && match(RHS1, m_ZeroInt()) && LHS0->getType() == RHS0->getType()) { Value *NewOr = Builder.CreateOr(LHS0, RHS0); return Builder.CreateICmp(PredL, NewOr, @@ -2514,114 +2531,8 @@ Value *InstCombinerImpl::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (!LHSC || !RHSC) return nullptr; - // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) - // iff C2 + CA == C1. 
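The fold above merges two unsigned range checks whose start points differ in exactly one bit by masking that bit away first. An exhaustive 8-bit sketch; the constants 20, 4, and 3 are mine, chosen so that 20 ^ 4 == 16 is the single differing bit:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // (a + 20 <u 3) || (a + 4 <u 3) --> ((a & ~16) + 20) <u 3
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t a = (uint8_t)v;
    bool orig = (uint8_t)(a + 20) < 3 || (uint8_t)(a + 4) < 3;
    bool merged = (uint8_t)((a & 0xEF) + 20) < 3; // 0xEF == (uint8_t)~16
    assert(orig == merged);
  }
}
```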
- if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { - ConstantInt *AddC; - if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) - if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) - return Builder.CreateICmpULE(LHS0, LHSC); - } - - // From here on, we only handle: - // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. - if (LHS0 != RHS0) - return nullptr; - - // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. - if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || - PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || - PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || - PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) - return nullptr; - - // We can't fold (ugt x, C) | (sgt x, C2). - if (!predicatesFoldable(PredL, PredR)) - return nullptr; - - // Ensure that the larger constant is on the RHS. - bool ShouldSwap; - if (CmpInst::isSigned(PredL) || - (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) - ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); - else - ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); - - if (ShouldSwap) { - std::swap(LHS, RHS); - std::swap(LHSC, RHSC); - std::swap(PredL, PredR); - } - - // At this point, we know we have two icmp instructions - // comparing a value against two constants and or'ing the result - // together. Because of the above check, we know that we only have - // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the - // icmp folding check above), that the two constants are not - // equal. - assert(LHSC != RHSC && "Compares not folded above?"); - - switch (PredL) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - // Potential folds for this case should already be handled. 
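The deleted "(icmp ult (X + CA), C1) | (icmp eq X, C2) -> ule" special case is also just range reasoning: the equality pins the one value sitting at the top of the strict range. Exhaustive 8-bit check with hypothetical constants satisfying C2 + CA == C1:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  // (x + 5 <u 15) || (x == 10) --> (x + 5 <=u 15), since 10 + 5 == 15
  for (unsigned v = 0; v < 256; ++v) {
    uint8_t x = (uint8_t)v;
    bool orig = (uint8_t)(x + 5) < 15 || x == 10;
    bool merged = (uint8_t)(x + 5) <= 15;
    assert(orig == merged);
  }
}
```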
- break; - case ICmpInst::ICMP_UGT: - // (X == 0 || X u> C) -> (X-1) u>= C - if (LHSC->isMinValue(false)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, - false, false); - // (X == 13 | X u> 14) -> no change - break; - case ICmpInst::ICMP_SGT: - // (X == INT_MIN || X s> C) -> (X-(INT_MIN+1)) u>= C-INT_MIN - if (LHSC->isMinValue(true)) - return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue() + 1, - true, false); - // (X == 13 | X s> 14) -> no change - break; - } - break; - case ICmpInst::ICMP_ULT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change - // (X u< C || X == UINT_MAX) => (X-C) u>= UINT_MAX-C - if (RHSC->isMaxValue(false)) - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), - false, false); - break; - case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 - assert(!RHSC->isMaxValue(false) && "Missed icmp simplification"); - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, - false, false); - } - break; - case ICmpInst::ICMP_SLT: - switch (PredR) { - default: - llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: - // (X s< C || X == INT_MAX) => (X-C) u>= INT_MAX-C - if (RHSC->isMaxValue(true)) - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue(), - true, false); - // (X s< 13 | X == 14) -> no change - break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) u> 2 - assert(!RHSC->isMaxValue(true) && "Missed icmp simplification"); - return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, - false); - } - break; - } - return nullptr; + return foldAndOrOfICmpsUsingRanges(PredL, LHS0, *LHSC, PredR, RHS0, *RHSC, + Builder, /* IsAnd */ false); } // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches @@ -2647,6 +2558,9 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { if (Instruction *Xor = foldOrToXor(I, Builder)) return Xor; + if (Instruction *X = foldComplexAndOrPatterns(I, Builder)) + return X; + // (A&B)|(A&C) -> A&(B|C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); @@ -2684,69 +2598,63 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { Value *X, *Y; const APInt *CV; if (match(&I, m_c_Or(m_OneUse(m_Xor(m_Value(X), m_APInt(CV))), m_Value(Y))) && - !CV->isAllOnesValue() && MaskedValueIsZero(Y, *CV, 0, &I)) { + !CV->isAllOnes() && MaskedValueIsZero(Y, *CV, 0, &I)) { // (X ^ C) | Y -> (X | Y) ^ C iff Y & C == 0 // The check for a 'not' op is for efficiency (if Y is known zero --> ~X). Value *Or = Builder.CreateOr(X, Y); return BinaryOperator::CreateXor(Or, ConstantInt::get(I.getType(), *CV)); } - // (A & C)|(B & D) + // (A & C) | (B & D) Value *A, *B, *C, *D; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - // (A & C1)|(B & C2) - ConstantInt *C1, *C2; - if (match(C, m_ConstantInt(C1)) && match(D, m_ConstantInt(C2))) { - Value *V1 = nullptr, *V2 = nullptr; - if ((C1->getValue() & C2->getValue()).isNullValue()) { - // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) - // iff (C1&C2) == 0 and (N&~C1) == 0 - if (match(A, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == B && - MaskedValueIsZero(V2, ~C1->getValue(), 0, &I)) || // (V|N) - (V2 == B && - MaskedValueIsZero(V1, ~C1->getValue(), 0, &I)))) // (N|V) - return BinaryOperator::CreateAnd(A, - Builder.getInt(C1->getValue()|C2->getValue())); - // Or commutes, try both ways. 
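The bitfield-merge fold in this hunk, ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2), is sound when the two masks are disjoint and N has no bits outside C1, since each result bit then comes from exactly one side. A quick check with masks of my choosing:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t C1 = 0xFF00, C2 = 0x00FF, N = 0x1200;
  assert((C1 & C2) == 0 && (N & ~C1) == 0); // the fold's side conditions
  for (uint32_t v : {0u, 0x1234u, 0xABCDu, 0xFFFFFFFFu})
    assert((((v | N) & C1) | (v & C2)) == ((v | N) & (C1 | C2)));
}
```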
- if (match(B, m_Or(m_Value(V1), m_Value(V2))) && - ((V1 == A && - MaskedValueIsZero(V2, ~C2->getValue(), 0, &I)) || // (V|N) - (V2 == A && - MaskedValueIsZero(V1, ~C2->getValue(), 0, &I)))) // (N|V) - return BinaryOperator::CreateAnd(B, - Builder.getInt(C1->getValue()|C2->getValue())); - - // ((V|C3)&C1) | ((V|C4)&C2) --> (V|C3|C4)&(C1|C2) - // iff (C1&C2) == 0 and (C3&~C1) == 0 and (C4&~C2) == 0. - ConstantInt *C3 = nullptr, *C4 = nullptr; - if (match(A, m_Or(m_Value(V1), m_ConstantInt(C3))) && - (C3->getValue() & ~C1->getValue()).isNullValue() && - match(B, m_Or(m_Specific(V1), m_ConstantInt(C4))) && - (C4->getValue() & ~C2->getValue()).isNullValue()) { - V2 = Builder.CreateOr(V1, ConstantExpr::getOr(C3, C4), "bitfield"); - return BinaryOperator::CreateAnd(V2, - Builder.getInt(C1->getValue()|C2->getValue())); - } - } - - if (C1->getValue() == ~C2->getValue()) { - Value *X; - // ((X|B)&C1)|(B&C2) -> (X&C1) | B iff C1 == ~C2 + // (A & C0) | (B & C1) + const APInt *C0, *C1; + if (match(C, m_APInt(C0)) && match(D, m_APInt(C1))) { + Value *X; + if (*C0 == ~*C1) { + // ((X | B) & MaskC) | (B & ~MaskC) -> (X & MaskC) | B if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, C1), B); - // (A&C2)|((X|A)&C1) -> (X&C2) | A iff C1 == ~C2 + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X | A) & ~MaskC) -> (X & ~MaskC) | A if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateOr(Builder.CreateAnd(X, C2), A); + return BinaryOperator::CreateOr(Builder.CreateAnd(X, *C1), A); - // ((X^B)&C1)|(B&C2) -> (X&C1) ^ B iff C1 == ~C2 + // ((X ^ B) & MaskC) | (B & ~MaskC) -> (X & MaskC) ^ B if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, C1), B); - // (A&C2)|((X^A)&C1) -> (X&C2) ^ A iff C1 == ~C2 + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C0), B); + // (A & MaskC) | ((X ^ A) & ~MaskC) -> (X & ~MaskC) ^ A if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) - return BinaryOperator::CreateXor(Builder.CreateAnd(X, C2), A); + return BinaryOperator::CreateXor(Builder.CreateAnd(X, *C1), A); + } + + if ((*C0 & *C1).isZero()) { + // ((X | B) & C0) | (B & C1) --> (X | B) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C0) == 0 + if (match(A, m_c_Or(m_Value(X), m_Specific(B))) && + MaskedValueIsZero(X, ~*C0, 0, &I)) { + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(A, C01); + } + // (A & C0) | ((X | A) & C1) --> (X | A) & (C0 | C1) + // iff (C0 & C1) == 0 and (X & ~C1) == 0 + if (match(B, m_c_Or(m_Value(X), m_Specific(A))) && + MaskedValueIsZero(X, ~*C1, 0, &I)) { + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(B, C01); + } + // ((X | C2) & C0) | ((X | C3) & C1) --> (X | C2 | C3) & (C0 | C1) + // iff (C0 & C1) == 0 and (C2 & ~C0) == 0 and (C3 & ~C1) == 0. + const APInt *C2, *C3; + if (match(A, m_Or(m_Value(X), m_APInt(C2))) && + match(B, m_Or(m_Specific(X), m_APInt(C3))) && + (*C2 & ~*C0).isZero() && (*C3 & ~*C1).isZero()) { + Value *Or = Builder.CreateOr(X, *C2 | *C3, "bitfield"); + Constant *C01 = ConstantInt::get(I.getType(), *C0 | *C1); + return BinaryOperator::CreateAnd(Or, C01); + } } } @@ -2801,6 +2709,8 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { // A | ( A ^ B) -> A | B // A | (~A ^ B) -> A | ~B // (A & B) | (A ^ B) + // ~A | (A ^ B) -> ~(A & B) + // The swap above should always make Op0 the 'not' for the last case. 
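A standalone check of the xor-feeding-or identities handled just below, including the newly added "~A | (A ^ B) -> ~(A & B)" case (values are arbitrary test vectors):

```cpp
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t a : {0u, 0x0Fu, 0xF0F0F0F0u})
    for (uint32_t b : {0u, 0x33u, 0xFFFFFFFFu}) {
      assert((a | (a ^ b)) == (a | b));   // A | (A ^ B) -> A | B
      assert((~a | (a ^ b)) == ~(a & b)); // ~A | (A ^ B) -> ~(A & B)
    }
}
```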
if (match(Op1, m_Xor(m_Value(A), m_Value(B)))) { if (Op0 == A || Op0 == B) return BinaryOperator::CreateOr(A, B); @@ -2809,6 +2719,10 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) { match(Op0, m_And(m_Specific(B), m_Specific(A)))) return BinaryOperator::CreateOr(A, B); + if ((Op0->hasOneUse() || Op1->hasOneUse()) && + (match(Op0, m_Not(m_Specific(A))) || match(Op0, m_Not(m_Specific(B))))) + return BinaryOperator::CreateNot(Builder.CreateAnd(A, B)); + if (Op1->hasOneUse() && match(A, m_Not(m_Specific(Op0)))) { Value *Not = Builder.CreateNot(B, B->getName() + ".not"); return BinaryOperator::CreateOr(Not, Op0); @@ -3275,71 +3189,45 @@ bool InstCombinerImpl::sinkNotIntoOtherHandOfAndOrOr(BinaryOperator &I) { return true; } -// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches -// here. We should standardize that construct where it is needed or choose some -// other way to ensure that commutated variants of patterns are not missed. -Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { - if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1), - SQ.getWithInstruction(&I))) - return replaceInstUsesWith(I, V); - - if (SimplifyAssociativeOrCommutative(I)) - return &I; - - if (Instruction *X = foldVectorBinop(I)) - return X; - - if (Instruction *NewXor = foldXorToXor(I, Builder)) - return NewXor; - - // (A&B)^(A&C) -> A&(B^C) etc - if (Value *V = SimplifyUsingDistributiveLaws(I)) - return replaceInstUsesWith(I, V); - - // See if we can simplify any instructions used by the instruction whose sole - // purpose is to compute bits we don't care about. - if (SimplifyDemandedInstructionBits(I)) - return &I; - - if (Value *V = SimplifyBSwap(I, Builder)) - return replaceInstUsesWith(I, V); - - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - Type *Ty = I.getType(); - - // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M) - // This it a special case in haveNoCommonBitsSet, but the computeKnownBits - // calls in there are unnecessary as SimplifyDemandedInstructionBits should - // have already taken care of those cases. - Value *M; - if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()), - m_c_And(m_Deferred(M), m_Value())))) - return BinaryOperator::CreateOr(Op0, Op1); +Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) { + Value *NotOp; + if (!match(&I, m_Not(m_Value(NotOp)))) + return nullptr; // Apply DeMorgan's Law for 'nand' / 'nor' logic with an inverted operand. - Value *X, *Y; - // We must eliminate the and/or (one-use) for these transforms to not increase // the instruction count. + // // ~(~X & Y) --> (X | ~Y) // ~(Y & ~X) --> (X | ~Y) - if (match(&I, m_Not(m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y)))))) { + // + // Note: The logical matches do not check for the commuted patterns because + // those are handled via SimplifySelectsFeedingBinaryOp(). 
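The nand/nor rewrites above are direct applications of De Morgan's laws; a minimal sanity check:

```cpp
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x : {0u, 0xAAu, 0xFFFF0000u})
    for (uint32_t y : {0u, 0x55u, 0x12345678u}) {
      assert(~(~x & y) == (x | ~y)); // ~(~X & Y) --> X | ~Y
      assert(~(~x | y) == (x & ~y)); // ~(~X | Y) --> X & ~Y
    }
}
```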
+ Type *Ty = I.getType(); + Value *X, *Y; + if (match(NotOp, m_OneUse(m_c_And(m_Not(m_Value(X)), m_Value(Y))))) { Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); return BinaryOperator::CreateOr(X, NotY); } + if (match(NotOp, m_OneUse(m_LogicalAnd(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, ConstantInt::getTrue(Ty), NotY); + } + // ~(~X | Y) --> (X & ~Y) // ~(Y | ~X) --> (X & ~Y) - if (match(&I, m_Not(m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y)))))) { + if (match(NotOp, m_OneUse(m_c_Or(m_Not(m_Value(X)), m_Value(Y))))) { Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); return BinaryOperator::CreateAnd(X, NotY); } - - if (Instruction *Xor = visitMaskedMerge(I, Builder)) - return Xor; + if (match(NotOp, m_OneUse(m_LogicalOr(m_Not(m_Value(X)), m_Value(Y))))) { + Value *NotY = Builder.CreateNot(Y, Y->getName() + ".not"); + return SelectInst::Create(X, NotY, ConstantInt::getFalse(Ty)); + } // Is this a 'not' (~) fed by a binary operator? BinaryOperator *NotVal; - if (match(&I, m_Not(m_BinOp(NotVal)))) { + if (match(NotOp, m_BinOp(NotVal))) { if (NotVal->getOpcode() == Instruction::And || NotVal->getOpcode() == Instruction::Or) { // Apply DeMorgan's Law when inverts are free: @@ -3411,9 +3299,164 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) { NotVal); } - // Use DeMorgan and reassociation to eliminate a 'not' op. + // not (cmp A, B) = !cmp A, B + CmpInst::Predicate Pred; + if (match(NotOp, m_OneUse(m_Cmp(Pred, m_Value(), m_Value())))) { + cast<CmpInst>(NotOp)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, NotOp); + } + + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // ~min(~X, ~Y) --> max(X, Y) + // ~max(~X, Y) --> min(X, ~Y) + auto *II = dyn_cast<IntrinsicInst>(NotOp); + if (II && II->hasOneUse()) { + if (match(NotOp, m_MaxOrMin(m_Value(X), m_Value(Y))) && + isFreeToInvert(X, X->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotX = Builder.CreateNot(X); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, NotX, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + if (match(NotOp, m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y)))) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID()); + Value *NotY = Builder.CreateNot(Y); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY); + return replaceInstUsesWith(I, InvMaxMin); + } + } + + // TODO: Remove folds if we canonicalize to intrinsics (see above). + // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: + // + // %notx = xor i32 %x, -1 + // %cmp1 = icmp sgt i32 %notx, %y + // %smax = select i1 %cmp1, i32 %notx, i32 %y + // %res = xor i32 %smax, -1 + // => + // %noty = xor i32 %y, -1 + // %cmp2 = icmp slt %x, %noty + // %res = select i1 %cmp2, i32 %x, i32 %noty + // + // Same is applicable for smin/umax/umin. + if (NotOp->hasOneUse()) { + Value *LHS, *RHS; + SelectPatternFlavor SPF = matchSelectPattern(NotOp, LHS, RHS).Flavor; + if (SelectPatternResult::isMinOrMax(SPF)) { + // It's possible we get here before the not has been simplified, so make + // sure the input to the not isn't freely invertible. 
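The min/max inversions rely on bitwise 'not' being strictly order-reversing in two's complement (~a == -a - 1), so it swaps min and max. A sketch using std::min/std::max as stand-ins for the smin/smax intrinsics:

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

int main() {
  for (int32_t x : {-50, -1, 0, 7, 123})
    for (int32_t y : {-9, 0, 99}) {
      assert(~std::min(~x, ~y) == std::max(x, y)); // ~min(~X, ~Y) --> max(X, Y)
      assert(~std::max(~x, ~y) == std::min(x, y)); // ~max(~X, ~Y) --> min(X, Y)
    }
}
```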
+      if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) {
+        Value *NotY = Builder.CreateNot(RHS);
+        return SelectInst::Create(
+            Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY);
+      }
+
+      // It's possible we get here before the not has been simplified, so make
+      // sure the input to the not isn't freely invertible.
+      if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) {
+        Value *NotX = Builder.CreateNot(LHS);
+        return SelectInst::Create(
+            Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y);
+      }
+
+      // If both sides are freely invertible, then we can get rid of the xor
+      // completely.
+      if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) &&
+          isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) {
+        Value *NotLHS = Builder.CreateNot(LHS);
+        Value *NotRHS = Builder.CreateNot(RHS);
+        return SelectInst::Create(
+            Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS),
+            NotLHS, NotRHS);
+      }
+    }
+
+    // Pull 'not' into operands of select if both operands are one-use compares
+    // or one is one-use compare and the other one is a constant.
+    // Inverting the predicates eliminates the 'not' operation.
+    // Example:
+    //   not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?)) -->
+    //        select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?)
+    //   not (select ?, (cmp TPred, ?, ?), true) -->
+    //        select ?, (cmp InvTPred, ?, ?), false
+    if (auto *Sel = dyn_cast<SelectInst>(NotOp)) {
+      Value *TV = Sel->getTrueValue();
+      Value *FV = Sel->getFalseValue();
+      auto *CmpT = dyn_cast<CmpInst>(TV);
+      auto *CmpF = dyn_cast<CmpInst>(FV);
+      bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV);
+      bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV);
+      if (InvertibleT && InvertibleF) {
+        if (CmpT)
+          CmpT->setPredicate(CmpT->getInversePredicate());
+        else
+          Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV)));
+        if (CmpF)
+          CmpF->setPredicate(CmpF->getInversePredicate());
+        else
+          Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV)));
+        return replaceInstUsesWith(I, Sel);
+      }
+    }
+  }
+
+  if (Instruction *NewXor = sinkNotIntoXor(I, Builder))
+    return NewXor;
+
+  return nullptr;
+}
+
+// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
+// here. We should standardize that construct where it is needed or choose some
+// other way to ensure that commutated variants of patterns are not missed.
+Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
+  if (Value *V = SimplifyXorInst(I.getOperand(0), I.getOperand(1),
+                                 SQ.getWithInstruction(&I)))
+    return replaceInstUsesWith(I, V);
+
+  if (SimplifyAssociativeOrCommutative(I))
+    return &I;
+
+  if (Instruction *X = foldVectorBinop(I))
+    return X;
+
+  if (Instruction *NewXor = foldXorToXor(I, Builder))
+    return NewXor;
+
+  // (A&B)^(A&C) -> A&(B^C) etc
+  if (Value *V = SimplifyUsingDistributiveLaws(I))
+    return replaceInstUsesWith(I, V);
+
+  // See if we can simplify any instructions used by the instruction whose sole
+  // purpose is to compute bits we don't care about.
+  if (SimplifyDemandedInstructionBits(I))
+    return &I;
+
+  if (Value *V = SimplifyBSwap(I, Builder))
+    return replaceInstUsesWith(I, V);
+
+  if (Instruction *R = foldNot(I))
+    return R;
+
+  // Fold (X & M) ^ (Y & ~M) -> (X & M) | (Y & ~M)
+  // This is a special case in haveNoCommonBitsSet, but the computeKnownBits
+  // calls in there are unnecessary as SimplifyDemandedInstructionBits should
+  // have already taken care of those cases.
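A sketch of that masked-merge case in IR (hypothetical names): the two 'and' results select disjoint bits, so the xor is equivalent to an or:

    define i32 @masked_merge(i32 %x, i32 %y, i32 %m) {
      %notm = xor i32 %m, -1
      %a = and i32 %x, %m
      %b = and i32 %y, %notm        ; %a and %b share no set bits
      %r = xor i32 %a, %b
      ret i32 %r
    }
    ; --> %r = or i32 %a, %b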
+  Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+  Value *M;
+  if (match(&I, m_c_Xor(m_c_And(m_Not(m_Value(M)), m_Value()),
+                        m_c_And(m_Deferred(M), m_Value()))))
+    return BinaryOperator::CreateOr(Op0, Op1);
+
+  if (Instruction *Xor = visitMaskedMerge(I, Builder))
+    return Xor;
+
+  Value *X, *Y;
   Constant *C1;
   if (match(Op1, m_Constant(C1))) {
+    // Use DeMorgan and reassociation to eliminate a 'not' op.
     Constant *C2;
     if (match(Op0, m_OneUse(m_Or(m_Not(m_Value(X)), m_Constant(C2))))) {
       // (~X | C2) ^ C1 --> ((X & ~C2) ^ -1) ^ C1 --> (X & ~C2) ^ ~C1
@@ -3425,15 +3468,24 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
       Value *Or = Builder.CreateOr(X, ConstantExpr::getNot(C2));
       return BinaryOperator::CreateXor(Or, ConstantExpr::getNot(C1));
     }
-  }
-  // not (cmp A, B) = !cmp A, B
-  CmpInst::Predicate Pred;
-  if (match(&I, m_Not(m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))))) {
-    cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred));
-    return replaceInstUsesWith(I, Op0);
+    // Convert xor ([trunc] (ashr X, BW-1)), C =>
+    //   select(X >s -1, C, ~C)
+    // The ashr creates an all-zeros or all-ones value, which then optionally
+    // inverts the constant depending on whether the input is less than 0.
+    const APInt *CA;
+    if (match(Op0, m_OneUse(m_TruncOrSelf(
+                       m_AShr(m_Value(X), m_APIntAllowUndef(CA))))) &&
+        *CA == X->getType()->getScalarSizeInBits() - 1 &&
+        !match(C1, m_AllOnes())) {
+      assert(!C1->isZeroValue() && "Unexpected xor with 0");
+      Value *ICmp =
+          Builder.CreateICmpSGT(X, Constant::getAllOnesValue(X->getType()));
+      return SelectInst::Create(ICmp, Op1, Builder.CreateNot(Op1));
+    }
   }
+  Type *Ty = I.getType();
   {
     const APInt *RHSC;
     if (match(Op1, m_APInt(RHSC))) {
@@ -3456,13 +3508,13 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
       // canonicalize to a 'not' before the shift to help SCEV and codegen:
       // (X << C) ^ RHSC --> ~X << C
       if (match(Op0, m_OneUse(m_Shl(m_Value(X), m_APInt(C)))) &&
-          *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).shl(*C)) {
+          *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).shl(*C)) {
         Value *NotX = Builder.CreateNot(X);
         return BinaryOperator::CreateShl(NotX, ConstantInt::get(Ty, *C));
       }
       // (X >>u C) ^ RHSC --> ~X >>u C
       if (match(Op0, m_OneUse(m_LShr(m_Value(X), m_APInt(C)))) &&
-          *RHSC == APInt::getAllOnesValue(Ty->getScalarSizeInBits()).lshr(*C)) {
+          *RHSC == APInt::getAllOnes(Ty->getScalarSizeInBits()).lshr(*C)) {
         Value *NotX = Builder.CreateNot(X);
         return BinaryOperator::CreateLShr(NotX, ConstantInt::get(Ty, *C));
       }
@@ -3572,101 +3624,6 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *CastedXor = foldCastedBitwiseLogic(I))
     return CastedXor;
-  // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max:
-  // ~min(~X, ~Y) --> max(X, Y)
-  // ~max(~X, Y) --> min(X, ~Y)
-  auto *II = dyn_cast<IntrinsicInst>(Op0);
-  if (II && match(Op1, m_AllOnes())) {
-    if (match(Op0, m_MaxOrMin(m_Not(m_Value(X)), m_Not(m_Value(Y))))) {
-      Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
-      Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y);
-      return replaceInstUsesWith(I, InvMaxMin);
-    }
-    if (match(Op0, m_OneUse(m_c_MaxOrMin(m_Not(m_Value(X)), m_Value(Y))))) {
-      Intrinsic::ID InvID = getInverseMinMaxIntrinsic(II->getIntrinsicID());
-      Value *NotY = Builder.CreateNot(Y);
-      Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotY);
-      return replaceInstUsesWith(I, InvMaxMin);
-    }
-  }
-
-  // TODO: Remove folds if we canonicalize to intrinsics (see
above). - // Eliminate a bitwise 'not' op of 'not' min/max by inverting the min/max: - // - // %notx = xor i32 %x, -1 - // %cmp1 = icmp sgt i32 %notx, %y - // %smax = select i1 %cmp1, i32 %notx, i32 %y - // %res = xor i32 %smax, -1 - // => - // %noty = xor i32 %y, -1 - // %cmp2 = icmp slt %x, %noty - // %res = select i1 %cmp2, i32 %x, i32 %noty - // - // Same is applicable for smin/umax/umin. - if (match(Op1, m_AllOnes()) && Op0->hasOneUse()) { - Value *LHS, *RHS; - SelectPatternFlavor SPF = matchSelectPattern(Op0, LHS, RHS).Flavor; - if (SelectPatternResult::isMinOrMax(SPF)) { - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(LHS, m_Not(m_Value(X))) && !isFreeToInvert(X, X->hasOneUse())) { - Value *NotY = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), X, NotY), X, NotY); - } - - // It's possible we get here before the not has been simplified, so make - // sure the input to the not isn't freely invertible. - if (match(RHS, m_Not(m_Value(Y))) && !isFreeToInvert(Y, Y->hasOneUse())) { - Value *NotX = Builder.CreateNot(LHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotX, Y), NotX, Y); - } - - // If both sides are freely invertible, then we can get rid of the xor - // completely. - if (isFreeToInvert(LHS, !LHS->hasNUsesOrMore(3)) && - isFreeToInvert(RHS, !RHS->hasNUsesOrMore(3))) { - Value *NotLHS = Builder.CreateNot(LHS); - Value *NotRHS = Builder.CreateNot(RHS); - return SelectInst::Create( - Builder.CreateICmp(getInverseMinMaxPred(SPF), NotLHS, NotRHS), - NotLHS, NotRHS); - } - } - - // Pull 'not' into operands of select if both operands are one-use compares - // or one is one-use compare and the other one is a constant. - // Inverting the predicates eliminates the 'not' operation. - // Example: - // not (select ?, (cmp TPred, ?, ?), (cmp FPred, ?, ?) --> - // select ?, (cmp InvTPred, ?, ?), (cmp InvFPred, ?, ?) 
- // not (select ?, (cmp TPred, ?, ?), true --> - // select ?, (cmp InvTPred, ?, ?), false - if (auto *Sel = dyn_cast<SelectInst>(Op0)) { - Value *TV = Sel->getTrueValue(); - Value *FV = Sel->getFalseValue(); - auto *CmpT = dyn_cast<CmpInst>(TV); - auto *CmpF = dyn_cast<CmpInst>(FV); - bool InvertibleT = (CmpT && CmpT->hasOneUse()) || isa<Constant>(TV); - bool InvertibleF = (CmpF && CmpF->hasOneUse()) || isa<Constant>(FV); - if (InvertibleT && InvertibleF) { - if (CmpT) - CmpT->setPredicate(CmpT->getInversePredicate()); - else - Sel->setTrueValue(ConstantExpr::getNot(cast<Constant>(TV))); - if (CmpF) - CmpF->setPredicate(CmpF->getInversePredicate()); - else - Sel->setFalseValue(ConstantExpr::getNot(cast<Constant>(FV))); - return replaceInstUsesWith(I, Sel); - } - } - } - - if (Instruction *NewXor = sinkNotIntoXor(I, Builder)) - return NewXor; - if (Instruction *Abs = canonicalizeAbs(I, Builder)) return Abs; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp index 726bb545be12..bfa7bfa2290a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -67,7 +67,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/AssumeBundleBuilder.h" #include "llvm/Transforms/Utils/Local.h" @@ -79,11 +78,12 @@ #include <utility> #include <vector> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumSimplified, "Number of library calls simplified"); static cl::opt<unsigned> GuardWideningWindow( @@ -513,7 +513,7 @@ static Instruction *foldCttzCtlz(IntrinsicInst &II, InstCombinerImpl &IC) { // If the input to cttz/ctlz is known to be non-zero, // then change the 'ZeroIsUndef' parameter to 'true' // because we know the zero behavior can't affect the result. - if (!Known.One.isNullValue() || + if (!Known.One.isZero() || isKnownNonZero(Op0, IC.getDataLayout(), 0, &IC.getAssumptionCache(), &II, &IC.getDominatorTree())) { if (!match(II.getArgOperand(1), m_One())) @@ -656,8 +656,8 @@ static Value *simplifyNeonTbl1(const IntrinsicInst &II, // comparison to the first NumOperands. 
static bool haveSameOperands(const IntrinsicInst &I, const IntrinsicInst &E, unsigned NumOperands) { - assert(I.getNumArgOperands() >= NumOperands && "Not enough operands"); - assert(E.getNumArgOperands() >= NumOperands && "Not enough operands"); + assert(I.arg_size() >= NumOperands && "Not enough operands"); + assert(E.arg_size() >= NumOperands && "Not enough operands"); for (unsigned i = 0; i < NumOperands; i++) if (I.getArgOperand(i) != E.getArgOperand(i)) return false; @@ -682,11 +682,11 @@ removeTriviallyEmptyRange(IntrinsicInst &EndI, InstCombinerImpl &IC, BasicBlock::reverse_iterator BI(EndI), BE(EndI.getParent()->rend()); for (; BI != BE; ++BI) { if (auto *I = dyn_cast<IntrinsicInst>(&*BI)) { - if (isa<DbgInfoIntrinsic>(I) || + if (I->isDebugOrPseudoInst() || I->getIntrinsicID() == EndI.getIntrinsicID()) continue; if (IsStart(*I)) { - if (haveSameOperands(EndI, *I, EndI.getNumArgOperands())) { + if (haveSameOperands(EndI, *I, EndI.arg_size())) { IC.eraseInstFromFunction(*I); IC.eraseInstFromFunction(EndI); return true; @@ -710,7 +710,7 @@ Instruction *InstCombinerImpl::visitVAEndInst(VAEndInst &I) { } static CallInst *canonicalizeConstantArg0ToArg1(CallInst &Call) { - assert(Call.getNumArgOperands() > 1 && "Need at least 2 args to swap"); + assert(Call.arg_size() > 1 && "Need at least 2 args to swap"); Value *Arg0 = Call.getArgOperand(0), *Arg1 = Call.getArgOperand(1); if (isa<Constant>(Arg0) && !isa<Constant>(Arg1)) { Call.setArgOperand(0, Arg1); @@ -754,6 +754,45 @@ static Optional<bool> getKnownSign(Value *Op, Instruction *CxtI, ICmpInst::ICMP_SLT, Op, Constant::getNullValue(Op->getType()), CxtI, DL); } +/// Try to canonicalize min/max(X + C0, C1) as min/max(X, C1 - C0) + C0. This +/// can trigger other combines. +static Instruction *moveAddAfterMinMax(IntrinsicInst *II, + InstCombiner::BuilderTy &Builder) { + Intrinsic::ID MinMaxID = II->getIntrinsicID(); + assert((MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin || + MinMaxID == Intrinsic::umax || MinMaxID == Intrinsic::umin) && + "Expected a min or max intrinsic"); + + // TODO: Match vectors with undef elements, but undef may not propagate. + Value *Op0 = II->getArgOperand(0), *Op1 = II->getArgOperand(1); + Value *X; + const APInt *C0, *C1; + if (!match(Op0, m_OneUse(m_Add(m_Value(X), m_APInt(C0)))) || + !match(Op1, m_APInt(C1))) + return nullptr; + + // Check for necessary no-wrap and overflow constraints. + bool IsSigned = MinMaxID == Intrinsic::smax || MinMaxID == Intrinsic::smin; + auto *Add = cast<BinaryOperator>(Op0); + if ((IsSigned && !Add->hasNoSignedWrap()) || + (!IsSigned && !Add->hasNoUnsignedWrap())) + return nullptr; + + // If the constant difference overflows, then instsimplify should reduce the + // min/max to the add or C1. + bool Overflow; + APInt CDiff = + IsSigned ? C1->ssub_ov(*C0, Overflow) : C1->usub_ov(*C0, Overflow); + assert(!Overflow && "Expected simplify of min/max"); + + // min/max (add X, C0), C1 --> add (min/max X, C1 - C0), C0 + // Note: the "mismatched" no-overflow setting does not propagate. + Constant *NewMinMaxC = ConstantInt::get(II->getType(), CDiff); + Value *NewMinMax = Builder.CreateBinaryIntrinsic(MinMaxID, X, NewMinMaxC); + return IsSigned ? BinaryOperator::CreateNSWAdd(NewMinMax, Add->getOperand(1)) + : BinaryOperator::CreateNUWAdd(NewMinMax, Add->getOperand(1)); +} + /// If we have a clamp pattern like max (min X, 42), 41 -- where the output /// can only be one of two possible constant values -- turn that into a select /// of constants. 
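For the clamp described above, a minimal IR sketch (hypothetical names; the in-tree fold derives the compare from the original operands, so this shows one correct resulting form):

    declare i32 @llvm.umin.i32(i32, i32)
    declare i32 @llvm.umax.i32(i32, i32)
    define i32 @clamp_41_42(i32 %x) {
      %lo = call i32 @llvm.umin.i32(i32 %x, i32 42)
      %r = call i32 @llvm.umax.i32(i32 %lo, i32 41)   ; only 41 or 42 possible
      ret i32 %r
    }
    ; --> %cmp = icmp ult i32 %x, 42
    ;     %r = select i1 %cmp, i32 41, i32 42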
@@ -795,6 +834,63 @@ static Instruction *foldClampRangeOfTwo(IntrinsicInst *II,
   return SelectInst::Create(Cmp, ConstantInt::get(II->getType(), *C0), I1);
 }
+/// Reduce a sequence of min/max intrinsics with a common operand.
+static Instruction *factorizeMinMaxTree(IntrinsicInst *II) {
+  // Match 3 of the same min/max ops. Example: umin(umin(), umin()).
+  auto *LHS = dyn_cast<IntrinsicInst>(II->getArgOperand(0));
+  auto *RHS = dyn_cast<IntrinsicInst>(II->getArgOperand(1));
+  Intrinsic::ID MinMaxID = II->getIntrinsicID();
+  if (!LHS || !RHS || LHS->getIntrinsicID() != MinMaxID ||
+      RHS->getIntrinsicID() != MinMaxID ||
+      (!LHS->hasOneUse() && !RHS->hasOneUse()))
+    return nullptr;
+
+  Value *A = LHS->getArgOperand(0);
+  Value *B = LHS->getArgOperand(1);
+  Value *C = RHS->getArgOperand(0);
+  Value *D = RHS->getArgOperand(1);
+
+  // Look for a common operand.
+  Value *MinMaxOp = nullptr;
+  Value *ThirdOp = nullptr;
+  if (LHS->hasOneUse()) {
+    // If the LHS is only used in this chain and the RHS is used outside of it,
+    // reuse the RHS min/max because that will eliminate the LHS.
+    if (D == A || C == A) {
+      // min(min(a, b), min(c, a)) --> min(min(c, a), b)
+      // min(min(a, b), min(a, d)) --> min(min(a, d), b)
+      MinMaxOp = RHS;
+      ThirdOp = B;
+    } else if (D == B || C == B) {
+      // min(min(a, b), min(c, b)) --> min(min(c, b), a)
+      // min(min(a, b), min(b, d)) --> min(min(b, d), a)
+      MinMaxOp = RHS;
+      ThirdOp = A;
+    }
+  } else {
+    assert(RHS->hasOneUse() && "Expected one-use operand");
+    // Reuse the LHS. This will eliminate the RHS.
+    if (D == A || D == B) {
+      // min(min(a, b), min(c, a)) --> min(min(a, b), c)
+      // min(min(a, b), min(c, b)) --> min(min(a, b), c)
+      MinMaxOp = LHS;
+      ThirdOp = C;
+    } else if (C == A || C == B) {
+      // min(min(a, b), min(a, d)) --> min(min(a, b), d)
+      // min(min(a, b), min(b, d)) --> min(min(a, b), d)
+      MinMaxOp = LHS;
+      ThirdOp = D;
+    }
+  }
+
+  if (!MinMaxOp || !ThirdOp)
+    return nullptr;
+
+  Module *Mod = II->getModule();
+  Function *MinMax = Intrinsic::getDeclaration(Mod, MinMaxID, II->getType());
+  return CallInst::Create(MinMax, { MinMaxOp, ThirdOp });
+}
+
 /// CallInst simplification. This mostly only handles folding of intrinsic
 /// instructions. For normal calls, it allows visitCallBase to do the heavy
 /// lifting.
@@ -896,7 +992,7 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
   if (auto *IIFVTy = dyn_cast<FixedVectorType>(II->getType())) {
     auto VWidth = IIFVTy->getNumElements();
     APInt UndefElts(VWidth, 0);
-    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    APInt AllOnesEltMask(APInt::getAllOnes(VWidth));
     if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) {
       if (V != II)
         return replaceInstUsesWith(*II, V);
@@ -1007,21 +1103,45 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) {
       }
     }
-    if (match(I0, m_Not(m_Value(X)))) {
-      // max (not X), (not Y) --> not (min X, Y)
-      Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID);
-      if (match(I1, m_Not(m_Value(Y))) &&
+    if (IID == Intrinsic::smax || IID == Intrinsic::smin) {
+      // smax (neg nsw X), (neg nsw Y) --> neg nsw (smin X, Y)
+      // smin (neg nsw X), (neg nsw Y) --> neg nsw (smax X, Y)
+      // TODO: Canonicalize neg after min/max if I1 is constant.
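A sketch of this new signed case, matched by the code that follows (hypothetical names; both negations must carry nsw, as the matcher requires):

    declare i32 @llvm.smax.i32(i32, i32)
    declare i32 @llvm.smin.i32(i32, i32)
    define i32 @smax_of_negs(i32 %x, i32 %y) {
      %nx = sub nsw i32 0, %x
      %ny = sub nsw i32 0, %y
      %r = call i32 @llvm.smax.i32(i32 %nx, i32 %ny)
      ret i32 %r
    }
    ; --> %min = call i32 @llvm.smin.i32(i32 %x, i32 %y)
    ;     %r = sub nsw i32 0, %min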
+ if (match(I0, m_NSWNeg(m_Value(X))) && match(I1, m_NSWNeg(m_Value(Y))) && (I0->hasOneUse() || I1->hasOneUse())) { + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, Y); - return BinaryOperator::CreateNot(InvMaxMin); + return BinaryOperator::CreateNSWNeg(InvMaxMin); } - // max (not X), C --> not(min X, ~C) - if (match(I1, m_Constant(C)) && I0->hasOneUse()) { - Constant *NotC = ConstantExpr::getNot(C); - Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, X, NotC); + } + + // If we can eliminate ~A and Y is free to invert: + // max ~A, Y --> ~(min A, ~Y) + // + // Examples: + // max ~A, ~Y --> ~(min A, Y) + // max ~A, C --> ~(min A, ~C) + // max ~A, (max ~Y, ~Z) --> ~min( A, (min Y, Z)) + auto moveNotAfterMinMax = [&](Value *X, Value *Y) -> Instruction * { + Value *A; + if (match(X, m_OneUse(m_Not(m_Value(A)))) && + !isFreeToInvert(A, A->hasOneUse()) && + isFreeToInvert(Y, Y->hasOneUse())) { + Value *NotY = Builder.CreateNot(Y); + Intrinsic::ID InvID = getInverseMinMaxIntrinsic(IID); + Value *InvMaxMin = Builder.CreateBinaryIntrinsic(InvID, A, NotY); return BinaryOperator::CreateNot(InvMaxMin); } - } + return nullptr; + }; + + if (Instruction *I = moveNotAfterMinMax(I0, I1)) + return I; + if (Instruction *I = moveNotAfterMinMax(I1, I0)) + return I; + + if (Instruction *I = moveAddAfterMinMax(II, Builder)) + return I; // smax(X, -X) --> abs(X) // smin(X, -X) --> -abs(X) @@ -1051,11 +1171,17 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Instruction *Sel = foldClampRangeOfTwo(II, Builder)) return Sel; + if (Instruction *SAdd = matchSAddSubSat(*II)) + return SAdd; + if (match(I1, m_ImmConstant())) if (auto *Sel = dyn_cast<SelectInst>(I0)) if (Instruction *R = FoldOpIntoSelect(*II, Sel)) return R; + if (Instruction *NewMinMax = factorizeMinMaxTree(II)) + return NewMinMax; + break; } case Intrinsic::bswap: { @@ -1098,6 +1224,19 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { if (Power->equalsInt(2)) return BinaryOperator::CreateFMulFMF(II->getArgOperand(0), II->getArgOperand(0), II); + + if (!Power->getValue()[0]) { + Value *X; + // If power is even: + // powi(-x, p) -> powi(x, p) + // powi(fabs(x), p) -> powi(x, p) + // powi(copysign(x, y), p) -> powi(x, p) + if (match(II->getArgOperand(0), m_FNeg(m_Value(X))) || + match(II->getArgOperand(0), m_FAbs(m_Value(X))) || + match(II->getArgOperand(0), + m_Intrinsic<Intrinsic::copysign>(m_Value(X), m_Value()))) + return replaceOperand(*II, 0, X); + } } break; @@ -1637,14 +1776,66 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { break; } case Intrinsic::stackrestore: { - // If the save is right next to the restore, remove the restore. This can - // happen when variable allocas are DCE'd. 
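A minimal sketch of the case the generalized scan below accepts (hypothetical function name): nothing between the save and the restore allocates, has side effects, or restores a different save, so the restore is dead:

    declare i8* @llvm.stacksave()
    declare void @llvm.stackrestore(i8*)
    define void @dead_restore() {
      %ss = call i8* @llvm.stacksave()
      %v = add i32 1, 2                      ; classified as None: harmless
      call void @llvm.stackrestore(i8* %ss)  ; removable
      ret void
    }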
+ enum class ClassifyResult { + None, + Alloca, + StackRestore, + CallWithSideEffects, + }; + auto Classify = [](const Instruction *I) { + if (isa<AllocaInst>(I)) + return ClassifyResult::Alloca; + + if (auto *CI = dyn_cast<CallInst>(I)) { + if (auto *II = dyn_cast<IntrinsicInst>(CI)) { + if (II->getIntrinsicID() == Intrinsic::stackrestore) + return ClassifyResult::StackRestore; + + if (II->mayHaveSideEffects()) + return ClassifyResult::CallWithSideEffects; + } else { + // Consider all non-intrinsic calls to be side effects + return ClassifyResult::CallWithSideEffects; + } + } + + return ClassifyResult::None; + }; + + // If the stacksave and the stackrestore are in the same BB, and there is + // no intervening call, alloca, or stackrestore of a different stacksave, + // remove the restore. This can happen when variable allocas are DCE'd. if (IntrinsicInst *SS = dyn_cast<IntrinsicInst>(II->getArgOperand(0))) { - if (SS->getIntrinsicID() == Intrinsic::stacksave) { - // Skip over debug info. - if (SS->getNextNonDebugInstruction() == II) { - return eraseInstFromFunction(CI); + if (SS->getIntrinsicID() == Intrinsic::stacksave && + SS->getParent() == II->getParent()) { + BasicBlock::iterator BI(SS); + bool CannotRemove = false; + for (++BI; &*BI != II; ++BI) { + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. + break; + + case ClassifyResult::StackRestore: + // If we found an intervening stackrestore for a different + // stacksave, we can't remove the stackrestore. Otherwise, continue. + if (cast<IntrinsicInst>(*BI).getArgOperand(0) != SS) + CannotRemove = true; + break; + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic + // call with side effects, we can't remove the stackrestore. + CannotRemove = true; + break; + } + if (CannotRemove) + break; } + + if (!CannotRemove) + return eraseInstFromFunction(CI); } } @@ -1654,29 +1845,25 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { Instruction *TI = II->getParent()->getTerminator(); bool CannotRemove = false; for (++BI; &*BI != TI; ++BI) { - if (isa<AllocaInst>(BI)) { - CannotRemove = true; + switch (Classify(&*BI)) { + case ClassifyResult::None: + // So far so good, look at next instructions. break; - } - if (CallInst *BCI = dyn_cast<CallInst>(BI)) { - if (auto *II2 = dyn_cast<IntrinsicInst>(BCI)) { - // If there is a stackrestore below this one, remove this one. - if (II2->getIntrinsicID() == Intrinsic::stackrestore) - return eraseInstFromFunction(CI); - // Bail if we cross over an intrinsic with side effects, such as - // llvm.stacksave, or llvm.read_register. - if (II2->mayHaveSideEffects()) { - CannotRemove = true; - break; - } - } else { - // If we found a non-intrinsic call, we can't remove the stack - // restore. - CannotRemove = true; - break; - } + case ClassifyResult::StackRestore: + // If there is a stackrestore below this one, remove this one. + return eraseInstFromFunction(CI); + + case ClassifyResult::Alloca: + case ClassifyResult::CallWithSideEffects: + // If we found an alloca, a non-intrinsic call, or an intrinsic call + // with side effects (such as llvm.stacksave and llvm.read_register), + // we can't remove the stack restore. 
+ CannotRemove = true; + break; } + if (CannotRemove) + break; } // If the stack restore is in a return, resume, or unwind block and if there @@ -1963,6 +2150,46 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } break; } + case Intrinsic::experimental_vector_reverse: { + Value *BO0, *BO1, *X, *Y; + Value *Vec = II->getArgOperand(0); + if (match(Vec, m_OneUse(m_BinOp(m_Value(BO0), m_Value(BO1))))) { + auto *OldBinOp = cast<BinaryOperator>(Vec); + if (match(BO0, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))) { + // rev(binop rev(X), rev(Y)) --> binop X, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y)))) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, Y, OldBinOp, + OldBinOp->getName(), II)); + // rev(binop rev(X), BO1Splat) --> binop X, BO1Splat + if (isSplatValue(BO1)) + return replaceInstUsesWith(CI, + BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), X, BO1, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(binop BO0Splat, rev(Y)) --> binop BO0Splat, Y + if (match(BO1, m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(Y))) && + isSplatValue(BO0)) + return replaceInstUsesWith(CI, BinaryOperator::CreateWithCopiedFlags( + OldBinOp->getOpcode(), BO0, Y, + OldBinOp, OldBinOp->getName(), II)); + } + // rev(unop rev(X)) --> unop X + if (match(Vec, m_OneUse(m_UnOp( + m_Intrinsic<Intrinsic::experimental_vector_reverse>( + m_Value(X)))))) { + auto *OldUnOp = cast<UnaryOperator>(Vec); + auto *NewUnOp = UnaryOperator::CreateWithCopiedFlags( + OldUnOp->getOpcode(), X, OldUnOp, OldUnOp->getName(), II); + return replaceInstUsesWith(CI, NewUnOp); + } + break; + } case Intrinsic::vector_reduce_or: case Intrinsic::vector_reduce_and: { // Canonicalize logical or/and reductions: @@ -1973,21 +2200,26 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { // %val = bitcast <ReduxWidth x i1> to iReduxWidth // %res = cmp eq iReduxWidth %val, 11111 Value *Arg = II->getArgOperand(0); - Type *RetTy = II->getType(); - if (RetTy == Builder.getInt1Ty()) - if (auto *FVTy = dyn_cast<FixedVectorType>(Arg->getType())) { - Value *Res = Builder.CreateBitCast( - Arg, Builder.getIntNTy(FVTy->getNumElements())); - if (IID == Intrinsic::vector_reduce_and) { - Res = Builder.CreateICmpEQ( - Res, ConstantInt::getAllOnesValue(Res->getType())); - } else { - assert(IID == Intrinsic::vector_reduce_or && - "Expected or reduction."); - Res = Builder.CreateIsNotNull(Res); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateBitCast( + Vect, Builder.getIntNTy(FTy->getNumElements())); + if (IID == Intrinsic::vector_reduce_and) { + Res = Builder.CreateICmpEQ( + Res, ConstantInt::getAllOnesValue(Res->getType())); + } else { + assert(IID == Intrinsic::vector_reduce_or && + "Expected or reduction."); + Res = Builder.CreateIsNotNull(Res); + } + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); } - return replaceInstUsesWith(CI, Res); - } + } LLVM_FALLTHROUGH; } case Intrinsic::vector_reduce_add: { @@ -2017,12 +2249,117 @@ Instruction *InstCombinerImpl::visitCallInst(CallInst &CI) { } LLVM_FALLTHROUGH; } - case Intrinsic::vector_reduce_mul: - case Intrinsic::vector_reduce_xor: - case Intrinsic::vector_reduce_umax: + case 
Intrinsic::vector_reduce_xor: { + if (IID == Intrinsic::vector_reduce_xor) { + // Exclusive disjunction reduction over the vector with + // (potentially-extended) i1 element type is actually a + // (potentially-extended) arithmetic `add` reduction over the original + // non-extended value: + // vector_reduce_xor(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_add(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAddReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } + case Intrinsic::vector_reduce_mul: { + if (IID == Intrinsic::vector_reduce_mul) { + // Multiplicative reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially zero-extended) + // logical `and` reduction over the original non-extended value: + // vector_reduce_mul(?ext(<n x i1>)) + // --> + // zext(vector_reduce_and(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = Builder.CreateAndReduce(Vect); + if (Res->getType() != II->getType()) + Res = Builder.CreateZExt(Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_umin: - case Intrinsic::vector_reduce_smax: + case Intrinsic::vector_reduce_umax: { + if (IID == Intrinsic::vector_reduce_umin || + IID == Intrinsic::vector_reduce_umax) { + // UMin/UMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_u{min,max}(?ext(<n x i1>)) + // --> + // ?ext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Value *Res = IID == Intrinsic::vector_reduce_umin + ? 
Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(cast<CastInst>(Arg)->getOpcode(), Res, + II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_smin: + case Intrinsic::vector_reduce_smax: { + if (IID == Intrinsic::vector_reduce_smin || + IID == Intrinsic::vector_reduce_smax) { + // SMin/SMax reduction over the vector with (potentially-extended) + // i1 element type is actually a (potentially-extended) + // logical `and`/`or` reduction over the original non-extended value: + // vector_reduce_s{min,max}(<n x i1>) + // --> + // vector_reduce_{or,and}(<n x i1>) + // and + // vector_reduce_s{min,max}(sext(<n x i1>)) + // --> + // sext(vector_reduce_{or,and}(<n x i1>)) + // and + // vector_reduce_s{min,max}(zext(<n x i1>)) + // --> + // zext(vector_reduce_{and,or}(<n x i1>)) + Value *Arg = II->getArgOperand(0); + Value *Vect; + if (match(Arg, m_ZExtOrSExtOrSelf(m_Value(Vect)))) { + if (auto *FTy = dyn_cast<FixedVectorType>(Vect->getType())) + if (FTy->getElementType() == Builder.getInt1Ty()) { + Instruction::CastOps ExtOpc = Instruction::CastOps::CastOpsEnd; + if (Arg != Vect) + ExtOpc = cast<CastInst>(Arg)->getOpcode(); + Value *Res = ((IID == Intrinsic::vector_reduce_smin) == + (ExtOpc == Instruction::CastOps::ZExt)) + ? Builder.CreateAndReduce(Vect) + : Builder.CreateOrReduce(Vect); + if (Arg != Vect) + Res = Builder.CreateCast(ExtOpc, Res, II->getType()); + return replaceInstUsesWith(CI, Res); + } + } + } + LLVM_FALLTHROUGH; + } case Intrinsic::vector_reduce_fmax: case Intrinsic::vector_reduce_fmin: case Intrinsic::vector_reduce_fadd: @@ -2228,7 +2565,7 @@ static IntrinsicInst *findInitTrampoline(Value *Callee) { } void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI) { - unsigned NumArgs = Call.getNumArgOperands(); + unsigned NumArgs = Call.arg_size(); ConstantInt *Op0C = dyn_cast<ConstantInt>(Call.getOperand(0)); ConstantInt *Op1C = (NumArgs == 1) ? nullptr : dyn_cast<ConstantInt>(Call.getOperand(1)); @@ -2239,55 +2576,46 @@ void InstCombinerImpl::annotateAnyAllocSite(CallBase &Call, const TargetLibraryI if (isMallocLikeFn(&Call, TLI) && Op0C) { if (isOpNewLikeFn(&Call, TLI)) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableBytes( - Call.getContext(), Op0C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableBytes( + Call.getContext(), Op0C->getZExtValue())); else - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op0C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op0C->getZExtValue())); } else if (isAlignedAllocLikeFn(&Call, TLI)) { if (Op1C) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op1C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); // Add alignment attribute if alignment is a power of two constant. 
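Sketch of the resulting annotation on an aligned_alloc-like call (hypothetical; assumes the usual C signature aligned_alloc(alignment, size)): a constant power-of-two alignment and a known-nonzero size let the fold attach align and dereferenceable_or_null return attributes:

    declare i8* @aligned_alloc(i64, i64)
    define i8* @buf() {
      %p = call i8* @aligned_alloc(i64 64, i64 256)
      ret i8* %p
    }
    ; --> %p = call align 64 dereferenceable_or_null(256) i8*
    ;          @aligned_alloc(i64 64, i64 256)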
if (Op0C && Op0C->getValue().ult(llvm::Value::MaximumAlignment) && isKnownNonZero(Call.getOperand(1), DL, 0, &AC, &Call, &DT)) { uint64_t AlignmentVal = Op0C->getZExtValue(); if (llvm::isPowerOf2_64(AlignmentVal)) { - Call.removeAttribute(AttributeList::ReturnIndex, Attribute::Alignment); - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithAlignment(Call.getContext(), - Align(AlignmentVal))); + Call.removeRetAttr(Attribute::Alignment); + Call.addRetAttr(Attribute::getWithAlignment(Call.getContext(), + Align(AlignmentVal))); } } } else if (isReallocLikeFn(&Call, TLI) && Op1C) { - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Op1C->getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Op1C->getZExtValue())); } else if (isCallocLikeFn(&Call, TLI) && Op0C && Op1C) { bool Overflow; const APInt &N = Op0C->getValue(); APInt Size = N.umul_ov(Op1C->getValue(), Overflow); if (!Overflow) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Size.getZExtValue())); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Size.getZExtValue())); } else if (isStrdupLikeFn(&Call, TLI)) { uint64_t Len = GetStringLength(Call.getOperand(0)); if (Len) { // strdup if (NumArgs == 1) - Call.addAttribute(AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), Len)); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), Len)); // strndup else if (NumArgs == 2 && Op1C) - Call.addAttribute( - AttributeList::ReturnIndex, - Attribute::getWithDereferenceableOrNullBytes( - Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1))); + Call.addRetAttr(Attribute::getWithDereferenceableOrNullBytes( + Call.getContext(), std::min(Len, Op1C->getZExtValue() + 1))); } } } @@ -2489,7 +2817,7 @@ Instruction *InstCombinerImpl::visitCallBase(CallBase &Call) { // isKnownNonNull -> nonnull attribute if (!GCR.hasRetAttr(Attribute::NonNull) && isKnownNonZero(DerivedPtr, DL, 0, &AC, &Call, &DT)) { - GCR.addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); + GCR.addRetAttr(Attribute::NonNull); // We discovered new fact, re-check users. Worklist.pushUsersToWorkList(GCR); } @@ -2646,19 +2974,19 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { if (!CastInst::isBitOrNoopPointerCastable(ActTy, ParamTy, DL)) return false; // Cannot transform this parameter value. - if (AttrBuilder(CallerPAL.getParamAttributes(i)) + if (AttrBuilder(CallerPAL.getParamAttrs(i)) .overlaps(AttributeFuncs::typeIncompatible(ParamTy))) return false; // Attribute not compatible with transformed value. if (Call.isInAllocaArgument(i)) return false; // Cannot transform to and from inalloca. - if (CallerPAL.hasParamAttribute(i, Attribute::SwiftError)) + if (CallerPAL.hasParamAttr(i, Attribute::SwiftError)) return false; // If the parameter is passed as a byval argument, then we have to have a // sized type and the sized type has to have the same size as the old type. 
- if (ParamTy != ActTy && CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { + if (ParamTy != ActTy && CallerPAL.hasParamAttr(i, Attribute::ByVal)) { PointerType *ParamPTy = dyn_cast<PointerType>(ParamTy); if (!ParamPTy || !ParamPTy->getElementType()->isSized()) return false; @@ -2699,7 +3027,7 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { // that are compatible with being a vararg call argument. unsigned SRetIdx; if (CallerPAL.hasAttrSomewhere(Attribute::StructRet, &SRetIdx) && - SRetIdx > FT->getNumParams()) + SRetIdx - AttributeList::FirstArgIndex >= FT->getNumParams()) return false; } @@ -2728,12 +3056,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Args.push_back(NewArg); // Add any parameter attributes. - if (CallerPAL.hasParamAttribute(i, Attribute::ByVal)) { - AttrBuilder AB(CallerPAL.getParamAttributes(i)); + if (CallerPAL.hasParamAttr(i, Attribute::ByVal)) { + AttrBuilder AB(CallerPAL.getParamAttrs(i)); AB.addByValAttr(NewArg->getType()->getPointerElementType()); ArgAttrs.push_back(AttributeSet::get(Ctx, AB)); } else - ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); + ArgAttrs.push_back(CallerPAL.getParamAttrs(i)); } // If the function takes more arguments than the call was taking, add them @@ -2760,12 +3088,12 @@ bool InstCombinerImpl::transformConstExprCastCall(CallBase &Call) { Args.push_back(NewArg); // Add any parameter attributes. - ArgAttrs.push_back(CallerPAL.getParamAttributes(i)); + ArgAttrs.push_back(CallerPAL.getParamAttrs(i)); } } } - AttributeSet FnAttrs = CallerPAL.getFnAttributes(); + AttributeSet FnAttrs = CallerPAL.getFnAttrs(); if (NewRetTy->isVoidTy()) Caller->setName(""); // Void type should not have a name. @@ -2866,7 +3194,7 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, for (FunctionType::param_iterator I = NestFTy->param_begin(), E = NestFTy->param_end(); I != E; ++NestArgNo, ++I) { - AttributeSet AS = NestAttrs.getParamAttributes(NestArgNo); + AttributeSet AS = NestAttrs.getParamAttrs(NestArgNo); if (AS.hasAttribute(Attribute::Nest)) { // Record the parameter type and any other attributes. NestTy = *I; @@ -2902,7 +3230,7 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, // Add the original argument and attributes. 
NewArgs.push_back(*I); - NewArgAttrs.push_back(Attrs.getParamAttributes(ArgNo)); + NewArgAttrs.push_back(Attrs.getParamAttrs(ArgNo)); ++ArgNo; ++I; @@ -2948,8 +3276,8 @@ InstCombinerImpl::transformCallThroughTrampoline(CallBase &Call, NestF : ConstantExpr::getBitCast(NestF, PointerType::getUnqual(NewFTy)); AttributeList NewPAL = - AttributeList::get(FTy->getContext(), Attrs.getFnAttributes(), - Attrs.getRetAttributes(), NewArgAttrs); + AttributeList::get(FTy->getContext(), Attrs.getFnAttrs(), + Attrs.getRetAttrs(), NewArgAttrs); SmallVector<OperandBundleDef, 1> OpBundles; Call.getOperandBundlesAsDefs(OpBundles); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp index 04877bec94ec..ca87477c5d81 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -333,7 +333,7 @@ Instruction *InstCombinerImpl::commonCastTransforms(CastInst &CI) { SrcTy->getNumElements() == DestTy->getNumElements() && SrcTy->getPrimitiveSizeInBits() == DestTy->getPrimitiveSizeInBits()) { Value *CastX = Builder.CreateCast(CI.getOpcode(), X, DestTy); - return new ShuffleVectorInst(CastX, UndefValue::get(DestTy), Mask); + return new ShuffleVectorInst(CastX, Mask); } } @@ -701,10 +701,10 @@ static Instruction *shrinkSplatShuffle(TruncInst &Trunc, if (Shuf && Shuf->hasOneUse() && match(Shuf->getOperand(1), m_Undef()) && is_splat(Shuf->getShuffleMask()) && Shuf->getType() == Shuf->getOperand(0)->getType()) { - // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Undef, SplatMask - Constant *NarrowUndef = UndefValue::get(Trunc.getType()); + // trunc (shuf X, Undef, SplatMask) --> shuf (trunc X), Poison, SplatMask + // trunc (shuf X, Poison, SplatMask) --> shuf (trunc X), Poison, SplatMask Value *NarrowOp = Builder.CreateTrunc(Shuf->getOperand(0), Trunc.getType()); - return new ShuffleVectorInst(NarrowOp, NarrowUndef, Shuf->getShuffleMask()); + return new ShuffleVectorInst(NarrowOp, Shuf->getShuffleMask()); } return nullptr; @@ -961,14 +961,25 @@ Instruction *InstCombinerImpl::visitTrunc(TruncInst &Trunc) { return BinaryOperator::CreateAdd(NarrowCtlz, WidthDiff); } } + + if (match(Src, m_VScale(DL))) { + if (Trunc.getFunction() && + Trunc.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = Trunc.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + if (MaxVScale > 0 && Log2_32(MaxVScale) < DestWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(Trunc, VScale); + } + } + } + return nullptr; } -/// Transform (zext icmp) to bitwise / integer operations in order to -/// eliminate it. If DoTransform is false, just test whether the given -/// (zext icmp) can be transformed. -Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, - bool DoTransform) { +Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext) { // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. @@ -977,10 +988,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. 
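The first of those sign-bit cases as a concrete sketch (hypothetical names):

    define i32 @signbit(i32 %x) {
      %c = icmp slt i32 %x, 0
      %r = zext i1 %c to i32
      ret i32 %r
    }
    ; --> %r = lshr i32 %x, 31      ; sign bit moved to bit 0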
- if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || - (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { - if (!DoTransform) return Cmp; - + if ((Cmp->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isZero()) || + (Cmp->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnes())) { Value *In = Cmp->getOperand(0); Value *Sh = ConstantInt::get(In->getType(), In->getType()->getScalarSizeInBits() - 1); @@ -1004,7 +1013,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && + if ((Op1CV->isZero() || Op1CV->isPowerOf2()) && // This only works for EQ and NE Cmp->isEquality()) { // If Op1C some other power of two, convert: @@ -1012,10 +1021,8 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, APInt KnownZeroMask(~Known.Zero); if (KnownZeroMask.isPowerOf2()) { // Exactly 1 possible 1? - if (!DoTransform) return Cmp; - bool isNE = Cmp->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { + if (!Op1CV->isZero() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true Constant *Res = ConstantInt::get(Zext.getType(), isNE); @@ -1031,7 +1038,7 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, In->getName() + ".lobit"); } - if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit. + if (!Op1CV->isZero() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One); } @@ -1053,9 +1060,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, if (Cmp->hasOneUse() && match(Cmp->getOperand(1), m_ZeroInt()) && match(Cmp->getOperand(0), m_OneUse(m_c_And(m_Shl(m_One(), m_Value(ShAmt)), m_Value(X))))) { - if (!DoTransform) - return Cmp; - if (Cmp->getPredicate() == ICmpInst::ICMP_EQ) X = Builder.CreateNot(X); Value *Lshr = Builder.CreateLShr(X, ShAmt); @@ -1077,8 +1081,6 @@ Instruction *InstCombinerImpl::transformZExtICmp(ICmpInst *Cmp, ZExtInst &Zext, APInt KnownBits = KnownLHS.Zero | KnownLHS.One; APInt UnknownBit = ~KnownBits; if (UnknownBit.countPopulation() == 1) { - if (!DoTransform) return Cmp; - Value *Result = Builder.CreateXor(LHS, RHS); // Mask off any bits that are set and won't be shifted away. @@ -1316,51 +1318,37 @@ Instruction *InstCombinerImpl::visitZExt(ZExtInst &CI) { if (ICmpInst *Cmp = dyn_cast<ICmpInst>(Src)) return transformZExtICmp(Cmp, CI); - BinaryOperator *SrcI = dyn_cast<BinaryOperator>(Src); - if (SrcI && SrcI->getOpcode() == Instruction::Or) { - // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) if at least one - // of the (zext icmp) can be eliminated. If so, immediately perform the - // according elimination. 
- ICmpInst *LHS = dyn_cast<ICmpInst>(SrcI->getOperand(0)); - ICmpInst *RHS = dyn_cast<ICmpInst>(SrcI->getOperand(1)); - if (LHS && RHS && LHS->hasOneUse() && RHS->hasOneUse() && - LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType() && - (transformZExtICmp(LHS, CI, false) || - transformZExtICmp(RHS, CI, false))) { - // zext (or icmp, icmp) -> or (zext icmp), (zext icmp) - Value *LCast = Builder.CreateZExt(LHS, CI.getType(), LHS->getName()); - Value *RCast = Builder.CreateZExt(RHS, CI.getType(), RHS->getName()); - Value *Or = Builder.CreateOr(LCast, RCast, CI.getName()); - if (auto *OrInst = dyn_cast<Instruction>(Or)) - Builder.SetInsertPoint(OrInst); - - // Perform the elimination. - if (auto *LZExt = dyn_cast<ZExtInst>(LCast)) - transformZExtICmp(LHS, *LZExt); - if (auto *RZExt = dyn_cast<ZExtInst>(RCast)) - transformZExtICmp(RHS, *RZExt); - - return replaceInstUsesWith(CI, Or); - } - } - // zext(trunc(X) & C) -> (X & zext(C)). Constant *C; Value *X; - if (SrcI && - match(SrcI, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && + if (match(Src, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Constant(C)))) && X->getType() == CI.getType()) return BinaryOperator::CreateAnd(X, ConstantExpr::getZExt(C, CI.getType())); // zext((trunc(X) & C) ^ C) -> ((X & zext(C)) ^ zext(C)). Value *And; - if (SrcI && match(SrcI, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && + if (match(Src, m_OneUse(m_Xor(m_Value(And), m_Constant(C)))) && match(And, m_OneUse(m_And(m_Trunc(m_Value(X)), m_Specific(C)))) && X->getType() == CI.getType()) { Constant *ZC = ConstantExpr::getZExt(C, CI.getType()); return BinaryOperator::CreateXor(Builder.CreateAnd(X, ZC), ZC); } + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = CI.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + unsigned TypeWidth = Src->getType()->getScalarSizeInBits(); + if (MaxVScale > 0 && Log2_32(MaxVScale) < TypeWidth) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + return nullptr; } @@ -1605,6 +1593,32 @@ Instruction *InstCombinerImpl::visitSExt(SExtInst &CI) { return BinaryOperator::CreateAShr(A, NewShAmt); } + // Splatting a bit of constant-index across a value: + // sext (ashr (trunc iN X to iM), M-1) to iN --> ashr (shl X, N-M), N-1 + // TODO: If the dest type is different, use a cast (adjust use check). 
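With iN = i32 and iM = i8, the splat-a-bit fold above reads as follows (hypothetical names):

    define i32 @splat_bit7(i32 %x) {
      %t = trunc i32 %x to i8
      %a = ashr i8 %t, 7            ; M-1 = 7
      %r = sext i8 %a to i32
      ret i32 %r
    }
    ; --> %shl = shl i32 %x, 24     ; N-M = 24
    ;     %r = ashr i32 %shl, 31    ; N-1 = 31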
+ if (match(Src, m_OneUse(m_AShr(m_Trunc(m_Value(X)), + m_SpecificInt(SrcBitSize - 1)))) && + X->getType() == DestTy) { + Constant *ShlAmtC = ConstantInt::get(DestTy, DestBitSize - SrcBitSize); + Constant *AshrAmtC = ConstantInt::get(DestTy, DestBitSize - 1); + Value *Shl = Builder.CreateShl(X, ShlAmtC); + return BinaryOperator::CreateAShr(Shl, AshrAmtC); + } + + if (match(Src, m_VScale(DL))) { + if (CI.getFunction() && + CI.getFunction()->hasFnAttribute(Attribute::VScaleRange)) { + unsigned MaxVScale = CI.getFunction() + ->getFnAttribute(Attribute::VScaleRange) + .getVScaleRangeArgs() + .second; + if (MaxVScale > 0 && Log2_32(MaxVScale) < (SrcBitSize - 1)) { + Value *VScale = Builder.CreateVScale(ConstantInt::get(DestTy, 1)); + return replaceInstUsesWith(CI, VScale); + } + } + } + return nullptr; } @@ -2060,6 +2074,19 @@ Instruction *InstCombinerImpl::visitPtrToInt(PtrToIntInst &CI) { return CastInst::CreateIntegerCast(P, Ty, /*isSigned=*/false); } + if (auto *GEP = dyn_cast<GetElementPtrInst>(SrcOp)) { + // Fold ptrtoint(gep null, x) to multiply + constant if the GEP has one use. + // While this can increase the number of instructions it doesn't actually + // increase the overall complexity since the arithmetic is just part of + // the GEP otherwise. + if (GEP->hasOneUse() && + isa<ConstantPointerNull>(GEP->getPointerOperand())) { + return replaceInstUsesWith(CI, + Builder.CreateIntCast(EmitGEPOffset(GEP), Ty, + /*isSigned=*/false)); + } + } + Value *Vec, *Scalar, *Index; if (match(SrcOp, m_OneUse(m_InsertElt(m_IntToPtr(m_Value(Vec)), m_Value(Scalar), m_Value(Index)))) && @@ -2133,9 +2160,9 @@ optimizeVectorResizeWithIntegerBitCasts(Value *InVal, VectorType *DestTy, if (SrcElts > DestElts) { // If we're shrinking the number of elements (rewriting an integer // truncate), just shuffle in the elements corresponding to the least - // significant bits from the input and use undef as the second shuffle + // significant bits from the input and use poison as the second shuffle // input. - V2 = UndefValue::get(SrcTy); + V2 = PoisonValue::get(SrcTy); // Make sure the shuffle mask selects the "least significant bits" by // keeping elements from back of the src vector for big endian, and from the // front for little endian. @@ -2528,7 +2555,7 @@ Instruction *InstCombinerImpl::optimizeBitCastFromPhi(CastInst &CI, // As long as the user is another old PHI node, then even if we don't // rewrite it, the PHI web we're considering won't have any users // outside itself, so it'll be dead. - if (OldPhiNodes.count(PHI) == 0) + if (!OldPhiNodes.contains(PHI)) return nullptr; } else { return nullptr; @@ -2736,6 +2763,30 @@ Instruction *InstCombinerImpl::visitBitCast(BitCastInst &CI) { if (auto *InsElt = dyn_cast<InsertElementInst>(Src)) return new BitCastInst(InsElt->getOperand(1), DestTy); } + + // Convert an artificial vector insert into more analyzable bitwise logic. + unsigned BitWidth = DestTy->getScalarSizeInBits(); + Value *X, *Y; + uint64_t IndexC; + if (match(Src, m_OneUse(m_InsertElt(m_OneUse(m_BitCast(m_Value(X))), + m_Value(Y), m_ConstantInt(IndexC)))) && + DestTy->isIntegerTy() && X->getType() == DestTy && + isDesirableIntType(BitWidth)) { + // Adjust for big endian - the LSBs are at the high index. + if (DL.isBigEndian()) + IndexC = SrcVTy->getNumElements() - 1 - IndexC; + + // We only handle (endian-normalized) insert to index 0. Any other insert + // would require a left-shift, so that is an extra instruction. 
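A little-endian sketch of the index-0 case handled below (hypothetical names; assumes i32 is a desirable integer type here, and -65536 is the high-bits mask 0xFFFF0000):

    define i32 @insert_low_i16(i32 %x, i16 %y) {
      %v = bitcast i32 %x to <2 x i16>
      %i = insertelement <2 x i16> %v, i16 %y, i64 0
      %r = bitcast <2 x i16> %i to i32
      ret i32 %r
    }
    ; --> %hi = and i32 %x, -65536
    ;     %zy = zext i16 %y to i32
    ;     %r = or i32 %hi, %zy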
+ if (IndexC == 0) { + // bitcast (inselt (bitcast X), Y, 0) --> or (and X, MaskC), (zext Y) + unsigned EltWidth = Y->getType()->getScalarSizeInBits(); + APInt MaskC = APInt::getHighBitsSet(BitWidth, BitWidth - EltWidth); + Value *AndX = Builder.CreateAnd(X, MaskC); + Value *ZextY = Builder.CreateZExt(Y, DestTy); + return BinaryOperator::CreateOr(AndX, ZextY); + } + } } if (auto *Shuf = dyn_cast<ShuffleVectorInst>(Src)) { diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index 2b0ef0c5f2cc..7a9e177f19da 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -78,15 +78,15 @@ static bool isSignTest(ICmpInst::Predicate &Pred, const APInt &C) { if (!ICmpInst::isSigned(Pred)) return false; - if (C.isNullValue()) + if (C.isZero()) return ICmpInst::isRelational(Pred); - if (C.isOneValue()) { + if (C.isOne()) { if (Pred == ICmpInst::ICMP_SLT) { Pred = ICmpInst::ICMP_SLE; return true; } - } else if (C.isAllOnesValue()) { + } else if (C.isAllOnes()) { if (Pred == ICmpInst::ICMP_SGT) { Pred = ICmpInst::ICMP_SGE; return true; @@ -541,7 +541,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, if (!CI->isNoopCast(DL)) return false; - if (Explored.count(CI->getOperand(0)) == 0) + if (!Explored.contains(CI->getOperand(0))) WorkList.push_back(CI->getOperand(0)); } @@ -553,7 +553,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, GEP->getType() != Start->getType()) return false; - if (Explored.count(GEP->getOperand(0)) == 0) + if (!Explored.contains(GEP->getOperand(0))) WorkList.push_back(GEP->getOperand(0)); } @@ -575,7 +575,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, // Explore the PHI nodes further. for (auto *PN : PHIs) for (Value *Op : PN->incoming_values()) - if (Explored.count(Op) == 0) + if (!Explored.contains(Op)) WorkList.push_back(Op); } @@ -589,7 +589,7 @@ static bool canRewriteGEPAsOffset(Value *Start, Value *Base, auto *Inst = dyn_cast<Instruction>(Val); if (Inst == Base || Inst == PHI || !Inst || !PHI || - Explored.count(PHI) == 0) + !Explored.contains(PHI)) continue; if (PHI->getParent() == Inst->getParent()) @@ -1147,12 +1147,12 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. - if (AP2.isNullValue()) + if (AP2.isZero()) return nullptr; bool IsAShr = isa<AShrOperator>(I.getOperand(0)); if (IsAShr) { - if (AP2.isAllOnesValue()) + if (AP2.isAllOnes()) return nullptr; if (AP2.isNegative() != AP1.isNegative()) return nullptr; @@ -1178,7 +1178,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstConst(ICmpInst &I, Value *A, if (IsAShr && AP1 == AP2.ashr(Shift)) { // There are multiple solutions if we are comparing against -1 and the LHS // of the ashr is not a power of two. - if (AP1.isAllOnesValue() && !AP2.isPowerOf2()) + if (AP1.isAllOnes() && !AP2.isPowerOf2()) return getICmp(I.ICMP_UGE, A, ConstantInt::get(A->getType(), Shift)); return getICmp(I.ICMP_EQ, A, ConstantInt::get(A->getType(), Shift)); } else if (AP1 == AP2.lshr(Shift)) { @@ -1206,7 +1206,7 @@ Instruction *InstCombinerImpl::foldICmpShlConstConst(ICmpInst &I, Value *A, }; // Don't bother doing any work for cases which InstSimplify handles. 
- if (AP2.isNullValue()) + if (AP2.isZero()) return nullptr; unsigned AP2TrailingZeros = AP2.countTrailingZeros(); @@ -1270,9 +1270,8 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, // This is only really a signed overflow check if the inputs have been // sign-extended; check for that condition. For example, if CI2 is 2^31 and // the operands of the add are 64 bits wide, we need at least 33 sign bits. - unsigned NeededSignBits = CI1->getBitWidth() - NewWidth + 1; - if (IC.ComputeNumSignBits(A, 0, &I) < NeededSignBits || - IC.ComputeNumSignBits(B, 0, &I) < NeededSignBits) + if (IC.ComputeMinSignedBits(A, 0, &I) > NewWidth || + IC.ComputeMinSignedBits(B, 0, &I) > NewWidth) return nullptr; // In order to replace the original add with a narrower @@ -1544,7 +1543,7 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (C.isOneValue() && C.getBitWidth() > 1) { + if (C.isOne() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1725,7 +1724,7 @@ Instruction *InstCombinerImpl::foldICmpAndShift(ICmpInst &Cmp, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. - if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() && + if (Shift->hasOneUse() && C1.isZero() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = @@ -1749,7 +1748,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, // For vectors: icmp ne (and X, 1), 0 --> trunc X to N x i1 // TODO: We canonicalize to the longer form for scalars because we have // better analysis/folds for icmp, and codegen may be better with icmp. - if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isNullValue() && + if (isICMP_NE && Cmp.getType()->isVectorTy() && C1.isZero() && match(And->getOperand(1), m_One())) return new TruncInst(And->getOperand(0), Cmp.getType()); @@ -1762,7 +1761,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, if (!And->hasOneUse()) return nullptr; - if (Cmp.isEquality() && C1.isNullValue()) { + if (Cmp.isEquality() && C1.isZero()) { // Restrict this fold to single-use 'and' (PR10267). // Replace (and X, (1 << size(X)-1) != 0) with X s< 0 if (C2->isSignMask()) { @@ -1812,7 +1811,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstConst(ICmpInst &Cmp, // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() && + if (!Cmp.isSigned() && C1.isZero() && And->getOperand(0)->hasOneUse() && match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); @@ -1889,7 +1888,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { + if (Cmp.getOperand(1) == Y && C.isNegatedPowerOf2()) { auto NewPred = Pred == CmpInst::ICMP_EQ ? 
CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); @@ -1899,7 +1898,7 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C.isZero() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); @@ -1920,7 +1919,7 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (C.isOneValue()) { + if (C.isOne()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1950,7 +1949,18 @@ Instruction *InstCombinerImpl::foldICmpOrConstant(ICmpInst &Cmp, } } - if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) + // (X | (X-1)) s< 0 --> X s< 1 + // (X | (X-1)) s> -1 --> X s> 0 + Value *X; + bool TrueIfSigned; + if (isSignBitCheck(Pred, C, TrueIfSigned) && + match(Or, m_c_Or(m_Add(m_Value(X), m_AllOnes()), m_Deferred(X)))) { + auto NewPred = TrueIfSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_SGT; + Constant *NewC = ConstantInt::get(X->getType(), TrueIfSigned ? 1 : 0); + return new ICmpInst(NewPred, X, NewC); + } + + if (!Cmp.isEquality() || !C.isZero() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -2001,14 +2011,14 @@ Instruction *InstCombinerImpl::foldICmpMulConstant(ICmpInst &Cmp, // If the multiply does not wrap, try to divide the compare constant by the // multiplication factor. 
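// For illustration, with made-up constants: (mul nsw X, 5) == 35 becomes
// X == 7 because 35 srem 5 == 0; a constant like 33 fails the remainder
// check below, and the fold is skipped.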
- if (Cmp.isEquality() && !MulC->isNullValue()) { + if (Cmp.isEquality() && !MulC->isZero()) { // (mul nsw X, MulC) == C --> X == C /s MulC - if (Mul->hasNoSignedWrap() && C.srem(*MulC).isNullValue()) { + if (Mul->hasNoSignedWrap() && C.srem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(Mul->getType(), C.sdiv(*MulC)); return new ICmpInst(Pred, Mul->getOperand(0), NewC); } // (mul nuw X, MulC) == C --> X == C /u MulC - if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isNullValue()) { + if (Mul->hasNoUnsignedWrap() && C.urem(*MulC).isZero()) { Constant *NewC = ConstantInt::get(Mul->getType(), C.udiv(*MulC)); return new ICmpInst(Pred, Mul->getOperand(0), NewC); } @@ -2053,7 +2063,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C.isAllOnesValue()) { + if (C.isAllOnes()) { // (1 << Y) <= -1 -> Y == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); @@ -2227,8 +2237,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); - if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && - C.isNullValue()) + if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && C.isZero()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; @@ -2316,7 +2325,7 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, if (Shr->isExact()) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - if (C.isNullValue()) { + if (C.isZero()) { // == 0 is u< 1. if (Pred == CmpInst::ICMP_EQ) return new ICmpInst(CmpInst::ICMP_ULT, X, @@ -2355,7 +2364,7 @@ Instruction *InstCombinerImpl::foldICmpSRemConstant(ICmpInst &Cmp, return nullptr; const APInt *DivisorC; - if (!C.isNullValue() || !match(SRem->getOperand(1), m_Power2(DivisorC))) + if (!C.isZero() || !match(SRem->getOperand(1), m_Power2(DivisorC))) return nullptr; // Mask off the sign bit and the modulo bits (low-bits). @@ -2435,8 +2444,7 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, // INT_MIN will also fail if the divisor is 1. Although folds of all these // division-by-constant cases should be present, we can not assert that they // have happened before we reach this icmp instruction. - if (C2->isNullValue() || C2->isOneValue() || - (DivIsSigned && C2->isAllOnesValue())) + if (C2->isZero() || C2->isOne() || (DivIsSigned && C2->isAllOnes())) return nullptr; // Compute Prod = C * C2. We are essentially solving an equation of @@ -2476,16 +2484,16 @@ Instruction *InstCombinerImpl::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (C.isNullValue()) { // (X / pos) op 0 + if (C.isZero()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) LoBound = -(RangeSize - 1); HiBound = RangeSize; - } else if (C.isStrictlyPositive()) { // (X / pos) op pos + } else if (C.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); - } else { // (X / pos) op neg + } else { // (X / pos) op neg // e.g. 
X/5 op -3 --> [-15-4, -15+1) --> [-19, -14)
HiBound = Prod + 1;
LoOverflow = HiOverflow = ProdOV ? -1 : 0;
@@ -2497,7 +2505,7 @@
} else if (C2->isNegative()) { // Divisor is < 0.
if (Div->isExact())
RangeSize.negate();
- if (C.isNullValue()) { // (X / neg) op 0
+ if (C.isZero()) { // (X / neg) op 0
// e.g. X/-5 op 0 --> [-4, 5)
LoBound = RangeSize + 1;
HiBound = -RangeSize;
@@ -2505,13 +2513,13 @@
HiOverflow = 1; // [INTMIN+1, overflow)
HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN
}
- } else if (C.isStrictlyPositive()) { // (X / neg) op pos
+ } else if (C.isStrictlyPositive()) { // (X / neg) op pos
// e.g. X/-5 op 3 --> [-19, -14)
HiBound = Prod + 1;
HiOverflow = LoOverflow = ProdOV ? -1 : 0;
if (!LoOverflow)
LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0;
- } else { // (X / neg) op neg
+ } else { // (X / neg) op neg
LoBound = Prod; // e.g. X/-5 op -3 --> [15, 20)
LoOverflow = HiOverflow = ProdOV;
if (!HiOverflow)
@@ -2581,42 +2589,54 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp,
const APInt &C) {
Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1);
ICmpInst::Predicate Pred = Cmp.getPredicate();
- const APInt *C2;
- APInt SubResult;
+ Type *Ty = Sub->getType();
- // icmp eq/ne (sub C, Y), C -> icmp eq/ne Y, 0
- if (match(X, m_APInt(C2)) && *C2 == C && Cmp.isEquality())
- return new ICmpInst(Cmp.getPredicate(), Y,
- ConstantInt::get(Y->getType(), 0));
+ // (SubC - Y) == C --> Y == (SubC - C)
+ // (SubC - Y) != C --> Y != (SubC - C)
+ Constant *SubC;
+ if (Cmp.isEquality() && match(X, m_ImmConstant(SubC))) {
+ return new ICmpInst(Pred, Y,
+ ConstantExpr::getSub(SubC, ConstantInt::get(Ty, C)));
+ }
// (icmp P (sub nuw|nsw C2, Y), C) -> (icmp swap(P) Y, C2-C)
+ const APInt *C2;
+ APInt SubResult;
+ ICmpInst::Predicate SwappedPred = Cmp.getSwappedPredicate();
+ bool HasNSW = Sub->hasNoSignedWrap();
+ bool HasNUW = Sub->hasNoUnsignedWrap();
if (match(X, m_APInt(C2)) &&
- ((Cmp.isUnsigned() && Sub->hasNoUnsignedWrap()) ||
- (Cmp.isSigned() && Sub->hasNoSignedWrap())) &&
+ ((Cmp.isUnsigned() && HasNUW) || (Cmp.isSigned() && HasNSW)) &&
!subWithOverflow(SubResult, *C2, C, Cmp.isSigned()))
- return new ICmpInst(Cmp.getSwappedPredicate(), Y,
- ConstantInt::get(Y->getType(), SubResult));
+ return new ICmpInst(SwappedPred, Y, ConstantInt::get(Ty, SubResult));
// The following transforms are only worth it if the only user of the subtract
// is the icmp.
+ // TODO: This is an artificial restriction for all of the transforms below
+ // that only need a single replacement icmp.
if (!Sub->hasOneUse())
return nullptr;
+ // X - Y == 0 --> X == Y.
+ // X - Y != 0 --> X != Y.
+ if (Cmp.isEquality() && C.isZero()) + return new ICmpInst(Pred, X, Y); + if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) - if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnes()) return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && C.isNullValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isZero()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && C.isNullValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isZero()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && C.isOneValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isOne()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2634,7 +2654,12 @@ Instruction *InstCombinerImpl::foldICmpSubConstant(ICmpInst &Cmp, if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); - return nullptr; + // We have handled special cases that reduce. + // Canonicalize any remaining sub to add as: + // (C2 - Y) > C --> (Y + ~C2) < ~C + Value *Add = Builder.CreateAdd(Y, ConstantInt::get(Ty, ~(*C2)), "notsub", + HasNUW, HasNSW); + return new ICmpInst(SwappedPred, Add, ConstantInt::get(Ty, ~C)); } /// Fold icmp (add X, Y), C. @@ -2723,6 +2748,14 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), ConstantExpr::getNeg(cast<Constant>(Y))); + // The range test idiom can use either ult or ugt. Arbitrarily canonicalize + // to the ult form. + // X+C2 >u C -> X+(C2-C-1) <u ~C + if (Pred == ICmpInst::ICMP_UGT) + return new ICmpInst(ICmpInst::ICMP_ULT, + Builder.CreateAdd(X, ConstantInt::get(Ty, *C2 - C - 1)), + ConstantInt::get(Ty, ~C)); + return nullptr; } @@ -2830,8 +2863,7 @@ Instruction *InstCombinerImpl::foldICmpSelectConstant(ICmpInst &Cmp, return nullptr; } -static Instruction *foldICmpBitCast(ICmpInst &Cmp, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombinerImpl::foldICmpBitCast(ICmpInst &Cmp) { auto *Bitcast = dyn_cast<BitCastInst>(Cmp.getOperand(0)); if (!Bitcast) return nullptr; @@ -2917,6 +2949,39 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, return new ICmpInst(Pred, BCSrcOp, Op1); } + const APInt *C; + if (!match(Cmp.getOperand(1), m_APInt(C)) || + !Bitcast->getType()->isIntegerTy() || + !Bitcast->getSrcTy()->isIntOrIntVectorTy()) + return nullptr; + + // If this is checking if all elements of a vector compare are set or not, + // invert the casted vector equality compare and test if all compare + // elements are clear or not. Compare against zero is generally easier for + // analysis and codegen. + // icmp eq/ne (bitcast (not X) to iN), -1 --> icmp eq/ne (bitcast X to iN), 0 + // Example: are all elements equal? --> are zero elements not equal? + // TODO: Try harder to reduce compare of 2 freely invertible operands? 
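// For illustration, with made-up operands: if X = icmp eq <4 x i32> %a, %b,
// then (bitcast X to i4) == -1 asks "all four lanes true?"; inverting X to
// icmp ne turns that into the easier (bitcast (not X) to i4) == 0 test.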
+ if (Cmp.isEquality() && C->isAllOnes() && Bitcast->hasOneUse() && + isFreeToInvert(BCSrcOp, BCSrcOp->hasOneUse())) { + Type *ScalarTy = Bitcast->getType(); + Value *Cast = Builder.CreateBitCast(Builder.CreateNot(BCSrcOp), ScalarTy); + return new ICmpInst(Pred, Cast, ConstantInt::getNullValue(ScalarTy)); + } + + // If this is checking if all elements of an extended vector are clear or not, + // compare in a narrow type to eliminate the extend: + // icmp eq/ne (bitcast (ext X) to iN), 0 --> icmp eq/ne (bitcast X to iM), 0 + Value *X; + if (Cmp.isEquality() && C->isZero() && Bitcast->hasOneUse() && + match(BCSrcOp, m_ZExtOrSExt(m_Value(X)))) { + if (auto *VecTy = dyn_cast<FixedVectorType>(X->getType())) { + Type *NewType = Builder.getIntNTy(VecTy->getPrimitiveSizeInBits()); + Value *NewCast = Builder.CreateBitCast(X, NewType); + return new ICmpInst(Pred, NewCast, ConstantInt::getNullValue(NewType)); + } + } + // Folding: icmp <pred> iN X, C // where X = bitcast <M x iK> (shufflevector <M x iK> %vec, undef, SC)) to iN // and C is a splat of a K-bit pattern @@ -2924,12 +2989,6 @@ static Instruction *foldICmpBitCast(ICmpInst &Cmp, // Into: // %E = extractelement <M x iK> %vec, i32 C' // icmp <pred> iK %E, trunc(C) - const APInt *C; - if (!match(Cmp.getOperand(1), m_APInt(C)) || - !Bitcast->getType()->isIntegerTy() || - !Bitcast->getSrcTy()->isIntOrIntVectorTy()) - return nullptr; - Value *Vec; ArrayRef<int> Mask; if (match(BCSrcOp, m_Shuffle(m_Value(Vec), m_Undef(), m_Mask(Mask)))) { @@ -3055,7 +3114,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (C.isNullValue() && BO->hasOneUse()) { + if (C.isZero() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); @@ -3069,7 +3128,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( if (Constant *BOC = dyn_cast<Constant>(BOp1)) { if (BO->hasOneUse()) return new ICmpInst(Pred, BOp0, ConstantExpr::getSub(RHS, BOC)); - } else if (C.isNullValue()) { + } else if (C.isZero()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -3090,25 +3149,12 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( // For the xor case, we can xor two constants together, eliminating // the explicit xor. return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C.isNullValue()) { + } else if (C.isZero()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } } break; - case Instruction::Sub: - if (BO->hasOneUse()) { - // Only check for constant LHS here, as constant RHS will be canonicalized - // to add and use the fold above. - if (Constant *BOC = dyn_cast<Constant>(BOp0)) { - // Replace ((sub BOC, B) != C) with (B != BOC-C). - return new ICmpInst(Pred, BOp1, ConstantExpr::getSub(BOC, RHS)); - } else if (C.isNullValue()) { - // Replace ((sub A, B) != 0) with (A != B). 
- return new ICmpInst(Pred, BOp0, BOp1); - } - } - break; case Instruction::Or: { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BO->hasOneUse() && RHS->isAllOnesValue()) { @@ -3132,7 +3178,7 @@ Instruction *InstCombinerImpl::foldICmpBinOpEqualityWithConstant( break; } case Instruction::UDiv: - if (C.isNullValue()) { + if (C.isZero()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -3149,25 +3195,26 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( ICmpInst &Cmp, IntrinsicInst *II, const APInt &C) { Type *Ty = II->getType(); unsigned BitWidth = C.getBitWidth(); + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + switch (II->getIntrinsicID()) { case Intrinsic::abs: // abs(A) == 0 -> A == 0 // abs(A) == INT_MIN -> A == INT_MIN - if (C.isNullValue() || C.isMinSignedValue()) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), - ConstantInt::get(Ty, C)); + if (C.isZero() || C.isMinSignedValue()) + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C)); break; case Intrinsic::bswap: // bswap(A) == C -> A == bswap(C) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::get(Ty, C.byteSwap())); case Intrinsic::ctlz: case Intrinsic::cttz: { // ctz(A) == bitwidth(A) -> A == 0 and likewise for != if (C == BitWidth) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), + return new ICmpInst(Pred, II->getArgOperand(0), ConstantInt::getNullValue(Ty)); // ctz(A) == C -> A & Mask1 == Mask2, where Mask2 only has bit C set @@ -3181,9 +3228,8 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( APInt Mask2 = IsTrailing ? APInt::getOneBitSet(BitWidth, Num) : APInt::getOneBitSet(BitWidth, BitWidth - Num - 1); - return new ICmpInst(Cmp.getPredicate(), - Builder.CreateAnd(II->getArgOperand(0), Mask1), - ConstantInt::get(Ty, Mask2)); + return new ICmpInst(Pred, Builder.CreateAnd(II->getArgOperand(0), Mask1), + ConstantInt::get(Ty, Mask2)); } break; } @@ -3191,28 +3237,49 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = C.isNullValue(); + bool IsZero = C.isZero(); if (IsZero || C == BitWidth) - return new ICmpInst(Cmp.getPredicate(), II->getArgOperand(0), - IsZero ? Constant::getNullValue(Ty) : Constant::getAllOnesValue(Ty)); + return new ICmpInst(Pred, II->getArgOperand(0), + IsZero ? Constant::getNullValue(Ty) + : Constant::getAllOnesValue(Ty)); break; } + case Intrinsic::fshl: + case Intrinsic::fshr: + if (II->getArgOperand(0) == II->getArgOperand(1)) { + // (rot X, ?) == 0/-1 --> X == 0/-1 + // TODO: This transform is safe to re-use undef elts in a vector, but + // the constant value passed in by the caller doesn't allow that. + if (C.isZero() || C.isAllOnes()) + return new ICmpInst(Pred, II->getArgOperand(0), Cmp.getOperand(1)); + + const APInt *RotAmtC; + // ror(X, RotAmtC) == C --> X == rol(C, RotAmtC) + // rol(X, RotAmtC) == C --> X == ror(C, RotAmtC) + if (match(II->getArgOperand(2), m_APInt(RotAmtC))) + return new ICmpInst(Pred, II->getArgOperand(0), + II->getIntrinsicID() == Intrinsic::fshl + ? 
ConstantInt::get(Ty, C.rotr(*RotAmtC)) + : ConstantInt::get(Ty, C.rotl(*RotAmtC))); + } + break; + case Intrinsic::uadd_sat: { // uadd.sat(a, b) == 0 -> (a | b) == 0 - if (C.isNullValue()) { + if (C.isZero()) { Value *Or = Builder.CreateOr(II->getArgOperand(0), II->getArgOperand(1)); - return new ICmpInst(Cmp.getPredicate(), Or, Constant::getNullValue(Ty)); + return new ICmpInst(Pred, Or, Constant::getNullValue(Ty)); } break; } case Intrinsic::usub_sat: { // usub.sat(a, b) == 0 -> a <= b - if (C.isNullValue()) { - ICmpInst::Predicate NewPred = Cmp.getPredicate() == ICmpInst::ICMP_EQ - ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; + if (C.isZero()) { + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, II->getArgOperand(0), II->getArgOperand(1)); } break; @@ -3224,6 +3291,42 @@ Instruction *InstCombinerImpl::foldICmpEqIntrinsicWithConstant( return nullptr; } +/// Fold an icmp with LLVM intrinsics +static Instruction *foldICmpIntrinsicWithIntrinsic(ICmpInst &Cmp) { + assert(Cmp.isEquality()); + + ICmpInst::Predicate Pred = Cmp.getPredicate(); + Value *Op0 = Cmp.getOperand(0); + Value *Op1 = Cmp.getOperand(1); + const auto *IIOp0 = dyn_cast<IntrinsicInst>(Op0); + const auto *IIOp1 = dyn_cast<IntrinsicInst>(Op1); + if (!IIOp0 || !IIOp1 || IIOp0->getIntrinsicID() != IIOp1->getIntrinsicID()) + return nullptr; + + switch (IIOp0->getIntrinsicID()) { + case Intrinsic::bswap: + case Intrinsic::bitreverse: + // If both operands are byte-swapped or bit-reversed, just compare the + // original values. + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + case Intrinsic::fshl: + case Intrinsic::fshr: + // If both operands are rotated by same amount, just compare the + // original values. + if (IIOp0->getOperand(0) != IIOp0->getOperand(1)) + break; + if (IIOp1->getOperand(0) != IIOp1->getOperand(1)) + break; + if (IIOp0->getOperand(2) != IIOp1->getOperand(2)) + break; + return new ICmpInst(Pred, IIOp0->getOperand(0), IIOp1->getOperand(0)); + default: + break; + } + + return nullptr; +} + /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. Instruction *InstCombinerImpl::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, IntrinsicInst *II, @@ -3663,7 +3766,7 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, (WidestTy->getScalarSizeInBits() - 1) + (NarrowestTy->getScalarSizeInBits() - 1); APInt MaximalRepresentableShiftAmount = - APInt::getAllOnesValue(XShAmt->getType()->getScalarSizeInBits()); + APInt::getAllOnes(XShAmt->getType()->getScalarSizeInBits()); if (MaximalRepresentableShiftAmount.ult(MaximalPossibleTotalShiftAmount)) return nullptr; @@ -3746,19 +3849,22 @@ foldShiftIntoShiftInAnotherHandOfAndInICmp(ICmpInst &I, const SimplifyQuery SQ, /// Fold /// (-1 u/ x) u< y -/// ((x * y) u/ x) != y +/// ((x * y) ?/ x) != y /// to -/// @llvm.umul.with.overflow(x, y) plus extraction of overflow bit +/// @llvm.?mul.with.overflow(x, y) plus extraction of overflow bit /// Note that the comparison is commutative, while inverted (u>=, ==) predicate /// will mean that we are looking for the opposite answer. 
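/// For illustration, at a made-up width: with i8 operands and nonzero x,
/// (-1 u/ x) u< y holds exactly when x * y >= 256, i.e. when the unsigned
/// product overflows, which is the bit extracted from the intrinsic.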
-Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { +Value *InstCombinerImpl::foldMultiplicationOverflowCheck(ICmpInst &I) { ICmpInst::Predicate Pred; Value *X, *Y; Instruction *Mul; + Instruction *Div; bool NeedNegation; // Look for: (-1 u/ x) u</u>= y if (!I.isEquality() && - match(&I, m_c_ICmp(Pred, m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + match(&I, m_c_ICmp(Pred, + m_CombineAnd(m_OneUse(m_UDiv(m_AllOnes(), m_Value(X))), + m_Instruction(Div)), m_Value(Y)))) { Mul = nullptr; @@ -3773,13 +3879,16 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { default: return nullptr; // Wrong predicate. } - } else // Look for: ((x * y) u/ x) !=/== y + } else // Look for: ((x * y) / x) !=/== y if (I.isEquality() && - match(&I, m_c_ICmp(Pred, m_Value(Y), - m_OneUse(m_UDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), + match(&I, + m_c_ICmp(Pred, m_Value(Y), + m_CombineAnd( + m_OneUse(m_IDiv(m_CombineAnd(m_c_Mul(m_Deferred(Y), m_Value(X)), m_Instruction(Mul)), - m_Deferred(X)))))) { + m_Deferred(X))), + m_Instruction(Div))))) { NeedNegation = Pred == ICmpInst::Predicate::ICMP_EQ; } else return nullptr; @@ -3791,19 +3900,22 @@ Value *InstCombinerImpl::foldUnsignedMultiplicationOverflowCheck(ICmpInst &I) { if (MulHadOtherUses) Builder.SetInsertPoint(Mul); - Function *F = Intrinsic::getDeclaration( - I.getModule(), Intrinsic::umul_with_overflow, X->getType()); - CallInst *Call = Builder.CreateCall(F, {X, Y}, "umul"); + Function *F = Intrinsic::getDeclaration(I.getModule(), + Div->getOpcode() == Instruction::UDiv + ? Intrinsic::umul_with_overflow + : Intrinsic::smul_with_overflow, + X->getType()); + CallInst *Call = Builder.CreateCall(F, {X, Y}, "mul"); // If the multiplication was used elsewhere, to ensure that we don't leave // "duplicate" instructions, replace uses of that original multiplication // with the multiplication result from the with.overflow intrinsic. if (MulHadOtherUses) - replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "umul.val")); + replaceInstUsesWith(*Mul, Builder.CreateExtractValue(Call, 0, "mul.val")); - Value *Res = Builder.CreateExtractValue(Call, 1, "umul.ov"); + Value *Res = Builder.CreateExtractValue(Call, 1, "mul.ov"); if (NeedNegation) // This technically increases instruction count. - Res = Builder.CreateNot(Res, "umul.not.ov"); + Res = Builder.CreateNot(Res, "mul.not.ov"); // If we replaced the mul, erase it. Do this after all uses of Builder, // as the mul is used as insertion point. @@ -4079,8 +4191,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, if (match(Op0, m_Mul(m_Value(X), m_APInt(C))) && *C != 0 && match(Op1, m_Mul(m_Value(Y), m_SpecificInt(*C))) && I.isEquality()) if (!C->countTrailingZeros() || - (BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || - (BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) + (BO0 && BO1 && BO0->hasNoSignedWrap() && BO1->hasNoSignedWrap()) || + (BO0 && BO1 && BO0->hasNoUnsignedWrap() && BO1->hasNoUnsignedWrap())) return new ICmpInst(Pred, X, Y); } @@ -4146,8 +4258,8 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, break; const APInt *C; - if (match(BO0->getOperand(1), m_APInt(C)) && !C->isNullValue() && - !C->isOneValue()) { + if (match(BO0->getOperand(1), m_APInt(C)) && !C->isZero() && + !C->isOne()) { // icmp eq/ne (X * C), (Y * C) --> icmp (X & Mask), (Y & Mask) // Mask = -1 >> count-trailing-zeros(C). 
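// For illustration, with made-up i8 operands and C = 12 (two trailing
// zeros): X*12 == Y*12 (mod 256) iff (X & 63) == (Y & 63), because the top
// two bits of X and Y are shifted out of the product.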
if (unsigned TZs = C->countTrailingZeros()) {
@@ -4200,7 +4312,7 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
}
}
- if (Value *V = foldUnsignedMultiplicationOverflowCheck(I))
+ if (Value *V = foldMultiplicationOverflowCheck(I))
return replaceInstUsesWith(I, V);
if (Value *V = foldICmpWithLowBitMaskedVal(I, Builder))
@@ -4373,6 +4485,19 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
}
+ {
+ // Similar to above, but specialized for constant because invert is needed:
+ // (X | C) == (Y | C) --> (X ^ Y) & ~C == 0
+ Value *X, *Y;
+ Constant *C;
+ if (match(Op0, m_OneUse(m_Or(m_Value(X), m_Constant(C)))) &&
+ match(Op1, m_OneUse(m_Or(m_Value(Y), m_Specific(C))))) {
+ Value *Xor = Builder.CreateXor(X, Y);
+ Value *And = Builder.CreateAnd(Xor, ConstantExpr::getNot(C));
+ return new ICmpInst(Pred, And, Constant::getNullValue(And->getType()));
+ }
+ }
+
// Transform (zext A) == (B & (1<<X)-1) --> A == (trunc B)
// and (B & (1<<X)-1) == (zext A) --> A == (trunc B)
ConstantInt *Cst1;
@@ -4441,14 +4566,8 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
}
}
- // If both operands are byte-swapped or bit-reversed, just compare the
- // original values.
- // TODO: Move this to a function similar to foldICmpIntrinsicWithConstant()
- // and handle more intrinsics.
- if ((match(Op0, m_BSwap(m_Value(A))) && match(Op1, m_BSwap(m_Value(B)))) ||
- (match(Op0, m_BitReverse(m_Value(A))) &&
- match(Op1, m_BitReverse(m_Value(B)))))
- return new ICmpInst(Pred, A, B);
+ if (Instruction *ICmp = foldICmpIntrinsicWithIntrinsic(I))
+ return ICmp;
// Canonicalize checking for a power-of-2-or-zero value:
// (A & (A-1)) == 0 --> ctpop(A) < 2 (two commuted variants)
@@ -4474,6 +4593,74 @@ Instruction *InstCombinerImpl::foldICmpEquality(ICmpInst &I) {
: new ICmpInst(ICmpInst::ICMP_UGT, CtPop, ConstantInt::get(Ty, 1));
}
+ // Match icmp eq (ashr (trunc A), BW-1), (trunc (lshr A, BW)), which checks
+ // that the top BW+1 bits of the 2*BW-wide A are all the same.
+ // Create "A >=s INT_MIN && A <=s INT_MAX",
+ // which we generate as "icmp ult (add A, 2^(BW-1)), 2^BW" to skip a few steps
+ // of instcombine.
+ unsigned BitWidth = Op0->getType()->getScalarSizeInBits();
+ if (match(Op0, m_AShr(m_Trunc(m_Value(A)), m_SpecificInt(BitWidth - 1))) &&
+ match(Op1, m_Trunc(m_LShr(m_Specific(A), m_SpecificInt(BitWidth)))) &&
+ A->getType()->getScalarSizeInBits() == BitWidth * 2 &&
+ (I.getOperand(0)->hasOneUse() || I.getOperand(1)->hasOneUse())) {
+ APInt C = APInt::getOneBitSet(BitWidth * 2, BitWidth - 1);
+ Value *Add = Builder.CreateAdd(A, ConstantInt::get(A->getType(), C));
+ return new ICmpInst(Pred == ICmpInst::ICMP_EQ ? ICmpInst::ICMP_ULT
+ : ICmpInst::ICMP_UGE,
+ Add, ConstantInt::get(A->getType(), C.shl(1)));
+ }
+
+ return nullptr;
+}
+
+static Instruction *foldICmpWithTrunc(ICmpInst &ICmp,
+ InstCombiner::BuilderTy &Builder) {
+ const ICmpInst::Predicate Pred = ICmp.getPredicate();
+ Value *Op0 = ICmp.getOperand(0), *Op1 = ICmp.getOperand(1);
+
+ // Try to canonicalize trunc + compare-to-constant into a mask + cmp.
+ // The trunc masks high bits while the compare may effectively mask low bits.
+ Value *X;
+ const APInt *C;
+ if (!match(Op0, m_OneUse(m_Trunc(m_Value(X)))) || !match(Op1, m_APInt(C)))
+ return nullptr;
+
+ unsigned SrcBits = X->getType()->getScalarSizeInBits();
+ if (Pred == ICmpInst::ICMP_ULT) {
+ if (C->isPowerOf2()) {
+ // If C is a power-of-2 (one set bit):
+ // (trunc X) u< C --> (X & -C) == 0 (are all masked-high-bits clear?)
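// For illustration, with made-up types: (trunc i32 %x to i8) u< 16 becomes
// (%x & 0xF0) == 0; the mask -16 is formed at i8 width and zero-extended,
// so only bits 4..7 of %x are tested.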
+ Constant *MaskC = ConstantInt::get(X->getType(), (-*C).zext(SrcBits));
+ Value *And = Builder.CreateAnd(X, MaskC);
+ Constant *Zero = ConstantInt::getNullValue(X->getType());
+ return new ICmpInst(ICmpInst::ICMP_EQ, And, Zero);
+ }
+ // If C is a negative power-of-2 (high-bit mask):
+ // (trunc X) u< C --> (X & C) != C (are any masked-high-bits clear?)
+ if (C->isNegatedPowerOf2()) {
+ Constant *MaskC = ConstantInt::get(X->getType(), C->zext(SrcBits));
+ Value *And = Builder.CreateAnd(X, MaskC);
+ return new ICmpInst(ICmpInst::ICMP_NE, And, MaskC);
+ }
+ }
+
+ if (Pred == ICmpInst::ICMP_UGT) {
+ // If C is a low-bit-mask (C+1 is a power-of-2):
+ // (trunc X) u> C --> (X & ~C) != 0 (are any masked-high-bits set?)
+ if (C->isMask()) {
+ Constant *MaskC = ConstantInt::get(X->getType(), (~*C).zext(SrcBits));
+ Value *And = Builder.CreateAnd(X, MaskC);
+ Constant *Zero = ConstantInt::getNullValue(X->getType());
+ return new ICmpInst(ICmpInst::ICMP_NE, And, Zero);
+ }
+ // If ~C is a power-of-2 (C has a single clear bit):
+ // (trunc X) u> C --> (X & (C+1)) == C+1 (are all masked-high-bits set?)
+ if ((~*C).isPowerOf2()) {
+ Constant *MaskC = ConstantInt::get(X->getType(), (*C + 1).zext(SrcBits));
+ Value *And = Builder.CreateAnd(X, MaskC);
+ return new ICmpInst(ICmpInst::ICMP_EQ, And, MaskC);
+ }
+ }
+ return nullptr;
}
@@ -4620,6 +4807,9 @@ Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
}
+ if (Instruction *R = foldICmpWithTrunc(ICmp, Builder))
+ return R;
+
return foldICmpWithZextOrSext(ICmp, Builder);
}
@@ -4943,7 +5133,7 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal,
static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
const APInt *RHS;
if (!match(I.getOperand(1), m_APInt(RHS)))
- return APInt::getAllOnesValue(BitWidth);
+ return APInt::getAllOnes(BitWidth);
// If this is a normal comparison, it demands all bits. If it is a sign bit
// comparison, it only demands the sign bit.
@@ -4965,7 +5155,7 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) {
return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros());
default:
- return APInt::getAllOnesValue(BitWidth);
+ return APInt::getAllOnes(BitWidth);
}
}
@@ -5129,8 +5319,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
Op0Known, 0))
return &I;
- if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
- Op1Known, 0))
+ if (SimplifyDemandedBits(&I, 1, APInt::getAllOnes(BitWidth), Op1Known, 0))
return &I;
// Given the known and unknown bits, compute a range that the LHS could be
@@ -5158,6 +5347,83 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) {
if (!isa<Constant>(Op1) && Op1Min == Op1Max)
return new ICmpInst(Pred, Op0, ConstantExpr::getIntegerValue(Ty, Op1Min));
+ // Don't break up a clamp pattern -- min(max(X, Y), Z) -- by replacing a
+ // min/max canonical compare with some other compare. That could lead to
+ // conflict with select canonicalization and infinite looping.
+ // FIXME: This constraint may go away if min/max intrinsics are canonical.
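// For illustration, with made-up bounds: a clamp such as
// smax(smin(X, 255), 0) is recognized through its compare+select pairs, and
// rewriting its canonical 'icmp slt X, 255' into an eq/ne form here would
// hide the pattern from matchSelectPattern.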
+ auto isMinMaxCmp = [&](Instruction &Cmp) { + if (!Cmp.hasOneUse()) + return false; + Value *A, *B; + SelectPatternFlavor SPF = matchSelectPattern(Cmp.user_back(), A, B).Flavor; + if (!SelectPatternResult::isMinOrMax(SPF)) + return false; + return match(Op0, m_MaxOrMin(m_Value(), m_Value())) || + match(Op1, m_MaxOrMin(m_Value(), m_Value())); + }; + if (!isMinMaxCmp(I)) { + switch (Pred) { + default: + break; + case ICmpInst::ICMP_ULT: { + if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A <u C -> A == C-1 if min(A)+1 == C + if (*CmpC == Op0Min + 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + // X <u C --> X == 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_UGT: { + if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + // A >u C -> A == C+1 if max(a)-1 == C + if (*CmpC == Op0Max - 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + // X >u C --> X != 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Op1->getType())); + } + break; + } + case ICmpInst::ICMP_SLT: { + if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + } + break; + } + case ICmpInst::ICMP_SGT: { + if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC + 1)); + } + break; + } + } + } + // Based on the range information we know about the LHS, see if we can // simplify this comparison. For example, (x&4) < 8 is always true. switch (Pred) { @@ -5203,7 +5469,7 @@ Instruction *InstCombinerImpl::foldICmpUsingKnownBits(ICmpInst &I) { // Check if the LHS is 8 >>u x and the result is a power of 2 like 1. 
const APInt *CI;
- if (Op0KnownZeroInverted.isOneValue() &&
+ if (Op0KnownZeroInverted.isOne() &&
match(LHS, m_LShr(m_Power2(CI), m_Value(X)))) {
// ((8 >>u X) & 1) == 0 -> X != 3
// ((8 >>u X) & 1) != 0 -> X == 3
@@ -5219,21 +5485,6 @@
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.uge(Op1Max)) // A <u B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <u B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A <u C -> A == C-1 if min(A)+1 == C
- if (*CmpC == Op0Min + 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- // X <u C --> X == 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2())
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- Constant::getNullValue(Op1->getType()));
- }
break;
}
case ICmpInst::ICMP_UGT: {
@@ -5241,21 +5492,6 @@
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
-
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- // A >u C -> A == C+1 if max(A)-1 == C
- if (*CmpC == Op0Max - 1)
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- // X >u C --> X != 0, if the number of zero bits in the bottom of X
- // exceeds the log2 of C.
- if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits())
- return new ICmpInst(ICmpInst::ICMP_NE, Op0,
- Constant::getNullValue(Op1->getType()));
- }
break;
}
case ICmpInst::ICMP_SLT: {
@@ -5263,14 +5499,6 @@
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC - 1));
- }
break;
}
case ICmpInst::ICMP_SGT: {
@@ -5278,14 +5506,6 @@
return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType()));
if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B)
return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType()));
- if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B)
- return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1);
- const APInt *CmpC;
- if (match(Op1, m_APInt(CmpC))) {
- if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C
- return new ICmpInst(ICmpInst::ICMP_EQ, Op0,
- ConstantInt::get(Op1->getType(), *CmpC + 1));
- }
break;
}
case ICmpInst::ICMP_SGE:
@@ -5587,7 +5807,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp,
if (match(RHS, m_Shuffle(m_Value(V2), m_Undef(), m_SpecificMask(M))) &&
V1Ty == V2->getType() && (LHS->hasOneUse() || RHS->hasOneUse())) {
Value *NewCmp = Builder.CreateCmp(Pred, V1, V2);
- return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), M);
+ return new ShuffleVectorInst(NewCmp, M);
}
// Try to canonicalize compare with splatted operand and splat constant.
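// For illustration, with made-up values: for icmp sgt (shufflevector %v,
// poison, <2,2,2,2>), splat(42), it suffices to place 42 into lane 2 of a
// compare constant, compare %v directly, and then splat lane 2 of the
// <4 x i1> result with the same mask.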
@@ -5608,8 +5828,7 @@ static Instruction *foldVectorCmp(CmpInst &Cmp, ScalarC); SmallVector<int, 8> NewM(M.size(), MaskSplatIndex); Value *NewCmp = Builder.CreateCmp(Pred, V1, C); - return new ShuffleVectorInst(NewCmp, UndefValue::get(NewCmp->getType()), - NewM); + return new ShuffleVectorInst(NewCmp, NewM); } return nullptr; @@ -5645,6 +5864,23 @@ static Instruction *foldICmpOfUAddOv(ICmpInst &I) { return ExtractValueInst::Create(UAddOv, 1); } +static Instruction *foldICmpInvariantGroup(ICmpInst &I) { + if (!I.getOperand(0)->getType()->isPointerTy() || + NullPointerIsDefined( + I.getParent()->getParent(), + I.getOperand(0)->getType()->getPointerAddressSpace())) { + return nullptr; + } + Instruction *Op; + if (match(I.getOperand(0), m_Instruction(Op)) && + match(I.getOperand(1), m_Zero()) && + Op->isLaunderOrStripInvariantGroup()) { + return ICmpInst::Create(Instruction::ICmp, I.getPredicate(), + Op->getOperand(0), I.getOperand(1)); + } + return nullptr; +} + Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { bool Changed = false; const SimplifyQuery Q = SQ.getWithInstruction(&I); @@ -5698,9 +5934,6 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpWithDominatingICmp(I)) return Res; - if (Instruction *Res = foldICmpBinOp(I, Q)) - return Res; - if (Instruction *Res = foldICmpUsingKnownBits(I)) return Res; @@ -5746,6 +5979,15 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { } } + // The folds in here may rely on wrapping flags and special constants, so + // they can break up min/max idioms in some cases but not seemingly similar + // patterns. + // FIXME: It may be possible to enhance select folding to make this + // unnecessary. It may also be moot if we canonicalize to min/max + // intrinsics. + if (Instruction *Res = foldICmpBinOp(I, Q)) + return Res; + if (Instruction *Res = foldICmpInstWithConstant(I)) return Res; @@ -5757,13 +5999,12 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldICmpInstWithConstantNotInt(I)) return Res; - // If we can optimize a 'icmp GEP, P' or 'icmp P, GEP', do so now. - if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op0)) + // Try to optimize 'icmp GEP, P' or 'icmp P, GEP'. + if (auto *GEP = dyn_cast<GEPOperator>(Op0)) if (Instruction *NI = foldGEPICmp(GEP, Op1, I.getPredicate(), I)) return NI; - if (GEPOperator *GEP = dyn_cast<GEPOperator>(Op1)) - if (Instruction *NI = foldGEPICmp(GEP, Op0, - ICmpInst::getSwappedPredicate(I.getPredicate()), I)) + if (auto *GEP = dyn_cast<GEPOperator>(Op1)) + if (Instruction *NI = foldGEPICmp(GEP, Op0, I.getSwappedPredicate(), I)) return NI; // Try to optimize equality comparisons against alloca-based pointers. @@ -5777,7 +6018,7 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { return New; } - if (Instruction *Res = foldICmpBitCast(I, Builder)) + if (Instruction *Res = foldICmpBitCast(I)) return Res; // TODO: Hoist this above the min/max bailout. @@ -5879,6 +6120,9 @@ Instruction *InstCombinerImpl::visitICmpInst(ICmpInst &I) { if (Instruction *Res = foldVectorCmp(I, Builder)) return Res; + if (Instruction *Res = foldICmpInvariantGroup(I)) + return Res; + return Changed ? 
&I : nullptr;
}
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
index eaa53348028d..72e1b21e8d49 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -22,14 +22,15 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Value.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
-#include "llvm/Transforms/InstCombine/InstCombineWorklist.h"
#include "llvm/Transforms/InstCombine/InstCombiner.h"
#include "llvm/Transforms/Utils/Local.h"
#include <cassert>
#define DEBUG_TYPE "instcombine"
+#include "llvm/Transforms/Utils/InstructionWorklist.h"
using namespace llvm::PatternMatch;
@@ -61,7 +62,7 @@
class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
: public InstCombiner,
public InstVisitor<InstCombinerImpl, Instruction *> {
public:
- InstCombinerImpl(InstCombineWorklist &Worklist, BuilderTy &Builder,
+ InstCombinerImpl(InstructionWorklist &Worklist, BuilderTy &Builder,
bool MinimizeSize, AAResults *AA, AssumptionCache &AC,
TargetLibraryInfo &TLI, TargetTransformInfo &TTI,
DominatorTree &DT, OptimizationRemarkEmitter &ORE,
@@ -190,6 +191,7 @@ public:
private:
void annotateAnyAllocSite(CallBase &Call, const TargetLibraryInfo *TLI);
+ bool isDesirableIntType(unsigned BitWidth) const;
bool shouldChangeType(unsigned FromBitWidth, unsigned ToBitWidth) const;
bool shouldChangeType(Type *From, Type *To) const;
Value *dyn_castNegVal(Value *V) const;
@@ -240,15 +242,11 @@ private:
///
/// \param ICI The icmp of the (zext icmp) pair we are interested in.
/// \param CI The zext of the (zext icmp) pair we are interested in.
- /// \param DoTransform Pass false to just test whether the given (zext icmp)
- /// would be transformed. Pass true to actually perform the transformation.
///
/// \return null if the transformation cannot be performed. If the
/// transformation can be performed the new instruction that replaces the
- /// (zext icmp) pair will be returned (if \p DoTransform is false the
- /// unmodified \p ICI will be returned in this case).
- Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI,
- bool DoTransform = true);
+ /// (zext icmp) pair will be returned.
+ Instruction *transformZExtICmp(ICmpInst *ICI, ZExtInst &CI);
Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI);
@@ -319,13 +317,15 @@ private:
Value *EmitGEPOffset(User *GEP);
Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN);
+ Instruction *foldBitcastExtElt(ExtractElementInst &ExtElt);
Instruction *foldCastedBitwiseLogic(BinaryOperator &I);
Instruction *narrowBinOp(TruncInst &Trunc);
Instruction *narrowMaskedBinOp(BinaryOperator &And);
Instruction *narrowMathIfNoOverflow(BinaryOperator &I);
Instruction *narrowFunnelShift(TruncInst &Trunc);
Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN);
- Instruction *matchSAddSubSat(SelectInst &MinMax1);
+ Instruction *matchSAddSubSat(Instruction &MinMax1);
+ Instruction *foldNot(BinaryOperator &I);
void freelyInvertAllUsersOf(Value *V);
@@ -347,6 +347,8 @@ private:
Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or);
Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
+ Value *foldEqOfParts(ICmpInst *Cmp0, ICmpInst *Cmp1, bool IsAnd);
+
/// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp).
/// NOTE: Unlike most of instcombine, this returns a Value which should /// already be inserted into the function. @@ -623,6 +625,7 @@ public: Instruction *foldPHIArgGEPIntoPHI(PHINode &PN); Instruction *foldPHIArgLoadIntoPHI(PHINode &PN); Instruction *foldPHIArgZextsIntoPHI(PHINode &PN); + Instruction *foldPHIArgIntToPtrToPHI(PHINode &PN); /// If an integer typed PHI has only one use which is an IntToPtr operation, /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise @@ -657,7 +660,7 @@ public: Instruction *foldSignBitTest(ICmpInst &I); Instruction *foldICmpWithZero(ICmpInst &Cmp); - Value *foldUnsignedMultiplicationOverflowCheck(ICmpInst &Cmp); + Value *foldMultiplicationOverflowCheck(ICmpInst &Cmp); Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); @@ -701,6 +704,7 @@ public: const APInt &C); Instruction *foldICmpEqIntrinsicWithConstant(ICmpInst &ICI, IntrinsicInst *II, const APInt &C); + Instruction *foldICmpBitCast(ICmpInst &Cmp); // Helpers of visitSelectInst(). Instruction *foldSelectExtConst(SelectInst &Sel); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index a8474e27383d..79a8a065d02a 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -261,8 +261,8 @@ private: bool PointerReplacer::collectUsers(Instruction &I) { for (auto U : I.users()) { - Instruction *Inst = cast<Instruction>(&*U); - if (LoadInst *Load = dyn_cast<LoadInst>(Inst)) { + auto *Inst = cast<Instruction>(&*U); + if (auto *Load = dyn_cast<LoadInst>(Inst)) { if (Load->isVolatile()) return false; Worklist.insert(Load); @@ -270,7 +270,9 @@ bool PointerReplacer::collectUsers(Instruction &I) { Worklist.insert(Inst); if (!collectUsers(*Inst)) return false; - } else if (isa<MemTransferInst>(Inst)) { + } else if (auto *MI = dyn_cast<MemTransferInst>(Inst)) { + if (MI->isVolatile()) + return false; Worklist.insert(Inst); } else if (Inst->isLifetimeStartOrEnd()) { continue; @@ -335,8 +337,7 @@ void PointerReplacer::replace(Instruction *I) { MemCpy->getIntrinsicID(), MemCpy->getRawDest(), MemCpy->getDestAlign(), SrcV, MemCpy->getSourceAlign(), MemCpy->getLength(), MemCpy->isVolatile()); - AAMDNodes AAMD; - MemCpy->getAAMetadata(AAMD); + AAMDNodes AAMD = MemCpy->getAAMetadata(); if (AAMD) NewI->setAAMetadata(AAMD); @@ -647,9 +648,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { if (NumElements == 1) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ST->getTypeAtIndex(0U), ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - NewLoad->setAAMetadata(AAMD); + NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -678,9 +677,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { ST->getElementType(i), Ptr, commonAlignment(Align, SL->getElementOffset(i)), Name + ".unpack"); // Propagate AA metadata. It'll still be valid on the narrowed load. 
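// (In this version of the API, getAAMetadata() returns the AAMDNodes by
// value, so the three-line out-parameter sequence collapses to one call.)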
- AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - L->setAAMetadata(AAMD); + L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); } @@ -693,9 +690,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto NumElements = AT->getNumElements(); if (NumElements == 1) { LoadInst *NewLoad = IC.combineLoadToNewType(LI, ET, ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - NewLoad->setAAMetadata(AAMD); + NewLoad->setAAMetadata(LI.getAAMetadata()); return IC.replaceInstUsesWith(LI, IC.Builder.CreateInsertValue( UndefValue::get(T), NewLoad, 0, Name)); } @@ -727,9 +722,7 @@ static Instruction *unpackLoadToAggregate(InstCombinerImpl &IC, LoadInst &LI) { auto *L = IC.Builder.CreateAlignedLoad(AT->getElementType(), Ptr, commonAlignment(Align, Offset), Name + ".unpack"); - AAMDNodes AAMD; - LI.getAAMetadata(AAMD); - L->setAAMetadata(AAMD); + L->setAAMetadata(LI.getAAMetadata()); V = IC.Builder.CreateInsertValue(V, L, i); Offset += EltSize; } @@ -1206,9 +1199,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, SL->getElementOffset(i)); llvm::Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); - AAMDNodes AAMD; - SI.getAAMetadata(AAMD); - NS->setAAMetadata(AAMD); + NS->setAAMetadata(SI.getAAMetadata()); } return true; @@ -1254,9 +1245,7 @@ static bool unpackStoreToAggregate(InstCombinerImpl &IC, StoreInst &SI) { auto *Val = IC.Builder.CreateExtractValue(V, i, EltName); auto EltAlign = commonAlignment(Align, Offset); Instruction *NS = IC.Builder.CreateAlignedStore(Val, Ptr, EltAlign); - AAMDNodes AAMD; - SI.getAAMetadata(AAMD); - NS->setAAMetadata(AAMD); + NS->setAAMetadata(SI.getAAMetadata()); Offset += EltSize; } @@ -1498,8 +1487,8 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { StoreInst *OtherStore = nullptr; if (OtherBr->isUnconditional()) { --BBI; - // Skip over debugging info. - while (isa<DbgInfoIntrinsic>(BBI) || + // Skip over debugging info and pseudo probes. + while (BBI->isDebugOrPseudoInst() || (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy())) { if (BBI==OtherBB->begin()) return false; @@ -1567,12 +1556,9 @@ bool InstCombinerImpl::mergeStoreIntoSuccessor(StoreInst &SI) { NewSI->setDebugLoc(MergedLoc); // If the two stores had AA tags, merge them. - AAMDNodes AATags; - SI.getAAMetadata(AATags); - if (AATags) { - OtherStore->getAAMetadata(AATags, /* Merge = */ true); - NewSI->setAAMetadata(AATags); - } + AAMDNodes AATags = SI.getAAMetadata(); + if (AATags) + NewSI->setAAMetadata(AATags.merge(OtherStore->getAAMetadata())); // Nuke the old stores. 
eraseInstFromFunction(SI); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index 6f2a8ebf839a..779d298da7a4 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -31,7 +31,6 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include <cassert> @@ -39,11 +38,12 @@ #include <cstdint> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - /// The specific integer value is used in a context where it is known to be /// non-zero. If this allows us to simplify the computation, do so and return /// the new operand, otherwise return null. @@ -107,14 +107,19 @@ static Value *foldMulSelectToNegate(BinaryOperator &I, // mul (select Cond, 1, -1), OtherOp --> select Cond, OtherOp, -OtherOp // mul OtherOp, (select Cond, 1, -1) --> select Cond, OtherOp, -OtherOp if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_One(), m_AllOnes())), - m_Value(OtherOp)))) - return Builder.CreateSelect(Cond, OtherOp, Builder.CreateNeg(OtherOp)); - + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, OtherOp, Neg); + } // mul (select Cond, -1, 1), OtherOp --> select Cond, -OtherOp, OtherOp // mul OtherOp, (select Cond, -1, 1) --> select Cond, -OtherOp, OtherOp if (match(&I, m_c_Mul(m_OneUse(m_Select(m_Value(Cond), m_AllOnes(), m_One())), - m_Value(OtherOp)))) - return Builder.CreateSelect(Cond, Builder.CreateNeg(OtherOp), OtherOp); + m_Value(OtherOp)))) { + bool HasAnyNoWrap = I.hasNoSignedWrap() || I.hasNoUnsignedWrap(); + Value *Neg = Builder.CreateNeg(OtherOp, "", false, HasAnyNoWrap); + return Builder.CreateSelect(Cond, Neg, OtherOp); + } // fmul (select Cond, 1.0, -1.0), OtherOp --> select Cond, OtherOp, -OtherOp // fmul OtherOp, (select Cond, 1.0, -1.0) --> select Cond, OtherOp, -OtherOp @@ -564,6 +569,16 @@ Instruction *InstCombinerImpl::visitFMul(BinaryOperator &I) { return replaceInstUsesWith(I, NewPow); } + // powi(x, y) * powi(x, z) -> powi(x, y + z) + if (match(Op0, m_Intrinsic<Intrinsic::powi>(m_Value(X), m_Value(Y))) && + match(Op1, m_Intrinsic<Intrinsic::powi>(m_Specific(X), m_Value(Z))) && + Y->getType() == Z->getType()) { + auto *YZ = Builder.CreateAdd(Y, Z); + auto *NewPow = Builder.CreateIntrinsic( + Intrinsic::powi, {X->getType(), YZ->getType()}, {X, YZ}, &I); + return replaceInstUsesWith(I, NewPow); + } + // exp(X) * exp(Y) -> exp(X + Y) if (match(Op0, m_Intrinsic<Intrinsic::exp>(m_Value(X))) && match(Op1, m_Intrinsic<Intrinsic::exp>(m_Value(Y)))) { @@ -706,11 +721,11 @@ static bool isMultiple(const APInt &C1, const APInt &C2, APInt &Quotient, assert(C1.getBitWidth() == C2.getBitWidth() && "Constant widths not equal"); // Bail if we will divide by zero. - if (C2.isNullValue()) + if (C2.isZero()) return false; // Bail if we would divide INT_MIN by -1. 
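// For illustration, at a made-up width: in i8, -128 sdiv -1 would have to
// produce +128, which is unrepresentable, so that pair is rejected instead
// of being treated as an exact multiple.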
- if (IsSigned && C1.isMinSignedValue() && C2.isAllOnesValue()) + if (IsSigned && C1.isMinSignedValue() && C2.isAllOnes()) return false; APInt Remainder(C1.getBitWidth(), /*val=*/0ULL, IsSigned); @@ -778,11 +793,12 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } if ((IsSigned && match(Op0, m_NSWShl(m_Value(X), m_APInt(C1))) && - *C1 != C1->getBitWidth() - 1) || - (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))))) { + C1->ult(C1->getBitWidth() - 1)) || + (!IsSigned && match(Op0, m_NUWShl(m_Value(X), m_APInt(C1))) && + C1->ult(C1->getBitWidth()))) { APInt Quotient(C1->getBitWidth(), /*val=*/0ULL, IsSigned); APInt C1Shifted = APInt::getOneBitSet( - C1->getBitWidth(), static_cast<unsigned>(C1->getLimitedValue())); + C1->getBitWidth(), static_cast<unsigned>(C1->getZExtValue())); // (X << C1) / C2 -> X / (C2 >> C1) if C2 is a multiple of 1 << C1. if (isMultiple(*C2, C1Shifted, Quotient, IsSigned)) { @@ -803,7 +819,7 @@ Instruction *InstCombinerImpl::commonIDivTransforms(BinaryOperator &I) { } } - if (!C2->isNullValue()) // avoid X udiv 0 + if (!C2->isZero()) // avoid X udiv 0 if (Instruction *FoldedDiv = foldBinOpIntoSelectOrPhi(I)) return FoldedDiv; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp index 37c7e6135501..7dc516c6fdc3 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -215,6 +215,20 @@ LLVM_NODISCARD Value *Negator::visitImpl(Value *V, unsigned Depth) { : Builder.CreateSExt(I->getOperand(0), I->getType(), I->getName() + ".neg"); break; + case Instruction::Select: { + // If both arms of the select are constants, we don't need to recurse. + // Therefore, this transform is not limited by uses. + auto *Sel = cast<SelectInst>(I); + Constant *TrueC, *FalseC; + if (match(Sel->getTrueValue(), m_ImmConstant(TrueC)) && + match(Sel->getFalseValue(), m_ImmConstant(FalseC))) { + Constant *NegTrueC = ConstantExpr::getNeg(TrueC); + Constant *NegFalseC = ConstantExpr::getNeg(FalseC); + return Builder.CreateSelect(Sel->getCondition(), NegTrueC, NegFalseC, + I->getName() + ".neg", /*MDFrom=*/I); + } + break; + } default: break; // Other instructions require recursive reasoning. } diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp index 6c6351c70e3a..35739c3b9a21 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -299,6 +299,29 @@ Instruction *InstCombinerImpl::foldIntegerTypedPHI(PHINode &PN) { IntToPtr->getOperand(0)->getType()); } +// Remove RoundTrip IntToPtr/PtrToInt Cast on PHI-Operand and +// fold Phi-operand to bitcast. +Instruction *InstCombinerImpl::foldPHIArgIntToPtrToPHI(PHINode &PN) { + // convert ptr2int ( phi[ int2ptr(ptr2int(x))] ) --> ptr2int ( phi [ x ] ) + // Make sure all uses of phi are ptr2int. + if (!all_of(PN.users(), [](User *U) { return isa<PtrToIntInst>(U); })) + return nullptr; + + // Iterating over all operands to check presence of target pointers for + // optimization. 
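// For illustration, with made-up types: an incoming value of the form
// inttoptr (ptrtoint %p to i64) collapses back to %p when i64 is as wide as
// the pointer, which is the condition simplifyIntToPtrRoundTripCast checks.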
+ bool OperandWithRoundTripCast = false; + for (unsigned OpNum = 0; OpNum != PN.getNumIncomingValues(); ++OpNum) { + if (auto *NewOp = + simplifyIntToPtrRoundTripCast(PN.getIncomingValue(OpNum))) { + PN.setIncomingValue(OpNum, NewOp); + OperandWithRoundTripCast = true; + } + } + if (!OperandWithRoundTripCast) + return nullptr; + return &PN; +} + /// If we have something like phi [insertvalue(a,b,0), insertvalue(c,d,0)], /// turn this into a phi[a,c] and phi[b,d] and a single insertvalue. Instruction * @@ -1306,6 +1329,9 @@ Instruction *InstCombinerImpl::visitPHINode(PHINode &PN) { if (Instruction *Result = foldPHIArgZextsIntoPHI(PN)) return Result; + if (Instruction *Result = foldPHIArgIntToPtrToPHI(PN)) + return Result; + // If all PHI operands are the same operation, pull them through the PHI, // reducing code size. if (isa<Instruction>(PN.getIncomingValue(0)) && diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp index ce2b913dba61..4a1e82ae9c1d 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -38,15 +38,16 @@ #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" static Value *createMinMax(InstCombiner::BuilderTy &Builder, SelectPatternFlavor SPF, Value *A, Value *B) { @@ -165,7 +166,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // simplify/reduce the instructions. APInt TC = *SelTC; APInt FC = *SelFC; - if (!TC.isNullValue() && !FC.isNullValue()) { + if (!TC.isZero() && !FC.isZero()) { // If the select constants differ by exactly one bit and that's the same // bit that is masked and checked by the select condition, the select can // be replaced by bitwise logic to set/clear one bit of the constant result. @@ -202,7 +203,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Determine which shift is needed to transform result of the 'and' into the // desired result. - const APInt &ValC = !TC.isNullValue() ? TC : FC; + const APInt &ValC = !TC.isZero() ? TC : FC; unsigned ValZeros = ValC.logBase2(); unsigned AndZeros = AndMask.logBase2(); @@ -224,7 +225,7 @@ static Value *foldSelectICmpAnd(SelectInst &Sel, ICmpInst *Cmp, // Okay, now we know that everything is set up, we just don't know whether we // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TC.isNullValue(); + bool ShouldNotVal = !TC.isZero(); ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; if (ShouldNotVal) V = Builder.CreateXor(V, ValC); @@ -319,8 +320,16 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI, Value *X, *Y; if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) && (TI->hasOneUse() || FI->hasOneUse())) { + // Intersect FMF from the fneg instructions and union those with the select. 
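// Only a flag present on both fnegs is guaranteed to hold for the merged
// fneg, while the select's own flags can be added back because the new
// select and fneg together replace that original select.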
+    FastMathFlags FMF = TI->getFastMathFlags();
+    FMF &= FI->getFastMathFlags();
+    FMF |= SI.getFastMathFlags();
     Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
-    return UnaryOperator::CreateFNegFMF(NewSel, TI);
+    if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
+      NewSelI->setFastMathFlags(FMF);
+    Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
+    NewFNeg->setFastMathFlags(FMF);
+    return NewFNeg;
   }

   // Min/max intrinsic with a common operand can have the common operand pulled
@@ -420,10 +429,9 @@ Instruction *InstCombinerImpl::foldSelectOpOp(SelectInst &SI, Instruction *TI,
 }

 static bool isSelect01(const APInt &C1I, const APInt &C2I) {
-  if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero.
+  if (!C1I.isZero() && !C2I.isZero()) // One side must be zero.
     return false;
-  return C1I.isOneValue() || C1I.isAllOnesValue() ||
-         C2I.isOneValue() || C2I.isAllOnesValue();
+  return C1I.isOne() || C1I.isAllOnes() || C2I.isOne() || C2I.isAllOnes();
 }

 /// Try to fold the select into one of the operands to allow further
@@ -715,6 +723,58 @@ static Instruction *foldSetClearBits(SelectInst &Sel,
   return nullptr;
 }

+// select (x == 0), 0, x * y --> freeze(y) * x
+// select (y == 0), 0, x * y --> freeze(x) * y
+// select (x == 0), undef, x * y --> freeze(y) * x
+// select (x == undef), 0, x * y --> freeze(y) * x
+// Using mul instead of 0 will make the result more poisonous,
+// so the operand that was not checked in the condition should be frozen.
+// The latter folding is applied only when a constant compared with x
+// is a vector consisting of 0 and undefs. If a constant compared with x
+// is a scalar undefined value or an undefined vector, then the expression
+// should already have been folded into a constant.
+static Instruction *foldSelectZeroOrMul(SelectInst &SI, InstCombinerImpl &IC) {
+  auto *CondVal = SI.getCondition();
+  auto *TrueVal = SI.getTrueValue();
+  auto *FalseVal = SI.getFalseValue();
+  Value *X, *Y;
+  ICmpInst::Predicate Predicate;
+
+  // Assume that the constant compared with zero is not undef (but it may be
+  // a vector with some undef elements). Otherwise (when the constant is undef)
+  // the select expression should already have been simplified.
+  if (!match(CondVal, m_ICmp(Predicate, m_Value(X), m_Zero())) ||
+      !ICmpInst::isEquality(Predicate))
+    return nullptr;
+
+  if (Predicate == ICmpInst::ICMP_NE)
+    std::swap(TrueVal, FalseVal);
+
+  // Check that TrueVal is a constant instead of matching it with m_Zero()
+  // to handle the case when it is a scalar undef value or a vector containing
+  // non-zero elements that are masked by undef elements in the compare
+  // constant.
+  auto *TrueValC = dyn_cast<Constant>(TrueVal);
+  if (TrueValC == nullptr ||
+      !match(FalseVal, m_c_Mul(m_Specific(X), m_Value(Y))) ||
+      !isa<Instruction>(FalseVal))
+    return nullptr;
+
+  auto *ZeroC = cast<Constant>(cast<Instruction>(CondVal)->getOperand(1));
+  auto *MergedC = Constant::mergeUndefsWith(TrueValC, ZeroC);
+  // If X is compared with 0 then TrueVal could be either zero or undef.
+  // m_Zero matches vectors containing some undef elements, but for scalars
+  // m_Undef should be used explicitly.
+  if (!match(MergedC, m_Zero()) && !match(MergedC, m_Undef()))
+    return nullptr;
+
+  auto *FalseValI = cast<Instruction>(FalseVal);
+  auto *FrY = IC.InsertNewInstBefore(new FreezeInst(Y, Y->getName() + ".fr"),
+                                     *FalseValI);
+  IC.replaceOperand(*FalseValI, FalseValI->getOperand(0) == Y ?
0 : 1, FrY); + return IC.replaceInstUsesWith(SI, FalseValI); +} + /// Transform patterns such as (a > b) ? a - b : 0 into usub.sat(a, b). /// There are 8 commuted/swapped variants of this pattern. /// TODO: Also support a - UMIN(a,b) patterns. @@ -1229,8 +1289,8 @@ Instruction *InstCombinerImpl::foldSelectValueEquivalence(SelectInst &Sel, // Iff -C1 s<= C2 s<= C0-C1 // Also ULT predicate can also be UGT iff C0 != -1 (+invert result) // SLT predicate can also be SGT iff C2 != INT_MAX (+invert res.) -static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, - InstCombiner::BuilderTy &Builder) { +static Value *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, + InstCombiner::BuilderTy &Builder) { Value *X = Sel0.getTrueValue(); Value *Sel1 = Sel0.getFalseValue(); @@ -1238,36 +1298,42 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // Said condition must be one-use. if (!Cmp0.hasOneUse()) return nullptr; + ICmpInst::Predicate Pred0 = Cmp0.getPredicate(); Value *Cmp00 = Cmp0.getOperand(0); Constant *C0; if (!match(Cmp0.getOperand(1), m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C0)))) return nullptr; - // Canonicalize Cmp0 into the form we expect. + + if (!isa<SelectInst>(Sel1)) { + Pred0 = ICmpInst::getInversePredicate(Pred0); + std::swap(X, Sel1); + } + + // Canonicalize Cmp0 into ult or uge. // FIXME: we shouldn't care about lanes that are 'undef' in the end? - switch (Cmp0.getPredicate()) { + switch (Pred0) { case ICmpInst::Predicate::ICMP_ULT: + case ICmpInst::Predicate::ICMP_UGE: + // Although icmp ult %x, 0 is an unusual thing to try and should generally + // have been simplified, it does not verify with undef inputs so ensure we + // are not in a strange state. + if (!match(C0, m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getZero(C0->getType()->getScalarSizeInBits())))) + return nullptr; break; // Great! case ICmpInst::Predicate::ICMP_ULE: - // We'd have to increment C0 by one, and for that it must not have all-ones - // element, but then it would have been canonicalized to 'ult' before - // we get here. So we can't do anything useful with 'ule'. - return nullptr; case ICmpInst::Predicate::ICMP_UGT: - // We want to canonicalize it to 'ult', so we'll need to increment C0, - // which again means it must not have any all-ones elements. + // We want to canonicalize it to 'ult' or 'uge', so we'll need to increment + // C0, which again means it must not have any all-ones elements. if (!match(C0, - m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_NE, - APInt::getAllOnesValue( - C0->getType()->getScalarSizeInBits())))) + m_SpecificInt_ICMP( + ICmpInst::Predicate::ICMP_NE, + APInt::getAllOnes(C0->getType()->getScalarSizeInBits())))) return nullptr; // Can't do, have all-ones element[s]. C0 = InstCombiner::AddOne(C0); - std::swap(X, Sel1); break; - case ICmpInst::Predicate::ICMP_UGE: - // The only way we'd get this predicate if this `icmp` has extra uses, - // but then we won't be able to do this fold. - return nullptr; default: return nullptr; // Unknown predicate. } @@ -1277,11 +1343,16 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!Sel1->hasOneUse()) return nullptr; + // If the types do not match, look through any truncs to the underlying + // instruction. 
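// Aside: the ule/ugt canonicalization above relies on the identity
// (x u> C) == (x u>= C + 1), which holds only when C has no all-ones
// element (an all-ones C + 1 would wrap to 0). A quick check in 8-bit
// terms, excluding the all-ones value:
#include <cassert>

int main() {
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned C = 0; C < 255; ++C) // C == 255 excluded: mirrors the all-ones restriction
      assert((X > C) == (X >= C + 1));
  return 0;
}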
+ if (Cmp00->getType() != X->getType() && X->hasOneUse()) + match(X, m_TruncOrSelf(m_Value(X))); + // We now can finish matching the condition of the outermost select: // it should either be the X itself, or an addition of some constant to X. Constant *C1; if (Cmp00 == X) - C1 = ConstantInt::getNullValue(Sel0.getType()); + C1 = ConstantInt::getNullValue(X->getType()); else if (!match(Cmp00, m_Add(m_Specific(X), m_CombineAnd(m_AnyIntegralConstant(), m_Constant(C1))))) @@ -1335,6 +1406,8 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, // The thresholds of this clamp-like pattern. auto *ThresholdLowIncl = ConstantExpr::getNeg(C1); auto *ThresholdHighExcl = ConstantExpr::getSub(C0, C1); + if (Pred0 == ICmpInst::Predicate::ICMP_UGE) + std::swap(ThresholdLowIncl, ThresholdHighExcl); // The fold has a precondition 1: C2 s>= ThresholdLow auto *Precond1 = ConstantExpr::getICmp(ICmpInst::Predicate::ICMP_SGE, C2, @@ -1347,15 +1420,29 @@ static Instruction *canonicalizeClampLike(SelectInst &Sel0, ICmpInst &Cmp0, if (!match(Precond2, m_One())) return nullptr; + // If we are matching from a truncated input, we need to sext the + // ReplacementLow and ReplacementHigh values. Only do the transform if they + // are free to extend due to being constants. + if (X->getType() != Sel0.getType()) { + Constant *LowC, *HighC; + if (!match(ReplacementLow, m_ImmConstant(LowC)) || + !match(ReplacementHigh, m_ImmConstant(HighC))) + return nullptr; + ReplacementLow = ConstantExpr::getSExt(LowC, X->getType()); + ReplacementHigh = ConstantExpr::getSExt(HighC, X->getType()); + } + // All good, finally emit the new pattern. Value *ShouldReplaceLow = Builder.CreateICmpSLT(X, ThresholdLowIncl); Value *ShouldReplaceHigh = Builder.CreateICmpSGE(X, ThresholdHighExcl); Value *MaybeReplacedLow = Builder.CreateSelect(ShouldReplaceLow, ReplacementLow, X); - Instruction *MaybeReplacedHigh = - SelectInst::Create(ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); - return MaybeReplacedHigh; + // Create the final select. If we looked through a truncate above, we will + // need to retruncate the result. + Value *MaybeReplacedHigh = Builder.CreateSelect( + ShouldReplaceHigh, ReplacementHigh, MaybeReplacedLow); + return Builder.CreateTrunc(MaybeReplacedHigh, Sel0.getType()); } // If we have @@ -1446,8 +1533,8 @@ Instruction *InstCombinerImpl::foldSelectInstWithICmp(SelectInst &SI, if (Instruction *NewAbs = canonicalizeAbsNabs(SI, *ICI, *this)) return NewAbs; - if (Instruction *NewAbs = canonicalizeClampLike(SI, *ICI, Builder)) - return NewAbs; + if (Value *V = canonicalizeClampLike(SI, *ICI, Builder)) + return replaceInstUsesWith(SI, V); if (Instruction *NewSel = tryToReuseConstantFromSelectInComparison(SI, *ICI, *this)) @@ -1816,9 +1903,7 @@ foldOverflowingAddSubSelect(SelectInst &SI, InstCombiner::BuilderTy &Builder) { m_Value(TrueVal), m_Value(FalseVal)))) return false; - auto IsZeroOrOne = [](const APInt &C) { - return C.isNullValue() || C.isOneValue(); - }; + auto IsZeroOrOne = [](const APInt &C) { return C.isZero() || C.isOne(); }; auto IsMinMax = [&](Value *Min, Value *Max) { APInt MinVal = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); APInt MaxVal = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); @@ -2182,7 +2267,7 @@ static Instruction *moveAddAfterMinMax(SelectPatternFlavor SPF, Value *X, } /// Match a sadd_sat or ssub_sat which is using min/max to clamp the value. 
-Instruction *InstCombinerImpl::matchSAddSubSat(SelectInst &MinMax1) {
+Instruction *InstCombinerImpl::matchSAddSubSat(Instruction &MinMax1) {
   Type *Ty = MinMax1.getType();

   // We are looking for a tree of:
@@ -2212,23 +2297,14 @@
   if (!shouldChangeType(Ty->getScalarType()->getIntegerBitWidth(), NewBitWidth))
     return nullptr;

-  // Also make sure that the number of uses is as expected. The "3"s are for the
-  // the two items of min/max (the compare and the select).
-  if (MinMax2->hasNUsesOrMore(3) || AddSub->hasNUsesOrMore(3))
+  // Also make sure that the number of uses is as expected. The 3 is for the
+  // two items of the compare and the select, or 2 for a min/max intrinsic.
+  unsigned ExpUses = isa<IntrinsicInst>(MinMax1) ? 2 : 3;
+  if (MinMax2->hasNUsesOrMore(ExpUses) || AddSub->hasNUsesOrMore(ExpUses))
     return nullptr;

   // Create the new type (which can be a vector type)
   Type *NewTy = Ty->getWithNewBitWidth(NewBitWidth);
-  // Match the two extends from the add/sub
-  Value *A, *B;
-  if(!match(AddSub, m_BinOp(m_SExt(m_Value(A)), m_SExt(m_Value(B)))))
-    return nullptr;
-  // And check the incoming values are of a type smaller than or equal to the
-  // size of the saturation. Otherwise the higher bits can cause different
-  // results.
-  if (A->getType()->getScalarSizeInBits() > NewBitWidth ||
-      B->getType()->getScalarSizeInBits() > NewBitWidth)
-    return nullptr;

   Intrinsic::ID IntrinsicID;
   if (AddSub->getOpcode() == Instruction::Add)
@@ -2238,10 +2314,16 @@
   else
     return nullptr;

+  // The two operands of the add/sub must be nsw-truncatable to the NewTy. This
+  // is usually achieved via a sext from a smaller type.
+  if (ComputeMinSignedBits(AddSub->getOperand(0), 0, AddSub) > NewBitWidth ||
+      ComputeMinSignedBits(AddSub->getOperand(1), 0, AddSub) > NewBitWidth)
+    return nullptr;
+
   // Finally create and return the sat intrinsic, truncated to the new type
   Function *F = Intrinsic::getDeclaration(MinMax1.getModule(), IntrinsicID, NewTy);
-  Value *AT = Builder.CreateSExt(A, NewTy);
-  Value *BT = Builder.CreateSExt(B, NewTy);
+  Value *AT = Builder.CreateTrunc(AddSub->getOperand(0), NewTy);
+  Value *BT = Builder.CreateTrunc(AddSub->getOperand(1), NewTy);
   Value *Sat = Builder.CreateCall(F, {AT, BT});
   return CastInst::Create(Instruction::SExt, Sat, Ty);
 }
@@ -2432,7 +2514,7 @@ Instruction *InstCombinerImpl::foldVectorSelect(SelectInst &Sel) {
   unsigned NumElts = VecTy->getNumElements();
   APInt UndefElts(NumElts, 0);
-  APInt AllOnesEltMask(APInt::getAllOnesValue(NumElts));
+  APInt AllOnesEltMask(APInt::getAllOnes(NumElts));
   if (Value *V = SimplifyDemandedVectorElts(&Sel, AllOnesEltMask, UndefElts)) {
     if (V != &Sel)
       return replaceInstUsesWith(Sel, V);
@@ -2754,11 +2836,16 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
                                /* IsAnd */ IsAnd))
         return I;

-    if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal))
-      if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1))
+    if (auto *ICmp0 = dyn_cast<ICmpInst>(CondVal)) {
+      if (auto *ICmp1 = dyn_cast<ICmpInst>(Op1)) {
         if (auto *V = foldAndOrOfICmpsOfAndWithPow2(ICmp0, ICmp1, &SI, IsAnd,
                                                     /* IsLogical */ true))
           return replaceInstUsesWith(SI, V);
+
+        if (auto *V = foldEqOfParts(ICmp0, ICmp1, IsAnd))
+          return replaceInstUsesWith(SI, V);
+      }
+    }
   }

   // select (select a, true, b), c, false -> select a, c, false
@@ -2863,14 +2950,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) {
   }

   // Canonicalize select with fcmp to fabs().
-0.0 makes this tricky. We need - // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. We - // also require nnan because we do not want to unintentionally change the - // sign of a NaN value. + // fast-math-flags (nsz) or fsub with +0.0 (not fneg) for this to work. // (X <= +/-0.0) ? (0.0 - X) : X --> fabs(X) - Instruction *FSub; if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && match(TrueVal, m_FSub(m_PosZeroFP(), m_Specific(FalseVal))) && - match(TrueVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && (Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); return replaceInstUsesWith(SI, Fabs); @@ -2878,7 +2961,6 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X > +/-0.0) ? X : (0.0 - X) --> fabs(X) if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && match(FalseVal, m_FSub(m_PosZeroFP(), m_Specific(TrueVal))) && - match(FalseVal, m_Instruction(FSub)) && FSub->hasNoNaNs() && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_UGT)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); return replaceInstUsesWith(SI, Fabs); @@ -2886,11 +2968,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // With nnan and nsz: // (X < +/-0.0) ? -X : X --> fabs(X) // (X <= +/-0.0) ? -X : X --> fabs(X) - Instruction *FNeg; if (match(CondVal, m_FCmp(Pred, m_Specific(FalseVal), m_AnyZeroFP())) && - match(TrueVal, m_FNeg(m_Specific(FalseVal))) && - match(TrueVal, m_Instruction(FNeg)) && FNeg->hasNoNaNs() && - FNeg->hasNoSignedZeros() && SI.hasNoSignedZeros() && + match(TrueVal, m_FNeg(m_Specific(FalseVal))) && SI.hasNoSignedZeros() && (Pred == FCmpInst::FCMP_OLT || Pred == FCmpInst::FCMP_OLE || Pred == FCmpInst::FCMP_ULT || Pred == FCmpInst::FCMP_ULE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, FalseVal, &SI); @@ -2900,9 +2979,7 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { // (X > +/-0.0) ? X : -X --> fabs(X) // (X >= +/-0.0) ? 
X : -X --> fabs(X) if (match(CondVal, m_FCmp(Pred, m_Specific(TrueVal), m_AnyZeroFP())) && - match(FalseVal, m_FNeg(m_Specific(TrueVal))) && - match(FalseVal, m_Instruction(FNeg)) && FNeg->hasNoNaNs() && - FNeg->hasNoSignedZeros() && SI.hasNoSignedZeros() && + match(FalseVal, m_FNeg(m_Specific(TrueVal))) && SI.hasNoSignedZeros() && (Pred == FCmpInst::FCMP_OGT || Pred == FCmpInst::FCMP_OGE || Pred == FCmpInst::FCMP_UGT || Pred == FCmpInst::FCMP_UGE)) { Value *Fabs = Builder.CreateUnaryIntrinsic(Intrinsic::fabs, TrueVal, &SI); @@ -2920,6 +2997,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { return Add; if (Instruction *Or = foldSetClearBits(SI, Builder)) return Or; + if (Instruction *Mul = foldSelectZeroOrMul(SI, *this)) + return Mul; // Turn (select C, (op X, Y), (op X, Z)) -> (op X, (select C, Y, Z)) auto *TI = dyn_cast<Instruction>(TrueVal); @@ -2939,8 +3018,10 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (Gep->getNumOperands() != 2 || Gep->getPointerOperand() != Base || !Gep->hasOneUse()) return nullptr; - Type *ElementType = Gep->getResultElementType(); Value *Idx = Gep->getOperand(1); + if (isa<VectorType>(CondVal->getType()) && !isa<VectorType>(Idx->getType())) + return nullptr; + Type *ElementType = Gep->getResultElementType(); Value *NewT = Idx; Value *NewF = Constant::getNullValue(Idx->getType()); if (Swap) @@ -3188,9 +3269,9 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { if (!CondVal->getType()->isVectorTy() && !AC.assumptions().empty()) { KnownBits Known(1); computeKnownBits(CondVal, Known, 0, &SI); - if (Known.One.isOneValue()) + if (Known.One.isOne()) return replaceInstUsesWith(SI, TrueVal); - if (Known.Zero.isOneValue()) + if (Known.Zero.isOne()) return replaceInstUsesWith(SI, FalseVal); } @@ -3230,7 +3311,8 @@ Instruction *InstCombinerImpl::visitSelectInst(SelectInst &SI) { Value *Mask; if (match(TrueVal, m_Zero()) && match(FalseVal, m_MaskedLoad(m_Value(), m_Value(), m_Value(Mask), - m_CombineOr(m_Undef(), m_Zero())))) { + m_CombineOr(m_Undef(), m_Zero()))) && + (CondVal->getType() == Mask->getType())) { // We can remove the select by ensuring the load zeros all lanes the // select would have. We determine this by proving there is no overlap // between the load and select masks. 
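// Aside: a scalar sanity check of the fabs canonicalizations earlier in
// this function. For ordinary finite values, (X > +/-0.0) ? X : (0.0 - X)
// is exactly fabs(X); the nsz discussion in the surrounding code concerns
// -0.0 and NaN, which this standalone sketch deliberately sidesteps:
#include <cassert>
#include <cmath>

int main() {
  for (double X : {-2.5, -1.0, 0.0, 1.0, 3.5}) {
    double Sel = (X > 0.0) ? X : (0.0 - X);
    assert(Sel == std::fabs(X));
  }
  return 0;
}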
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp index ca5e473fdecb..06421d553915 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -41,7 +41,7 @@ bool canTryToConstantAddTwoShiftAmounts(Value *Sh0, Value *ShAmt0, Value *Sh1, (Sh0->getType()->getScalarSizeInBits() - 1) + (Sh1->getType()->getScalarSizeInBits() - 1); APInt MaximalRepresentableShiftAmount = - APInt::getAllOnesValue(ShAmt0->getType()->getScalarSizeInBits()); + APInt::getAllOnes(ShAmt0->getType()->getScalarSizeInBits()); return MaximalRepresentableShiftAmount.uge(MaximalPossibleTotalShiftAmount); } @@ -172,8 +172,8 @@ Value *InstCombinerImpl::reassociateShiftAmtsOfTwoSameDirectionShifts( // There are many variants to this pattern: // a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt // b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt -// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt -// d) (x & ((-1 << MaskShAmt) >> MaskShAmt)) << ShiftShAmt +// c) (x & (-1 l>> MaskShAmt)) << ShiftShAmt +// d) (x & ((-1 << MaskShAmt) l>> MaskShAmt)) << ShiftShAmt // e) ((x << MaskShAmt) l>> MaskShAmt) << ShiftShAmt // f) ((x << MaskShAmt) a>> MaskShAmt) << ShiftShAmt // All these patterns can be simplified to just: @@ -213,11 +213,11 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes()); // (~(-1 << maskNbits)) auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes()); - // (-1 >> MaskShAmt) - auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt)); - // ((-1 << MaskShAmt) >> MaskShAmt) + // (-1 l>> MaskShAmt) + auto MaskC = m_LShr(m_AllOnes(), m_Value(MaskShAmt)); + // ((-1 << MaskShAmt) l>> MaskShAmt) auto MaskD = - m_Shr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); + m_LShr(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_Deferred(MaskShAmt)); Value *X; Constant *NewMask; @@ -240,7 +240,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // that shall remain in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with final + // completely unknown. Replace the `undef` shift amounts with final // shift bitwidth to ensure that the value remains undef when creating the // subsequent shift op. SumOfShAmts = Constant::replaceUndefsWith( @@ -272,7 +272,7 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift, // shall be unset in the root value (OuterShift). // An extend of an undef value becomes zero because the high bits are never - // completely unknown. Replace the the `undef` shift amounts with negated + // completely unknown. Replace the `undef` shift amounts with negated // bitwidth of innermost shift to ensure that the value remains undef when // creating the subsequent shift op. unsigned WidestTyBitWidth = WidestTy->getScalarSizeInBits(); @@ -346,9 +346,8 @@ static Instruction *foldShiftOfShiftedLogic(BinaryOperator &I, // TODO: Remove the one-use check if the other logic operand (Y) is constant. 
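// Aside: a standalone check of pattern (e) in the masking list above. The
// shl/lshr masking pair is redundant whenever the final left shift is at
// least as large as the mask shift, because every bit the mask would have
// cleared is shifted out anyway:
#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t X : {0x12345678u, 0xffffffffu, 0x80000001u})
    for (unsigned M = 0; M < 8; ++M)    // MaskShAmt
      for (unsigned S = M; S < 12; ++S) // ShiftShAmt >= MaskShAmt
        assert((((X << M) >> M) << S) == (X << S));
  return 0;
}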
Value *X, *Y; auto matchFirstShift = [&](Value *V) { - BinaryOperator *BO; APInt Threshold(Ty->getScalarSizeInBits(), Ty->getScalarSizeInBits()); - return match(V, m_BinOp(BO)) && BO->getOpcode() == ShiftOpcode && + return match(V, m_BinOp(ShiftOpcode, m_Value(), m_Value())) && match(V, m_OneUse(m_Shift(m_Value(X), m_Constant(C0)))) && match(ConstantExpr::getAdd(C0, C1), m_SpecificInt_ICMP(ICmpInst::ICMP_ULT, Threshold)); @@ -661,23 +660,22 @@ static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { - bool isLeftShift = I.getOpcode() == Instruction::Shl; - const APInt *Op1C; if (!match(Op1, m_APInt(Op1C))) return nullptr; // See if we can propagate this shift into the input, this covers the trivial // cast of lshr(shl(x,c1),c2) as well as other more complex cases. + bool IsLeftShift = I.getOpcode() == Instruction::Shl; if (I.getOpcode() != Instruction::AShr && - canEvaluateShifted(Op0, Op1C->getZExtValue(), isLeftShift, *this, &I)) { + canEvaluateShifted(Op0, Op1C->getZExtValue(), IsLeftShift, *this, &I)) { LLVM_DEBUG( dbgs() << "ICE: GetShiftedValue propagating shift through expression" " to eliminate shift:\n IN: " << *Op0 << "\n SH: " << I << "\n"); return replaceInstUsesWith( - I, getShiftedValue(Op0, Op1C->getZExtValue(), isLeftShift, *this, DL)); + I, getShiftedValue(Op0, Op1C->getZExtValue(), IsLeftShift, *this, DL)); } // See if we can simplify any instructions used by the instruction whose sole @@ -686,202 +684,72 @@ Instruction *InstCombinerImpl::FoldShiftByConstant(Value *Op0, Constant *Op1, unsigned TypeBits = Ty->getScalarSizeInBits(); assert(!Op1C->uge(TypeBits) && "Shift over the type width should have been removed already"); + (void)TypeBits; if (Instruction *FoldedShift = foldBinOpIntoSelectOrPhi(I)) return FoldedShift; - // Fold shift2(trunc(shift1(x,c1)), c2) -> trunc(shift2(shift1(x,c1),c2)) - if (auto *TI = dyn_cast<TruncInst>(Op0)) { - // If 'shift2' is an ashr, we would have to get the sign bit into a funny - // place. Don't try to do this transformation in this case. Also, we - // require that the input operand is a shift-by-constant so that we have - // confidence that the shifts will get folded together. We could do this - // xform in more cases, but it is unlikely to be profitable. - const APInt *TrShiftAmt; - if (I.isLogicalShift() && - match(TI->getOperand(0), m_Shift(m_Value(), m_APInt(TrShiftAmt)))) { - auto *TrOp = cast<Instruction>(TI->getOperand(0)); - Type *SrcTy = TrOp->getType(); - - // Okay, we'll do this xform. Make the shift of shift. - Constant *ShAmt = ConstantExpr::getZExt(Op1, SrcTy); - // (shift2 (shift1 & 0x00FF), c2) - Value *NSh = Builder.CreateBinOp(I.getOpcode(), TrOp, ShAmt, I.getName()); - - // For logical shifts, the truncation has the effect of making the high - // part of the register be zeros. Emulate this by inserting an AND to - // clear the top bits as needed. This 'and' will usually be zapped by - // other xforms later if dead. - unsigned SrcSize = SrcTy->getScalarSizeInBits(); - Constant *MaskV = - ConstantInt::get(SrcTy, APInt::getLowBitsSet(SrcSize, TypeBits)); - - // The mask we constructed says what the trunc would do if occurring - // between the shifts. We want to know the effect *after* the second - // shift. We know that it is a logical shift by a constant, so adjust the - // mask as appropriate. 
- MaskV = ConstantExpr::get(I.getOpcode(), MaskV, ShAmt); - // shift1 & 0x00FF - Value *And = Builder.CreateAnd(NSh, MaskV, TI->getName()); - // Return the value truncated to the interesting size. - return new TruncInst(And, Ty); - } - } - - if (Op0->hasOneUse()) { - if (BinaryOperator *Op0BO = dyn_cast<BinaryOperator>(Op0)) { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - Value *V1; - const APInt *CC; - switch (Op0BO->getOpcode()) { - default: break; - case Instruction::Add: - case Instruction::And: - case Instruction::Or: - case Instruction::Xor: { - // These operators commute. - // Turn (Y + (X >> C)) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(1)->hasOneUse() && - match(Op0BO->getOperand(1), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), YS, V1, - Op0BO->getOperand(1)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (Y + ((X >> C) & CC)) << C -> ((X & (CC << C)) + (Y << C)) - Value *Op0BOOp1 = Op0BO->getOperand(1); - if (isLeftShift && Op0BOOp1->hasOneUse() && - match(Op0BOOp1, m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(0), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), YS, XM); - } - LLVM_FALLTHROUGH; - } - - case Instruction::Sub: { - // Turn ((X >> C) + Y) << C -> (X + (Y << C)) & (~0 << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), m_Shr(m_Value(V1), - m_Specific(Op1)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // (X + (Y << C)) - Value *X = Builder.CreateBinOp(Op0BO->getOpcode(), V1, YS, - Op0BO->getOperand(0)->getName()); - unsigned Op1Val = Op1C->getLimitedValue(TypeBits); - APInt Bits = APInt::getHighBitsSet(TypeBits, TypeBits - Op1Val); - Constant *Mask = ConstantInt::get(Ty, Bits); - return BinaryOperator::CreateAnd(X, Mask); - } - - // Turn (((X >> C)&CC) + Y) << C -> (X + (Y << C)) & (CC << C) - if (isLeftShift && Op0BO->getOperand(0)->hasOneUse() && - match(Op0BO->getOperand(0), - m_And(m_OneUse(m_Shr(m_Value(V1), m_Specific(Op1))), - m_APInt(CC)))) { - Value *YS = // (Y << C) - Builder.CreateShl(Op0BO->getOperand(1), Op1, Op0BO->getName()); - // X & (CC << C) - Value *XM = Builder.CreateAnd( - V1, ConstantExpr::getShl(ConstantInt::get(Ty, *CC), Op1), - V1->getName() + ".mask"); - return BinaryOperator::Create(Op0BO->getOpcode(), XM, YS); - } - - break; - } - } + if (!Op0->hasOneUse()) + return nullptr; - // If the operand is a bitwise operator with a constant RHS, and the - // shift is the only use, we can pull it out of the shift. 
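// Aside: the "pull the binop out of the shift" rewrites in this region all
// rest on the shift distributing over a binop with a constant operand,
// e.g. (X & C1) << C2 == (X << C2) & (C1 << C2). The and/or/xor cases are
// exact, and add behaves the same under shl because shifting left is
// multiplication by a power of two modulo 2^N. An 8-bit exhaustive check
// of the and instance:
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C2 = 3;
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned C1 = 0; C1 < 256; ++C1)
      assert((uint8_t)((X & C1) << C2) ==
             (uint8_t)((uint8_t)(X << C2) & (uint8_t)(C1 << C2)));
  return 0;
}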
- const APInt *Op0C; - if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { - if (canShiftBinOpWithConstantRHS(I, Op0BO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(1)), Op1); + if (auto *Op0BO = dyn_cast<BinaryOperator>(Op0)) { + // If the operand is a bitwise operator with a constant RHS, and the + // shift is the only use, we can pull it out of the shift. + const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(Op0BO->getOperand(1)), Op1); - Value *NewShift = + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); - NewShift->takeName(Op0BO); - - return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, - NewRHS); - } - } - - // If the operand is a subtract with a constant LHS, and the shift - // is the only use, we can pull it out of the shift. - // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) - if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && - match(Op0BO->getOperand(0), m_APInt(Op0C))) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(Op0BO->getOperand(0)), Op1); - - Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); NewShift->takeName(Op0BO); - return BinaryOperator::CreateSub(NewRHS, NewShift); + return BinaryOperator::Create(Op0BO->getOpcode(), NewShift, NewRHS); } } + } - // If we have a select that conditionally executes some binary operator, - // see if we can pull it the select and operator through the shift. - // - // For example, turning: - // shl (select C, (add X, C1), X), C2 - // Into: - // Y = shl X, C2 - // select C, (add Y, C1 << C2), Y - Value *Cond; - BinaryOperator *TBO; - Value *FalseVal; - if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), - m_Value(FalseVal)))) { - const APInt *C; - if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && - match(TBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, TBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(TBO->getOperand(1)), Op1); + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. 
+ // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(TBO->getOperand(1)), Op1); - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); - Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewOp, NewShift); - } + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); } + } - BinaryOperator *FBO; - Value *TrueVal; - if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), - m_OneUse(m_BinOp(FBO))))) { - const APInt *C; - if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && - match(FBO->getOperand(1), m_APInt(C)) && - canShiftBinOpWithConstantRHS(I, FBO)) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), - cast<Constant>(FBO->getOperand(1)), Op1); + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO)) { + Constant *NewRHS = ConstantExpr::get( + I.getOpcode(), cast<Constant>(FBO->getOperand(1)), Op1); - Value *NewShift = - Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); - Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, - NewRHS); - return SelectInst::Create(Cond, NewShift, NewOp); - } + Value *NewShift = Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); } } @@ -908,41 +776,41 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { Type *Ty = I.getType(); unsigned BitWidth = Ty->getScalarSizeInBits(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); - // shl (zext X), ShAmt --> zext (shl X, ShAmt) + // shl (zext X), C --> zext (shl X, C) // This is only valid if X would have zeros shifted out. 
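// Aside: a standalone check of the zext/shl commutation described above,
// with i8 -> i32 and C = 4. The fold is valid exactly when the bits the
// shift would push out of the narrow type are already zero, so only such
// values of X are tested:
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 4;
  for (unsigned X = 0; X < 16; ++X) {                // high C bits of X are 0
    uint32_t ShlOfZext = (uint32_t)X << C;           // shl (zext X), C
    uint32_t ZextOfShl = (uint8_t)((uint8_t)X << C); // zext (shl X, C)
    assert(ShlOfZext == ZextOfShl);
  }
  return 0;
}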
Value *X; if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) { unsigned SrcWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt < SrcWidth && - MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmt), 0, &I)) - return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); + if (ShAmtC < SrcWidth && + MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), 0, &I)) + return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty); } // (X >> C) << C --> X & (-1 << C) if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } - const APInt *ShOp1; - if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>?,exact C1) << C2 --> X << (C2 - C1) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + const APInt *C1; + if (match(Op0, m_Exact(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>?,exact C1) << C --> X << (C - C1) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); return NewShl; } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>?exact C1) << C2 --> X >>?exact (C1 - C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>?exact C1) << C --> X >>?exact (C1 - C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC); auto *NewShr = BinaryOperator::Create( cast<BinaryOperator>(Op0)->getOpcode(), X, ShiftDiff); NewShr->setIsExact(true); @@ -950,49 +818,135 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) { } } - if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(ShOp1)))) && - ShOp1->ult(BitWidth)) { - unsigned ShrAmt = ShOp1->getZExtValue(); - if (ShrAmt < ShAmt) { - // If C1 < C2: (X >>? C1) << C2 --> X << (C2 - C1) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShrAmt); + if (match(Op0, m_OneUse(m_Shr(m_Value(X), m_APInt(C1)))) && + C1->ult(BitWidth)) { + unsigned ShrAmt = C1->getZExtValue(); + if (ShrAmt < ShAmtC) { + // If C1 < C: (X >>? C1) << C --> (X << (C - C1)) & (-1 << C) + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShrAmt); auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); NewShl->setHasNoSignedWrap(I.hasNoSignedWrap()); Builder.Insert(NewShl); - APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - if (ShrAmt > ShAmt) { - // If C1 > C2: (X >>? C1) << C2 --> X >>? (C1 - C2) & (-1 << C2) - Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmt); + if (ShrAmt > ShAmtC) { + // If C1 > C: (X >>? C1) << C --> (X >>? 
(C1 - C)) & (-1 << C)
+        Constant *ShiftDiff = ConstantInt::get(Ty, ShrAmt - ShAmtC);
         auto *OldShr = cast<BinaryOperator>(Op0);
         auto *NewShr =
             BinaryOperator::Create(OldShr->getOpcode(), X, ShiftDiff);
         NewShr->setIsExact(OldShr->isExact());
         Builder.Insert(NewShr);
-        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt));
+        APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
         return BinaryOperator::CreateAnd(NewShr, ConstantInt::get(Ty, Mask));
       }
     }

-    if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) {
-      unsigned AmtSum = ShAmt + ShOp1->getZExtValue();
+    // Similar to above, but look through an intermediate trunc instruction.
+    BinaryOperator *Shr;
+    if (match(Op0, m_OneUse(m_Trunc(m_OneUse(m_BinOp(Shr))))) &&
+        match(Shr, m_Shr(m_Value(X), m_APInt(C1)))) {
+      // The larger shift direction survives through the transform.
+      unsigned ShrAmtC = C1->getZExtValue();
+      unsigned ShDiff = ShrAmtC > ShAmtC ? ShrAmtC - ShAmtC : ShAmtC - ShrAmtC;
+      Constant *ShiftDiffC = ConstantInt::get(X->getType(), ShDiff);
+      auto ShiftOpc = ShrAmtC > ShAmtC ? Shr->getOpcode() : Instruction::Shl;
+
+      // If C1 > C:
+      // (trunc (X >> C1)) << C --> (trunc (X >> (C1 - C))) & (-1 << C)
+      // If C > C1:
+      // (trunc (X >> C1)) << C --> (trunc (X << (C - C1))) & (-1 << C)
+      Value *NewShift = Builder.CreateBinOp(ShiftOpc, X, ShiftDiffC, "sh.diff");
+      Value *Trunc = Builder.CreateTrunc(NewShift, Ty, "tr.sh.diff");
+      APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmtC));
+      return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, Mask));
+    }
+
+    if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) {
+      unsigned AmtSum = ShAmtC + C1->getZExtValue();
       // Oversized shifts are simplified to zero in InstSimplify.
       if (AmtSum < BitWidth)
         // (X << C1) << C2 --> X << (C1 + C2)
         return BinaryOperator::CreateShl(X, ConstantInt::get(Ty, AmtSum));
     }

+    // If we have an opposite shift by the same amount, we may be able to
+    // reorder binops and shifts to eliminate math/logic.
+    auto isSuitableBinOpcode = [](Instruction::BinaryOps BinOpcode) {
+      switch (BinOpcode) {
+      default:
+        return false;
+      case Instruction::Add:
+      case Instruction::And:
+      case Instruction::Or:
+      case Instruction::Xor:
+      case Instruction::Sub:
+        // NOTE: Sub is not commutable and the transforms below may not be valid
+        //       when the shift-right is operand 1 (RHS) of the sub.
+        return true;
+      }
+    };
+    BinaryOperator *Op0BO;
+    if (match(Op0, m_OneUse(m_BinOp(Op0BO))) &&
+        isSuitableBinOpcode(Op0BO->getOpcode())) {
+      // Commute so shift-right is on LHS of the binop.
+ // (Y bop (X >> C)) << C -> ((X >> C) bop Y) << C + // (Y bop ((X >> C) & CC)) << C -> (((X >> C) & CC) bop Y) << C + Value *Shr = Op0BO->getOperand(0); + Value *Y = Op0BO->getOperand(1); + Value *X; + const APInt *CC; + if (Op0BO->isCommutative() && Y->hasOneUse() && + (match(Y, m_Shr(m_Value(), m_Specific(Op1))) || + match(Y, m_And(m_OneUse(m_Shr(m_Value(), m_Specific(Op1))), + m_APInt(CC))))) + std::swap(Shr, Y); + + // ((X >> C) bop Y) << C -> (X bop (Y << C)) & (~0 << C) + if (match(Shr, m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // (X bop (Y << C)) + Value *B = + Builder.CreateBinOp(Op0BO->getOpcode(), X, YS, Shr->getName()); + unsigned Op1Val = C->getLimitedValue(BitWidth); + APInt Bits = APInt::getHighBitsSet(BitWidth, BitWidth - Op1Val); + Constant *Mask = ConstantInt::get(Ty, Bits); + return BinaryOperator::CreateAnd(B, Mask); + } + + // (((X >> C) & CC) bop Y) << C -> (X & (CC << C)) bop (Y << C) + if (match(Shr, + m_OneUse(m_And(m_OneUse(m_Shr(m_Value(X), m_Specific(Op1))), + m_APInt(CC))))) { + // Y << C + Value *YS = Builder.CreateShl(Y, Op1, Op0BO->getName()); + // X & (CC << C) + Value *M = Builder.CreateAnd(X, ConstantInt::get(Ty, CC->shl(*C)), + X->getName() + ".mask"); + return BinaryOperator::Create(Op0BO->getOpcode(), M, YS); + } + } + + // (C1 - X) << C --> (C1 << C) - (X << C) + if (match(Op0, m_OneUse(m_Sub(m_APInt(C1), m_Value(X))))) { + Constant *NewLHS = ConstantInt::get(Ty, C1->shl(*C)); + Value *NewShift = Builder.CreateShl(X, Op1); + return BinaryOperator::CreateSub(NewLHS, NewShift); + } + // If the shifted-out value is known-zero, then this is a NUW shift. if (!I.hasNoUnsignedWrap() && - MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getHighBitsSet(BitWidth, ShAmtC), 0, + &I)) { I.setHasNoUnsignedWrap(); return &I; } // If the shifted-out value is all signbits, then this is a NSW shift. 
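// Aside: for bop = add and a logical shift right, the reordering above is
// the identity ((X >> C) + Y) << C == (X + (Y << C)) & (~0 << C): the low
// C bits of the left-hand side are zero by construction, and the mask
// clears the same bits on the right. An 8-bit exhaustive check:
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C = 3;
  for (unsigned X = 0; X < 256; ++X)
    for (unsigned Y = 0; Y < 256; ++Y) {
      uint8_t LHS = (uint8_t)((((uint8_t)X >> C) + Y) << C);
      uint8_t RHS = (uint8_t)((X + (uint8_t)(Y << C)) & (uint8_t)(0xFF << C));
      assert(LHS == RHS);
    }
  return 0;
}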
- if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmt) { + if (!I.hasNoSignedWrap() && ComputeNumSignBits(Op0, 0, &I) > ShAmtC) { I.setHasNoSignedWrap(); return &I; } @@ -1048,12 +1002,12 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); Type *Ty = I.getType(); - const APInt *ShAmtAPInt; - if (match(Op1, m_APInt(ShAmtAPInt))) { - unsigned ShAmt = ShAmtAPInt->getZExtValue(); + const APInt *C; + if (match(Op1, m_APInt(C))) { + unsigned ShAmtC = C->getZExtValue(); unsigned BitWidth = Ty->getScalarSizeInBits(); auto *II = dyn_cast<IntrinsicInst>(Op0); - if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmt && + if (II && isPowerOf2_32(BitWidth) && Log2_32(BitWidth) == ShAmtC && (II->getIntrinsicID() == Intrinsic::ctlz || II->getIntrinsicID() == Intrinsic::cttz || II->getIntrinsicID() == Intrinsic::ctpop)) { @@ -1067,78 +1021,81 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { } Value *X; - const APInt *ShOp1; - if (match(Op0, m_Shl(m_Value(X), m_APInt(ShOp1))) && ShOp1->ult(BitWidth)) { - if (ShOp1->ult(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - Constant *ShiftDiff = ConstantInt::get(Ty, ShAmt - ShlAmt); + const APInt *C1; + if (match(Op0, m_Shl(m_Value(X), m_APInt(C1))) && C1->ult(BitWidth)) { + if (C1->ult(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShAmtC - ShlAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X >>u (C2 - C1) + // (X <<nuw C1) >>u C --> X >>u (C - C1) auto *NewLShr = BinaryOperator::CreateLShr(X, ShiftDiff); NewLShr->setIsExact(I.isExact()); return NewLShr; } - // (X << C1) >>u C2 --> (X >>u (C2 - C1)) & (-1 >> C2) + // (X << C1) >>u C --> (X >>u (C - C1)) & (-1 >> C) Value *NewLShr = Builder.CreateLShr(X, ShiftDiff, "", I.isExact()); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewLShr, ConstantInt::get(Ty, Mask)); } - if (ShOp1->ugt(ShAmt)) { - unsigned ShlAmt = ShOp1->getZExtValue(); - Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmt - ShAmt); + if (C1->ugt(ShAmtC)) { + unsigned ShlAmtC = C1->getZExtValue(); + Constant *ShiftDiff = ConstantInt::get(Ty, ShlAmtC - ShAmtC); if (cast<BinaryOperator>(Op0)->hasNoUnsignedWrap()) { - // (X <<nuw C1) >>u C2 --> X <<nuw (C1 - C2) + // (X <<nuw C1) >>u C --> X <<nuw (C1 - C) auto *NewShl = BinaryOperator::CreateShl(X, ShiftDiff); NewShl->setHasNoUnsignedWrap(true); return NewShl; } - // (X << C1) >>u C2 --> X << (C1 - C2) & (-1 >> C2) + // (X << C1) >>u C --> X << (C1 - C) & (-1 >> C) Value *NewShl = Builder.CreateShl(X, ShiftDiff); - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(NewShl, ConstantInt::get(Ty, Mask)); } - assert(*ShOp1 == ShAmt); + assert(*C1 == ShAmtC); // (X << C) >>u C --> X & (-1 >>u C) - APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmt)); + APInt Mask(APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - assert(ShAmt < X->getType()->getScalarSizeInBits() && + assert(ShAmtC < X->getType()->getScalarSizeInBits() && "Big shift not simplified to zero?"); // lshr (zext iM X to iN), C --> zext 
(lshr X, C) to iN - Value *NewLShr = Builder.CreateLShr(X, ShAmt); + Value *NewLShr = Builder.CreateLShr(X, ShAmtC); return new ZExtInst(NewLShr, Ty); } - if (match(Op0, m_SExt(m_Value(X))) && - (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { - // Are we moving the sign bit to the low bit and widening with high zeros? + if (match(Op0, m_SExt(m_Value(X)))) { unsigned SrcTyBitWidth = X->getType()->getScalarSizeInBits(); - if (ShAmt == BitWidth - 1) { - // lshr (sext i1 X to iN), N-1 --> zext X to iN - if (SrcTyBitWidth == 1) - return new ZExtInst(X, Ty); + // lshr (sext i1 X to iN), C --> select (X, -1 >> C, 0) + if (SrcTyBitWidth == 1) { + auto *NewC = ConstantInt::get( + Ty, APInt::getLowBitsSet(BitWidth, BitWidth - ShAmtC)); + return SelectInst::Create(X, NewC, ConstantInt::getNullValue(Ty)); + } - // lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN - if (Op0->hasOneUse()) { + if ((!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType())) && + Op0->hasOneUse()) { + // Are we moving the sign bit to the low bit and widening with high + // zeros? lshr (sext iM X to iN), N-1 --> zext (lshr X, M-1) to iN + if (ShAmtC == BitWidth - 1) { Value *NewLShr = Builder.CreateLShr(X, SrcTyBitWidth - 1); return new ZExtInst(NewLShr, Ty); } - } - // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN - if (ShAmt == BitWidth - SrcTyBitWidth && Op0->hasOneUse()) { - // The new shift amount can't be more than the narrow source type. - unsigned NewShAmt = std::min(ShAmt, SrcTyBitWidth - 1); - Value *AShr = Builder.CreateAShr(X, NewShAmt); - return new ZExtInst(AShr, Ty); + // lshr (sext iM X to iN), N-M --> zext (ashr X, min(N-M, M-1)) to iN + if (ShAmtC == BitWidth - SrcTyBitWidth) { + // The new shift amount can't be more than the narrow source type. + unsigned NewShAmt = std::min(ShAmtC, SrcTyBitWidth - 1); + Value *AShr = Builder.CreateAShr(X, NewShAmt); + return new ZExtInst(AShr, Ty); + } } } Value *Y; - if (ShAmt == BitWidth - 1) { + if (ShAmtC == BitWidth - 1) { // lshr i32 or(X,-X), 31 --> zext (X != 0) if (match(Op0, m_OneUse(m_c_Or(m_Neg(m_Value(X)), m_Deferred(X))))) return new ZExtInst(Builder.CreateIsNotNull(X), Ty); @@ -1150,32 +1107,55 @@ Instruction *InstCombinerImpl::visitLShr(BinaryOperator &I) { // Check if a number is negative and odd: // lshr i32 (srem X, 2), 31 --> and (X >> 31), X if (match(Op0, m_OneUse(m_SRem(m_Value(X), m_SpecificInt(2))))) { - Value *Signbit = Builder.CreateLShr(X, ShAmt); + Value *Signbit = Builder.CreateLShr(X, ShAmtC); return BinaryOperator::CreateAnd(Signbit, X); } } - if (match(Op0, m_LShr(m_Value(X), m_APInt(ShOp1)))) { - unsigned AmtSum = ShAmt + ShOp1->getZExtValue(); + // (X >>u C1) >>u C --> X >>u (C1 + C) + if (match(Op0, m_LShr(m_Value(X), m_APInt(C1)))) { // Oversized shifts are simplified to zero in InstSimplify. + unsigned AmtSum = ShAmtC + C1->getZExtValue(); if (AmtSum < BitWidth) - // (X >>u C1) >>u C2 --> X >>u (C1 + C2) return BinaryOperator::CreateLShr(X, ConstantInt::get(Ty, AmtSum)); } + Instruction *TruncSrc; + if (match(Op0, m_OneUse(m_Trunc(m_Instruction(TruncSrc)))) && + match(TruncSrc, m_LShr(m_Value(X), m_APInt(C1)))) { + unsigned SrcWidth = X->getType()->getScalarSizeInBits(); + unsigned AmtSum = ShAmtC + C1->getZExtValue(); + + // If the combined shift fits in the source width: + // (trunc (X >>u C1)) >>u C --> and (trunc (X >>u (C1 + C)), MaskC + // + // If the first shift covers the number of bits truncated, then the + // mask instruction is eliminated (and so the use check is relaxed). 
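// Aside: a standalone check of the shift-through-trunc identity used
// below, with i32 truncated to i8, C1 = 4, and C = 2:
//   (trunc (X >>u 4)) >>u 2 == (trunc (X >>u 6)) & (0xFF >>u 2)
#include <cassert>
#include <cstdint>

int main() {
  const unsigned C1 = 4, C = 2;
  for (uint32_t X : {0u, 0xdeadbeefu, 0xffffffffu, 0x0f0f0f0fu}) {
    uint8_t LHS = (uint8_t)((uint8_t)(X >> C1) >> C);
    uint8_t RHS = (uint8_t)((uint8_t)(X >> (C1 + C)) & (0xFFu >> C));
    assert(LHS == RHS);
  }
  return 0;
}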
+ if (AmtSum < SrcWidth && + (TruncSrc->hasOneUse() || C1->uge(SrcWidth - BitWidth))) { + Value *SumShift = Builder.CreateLShr(X, AmtSum, "sum.shift"); + Value *Trunc = Builder.CreateTrunc(SumShift, Ty, I.getName()); + + // If the first shift does not cover the number of bits truncated, then + // we require a mask to get rid of high bits in the result. + APInt MaskC = APInt::getAllOnes(BitWidth).lshr(ShAmtC); + return BinaryOperator::CreateAnd(Trunc, ConstantInt::get(Ty, MaskC)); + } + } + // Look for a "splat" mul pattern - it replicates bits across each half of // a value, so a right shift is just a mask of the low bits: // lshr i32 (mul nuw X, Pow2+1), 16 --> and X, Pow2-1 // TODO: Generalize to allow more than just half-width shifts? const APInt *MulC; if (match(Op0, m_NUWMul(m_Value(X), m_APInt(MulC))) && - ShAmt * 2 == BitWidth && (*MulC - 1).isPowerOf2() && - MulC->logBase2() == ShAmt) + ShAmtC * 2 == BitWidth && (*MulC - 1).isPowerOf2() && + MulC->logBase2() == ShAmtC) return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, *MulC - 2)); // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && - MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { + MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmtC), 0, &I)) { I.setIsExact(); return &I; } @@ -1346,6 +1326,22 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { } } + // Prefer `-(x & 1)` over `(x << (bitwidth(x)-1)) a>> (bitwidth(x)-1)` + // as the pattern to splat the lowest bit. + // FIXME: iff X is already masked, we don't need the one-use check. + Value *X; + if (match(Op1, m_SpecificIntAllowUndef(BitWidth - 1)) && + match(Op0, m_OneUse(m_Shl(m_Value(X), + m_SpecificIntAllowUndef(BitWidth - 1))))) { + Constant *Mask = ConstantInt::get(Ty, 1); + // Retain the knowledge about the ignored lanes. + Mask = Constant::mergeUndefsWith( + Constant::mergeUndefsWith(Mask, cast<Constant>(Op1)), + cast<Constant>(cast<Instruction>(Op0)->getOperand(1))); + X = Builder.CreateAnd(X, Mask); + return BinaryOperator::CreateNeg(X); + } + if (Instruction *R = foldVariableSignZeroExtensionOfVariableHighBitExtract(I)) return R; @@ -1354,7 +1350,6 @@ Instruction *InstCombinerImpl::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateLShr(Op0, Op1); // ashr (xor %x, -1), %y --> xor (ashr %x, %y), -1 - Value *X; if (match(Op0, m_OneUse(m_Not(m_Value(X))))) { // Note that we must drop 'exact'-ness of the shift! // Note that we can't keep undef's in -1 vector constant! diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index 15b51ae8a5ee..e357a9da8b12 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -55,7 +55,7 @@ static bool ShrinkDemandedConstant(Instruction *I, unsigned OpNo, bool InstCombinerImpl::SimplifyDemandedInstructionBits(Instruction &Inst) { unsigned BitWidth = Inst.getType()->getScalarSizeInBits(); KnownBits Known(BitWidth); - APInt DemandedMask(APInt::getAllOnesValue(BitWidth)); + APInt DemandedMask(APInt::getAllOnes(BitWidth)); Value *V = SimplifyDemandedUseBits(&Inst, DemandedMask, Known, 0, &Inst); @@ -124,7 +124,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } Known.resetAll(); - if (DemandedMask.isNullValue()) // Not demanding any bits from V. + if (DemandedMask.isZero()) // Not demanding any bits from V. 
return UndefValue::get(VTy); if (Depth == MaxAnalysisRecursionDepth) @@ -274,8 +274,8 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // constant because that's a canonical 'not' op, and that is better for // combining, SCEV, and codegen. const APInt *C; - if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnesValue()) { - if ((*C | ~DemandedMask).isAllOnesValue()) { + if (match(I->getOperand(1), m_APInt(C)) && !C->isAllOnes()) { + if ((*C | ~DemandedMask).isAllOnes()) { // Force bits to 1 to create a 'not' op. I->setOperand(1, ConstantInt::getAllOnesValue(VTy)); return I; @@ -385,8 +385,26 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, Known = KnownBits::commonBits(LHSKnown, RHSKnown); break; } - case Instruction::ZExt: case Instruction::Trunc: { + // If we do not demand the high bits of a right-shifted and truncated value, + // then we may be able to truncate it before the shift. + Value *X; + const APInt *C; + if (match(I->getOperand(0), m_OneUse(m_LShr(m_Value(X), m_APInt(C))))) { + // The shift amount must be valid (not poison) in the narrow type, and + // it must not be greater than the high bits demanded of the result. + if (C->ult(I->getType()->getScalarSizeInBits()) && + C->ule(DemandedMask.countLeadingZeros())) { + // trunc (lshr X, C) --> lshr (trunc X), C + IRBuilderBase::InsertPointGuard Guard(Builder); + Builder.SetInsertPoint(I); + Value *Trunc = Builder.CreateTrunc(X, I->getType()); + return Builder.CreateLShr(Trunc, C->getZExtValue()); + } + } + } + LLVM_FALLTHROUGH; + case Instruction::ZExt: { unsigned SrcBitWidth = I->getOperand(0)->getType()->getScalarSizeInBits(); APInt InputDemandedMask = DemandedMask.zextOrTrunc(SrcBitWidth); @@ -516,8 +534,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, return I->getOperand(0); // We can't do this with the LHS for subtraction, unless we are only // demanding the LSB. - if ((I->getOpcode() == Instruction::Add || - DemandedFromOps.isOneValue()) && + if ((I->getOpcode() == Instruction::Add || DemandedFromOps.isOne()) && DemandedFromOps.isSubsetOf(LHSKnown.Zero)) return I->getOperand(1); @@ -615,7 +632,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, // always convert this into a logical shr, even if the shift amount is // variable. The low bit of the shift cannot be an input sign bit unless // the shift amount is >= the size of the datatype, which is undefined. - if (DemandedMask.isOneValue()) { + if (DemandedMask.isOne()) { // Perform the logical shift right. Instruction *NewVal = BinaryOperator::CreateLShr( I->getOperand(0), I->getOperand(1), I->getName()); @@ -743,7 +760,7 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, } case Instruction::URem: { KnownBits Known2(BitWidth); - APInt AllOnes = APInt::getAllOnesValue(BitWidth); + APInt AllOnes = APInt::getAllOnes(BitWidth); if (SimplifyDemandedBits(I, 0, AllOnes, Known2, Depth + 1) || SimplifyDemandedBits(I, 1, AllOnes, Known2, Depth + 1)) return I; @@ -829,6 +846,29 @@ Value *InstCombinerImpl::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, KnownBitsComputed = true; break; } + case Intrinsic::umax: { + // UMax(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-zero bit of C. 
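// Aside: a scalar illustration of the umax rule being added here. When
// every demanded bit sits above the highest set bit of C, umax(A, C)
// agrees with A on those bits: if A >= C the max is A outright, and if
// A < C both values have only zeros in the demanded positions.
#include <algorithm>
#include <cassert>

int main() {
  const unsigned C = 0b101;       // highest active bit: 2
  const unsigned Mask = ~0u << 3; // demanded bits start at bit 3
  for (unsigned A = 0; A < 1024; ++A)
    assert((std::max(A, C) & Mask) == (A & Mask));
  return 0;
}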
+ const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getActiveBits()) + return II->getArgOperand(0); + break; + } + case Intrinsic::umin: { + // UMin(A, C) == A if ... + // The lowest non-zero bit of DemandMask is higher than the highest + // non-one bit of C. + // This comes from using DeMorgans on the above umax example. + const APInt *C; + unsigned CTZ = DemandedMask.countTrailingZeros(); + if (match(II->getArgOperand(1), m_APInt(C)) && + CTZ >= C->getBitWidth() - C->countLeadingOnes()) + return II->getArgOperand(0); + break; + } default: { // Handle target specific intrinsics Optional<Value *> V = targetSimplifyDemandedUseBitsIntrinsic( @@ -1021,8 +1061,8 @@ Value *InstCombinerImpl::simplifyShrShlDemandedBits( Known.Zero.setLowBits(ShlAmt - 1); Known.Zero &= DemandedMask; - APInt BitMask1(APInt::getAllOnesValue(BitWidth)); - APInt BitMask2(APInt::getAllOnesValue(BitWidth)); + APInt BitMask1(APInt::getAllOnes(BitWidth)); + APInt BitMask2(APInt::getAllOnes(BitWidth)); bool isLshr = (Shr->getOpcode() == Instruction::LShr); BitMask1 = isLshr ? (BitMask1.lshr(ShrAmt) << ShlAmt) : @@ -1088,7 +1128,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); - APInt EltMask(APInt::getAllOnesValue(VWidth)); + APInt EltMask(APInt::getAllOnes(VWidth)); assert((DemandedElts & ~EltMask) == 0 && "Invalid DemandedElts!"); if (match(V, m_Undef())) { @@ -1097,7 +1137,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, return nullptr; } - if (DemandedElts.isNullValue()) { // If nothing is demanded, provide poison. + if (DemandedElts.isZero()) { // If nothing is demanded, provide poison. UndefElts = EltMask; return PoisonValue::get(V->getType()); } @@ -1107,7 +1147,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, if (auto *C = dyn_cast<Constant>(V)) { // Check if this is identity. If so, return 0 since we are not simplifying // anything. - if (DemandedElts.isAllOnesValue()) + if (DemandedElts.isAllOnes()) return nullptr; Type *EltTy = cast<VectorType>(V->getType())->getElementType(); @@ -1260,7 +1300,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Handle trivial case of a splat. Only check the first element of LHS // operand. if (all_of(Shuffle->getShuffleMask(), [](int Elt) { return Elt == 0; }) && - DemandedElts.isAllOnesValue()) { + DemandedElts.isAllOnes()) { if (!match(I->getOperand(1), m_Undef())) { I->setOperand(1, PoisonValue::get(I->getOperand(1)->getType())); MadeChange = true; @@ -1515,8 +1555,8 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // Subtlety: If we load from a pointer, the pointer must be valid // regardless of whether the element is demanded. Doing otherwise risks // segfaults which didn't exist in the original program. - APInt DemandedPtrs(APInt::getAllOnesValue(VWidth)), - DemandedPassThrough(DemandedElts); + APInt DemandedPtrs(APInt::getAllOnes(VWidth)), + DemandedPassThrough(DemandedElts); if (auto *CV = dyn_cast<ConstantVector>(II->getOperand(2))) for (unsigned i = 0; i < VWidth; i++) { Constant *CElt = CV->getAggregateElement(i); @@ -1568,7 +1608,7 @@ Value *InstCombinerImpl::SimplifyDemandedVectorElts(Value *V, // If we've proven all of the lanes undef, return an undef value. // TODO: Intersect w/demanded lanes - if (UndefElts.isAllOnesValue()) + if (UndefElts.isAllOnes()) return UndefValue::get(I->getType());; return MadeChange ? 
I : nullptr; diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index 32b15376f898..32e537897140 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -35,37 +35,46 @@ #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/InstCombine/InstCombiner.h" #include <cassert> #include <cstdint> #include <iterator> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumAggregateReconstructionsSimplified, "Number of aggregate reconstructions turned into reuse of the " "original aggregate"); /// Return true if the value is cheaper to scalarize than it is to leave as a -/// vector operation. IsConstantExtractIndex indicates whether we are extracting -/// one known element from a vector constant. +/// vector operation. If the extract index \p EI is a constant integer, then +/// some operations may be cheap to scalarize. /// /// FIXME: It's possible to create more instructions than previously existed. -static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { +static bool cheapToScalarize(Value *V, Value *EI) { + ConstantInt *CEI = dyn_cast<ConstantInt>(EI); + // If we can pick a scalar constant value out of a vector, that is free. if (auto *C = dyn_cast<Constant>(V)) - return IsConstantExtractIndex || C->getSplatValue(); + return CEI || C->getSplatValue(); + + if (CEI && match(V, m_Intrinsic<Intrinsic::experimental_stepvector>())) { + ElementCount EC = cast<VectorType>(V->getType())->getElementCount(); + // Index needs to be lower than the minimum size of the vector, because + // for a scalable vector the actual size is only known at run time. + return CEI->getValue().ult(EC.getKnownMinValue()); + } // An insertelement to the same constant index as our extract will simplify // to the scalar inserted element. An insertelement to a different constant // index is irrelevant to our extract. if (match(V, m_InsertElt(m_Value(), m_Value(), m_ConstantInt()))) - return IsConstantExtractIndex; + return CEI; if (match(V, m_OneUse(m_Load(m_Value())))) return true; @@ -75,14 +84,12 @@ static bool cheapToScalarize(Value *V, bool IsConstantExtractIndex) { Value *V0, *V1; if (match(V, m_OneUse(m_BinOp(m_Value(V0), m_Value(V1))))) - if (cheapToScalarize(V0, IsConstantExtractIndex) || - cheapToScalarize(V1, IsConstantExtractIndex) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) return true; CmpInst::Predicate UnusedPred; if (match(V, m_OneUse(m_Cmp(UnusedPred, m_Value(V0), m_Value(V1))))) - if (cheapToScalarize(V0, IsConstantExtractIndex) || - cheapToScalarize(V1, IsConstantExtractIndex) + if (cheapToScalarize(V0, EI) || cheapToScalarize(V1, EI)) return true; return false; @@ -119,7 +126,8 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, // and that it is a binary operation which is cheap to scalarize. // Otherwise, return nullptr.
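// For example (an illustrative sketch, not taken verbatim from this patch): //   %phi  = phi <4 x i32> [ %init, %entry ], [ %next, %loop ] //   %next = add <4 x i32> %phi, <i32 1, i32 1, i32 1, i32 1> //   %e    = extractelement <4 x i32> %phi, i64 0 // Such a recurrence can be rewritten as a scalar phi plus a scalar add of // the extracted lane, eliminating the vector ops when only one lane is used.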
if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) - !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true) + !(isa<BinaryOperator>(PHIUser)) || + !cheapToScalarize(PHIUser, EI.getIndexOperand()) return nullptr; // Create a scalar PHI node that will replace the vector PHI node @@ -170,24 +178,46 @@ Instruction *InstCombinerImpl::scalarizePHI(ExtractElementInst &EI, return &EI; } -static Instruction *foldBitcastExtElt(ExtractElementInst &Ext, - InstCombiner::BuilderTy &Builder, - bool IsBigEndian) { +Instruction *InstCombinerImpl::foldBitcastExtElt(ExtractElementInst &Ext) { Value *X; uint64_t ExtIndexC; if (!match(Ext.getVectorOperand(), m_BitCast(m_Value(X))) || - !X->getType()->isVectorTy() || !match(Ext.getIndexOperand(), m_ConstantInt(ExtIndexC))) return nullptr; + ElementCount NumElts = + cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); + Type *DestTy = Ext.getType(); + bool IsBigEndian = DL.isBigEndian(); + + // If we are casting an integer to a vector and extracting a portion, that is + // a shift-right and truncate. + // TODO: Allow FP dest type by casting the trunc to FP? + if (X->getType()->isIntegerTy() && DestTy->isIntegerTy() && + isDesirableIntType(X->getType()->getPrimitiveSizeInBits())) { + assert(isa<FixedVectorType>(Ext.getVectorOperand()->getType()) && + "Expected fixed vector type for bitcast from scalar integer"); + + // Big endian requires adjusting the extract index since the MSB is at index 0. + // LittleEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 X to i8 + // BigEndian: extelt (bitcast i32 X to v4i8), 0 -> trunc i32 (X >> 24) to i8 + if (IsBigEndian) + ExtIndexC = NumElts.getKnownMinValue() - 1 - ExtIndexC; + unsigned ShiftAmountC = ExtIndexC * DestTy->getPrimitiveSizeInBits(); + if (!ShiftAmountC || Ext.getVectorOperand()->hasOneUse()) { + Value *Lshr = Builder.CreateLShr(X, ShiftAmountC, "extelt.offset"); + return new TruncInst(Lshr, DestTy); + } + } + + if (!X->getType()->isVectorTy()) + return nullptr; + // If this extractelement is using a bitcast from a vector of the same number // of elements, see if we can find the source element from the source vector: // extelt (bitcast VecX), IndexC --> bitcast X[IndexC] auto *SrcTy = cast<VectorType>(X->getType()); - Type *DestTy = Ext.getType(); ElementCount NumSrcElts = SrcTy->getElementCount(); - ElementCount NumElts = - cast<VectorType>(Ext.getVectorOperandType())->getElementCount(); if (NumSrcElts == NumElts) if (Value *Elt = findScalarElement(X, ExtIndexC)) return new BitCastInst(Elt, DestTy); @@ -274,7 +304,7 @@ static APInt findDemandedEltsBySingleUser(Value *V, Instruction *UserInstr) { unsigned VWidth = cast<FixedVectorType>(V->getType())->getNumElements(); // Conservatively assume that all elements are needed. - APInt UsedElts(APInt::getAllOnesValue(VWidth)); + APInt UsedElts(APInt::getAllOnes(VWidth)); switch (UserInstr->getOpcode()) { case Instruction::ExtractElement: { @@ -322,11 +352,11 @@ static APInt findDemandedEltsByAllUsers(Value *V) { if (Instruction *I = dyn_cast<Instruction>(U.getUser())) { UnionUsedElts |= findDemandedEltsBySingleUser(V, I); } else { - UnionUsedElts = APInt::getAllOnesValue(VWidth); + UnionUsedElts = APInt::getAllOnes(VWidth); break; } - if (UnionUsedElts.isAllOnesValue()) + if (UnionUsedElts.isAllOnes()) break; } @@ -388,7 +418,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // If the input vector has multiple uses, simplify it based on a union // of all elements used.
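// For example (a hypothetical sketch): if one user extracts lane 0 and // another extracts lane 2 of the same 4-element source, the union of // demanded elements is {0, 2}, so lanes 1 and 3 may be simplified away.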
APInt DemandedElts = findDemandedEltsByAllUsers(SrcVec); - if (!DemandedElts.isAllOnesValue()) { + if (!DemandedElts.isAllOnes()) { APInt UndefElts(NumElts, 0); if (Value *V = SimplifyDemandedVectorElts( SrcVec, DemandedElts, UndefElts, 0 /* Depth */, @@ -402,7 +432,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } } - if (Instruction *I = foldBitcastExtElt(EI, Builder, DL.isBigEndian())) + if (Instruction *I = foldBitcastExtElt(EI)) return I; // If there's a vector PHI feeding a scalar use through this extractelement @@ -415,7 +445,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { // TODO: Come up with an n-ary matcher that subsumes both unary and // binary matchers. UnaryOperator *UO; - if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, IndexC)) { + if (match(SrcVec, m_UnOp(UO)) && cheapToScalarize(SrcVec, Index)) { // extelt (unop X), Index --> unop (extelt X, Index) Value *X = UO->getOperand(0); Value *E = Builder.CreateExtractElement(X, Index); @@ -423,7 +453,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { } BinaryOperator *BO; - if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, IndexC)) { + if (match(SrcVec, m_BinOp(BO)) && cheapToScalarize(SrcVec, Index)) { // extelt (binop X, Y), Index --> binop (extelt X, Index), (extelt Y, Index) Value *X = BO->getOperand(0), *Y = BO->getOperand(1); Value *E0 = Builder.CreateExtractElement(X, Index); Value *E1 = Builder.CreateExtractElement(Y, Index); @@ -434,7 +464,7 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) { Value *X, *Y; CmpInst::Predicate Pred; if (match(SrcVec, m_Cmp(Pred, m_Value(X), m_Value(Y))) && - cheapToScalarize(SrcVec, IndexC)) { + cheapToScalarize(SrcVec, Index)) { // extelt (cmp X, Y), Index --> cmp (extelt X, Index), (extelt Y, Index) Value *E0 = Builder.CreateExtractElement(X, Index); Value *E1 = Builder.CreateExtractElement(Y, Index); @@ -651,8 +681,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, if (InsElt->hasOneUse() && isa<InsertElementInst>(InsElt->user_back())) return; - auto *WideVec = - new ShuffleVectorInst(ExtVecOp, PoisonValue::get(ExtVecType), ExtendMask); + auto *WideVec = new ShuffleVectorInst(ExtVecOp, ExtendMask); // Insert the new shuffle after the vector operand of the extract is defined // (as long as it's not a PHI) or at the start of the basic block of the @@ -913,7 +942,7 @@ Instruction *InstCombinerImpl::foldAggregateConstructionIntoAggregateReuse( "We don't store nullptr in SourceAggregate!"); assert((Describe(SourceAggregate) == AggregateDescription::Found) == (I.index() != 0) && - "SourceAggregate should be valid after the the first element,"); + "SourceAggregate should be valid after the first element,"); // For this element, is there a plausible source aggregate? // FIXME: we could special-case undef element, IFF we know that in the @@ -1179,7 +1208,7 @@ static Instruction *foldInsSequenceIntoSplat(InsertElementInst &InsElt) { if (!ElementPresent[i]) Mask[i] = -1; - return new ShuffleVectorInst(FirstIE, PoisonVec, Mask); + return new ShuffleVectorInst(FirstIE, Mask); } /// Try to fold an insert element into an existing splat shuffle by changing @@ -1208,15 +1237,15 @@ static Instruction *foldInsEltIntoSplat(InsertElementInst &InsElt) { // Replace the shuffle mask element at the index of this insert with a zero.
// For example: - // inselt (shuf (inselt undef, X, 0), undef, <0,undef,0,undef>), X, 1 - // --> shuf (inselt undef, X, 0), undef, <0,0,0,undef> + // inselt (shuf (inselt undef, X, 0), _, <0,undef,0,undef>), X, 1 + // --> shuf (inselt undef, X, 0), poison, <0,0,0,undef> unsigned NumMaskElts = cast<FixedVectorType>(Shuf->getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts); for (unsigned i = 0; i != NumMaskElts; ++i) NewMask[i] = i == IdxC ? 0 : Shuf->getMaskValue(i); - return new ShuffleVectorInst(Op0, UndefValue::get(Op0->getType()), NewMask); + return new ShuffleVectorInst(Op0, NewMask); } /// Try to fold an extract+insert element into an existing identity shuffle by @@ -1348,6 +1377,10 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { NewShufElts[I] = ShufConstVec->getAggregateElement(I); NewMaskElts[I] = Mask[I]; } + + // Bail if we failed to find an element. + if (!NewShufElts[I]) + return nullptr; } // Create new operands for a shuffle that includes the constant of the @@ -1399,6 +1432,41 @@ static Instruction *foldConstantInsEltIntoShuffle(InsertElementInst &InsElt) { return nullptr; } +/// If both the base vector and the inserted element are extended from the same +/// type, do the insert element in the narrow source type followed by extend. +/// TODO: This can be extended to include other cast opcodes, but particularly +/// if we create a wider insertelement, make sure codegen is not harmed. +static Instruction *narrowInsElt(InsertElementInst &InsElt, + InstCombiner::BuilderTy &Builder) { + // We are creating a vector extend. If the original vector extend has another + // use, that would mean we end up with 2 vector extends, so avoid that. + // TODO: We could ease the use-clause to "if at least one op has one use" + // (assuming that the source types match - see next TODO comment). + Value *Vec = InsElt.getOperand(0); + if (!Vec->hasOneUse()) + return nullptr; + + Value *Scalar = InsElt.getOperand(1); + Value *X, *Y; + CastInst::CastOps CastOpcode; + if (match(Vec, m_FPExt(m_Value(X))) && match(Scalar, m_FPExt(m_Value(Y)))) + CastOpcode = Instruction::FPExt; + else if (match(Vec, m_SExt(m_Value(X))) && match(Scalar, m_SExt(m_Value(Y)))) + CastOpcode = Instruction::SExt; + else if (match(Vec, m_ZExt(m_Value(X))) && match(Scalar, m_ZExt(m_Value(Y)))) + CastOpcode = Instruction::ZExt; + else + return nullptr; + + // TODO: We can allow mismatched types by creating an intermediate cast. 
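+  // E.g. (an illustrative sketch): if Vec is "fpext <2 x float> %x" and
+  // Scalar is "fpext half %y", the source element types differ (float vs.
+  // half), so an intermediate half-to-float cast would be needed first.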
+ if (X->getType()->getScalarType() != Y->getType()) + return nullptr; + + // inselt (ext X), (ext Y), Index --> ext (inselt X, Y, Index) + Value *NewInsElt = Builder.CreateInsertElement(X, Y, InsElt.getOperand(2)); + return CastInst::Create(CastOpcode, NewInsElt, InsElt.getType()); +} + Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { Value *VecOp = IE.getOperand(0); Value *ScalarOp = IE.getOperand(1); @@ -1495,7 +1563,7 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (auto VecTy = dyn_cast<FixedVectorType>(VecOp->getType())) { unsigned VWidth = VecTy->getNumElements(); APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&IE, AllOnesEltMask, UndefElts)) { if (V != &IE) return replaceInstUsesWith(IE, V); @@ -1518,6 +1586,9 @@ Instruction *InstCombinerImpl::visitInsertElementInst(InsertElementInst &IE) { if (Instruction *IdentityShuf = foldInsEltIntoIdentityShuffle(IE)) return IdentityShuf; + if (Instruction *Ext = narrowInsElt(IE, Builder)) + return Ext; + return nullptr; } @@ -1924,8 +1995,8 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, // Splat from element 0. Any mask element that is undefined remains undefined. // For example: - // shuf (inselt undef, X, 2), undef, <2,2,undef> - // --> shuf (inselt undef, X, 0), undef, <0,0,undef> + // shuf (inselt undef, X, 2), _, <2,2,undef> + // --> shuf (inselt undef, X, 0), poison, <0,0,undef> unsigned NumMaskElts = cast<FixedVectorType>(Shuf.getType())->getNumElements(); SmallVector<int, 16> NewMask(NumMaskElts, 0); @@ -1933,7 +2004,7 @@ static Instruction *canonicalizeInsertSplat(ShuffleVectorInst &Shuf, if (Mask[i] == UndefMaskElem) NewMask[i] = Mask[i]; - return new ShuffleVectorInst(NewIns, UndefVec, NewMask); + return new ShuffleVectorInst(NewIns, NewMask); } /// Try to fold shuffles that are the equivalent of a vector select. @@ -2197,12 +2268,8 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, SmallVector<int, 16> Mask; Shuf.getShuffleMask(Mask); - // The shuffle must not change vector sizes. - // TODO: This restriction could be removed if the insert has only one use - // (because the transform would require a new length-changing shuffle). int NumElts = Mask.size(); - if (NumElts != (int)(cast<FixedVectorType>(V0->getType())->getNumElements())) - return nullptr; + int InpNumElts = cast<FixedVectorType>(V0->getType())->getNumElements(); // This is a specialization of a fold in SimplifyDemandedVectorElts. We may // not be able to handle it there if the insertelement has >1 use. @@ -2219,11 +2286,16 @@ static Instruction *foldShuffleWithInsert(ShuffleVectorInst &Shuf, if (match(V1, m_InsertElt(m_Value(X), m_Value(), m_ConstantInt(IdxC)))) { // Offset the index constant by the vector width because we are checking for // accesses to the 2nd vector input of the shuffle. - IdxC += NumElts; + IdxC += InpNumElts; // shuf ?, (inselt X, ?, IdxC), Mask --> shuf ?, X, Mask if (!is_contained(Mask, (int)IdxC)) return IC.replaceOperand(Shuf, 1, X); } + // For the rest of the transform, the shuffle must not change vector sizes. + // TODO: This restriction could be removed if the insert has only one use + // (because the transform would require a new length-changing shuffle). 
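+  // E.g. (illustrative): a shuffle producing <4 x i32> from <2 x i32> inputs
+  // changes the vector length, so folding the insert would itself require a
+  // new length-changing shuffle.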
+ if (NumElts != InpNumElts) + return nullptr; // shuffle (insert ?, Scalar, IndexC), V1, Mask --> insert V1, Scalar, IndexC' auto isShufflingScalarIntoOp1 = [&](Value *&Scalar, ConstantInt *&IndexC) { @@ -2413,16 +2485,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (LHS == RHS) { assert(!match(RHS, m_Undef()) && "Shuffle with 2 undef ops not simplified?"); - // Remap any references to RHS to use LHS. - SmallVector<int, 16> Elts; - for (unsigned i = 0; i != VWidth; ++i) { - // Propagate undef elements or force mask to LHS. - if (Mask[i] < 0) - Elts.push_back(UndefMaskElem); - else - Elts.push_back(Mask[i] % LHSWidth); - } - return new ShuffleVectorInst(LHS, UndefValue::get(RHS->getType()), Elts); + return new ShuffleVectorInst(LHS, createUnaryMask(Mask, LHSWidth)); } // shuffle undef, x, mask --> shuffle x, undef, mask' @@ -2444,7 +2507,7 @@ Instruction *InstCombinerImpl::visitShuffleVectorInst(ShuffleVectorInst &SVI) { return I; APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) return replaceInstUsesWith(SVI, V); diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp index 4e3b18e805ee..47b6dcb67a78 100644 --- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -100,7 +100,6 @@ #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" -#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> #include <cassert> @@ -109,11 +108,12 @@ #include <string> #include <utility> +#define DEBUG_TYPE "instcombine" +#include "llvm/Transforms/Utils/InstructionWorklist.h" + using namespace llvm; using namespace llvm::PatternMatch; -#define DEBUG_TYPE "instcombine" - STATISTIC(NumWorklistIterations, "Number of instruction combining iterations performed"); @@ -202,23 +202,37 @@ Value *InstCombinerImpl::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(&Builder, DL, GEP); } +/// Legal integers and common types are considered desirable. This is used to +/// avoid creating instructions with types that may not be supported well by +/// the backend. +/// NOTE: This treats i8, i16 and i32 specially because they are common +/// types in frontend languages. +bool InstCombinerImpl::isDesirableIntType(unsigned BitWidth) const { + switch (BitWidth) { + case 8: + case 16: + case 32: + return true; + default: + return DL.isLegalInteger(BitWidth); + } +} + /// Return true if it is desirable to convert an integer computation from a /// given bit width to a new bit width. /// We don't want to convert from a legal to an illegal type or from a smaller -/// to a larger illegal type. A width of '1' is always treated as a legal type -/// because i1 is a fundamental type in IR, and there are many specialized -/// optimizations for i1 types. Widths of 8, 16 or 32 are equally treated as +/// to a larger illegal type. A width of '1' is always treated as a desirable +/// type because i1 is a fundamental type in IR, and there are many specialized +/// optimizations for i1 types. Common/desirable widths are equally treated as /// legal to convert to, in order to open up more combining opportunities.
-/// NOTE: this treats i8, i16 and i32 specially, due to them being so common -/// from frontend languages. bool InstCombinerImpl::shouldChangeType(unsigned FromWidth, unsigned ToWidth) const { bool FromLegal = FromWidth == 1 || DL.isLegalInteger(FromWidth); bool ToLegal = ToWidth == 1 || DL.isLegalInteger(ToWidth); - // Convert to widths of 8, 16 or 32 even if they are not legal types. Only - // shrink types, to prevent infinite loops. - if (ToWidth < FromWidth && (ToWidth == 8 || ToWidth == 16 || ToWidth == 32)) + // Convert to desirable widths even if they are not legal types. + // Only shrink types, to prevent infinite loops. + if (ToWidth < FromWidth && isDesirableIntType(ToWidth)) return true; // If this is a legal integer from type, and the result would be an illegal @@ -359,7 +373,8 @@ Value *InstCombinerImpl::simplifyIntToPtrRoundTripCast(Value *Val) { PtrToInt->getSrcTy()->getPointerAddressSpace() && DL.getPointerTypeSizeInBits(PtrToInt->getSrcTy()) == DL.getTypeSizeInBits(PtrToInt->getDestTy())) { - return Builder.CreateBitCast(PtrToInt->getOperand(0), CastTy); + return CastInst::CreateBitOrPointerCast(PtrToInt->getOperand(0), CastTy, + "", PtrToInt); } } return nullptr; @@ -961,14 +976,14 @@ static Value *foldOperationIntoSelectOperand(Instruction &I, Value *SO, assert(canConstantFoldCallTo(II, cast<Function>(II->getCalledOperand())) && "Expected constant-foldable intrinsic"); Intrinsic::ID IID = II->getIntrinsicID(); - if (II->getNumArgOperands() == 1) + if (II->arg_size() == 1) return Builder.CreateUnaryIntrinsic(IID, SO); // This works for real binary ops like min/max (where we always expect the // constant operand to be canonicalized as op1) and unary ops with a bonus // constant argument like ctlz/cttz. // TODO: Handle non-commutative binary intrinsics as below for binops. - assert(II->getNumArgOperands() == 2 && "Expected binary intrinsic"); + assert(II->arg_size() == 2 && "Expected binary intrinsic"); assert(isa<Constant>(II->getArgOperand(1)) && "Expected constant operand"); return Builder.CreateBinaryIntrinsic(IID, SO, II->getArgOperand(1)); } @@ -1058,7 +1073,7 @@ Instruction *InstCombinerImpl::FoldOpIntoSelect(Instruction &Op, // Compare for equality including undefs as equal. auto *Cmp = ConstantExpr::getCompare(ICmpInst::ICMP_EQ, ConstA, ConstB); const APInt *C; - return match(Cmp, m_APIntAllowUndef(C)) && C->isOneValue(); + return match(Cmp, m_APIntAllowUndef(C)) && C->isOne(); }; if ((areLooselyEqual(TV, Op0) && areLooselyEqual(FV, Op1)) || @@ -1120,9 +1135,11 @@ Instruction *InstCombinerImpl::foldOpIntoPhi(Instruction &I, PHINode *PN) { BasicBlock *NonConstBB = nullptr; for (unsigned i = 0; i != NumPHIValues; ++i) { Value *InVal = PN->getIncomingValue(i); - // If I is a freeze instruction, count undef as a non-constant. - if (match(InVal, m_ImmConstant()) && - (!isa<FreezeInst>(I) || isGuaranteedNotToBeUndefOrPoison(InVal))) + // For non-freeze, require constant operand + // For freeze, require non-undef, non-poison operand + if (!isa<FreezeInst>(I) && match(InVal, m_ImmConstant())) + continue; + if (isa<FreezeInst>(I) && isGuaranteedNotToBeUndefOrPoison(InVal)) continue; if (isa<PHINode>(InVal)) return nullptr; // Itself a phi. @@ -1268,61 +1285,19 @@ Instruction *InstCombinerImpl::foldBinOpIntoSelectOrPhi(BinaryOperator &I) { /// specified offset. If so, fill them into NewIndices and return the resultant /// element type, otherwise return null. 
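/// For example (an illustrative sketch): with %T = type { i32, [2 x i64] } /// and Offset = 16, the indices would be [0, 1, 1] -- the second i64 of the /// array member -- assuming the usual layout with the array at byte offset 8.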
Type * -InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t Offset, +InstCombinerImpl::FindElementAtOffset(PointerType *PtrTy, int64_t IntOffset, SmallVectorImpl<Value *> &NewIndices) { Type *Ty = PtrTy->getElementType(); if (!Ty->isSized()) return nullptr; - // Start with the index over the outer type. Note that the type size - // might be zero (even if the offset isn't zero) if the indexed type - // is something like [0 x {int, int}] - Type *IndexTy = DL.getIndexType(PtrTy); - int64_t FirstIdx = 0; - if (int64_t TySize = DL.getTypeAllocSize(Ty)) { - FirstIdx = Offset/TySize; - Offset -= FirstIdx*TySize; - - // Handle hosts where % returns negative instead of values [0..TySize). - if (Offset < 0) { - --FirstIdx; - Offset += TySize; - assert(Offset >= 0); - } - assert((uint64_t)Offset < (uint64_t)TySize && "Out of range offset"); - } - - NewIndices.push_back(ConstantInt::get(IndexTy, FirstIdx)); - - // Index into the types. If we fail, set OrigBase to null. - while (Offset) { - // Indexing into tail padding between struct/array elements. - if (uint64_t(Offset * 8) >= DL.getTypeSizeInBits(Ty)) - return nullptr; - - if (StructType *STy = dyn_cast<StructType>(Ty)) { - const StructLayout *SL = DL.getStructLayout(STy); - assert(Offset < (int64_t)SL->getSizeInBytes() && - "Offset must stay within the indexed type"); - - unsigned Elt = SL->getElementContainingOffset(Offset); - NewIndices.push_back(ConstantInt::get(Type::getInt32Ty(Ty->getContext()), - Elt)); - - Offset -= SL->getElementOffset(Elt); - Ty = STy->getElementType(Elt); - } else if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) { - uint64_t EltSize = DL.getTypeAllocSize(AT->getElementType()); - assert(EltSize && "Cannot index into a zero-sized array"); - NewIndices.push_back(ConstantInt::get(IndexTy,Offset/EltSize)); - Offset %= EltSize; - Ty = AT->getElementType(); - } else { - // Otherwise, we can't index into the middle of this atomic type, bail. - return nullptr; - } - } + APInt Offset(DL.getIndexTypeSizeInBits(PtrTy), IntOffset); + SmallVector<APInt> Indices = DL.getGEPIndicesForOffset(Ty, Offset); + if (!Offset.isZero()) + return nullptr; + for (const APInt &Index : Indices) + NewIndices.push_back(Builder.getInt(Index)); return Ty; } @@ -1623,7 +1598,7 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { Value *XY = Builder.CreateBinOp(Opcode, X, Y); if (auto *BO = dyn_cast<BinaryOperator>(XY)) BO->copyIRFlags(&Inst); - return new ShuffleVectorInst(XY, UndefValue::get(XY->getType()), M); + return new ShuffleVectorInst(XY, M); }; // If both arguments of the binary operation are shuffles that use the same @@ -1754,25 +1729,20 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { Value *X; ArrayRef<int> MaskC; int SplatIndex; - BinaryOperator *BO; + Value *Y, *OtherOp; if (!match(LHS, m_OneUse(m_Shuffle(m_Value(X), m_Undef(), m_Mask(MaskC)))) || !match(MaskC, m_SplatOrUndefMask(SplatIndex)) || - X->getType() != Inst.getType() || !match(RHS, m_OneUse(m_BinOp(BO))) || - BO->getOpcode() != Opcode) + X->getType() != Inst.getType() || + !match(RHS, m_OneUse(m_BinOp(Opcode, m_Value(Y), m_Value(OtherOp))))) return nullptr; // FIXME: This may not be safe if the analysis allows undef elements. By // moving 'Y' before the splat shuffle, we are implicitly assuming // that it is not undef/poison at the splat index. 
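// E.g. (illustrative): for a splat of lane 2, hoisting Y above the shuffle // assumes lane 2 of Y is well-defined; a poison element there would // otherwise have been masked by the splat.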
- Value *Y, *OtherOp; - if (isSplatValue(BO->getOperand(0), SplatIndex)) { - Y = BO->getOperand(0); - OtherOp = BO->getOperand(1); - } else if (isSplatValue(BO->getOperand(1), SplatIndex)) { - Y = BO->getOperand(1); - OtherOp = BO->getOperand(0); - } else { + if (isSplatValue(OtherOp, SplatIndex)) { + std::swap(Y, OtherOp); + } else if (!isSplatValue(Y, SplatIndex)) { return nullptr; } @@ -1788,7 +1758,7 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { // dropped to be safe. if (isa<FPMathOperator>(R)) { R->copyFastMathFlags(&Inst); - R->andIRFlags(BO); + R->andIRFlags(RHS); } if (auto *NewInstBO = dyn_cast<BinaryOperator>(NewBO)) NewInstBO->copyIRFlags(R); @@ -1896,7 +1866,8 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { Type *GEPType = GEP.getType(); Type *GEPEltType = GEP.getSourceElementType(); bool IsGEPSrcEleScalable = isa<ScalableVectorType>(GEPEltType); - if (Value *V = SimplifyGEPInst(GEPEltType, Ops, SQ.getWithInstruction(&GEP))) + if (Value *V = SimplifyGEPInst(GEPEltType, Ops, GEP.isInBounds(), + SQ.getWithInstruction(&GEP))) return replaceInstUsesWith(GEP, V); // For vector geps, use the generic demanded vector support. @@ -1905,7 +1876,7 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (auto *GEPFVTy = dyn_cast<FixedVectorType>(GEPType)) { auto VWidth = GEPFVTy->getNumElements(); APInt UndefElts(VWidth, 0); - APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth)); + APInt AllOnesEltMask(APInt::getAllOnes(VWidth)); if (Value *V = SimplifyDemandedVectorElts(&GEP, AllOnesEltMask, UndefElts)) { if (V != &GEP) @@ -2117,10 +2088,12 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { // -- have to recreate %src & %gep // put NewSrc at same location as %src Builder.SetInsertPoint(cast<Instruction>(PtrOp)); - auto *NewSrc = cast<GetElementPtrInst>( - Builder.CreateGEP(GEPEltType, SO0, GO1, Src->getName())); - NewSrc->setIsInBounds(Src->isInBounds()); - auto *NewGEP = + Value *NewSrc = + Builder.CreateGEP(GEPEltType, SO0, GO1, Src->getName()); + // Propagate 'inbounds' if the new source was not constant-folded. + if (auto *NewSrcGEPI = dyn_cast<GetElementPtrInst>(NewSrc)) + NewSrcGEPI->setIsInBounds(Src->isInBounds()); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create(GEPEltType, NewSrc, {SO1}); NewGEP->setIsInBounds(GEP.isInBounds()); return NewGEP; @@ -2128,18 +2101,6 @@ Instruction *InstCombinerImpl::visitGetElementPtrInst(GetElementPtrInst &GEP) { } } } - - // Fold (gep(gep(Ptr,Idx0),Idx1) -> gep(Ptr,add(Idx0,Idx1)) - if (GO1->getType() == SO1->getType()) { - bool NewInBounds = GEP.isInBounds() && Src->isInBounds(); - auto *NewIdx = - Builder.CreateAdd(GO1, SO1, GEP.getName() + ".idx", - /*HasNUW*/ false, /*HasNSW*/ NewInBounds); - auto *NewGEP = GetElementPtrInst::Create( - GEPEltType, Src->getPointerOperand(), {NewIdx}); - NewGEP->setIsInBounds(NewInBounds); - return NewGEP; - } } // Note that if our source is a gep chain itself then we wait for that @@ -2647,6 +2608,13 @@ static bool isAllocSiteRemovable(Instruction *AI, Users.emplace_back(I); continue; } + + if (isReallocLikeFn(I, TLI, true)) { + Users.emplace_back(I); + Worklist.push_back(I); + continue; + } + return false; case Instruction::Store: { @@ -2834,15 +2802,33 @@ static Instruction *tryToMoveFreeBeforeNullTest(CallInst &FI, // At this point, we know that everything in FreeInstrBB can be moved // before TI. 
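// For instance (a sketch of the overall effect): "if (p) free(p);" becomes // "free(p);" followed by the now-empty conditional, relying on free(NULL) // being a no-op so later passes can delete the dead block.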
- for (BasicBlock::iterator It = FreeInstrBB->begin(), End = FreeInstrBB->end(); - It != End;) { - Instruction &Instr = *It++; + for (Instruction &Instr : llvm::make_early_inc_range(*FreeInstrBB)) { if (&Instr == FreeInstrBBTerminator) break; Instr.moveBefore(TI); } assert(FreeInstrBB->size() == 1 && "Only the branch instruction should remain"); + + // Now that we've moved the call to free before the NULL check, we have to + // remove any attributes on its parameter that imply it's non-null, because + // those attributes might have only been valid because of the NULL check, and + // we can get miscompiles if we keep them. This is conservative if non-null is + // also implied by something other than the NULL check, but it's guaranteed to + // be correct, and the conservativeness won't matter in practice, since the + // attributes are irrelevant for the call to free itself and the pointer + // shouldn't be used after the call. + AttributeList Attrs = FI.getAttributes(); + Attrs = Attrs.removeParamAttribute(FI.getContext(), 0, Attribute::NonNull); + Attribute Dereferenceable = Attrs.getParamAttr(0, Attribute::Dereferenceable); + if (Dereferenceable.isValid()) { + uint64_t Bytes = Dereferenceable.getDereferenceableBytes(); + Attrs = Attrs.removeParamAttribute(FI.getContext(), 0, + Attribute::Dereferenceable); + Attrs = Attrs.addDereferenceableOrNullParamAttr(FI.getContext(), 0, Bytes); + } + FI.setAttributes(Attrs); + return &FI; } @@ -2861,6 +2847,15 @@ Instruction *InstCombinerImpl::visitFree(CallInst &FI) { if (isa<ConstantPointerNull>(Op)) return eraseInstFromFunction(FI); + // If we had free(realloc(...)) with no intervening uses, then eliminate the + // realloc() entirely. + if (CallInst *CI = dyn_cast<CallInst>(Op)) { + if (CI->hasOneUse() && isReallocLikeFn(CI, &TLI, true)) { + return eraseInstFromFunction( + *replaceInstUsesWith(*CI, CI->getOperand(0))); + } + } + // If we optimize for code size, try to move the call to free before the null // test so that SimplifyCFG can remove the empty block and dead code // elimination can remove the branch. I.e., helps to turn something like: @@ -2947,7 +2942,7 @@ Instruction *InstCombinerImpl::visitUnconditionalBranchInst(BranchInst &BI) { auto GetLastSinkableStore = [](BasicBlock::iterator BBI) { auto IsNoopInstrForStoreMerging = [](BasicBlock::iterator BBI) { - return isa<DbgInfoIntrinsic>(BBI) || + return BBI->isDebugOrPseudoInst() || (isa<BitCastInst>(BBI) && BBI->getType()->isPointerTy()); }; @@ -3138,26 +3133,21 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { // checking for overflow. const APInt *C; if (match(WO->getRHS(), m_APInt(C))) { - // Compute the no-wrap range [X,Y) for LHS given RHS=C, then - // check for the inverted range using range offset trick (i.e. - // use a subtract to shift the range to bottom of either the - // signed or unsigned domain and then use a single compare to - // check range membership). + // Compute the no-wrap range for LHS given RHS=C, then construct an + // equivalent icmp, potentially using an offset. ConstantRange NWR = ConstantRange::makeExactNoWrapRegion(WO->getBinaryOp(), *C, WO->getNoWrapKind()); - APInt Min = WO->isSigned() ?
NWR.getSignedMin() : NWR.getUnsignedMin(); - NWR = NWR.subtract(Min); CmpInst::Predicate Pred; - APInt NewRHSC; - if (NWR.getEquivalentICmp(Pred, NewRHSC)) { - auto *OpTy = WO->getRHS()->getType(); - auto *NewLHS = Builder.CreateSub(WO->getLHS(), - ConstantInt::get(OpTy, Min)); - return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, - ConstantInt::get(OpTy, NewRHSC)); - } + APInt NewRHSC, Offset; + NWR.getEquivalentICmp(Pred, NewRHSC, Offset); + auto *OpTy = WO->getRHS()->getType(); + auto *NewLHS = WO->getLHS(); + if (Offset != 0) + NewLHS = Builder.CreateAdd(NewLHS, ConstantInt::get(OpTy, Offset)); + return new ICmpInst(ICmpInst::getInversePredicate(Pred), NewLHS, + ConstantInt::get(OpTy, NewRHSC)); } } } @@ -3183,9 +3173,7 @@ Instruction *InstCombinerImpl::visitExtractValueInst(ExtractValueInst &EV) { Instruction *NL = Builder.CreateLoad(EV.getType(), GEP); // Whatever aliasing information we had for the original load must also // hold for the smaller load, so propagate the annotations. - AAMDNodes Nodes; - L->getAAMetadata(Nodes); - NL->setAAMetadata(Nodes); + NL->setAAMetadata(L->getAAMetadata()); // Returning the load directly will cause the main loop to insert it in // the wrong spot, so use replaceInstUsesWith(). return replaceInstUsesWith(EV, NL); @@ -3568,8 +3556,14 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { // While we could change the other users of OrigOp to use freeze(OrigOp), that // potentially reduces their optimization potential, so let's only do this iff // the OrigOp is only used by the freeze. - if (!OrigOpInst || !OrigOpInst->hasOneUse() || isa<PHINode>(OrigOp) || - canCreateUndefOrPoison(dyn_cast<Operator>(OrigOp))) + if (!OrigOpInst || !OrigOpInst->hasOneUse() || isa<PHINode>(OrigOp)) + return nullptr; + + // We can't push the freeze through an instruction which can itself create + // poison. If the only source of new poison is flags, we can simply + // strip them (since we know the only use is the freeze and nothing can + // benefit from them). + if (canCreateUndefOrPoison(cast<Operator>(OrigOp), /*ConsiderFlags*/ false)) return nullptr; // If operand is guaranteed not to be poison, there is no need to add freeze @@ -3585,6 +3579,8 @@ InstCombinerImpl::pushFreezeToPreventPoisonFromPropagating(FreezeInst &OrigFI) { return nullptr; } + OrigOpInst->dropPoisonGeneratingFlags(); + // If all operands are guaranteed to be non-poison, we can drop freeze. if (!MaybePoisonOperand) return OrigOp; @@ -3668,7 +3664,7 @@ Instruction *InstCombinerImpl::visitFreeze(FreezeInst &I) { /// instruction past all of the instructions between it and the end of its /// block. static bool TryToSinkInstruction(Instruction *I, BasicBlock *DestBlock) { - assert(I->getSingleUndroppableUse() && "Invariants didn't hold!"); + assert(I->getUniqueUndroppableUser() && "Invariants didn't hold!"); BasicBlock *SrcBlock = I->getParent(); // Cannot move control-flow-involving, volatile loads, vaarg, etc. @@ -3822,51 +3818,71 @@ bool InstCombinerImpl::run() { // See if we can trivially sink this instruction to its user if we can // prove that the successor is not executed more frequently than our block. - if (EnableCodeSinking) - if (Use *SingleUse = I->getSingleUndroppableUse()) { - BasicBlock *BB = I->getParent(); - Instruction *UserInst = cast<Instruction>(SingleUse->getUser()); - BasicBlock *UserParent; + // Return the UserBlock if successful.
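+      // For instance (illustrative): a value computed in %entry whose only
+      // user lives in a successor having %entry as its unique predecessor
+      // can be sunk there, so paths that never use it do not compute it.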
+ auto getOptionalSinkBlockForInst = + [this](Instruction *I) -> Optional<BasicBlock *> { + if (!EnableCodeSinking) + return None; + auto *UserInst = cast_or_null<Instruction>(I->getUniqueUndroppableUser()); + if (!UserInst) + return None; - // Get the block the use occurs in. - if (PHINode *PN = dyn_cast<PHINode>(UserInst)) - UserParent = PN->getIncomingBlock(*SingleUse); - else - UserParent = UserInst->getParent(); + BasicBlock *BB = I->getParent(); + BasicBlock *UserParent = nullptr; - // Try sinking to another block. If that block is unreachable, then do - // not bother. SimplifyCFG should handle it. - if (UserParent != BB && DT.isReachableFromEntry(UserParent)) { - // See if the user is one of our successors that has only one - // predecessor, so that we don't have to split the critical edge. - bool ShouldSink = UserParent->getUniquePredecessor() == BB; - // Another option where we can sink is a block that ends with a - // terminator that does not pass control to other block (such as - // return or unreachable). In this case: - // - I dominates the User (by SSA form); - // - the User will be executed at most once. - // So sinking I down to User is always profitable or neutral. - if (!ShouldSink) { - auto *Term = UserParent->getTerminator(); - ShouldSink = isa<ReturnInst>(Term) || isa<UnreachableInst>(Term); - } - if (ShouldSink) { - assert(DT.dominates(BB, UserParent) && - "Dominance relation broken?"); - // Okay, the CFG is simple enough, try to sink this instruction. - if (TryToSinkInstruction(I, UserParent)) { - LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); - MadeIRChange = true; - // We'll add uses of the sunk instruction below, but since sinking - // can expose opportunities for its *operands*, add them to the - // worklist - for (Use &U : I->operands()) - if (Instruction *OpI = dyn_cast<Instruction>(U.get())) - Worklist.push(OpI); - } + // Special handling for Phi nodes - get the block the use occurs in. + if (PHINode *PN = dyn_cast<PHINode>(UserInst)) { + for (unsigned i = 0; i < PN->getNumIncomingValues(); i++) { + if (PN->getIncomingValue(i) == I) { + // Bail out if we have uses in different blocks. We don't do any + // sophisticated analysis (i.e., finding the NearestCommonDominator of these + // use blocks). + if (UserParent && UserParent != PN->getIncomingBlock(i)) + return None; + UserParent = PN->getIncomingBlock(i); } } + assert(UserParent && "expected to find user block!"); + } else + UserParent = UserInst->getParent(); + + // Try sinking to another block. If that block is unreachable, then do + // not bother. SimplifyCFG should handle it. + if (UserParent == BB || !DT.isReachableFromEntry(UserParent)) + return None; + + auto *Term = UserParent->getTerminator(); + // See if the user is one of our successors that has only one + // predecessor, so that we don't have to split the critical edge. + // Another option where we can sink is a block that ends with a + // terminator that does not pass control to another block (such as + // return or unreachable). In this case: + // - I dominates the User (by SSA form); + // - the User will be executed at most once. + // So sinking I down to User is always profitable or neutral.
+ if (UserParent->getUniquePredecessor() == BB || + (isa<ReturnInst>(Term) || isa<UnreachableInst>(Term))) { + assert(DT.dominates(BB, UserParent) && "Dominance relation broken?"); + return UserParent; } + return None; + }; + + auto OptBB = getOptionalSinkBlockForInst(I); + if (OptBB) { + auto *UserParent = *OptBB; + // Okay, the CFG is simple enough, try to sink this instruction. + if (TryToSinkInstruction(I, UserParent)) { + LLVM_DEBUG(dbgs() << "IC: Sink: " << *I << '\n'); + MadeIRChange = true; + // We'll add uses of the sunk instruction below, but since + // sinking can expose opportunities for its *operands*, add + // them to the worklist + for (Use &U : I->operands()) + if (Instruction *OpI = dyn_cast<Instruction>(U.get())) + Worklist.push(OpI); + } + } // Now that we have an instruction, try combining it to simplify it. Builder.SetInsertPoint(I); @@ -3994,13 +4010,13 @@ public: /// whose condition is a known constant, we only visit the reachable successors. static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI, - InstCombineWorklist &ICWorklist) { + InstructionWorklist &ICWorklist) { bool MadeIRChange = false; SmallPtrSet<BasicBlock *, 32> Visited; SmallVector<BasicBlock*, 256> Worklist; Worklist.push_back(&F.front()); - SmallVector<Instruction*, 128> InstrsForInstCombineWorklist; + SmallVector<Instruction *, 128> InstrsForInstructionWorklist; DenseMap<Constant *, Constant *> FoldedConstants; AliasScopeTracker SeenAliasScopes; @@ -4011,25 +4027,23 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, if (!Visited.insert(BB).second) continue; - for (BasicBlock::iterator BBI = BB->begin(), E = BB->end(); BBI != E; ) { - Instruction *Inst = &*BBI++; - + for (Instruction &Inst : llvm::make_early_inc_range(*BB)) { // ConstantProp instruction if trivially constant. - if (!Inst->use_empty() && - (Inst->getNumOperands() == 0 || isa<Constant>(Inst->getOperand(0)))) - if (Constant *C = ConstantFoldInstruction(Inst, DL, TLI)) { - LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << *Inst + if (!Inst.use_empty() && + (Inst.getNumOperands() == 0 || isa<Constant>(Inst.getOperand(0)))) + if (Constant *C = ConstantFoldInstruction(&Inst, DL, TLI)) { + LLVM_DEBUG(dbgs() << "IC: ConstFold to: " << *C << " from: " << Inst << '\n'); - Inst->replaceAllUsesWith(C); + Inst.replaceAllUsesWith(C); ++NumConstProp; - if (isInstructionTriviallyDead(Inst, TLI)) - Inst->eraseFromParent(); + if (isInstructionTriviallyDead(&Inst, TLI)) + Inst.eraseFromParent(); MadeIRChange = true; continue; } // See if we can constant fold its operands. - for (Use &U : Inst->operands()) { + for (Use &U : Inst.operands()) { if (!isa<ConstantVector>(U) && !isa<ConstantExpr>(U)) continue; @@ -4039,7 +4053,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, FoldRes = ConstantFoldConstant(C, DL, TLI); if (FoldRes != C) { - LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << *Inst + LLVM_DEBUG(dbgs() << "IC: ConstFold operand of: " << Inst << "\n Old = " << *C << "\n New = " << *FoldRes << '\n'); U = FoldRes; @@ -4050,9 +4064,9 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // Skip processing debug and pseudo intrinsics in InstCombine. Processing // these call instructions consumes a non-trivial amount of time and // provides no value for the optimization.
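// E.g. calls to llvm.dbg.value or llvm.pseudoprobe fall into this category.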
- if (!Inst->isDebugOrPseudoInst()) { - InstrsForInstCombineWorklist.push_back(Inst); - SeenAliasScopes.analyse(Inst); + if (!Inst.isDebugOrPseudoInst()) { + InstrsForInstructionWorklist.push_back(&Inst); + SeenAliasScopes.analyse(&Inst); } } @@ -4097,8 +4111,8 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, // of the function down. This jives well with the way that it adds all uses // of instructions to the worklist after doing a transformation, thus avoiding // some N^2 behavior in pathological cases. - ICWorklist.reserve(InstrsForInstCombineWorklist.size()); - for (Instruction *Inst : reverse(InstrsForInstCombineWorklist)) { + ICWorklist.reserve(InstrsForInstructionWorklist.size()); + for (Instruction *Inst : reverse(InstrsForInstructionWorklist)) { // DCE instruction if trivially dead. As we iterate in reverse program // order here, we will clean up whole chains of dead instructions. if (isInstructionTriviallyDead(Inst, TLI) || @@ -4118,7 +4132,7 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, } static bool combineInstructionsOverFunction( - Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, + Function &F, InstructionWorklist &Worklist, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, TargetTransformInfo &TTI, DominatorTree &DT, OptimizationRemarkEmitter &ORE, BlockFrequencyInfo *BFI, ProfileSummaryInfo *PSI, unsigned MaxIterations, LoopInfo *LI) {
