diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineCompares.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineCompares.cpp | 49 |
1 files changed, 23 insertions, 26 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index 6ad32490a3288..58b8b2f526299 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -112,10 +112,10 @@ static bool subWithOverflow(Constant *&Result, Constant *In1, /// Given an icmp instruction, return true if any use of this comparison is a /// branch on sign bit comparison. -static bool isBranchOnSignBitCheck(ICmpInst &I, bool isSignBit) { +static bool hasBranchUse(ICmpInst &I) { for (auto *U : I.users()) if (isa<BranchInst>(U)) - return isSignBit; + return true; return false; } @@ -1448,12 +1448,13 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { // of a test and branch. So we avoid canonicalizing in such situations // because test and branch instruction has better branch displacement // than compare and branch instruction. - if (!isBranchOnSignBitCheck(Cmp, IsSignBit) && !Cmp.isEquality()) { - if (auto *AI = Intersection.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); - if (auto *AD = Difference.getSingleElement()) - return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); - } + if (Cmp.isEquality() || (IsSignBit && hasBranchUse(Cmp))) + return nullptr; + + if (auto *AI = Intersection.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_EQ, X, Builder->getInt(*AI)); + if (auto *AD = Difference.getSingleElement()) + return new ICmpInst(ICmpInst::ICMP_NE, X, Builder->getInt(*AD)); } return nullptr; @@ -3301,12 +3302,12 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { return nullptr; Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); + const CmpInst::Predicate Pred = I.getPredicate(); Value *A, *B, *C, *D; if (match(Op0, m_Xor(m_Value(A), m_Value(B)))) { if (A == Op1 || B == Op1) { // (A^B) == A -> B == 0 Value *OtherVal = A == Op1 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); } if (match(Op1, m_Xor(m_Value(C), m_Value(D)))) { @@ -3316,26 +3317,25 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { Op1->hasOneUse()) { Constant *NC = Builder->getInt(C1->getValue() ^ C2->getValue()); Value *Xor = Builder->CreateXor(C, NC); - return new ICmpInst(I.getPredicate(), A, Xor); + return new ICmpInst(Pred, A, Xor); } // A^B == A^D -> B == D if (A == C) - return new ICmpInst(I.getPredicate(), B, D); + return new ICmpInst(Pred, B, D); if (A == D) - return new ICmpInst(I.getPredicate(), B, C); + return new ICmpInst(Pred, B, C); if (B == C) - return new ICmpInst(I.getPredicate(), A, D); + return new ICmpInst(Pred, A, D); if (B == D) - return new ICmpInst(I.getPredicate(), A, C); + return new ICmpInst(Pred, A, C); } } if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && (A == Op0 || B == Op0)) { // A == (A^B) -> B == 0 Value *OtherVal = A == Op0 ? B : A; - return new ICmpInst(I.getPredicate(), OtherVal, - Constant::getNullValue(A->getType())); + return new ICmpInst(Pred, OtherVal, Constant::getNullValue(A->getType())); } // (X&Z) == (Y&Z) -> (X^Y) & Z == 0 @@ -3380,8 +3380,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { APInt Pow2 = Cst1->getValue() + 1; if (Pow2.isPowerOf2() && isa<IntegerType>(A->getType()) && Pow2.logBase2() == cast<IntegerType>(A->getType())->getBitWidth()) - return new ICmpInst(I.getPredicate(), A, - Builder->CreateTrunc(B, A->getType())); + return new ICmpInst(Pred, A, Builder->CreateTrunc(B, A->getType())); } // (A >> C) == (B >> C) --> (A^B) u< (1 << C) @@ -3393,12 +3392,11 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { unsigned TypeBits = Cst1->getBitWidth(); unsigned ShAmt = (unsigned)Cst1->getLimitedValue(TypeBits); if (ShAmt < TypeBits && ShAmt != 0) { - ICmpInst::Predicate Pred = I.getPredicate() == ICmpInst::ICMP_NE - ? ICmpInst::ICMP_UGE - : ICmpInst::ICMP_ULT; + ICmpInst::Predicate NewPred = + Pred == ICmpInst::ICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; Value *Xor = Builder->CreateXor(A, B, I.getName() + ".unshifted"); APInt CmpVal = APInt::getOneBitSet(TypeBits, ShAmt); - return new ICmpInst(Pred, Xor, Builder->getInt(CmpVal)); + return new ICmpInst(NewPred, Xor, Builder->getInt(CmpVal)); } } @@ -3412,8 +3410,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { APInt AndVal = APInt::getLowBitsSet(TypeBits, TypeBits - ShAmt); Value *And = Builder->CreateAnd(Xor, Builder->getInt(AndVal), I.getName() + ".mask"); - return new ICmpInst(I.getPredicate(), And, - Constant::getNullValue(Cst1->getType())); + return new ICmpInst(Pred, And, Constant::getNullValue(Cst1->getType())); } } @@ -3437,7 +3434,7 @@ Instruction *InstCombiner::foldICmpEquality(ICmpInst &I) { CmpV <<= ShAmt; Value *Mask = Builder->CreateAnd(A, Builder->getInt(MaskV)); - return new ICmpInst(I.getPredicate(), Mask, Builder->getInt(CmpV)); + return new ICmpInst(Pred, Mask, Builder->getInt(CmpV)); } } |