Diffstat (limited to 'lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | lib/Analysis/ValueTracking.cpp | 515
1 file changed, 435 insertions, 80 deletions
diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index cd4cee631568..04a7b73c22bf 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -89,7 +89,7 @@ static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { if (unsigned BitWidth = Ty->getScalarSizeInBits()) return BitWidth; - return DL.getPointerTypeSizeInBits(Ty); + return DL.getIndexTypeSizeInBits(Ty); } namespace { @@ -190,6 +190,14 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, "LHS and RHS should have the same type"); assert(LHS->getType()->isIntOrIntVectorTy() && "LHS and RHS should be integers"); + // Look for an inverted mask: (X & ~M) op (Y & M). + Value *M; + if (match(LHS, m_c_And(m_Not(m_Value(M)), m_Value())) && + match(RHS, m_c_And(m_Specific(M), m_Value()))) + return true; + if (match(RHS, m_c_And(m_Not(m_Value(M)), m_Value())) && + match(LHS, m_c_And(m_Specific(M), m_Value()))) + return true; IntegerType *IT = cast<IntegerType>(LHS->getType()->getScalarType()); KnownBits LHSKnown(IT->getBitWidth()); KnownBits RHSKnown(IT->getBitWidth()); @@ -493,6 +501,7 @@ bool llvm::isAssumeLikeIntrinsic(const Instruction *I) { case Intrinsic::sideeffect: case Intrinsic::dbg_declare: case Intrinsic::dbg_value: + case Intrinsic::dbg_label: case Intrinsic::invariant_start: case Intrinsic::invariant_end: case Intrinsic::lifetime_start: @@ -530,7 +539,7 @@ bool llvm::isValidAssumeForContext(const Instruction *Inv, if (Inv->getParent() != CxtI->getParent()) return false; - // If we have a dom tree, then we now know that the assume doens't dominate + // If we have a dom tree, then we now know that the assume doesn't dominate // the other instruction. If we don't have a dom tree then we can check if // the assume is first in the BB. if (!DT) { @@ -574,7 +583,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, if (Q.isExcluded(I)) continue; - // Warning: This loop can end up being somewhat performance sensetive. + // Warning: This loop can end up being somewhat performance sensitive. // We're running this loop for once for each value queried resulting in a // runtime of ~O(#assumes * #values). @@ -816,6 +825,14 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, KnownBits RHSKnown(BitWidth); computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); + // If the RHS is known zero, then this assumption must be wrong (nothing + // is unsigned less than zero). Signal a conflict and get out of here. + if (RHSKnown.isZero()) { + Known.Zero.setAllBits(); + Known.One.setAllBits(); + break; + } + // Whatever high bits in c are zero are known to be zero (if c is a power // of 2, then one more). if (isKnownToBeAPowerOfTwo(A, false, Depth + 1, Query(Q, I))) @@ -848,7 +865,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, /// Compute known bits from a shift operator, including those with a /// non-constant shift amount. Known is the output of this function. Known2 is a /// pre-allocated temporary with the same bit width as Known. KZF and KOF are -/// operator-specific functors that, given the known-zero or known-one bits +/// operator-specific functions that, given the known-zero or known-one bits /// respectively, and a shift amount, compute the implied known-zero or /// known-one bits of the shift operator's result respectively for that shift /// amount. 
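/// As a hedged illustration only (the exact lambdas passed at the call sites
/// may differ), a KZF/KOF pair for a plain shl could look like:
///
///   auto KZF = [BitWidth](const APInt &KnownZero, unsigned ShiftAmt) {
///     // Known-zero bits shift left, and the vacated low bits are zero.
///     return (KnownZero << ShiftAmt) | APInt::getLowBitsSet(BitWidth, ShiftAmt);
///   };
///   auto KOF = [](const APInt &KnownOne, unsigned ShiftAmt) {
///     // Known-one bits shift left; nothing is known about the vacated bits.
///     return KnownOne << ShiftAmt;
///   };
///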
The results from calling KZF and KOF are conservatively combined for @@ -966,12 +983,9 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // matching the form add(x, add(x, y)) where y is odd. // TODO: This could be generalized to clearing any bit set in y where the // following bit is known to be unset in y. - Value *Y = nullptr; + Value *X = nullptr, *Y = nullptr; if (!Known.Zero[0] && !Known.One[0] && - (match(I->getOperand(0), m_Add(m_Specific(I->getOperand(1)), - m_Value(Y))) || - match(I->getOperand(1), m_Add(m_Specific(I->getOperand(0)), - m_Value(Y))))) { + match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) { Known2.resetAll(); computeKnownBits(Y, Known2, Depth + 1, Q); if (Known2.countMinTrailingOnes() > 0) @@ -1064,6 +1078,12 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // leading zero bits. MaxHighZeros = std::max(Known.countMinLeadingZeros(), Known2.countMinLeadingZeros()); + } else if (SPF == SPF_ABS) { + // RHS from matchSelectPattern returns the negation part of abs pattern. + // If the negate has an NSW flag we can assume the sign bit of the result + // will be 0 because that makes abs(INT_MIN) undefined. + if (cast<Instruction>(RHS)->hasNoSignedWrap()) + MaxHighZeros = 1; } // Only known if known in both the LHS and RHS. @@ -1093,7 +1113,10 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, unsigned SrcBitWidth; // Note that we handle pointer operands here because of inttoptr/ptrtoint // which fall through here. - SrcBitWidth = Q.DL.getTypeSizeInBits(SrcTy->getScalarType()); + Type *ScalarTy = SrcTy->getScalarType(); + SrcBitWidth = ScalarTy->isPointerTy() ? + Q.DL.getIndexTypeSizeInBits(ScalarTy) : + Q.DL.getTypeSizeInBits(ScalarTy); assert(SrcBitWidth && "SrcBitWidth can't be zero"); Known = Known.zextOrTrunc(SrcBitWidth); @@ -1106,7 +1129,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, } case Instruction::BitCast: { Type *SrcTy = I->getOperand(0)->getType(); - if ((SrcTy->isIntegerTy() || SrcTy->isPointerTy()) && + if (SrcTy->isIntOrPtrTy() && // TODO: For now, not handling conversions like: // (bitcast i64 %x to <2 x i32>) !I->getType()->isVectorTy()) { @@ -1547,9 +1570,13 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, assert((V->getType()->isIntOrIntVectorTy(BitWidth) || V->getType()->isPtrOrPtrVectorTy()) && "Not integer or pointer type!"); - assert(Q.DL.getTypeSizeInBits(V->getType()->getScalarType()) == BitWidth && - "V and Known should have same BitWidth"); + + Type *ScalarTy = V->getType()->getScalarType(); + unsigned ExpectedWidth = ScalarTy->isPointerTy() ? + Q.DL.getIndexTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy); + assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth"); (void)BitWidth; + (void)ExpectedWidth; const APInt *C; if (match(V, m_APInt(C))) { @@ -1646,14 +1673,11 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, const Query &Q) { assert(Depth <= MaxDepth && "Limit Search Depth"); - if (const Constant *C = dyn_cast<Constant>(V)) { - if (C->isNullValue()) - return OrZero; - - const APInt *ConstIntOrConstSplatInt; - if (match(C, m_APInt(ConstIntOrConstSplatInt))) - return ConstIntOrConstSplatInt->isPowerOf2(); - } + // Attempt to match against constants. 
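// (Hedged aside: m_Power2 and m_Power2OrZero are PatternMatch constant
// matchers, so they should also cover vector constants whose elements all
// satisfy the predicate, e.g. <2 x i32> <i32 8, i32 8> for m_Power2, with
// m_Power2OrZero additionally accepting zero, matching the OrZero semantics
// of this function.)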
+ if (OrZero && match(V, m_Power2OrZero())) + return true; + if (match(V, m_Power2())) + return true; // 1 << X is clearly a power of two if the one is not shifted off the end. If // it is shifted off the end then the result is undefined. @@ -1737,7 +1761,7 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, return false; } -/// \brief Test whether a GEP's result is known to be non-null. +/// Test whether a GEP's result is known to be non-null. /// /// Uses properties inherent in a GEP to try to determine whether it is known /// to be non-null. @@ -1745,7 +1769,12 @@ bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, /// Currently this routine does not support vector GEPs. static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, const Query &Q) { - if (!GEP->isInBounds() || GEP->getPointerAddressSpace() != 0) + const Function *F = nullptr; + if (const Instruction *I = dyn_cast<Instruction>(GEP)) + F = I->getFunction(); + + if (!GEP->isInBounds() || + NullPointerIsDefined(F, GEP->getPointerAddressSpace())) return false; // FIXME: Support vector-GEPs. @@ -1919,6 +1948,10 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } } + // Some of the tests below are recursive, so bail out if we hit the limit. + if (Depth++ >= MaxDepth) + return false; + // Check for pointer simplifications. if (V->getType()->isPointerTy()) { // Alloca never returns null, malloc might. @@ -1935,14 +1968,14 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { if (LI->getMetadata(LLVMContext::MD_nonnull)) return true; - if (auto CS = ImmutableCallSite(V)) + if (auto CS = ImmutableCallSite(V)) { if (CS.isReturnNonNull()) return true; + if (const auto *RP = getArgumentAliasingToReturnedPointer(CS)) + return isKnownNonZero(RP, Depth, Q); + } } - // The remaining tests are all recursive, so bail out if we hit the limit. - if (Depth++ >= MaxDepth) - return false; // Check for recursive pointer simplifications. if (V->getType()->isPointerTy()) { @@ -2180,7 +2213,7 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, /// (itself), but other cases can give us information. For example, immediately /// after an "ashr X, 2", we know that the top 3 bits are all equal to each /// other, so we return 3. For vectors, return the number of sign bits for the -/// vector element with the mininum number of known sign bits. +/// vector element with the minimum number of known sign bits. static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, const Query &Q) { assert(Depth <= MaxDepth && "Limit Search Depth"); @@ -2189,7 +2222,11 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // in V, so for undef we have to conservatively return 1. We don't have the // same behavior for poison though -- that's a FIXME today. - unsigned TyBits = Q.DL.getTypeSizeInBits(V->getType()->getScalarType()); + Type *ScalarTy = V->getType()->getScalarType(); + unsigned TyBits = ScalarTy->isPointerTy() ? + Q.DL.getIndexTypeSizeInBits(ScalarTy) : + Q.DL.getTypeSizeInBits(ScalarTy); + unsigned Tmp, Tmp2; unsigned FirstAnswer = 1; @@ -2264,9 +2301,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // ashr X, C -> adds C sign bits. Vectors too. const APInt *ShAmt; if (match(U->getOperand(1), m_APInt(ShAmt))) { - unsigned ShAmtLimited = ShAmt->getZExtValue(); - if (ShAmtLimited >= TyBits) + if (ShAmt->uge(TyBits)) break; // Bad shift. 
+ unsigned ShAmtLimited = ShAmt->getZExtValue(); Tmp += ShAmtLimited; if (Tmp > TyBits) Tmp = TyBits; } @@ -2277,9 +2314,9 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, if (match(U->getOperand(1), m_APInt(ShAmt))) { // shl destroys sign bits. Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + if (ShAmt->uge(TyBits) || // Bad shift. + ShAmt->uge(Tmp)) break; // Shifted all sign bits out. Tmp2 = ShAmt->getZExtValue(); - if (Tmp2 >= TyBits || // Bad shift. - Tmp2 >= Tmp) break; // Shifted all sign bits out. return Tmp - Tmp2; } break; @@ -2300,7 +2337,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, case Instruction::Select: Tmp = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp == 1) return 1; // Early out. + if (Tmp == 1) break; Tmp2 = ComputeNumSignBits(U->getOperand(2), Depth + 1, Q); return std::min(Tmp, Tmp2); @@ -2308,7 +2345,7 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // Add can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (Tmp == 1) return 1; // Early out. + if (Tmp == 1) break; // Special case decrementing a value (ADD X, -1): if (const auto *CRHS = dyn_cast<Constant>(U->getOperand(1))) @@ -2328,12 +2365,12 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, } Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp2 == 1) return 1; + if (Tmp2 == 1) break; return std::min(Tmp, Tmp2)-1; case Instruction::Sub: Tmp2 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (Tmp2 == 1) return 1; + if (Tmp2 == 1) break; // Handle NEG. if (const auto *CLHS = dyn_cast<Constant>(U->getOperand(0))) @@ -2356,15 +2393,15 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, // Sub can have at most one carry bit. Thus we know that the output // is, at worst, one more bit than the inputs. Tmp = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (Tmp == 1) return 1; // Early out. + if (Tmp == 1) break; return std::min(Tmp, Tmp2)-1; case Instruction::Mul: { // The output of the Mul can be at most twice the valid bits in the inputs. unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); - if (SignBitsOp0 == 1) return 1; // Early out. + if (SignBitsOp0 == 1) break; unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); - if (SignBitsOp1 == 1) return 1; + if (SignBitsOp1 == 1) break; unsigned OutValidBits = (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1); return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1; @@ -2671,7 +2708,7 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, return true; // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. - if (match(Op, m_FAdd(m_Value(), m_Zero()))) + if (match(Op, m_FAdd(m_Value(), m_PosZeroFP()))) return true; // sitofp and uitofp turn into +0.0 for zero. @@ -2712,6 +2749,24 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, (!SignBitOnly && CFP->getValueAPF().isZero()); } + // Handle vector of constants. 
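// (For example: <2 x float> <float 1.0, float 2.0> is accepted; with
// SignBitOnly, <float 1.0, float -0.0> is rejected because -0.0 has its
// sign bit set even though it is not ordered-less-than zero.)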
+ if (auto *CV = dyn_cast<Constant>(V)) { + if (CV->getType()->isVectorTy()) { + unsigned NumElts = CV->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); + if (!CFP) + return false; + if (CFP->getValueAPF().isNegative() && + (SignBitOnly || !CFP->getValueAPF().isZero())) + return false; + } + + // All non-negative ConstantFPs. + return true; + } + } + if (Depth == MaxDepth) return false; // Limit search depth. @@ -2749,6 +2804,12 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, // Widening/narrowing never change sign. return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1); + case Instruction::ExtractElement: + // Look through extract element. At the moment we keep this simple and skip + // tracking the specific element. But at least we might find information + // valid for all elements of the vector. + return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, + Depth + 1); case Instruction::Call: const auto *CI = cast<CallInst>(I); Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); @@ -2963,7 +3024,7 @@ static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType, if (!V) return nullptr; - // Insert the value in the new (sub) aggregrate + // Insert the value in the new (sub) aggregate return InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip), "tmp", InsertBefore); } @@ -2992,9 +3053,9 @@ static Value *BuildSubAggregate(Value *From, ArrayRef<unsigned> idx_range, return BuildSubAggregate(From, To, IndexedType, Idxs, IdxSkip, InsertBefore); } -/// Given an aggregrate and an sequence of indices, see if -/// the scalar value indexed is already around as a register, for example if it -/// were inserted directly into the aggregrate. +/// Given an aggregate and a sequence of indices, see if the scalar value +/// indexed is already around as a register, for example if it was inserted +/// directly into the aggregate. /// /// If InsertBefore is not null, this function will duplicate (modified) /// insertvalues when a part of a nested struct is extracted. @@ -3086,7 +3147,7 @@ Value *llvm::FindInsertedValue(Value *V, ArrayRef<unsigned> idx_range, /// pointer plus a constant offset. Return the base and offset to the caller. Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, const DataLayout &DL) { - unsigned BitWidth = DL.getPointerTypeSizeInBits(Ptr->getType()); + unsigned BitWidth = DL.getIndexTypeSizeInBits(Ptr->getType()); APInt ByteOffset(BitWidth, 0); // We walk up the defs but use a visited set to handle unreachable code. In @@ -3104,7 +3165,7 @@ Value *llvm::GetPointerBaseWithConstantOffset(Value *Ptr, int64_t &Offset, // means when we construct GEPOffset, we need to use the size // of GEP's pointer type rather than the size of the original // pointer type. - APInt GEPOffset(DL.getPointerTypeSizeInBits(Ptr->getType()), 0); + APInt GEPOffset(DL.getIndexTypeSizeInBits(Ptr->getType()), 0); if (!GEP->accumulateConstantOffset(DL, GEPOffset)) break; @@ -3326,7 +3387,8 @@ static uint64_t GetStringLengthH(const Value *V, /// If we can compute the length of the string pointed to by /// the specified pointer, return 'len+1'. If we can't, return 0. 
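/// For example, a pointer to the constant c"abc\00" yields 4 (strlen + 1),
/// while a pointer of unknown provenance yields 0.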
uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { - if (!V->getType()->isPointerTy()) return 0; + if (!V->getType()->isPointerTy()) + return 0; SmallPtrSet<const PHINode*, 32> PHIs; uint64_t Len = GetStringLengthH(V, PHIs, CharSize); @@ -3335,7 +3397,24 @@ uint64_t llvm::GetStringLength(const Value *V, unsigned CharSize) { return Len == ~0ULL ? 1 : Len; } -/// \brief \p PN defines a loop-variant pointer to an object. Check if the +const Value *llvm::getArgumentAliasingToReturnedPointer(ImmutableCallSite CS) { + assert(CS && + "getArgumentAliasingToReturnedPointer only works on nonnull CallSite"); + if (const Value *RV = CS.getReturnedArgOperand()) + return RV; + // This can be used only as an aliasing property. + if (isIntrinsicReturningPointerAliasingArgumentWithoutCapturing(CS)) + return CS.getArgOperand(0); + return nullptr; +} + +bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( + ImmutableCallSite CS) { + return CS.getIntrinsicID() == Intrinsic::launder_invariant_group || + CS.getIntrinsicID() == Intrinsic::strip_invariant_group; +} + +/// \p PN defines a loop-variant pointer to an object. Check if the /// previous iteration of the loop was referring to the same object as \p PN. static bool isSameUnderlyingObjectInLoop(const PHINode *PN, const LoopInfo *LI) { @@ -3380,11 +3459,21 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, // An alloca can't be further simplified. return V; } else { - if (auto CS = CallSite(V)) - if (Value *RV = CS.getReturnedArgOperand()) { - V = RV; + if (auto CS = CallSite(V)) { + // CaptureTracking can know about special capturing properties of some + // intrinsics like launder.invariant.group that can't be expressed with + // attributes, but that have properties like returning an aliasing pointer. + // Because some analyses may assume that a nocapture'd pointer is never + // returned from a special intrinsic (the function would otherwise have to + // be marked with the 'returned' attribute), it is crucial to use this + // helper, which is kept in sync with CaptureTracking. Not using it may + // cause miscompilations where two aliasing pointers are assumed not to + // alias. + if (auto *RP = getArgumentAliasingToReturnedPointer(CS)) { + V = RP; continue; } + } // See if InstructionSimplify knows any relevant tricks. if (Instruction *I = dyn_cast<Instruction>(V)) @@ -3658,6 +3747,48 @@ OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, return OverflowResult::MayOverflow; } +OverflowResult llvm::computeOverflowForSignedMul(const Value *LHS, + const Value *RHS, + const DataLayout &DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + // Multiplying n * m significant bits yields a result of n + m significant + // bits. If the total number of significant bits does not exceed the + // result bit width (minus 1), there is no overflow. + // This means if we have enough leading sign bits in the operands + // we can guarantee that the result does not overflow. + // Ref: "Hacker's Delight" by Henry Warren + unsigned BitWidth = LHS->getType()->getScalarSizeInBits(); + + // Note that underestimating the number of sign bits gives a more + // conservative answer. + unsigned SignBits = ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) + + ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT); + + // First handle the easy case: if we have enough sign bits there's + // definitely no overflow.
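// (Worked example: if both i32 operands are sign extensions from i16, each
// has at least 17 sign bits, so SignBits >= 34 > BitWidth + 1 = 33 and the
// multiply can never overflow; the full i16 x i16 product always fits in
// the signed i32 range.)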
+ if (SignBits > BitWidth + 1) + return OverflowResult::NeverOverflows; + + // There are two ambiguous cases where there can be no overflow: + // SignBits == BitWidth + 1 and + // SignBits == BitWidth + // The second case is difficult to check, therefore we only handle the + // first case. + if (SignBits == BitWidth + 1) { + // It overflows only when both arguments are negative and the true + // product is exactly the minimum negative number. + // E.g. mul i16 with 17 sign bits: 0xff00 * 0xff80 = 0x8000 + // For simplicity we just check if at least one side is not negative. + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + if (LHSKnown.isNonNegative() || RHSKnown.isNonNegative()) + return OverflowResult::NeverOverflows; + } + return OverflowResult::MayOverflow; +} + OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, const DataLayout &DL, @@ -3684,7 +3815,7 @@ OverflowResult llvm::computeOverflowForUnsignedAdd(const Value *LHS, return OverflowResult::MayOverflow; } -/// \brief Return true if we can prove that adding the two values of the +/// Return true if we can prove that adding the two values of the /// knownbits will not overflow. /// Otherwise return false. static bool checkRippleForSignedAdd(const KnownBits &LHSKnown, @@ -3787,6 +3918,47 @@ static OverflowResult computeOverflowForSignedAdd(const Value *LHS, return OverflowResult::MayOverflow; } +OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS, + const Value *RHS, + const DataLayout &DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + // If the LHS is negative and the RHS is non-negative, no unsigned wrap. + KnownBits LHSKnown = computeKnownBits(LHS, DL, /*Depth=*/0, AC, CxtI, DT); + KnownBits RHSKnown = computeKnownBits(RHS, DL, /*Depth=*/0, AC, CxtI, DT); + if (LHSKnown.isNegative() && RHSKnown.isNonNegative()) + return OverflowResult::NeverOverflows; + + return OverflowResult::MayOverflow; +} + +OverflowResult llvm::computeOverflowForSignedSub(const Value *LHS, + const Value *RHS, + const DataLayout &DL, + AssumptionCache *AC, + const Instruction *CxtI, + const DominatorTree *DT) { + // If LHS and RHS each have at least two sign bits, the subtraction + // cannot overflow. + if (ComputeNumSignBits(LHS, DL, 0, AC, CxtI, DT) > 1 && + ComputeNumSignBits(RHS, DL, 0, AC, CxtI, DT) > 1) + return OverflowResult::NeverOverflows; + + KnownBits LHSKnown = computeKnownBits(LHS, DL, 0, AC, CxtI, DT); + + KnownBits RHSKnown = computeKnownBits(RHS, DL, 0, AC, CxtI, DT); + + // Subtraction of two 2's complement numbers having identical signs will + // never overflow. + if ((LHSKnown.isNegative() && RHSKnown.isNegative()) || + (LHSKnown.isNonNegative() && RHSKnown.isNonNegative())) + return OverflowResult::NeverOverflows; + + // TODO: implement logic similar to checkRippleForAdd + return OverflowResult::MayOverflow; +} + bool llvm::isOverflowIntrinsicNoWrap(const IntrinsicInst *II, const DominatorTree &DT) { #ifndef NDEBUG @@ -3928,6 +4100,15 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { return true; } +bool llvm::isGuaranteedToTransferExecutionToSuccessor(const BasicBlock *BB) { + // TODO: This is slightly conservative for invoke instructions since exiting + // via an exception *is* normal control flow for them.
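// (A hedged usage note: this block-level variant lets callers ask whether,
// once BB is entered, execution is guaranteed to reach the block's
// terminator; that is, no instruction in the block may throw, diverge, or
// otherwise fail to transfer control to its successor.)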
+ for (auto I = BB->begin(), E = BB->end(); I != E; ++I) + if (!isGuaranteedToTransferExecutionToSuccessor(&*I)) + return false; + return true; +} + bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, const Loop *L) { // The loop header is guaranteed to be executed for every iteration. @@ -4161,11 +4342,107 @@ static SelectPatternResult matchClamp(CmpInst::Predicate Pred, return {SPF_UNKNOWN, SPNB_NA, false}; } +/// Recognize variations of: +/// a < c ? min(a,b) : min(b,c) ==> min(min(a,b),min(b,c)) +static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TVal, Value *FVal, + unsigned Depth) { + // TODO: Allow FP min/max with nnan/nsz. + assert(CmpInst::isIntPredicate(Pred) && "Expected integer comparison"); + + Value *A, *B; + SelectPatternResult L = matchSelectPattern(TVal, A, B, nullptr, Depth + 1); + if (!SelectPatternResult::isMinOrMax(L.Flavor)) + return {SPF_UNKNOWN, SPNB_NA, false}; + + Value *C, *D; + SelectPatternResult R = matchSelectPattern(FVal, C, D, nullptr, Depth + 1); + if (L.Flavor != R.Flavor) + return {SPF_UNKNOWN, SPNB_NA, false}; + + // We have something like: x Pred y ? min(a, b) : min(c, d). + // Try to match the compare to the min/max operations of the select operands. + // First, make sure we have the right compare predicate. + switch (L.Flavor) { + case SPF_SMIN: + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_SMAX: + if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SGE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_UMIN: + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + case SPF_UMAX: + if (Pred == ICmpInst::ICMP_ULT || Pred == ICmpInst::ICMP_ULE) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(CmpLHS, CmpRHS); + } + if (Pred == ICmpInst::ICMP_UGT || Pred == ICmpInst::ICMP_UGE) + break; + return {SPF_UNKNOWN, SPNB_NA, false}; + default: + return {SPF_UNKNOWN, SPNB_NA, false}; + } + + // If there is a common operand in the already matched min/max and the other + // min/max operands match the compare operands (either directly or inverted), + // then this is min/max of the same flavor. + + // a pred c ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) + // ~c pred ~a ? m(a, b) : m(c, b) --> m(m(a, b), m(c, b)) + if (D == B) { + if ((CmpLHS == A && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && + match(A, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // a pred d ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) + // ~d pred ~a ? m(a, b) : m(b, d) --> m(m(a, b), m(b, d)) + if (C == B) { + if ((CmpLHS == A && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && + match(A, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // b pred c ? m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) + // ~c pred ~b ? 
m(a, b) : m(c, a) --> m(m(a, b), m(c, a)) + if (D == A) { + if ((CmpLHS == B && CmpRHS == C) || (match(C, m_Not(m_Specific(CmpLHS))) && + match(B, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + // b pred d ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) + // ~d pred ~b ? m(a, b) : m(a, d) --> m(m(a, b), m(a, d)) + if (C == A) { + if ((CmpLHS == B && CmpRHS == D) || (match(D, m_Not(m_Specific(CmpLHS))) && + match(B, m_Not(m_Specific(CmpRHS))))) + return {L.Flavor, SPNB_NA, false}; + } + + return {SPF_UNKNOWN, SPNB_NA, false}; +} + /// Match non-obvious integer minimum and maximum sequences. static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS, Value *TrueVal, Value *FalseVal, - Value *&LHS, Value *&RHS) { + Value *&LHS, Value *&RHS, + unsigned Depth) { // Assume success. If there's no match, callers should not use these anyway. LHS = TrueVal; RHS = FalseVal; @@ -4174,6 +4451,10 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) return SPR; + SPR = matchMinMaxOfMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, Depth); + if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) + return SPR; + if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4230,11 +4511,33 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, return {SPF_UNKNOWN, SPNB_NA, false}; } +bool llvm::isKnownNegation(const Value *X, const Value *Y, bool NeedNSW) { + assert(X && Y && "Invalid operand"); + + // X = sub (0, Y) || X = sub nsw (0, Y) + if ((!NeedNSW && match(X, m_Sub(m_ZeroInt(), m_Specific(Y)))) || + (NeedNSW && match(X, m_NSWSub(m_ZeroInt(), m_Specific(Y))))) + return true; + + // Y = sub (0, X) || Y = sub nsw (0, X) + if ((!NeedNSW && match(Y, m_Sub(m_ZeroInt(), m_Specific(X)))) || + (NeedNSW && match(Y, m_NSWSub(m_ZeroInt(), m_Specific(X))))) + return true; + + // X = sub (A, B), Y = sub (B, A) || X = sub nsw (A, B), Y = sub nsw (B, A) + Value *A, *B; + return (!NeedNSW && (match(X, m_Sub(m_Value(A), m_Value(B))) && + match(Y, m_Sub(m_Specific(B), m_Specific(A))))) || + (NeedNSW && (match(X, m_NSWSub(m_Value(A), m_Value(B))) && + match(Y, m_NSWSub(m_Specific(B), m_Specific(A))))); +} + static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, FastMathFlags FMF, Value *CmpLHS, Value *CmpRHS, Value *TrueVal, Value *FalseVal, - Value *&LHS, Value *&RHS) { + Value *&LHS, Value *&RHS, + unsigned Depth) { LHS = CmpLHS; RHS = CmpRHS; @@ -4327,30 +4630,54 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, case FCmpInst::FCMP_OLE: return {SPF_FMINNUM, NaNBehavior, Ordered}; } } - - const APInt *C1; - if (match(CmpRHS, m_APInt(C1))) { - if ((CmpLHS == TrueVal && match(FalseVal, m_Neg(m_Specific(CmpLHS)))) || - (CmpLHS == FalseVal && match(TrueVal, m_Neg(m_Specific(CmpLHS))))) { - - // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X - // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? -X : X - if (Pred == ICmpInst::ICMP_SGT && - (C1->isNullValue() || C1->isAllOnesValue())) { - return {(CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false}; - } - - // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X - // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X - if (Pred == ICmpInst::ICMP_SLT && - (C1->isNullValue() || C1->isOneValue())) { - return {(CmpLHS == FalseVal) ? 
SPF_ABS : SPF_NABS, SPNB_NA, false}; - } + + if (isKnownNegation(TrueVal, FalseVal)) { + // Sign-extending LHS does not change its sign, so TrueVal/FalseVal can + // match against either LHS or sext(LHS). + auto MaybeSExtCmpLHS = + m_CombineOr(m_Specific(CmpLHS), m_SExt(m_Specific(CmpLHS))); + auto ZeroOrAllOnes = m_CombineOr(m_ZeroInt(), m_AllOnes()); + auto ZeroOrOne = m_CombineOr(m_ZeroInt(), m_One()); + if (match(TrueVal, MaybeSExtCmpLHS)) { + // Set the return values. If the compare uses the negated value (-X >s 0), + // swap the return values because the negated value is always 'RHS'. + LHS = TrueVal; + RHS = FalseVal; + if (match(CmpLHS, m_Neg(m_Specific(FalseVal)))) + std::swap(LHS, RHS); + + // (X >s 0) ? X : -X or (X >s -1) ? X : -X --> ABS(X) + // (-X >s 0) ? -X : X or (-X >s -1) ? -X : X --> ABS(X) + if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) + return {SPF_ABS, SPNB_NA, false}; + + // (X <s 0) ? X : -X or (X <s 1) ? X : -X --> NABS(X) + // (-X <s 0) ? -X : X or (-X <s 1) ? -X : X --> NABS(X) + if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne)) + return {SPF_NABS, SPNB_NA, false}; + } + else if (match(FalseVal, MaybeSExtCmpLHS)) { + // Set the return values. If the compare uses the negated value (-X >s 0), + // swap the return values because the negated value is always 'RHS'. + LHS = FalseVal; + RHS = TrueVal; + if (match(CmpLHS, m_Neg(m_Specific(TrueVal)))) + std::swap(LHS, RHS); + + // (X >s 0) ? -X : X or (X >s -1) ? -X : X --> NABS(X) + // (-X >s 0) ? X : -X or (-X >s -1) ? X : -X --> NABS(X) + if (Pred == ICmpInst::ICMP_SGT && match(CmpRHS, ZeroOrAllOnes)) + return {SPF_NABS, SPNB_NA, false}; + + // (X <s 0) ? -X : X or (X <s 1) ? -X : X --> ABS(X) + // (-X <s 0) ? X : -X or (-X <s 1) ? X : -X --> ABS(X) + if (Pred == ICmpInst::ICMP_SLT && match(CmpRHS, ZeroOrOne)) + return {SPF_ABS, SPNB_NA, false}; } } if (CmpInst::isIntPredicate(Pred)) - return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS); + return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS, Depth); // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar // may return either -0.0 or 0.0, so fcmp/select pair has stricter @@ -4367,7 +4694,7 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, /// /// The function processes the case when type of true and false values of a /// select instruction differs from type of the cmp instruction operands because -/// of a cast instructon. The function checks if it is legal to move the cast +/// of a cast instruction. The function checks if it is legal to move the cast /// operation after "select". If yes, it returns the new second value of /// "select" (with the assumption that cast is moved): /// 1. 
As operand of cast instruction when both values of "select" are same cast @@ -4471,7 +4798,11 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, } SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, - Instruction::CastOps *CastOp) { + Instruction::CastOps *CastOp, + unsigned Depth) { + if (Depth >= MaxDepth) + return {SPF_UNKNOWN, SPNB_NA, false}; + SelectInst *SI = dyn_cast<SelectInst>(V); if (!SI) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4500,7 +4831,7 @@ SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, FMF.setNoSignedZeros(); return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, cast<CastInst>(TrueVal)->getOperand(0), C, - LHS, RHS); + LHS, RHS, Depth); } if (Value *C = lookThroughCast(CmpI, FalseVal, TrueVal, CastOp)) { // If this is a potential fmin/fmax with a cast to integer, then ignore @@ -4509,11 +4840,35 @@ SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, FMF.setNoSignedZeros(); return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, C, cast<CastInst>(FalseVal)->getOperand(0), - LHS, RHS); + LHS, RHS, Depth); } } return ::matchSelectPattern(Pred, FMF, CmpLHS, CmpRHS, TrueVal, FalseVal, - LHS, RHS); + LHS, RHS, Depth); +} + +CmpInst::Predicate llvm::getMinMaxPred(SelectPatternFlavor SPF, bool Ordered) { + if (SPF == SPF_SMIN) return ICmpInst::ICMP_SLT; + if (SPF == SPF_UMIN) return ICmpInst::ICMP_ULT; + if (SPF == SPF_SMAX) return ICmpInst::ICMP_SGT; + if (SPF == SPF_UMAX) return ICmpInst::ICMP_UGT; + if (SPF == SPF_FMINNUM) + return Ordered ? FCmpInst::FCMP_OLT : FCmpInst::FCMP_ULT; + if (SPF == SPF_FMAXNUM) + return Ordered ? FCmpInst::FCMP_OGT : FCmpInst::FCMP_UGT; + llvm_unreachable("unhandled!"); +} + +SelectPatternFlavor llvm::getInverseMinMaxFlavor(SelectPatternFlavor SPF) { + if (SPF == SPF_SMIN) return SPF_SMAX; + if (SPF == SPF_UMIN) return SPF_UMAX; + if (SPF == SPF_SMAX) return SPF_SMIN; + if (SPF == SPF_UMAX) return SPF_UMIN; + llvm_unreachable("unhandled!"); +} + +CmpInst::Predicate llvm::getInverseMinMaxPred(SelectPatternFlavor SPF) { + return getMinMaxPred(getInverseMinMaxFlavor(SPF)); } /// Return true if "icmp Pred LHS RHS" is always true. |
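A hedged usage sketch of the helpers added above (the wrapper function and the 'Sel' value are hypothetical, not part of this patch; the usual ValueTracking.h declarations are assumed):

  // Classify a select as an integer min/max and recover its icmp predicate.
  static CmpInst::Predicate classifyIntMinMax(Value *Sel) {
    Value *LHS, *RHS;
    SelectPatternFlavor SPF = matchSelectPattern(Sel, LHS, RHS).Flavor;
    switch (SPF) {
    case SPF_SMIN: case SPF_SMAX: case SPF_UMIN: case SPF_UMAX:
      // Ordered only matters for the FP flavors; pass false for integers.
      return getMinMaxPred(SPF, /*Ordered=*/false);
    default:
      return CmpInst::BAD_ICMP_PREDICATE;
    }
  }

Similarly, getInverseMinMaxPred would map, e.g., SPF_SMIN to ICMP_SGT by first inverting the flavor to SPF_SMAX.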