Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- lib/Transforms/InstCombine/InstCombineCompares.cpp | 212
1 file changed, 142 insertions(+), 70 deletions(-)
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 428f94bb5e93..bbafa9e9f468 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -230,7 +230,9 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
     return nullptr;
 
   uint64_t ArrayElementCount = Init->getType()->getArrayNumElements();
-  if (ArrayElementCount > 1024) return nullptr; // Don't blow up on huge arrays.
+  // Don't blow up on huge arrays.
+  if (ArrayElementCount > MaxArraySizeForCombine)
+    return nullptr;
 
   // There are many forms of this optimization we can handle, for now, just do
   // the simple index into a single-dimensional array.
@@ -1663,7 +1665,7 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp,
       (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) {
     // TODO: Is this a good transform for vectors? Wider types may reduce
     // throughput. Should this transform be limited (even for scalars) by using
-    // ShouldChangeType()?
+    // shouldChangeType()?
     if (!Cmp.getType()->isVectorTy()) {
       Type *WideType = W->getType();
       unsigned WideScalarBits = WideType->getScalarSizeInBits();
@@ -1792,6 +1794,15 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or,
                         ConstantInt::get(V->getType(), 1));
   }
 
+  // X | C == C --> X <=u C
+  // X | C != C --> X >u C
+  // iff C+1 is a power of 2 (C is a bitmask of the low bits)
+  if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) &&
+      (*C + 1).isPowerOf2()) {
+    Pred = (Pred == CmpInst::ICMP_EQ) ? CmpInst::ICMP_ULE : CmpInst::ICMP_UGT;
+    return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1));
+  }
+
   if (!Cmp.isEquality() || *C != 0 || !Or->hasOneUse())
     return nullptr;
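The mask fold added to foldICmpOrConstant above relies on C being a low-bit mask. As a sanity check outside of LLVM, here is a minimal standalone C++ sketch (plain integers, no LLVM APIs, not part of the patch) that exhaustively verifies the equivalence for all 8-bit values:

// For every C where C+1 is a power of 2 (C is a low-bit mask),
// (X | C) == C must agree with X <=u C, and (X | C) != C with X >u C.
#include <cassert>
#include <cstdio>

int main() {
  for (unsigned C = 0; C <= 0xFF; ++C) {
    if (((C + 1) & C) != 0)
      continue; // C+1 is not a power of 2; the fold does not apply
    for (unsigned X = 0; X <= 0xFF; ++X) {
      assert(((X | C) == C) == (X <= C));
      assert(((X | C) != C) == (X > C));
    }
  }
  std::printf("or-mask compare fold verified for all i8 cases\n");
  return 0;
}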
@@ -1914,61 +1925,89 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
   ICmpInst::Predicate Pred = Cmp.getPredicate();
   Value *X = Shl->getOperand(0);
-  if (Cmp.isEquality()) {
-    // If the shift is NUW, then it is just shifting out zeros, no need for an
-    // AND.
-    Constant *LShrC = ConstantInt::get(Shl->getType(), C->lshr(*ShiftAmt));
-    if (Shl->hasNoUnsignedWrap())
-      return new ICmpInst(Pred, X, LShrC);
-
-    // If the shift is NSW and we compare to 0, then it is just shifting out
-    // sign bits, no need for an AND either.
-    if (Shl->hasNoSignedWrap() && *C == 0)
-      return new ICmpInst(Pred, X, LShrC);
-
-    if (Shl->hasOneUse()) {
-      // Otherwise, strength reduce the shift into an and.
-      Constant *Mask = ConstantInt::get(Shl->getType(),
-          APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
-
-      Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
-      return new ICmpInst(Pred, And, LShrC);
+  Type *ShType = Shl->getType();
+
+  // NSW guarantees that we are only shifting out sign bits from the high bits,
+  // so we can ASHR the compare constant without needing a mask and eliminate
+  // the shift.
+  if (Shl->hasNoSignedWrap()) {
+    if (Pred == ICmpInst::ICMP_SGT) {
+      // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt)
+      APInt ShiftedC = C->ashr(*ShiftAmt);
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+    }
+    if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+      // This is the same code as the SGT case, but assert the pre-condition
+      // that is needed for this to work with equality predicates.
+      assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+             "Compare known true or false was not folded");
+      APInt ShiftedC = C->ashr(*ShiftAmt);
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+    }
+    if (Pred == ICmpInst::ICMP_SLT) {
+      // SLE is the same as above, but SLE is canonicalized to SLT, so convert:
+      // (X << S) <=s C is equiv to X <=s (C >> S) for all C
+      // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX
+      // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN
+      assert(!C->isMinSignedValue() && "Unexpected icmp slt");
+      APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1;
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+    }
+    // If this is a signed comparison to 0 and the shift is sign preserving,
+    // use the shift LHS operand instead; isSignTest may change 'Pred', so only
+    // do that if we're sure to not continue on in this function.
+    if (isSignTest(Pred, *C))
+      return new ICmpInst(Pred, X, Constant::getNullValue(ShType));
+  }
+
+  // NUW guarantees that we are only shifting out zero bits from the high bits,
+  // so we can LSHR the compare constant without needing a mask and eliminate
+  // the shift.
+  if (Shl->hasNoUnsignedWrap()) {
+    if (Pred == ICmpInst::ICMP_UGT) {
+      // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt)
+      APInt ShiftedC = C->lshr(*ShiftAmt);
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+    }
+    if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) {
+      // This is the same code as the UGT case, but assert the pre-condition
+      // that is needed for this to work with equality predicates.
+      assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C &&
+             "Compare known true or false was not folded");
+      APInt ShiftedC = C->lshr(*ShiftAmt);
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
+    }
+    if (Pred == ICmpInst::ICMP_ULT) {
+      // ULE is the same as above, but ULE is canonicalized to ULT, so convert:
+      // (X << S) <=u C is equiv to X <=u (C >> S) for all C
+      // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
+      // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
+      assert(C->ugt(0) && "ult 0 should have been eliminated");
+      APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1;
+      return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC));
     }
   }
 
-  // If this is a signed comparison to 0 and the shift is sign preserving,
-  // use the shift LHS operand instead; isSignTest may change 'Pred', so only
-  // do that if we're sure to not continue on in this function.
-  if (Shl->hasNoSignedWrap() && isSignTest(Pred, *C))
-    return new ICmpInst(Pred, X, Constant::getNullValue(X->getType()));
+  if (Cmp.isEquality() && Shl->hasOneUse()) {
+    // Strength-reduce the shift into an 'and'.
+    Constant *Mask = ConstantInt::get(
+        ShType,
+        APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue()));
+    Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
+    Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt));
+    return new ICmpInst(Pred, And, LShrC);
+  }
 
   // Otherwise, if this is a comparison of the sign bit, simplify to and/test.
   bool TrueIfSigned = false;
   if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) {
     // (X << 31) <s 0 --> (X & 1) != 0
     Constant *Mask = ConstantInt::get(
-        X->getType(),
+        ShType,
         APInt::getOneBitSet(TypeBits, TypeBits - ShiftAmt->getZExtValue() - 1));
     Value *And = Builder->CreateAnd(X, Mask, Shl->getName() + ".mask");
     return new ICmpInst(TrueIfSigned ? ICmpInst::ICMP_NE : ICmpInst::ICMP_EQ,
-                        And, Constant::getNullValue(And->getType()));
-  }
-
-  // When the shift is nuw and pred is >u or <=u, comparison only really happens
-  // in the pre-shifted bits. Since InstSimplify canonicalizes <=u into <u, the
-  // <=u case can be further converted to match <u (see below).
-  if (Shl->hasNoUnsignedWrap() &&
-      (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_ULT)) {
-    // Derivation for the ult case:
-    // (X << S) <=u C is equiv to X <=u (C >> S) for all C
-    // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u
-    // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0
-    assert((Pred != ICmpInst::ICMP_ULT || C->ugt(0)) &&
-           "Encountered `ult 0` that should have been eliminated by "
-           "InstSimplify.");
-    APInt ShiftedC = Pred == ICmpInst::ICMP_ULT ? (*C - 1).lshr(*ShiftAmt) + 1
-                                                : C->lshr(*ShiftAmt);
-    return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), ShiftedC));
+                        And, Constant::getNullValue(ShType));
   }
 
   // Transform (icmp pred iM (shl iM %v, N), C)
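The trickiest step in the foldICmpShlConstant rewrite above is the SLT derivation, which shifts the compare constant instead of masking the shifted value. Below is a standalone C++ sketch (not part of the patch; i8 semantics emulated with plain ints, and assuming arithmetic right shift for negative values, which C++20 guarantees) that checks it exhaustively:

// Checks: (X << S) <s C == X <s ((C - 1) >>s S) + 1 for 'shl nsw' on i8,
// for every C except SMIN (which is excluded by the assert in the patch).
#include <cassert>
#include <cstdio>

int main() {
  for (int S = 1; S <= 7; ++S) {
    for (int X = -128; X <= 127; ++X) {
      int Shifted = X * (1 << S); // X << S without signed-shift UB
      if (Shifted < -128 || Shifted > 127)
        continue; // this X would wrap: the shift is not 'nsw' on i8
      for (int C = -127; C <= 127; ++C) {
        int ShiftedC = ((C - 1) >> S) + 1;
        assert((Shifted < C) == (X < ShiftedC));
      }
    }
  }
  std::printf("shl-nsw slt fold verified for all i8 cases\n");
  return 0;
}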
@@ -1981,8 +2020,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp,
   if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt &&
       DL.isLegalInteger(TypeBits - Amt)) {
     Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt);
-    if (X->getType()->isVectorTy())
-      TruncTy = VectorType::get(TruncTy, X->getType()->getVectorNumElements());
+    if (ShType->isVectorTy())
+      TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements());
     Constant *NewC =
         ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt));
     return new ICmpInst(Pred, Builder->CreateTrunc(X, TruncTy), NewC);
@@ -2342,8 +2381,24 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
   // Fold icmp pred (add X, C2), C.
   Value *X = Add->getOperand(0);
   Type *Ty = Add->getType();
-  auto CR =
-      ConstantRange::makeExactICmpRegion(Cmp.getPredicate(), *C).subtract(*C2);
+  CmpInst::Predicate Pred = Cmp.getPredicate();
+
+  // If the add does not wrap, we can always adjust the compare by subtracting
+  // the constants. Equality comparisons are handled elsewhere. SGE/SLE are
+  // canonicalized to SGT/SLT.
+  if (Add->hasNoSignedWrap() &&
+      (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) {
+    bool Overflow;
+    APInt NewC = C->ssub_ov(*C2, Overflow);
+    // If there is overflow, the result must be true or false.
+    // TODO: Can we assert there is no overflow because InstSimplify always
+    // handles those cases?
+    if (!Overflow)
+      // icmp Pred (add nsw X, C2), C --> icmp Pred X, (C - C2)
+      return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC));
+  }
+
+  auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2);
   const APInt &Upper = CR.getUpper();
   const APInt &Lower = CR.getLower();
   if (Cmp.isSigned()) {
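The new early exit in foldICmpAddConstant above is plain constant re-association: if the add cannot wrap and C - C2 cannot overflow, comparing X + C2 against C is the same as comparing X against C - C2. A standalone C++ sketch (not part of the patch; i8 emulated with plain ints) checking both the SGT and SLT forms:

// Checks: icmp sgt/slt (add nsw X, C2), C == icmp sgt/slt X, (C - C2)
// whenever C - C2 fits in i8 (the ssub_ov no-overflow condition).
#include <cassert>
#include <cstdio>

int main() {
  for (int C2 = -128; C2 <= 127; ++C2) {
    for (int C = -128; C <= 127; ++C) {
      int NewC = C - C2;
      if (NewC < -128 || NewC > 127)
        continue; // ssub_ov would report overflow; the fold is skipped
      for (int X = -128; X <= 127; ++X) {
        int Sum = X + C2;
        if (Sum < -128 || Sum > 127)
          continue; // this X would wrap: the add is not 'nsw' for it
        assert((Sum > C) == (X > NewC));
        assert((Sum < C) == (X < NewC));
      }
    }
  }
  std::printf("add-nsw compare adjustment verified for all i8 cases\n");
  return 0;
}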
@@ -2364,16 +2419,14 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp,
   // X+C <u C2 -> (X & -C2) == C
   //   iff C & (C2-1) == 0
   //       C2 is a power of 2
-  if (Cmp.getPredicate() == ICmpInst::ICMP_ULT && C->isPowerOf2() &&
-      (*C2 & (*C - 1)) == 0)
+  if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0)
     return new ICmpInst(ICmpInst::ICMP_EQ, Builder->CreateAnd(X, -(*C)),
                         ConstantExpr::getNeg(cast<Constant>(Y)));
 
   // X+C >u C2 -> (X & ~C2) != C
   //   iff C & C2 == 0
   //       C2+1 is a power of 2
-  if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() &&
-      (*C2 & *C) == 0)
+  if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0)
     return new ICmpInst(ICmpInst::ICMP_NE, Builder->CreateAnd(X, ~(*C)),
                         ConstantExpr::getNeg(cast<Constant>(Y)));
 
@@ -2656,7 +2709,7 @@ Instruction *InstCombiner::foldICmpInstWithConstantNotInt(ICmpInst &I) {
     // block. If in the same block, we're encouraging jump threading. If
     // not, we are just pessimizing the code by making an i1 phi.
     if (LHSI->getParent() == I.getParent())
-      if (Instruction *NV = FoldOpIntoPhi(I))
+      if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
        return NV;
     break;
   case Instruction::Select: {
@@ -2767,12 +2820,6 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
     D = BO1->getOperand(1);
   }
 
-  // icmp (X+cst) < 0 --> X < -cst
-  if (NoOp0WrapProblem && ICmpInst::isSigned(Pred) && match(Op1, m_Zero()))
-    if (ConstantInt *RHSC = dyn_cast_or_null<ConstantInt>(B))
-      if (!RHSC->isMinValue(/*isSigned=*/true))
-        return new ICmpInst(Pred, A, ConstantExpr::getNeg(RHSC));
-
   // icmp (X+Y), X -> icmp Y, 0 for equalities or if there is no overflow.
   if ((A == Op1 || B == Op1) && NoOp0WrapProblem)
     return new ICmpInst(Pred, A == Op1 ? B : A,
@@ -2847,6 +2894,31 @@ Instruction *InstCombiner::foldICmpBinOp(ICmpInst &I) {
   if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_SLT && match(D, m_One()))
     return new ICmpInst(CmpInst::ICMP_SLE, Op0, C);
 
+  // TODO: The subtraction-related identities shown below also hold, but
+  // canonicalization from (X -nuw 1) to (X + -1) means that the combinations
+  // wouldn't happen even if they were implemented.
+  //
+  // icmp ult (X - 1), Y -> icmp ule X, Y
+  // icmp uge (X - 1), Y -> icmp ugt X, Y
+  // icmp ugt X, (Y - 1) -> icmp uge X, Y
+  // icmp ule X, (Y - 1) -> icmp ult X, Y
+
+  // icmp ule (X + 1), Y -> icmp ult X, Y
+  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_ULE && match(B, m_One()))
+    return new ICmpInst(CmpInst::ICMP_ULT, A, Op1);
+
+  // icmp ugt (X + 1), Y -> icmp uge X, Y
+  if (A && NoOp0WrapProblem && Pred == CmpInst::ICMP_UGT && match(B, m_One()))
+    return new ICmpInst(CmpInst::ICMP_UGE, A, Op1);
+
+  // icmp uge X, (Y + 1) -> icmp ugt X, Y
+  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_UGE && match(D, m_One()))
+    return new ICmpInst(CmpInst::ICMP_UGT, Op0, C);
+
+  // icmp ult X, (Y + 1) -> icmp ule X, Y
+  if (C && NoOp1WrapProblem && Pred == CmpInst::ICMP_ULT && match(D, m_One()))
+    return new ICmpInst(CmpInst::ICMP_ULE, Op0, C);
+
   // if C1 has greater magnitude than C2:
   //    icmp (X + C1), (Y + C2) -> icmp (X + C3), Y
   //    s.t. C3 = C1 - C2
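The four +1 identities added to foldICmpBinOp above all follow from ordinary integer arithmetic once the add is known not to wrap. A standalone C++ sketch (not part of the patch) verifying them exhaustively on i8, where 'add nuw X, 1' simply means X != 255:

// Checks, for X != 255 (so X + 1 is 'nuw' on i8):
//   (X + 1) <=u Y == X <u Y        (X + 1) >u Y == X >=u Y
//   Y >=u (X + 1) == Y >u X        Y <u (X + 1) == Y <=u X
#include <cassert>
#include <cstdio>

int main() {
  for (unsigned X = 0; X <= 254; ++X) {
    unsigned XPlus1 = X + 1;
    for (unsigned Y = 0; Y <= 255; ++Y) {
      assert((XPlus1 <= Y) == (X < Y));
      assert((XPlus1 > Y) == (X >= Y));
      assert((Y >= XPlus1) == (Y > X));
      assert((Y < XPlus1) == (Y <= X));
    }
  }
  std::printf("+1 compare identities verified for all i8 cases\n");
  return 0;
}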
@@ -3738,16 +3810,14 @@ static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth,
   // greater than the RHS must differ in a bit higher than these due to carry.
   case ICmpInst::ICMP_UGT: {
     unsigned trailingOnes = RHS.countTrailingOnes();
-    APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingOnes);
-    return ~lowBitsSet;
+    return APInt::getBitsSetFrom(BitWidth, trailingOnes);
   }
 
   // Similarly, for a ULT comparison, we don't care about the trailing zeros.
   // Any value less than the RHS must differ in a higher bit because of carries.
   case ICmpInst::ICMP_ULT: {
     unsigned trailingZeros = RHS.countTrailingZeros();
-    APInt lowBitsSet = APInt::getLowBitsSet(BitWidth, trailingZeros);
-    return ~lowBitsSet;
+    return APInt::getBitsSetFrom(BitWidth, trailingZeros);
   }
 
   default:
@@ -3887,7 +3957,7 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
   assert((SIOpd == 1 || SIOpd == 2) && "Invalid select operand!");
   if (isChainSelectCmpBranch(SI) && Icmp->getPredicate() == ICmpInst::ICMP_EQ) {
     BasicBlock *Succ = SI->getParent()->getTerminator()->getSuccessor(1);
-    // The check for the unique predecessor is not the best that can be
+    // The check for the single predecessor is not the best that can be
     // done. But it protects efficiently against cases like when SI's
     // home block has two successors, Succ and Succ1, and Succ1 predecessor
     // of Succ. Then SI can't be replaced by SIOpd because the use that gets
@@ -3895,8 +3965,10 @@ bool InstCombiner::replacedSelectWithOperand(SelectInst *SI,
     // guarantees that the path all uses of SI (outside SI's parent) are on
     // is disjoint from all other paths out of SI. But that information
     // is more expensive to compute, and the trade-off here is in favor
-    // of compile-time.
-    if (Succ->getUniquePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
+    // of compile-time. It should also be noted that we check for a single
+    // predecessor and not only uniqueness. This is to handle the situation
+    // when Succ and Succ1 point to the same basic block.
+    if (Succ->getSinglePredecessor() && dominatesAllUses(SI, Icmp, Succ)) {
       NumSel++;
       SI->replaceUsesOutsideBlock(SI->getOperand(SIOpd), SI->getParent());
       return true;
@@ -3932,12 +4004,12 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) {
   APInt Op0KnownZero(BitWidth, 0), Op0KnownOne(BitWidth, 0);
   APInt Op1KnownZero(BitWidth, 0), Op1KnownOne(BitWidth, 0);
 
-  if (SimplifyDemandedBits(I.getOperandUse(0),
+  if (SimplifyDemandedBits(&I, 0,
                            getDemandedBitsLHSMask(I, BitWidth, IsSignBit),
                            Op0KnownZero, Op0KnownOne, 0))
     return &I;
-  if (SimplifyDemandedBits(I.getOperandUse(1), APInt::getAllOnesValue(BitWidth),
+  if (SimplifyDemandedBits(&I, 1, APInt::getAllOnesValue(BitWidth),
                            Op1KnownZero, Op1KnownOne, 0))
     return &I;
 
@@ -4801,7 +4873,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) {
       // block. If in the same block, we're encouraging jump threading. If
       // not, we are just pessimizing the code by making an i1 phi.
       if (LHSI->getParent() == I.getParent())
-        if (Instruction *NV = FoldOpIntoPhi(I))
+        if (Instruction *NV = foldOpIntoPhi(I, cast<PHINode>(LHSI)))
          return NV;
       break;
     case Instruction::SIToFP:
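The getDemandedBitsLHSMask change above is a pure API cleanup (getBitsSetFrom(BW, N) equals ~getLowBitsSet(BW, N)), but the claim behind the mask is worth spelling out: for LHS >u RHS, the bits of LHS below the trailing ones of RHS can never affect the result. A standalone C++ sketch (not part of the patch) checking that on i8:

// If RHS has K trailing one bits, then replacing the low K bits of LHS
// with arbitrary junk never changes the result of LHS >u RHS.
#include <cassert>
#include <cstdio>

// 8-bit analog of APInt::countTrailingOnes().
static unsigned trailingOnes(unsigned V) {
  unsigned N = 0;
  for (; V & 1; V >>= 1)
    ++N;
  return N;
}

int main() {
  for (unsigned RHS = 0; RHS <= 255; ++RHS) {
    unsigned LowMask = (1u << trailingOnes(RHS)) - 1; // bits not demanded
    for (unsigned LHS = 0; LHS <= 255; ++LHS)
      for (unsigned Junk = 0; Junk <= LowMask; ++Junk)
        assert((LHS > RHS) == (((LHS & ~LowMask) | Junk) > RHS));
  }
  std::printf("ICMP_UGT demanded-bits mask verified for all i8 cases\n");
  return 0;
}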