summaryrefslogtreecommitdiff
path: root/lib/Transforms/InstCombine/InstCombineCompares.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r--lib/Transforms/InstCombine/InstCombineCompares.cpp49
1 files changed, 23 insertions, 26 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp
index 6ad32490a3288..58b8b2f526299 100644
--- a/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -112,10 +112,10 @@ static bool subWithOverflow(Constant *&Result, Constant *In1,
/// Given an icmp instruction, return true if any use of this comparison is a
/// branch on sign bit comparison.
-static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) {
+static bool hasBranchUse(ICmpInst &I) {
for (auto *U : I.users())
if (isa<BranchInst>(U))
- return isSignBit;
+ return true;
return false;
}
@@ -1448,12 +1448,13 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) {
// of a test and branch. So we avoid canonicalizing in such situations
// because test and branch instruction has better branch displacement
// than compare and branch instruction.
- if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) {
- if (auto *AI = Intersection.getSingleElement())
- return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI));
- if (auto *AD = Difference.getSingleElement())
- return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD));
- }
+ if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp)))
+ return nullptr;
+
+ if (auto *AI = Intersection.getSingleElement())
+ return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI));
+ if (auto *AD = Difference.getSingleElement())
+ return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD));
}
return nullptr;
@@ -3301,12 +3302,12 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
return nullptr;
Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
+ const CmpInst::Predicate Pred = I.getPredicate();
Value *A, *B, *C, *D;
if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) {
if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0
Value *OtherVal = A == Op1 ? B : A;
- return new ICmpInst(I.getPredicate(), OtherVal,
- Constant::getNullValue(A->getType()));
+ return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
}
if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) {
@@ -3316,26 +3317,25 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
Op1->hasOneUse()) {
Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue());
Value *Xor = Builder->CreateXor(C, NC);
- return new ICmpInst(I.getPredicate(), A, Xor);
+ return new ICmpInst(Pred, A, Xor);
}
// A^B == A^D -> B == D
if (A == C)
- return new ICmpInst(I.getPredicate(), B, D);
+ return new ICmpInst(Pred, B, D);
if (A == D)
- return new ICmpInst(I.getPredicate(), B, C);
+ return new ICmpInst(Pred, B, C);
if (B == C)
- return new ICmpInst(I.getPredicate(), A, D);
+ return new ICmpInst(Pred, A, D);
if (B == D)
- return new ICmpInst(I.getPredicate(), A, C);
+ return new ICmpInst(Pred, A, C);
}
}
if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) {
// A == (A^B) -> B == 0
Value *OtherVal = A == Op0 ? B : A;
- return new ICmpInst(I.getPredicate(), OtherVal,
- Constant::getNullValue(A->getType()));
+ return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType()));
}
// (X&Z) == (Y&Z) -> (X^Y) & Z == 0
@@ -3380,8 +3380,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
APInt Pow2 = Cst1->getValue() + 1;
if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) &&
Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth())
- return new ICmpInst(I.getPredicate(), A,
- Builder->CreateTrunc(B, A->getType()));
+ return new ICmpInst(Pred, A, Builder->CreateTrunc(B, A->getType()));
}
// (A >> C) == (B >> C) --> (A^B) u< (1 << C)
@@ -3393,12 +3392,11 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
unsigned TypeBits = Cst1->getBitWidth();
unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits);
if (ShAmt < TypeBits && ShAmt != 0) {
- ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE
- ? ICmpInst::ICMP_UGE
- : ICmpInst::ICMP_ULT;
+ ICmpInst::Predicate NewPred =
+ Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT;
Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted");
APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt);
- return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal));
+ return new ICmpInst(NewPred, Xor, Builder->getInt(CmpVal));
}
}
@@ -3412,8 +3410,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt);
Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal),
I.getName() + ".mask");
- return new ICmpInst(I.getPredicate(), And,
- Constant::getNullValue(Cst1->getType()));
+ return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType()));
}
}
@@ -3437,7 +3434,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) {
CmpV <<= ShAmt;
Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV));
- return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV));
+ return new ICmpInst(Pred, Mask, Builder->getInt(CmpV));
}
}