Diffstat (limited to 'llvm/lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | llvm/lib/Analysis/ValueTracking.cpp | 1413
1 file changed, 980 insertions(+), 433 deletions(-)
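The patch threads a per-element DemandedElts mask through computeKnownBits, ComputeNumSignBits and isKnownNonZero so that vector queries can be restricted to a subset of lanes; scalars use a one-bit APInt(1, 1) mask and scalable vectors conservatively bail out. A minimal caller sketch (not part of the patch; it assumes the new overloads keep the usual trailing default arguments, as the definitions below suggest):

    #include "llvm/Analysis/ValueTracking.h"
    #include "llvm/IR/DataLayout.h"
    #include "llvm/IR/DerivedTypes.h"
    #include "llvm/Support/KnownBits.h"
    using namespace llvm;

    // Ask only about lane 0 of a fixed-width vector value.
    static KnownBits knownBitsOfLane0(const Value *V, const DataLayout &DL) {
      if (auto *FVTy = dyn_cast<FixedVectorType>(V->getType())) {
        APInt DemandedElts = APInt::getOneBitSet(FVTy->getNumElements(), 0);
        return computeKnownBits(V, DemandedElts, DL);
      }
      return computeKnownBits(V, DL); // scalar (or scalable: nothing known)
    }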
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
index ad6765e2514b4..43caaa62c2ec5 100644
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -24,6 +24,7 @@
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumeBundleQueries.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/GuardUtils.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -34,7 +35,6 @@
 #include "llvm/IR/Argument.h"
 #include "llvm/IR/Attributes.h"
 #include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/CallSite.h"
 #include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
@@ -163,8 +163,61 @@ static const Instruction *safeCxtI(const Value *V, const Instruction *CxtI) {
   return nullptr;
 }

-static void computeKnownBits(const Value *V, KnownBits &Known,
-                             unsigned Depth, const Query &Q);
+static bool getShuffleDemandedElts(const ShuffleVectorInst *Shuf,
+                                   const APInt &DemandedElts,
+                                   APInt &DemandedLHS, APInt &DemandedRHS) {
+  // The length of scalable vectors is unknown at compile time, thus we
+  // cannot check their values
+  if (isa<ScalableVectorType>(Shuf->getType()))
+    return false;
+
+  int NumElts =
+      cast<VectorType>(Shuf->getOperand(0)->getType())->getNumElements();
+  int NumMaskElts = Shuf->getType()->getNumElements();
+  DemandedLHS = DemandedRHS = APInt::getNullValue(NumElts);
+  if (DemandedElts.isNullValue())
+    return true;
+  // Simple case of a shuffle with zeroinitializer.
+  if (all_of(Shuf->getShuffleMask(), [](int Elt) { return Elt == 0; })) {
+    DemandedLHS.setBit(0);
+    return true;
+  }
+  for (int i = 0; i != NumMaskElts; ++i) {
+    if (!DemandedElts[i])
+      continue;
+    int M = Shuf->getMaskValue(i);
+    assert(M < (NumElts * 2) && "Invalid shuffle mask constant");
+
+    // For undef elements, we don't know anything about the common state of
+    // the shuffle result.
+    if (M == -1)
+      return false;
+    if (M < NumElts)
+      DemandedLHS.setBit(M % NumElts);
+    else
+      DemandedRHS.setBit(M % NumElts);
+  }
+
+  return true;
+}
+
+static void computeKnownBits(const Value *V, const APInt &DemandedElts,
+                             KnownBits &Known, unsigned Depth, const Query &Q);
+
+static void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
+                             const Query &Q) {
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType())) {
+    Known.resetAll();
+    return;
+  }
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
+  computeKnownBits(V, DemandedElts, Known, Depth, Q);
+}

 void llvm::computeKnownBits(const Value *V, KnownBits &Known,
                             const DataLayout &DL, unsigned Depth,
@@ -175,6 +228,18 @@ void llvm::computeKnownBits(const Value *V, KnownBits &Known,
           Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }

+void llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+                            KnownBits &Known, const DataLayout &DL,
+                            unsigned Depth, AssumptionCache *AC,
+                            const Instruction *CxtI, const DominatorTree *DT,
+                            OptimizationRemarkEmitter *ORE, bool UseInstrInfo) {
+  ::computeKnownBits(V, DemandedElts, Known, Depth,
+                     Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
+static KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+                                  unsigned Depth, const Query &Q);
+
 static KnownBits computeKnownBits(const Value *V, unsigned Depth,
                                   const Query &Q);

@@ -188,6 +253,17 @@ KnownBits llvm::computeKnownBits(const Value *V, const DataLayout &DL,
       V, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
 }

+KnownBits llvm::computeKnownBits(const Value *V, const APInt &DemandedElts,
+                                 const DataLayout &DL, unsigned Depth,
+                                 AssumptionCache *AC, const Instruction *CxtI,
+                                 const DominatorTree *DT,
+                                 OptimizationRemarkEmitter *ORE,
+                                 bool UseInstrInfo) {
+  return ::computeKnownBits(
+      V, DemandedElts, Depth,
+      Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo, ORE));
+}
+
 bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS,
                                const DataLayout &DL, AssumptionCache *AC,
                                const Instruction *CxtI, const DominatorTree *DT,
@@ -235,6 +311,9 @@ bool llvm::isKnownToBeAPowerOfTwo(const Value *V, const DataLayout &DL,
       V, OrZero, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }

+static bool isKnownNonZero(const Value *V, const APInt &DemandedElts,
+                           unsigned Depth, const Query &Q);
+
 static bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q);

 bool llvm::isKnownNonZero(const Value *V, const DataLayout &DL, unsigned Depth,
@@ -295,8 +374,21 @@ bool llvm::MaskedValueIsZero(const Value *V, const APInt &Mask,
       V, Mask, Depth, Query(DL, AC, safeCxtI(V, CxtI), DT, UseInstrInfo));
 }

+static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
+                                   unsigned Depth, const Query &Q);
+
 static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
-                                   const Query &Q);
+                                   const Query &Q) {
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return 1;
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
+  return ComputeNumSignBits(V, DemandedElts, Depth, Q);
+}

 unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
                                   unsigned Depth, AssumptionCache *AC,
@@ -307,26 +399,27 @@ unsigned llvm::ComputeNumSignBits(const Value *V, const DataLayout &DL,
 }

 static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1,
-                                   bool NSW,
+                                   bool NSW, const APInt &DemandedElts,
                                    KnownBits &KnownOut, KnownBits &Known2,
                                    unsigned Depth, const Query &Q) {
-  unsigned BitWidth = KnownOut.getBitWidth();
+  computeKnownBits(Op1, DemandedElts, KnownOut, Depth + 1, Q);

-  // If an initial sequence of bits in the result is not needed, the
-  // corresponding bits in the operands are not needed.
-  KnownBits LHSKnown(BitWidth);
-  computeKnownBits(Op0, LHSKnown, Depth + 1, Q);
-  computeKnownBits(Op1, Known2, Depth + 1, Q);
+  // If one operand is unknown and we have no nowrap information,
+  // the result will be unknown independently of the second operand.
+  if (KnownOut.isUnknown() && !NSW)
+    return;

-  KnownOut = KnownBits::computeForAddSub(Add, NSW, LHSKnown, Known2);
+  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);
+  KnownOut = KnownBits::computeForAddSub(Add, NSW, Known2, KnownOut);
 }

 static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW,
-                                KnownBits &Known, KnownBits &Known2,
-                                unsigned Depth, const Query &Q) {
+                                const APInt &DemandedElts, KnownBits &Known,
+                                KnownBits &Known2, unsigned Depth,
+                                const Query &Q) {
   unsigned BitWidth = Known.getBitWidth();
-  computeKnownBits(Op1, Known, Depth + 1, Q);
-  computeKnownBits(Op0, Known2, Depth + 1, Q);
+  computeKnownBits(Op1, DemandedElts, Known, Depth + 1, Q);
+  computeKnownBits(Op0, DemandedElts, Known2, Depth + 1, Q);

   bool isKnownNegative = false;
   bool isKnownNonNegative = false;
@@ -535,6 +628,29 @@ bool llvm::isValidAssumeForContext(const Instruction *Inv,
   // feeding the assume is trivially true, thus causing the removal of
   // the assume).

+  if (Inv->getParent() == CxtI->getParent()) {
+    // If Inv and CtxI are in the same block, check if the assume (Inv) is first
+    // in the BB.
+    if (Inv->comesBefore(CxtI))
+      return true;
+
+    // Don't let an assume affect itself - this would cause the problems
+    // `isEphemeralValueOf` is trying to prevent, and it would also make
+    // the loop below go out of bounds.
+    if (Inv == CxtI)
+      return false;
+
+    // The context comes first, but they're both in the same block.
+    // Make sure there is nothing in between that might interrupt
+    // the control flow, not even CxtI itself.
+    for (BasicBlock::const_iterator I(CxtI), IE(Inv); I != IE; ++I)
+      if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+        return false;
+
+    return !isEphemeralValueOf(Inv, CxtI);
+  }
+
+  // Inv and CxtI are in different blocks.
   if (DT) {
     if (DT->dominates(Inv, CxtI))
       return true;
@@ -543,37 +659,7 @@ bool llvm::isValidAssumeForContext(const Instruction *Inv,
       return true;
   }

-  // With or without a DT, the only remaining case we will check is if the
-  // instructions are in the same BB. Give up if that is not the case.
-  if (Inv->getParent() != CxtI->getParent())
-    return false;
-
-  // If we have a dom tree, then we now know that the assume doesn't dominate
-  // the other instruction. If we don't have a dom tree then we can check if
-  // the assume is first in the BB.
-  if (!DT) {
-    // Search forward from the assume until we reach the context (or the end
-    // of the block); the common case is that the assume will come first.
-    for (auto I = std::next(BasicBlock::const_iterator(Inv)),
-         IE = Inv->getParent()->end(); I != IE; ++I)
-      if (&*I == CxtI)
-        return true;
-  }
-
-  // Don't let an assume affect itself - this would cause the problems
-  // `isEphemeralValueOf` is trying to prevent, and it would also make
-  // the loop below go out of bounds.
-  if (Inv == CxtI)
-    return false;
-
-  // The context comes first, but they're both in the same block.
-  // Make sure there is nothing in between that might interrupt
-  // the control flow, not even CxtI itself.
-  for (BasicBlock::const_iterator I(CxtI), IE(Inv); I != IE; ++I)
-    if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
-      return false;
-
-  return !isEphemeralValueOf(Inv, CxtI);
+  return false;
 }

 static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
@@ -592,10 +678,6 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
     CmpInst::Predicate Pred;
     if (!match(Cmp, m_c_ICmp(Pred, m_V, m_Value(RHS))))
       return false;
-    // Canonicalize 'v' to be on the LHS of the comparison.
-    if (Cmp->getOperand(1) != RHS)
-      Pred = CmpInst::getSwappedPredicate(Pred);
-
     // assume(v u> y) -> assume(v != 0)
     if (Pred == ICmpInst::ICMP_UGT)
       return true;
@@ -615,6 +697,16 @@ static bool isKnownNonZeroFromAssume(const Value *V, const Query &Q) {
    return !TrueValues.contains(APInt::getNullValue(CI->getBitWidth()));
  };

+  if (Q.CxtI && V->getType()->isPointerTy()) {
+    SmallVector<Attribute::AttrKind, 2> AttrKinds{Attribute::NonNull};
+    if (!NullPointerIsDefined(Q.CxtI->getFunction(),
+                              V->getType()->getPointerAddressSpace()))
+      AttrKinds.push_back(Attribute::Dereferenceable);
+
+    if (getKnowledgeValidInContext(V, AttrKinds, Q.CxtI, Q.DT, Q.AC))
+      return true;
+  }
+
   for (auto &AssumeVH : Q.AC->assumptionsFor(V)) {
     if (!AssumeVH)
       continue;
@@ -693,6 +785,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
     if (!Cmp)
       continue;

+    // Note that ptrtoint may change the bitwidth.
     Value *A, *B;
     auto m_V = m_CombineOr(m_Specific(V), m_PtrToInt(m_Specific(V)));

@@ -705,18 +798,18 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v = a)
       if (match(Cmp, m_c_ICmp(Pred, m_V, m_Value(A))) &&
           isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         Known.Zero |= RHSKnown.Zero;
         Known.One  |= RHSKnown.One;
       // assume(v & b = a)
       } else if (match(Cmp, m_c_ICmp(Pred, m_c_And(m_V, m_Value(B)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits MaskKnown(BitWidth);
-        computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits MaskKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in the mask that are known to be one, we can propagate
         // known bits from the RHS to V.
@@ -726,10 +819,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_And(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits MaskKnown(BitWidth);
-        computeKnownBits(B, MaskKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits MaskKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in the mask that are known to be one, we can propagate
         // inverted known bits from the RHS to V.
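Aside (not part of the diff): the assume((v & b) == a) case above transfers a's known bits to v only where the mask b is known-one. A standalone sketch of that rule, with a hypothetical Bits struct standing in for llvm::KnownBits:

    #include <cstdint>

    struct Bits { uint64_t Zero = 0, One = 0; }; // known-zero / known-one masks

    // assume((v & b) == a): a's bits are trustworthy exactly where b is 1.
    void applyAndEqAssume(Bits &V, const Bits &A, const Bits &B) {
      V.Zero |= A.Zero & B.One; // where b is 1 and a is 0, v must be 0
      V.One  |= A.One  & B.One; // where b is 1 and a is 1, v must be 1
    }

The anyextOrTrunc(BitWidth) calls introduced above re-size the bits learned about the right-hand side to the width of the queried value before this intersection, since ptrtoint may change the bit width.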
@@ -739,10 +832,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_c_Or(m_V, m_Value(B)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V.
@@ -752,10 +845,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Or(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V.
@@ -765,10 +858,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_c_Xor(m_V, m_Value(B)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in B that are known to be zero, we can propagate known
         // bits from the RHS to V. For those bits in B that are known to be one,
@@ -781,10 +874,10 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_c_Xor(m_V, m_Value(B))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
-        KnownBits BKnown(BitWidth);
-        computeKnownBits(B, BKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+        KnownBits BKnown =
+            computeKnownBits(B, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // For those bits in B that are known to be zero, we can propagate
         // inverted known bits from the RHS to V. For those bits in B that are
@@ -797,8 +890,9 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
+
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         RHSKnown.Zero.lshrInPlace(C);
@@ -809,8 +903,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         RHSKnown.One.lshrInPlace(C);
@@ -821,8 +915,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them to known
         // bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.Zero << C;
@@ -831,8 +925,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       } else if (match(Cmp, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))),
                                      m_Value(A))) &&
                  isValidAssumeForContext(I, Q.CxtI, Q.DT) && C < BitWidth) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);
         // For those bits in RHS that are known, we can propagate them inverted
         // to known bits in V shifted to the right by C.
         Known.Zero |= RHSKnown.One << C;
@@ -843,8 +937,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v >=_s c) where c is non-negative
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);

         if (RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -856,8 +950,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v >_s c) where c is at least -1.
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);

         if (RHSKnown.isAllOnes() || RHSKnown.isNonNegative()) {
           // We know that the sign bit is zero.
@@ -869,8 +963,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <=_s c) where c is negative
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth + 1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth + 1, Query(Q, I)).anyextOrTrunc(BitWidth);

         if (RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -882,8 +976,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <_s c) where c is non-positive
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         if (RHSKnown.isZero() || RHSKnown.isNegative()) {
           // We know that the sign bit is one.
@@ -895,8 +989,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <=_u c)
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // Whatever high bits in c are zero are known to be zero.
         Known.Zero.setHighBits(RHSKnown.countMinLeadingZeros());
@@ -906,8 +1000,8 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
       // assume(v <_u c)
       if (match(Cmp, m_ICmp(Pred, m_V, m_Value(A))) &&
          isValidAssumeForContext(I, Q.CxtI, Q.DT)) {
-        KnownBits RHSKnown(BitWidth);
-        computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I));
+        KnownBits RHSKnown =
+            computeKnownBits(A, Depth+1, Query(Q, I)).anyextOrTrunc(BitWidth);

         // If the RHS is known zero, then this assumption must be wrong (nothing
         // is unsigned less than zero). Signal a conflict and get out of here.
@@ -957,16 +1051,17 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known,
 /// amount. The results from calling KZF and KOF are conservatively combined for
 /// all permitted shift amounts.
 static void computeKnownBitsFromShiftOperator(
-    const Operator *I, KnownBits &Known, KnownBits &Known2,
-    unsigned Depth, const Query &Q,
+    const Operator *I, const APInt &DemandedElts, KnownBits &Known,
+    KnownBits &Known2, unsigned Depth, const Query &Q,
     function_ref<APInt(const APInt &, unsigned)> KZF,
     function_ref<APInt(const APInt &, unsigned)> KOF) {
   unsigned BitWidth = Known.getBitWidth();

-  if (auto *SA = dyn_cast<ConstantInt>(I->getOperand(1))) {
-    unsigned ShiftAmt = SA->getLimitedValue(BitWidth-1);
+  computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
+  if (Known.isConstant()) {
+    unsigned ShiftAmt = Known.getConstant().getLimitedValue(BitWidth - 1);

-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known, Depth + 1, Q);
     Known.Zero = KZF(Known.Zero, ShiftAmt);
     Known.One  = KOF(Known.One, ShiftAmt);
     // If the known bits conflict, this must be an overflowing left shift, so
@@ -978,11 +1073,10 @@ static void computeKnownBitsFromShiftOperator(
     return;
   }

-  computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
-
   // If the shift amount could be greater than or equal to the bit-width of the
   // LHS, the value could be poison, but bail out because the check below is
-  // expensive. TODO: Should we just carry on?
+  // expensive.
+  // TODO: Should we just carry on?
   if (Known.getMaxValue().uge(BitWidth)) {
     Known.resetAll();
     return;
@@ -1006,12 +1100,13 @@ static void computeKnownBitsFromShiftOperator(
   // Early exit if we can't constrain any well-defined shift amount.
   if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) &&
       !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) {
-    ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), Depth + 1, Q);
+    ShifterOperandIsNonZero =
+        isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q);
     if (!*ShifterOperandIsNonZero)
       return;
   }

-  computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+  computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

   Known.Zero.setAllBits();
   Known.One.setAllBits();
@@ -1028,7 +1123,7 @@ static void computeKnownBitsFromShiftOperator(
     if (ShiftAmt == 0) {
       if (!ShifterOperandIsNonZero.hasValue())
         ShifterOperandIsNonZero =
-            isKnownNonZero(I->getOperand(1), Depth + 1, Q);
+            isKnownNonZero(I->getOperand(1), DemandedElts, Depth + 1, Q);
       if (*ShifterOperandIsNonZero)
         continue;
     }
@@ -1043,11 +1138,13 @@ static void computeKnownBitsFromShiftOperator(
     Known.setAllZero();
 }

-static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
-                                         unsigned Depth, const Query &Q) {
+static void computeKnownBitsFromOperator(const Operator *I,
+                                         const APInt &DemandedElts,
+                                         KnownBits &Known, unsigned Depth,
+                                         const Query &Q) {
   unsigned BitWidth = Known.getBitWidth();

-  KnownBits Known2(Known);
+  KnownBits Known2(BitWidth);
   switch (I->getOpcode()) {
   default: break;
   case Instruction::Load:
@@ -1057,13 +1154,10 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     break;
   case Instruction::And: {
     // If either the LHS or the RHS are Zero, the result is zero.
-    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

-    // Output known-1 bits are only known if set in both the LHS & RHS.
-    Known.One &= Known2.One;
-    // Output known-0 are known to be clear if zero in either the LHS | RHS.
-    Known.Zero |= Known2.Zero;
+    Known &= Known2;

     // and(x, add (x, -1)) is a common idiom that always clears the low bit;
     // here we handle the more general case of adding any odd number by
@@ -1074,36 +1168,28 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     if (!Known.Zero[0] && !Known.One[0] &&
         match(I, m_c_BinOp(m_Value(X), m_Add(m_Deferred(X), m_Value(Y))))) {
       Known2.resetAll();
-      computeKnownBits(Y, Known2, Depth + 1, Q);
+      computeKnownBits(Y, DemandedElts, Known2, Depth + 1, Q);
       if (Known2.countMinTrailingOnes() > 0)
         Known.Zero.setBit(0);
     }
     break;
   }
   case Instruction::Or:
-    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

-    // Output known-0 bits are only known if clear in both the LHS & RHS.
-    Known.Zero &= Known2.Zero;
-    // Output known-1 are known to be set if set in either the LHS | RHS.
-    Known.One |= Known2.One;
+    Known |= Known2;
     break;
-  case Instruction::Xor: {
-    computeKnownBits(I->getOperand(1), Known, Depth + 1, Q);
-    computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+  case Instruction::Xor:
+    computeKnownBits(I->getOperand(1), DemandedElts, Known, Depth + 1, Q);
+    computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);

-    // Output known-0 bits are known if clear or set in both the LHS & RHS.
-    APInt KnownZeroOut = (Known.Zero & Known2.Zero) | (Known.One & Known2.One);
-    // Output known-1 are known to be set if set in only one of the LHS, RHS.
-    Known.One = (Known.Zero & Known2.One) | (Known.One & Known2.Zero);
-    Known.Zero = std::move(KnownZeroOut);
+    Known ^= Known2;
     break;
-  }
   case Instruction::Mul: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
-    computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, Known,
-                        Known2, Depth, Q);
+    computeKnownBitsMul(I->getOperand(0), I->getOperand(1), NSW, DemandedElts,
+                        Known, Known2, Depth, Q);
     break;
   }
   case Instruction::UDiv: {
@@ -1207,9 +1293,9 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
         Q.DL.getTypeSizeInBits(ScalarTy);

     assert(SrcBitWidth && "SrcBitWidth can't be zero");
-    Known = Known.zextOrTrunc(SrcBitWidth, false);
+    Known = Known.anyextOrTrunc(SrcBitWidth);
     computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
-    Known = Known.zextOrTrunc(BitWidth, true /* ExtendedBitsAreKnownZero */);
+    Known = Known.zextOrTrunc(BitWidth);
     break;
   }
   case Instruction::BitCast: {
@@ -1254,7 +1340,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       return KOResult;
     };

-    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+    computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
+                                      KZF, KOF);
     break;
   }
   case Instruction::LShr: {
@@ -1270,7 +1357,8 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       return KnownOne.lshr(ShiftAmt);
     };

-    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+    computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
+                                      KZF, KOF);
     break;
   }
   case Instruction::AShr: {
@@ -1283,19 +1371,20 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       return KnownOne.ashr(ShiftAmt);
     };

-    computeKnownBitsFromShiftOperator(I, Known, Known2, Depth, Q, KZF, KOF);
+    computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
+                                      KZF, KOF);
     break;
   }
   case Instruction::Sub: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     computeKnownBitsAddSub(false, I->getOperand(0), I->getOperand(1), NSW,
-                           Known, Known2, Depth, Q);
+                           DemandedElts, Known, Known2, Depth, Q);
     break;
   }
   case Instruction::Add: {
     bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
     computeKnownBitsAddSub(true, I->getOperand(0), I->getOperand(1), NSW,
-                           Known, Known2, Depth, Q);
+                           DemandedElts, Known, Known2, Depth, Q);
     break;
   }
   case Instruction::SRem:
@@ -1355,17 +1444,9 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     Known.Zero.setHighBits(Leaders);
     break;
   }
-
-  case Instruction::Alloca: {
-    const AllocaInst *AI = cast<AllocaInst>(I);
-    unsigned Align = AI->getAlignment();
-    if (Align == 0)
-      Align = Q.DL.getABITypeAlignment(AI->getAllocatedType());
-
-    if (Align > 0)
-      Known.Zero.setLowBits(countTrailingZeros(Align));
+  case Instruction::Alloca:
+    Known.Zero.setLowBits(Log2(cast<AllocaInst>(I)->getAlign()));
     break;
-  }
   case Instruction::GetElementPtr: {
     // Analyze all of the subscripts of this getelementptr instruction
     // to determine if we can prove known low zero bits.
@@ -1375,6 +1456,10 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     gep_type_iterator GTI = gep_type_begin(I);
     for (unsigned i = 1, e = I->getNumOperands(); i != e; ++i, ++GTI) {
+      // TrailZ can only become smaller, short-circuit if we hit zero.
+      if (TrailZ == 0)
+        break;
+
       Value *Index = I->getOperand(i);
       if (StructType *STy = GTI.getStructTypeOrNull()) {
         // Handle struct member offset arithmetic.
@@ -1400,7 +1485,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
           break;
         }
         unsigned GEPOpiBits = Index->getType()->getScalarSizeInBits();
-        uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy);
+        uint64_t TypeSize = Q.DL.getTypeAllocSize(IndexedTy).getKnownMinSize();
         LocalKnown.Zero = LocalKnown.One = APInt(GEPOpiBits, 0);
         computeKnownBits(Index, LocalKnown, Depth + 1, Q);
         TrailZ = std::min(TrailZ,
@@ -1457,7 +1542,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
         computeKnownBits(R, Known2, Depth + 1, RecQ);

         // We need to take the minimum number of known bits
-        KnownBits Known3(Known);
+        KnownBits Known3(BitWidth);
         RecQ.CxtI = LInst;
         computeKnownBits(L, Known3, Depth + 1, RecQ);

@@ -1549,7 +1634,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
     if (MDNode *MD =
             Q.IIQ.getMetadata(cast<Instruction>(I), LLVMContext::MD_range))
       computeKnownBitsFromRangeMetadata(*MD, Known);
-    if (const Value *RV = ImmutableCallSite(I).getReturnedArgOperand()) {
+    if (const Value *RV = cast<CallBase>(I)->getReturnedArgOperand()) {
       computeKnownBits(RV, Known2, Depth + 1, Q);
       Known.Zero |= Known2.Zero;
       Known.One |= Known2.One;
@@ -1558,12 +1643,12 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       switch (II->getIntrinsicID()) {
       default: break;
       case Intrinsic::bitreverse:
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         Known.Zero |= Known2.Zero.reverseBits();
         Known.One |= Known2.One.reverseBits();
         break;
       case Intrinsic::bswap:
-        computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
+        computeKnownBits(I->getOperand(0), DemandedElts, Known2, Depth + 1, Q);
         Known.Zero |= Known2.Zero.byteSwap();
         Known.One |= Known2.One.byteSwap();
         break;
@@ -1611,7 +1696,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
         if (II->getIntrinsicID() == Intrinsic::fshr)
           ShiftAmt = BitWidth - ShiftAmt;

-        KnownBits Known3(Known);
+        KnownBits Known3(BitWidth);
         computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q);
         computeKnownBits(I->getOperand(1), Known3, Depth + 1, Q);

@@ -1658,13 +1743,85 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       }
     }
     break;
-  case Instruction::ExtractElement:
-    // Look through extract element. At the moment we keep this simple and skip
-    // tracking the specific element. But at least we might find information
-    // valid for all elements of the vector (for example if vector is sign
-    // extended, shifted, etc).
-    computeKnownBits(I->getOperand(0), Known, Depth + 1, Q);
+  case Instruction::ShuffleVector: {
+    auto *Shuf = dyn_cast<ShuffleVectorInst>(I);
+    // FIXME: Do we need to handle ConstantExpr involving shufflevectors?
+    if (!Shuf) {
+      Known.resetAll();
+      return;
+    }
+    // For undef elements, we don't know anything about the common state of
+    // the shuffle result.
+    APInt DemandedLHS, DemandedRHS;
+    if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS)) {
+      Known.resetAll();
+      return;
+    }
+    Known.One.setAllBits();
+    Known.Zero.setAllBits();
+    if (!!DemandedLHS) {
+      const Value *LHS = Shuf->getOperand(0);
+      computeKnownBits(LHS, DemandedLHS, Known, Depth + 1, Q);
+      // If we don't know any bits, early out.
+      if (Known.isUnknown())
+        break;
+    }
+    if (!!DemandedRHS) {
+      const Value *RHS = Shuf->getOperand(1);
+      computeKnownBits(RHS, DemandedRHS, Known2, Depth + 1, Q);
+      Known.One &= Known2.One;
+      Known.Zero &= Known2.Zero;
+    }
+    break;
+  }
+  case Instruction::InsertElement: {
+    const Value *Vec = I->getOperand(0);
+    const Value *Elt = I->getOperand(1);
+    auto *CIdx = dyn_cast<ConstantInt>(I->getOperand(2));
+    // Early out if the index is non-constant or out-of-range.
+    unsigned NumElts = DemandedElts.getBitWidth();
+    if (!CIdx || CIdx->getValue().uge(NumElts)) {
+      Known.resetAll();
+      return;
+    }
+    Known.One.setAllBits();
+    Known.Zero.setAllBits();
+    unsigned EltIdx = CIdx->getZExtValue();
+    // Do we demand the inserted element?
+    if (DemandedElts[EltIdx]) {
+      computeKnownBits(Elt, Known, Depth + 1, Q);
+      // If we don't know any bits, early out.
+      if (Known.isUnknown())
+        break;
+    }
+    // We don't need the base vector element that has been inserted.
+    APInt DemandedVecElts = DemandedElts;
+    DemandedVecElts.clearBit(EltIdx);
+    if (!!DemandedVecElts) {
+      computeKnownBits(Vec, DemandedVecElts, Known2, Depth + 1, Q);
+      Known.One &= Known2.One;
+      Known.Zero &= Known2.Zero;
+    }
     break;
+  }
+  case Instruction::ExtractElement: {
+    // Look through extract element. If the index is non-constant or
+    // out-of-range demand all elements, otherwise just the extracted element.
+    const Value *Vec = I->getOperand(0);
+    const Value *Idx = I->getOperand(1);
+    auto *CIdx = dyn_cast<ConstantInt>(Idx);
+    if (isa<ScalableVectorType>(Vec->getType())) {
+      // FIXME: there's probably *something* we can do with scalable vectors
+      Known.resetAll();
+      break;
+    }
+    unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
+    APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
+    if (CIdx && CIdx->getValue().ult(NumElts))
+      DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
+    computeKnownBits(Vec, DemandedVecElts, Known, Depth + 1, Q);
+    break;
+  }
   case Instruction::ExtractValue:
     if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I->getOperand(0))) {
       const ExtractValueInst *EVI = cast<ExtractValueInst>(I);
@@ -1675,28 +1832,38 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known,
       case Intrinsic::uadd_with_overflow:
       case Intrinsic::sadd_with_overflow:
         computeKnownBitsAddSub(true, II->getArgOperand(0),
-                               II->getArgOperand(1), false, Known, Known2,
-                               Depth, Q);
+                               II->getArgOperand(1), false, DemandedElts,
+                               Known, Known2, Depth, Q);
         break;
       case Intrinsic::usub_with_overflow:
       case Intrinsic::ssub_with_overflow:
         computeKnownBitsAddSub(false, II->getArgOperand(0),
-                               II->getArgOperand(1), false, Known, Known2,
-                               Depth, Q);
+                               II->getArgOperand(1), false, DemandedElts,
+                               Known, Known2, Depth, Q);
         break;
       case Intrinsic::umul_with_overflow:
      case Intrinsic::smul_with_overflow:
         computeKnownBitsMul(II->getArgOperand(0), II->getArgOperand(1), false,
-                            Known, Known2, Depth, Q);
+                            DemandedElts, Known, Known2, Depth, Q);
         break;
       }
     }
   }
+    break;
   }
 }

 /// Determine which bits of V are known to be either zero or one and return
 /// them.
+KnownBits computeKnownBits(const Value *V, const APInt &DemandedElts,
+                           unsigned Depth, const Query &Q) {
+  KnownBits Known(getBitWidth(V->getType(), Q.DL));
+  computeKnownBits(V, DemandedElts, Known, Depth, Q);
+  return Known;
+}
+
+/// Determine which bits of V are known to be either zero or one and return
+/// them.
 KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
   KnownBits Known(getBitWidth(V->getType(), Q.DL));
   computeKnownBits(V, Known, Depth, Q);
@@ -1717,23 +1884,44 @@ KnownBits computeKnownBits(const Value *V, unsigned Depth, const Query &Q) {
 /// type, and vectors of integers.  In the case
 /// where V is a vector, known zero, and known one values are the
 /// same width as the vector element, and the bit is set only if it is true
-/// for all of the elements in the vector.
-void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
-                      const Query &Q) {
+/// for all of the demanded elements in the vector specified by DemandedElts.
+void computeKnownBits(const Value *V, const APInt &DemandedElts,
+                      KnownBits &Known, unsigned Depth, const Query &Q) {
+  if (!DemandedElts || isa<ScalableVectorType>(V->getType())) {
+    // No demanded elts or V is a scalable vector, better to assume we don't
+    // know anything.
+    Known.resetAll();
+    return;
+  }
+
   assert(V && "No Value?");
   assert(Depth <= MaxDepth && "Limit Search Depth");
+
+#ifndef NDEBUG
+  Type *Ty = V->getType();
   unsigned BitWidth = Known.getBitWidth();

-  assert((V->getType()->isIntOrIntVectorTy(BitWidth) ||
-          V->getType()->isPtrOrPtrVectorTy()) &&
+  assert((Ty->isIntOrIntVectorTy(BitWidth) || Ty->isPtrOrPtrVectorTy()) &&
          "Not integer or pointer type!");

-  Type *ScalarTy = V->getType()->getScalarType();
-  unsigned ExpectedWidth = ScalarTy->isPointerTy() ?
-    Q.DL.getPointerTypeSizeInBits(ScalarTy) : Q.DL.getTypeSizeInBits(ScalarTy);
-  assert(ExpectedWidth == BitWidth && "V and Known should have same BitWidth");
-  (void)BitWidth;
-  (void)ExpectedWidth;
+  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
+    assert(
+        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
+        "DemandedElt width should equal the fixed vector number of elements");
+  } else {
+    assert(DemandedElts == APInt(1, 1) &&
+           "DemandedElt width should be 1 for scalars");
+  }
+
+  Type *ScalarTy = Ty->getScalarType();
+  if (ScalarTy->isPointerTy()) {
+    assert(BitWidth == Q.DL.getPointerTypeSizeInBits(ScalarTy) &&
+           "V and Known should have same BitWidth");
+  } else {
+    assert(BitWidth == Q.DL.getTypeSizeInBits(ScalarTy) &&
+           "V and Known should have same BitWidth");
+  }
+#endif

   const APInt *C;
   if (match(V, m_APInt(C))) {
@@ -1749,12 +1937,14 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
   }
   // Handle a constant vector by taking the intersection of the known bits of
   // each element.
-  if (const ConstantDataSequential *CDS = dyn_cast<ConstantDataSequential>(V)) {
-    // We know that CDS must be a vector of integers. Take the intersection of
+  if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(V)) {
+    // We know that CDV must be a vector of integers. Take the intersection of
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
-    for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) {
-      APInt Elt = CDS->getElementAsAPInt(i);
+    for (unsigned i = 0, e = CDV->getNumElements(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
+      APInt Elt = CDV->getElementAsAPInt(i);
       Known.Zero &= ~Elt;
       Known.One &= Elt;
     }
@@ -1766,6 +1956,8 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
     // each element.
     Known.Zero.setAllBits(); Known.One.setAllBits();
     for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) {
+      if (!DemandedElts[i])
+        continue;
       Constant *Element = CV->getAggregateElement(i);
       auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element);
       if (!ElementCI) {
@@ -1804,13 +1996,12 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth,
   }

   if (const Operator *I = dyn_cast<Operator>(V))
-    computeKnownBitsFromOperator(I, Known, Depth, Q);
+    computeKnownBitsFromOperator(I, DemandedElts, Known, Depth, Q);

   // Aligned pointers have trailing zeros - refine Known.Zero set
-  if (V->getType()->isPointerTy()) {
-    const MaybeAlign Align = V->getPointerAlignment(Q.DL);
-    if (Align)
-      Known.Zero.setLowBits(countTrailingZeros(Align->value()));
+  if (isa<PointerType>(V->getType())) {
+    Align Alignment = V->getPointerAlignment(Q.DL);
+    Known.Zero.setLowBits(countTrailingZeros(Alignment.value()));
   }

   // computeKnownBitsFromAssume strictly refines Known.
@@ -1960,7 +2151,7 @@ static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth,
     }

     // If we have a zero-sized type, the index doesn't matter. Keep looping.
-    if (Q.DL.getTypeAllocSize(GTI.getIndexedType()) == 0)
+    if (Q.DL.getTypeAllocSize(GTI.getIndexedType()).getKnownMinSize() == 0)
       continue;

     // Fast path the constant operand case both for efficiency and so we don't
@@ -2004,11 +2195,11 @@ static bool isKnownNonNullFromDominatingCondition(const Value *V,

     // If the value is used as an argument to a call or invoke, then argument
     // attributes may provide an answer about null-ness.
-    if (auto CS = ImmutableCallSite(U))
-      if (auto *CalledFunc = CS.getCalledFunction())
+    if (const auto *CB = dyn_cast<CallBase>(U))
+      if (auto *CalledFunc = CB->getCalledFunction())
         for (const Argument &Arg : CalledFunc->args())
-          if (CS.getArgOperand(Arg.getArgNo()) == V &&
-              Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI))
+          if (CB->getArgOperand(Arg.getArgNo()) == V &&
+              Arg.hasNonNullAttr() && DT->dominates(CB, CtxI))
             return true;

     // If the value is used as a load/store, then the pointer must be non null.
@@ -2088,12 +2279,18 @@ static bool rangeMetadataExcludesValue(const MDNode* Ranges, const APInt& Value)
 }

 /// Return true if the given value is known to be non-zero when defined. For
-/// vectors, return true if every element is known to be non-zero when
+/// vectors, return true if every demanded element is known to be non-zero when
 /// defined. For pointers, if the context instruction and dominator tree are
 /// specified, perform context-sensitive analysis and return true if the
 /// pointer couldn't possibly be null at the specified instruction.
 /// Supports values with integer or pointer type and vectors of integers.
-bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
+bool isKnownNonZero(const Value *V, const APInt &DemandedElts, unsigned Depth,
+                    const Query &Q) {
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return false;
+
   if (auto *C = dyn_cast<Constant>(V)) {
     if (C->isNullValue())
       return false;
@@ -2112,8 +2309,10 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {

     // For constant vectors, check that all elements are undefined or known
     // non-zero to determine that the whole vector is known non-zero.
-    if (auto *VecTy = dyn_cast<VectorType>(C->getType())) {
+    if (auto *VecTy = dyn_cast<FixedVectorType>(C->getType())) {
       for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
+        if (!DemandedElts[i])
+          continue;
         Constant *Elt = C->getAggregateElement(i);
         if (!Elt || Elt->isNullValue())
           return false;
@@ -2161,7 +2360,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {

   // A byval, inalloca, or nonnull argument is never null.
   if (const Argument *A = dyn_cast<Argument>(V))
-    if (A->hasByValOrInAllocaAttr() || A->hasNonNullAttr())
+    if (A->hasPassPointeeByValueAttr() || A->hasNonNullAttr())
       return true;

   // A Load tagged with nonnull metadata is never null.
@@ -2214,7 +2413,8 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
   // X | Y != 0 if X != 0 or Y != 0.
   Value *X = nullptr, *Y = nullptr;
   if (match(V, m_Or(m_Value(X), m_Value(Y))))
-    return isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q);
+    return isKnownNonZero(X, DemandedElts, Depth, Q) ||
+           isKnownNonZero(Y, DemandedElts, Depth, Q);

   // ext X != 0 if X != 0.
   if (isa<SExtInst>(V) || isa<ZExtInst>(V))
@@ -2229,7 +2429,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
       return isKnownNonZero(X, Depth, Q);

     KnownBits Known(BitWidth);
-    computeKnownBits(X, Known, Depth, Q);
+    computeKnownBits(X, DemandedElts, Known, Depth, Q);
     if (Known.One[0])
       return true;
   }
@@ -2241,7 +2441,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
     if (BO->isExact())
       return isKnownNonZero(X, Depth, Q);

-    KnownBits Known = computeKnownBits(X, Depth, Q);
+    KnownBits Known = computeKnownBits(X, DemandedElts, Depth, Q);
     if (Known.isNegative())
       return true;

@@ -2255,22 +2455,23 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
         return true;
       // Are all the bits to be shifted out known zero?
       if (Known.countMinTrailingZeros() >= ShiftVal)
-        return isKnownNonZero(X, Depth, Q);
+        return isKnownNonZero(X, DemandedElts, Depth, Q);
     }
   }
   // div exact can only produce a zero if the dividend is zero.
   else if (match(V, m_Exact(m_IDiv(m_Value(X), m_Value())))) {
-    return isKnownNonZero(X, Depth, Q);
+    return isKnownNonZero(X, DemandedElts, Depth, Q);
   }
   // X + Y.
   else if (match(V, m_Add(m_Value(X), m_Value(Y)))) {
-    KnownBits XKnown = computeKnownBits(X, Depth, Q);
-    KnownBits YKnown = computeKnownBits(Y, Depth, Q);
+    KnownBits XKnown = computeKnownBits(X, DemandedElts, Depth, Q);
+    KnownBits YKnown = computeKnownBits(Y, DemandedElts, Depth, Q);

     // If X and Y are both non-negative (as signed values) then their sum is not
     // zero unless both X and Y are zero.
     if (XKnown.isNonNegative() && YKnown.isNonNegative())
-      if (isKnownNonZero(X, Depth, Q) || isKnownNonZero(Y, Depth, Q))
+      if (isKnownNonZero(X, DemandedElts, Depth, Q) ||
+          isKnownNonZero(Y, DemandedElts, Depth, Q))
        return true;

     // If X and Y are both negative (as signed values) then their sum is not
@@ -2301,13 +2502,14 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
     // If X and Y are non-zero then so is X * Y as long as the multiplication
     // does not overflow.
     if ((Q.IIQ.hasNoSignedWrap(BO) || Q.IIQ.hasNoUnsignedWrap(BO)) &&
-        isKnownNonZero(X, Depth, Q) && isKnownNonZero(Y, Depth, Q))
+        isKnownNonZero(X, DemandedElts, Depth, Q) &&
+        isKnownNonZero(Y, DemandedElts, Depth, Q))
       return true;
   }
   // (C ? X : Y) != 0 if X != 0 and Y != 0.
   else if (const SelectInst *SI = dyn_cast<SelectInst>(V)) {
-    if (isKnownNonZero(SI->getTrueValue(), Depth, Q) &&
-        isKnownNonZero(SI->getFalseValue(), Depth, Q))
+    if (isKnownNonZero(SI->getTrueValue(), DemandedElts, Depth, Q) &&
+        isKnownNonZero(SI->getFalseValue(), DemandedElts, Depth, Q))
       return true;
   }
   // PHI
@@ -2337,12 +2539,35 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) {
     if (AllNonZeroConstants)
       return true;
   }
+  // ExtractElement
+  else if (const auto *EEI = dyn_cast<ExtractElementInst>(V)) {
+    const Value *Vec = EEI->getVectorOperand();
+    const Value *Idx = EEI->getIndexOperand();
+    auto *CIdx = dyn_cast<ConstantInt>(Idx);
+    unsigned NumElts = cast<FixedVectorType>(Vec->getType())->getNumElements();
+    APInt DemandedVecElts = APInt::getAllOnesValue(NumElts);
+    if (CIdx && CIdx->getValue().ult(NumElts))
+      DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue());
+    return isKnownNonZero(Vec, DemandedVecElts, Depth, Q);
+  }

   KnownBits Known(BitWidth);
-  computeKnownBits(V, Known, Depth, Q);
+  computeKnownBits(V, DemandedElts, Known, Depth, Q);
   return Known.One != 0;
 }

+bool isKnownNonZero(const Value* V, unsigned Depth, const Query& Q) {
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(V->getType()))
+    return false;
+
+  auto *FVTy = dyn_cast<FixedVectorType>(V->getType());
+  APInt DemandedElts =
+      FVTy ? APInt::getAllOnesValue(FVTy->getNumElements()) : APInt(1, 1);
+  return isKnownNonZero(V, DemandedElts, Depth, Q);
+}
+
 /// Return true if V2 == V1 + X, where X is known non-zero.
 static bool isAddOfNonZero(const Value *V1, const Value *V2, const Query &Q) {
   const BinaryOperator *BO = dyn_cast<BinaryOperator>(V1);
@@ -2433,14 +2658,17 @@ static bool isSignedMinMaxClamp(const Value *Select, const Value *&In,
 /// or if any element was not analyzed; otherwise, return the count for the
 /// element with the minimum number of sign bits.
 static unsigned computeNumSignBitsVectorConstant(const Value *V,
+                                                 const APInt &DemandedElts,
                                                  unsigned TyBits) {
   const auto *CV = dyn_cast<Constant>(V);
-  if (!CV || !CV->getType()->isVectorTy())
+  if (!CV || !isa<FixedVectorType>(CV->getType()))
     return 0;

   unsigned MinSignBits = TyBits;
-  unsigned NumElts = CV->getType()->getVectorNumElements();
+  unsigned NumElts = cast<FixedVectorType>(CV->getType())->getNumElements();
   for (unsigned i = 0; i != NumElts; ++i) {
+    if (!DemandedElts[i])
+      continue;
     // If we find a non-ConstantInt, bail out.
     auto *Elt = dyn_cast_or_null<ConstantInt>(CV->getAggregateElement(i));
     if (!Elt)
@@ -2452,12 +2680,13 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V,
   return MinSignBits;
 }

-static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
-                                       const Query &Q);
+static unsigned ComputeNumSignBitsImpl(const Value *V,
+                                       const APInt &DemandedElts,
+                                       unsigned Depth, const Query &Q);

-static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
-                                   const Query &Q) {
-  unsigned Result = ComputeNumSignBitsImpl(V, Depth, Q);
+static unsigned ComputeNumSignBits(const Value *V, const APInt &DemandedElts,
+                                   unsigned Depth, const Query &Q) {
+  unsigned Result = ComputeNumSignBitsImpl(V, DemandedElts, Depth, Q);
   assert(Result > 0 && "At least one sign bit needs to be present!");
   return Result;
 }
@@ -2467,16 +2696,36 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth,
 /// (itself), but other cases can give us information. For example, immediately
 /// after an "ashr X, 2", we know that the top 3 bits are all equal to each
 /// other, so we return 3. For vectors, return the number of sign bits for the
-/// vector element with the minimum number of known sign bits.
-static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
-                                       const Query &Q) {
+/// vector element with the minimum number of known sign bits of the demanded
+/// elements in the vector specified by DemandedElts.
+static unsigned ComputeNumSignBitsImpl(const Value *V,
+                                       const APInt &DemandedElts,
+                                       unsigned Depth, const Query &Q) {
+  Type *Ty = V->getType();
+
+  // FIXME: We currently have no way to represent the DemandedElts of a scalable
+  // vector
+  if (isa<ScalableVectorType>(Ty))
+    return 1;
+
+#ifndef NDEBUG
   assert(Depth <= MaxDepth && "Limit Search Depth");
+
+  if (auto *FVTy = dyn_cast<FixedVectorType>(Ty)) {
+    assert(
+        FVTy->getNumElements() == DemandedElts.getBitWidth() &&
+        "DemandedElt width should equal the fixed vector number of elements");
+  } else {
+    assert(DemandedElts == APInt(1, 1) &&
+           "DemandedElt width should be 1 for scalars");
+  }
+#endif

   // We return the minimum number of sign bits that are guaranteed to be present
   // in V, so for undef we have to conservatively return 1. We don't have the
   // same behavior for poison though -- that's a FIXME today.

-  Type *ScalarTy = V->getType()->getScalarType();
+  Type *ScalarTy = Ty->getScalarType();
   unsigned TyBits = ScalarTy->isPointerTy() ?
     Q.DL.getPointerTypeSizeInBits(ScalarTy) :
     Q.DL.getTypeSizeInBits(ScalarTy);
@@ -2702,40 +2951,37 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,
     return ComputeNumSignBits(U->getOperand(0), Depth + 1, Q);

   case Instruction::ShuffleVector: {
-    // TODO: This is copied almost directly from the SelectionDAG version of
-    //       ComputeNumSignBits. It would be better if we could share common
-    //       code. If not, make sure that changes are translated to the DAG.
-
     // Collect the minimum number of sign bits that are shared by every vector
     // element referenced by the shuffle.
-    auto *Shuf = cast<ShuffleVectorInst>(U);
-    int NumElts = Shuf->getOperand(0)->getType()->getVectorNumElements();
-    int NumMaskElts = Shuf->getMask()->getType()->getVectorNumElements();
-    APInt DemandedLHS(NumElts, 0), DemandedRHS(NumElts, 0);
-    for (int i = 0; i != NumMaskElts; ++i) {
-      int M = Shuf->getMaskValue(i);
-      assert(M < NumElts * 2 && "Invalid shuffle mask constant");
-      // For undef elements, we don't know anything about the common state of
-      // the shuffle result.
-      if (M == -1)
-        return 1;
-      if (M < NumElts)
-        DemandedLHS.setBit(M % NumElts);
-      else
-        DemandedRHS.setBit(M % NumElts);
+    auto *Shuf = dyn_cast<ShuffleVectorInst>(U);
+    if (!Shuf) {
+      // FIXME: Add support for shufflevector constant expressions.
+      return 1;
     }
+    APInt DemandedLHS, DemandedRHS;
+    // For undef elements, we don't know anything about the common state of
+    // the shuffle result.
+    if (!getShuffleDemandedElts(Shuf, DemandedElts, DemandedLHS, DemandedRHS))
+      return 1;
     Tmp = std::numeric_limits<unsigned>::max();
-    if (!!DemandedLHS)
-      Tmp = ComputeNumSignBits(Shuf->getOperand(0), Depth + 1, Q);
+    if (!!DemandedLHS) {
+      const Value *LHS = Shuf->getOperand(0);
+      Tmp = ComputeNumSignBits(LHS, DemandedLHS, Depth + 1, Q);
+    }
+    // If we don't know anything, early out and try computeKnownBits
+    // fall-back.
+    if (Tmp == 1)
+      break;
     if (!!DemandedRHS) {
-      Tmp2 = ComputeNumSignBits(Shuf->getOperand(1), Depth + 1, Q);
+      const Value *RHS = Shuf->getOperand(1);
+      Tmp2 = ComputeNumSignBits(RHS, DemandedRHS, Depth + 1, Q);
       Tmp = std::min(Tmp, Tmp2);
     }
     // If we don't know anything, early out and try computeKnownBits
     // fall-back.
     if (Tmp == 1)
       break;
-    assert(Tmp <= V->getType()->getScalarSizeInBits() &&
+    assert(Tmp <= Ty->getScalarSizeInBits() &&
            "Failed to determine minimum sign bits");
     return Tmp;
   }
@@ -2747,11 +2993,12 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth,

   // If we can examine all elements of a vector constant successfully, we're
   // done (we can't do any better than that). If not, keep trying.
-  if (unsigned VecSignBits = computeNumSignBitsVectorConstant(V, TyBits))
+  if (unsigned VecSignBits =
+          computeNumSignBitsVectorConstant(V, DemandedElts, TyBits))
     return VecSignBits;

   KnownBits Known(TyBits);
-  computeKnownBits(V, Known, Depth, Q);
+  computeKnownBits(V, DemandedElts, Known, Depth, Q);

   // If we know that the sign bit is either zero or one, determine the number of
   // identical bits in the top of the input value.
@@ -2877,30 +3124,23 @@ bool llvm::ComputeMultiple(Value *V, unsigned Base, Value *&Multiple,
   return false;
 }

-Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS,
+Intrinsic::ID llvm::getIntrinsicForCallSite(const CallBase &CB,
                                             const TargetLibraryInfo *TLI) {
-  const Function *F = ICS.getCalledFunction();
+  const Function *F = CB.getCalledFunction();
   if (!F)
     return Intrinsic::not_intrinsic;

   if (F->isIntrinsic())
     return F->getIntrinsicID();

-  if (!TLI)
-    return Intrinsic::not_intrinsic;
-
+  // We are going to infer semantics of a library function based on mapping it
+  // to an LLVM intrinsic. Check that the library function is available from
+  // this callbase and in this environment.
   LibFunc Func;
-  // We're going to make assumptions on the semantics of the functions, check
-  // that the target knows that it's available in this environment and it does
-  // not have local linkage.
-  if (!F || F->hasLocalLinkage() || !TLI->getLibFunc(*F, Func))
+  if (F->hasLocalLinkage() || !TLI || !TLI->getLibFunc(CB, Func) ||
+      !CB.onlyReadsMemory())
     return Intrinsic::not_intrinsic;

-  if (!ICS.onlyReadsMemory())
-    return Intrinsic::not_intrinsic;
-
-  // Otherwise check if we have a call to a function that can be turned into a
-  // vector intrinsic.
   switch (Func) {
   default:
     break;
@@ -2972,6 +3212,10 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS,
   case LibFunc_roundf:
   case LibFunc_roundl:
     return Intrinsic::round;
+  case LibFunc_roundeven:
+  case LibFunc_roundevenf:
+  case LibFunc_roundevenl:
+    return Intrinsic::roundeven;
   case LibFunc_pow:
   case LibFunc_powf:
   case LibFunc_powl:
    return Intrinsic::pow;
@@ -2987,6 +3231,9 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS,

 /// Return true if we can prove that the specified FP value is never equal to
 /// -0.0.
+/// NOTE: Do not check 'nsz' here because that fast-math-flag does not guarantee
+///       that a value is not -0.0. It only guarantees that -0.0 may be treated
+///       the same as +0.0 in floating-point ops.
 ///
 /// NOTE: this function will need to be revisited when we support non-default
 /// rounding modes!
@@ -3003,11 +3250,6 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI,
   if (!Op)
     return false;

-  // Check if the nsz fast-math flag is set.
- if (auto *FPO = dyn_cast<FPMathOperator>(Op)) - if (FPO->hasNoSignedZeros()) - return true; - // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. if (match(Op, m_FAdd(m_Value(), m_PosZeroFP()))) return true; @@ -3017,7 +3259,7 @@ bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, return true; if (auto *Call = dyn_cast<CallInst>(Op)) { - Intrinsic::ID IID = getIntrinsicForCallSite(Call, TLI); + Intrinsic::ID IID = getIntrinsicForCallSite(*Call, TLI); switch (IID) { default: break; @@ -3053,8 +3295,8 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, // Handle vector of constants. if (auto *CV = dyn_cast<Constant>(V)) { - if (CV->getType()->isVectorTy()) { - unsigned NumElts = CV->getType()->getVectorNumElements(); + if (auto *CVFVTy = dyn_cast<FixedVectorType>(CV->getType())) { + unsigned NumElts = CVFVTy->getNumElements(); for (unsigned i = 0; i != NumElts; ++i) { auto *CFP = dyn_cast_or_null<ConstantFP>(CV->getAggregateElement(i)); if (!CFP) @@ -3083,14 +3325,15 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, case Instruction::UIToFP: return true; case Instruction::FMul: - // x*x is always non-negative or a NaN. + case Instruction::FDiv: + // X * X is always non-negative or a NaN. + // X / X is always exactly 1.0 or a NaN. if (I->getOperand(0) == I->getOperand(1) && (!SignBitOnly || cast<FPMathOperator>(I)->hasNoNaNs())) return true; LLVM_FALLTHROUGH; case Instruction::FAdd: - case Instruction::FDiv: case Instruction::FRem: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, Depth + 1) && @@ -3114,17 +3357,32 @@ static bool cannotBeOrderedLessThanZeroImpl(const Value *V, Depth + 1); case Instruction::Call: const auto *CI = cast<CallInst>(I); - Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + Intrinsic::ID IID = getIntrinsicForCallSite(*CI, TLI); switch (IID) { default: break; - case Intrinsic::maxnum: - return (isKnownNeverNaN(I->getOperand(0), TLI) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, - SignBitOnly, Depth + 1)) || - (isKnownNeverNaN(I->getOperand(1), TLI) && - cannotBeOrderedLessThanZeroImpl(I->getOperand(1), TLI, - SignBitOnly, Depth + 1)); + case Intrinsic::maxnum: { + Value *V0 = I->getOperand(0), *V1 = I->getOperand(1); + auto isPositiveNum = [&](Value *V) { + if (SignBitOnly) { + // With SignBitOnly, this is tricky because the result of + // maxnum(+0.0, -0.0) is unspecified. Just check if the operand is + // a constant strictly greater than 0.0. + const APFloat *C; + return match(V, m_APFloat(C)) && + *C > APFloat::getZero(C->getSemantics()); + } + + // -0.0 compares equal to 0.0, so if this operand is at least -0.0, + // maxnum can't be ordered-less-than-zero. + return isKnownNeverNaN(V, TLI) && + cannotBeOrderedLessThanZeroImpl(V, TLI, false, Depth + 1); + }; + + // TODO: This could be improved. We could also check that neither operand + // has its sign bit set (and at least 1 is not-NAN?). + return isPositiveNum(V0) || isPositiveNum(V1); + } case Intrinsic::maximum: return cannotBeOrderedLessThanZeroImpl(I->getOperand(0), TLI, SignBitOnly, @@ -3225,24 +3483,26 @@ bool llvm::isKnownNeverInfinity(const Value *V, const TargetLibraryInfo *TLI, } } - // Bail out for constant expressions, but try to handle vector constants. - if (!V->getType()->isVectorTy() || !isa<Constant>(V)) - return false; - - // For vectors, verify that each element is not infinity. 
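// Illustration (standalone, not part of this patch): why the SignBitOnly
// path above demands a constant strictly greater than 0.0. maxnum of two
// equal-comparing zeros may return either operand, so the sign of
// maxnum(+0.0, -0.0) is unspecified; std::fmax is the C library analogue.
#include <cmath>
#include <cstdio>

int main() {
  double R = std::fmax(+0.0, -0.0); // either +0.0 or -0.0 is conforming
  std::printf("signbit = %d\n", (int)std::signbit(R));
}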
- unsigned NumElts = V->getType()->getVectorNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa<UndefValue>(Elt)) - continue; - auto *CElt = dyn_cast<ConstantFP>(Elt); - if (!CElt || CElt->isInfinity()) - return false; + // Try to handle fixed-width vector constants. + if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) { + // For vectors, verify that each element is not infinity. + unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast<Constant>(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CElt = dyn_cast<ConstantFP>(Elt); + if (!CElt || CElt->isInfinity()) + return false; + } + // All elements were confirmed non-infinity or undefined. + return true; } - // All elements were confirmed non-infinity or undefined. - return true; + + // Was not able to prove that V never contains infinity. + return false; } bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, @@ -3312,6 +3572,7 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, case Intrinsic::rint: case Intrinsic::nearbyint: case Intrinsic::round: + case Intrinsic::roundeven: return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1); case Intrinsic::sqrt: return isKnownNeverNaN(II->getArgOperand(0), TLI, Depth + 1) && @@ -3326,24 +3587,26 @@ bool llvm::isKnownNeverNaN(const Value *V, const TargetLibraryInfo *TLI, } } - // Bail out for constant expressions, but try to handle vector constants. - if (!V->getType()->isVectorTy() || !isa<Constant>(V)) - return false; - - // For vectors, verify that each element is not NaN. - unsigned NumElts = V->getType()->getVectorNumElements(); - for (unsigned i = 0; i != NumElts; ++i) { - Constant *Elt = cast<Constant>(V)->getAggregateElement(i); - if (!Elt) - return false; - if (isa<UndefValue>(Elt)) - continue; - auto *CElt = dyn_cast<ConstantFP>(Elt); - if (!CElt || CElt->isNaN()) - return false; + // Try to handle fixed-width vector constants. + if (isa<FixedVectorType>(V->getType()) && isa<Constant>(V)) { + // For vectors, verify that each element is not NaN. + unsigned NumElts = cast<VectorType>(V->getType())->getNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast<Constant>(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CElt = dyn_cast<ConstantFP>(Elt); + if (!CElt || CElt->isNaN()) + return false; + } + // All elements were confirmed not-NaN or undefined. + return true; } - // All elements were confirmed not-NaN or undefined. - return true; + + // Was not able to prove that V never contains NaN. + return false; } Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { @@ -3359,8 +3622,8 @@ Value *llvm::isBytewiseValue(Value *V, const DataLayout &DL) { if (isa<UndefValue>(V)) return UndefInt8; - const uint64_t Size = DL.getTypeStoreSize(V->getType()); - if (!Size) + // Return Undef for zero-sized type.
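// Illustration (standalone, not part of this patch): the shape of the
// per-lane constant scan used by both the no-infinity and no-NaN checks
// above. Undef lanes are tolerated; every defined lane must pass the
// predicate.
#include <cmath>
#include <optional>
#include <vector>

static bool
allDefinedLanesFinite(const std::vector<std::optional<double>> &Lanes) {
  for (const auto &L : Lanes) {
    if (!L)
      continue; // undef lane: proves nothing either way
    if (std::isnan(*L) || std::isinf(*L))
      return false;
  }
  return true; // every defined lane confirmed finite
}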
+ if (!DL.getTypeStoreSize(V->getType()).isNonZero()) return UndefInt8; Constant *C = dyn_cast<Constant>(V); @@ -3678,7 +3941,7 @@ bool llvm::getConstantDataArrayInfo(const Value *V, Array = nullptr; } else { const DataLayout &DL = GV->getParent()->getDataLayout(); - uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy); + uint64_t SizeInBytes = DL.getTypeStoreSize(GVTy).getFixedSize(); uint64_t Length = SizeInBytes / (ElementSize / 8); if (Length <= Offset) return false; @@ -3839,12 +4102,17 @@ llvm::getArgumentAliasingToReturnedPointer(const CallBase *Call, bool llvm::isIntrinsicReturningPointerAliasingArgumentWithoutCapturing( const CallBase *Call, bool MustPreserveNullness) { - return Call->getIntrinsicID() == Intrinsic::launder_invariant_group || - Call->getIntrinsicID() == Intrinsic::strip_invariant_group || - Call->getIntrinsicID() == Intrinsic::aarch64_irg || - Call->getIntrinsicID() == Intrinsic::aarch64_tagp || - (!MustPreserveNullness && - Call->getIntrinsicID() == Intrinsic::ptrmask); + switch (Call->getIntrinsicID()) { + case Intrinsic::launder_invariant_group: + case Intrinsic::strip_invariant_group: + case Intrinsic::aarch64_irg: + case Intrinsic::aarch64_tagp: + return true; + case Intrinsic::ptrmask: + return !MustPreserveNullness; + default: + return false; + } } /// \p PN defines a loop-variant pointer to an object. Check if the @@ -3884,15 +4152,20 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, } else if (Operator::getOpcode(V) == Instruction::BitCast || Operator::getOpcode(V) == Instruction::AddrSpaceCast) { V = cast<Operator>(V)->getOperand(0); + if (!V->getType()->isPointerTy()) + return V; } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(V)) { if (GA->isInterposable()) return V; V = GA->getAliasee(); - } else if (isa<AllocaInst>(V)) { - // An alloca can't be further simplified. - return V; } else { - if (auto *Call = dyn_cast<CallBase>(V)) { + if (auto *PHI = dyn_cast<PHINode>(V)) { + // Look through single-arg phi nodes created by LCSSA. + if (PHI->getNumIncomingValues() == 1) { + V = PHI->getIncomingValue(0); + continue; + } + } else if (auto *Call = dyn_cast<CallBase>(V)) { // CaptureTracking can know about special capturing properties of some // intrinsics like launder.invariant.group, that can't be expressed with // the attributes, but have properties like returning aliasing pointer. @@ -3908,14 +4181,6 @@ Value *llvm::GetUnderlyingObject(Value *V, const DataLayout &DL, } } - // See if InstructionSimplify knows any relevant tricks. - if (Instruction *I = dyn_cast<Instruction>(V)) - // TODO: Acquire a DominatorTree and AssumptionCache and use them. - if (Value *Simplified = SimplifyInstruction(I, {DL, I})) { - V = Simplified; - continue; - } - return V; } assert(V->getType()->isPointerTy() && "Unexpected operand type!"); @@ -4309,6 +4574,16 @@ OverflowResult llvm::computeOverflowForUnsignedSub(const Value *LHS, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT) { + // Checking for conditions implied by dominating conditions may be expensive. + // Limit it to usub_with_overflow calls for now. 
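// Illustration (standalone, not part of this patch): the arithmetic fact the
// usub_with_overflow early exit below relies on. Unsigned a - b wraps exactly
// when a < b, so a dominating "a >= b" condition proves the subtraction never
// overflows, and a dominating "a < b" proves it always does.
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (uint32_t A : {0u, 7u, 9u})
    for (uint32_t B : {0u, 7u, 9u})
      assert((A < B) == (A - B > A)); // wrap occurs exactly when A < B
}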
+ if (match(CxtI, + m_Intrinsic<Intrinsic::usub_with_overflow>(m_Value(), m_Value()))) + if (auto C = + isImpliedByDomCondition(CmpInst::ICMP_UGE, LHS, RHS, CxtI, DL)) { + if (*C) + return OverflowResult::NeverOverflows; + return OverflowResult::AlwaysOverflowsLow; + } ConstantRange LHSRange = computeConstantRangeIncludingKnownBits( LHS, /*ForSigned=*/false, DL, /*Depth=*/0, AC, CxtI, DT); ConstantRange RHSRange = computeConstantRangeIncludingKnownBits( @@ -4385,7 +4660,100 @@ bool llvm::isOverflowIntrinsicNoWrap(const WithOverflowInst *WO, return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch); } -bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V) { +bool llvm::canCreatePoison(const Instruction *I) { + // See whether I has flags that may create poison. + if (isa<OverflowingBinaryOperator>(I) && + (I->hasNoSignedWrap() || I->hasNoUnsignedWrap())) + return true; + if (isa<PossiblyExactOperator>(I) && I->isExact()) + return true; + if (auto *FP = dyn_cast<FPMathOperator>(I)) { + auto FMF = FP->getFastMathFlags(); + if (FMF.noNaNs() || FMF.noInfs()) + return true; + } + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) + if (GEP->isInBounds()) + return true; + + unsigned Opcode = I->getOpcode(); + + // Check whether the opcode is a poison-generating operation. + switch (Opcode) { + case Instruction::Shl: + case Instruction::AShr: + case Instruction::LShr: { + // Shifts return poison if the shift amount is >= the bitwidth. + if (auto *C = dyn_cast<Constant>(I->getOperand(1))) { + SmallVector<Constant *, 4> ShiftAmounts; + if (auto *FVTy = dyn_cast<FixedVectorType>(C->getType())) { + unsigned NumElts = FVTy->getNumElements(); + for (unsigned i = 0; i < NumElts; ++i) + ShiftAmounts.push_back(C->getAggregateElement(i)); + } else if (isa<ScalableVectorType>(C->getType())) + return true; // Can't tell, just return true to be safe + else + ShiftAmounts.push_back(C); + + bool Safe = llvm::all_of(ShiftAmounts, [](Constant *C) { + auto *CI = dyn_cast<ConstantInt>(C); + return CI && CI->getZExtValue() < C->getType()->getIntegerBitWidth(); + }); + return !Safe; + } + return true; + } + case Instruction::FPToSI: + case Instruction::FPToUI: + // fptosi/ui yields poison if the resulting value does not fit in the + // destination type. + return true; + case Instruction::Call: + case Instruction::CallBr: + case Instruction::Invoke: + // Function calls can return a poison value even if args are non-poison + // values. + return true; + case Instruction::InsertElement: + case Instruction::ExtractElement: { + // If the index exceeds the length of the vector, the result is poison. + auto *VTy = cast<VectorType>(I->getOperand(0)->getType()); + unsigned IdxOp = I->getOpcode() == Instruction::InsertElement ? 2 : 1; + auto *Idx = dyn_cast<ConstantInt>(I->getOperand(IdxOp)); + if (!Idx || Idx->getZExtValue() >= VTy->getElementCount().Min) + return true; + return false; + } + case Instruction::FNeg: + case Instruction::PHI: + case Instruction::Select: + case Instruction::URem: + case Instruction::SRem: + case Instruction::ShuffleVector: + case Instruction::ExtractValue: + case Instruction::InsertValue: + case Instruction::Freeze: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::GetElementPtr: + return false; + default: + if (isa<CastInst>(I)) + return false; + else if (isa<BinaryOperator>(I)) + return false; + // Be conservative and return true.
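// Illustration (standalone, not part of this patch): the shift rule encoded
// above. In LLVM IR, a shift whose amount is >= the bit width (for example
// "shl i8 %x, 8") yields poison, so only constant amounts strictly below the
// width are treated as safe; the C++ analogue of an oversized shift is UB.
#include <cstdint>

static bool shiftAmountSafe(uint64_t Amt, unsigned BitWidth) {
  return Amt < BitWidth; // mirrors the getZExtValue() < bitwidth check
}

static uint32_t shlOrZero(uint32_t X, uint64_t Amt) {
  return shiftAmountSafe(Amt, 32) ? X << Amt : 0; // never shifts by >= 32
}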
+ return true; + } +} + +bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V, + const Instruction *CtxI, + const DominatorTree *DT, + unsigned Depth) { + if (Depth >= MaxDepth) + return false; + // If the value is a freeze instruction, then it can never // be undef or poison. if (isa<FreezeInst>(V)) @@ -4393,10 +4761,100 @@ bool llvm::isGuaranteedNotToBeUndefOrPoison(const Value *V) { // TODO: Some instructions are guaranteed to return neither undef // nor poison if their arguments are not poison/undef. - // TODO: Deal with other Constant subclasses. - if (isa<ConstantInt>(V) || isa<GlobalVariable>(V)) + if (auto *C = dyn_cast<Constant>(V)) { + // TODO: We can analyze ConstExpr by opcode to determine if there is any + // possibility of poison. + if (isa<UndefValue>(C) || isa<ConstantExpr>(C)) + return false; + + if (isa<ConstantInt>(C) || isa<GlobalVariable>(C) || isa<ConstantFP>(V) || + isa<ConstantPointerNull>(C) || isa<Function>(C)) + return true; + + if (C->getType()->isVectorTy()) + return !C->containsUndefElement() && !C->containsConstantExpression(); + + // TODO: Recursively analyze aggregates or other constants. + return false; + } + + // Strip cast operations from a pointer value. + // Note that stripPointerCastsSameRepresentation can strip off getelementptr + // inbounds with zero offset. To guarantee that the result isn't poison, the + // stripped pointer is checked as it has to be pointing into an allocated + // object or be null, so that `inbounds` getelementptrs with a zero offset + // could not produce poison. + // It can also strip off addrspacecasts that do not change the bit + // representation; we consider such an addrspacecast equivalent to a no-op. + auto *StrippedV = V->stripPointerCastsSameRepresentation(); + if (isa<AllocaInst>(StrippedV) || isa<GlobalVariable>(StrippedV) || + isa<Function>(StrippedV) || isa<ConstantPointerNull>(StrippedV)) return true; + auto OpCheck = [&](const Value *V) { + return isGuaranteedNotToBeUndefOrPoison(V, CtxI, DT, Depth + 1); + }; + + if (auto *I = dyn_cast<Instruction>(V)) { + switch (I->getOpcode()) { + case Instruction::GetElementPtr: { + auto *GEPI = dyn_cast<GetElementPtrInst>(I); + if (!GEPI->isInBounds() && llvm::all_of(GEPI->operands(), OpCheck)) + return true; + break; + } + case Instruction::FCmp: { + auto *FI = dyn_cast<FCmpInst>(I); + if (FI->getFastMathFlags().none() && + llvm::all_of(FI->operands(), OpCheck)) + return true; + break; + } + case Instruction::BitCast: + case Instruction::PHI: + case Instruction::ICmp: + if (llvm::all_of(I->operands(), OpCheck)) + return true; + break; + default: + break; + } + + if (programUndefinedIfPoison(I) && I->getType()->isIntegerTy(1)) + // Note: once we have an agreement that poison is a value-wise concept, + // we can remove the isIntegerTy(1) constraint. + return true; + } + + // CxtI may be null or a cloned instruction. + if (!CtxI || !CtxI->getParent() || !DT) + return false; + + auto *DNode = DT->getNode(CtxI->getParent()); + if (!DNode) + // Unreachable block + return false; + + // If V is used as a branch condition before reaching CtxI, V cannot be + // undef or poison.
+ // br V, BB1, BB2 + // BB1: + // CtxI ; V cannot be undef or poison here + auto *Dominator = DNode->getIDom(); + while (Dominator) { + auto *TI = Dominator->getBlock()->getTerminator(); + + if (auto BI = dyn_cast<BranchInst>(TI)) { + if (BI->isConditional() && BI->getCondition() == V) + return true; + } else if (auto SI = dyn_cast<SwitchInst>(TI)) { + if (SI->getCondition() == V) + return true; + } + + Dominator = Dominator->getIDom(); + } + return false; } @@ -4436,14 +4894,14 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { return false; // Calls can throw, or contain an infinite loop, or kill the process. - if (auto CS = ImmutableCallSite(I)) { + if (const auto *CB = dyn_cast<CallBase>(I)) { // Call sites that throw have implicit non-local control flow. - if (!CS.doesNotThrow()) + if (!CB->doesNotThrow()) return false; // A function which doesn't throw and has "willreturn" attribute will // always return. - if (CS.hasFnAttr(Attribute::WillReturn)) + if (CB->hasFnAttr(Attribute::WillReturn)) return true; // Non-throwing call sites can loop infinitely, call exit/pthread_exit @@ -4462,7 +4920,7 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { // FIXME: This isn't aggressive enough; a call which only writes to a global // is guaranteed to return. - return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory(); + return CB->onlyReadsMemory() || CB->onlyAccessesArgMemory(); } // Other instructions return normally. @@ -4493,41 +4951,28 @@ bool llvm::isGuaranteedToExecuteForEveryIteration(const Instruction *I, llvm_unreachable("Instruction not contained in its own parent basic block."); } -bool llvm::propagatesFullPoison(const Instruction *I) { - // TODO: This should include all instructions apart from phis, selects and - // call-like instructions. +bool llvm::propagatesPoison(const Instruction *I) { switch (I->getOpcode()) { - case Instruction::Add: - case Instruction::Sub: - case Instruction::Xor: - case Instruction::Trunc: - case Instruction::BitCast: - case Instruction::AddrSpaceCast: - case Instruction::Mul: - case Instruction::Shl: - case Instruction::GetElementPtr: - // These operations all propagate poison unconditionally. Note that poison - // is not any particular value, so xor or subtraction of poison with - // itself still yields poison, not zero. - return true; - - case Instruction::AShr: - case Instruction::SExt: - // For these operations, one bit of the input is replicated across - // multiple output bits. A replicated poison bit is still poison. - return true; - + case Instruction::Freeze: + case Instruction::Select: + case Instruction::PHI: + case Instruction::Call: + case Instruction::Invoke: + return false; case Instruction::ICmp: - // Comparing poison with any value yields poison. This is why, for - // instance, x s< (x +nsw 1) can be folded to true. + case Instruction::FCmp: + case Instruction::GetElementPtr: return true; - default: + if (isa<BinaryOperator>(I) || isa<UnaryOperator>(I) || isa<CastInst>(I)) + return true; + + // Be conservative and return false.
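// Illustration (comments only, not part of this patch): what the
// propagatesPoison classification above means in IR terms:
//   %a = add i32 %x, poison                 ; propagates: %a is poison
//   %c = icmp eq i32 %x, poison             ; propagates: %c is poison
//   %f = freeze i32 poison                  ; blocked: %f is some fixed value
//   %s = select i1 false, i32 poison, i32 0 ; blocked: %s is 0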
return false; } } -const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) { +const Value *llvm::getGuaranteedNonPoisonOp(const Instruction *I) { switch (I->getOpcode()) { case Instruction::Store: return cast<StoreInst>(I)->getPointerOperand(); @@ -4547,23 +4992,30 @@ const Value *llvm::getGuaranteedNonFullPoisonOp(const Instruction *I) { case Instruction::SRem: return I->getOperand(1); + case Instruction::Call: + if (auto *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + case Intrinsic::assume: + return II->getArgOperand(0); + default: + return nullptr; + } + } + return nullptr; + default: - // Note: It's really tempting to think that a conditional branch or - // switch should be listed here, but that's incorrect. It's not - // branching off of poison which is UB, it is executing a side effecting - // instruction which follows the branch. return nullptr; } } bool llvm::mustTriggerUB(const Instruction *I, const SmallSet<const Value *, 16>& KnownPoison) { - auto *NotPoison = getGuaranteedNonFullPoisonOp(I); + auto *NotPoison = getGuaranteedNonPoisonOp(I); return (NotPoison && KnownPoison.count(NotPoison)); } -bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { +bool llvm::programUndefinedIfPoison(const Instruction *PoisonI) { // We currently only look for uses of poison values within the same basic // block, as that makes it easier to guarantee that the uses will be // executed given that PoisonI is executed. @@ -4596,7 +5048,7 @@ bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { if (YieldsPoison.count(&I)) { for (const User *User : I.users()) { const Instruction *UserI = cast<Instruction>(User); - if (propagatesFullPoison(UserI)) + if (propagatesPoison(UserI)) YieldsPoison.insert(User); } } @@ -4633,6 +5085,9 @@ static bool isKnownNonNaN(const Value *V, FastMathFlags FMF) { return true; } + if (isa<ConstantAggregateZero>(V)) + return true; + return false; } @@ -4689,7 +5144,7 @@ static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred, if (match(FalseVal, m_CombineOr(m_OrdFMin(m_Specific(CmpLHS), m_APFloat(FC2)), m_UnordFMin(m_Specific(CmpLHS), m_APFloat(FC2)))) && - FC1->compare(*FC2) == APFloat::cmpResult::cmpLessThan) + *FC1 < *FC2) return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false}; break; case CmpInst::FCMP_OGT: @@ -4699,7 +5154,7 @@ static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred, if (match(FalseVal, m_CombineOr(m_OrdFMax(m_Specific(CmpLHS), m_APFloat(FC2)), m_UnordFMax(m_Specific(CmpLHS), m_APFloat(FC2)))) && - FC1->compare(*FC2) == APFloat::cmpResult::cmpGreaterThan) + *FC1 > *FC2) return {SPF_FMINNUM, SPNB_RETURNS_ANY, false}; break; default: @@ -4840,6 +5295,21 @@ static SelectPatternResult matchMinMaxOfMinMax(CmpInst::Predicate Pred, return {SPF_UNKNOWN, SPNB_NA, false}; } +/// If the input value is the result of a 'not' op, constant integer, or vector +/// splat of a constant integer, return the bitwise-not source value. +/// TODO: This could be extended to handle non-splat vector integer constants. +static Value *getNotValue(Value *V) { + Value *NotV; + if (match(V, m_Not(m_Value(NotV)))) + return NotV; + + const APInt *C; + if (match(V, m_APInt(C))) + return ConstantInt::get(V->getType(), ~(*C)); + + return nullptr; +} + /// Match non-obvious integer minimum and maximum sequences. 
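// Illustration (standalone, not part of this patch): the identity behind the
// 'not' min/max matching that getNotValue enables below. Bitwise-not is
// order-reversing (~x == -1 - x), so (x > y) ? ~x : ~y selects min(~x, ~y)
// and (x < y) ? ~x : ~y selects max(~x, ~y).
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <initializer_list>

int main() {
  for (int32_t X : {-7, 0, 5})
    for (int32_t Y : {-3, 0, 9}) {
      assert(((X > Y) ? ~X : ~Y) == std::min(~X, ~Y));
      assert(((X < Y) ? ~X : ~Y) == std::max(~X, ~Y));
    }
}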
static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, Value *CmpLHS, Value *CmpRHS, @@ -4858,6 +5328,31 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) return SPR; + // Look through 'not' ops to find disguised min/max. + // (X > Y) ? ~X : ~Y ==> (~X < ~Y) ? ~X : ~Y ==> MIN(~X, ~Y) + // (X < Y) ? ~X : ~Y ==> (~X > ~Y) ? ~X : ~Y ==> MAX(~X, ~Y) + if (CmpLHS == getNotValue(TrueVal) && CmpRHS == getNotValue(FalseVal)) { + switch (Pred) { + case CmpInst::ICMP_SGT: return {SPF_SMIN, SPNB_NA, false}; + case CmpInst::ICMP_SLT: return {SPF_SMAX, SPNB_NA, false}; + case CmpInst::ICMP_UGT: return {SPF_UMIN, SPNB_NA, false}; + case CmpInst::ICMP_ULT: return {SPF_UMAX, SPNB_NA, false}; + default: break; + } + } + + // (X > Y) ? ~Y : ~X ==> (~X < ~Y) ? ~Y : ~X ==> MAX(~Y, ~X) + // (X < Y) ? ~Y : ~X ==> (~X > ~Y) ? ~Y : ~X ==> MIN(~Y, ~X) + if (CmpLHS == getNotValue(FalseVal) && CmpRHS == getNotValue(TrueVal)) { + switch (Pred) { + case CmpInst::ICMP_SGT: return {SPF_SMAX, SPNB_NA, false}; + case CmpInst::ICMP_SLT: return {SPF_SMIN, SPNB_NA, false}; + case CmpInst::ICMP_UGT: return {SPF_UMAX, SPNB_NA, false}; + case CmpInst::ICMP_ULT: return {SPF_UMIN, SPNB_NA, false}; + default: break; + } + } + if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4898,19 +5393,6 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, return {CmpLHS == FalseVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; } - // Look through 'not' ops to find disguised signed min/max. - // (X >s C) ? ~X : ~C ==> (~X <s ~C) ? ~X : ~C ==> SMIN(~X, ~C) - // (X <s C) ? ~X : ~C ==> (~X >s ~C) ? ~X : ~C ==> SMAX(~X, ~C) - if (match(TrueVal, m_Not(m_Specific(CmpLHS))) && - match(FalseVal, m_APInt(C2)) && ~(*C1) == *C2) - return {Pred == CmpInst::ICMP_SGT ? SPF_SMIN : SPF_SMAX, SPNB_NA, false}; - - // (X >s C) ? ~C : ~X ==> (~X <s ~C) ? ~C : ~X ==> SMAX(~C, ~X) - // (X <s C) ? ~C : ~X ==> (~X >s ~C) ? ~C : ~X ==> SMIN(~C, ~X) - if (match(FalseVal, m_Not(m_Specific(CmpLHS))) && - match(TrueVal, m_APInt(C2)) && ~(*C1) == *C2) - return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; - return {SPF_UNKNOWN, SPNB_NA, false}; } @@ -5445,20 +5927,18 @@ isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, /// Return true if LHS implies RHS is true. Return false if LHS implies RHS is /// false. Otherwise, return None if we can't infer anything. static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, - const ICmpInst *RHS, + CmpInst::Predicate BPred, + const Value *BLHS, const Value *BRHS, const DataLayout &DL, bool LHSIsTrue, unsigned Depth) { Value *ALHS = LHS->getOperand(0); Value *ARHS = LHS->getOperand(1); + // The rest of the logic assumes the LHS condition is true. If that's not the // case, invert the predicate to make it so. - ICmpInst::Predicate APred = + CmpInst::Predicate APred = LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate(); - Value *BLHS = RHS->getOperand(0); - Value *BRHS = RHS->getOperand(1); - ICmpInst::Predicate BPred = RHS->getPredicate(); - // Can we infer anything when the two compares have matching operands? bool AreSwappedOps; if (isMatchingOps(ALHS, ARHS, BLHS, BRHS, AreSwappedOps)) { @@ -5489,10 +5969,11 @@ static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, /// Return true if LHS implies RHS is true. Return false if LHS implies RHS is /// false. Otherwise, return None if we can't infer anything. 
We expect the /// RHS to be an icmp and the LHS to be an 'and' or an 'or' instruction. -static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS, - const ICmpInst *RHS, - const DataLayout &DL, bool LHSIsTrue, - unsigned Depth) { +static Optional<bool> +isImpliedCondAndOr(const BinaryOperator *LHS, CmpInst::Predicate RHSPred, + const Value *RHSOp0, const Value *RHSOp1, + const DataLayout &DL, bool LHSIsTrue, unsigned Depth) { // The LHS must be an 'or' or an 'and' instruction. assert((LHS->getOpcode() == Instruction::And || LHS->getOpcode() == Instruction::Or) && @@ -5507,36 +5988,33 @@ static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS, if ((!LHSIsTrue && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) || (LHSIsTrue && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) { // FIXME: Make this non-recursive. - if (Optional<bool> Implication = - isImpliedCondition(ALHS, RHS, DL, LHSIsTrue, Depth + 1)) + if (Optional<bool> Implication = isImpliedCondition( + ALHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1)) return Implication; - if (Optional<bool> Implication = - isImpliedCondition(ARHS, RHS, DL, LHSIsTrue, Depth + 1)) + if (Optional<bool> Implication = isImpliedCondition( + ARHS, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, Depth + 1)) return Implication; return None; } return None; } -Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, - const DataLayout &DL, bool LHSIsTrue, - unsigned Depth) { +Optional<bool> +llvm::isImpliedCondition(const Value *LHS, CmpInst::Predicate RHSPred, + const Value *RHSOp0, const Value *RHSOp1, + const DataLayout &DL, bool LHSIsTrue, unsigned Depth) { // Bail out when we hit the limit. if (Depth == MaxDepth) return None; // A mismatch occurs when we compare a scalar cmp to a vector cmp, for // example. - if (LHS->getType() != RHS->getType()) + if (RHSOp0->getType()->isVectorTy() != LHS->getType()->isVectorTy()) return None; Type *OpTy = LHS->getType(); assert(OpTy->isIntOrIntVectorTy(1) && "Expected integer type only!"); - // LHS ==> RHS by definition - if (LHS == RHS) - return LHSIsTrue; - // FIXME: Extending the code below to handle vectors. if (OpTy->isVectorTy()) return None; @@ -5545,51 +6023,87 @@ Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, // Both LHS and RHS are icmps. const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS); - const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); - if (LHSCmp && RHSCmp) - return isImpliedCondICmps(LHSCmp, RHSCmp, DL, LHSIsTrue, Depth); + if (LHSCmp) + return isImpliedCondICmps(LHSCmp, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, + Depth); - // The LHS should be an 'or' or an 'and' instruction. We expect the RHS to be - // an icmp. FIXME: Add support for and/or on the RHS. + // The LHS should be an 'or' or an 'and' instruction. We expect the RHS to + // be an icmp. FIXME: Add support for and/or on the RHS.
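// Illustration (standalone, not part of this patch): the propositional rules
// isImpliedCondAndOr applies above. A true (A && B) makes each operand true;
// a false (A || B) makes each operand false; so each operand may be tested
// recursively against the RHS condition. Exhaustive check:
#include <cassert>

int main() {
  for (int A = 0; A <= 1; ++A)
    for (int B = 0; B <= 1; ++B) {
      if (A && B) { assert(A); assert(B); }
      if (!(A || B)) { assert(!A); assert(!B); }
    }
}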
const BinaryOperator *LHSBO = dyn_cast<BinaryOperator>(LHS); - if (LHSBO && RHSCmp) { + if (LHSBO) { if ((LHSBO->getOpcode() == Instruction::And || LHSBO->getOpcode() == Instruction::Or)) - return isImpliedCondAndOr(LHSBO, RHSCmp, DL, LHSIsTrue, Depth); + return isImpliedCondAndOr(LHSBO, RHSPred, RHSOp0, RHSOp1, DL, LHSIsTrue, + Depth); } return None; } -Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond, - const Instruction *ContextI, - const DataLayout &DL) { - assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool"); +Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + // LHS ==> RHS by definition + if (LHS == RHS) + return LHSIsTrue; + + const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); + if (RHSCmp) + return isImpliedCondition(LHS, RHSCmp->getPredicate(), + RHSCmp->getOperand(0), RHSCmp->getOperand(1), DL, + LHSIsTrue, Depth); + return None; +} + +// Returns a pair (Condition, ConditionIsTrue), where Condition is a branch +// condition dominating ContextI, or nullptr if no condition is found. +static std::pair<Value *, bool> +getDomPredecessorCondition(const Instruction *ContextI) { if (!ContextI || !ContextI->getParent()) - return None; + return {nullptr, false}; // TODO: This is a poor/cheap way to determine dominance. Should we use a // dominator tree (eg, from a SimplifyQuery) instead? const BasicBlock *ContextBB = ContextI->getParent(); const BasicBlock *PredBB = ContextBB->getSinglePredecessor(); if (!PredBB) - return None; + return {nullptr, false}; // We need a conditional branch in the predecessor. Value *PredCond; BasicBlock *TrueBB, *FalseBB; if (!match(PredBB->getTerminator(), m_Br(m_Value(PredCond), TrueBB, FalseBB))) - return None; + return {nullptr, false}; // The branch should get simplified. Don't bother simplifying this condition. if (TrueBB == FalseBB) - return None; + return {nullptr, false}; assert((TrueBB == ContextBB || FalseBB == ContextBB) && "Predecessor block does not point to successor?"); // Is this condition implied by the predecessor condition?
- bool CondIsTrue = TrueBB == ContextBB; - return isImpliedCondition(PredCond, Cond, DL, CondIsTrue); + return {PredCond, TrueBB == ContextBB}; +} + +Optional<bool> llvm::isImpliedByDomCondition(const Value *Cond, + const Instruction *ContextI, + const DataLayout &DL) { + assert(Cond->getType()->isIntOrIntVectorTy(1) && "Condition must be bool"); + auto PredCond = getDomPredecessorCondition(ContextI); + if (PredCond.first) + return isImpliedCondition(PredCond.first, Cond, DL, PredCond.second); + return None; +} + +Optional<bool> llvm::isImpliedByDomCondition(CmpInst::Predicate Pred, + const Value *LHS, const Value *RHS, + const Instruction *ContextI, + const DataLayout &DL) { + auto PredCond = getDomPredecessorCondition(ContextI); + if (PredCond.first) + return isImpliedCondition(PredCond.first, Pred, LHS, RHS, DL, + PredCond.second); + return None; } static void setLimitsForBinOp(const BinaryOperator &BO, APInt &Lower, @@ -5861,9 +6375,15 @@ static void setLimitsForSelectPattern(const SelectInst &SI, APInt &Lower, } } -ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) { +ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo, + AssumptionCache *AC, + const Instruction *CtxI, + unsigned Depth) { assert(V->getType()->isIntOrIntVectorTy() && "Expected integer instruction"); + if (Depth == MaxDepth) + return ConstantRange::getFull(V->getType()->getScalarSizeInBits()); + const APInt *C; if (match(V, m_APInt(C))) return ConstantRange(*C); @@ -5885,6 +6405,31 @@ ConstantRange llvm::computeConstantRange(const Value *V, bool UseInstrInfo) { if (auto *Range = IIQ.getMetadata(I, LLVMContext::MD_range)) CR = CR.intersectWith(getConstantRangeFromMetadata(*Range)); + if (CtxI && AC) { + // Try to restrict the range based on information from assumptions. + for (auto &AssumeVH : AC->assumptionsFor(V)) { + if (!AssumeVH) + continue; + CallInst *I = cast<CallInst>(AssumeVH); + assert(I->getParent()->getParent() == CtxI->getParent()->getParent() && + "Got assumption for the wrong function!"); + assert(I->getCalledFunction()->getIntrinsicID() == Intrinsic::assume && + "must be an assume intrinsic"); + + if (!isValidAssumeForContext(I, CtxI, nullptr)) + continue; + Value *Arg = I->getArgOperand(0); + ICmpInst *Cmp = dyn_cast<ICmpInst>(Arg); + // Currently we just use information from comparisons. + if (!Cmp || Cmp->getOperand(0) != V) + continue; + ConstantRange RHS = computeConstantRange(Cmp->getOperand(1), UseInstrInfo, + AC, I, Depth + 1); + CR = CR.intersectWith( + ConstantRange::makeSatisfyingICmpRegion(Cmp->getPredicate(), RHS)); + } + } + return CR; } @@ -5910,10 +6455,12 @@ getOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, const DataLayout &DL) { continue; } - // Otherwise, we have a sequential type like an array or vector. Multiply - // the index by the ElementSize. - uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); - Offset += Size * OpC->getSExtValue(); + // Otherwise, we have a sequential type like an array or fixed-length + // vector. Multiply the index by the ElementSize. + TypeSize Size = DL.getTypeAllocSize(GTI.getIndexedType()); + if (Size.isScalable()) + return None; + Offset += Size.getFixedSize() * OpC->getSExtValue(); } return Offset;
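// Illustration (source-level sketch, not part of this patch): what the
// assumption-driven range refinement above can conclude. With clang,
// __builtin_assume lowers to llvm.assume, the intrinsic consulted through
// the AssumptionCache; each assume below carries a single icmp whose first
// operand is X, matching the "Cmp->getOperand(0) != V" filter. Intersecting
// [0, INT_MAX] with [INT_MIN, 16) gives [0, 16), under which the mask below
// is a no-op that a pass may then remove. A sketch, not a guarantee.
int maskedByAssume(int X) {
  __builtin_assume(X >= 0);
  __builtin_assume(X < 16);
  return X & 15; // removable once X is known to lie in [0, 16)
}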