diff options
Diffstat (limited to 'lib/Analysis/InstructionSimplify.cpp')
-rw-r--r-- | lib/Analysis/InstructionSimplify.cpp | 167 |
1 files changed, 88 insertions, 79 deletions
diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 4a713f441ce87..5728887cc1e9c 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -1317,7 +1317,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0, // If all valid bits in the shift amount are known zero, the first operand is // unchanged. unsigned NumValidShiftBits = Log2_32_Ceil(BitWidth); - if (Known.Zero.countTrailingOnes() >= NumValidShiftBits) + if (Known.countMinTrailingZeros() >= NumValidShiftBits) return Op0; return nullptr; @@ -1536,7 +1536,7 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1, auto Range0 = ConstantRange::makeExactICmpRegion(Cmp0->getPredicate(), *C0); auto Range1 = ConstantRange::makeExactICmpRegion(Cmp1->getPredicate(), *C1); - // For and-of-comapares, check if the intersection is empty: + // For and-of-compares, check if the intersection is empty: // (icmp X, C0) && (icmp X, C1) --> empty set --> false if (IsAnd && Range0.intersectWith(Range1).isEmptySet()) return getFalse(Cmp0->getType()); @@ -1870,6 +1870,24 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q, match(Op1, m_c_And(m_Not(m_Specific(A)), m_Specific(B))))) return Op0; + // (A & B) | (~A ^ B) -> (~A ^ B) + // (B & A) | (~A ^ B) -> (~A ^ B) + // (A & B) | (B ^ ~A) -> (B ^ ~A) + // (B & A) | (B ^ ~A) -> (B ^ ~A) + if (match(Op0, m_And(m_Value(A), m_Value(B))) && + (match(Op1, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op1, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op1; + + // (~A ^ B) | (A & B) -> (~A ^ B) + // (~A ^ B) | (B & A) -> (~A ^ B) + // (B ^ ~A) | (A & B) -> (B ^ ~A) + // (B ^ ~A) | (B & A) -> (B ^ ~A) + if (match(Op1, m_And(m_Value(A), m_Value(B))) && + (match(Op0, m_c_Xor(m_Specific(A), m_Not(m_Specific(B)))) || + match(Op0, m_c_Xor(m_Not(m_Specific(A)), m_Specific(B))))) + return Op0; + if (Value *V = simplifyAndOrOfICmps(Op0, Op1, false)) return V; @@ -2286,7 +2304,6 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, return nullptr; Type *ITy = GetCompareTy(LHS); // The return type. - bool LHSKnownNonNegative, LHSKnownNegative; switch (Pred) { default: llvm_unreachable("Unknown ICmp predicate!"); @@ -2304,39 +2321,41 @@ static Value *simplifyICmpWithZero(CmpInst::Predicate Pred, Value *LHS, if (isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; - case ICmpInst::ICMP_SLT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + case ICmpInst::ICMP_SLT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getFalse(ITy); break; - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SLE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getTrue(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getFalse(ITy); break; - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGE: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative) + if (LHSKnown.isNonNegative()) return getTrue(ITy); break; - case ICmpInst::ICMP_SGT: - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNegative) + } + case ICmpInst::ICMP_SGT: { + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNegative()) return getFalse(ITy); - if (LHSKnownNonNegative && isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) + if (LHSKnown.isNonNegative() && + isKnownNonZero(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT)) return getTrue(ITy); break; } + } return nullptr; } @@ -2535,6 +2554,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS, return nullptr; } +/// TODO: A large part of this logic is duplicated in InstCombine's +/// foldICmpBinOp(). We should be able to share that and avoid the code +/// duplication. static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, Value *RHS, const SimplifyQuery &Q, unsigned MaxRecurse) { @@ -2616,15 +2638,11 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getTrue(ITy); if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGE) { - bool RHSKnownNonNegative, RHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(RHS, RHSKnownNonNegative, RHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (RHSKnownNonNegative && YKnownNegative) + KnownBits RHSKnown = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (RHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SLT ? getTrue(ITy) : getFalse(ITy); - if (RHSKnownNegative || YKnownNonNegative) + if (RHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SLT ? getFalse(ITy) : getTrue(ITy); } } @@ -2636,15 +2654,11 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, return getFalse(ITy); if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLE) { - bool LHSKnownNonNegative, LHSKnownNegative; - bool YKnownNonNegative, YKnownNegative; - ComputeSignBit(LHS, LHSKnownNonNegative, LHSKnownNegative, Q.DL, 0, - Q.AC, Q.CxtI, Q.DT); - ComputeSignBit(Y, YKnownNonNegative, YKnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (LHSKnownNonNegative && YKnownNegative) + KnownBits LHSKnown = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + KnownBits YKnown = computeKnownBits(Y, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (LHSKnown.isNonNegative() && YKnown.isNegative()) return Pred == ICmpInst::ICMP_SGT ? getTrue(ITy) : getFalse(ITy); - if (LHSKnownNegative || YKnownNonNegative) + if (LHSKnown.isNegative() || YKnown.isNonNegative()) return Pred == ICmpInst::ICMP_SGT ? getFalse(ITy) : getTrue(ITy); } } @@ -2691,28 +2705,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred (urem X, Y), Y if (LBO && match(LBO, m_URem(m_Value(), m_Specific(RHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getFalse(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(RHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(RHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2722,28 +2735,27 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, // icmp pred X, (urem Y, X) if (RBO && match(RBO, m_URem(m_Value(), m_Specific(LHS)))) { - bool KnownNonNegative, KnownNegative; switch (Pred) { default: break; case ICmpInst::ICMP_SGT: - case ICmpInst::ICMP_SGE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SGE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_NE: case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_UGE: return getTrue(ITy); case ICmpInst::ICMP_SLT: - case ICmpInst::ICMP_SLE: - ComputeSignBit(LHS, KnownNonNegative, KnownNegative, Q.DL, 0, Q.AC, - Q.CxtI, Q.DT); - if (!KnownNonNegative) + case ICmpInst::ICMP_SLE: { + KnownBits Known = computeKnownBits(LHS, Q.DL, 0, Q.AC, Q.CxtI, Q.DT); + if (!Known.isNonNegative()) break; LLVM_FALLTHROUGH; + } case ICmpInst::ICMP_EQ: case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_ULE: @@ -2815,10 +2827,19 @@ static Value *simplifyICmpWithBinOp(CmpInst::Predicate Pred, Value *LHS, break; case Instruction::UDiv: case Instruction::LShr: - if (ICmpInst::isSigned(Pred)) + if (ICmpInst::isSigned(Pred) || !LBO->isExact() || !RBO->isExact()) break; - LLVM_FALLTHROUGH; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::SDiv: + if (!ICmpInst::isEquality(Pred) || !LBO->isExact() || !RBO->isExact()) + break; + if (Value *V = SimplifyICmpInst(Pred, LBO->getOperand(0), + RBO->getOperand(0), Q, MaxRecurse - 1)) + return V; + break; case Instruction::AShr: if (!LBO->isExact() || !RBO->isExact()) break; @@ -4034,24 +4055,21 @@ Value *llvm::SimplifyCastInst(unsigned CastOpc, Value *Op, Type *Ty, /// match a root vector source operand that contains that element in the same /// vector lane (ie, the same mask index), so we can eliminate the shuffle(s). static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, - Constant *Mask, Value *RootVec, int RootElt, + int MaskVal, Value *RootVec, unsigned MaxRecurse) { if (!MaxRecurse--) return nullptr; // Bail out if any mask value is undefined. That kind of shuffle may be // simplified further based on demanded bits or other folds. - int MaskVal = ShuffleVectorInst::getMaskValue(Mask, RootElt); if (MaskVal == -1) return nullptr; // The mask value chooses which source operand we need to look at next. - Value *SourceOp; int InVecNumElts = Op0->getType()->getVectorNumElements(); - if (MaskVal < InVecNumElts) { - RootElt = MaskVal; - SourceOp = Op0; - } else { + int RootElt = MaskVal; + Value *SourceOp = Op0; + if (MaskVal >= InVecNumElts) { RootElt = MaskVal - InVecNumElts; SourceOp = Op1; } @@ -4061,7 +4079,7 @@ static Value *foldIdentityShuffles(int DestElt, Value *Op0, Value *Op1, if (auto *SourceShuf = dyn_cast<ShuffleVectorInst>(SourceOp)) { return foldIdentityShuffles( DestElt, SourceShuf->getOperand(0), SourceShuf->getOperand(1), - SourceShuf->getMask(), RootVec, RootElt, MaxRecurse); + SourceShuf->getMaskValue(RootElt), RootVec, MaxRecurse); } // TODO: Look through bitcasts? What if the bitcast changes the vector element @@ -4126,17 +4144,7 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, // second one. if (Op0Const && !Op1Const) { std::swap(Op0, Op1); - for (int &Idx : Indices) { - if (Idx == -1) - continue; - Idx = Idx < (int)InVecNumElts ? Idx + InVecNumElts : Idx - InVecNumElts; - assert(Idx >= 0 && Idx < (int)InVecNumElts * 2 && - "shufflevector mask index out of range"); - } - Mask = ConstantDataVector::get( - Mask->getContext(), - makeArrayRef(reinterpret_cast<uint32_t *>(Indices.data()), - MaskNumElts)); + ShuffleVectorInst::commuteShuffleMask(Indices, InVecNumElts); } // A shuffle of a splat is always the splat itself. Legal if the shuffle's @@ -4160,7 +4168,8 @@ static Value *SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask, for (unsigned i = 0; i != MaskNumElts; ++i) { // Note that recursion is limited for each vector element, so if any element // exceeds the limit, this will fail to simplify. - RootVec = foldIdentityShuffles(i, Op0, Op1, Mask, RootVec, i, MaxRecurse); + RootVec = + foldIdentityShuffles(i, Op0, Op1, Indices[i], RootVec, MaxRecurse); // We can't replace a widening/narrowing shuffle with one of its operands. if (!RootVec || RootVec->getType() != RetTy) |