summary | refs | log | tree | commit | diff
path: root/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
diff options
context:
space:
mode:
author Dimitry Andric <dim@FreeBSD.org> 2021-07-29 20:15:26 +0000
committer Dimitry Andric <dim@FreeBSD.org> 2021-07-29 20:15:26 +0000
commit 344a3780b2e33f6ca763666c380202b18aab72a3 (patch)
tree f0b203ee6eb71d7fdd792373e3c81eb18d6934dd /llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
parent b60736ec1405bb0a8dd40989f67ef4c93da068ab (diff)
Diffstat (limited to 'llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp | 168
1 files changed, 151 insertions, 17 deletions
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index cd9a036179b6..2b0ef0c5f2cc 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -279,9 +279,28 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
Idx = Builder.CreateTrunc(Idx, IntPtrTy);
}
+ // If the inbounds keyword is not present, Idx * ElementSize can overflow.
+ // Let's assume that ElementSize is 2 and the wanted value is at offset 0.
+ // Then, there are two possible values for Idx to match offset 0:
+ // 0x00..00, 0x80..00.
+ // Emitting 'icmp eq Idx, 0' isn't correct in this case because the
+ // comparison is false if Idx was 0x80..00.
+ // We need to erase the highest countTrailingZeros(ElementSize) bits of Idx.
+ unsigned ElementSize =
+ DL.getTypeAllocSize(Init->getType()->getArrayElementType());
+ auto MaskIdx = [&](Value* Idx){
+ if (!GEP->isInBounds() && countTrailingZeros(ElementSize) != 0) {
+ Value *Mask = ConstantInt::get(Idx->getType(), -1);
+ Mask = Builder.CreateLShr(Mask, countTrailingZeros(ElementSize));
+ Idx = Builder.CreateAnd(Idx, Mask);
+ }
+ return Idx;
+ };
+
// If the comparison is only true for one or two elements, emit direct
// comparisons.
if (SecondTrueElement != Overdefined) {
+ Idx = MaskIdx(Idx);
// None true -> false.
if (FirstTrueElement == Undefined)
return replaceInstUsesWith(ICI, Builder.getFalse());
@@ -302,6 +321,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
// If the comparison is only false for one or two elements, emit direct
// comparisons.
if (SecondFalseElement != Overdefined) {
+ Idx = MaskIdx(Idx);
// None false -> true.
if (FirstFalseElement == Undefined)
return replaceInstUsesWith(ICI, Builder.getTrue());
@@ -323,6 +343,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
// where it is true, emit the range check.
if (TrueRangeEnd != Overdefined) {
assert(TrueRangeEnd != FirstTrueElement && "Should emit single compare");
+ Idx = MaskIdx(Idx);
// Generate (i-FirstTrue) <u (TrueRangeEnd-FirstTrue+1).
if (FirstTrueElement) {
@@ -338,6 +359,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
// False range check.
if (FalseRangeEnd != Overdefined) {
assert(FalseRangeEnd != FirstFalseElement && "Should emit single compare");
+ Idx = MaskIdx(Idx);
// Generate (i-FirstFalse) >u (FalseRangeEnd-FirstFalse).
if (FirstFalseElement) {
Value *Offs = ConstantInt::get(Idx->getType(), -FirstFalseElement);
@@ -364,6 +386,7 @@ InstCombinerImpl::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP,
Ty = DL.getSmallestLegalIntType(Init->getContext(), ArrayElementCount);
if (Ty) {
+ Idx = MaskIdx(Idx);
Value *V = Builder.CreateIntCast(Idx, Ty, false);
V = Builder.CreateLShr(ConstantInt::get(Ty, MagicBitvector), V);
V = Builder.CreateAnd(ConstantInt::get(Ty, 1), V);
@@ -1417,7 +1440,7 @@ Instruction *InstCombinerImpl::foldICmpWithConstant(ICmpInst &Cmp) {
// icmp(phi(C1, C2, ...), C) -> phi(icmp(C1, C), icmp(C2, C), ...).
Constant *C = dyn_cast<Constant>(Op1);
- if (!C)
+ if (!C || C->canTrap())
return nullptr;
if (auto *Phi = dyn_cast<PHINode>(Op0))
@@ -1478,7 +1501,7 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
// br DomCond, CmpBB, FalseBB
// CmpBB:
// Cmp = icmp Pred X, C
- ConstantRange CR = ConstantRange::makeAllowedICmpRegion(Pred, *C);
+ ConstantRange CR = ConstantRange::makeExactICmpRegion(Pred, *C);
ConstantRange DominatingCR =
(CmpBB == TrueBB) ? ConstantRange::makeExactICmpRegion(DomPred, *DomC)
: ConstantRange::makeExactICmpRegion(
@@ -1500,6 +1523,12 @@ Instruction *InstCombinerImpl::foldICmpWithDominatingICmp(ICmpInst &Cmp) {
if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp)))
return nullptr;
+ // Avoid an infinite loop with min/max canonicalization.
+ // TODO: This will be unnecessary if we canonicalize to min/max intrinsics.
+ if (Cmp.hasOneUse() &&
+ match(Cmp.user_back(), m_MaxOrMin(m_Value(), m_Value())))
+ return nullptr;
+
if (const APInt *EqC = Intersection.getSingleElement())
return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder.getInt(*EqC));
if (const APInt *NeC = Difference.getSingleElement())
@@ -1523,11 +1552,11 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
ConstantInt::get(V->getType(), 1));
}
+ unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
+ SrcBits = X->getType()->getScalarSizeInBits();
if (Cmp.isEquality() && Trunc->hasOneUse()) {
// Simplify icmp eq (trunc x to i8), 42 -> icmp eq x, 42|highbits if all
// of the high bits truncated out of x are known.
- unsigned DstBits = Trunc->getType()->getScalarSizeInBits(),
- SrcBits = X->getType()->getScalarSizeInBits();
KnownBits Known = computeKnownBits(X, 0, &Cmp);
// If all the high bits are known, we can do this xform.
@@ -1539,6 +1568,22 @@ Instruction *InstCombinerImpl::foldICmpTruncConstant(ICmpInst &Cmp,
}
}
+ // Look through truncated right-shift of the sign-bit for a sign-bit check:
+ // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] < 0 --> ShOp < 0
+ // trunc iN (ShOp >> ShAmtC) to i[N - ShAmtC] > -1 --> ShOp > -1
+ Value *ShOp;
+ const APInt *ShAmtC;
+ bool TrueIfSigned;
+ if (isSignBitCheck(Pred, C, TrueIfSigned) &&
+ match(X, m_Shr(m_Value(ShOp), m_APInt(ShAmtC))) &&
+ DstBits == SrcBits - ShAmtC->getZExtValue()) {
+ return TrueIfSigned
+ ? new ICmpInst(ICmpInst::ICMP_SLT, ShOp,
+ ConstantInt::getNullValue(X->getType()))
+ : new ICmpInst(ICmpInst::ICMP_SGT, ShOp,
+ ConstantInt::getAllOnesValue(X->getType()));
+ }
+
return nullptr;
}
@@ -1810,6 +1855,19 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
if (Instruction *I = foldICmpAndConstConst(Cmp, And, C))
return I;
+ const ICmpInst::Predicate Pred = Cmp.getPredicate();
+ bool TrueIfNeg;
+ if (isSignBitCheck(Pred, C, TrueIfNeg)) {
+ // ((X - 1) & ~X) < 0 --> X == 0
+ // ((X - 1) & ~X) >= 0 --> X != 0
+ Value *X;
+ if (match(And->getOperand(0), m_Add(m_Value(X), m_AllOnes())) &&
+ match(And->getOperand(1), m_Not(m_Specific(X)))) {
+ auto NewPred = TrueIfNeg ? CmpInst::ICMP_EQ : CmpInst::ICMP_NE;
+ return new ICmpInst(NewPred, X, ConstantInt::getNullValue(X->getType()));
+ }
+ }
+
// TODO: These all require that Y is constant too, so refactor with the above.
// Try to optimize things like "A[i] & 42 == 0" to index computations.
@@ -1832,8 +1890,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
// X & -C != -C -> X <= u ~C
// iff C is a power of 2
if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) {
- auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT
- : CmpInst::ICMP_ULE;
+ auto NewPred =
+ Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE;
return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1))));
}
@@ -1848,8 +1906,8 @@ Instruction *InstCombinerImpl::foldICmpAndConstant(ICmpInst &Cmp,
if (auto *AndVTy = dyn_cast<VectorType>(And->getType()))
NTy = VectorType::get(NTy, AndVTy->getElementCount());
Value *Trunc = Builder.CreateTrunc(X, NTy);
- auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE
- : CmpInst::ICMP_SLT;
+ auto NewPred =
+ Pred == CmpInst::ICMP_EQ ? CmpInst::ICMP_SGE : CmpInst::ICMP_SLT;
return new ICmpInst(NewPred, Trunc, Constant::getNullValue(NTy));
}
}
@@ -2258,6 +2316,16 @@ Instruction *InstCombinerImpl::foldICmpShrConstant(ICmpInst &Cmp,
if (Shr->isExact())
return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal));
+ if (C.isNullValue()) {
+ // == 0 is u< 1.
+ if (Pred == CmpInst::ICMP_EQ)
+ return new ICmpInst(CmpInst::ICMP_ULT, X,
+ ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal)));
+ else
+ return new ICmpInst(CmpInst::ICMP_UGT, X,
+ ConstantInt::get(ShrTy, (C + 1).shl(ShAmtVal) - 1));
+ }
+
if (Shr->hasOneUse()) {
// Canonicalize the shift into an 'and':
// icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt)
@@ -2581,7 +2649,7 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
// Fold icmp pred (add X, C2), C.
Value *X = Add->getOperand(0);
Type *Ty = Add->getType();
- CmpInst::Predicate Pred = Cmp.getPredicate();
+ const CmpInst::Predicate Pred = Cmp.getPredicate();
// If the add does not wrap, we can always adjust the compare by subtracting
// the constants. Equality comparisons are handled elsewhere. SGE/SLE/UGE/ULE
@@ -2616,6 +2684,28 @@ Instruction *InstCombinerImpl::foldICmpAddConstant(ICmpInst &Cmp,
return new ICmpInst(ICmpInst::ICMP_UGE, X, ConstantInt::get(Ty, Lower));
}
+ // This set of folds is intentionally placed after folds that use no-wrapping
+ // flags because those folds are likely better for later analysis/codegen.
+ const APInt SMax = APInt::getSignedMaxValue(Ty->getScalarSizeInBits());
+ const APInt SMin = APInt::getSignedMinValue(Ty->getScalarSizeInBits());
+
+ // Fold compare with offset to opposite sign compare if it eliminates offset:
+ // (X + C2) >u C --> X <s -C2 (if C == C2 + SMAX)
+ if (Pred == CmpInst::ICMP_UGT && C == *C2 + SMax)
+ return new ICmpInst(ICmpInst::ICMP_SLT, X, ConstantInt::get(Ty, -(*C2)));
+
+ // (X + C2) <u C --> X >s ~C2 (if C == C2 + SMIN)
+ if (Pred == CmpInst::ICMP_ULT && C == *C2 + SMin)
+ return new ICmpInst(ICmpInst::ICMP_SGT, X, ConstantInt::get(Ty, ~(*C2)));
+
+ // (X + C2) >s C --> X <u (SMAX - C) (if C == C2 - 1)
+ if (Pred == CmpInst::ICMP_SGT && C == *C2 - 1)
+ return new ICmpInst(ICmpInst::ICMP_ULT, X, ConstantInt::get(Ty, SMax - C));
+
+ // (X + C2) <s C --> X >u (C ^ SMAX) (if C == C2)
+ if (Pred == CmpInst::ICMP_SLT && C == *C2)
+ return new ICmpInst(ICmpInst::ICMP_UGT, X, ConstantInt::get(Ty, C ^ SMax));
+
if (!Add->hasOneUse())
return nullptr;
@@ -3235,14 +3325,24 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
// constant folded and the select turned into a bitwise or.
Value *Op1 = nullptr, *Op2 = nullptr;
ConstantInt *CI = nullptr;
- if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(1))) {
- Op1 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+
+ auto SimplifyOp = [&](Value *V) {
+ Value *Op = nullptr;
+ if (Constant *C = dyn_cast<Constant>(V)) {
+ Op = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+ } else if (RHSC->isNullValue()) {
+ // If null is being compared, check if it can be further simplified.
+ Op = SimplifyICmpInst(I.getPredicate(), V, RHSC, SQ);
+ }
+ return Op;
+ };
+ Op1 = SimplifyOp(LHSI->getOperand(1));
+ if (Op1)
CI = dyn_cast<ConstantInt>(Op1);
- }
- if (Constant *C = dyn_cast<Constant>(LHSI->getOperand(2))) {
- Op2 = ConstantExpr::getICmp(I.getPredicate(), C, RHSC);
+
+ Op2 = SimplifyOp(LHSI->getOperand(2));
+ if (Op2)
CI = dyn_cast<ConstantInt>(Op2);
- }
// We only want to perform this transformation if it will not lead to
// additional code. This is true if either both sides of the select
@@ -3901,11 +4001,15 @@ Instruction *InstCombinerImpl::foldICmpBinOp(ICmpInst &I,
APInt AP2Abs = C2->getValue().abs();
if (AP1Abs.uge(AP2Abs)) {
ConstantInt *C3 = Builder.getInt(AP1 - AP2);
- Value *NewAdd = Builder.CreateNSWAdd(A, C3);
+ bool HasNUW = BO0->hasNoUnsignedWrap() && C3->getValue().ule(AP1);
+ bool HasNSW = BO0->hasNoSignedWrap();
+ Value *NewAdd = Builder.CreateAdd(A, C3, "", HasNUW, HasNSW);
return new ICmpInst(Pred, NewAdd, C);
} else {
ConstantInt *C3 = Builder.getInt(AP2 - AP1);
- Value *NewAdd = Builder.CreateNSWAdd(C, C3);
+ bool HasNUW = BO1->hasNoUnsignedWrap() && C3->getValue().ule(AP2);
+ bool HasNSW = BO1->hasNoSignedWrap();
+ Value *NewAdd = Builder.CreateAdd(C, C3, "", HasNUW, HasNSW);
return new ICmpInst(Pred, A, NewAdd);
}
}
@@ -4467,6 +4571,16 @@ static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp,
/// Handle icmp (cast x), (cast or constant).
Instruction *InstCombinerImpl::foldICmpWithCastOp(ICmpInst &ICmp) {
+ // If any operand of ICmp is an inttoptr roundtrip cast then remove it, as
+ // icmp compares only the pointer's value.
+ // icmp (inttoptr (ptrtoint p1)), p2 --> icmp p1, p2.
+ Value *SimplifiedOp0 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(0));
+ Value *SimplifiedOp1 = simplifyIntToPtrRoundTripCast(ICmp.getOperand(1));
+ if (SimplifiedOp0 || SimplifiedOp1)
+ return new ICmpInst(ICmp.getPredicate(),
+ SimplifiedOp0 ? SimplifiedOp0 : ICmp.getOperand(0),
+ SimplifiedOp1 ? SimplifiedOp1 : ICmp.getOperand(1));
+
auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
if (!CastOp0)
return nullptr;
@@ -6267,6 +6381,26 @@ Instruction *InstCombinerImpl::visitFCmpInst(FCmpInst &I) {
}
}
+ // Convert a sign-bit test of an FP value into a cast and integer compare.
+ // TODO: Simplify if the copysign constant is 0.0 or NaN.
+ // TODO: Handle non-zero compare constants.
+ // TODO: Handle other predicates.
+ const APFloat *C;
+ if (match(Op0, m_OneUse(m_Intrinsic<Intrinsic::copysign>(m_APFloat(C),
+ m_Value(X)))) &&
+ match(Op1, m_AnyZeroFP()) && !C->isZero() && !C->isNaN()) {
+ Type *IntType = Builder.getIntNTy(X->getType()->getScalarSizeInBits());
+ if (auto *VecTy = dyn_cast<VectorType>(OpType))
+ IntType = VectorType::get(IntType, VecTy->getElementCount());
+
+ // copysign(non-zero constant, X) < 0.0 --> (bitcast X) < 0
+ if (Pred == FCmpInst::FCMP_OLT) {
+ Value *IntX = Builder.CreateBitCast(X, IntType);
+ return new ICmpInst(ICmpInst::ICMP_SLT, IntX,
+ ConstantInt::getNullValue(IntType));
+ }
+ }
+
if (I.getType()->isVectorTy())
if (Instruction *Res = foldVectorCmp(I, Builder))
return Res;