diff options
Diffstat (limited to 'lib/Analysis/InstructionSimplify.cpp')
| -rw-r--r-- | lib/Analysis/InstructionSimplify.cpp | 506 | 
1 files changed, 336 insertions, 170 deletions
| diff --git a/lib/Analysis/InstructionSimplify.cpp b/lib/Analysis/InstructionSimplify.cpp index 93fb1143e505..519d6d67be51 100644 --- a/lib/Analysis/InstructionSimplify.cpp +++ b/lib/Analysis/InstructionSimplify.cpp @@ -62,6 +62,8 @@ static Value *SimplifyOrInst(Value *, Value *, const SimplifyQuery &, unsigned);  static Value *SimplifyXorInst(Value *, Value *, const SimplifyQuery &, unsigned);  static Value *SimplifyCastInst(unsigned, Value *, Type *,                                 const SimplifyQuery &, unsigned); +static Value *SimplifyGEPInst(Type *, ArrayRef<Value *>, const SimplifyQuery &, +                              unsigned);  /// For a boolean type or a vector of boolean type, return false or a vector  /// with every element false. @@ -90,7 +92,7 @@ static bool isSameCompare(Value *V, CmpInst::Predicate Pred, Value *LHS,  }  /// Does the given value dominate the specified phi node? -static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) { +static bool valueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {    Instruction *I = dyn_cast<Instruction>(V);    if (!I)      // Arguments and constants dominate all instructions. @@ -99,7 +101,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {    // If we are processing instructions (and/or basic blocks) that have not been    // fully added to a function, the parent nodes may still be null. Simply    // return the conservative answer in these cases. -  if (!I->getParent() || !P->getParent() || !I->getParent()->getParent()) +  if (!I->getParent() || !P->getParent() || !I->getFunction())      return false;    // If we have a DominatorTree then do a precise test. @@ -108,7 +110,7 @@ static bool ValueDominatesPHI(Value *V, PHINode *P, const DominatorTree *DT) {    // Otherwise, if the instruction is in the entry block and is not an invoke,    // then it obviously dominates all phi nodes. -  if (I->getParent() == &I->getParent()->getParent()->getEntryBlock() && +  if (I->getParent() == &I->getFunction()->getEntryBlock() &&        !isa<InvokeInst>(I))      return true; @@ -443,13 +445,13 @@ static Value *ThreadBinOpOverPHI(Instruction::BinaryOps Opcode, Value *LHS,    if (isa<PHINode>(LHS)) {      PI = cast<PHINode>(LHS);      // Bail out if RHS and the phi may be mutually interdependent due to a loop. -    if (!ValueDominatesPHI(RHS, PI, Q.DT)) +    if (!valueDominatesPHI(RHS, PI, Q.DT))        return nullptr;    } else {      assert(isa<PHINode>(RHS) && "No PHI instruction operand!");      PI = cast<PHINode>(RHS);      // Bail out if LHS and the phi may be mutually interdependent due to a loop. -    if (!ValueDominatesPHI(LHS, PI, Q.DT)) +    if (!valueDominatesPHI(LHS, PI, Q.DT))        return nullptr;    } @@ -490,7 +492,7 @@ static Value *ThreadCmpOverPHI(CmpInst::Predicate Pred, Value *LHS, Value *RHS,    PHINode *PI = cast<PHINode>(LHS);    // Bail out if RHS and the phi may be mutually interdependent due to a loop. -  if (!ValueDominatesPHI(RHS, PI, Q.DT)) +  if (!valueDominatesPHI(RHS, PI, Q.DT))      return nullptr;    // Evaluate the BinOp on the incoming phi values. @@ -525,7 +527,7 @@ static Constant *foldOrCommuteConstant(Instruction::BinaryOps Opcode,  /// Given operands for an Add, see if we can fold the result.  /// If not, this returns null. -static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,                                const SimplifyQuery &Q, unsigned MaxRecurse) {    if (Constant *C = foldOrCommuteConstant(Instruction::Add, Op0, Op1, Q))      return C; @@ -538,6 +540,10 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    if (match(Op1, m_Zero()))      return Op0; +  // If two operands are negative, return 0. +  if (isKnownNegation(Op0, Op1)) +    return Constant::getNullValue(Op0->getType()); +    // X + (Y - X) -> Y    // (Y - X) + X -> Y    // Eg: X + -X -> 0 @@ -555,10 +561,14 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    // add nsw/nuw (xor Y, signmask), signmask --> Y    // The no-wrapping add guarantees that the top bit will be set by the add.    // Therefore, the xor must be clearing the already set sign bit of Y. -  if ((isNSW || isNUW) && match(Op1, m_SignMask()) && +  if ((IsNSW || IsNUW) && match(Op1, m_SignMask()) &&        match(Op0, m_Xor(m_Value(Y), m_SignMask())))      return Y; +  // add nuw %x, -1  ->  -1, because %x can only be 0. +  if (IsNUW && match(Op1, m_AllOnes())) +    return Op1; // Which is -1. +    /// i1 add -> xor.    if (MaxRecurse && Op0->getType()->isIntOrIntVectorTy(1))      if (Value *V = SimplifyXorInst(Op0, Op1, Q, MaxRecurse-1)) @@ -581,12 +591,12 @@ static Value *SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    return nullptr;  } -Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW, +Value *llvm::SimplifyAddInst(Value *Op0, Value *Op1, bool IsNSW, bool IsNUW,                               const SimplifyQuery &Query) { -  return ::SimplifyAddInst(Op0, Op1, isNSW, isNUW, Query, RecursionLimit); +  return ::SimplifyAddInst(Op0, Op1, IsNSW, IsNUW, Query, RecursionLimit);  } -/// \brief Compute the base pointer and cumulative constant offsets for V. +/// Compute the base pointer and cumulative constant offsets for V.  ///  /// This strips all constant offsets off of V, leaving it the base pointer, and  /// accumulates the total constant offset applied in the returned constant. It @@ -637,7 +647,7 @@ static Constant *stripAndComputeConstantOffsets(const DataLayout &DL, Value *&V,    return OffsetIntPtr;  } -/// \brief Compute the constant difference between two pointer values. +/// Compute the constant difference between two pointer values.  /// If the difference is not a constant, returns zero.  static Constant *computePointerDifference(const DataLayout &DL, Value *LHS,                                            Value *RHS) { @@ -680,14 +690,14 @@ static Value *SimplifySubInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    if (match(Op0, m_Zero())) {      // 0 - X -> 0 if the sub is NUW.      if (isNUW) -      return Op0; +      return Constant::getNullValue(Op0->getType());      KnownBits Known = computeKnownBits(Op1, Q.DL, 0, Q.AC, Q.CxtI, Q.DT);      if (Known.Zero.isMaxSignedValue()) {        // Op1 is either 0 or the minimum signed value. If the sub is NSW, then        // Op1 must be 0 because negating the minimum signed value is undefined.        if (isNSW) -        return Op0; +        return Constant::getNullValue(Op0->getType());        // 0 - X -> X if X is 0 or the minimum signed value.        return Op1; @@ -799,12 +809,9 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return C;    // X * undef -> 0 -  if (match(Op1, m_Undef())) -    return Constant::getNullValue(Op0->getType()); -    // X * 0 -> 0 -  if (match(Op1, m_Zero())) -    return Op1; +  if (match(Op1, m_CombineOr(m_Undef(), m_Zero()))) +    return Constant::getNullValue(Op0->getType());    // X * 1 -> X    if (match(Op1, m_One())) @@ -826,7 +833,7 @@ static Value *SimplifyMulInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,                                            MaxRecurse))      return V; -  // Mul distributes over Add.  Try some generic simplifications based on this. +  // Mul distributes over Add. Try some generic simplifications based on this.    if (Value *V = ExpandBinOp(Instruction::Mul, Op0, Op1, Instruction::Add,                               Q, MaxRecurse))      return V; @@ -868,13 +875,14 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {    if (match(Op1, m_Zero()))      return UndefValue::get(Ty); -  // If any element of a constant divisor vector is zero, the whole op is undef. +  // If any element of a constant divisor vector is zero or undef, the whole op +  // is undef.    auto *Op1C = dyn_cast<Constant>(Op1);    if (Op1C && Ty->isVectorTy()) {      unsigned NumElts = Ty->getVectorNumElements();      for (unsigned i = 0; i != NumElts; ++i) {        Constant *Elt = Op1C->getAggregateElement(i); -      if (Elt && Elt->isNullValue()) +      if (Elt && (Elt->isNullValue() || isa<UndefValue>(Elt)))          return UndefValue::get(Ty);      }    } @@ -887,7 +895,7 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {    // 0 / X -> 0    // 0 % X -> 0    if (match(Op0, m_Zero())) -    return Op0; +    return Constant::getNullValue(Op0->getType());    // X / X -> 1    // X % X -> 0 @@ -898,7 +906,10 @@ static Value *simplifyDivRem(Value *Op0, Value *Op1, bool IsDiv) {    // X % 1 -> 0    // If this is a boolean op (single-bit element type), we can't have    // division-by-zero or remainder-by-zero, so assume the divisor is 1. -  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1)) +  // Similarly, if we're zero-extending a boolean divisor, then assume it's a 1. +  Value *X; +  if (match(Op1, m_One()) || Ty->isIntOrIntVectorTy(1) || +      (match(Op1, m_ZExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))      return IsDiv ? Op0 : Constant::getNullValue(Ty);    return nullptr; @@ -978,18 +989,17 @@ static Value *simplifyDiv(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,    bool IsSigned = Opcode == Instruction::SDiv;    // (X * Y) / Y -> X if the multiplication does not overflow. -  Value *X = nullptr, *Y = nullptr; -  if (match(Op0, m_Mul(m_Value(X), m_Value(Y))) && (X == Op1 || Y == Op1)) { -    if (Y != Op1) std::swap(X, Y); // Ensure expression is (X * Y) / Y, Y = Op1 -    OverflowingBinaryOperator *Mul = cast<OverflowingBinaryOperator>(Op0); -    // If the Mul knows it does not overflow, then we are good to go. +  Value *X; +  if (match(Op0, m_c_Mul(m_Value(X), m_Specific(Op1)))) { +    auto *Mul = cast<OverflowingBinaryOperator>(Op0); +    // If the Mul does not overflow, then we are good to go.      if ((IsSigned && Mul->hasNoSignedWrap()) ||          (!IsSigned && Mul->hasNoUnsignedWrap()))        return X; -    // If X has the form X = A / Y then X * Y cannot overflow. -    if (BinaryOperator *Div = dyn_cast<BinaryOperator>(X)) -      if (Div->getOpcode() == Opcode && Div->getOperand(1) == Y) -        return X; +    // If X has the form X = A / Y, then X * Y cannot overflow. +    if ((IsSigned && match(X, m_SDiv(m_Value(), m_Specific(Op1)))) || +        (!IsSigned && match(X, m_UDiv(m_Value(), m_Specific(Op1))))) +      return X;    }    // (X rem Y) / Y -> 0 @@ -1041,6 +1051,13 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,         match(Op0, m_URem(m_Value(), m_Specific(Op1)))))      return Op0; +  // (X << Y) % X -> 0 +  if ((Opcode == Instruction::SRem && +       match(Op0, m_NSWShl(m_Specific(Op1), m_Value()))) || +      (Opcode == Instruction::URem && +       match(Op0, m_NUWShl(m_Specific(Op1), m_Value())))) +    return Constant::getNullValue(Op0->getType()); +    // If the operation is with the result of a select instruction, check whether    // operating on either branch of the select always yields the same value.    if (isa<SelectInst>(Op0) || isa<SelectInst>(Op1)) @@ -1064,6 +1081,10 @@ static Value *simplifyRem(Instruction::BinaryOps Opcode, Value *Op0, Value *Op1,  /// If not, this returns null.  static Value *SimplifySDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,                                 unsigned MaxRecurse) { +  // If two operands are negated and no signed overflow, return -1. +  if (isKnownNegation(Op0, Op1, /*NeedNSW=*/true)) +    return Constant::getAllOnesValue(Op0->getType()); +    return simplifyDiv(Instruction::SDiv, Op0, Op1, Q, MaxRecurse);  } @@ -1086,6 +1107,16 @@ Value *llvm::SimplifyUDivInst(Value *Op0, Value *Op1, const SimplifyQuery &Q) {  /// If not, this returns null.  static Value *SimplifySRemInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,                                 unsigned MaxRecurse) { +  // If the divisor is 0, the result is undefined, so assume the divisor is -1. +  // srem Op0, (sext i1 X) --> srem Op0, -1 --> 0 +  Value *X; +  if (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) +    return ConstantInt::getNullValue(Op0->getType()); + +  // If the two operands are negated, return 0. +  if (isKnownNegation(Op0, Op1)) +    return ConstantInt::getNullValue(Op0->getType()); +    return simplifyRem(Instruction::SRem, Op0, Op1, Q, MaxRecurse);  } @@ -1140,10 +1171,14 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,    // 0 shift by X -> 0    if (match(Op0, m_Zero())) -    return Op0; +    return Constant::getNullValue(Op0->getType());    // X shift by 0 -> X -  if (match(Op1, m_Zero())) +  // Shift-by-sign-extended bool must be shift-by-0 because shift-by-all-ones +  // would be poison. +  Value *X; +  if (match(Op1, m_Zero()) || +      (match(Op1, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)))      return Op0;    // Fold undefined shifts. @@ -1177,7 +1212,7 @@ static Value *SimplifyShift(Instruction::BinaryOps Opcode, Value *Op0,    return nullptr;  } -/// \brief Given operands for an Shl, LShr or AShr, see if we can +/// Given operands for an Shl, LShr or AShr, see if we can  /// fold the result.  If not, this returns null.  static Value *SimplifyRightShift(Instruction::BinaryOps Opcode, Value *Op0,                                   Value *Op1, bool isExact, const SimplifyQuery &Q, @@ -1220,6 +1255,13 @@ static Value *SimplifyShlInst(Value *Op0, Value *Op1, bool isNSW, bool isNUW,    Value *X;    if (match(Op0, m_Exact(m_Shr(m_Value(X), m_Specific(Op1)))))      return X; + +  // shl nuw i8 C, %x  ->  C  iff C has sign bit set. +  if (isNUW && match(Op0, m_Negative())) +    return Op0; +  // NOTE: could use computeKnownBits() / LazyValueInfo, +  // but the cost-benefit analysis suggests it isn't worth it. +    return nullptr;  } @@ -1257,9 +1299,10 @@ static Value *SimplifyAShrInst(Value *Op0, Value *Op1, bool isExact,                                      MaxRecurse))      return V; -  // all ones >>a X -> all ones +  // all ones >>a X -> -1 +  // Do not return Op0 because it may contain undef elements if it's a vector.    if (match(Op0, m_AllOnes())) -    return Op0; +    return Constant::getAllOnesValue(Op0->getType());    // (X << A) >> A -> X    Value *X; @@ -1295,7 +1338,7 @@ static Value *simplifyUnsignedRangeCheck(ICmpInst *ZeroICmp,        ICmpInst::isUnsigned(UnsignedPred))      ;    else if (match(UnsignedICmp, -                 m_ICmp(UnsignedPred, m_Value(Y), m_Specific(X))) && +                 m_ICmp(UnsignedPred, m_Specific(Y), m_Value(X))) &&             ICmpInst::isUnsigned(UnsignedPred))      UnsignedPred = ICmpInst::getSwappedPredicate(UnsignedPred);    else @@ -1413,6 +1456,43 @@ static Value *simplifyAndOrOfICmpsWithConstants(ICmpInst *Cmp0, ICmpInst *Cmp1,    return nullptr;  } +static Value *simplifyAndOrOfICmpsWithZero(ICmpInst *Cmp0, ICmpInst *Cmp1, +                                           bool IsAnd) { +  ICmpInst::Predicate P0 = Cmp0->getPredicate(), P1 = Cmp1->getPredicate(); +  if (!match(Cmp0->getOperand(1), m_Zero()) || +      !match(Cmp1->getOperand(1), m_Zero()) || P0 != P1) +    return nullptr; + +  if ((IsAnd && P0 != ICmpInst::ICMP_NE) || (!IsAnd && P1 != ICmpInst::ICMP_EQ)) +    return nullptr; + +  // We have either "(X == 0 || Y == 0)" or "(X != 0 && Y != 0)". +  Value *X = Cmp0->getOperand(0); +  Value *Y = Cmp1->getOperand(0); + +  // If one of the compares is a masked version of a (not) null check, then +  // that compare implies the other, so we eliminate the other. Optionally, look +  // through a pointer-to-int cast to match a null check of a pointer type. + +  // (X == 0) || (([ptrtoint] X & ?) == 0) --> ([ptrtoint] X & ?) == 0 +  // (X == 0) || ((? & [ptrtoint] X) == 0) --> (? & [ptrtoint] X) == 0 +  // (X != 0) && (([ptrtoint] X & ?) != 0) --> ([ptrtoint] X & ?) != 0 +  // (X != 0) && ((? & [ptrtoint] X) != 0) --> (? & [ptrtoint] X) != 0 +  if (match(Y, m_c_And(m_Specific(X), m_Value())) || +      match(Y, m_c_And(m_PtrToInt(m_Specific(X)), m_Value()))) +    return Cmp1; + +  // (([ptrtoint] Y & ?) == 0) || (Y == 0) --> ([ptrtoint] Y & ?) == 0 +  // ((? & [ptrtoint] Y) == 0) || (Y == 0) --> (? & [ptrtoint] Y) == 0 +  // (([ptrtoint] Y & ?) != 0) && (Y != 0) --> ([ptrtoint] Y & ?) != 0 +  // ((? & [ptrtoint] Y) != 0) && (Y != 0) --> (? & [ptrtoint] Y) != 0 +  if (match(X, m_c_And(m_Specific(Y), m_Value())) || +      match(X, m_c_And(m_PtrToInt(m_Specific(Y)), m_Value()))) +    return Cmp0; + +  return nullptr; +} +  static Value *simplifyAndOfICmpsWithAdd(ICmpInst *Op0, ICmpInst *Op1) {    // (icmp (add V, C0), C1) & (icmp V, C0)    ICmpInst::Predicate Pred0, Pred1; @@ -1473,6 +1553,9 @@ static Value *simplifyAndOfICmps(ICmpInst *Op0, ICmpInst *Op1) {    if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, true))      return X; +  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, true)) +    return X; +    if (Value *X = simplifyAndOfICmpsWithAdd(Op0, Op1))      return X;    if (Value *X = simplifyAndOfICmpsWithAdd(Op1, Op0)) @@ -1541,6 +1624,9 @@ static Value *simplifyOrOfICmps(ICmpInst *Op0, ICmpInst *Op1) {    if (Value *X = simplifyAndOrOfICmpsWithConstants(Op0, Op1, false))      return X; +  if (Value *X = simplifyAndOrOfICmpsWithZero(Op0, Op1, false)) +    return X; +    if (Value *X = simplifyOrOfICmpsWithAdd(Op0, Op1))      return X;    if (Value *X = simplifyOrOfICmpsWithAdd(Op1, Op0)) @@ -1638,7 +1724,7 @@ static Value *SimplifyAndInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,    // X & 0 = 0    if (match(Op1, m_Zero())) -    return Op1; +    return Constant::getNullValue(Op0->getType());    // X & -1 = X    if (match(Op1, m_AllOnes())) @@ -1733,21 +1819,16 @@ static Value *SimplifyOrInst(Value *Op0, Value *Op1, const SimplifyQuery &Q,      return C;    // X | undef -> -1 -  if (match(Op1, m_Undef())) +  // X | -1 = -1 +  // Do not return Op1 because it may contain undef elements if it's a vector. +  if (match(Op1, m_Undef()) || match(Op1, m_AllOnes()))      return Constant::getAllOnesValue(Op0->getType());    // X | X = X -  if (Op0 == Op1) -    return Op0; -    // X | 0 = X -  if (match(Op1, m_Zero())) +  if (Op0 == Op1 || match(Op1, m_Zero()))      return Op0; -  // X | -1 = -1 -  if (match(Op1, m_AllOnes())) -    return Op1; -    // A | ~A  =  ~A | A  =  -1    if (match(Op0, m_Not(m_Specific(Op1))) ||        match(Op1, m_Not(m_Specific(Op0)))) @@ -2051,9 +2132,12 @@ computePointerICmp(const DataLayout &DL, const TargetLibraryInfo *TLI,        ConstantInt *LHSOffsetCI = dyn_cast<ConstantInt>(LHSOffset);        ConstantInt *RHSOffsetCI = dyn_cast<ConstantInt>(RHSOffset);        uint64_t LHSSize, RHSSize; +      ObjectSizeOpts Opts; +      Opts.NullIsUnknownSize = +          NullPointerIsDefined(cast<AllocaInst>(LHS)->getFunction());        if (LHSOffsetCI && RHSOffsetCI && -          getObjectSize(LHS, LHSSize, DL, TLI) && -          getObjectSize(RHS, RHSSize, DL, TLI)) { +          getObjectSize(LHS, LHSSize, DL, TLI, Opts) && +          getObjectSize(RHS, RHSSize, DL, TLI, Opts)) {          const APInt &LHSOffsetValue = LHSOffsetCI->getValue();          const APInt &RHSOffsetValue = RHSOffsetCI->getValue();          if (!LHSOffsetValue.isNegative() && @@ -2442,6 +2526,20 @@ static void setLimitsForBinOp(BinaryOperator &BO, APInt &Lower, APInt &Upper) {  static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,                                         Value *RHS) { +  Type *ITy = GetCompareTy(RHS); // The return type. + +  Value *X; +  // Sign-bit checks can be optimized to true/false after unsigned +  // floating-point casts: +  // icmp slt (bitcast (uitofp X)),  0 --> false +  // icmp sgt (bitcast (uitofp X)), -1 --> true +  if (match(LHS, m_BitCast(m_UIToFP(m_Value(X))))) { +    if (Pred == ICmpInst::ICMP_SLT && match(RHS, m_Zero())) +      return ConstantInt::getFalse(ITy); +    if (Pred == ICmpInst::ICMP_SGT && match(RHS, m_AllOnes())) +      return ConstantInt::getTrue(ITy); +  } +    const APInt *C;    if (!match(RHS, m_APInt(C)))      return nullptr; @@ -2449,9 +2547,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,    // Rule out tautological comparisons (eg., ult 0 or uge 0).    ConstantRange RHS_CR = ConstantRange::makeExactICmpRegion(Pred, *C);    if (RHS_CR.isEmptySet()) -    return ConstantInt::getFalse(GetCompareTy(RHS)); +    return ConstantInt::getFalse(ITy);    if (RHS_CR.isFullSet()) -    return ConstantInt::getTrue(GetCompareTy(RHS)); +    return ConstantInt::getTrue(ITy);    // Find the range of possible values for binary operators.    unsigned Width = C->getBitWidth(); @@ -2469,9 +2567,9 @@ static Value *simplifyICmpWithConstant(CmpInst::Predicate Pred, Value *LHS,    if (!LHS_CR.isFullSet()) {      if (RHS_CR.contains(LHS_CR)) -      return ConstantInt::getTrue(GetCompareTy(RHS)); +      return ConstantInt::getTrue(ITy);      if (RHS_CR.inverse().contains(LHS_CR)) -      return ConstantInt::getFalse(GetCompareTy(RHS)); +      return ConstantInt::getFalse(ITy);    }    return nullptr; @@ -3008,8 +3106,7 @@ static Value *SimplifyICmpInst(unsigned Predicate, Value *LHS, Value *RHS,    Type *ITy = GetCompareTy(LHS); // The return type.    // icmp X, X -> true/false -  // X icmp undef -> true/false.  For example, icmp ugt %X, undef -> false -  // because X could be 0. +  // icmp X, undef -> true/false because undef could be X.    if (LHS == RHS || isa<UndefValue>(RHS))      return ConstantInt::get(ITy, CmpInst::isTrueWhenEqual(Pred)); @@ -3309,6 +3406,12 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,        return getTrue(RetTy);    } +  // NaN is unordered; NaN is not ordered. +  assert((FCmpInst::isOrdered(Pred) || FCmpInst::isUnordered(Pred)) && +         "Comparison must be either ordered or unordered"); +  if (match(RHS, m_NaN())) +    return ConstantInt::get(RetTy, CmpInst::isUnordered(Pred)); +    // fcmp pred x, undef  and  fcmp pred undef, x    // fold to true if unordered, false if ordered    if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) { @@ -3328,15 +3431,6 @@ static Value *SimplifyFCmpInst(unsigned Predicate, Value *LHS, Value *RHS,    // Handle fcmp with constant RHS.    const APFloat *C;    if (match(RHS, m_APFloat(C))) { -    // If the constant is a nan, see if we can fold the comparison based on it. -    if (C->isNaN()) { -      if (FCmpInst::isOrdered(Pred)) // True "if ordered and foo" -        return getFalse(RetTy); -      assert(FCmpInst::isUnordered(Pred) && -             "Comparison must be either ordered or unordered!"); -      // True if unordered. -      return getTrue(RetTy); -    }      // Check whether the constant is an infinity.      if (C->isInfinity()) {        if (C->isNegative()) { @@ -3475,6 +3569,17 @@ static const Value *SimplifyWithOpReplaced(Value *V, Value *Op, Value *RepOp,      }    } +  // Same for GEPs. +  if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { +    if (MaxRecurse) { +      SmallVector<Value *, 8> NewOps(GEP->getNumOperands()); +      transform(GEP->operands(), NewOps.begin(), +                [&](Value *V) { return V == Op ? RepOp : V; }); +      return SimplifyGEPInst(GEP->getSourceElementType(), NewOps, Q, +                             MaxRecurse - 1); +    } +  } +    // TODO: We could hand off more cases to instsimplify here.    // If all operands are constant after substituting Op for RepOp then we can @@ -3581,24 +3686,6 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,                                                TrueVal, FalseVal))      return V; -  if (CondVal->hasOneUse()) { -    const APInt *C; -    if (match(CmpRHS, m_APInt(C))) { -      // X < MIN ? T : F  -->  F -      if (Pred == ICmpInst::ICMP_SLT && C->isMinSignedValue()) -        return FalseVal; -      // X < MIN ? T : F  -->  F -      if (Pred == ICmpInst::ICMP_ULT && C->isMinValue()) -        return FalseVal; -      // X > MAX ? T : F  -->  F -      if (Pred == ICmpInst::ICMP_SGT && C->isMaxSignedValue()) -        return FalseVal; -      // X > MAX ? T : F  -->  F -      if (Pred == ICmpInst::ICMP_UGT && C->isMaxValue()) -        return FalseVal; -    } -  } -    // If we have an equality comparison, then we know the value in one of the    // arms of the select. See if substituting this value into the arm and    // simplifying the result yields the same value as the other arm. @@ -3631,37 +3718,38 @@ static Value *simplifySelectWithICmpCond(Value *CondVal, Value *TrueVal,  /// Given operands for a SelectInst, see if we can fold the result.  /// If not, this returns null. -static Value *SimplifySelectInst(Value *CondVal, Value *TrueVal, -                                 Value *FalseVal, const SimplifyQuery &Q, -                                 unsigned MaxRecurse) { -  // select true, X, Y  -> X -  // select false, X, Y -> Y -  if (Constant *CB = dyn_cast<Constant>(CondVal)) { -    if (Constant *CT = dyn_cast<Constant>(TrueVal)) -      if (Constant *CF = dyn_cast<Constant>(FalseVal)) -        return ConstantFoldSelectInstruction(CB, CT, CF); -    if (CB->isAllOnesValue()) +static Value *SimplifySelectInst(Value *Cond, Value *TrueVal, Value *FalseVal, +                                 const SimplifyQuery &Q, unsigned MaxRecurse) { +  if (auto *CondC = dyn_cast<Constant>(Cond)) { +    if (auto *TrueC = dyn_cast<Constant>(TrueVal)) +      if (auto *FalseC = dyn_cast<Constant>(FalseVal)) +        return ConstantFoldSelectInstruction(CondC, TrueC, FalseC); + +    // select undef, X, Y -> X or Y +    if (isa<UndefValue>(CondC)) +      return isa<Constant>(FalseVal) ? FalseVal : TrueVal; + +    // TODO: Vector constants with undef elements don't simplify. + +    // select true, X, Y  -> X +    if (CondC->isAllOnesValue())        return TrueVal; -    if (CB->isNullValue()) +    // select false, X, Y -> Y +    if (CondC->isNullValue())        return FalseVal;    } -  // select C, X, X -> X +  // select ?, X, X -> X    if (TrueVal == FalseVal)      return TrueVal; -  if (isa<UndefValue>(CondVal)) {  // select undef, X, Y -> X or Y -    if (isa<Constant>(FalseVal)) -      return FalseVal; -    return TrueVal; -  } -  if (isa<UndefValue>(TrueVal))   // select C, undef, X -> X +  if (isa<UndefValue>(TrueVal))   // select ?, undef, X -> X      return FalseVal; -  if (isa<UndefValue>(FalseVal))   // select C, X, undef -> X +  if (isa<UndefValue>(FalseVal))   // select ?, X, undef -> X      return TrueVal;    if (Value *V = -          simplifySelectWithICmpCond(CondVal, TrueVal, FalseVal, Q, MaxRecurse)) +          simplifySelectWithICmpCond(Cond, TrueVal, FalseVal, Q, MaxRecurse))      return V;    return nullptr; @@ -3697,7 +3785,7 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,    if (Ops.size() == 2) {      // getelementptr P, 0 -> P. -    if (match(Ops[1], m_Zero())) +    if (match(Ops[1], m_Zero()) && Ops[0]->getType() == GEPTy)        return Ops[0];      Type *Ty = SrcTy; @@ -3706,13 +3794,13 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,        uint64_t C;        uint64_t TyAllocSize = Q.DL.getTypeAllocSize(Ty);        // getelementptr P, N -> P if P points to a type of zero size. -      if (TyAllocSize == 0) +      if (TyAllocSize == 0 && Ops[0]->getType() == GEPTy)          return Ops[0];        // The following transforms are only safe if the ptrtoint cast        // doesn't truncate the pointers.        if (Ops[1]->getType()->getScalarSizeInBits() == -          Q.DL.getPointerSizeInBits(AS)) { +          Q.DL.getIndexSizeInBits(AS)) {          auto PtrToIntOrZero = [GEPTy](Value *P) -> Value * {            if (match(P, m_Zero()))              return Constant::getNullValue(GEPTy); @@ -3752,10 +3840,10 @@ static Value *SimplifyGEPInst(Type *SrcTy, ArrayRef<Value *> Ops,    if (Q.DL.getTypeAllocSize(LastType) == 1 &&        all_of(Ops.slice(1).drop_back(1),               [](Value *Idx) { return match(Idx, m_Zero()); })) { -    unsigned PtrWidth = -        Q.DL.getPointerSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); -    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == PtrWidth) { -      APInt BasePtrOffset(PtrWidth, 0); +    unsigned IdxWidth = +        Q.DL.getIndexSizeInBits(Ops[0]->getType()->getPointerAddressSpace()); +    if (Q.DL.getTypeSizeInBits(Ops.back()->getType()) == IdxWidth) { +      APInt BasePtrOffset(IdxWidth, 0);        Value *StrippedBasePtr =            Ops[0]->stripAndAccumulateInBoundsConstantOffsets(Q.DL,                                                              BasePtrOffset); @@ -3838,12 +3926,13 @@ Value *llvm::SimplifyInsertElementInst(Value *Vec, Value *Val, Value *Idx,    // Fold into undef if index is out of bounds.    if (auto *CI = dyn_cast<ConstantInt>(Idx)) {      uint64_t NumElements = cast<VectorType>(Vec->getType())->getNumElements(); -      if (CI->uge(NumElements))        return UndefValue::get(Vec->getType());    } -  // TODO: We should also fold if index is iteslf an undef. +  // If index is undef, it might be out of bounds (see above case) +  if (isa<UndefValue>(Idx)) +    return UndefValue::get(Vec->getType());    return nullptr;  } @@ -3896,10 +3985,13 @@ static Value *SimplifyExtractElementInst(Value *Vec, Value *Idx, const SimplifyQ    // If extracting a specified index from the vector, see if we can recursively    // find a previously computed scalar that was inserted into the vector. -  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) -    if (IdxC->getValue().ule(Vec->getType()->getVectorNumElements())) -      if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) -        return Elt; +  if (auto *IdxC = dyn_cast<ConstantInt>(Idx)) { +    if (IdxC->getValue().uge(Vec->getType()->getVectorNumElements())) +      // definitely out of bounds, thus undefined result +      return UndefValue::get(Vec->getType()->getVectorElementType()); +    if (Value *Elt = findScalarElement(Vec, IdxC->getZExtValue())) +      return Elt; +  }    // An undef extract index can be arbitrarily chosen to be an out-of-range    // index value, which would result in the instruction being undef. @@ -3942,7 +4034,7 @@ static Value *SimplifyPHINode(PHINode *PN, const SimplifyQuery &Q) {    // instruction, we cannot return X as the result of the PHI node unless it    // dominates the PHI block.    if (HasUndefInput) -    return ValueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr; +    return valueDominatesPHI(CommonValue, PN, Q.DT) ? CommonValue : nullptr;    return CommonValue;  } @@ -4119,6 +4211,28 @@ Value *llvm::SimplifyShuffleVectorInst(Value *Op0, Value *Op1, Constant *Mask,    return ::SimplifyShuffleVectorInst(Op0, Op1, Mask, RetTy, Q, RecursionLimit);  } +static Constant *propagateNaN(Constant *In) { +  // If the input is a vector with undef elements, just return a default NaN. +  if (!In->isNaN()) +    return ConstantFP::getNaN(In->getType()); + +  // Propagate the existing NaN constant when possible. +  // TODO: Should we quiet a signaling NaN? +  return In; +} + +static Constant *simplifyFPBinop(Value *Op0, Value *Op1) { +  if (isa<UndefValue>(Op0) || isa<UndefValue>(Op1)) +    return ConstantFP::getNaN(Op0->getType()); + +  if (match(Op0, m_NaN())) +    return propagateNaN(cast<Constant>(Op0)); +  if (match(Op1, m_NaN())) +    return propagateNaN(cast<Constant>(Op1)); + +  return nullptr; +} +  /// Given operands for an FAdd, see if we can fold the result.  If not, this  /// returns null.  static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF, @@ -4126,29 +4240,28 @@ static Value *SimplifyFAddInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FAdd, Op0, Op1, Q))      return C; +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; +    // fadd X, -0 ==> X -  if (match(Op1, m_NegZero())) +  if (match(Op1, m_NegZeroFP()))      return Op0;    // fadd X, 0 ==> X, when we know X is not -0 -  if (match(Op1, m_Zero()) && +  if (match(Op1, m_PosZeroFP()) &&        (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))      return Op0; -  // fadd [nnan ninf] X, (fsub [nnan ninf] 0, X) ==> 0 -  //   where nnan and ninf have to occur at least once somewhere in this -  //   expression -  Value *SubOp = nullptr; -  if (match(Op1, m_FSub(m_AnyZero(), m_Specific(Op0)))) -    SubOp = Op1; -  else if (match(Op0, m_FSub(m_AnyZero(), m_Specific(Op1)))) -    SubOp = Op0; -  if (SubOp) { -    Instruction *FSub = cast<Instruction>(SubOp); -    if ((FMF.noNaNs() || FSub->hasNoNaNs()) && -        (FMF.noInfs() || FSub->hasNoInfs())) -      return Constant::getNullValue(Op0->getType()); -  } +  // With nnan: (+/-0.0 - X) + X --> 0.0 (and commuted variant) +  // We don't have to explicitly exclude infinities (ninf): INF + -INF == NaN. +  // Negative zeros are allowed because we always end up with positive zero: +  // X = -0.0: (-0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 +  // X = -0.0: ( 0.0 - (-0.0)) + (-0.0) == ( 0.0) + (-0.0) == 0.0 +  // X =  0.0: (-0.0 - ( 0.0)) + ( 0.0) == (-0.0) + ( 0.0) == 0.0 +  // X =  0.0: ( 0.0 - ( 0.0)) + ( 0.0) == ( 0.0) + ( 0.0) == 0.0 +  if (FMF.noNaNs() && (match(Op0, m_FSub(m_AnyZeroFP(), m_Specific(Op1))) || +                       match(Op1, m_FSub(m_AnyZeroFP(), m_Specific(Op0))))) +    return ConstantFP::getNullValue(Op0->getType());    return nullptr;  } @@ -4160,23 +4273,27 @@ static Value *SimplifyFSubInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FSub, Op0, Op1, Q))      return C; -  // fsub X, 0 ==> X -  if (match(Op1, m_Zero())) +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; + +  // fsub X, +0 ==> X +  if (match(Op1, m_PosZeroFP()))      return Op0;    // fsub X, -0 ==> X, when we know X is not -0 -  if (match(Op1, m_NegZero()) && +  if (match(Op1, m_NegZeroFP()) &&        (FMF.noSignedZeros() || CannotBeNegativeZero(Op0, Q.TLI)))      return Op0;    // fsub -0.0, (fsub -0.0, X) ==> X    Value *X; -  if (match(Op0, m_NegZero()) && match(Op1, m_FSub(m_NegZero(), m_Value(X)))) +  if (match(Op0, m_NegZeroFP()) && +      match(Op1, m_FSub(m_NegZeroFP(), m_Value(X))))      return X;    // fsub 0.0, (fsub 0.0, X) ==> X if signed zeros are ignored. -  if (FMF.noSignedZeros() && match(Op0, m_AnyZero()) && -      match(Op1, m_FSub(m_AnyZero(), m_Value(X)))) +  if (FMF.noSignedZeros() && match(Op0, m_AnyZeroFP()) && +      match(Op1, m_FSub(m_AnyZeroFP(), m_Value(X))))      return X;    // fsub nnan x, x ==> 0.0 @@ -4192,13 +4309,25 @@ static Value *SimplifyFMulInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FMul, Op0, Op1, Q))      return C; +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; +    // fmul X, 1.0 ==> X    if (match(Op1, m_FPOne()))      return Op0;    // fmul nnan nsz X, 0 ==> 0 -  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZero())) -    return Op1; +  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op1, m_AnyZeroFP())) +    return ConstantFP::getNullValue(Op0->getType()); + +  // sqrt(X) * sqrt(X) --> X, if we can: +  // 1. Remove the intermediate rounding (reassociate). +  // 2. Ignore non-zero negative numbers because sqrt would produce NAN. +  // 3. Ignore -0.0 because sqrt(-0.0) == -0.0, but -0.0 * -0.0 == 0.0. +  Value *X; +  if (Op0 == Op1 && match(Op0, m_Intrinsic<Intrinsic::sqrt>(m_Value(X))) && +      FMF.allowReassoc() && FMF.noNaNs() && FMF.noSignedZeros()) +    return X;    return nullptr;  } @@ -4224,13 +4353,8 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FDiv, Op0, Op1, Q))      return C; -  // undef / X -> undef    (the undef could be a snan). -  if (match(Op0, m_Undef())) -    return Op0; - -  // X / undef -> undef -  if (match(Op1, m_Undef())) -    return Op1; +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C;    // X / 1.0 -> X    if (match(Op1, m_FPOne())) @@ -4239,14 +4363,20 @@ static Value *SimplifyFDivInst(Value *Op0, Value *Op1, FastMathFlags FMF,    // 0 / X -> 0    // Requires that NaNs are off (X could be zero) and signed zeroes are    // ignored (X could be positive or negative, so the output sign is unknown). -  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero())) -    return Op0; +  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZeroFP())) +    return ConstantFP::getNullValue(Op0->getType());    if (FMF.noNaNs()) {      // X / X -> 1.0 is legal when NaNs are ignored. +    // We can ignore infinities because INF/INF is NaN.      if (Op0 == Op1)        return ConstantFP::get(Op0->getType(), 1.0); +    // (X * Y) / Y --> X if we can reassociate to the above form. +    Value *X; +    if (FMF.allowReassoc() && match(Op0, m_c_FMul(m_Value(X), m_Specific(Op1)))) +      return X; +      // -X /  X -> -1.0 and      //  X / -X -> -1.0 are legal when NaNs are ignored.      // We can ignore signed zeros because +-0.0/+-0.0 is NaN and ignored. @@ -4270,19 +4400,20 @@ static Value *SimplifyFRemInst(Value *Op0, Value *Op1, FastMathFlags FMF,    if (Constant *C = foldOrCommuteConstant(Instruction::FRem, Op0, Op1, Q))      return C; -  // undef % X -> undef    (the undef could be a snan). -  if (match(Op0, m_Undef())) -    return Op0; - -  // X % undef -> undef -  if (match(Op1, m_Undef())) -    return Op1; +  if (Constant *C = simplifyFPBinop(Op0, Op1)) +    return C; -  // 0 % X -> 0 -  // Requires that NaNs are off (X could be zero) and signed zeroes are -  // ignored (X could be positive or negative, so the output sign is unknown). -  if (FMF.noNaNs() && FMF.noSignedZeros() && match(Op0, m_AnyZero())) -    return Op0; +  // Unlike fdiv, the result of frem always matches the sign of the dividend. +  // The constant match may include undef elements in a vector, so return a full +  // zero constant as the result. +  if (FMF.noNaNs()) { +    // +0 % X -> 0 +    if (match(Op0, m_PosZeroFP())) +      return ConstantFP::getNullValue(Op0->getType()); +    // -0 % X -> -0 +    if (match(Op0, m_NegZeroFP())) +      return ConstantFP::getNegativeZero(Op0->getType()); +  }    return nullptr;  } @@ -4489,28 +4620,55 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,        }      } +    Value *IIOperand = *ArgBegin; +    Value *X;      switch (IID) {      case Intrinsic::fabs: { -      if (SignBitMustBeZero(*ArgBegin, Q.TLI)) -        return *ArgBegin; +      if (SignBitMustBeZero(IIOperand, Q.TLI)) +        return IIOperand;        return nullptr;      }      case Intrinsic::bswap: { -      Value *IIOperand = *ArgBegin; -      Value *X = nullptr;        // bswap(bswap(x)) -> x        if (match(IIOperand, m_BSwap(m_Value(X))))          return X;        return nullptr;      }      case Intrinsic::bitreverse: { -      Value *IIOperand = *ArgBegin; -      Value *X = nullptr;        // bitreverse(bitreverse(x)) -> x        if (match(IIOperand, m_BitReverse(m_Value(X))))          return X;        return nullptr;      } +    case Intrinsic::exp: { +      // exp(log(x)) -> x +      if (Q.CxtI->hasAllowReassoc() && +          match(IIOperand, m_Intrinsic<Intrinsic::log>(m_Value(X)))) +        return X; +      return nullptr; +    } +    case Intrinsic::exp2: { +      // exp2(log2(x)) -> x +      if (Q.CxtI->hasAllowReassoc() && +          match(IIOperand, m_Intrinsic<Intrinsic::log2>(m_Value(X)))) +        return X; +      return nullptr; +    } +    case Intrinsic::log: { +      // log(exp(x)) -> x +      if (Q.CxtI->hasAllowReassoc() && +          match(IIOperand, m_Intrinsic<Intrinsic::exp>(m_Value(X)))) +        return X; +      return nullptr; +    } +    case Intrinsic::log2: { +      // log2(exp2(x)) -> x +      if (Q.CxtI->hasAllowReassoc() && +          match(IIOperand, m_Intrinsic<Intrinsic::exp2>(m_Value(X)))) { +        return X; +      } +      return nullptr; +    }      default:        return nullptr;      } @@ -4575,6 +4733,14 @@ static Value *SimplifyIntrinsic(Function *F, IterTy ArgBegin, IterTy ArgEnd,            return LHS;        }        return nullptr; +    case Intrinsic::maxnum: +    case Intrinsic::minnum: +      // If one argument is NaN, return the other argument. +      if (match(LHS, m_NaN())) +        return RHS; +      if (match(RHS, m_NaN())) +        return LHS; +      return nullptr;      default:        return nullptr;      } @@ -4812,7 +4978,7 @@ Value *llvm::SimplifyInstruction(Instruction *I, const SimplifyQuery &SQ,    return Result == I ? UndefValue::get(I->getType()) : Result;  } -/// \brief Implementation of recursive simplification through an instruction's +/// Implementation of recursive simplification through an instruction's  /// uses.  ///  /// This is the common implementation of the recursive simplification routines. | 
