diff options
| author | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
|---|---|---|
| committer | Dimitry Andric <dim@FreeBSD.org> | 2021-07-29 20:15:26 +0000 |
| commit | 344a3780b2e33f6ca763666c380202b18aab72a3 (patch) | |
| tree | f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | |
| parent | b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff) | |
vendor/llvm-project/llvmorg-13-init-16847-g88e66fa60ae5 vendor/llvm-project/llvmorg-12.0.1-rc2-0-ge7dac564cd0e vendor/llvm-project/llvmorg-12.0.1-0-gfed41342a82f
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
| -rw-r--r-- | llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 168 |
1 files changed, 151 insertions, 17 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp index cd9a036179b6..2b0ef0c5f2cc 100644 --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -279,9 +279,28 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, Idx = Builder.CreateTrunc(Idx, IntPtrTy); } + // If inbounds keyword is not present, Idx * ElementSize can overflow. + // Let's assume that ElementSize is 2 and the wanted value is at offset 0. + // Then, there are two possible values for Idx to match offset 0: + // 0x00..00, 0x80..00. + // Emitting 'icmp eq Idx, 0' isn't correct in this case because the + // comparison is false if Idx was 0x80..00. + // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx. + unsigned ElementSize = + DL.getTypeAllocSize(Init->getType()->getArrayElementType()); + auto MaskIdx = [&](Value* Idx){ + if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) { + Value *Mask = ConstantInt::get(Idx->getType(), -1); + Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize)); + Idx = Builder.CreateAnd(Idx, Mask); + } + return Idx; + }; + // If the comparison is only true for one or two elements, emit direct // comparisons. if (SecondTrueElement != Overdefined) { + Idx = MaskIdx(Idx); // None true -> false. if (FirstTrueElement == Undefined) return replaceInstUsesWith(ICI, Builder.getFalse()); @@ -302,6 +321,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // If the comparison is only false for one or two elements, emit direct // comparisons. if (SecondFalseElement != Overdefined) { + Idx = MaskIdx(Idx); // None false -> true. if (FirstFalseElement == Undefined) return replaceInstUsesWith(ICI, Builder.getTrue()); @@ -323,6 +343,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // where it is true, emit the range check. 
if (TrueRangeEnd != Overdefined) { assert(TrueRangeEnd != FirstTrueElement && "Should emit single compare"); + Idx = MaskIdx(Idx); // Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1). if (FirstTrueElement) { @@ -338,6 +359,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // False range check. if (FalseRangeEnd != Overdefined) { assert(FalseRangeEnd != FirstFalseElement && "Should emit single compare"); + Idx = MaskIdx(Idx); // Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse). if (FirstFalseElement) { Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement); @@ -364,6 +386,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount); if (Ty) { + Idx = MaskIdx(Idx); Value *V = Builder.CreateIntCast(Idx, Ty, false); V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V); V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V); @@ -1417,7 +1440,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) { // icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...). Constant *C = dyn_cast<Constant>(Op1); - if (!C) + if (!C || C->canTrap()) return nullptr; if (auto *Phi = dyn_cast<PHINode>(Op0)) @@ -1478,7 +1501,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { // br DomCond, CmpBB, FalseBB // CmpBB: // Cmp = icmp Pred X, C - ConstantRange CR = ConstantRange::makeAllowedICmpRegion(Pred, *C); + ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C); ConstantRange DominatingCR = (CmpBB == TrueBB) ? ConstantRange::makeExactICmpRegion(DomPred, *DomC) : ConstantRange::makeExactICmpRegion( @@ -1500,6 +1523,12 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) { if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp))) return nullptr; + // Avoid an infinite loop with min/max canonicalization. 
+ // TODO: This will be unnecessary if we canonicalize to min/max intrinsics. + if (Cmp.hasOneUse() && + match(Cmp.user_back(), m_MaxOrMin(m_Value(), m_Value()))) + return nullptr; + if (const APInt *EqC = Intersection.getSingleElement()) return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC)); if (const APInt *NeC = Difference.getSingleElement()) @@ -1523,11 +1552,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, ConstantInt::get(V->getType(), 1)); } + unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), + SrcBits = X->getType()->getScalarSizeInBits(); if (Cmp.isEquality() && Trunc->hasOneUse()) { // Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all // of the high bits truncated out of x are known. - unsigned DstBits = Trunc->getType()->getScalarSizeInBits(), - SrcBits = X->getType()->getScalarSizeInBits(); KnownBits Known = computeKnownBits(X, 0, &Cmp); // If all the high bits are known, we can do this xform. @@ -1539,6 +1568,22 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp, } } + // Look through truncated right-shift of the sign-bit for a sign-bit check: + // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0 + // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1 + Value *ShOp; + const APInt *ShAmtC; + bool TrueIfSigned; + if (isSignBitCheck(Pred, C, TrueIfSigned) && + match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) && + DstBits == SrcBits - ShAmtC->getZExtValue()) { + return TrueIfSigned + ? 
new ICmpInst(ICmpInst::ICMP_SLT, ShOp, + ConstantInt::getNullValue(X->getType())) + : new ICmpInst(ICmpInst::ICMP_SGT, ShOp, + ConstantInt::getAllOnesValue(X->getType())); + } + return nullptr; } @@ -1810,6 +1855,19 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) return I; + const ICmpInst::Predicate Pred = Cmp.getPredicate(); + bool TrueIfNeg; + if (isSignBitCheck(Pred, C, TrueIfNeg)) { + // ((X - 1) & ~X) < 0 --> X == 0 + // ((X - 1) & ~X) >= 0 --> X != 0 + Value *X; + if (match(And->getOperand(0), m_Add(m_Value(X), m_AllOnes())) && + match(And->getOperand(1), m_Not(m_Specific(X)))) { + auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE; + return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType())); + } + } + // TODO: These all require that Y is constant too, so refactor with the above. // Try to optimize things like "A[i] & 42 == 0" to index computations. @@ -1832,8 +1890,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, // X & -C != -C -> X <= u ~C // iff C is a power of 2 if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { - auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT - : CmpInst::ICMP_ULE; + auto NewPred = + Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); } @@ -1848,8 +1906,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp, if (auto *AndVTy = dyn_cast<VectorType>(And->getType())) NTy = VectorType::get(NTy, AndVTy->getElementCount()); Value *Trunc = Builder.CreateTrunc(X, NTy); - auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE - : CmpInst::ICMP_SLT; + auto NewPred = + Pred == CmpInst::ICMP_EQ ? 
CmpInst::ICMP_SGE : CmpInst::ICMP_SLT; return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy)); } } @@ -2258,6 +2316,16 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp, if (Shr->isExact()) return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); + if (C.isNullValue()) { + // == 0 is u< 1. + if (Pred == CmpInst::ICMP_EQ) + return new ICmpInst(CmpInst::ICMP_ULT, X, + ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal))); + else + return new ICmpInst(CmpInst::ICMP_UGT, X, + ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1)); + } + if (Shr->hasOneUse()) { // Canonicalize the shift into an 'and': // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) @@ -2581,7 +2649,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, // Fold icmp pred (add X, C2), C. Value *X = Add->getOperand(0); Type *Ty = Add->getType(); - CmpInst::Predicate Pred = Cmp.getPredicate(); + const CmpInst::Predicate Pred = Cmp.getPredicate(); // If the add does not wrap, we can always adjust the compare by subtracting // the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE @@ -2616,6 +2684,28 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower)); } + // This set of folds is intentionally placed after folds that use no-wrapping + // flags because those folds are likely better for later analysis/codegen. 
+ const APInt SMax = APInt::getSignedMaxValue(Ty->getScalarSizeInBits()); + const APInt SMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits()); + + // Fold compare with offset to opposite sign compare if it eliminates offset: + // (X + C2) >u C --> X <s -C2 (if C == C2 + SMAX) + if (Pred == CmpInst::ICMP_UGT && C == *C2 + SMax) + return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, -(*C2))); + + // (X + C2) <u C --> X >s ~C2 (if C == C2 + SMIN) + if (Pred == CmpInst::ICMP_ULT && C == *C2 + SMin) + return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::get(Ty, ~(*C2))); + + // (X + C2) >s C --> X <u (SMAX - C) (if C == C2 - 1) + if (Pred == CmpInst::ICMP_SGT && C == *C2 - 1) + return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, SMax - C)); + + // (X + C2) <s C --> X >u (C ^ SMAX) (if C == C2) + if (Pred == CmpInst::ICMP_SLT && C == *C2) + return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax)); + if (!Add->hasOneUse()) return nullptr; @@ -3235,14 +3325,24 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) { // constant folded and the select turned into a bitwise or. Value *Op1 = nullptr, *Op2 = nullptr; ConstantInt *CI = nullptr; - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) { - Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + + auto SimplifyOp = [&](Value *V) { + Value *Op = nullptr; + if (Constant *C = dyn_cast<Constant>(V)) { + Op = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + } else if (RHSC->isNullValue()) { + // If null is being compared, check if it can be further simplified. 
+ Op = SimplifyICmpInst(I.getPredicate(), V, RHSC, SQ); + } + return Op; + }; + Op1 = SimplifyOp(LHSI->getOperand(1)); + if (Op1) CI = dyn_cast<ConstantInt>(Op1); - } - if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) { - Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC); + + Op2 = SimplifyOp(LHSI->getOperand(2)); + if (Op2) CI = dyn_cast<ConstantInt>(Op2); - } // We only want to perform this transformation if it will not lead to // additional code. This is true if either both sides of the select @@ -3901,11 +4001,15 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I, APInt AP2Abs = C2->getValue().abs(); if (AP1Abs.uge(AP2Abs)) { ConstantInt *C3 = Builder.getInt(AP1 - AP2); - Value *NewAdd = Builder.CreateNSWAdd(A, C3); + bool HasNUW = BO0->hasNoUnsignedWrap() && C3->getValue().ule(AP1); + bool HasNSW = BO0->hasNoSignedWrap(); + Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW); return new ICmpInst(Pred, NewAdd, C); } else { ConstantInt *C3 = Builder.getInt(AP2 - AP1); - Value *NewAdd = Builder.CreateNSWAdd(C, C3); + bool HasNUW = BO1->hasNoUnsignedWrap() && C3->getValue().ule(AP2); + bool HasNSW = BO1->hasNoSignedWrap(); + Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW); return new ICmpInst(Pred, A, NewAdd); } } @@ -4467,6 +4571,16 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp, /// Handle icmp (cast x), (cast or constant). Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) { + // If any operand of ICmp is a inttoptr roundtrip cast then remove it as + // icmp compares only pointer's value. + // icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2. + Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0)); + Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1)); + if (SimplifiedOp0 || SimplifiedOp1) + return new ICmpInst(ICmp.getPredicate(), + SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0), + SimplifiedOp1 ? 
SimplifiedOp1 : ICmp.getOperand(1)); + auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0)); if (!CastOp0) return nullptr; @@ -6267,6 +6381,26 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) { } } + // Convert a sign-bit test of an FP value into a cast and integer compare. + // TODO: Simplify if the copysign constant is 0.0 or NaN. + // TODO: Handle non-zero compare constants. + // TODO: Handle other predicates. + const APFloat *C; + if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C), + m_Value(X)))) && + match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) { + Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits()); + if (auto *VecTy = dyn_cast<VectorType>(OpType)) + IntType = VectorType::get(IntType, VecTy->getElementCount()); + + // copysign(non-zero constant, X) < 0.0 --> (bitcast X) < 0 + if (Pred == FCmpInst::FCMP_OLT) { + Value *IntX = Builder.CreateBitCast(X, IntType); + return new ICmpInst(ICmpInst::ICMP_SLT, IntX, + ConstantInt::getNullValue(IntType)); + } + } + if (I.getType()->isVectorTy()) if (Instruction *Res = foldVectorCmp(I, Builder)) return Res; |
