Diffstat (limited to 'lib/Analysis/ValueTracking.cpp')
-rw-r--r-- | lib/Analysis/ValueTracking.cpp | 926
1 file changed, 574 insertions, 352 deletions
diff --git a/lib/Analysis/ValueTracking.cpp b/lib/Analysis/ValueTracking.cpp index cdfe74d158c9..2730daefa625 100644 --- a/lib/Analysis/ValueTracking.cpp +++ b/lib/Analysis/ValueTracking.cpp @@ -13,37 +13,66 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ValueTracking.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/IR/Statepoint.h" -#include "llvm/Support/Debug.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" #include <algorithm> #include <array> -#include <cstring> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + using namespace llvm; using namespace llvm::PatternMatch; @@ -54,12 +83,6 @@ const unsigned MaxDepth = 6; static cl::opt<unsigned> DomConditionsMaxUses("dom-conditions-max-uses", cl::Hidden, cl::init(20)); -// This optimization is known to cause performance regressions is some cases, -// keep it under a temporary flag for now. -static cl::opt<bool> -DontImproveNonNegativePhiBits("dont-improve-non-negative-phi-bits", - cl::Hidden, cl::init(true)); - /// Returns the bitwidth of the given scalar or pointer type. For vector types, /// returns the element type's bitwidth. static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { @@ -70,6 +93,7 @@ static unsigned getBitWidth(Type *Ty, const DataLayout &DL) { } namespace { + // Simplifying using an assume can only be done in a particular control-flow // context (the context instruction provides that context). 
If an assume and // the context instruction are not in the same block then the DT helps in @@ -79,6 +103,7 @@ struct Query { AssumptionCache *AC; const Instruction *CxtI; const DominatorTree *DT; + // Unlike the other analyses, this may be a nullptr because not all clients // provide it currently. OptimizationRemarkEmitter *ORE; @@ -92,11 +117,12 @@ struct Query { /// isKnownNonZero, which calls computeKnownBits and isKnownToBeAPowerOfTwo /// (all of which can call computeKnownBits), and so on. std::array<const Value *, MaxDepth> Excluded; - unsigned NumExcluded; + + unsigned NumExcluded = 0; Query(const DataLayout &DL, AssumptionCache *AC, const Instruction *CxtI, const DominatorTree *DT, OptimizationRemarkEmitter *ORE = nullptr) - : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE), NumExcluded(0) {} + : DL(DL), AC(AC), CxtI(CxtI), DT(DT), ORE(ORE) {} Query(const Query &Q, const Value *NewExcl) : DL(Q.DL), AC(Q.AC), CxtI(Q.CxtI), DT(Q.DT), ORE(Q.ORE), @@ -113,6 +139,7 @@ struct Query { return std::find(Excluded.begin(), End, Value) != End; } }; + } // end anonymous namespace // Given the provided Value and, potentially, a context instruction, return @@ -171,7 +198,6 @@ bool llvm::haveNoCommonBitsSet(const Value *LHS, const Value *RHS, return (LHSKnown.Zero | RHSKnown.Zero).isAllOnesValue(); } - bool llvm::isOnlyUsedInZeroEqualityComparison(const Instruction *CxtI) { for (const User *U : CxtI->users()) { if (const ICmpInst *IC = dyn_cast<ICmpInst>(U)) @@ -275,47 +301,7 @@ static void computeKnownBitsAddSub(bool Add, const Value *Op0, const Value *Op1, computeKnownBits(Op0, LHSKnown, Depth + 1, Q); computeKnownBits(Op1, Known2, Depth + 1, Q); - // Carry in a 1 for a subtract, rather than a 0. - uint64_t CarryIn = 0; - if (!Add) { - // Sum = LHS + ~RHS + 1 - std::swap(Known2.Zero, Known2.One); - CarryIn = 1; - } - - APInt PossibleSumZero = ~LHSKnown.Zero + ~Known2.Zero + CarryIn; - APInt PossibleSumOne = LHSKnown.One + Known2.One + CarryIn; - - // Compute known bits of the carry. - APInt CarryKnownZero = ~(PossibleSumZero ^ LHSKnown.Zero ^ Known2.Zero); - APInt CarryKnownOne = PossibleSumOne ^ LHSKnown.One ^ Known2.One; - - // Compute set of known bits (where all three relevant bits are known). - APInt LHSKnownUnion = LHSKnown.Zero | LHSKnown.One; - APInt RHSKnownUnion = Known2.Zero | Known2.One; - APInt CarryKnownUnion = CarryKnownZero | CarryKnownOne; - APInt Known = LHSKnownUnion & RHSKnownUnion & CarryKnownUnion; - - assert((PossibleSumZero & Known) == (PossibleSumOne & Known) && - "known bits of sum differ"); - - // Compute known bits of the result. - KnownOut.Zero = ~PossibleSumOne & Known; - KnownOut.One = PossibleSumOne & Known; - - // Are we still trying to solve for the sign bit? - if (!Known.isSignBitSet()) { - if (NSW) { - // Adding two non-negative numbers, or subtracting a negative number from - // a non-negative one, can't wrap into negative. - if (LHSKnown.isNonNegative() && Known2.isNonNegative()) - KnownOut.makeNonNegative(); - // Adding two negative numbers, or subtracting a non-negative number from - // a negative one, can't wrap into non-negative. - else if (LHSKnown.isNegative() && Known2.isNegative()) - KnownOut.makeNegative(); - } - } + KnownOut = KnownBits::computeForAddSub(Add, NSW, LHSKnown, Known2); } static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, @@ -350,21 +336,78 @@ static void computeKnownBitsMul(const Value *Op0, const Value *Op1, bool NSW, } } - // If low bits are zero in either operand, output low known-0 bits. 
- // Also compute a conservative estimate for high known-0 bits. - // More trickiness is possible, but this is sufficient for the - // interesting case of alignment computation. - unsigned TrailZ = Known.countMinTrailingZeros() + - Known2.countMinTrailingZeros(); + assert(!Known.hasConflict() && !Known2.hasConflict()); + // Compute a conservative estimate for high known-0 bits. unsigned LeadZ = std::max(Known.countMinLeadingZeros() + Known2.countMinLeadingZeros(), BitWidth) - BitWidth; - - TrailZ = std::min(TrailZ, BitWidth); LeadZ = std::min(LeadZ, BitWidth); + + // The result of the bottom bits of an integer multiply can be + // inferred by looking at the bottom bits of both operands and + // multiplying them together. + // We can infer at least the minimum number of known trailing bits + // of both operands. Depending on number of trailing zeros, we can + // infer more bits, because (a*b) <=> ((a/m) * (b/n)) * (m*n) assuming + // a and b are divisible by m and n respectively. + // We then calculate how many of those bits are inferrable and set + // the output. For example, the i8 mul: + // a = XXXX1100 (12) + // b = XXXX1110 (14) + // We know the bottom 3 bits are zero since the first can be divided by + // 4 and the second by 2, thus having ((12/4) * (14/2)) * (2*4). + // Applying the multiplication to the trimmed arguments gets: + // XX11 (3) + // X111 (7) + // ------- + // XX11 + // XX11 + // XX11 + // XX11 + // ------- + // XXXXX01 + // Which allows us to infer the 2 LSBs. Since we're multiplying the result + // by 8, the bottom 3 bits will be 0, so we can infer a total of 5 bits. + // The proof for this can be described as: + // Pre: (C1 >= 0) && (C1 < (1 << C5)) && (C2 >= 0) && (C2 < (1 << C6)) && + // (C7 == (1 << (umin(countTrailingZeros(C1), C5) + + // umin(countTrailingZeros(C2), C6) + + // umin(C5 - umin(countTrailingZeros(C1), C5), + // C6 - umin(countTrailingZeros(C2), C6)))) - 1) + // %aa = shl i8 %a, C5 + // %bb = shl i8 %b, C6 + // %aaa = or i8 %aa, C1 + // %bbb = or i8 %bb, C2 + // %mul = mul i8 %aaa, %bbb + // %mask = and i8 %mul, C7 + // => + // %mask = i8 ((C1*C2)&C7) + // Where C5, C6 describe the known bits of %a, %b + // C1, C2 describe the known bottom bits of %a, %b. + // C7 describes the mask of the known bits of the result. + APInt Bottom0 = Known.One; + APInt Bottom1 = Known2.One; + + // How many times we'd be able to divide each argument by 2 (shr by 1). + // This gives us the number of trailing zeros on the multiplication result. + unsigned TrailBitsKnown0 = (Known.Zero | Known.One).countTrailingOnes(); + unsigned TrailBitsKnown1 = (Known2.Zero | Known2.One).countTrailingOnes(); + unsigned TrailZero0 = Known.countMinTrailingZeros(); + unsigned TrailZero1 = Known2.countMinTrailingZeros(); + unsigned TrailZ = TrailZero0 + TrailZero1; + + // Figure out the fewest known-bits operand. + unsigned SmallestOperand = std::min(TrailBitsKnown0 - TrailZero0, + TrailBitsKnown1 - TrailZero1); + unsigned ResultBitsKnown = std::min(SmallestOperand + TrailZ, BitWidth); + + APInt BottomKnown = Bottom0.getLoBits(TrailBitsKnown0) * + Bottom1.getLoBits(TrailBitsKnown1); + Known.resetAll(); - Known.Zero.setLowBits(TrailZ); Known.Zero.setHighBits(LeadZ); + Known.Zero |= (~BottomKnown).getLoBits(ResultBitsKnown); + Known.One |= BottomKnown.getLoBits(ResultBitsKnown); // Only make use of no-wrap flags if we failed to compute the sign bit // directly. 
This matters if the multiplication always overflows, in @@ -420,17 +463,19 @@ static bool isEphemeralValueOf(const Instruction *I, const Value *E) { continue; // If all uses of this value are ephemeral, then so is this value. - if (all_of(V->users(), [&](const User *U) { return EphValues.count(U); })) { + if (llvm::all_of(V->users(), [&](const User *U) { + return EphValues.count(U); + })) { if (V == E) return true; - EphValues.insert(V); - if (const User *U = dyn_cast<User>(V)) - for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); - J != JE; ++J) { - if (isSafeToSpeculativelyExecute(*J)) - WorkSet.push_back(*J); - } + if (V == I || isSafeToSpeculativelyExecute(V)) { + EphValues.insert(V); + if (const User *U = dyn_cast<User>(V)) + for (User::const_op_iterator J = U->op_begin(), JE = U->op_end(); + J != JE; ++J) + WorkSet.push_back(*J); + } } } @@ -438,13 +483,14 @@ static bool isEphemeralValueOf(const Instruction *I, const Value *E) { } // Is this an intrinsic that cannot be speculated but also cannot trap? -static bool isAssumeLikeIntrinsic(const Instruction *I) { +bool llvm::isAssumeLikeIntrinsic(const Instruction *I) { if (const CallInst *CI = dyn_cast<CallInst>(I)) if (Function *F = CI->getCalledFunction()) switch (F->getIntrinsicID()) { default: break; // FIXME: This list is repeated from NoTTI::getIntrinsicCost. case Intrinsic::assume: + case Intrinsic::sideeffect: case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::invariant_start: @@ -463,7 +509,6 @@ static bool isAssumeLikeIntrinsic(const Instruction *I) { bool llvm::isValidAssumeForContext(const Instruction *Inv, const Instruction *CxtI, const DominatorTree *DT) { - // There are two restrictions on the use of an assume: // 1. The assume must dominate the context (or the control flow must // reach the assume whenever it reaches the context). @@ -560,7 +605,7 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, m_BitCast(m_Specific(V)))); CmpInst::Predicate Pred; - ConstantInt *C; + uint64_t C; // assume(v = a) if (match(Arg, m_c_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_EQ && isValidAssumeForContext(I, Q.CxtI, Q.DT)) { @@ -662,51 +707,55 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, } else if (match(Arg, m_c_ICmp(Pred, m_Shl(m_V, m_ConstantInt(C)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { KnownBits RHSKnown(BitWidth); computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. - RHSKnown.Zero.lshrInPlace(C->getZExtValue()); + RHSKnown.Zero.lshrInPlace(C); Known.Zero |= RHSKnown.Zero; - RHSKnown.One.lshrInPlace(C->getZExtValue()); + RHSKnown.One.lshrInPlace(C); Known.One |= RHSKnown.One; // assume(~(v << c) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shl(m_V, m_ConstantInt(C))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { KnownBits RHSKnown(BitWidth); computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. 
- RHSKnown.One.lshrInPlace(C->getZExtValue()); + RHSKnown.One.lshrInPlace(C); Known.Zero |= RHSKnown.One; - RHSKnown.Zero.lshrInPlace(C->getZExtValue()); + RHSKnown.Zero.lshrInPlace(C); Known.One |= RHSKnown.Zero; // assume(v >> c = a) } else if (match(Arg, m_c_ICmp(Pred, m_Shr(m_V, m_ConstantInt(C)), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { KnownBits RHSKnown(BitWidth); computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them to known // bits in V shifted to the right by C. - Known.Zero |= RHSKnown.Zero << C->getZExtValue(); - Known.One |= RHSKnown.One << C->getZExtValue(); + Known.Zero |= RHSKnown.Zero << C; + Known.One |= RHSKnown.One << C; // assume(~(v >> c) = a) } else if (match(Arg, m_c_ICmp(Pred, m_Not(m_Shr(m_V, m_ConstantInt(C))), m_Value(A))) && Pred == ICmpInst::ICMP_EQ && - isValidAssumeForContext(I, Q.CxtI, Q.DT)) { + isValidAssumeForContext(I, Q.CxtI, Q.DT) && + C < BitWidth) { KnownBits RHSKnown(BitWidth); computeKnownBits(A, RHSKnown, Depth+1, Query(Q, I)); // For those bits in RHS that are known, we can propagate them inverted // to known bits in V shifted to the right by C. - Known.Zero |= RHSKnown.One << C->getZExtValue(); - Known.One |= RHSKnown.Zero << C->getZExtValue(); + Known.Zero |= RHSKnown.One << C; + Known.One |= RHSKnown.Zero << C; // assume(v >=_s c) where c is non-negative } else if (match(Arg, m_ICmp(Pred, m_V, m_Value(A))) && Pred == ICmpInst::ICMP_SGE && @@ -784,24 +833,26 @@ static void computeKnownBitsFromAssume(const Value *V, KnownBits &Known, if (Known.Zero.intersects(Known.One)) { Known.resetAll(); - if (Q.ORE) { - auto *CxtI = const_cast<Instruction *>(Q.CxtI); - OptimizationRemarkAnalysis ORA("value-tracking", "BadAssumption", CxtI); - Q.ORE->emit(ORA << "Detected conflicting code assumptions. Program may " - "have undefined behavior, or compiler may have " - "internal error."); - } + if (Q.ORE) + Q.ORE->emit([&]() { + auto *CxtI = const_cast<Instruction *>(Q.CxtI); + return OptimizationRemarkAnalysis("value-tracking", "BadAssumption", + CxtI) + << "Detected conflicting code assumptions. Program may " + "have undefined behavior, or compiler may have " + "internal error."; + }); } } -// Compute known bits from a shift operator, including those with a -// non-constant shift amount. Known is the outputs of this function. Known2 is a -// pre-allocated temporary with the/ same bit width as Known. KZF and KOF are -// operator-specific functors that, given the known-zero or known-one bits -// respectively, and a shift amount, compute the implied known-zero or known-one -// bits of the shift operator's result respectively for that shift amount. The -// results from calling KZF and KOF are conservatively combined for all -// permitted shift amounts. +/// Compute known bits from a shift operator, including those with a +/// non-constant shift amount. Known is the output of this function. Known2 is a +/// pre-allocated temporary with the same bit width as Known. KZF and KOF are +/// operator-specific functors that, given the known-zero or known-one bits +/// respectively, and a shift amount, compute the implied known-zero or +/// known-one bits of the shift operator's result respectively for that shift +/// amount. The results from calling KZF and KOF are conservatively combined for +/// all permitted shift amounts. 
static void computeKnownBitsFromShiftOperator( const Operator *I, KnownBits &Known, KnownBits &Known2, unsigned Depth, const Query &Q, @@ -815,19 +866,20 @@ static void computeKnownBitsFromShiftOperator( computeKnownBits(I->getOperand(0), Known, Depth + 1, Q); Known.Zero = KZF(Known.Zero, ShiftAmt); Known.One = KOF(Known.One, ShiftAmt); - // If there is conflict between Known.Zero and Known.One, this must be an - // overflowing left shift, so the shift result is undefined. Clear Known - // bits so that other code could propagate this undef. - if ((Known.Zero & Known.One) != 0) - Known.resetAll(); + // If the known bits conflict, this must be an overflowing left shift, so + // the shift result is poison. We can return anything we want. Choose 0 for + // the best folding opportunity. + if (Known.hasConflict()) + Known.setAllZero(); return; } computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); - // If the shift amount could be greater than or equal to the bit-width of the LHS, the - // value could be undef, so we don't know anything about it. + // If the shift amount could be greater than or equal to the bit-width of the + // LHS, the value could be poison, but bail out because the check below is + // expensive. TODO: Should we just carry on? if ((~Known.Zero).uge(BitWidth)) { Known.resetAll(); return; @@ -851,8 +903,7 @@ static void computeKnownBitsFromShiftOperator( // Early exit if we can't constrain any well-defined shift amount. if (!(ShiftAmtKZ & (PowerOf2Ceil(BitWidth) - 1)) && !(ShiftAmtKO & (PowerOf2Ceil(BitWidth) - 1))) { - ShifterOperandIsNonZero = - isKnownNonZero(I->getOperand(1), Depth + 1, Q); + ShifterOperandIsNonZero = isKnownNonZero(I->getOperand(1), Depth + 1, Q); if (!*ShifterOperandIsNonZero) return; } @@ -883,13 +934,10 @@ static void computeKnownBitsFromShiftOperator( Known.One &= KOF(Known2.One, ShiftAmt); } - // If there are no compatible shift amounts, then we've proven that the shift - // amount must be >= the BitWidth, and the result is undefined. We could - // return anything we'd like, but we need to make sure the sets of known bits - // stay disjoint (it should be better for some other code to actually - // propagate the undef than to pick a value here using known bits). - if (Known.Zero.intersects(Known.One)) - Known.resetAll(); + // If the known bits conflict, the result is poison. Return a 0 and hope the + // caller can further optimize that. + if (Known.hasConflict()) + Known.setAllZero(); } static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, @@ -931,7 +979,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, } break; } - case Instruction::Or: { + case Instruction::Or: computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); @@ -940,7 +988,6 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, // Output known-1 are known to be set if set in either the LHS | RHS. Known.One |= Known2.One; break; - } case Instruction::Xor: { computeKnownBits(I->getOperand(1), Known, Depth + 1, Q); computeKnownBits(I->getOperand(0), Known2, Depth + 1, Q); @@ -1103,7 +1150,7 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, break; } case Instruction::LShr: { - // (ushr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 + // (lshr X, C1) & C2 == 0 iff (-1 >> C1) & C2 == 0 auto KZF = [](const APInt &KnownZero, unsigned ShiftAmt) { APInt KZResult = KnownZero.lshr(ShiftAmt); // High bits known zero. 
@@ -1298,9 +1345,6 @@ static void computeKnownBitsFromOperator(const Operator *I, KnownBits &Known, Known.Zero.setLowBits(std::min(Known2.countMinTrailingZeros(), Known3.countMinTrailingZeros())); - if (DontImproveNonNegativePhiBits) - break; - auto *OverflowOp = dyn_cast<OverflowingBinaryOperator>(LU); if (OverflowOp && OverflowOp->hasNoSignedWrap()) { // If initial value of recurrence is nonnegative, and we are adding @@ -1525,9 +1569,8 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, // We know that CDS must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); - APInt Elt(BitWidth, 0); for (unsigned i = 0, e = CDS->getNumElements(); i != e; ++i) { - Elt = CDS->getElementAsInteger(i); + APInt Elt = CDS->getElementAsAPInt(i); Known.Zero &= ~Elt; Known.One &= Elt; } @@ -1538,7 +1581,6 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, // We know that CV must be a vector of integers. Take the intersection of // each element. Known.Zero.setAllBits(); Known.One.setAllBits(); - APInt Elt(BitWidth, 0); for (unsigned i = 0, e = CV->getNumOperands(); i != e; ++i) { Constant *Element = CV->getAggregateElement(i); auto *ElementCI = dyn_cast_or_null<ConstantInt>(Element); @@ -1546,7 +1588,7 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, Known.resetAll(); return; } - Elt = ElementCI->getValue(); + const APInt &Elt = ElementCI->getValue(); Known.Zero &= ~Elt; Known.One &= Elt; } @@ -1602,6 +1644,8 @@ void computeKnownBits(const Value *V, KnownBits &Known, unsigned Depth, /// types and vectors of integers. bool isKnownToBeAPowerOfTwo(const Value *V, bool OrZero, unsigned Depth, const Query &Q) { + assert(Depth <= MaxDepth && "Limit Search Depth"); + if (const Constant *C = dyn_cast<Constant>(V)) { if (C->isNullValue()) return OrZero; @@ -1755,6 +1799,58 @@ static bool isGEPKnownNonNull(const GEPOperator *GEP, unsigned Depth, return false; } +static bool isKnownNonNullFromDominatingCondition(const Value *V, + const Instruction *CtxI, + const DominatorTree *DT) { + assert(V->getType()->isPointerTy() && "V must be pointer type"); + assert(!isa<ConstantData>(V) && "Did not expect ConstantPointerNull"); + + if (!CtxI || !DT) + return false; + + unsigned NumUsesExplored = 0; + for (auto *U : V->users()) { + // Avoid massive lists + if (NumUsesExplored >= DomConditionsMaxUses) + break; + NumUsesExplored++; + + // If the value is used as an argument to a call or invoke, then argument + // attributes may provide an answer about null-ness. + if (auto CS = ImmutableCallSite(U)) + if (auto *CalledFunc = CS.getCalledFunction()) + for (const Argument &Arg : CalledFunc->args()) + if (CS.getArgOperand(Arg.getArgNo()) == V && + Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI)) + return true; + + // Consider only compare instructions uniquely controlling a branch + CmpInst::Predicate Pred; + if (!match(const_cast<User *>(U), + m_c_ICmp(Pred, m_Specific(V), m_Zero())) || + (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) + continue; + + for (auto *CmpU : U->users()) { + if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { + assert(BI->isConditional() && "uses a comparison!"); + + BasicBlock *NonNullSuccessor = + BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 
1 : 0); + BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); + if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) + return true; + } else if (Pred == ICmpInst::ICMP_NE && + match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && + DT->dominates(cast<Instruction>(CmpU), CtxI)) { + return true; + } + } + } + + return false; +} + /// Does the 'Range' metadata (which must be a valid MD_range operand list) /// ensure that the value it's attached to is never Value? 'RangeType' is /// is the type of the value described by the range. @@ -1800,7 +1896,15 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { return true; } - return false; + // A global variable in address space 0 is non null unless extern weak + // or an absolute symbol reference. Other address spaces may have null as a + // valid address for a global, so we can't assume anything. + if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) { + if (!GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && + GV->getType()->getAddressSpace() == 0) + return true; + } else + return false; } if (auto *I = dyn_cast<Instruction>(V)) { @@ -1815,14 +1919,36 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } } + // Check for pointer simplifications. + if (V->getType()->isPointerTy()) { + // Alloca never returns null, malloc might. + if (isa<AllocaInst>(V) && Q.DL.getAllocaAddrSpace() == 0) + return true; + + // A byval, inalloca, or nonnull argument is never null. + if (const Argument *A = dyn_cast<Argument>(V)) + if (A->hasByValOrInAllocaAttr() || A->hasNonNullAttr()) + return true; + + // A Load tagged with nonnull metadata is never null. + if (const LoadInst *LI = dyn_cast<LoadInst>(V)) + if (LI->getMetadata(LLVMContext::MD_nonnull)) + return true; + + if (auto CS = ImmutableCallSite(V)) + if (CS.isReturnNonNull()) + return true; + } + // The remaining tests are all recursive, so bail out if we hit the limit. if (Depth++ >= MaxDepth) return false; - // Check for pointer simplifications. + // Check for recursive pointer simplifications. if (V->getType()->isPointerTy()) { - if (isKnownNonNullAt(V, Q.CxtI, Q.DT)) + if (isKnownNonNullFromDominatingCondition(V, Q.CxtI, Q.DT)) return true; + if (const GEPOperator *GEP = dyn_cast<GEPOperator>(V)) if (isGEPKnownNonNull(GEP, Depth, Q)) return true; @@ -1949,7 +2075,7 @@ bool isKnownNonZero(const Value *V, unsigned Depth, const Query &Q) { } } // Check if all incoming values are non-zero constant. - bool AllNonZeroConstants = all_of(PN->operands(), [](Value *V) { + bool AllNonZeroConstants = llvm::all_of(PN->operands(), [](Value *V) { return isa<ConstantInt>(V) && !cast<ConstantInt>(V)->isZero(); }); if (AllNonZeroConstants) @@ -2033,11 +2159,7 @@ static unsigned computeNumSignBitsVectorConstant(const Value *V, if (!Elt) return 0; - // If the sign bit is 1, flip the bits, so we always count leading zeros. - APInt EltVal = Elt->getValue(); - if (EltVal.isNegative()) - EltVal = ~EltVal; - MinSignBits = std::min(MinSignBits, EltVal.countLeadingZeros()); + MinSignBits = std::min(MinSignBits, Elt->getValue().getNumSignBits()); } return MinSignBits; @@ -2061,6 +2183,7 @@ static unsigned ComputeNumSignBits(const Value *V, unsigned Depth, /// vector element with the mininum number of known sign bits. 
static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, const Query &Q) { + assert(Depth <= MaxDepth && "Limit Search Depth"); // We return the minimum number of sign bits that are guaranteed to be present // in V, so for undef we have to conservatively return 1. We don't have the @@ -2236,6 +2359,17 @@ static unsigned ComputeNumSignBitsImpl(const Value *V, unsigned Depth, if (Tmp == 1) return 1; // Early out. return std::min(Tmp, Tmp2)-1; + case Instruction::Mul: { + // The output of the Mul can be at most twice the valid bits in the inputs. + unsigned SignBitsOp0 = ComputeNumSignBits(U->getOperand(0), Depth + 1, Q); + if (SignBitsOp0 == 1) return 1; // Early out. + unsigned SignBitsOp1 = ComputeNumSignBits(U->getOperand(1), Depth + 1, Q); + if (SignBitsOp1 == 1) return 1; + unsigned OutValidBits = + (TyBits - SignBitsOp0 + 1) + (TyBits - SignBitsOp1 + 1); + return OutValidBits > TyBits ? 1 : TyBits - OutValidBits + 1; + } + case Instruction::PHI: { const PHINode *PN = cast<PHINode>(U); unsigned NumIncomingValues = PN->getNumIncomingValues(); @@ -2507,9 +2641,7 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, case LibFunc_sqrt: case LibFunc_sqrtf: case LibFunc_sqrtl: - if (ICS->hasNoNaNs()) - return Intrinsic::sqrt; - return Intrinsic::not_intrinsic; + return Intrinsic::sqrt; } return Intrinsic::not_intrinsic; @@ -2520,41 +2652,40 @@ Intrinsic::ID llvm::getIntrinsicForCallSite(ImmutableCallSite ICS, /// /// NOTE: this function will need to be revisited when we support non-default /// rounding modes! -/// bool llvm::CannotBeNegativeZero(const Value *V, const TargetLibraryInfo *TLI, unsigned Depth) { - if (const ConstantFP *CFP = dyn_cast<ConstantFP>(V)) + if (auto *CFP = dyn_cast<ConstantFP>(V)) return !CFP->getValueAPF().isNegZero(); + // Limit search depth. if (Depth == MaxDepth) - return false; // Limit search depth. + return false; - const Operator *I = dyn_cast<Operator>(V); - if (!I) return false; + auto *Op = dyn_cast<Operator>(V); + if (!Op) + return false; - // Check if the nsz fast-math flag is set - if (const FPMathOperator *FPO = dyn_cast<FPMathOperator>(I)) + // Check if the nsz fast-math flag is set. + if (auto *FPO = dyn_cast<FPMathOperator>(Op)) if (FPO->hasNoSignedZeros()) return true; - // (add x, 0.0) is guaranteed to return +0.0, not -0.0. - if (I->getOpcode() == Instruction::FAdd) - if (ConstantFP *CFP = dyn_cast<ConstantFP>(I->getOperand(1))) - if (CFP->isNullValue()) - return true; + // (fadd x, 0.0) is guaranteed to return +0.0, not -0.0. + if (match(Op, m_FAdd(m_Value(), m_Zero()))) + return true; // sitofp and uitofp turn into +0.0 for zero. - if (isa<SIToFPInst>(I) || isa<UIToFPInst>(I)) + if (isa<SIToFPInst>(Op) || isa<UIToFPInst>(Op)) return true; - if (const CallInst *CI = dyn_cast<CallInst>(I)) { - Intrinsic::ID IID = getIntrinsicForCallSite(CI, TLI); + if (auto *Call = dyn_cast<CallInst>(Op)) { + Intrinsic::ID IID = getIntrinsicForCallSite(Call, TLI); switch (IID) { default: break; // sqrt(-0.0) = -0.0, no other negative results are possible. 
case Intrinsic::sqrt: - return CannotBeNegativeZero(CI->getArgOperand(0), TLI, Depth + 1); + return CannotBeNegativeZero(Call->getArgOperand(0), TLI, Depth + 1); // fabs(x) != -0.0 case Intrinsic::fabs: return true; @@ -2690,6 +2821,41 @@ bool llvm::SignBitMustBeZero(const Value *V, const TargetLibraryInfo *TLI) { return cannotBeOrderedLessThanZeroImpl(V, TLI, true, 0); } +bool llvm::isKnownNeverNaN(const Value *V) { + assert(V->getType()->isFPOrFPVectorTy() && "Querying for NaN on non-FP type"); + + // If we're told that NaNs won't happen, assume they won't. + if (auto *FPMathOp = dyn_cast<FPMathOperator>(V)) + if (FPMathOp->hasNoNaNs()) + return true; + + // TODO: Handle instructions and potentially recurse like other 'isKnown' + // functions. For example, the result of sitofp is never NaN. + + // Handle scalar constants. + if (auto *CFP = dyn_cast<ConstantFP>(V)) + return !CFP->isNaN(); + + // Bail out for constant expressions, but try to handle vector constants. + if (!V->getType()->isVectorTy() || !isa<Constant>(V)) + return false; + + // For vectors, verify that each element is not NaN. + unsigned NumElts = V->getType()->getVectorNumElements(); + for (unsigned i = 0; i != NumElts; ++i) { + Constant *Elt = cast<Constant>(V)->getAggregateElement(i); + if (!Elt) + return false; + if (isa<UndefValue>(Elt)) + continue; + auto *CElt = dyn_cast<ConstantFP>(Elt); + if (!CElt || CElt->isNaN()) + return false; + } + // All elements were confirmed not-NaN or undefined. + return true; +} + /// If the specified value can be set by repeating the same byte in memory, /// return the i8 value that it is represented with. This is /// true for all i8 values obviously, but is also true for i32 0, i32 -1, @@ -2749,7 +2915,6 @@ Value *llvm::isBytewiseValue(Value *V) { return nullptr; } - // This is the recursive version of BuildSubAggregate. It takes a few different // arguments. Idxs is the index within the nested struct From that we are // looking at now (which is of type IndexedType). IdxSkip is the number of @@ -2760,7 +2925,7 @@ static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType, SmallVectorImpl<unsigned> &Idxs, unsigned IdxSkip, Instruction *InsertBefore) { - llvm::StructType *STy = dyn_cast<llvm::StructType>(IndexedType); + StructType *STy = dyn_cast<StructType>(IndexedType); if (STy) { // Save the original To argument so we can modify it Value *OrigTo = To; @@ -2799,8 +2964,8 @@ static Value *BuildSubAggregate(Value *From, Value* To, Type *IndexedType, return nullptr; // Insert the value in the new (sub) aggregrate - return llvm::InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip), - "tmp", InsertBefore); + return InsertValueInst::Create(To, V, makeArrayRef(Idxs).slice(IdxSkip), + "tmp", InsertBefore); } // This helper takes a nested struct and extracts a part of it (which is again a @@ -3307,7 +3472,8 @@ static const Value *getUnderlyingObjectFromInt(const Value *V) { /// This is a wrapper around GetUnderlyingObjects and adds support for basic /// ptrtoint+arithmetic+inttoptr sequences. -void llvm::getUnderlyingObjectsForCodeGen(const Value *V, +/// It returns false if unidentified object is found in GetUnderlyingObjects. +bool llvm::getUnderlyingObjectsForCodeGen(const Value *V, SmallVectorImpl<Value *> &Objects, const DataLayout &DL) { SmallPtrSet<const Value *, 16> Visited; @@ -3333,11 +3499,12 @@ void llvm::getUnderlyingObjectsForCodeGen(const Value *V, // getUnderlyingObjectsForCodeGen also fails for safety. 
if (!isIdentifiedObject(V)) { Objects.clear(); - return; + return false; } Objects.push_back(const_cast<Value *>(V)); } } while (!Working.empty()); + return true; } /// Return true if the only users of this pointer are lifetime markers. @@ -3401,7 +3568,8 @@ bool llvm::isSafeToSpeculativelyExecute(const Value *V, // Speculative load may create a race that did not exist in the source. LI->getFunction()->hasFnAttribute(Attribute::SanitizeThread) || // Speculative load may load data from dirty regions. - LI->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) + LI->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || + LI->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) return false; const DataLayout &DL = LI->getModule()->getDataLayout(); return isDereferenceableAndAlignedPointer(LI->getPointerOperand(), @@ -3443,100 +3611,6 @@ bool llvm::mayBeMemoryDependent(const Instruction &I) { return I.mayReadOrWriteMemory() || !isSafeToSpeculativelyExecute(&I); } -/// Return true if we know that the specified value is never null. -bool llvm::isKnownNonNull(const Value *V) { - assert(V->getType()->isPointerTy() && "V must be pointer type"); - - // Alloca never returns null, malloc might. - if (isa<AllocaInst>(V)) return true; - - // A byval, inalloca, or nonnull argument is never null. - if (const Argument *A = dyn_cast<Argument>(V)) - return A->hasByValOrInAllocaAttr() || A->hasNonNullAttr(); - - // A global variable in address space 0 is non null unless extern weak - // or an absolute symbol reference. Other address spaces may have null as a - // valid address for a global, so we can't assume anything. - if (const GlobalValue *GV = dyn_cast<GlobalValue>(V)) - return !GV->isAbsoluteSymbolRef() && !GV->hasExternalWeakLinkage() && - GV->getType()->getAddressSpace() == 0; - - // A Load tagged with nonnull metadata is never null. - if (const LoadInst *LI = dyn_cast<LoadInst>(V)) - return LI->getMetadata(LLVMContext::MD_nonnull); - - if (auto CS = ImmutableCallSite(V)) - if (CS.isReturnNonNull()) - return true; - - return false; -} - -static bool isKnownNonNullFromDominatingCondition(const Value *V, - const Instruction *CtxI, - const DominatorTree *DT) { - assert(V->getType()->isPointerTy() && "V must be pointer type"); - assert(!isa<ConstantData>(V) && "Did not expect ConstantPointerNull"); - assert(CtxI && "Context instruction required for analysis"); - assert(DT && "Dominator tree required for analysis"); - - unsigned NumUsesExplored = 0; - for (auto *U : V->users()) { - // Avoid massive lists - if (NumUsesExplored >= DomConditionsMaxUses) - break; - NumUsesExplored++; - - // If the value is used as an argument to a call or invoke, then argument - // attributes may provide an answer about null-ness. - if (auto CS = ImmutableCallSite(U)) - if (auto *CalledFunc = CS.getCalledFunction()) - for (const Argument &Arg : CalledFunc->args()) - if (CS.getArgOperand(Arg.getArgNo()) == V && - Arg.hasNonNullAttr() && DT->dominates(CS.getInstruction(), CtxI)) - return true; - - // Consider only compare instructions uniquely controlling a branch - CmpInst::Predicate Pred; - if (!match(const_cast<User *>(U), - m_c_ICmp(Pred, m_Specific(V), m_Zero())) || - (Pred != ICmpInst::ICMP_EQ && Pred != ICmpInst::ICMP_NE)) - continue; - - for (auto *CmpU : U->users()) { - if (const BranchInst *BI = dyn_cast<BranchInst>(CmpU)) { - assert(BI->isConditional() && "uses a comparison!"); - - BasicBlock *NonNullSuccessor = - BI->getSuccessor(Pred == ICmpInst::ICMP_EQ ? 
1 : 0); - BasicBlockEdge Edge(BI->getParent(), NonNullSuccessor); - if (Edge.isSingleEdge() && DT->dominates(Edge, CtxI->getParent())) - return true; - } else if (Pred == ICmpInst::ICMP_NE && - match(CmpU, m_Intrinsic<Intrinsic::experimental_guard>()) && - DT->dominates(cast<Instruction>(CmpU), CtxI)) { - return true; - } - } - } - - return false; -} - -bool llvm::isKnownNonNullAt(const Value *V, const Instruction *CtxI, - const DominatorTree *DT) { - if (isa<ConstantPointerNull>(V) || isa<UndefValue>(V)) - return false; - - if (isKnownNonNull(V)) - return true; - - if (!CtxI || !DT) - return false; - - return ::isKnownNonNullFromDominatingCondition(V, CtxI, DT); -} - OverflowResult llvm::computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const DataLayout &DL, @@ -3771,7 +3845,7 @@ bool llvm::isOverflowIntrinsicNoWrap(const IntrinsicInst *II, return true; }; - return any_of(GuardingBranches, AllUsesGuardedByBranch); + return llvm::any_of(GuardingBranches, AllUsesGuardedByBranch); } @@ -3846,7 +3920,8 @@ bool llvm::isGuaranteedToTransferExecutionToSuccessor(const Instruction *I) { // FIXME: This isn't aggressive enough; a call which only writes to a global // is guaranteed to return. return CS.onlyReadsMemory() || CS.onlyAccessesArgMemory() || - match(I, m_Intrinsic<Intrinsic::assume>()); + match(I, m_Intrinsic<Intrinsic::assume>()) || + match(I, m_Intrinsic<Intrinsic::sideeffect>()); } // Other instructions return normally. @@ -3975,7 +4050,7 @@ bool llvm::programUndefinedIfFullPoison(const Instruction *PoisonI) { } break; - }; + } return false; } @@ -3994,21 +4069,75 @@ static bool isKnownNonZero(const Value *V) { return false; } -/// Match non-obvious integer minimum and maximum sequences. -static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, - Value *CmpLHS, Value *CmpRHS, - Value *TrueVal, Value *FalseVal, - Value *&LHS, Value *&RHS) { - // Assume success. If there's no match, callers should not use these anyway. +/// Match clamp pattern for float types without care about NaNs or signed zeros. +/// Given non-min/max outer cmp/select from the clamp pattern this +/// function recognizes if it can be substitued by a "canonical" min/max +/// pattern. +static SelectPatternResult matchFastFloatClamp(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + Value *&LHS, Value *&RHS) { + // Try to match + // X < C1 ? C1 : Min(X, C2) --> Max(C1, Min(X, C2)) + // X > C1 ? C1 : Max(X, C2) --> Min(C1, Max(X, C2)) + // and return description of the outer Max/Min. + + // First, check if select has inverse order: + if (CmpRHS == FalseVal) { + std::swap(TrueVal, FalseVal); + Pred = CmpInst::getInversePredicate(Pred); + } + + // Assume success now. If there's no match, callers should not use these anyway. LHS = TrueVal; RHS = FalseVal; - // Recognize variations of: - // CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? 
(h) : (v))) + const APFloat *FC1; + if (CmpRHS != TrueVal || !match(CmpRHS, m_APFloat(FC1)) || !FC1->isFinite()) + return {SPF_UNKNOWN, SPNB_NA, false}; + + const APFloat *FC2; + switch (Pred) { + case CmpInst::FCMP_OLT: + case CmpInst::FCMP_OLE: + case CmpInst::FCMP_ULT: + case CmpInst::FCMP_ULE: + if (match(FalseVal, + m_CombineOr(m_OrdFMin(m_Specific(CmpLHS), m_APFloat(FC2)), + m_UnordFMin(m_Specific(CmpLHS), m_APFloat(FC2)))) && + FC1->compare(*FC2) == APFloat::cmpResult::cmpLessThan) + return {SPF_FMAXNUM, SPNB_RETURNS_ANY, false}; + break; + case CmpInst::FCMP_OGT: + case CmpInst::FCMP_OGE: + case CmpInst::FCMP_UGT: + case CmpInst::FCMP_UGE: + if (match(FalseVal, + m_CombineOr(m_OrdFMax(m_Specific(CmpLHS), m_APFloat(FC2)), + m_UnordFMax(m_Specific(CmpLHS), m_APFloat(FC2)))) && + FC1->compare(*FC2) == APFloat::cmpResult::cmpGreaterThan) + return {SPF_FMINNUM, SPNB_RETURNS_ANY, false}; + break; + default: + break; + } + + return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Recognize variations of: +/// CLAMP(v,l,h) ==> ((v) < (l) ? (l) : ((v) > (h) ? (h) : (v))) +static SelectPatternResult matchClamp(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal) { + // Swap the select operands and predicate to match the patterns below. + if (CmpRHS != TrueVal) { + Pred = ICmpInst::getSwappedPredicate(Pred); + std::swap(TrueVal, FalseVal); + } const APInt *C1; if (CmpRHS == TrueVal && match(CmpRHS, m_APInt(C1))) { const APInt *C2; - // (X <s C1) ? C1 : SMIN(X, C2) ==> SMAX(SMIN(X, C2), C1) if (match(FalseVal, m_SMin(m_Specific(CmpLHS), m_APInt(C2))) && C1->slt(*C2) && Pred == CmpInst::ICMP_SLT) @@ -4029,6 +4158,21 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, C1->ugt(*C2) && Pred == CmpInst::ICMP_UGT) return {SPF_UMIN, SPNB_NA, false}; } + return {SPF_UNKNOWN, SPNB_NA, false}; +} + +/// Match non-obvious integer minimum and maximum sequences. +static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, + Value *CmpLHS, Value *CmpRHS, + Value *TrueVal, Value *FalseVal, + Value *&LHS, Value *&RHS) { + // Assume success. If there's no match, callers should not use these anyway. + LHS = TrueVal; + RHS = FalseVal; + + SelectPatternResult SPR = matchClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal); + if (SPR.Flavor != SelectPatternFlavor::SPF_UNKNOWN) + return SPR; if (Pred != CmpInst::ICMP_SGT && Pred != CmpInst::ICMP_SLT) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4047,6 +4191,7 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, match(TrueVal, m_NSWSub(m_Specific(CmpLHS), m_Specific(CmpRHS)))) return {Pred == CmpInst::ICMP_SGT ? SPF_SMAX : SPF_SMIN, SPNB_NA, false}; + const APInt *C1; if (!match(CmpRHS, m_APInt(C1))) return {SPF_UNKNOWN, SPNB_NA, false}; @@ -4057,7 +4202,8 @@ static SelectPatternResult matchMinMax(CmpInst::Predicate Pred, // Is the sign bit set? // (X <s 0) ? X : MAXVAL ==> (X >u MAXVAL) ? X : MAXVAL ==> UMAX // (X <s 0) ? MAXVAL : X ==> (X >u MAXVAL) ? MAXVAL : X ==> UMIN - if (Pred == CmpInst::ICMP_SLT && *C1 == 0 && C2->isMaxSignedValue()) + if (Pred == CmpInst::ICMP_SLT && C1->isNullValue() && + C2->isMaxSignedValue()) return {CmpLHS == TrueVal ? SPF_UMAX : SPF_UMIN, SPNB_NA, false}; // Is the sign bit clear? @@ -4189,21 +4335,48 @@ static SelectPatternResult matchSelectPattern(CmpInst::Predicate Pred, // ABS(X) ==> (X >s 0) ? X : -X and (X >s -1) ? X : -X // NABS(X) ==> (X >s 0) ? -X : X and (X >s -1) ? 
-X : X - if (Pred == ICmpInst::ICMP_SGT && (*C1 == 0 || C1->isAllOnesValue())) { + if (Pred == ICmpInst::ICMP_SGT && + (C1->isNullValue() || C1->isAllOnesValue())) { return {(CmpLHS == TrueVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false}; } // ABS(X) ==> (X <s 0) ? -X : X and (X <s 1) ? -X : X // NABS(X) ==> (X <s 0) ? X : -X and (X <s 1) ? X : -X - if (Pred == ICmpInst::ICMP_SLT && (*C1 == 0 || *C1 == 1)) { + if (Pred == ICmpInst::ICMP_SLT && + (C1->isNullValue() || C1->isOneValue())) { return {(CmpLHS == FalseVal) ? SPF_ABS : SPF_NABS, SPNB_NA, false}; } } } - return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS); + if (CmpInst::isIntPredicate(Pred)) + return matchMinMax(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS); + + // According to (IEEE 754-2008 5.3.1), minNum(0.0, -0.0) and similar + // may return either -0.0 or 0.0, so fcmp/select pair has stricter + // semantics than minNum. Be conservative in such case. + if (NaNBehavior != SPNB_RETURNS_ANY || + (!FMF.noSignedZeros() && !isKnownNonZero(CmpLHS) && + !isKnownNonZero(CmpRHS))) + return {SPF_UNKNOWN, SPNB_NA, false}; + + return matchFastFloatClamp(Pred, CmpLHS, CmpRHS, TrueVal, FalseVal, LHS, RHS); } +/// Helps to match a select pattern in case of a type mismatch. +/// +/// The function processes the case when type of true and false values of a +/// select instruction differs from type of the cmp instruction operands because +/// of a cast instructon. The function checks if it is legal to move the cast +/// operation after "select". If yes, it returns the new second value of +/// "select" (with the assumption that cast is moved): +/// 1. As operand of cast instruction when both values of "select" are same cast +/// instructions. +/// 2. As restored constant (by applying reverse cast operation) when the first +/// value of the "select" is a cast operation and the second value is a +/// constant. +/// NOTE: We return only the new second value because the first value could be +/// accessed as operand of cast instruction. static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, Instruction::CastOps *CastOp) { auto *Cast1 = dyn_cast<CastInst>(V1); @@ -4234,7 +4407,34 @@ static Value *lookThroughCast(CmpInst *CmpI, Value *V1, Value *V2, CastedTo = ConstantExpr::getTrunc(C, SrcTy, true); break; case Instruction::Trunc: - CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); + Constant *CmpConst; + if (match(CmpI->getOperand(1), m_Constant(CmpConst)) && + CmpConst->getType() == SrcTy) { + // Here we have the following case: + // + // %cond = cmp iN %x, CmpConst + // %tr = trunc iN %x to iK + // %narrowsel = select i1 %cond, iK %t, iK C + // + // We can always move trunc after select operation: + // + // %cond = cmp iN %x, CmpConst + // %widesel = select i1 %cond, iN %x, iN CmpConst + // %tr = trunc iN %widesel to iK + // + // Note that C could be extended in any way because we don't care about + // upper bits after truncation. It can't be abs pattern, because it would + // look like: + // + // select i1 %cond, x, -x. + // + // So only min/max pattern could be matched. Such match requires widened C + // == CmpConst. That is why set widened C = CmpConst, condition trunc + // CmpConst == C is checked below. 
+ CastedTo = CmpConst; + } else { + CastedTo = ConstantExpr::getIntegerCast(C, SrcTy, CmpI->isSigned()); + } break; case Instruction::FPTrunc: CastedTo = ConstantExpr::getFPExtend(C, SrcTy, true); @@ -4307,11 +4507,9 @@ SelectPatternResult llvm::matchSelectPattern(Value *V, Value *&LHS, Value *&RHS, } /// Return true if "icmp Pred LHS RHS" is always true. -static bool isTruePredicate(CmpInst::Predicate Pred, - const Value *LHS, const Value *RHS, - const DataLayout &DL, unsigned Depth, - AssumptionCache *AC, const Instruction *CxtI, - const DominatorTree *DT) { +static bool isTruePredicate(CmpInst::Predicate Pred, const Value *LHS, + const Value *RHS, const DataLayout &DL, + unsigned Depth) { assert(!LHS->getType()->isVectorTy() && "TODO: extend to handle vectors!"); if (ICmpInst::isTrueWhenEqual(Pred) && LHS == RHS) return true; @@ -4348,8 +4546,8 @@ static bool isTruePredicate(CmpInst::Predicate Pred, if (match(A, m_Or(m_Value(X), m_APInt(CA))) && match(B, m_Or(m_Specific(X), m_APInt(CB)))) { KnownBits Known(CA->getBitWidth()); - computeKnownBits(X, Known, DL, Depth + 1, AC, CxtI, DT); - + computeKnownBits(X, Known, DL, Depth + 1, /*AC*/ nullptr, + /*CxtI*/ nullptr, /*DT*/ nullptr); if (CA->isSubsetOf(Known.Zero) && CB->isSubsetOf(Known.Zero)) return true; } @@ -4371,27 +4569,23 @@ static bool isTruePredicate(CmpInst::Predicate Pred, /// ALHS ARHS" is true. Otherwise, return None. static Optional<bool> isImpliedCondOperands(CmpInst::Predicate Pred, const Value *ALHS, - const Value *ARHS, const Value *BLHS, - const Value *BRHS, const DataLayout &DL, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, const DominatorTree *DT) { + const Value *ARHS, const Value *BLHS, const Value *BRHS, + const DataLayout &DL, unsigned Depth) { switch (Pred) { default: return None; case CmpInst::ICMP_SLT: case CmpInst::ICMP_SLE: - if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth, AC, CxtI, - DT) && - isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth, AC, CxtI, DT)) + if (isTruePredicate(CmpInst::ICMP_SLE, BLHS, ALHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_SLE, ARHS, BRHS, DL, Depth)) return true; return None; case CmpInst::ICMP_ULT: case CmpInst::ICMP_ULE: - if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth, AC, CxtI, - DT) && - isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth, AC, CxtI, DT)) + if (isTruePredicate(CmpInst::ICMP_ULE, BLHS, ALHS, DL, Depth) && + isTruePredicate(CmpInst::ICMP_ULE, ARHS, BRHS, DL, Depth)) return true; return None; } @@ -4453,66 +4647,22 @@ isImpliedCondMatchingImmOperands(CmpInst::Predicate APred, const Value *ALHS, return None; } -Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, - const DataLayout &DL, bool LHSIsFalse, - unsigned Depth, AssumptionCache *AC, - const Instruction *CxtI, - const DominatorTree *DT) { - // Bail out when we hit the limit. - if (Depth == MaxDepth) - return None; - - // A mismatch occurs when we compare a scalar cmp to a vector cmp, for example. - if (LHS->getType() != RHS->getType()) - return None; - - Type *OpTy = LHS->getType(); - assert(OpTy->isIntOrIntVectorTy(1)); - - // LHS ==> RHS by definition - if (LHS == RHS) - return !LHSIsFalse; - - if (OpTy->isVectorTy()) - // TODO: extending the code below to handle vectors - return None; - assert(OpTy->isIntegerTy(1) && "implied by above"); - - Value *BLHS, *BRHS; - ICmpInst::Predicate BPred; - // We expect the RHS to be an icmp. 
- if (!match(RHS, m_ICmp(BPred, m_Value(BLHS), m_Value(BRHS)))) - return None; - - Value *ALHS, *ARHS; - ICmpInst::Predicate APred; - // The LHS can be an 'or', 'and', or 'icmp'. - if (!match(LHS, m_ICmp(APred, m_Value(ALHS), m_Value(ARHS)))) { - // The remaining tests are all recursive, so bail out if we hit the limit. - if (Depth == MaxDepth) - return None; - // If the result of an 'or' is false, then we know both legs of the 'or' are - // false. Similarly, if the result of an 'and' is true, then we know both - // legs of the 'and' are true. - if ((LHSIsFalse && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) || - (!LHSIsFalse && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) { - if (Optional<bool> Implication = isImpliedCondition( - ALHS, RHS, DL, LHSIsFalse, Depth + 1, AC, CxtI, DT)) - return Implication; - if (Optional<bool> Implication = isImpliedCondition( - ARHS, RHS, DL, LHSIsFalse, Depth + 1, AC, CxtI, DT)) - return Implication; - return None; - } - return None; - } - // All of the below logic assumes both LHS and RHS are icmps. - assert(isa<ICmpInst>(LHS) && isa<ICmpInst>(RHS) && "Expected icmps."); - +/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is +/// false. Otherwise, return None if we can't infer anything. +static Optional<bool> isImpliedCondICmps(const ICmpInst *LHS, + const ICmpInst *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + Value *ALHS = LHS->getOperand(0); + Value *ARHS = LHS->getOperand(1); // The rest of the logic assumes the LHS condition is true. If that's not the // case, invert the predicate to make it so. - if (LHSIsFalse) - APred = CmpInst::getInversePredicate(APred); + ICmpInst::Predicate APred = + LHSIsTrue ? LHS->getPredicate() : LHS->getInversePredicate(); + + Value *BLHS = RHS->getOperand(0); + Value *BRHS = RHS->getOperand(1); + ICmpInst::Predicate BPred = RHS->getPredicate(); // Can we infer anything when the two compares have matching operands? bool IsSwappedOps; @@ -4538,8 +4688,80 @@ Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, } if (APred == BPred) - return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth, AC, - CxtI, DT); + return isImpliedCondOperands(APred, ALHS, ARHS, BLHS, BRHS, DL, Depth); + return None; +} + +/// Return true if LHS implies RHS is true. Return false if LHS implies RHS is +/// false. Otherwise, return None if we can't infer anything. We expect the +/// RHS to be an icmp and the LHS to be an 'and' or an 'or' instruction. +static Optional<bool> isImpliedCondAndOr(const BinaryOperator *LHS, + const ICmpInst *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + // The LHS must be an 'or' or an 'and' instruction. + assert((LHS->getOpcode() == Instruction::And || + LHS->getOpcode() == Instruction::Or) && + "Expected LHS to be 'and' or 'or'."); + + assert(Depth <= MaxDepth && "Hit recursion limit"); + + // If the result of an 'or' is false, then we know both legs of the 'or' are + // false. Similarly, if the result of an 'and' is true, then we know both + // legs of the 'and' are true. + Value *ALHS, *ARHS; + if ((!LHSIsTrue && match(LHS, m_Or(m_Value(ALHS), m_Value(ARHS)))) || + (LHSIsTrue && match(LHS, m_And(m_Value(ALHS), m_Value(ARHS))))) { + // FIXME: Make this non-recursion. 
+ if (Optional<bool> Implication = + isImpliedCondition(ALHS, RHS, DL, LHSIsTrue, Depth + 1)) + return Implication; + if (Optional<bool> Implication = + isImpliedCondition(ARHS, RHS, DL, LHSIsTrue, Depth + 1)) + return Implication; + return None; + } + return None; +} +Optional<bool> llvm::isImpliedCondition(const Value *LHS, const Value *RHS, + const DataLayout &DL, bool LHSIsTrue, + unsigned Depth) { + // Bail out when we hit the limit. + if (Depth == MaxDepth) + return None; + + // A mismatch occurs when we compare a scalar cmp to a vector cmp, for + // example. + if (LHS->getType() != RHS->getType()) + return None; + + Type *OpTy = LHS->getType(); + assert(OpTy->isIntOrIntVectorTy(1) && "Expected integer type only!"); + + // LHS ==> RHS by definition + if (LHS == RHS) + return LHSIsTrue; + + // FIXME: Extending the code below to handle vectors. + if (OpTy->isVectorTy()) + return None; + + assert(OpTy->isIntegerTy(1) && "implied by above"); + + // Both LHS and RHS are icmps. + const ICmpInst *LHSCmp = dyn_cast<ICmpInst>(LHS); + const ICmpInst *RHSCmp = dyn_cast<ICmpInst>(RHS); + if (LHSCmp && RHSCmp) + return isImpliedCondICmps(LHSCmp, RHSCmp, DL, LHSIsTrue, Depth); + + // The LHS should be an 'or' or an 'and' instruction. We expect the RHS to be + // an icmp. FIXME: Add support for and/or on the RHS. + const BinaryOperator *LHSBO = dyn_cast<BinaryOperator>(LHS); + if (LHSBO && RHSCmp) { + if ((LHSBO->getOpcode() == Instruction::And || + LHSBO->getOpcode() == Instruction::Or)) + return isImpliedCondAndOr(LHSBO, RHSCmp, DL, LHSIsTrue, Depth); + } return None; } |
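
The computeKnownBitsMul comment above walks through why the low bits of a multiply can be inferred from the known low bits of its operands. The following standalone C++ sketch (not LLVM code; the struct and helper names are hypothetical) reproduces that 12 * 14 example on 8-bit values and prints the five low result bits the patch is able to infer.

// Illustrative sketch, assuming 8-bit operands with the known-bit masks from
// the comment above: a = XXXX1100, b = XXXX1110.
#include <algorithm>
#include <cstdint>
#include <cstdio>

struct Known8 {
  uint8_t Zero; // bits known to be 0
  uint8_t One;  // bits known to be 1
};

static unsigned trailingOnes(uint8_t V) {
  unsigned N = 0;
  while (N < 8 && (V & (1u << N)))
    ++N;
  return N;
}

int main() {
  Known8 A{/*Zero=*/0x03, /*One=*/0x0C}; // XXXX1100 (12 mod 16)
  Known8 B{/*Zero=*/0x01, /*One=*/0x0E}; // XXXX1110 (14 mod 16)

  unsigned TrailKnown0 = trailingOnes(A.Zero | A.One); // 4 fully known low bits
  unsigned TrailKnown1 = trailingOnes(B.Zero | B.One); // 4
  unsigned TrailZero0 = trailingOnes(A.Zero);          // 2 known trailing zeros
  unsigned TrailZero1 = trailingOnes(B.Zero);          // 1
  unsigned TrailZ = TrailZero0 + TrailZero1;           // 3

  unsigned Smallest = std::min(TrailKnown0 - TrailZero0, TrailKnown1 - TrailZero1);
  unsigned ResultBitsKnown = std::min(Smallest + TrailZ, 8u); // 5

  // Multiply just the known low bits of each operand.
  uint8_t BottomKnown = (A.One & ((1u << TrailKnown0) - 1)) *
                        (B.One & ((1u << TrailKnown1) - 1)); // 12 * 14 = 168

  uint8_t Mask = (1u << ResultBitsKnown) - 1;
  printf("low %u bits of the product are 0x%02x\n", ResultBitsKnown,
         BottomKnown & Mask); // prints: low 5 bits of the product are 0x08
  return 0;
}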
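
Similarly, the new Instruction::Mul case in ComputeNumSignBitsImpl bounds the product's valid bits by the sum of the operands' valid bits. The short check below (again an illustrative sketch under stated assumptions, not LLVM code) brute-forces a 16-bit example to confirm the guarantee implied by that formula.

// Sketch: if op0 has S0 sign bits and op1 has S1 sign bits in an N-bit type,
// OutValidBits = (N - S0 + 1) + (N - S1 + 1), and the product keeps at least
// N - OutValidBits + 1 sign bits whenever OutValidBits <= N.
#include <cassert>
#include <cstdint>
#include <cstdio>

static unsigned numSignBits(int16_t V) {
  uint16_t U = (uint16_t)V;
  unsigned Sign = (U >> 15) & 1;
  unsigned N = 0;
  for (int Bit = 15; Bit >= 0 && ((U >> Bit) & 1) == Sign; --Bit)
    ++N;
  return N;
}

int main() {
  const unsigned TyBits = 16;
  // Operand ranges chosen so op0 has >= 11 sign bits and op1 has >= 14.
  const unsigned S0 = 11, S1 = 14;
  const unsigned OutValidBits = (TyBits - S0 + 1) + (TyBits - S1 + 1); // 9
  const unsigned GuaranteedSignBits = TyBits - OutValidBits + 1;       // 8

  for (int A = -32; A <= 31; ++A)
    for (int B = -4; B <= 3; ++B) {
      assert(numSignBits((int16_t)A) >= S0 && numSignBits((int16_t)B) >= S1);
      assert(numSignBits((int16_t)(A * B)) >= GuaranteedSignBits);
    }
  printf("every product of the two ranges keeps >= %u sign bits\n",
         GuaranteedSignBits);
  return 0;
}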