Diffstat (limited to 'lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | lib/Analysis/ValueTracking.cpp | 648
1 file changed, 474 insertions, 174 deletions
diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index 0ef39163bda3..0446426c0e66 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -26,6 +26,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GuardUtils.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" @@ -118,14 +119,18 @@ struct Query { /// (all of which can call computeKnownBits), and so on. std::array<const Value *, MaxDepth> Excluded; + /// If true, it is safe to use metadata during simplification. + InstrInfoQuery IIQ; + unsigned NumExcluded = 0; Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE) {} + const DominatorTree *DT, bool UseInstrInfo, + OptimizationRemarkEmitter *ORE = nullptr) + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), IIQ(UseInstrInfo) {} Query(const Query &Q, const Value *NewExcl) - : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), + : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), IIQ(Q.IIQ), NumExcluded(Q.NumExcluded) { Excluded = Q.Excluded; Excluded[NumExcluded++] = NewExcl; @@ -165,9 +170,9 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE) { + OptimizationRemarkEmitter *ORE, bool UseInstrInfo) { ::computeKnownBits(V, Known, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); + Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); } static KnownBits computeKnownBits(const Value *V, unsigned Depth, @@ -177,15 +182,16 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, - OptimizationRemarkEmitter *ORE) { - return ::computeKnownBits(V, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT, ORE)); + OptimizationRemarkEmitter *ORE, + bool UseInstrInfo) { + return ::computeKnownBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE)); } bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { assert(LHS->getType() == RHS->getType() && "LHS and RHS should have the same type"); assert(LHS->getType()->isIntOrIntVectorTy() && @@ -201,8 +207,8 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); KnownBits LHSKnown(IT->getBitWidth()); KnownBits RHSKnown(IT->getBitWidth()); - computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT); - computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT); + computeKnownBits(LHS, LHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); + computeKnownBits(RHS, RHSKnown, DL, 0, AC, CxtI, DT, nullptr, UseInstrInfo); return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue(); } @@ -222,69 +228,71 @@ static bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, const Query &Q); bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL, - bool OrZero, - unsigned Depth, AssumptionCache *AC, - const 
Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownToBeAPowerOfTwo(V, OrZero, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + bool OrZero, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::isKnownToBeAPowerOfTwo( + V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q); bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownNonZero(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::isKnownNonZero(V, Depth, + Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } bool llvm::isKnownNonNegative(const Value *V, const DataLayout &DL, - unsigned Depth, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + unsigned Depth, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownBits Known = + computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); return Known.isNonNegative(); } bool llvm::isKnownPositive(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { + const DominatorTree *DT, bool UseInstrInfo) { if (auto *CI = dyn_cast<ConstantInt>(V)) return CI->getValue().isStrictlyPositive(); // TODO: We'd doing two recursive queries here. We should factor this such // that only a single query is needed. - return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT) && - isKnownNonZero(V, DL, Depth, AC, CxtI, DT); + return isKnownNonNegative(V, DL, Depth, AC, CxtI, DT, UseInstrInfo) && + isKnownNonZero(V, DL, Depth, AC, CxtI, DT, UseInstrInfo); } bool llvm::isKnownNegative(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits Known = computeKnownBits(V, DL, Depth, AC, CxtI, DT); + const DominatorTree *DT, bool UseInstrInfo) { + KnownBits Known = + computeKnownBits(V, DL, Depth, AC, CxtI, DT, nullptr, UseInstrInfo); return Known.isNegative(); } static bool isKnownNonEqual(const Value *V1, const Value *V2, const Query &Q); bool llvm::isKnownNonEqual(const Value *V1, const Value *V2, - const DataLayout &DL, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::isKnownNonEqual(V1, V2, Query(DL, AC, - safeCxtI(V1, safeCxtI(V2, CxtI)), - DT)); + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + return ::isKnownNonEqual(V1, V2, + Query(DL, AC, safeCxtI(V1, safeCxtI(V2, CxtI)), DT, + UseInstrInfo, /*ORE=*/nullptr)); } static bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, const Query &Q); bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask, - const DataLayout &DL, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, const DominatorTree *DT) { - return ::MaskedValueIsZero(V, Mask, Depth, - Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DataLayout &DL, unsigned Depth, + AssumptionCache *AC, const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { + return ::MaskedValueIsZero( + V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static unsigned ComputeNumSignBits(const Value 
*V, unsigned Depth, @@ -293,8 +301,9 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL, unsigned Depth, AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { - return ::ComputeNumSignBits(V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT)); + const DominatorTree *DT, bool UseInstrInfo) { + return ::ComputeNumSignBits( + V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo)); } static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, @@ -965,7 +974,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, switch (I->getOpcode()) { default: break; case Instruction::Load: - if (MDNode *MD = cast<LoadInst>(I)->getMetadata(LLVMContext::MD_range)) + if (MDNode *MD = + Q.IIQ.getMetadata(cast<LoadInst>(I), LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, Known); break; case Instruction::And: { @@ -1014,7 +1024,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, break; } case Instruction::Mul: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; @@ -1082,7 +1092,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // RHS from matchSelectPattern returns the negation part of abs pattern. // If the negate has an NSW flag we can assume the sign bit of the result // will be 0 because that makes abs(INT_MIN) undefined. - if (cast<Instruction>(RHS)->hasNoSignedWrap()) + if (Q.IIQ.hasNoSignedWrap(cast<Instruction>(RHS))) MaxHighZeros = 1; } @@ -1151,7 +1161,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, } case Instruction::Shl: { // (shl X, C1) & C2 == 0 iff (X & C2 >>u C1) == 0 - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); auto KZF = [NSW](const APInt &KnownZero, unsigned ShiftAmt) { APInt KZResult = KnownZero << ShiftAmt; KZResult.setLowBits(ShiftAmt); // Low bits known 0. 
@@ -1202,13 +1212,13 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, break; } case Instruction::Sub: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; } case Instruction::Add: { - bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I)); computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW, Known, Known2, Depth, Q); break; @@ -1369,7 +1379,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, Known3.countMinTrailingZeros())); auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU); - if (OverflowOp && OverflowOp->hasNoSignedWrap()) { + if (OverflowOp && Q.IIQ.hasNoSignedWrap(OverflowOp)) { // If initial value of recurrence is nonnegative, and we are adding // a nonnegative number with nsw, the result can only be nonnegative // or poison value regardless of the number of times we execute the @@ -1442,7 +1452,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // If range metadata is attached to this call, set known bits from that, // and then intersect with known bits based on other properties of the // function. - if (MDNode *MD = cast<Instruction>(I)->getMetadata(LLVMContext::MD_range)) + if (MDNode *MD = + Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range)) computeKnownBitsFromRangeMetadata(*MD, Known); if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) { computeKnownBits(RV, Known2, Depth + 1, Q); @@ -1495,6 +1506,27 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // of bits which might be set provided by popcnt KnownOne2. break; } + case Intrinsic::fshr: + case Intrinsic::fshl: { + const APInt *SA; + if (!match(I->getOperand(2), m_APInt(SA))) + break; + + // Normalize to funnel shift left. + uint64_t ShiftAmt = SA->urem(BitWidth); + if (II->getIntrinsicID() == Intrinsic::fshr) + ShiftAmt = BitWidth - ShiftAmt; + + KnownBits Known3(Known); + computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); + computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q); + + Known.Zero = + Known2.Zero.shl(ShiftAmt) | Known3.Zero.lshr(BitWidth - ShiftAmt); + Known.One = + Known2.One.shl(ShiftAmt) | Known3.One.lshr(BitWidth - ShiftAmt); + break; + } case Intrinsic::x86_sse42_crc32_64_64: Known.Zero.setBitsFrom(32); break; @@ -1722,7 +1754,8 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, // either the original power-of-two, a larger power-of-two or zero. 
if (match(V, m_Add(m_Value(X), m_Value(Y)))) { const OverflowingBinaryOperator *VOBO = cast<OverflowingBinaryOperator>(V); - if (OrZero || VOBO->hasNoUnsignedWrap() || VOBO->hasNoSignedWrap()) { + if (OrZero || Q.IIQ.hasNoUnsignedWrap(VOBO) || + Q.IIQ.hasNoSignedWrap(VOBO)) { if (match(X, m_And(m_Specific(Y), m_Value())) || match(X, m_And(m_Value(), m_Specific(Y)))) if (isKnownToBeAPowerOfTwo(Y, OrZero, Depth, Q)) @@ -1860,19 +1893,41 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V, (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) continue; + SmallVector<const User *, 4> WorkList; + SmallPtrSet<const User *, 4> Visited; for (auto *CmpU : U->users()) { - if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { - assert(BI->isConditional() && "uses a comparison!"); + assert(WorkList.empty() && "Should be!"); + if (Visited.insert(CmpU).second) + WorkList.push_back(CmpU); + + while (!WorkList.empty()) { + auto *Curr = WorkList.pop_back_val(); + + // If a user is an AND, add all its users to the work list. We only + // propagate "pred != null" condition through AND because it is only + // correct to assume that all conditions of AND are met in true branch. + // TODO: Support similar logic of OR and EQ predicate? + if (Pred == ICmpInst::ICMP_NE) + if (auto *BO = dyn_cast<BinaryOperator>(Curr)) + if (BO->getOpcode() == Instruction::And) { + for (auto *BOU : BO->users()) + if (Visited.insert(BOU).second) + WorkList.push_back(BOU); + continue; + } - BasicBlock *NonNullSuccessor = - BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); - BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); - if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) + if (const BranchInst *BI = dyn_cast<BranchInst>(Curr)) { + assert(BI->isConditional() && "uses a comparison!"); + + BasicBlock *NonNullSuccessor = + BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 1 : 0); + BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); + if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) + return true; + } else if (Pred == ICmpInst::ICMP_NE && isGuard(Curr) && + DT->dominates(cast<Instruction>(Curr), CtxI)) { return true; - } else if (Pred == ICmpInst::ICMP_NE && - match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && - DT->dominates(cast<Instruction>(CmpU), CtxI)) { - return true; + } } } } @@ -1937,7 +1992,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } if (auto *I = dyn_cast<Instruction>(V)) { - if (MDNode *Ranges = I->getMetadata(LLVMContext::MD_range)) { + if (MDNode *Ranges = Q.IIQ.getMetadata(I, LLVMContext::MD_range)) { // If the possible ranges don't contain zero, then the value is // definitely non-zero. if (auto *Ty = dyn_cast<IntegerType>(V->getType())) { @@ -1965,13 +2020,13 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { // A Load tagged with nonnull metadata is never null. 
if (const LoadInst *LI = dyn_cast<LoadInst>(V)) - if (LI->getMetadata(LLVMContext::MD_nonnull)) + if (Q.IIQ.getMetadata(LI, LLVMContext::MD_nonnull)) return true; - if (auto CS = ImmutableCallSite(V)) { - if (CS.isReturnNonNull()) + if (const auto *Call = dyn_cast<CallBase>(V)) { + if (Call->isReturnNonNull()) return true; - if (const auto *RP = getArgumentAliasingToReturnedPointer(CS)) + if (const auto *RP = getArgumentAliasingToReturnedPointer(Call)) return isKnownNonZero(RP, Depth, Q); } } @@ -2003,7 +2058,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (match(V, m_Shl(m_Value(X), m_Value(Y)))) { // shl nuw can't remove any non-zero bits. const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); - if (BO->hasNoUnsignedWrap()) + if (Q.IIQ.hasNoUnsignedWrap(BO)) return isKnownNonZero(X, Depth, Q); KnownBits Known(BitWidth); @@ -2078,7 +2133,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { const OverflowingBinaryOperator *BO = cast<OverflowingBinaryOperator>(V); // If X and Y are non-zero then so is X * Y as long as the multiplication // does not overflow. - if ((BO->hasNoSignedWrap() || BO->hasNoUnsignedWrap()) && + if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) && isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q)) return true; } @@ -2100,7 +2155,8 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (ConstantInt *C = dyn_cast<ConstantInt>(Start)) { if (!C->isZero() && !C->isNegative()) { ConstantInt *X; - if ((match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || + if (Q.IIQ.UseInstrInfo && + (match(Induction, m_NSWAdd(m_Specific(PN), m_ConstantInt(X))) || match(Induction, m_NUWAdd(m_Specific(PN), m_ConstantInt(X)))) && !X->isNegative()) return true; @@ -2174,6 +2230,36 @@ bool MaskedValueIsZero(const Value *V, const APInt &Mask, unsigned Depth, return Mask.isSubsetOf(Known.Zero); } +// Match a signed min+max clamp pattern like smax(smin(In, CHigh), CLow). +// Returns the input and lower/upper bounds. +static bool isSignedMinMaxClamp(const Value *Select, const Value *&In, + const APInt *&CLow, const APInt *&CHigh) { + assert(isa<Operator>(Select) && + cast<Operator>(Select)->getOpcode() == Instruction::Select && + "Input should be a Select!"); + + const Value *LHS, *RHS, *LHS2, *RHS2; + SelectPatternFlavor SPF = matchSelectPattern(Select, LHS, RHS).Flavor; + if (SPF != SPF_SMAX && SPF != SPF_SMIN) + return false; + + if (!match(RHS, m_APInt(CLow))) + return false; + + SelectPatternFlavor SPF2 = matchSelectPattern(LHS, LHS2, RHS2).Flavor; + if (getInverseMinMaxFlavor(SPF) != SPF2) + return false; + + if (!match(RHS2, m_APInt(CHigh))) + return false; + + if (SPF == SPF_SMIN) + std::swap(CLow, CHigh); + + In = LHS2; + return CLow->sle(*CHigh); +} + /// For vector constants, loop over the elements and find the constant with the /// minimum number of sign bits. Return 0 if the value is not a vector constant /// or if any element was not analyzed; otherwise, return the count for the @@ -2335,11 +2421,19 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, } break; - case Instruction::Select: + case Instruction::Select: { + // If we have a clamp pattern, we know that the number of sign bits will be + // the minimum of the clamp min/max range. 
+ const Value *X; + const APInt *CLow, *CHigh; + if (isSignedMinMaxClamp(U, X, CLow, CHigh)) + return std::min(CLow->getNumSignBits(), CHigh->getNumSignBits()); + Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); if (Tmp == 1) break; Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); return std::min(Tmp, Tmp2); + } case Instruction::Add: // Add can have at most one carry bit. Thus we know that the output @@ -2437,6 +2531,44 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // valid for all elements of the vector (for example if vector is sign // extended, shifted, etc). return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + + case Instruction::ShuffleVector: { + // TODO: This is copied almost directly from the SelectionDAG version of + // ComputeNumSignBits. It would be better if we could share common + // code. If not, make sure that changes are translated to the DAG. + + // Collect the minimum number of sign bits that are shared by every vector + // element referenced by the shuffle. + auto *Shuf = cast<ShuffleVectorInst>(U); + int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements(); + int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements(); + APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0); + for (int i = 0; i != NumMaskElts; ++i) { + int M = Shuf->getMaskValue(i); + assert(M < NumElts * 2 && "Invalid shuffle mask constant"); + // For undef elements, we don't know anything about the common state of + // the shuffle result. + if (M == -1) + return 1; + if (M < NumElts) + DemandedLHS.setBit(M % NumElts); + else + DemandedRHS.setBit(M % NumElts); + } + Tmp = std::numeric_limits<unsigned>::max(); + if (!!DemandedLHS) + Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q); + if (!!DemandedRHS) { + Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q); + Tmp = std::min(Tmp, Tmp2); + } + // If we don't know anything, early out and try computeKnownBits fall-back. + if (Tmp == 1) + break; + assert(Tmp <= V->getType()->getScalarSizeInBits() && + "Failed to determine minimum sign bits"); + return Tmp; + } } // Finally, if we can prove that the top bits of the result are 0's or 1's, @@ -2722,6 +2854,7 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, break; // sqrt(-0.0) = -0.0, no other negative results are possible. 
case Intrinsic::sqrt: + case Intrinsic::canonicalize: return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); // fabs(x) != -0.0 case Intrinsic::fabs: @@ -2817,11 +2950,20 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, default: break; case Intrinsic::maxnum: + return (isKnownNeverNaN(I->getOperand(0), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, + SignBitOnly, Depth + 1)) || + (isKnownNeverNaN(I->getOperand(1), TLI) && + cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, + SignBitOnly, Depth + 1)); + + case Intrinsic::maximum: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1) || cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, Depth + 1); case Intrinsic::minnum: + case Intrinsic::minimum: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1) && cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, SignBitOnly, @@ -2882,7 +3024,8 @@ bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); } -bool llvm::isKnownNeverNaN(const Value *V) { +bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, + unsigned Depth) { assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); // If we're told that NaNs won't happen, assume they won't. @@ -2890,13 +3033,60 @@ bool llvm::isKnownNeverNaN(const Value *V) { if (FPMathOp->hasNoNaNs()) return true; - // TODO: Handle instructions and potentially recurse like other 'isKnown' - // functions. For example, the result of sitofp is never NaN. - // Handle scalar constants. if (auto *CFP = dyn_cast<ConstantFP>(V)) return !CFP->isNaN(); + if (Depth == MaxDepth) + return false; + + if (auto *Inst = dyn_cast<Instruction>(V)) { + switch (Inst->getOpcode()) { + case Instruction::FAdd: + case Instruction::FMul: + case Instruction::FSub: + case Instruction::FDiv: + case Instruction::FRem: { + // TODO: Need isKnownNeverInfinity + return false; + } + case Instruction::Select: { + return isKnownNeverNaN(Inst->getOperand(1), TLI, Depth + 1) && + isKnownNeverNaN(Inst->getOperand(2), TLI, Depth + 1); + } + case Instruction::SIToFP: + case Instruction::UIToFP: + return true; + case Instruction::FPTrunc: + case Instruction::FPExt: + return isKnownNeverNaN(Inst->getOperand(0), TLI, Depth + 1); + default: + break; + } + } + + if (const auto *II = dyn_cast<IntrinsicInst>(V)) { + switch (II->getIntrinsicID()) { + case Intrinsic::canonicalize: + case Intrinsic::fabs: + case Intrinsic::copysign: + case Intrinsic::exp: + case Intrinsic::exp2: + case Intrinsic::floor: + case Intrinsic::ceil: + case Intrinsic::trunc: + case Intrinsic::rint: + case Intrinsic::nearbyint: + case Intrinsic::round: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1); + case Intrinsic::sqrt: + return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) && + CannotBeOrderedLessThanZero(II->getArgOperand(0), TLI); + default: + return false; + } + } + // Bail out for constant expressions, but try to handle vector constants. if (!V->getType()->isVectorTy() || !isa<Constant>(V)) return false; @@ -2917,62 +3107,92 @@ bool llvm::isKnownNeverNaN(const Value *V) { return true; } -/// If the specified value can be set by repeating the same byte in memory, -/// return the i8 value that it is represented with. This is -/// true for all i8 values obviously, but is also true for i32 0, i32 -1, -/// i16 0xF0F0, double 0.0 etc. 
If the value can't be handled with a repeated -/// byte store (e.g. i16 0x1234), return null. Value *llvm::isBytewiseValue(Value *V) { + // All byte-wide stores are splatable, even of arbitrary variables. - if (V->getType()->isIntegerTy(8)) return V; + if (V->getType()->isIntegerTy(8)) + return V; + + LLVMContext &Ctx = V->getContext(); + + // Undef don't care. + auto *UndefInt8 = UndefValue::get(Type::getInt8Ty(Ctx)); + if (isa<UndefValue>(V)) + return UndefInt8; + + Constant *C = dyn_cast<Constant>(V); + if (!C) { + // Conceptually, we could handle things like: + // %a = zext i8 %X to i16 + // %b = shl i16 %a, 8 + // %c = or i16 %a, %b + // but until there is an example that actually needs this, it doesn't seem + // worth worrying about. + return nullptr; + } // Handle 'null' ConstantArrayZero etc. - if (Constant *C = dyn_cast<Constant>(V)) - if (C->isNullValue()) - return Constant::getNullValue(Type::getInt8Ty(V->getContext())); + if (C->isNullValue()) + return Constant::getNullValue(Type::getInt8Ty(Ctx)); - // Constant float and double values can be handled as integer values if the + // Constant floating-point values can be handled as integer values if the // corresponding integer value is "byteable". An important case is 0.0. - if (ConstantFP *CFP = dyn_cast<ConstantFP>(V)) { - if (CFP->getType()->isFloatTy()) - V = ConstantExpr::getBitCast(CFP, Type::getInt32Ty(V->getContext())); - if (CFP->getType()->isDoubleTy()) - V = ConstantExpr::getBitCast(CFP, Type::getInt64Ty(V->getContext())); + if (ConstantFP *CFP = dyn_cast<ConstantFP>(C)) { + Type *Ty = nullptr; + if (CFP->getType()->isHalfTy()) + Ty = Type::getInt16Ty(Ctx); + else if (CFP->getType()->isFloatTy()) + Ty = Type::getInt32Ty(Ctx); + else if (CFP->getType()->isDoubleTy()) + Ty = Type::getInt64Ty(Ctx); // Don't handle long double formats, which have strange constraints. + return Ty ? isBytewiseValue(ConstantExpr::getBitCast(CFP, Ty)) : nullptr; } // We can handle constant integers that are multiple of 8 bits. - if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { + if (ConstantInt *CI = dyn_cast<ConstantInt>(C)) { if (CI->getBitWidth() % 8 == 0) { assert(CI->getBitWidth() > 8 && "8 bits should be handled above!"); - if (!CI->getValue().isSplat(8)) return nullptr; - return ConstantInt::get(V->getContext(), CI->getValue().trunc(8)); + return ConstantInt::get(Ctx, CI->getValue().trunc(8)); } } - // A ConstantDataArray/Vector is splatable if all its members are equal and - // also splatable. - if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(V)) { - Value *Elt = CA->getElementAsConstant(0); - Value *Val = isBytewiseValue(Elt); - if (!Val) + auto Merge = [&](Value *LHS, Value *RHS) -> Value * { + if (LHS == RHS) + return LHS; + if (!LHS || !RHS) return nullptr; + if (LHS == UndefInt8) + return RHS; + if (RHS == UndefInt8) + return LHS; + return nullptr; + }; - for (unsigned I = 1, E = CA->getNumElements(); I != E; ++I) - if (CA->getElementAsConstant(I) != Elt) + if (ConstantDataSequential *CA = dyn_cast<ConstantDataSequential>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = CA->getNumElements(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(CA->getElementAsConstant(I))))) return nullptr; + return Val; + } + if (isa<ConstantVector>(C)) { + Constant *Splat = cast<ConstantVector>(C)->getSplatValue(); + return Splat ? 
isBytewiseValue(Splat) : nullptr; + } + + if (isa<ConstantArray>(C) || isa<ConstantStruct>(C)) { + Value *Val = UndefInt8; + for (unsigned I = 0, E = C->getNumOperands(); I != E; ++I) + if (!(Val = Merge(Val, isBytewiseValue(C->getOperand(I))))) + return nullptr; return Val; } - // Conceptually, we could handle things like: - // %a = zext i8 %X to i16 - // %b = shl i16 %a, 8 - // %c = or i16 %a, %b - // but until there is an example that actually needs this, it doesn't seem - // worth worrying about. + // Don't try to handle the handful of other constants. return nullptr; } @@ -3169,7 +3389,14 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, if (!GEP->accumulateConstantOffset(DL, GEPOffset)) break; - ByteOffset += GEPOffset.getSExtValue(); + APInt OrigByteOffset(ByteOffset); + ByteOffset += GEPOffset.sextOrTrunc(ByteOffset.getBitWidth()); + if (ByteOffset.getMinSignedBits() > 64) { + // Stop traversal if the pointer offset wouldn't fit into int64_t + // (this should be removed if Offset is updated to an APInt) + ByteOffset = OrigByteOffset; + break; + } Ptr = GEP->getPointerOperand(); } else if (Operator::getOpcode(Ptr) == Instruction::BitCast || @@ -3397,21 +3624,21 @@ uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { return Len == ~0ULL ? 1 : Len; } -const Value *llvm::getArgumentAliasingToReturnedPointer(ImmutableCallSite CS) { - assert(CS && - "getArgumentAliasingToReturnedPointer only works on nonnull CallSite"); - if (const Value *RV = CS.getReturnedArgOperand()) +const Value *llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call) { + assert(Call && + "getArgumentAliasingToReturnedPointer only works on nonnull calls"); + if (const Value *RV = Call->getReturnedArgOperand()) return RV; // This can be used only as a aliasing property. - if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) - return CS.getArgOperand(0); + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(Call)) + return Call->getArgOperand(0); return nullptr; } bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( - ImmutableCallSite CS) { - return CS.getIntrinsicID() == Intrinsic::launder_invariant_group || - CS.getIntrinsicID() == Intrinsic::strip_invariant_group; + const CallBase *Call) { + return Call->getIntrinsicID() == Intrinsic::launder_invariant_group || + Call->getIntrinsicID() == Intrinsic::strip_invariant_group; } /// \p PN defines a loop-variant pointer to an object. Check if the @@ -3459,7 +3686,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // An alloca can't be further simplified. return V; } else { - if (auto CS = CallSite(V)) { + if (auto *Call = dyn_cast<CallBase>(V)) { // CaptureTracking can know about special capturing properties of some // intrinsics like launder.invariant.group, that can't be expressed with // the attributes, but have properties like returning aliasing pointer. @@ -3469,7 +3696,7 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // because it should be in sync with CaptureTracking. Not using it may // cause weird miscompilations where 2 aliasing pointers are assumed to // noalias. 
- if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { + if (auto *RP = getArgumentAliasingToReturnedPointer(Call)) { V = RP; continue; } @@ -3602,8 +3829,7 @@ bool llvm::onlyUsedByLifetimeMarkers(const Value *V) { const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U); if (!II) return false; - if (II->getIntrinsicID() != Intrinsic::lifetime_start && - II->getIntrinsicID() != Intrinsic::lifetime_end) + if (!II->isLifetimeStartOrEnd()) return false; } return true; @@ -3700,12 +3926,10 @@ bool llvm::mayBeMemoryDependent(const Instruction &I) { return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I); } -OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { +OverflowResult llvm::computeOverflowForUnsignedMul( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { // Multiplying n * m significant bits yields a result of n + m significant // bits. If the total number of significant bits does not exceed the // result bit width (minus 1), there is no overflow. @@ -3715,8 +3939,10 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); KnownBits LHSKnown(BitWidth); KnownBits RHSKnown(BitWidth); - computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); - computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT); + computeKnownBits(LHS, LHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr, + UseInstrInfo); + computeKnownBits(RHS, RHSKnown, DL, /*Depth=*/0, AC, CxtI, DT, nullptr, + UseInstrInfo); // Note that underestimating the number of zero bits gives a more // conservative answer. unsigned ZeroBits = LHSKnown.countMinLeadingZeros() + @@ -3747,12 +3973,11 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, return OverflowResult::MayOverflow; } -OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { +OverflowResult +llvm::computeOverflowForSignedMul(const Value *LHS, const Value *RHS, + const DataLayout &DL, AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT, bool UseInstrInfo) { // Multiplying n * m significant bits yields a result of n + m significant // bits. If the total number of significant bits does not exceed the // result bit width (minus 1), there is no overflow. @@ -3781,33 +4006,33 @@ OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, // product is exactly the minimum negative number. // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 // For simplicity we just check if at least one side is not negative. 
- KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) return OverflowResult::NeverOverflows; } return OverflowResult::MayOverflow; } -OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, - const Value *RHS, - const DataLayout &DL, - AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { - KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); +OverflowResult llvm::computeOverflowForUnsignedAdd( + const Value *LHS, const Value *RHS, const DataLayout &DL, + AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, + bool UseInstrInfo) { + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT, + nullptr, UseInstrInfo); if (LHSKnown.isNegative() && RHSKnown.isNegative()) { // The sign bit is set in both cases: this MUST overflow. - // Create a simple add instruction, and insert it into the struct. return OverflowResult::AlwaysOverflows; } if (LHSKnown.isNonNegative() && RHSKnown.isNonNegative()) { // The sign bit is clear in both cases: this CANNOT overflow. - // Create a simple add instruction, and insert it into the struct. return OverflowResult::NeverOverflows; } } @@ -3924,11 +4149,18 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { - // If the LHS is negative and the RHS is non-negative, no unsigned wrap. KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); - KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); - if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) - return OverflowResult::NeverOverflows; + if (LHSKnown.isNonNegative() || LHSKnown.isNegative()) { + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + + // If the LHS is negative and the RHS is non-negative, no unsigned wrap. + if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) + return OverflowResult::NeverOverflows; + + // If the LHS is non-negative and the RHS negative, we always wrap. 
+ if (LHSKnown.isNonNegative() && RHSKnown.isNegative()) + return OverflowResult::AlwaysOverflows; + } return OverflowResult::MayOverflow; } @@ -4241,12 +4473,34 @@ static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) { if (auto *C = dyn_cast<ConstantFP>(V)) return !C->isNaN(); + + if (auto *C = dyn_cast<ConstantDataVector>(V)) { + if (!C->getElementType()->isFloatingPointTy()) + return false; + for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) { + if (C->getElementAsAPFloat(I).isNaN()) + return false; + } + return true; + } + return false; } static bool isKnownNonZero(const Value *V) { if (auto *C = dyn_cast<ConstantFP>(V)) return !C->isZero(); + + if (auto *C = dyn_cast<ConstantDataVector>(V)) { + if (!C->getElementType()->isFloatingPointTy()) + return false; + for (unsigned I = 0, E = C->getNumElements(); I < E; ++I) { + if (C->getElementAsAPFloat(I).isZero()) + return false; + } + return true; + } + return false; } @@ -4538,6 +4792,27 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, Value *TrueVal, Value *FalseVal, Value *&LHS, Value *&RHS, unsigned Depth) { + if (CmpInst::isFPPredicate(Pred)) { + // IEEE-754 ignores the sign of 0.0 in comparisons. So if the select has one + // 0.0 operand, set the compare's 0.0 operands to that same value for the + // purpose of identifying min/max. Disregard vector constants with undefined + // elements because those can not be back-propagated for analysis. + Value *OutputZeroVal = nullptr; + if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && + !cast<Constant>(TrueVal)->containsUndefElement()) + OutputZeroVal = TrueVal; + else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && + !cast<Constant>(FalseVal)->containsUndefElement()) + OutputZeroVal = FalseVal; + + if (OutputZeroVal) { + if (match(CmpLHS, m_AnyZeroFP())) + CmpLHS = OutputZeroVal; + if (match(CmpRHS, m_AnyZeroFP())) + CmpRHS = OutputZeroVal; + } + } + LHS = CmpLHS; RHS = CmpRHS; @@ -4967,21 +5242,16 @@ static bool isMatchingOps(const Value *ALHS, const Value *ARHS, return IsMatchingOps || IsSwappedOps; } -/// Return true if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS BRHS" is -/// true. Return false if "icmp1 APred ALHS ARHS" implies "icmp2 BPred BLHS -/// BRHS" is false. Otherwise, return None if we can't infer anything. +/// Return true if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is true. +/// Return false if "icmp1 APred X, Y" implies "icmp2 BPred X, Y" is false. +/// Otherwise, return None if we can't infer anything. static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, - const Value *ALHS, - const Value *ARHS, CmpInst::Predicate BPred, - const Value *BLHS, - const Value *BRHS, - bool IsSwappedOps) { - // Canonicalize the operands so they're matching. - if (IsSwappedOps) { - std::swap(BLHS, BRHS); + bool AreSwappedOps) { + // Canonicalize the predicate as if the operands were not commuted. + if (AreSwappedOps) BPred = ICmpInst::getSwappedPredicate(BPred); - } + if (CmpInst::isImpliedTrueByMatchingCmp(APred, BPred)) return true; if (CmpInst::isImpliedFalseByMatchingCmp(APred, BPred)) @@ -4990,15 +5260,14 @@ static Optional<bool> isImpliedCondMatchingOperands(CmpInst::Predicate APred, return None; } -/// Return true if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS C2" is -/// true. Return false if "icmp1 APred ALHS C1" implies "icmp2 BPred BLHS -/// C2" is false. Otherwise, return None if we can't infer anything. 
+/// Return true if "icmp APred X, C1" implies "icmp BPred X, C2" is true. +/// Return false if "icmp APred X, C1" implies "icmp BPred X, C2" is false. +/// Otherwise, return None if we can't infer anything. static Optional<bool> -isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const Value *ALHS, +isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const ConstantInt *C1, CmpInst::Predicate BPred, - const Value *BLHS, const ConstantInt *C2) { - assert(ALHS == BLHS && "LHS operands must match."); + const ConstantInt *C2) { ConstantRange DomCR = ConstantRange::makeExactICmpRegion(APred, C1->getValue()); ConstantRange CR = @@ -5030,10 +5299,10 @@ static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, ICmpInst::Predicate BPred = RHS->getPredicate(); // Can we infer anything when the two compares have matching operands? - bool IsSwappedOps; - if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, IsSwappedOps)) { + bool AreSwappedOps; + if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, AreSwappedOps)) { if (Optional<bool> Implication = isImpliedCondMatchingOperands( - APred, ALHS, ARHS, BPred, BLHS, BRHS, IsSwappedOps)) + APred, BPred, AreSwappedOps)) return Implication; // No amount of additional analysis will infer the second condition, so // early exit. @@ -5044,8 +5313,7 @@ static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, // constants (not necessarily matching)? if (ALHS == BLHS && isa<ConstantInt>(ARHS) && isa<ConstantInt>(BRHS)) { if (Optional<bool> Implication = isImpliedCondMatchingImmOperands( - APred, ALHS, cast<ConstantInt>(ARHS), BPred, BLHS, - cast<ConstantInt>(BRHS))) + APred, cast<ConstantInt>(ARHS), BPred, cast<ConstantInt>(BRHS))) return Implication; // No amount of additional analysis will infer the second condition, so // early exit. @@ -5130,3 +5398,35 @@ Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, } return None; } + +Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond, + const Instruction *ContextI, + const DataLayout &DL) { + assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool"); + if (!ContextI || !ContextI->getParent()) + return None; + + // TODO: This is a poor/cheap way to determine dominance. Should we use a + // dominator tree (eg, from a SimplifyQuery) instead? + const BasicBlock *ContextBB = ContextI->getParent(); + const BasicBlock *PredBB = ContextBB->getSinglePredecessor(); + if (!PredBB) + return None; + + // We need a conditional branch in the predecessor. + Value *PredCond; + BasicBlock *TrueBB, *FalseBB; + if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB))) + return None; + + // The branch should get simplified. Don't bother simplifying this condition. + if (TrueBB == FalseBB) + return None; + + assert((TrueBB == ContextBB || FalseBB == ContextBB) && + "Predecessor block does not point to successor?"); + + // Is this condition implied by the predecessor condition? + bool CondIsTrue = TrueBB == ContextBB; + return isImpliedCondition(PredCond, Cond, DL, CondIsTrue); +} |