diff options
Diffstat (limited to 'lib/Transforms/InstCombine/InstCombineAndOrXor.cpp')
-rw-r--r-- | lib/Transforms/InstCombine/InstCombineAndOrXor.cpp | 1192 |
1 files changed, 566 insertions, 626 deletions
diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index da5384a86aac..b2a41c699202 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -137,9 +137,8 @@ Value *InstCombiner::SimplifyBSwap(BinaryOperator &I) { } /// This handles expressions of the form ((val OP C1) & C2). Where -/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. Op is -/// guaranteed to be a binary operator. -Instruction *InstCombiner::OptAndOp(Instruction *Op, +/// the Op parameter is 'OP', OpRHS is 'C1', and AndRHS is 'C2'. +Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, ConstantInt *OpRHS, ConstantInt *AndRHS, BinaryOperator &TheAnd) { @@ -149,6 +148,7 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, Together = ConstantExpr::getAnd(AndRHS, OpRHS); switch (Op->getOpcode()) { + default: break; case Instruction::Xor: if (Op->hasOneUse()) { // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) @@ -159,13 +159,6 @@ Instruction *InstCombiner::OptAndOp(Instruction *Op, break; case Instruction::Or: if (Op->hasOneUse()){ - if (Together != OpRHS) { - // (X | C1) & C2 --> (X | (C1&C2)) & C2 - Value *Or = Builder->CreateOr(X, Together); - Or->takeName(Op); - return BinaryOperator::CreateAnd(Or, AndRHS); - } - ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together); if (TogetherCI && !TogetherCI->isZero()){ // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 @@ -302,178 +295,91 @@ Value *InstCombiner::insertRangeTest(Value *V, const APInt &Lo, const APInt &Hi, return Builder->CreateICmp(Pred, VMinusLo, HiMinusLo); } -/// Returns true iff Val consists of one contiguous run of 1s with any number -/// of 0s on either side. The 1s are allowed to wrap from LSB to MSB, -/// so 0x000FFF0, 0x0000FFFF, and 0xFF0000FF are all runs. 0x0F0F0000 is -/// not, since all 1s are not contiguous. -static bool isRunOfOnes(ConstantInt *Val, uint32_t &MB, uint32_t &ME) { - const APInt& V = Val->getValue(); - uint32_t BitWidth = Val->getType()->getBitWidth(); - if (!APIntOps::isShiftedMask(BitWidth, V)) return false; - - // look for the first zero bit after the run of ones - MB = BitWidth - ((V - 1) ^ V).countLeadingZeros(); - // look for the first non-zero bit - ME = V.getActiveBits(); - return true; -} - -/// This is part of an expression (LHS +/- RHS) & Mask, where isSub determines -/// whether the operator is a sub. If we can fold one of the following xforms: +/// Classify (icmp eq (A & B), C) and (icmp ne (A & B), C) as matching patterns +/// that can be simplified. +/// One of A and B is considered the mask. The other is the value. This is +/// described as the "AMask" or "BMask" part of the enum. If the enum contains +/// only "Mask", then both A and B can be considered masks. If A is the mask, +/// then it was proven that (A & C) == C. This is trivial if C == A or C == 0. +/// If both A and C are constants, this proof is also easy. +/// For the following explanations, we assume that A is the mask. /// -/// ((A & N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == Mask -/// ((A | N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 -/// ((A ^ N) +/- B) & Mask -> (A +/- B) & Mask iff N&Mask == 0 +/// "AllOnes" declares that the comparison is true only if (A & B) == A or all +/// bits of A are set in B. +/// Example: (icmp eq (A & 3), 3) -> AMask_AllOnes /// -/// return (A +/- B). +/// "AllZeros" declares that the comparison is true only if (A & B) == 0 or all +/// bits of A are cleared in B. +/// Example: (icmp eq (A & 3), 0) -> Mask_AllZeroes +/// +/// "Mixed" declares that (A & B) == C and C might or might not contain any +/// number of one bits and zero bits. +/// Example: (icmp eq (A & 3), 1) -> AMask_Mixed +/// +/// "Not" means that in above descriptions "==" should be replaced by "!=". +/// Example: (icmp ne (A & 3), 3) -> AMask_NotAllOnes /// -Value *InstCombiner::FoldLogicalPlusAnd(Value *LHS, Value *RHS, - ConstantInt *Mask, bool isSub, - Instruction &I) { - Instruction *LHSI = dyn_cast<Instruction>(LHS); - if (!LHSI || LHSI->getNumOperands() != 2 || - !isa<ConstantInt>(LHSI->getOperand(1))) return nullptr; - - ConstantInt *N = cast<ConstantInt>(LHSI->getOperand(1)); - - switch (LHSI->getOpcode()) { - default: return nullptr; - case Instruction::And: - if (ConstantExpr::getAnd(N, Mask) == Mask) { - // If the AndRHS is a power of two minus one (0+1+), this is simple. - if ((Mask->getValue().countLeadingZeros() + - Mask->getValue().countPopulation()) == - Mask->getValue().getBitWidth()) - break; - - // Otherwise, if Mask is 0+1+0+, and if B is known to have the low 0+ - // part, we don't need any explicit masks to take them out of A. If that - // is all N is, ignore it. - uint32_t MB = 0, ME = 0; - if (isRunOfOnes(Mask, MB, ME)) { // begin/end bit of run, inclusive - uint32_t BitWidth = cast<IntegerType>(RHS->getType())->getBitWidth(); - APInt Mask(APInt::getLowBitsSet(BitWidth, MB-1)); - if (MaskedValueIsZero(RHS, Mask, 0, &I)) - break; - } - } - return nullptr; - case Instruction::Or: - case Instruction::Xor: - // If the AndRHS is a power of two minus one (0+1+), and N&Mask == 0 - if ((Mask->getValue().countLeadingZeros() + - Mask->getValue().countPopulation()) == Mask->getValue().getBitWidth() - && ConstantExpr::getAnd(N, Mask)->isNullValue()) - break; - return nullptr; - } - - if (isSub) - return Builder->CreateSub(LHSI->getOperand(0), RHS, "fold"); - return Builder->CreateAdd(LHSI->getOperand(0), RHS, "fold"); -} - -/// enum for classifying (icmp eq (A & B), C) and (icmp ne (A & B), C) -/// One of A and B is considered the mask, the other the value. This is -/// described as the "AMask" or "BMask" part of the enum. If the enum -/// contains only "Mask", then both A and B can be considered masks. -/// If A is the mask, then it was proven, that (A & C) == C. This -/// is trivial if C == A, or C == 0. If both A and C are constants, this -/// proof is also easy. -/// For the following explanations we assume that A is the mask. -/// The part "AllOnes" declares, that the comparison is true only -/// if (A & B) == A, or all bits of A are set in B. -/// Example: (icmp eq (A & 3), 3) -> FoldMskICmp_AMask_AllOnes -/// The part "AllZeroes" declares, that the comparison is true only -/// if (A & B) == 0, or all bits of A are cleared in B. -/// Example: (icmp eq (A & 3), 0) -> FoldMskICmp_Mask_AllZeroes -/// The part "Mixed" declares, that (A & B) == C and C might or might not -/// contain any number of one bits and zero bits. -/// Example: (icmp eq (A & 3), 1) -> FoldMskICmp_AMask_Mixed -/// The Part "Not" means, that in above descriptions "==" should be replaced -/// by "!=". -/// Example: (icmp ne (A & 3), 3) -> FoldMskICmp_AMask_NotAllOnes /// If the mask A contains a single bit, then the following is equivalent: /// (icmp eq (A & B), A) equals (icmp ne (A & B), 0) /// (icmp ne (A & B), A) equals (icmp eq (A & B), 0) enum MaskedICmpType { - FoldMskICmp_AMask_AllOnes = 1, - FoldMskICmp_AMask_NotAllOnes = 2, - FoldMskICmp_BMask_AllOnes = 4, - FoldMskICmp_BMask_NotAllOnes = 8, - FoldMskICmp_Mask_AllZeroes = 16, - FoldMskICmp_Mask_NotAllZeroes = 32, - FoldMskICmp_AMask_Mixed = 64, - FoldMskICmp_AMask_NotMixed = 128, - FoldMskICmp_BMask_Mixed = 256, - FoldMskICmp_BMask_NotMixed = 512 + AMask_AllOnes = 1, + AMask_NotAllOnes = 2, + BMask_AllOnes = 4, + BMask_NotAllOnes = 8, + Mask_AllZeros = 16, + Mask_NotAllZeros = 32, + AMask_Mixed = 64, + AMask_NotMixed = 128, + BMask_Mixed = 256, + BMask_NotMixed = 512 }; -/// Return the set of pattern classes (from MaskedICmpType) -/// that (icmp SCC (A & B), C) satisfies. -static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, - ICmpInst::Predicate SCC) -{ +/// Return the set of patterns (from MaskedICmpType) that (icmp SCC (A & B), C) +/// satisfies. +static unsigned getMaskedICmpType(Value *A, Value *B, Value *C, + ICmpInst::Predicate Pred) { ConstantInt *ACst = dyn_cast<ConstantInt>(A); ConstantInt *BCst = dyn_cast<ConstantInt>(B); ConstantInt *CCst = dyn_cast<ConstantInt>(C); - bool icmp_eq = (SCC == ICmpInst::ICMP_EQ); - bool icmp_abit = (ACst && !ACst->isZero() && - ACst->getValue().isPowerOf2()); - bool icmp_bbit = (BCst && !BCst->isZero() && - BCst->getValue().isPowerOf2()); - unsigned result = 0; + bool IsEq = (Pred == ICmpInst::ICMP_EQ); + bool IsAPow2 = (ACst && !ACst->isZero() && ACst->getValue().isPowerOf2()); + bool IsBPow2 = (BCst && !BCst->isZero() && BCst->getValue().isPowerOf2()); + unsigned MaskVal = 0; if (CCst && CCst->isZero()) { // if C is zero, then both A and B qualify as mask - result |= (icmp_eq ? (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed)); - return result; + MaskVal |= (IsEq ? (Mask_AllZeros | AMask_Mixed | BMask_Mixed) + : (Mask_NotAllZeros | AMask_NotMixed | BMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (AMask_NotAllOnes | AMask_NotMixed) + : (AMask_AllOnes | AMask_Mixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (BMask_NotAllOnes | BMask_NotMixed) + : (BMask_AllOnes | BMask_Mixed)); + return MaskVal; } + if (A == C) { - result |= (icmp_eq ? (FoldMskICmp_AMask_AllOnes | - FoldMskICmp_AMask_Mixed) - : (FoldMskICmp_AMask_NotAllOnes | - FoldMskICmp_AMask_NotMixed)); - if (icmp_abit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_AMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_AMask_Mixed)); - } else if (ACst && CCst && - ConstantExpr::getAnd(ACst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_AMask_Mixed - : FoldMskICmp_AMask_NotMixed); + MaskVal |= (IsEq ? (AMask_AllOnes | AMask_Mixed) + : (AMask_NotAllOnes | AMask_NotMixed)); + if (IsAPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | AMask_NotMixed) + : (Mask_AllZeros | AMask_Mixed)); + } else if (ACst && CCst && ConstantExpr::getAnd(ACst, CCst) == CCst) { + MaskVal |= (IsEq ? AMask_Mixed : AMask_NotMixed); } + if (B == C) { - result |= (icmp_eq ? (FoldMskICmp_BMask_AllOnes | - FoldMskICmp_BMask_Mixed) - : (FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_BMask_NotMixed)); - if (icmp_bbit) - result |= (icmp_eq ? (FoldMskICmp_Mask_NotAllZeroes | - FoldMskICmp_BMask_NotMixed) - : (FoldMskICmp_Mask_AllZeroes | - FoldMskICmp_BMask_Mixed)); - } else if (BCst && CCst && - ConstantExpr::getAnd(BCst, CCst) == CCst) { - result |= (icmp_eq ? FoldMskICmp_BMask_Mixed - : FoldMskICmp_BMask_NotMixed); - } - return result; + MaskVal |= (IsEq ? (BMask_AllOnes | BMask_Mixed) + : (BMask_NotAllOnes | BMask_NotMixed)); + if (IsBPow2) + MaskVal |= (IsEq ? (Mask_NotAllZeros | BMask_NotMixed) + : (Mask_AllZeros | BMask_Mixed)); + } else if (BCst && CCst && ConstantExpr::getAnd(BCst, CCst) == CCst) { + MaskVal |= (IsEq ? BMask_Mixed : BMask_NotMixed); + } + + return MaskVal; } /// Convert an analysis of a masked ICmp into its equivalent if all boolean @@ -482,32 +388,30 @@ static unsigned getTypeOfMaskedICmp(Value* A, Value* B, Value* C, /// involves swapping those bits over. static unsigned conjugateICmpMask(unsigned Mask) { unsigned NewMask; - NewMask = (Mask & (FoldMskICmp_AMask_AllOnes | FoldMskICmp_BMask_AllOnes | - FoldMskICmp_Mask_AllZeroes | FoldMskICmp_AMask_Mixed | - FoldMskICmp_BMask_Mixed)) + NewMask = (Mask & (AMask_AllOnes | BMask_AllOnes | Mask_AllZeros | + AMask_Mixed | BMask_Mixed)) << 1; - NewMask |= - (Mask & (FoldMskICmp_AMask_NotAllOnes | FoldMskICmp_BMask_NotAllOnes | - FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_AMask_NotMixed | - FoldMskICmp_BMask_NotMixed)) - >> 1; + NewMask |= (Mask & (AMask_NotAllOnes | BMask_NotAllOnes | Mask_NotAllZeros | + AMask_NotMixed | BMask_NotMixed)) + >> 1; return NewMask; } -/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E) -/// Return the set of pattern classes (from MaskedICmpType) -/// that both LHS and RHS satisfy. -static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, - Value*& B, Value*& C, - Value*& D, Value*& E, - ICmpInst *LHS, ICmpInst *RHS, - ICmpInst::Predicate &LHSCC, - ICmpInst::Predicate &RHSCC) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) return 0; +/// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). +/// Return the set of pattern classes (from MaskedICmpType) that both LHS and +/// RHS satisfy. +static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, + Value *&D, Value *&E, ICmpInst *LHS, + ICmpInst *RHS, + ICmpInst::Predicate &PredL, + ICmpInst::Predicate &PredR) { + if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + return 0; // vectors are not (yet?) supported - if (LHS->getOperand(0)->getType()->isVectorTy()) return 0; + if (LHS->getOperand(0)->getType()->isVectorTy()) + return 0; // Here comes the tricky part: // LHS might be of the form L11 & L12 == X, X == L21 & L22, @@ -517,9 +421,9 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, // above. Value *L1 = LHS->getOperand(0); Value *L2 = LHS->getOperand(1); - Value *L11,*L12,*L21,*L22; + Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. - if (decomposeBitTestICmp(LHS, LHSCC, L11, L12, L2)) { + if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { // Look for ANDs in the LHS icmp. @@ -543,22 +447,26 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } // Bail if LHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(LHSCC)) + if (!ICmpInst::isEquality(PredL)) return 0; Value *R1 = RHS->getOperand(0); Value *R2 = RHS->getOperand(1); - Value *R11,*R12; - bool ok = false; - if (decomposeBitTestICmp(RHS, RHSCC, R11, R12, R2)) { + Value *R11, *R12; + bool Ok = false; + if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; + A = R11; + D = R12; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; + A = R12; + D = R11; } else { return 0; } - E = R2; R1 = nullptr; ok = true; + E = R2; + R1 = nullptr; + Ok = true; } else if (R1->getType()->isIntegerTy()) { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an @@ -568,46 +476,62 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R2; ok = true; + A = R11; + D = R12; + E = R2; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R2; ok = true; + A = R12; + D = R11; + E = R2; + Ok = true; } } // Bail if RHS was a icmp that can't be decomposed into an equality. - if (!ICmpInst::isEquality(RHSCC)) + if (!ICmpInst::isEquality(PredR)) return 0; // Look for ANDs on the right side of the RHS icmp. - if (!ok && R2->getType()->isIntegerTy()) { + if (!Ok && R2->getType()->isIntegerTy()) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; R12 = Constant::getAllOnesValue(R2->getType()); } if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { - A = R11; D = R12; E = R1; ok = true; + A = R11; + D = R12; + E = R1; + Ok = true; } else if (R12 == L11 || R12 == L12 || R12 == L21 || R12 == L22) { - A = R12; D = R11; E = R1; ok = true; + A = R12; + D = R11; + E = R1; + Ok = true; } else { return 0; } } - if (!ok) + if (!Ok) return 0; if (L11 == A) { - B = L12; C = L2; + B = L12; + C = L2; } else if (L12 == A) { - B = L11; C = L2; + B = L11; + C = L2; } else if (L21 == A) { - B = L22; C = L1; + B = L22; + C = L1; } else if (L22 == A) { - B = L21; C = L1; + B = L21; + C = L1; } - unsigned LeftType = getTypeOfMaskedICmp(A, B, C, LHSCC); - unsigned RightType = getTypeOfMaskedICmp(A, D, E, RHSCC); + unsigned LeftType = getMaskedICmpType(A, B, C, PredL); + unsigned RightType = getMaskedICmpType(A, D, E, PredR); return LeftType & RightType; } @@ -616,12 +540,14 @@ static unsigned foldLogOpOfMaskedICmpsHelper(Value*& A, static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, llvm::InstCombiner::BuilderTy *Builder) { Value *A = nullptr, *B = nullptr, *C = nullptr, *D = nullptr, *E = nullptr; - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); - unsigned Mask = foldLogOpOfMaskedICmpsHelper(A, B, C, D, E, LHS, RHS, - LHSCC, RHSCC); - if (Mask == 0) return nullptr; - assert(ICmpInst::isEquality(LHSCC) && ICmpInst::isEquality(RHSCC) && - "foldLogOpOfMaskedICmpsHelper must return an equality predicate."); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); + unsigned Mask = + getMaskedTypeForICmpPair(A, B, C, D, E, LHS, RHS, PredL, PredR); + if (Mask == 0) + return nullptr; + + assert(ICmpInst::isEquality(PredL) && ICmpInst::isEquality(PredR) && + "Expected equality predicates for masked type of icmps."); // In full generality: // (icmp (A & B) Op C) | (icmp (A & D) Op E) @@ -642,7 +568,7 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Mask = conjugateICmpMask(Mask); } - if (Mask & FoldMskICmp_Mask_AllZeroes) { + if (Mask & Mask_AllZeros) { // (icmp eq (A & B), 0) & (icmp eq (A & D), 0) // -> (icmp eq (A & (B|D)), 0) Value *NewOr = Builder->CreateOr(B, D); @@ -653,14 +579,14 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, Value *Zero = Constant::getNullValue(A->getType()); return Builder->CreateICmp(NewCC, NewAnd, Zero); } - if (Mask & FoldMskICmp_BMask_AllOnes) { + if (Mask & BMask_AllOnes) { // (icmp eq (A & B), B) & (icmp eq (A & D), D) // -> (icmp eq (A & (B|D)), (B|D)) Value *NewOr = Builder->CreateOr(B, D); Value *NewAnd = Builder->CreateAnd(A, NewOr); return Builder->CreateICmp(NewCC, NewAnd, NewOr); } - if (Mask & FoldMskICmp_AMask_AllOnes) { + if (Mask & AMask_AllOnes) { // (icmp eq (A & B), A) & (icmp eq (A & D), A) // -> (icmp eq (A & (B&D)), A) Value *NewAnd1 = Builder->CreateAnd(B, D); @@ -672,11 +598,13 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // their actual values. This isn't strictly necessary, just a "handle the // easy cases for now" decision. ConstantInt *BCst = dyn_cast<ConstantInt>(B); - if (!BCst) return nullptr; + if (!BCst) + return nullptr; ConstantInt *DCst = dyn_cast<ConstantInt>(D); - if (!DCst) return nullptr; + if (!DCst) + return nullptr; - if (Mask & (FoldMskICmp_Mask_NotAllZeroes | FoldMskICmp_BMask_NotAllOnes)) { + if (Mask & (Mask_NotAllZeros | BMask_NotAllOnes)) { // (icmp ne (A & B), 0) & (icmp ne (A & D), 0) and // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), 0) or (icmp ne (A & D), 0) @@ -689,7 +617,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_AMask_NotAllOnes) { + + if (Mask & AMask_NotAllOnes) { // (icmp ne (A & B), B) & (icmp ne (A & D), D) // -> (icmp ne (A & B), A) or (icmp ne (A & D), A) // Only valid if one of the masks is a superset of the other (check "B|D" is @@ -701,7 +630,8 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, else if (NewMask == DCst->getValue()) return RHS; } - if (Mask & FoldMskICmp_BMask_Mixed) { + + if (Mask & BMask_Mixed) { // (icmp eq (A & B), C) & (icmp eq (A & D), E) // We already know that B & C == C && D & E == E. // If we can prove that (B & D) & (C ^ E) == 0, that is, the bits of @@ -713,23 +643,28 @@ static Value *foldLogOpOfMaskedICmps(ICmpInst *LHS, ICmpInst *RHS, bool IsAnd, // (icmp ne (A & B), B) & (icmp eq (A & D), D) // with B and D, having a single bit set. ConstantInt *CCst = dyn_cast<ConstantInt>(C); - if (!CCst) return nullptr; + if (!CCst) + return nullptr; ConstantInt *ECst = dyn_cast<ConstantInt>(E); - if (!ECst) return nullptr; - if (LHSCC != NewCC) + if (!ECst) + return nullptr; + if (PredL != NewCC) CCst = cast<ConstantInt>(ConstantExpr::getXor(BCst, CCst)); - if (RHSCC != NewCC) + if (PredR != NewCC) ECst = cast<ConstantInt>(ConstantExpr::getXor(DCst, ECst)); + // If there is a conflict, we should actually return a false for the // whole construct. if (((BCst->getValue() & DCst->getValue()) & (CCst->getValue() ^ ECst->getValue())) != 0) return ConstantInt::get(LHS->getType(), !IsAnd); + Value *NewOr1 = Builder->CreateOr(B, D); Value *NewOr2 = ConstantExpr::getOr(CCst, ECst); Value *NewAnd = Builder->CreateAnd(A, NewOr1); return Builder->CreateICmp(NewCC, NewAnd, NewOr2); } + return nullptr; } @@ -789,12 +724,67 @@ Value *InstCombiner::simplifyRangeCheck(ICmpInst *Cmp0, ICmpInst *Cmp1, return Builder->CreateICmp(NewPred, Input, RangeEnd); } +static Value * +foldAndOrOfEqualityCmpsWithConstants(ICmpInst *LHS, ICmpInst *RHS, + bool JoinedByAnd, + InstCombiner::BuilderTy *Builder) { + Value *X = LHS->getOperand(0); + if (X != RHS->getOperand(0)) + return nullptr; + + const APInt *C1, *C2; + if (!match(LHS->getOperand(1), m_APInt(C1)) || + !match(RHS->getOperand(1), m_APInt(C2))) + return nullptr; + + // We only handle (X != C1 && X != C2) and (X == C1 || X == C2). + ICmpInst::Predicate Pred = LHS->getPredicate(); + if (Pred != RHS->getPredicate()) + return nullptr; + if (JoinedByAnd && Pred != ICmpInst::ICMP_NE) + return nullptr; + if (!JoinedByAnd && Pred != ICmpInst::ICMP_EQ) + return nullptr; + + // The larger unsigned constant goes on the right. + if (C1->ugt(*C2)) + std::swap(C1, C2); + + APInt Xor = *C1 ^ *C2; + if (Xor.isPowerOf2()) { + // If LHSC and RHSC differ by only one bit, then set that bit in X and + // compare against the larger constant: + // (X == C1 || X == C2) --> (X | (C1 ^ C2)) == C2 + // (X != C1 && X != C2) --> (X | (C1 ^ C2)) != C2 + // We choose an 'or' with a Pow2 constant rather than the inverse mask with + // 'and' because that may lead to smaller codegen from a smaller constant. + Value *Or = Builder->CreateOr(X, ConstantInt::get(X->getType(), Xor)); + return Builder->CreateICmp(Pred, Or, ConstantInt::get(X->getType(), *C2)); + } + + // Special case: get the ordering right when the values wrap around zero. + // Ie, we assumed the constants were unsigned when swapping earlier. + if (*C1 == 0 && C2->isAllOnesValue()) + std::swap(C1, C2); + + if (*C1 == *C2 - 1) { + // (X == 13 || X == 14) --> X - 13 <=u 1 + // (X != 13 && X != 14) --> X - 13 >u 1 + // An 'add' is the canonical IR form, so favor that over a 'sub'. + Value *Add = Builder->CreateAdd(X, ConstantInt::get(X->getType(), -(*C1))); + auto NewPred = JoinedByAnd ? ICmpInst::ICMP_UGT : ICmpInst::ICMP_ULE; + return Builder->CreateICmp(NewPred, Add, ConstantInt::get(X->getType(), 1)); + } + + return nullptr; +} + /// Fold (icmp)&(icmp) if possible. Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // (icmp1 A, B) & (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -819,86 +809,90 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/false)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, true, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) & (icmp2 B, C2). - Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (!LHSCst || !RHSCst) return nullptr; + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ult A, C) & (icmp ult B, C) --> (icmp ult (A|B), C) // where C is a power of 2 or // (icmp eq A, 0) & (icmp eq B, 0) --> (icmp eq (A|B), 0) - if ((LHSCC == ICmpInst::ICMP_ULT && LHSCst->getValue().isPowerOf2()) || - (LHSCC == ICmpInst::ICMP_EQ && LHSCst->isZero())) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if ((PredL == ICmpInst::ICMP_ULT && LHSC->getValue().isPowerOf2()) || + (PredL == ICmpInst::ICMP_EQ && LHSC->isZero())) { + Value *NewOr = Builder->CreateOr(LHS0, RHS0); + return Builder->CreateICmp(PredL, NewOr, LHSC); } } // (trunc x) == C1 & (and x, CA) == C2 -> (and x, CA|CMAX) == C1|C2 // where CMAX is the all ones value for the truncated type, // iff the lower bits of C2 and CA are zero. - if (LHSCC == ICmpInst::ICMP_EQ && LHSCC == RHSCC && - LHS->hasOneUse() && RHS->hasOneUse()) { + if (PredL == ICmpInst::ICMP_EQ && PredL == PredR && LHS->hasOneUse() && + RHS->hasOneUse()) { Value *V; - ConstantInt *AndCst, *SmallCst = nullptr, *BigCst = nullptr; + ConstantInt *AndC, *SmallC = nullptr, *BigC = nullptr; // (trunc x) == C1 & (and x, CA) == C2 // (and x, CA) == C2 & (trunc x) == C1 - if (match(Val2, m_Trunc(m_Value(V))) && - match(Val, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = RHSCst; - BigCst = LHSCst; - } else if (match(Val, m_Trunc(m_Value(V))) && - match(Val2, m_And(m_Specific(V), m_ConstantInt(AndCst)))) { - SmallCst = LHSCst; - BigCst = RHSCst; + if (match(RHS0, m_Trunc(m_Value(V))) && + match(LHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = RHSC; + BigC = LHSC; + } else if (match(LHS0, m_Trunc(m_Value(V))) && + match(RHS0, m_And(m_Specific(V), m_ConstantInt(AndC)))) { + SmallC = LHSC; + BigC = RHSC; } - if (SmallCst && BigCst) { - unsigned BigBitSize = BigCst->getType()->getBitWidth(); - unsigned SmallBitSize = SmallCst->getType()->getBitWidth(); + if (SmallC && BigC) { + unsigned BigBitSize = BigC->getType()->getBitWidth(); + unsigned SmallBitSize = SmallC->getType()->getBitWidth(); // Check that the low bits are zero. APInt Low = APInt::getLowBitsSet(BigBitSize, SmallBitSize); - if ((Low & AndCst->getValue()) == 0 && (Low & BigCst->getValue()) == 0) { - Value *NewAnd = Builder->CreateAnd(V, Low | AndCst->getValue()); - APInt N = SmallCst->getValue().zext(BigBitSize) | BigCst->getValue(); - Value *NewVal = ConstantInt::get(AndCst->getType()->getContext(), N); - return Builder->CreateICmp(LHSCC, NewAnd, NewVal); + if ((Low & AndC->getValue()) == 0 && (Low & BigC->getValue()) == 0) { + Value *NewAnd = Builder->CreateAnd(V, Low | AndC->getValue()); + APInt N = SmallC->getValue().zext(BigBitSize) | BigC->getValue(); + Value *NewVal = ConstantInt::get(AndC->getType()->getContext(), N); + return Builder->CreateICmp(PredL, NewAnd, NewVal); } } } // From here on, we only handle: // (icmp1 A, C1) & (icmp2 A, C2) --> something simpler. - if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) & (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -907,113 +901,95 @@ Value *InstCombiner::FoldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS) { // icmp eq, icmp ne, icmp [su]lt, and icmp [SU]gt here. We also know // (from the icmp folding check above), that the two constants // are not equal and that the larger constant is on the RHS - assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 - case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 - case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_NE: // (X == 13 & X != 15) -> X == 13 + case ICmpInst::ICMP_ULT: // (X == 13 & X < 15) -> X == 13 + case ICmpInst::ICMP_SLT: // (X == 13 & X < 15) -> X == 13 return LHS; } case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_ULT: - if (LHSCst == SubOne(RHSCst)) // (X != 13 & X u< 14) -> X < 13 - return Builder->CreateICmpULT(Val, LHSCst); - if (LHSCst->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (LHSC == SubOne(RHSC)) // (X != 13 & X u< 14) -> X < 13 + return Builder->CreateICmpULT(LHS0, LHSC); + if (LHSC->isNullValue()) // (X != 0 & X u< 14) -> X-1 u< 13 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - break; // (X != 13 & X u< 15) -> no change + break; // (X != 13 & X u< 15) -> no change case ICmpInst::ICMP_SLT: - if (LHSCst == SubOne(RHSCst)) // (X != 13 & X s< 14) -> X < 13 - return Builder->CreateICmpSLT(Val, LHSCst); - break; // (X != 13 & X s< 15) -> no change - case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 - case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15 + if (LHSC == SubOne(RHSC)) // (X != 13 & X s< 14) -> X < 13 + return Builder->CreateICmpSLT(LHS0, LHSC); + break; // (X != 13 & X s< 15) -> no change + case ICmpInst::ICMP_EQ: // (X != 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X != 13 & X u> 15) -> X u> 15 + case ICmpInst::ICMP_SGT: // (X != 13 & X s> 15) -> X s> 15 return RHS; case ICmpInst::ICMP_NE: - // Special case to get the ordering right when the values wrap around - // zero. - if (LHSCst->getValue() == 0 && RHSCst->getValue().isAllOnesValue()) - std::swap(LHSCst, RHSCst); - if (LHSCst == SubOne(RHSCst)){// (X != 13 & X != 14) -> X-13 >u 1 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - return Builder->CreateICmpUGT(Add, ConstantInt::get(Add->getType(), 1), - Val->getName()+".cmp"); - } - break; // (X != 13 & X != 15) -> no change + // Potential folds for this case should already be handled. + break; } break; case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false - case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 & X == 15) -> false + case ICmpInst::ICMP_UGT: // (X u< 13 & X u> 15) -> false return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case ICmpInst::ICMP_SGT: // (X u< 13 & X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 - case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 + case ICmpInst::ICMP_NE: // (X u< 13 & X != 15) -> X u< 13 + case ICmpInst::ICMP_ULT: // (X u< 13 & X u< 15) -> X u< 13 return LHS; - case ICmpInst::ICMP_SLT: // (X u< 13 & X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_UGT: // (X s< 13 & X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 - case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_NE: // (X s< 13 & X != 15) -> X < 13 + case ICmpInst::ICMP_SLT: // (X s< 13 & X s< 15) -> X < 13 return LHS; - case ICmpInst::ICMP_ULT: // (X s< 13 & X u< 15) -> no change - break; } break; case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_UGT: // (X u> 13 & X u> 15) -> X u> 15 return RHS; - case ICmpInst::ICMP_SGT: // (X u> 13 & X s> 15) -> no change - break; case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X u> 13 & X != 14) -> X u> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X u> 13 & X != 15) -> no change - case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), + if (RHSC == AddOne(LHSC)) // (X u> 13 & X != 14) -> X u> 14 + return Builder->CreateICmp(PredL, LHS0, RHSC); + break; // (X u> 13 & X != 15) -> no change + case ICmpInst::ICMP_ULT: // (X u> 13 & X u< 15) -> (X-14) <u 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), false, true); - case ICmpInst::ICMP_SLT: // (X u> 13 & X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 - case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 & X == 15) -> X == 15 + case ICmpInst::ICMP_SGT: // (X s> 13 & X s> 15) -> X s> 15 return RHS; - case ICmpInst::ICMP_UGT: // (X s> 13 & X u> 15) -> no change - break; case ICmpInst::ICMP_NE: - if (RHSCst == AddOne(LHSCst)) // (X s> 13 & X != 14) -> X s> 14 - return Builder->CreateICmp(LHSCC, Val, RHSCst); - break; // (X s> 13 & X != 15) -> no change - case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 - return insertRangeTest(Val, LHSCst->getValue() + 1, RHSCst->getValue(), - true, true); - case ICmpInst::ICMP_ULT: // (X s> 13 & X u< 15) -> no change - break; + if (RHSC == AddOne(LHSC)) // (X s> 13 & X != 14) -> X s> 14 + return Builder->CreateICmp(PredL, LHS0, RHSC); + break; // (X s> 13 & X != 15) -> no change + case ICmpInst::ICMP_SLT: // (X s> 13 & X s< 15) -> (X-14) s< 1 + return insertRangeTest(LHS0, LHSC->getValue() + 1, RHSC->getValue(), true, + true); } break; } @@ -1314,39 +1290,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { break; } - case Instruction::Add: - // ((A & N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) + B) & AndRHS -> (A + B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - if (Value *V = FoldLogicalPlusAnd(Op0RHS, Op0LHS, AndRHS, false, I)) - return BinaryOperator::CreateAnd(V, AndRHS); // Add commutes - break; - case Instruction::Sub: - // ((A & N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == AndRHS. - // ((A | N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - // ((A ^ N) - B) & AndRHS -> (A - B) & AndRHS iff N&AndRHS == 0 - if (Value *V = FoldLogicalPlusAnd(Op0LHS, Op0RHS, AndRHS, true, I)) - return BinaryOperator::CreateAnd(V, AndRHS); - // -x & 1 -> x & 1 if (AndRHSMask == 1 && match(Op0LHS, m_Zero())) return BinaryOperator::CreateAnd(Op0RHS, AndRHS); - // (A - N) & AndRHS -> -N & AndRHS iff A&AndRHS==0 and AndRHS - // has 1's for all bits that the subtraction with A might affect. - if (Op0I->hasOneUse() && !match(Op0LHS, m_Zero())) { - uint32_t BitWidth = AndRHSMask.getBitWidth(); - uint32_t Zeros = AndRHSMask.countLeadingZeros(); - APInt Mask = APInt::getLowBitsSet(BitWidth, BitWidth - Zeros); - - if (MaskedValueIsZero(Op0LHS, Mask, 0, &I)) { - Value *NewNeg = Builder->CreateNeg(Op0RHS); - return BinaryOperator::CreateAnd(NewNeg, AndRHS); - } - } break; case Instruction::Shl: @@ -1361,6 +1309,33 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { break; } + // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth + // of X and OP behaves well when given trunc(C1) and X. + switch (Op0I->getOpcode()) { + default: + break; + case Instruction::Xor: + case Instruction::Or: + case Instruction::Mul: + case Instruction::Add: + case Instruction::Sub: + Value *X; + ConstantInt *C1; + if (match(Op0I, m_c_BinOp(m_ZExt(m_Value(X)), m_ConstantInt(C1)))) { + if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { + auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); + Value *BinOp; + if (isa<ZExtInst>(Op0LHS)) + BinOp = Builder->CreateBinOp(Op0I->getOpcode(), X, TruncC1); + else + BinOp = Builder->CreateBinOp(Op0I->getOpcode(), TruncC1, X); + auto *TruncC2 = ConstantExpr::getTrunc(AndRHS, X->getType()); + auto *And = Builder->CreateAnd(BinOp, TruncC2); + return new ZExtInst(And, I.getType()); + } + } + } + if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) if (Instruction *Res = OptAndOp(Op0I, Op0CI, AndRHS, I)) return Res; @@ -1381,10 +1356,11 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { return BinaryOperator::CreateAnd(NewCast, C3); } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } if (Instruction *DeMorgan = matchDeMorgansLaws(I, Builder)) return DeMorgan; @@ -1630,15 +1606,15 @@ static Value *matchSelectFromAndOr(Value *A, Value *C, Value *B, Value *D, /// Fold (icmp)|(icmp) if possible. Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction *CxtI) { - ICmpInst::Predicate LHSCC = LHS->getPredicate(), RHSCC = RHS->getPredicate(); + ICmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); // Fold (iszero(A & K1) | iszero(A & K2)) -> (A & (K1 | K2)) != (K1 | K2) // if K1 and K2 are a one-bit mask. - ConstantInt *LHSCst = dyn_cast<ConstantInt>(LHS->getOperand(1)); - ConstantInt *RHSCst = dyn_cast<ConstantInt>(RHS->getOperand(1)); + ConstantInt *LHSC = dyn_cast<ConstantInt>(LHS->getOperand(1)); + ConstantInt *RHSC = dyn_cast<ConstantInt>(RHS->getOperand(1)); - if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero() && - RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { + if (LHS->getPredicate() == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero() && + RHS->getPredicate() == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { BinaryOperator *LAnd = dyn_cast<BinaryOperator>(LHS->getOperand(0)); BinaryOperator *RAnd = dyn_cast<BinaryOperator>(RHS->getOperand(0)); @@ -1680,52 +1656,52 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // 4) LowRange1 ^ LowRange2 and HighRange1 ^ HighRange2 are one-bit mask. // This implies all values in the two ranges differ by exactly one bit. - if ((LHSCC == ICmpInst::ICMP_ULT || LHSCC == ICmpInst::ICMP_ULE) && - LHSCC == RHSCC && LHSCst && RHSCst && LHS->hasOneUse() && - RHS->hasOneUse() && LHSCst->getType() == RHSCst->getType() && - LHSCst->getValue() == (RHSCst->getValue())) { + if ((PredL == ICmpInst::ICMP_ULT || PredL == ICmpInst::ICMP_ULE) && + PredL == PredR && LHSC && RHSC && LHS->hasOneUse() && RHS->hasOneUse() && + LHSC->getType() == RHSC->getType() && + LHSC->getValue() == (RHSC->getValue())) { Value *LAdd = LHS->getOperand(0); Value *RAdd = RHS->getOperand(0); Value *LAddOpnd, *RAddOpnd; - ConstantInt *LAddCst, *RAddCst; - if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddCst))) && - match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddCst))) && - LAddCst->getValue().ugt(LHSCst->getValue()) && - RAddCst->getValue().ugt(LHSCst->getValue())) { - - APInt DiffCst = LAddCst->getValue() ^ RAddCst->getValue(); - if (LAddOpnd == RAddOpnd && DiffCst.isPowerOf2()) { - ConstantInt *MaxAddCst = nullptr; - if (LAddCst->getValue().ult(RAddCst->getValue())) - MaxAddCst = RAddCst; + ConstantInt *LAddC, *RAddC; + if (match(LAdd, m_Add(m_Value(LAddOpnd), m_ConstantInt(LAddC))) && + match(RAdd, m_Add(m_Value(RAddOpnd), m_ConstantInt(RAddC))) && + LAddC->getValue().ugt(LHSC->getValue()) && + RAddC->getValue().ugt(LHSC->getValue())) { + + APInt DiffC = LAddC->getValue() ^ RAddC->getValue(); + if (LAddOpnd == RAddOpnd && DiffC.isPowerOf2()) { + ConstantInt *MaxAddC = nullptr; + if (LAddC->getValue().ult(RAddC->getValue())) + MaxAddC = RAddC; else - MaxAddCst = LAddCst; + MaxAddC = LAddC; - APInt RRangeLow = -RAddCst->getValue(); - APInt RRangeHigh = RRangeLow + LHSCst->getValue(); - APInt LRangeLow = -LAddCst->getValue(); - APInt LRangeHigh = LRangeLow + LHSCst->getValue(); + APInt RRangeLow = -RAddC->getValue(); + APInt RRangeHigh = RRangeLow + LHSC->getValue(); + APInt LRangeLow = -LAddC->getValue(); + APInt LRangeHigh = LRangeLow + LHSC->getValue(); APInt LowRangeDiff = RRangeLow ^ LRangeLow; APInt HighRangeDiff = RRangeHigh ^ LRangeHigh; APInt RangeDiff = LRangeLow.sgt(RRangeLow) ? LRangeLow - RRangeLow : RRangeLow - LRangeLow; if (LowRangeDiff.isPowerOf2() && LowRangeDiff == HighRangeDiff && - RangeDiff.ugt(LHSCst->getValue())) { - Value *MaskCst = ConstantInt::get(LAddCst->getType(), ~DiffCst); + RangeDiff.ugt(LHSC->getValue())) { + Value *MaskC = ConstantInt::get(LAddC->getType(), ~DiffC); - Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskCst); - Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddCst); - return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSCst)); + Value *NewAnd = Builder->CreateAnd(LAddOpnd, MaskC); + Value *NewAdd = Builder->CreateAdd(NewAnd, MaxAddC); + return (Builder->CreateICmp(LHS->getPredicate(), NewAdd, LHSC)); } } } } // (icmp1 A, B) | (icmp2 A, B) --> (icmp3 A, B) - if (PredicatesFoldable(LHSCC, RHSCC)) { + if (PredicatesFoldable(PredL, PredR)) { if (LHS->getOperand(0) == RHS->getOperand(1) && LHS->getOperand(1) == RHS->getOperand(0)) LHS->swapOperands(); @@ -1743,25 +1719,25 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, false, Builder)) return V; - Value *Val = LHS->getOperand(0), *Val2 = RHS->getOperand(0); + Value *LHS0 = LHS->getOperand(0), *RHS0 = RHS->getOperand(0); if (LHS->hasOneUse() || RHS->hasOneUse()) { // (icmp eq B, 0) | (icmp ult A, B) -> (icmp ule A, B-1) // (icmp eq B, 0) | (icmp ugt B, A) -> (icmp ule A, B-1) Value *A = nullptr, *B = nullptr; - if (LHSCC == ICmpInst::ICMP_EQ && LHSCst && LHSCst->isZero()) { - B = Val; - if (RHSCC == ICmpInst::ICMP_ULT && Val == RHS->getOperand(1)) - A = Val2; - else if (RHSCC == ICmpInst::ICMP_UGT && Val == Val2) + if (PredL == ICmpInst::ICMP_EQ && LHSC && LHSC->isZero()) { + B = LHS0; + if (PredR == ICmpInst::ICMP_ULT && LHS0 == RHS->getOperand(1)) + A = RHS0; + else if (PredR == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = RHS->getOperand(1); } // (icmp ult A, B) | (icmp eq B, 0) -> (icmp ule A, B-1) // (icmp ugt B, A) | (icmp eq B, 0) -> (icmp ule A, B-1) - else if (RHSCC == ICmpInst::ICMP_EQ && RHSCst && RHSCst->isZero()) { - B = Val2; - if (LHSCC == ICmpInst::ICMP_ULT && Val2 == LHS->getOperand(1)) - A = Val; - else if (LHSCC == ICmpInst::ICMP_UGT && Val2 == Val) + else if (PredR == ICmpInst::ICMP_EQ && RHSC && RHSC->isZero()) { + B = RHS0; + if (PredL == ICmpInst::ICMP_ULT && RHS0 == LHS->getOperand(1)) + A = LHS0; + else if (PredL == ICmpInst::ICMP_UGT && LHS0 == RHS0) A = LHS->getOperand(1); } if (A && B) @@ -1778,54 +1754,58 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, if (Value *V = simplifyRangeCheck(RHS, LHS, /*Inverted=*/true)) return V; + if (Value *V = foldAndOrOfEqualityCmpsWithConstants(LHS, RHS, false, Builder)) + return V; + // This only handles icmp of constants: (icmp1 A, C1) | (icmp2 B, C2). - if (!LHSCst || !RHSCst) return nullptr; + if (!LHSC || !RHSC) + return nullptr; - if (LHSCst == RHSCst && LHSCC == RHSCC) { + if (LHSC == RHSC && PredL == PredR) { // (icmp ne A, 0) | (icmp ne B, 0) --> (icmp ne (A|B), 0) - if (LHSCC == ICmpInst::ICMP_NE && LHSCst->isZero()) { - Value *NewOr = Builder->CreateOr(Val, Val2); - return Builder->CreateICmp(LHSCC, NewOr, LHSCst); + if (PredL == ICmpInst::ICMP_NE && LHSC->isZero()) { + Value *NewOr = Builder->CreateOr(LHS0, RHS0); + return Builder->CreateICmp(PredL, NewOr, LHSC); } } // (icmp ult (X + CA), C1) | (icmp eq X, C2) -> (icmp ule (X + CA), C1) // iff C2 + CA == C1. - if (LHSCC == ICmpInst::ICMP_ULT && RHSCC == ICmpInst::ICMP_EQ) { - ConstantInt *AddCst; - if (match(Val, m_Add(m_Specific(Val2), m_ConstantInt(AddCst)))) - if (RHSCst->getValue() + AddCst->getValue() == LHSCst->getValue()) - return Builder->CreateICmpULE(Val, LHSCst); + if (PredL == ICmpInst::ICMP_ULT && PredR == ICmpInst::ICMP_EQ) { + ConstantInt *AddC; + if (match(LHS0, m_Add(m_Specific(RHS0), m_ConstantInt(AddC)))) + if (RHSC->getValue() + AddC->getValue() == LHSC->getValue()) + return Builder->CreateICmpULE(LHS0, LHSC); } // From here on, we only handle: // (icmp1 A, C1) | (icmp2 A, C2) --> something simpler. - if (Val != Val2) return nullptr; + if (LHS0 != RHS0) + return nullptr; - // ICMP_[US][GL]E X, CST is folded to ICMP_[US][GL]T elsewhere. - if (LHSCC == ICmpInst::ICMP_UGE || LHSCC == ICmpInst::ICMP_ULE || - RHSCC == ICmpInst::ICMP_UGE || RHSCC == ICmpInst::ICMP_ULE || - LHSCC == ICmpInst::ICMP_SGE || LHSCC == ICmpInst::ICMP_SLE || - RHSCC == ICmpInst::ICMP_SGE || RHSCC == ICmpInst::ICMP_SLE) + // ICMP_[US][GL]E X, C is folded to ICMP_[US][GL]T elsewhere. + if (PredL == ICmpInst::ICMP_UGE || PredL == ICmpInst::ICMP_ULE || + PredR == ICmpInst::ICMP_UGE || PredR == ICmpInst::ICMP_ULE || + PredL == ICmpInst::ICMP_SGE || PredL == ICmpInst::ICMP_SLE || + PredR == ICmpInst::ICMP_SGE || PredR == ICmpInst::ICMP_SLE) return nullptr; // We can't fold (ugt x, C) | (sgt x, C2). - if (!PredicatesFoldable(LHSCC, RHSCC)) + if (!PredicatesFoldable(PredL, PredR)) return nullptr; // Ensure that the larger constant is on the RHS. bool ShouldSwap; - if (CmpInst::isSigned(LHSCC) || - (ICmpInst::isEquality(LHSCC) && - CmpInst::isSigned(RHSCC))) - ShouldSwap = LHSCst->getValue().sgt(RHSCst->getValue()); + if (CmpInst::isSigned(PredL) || + (ICmpInst::isEquality(PredL) && CmpInst::isSigned(PredR))) + ShouldSwap = LHSC->getValue().sgt(RHSC->getValue()); else - ShouldSwap = LHSCst->getValue().ugt(RHSCst->getValue()); + ShouldSwap = LHSC->getValue().ugt(RHSC->getValue()); if (ShouldSwap) { std::swap(LHS, RHS); - std::swap(LHSCst, RHSCst); - std::swap(LHSCC, RHSCC); + std::swap(LHSC, RHSC); + std::swap(PredL, PredR); } // At this point, we know we have two icmp instructions @@ -1834,127 +1814,98 @@ Value *InstCombiner::FoldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, // ICMP_EQ, ICMP_NE, ICMP_LT, and ICMP_GT here. We also know (from the // icmp folding check above), that the two constants are not // equal. - assert(LHSCst != RHSCst && "Compares not folded above?"); + assert(LHSC != RHSC && "Compares not folded above?"); - switch (LHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredL) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); case ICmpInst::ICMP_EQ: - if (LHS->getOperand(0) == RHS->getOperand(0)) { - // if LHSCst and RHSCst differ only by one bit: - // (A == C1 || A == C2) -> (A | (C1 ^ C2)) == C2 - assert(LHSCst->getValue().ule(LHSCst->getValue())); - - APInt Xor = LHSCst->getValue() ^ RHSCst->getValue(); - if (Xor.isPowerOf2()) { - Value *Cst = Builder->getInt(Xor); - Value *Or = Builder->CreateOr(LHS->getOperand(0), Cst); - return Builder->CreateICmp(ICmpInst::ICMP_EQ, Or, RHSCst); - } - } - - if (LHSCst == SubOne(RHSCst)) { - // (X == 13 | X == 14) -> X-13 <u 2 - Constant *AddCST = ConstantExpr::getNeg(LHSCst); - Value *Add = Builder->CreateAdd(Val, AddCST, Val->getName()+".off"); - AddCST = ConstantExpr::getSub(AddOne(RHSCst), LHSCst); - return Builder->CreateICmpULT(Add, AddCST); - } - - break; // (X == 13 | X == 15) -> no change - case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change - case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change + // Potential folds for this case should already be handled. + break; + case ICmpInst::ICMP_UGT: // (X == 13 | X u> 14) -> no change + case ICmpInst::ICMP_SGT: // (X == 13 | X s> 14) -> no change break; - case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 - case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 + case ICmpInst::ICMP_NE: // (X == 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X == 13 | X u< 15) -> X u< 15 + case ICmpInst::ICMP_SLT: // (X == 13 | X s< 15) -> X s< 15 return RHS; } break; case ICmpInst::ICMP_NE: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 - case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 - case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X != 13 | X == 15) -> X != 13 + case ICmpInst::ICMP_UGT: // (X != 13 | X u> 15) -> X != 13 + case ICmpInst::ICMP_SGT: // (X != 13 | X s> 15) -> X != 13 return LHS; - case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true - case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true + case ICmpInst::ICMP_NE: // (X != 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X != 13 | X u< 15) -> true + case ICmpInst::ICMP_SLT: // (X != 13 | X s< 15) -> true return Builder->getTrue(); } case ICmpInst::ICMP_ULT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling + case ICmpInst::ICMP_UGT: // (X u< 13 | X u> 15) -> (X-13) u> 2 + // If RHSC is [us]MAXINT, it is always false. Not handling // this can cause overflow. - if (RHSCst->isMaxValue(false)) + if (RHSC->isMaxValue(false)) return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, false, false); - case ICmpInst::ICMP_SGT: // (X u< 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 + case ICmpInst::ICMP_NE: // (X u< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_ULT: // (X u< 13 | X u< 15) -> X u< 15 return RHS; - case ICmpInst::ICMP_SLT: // (X u< 13 | X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SLT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s< 13 | X == 14) -> no change break; - case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 - // If RHSCst is [us]MAXINT, it is always false. Not handling + case ICmpInst::ICMP_SGT: // (X s< 13 | X s> 15) -> (X-13) s> 2 + // If RHSC is [us]MAXINT, it is always false. Not handling // this can cause overflow. - if (RHSCst->isMaxValue(true)) + if (RHSC->isMaxValue(true)) return LHS; - return insertRangeTest(Val, LHSCst->getValue(), RHSCst->getValue() + 1, - true, false); - case ICmpInst::ICMP_UGT: // (X s< 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 - case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 + return insertRangeTest(LHS0, LHSC->getValue(), RHSC->getValue() + 1, true, + false); + case ICmpInst::ICMP_NE: // (X s< 13 | X != 15) -> X != 15 + case ICmpInst::ICMP_SLT: // (X s< 13 | X s< 15) -> X s< 15 return RHS; - case ICmpInst::ICMP_ULT: // (X s< 13 | X u< 15) -> no change - break; } break; case ICmpInst::ICMP_UGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 - case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X u> 13 | X == 15) -> X u> 13 + case ICmpInst::ICMP_UGT: // (X u> 13 | X u> 15) -> X u> 13 return LHS; - case ICmpInst::ICMP_SGT: // (X u> 13 | X s> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true - case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true + case ICmpInst::ICMP_NE: // (X u> 13 | X != 15) -> true + case ICmpInst::ICMP_ULT: // (X u> 13 | X u< 15) -> true return Builder->getTrue(); - case ICmpInst::ICMP_SLT: // (X u> 13 | X s< 15) -> no change - break; } break; case ICmpInst::ICMP_SGT: - switch (RHSCC) { - default: llvm_unreachable("Unknown integer condition code!"); - case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 - case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 + switch (PredR) { + default: + llvm_unreachable("Unknown integer condition code!"); + case ICmpInst::ICMP_EQ: // (X s> 13 | X == 15) -> X > 13 + case ICmpInst::ICMP_SGT: // (X s> 13 | X s> 15) -> X > 13 return LHS; - case ICmpInst::ICMP_UGT: // (X s> 13 | X u> 15) -> no change - break; - case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true - case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true + case ICmpInst::ICMP_NE: // (X s> 13 | X != 15) -> true + case ICmpInst::ICMP_SLT: // (X s> 13 | X s< 15) -> true return Builder->getTrue(); - case ICmpInst::ICMP_ULT: // (X s> 13 | X u< 15) -> no change - break; } break; } @@ -2100,17 +2051,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { ConstantInt *C1 = nullptr; Value *X = nullptr; - // (X & C1) | C2 --> (X | C2) & (C1|C2) - // iff (C1 & C2) == 0. - if (match(Op0, m_And(m_Value(X), m_ConstantInt(C1))) && - (RHS->getValue() & C1->getValue()) != 0 && - Op0->hasOneUse()) { - Value *Or = Builder->CreateOr(X, RHS); - Or->takeName(Op0); - return BinaryOperator::CreateAnd(Or, - Builder->getInt(RHS->getValue() | C1->getValue())); - } - // (X ^ C1) | C2 --> (X | C2) ^ (C1&~C2) if (match(Op0, m_Xor(m_Value(X), m_ConstantInt(C1))) && Op0->hasOneUse()) { @@ -2119,45 +2059,51 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { return BinaryOperator::CreateXor(Or, Builder->getInt(C1->getValue() & ~RHS->getValue())); } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } // Given an OR instruction, check to see if this is a bswap. if (Instruction *BSwap = MatchBSwap(I)) return BSwap; - Value *A = nullptr, *B = nullptr; - ConstantInt *C1 = nullptr, *C2 = nullptr; + { + Value *A; + const APInt *C; + // (X^C)|Y -> (X|Y)^C iff Y&C == 0 + if (match(Op0, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op1, *C, 0, &I)) { + Value *NOr = Builder->CreateOr(A, Op1); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + cast<Instruction>(Op0)->getOperand(1)); + } - // (X^C)|Y -> (X|Y)^C iff Y&C == 0 - if (Op0->hasOneUse() && - match(Op0, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op1, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op1); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); + // Y|(X^C) -> (X|Y)^C iff Y&C == 0 + if (match(Op1, m_OneUse(m_Xor(m_Value(A), m_APInt(C)))) && + MaskedValueIsZero(Op0, *C, 0, &I)) { + Value *NOr = Builder->CreateOr(A, Op0); + NOr->takeName(Op0); + return BinaryOperator::CreateXor(NOr, + cast<Instruction>(Op1)->getOperand(1)); + } } - // Y|(X^C) -> (X|Y)^C iff Y&C == 0 - if (Op1->hasOneUse() && - match(Op1, m_Xor(m_Value(A), m_ConstantInt(C1))) && - MaskedValueIsZero(Op0, C1->getValue(), 0, &I)) { - Value *NOr = Builder->CreateOr(A, Op0); - NOr->takeName(Op0); - return BinaryOperator::CreateXor(NOr, C1); - } + Value *A, *B; // ((~A & B) | A) -> (A | B) - if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1, m_Specific(A))) - return BinaryOperator::CreateOr(A, B); + if (match(Op0, m_c_And(m_Not(m_Specific(Op1)), m_Value(A)))) + return BinaryOperator::CreateOr(A, Op1); + if (match(Op1, m_c_And(m_Not(m_Specific(Op0)), m_Value(A)))) + return BinaryOperator::CreateOr(Op0, A); // ((A & B) | ~A) -> (~A | B) - if (match(Op0, m_And(m_Value(A), m_Value(B))) && - match(Op1, m_Not(m_Specific(A)))) - return BinaryOperator::CreateOr(Builder->CreateNot(A), B); + // The NOT is guaranteed to be in the RHS by complexity ordering. + if (match(Op1, m_Not(m_Value(A))) && + match(Op0, m_c_And(m_Specific(A), m_Value(B)))) + return BinaryOperator::CreateOr(Op1, B); // (A & ~B) | (A ^ B) -> (A ^ B) // (~B & A) | (A ^ B) -> (A ^ B) @@ -2177,8 +2123,8 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { Value *V1 = nullptr, *V2 = nullptr; - C1 = dyn_cast<ConstantInt>(C); - C2 = dyn_cast<ConstantInt>(D); + ConstantInt *C1 = dyn_cast<ConstantInt>(C); + ConstantInt *C2 = dyn_cast<ConstantInt>(D); if (C1 && C2) { // (A & C1)|(B & C2) if ((C1->getValue() & C2->getValue()) == 0) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) @@ -2403,6 +2349,7 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { // be simplified by a later pass either, so we try swapping the inner/outer // ORs in the hopes that we'll be able to simplify it this way. // (X|C) | V --> (X|V) | C + ConstantInt *C1; if (Op0->hasOneUse() && !isa<ConstantInt>(Op1) && match(Op0, m_Or(m_Value(A), m_ConstantInt(C1)))) { Value *Inner = Builder->CreateOr(A, Op1); @@ -2493,23 +2440,22 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - if (Constant *RHS = dyn_cast<Constant>(Op1)) { - if (RHS->isAllOnesValue() && Op0->hasOneUse()) - // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B - if (CmpInst *CI = dyn_cast<CmpInst>(Op0)) - return CmpInst::Create(CI->getOpcode(), - CI->getInversePredicate(), - CI->getOperand(0), CI->getOperand(1)); + // xor (cmp A, B), true = not (cmp A, B) = !cmp A, B + ICmpInst::Predicate Pred; + if (match(Op0, m_OneUse(m_Cmp(Pred, m_Value(), m_Value()))) && + match(Op1, m_AllOnes())) { + cast<CmpInst>(Op0)->setPredicate(CmpInst::getInversePredicate(Pred)); + return replaceInstUsesWith(I, Op0); } - if (ConstantInt *RHS = dyn_cast<ConstantInt>(Op1)) { + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) { if (CI->hasOneUse() && Op0C->hasOneUse()) { Instruction::CastOps Opcode = Op0C->getOpcode(); if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && - (RHS == ConstantExpr::getCast(Opcode, Builder->getTrue(), + (RHSC == ConstantExpr::getCast(Opcode, Builder->getTrue(), Op0C->getDestTy()))) { CI->setPredicate(CI->getInversePredicate()); return CastInst::Create(Opcode, CI, Op0C->getType()); @@ -2520,26 +2466,23 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { // ~(c-X) == X-c-1 == X+(-c-1) - if (Op0I->getOpcode() == Instruction::Sub && RHS->isAllOnesValue()) + if (Op0I->getOpcode() == Instruction::Sub && RHSC->isAllOnesValue()) if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) { Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); - Constant *ConstantRHS = ConstantExpr::getSub(NegOp0I0C, - ConstantInt::get(I.getType(), 1)); - return BinaryOperator::CreateAdd(Op0I->getOperand(1), ConstantRHS); + return BinaryOperator::CreateAdd(Op0I->getOperand(1), + SubOne(NegOp0I0C)); } if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { if (Op0I->getOpcode() == Instruction::Add) { // ~(X-c) --> (-c-1)-X - if (RHS->isAllOnesValue()) { + if (RHSC->isAllOnesValue()) { Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); - return BinaryOperator::CreateSub( - ConstantExpr::getSub(NegOp0CI, - ConstantInt::get(I.getType(), 1)), - Op0I->getOperand(0)); - } else if (RHS->getValue().isSignBit()) { + return BinaryOperator::CreateSub(SubOne(NegOp0CI), + Op0I->getOperand(0)); + } else if (RHSC->getValue().isSignBit()) { // (X + C) ^ signbit -> (X + C + signbit) - Constant *C = Builder->getInt(RHS->getValue() + Op0CI->getValue()); + Constant *C = Builder->getInt(RHSC->getValue() + Op0CI->getValue()); return BinaryOperator::CreateAdd(Op0I->getOperand(0), C); } @@ -2547,10 +2490,10 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), 0, &I)) { - Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHS); + Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC); // Anything in both C1 and C2 is known to be zero, remove it from // NewRHS. - Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHS); + Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC); NewRHS = ConstantExpr::getAnd(NewRHS, ConstantExpr::getNot(CommonBits)); Worklist.Add(Op0I); @@ -2568,7 +2511,7 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { E1->getOpcode() == Instruction::Xor && (C1 = dyn_cast<ConstantInt>(E1->getOperand(1)))) { // fold (C1 >> C2) ^ C3 - ConstantInt *C2 = Op0CI, *C3 = RHS; + ConstantInt *C2 = Op0CI, *C3 = RHSC; APInt FoldConst = C1->getValue().lshr(C2->getValue()); FoldConst ^= C3->getValue(); // Prepare the two operands. @@ -2582,27 +2525,26 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } } + } + if (isa<Constant>(Op1)) if (Instruction *FoldedLogic = foldOpWithConstantIntoOperand(I)) return FoldedLogic; - } - BinaryOperator *Op1I = dyn_cast<BinaryOperator>(Op1); - if (Op1I) { + { Value *A, *B; - if (match(Op1I, m_Or(m_Value(A), m_Value(B)))) { - if (A == Op0) { // B^(B|A) == (A|B)^B - Op1I->swapOperands(); - I.swapOperands(); - std::swap(Op0, Op1); - } else if (B == Op0) { // B^(A|B) == (A|B)^B + if (match(Op1, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { + if (A == Op0) { // A^(A|B) == A^(B|A) + cast<BinaryOperator>(Op1)->swapOperands(); + std::swap(A, B); + } + if (B == Op0) { // A^(B|A) == (B|A)^A I.swapOperands(); // Simplified below. std::swap(Op0, Op1); } - } else if (match(Op1I, m_And(m_Value(A), m_Value(B))) && - Op1I->hasOneUse()){ + } else if (match(Op1, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op0) { // A^(A&B) -> A^(B&A) - Op1I->swapOperands(); + cast<BinaryOperator>(Op1)->swapOperands(); std::swap(A, B); } if (B == Op0) { // A^(B&A) -> (B&A)^A @@ -2612,65 +2554,63 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { } } - BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0); - if (Op0I) { + { Value *A, *B; - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()) { + if (match(Op0, m_OneUse(m_Or(m_Value(A), m_Value(B))))) { if (A == Op1) // (B|A)^B == (A|B)^B std::swap(A, B); if (B == Op1) // (A|B)^B == A & ~B return BinaryOperator::CreateAnd(A, Builder->CreateNot(Op1)); - } else if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - Op0I->hasOneUse()){ + } else if (match(Op0, m_OneUse(m_And(m_Value(A), m_Value(B))))) { if (A == Op1) // (A&B)^A -> (B&A)^A std::swap(A, B); + const APInt *C; if (B == Op1 && // (B&A)^A == ~B & A - !isa<ConstantInt>(Op1)) { // Canonical form is (B&C)^C + !match(Op1, m_APInt(C))) { // Canonical form is (B&C)^C return BinaryOperator::CreateAnd(Builder->CreateNot(A), Op1); } } } - if (Op0I && Op1I) { + { Value *A, *B, *C, *D; // (A & B)^(A | B) -> A ^ B - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Or(m_Value(C), m_Value(D)))) { + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_Or(m_Value(C), m_Value(D)))) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } // (A | B)^(A & B) -> A ^ B - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Value(C), m_Value(D)))) { + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_And(m_Value(C), m_Value(D)))) { if ((A == C && B == D) || (A == D && B == C)) return BinaryOperator::CreateXor(A, B); } // (A | ~B) ^ (~A | B) -> A ^ B // (~B | A) ^ (~A | B) -> A ^ B - if (match(Op0I, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(Op0, m_c_Or(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_Or(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (~A | B) ^ (A | ~B) -> A ^ B - if (match(Op0I, m_Or(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { + if (match(Op0, m_Or(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_Or(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A & ~B) ^ (~A & B) -> A ^ B // (~B & A) ^ (~A & B) -> A ^ B - if (match(Op0I, m_c_And(m_Value(A), m_Not(m_Value(B)))) && - match(Op1I, m_And(m_Not(m_Specific(A)), m_Specific(B)))) + if (match(Op0, m_c_And(m_Value(A), m_Not(m_Value(B)))) && + match(Op1, m_And(m_Not(m_Specific(A)), m_Specific(B)))) return BinaryOperator::CreateXor(A, B); // (~A & B) ^ (A & ~B) -> A ^ B - if (match(Op0I, m_And(m_Not(m_Value(A)), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Not(m_Specific(B))))) { + if (match(Op0, m_And(m_Not(m_Value(A)), m_Value(B))) && + match(Op1, m_And(m_Specific(A), m_Not(m_Specific(B))))) { return BinaryOperator::CreateXor(A, B); } // (A ^ C)^(A | B) -> ((~A) & B) ^ C - if (match(Op0I, m_Xor(m_Value(D), m_Value(C))) && - match(Op1I, m_Or(m_Value(A), m_Value(B)))) { + if (match(Op0, m_Xor(m_Value(D), m_Value(C))) && + match(Op1, m_Or(m_Value(A), m_Value(B)))) { if (D == A) return BinaryOperator::CreateXor( Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2679,8 +2619,8 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Builder->CreateAnd(Builder->CreateNot(B), A), C); } // (A | B)^(A ^ C) -> ((~A) & B) ^ C - if (match(Op0I, m_Or(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Value(D), m_Value(C)))) { + if (match(Op0, m_Or(m_Value(A), m_Value(B))) && + match(Op1, m_Xor(m_Value(D), m_Value(C)))) { if (D == A) return BinaryOperator::CreateXor( Builder->CreateAnd(Builder->CreateNot(A), B), C); @@ -2689,12 +2629,12 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { Builder->CreateAnd(Builder->CreateNot(B), A), C); } // (A & B) ^ (A ^ B) -> (A | B) - if (match(Op0I, m_And(m_Value(A), m_Value(B))) && - match(Op1I, m_Xor(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + match(Op1, m_c_Xor(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); // (A ^ B) ^ (A & B) -> (A | B) - if (match(Op0I, m_Xor(m_Value(A), m_Value(B))) && - match(Op1I, m_And(m_Specific(A), m_Specific(B)))) + if (match(Op0, m_Xor(m_Value(A), m_Value(B))) && + match(Op1, m_c_And(m_Specific(A), m_Specific(B)))) return BinaryOperator::CreateOr(A, B); } |