diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineAndOrXor.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 243 |
1 files changed, 106 insertions, 137 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index 5e0bfe8e26d2..0dbe11d2f01f 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -14,6 +14,7 @@ #include "InstCombine.h" #include "llvm/Intrinsics.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Transforms/Utils/CmpInstAnalysis.h" #include "llvm/Support/ConstantRange.h" #include "llvm/Support/PatternMatch.h" using namespace llvm; @@ -62,50 +63,6 @@ static inline Value *dyn_castNotVal(Value *V) { return 0; } - -/// getICmpCode - Encode a icmp predicate into a three bit mask. These bits -/// are carefully arranged to allow folding of expressions such as: -/// -/// (A < B) | (A > B) --> (A != B) -/// -/// Note that this is only valid if the first and second predicates have the -/// same sign. Is illegal to do: (A u< B) | (A s> B) -/// -/// Three bits are used to represent the condition, as follows: -/// 0 A > B -/// 1 A == B -/// 2 A < B -/// -/// <=> Value Definition -/// 000 0 Always false -/// 001 1 A > B -/// 010 2 A == B -/// 011 3 A >= B -/// 100 4 A < B -/// 101 5 A != B -/// 110 6 A <= B -/// 111 7 Always true -/// -static unsigned getICmpCode(const ICmpInst *ICI) { - switch (ICI->getPredicate()) { - // False -> 0 - case ICmpInst::ICMP_UGT: return 1; // 001 - case ICmpInst::ICMP_SGT: return 1; // 001 - case ICmpInst::ICMP_EQ: return 2; // 010 - case ICmpInst::ICMP_UGE: return 3; // 011 - case ICmpInst::ICMP_SGE: return 3; // 011 - case ICmpInst::ICMP_ULT: return 4; // 100 - case ICmpInst::ICMP_SLT: return 4; // 100 - case ICmpInst::ICMP_NE: return 5; // 101 - case ICmpInst::ICMP_ULE: return 6; // 110 - case ICmpInst::ICMP_SLE: return 6; // 110 - // True -> 7 - default: - llvm_unreachable("Invalid ICmp predicate!"); - return 0; - } -} - /// getFCmpCode - Similar to getICmpCode but for FCmpInst. This encodes a fcmp /// predicate into a three bit mask. It also returns whether it is an ordered /// predicate by reference. @@ -130,31 +87,19 @@ static unsigned getFCmpCode(FCmpInst::Predicate CC, bool &isOrdered) { default: // Not expecting FCMP_FALSE and FCMP_TRUE; llvm_unreachable("Unexpected FCmp predicate!"); - return 0; } } -/// getICmpValue - This is the complement of getICmpCode, which turns an +/// getNewICmpValue - This is the complement of getICmpCode, which turns an /// opcode and two operands into either a constant true or false, or a brand /// new ICmp instruction. The sign is passed in to determine which kind /// of predicate to use in the new icmp instruction. -static Value *getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, - InstCombiner::BuilderTy *Builder) { - CmpInst::Predicate Pred; - switch (Code) { - default: assert(0 && "Illegal ICmp code!"); - case 0: // False. - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case 1: Pred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; - case 2: Pred = ICmpInst::ICMP_EQ; break; - case 3: Pred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; - case 4: Pred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; - case 5: Pred = ICmpInst::ICMP_NE; break; - case 6: Pred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; - case 7: // True. - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - } - return Builder->CreateICmp(Pred, LHS, RHS); +static Value *getNewICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, + InstCombiner::BuilderTy *Builder) { + ICmpInst::Predicate NewPred; + if (Value *NewConstant = getICmpValue(Sign, Code, LHS, RHS, NewPred)) + return NewConstant; + return Builder->CreateICmp(NewPred, LHS, RHS); } /// getFCmpValue - This is the complement of getFCmpCode, which turns an @@ -165,7 +110,7 @@ static Value *getFCmpValue(bool isordered, unsigned code, InstCombiner::BuilderTy *Builder) { CmpInst::Predicate Pred; switch (code) { - default: assert(0 && "Illegal FCmp code!"); + default: llvm_unreachable("Illegal FCmp code!"); case 0: Pred = isordered ? FCmpInst::FCMP_ORD : FCmpInst::FCMP_UNO; break; case 1: Pred = isordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; break; case 2: Pred = isordered ? FCmpInst::FCMP_OEQ : FCmpInst::FCMP_UEQ; break; @@ -180,14 +125,6 @@ static Value *getFCmpValue(bool isordered, unsigned code, return Builder->CreateFCmp(Pred, LHS, RHS); } -/// PredicatesFoldable - Return true if both predicates match sign or if at -/// least one of them is an equality comparison (which is signless). -static bool PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { - return (CmpInst::isSigned(p1) == CmpInst::isSigned(p2)) || - (CmpInst::isSigned(p1) && ICmpInst::isEquality(p2)) || - (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1)); -} - // OptAndOp - This handles expressions of the form ((val OP C1) & C2). Where // the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is // guaranteed to be a binary operator. @@ -558,6 +495,38 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, return result; } +/// decomposeBitTestICmp - Decompose an icmp into the form ((X & Y) pred Z) +/// if possible. The returned predicate is either == or !=. Returns false if +/// decomposition fails. +static bool decomposeBitTestICmp(const ICmpInst *I, ICmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + // X < 0 is equivalent to (X & SignBit) != 0. + if (I->getPredicate() == ICmpInst::ICMP_SLT) + if (ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1))) + if (C->isZero()) { + X = I->getOperand(0); + Y = ConstantInt::get(I->getContext(), + APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_NE; + Z = C; + return true; + } + + // X > -1 is equivalent to (X & SignBit) == 0. + if (I->getPredicate() == ICmpInst::ICMP_SGT) + if (ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1))) + if (C->isAllOnesValue()) { + X = I->getOperand(0); + Y = ConstantInt::get(I->getContext(), + APInt::getSignBit(C->getBitWidth())); + Pred = ICmpInst::ICMP_EQ; + Z = ConstantInt::getNullValue(C->getType()); + return true; + } + + return false; +} + /// foldLogOpOfMaskedICmpsHelper: /// handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) /// return the set of pattern classes (from MaskedICmpType) @@ -565,10 +534,9 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, Value*& B, Value*& C, Value*& D, Value*& E, - ICmpInst *LHS, ICmpInst *RHS) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - if (LHSCC != ICmpInst::ICMP_EQ && LHSCC != ICmpInst::ICMP_NE) return 0; - if (RHSCC != ICmpInst::ICMP_EQ && RHSCC != ICmpInst::ICMP_NE) return 0; + ICmpInst *LHS, ICmpInst *RHS, + ICmpInst::Predicate &LHSCC, + ICmpInst::Predicate &RHSCC) { if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; // vectors are not (yet?) supported if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; @@ -582,40 +550,60 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, Value *L1 = LHS->getOperand(0); Value *L2 = LHS->getOperand(1); Value *L11,*L12,*L21,*L22; - if (match(L1, m_And(m_Value(L11), m_Value(L12)))) { - if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) + // Check whether the icmp can be decomposed into a bit test. + if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { + L21 = L22 = L1 = 0; + } else { + // Look for ANDs in the LHS icmp. + if (match(L1, m_And(m_Value(L11), m_Value(L12)))) { + if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) + L21 = L22 = 0; + } else { + if (!match(L2, m_And(m_Value(L11), m_Value(L12)))) + return 0; + std::swap(L1, L2); L21 = L22 = 0; - } - else { - if (!match(L2, m_And(m_Value(L11), m_Value(L12)))) - return 0; - std::swap(L1, L2); - L21 = L22 = 0; + } } + // Bail if LHS was a icmp that can't be decomposed into an equality. + if (!ICmpInst::isEquality(LHSCC)) + return 0; + Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); Value *R11,*R12; bool ok = false; - if (match(R1, m_And(m_Value(R11), m_Value(R12)))) { - if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) { - A = R11; D = R12; E = R2; ok = true; + if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { + A = R12; D = R11; + } else { + return 0; } - else - if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) { + E = R2; R1 = 0; ok = true; + } else if (match(R1, m_And(m_Value(R11), m_Value(R12)))) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; E = R2; ok = true; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { A = R12; D = R11; E = R2; ok = true; } } + + // Bail if RHS was a icmp that can't be decomposed into an equality. + if (!ICmpInst::isEquality(RHSCC)) + return 0; + + // Look for ANDs in on the right side of the RHS icmp. if (!ok && match(R2, m_And(m_Value(R11), m_Value(R12)))) { - if (R11 != 0 && (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22)) { - A = R11; D = R12; E = R1; ok = true; - } - else - if (R12 != 0 && (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22)) { + if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { + A = R11; D = R12; E = R1; ok = true; + } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { A = R12; D = R11; E = R1; ok = true; - } - else + } else { return 0; + } } if (!ok) return 0; @@ -644,8 +632,12 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, ICmpInst::Predicate NEWCC, llvm::InstCombiner::BuilderTy* Builder) { Value *A = 0, *B = 0, *C = 0, *D = 0, *E = 0; - unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS); + ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + unsigned mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, + LHSCC, RHSCC); if (mask == 0) return 0; + assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && + "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); if (NEWCC == ICmpInst::ICMP_NE) mask >>= 1; // treat "Not"-states as normal states @@ -693,11 +685,11 @@ static Value* foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, ConstantInt *CCst = dyn_cast<ConstantInt>(C); if (CCst == 0) return 0; - if (LHS->getPredicate() != NEWCC) + if (LHSCC != NEWCC) CCst = dyn_cast<ConstantInt>( ConstantExpr::getXor(BCst, CCst) ); ConstantInt *ECst = dyn_cast<ConstantInt>(E); if (ECst == 0) return 0; - if (RHS->getPredicate() != NEWCC) + if (RHSCC != NEWCC) ECst = dyn_cast<ConstantInt>( ConstantExpr::getXor(DCst, ECst) ); ConstantInt* MCst = dyn_cast<ConstantInt>( ConstantExpr::getAnd(ConstantExpr::getAnd(BCst, DCst), @@ -728,7 +720,7 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) & getICmpCode(RHS); bool isSigned = LHS->isSigned() || RHS->isSigned(); - return getICmpValue(isSigned, Code, Op0, Op1, Builder); + return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); } } @@ -756,24 +748,12 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } - - // (icmp slt A, 0) & (icmp slt B, 0) --> (icmp slt (A&B), 0) - if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) { - Value *NewAnd = Builder->CreateAnd(Val, Val2); - return Builder->CreateICmp(LHSCC, NewAnd, LHSCst); - } - - // (icmp sgt A, -1) & (icmp sgt B, -1) --> (icmp sgt (A|B), -1) - if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } } // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. - if (LHSCC == RHSCC && ICmpInst::isEquality(LHSCC) && + if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC && LHS->hasOneUse() && RHS->hasOneUse()) { Value *V; ConstantInt *AndCst, *SmallCst = 0, *BigCst = 0; @@ -805,7 +785,7 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { } } } - + // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. if (Val != Val2) return 0; @@ -1382,13 +1362,8 @@ static bool CollectBSwapParts(Value *V, int OverallLeftShift, uint32_t ByteMask, // part of the value (e.g. byte 3) then it must be shifted right. If from the // low part, it must be shifted left. unsigned DestByteNo = InputByteNo + OverallLeftShift; - if (InputByteNo < ByteValues.size()/2) { - if (ByteValues.size()-1-DestByteNo != InputByteNo) - return true; - } else { - if (ByteValues.size()-1-DestByteNo != InputByteNo) - return true; - } + if (ByteValues.size()-1-DestByteNo != InputByteNo) + return true; // If the destination byte value is already defined, the values are or'd // together, which isn't a bswap (unless it's an or of the same bits). @@ -1469,7 +1444,7 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *Op0 = LHS->getOperand(0), *Op1 = LHS->getOperand(1); unsigned Code = getICmpCode(LHS) | getICmpCode(RHS); bool isSigned = LHS->isSigned() || RHS->isSigned(); - return getICmpValue(isSigned, Code, Op0, Op1, Builder); + return getNewICmpValue(isSigned, Code, Op0, Op1, Builder); } } @@ -1490,18 +1465,6 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { Value *NewOr = Builder->CreateOr(Val, Val2); return Builder->CreateICmp(LHSCC, NewOr, LHSCst); } - - // (icmp slt A, 0) | (icmp slt B, 0) --> (icmp slt (A|B), 0) - if (LHSCC == ICmpInst::ICMP_SLT && LHSCst->isZero()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); - } - - // (icmp sgt A, -1) | (icmp sgt B, -1) --> (icmp sgt (A&B), -1) - if (LHSCC == ICmpInst::ICMP_SGT && LHSCst->isAllOnesValue()) { - Value *NewAnd = Builder->CreateAnd(Val, Val2); - return Builder->CreateICmp(LHSCC, NewAnd, LHSCst); - } } // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) @@ -1586,7 +1549,6 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS) { case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true return ConstantInt::getTrue(LHS->getContext()); } - break; case ICmpInst::ICMP_ULT: switch (RHSCC) { default: llvm_unreachable("Unknown integer condition code!"); @@ -1962,8 +1924,11 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } // Canonicalize xor to the RHS. - if (match(Op0, m_Xor(m_Value(), m_Value()))) + bool SwappedForXor = false; + if (match(Op0, m_Xor(m_Value(), m_Value()))) { std::swap(Op0, Op1); + SwappedForXor = true; + } // A | ( A ^ B) -> A | B // A | (~A ^ B) -> A | ~B @@ -1994,6 +1959,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateOr(Not, Op0); } + if (SwappedForXor) + std::swap(Op0, Op1); + if (ICmpInst *RHS = dyn_cast<ICmpInst>(I.getOperand(1))) if (ICmpInst *LHS = dyn_cast<ICmpInst>(I.getOperand(0))) if (Value *Res = FoldOrOfICmps(LHS, RHS)) @@ -2281,7 +2249,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { unsigned Code = getICmpCode(LHS) ^ getICmpCode(RHS); bool isSigned = LHS->isSigned() || RHS->isSigned(); return ReplaceInstUsesWith(I, - getICmpValue(isSigned, Code, Op0, Op1, Builder)); + getNewICmpValue(isSigned, Code, Op0, Op1, + Builder)); } } |