diff options
Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
| -rw-r--r-- | lib/Analysis/ScalarEvolution.cpp | 1484 | 
1 files changed, 751 insertions, 733 deletions
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index e42a4b574d90..5e566bcdaff4 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -61,6 +61,8 @@  #include "llvm/Analysis/ScalarEvolution.h"  #include "llvm/ADT/Optional.h"  #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopeExit.h" +#include "llvm/ADT/Sequence.h"  #include "llvm/ADT/SmallPtrSet.h"  #include "llvm/ADT/Statistic.h"  #include "llvm/Analysis/AssumptionCache.h" @@ -120,6 +122,16 @@ static cl::opt<bool>                    cl::desc("Verify no dangling value in ScalarEvolution's "                             "ExprValueMap (slow)")); +static cl::opt<unsigned> MulOpsInlineThreshold( +    "scev-mulops-inline-threshold", cl::Hidden, +    cl::desc("Threshold for inlining multiplication operands into a SCEV"), +    cl::init(1000)); + +static cl::opt<unsigned> +    MaxCompareDepth("scalar-evolution-max-compare-depth", cl::Hidden, +                    cl::desc("Maximum depth of recursive compare complexity"), +                    cl::init(32)); +  //===----------------------------------------------------------------------===//  //                           SCEV class definitions  //===----------------------------------------------------------------------===// @@ -447,180 +459,233 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {  //                               SCEV Utilities  //===----------------------------------------------------------------------===// -namespace { -/// SCEVComplexityCompare - Return true if the complexity of the LHS is less -/// than the complexity of the RHS.  This comparator is used to canonicalize -/// expressions. -class SCEVComplexityCompare { -  const LoopInfo *const LI; -public: -  explicit SCEVComplexityCompare(const LoopInfo *li) : LI(li) {} +/// Compare the two values \p LV and \p RV in terms of their "complexity" where +/// "complexity" is a partial (and somewhat ad-hoc) relation used to order +/// operands in SCEV expressions.  \p EqCache is a set of pairs of values that +/// have been previously deemed to be "equally complex" by this routine.  It is +/// intended to avoid exponential time complexity in cases like: +/// +///   %a = f(%x, %y) +///   %b = f(%a, %a) +///   %c = f(%b, %b) +/// +///   %d = f(%x, %y) +///   %e = f(%d, %d) +///   %f = f(%e, %e) +/// +///   CompareValueComplexity(%f, %c) +/// +/// Since we do not continue running this routine on expression trees once we +/// have seen unequal values, there is no need to track them in the cache. +static int +CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, +                       const LoopInfo *const LI, Value *LV, Value *RV, +                       unsigned Depth) { +  if (Depth > MaxCompareDepth || EqCache.count({LV, RV})) +    return 0; + +  // Order pointer values after integer values. This helps SCEVExpander form +  // GEPs. +  bool LIsPointer = LV->getType()->isPointerTy(), +       RIsPointer = RV->getType()->isPointerTy(); +  if (LIsPointer != RIsPointer) +    return (int)LIsPointer - (int)RIsPointer; -  // Return true or false if LHS is less than, or at least RHS, respectively. -  bool operator()(const SCEV *LHS, const SCEV *RHS) const { -    return compare(LHS, RHS) < 0; +  // Compare getValueID values. +  unsigned LID = LV->getValueID(), RID = RV->getValueID(); +  if (LID != RID) +    return (int)LID - (int)RID; + +  // Sort arguments by their position. +  if (const auto *LA = dyn_cast<Argument>(LV)) { +    const auto *RA = cast<Argument>(RV); +    unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); +    return (int)LArgNo - (int)RArgNo;    } -  // Return negative, zero, or positive, if LHS is less than, equal to, or -  // greater than RHS, respectively. A three-way result allows recursive -  // comparisons to be more efficient. -  int compare(const SCEV *LHS, const SCEV *RHS) const { -    // Fast-path: SCEVs are uniqued so we can do a quick equality check. -    if (LHS == RHS) -      return 0; - -    // Primarily, sort the SCEVs by their getSCEVType(). -    unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); -    if (LType != RType) -      return (int)LType - (int)RType; - -    // Aside from the getSCEVType() ordering, the particular ordering -    // isn't very important except that it's beneficial to be consistent, -    // so that (a + b) and (b + a) don't end up as different expressions. -    switch (static_cast<SCEVTypes>(LType)) { -    case scUnknown: { -      const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); -      const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); - -      // Sort SCEVUnknown values with some loose heuristics. TODO: This is -      // not as complete as it could be. -      const Value *LV = LU->getValue(), *RV = RU->getValue(); - -      // Order pointer values after integer values. This helps SCEVExpander -      // form GEPs. -      bool LIsPointer = LV->getType()->isPointerTy(), -        RIsPointer = RV->getType()->isPointerTy(); -      if (LIsPointer != RIsPointer) -        return (int)LIsPointer - (int)RIsPointer; - -      // Compare getValueID values. -      unsigned LID = LV->getValueID(), -        RID = RV->getValueID(); -      if (LID != RID) -        return (int)LID - (int)RID; - -      // Sort arguments by their position. -      if (const Argument *LA = dyn_cast<Argument>(LV)) { -        const Argument *RA = cast<Argument>(RV); -        unsigned LArgNo = LA->getArgNo(), RArgNo = RA->getArgNo(); -        return (int)LArgNo - (int)RArgNo; -      } +  if (const auto *LGV = dyn_cast<GlobalValue>(LV)) { +    const auto *RGV = cast<GlobalValue>(RV); -      // For instructions, compare their loop depth, and their operand -      // count.  This is pretty loose. -      if (const Instruction *LInst = dyn_cast<Instruction>(LV)) { -        const Instruction *RInst = cast<Instruction>(RV); - -        // Compare loop depths. -        const BasicBlock *LParent = LInst->getParent(), -          *RParent = RInst->getParent(); -        if (LParent != RParent) { -          unsigned LDepth = LI->getLoopDepth(LParent), -            RDepth = LI->getLoopDepth(RParent); -          if (LDepth != RDepth) -            return (int)LDepth - (int)RDepth; -        } +    const auto IsGVNameSemantic = [&](const GlobalValue *GV) { +      auto LT = GV->getLinkage(); +      return !(GlobalValue::isPrivateLinkage(LT) || +               GlobalValue::isInternalLinkage(LT)); +    }; -        // Compare the number of operands. -        unsigned LNumOps = LInst->getNumOperands(), -          RNumOps = RInst->getNumOperands(); -        return (int)LNumOps - (int)RNumOps; -      } +    // Use the names to distinguish the two values, but only if the +    // names are semantically important. +    if (IsGVNameSemantic(LGV) && IsGVNameSemantic(RGV)) +      return LGV->getName().compare(RGV->getName()); +  } + +  // For instructions, compare their loop depth, and their operand count.  This +  // is pretty loose. +  if (const auto *LInst = dyn_cast<Instruction>(LV)) { +    const auto *RInst = cast<Instruction>(RV); -      return 0; +    // Compare loop depths. +    const BasicBlock *LParent = LInst->getParent(), +                     *RParent = RInst->getParent(); +    if (LParent != RParent) { +      unsigned LDepth = LI->getLoopDepth(LParent), +               RDepth = LI->getLoopDepth(RParent); +      if (LDepth != RDepth) +        return (int)LDepth - (int)RDepth;      } -    case scConstant: { -      const SCEVConstant *LC = cast<SCEVConstant>(LHS); -      const SCEVConstant *RC = cast<SCEVConstant>(RHS); +    // Compare the number of operands. +    unsigned LNumOps = LInst->getNumOperands(), +             RNumOps = RInst->getNumOperands(); +    if (LNumOps != RNumOps) +      return (int)LNumOps - (int)RNumOps; -      // Compare constant values. -      const APInt &LA = LC->getAPInt(); -      const APInt &RA = RC->getAPInt(); -      unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); -      if (LBitWidth != RBitWidth) -        return (int)LBitWidth - (int)RBitWidth; -      return LA.ult(RA) ? -1 : 1; +    for (unsigned Idx : seq(0u, LNumOps)) { +      int Result = +          CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx), +                                 RInst->getOperand(Idx), Depth + 1); +      if (Result != 0) +        return Result;      } +  } -    case scAddRecExpr: { -      const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); -      const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); +  EqCache.insert({LV, RV}); +  return 0; +} -      // Compare addrec loop depths. -      const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); -      if (LLoop != RLoop) { -        unsigned LDepth = LLoop->getLoopDepth(), -          RDepth = RLoop->getLoopDepth(); -        if (LDepth != RDepth) -          return (int)LDepth - (int)RDepth; -      } +// Return negative, zero, or positive, if LHS is less than, equal to, or greater +// than RHS, respectively. A three-way result allows recursive comparisons to be +// more efficient. +static int CompareSCEVComplexity( +    SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV, +    const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, +    unsigned Depth = 0) { +  // Fast-path: SCEVs are uniqued so we can do a quick equality check. +  if (LHS == RHS) +    return 0; -      // Addrec complexity grows with operand count. -      unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); -      if (LNumOps != RNumOps) -        return (int)LNumOps - (int)RNumOps; +  // Primarily, sort the SCEVs by their getSCEVType(). +  unsigned LType = LHS->getSCEVType(), RType = RHS->getSCEVType(); +  if (LType != RType) +    return (int)LType - (int)RType; -      // Lexicographically compare. -      for (unsigned i = 0; i != LNumOps; ++i) { -        long X = compare(LA->getOperand(i), RA->getOperand(i)); -        if (X != 0) -          return X; -      } +  if (Depth > MaxCompareDepth || EqCacheSCEV.count({LHS, RHS})) +    return 0; +  // Aside from the getSCEVType() ordering, the particular ordering +  // isn't very important except that it's beneficial to be consistent, +  // so that (a + b) and (b + a) don't end up as different expressions. +  switch (static_cast<SCEVTypes>(LType)) { +  case scUnknown: { +    const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); +    const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); + +    SmallSet<std::pair<Value *, Value *>, 8> EqCache; +    int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(), +                                   Depth + 1); +    if (X == 0) +      EqCacheSCEV.insert({LHS, RHS}); +    return X; +  } -      return 0; +  case scConstant: { +    const SCEVConstant *LC = cast<SCEVConstant>(LHS); +    const SCEVConstant *RC = cast<SCEVConstant>(RHS); + +    // Compare constant values. +    const APInt &LA = LC->getAPInt(); +    const APInt &RA = RC->getAPInt(); +    unsigned LBitWidth = LA.getBitWidth(), RBitWidth = RA.getBitWidth(); +    if (LBitWidth != RBitWidth) +      return (int)LBitWidth - (int)RBitWidth; +    return LA.ult(RA) ? -1 : 1; +  } + +  case scAddRecExpr: { +    const SCEVAddRecExpr *LA = cast<SCEVAddRecExpr>(LHS); +    const SCEVAddRecExpr *RA = cast<SCEVAddRecExpr>(RHS); + +    // Compare addrec loop depths. +    const Loop *LLoop = LA->getLoop(), *RLoop = RA->getLoop(); +    if (LLoop != RLoop) { +      unsigned LDepth = LLoop->getLoopDepth(), RDepth = RLoop->getLoopDepth(); +      if (LDepth != RDepth) +        return (int)LDepth - (int)RDepth;      } -    case scAddExpr: -    case scMulExpr: -    case scSMaxExpr: -    case scUMaxExpr: { -      const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); -      const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); - -      // Lexicographically compare n-ary expressions. -      unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); -      if (LNumOps != RNumOps) -        return (int)LNumOps - (int)RNumOps; - -      for (unsigned i = 0; i != LNumOps; ++i) { -        if (i >= RNumOps) -          return 1; -        long X = compare(LC->getOperand(i), RC->getOperand(i)); -        if (X != 0) -          return X; -      } +    // Addrec complexity grows with operand count. +    unsigned LNumOps = LA->getNumOperands(), RNumOps = RA->getNumOperands(); +    if (LNumOps != RNumOps)        return (int)LNumOps - (int)RNumOps; + +    // Lexicographically compare. +    for (unsigned i = 0; i != LNumOps; ++i) { +      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), +                                    RA->getOperand(i), Depth + 1); +      if (X != 0) +        return X;      } +    EqCacheSCEV.insert({LHS, RHS}); +    return 0; +  } -    case scUDivExpr: { -      const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); -      const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); +  case scAddExpr: +  case scMulExpr: +  case scSMaxExpr: +  case scUMaxExpr: { +    const SCEVNAryExpr *LC = cast<SCEVNAryExpr>(LHS); +    const SCEVNAryExpr *RC = cast<SCEVNAryExpr>(RHS); + +    // Lexicographically compare n-ary expressions. +    unsigned LNumOps = LC->getNumOperands(), RNumOps = RC->getNumOperands(); +    if (LNumOps != RNumOps) +      return (int)LNumOps - (int)RNumOps; -      // Lexicographically compare udiv expressions. -      long X = compare(LC->getLHS(), RC->getLHS()); +    for (unsigned i = 0; i != LNumOps; ++i) { +      if (i >= RNumOps) +        return 1; +      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), +                                    RC->getOperand(i), Depth + 1);        if (X != 0)          return X; -      return compare(LC->getRHS(), RC->getRHS());      } +    EqCacheSCEV.insert({LHS, RHS}); +    return 0; +  } -    case scTruncate: -    case scZeroExtend: -    case scSignExtend: { -      const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS); -      const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); +  case scUDivExpr: { +    const SCEVUDivExpr *LC = cast<SCEVUDivExpr>(LHS); +    const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); -      // Compare cast expressions by operand. -      return compare(LC->getOperand(), RC->getOperand()); -    } +    // Lexicographically compare udiv expressions. +    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), +                                  Depth + 1); +    if (X != 0) +      return X; +    X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), +                              Depth + 1); +    if (X == 0) +      EqCacheSCEV.insert({LHS, RHS}); +    return X; +  } -    case scCouldNotCompute: -      llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); -    } -    llvm_unreachable("Unknown SCEV kind!"); +  case scTruncate: +  case scZeroExtend: +  case scSignExtend: { +    const SCEVCastExpr *LC = cast<SCEVCastExpr>(LHS); +    const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); + +    // Compare cast expressions by operand. +    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), +                                  RC->getOperand(), Depth + 1); +    if (X == 0) +      EqCacheSCEV.insert({LHS, RHS}); +    return X;    } -}; -}  // end anonymous namespace + +  case scCouldNotCompute: +    llvm_unreachable("Attempt to use a SCEVCouldNotCompute object!"); +  } +  llvm_unreachable("Unknown SCEV kind!"); +}  /// Given a list of SCEV objects, order them by their complexity, and group  /// objects of the same complexity together by value.  When this routine is @@ -635,17 +700,22 @@ public:  static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,                                LoopInfo *LI) {    if (Ops.size() < 2) return;  // Noop + +  SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;    if (Ops.size() == 2) {      // This is the common case, which also happens to be trivially simple.      // Special case it.      const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; -    if (SCEVComplexityCompare(LI)(RHS, LHS)) +    if (CompareSCEVComplexity(EqCache, LI, RHS, LHS) < 0)        std::swap(LHS, RHS);      return;    }    // Do the rough sort by complexity. -  std::stable_sort(Ops.begin(), Ops.end(), SCEVComplexityCompare(LI)); +  std::stable_sort(Ops.begin(), Ops.end(), +                   [&EqCache, LI](const SCEV *LHS, const SCEV *RHS) { +                     return CompareSCEVComplexity(EqCache, LI, LHS, RHS) < 0; +                   });    // Now that we are sorted by complexity, group elements of the same    // complexity.  Note that this is, at worst, N^2, but the vector is likely to @@ -2518,6 +2588,8 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,    if (Idx < Ops.size()) {      bool DeletedMul = false;      while (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { +      if (Ops.size() > MulOpsInlineThreshold) +        break;        // If we have an mul, expand the mul operands onto the end of the operands        // list.        Ops.erase(Ops.begin()+Idx); @@ -2970,9 +3042,9 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,  }  const SCEV * -ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr, -                            const SmallVectorImpl<const SCEV *> &IndexExprs, -                            bool InBounds) { +ScalarEvolution::getGEPExpr(GEPOperator *GEP, +                            const SmallVectorImpl<const SCEV *> &IndexExprs) { +  const SCEV *BaseExpr = getSCEV(GEP->getPointerOperand());    // getSCEV(Base)->getType() has the same address space as Base->getType()    // because SCEV::getType() preserves the address space.    Type *IntPtrTy = getEffectiveSCEVType(BaseExpr->getType()); @@ -2981,12 +3053,13 @@ ScalarEvolution::getGEPExpr(Type *PointeeType, const SCEV *BaseExpr,    // flow and the no-overflow bits may not be valid for the expression in any    // context. This can be fixed similarly to how these flags are handled for    // adds. -  SCEV::NoWrapFlags Wrap = InBounds ? SCEV::FlagNSW : SCEV::FlagAnyWrap; +  SCEV::NoWrapFlags Wrap = GEP->isInBounds() ? SCEV::FlagNSW +                                             : SCEV::FlagAnyWrap;    const SCEV *TotalOffset = getZero(IntPtrTy); -  // The address space is unimportant. The first thing we do on CurTy is getting +  // The array size is unimportant. The first thing we do on CurTy is getting    // its element type. -  Type *CurTy = PointerType::getUnqual(PointeeType); +  Type *CurTy = ArrayType::get(GEP->getSourceElementType(), 0);    for (const SCEV *IndexExpr : IndexExprs) {      // Compute the (potentially symbolic) offset in bytes for this index.      if (StructType *STy = dyn_cast<StructType>(CurTy)) { @@ -3311,75 +3384,47 @@ const SCEV *ScalarEvolution::getCouldNotCompute() {    return CouldNotCompute.get();  } -  bool ScalarEvolution::checkValidity(const SCEV *S) const { -  // Helper class working with SCEVTraversal to figure out if a SCEV contains -  // a SCEVUnknown with null value-pointer. FindInvalidSCEVUnknown::FindOne -  // is set iff if find such SCEVUnknown. -  // -  struct FindInvalidSCEVUnknown { -    bool FindOne; -    FindInvalidSCEVUnknown() { FindOne = false; } -    bool follow(const SCEV *S) { -      switch (static_cast<SCEVTypes>(S->getSCEVType())) { -      case scConstant: -        return false; -      case scUnknown: -        if (!cast<SCEVUnknown>(S)->getValue()) -          FindOne = true; -        return false; -      default: -        return true; -      } -    } -    bool isDone() const { return FindOne; } -  }; - -  FindInvalidSCEVUnknown F; -  SCEVTraversal<FindInvalidSCEVUnknown> ST(F); -  ST.visitAll(S); +  bool ContainsNulls = SCEVExprContains(S, [](const SCEV *S) { +    auto *SU = dyn_cast<SCEVUnknown>(S); +    return SU && SU->getValue() == nullptr; +  }); -  return !F.FindOne; -} - -namespace { -// Helper class working with SCEVTraversal to figure out if a SCEV contains -// a sub SCEV of scAddRecExpr type.  FindInvalidSCEVUnknown::FoundOne is set -// iff if such sub scAddRecExpr type SCEV is found. -struct FindAddRecurrence { -  bool FoundOne; -  FindAddRecurrence() : FoundOne(false) {} - -  bool follow(const SCEV *S) { -    switch (static_cast<SCEVTypes>(S->getSCEVType())) { -    case scAddRecExpr: -      FoundOne = true; -    case scConstant: -    case scUnknown: -    case scCouldNotCompute: -      return false; -    default: -      return true; -    } -  } -  bool isDone() const { return FoundOne; } -}; +  return !ContainsNulls;  }  bool ScalarEvolution::containsAddRecurrence(const SCEV *S) { -  HasRecMapType::iterator I = HasRecMap.find_as(S); +  HasRecMapType::iterator I = HasRecMap.find(S);    if (I != HasRecMap.end())      return I->second; -  FindAddRecurrence F; -  SCEVTraversal<FindAddRecurrence> ST(F); -  ST.visitAll(S); -  HasRecMap.insert({S, F.FoundOne}); -  return F.FoundOne; +  bool FoundAddRec = SCEVExprContains(S, isa<SCEVAddRecExpr, const SCEV *>); +  HasRecMap.insert({S, FoundAddRec}); +  return FoundAddRec; +} + +/// Try to split a SCEVAddExpr into a pair of {SCEV, ConstantInt}. +/// If \p S is a SCEVAddExpr and is composed of a sub SCEV S' and an +/// offset I, then return {S', I}, else return {\p S, nullptr}. +static std::pair<const SCEV *, ConstantInt *> splitAddExpr(const SCEV *S) { +  const auto *Add = dyn_cast<SCEVAddExpr>(S); +  if (!Add) +    return {S, nullptr}; + +  if (Add->getNumOperands() != 2) +    return {S, nullptr}; + +  auto *ConstOp = dyn_cast<SCEVConstant>(Add->getOperand(0)); +  if (!ConstOp) +    return {S, nullptr}; + +  return {Add->getOperand(1), ConstOp->getValue()};  } -/// Return the Value set from S. -SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) { +/// Return the ValueOffsetPair set for \p S. \p S can be represented +/// by the value and offset from any ValueOffsetPair in the set. +SetVector<ScalarEvolution::ValueOffsetPair> * +ScalarEvolution::getSCEVValues(const SCEV *S) {    ExprValueMapType::iterator SI = ExprValueMap.find_as(S);    if (SI == ExprValueMap.end())      return nullptr; @@ -3387,24 +3432,31 @@ SetVector<Value *> *ScalarEvolution::getSCEVValues(const SCEV *S) {    if (VerifySCEVMap) {      // Check there is no dangling Value in the set returned.      for (const auto &VE : SI->second) -      assert(ValueExprMap.count(VE)); +      assert(ValueExprMap.count(VE.first));    }  #endif    return &SI->second;  } -/// Erase Value from ValueExprMap and ExprValueMap.  If ValueExprMap.erase(V) is -/// not used together with forgetMemoizedResults(S), eraseValueFromMap should be -/// used instead to ensure whenever V->S is removed from ValueExprMap, V is also -/// removed from the set of ExprValueMap[S]. +/// Erase Value from ValueExprMap and ExprValueMap. ValueExprMap.erase(V) +/// cannot be used separately. eraseValueFromMap should be used to remove +/// V from ValueExprMap and ExprValueMap at the same time.  void ScalarEvolution::eraseValueFromMap(Value *V) {    ValueExprMapType::iterator I = ValueExprMap.find_as(V);    if (I != ValueExprMap.end()) {      const SCEV *S = I->second; -    SetVector<Value *> *SV = getSCEVValues(S); -    // Remove V from the set of ExprValueMap[S] -    if (SV) -      SV->remove(V); +    // Remove {V, 0} from the set of ExprValueMap[S] +    if (SetVector<ValueOffsetPair> *SV = getSCEVValues(S)) +      SV->remove({V, nullptr}); + +    // Remove {V, Offset} from the set of ExprValueMap[Stripped] +    const SCEV *Stripped; +    ConstantInt *Offset; +    std::tie(Stripped, Offset) = splitAddExpr(S); +    if (Offset != nullptr) { +      if (SetVector<ValueOffsetPair> *SV = getSCEVValues(Stripped)) +        SV->remove({V, Offset}); +    }      ValueExprMap.erase(V);    }  } @@ -3419,11 +3471,26 @@ const SCEV *ScalarEvolution::getSCEV(Value *V) {      S = createSCEV(V);      // During PHI resolution, it is possible to create two SCEVs for the same      // V, so it is needed to double check whether V->S is inserted into -    // ValueExprMap before insert S->V into ExprValueMap. +    // ValueExprMap before insert S->{V, 0} into ExprValueMap.      std::pair<ValueExprMapType::iterator, bool> Pair =          ValueExprMap.insert({SCEVCallbackVH(V, this), S}); -    if (Pair.second) -      ExprValueMap[S].insert(V); +    if (Pair.second) { +      ExprValueMap[S].insert({V, nullptr}); + +      // If S == Stripped + Offset, add Stripped -> {V, Offset} into +      // ExprValueMap. +      const SCEV *Stripped = S; +      ConstantInt *Offset = nullptr; +      std::tie(Stripped, Offset) = splitAddExpr(S); +      // If stripped is SCEVUnknown, don't bother to save +      // Stripped -> {V, offset}. It doesn't simplify and sometimes even +      // increase the complexity of the expansion code. +      // If V is GetElementPtrInst, don't save Stripped -> {V, offset} +      // because it may generate add/sub instead of GEP in SCEV expansion. +      if (Offset != nullptr && !isa<SCEVUnknown>(Stripped) && +          !isa<GetElementPtrInst>(V)) +        ExprValueMap[Stripped].insert({V, Offset}); +    }    }    return S;  } @@ -3436,8 +3503,8 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {      const SCEV *S = I->second;      if (checkValidity(S))        return S; +    eraseValueFromMap(V);      forgetMemoizedResults(S); -    ValueExprMap.erase(I);    }    return nullptr;  } @@ -3675,8 +3742,8 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {        if (!isa<PHINode>(I) ||            !isa<SCEVUnknown>(Old) ||            (I != PN && Old == SymName)) { +        eraseValueFromMap(It->first);          forgetMemoizedResults(Old); -        ValueExprMap.erase(It);        }      } @@ -4055,7 +4122,7 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {      // to create an AddRecExpr for this PHI node. We can not keep this temporary      // as it will prevent later (possibly simpler) SCEV expressions to be added      // to the ValueExprMap. -    ValueExprMap.erase(PN); +    eraseValueFromMap(PN);    }    return nullptr; @@ -4168,7 +4235,9 @@ static bool BrPHIToSelect(DominatorTree &DT, BranchInst *BI, PHINode *Merge,  }  const SCEV *ScalarEvolution::createNodeFromSelectLikePHI(PHINode *PN) { -  if (PN->getNumIncomingValues() == 2) { +  auto IsReachable = +      [&](BasicBlock *BB) { return DT.isReachableFromEntry(BB); }; +  if (PN->getNumIncomingValues() == 2 && all_of(PN->blocks(), IsReachable)) {      const Loop *L = LI.getLoopFor(PN->getParent());      // We don't want to break LCSSA, even in a SCEV expression tree. @@ -4244,7 +4313,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,    case ICmpInst::ICMP_SLT:    case ICmpInst::ICMP_SLE:      std::swap(LHS, RHS); -  // fall through +    LLVM_FALLTHROUGH;    case ICmpInst::ICMP_SGT:    case ICmpInst::ICMP_SGE:      // a >s b ? a+x : b+x  ->  smax(a, b)+x @@ -4267,7 +4336,7 @@ const SCEV *ScalarEvolution::createNodeForSelectOrPHI(Instruction *I,    case ICmpInst::ICMP_ULT:    case ICmpInst::ICMP_ULE:      std::swap(LHS, RHS); -  // fall through +    LLVM_FALLTHROUGH;    case ICmpInst::ICMP_UGT:    case ICmpInst::ICMP_UGE:      // a >u b ? a+x : b+x  ->  umax(a, b)+x @@ -4332,9 +4401,7 @@ const SCEV *ScalarEvolution::createNodeForGEP(GEPOperator *GEP) {    SmallVector<const SCEV *, 4> IndexExprs;    for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index)      IndexExprs.push_back(getSCEV(*Index)); -  return getGEPExpr(GEP->getSourceElementType(), -                    getSCEV(GEP->getPointerOperand()), -                    IndexExprs, GEP->isInBounds()); +  return getGEPExpr(GEP, IndexExprs);  }  uint32_t @@ -4612,19 +4679,18 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,    MaxBECount = getNoopOrZeroExtend(MaxBECount, Start->getType());    ConstantRange MaxBECountRange = getUnsignedRange(MaxBECount); -  ConstantRange ZExtMaxBECountRange = -      MaxBECountRange.zextOrTrunc(BitWidth * 2 + 1); +  ConstantRange ZExtMaxBECountRange = MaxBECountRange.zextOrTrunc(BitWidth * 2);    ConstantRange StepSRange = getSignedRange(Step); -  ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2 + 1); +  ConstantRange SExtStepSRange = StepSRange.sextOrTrunc(BitWidth * 2);    ConstantRange StartURange = getUnsignedRange(Start);    ConstantRange EndURange =        StartURange.add(MaxBECountRange.multiply(StepSRange));    // Check for unsigned overflow. -  ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2 + 1); -  ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2 + 1); +  ConstantRange ZExtStartURange = StartURange.zextOrTrunc(BitWidth * 2); +  ConstantRange ZExtEndURange = EndURange.zextOrTrunc(BitWidth * 2);    if (ZExtStartURange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==        ZExtEndURange) {      APInt Min = APIntOps::umin(StartURange.getUnsignedMin(), @@ -4644,8 +4710,8 @@ ConstantRange ScalarEvolution::getRangeForAffineAR(const SCEV *Start,    // Check for signed overflow. This must be done with ConstantRange    // arithmetic because we could be called from within the ScalarEvolution    // overflow checking code. -  ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2 + 1); -  ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2 + 1); +  ConstantRange SExtStartSRange = StartSRange.sextOrTrunc(BitWidth * 2); +  ConstantRange SExtEndSRange = EndSRange.sextOrTrunc(BitWidth * 2);    if (SExtStartSRange.add(ZExtMaxBECountRange.multiply(SExtStepSRange)) ==        SExtEndSRange) {      APInt Min = @@ -4909,17 +4975,33 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {    return LatchControlDependentOnPoison && loopHasNoAbnormalExits(L);  } -bool ScalarEvolution::loopHasNoAbnormalExits(const Loop *L) { -  auto Itr = LoopHasNoAbnormalExits.find(L); -  if (Itr == LoopHasNoAbnormalExits.end()) { -    auto NoAbnormalExitInBB = [&](BasicBlock *BB) { -      return all_of(*BB, [](Instruction &I) { -        return isGuaranteedToTransferExecutionToSuccessor(&I); -      }); +ScalarEvolution::LoopProperties +ScalarEvolution::getLoopProperties(const Loop *L) { +  typedef ScalarEvolution::LoopProperties LoopProperties; + +  auto Itr = LoopPropertiesCache.find(L); +  if (Itr == LoopPropertiesCache.end()) { +    auto HasSideEffects = [](Instruction *I) { +      if (auto *SI = dyn_cast<StoreInst>(I)) +        return !SI->isSimple(); + +      return I->mayHaveSideEffects();      }; -    auto InsertPair = LoopHasNoAbnormalExits.insert( -        {L, all_of(L->getBlocks(), NoAbnormalExitInBB)}); +    LoopProperties LP = {/* HasNoAbnormalExits */ true, +                         /*HasNoSideEffects*/ true}; + +    for (auto *BB : L->getBlocks()) +      for (auto &I : *BB) { +        if (!isGuaranteedToTransferExecutionToSuccessor(&I)) +          LP.HasNoAbnormalExits = false; +        if (HasSideEffects(&I)) +          LP.HasNoSideEffects = false; +        if (!LP.HasNoAbnormalExits && !LP.HasNoSideEffects) +          break; // We're already as pessimistic as we can get. +      } + +    auto InsertPair = LoopPropertiesCache.insert({L, LP});      assert(InsertPair.second && "We just checked!");      Itr = InsertPair.first;    } @@ -5247,6 +5329,20 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {  //                   Iteration Count Computation Code  // +static unsigned getConstantTripCount(const SCEVConstant *ExitCount) { +  if (!ExitCount) +    return 0; + +  ConstantInt *ExitConst = ExitCount->getValue(); + +  // Guard against huge trip counts. +  if (ExitConst->getValue().getActiveBits() > 32) +    return 0; + +  // In case of integer overflow, this returns 0, which is correct. +  return ((unsigned)ExitConst->getZExtValue()) + 1; +} +  unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L) {    if (BasicBlock *ExitingBB = L->getExitingBlock())      return getSmallConstantTripCount(L, ExitingBB); @@ -5262,17 +5358,13 @@ unsigned ScalarEvolution::getSmallConstantTripCount(Loop *L,           "Exiting block must actually branch out of the loop!");    const SCEVConstant *ExitCount =        dyn_cast<SCEVConstant>(getExitCount(L, ExitingBlock)); -  if (!ExitCount) -    return 0; - -  ConstantInt *ExitConst = ExitCount->getValue(); - -  // Guard against huge trip counts. -  if (ExitConst->getValue().getActiveBits() > 32) -    return 0; +  return getConstantTripCount(ExitCount); +} -  // In case of integer overflow, this returns 0, which is correct. -  return ((unsigned)ExitConst->getZExtValue()) + 1; +unsigned ScalarEvolution::getSmallConstantMaxTripCount(Loop *L) { +  const auto *MaxExitCount = +      dyn_cast<SCEVConstant>(getMaxBackedgeTakenCount(L)); +  return getConstantTripCount(MaxExitCount);  }  unsigned ScalarEvolution::getSmallConstantTripMultiple(Loop *L) { @@ -5351,6 +5443,10 @@ const SCEV *ScalarEvolution::getMaxBackedgeTakenCount(const Loop *L) {    return getBackedgeTakenInfo(L).getMax(this);  } +bool ScalarEvolution::isBackedgeTakenCountMaxOrZero(const Loop *L) { +  return getBackedgeTakenInfo(L).isMaxOrZero(this); +} +  /// Push PHI nodes in the header of the given loop onto the given Worklist.  static void  PushLoopPHIs(const Loop *L, SmallVectorImpl<Instruction *> &Worklist) { @@ -5376,7 +5472,7 @@ ScalarEvolution::getPredicatedBackedgeTakenInfo(const Loop *L) {    BackedgeTakenInfo Result =        computeBackedgeTakenCount(L, /*AllowPredicates=*/true); -  return PredicatedBackedgeTakenCounts.find(L)->second = Result; +  return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);  }  const ScalarEvolution::BackedgeTakenInfo & @@ -5435,8 +5531,8 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {          // case, createNodeForPHI will perform the necessary updates on its          // own when it gets to that point.          if (!isa<PHINode>(I) || !isa<SCEVUnknown>(Old)) { +          eraseValueFromMap(It->first);            forgetMemoizedResults(Old); -          ValueExprMap.erase(It);          }          if (PHINode *PN = dyn_cast<PHINode>(I))            ConstantEvolutionLoopExitValue.erase(PN); @@ -5451,7 +5547,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {    // recusive call to getBackedgeTakenInfo (on a different    // loop), which would invalidate the iterator computed    // earlier. -  return BackedgeTakenCounts.find(L)->second = Result; +  return BackedgeTakenCounts.find(L)->second = std::move(Result);  }  void ScalarEvolution::forgetLoop(const Loop *L) { @@ -5481,8 +5577,8 @@ void ScalarEvolution::forgetLoop(const Loop *L) {      ValueExprMapType::iterator It =        ValueExprMap.find_as(static_cast<Value *>(I));      if (It != ValueExprMap.end()) { +      eraseValueFromMap(It->first);        forgetMemoizedResults(It->second); -      ValueExprMap.erase(It);        if (PHINode *PN = dyn_cast<PHINode>(I))          ConstantEvolutionLoopExitValue.erase(PN);      } @@ -5495,7 +5591,7 @@ void ScalarEvolution::forgetLoop(const Loop *L) {    for (Loop *I : *L)      forgetLoop(I); -  LoopHasNoAbnormalExits.erase(L); +  LoopPropertiesCache.erase(L);  }  void ScalarEvolution::forgetValue(Value *V) { @@ -5515,8 +5611,8 @@ void ScalarEvolution::forgetValue(Value *V) {      ValueExprMapType::iterator It =        ValueExprMap.find_as(static_cast<Value *>(I));      if (It != ValueExprMap.end()) { +      eraseValueFromMap(It->first);        forgetMemoizedResults(It->second); -      ValueExprMap.erase(It);        if (PHINode *PN = dyn_cast<PHINode>(I))          ConstantEvolutionLoopExitValue.erase(PN);      } @@ -5534,14 +5630,11 @@ void ScalarEvolution::forgetValue(Value *V) {  /// caller's responsibility to specify the relevant loop exit using  /// getExact(ExitingBlock, SE).  const SCEV * -ScalarEvolution::BackedgeTakenInfo::getExact( -    ScalarEvolution *SE, SCEVUnionPredicate *Preds) const { +ScalarEvolution::BackedgeTakenInfo::getExact(ScalarEvolution *SE, +                                             SCEVUnionPredicate *Preds) const {    // If any exits were not computable, the loop is not computable. -  if (!ExitNotTaken.isCompleteList()) return SE->getCouldNotCompute(); - -  // We need exactly one computable exit. -  if (!ExitNotTaken.ExitingBlock) return SE->getCouldNotCompute(); -  assert(ExitNotTaken.ExactNotTaken && "uninitialized not-taken info"); +  if (!isComplete() || ExitNotTaken.empty()) +    return SE->getCouldNotCompute();    const SCEV *BECount = nullptr;    for (auto &ENT : ExitNotTaken) { @@ -5551,10 +5644,10 @@ ScalarEvolution::BackedgeTakenInfo::getExact(        BECount = ENT.ExactNotTaken;      else if (BECount != ENT.ExactNotTaken)        return SE->getCouldNotCompute(); -    if (Preds && ENT.getPred()) -      Preds->add(ENT.getPred()); +    if (Preds && !ENT.hasAlwaysTruePredicate()) +      Preds->add(ENT.Predicate.get()); -    assert((Preds || ENT.hasAlwaysTruePred()) && +    assert((Preds || ENT.hasAlwaysTruePredicate()) &&             "Predicate should be always true!");    } @@ -5567,7 +5660,7 @@ const SCEV *  ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,                                               ScalarEvolution *SE) const {    for (auto &ENT : ExitNotTaken) -    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePred()) +    if (ENT.ExitingBlock == ExitingBlock && ENT.hasAlwaysTruePredicate())        return ENT.ExactNotTaken;    return SE->getCouldNotCompute(); @@ -5576,21 +5669,29 @@ ScalarEvolution::BackedgeTakenInfo::getExact(BasicBlock *ExitingBlock,  /// getMax - Get the max backedge taken count for the loop.  const SCEV *  ScalarEvolution::BackedgeTakenInfo::getMax(ScalarEvolution *SE) const { -  for (auto &ENT : ExitNotTaken) -    if (!ENT.hasAlwaysTruePred()) -      return SE->getCouldNotCompute(); +  auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { +    return !ENT.hasAlwaysTruePredicate(); +  }; -  return Max ? Max : SE->getCouldNotCompute(); +  if (any_of(ExitNotTaken, PredicateNotAlwaysTrue) || !getMax()) +    return SE->getCouldNotCompute(); + +  return getMax(); +} + +bool ScalarEvolution::BackedgeTakenInfo::isMaxOrZero(ScalarEvolution *SE) const { +  auto PredicateNotAlwaysTrue = [](const ExitNotTakenInfo &ENT) { +    return !ENT.hasAlwaysTruePredicate(); +  }; +  return MaxOrZero && !any_of(ExitNotTaken, PredicateNotAlwaysTrue);  }  bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,                                                      ScalarEvolution *SE) const { -  if (Max && Max != SE->getCouldNotCompute() && SE->hasOperand(Max, S)) +  if (getMax() && getMax() != SE->getCouldNotCompute() && +      SE->hasOperand(getMax(), S))      return true; -  if (!ExitNotTaken.ExitingBlock) -    return false; -    for (auto &ENT : ExitNotTaken)      if (ENT.ExactNotTaken != SE->getCouldNotCompute() &&          SE->hasOperand(ENT.ExactNotTaken, S)) @@ -5602,62 +5703,31 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,  /// Allocate memory for BackedgeTakenInfo and copy the not-taken count of each  /// computable exit into a persistent ExitNotTakenInfo array.  ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( -    SmallVectorImpl<EdgeInfo> &ExitCounts, bool Complete, const SCEV *MaxCount) -    : Max(MaxCount) { - -  if (!Complete) -    ExitNotTaken.setIncomplete(); - -  unsigned NumExits = ExitCounts.size(); -  if (NumExits == 0) return; - -  ExitNotTaken.ExitingBlock = ExitCounts[0].ExitBlock; -  ExitNotTaken.ExactNotTaken = ExitCounts[0].Taken; - -  // Determine the number of ExitNotTakenExtras structures that we need. -  unsigned ExtraInfoSize = 0; -  if (NumExits > 1) -    ExtraInfoSize = 1 + std::count_if(std::next(ExitCounts.begin()), -                                      ExitCounts.end(), [](EdgeInfo &Entry) { -                                        return !Entry.Pred.isAlwaysTrue(); -                                      }); -  else if (!ExitCounts[0].Pred.isAlwaysTrue()) -    ExtraInfoSize = 1; - -  ExitNotTakenExtras *ENT = nullptr; - -  // Allocate the ExitNotTakenExtras structures and initialize the first -  // element (ExitNotTaken). -  if (ExtraInfoSize > 0) { -    ENT = new ExitNotTakenExtras[ExtraInfoSize]; -    ExitNotTaken.ExtraInfo = &ENT[0]; -    *ExitNotTaken.getPred() = std::move(ExitCounts[0].Pred); -  } - -  if (NumExits == 1) -    return; - -  assert(ENT && "ExitNotTakenExtras is NULL while having more than one exit"); - -  auto &Exits = ExitNotTaken.ExtraInfo->Exits; - -  // Handle the rare case of multiple computable exits. -  for (unsigned i = 1, PredPos = 1; i < NumExits; ++i) { -    ExitNotTakenExtras *Ptr = nullptr; -    if (!ExitCounts[i].Pred.isAlwaysTrue()) { -      Ptr = &ENT[PredPos++]; -      Ptr->Pred = std::move(ExitCounts[i].Pred); -    } - -    Exits.emplace_back(ExitCounts[i].ExitBlock, ExitCounts[i].Taken, Ptr); -  } +    SmallVectorImpl<ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo> +        &&ExitCounts, +    bool Complete, const SCEV *MaxCount, bool MaxOrZero) +    : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) { +  typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; +  ExitNotTaken.reserve(ExitCounts.size()); +  std::transform( +      ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken), +      [&](const EdgeExitInfo &EEI) { +        BasicBlock *ExitBB = EEI.first; +        const ExitLimit &EL = EEI.second; +        if (EL.Predicates.empty()) +          return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, nullptr); + +        std::unique_ptr<SCEVUnionPredicate> Predicate(new SCEVUnionPredicate); +        for (auto *Pred : EL.Predicates) +          Predicate->add(Pred); + +        return ExitNotTakenInfo(ExitBB, EL.ExactNotTaken, std::move(Predicate)); +      });  }  /// Invalidate this result and free the ExitNotTakenInfo array.  void ScalarEvolution::BackedgeTakenInfo::clear() { -  ExitNotTaken.ExitingBlock = nullptr; -  ExitNotTaken.ExactNotTaken = nullptr; -  delete[] ExitNotTaken.ExtraInfo; +  ExitNotTaken.clear();  }  /// Compute the number of times the backedge of the specified loop will execute. @@ -5667,11 +5737,14 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,    SmallVector<BasicBlock *, 8> ExitingBlocks;    L->getExitingBlocks(ExitingBlocks); -  SmallVector<EdgeInfo, 4> ExitCounts; +  typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; + +  SmallVector<EdgeExitInfo, 4> ExitCounts;    bool CouldComputeBECount = true;    BasicBlock *Latch = L->getLoopLatch(); // may be NULL.    const SCEV *MustExitMaxBECount = nullptr;    const SCEV *MayExitMaxBECount = nullptr; +  bool MustExitMaxOrZero = false;    // Compute the ExitLimit for each loop exit. Use this to populate ExitCounts    // and compute maxBECount. @@ -5680,17 +5753,17 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,      BasicBlock *ExitBB = ExitingBlocks[i];      ExitLimit EL = computeExitLimit(L, ExitBB, AllowPredicates); -    assert((AllowPredicates || EL.Pred.isAlwaysTrue()) && +    assert((AllowPredicates || EL.Predicates.empty()) &&             "Predicated exit limit when predicates are not allowed!");      // 1. For each exit that can be computed, add an entry to ExitCounts.      // CouldComputeBECount is true only if all exits can be computed. -    if (EL.Exact == getCouldNotCompute()) +    if (EL.ExactNotTaken == getCouldNotCompute())        // We couldn't compute an exact value for this exit, so        // we won't be able to compute an exact value for the loop.        CouldComputeBECount = false;      else -      ExitCounts.emplace_back(EdgeInfo(ExitBB, EL.Exact, EL.Pred)); +      ExitCounts.emplace_back(ExitBB, EL);      // 2. Derive the loop's MaxBECount from each exit's max number of      // non-exiting iterations. Partition the loop exits into two kinds: @@ -5698,29 +5771,35 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,      //      // If the exit dominates the loop latch, it is a LoopMustExit otherwise it      // is a LoopMayExit.  If any computable LoopMustExit is found, then -    // MaxBECount is the minimum EL.Max of computable LoopMustExits. Otherwise, -    // MaxBECount is conservatively the maximum EL.Max, where CouldNotCompute is -    // considered greater than any computable EL.Max. -    if (EL.Max != getCouldNotCompute() && Latch && +    // MaxBECount is the minimum EL.MaxNotTaken of computable +    // LoopMustExits. Otherwise, MaxBECount is conservatively the maximum +    // EL.MaxNotTaken, where CouldNotCompute is considered greater than any +    // computable EL.MaxNotTaken. +    if (EL.MaxNotTaken != getCouldNotCompute() && Latch &&          DT.dominates(ExitBB, Latch)) { -      if (!MustExitMaxBECount) -        MustExitMaxBECount = EL.Max; -      else { +      if (!MustExitMaxBECount) { +        MustExitMaxBECount = EL.MaxNotTaken; +        MustExitMaxOrZero = EL.MaxOrZero; +      } else {          MustExitMaxBECount = -          getUMinFromMismatchedTypes(MustExitMaxBECount, EL.Max); +            getUMinFromMismatchedTypes(MustExitMaxBECount, EL.MaxNotTaken);        }      } else if (MayExitMaxBECount != getCouldNotCompute()) { -      if (!MayExitMaxBECount || EL.Max == getCouldNotCompute()) -        MayExitMaxBECount = EL.Max; +      if (!MayExitMaxBECount || EL.MaxNotTaken == getCouldNotCompute()) +        MayExitMaxBECount = EL.MaxNotTaken;        else {          MayExitMaxBECount = -          getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.Max); +            getUMaxFromMismatchedTypes(MayExitMaxBECount, EL.MaxNotTaken);        }      }    }    const SCEV *MaxBECount = MustExitMaxBECount ? MustExitMaxBECount :      (MayExitMaxBECount ? MayExitMaxBECount : getCouldNotCompute()); -  return BackedgeTakenInfo(ExitCounts, CouldComputeBECount, MaxBECount); +  // The loop backedge will be taken the maximum or zero times if there's +  // a single exit that must be taken the maximum or zero times. +  bool MaxOrZero = (MustExitMaxOrZero && ExitingBlocks.size() == 1); +  return BackedgeTakenInfo(std::move(ExitCounts), CouldComputeBECount, +                           MaxBECount, MaxOrZero);  }  ScalarEvolution::ExitLimit @@ -5825,39 +5904,40 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,        if (EitherMayExit) {          // Both conditions must be true for the loop to continue executing.          // Choose the less conservative count. -        if (EL0.Exact == getCouldNotCompute() || -            EL1.Exact == getCouldNotCompute()) +        if (EL0.ExactNotTaken == getCouldNotCompute() || +            EL1.ExactNotTaken == getCouldNotCompute())            BECount = getCouldNotCompute();          else -          BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); -        if (EL0.Max == getCouldNotCompute()) -          MaxBECount = EL1.Max; -        else if (EL1.Max == getCouldNotCompute()) -          MaxBECount = EL0.Max; +          BECount = +              getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); +        if (EL0.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL1.MaxNotTaken; +        else if (EL1.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL0.MaxNotTaken;          else -          MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); +          MaxBECount = +              getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);        } else {          // Both conditions must be true at the same time for the loop to exit.          // For now, be conservative.          assert(L->contains(FBB) && "Loop block has no successor in loop!"); -        if (EL0.Max == EL1.Max) -          MaxBECount = EL0.Max; -        if (EL0.Exact == EL1.Exact) -          BECount = EL0.Exact; +        if (EL0.MaxNotTaken == EL1.MaxNotTaken) +          MaxBECount = EL0.MaxNotTaken; +        if (EL0.ExactNotTaken == EL1.ExactNotTaken) +          BECount = EL0.ExactNotTaken;        } -      SCEVUnionPredicate NP; -      NP.add(&EL0.Pred); -      NP.add(&EL1.Pred);        // There are cases (e.g. PR26207) where computeExitLimitFromCond is able        // to be more aggressive when computing BECount than when computing -      // MaxBECount.  In these cases it is possible for EL0.Exact and EL1.Exact -      // to match, but for EL0.Max and EL1.Max to not. +      // MaxBECount.  In these cases it is possible for EL0.ExactNotTaken and +      // EL1.ExactNotTaken to match, but for EL0.MaxNotTaken and EL1.MaxNotTaken +      // to not.        if (isa<SCEVCouldNotCompute>(MaxBECount) &&            !isa<SCEVCouldNotCompute>(BECount))          MaxBECount = BECount; -      return ExitLimit(BECount, MaxBECount, NP); +      return ExitLimit(BECount, MaxBECount, false, +                       {&EL0.Predicates, &EL1.Predicates});      }      if (BO->getOpcode() == Instruction::Or) {        // Recurse on the operands of the or. @@ -5873,31 +5953,31 @@ ScalarEvolution::computeExitLimitFromCond(const Loop *L,        if (EitherMayExit) {          // Both conditions must be false for the loop to continue executing.          // Choose the less conservative count. -        if (EL0.Exact == getCouldNotCompute() || -            EL1.Exact == getCouldNotCompute()) +        if (EL0.ExactNotTaken == getCouldNotCompute() || +            EL1.ExactNotTaken == getCouldNotCompute())            BECount = getCouldNotCompute();          else -          BECount = getUMinFromMismatchedTypes(EL0.Exact, EL1.Exact); -        if (EL0.Max == getCouldNotCompute()) -          MaxBECount = EL1.Max; -        else if (EL1.Max == getCouldNotCompute()) -          MaxBECount = EL0.Max; +          BECount = +              getUMinFromMismatchedTypes(EL0.ExactNotTaken, EL1.ExactNotTaken); +        if (EL0.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL1.MaxNotTaken; +        else if (EL1.MaxNotTaken == getCouldNotCompute()) +          MaxBECount = EL0.MaxNotTaken;          else -          MaxBECount = getUMinFromMismatchedTypes(EL0.Max, EL1.Max); +          MaxBECount = +              getUMinFromMismatchedTypes(EL0.MaxNotTaken, EL1.MaxNotTaken);        } else {          // Both conditions must be false at the same time for the loop to exit.          // For now, be conservative.          assert(L->contains(TBB) && "Loop block has no successor in loop!"); -        if (EL0.Max == EL1.Max) -          MaxBECount = EL0.Max; -        if (EL0.Exact == EL1.Exact) -          BECount = EL0.Exact; +        if (EL0.MaxNotTaken == EL1.MaxNotTaken) +          MaxBECount = EL0.MaxNotTaken; +        if (EL0.ExactNotTaken == EL1.ExactNotTaken) +          BECount = EL0.ExactNotTaken;        } -      SCEVUnionPredicate NP; -      NP.add(&EL0.Pred); -      NP.add(&EL1.Pred); -      return ExitLimit(BECount, MaxBECount, NP); +      return ExitLimit(BECount, MaxBECount, false, +                       {&EL0.Predicates, &EL1.Predicates});      }    } @@ -5979,8 +6059,8 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,      if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(LHS))        if (AddRec->getLoop() == L) {          // Form the constant range. -        ConstantRange CompRange( -            ICmpInst::makeConstantRange(Cond, RHSC->getAPInt())); +        ConstantRange CompRange = +            ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt());          const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);          if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; @@ -6184,7 +6264,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(    //   %iv = phi i32 [ %iv.shifted, %loop ], [ %val, %preheader ]    //   %iv.shifted = lshr i32 %iv, <positive constant>    // -  // Return true on a succesful match.  Return the corresponding PHI node (%iv +  // Return true on a successful match.  Return the corresponding PHI node (%iv    // above) in PNOut and the opcode of the shift operation in OpCodeOut.    auto MatchShiftRecurrence =        [&](Value *V, PHINode *&PNOut, Instruction::BinaryOps &OpCodeOut) { @@ -6282,8 +6362,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::computeShiftCompareExitLimit(      unsigned BitWidth = getTypeSizeInBits(RHS->getType());      const SCEV *UpperBound =          getConstant(getEffectiveSCEVType(RHS->getType()), BitWidth); -    SCEVUnionPredicate P; -    return ExitLimit(getCouldNotCompute(), UpperBound, P); +    return ExitLimit(getCouldNotCompute(), UpperBound, false);    }    return getCouldNotCompute(); @@ -7044,7 +7123,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,    // effectively V != 0.  We know and take advantage of the fact that this    // expression only being used in a comparison by zero context. -  SCEVUnionPredicate P; +  SmallPtrSet<const SCEVPredicate *, 4> Predicates;    // If the value is a constant    if (const SCEVConstant *C = dyn_cast<SCEVConstant>(V)) {      // If the value is already zero, the branch will execute zero times. @@ -7057,7 +7136,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,      // Try to make this an AddRec using runtime tests, in the first X      // iterations of this loop, where X is the SCEV expression found by the      // algorithm below. -    AddRec = convertSCEVToAddRecWithPredicates(V, L, P); +    AddRec = convertSCEVToAddRecWithPredicates(V, L, Predicates);    if (!AddRec || AddRec->getLoop() != L)      return getCouldNotCompute(); @@ -7079,7 +7158,8 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,          // should not accept a root of 2.          const SCEV *Val = AddRec->evaluateAtIteration(R1, *this);          if (Val->isZero()) -          return ExitLimit(R1, R1, P); // We found a quadratic root! +          // We found a quadratic root! +          return ExitLimit(R1, R1, false, Predicates);        }      }      return getCouldNotCompute(); @@ -7136,7 +7216,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,      else        MaxBECount = getConstant(CountDown ? CR.getUnsignedMax()                                           : -CR.getUnsignedMin()); -    return ExitLimit(Distance, MaxBECount, P); +    return ExitLimit(Distance, MaxBECount, false, Predicates);    }    // As a special case, handle the instance where Step is a positive power of @@ -7191,7 +7271,7 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,        const SCEV *Limit =            getZeroExtendExpr(getTruncateExpr(ModuloResult, NarrowTy), WideTy); -      return ExitLimit(Limit, Limit, P); +      return ExitLimit(Limit, Limit, false, Predicates);      }    } @@ -7204,14 +7284,14 @@ ScalarEvolution::howFarToZero(const SCEV *V, const Loop *L, bool ControlsExit,        loopHasNoAbnormalExits(AddRec->getLoop())) {      const SCEV *Exact =          getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step); -    return ExitLimit(Exact, Exact, P); +    return ExitLimit(Exact, Exact, false, Predicates);    }    // Then, try to solve the above equation provided that Start is constant.    if (const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start)) {      const SCEV *E = SolveLinEquationWithOverflow(          StepC->getValue()->getValue(), -StartC->getValue()->getValue(), *this); -    return ExitLimit(E, E, P); +    return ExitLimit(E, E, false, Predicates);    }    return getCouldNotCompute();  } @@ -7323,149 +7403,77 @@ bool ScalarEvolution::SimplifyICmpOperands(ICmpInst::Predicate &Pred,    // cases, and canonicalize *-or-equal comparisons to regular comparisons.    if (const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS)) {      const APInt &RA = RC->getAPInt(); -    switch (Pred) { -    default: llvm_unreachable("Unexpected ICmpInst::Predicate value!"); -    case ICmpInst::ICMP_EQ: -    case ICmpInst::ICMP_NE: -      // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. -      if (!RA) -        if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) -          if (const SCEVMulExpr *ME = dyn_cast<SCEVMulExpr>(AE->getOperand(0))) -            if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && -                ME->getOperand(0)->isAllOnesValue()) { -              RHS = AE->getOperand(1); -              LHS = ME->getOperand(1); -              Changed = true; -            } -      break; -    case ICmpInst::ICMP_UGE: -      if ((RA - 1).isMinValue()) { -        Pred = ICmpInst::ICMP_NE; -        RHS = getConstant(RA - 1); -        Changed = true; -        break; -      } -      if (RA.isMaxValue()) { -        Pred = ICmpInst::ICMP_EQ; -        Changed = true; -        break; -      } -      if (RA.isMinValue()) goto trivially_true; -      Pred = ICmpInst::ICMP_UGT; -      RHS = getConstant(RA - 1); -      Changed = true; -      break; -    case ICmpInst::ICMP_ULE: -      if ((RA + 1).isMaxValue()) { -        Pred = ICmpInst::ICMP_NE; -        RHS = getConstant(RA + 1); -        Changed = true; -        break; -      } -      if (RA.isMinValue()) { -        Pred = ICmpInst::ICMP_EQ; -        Changed = true; -        break; -      } -      if (RA.isMaxValue()) goto trivially_true; +    bool SimplifiedByConstantRange = false; -      Pred = ICmpInst::ICMP_ULT; -      RHS = getConstant(RA + 1); -      Changed = true; -      break; -    case ICmpInst::ICMP_SGE: -      if ((RA - 1).isMinSignedValue()) { -        Pred = ICmpInst::ICMP_NE; -        RHS = getConstant(RA - 1); -        Changed = true; -        break; -      } -      if (RA.isMaxSignedValue()) { -        Pred = ICmpInst::ICMP_EQ; -        Changed = true; -        break; +    if (!ICmpInst::isEquality(Pred)) { +      ConstantRange ExactCR = ConstantRange::makeExactICmpRegion(Pred, RA); +      if (ExactCR.isFullSet()) +        goto trivially_true; +      else if (ExactCR.isEmptySet()) +        goto trivially_false; + +      APInt NewRHS; +      CmpInst::Predicate NewPred; +      if (ExactCR.getEquivalentICmp(NewPred, NewRHS) && +          ICmpInst::isEquality(NewPred)) { +        // We were able to convert an inequality to an equality. +        Pred = NewPred; +        RHS = getConstant(NewRHS); +        Changed = SimplifiedByConstantRange = true;        } -      if (RA.isMinSignedValue()) goto trivially_true; +    } -      Pred = ICmpInst::ICMP_SGT; -      RHS = getConstant(RA - 1); -      Changed = true; -      break; -    case ICmpInst::ICMP_SLE: -      if ((RA + 1).isMaxSignedValue()) { -        Pred = ICmpInst::ICMP_NE; -        RHS = getConstant(RA + 1); -        Changed = true; +    if (!SimplifiedByConstantRange) { +      switch (Pred) { +      default:          break; -      } -      if (RA.isMinSignedValue()) { -        Pred = ICmpInst::ICMP_EQ; -        Changed = true; +      case ICmpInst::ICMP_EQ: +      case ICmpInst::ICMP_NE: +        // Fold ((-1) * %a) + %b == 0 (equivalent to %b-%a == 0) into %a == %b. +        if (!RA) +          if (const SCEVAddExpr *AE = dyn_cast<SCEVAddExpr>(LHS)) +            if (const SCEVMulExpr *ME = +                    dyn_cast<SCEVMulExpr>(AE->getOperand(0))) +              if (AE->getNumOperands() == 2 && ME->getNumOperands() == 2 && +                  ME->getOperand(0)->isAllOnesValue()) { +                RHS = AE->getOperand(1); +                LHS = ME->getOperand(1); +                Changed = true; +              }          break; -      } -      if (RA.isMaxSignedValue()) goto trivially_true; -      Pred = ICmpInst::ICMP_SLT; -      RHS = getConstant(RA + 1); -      Changed = true; -      break; -    case ICmpInst::ICMP_UGT: -      if (RA.isMinValue()) { -        Pred = ICmpInst::ICMP_NE; + +        // The "Should have been caught earlier!" messages refer to the fact +        // that the ExactCR.isFullSet() or ExactCR.isEmptySet() check above +        // should have fired on the corresponding cases, and canonicalized the +        // check to trivially_true or trivially_false. + +      case ICmpInst::ICMP_UGE: +        assert(!RA.isMinValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_UGT; +        RHS = getConstant(RA - 1);          Changed = true;          break; -      } -      if ((RA + 1).isMaxValue()) { -        Pred = ICmpInst::ICMP_EQ; +      case ICmpInst::ICMP_ULE: +        assert(!RA.isMaxValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_ULT;          RHS = getConstant(RA + 1);          Changed = true;          break; -      } -      if (RA.isMaxValue()) goto trivially_false; -      break; -    case ICmpInst::ICMP_ULT: -      if (RA.isMaxValue()) { -        Pred = ICmpInst::ICMP_NE; -        Changed = true; -        break; -      } -      if ((RA - 1).isMinValue()) { -        Pred = ICmpInst::ICMP_EQ; +      case ICmpInst::ICMP_SGE: +        assert(!RA.isMinSignedValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_SGT;          RHS = getConstant(RA - 1);          Changed = true;          break; -      } -      if (RA.isMinValue()) goto trivially_false; -      break; -    case ICmpInst::ICMP_SGT: -      if (RA.isMinSignedValue()) { -        Pred = ICmpInst::ICMP_NE; -        Changed = true; -        break; -      } -      if ((RA + 1).isMaxSignedValue()) { -        Pred = ICmpInst::ICMP_EQ; +      case ICmpInst::ICMP_SLE: +        assert(!RA.isMaxSignedValue() && "Should have been caught earlier!"); +        Pred = ICmpInst::ICMP_SLT;          RHS = getConstant(RA + 1);          Changed = true;          break;        } -      if (RA.isMaxSignedValue()) goto trivially_false; -      break; -    case ICmpInst::ICMP_SLT: -      if (RA.isMaxSignedValue()) { -        Pred = ICmpInst::ICMP_NE; -        Changed = true; -        break; -      } -      if ((RA - 1).isMinSignedValue()) { -       Pred = ICmpInst::ICMP_EQ; -       RHS = getConstant(RA - 1); -        Changed = true; -       break; -      } -      if (RA.isMinSignedValue()) goto trivially_false; -      break;      }    } @@ -8025,34 +8033,16 @@ ScalarEvolution::isLoopEntryGuardedByCond(const Loop *L,    return false;  } -namespace { -/// RAII wrapper to prevent recursive application of isImpliedCond. -/// ScalarEvolution's PendingLoopPredicates set must be empty unless we are -/// currently evaluating isImpliedCond. -struct MarkPendingLoopPredicate { -  Value *Cond; -  DenseSet<Value*> &LoopPreds; -  bool Pending; - -  MarkPendingLoopPredicate(Value *C, DenseSet<Value*> &LP) -    : Cond(C), LoopPreds(LP) { -    Pending = !LoopPreds.insert(Cond).second; -  } -  ~MarkPendingLoopPredicate() { -    if (!Pending) -      LoopPreds.erase(Cond); -  } -}; -} // end anonymous namespace -  bool ScalarEvolution::isImpliedCond(ICmpInst::Predicate Pred,                                      const SCEV *LHS, const SCEV *RHS,                                      Value *FoundCondValue,                                      bool Inverse) { -  MarkPendingLoopPredicate Mark(FoundCondValue, PendingLoopPredicates); -  if (Mark.Pending) +  if (!PendingLoopPredicates.insert(FoundCondValue).second)      return false; +  auto ClearOnExit = +      make_scope_exit([&]() { PendingLoopPredicates.erase(FoundCondValue); }); +    // Recursively handle And and Or conditions.    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FoundCondValue)) {      if (BO->getOpcode() == Instruction::And) { @@ -8237,9 +8227,8 @@ bool ScalarEvolution::splitBinaryAdd(const SCEV *Expr,    return true;  } -bool ScalarEvolution::computeConstantDifference(const SCEV *Less, -                                                const SCEV *More, -                                                APInt &C) { +Optional<APInt> ScalarEvolution::computeConstantDifference(const SCEV *More, +                                                           const SCEV *Less) {    // We avoid subtracting expressions here because this function is usually    // fairly deep in the call stack (i.e. is called many times). @@ -8248,15 +8237,15 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less,      const auto *MAR = cast<SCEVAddRecExpr>(More);      if (LAR->getLoop() != MAR->getLoop()) -      return false; +      return None;      // We look at affine expressions only; not for correctness but to keep      // getStepRecurrence cheap.      if (!LAR->isAffine() || !MAR->isAffine()) -      return false; +      return None;      if (LAR->getStepRecurrence(*this) != MAR->getStepRecurrence(*this)) -      return false; +      return None;      Less = LAR->getStart();      More = MAR->getStart(); @@ -8267,27 +8256,22 @@ bool ScalarEvolution::computeConstantDifference(const SCEV *Less,    if (isa<SCEVConstant>(Less) && isa<SCEVConstant>(More)) {      const auto &M = cast<SCEVConstant>(More)->getAPInt();      const auto &L = cast<SCEVConstant>(Less)->getAPInt(); -    C = M - L; -    return true; +    return M - L;    }    const SCEV *L, *R;    SCEV::NoWrapFlags Flags;    if (splitBinaryAdd(Less, L, R, Flags))      if (const auto *LC = dyn_cast<SCEVConstant>(L)) -      if (R == More) { -        C = -(LC->getAPInt()); -        return true; -      } +      if (R == More) +        return -(LC->getAPInt());    if (splitBinaryAdd(More, L, R, Flags))      if (const auto *LC = dyn_cast<SCEVConstant>(L)) -      if (R == Less) { -        C = LC->getAPInt(); -        return true; -      } +      if (R == Less) +        return LC->getAPInt(); -  return false; +  return None;  }  bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow( @@ -8344,22 +8328,21 @@ bool ScalarEvolution::isImpliedCondOperandsViaNoOverflow(    // neither necessary nor sufficient to prove "(FoundLHS + C) s< (FoundRHS +    // C)". -  APInt LDiff, RDiff; -  if (!computeConstantDifference(FoundLHS, LHS, LDiff) || -      !computeConstantDifference(FoundRHS, RHS, RDiff) || -      LDiff != RDiff) +  Optional<APInt> LDiff = computeConstantDifference(LHS, FoundLHS); +  Optional<APInt> RDiff = computeConstantDifference(RHS, FoundRHS); +  if (!LDiff || !RDiff || *LDiff != *RDiff)      return false; -  if (LDiff == 0) +  if (LDiff->isMinValue())      return true;    APInt FoundRHSLimit;    if (Pred == CmpInst::ICMP_ULT) { -    FoundRHSLimit = -RDiff; +    FoundRHSLimit = -(*RDiff);    } else {      assert(Pred == CmpInst::ICMP_SLT && "Checked above!"); -    FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - RDiff; +    FoundRHSLimit = APInt::getSignedMinValue(getTypeSizeInBits(RHS->getType())) - *RDiff;    }    // Try to prove (1) or (2), as needed. @@ -8469,7 +8452,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,    case ICmpInst::ICMP_SGE:      std::swap(LHS, RHS); -    // fall through +    LLVM_FALLTHROUGH;    case ICmpInst::ICMP_SLE:      return        // min(A, ...) <= A @@ -8479,7 +8462,7 @@ static bool IsKnownPredicateViaMinOrMax(ScalarEvolution &SE,    case ICmpInst::ICMP_UGE:      std::swap(LHS, RHS); -    // fall through +    LLVM_FALLTHROUGH;    case ICmpInst::ICMP_ULE:      return        // min(A, ...) <= A @@ -8550,9 +8533,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,      // reduce the compile time impact of this optimization.      return false; -  const SCEVAddExpr *AddLHS = dyn_cast<SCEVAddExpr>(LHS); -  if (!AddLHS || AddLHS->getOperand(1) != FoundLHS || -      !isa<SCEVConstant>(AddLHS->getOperand(0))) +  Optional<APInt> Addend = computeConstantDifference(LHS, FoundLHS); +  if (!Addend)      return false;    APInt ConstFoundRHS = cast<SCEVConstant>(FoundRHS)->getAPInt(); @@ -8562,10 +8544,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,    ConstantRange FoundLHSRange =        ConstantRange::makeAllowedICmpRegion(Pred, ConstFoundRHS); -  // Since `LHS` is `FoundLHS` + `AddLHS->getOperand(0)`, we can compute a range -  // for `LHS`: -  APInt Addend = cast<SCEVConstant>(AddLHS->getOperand(0))->getAPInt(); -  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(Addend)); +  // Since `LHS` is `FoundLHS` + `Addend`, we can compute a range for `LHS`: +  ConstantRange LHSRange = FoundLHSRange.add(ConstantRange(*Addend));    // We can also compute the range of values for `LHS` that satisfy the    // consequent, "`LHS` `Pred` `RHS`": @@ -8580,6 +8560,8 @@ bool ScalarEvolution::isImpliedCondOperandsViaRanges(ICmpInst::Predicate Pred,  bool ScalarEvolution::doesIVOverflowOnLT(const SCEV *RHS, const SCEV *Stride,                                           bool IsSigned, bool NoWrap) { +  assert(isKnownPositive(Stride) && "Positive stride expected!"); +    if (NoWrap) return false;    unsigned BitWidth = getTypeSizeInBits(RHS->getType()); @@ -8642,17 +8624,21 @@ ScalarEvolution::ExitLimit  ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,                                    const Loop *L, bool IsSigned,                                    bool ControlsExit, bool AllowPredicates) { -  SCEVUnionPredicate P; +  SmallPtrSet<const SCEVPredicate *, 4> Predicates;    // We handle only IV < Invariant    if (!isLoopInvariant(RHS, L))      return getCouldNotCompute();    const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS); -  if (!IV && AllowPredicates) +  bool PredicatedIV = false; + +  if (!IV && AllowPredicates) {      // Try to make this an AddRec using runtime tests, in the first X      // iterations of this loop, where X is the SCEV expression found by the      // algorithm below. -    IV = convertSCEVToAddRecWithPredicates(LHS, L, P); +    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates); +    PredicatedIV = true; +  }    // Avoid weird loops    if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8663,61 +8649,144 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,    const SCEV *Stride = IV->getStepRecurrence(*this); -  // Avoid negative or zero stride values -  if (!isKnownPositive(Stride)) -    return getCouldNotCompute(); +  bool PositiveStride = isKnownPositive(Stride); -  // Avoid proven overflow cases: this will ensure that the backedge taken count -  // will not generate any unsigned overflow. Relaxed no-overflow conditions -  // exploit NoWrapFlags, allowing to optimize in presence of undefined -  // behaviors like the case of C language. -  if (!Stride->isOne() && doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) +  // Avoid negative or zero stride values. +  if (!PositiveStride) { +    // We can compute the correct backedge taken count for loops with unknown +    // strides if we can prove that the loop is not an infinite loop with side +    // effects. Here's the loop structure we are trying to handle - +    // +    // i = start +    // do { +    //   A[i] = i; +    //   i += s; +    // } while (i < end); +    // +    // The backedge taken count for such loops is evaluated as - +    // (max(end, start + stride) - start - 1) /u stride +    // +    // The additional preconditions that we need to check to prove correctness +    // of the above formula is as follows - +    // +    // a) IV is either nuw or nsw depending upon signedness (indicated by the +    //    NoWrap flag). +    // b) loop is single exit with no side effects. +    // +    // +    // Precondition a) implies that if the stride is negative, this is a single +    // trip loop. The backedge taken count formula reduces to zero in this case. +    // +    // Precondition b) implies that the unknown stride cannot be zero otherwise +    // we have UB. +    // +    // The positive stride case is the same as isKnownPositive(Stride) returning +    // true (original behavior of the function). +    // +    // We want to make sure that the stride is truly unknown as there are edge +    // cases where ScalarEvolution propagates no wrap flags to the +    // post-increment/decrement IV even though the increment/decrement operation +    // itself is wrapping. The computed backedge taken count may be wrong in +    // such cases. This is prevented by checking that the stride is not known to +    // be either positive or non-positive. For example, no wrap flags are +    // propagated to the post-increment IV of this loop with a trip count of 2 - +    // +    // unsigned char i; +    // for(i=127; i<128; i+=129) +    //   A[i] = i; +    // +    if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) || +        !loopHasNoSideEffects(L)) +      return getCouldNotCompute(); + +  } else if (!Stride->isOne() && +             doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap)) +    // Avoid proven overflow cases: this will ensure that the backedge taken +    // count will not generate any unsigned overflow. Relaxed no-overflow +    // conditions exploit NoWrapFlags, allowing to optimize in presence of +    // undefined behaviors like the case of C language.      return getCouldNotCompute();    ICmpInst::Predicate Cond = IsSigned ? ICmpInst::ICMP_SLT                                        : ICmpInst::ICMP_ULT;    const SCEV *Start = IV->getStart();    const SCEV *End = RHS; -  if (!isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) +  // If the backedge is taken at least once, then it will be taken +  // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start +  // is the LHS value of the less-than comparison the first time it is evaluated +  // and End is the RHS. +  const SCEV *BECountIfBackedgeTaken = +    computeBECount(getMinusSCEV(End, Start), Stride, false); +  // If the loop entry is guarded by the result of the backedge test of the +  // first loop iteration, then we know the backedge will be taken at least +  // once and so the backedge taken count is as above. If not then we use the +  // expression (max(End,Start)-Start)/Stride to describe the backedge count, +  // as if the backedge is taken at least once max(End,Start) is End and so the +  // result is as above, and if not max(End,Start) is Start so we get a backedge +  // count of zero. +  const SCEV *BECount; +  if (isLoopEntryGuardedByCond(L, Cond, getMinusSCEV(Start, Stride), RHS)) +    BECount = BECountIfBackedgeTaken; +  else {      End = IsSigned ? getSMaxExpr(RHS, Start) : getUMaxExpr(RHS, Start); +    BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); +  } -  const SCEV *BECount = computeBECount(getMinusSCEV(End, Start), Stride, false); +  const SCEV *MaxBECount; +  bool MaxOrZero = false; +  if (isa<SCEVConstant>(BECount)) +    MaxBECount = BECount; +  else if (isa<SCEVConstant>(BECountIfBackedgeTaken)) { +    // If we know exactly how many times the backedge will be taken if it's +    // taken at least once, then the backedge count will either be that or +    // zero. +    MaxBECount = BECountIfBackedgeTaken; +    MaxOrZero = true; +  } else { +    // Calculate the maximum backedge count based on the range of values +    // permitted by Start, End, and Stride. +    APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() +                              : getUnsignedRange(Start).getUnsignedMin(); -  APInt MinStart = IsSigned ? getSignedRange(Start).getSignedMin() -                            : getUnsignedRange(Start).getUnsignedMin(); +    unsigned BitWidth = getTypeSizeInBits(LHS->getType()); -  APInt MinStride = IsSigned ? getSignedRange(Stride).getSignedMin() -                             : getUnsignedRange(Stride).getUnsignedMin(); +    APInt StrideForMaxBECount; -  unsigned BitWidth = getTypeSizeInBits(LHS->getType()); -  APInt Limit = IsSigned ? APInt::getSignedMaxValue(BitWidth) - (MinStride - 1) -                         : APInt::getMaxValue(BitWidth) - (MinStride - 1); +    if (PositiveStride) +      StrideForMaxBECount = +        IsSigned ? getSignedRange(Stride).getSignedMin() +                 : getUnsignedRange(Stride).getUnsignedMin(); +    else +      // Using a stride of 1 is safe when computing max backedge taken count for +      // a loop with unknown stride. +      StrideForMaxBECount = APInt(BitWidth, 1, IsSigned); -  // Although End can be a MAX expression we estimate MaxEnd considering only -  // the case End = RHS. This is safe because in the other case (End - Start) -  // is zero, leading to a zero maximum backedge taken count. -  APInt MaxEnd = -    IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) -             : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); +    APInt Limit = +      IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1) +               : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1); + +    // Although End can be a MAX expression we estimate MaxEnd considering only +    // the case End = RHS. This is safe because in the other case (End - Start) +    // is zero, leading to a zero maximum backedge taken count. +    APInt MaxEnd = +      IsSigned ? APIntOps::smin(getSignedRange(RHS).getSignedMax(), Limit) +               : APIntOps::umin(getUnsignedRange(RHS).getUnsignedMax(), Limit); -  const SCEV *MaxBECount; -  if (isa<SCEVConstant>(BECount)) -    MaxBECount = BECount; -  else      MaxBECount = computeBECount(getConstant(MaxEnd - MinStart), -                                getConstant(MinStride), false); +                                getConstant(StrideForMaxBECount), false); +  }    if (isa<SCEVCouldNotCompute>(MaxBECount))      MaxBECount = BECount; -  return ExitLimit(BECount, MaxBECount, P); +  return ExitLimit(BECount, MaxBECount, MaxOrZero, Predicates);  }  ScalarEvolution::ExitLimit  ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,                                       const Loop *L, bool IsSigned,                                       bool ControlsExit, bool AllowPredicates) { -  SCEVUnionPredicate P; +  SmallPtrSet<const SCEVPredicate *, 4> Predicates;    // We handle only IV > Invariant    if (!isLoopInvariant(RHS, L))      return getCouldNotCompute(); @@ -8727,7 +8796,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,      // Try to make this an AddRec using runtime tests, in the first X      // iterations of this loop, where X is the SCEV expression found by the      // algorithm below. -    IV = convertSCEVToAddRecWithPredicates(LHS, L, P); +    IV = convertSCEVToAddRecWithPredicates(LHS, L, Predicates);    // Avoid weird loops    if (!IV || IV->getLoop() != L || !IV->isAffine()) @@ -8787,7 +8856,7 @@ ScalarEvolution::howManyGreaterThans(const SCEV *LHS, const SCEV *RHS,    if (isa<SCEVCouldNotCompute>(MaxBECount))      MaxBECount = BECount; -  return ExitLimit(BECount, MaxBECount, P); +  return ExitLimit(BECount, MaxBECount, false, Predicates);  }  const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range, @@ -8859,9 +8928,7 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,      // Range.getUpper() is crossed.      SmallVector<const SCEV *, 4> NewOps(op_begin(), op_end());      NewOps[0] = SE.getNegativeSCEV(SE.getConstant(Range.getUpper())); -    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), -                                             // getNoWrapFlags(FlagNW) -                                             FlagAnyWrap); +    const SCEV *NewAddRec = SE.getAddRecExpr(NewOps, getLoop(), FlagAnyWrap);      // Next, solve the constructed addrec      if (auto Roots = @@ -8905,38 +8972,15 @@ const SCEV *SCEVAddRecExpr::getNumIterationsInRange(const ConstantRange &Range,    return SE.getCouldNotCompute();  } -namespace { -struct FindUndefs { -  bool Found; -  FindUndefs() : Found(false) {} - -  bool follow(const SCEV *S) { -    if (const SCEVUnknown *C = dyn_cast<SCEVUnknown>(S)) { -      if (isa<UndefValue>(C->getValue())) -        Found = true; -    } else if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { -      if (isa<UndefValue>(C->getValue())) -        Found = true; -    } - -    // Keep looking if we haven't found it yet. -    return !Found; -  } -  bool isDone() const { -    // Stop recursion if we have found an undef. -    return Found; -  } -}; -} -  // Return true when S contains at least an undef value. -static inline bool -containsUndefs(const SCEV *S) { -  FindUndefs F; -  SCEVTraversal<FindUndefs> ST(F); -  ST.visitAll(S); - -  return F.Found; +static inline bool containsUndefs(const SCEV *S) { +  return SCEVExprContains(S, [](const SCEV *S) { +    if (const auto *SU = dyn_cast<SCEVUnknown>(S)) +      return isa<UndefValue>(SU->getValue()); +    else if (const auto *SC = dyn_cast<SCEVConstant>(S)) +      return isa<UndefValue>(SC->getValue()); +    return false; +  });  }  namespace { @@ -8964,7 +9008,8 @@ struct SCEVCollectTerms {        : Terms(T) {}    bool follow(const SCEV *S) { -    if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S)) { +    if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) || +        isa<SCEVSignExtendExpr>(S)) {        if (!containsUndefs(S))          Terms.push_back(S); @@ -9116,10 +9161,9 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,    }    // Remove all SCEVConstants. -  Terms.erase(std::remove_if(Terms.begin(), Terms.end(), [](const SCEV *E) { -                return isa<SCEVConstant>(E); -              }), -              Terms.end()); +  Terms.erase( +      remove_if(Terms, [](const SCEV *E) { return isa<SCEVConstant>(E); }), +      Terms.end());    if (Terms.size() > 0)      if (!findArrayDimensionsRec(SE, Terms, Sizes)) @@ -9129,40 +9173,11 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,    return true;  } -// Returns true when S contains at least a SCEVUnknown parameter. -static inline bool -containsParameters(const SCEV *S) { -  struct FindParameter { -    bool FoundParameter; -    FindParameter() : FoundParameter(false) {} - -    bool follow(const SCEV *S) { -      if (isa<SCEVUnknown>(S)) { -        FoundParameter = true; -        // Stop recursion: we found a parameter. -        return false; -      } -      // Keep looking. -      return true; -    } -    bool isDone() const { -      // Stop recursion if we have found a parameter. -      return FoundParameter; -    } -  }; - -  FindParameter F; -  SCEVTraversal<FindParameter> ST(F); -  ST.visitAll(S); - -  return F.FoundParameter; -}  // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter. -static inline bool -containsParameters(SmallVectorImpl<const SCEV *> &Terms) { +static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {    for (const SCEV *T : Terms) -    if (containsParameters(T)) +    if (SCEVExprContains(T, isa<SCEVUnknown, const SCEV *>))        return true;    return false;  } @@ -9493,6 +9508,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)      : F(Arg.F), HasGuards(Arg.HasGuards), TLI(Arg.TLI), AC(Arg.AC), DT(Arg.DT),        LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),        ValueExprMap(std::move(Arg.ValueExprMap)), +      PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),        WalkingBEDominatingConds(false), ProvingSplitPredicate(false),        BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),        PredicatedBackedgeTakenCounts( @@ -9501,6 +9517,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)            std::move(Arg.ConstantEvolutionLoopExitValue)),        ValuesAtScopes(std::move(Arg.ValuesAtScopes)),        LoopDispositions(std::move(Arg.LoopDispositions)), +      LoopPropertiesCache(std::move(Arg.LoopPropertiesCache)),        BlockDispositions(std::move(Arg.BlockDispositions)),        UnsignedRanges(std::move(Arg.UnsignedRanges)),        SignedRanges(std::move(Arg.SignedRanges)), @@ -9569,6 +9586,8 @@ static void PrintLoopInfo(raw_ostream &OS, ScalarEvolution *SE,    if (!isa<SCEVCouldNotCompute>(SE->getMaxBackedgeTakenCount(L))) {      OS << "max backedge-taken count is " << *SE->getMaxBackedgeTakenCount(L); +    if (SE->isBackedgeTakenCountMaxOrZero(L)) +      OS << ", actual taken count either this or zero.";    } else {      OS << "Unpredictable max backedge-taken count. ";    } @@ -9829,8 +9848,10 @@ ScalarEvolution::computeBlockDisposition(const SCEV *S, const BasicBlock *BB) {      const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(S);      if (!DT.dominates(AR->getLoop()->getHeader(), BB))        return DoesNotDominateBlock; + +    // Fall through into SCEVNAryExpr handling. +    LLVM_FALLTHROUGH;    } -  // FALL THROUGH into SCEVNAryExpr handling.    case scAddExpr:    case scMulExpr:    case scUMaxExpr: @@ -9883,24 +9904,7 @@ bool ScalarEvolution::properlyDominates(const SCEV *S, const BasicBlock *BB) {  }  bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const { -  // Search for a SCEV expression node within an expression tree. -  // Implements SCEVTraversal::Visitor. -  struct SCEVSearch { -    const SCEV *Node; -    bool IsFound; - -    SCEVSearch(const SCEV *N): Node(N), IsFound(false) {} - -    bool follow(const SCEV *S) { -      IsFound |= (S == Node); -      return !IsFound; -    } -    bool isDone() const { return IsFound; } -  }; - -  SCEVSearch Search(Op); -  visitAll(S, Search); -  return Search.IsFound; +  return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });  }  void ScalarEvolution::forgetMemoizedResults(const SCEV *S) { @@ -10008,10 +10012,10 @@ void ScalarEvolution::verify() const {    // TODO: Verify more things.  } -char ScalarEvolutionAnalysis::PassID; +AnalysisKey ScalarEvolutionAnalysis::Key;  ScalarEvolution ScalarEvolutionAnalysis::run(Function &F, -                                             AnalysisManager<Function> &AM) { +                                             FunctionAnalysisManager &AM) {    return ScalarEvolution(F, AM.getResult<TargetLibraryAnalysis>(F),                           AM.getResult<AssumptionAnalysis>(F),                           AM.getResult<DominatorTreeAnalysis>(F), @@ -10019,7 +10023,7 @@ ScalarEvolution ScalarEvolutionAnalysis::run(Function &F,  }  PreservedAnalyses -ScalarEvolutionPrinterPass::run(Function &F, AnalysisManager<Function> &AM) { +ScalarEvolutionPrinterPass::run(Function &F, FunctionAnalysisManager &AM) {    AM.getResult<ScalarEvolutionAnalysis>(F).print(OS);    return PreservedAnalyses::all();  } @@ -10106,25 +10110,34 @@ namespace {  class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {  public: -  // Rewrites \p S in the context of a loop L and the predicate A. -  // If Assume is true, rewrite is free to add further predicates to A -  // such that the result will be an AddRecExpr. +  /// Rewrites \p S in the context of a loop L and the SCEV predication +  /// infrastructure. +  /// +  /// If \p Pred is non-null, the SCEV expression is rewritten to respect the +  /// equivalences present in \p Pred. +  /// +  /// If \p NewPreds is non-null, rewrite is free to add further predicates to +  /// \p NewPreds such that the result will be an AddRecExpr.    static const SCEV *rewrite(const SCEV *S, const Loop *L, ScalarEvolution &SE, -                             SCEVUnionPredicate &A, bool Assume) { -    SCEVPredicateRewriter Rewriter(L, SE, A, Assume); +                             SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, +                             SCEVUnionPredicate *Pred) { +    SCEVPredicateRewriter Rewriter(L, SE, NewPreds, Pred);      return Rewriter.visit(S);    }    SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE, -                        SCEVUnionPredicate &P, bool Assume) -      : SCEVRewriteVisitor(SE), P(P), L(L), Assume(Assume) {} +                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds, +                        SCEVUnionPredicate *Pred) +      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}    const SCEV *visitUnknown(const SCEVUnknown *Expr) { -    auto ExprPreds = P.getPredicatesForExpr(Expr); -    for (auto *Pred : ExprPreds) -      if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) -        if (IPred->getLHS() == Expr) -          return IPred->getRHS(); +    if (Pred) { +      auto ExprPreds = Pred->getPredicatesForExpr(Expr); +      for (auto *Pred : ExprPreds) +        if (const auto *IPred = dyn_cast<SCEVEqualPredicate>(Pred)) +          if (IPred->getLHS() == Expr) +            return IPred->getRHS(); +    }      return Expr;    } @@ -10165,32 +10178,31 @@ private:    bool addOverflowAssumption(const SCEVAddRecExpr *AR,                               SCEVWrapPredicate::IncrementWrapFlags AddedFlags) {      auto *A = SE.getWrapPredicate(AR, AddedFlags); -    if (!Assume) { +    if (!NewPreds) {        // Check if we've already made this assumption. -      if (P.implies(A)) -        return true; -      return false; +      return Pred && Pred->implies(A);      } -    P.add(A); +    NewPreds->insert(A);      return true;    } -  SCEVUnionPredicate &P; +  SmallPtrSetImpl<const SCEVPredicate *> *NewPreds; +  SCEVUnionPredicate *Pred;    const Loop *L; -  bool Assume;  };  } // end anonymous namespace  const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,                                                     SCEVUnionPredicate &Preds) { -  return SCEVPredicateRewriter::rewrite(S, L, *this, Preds, false); +  return SCEVPredicateRewriter::rewrite(S, L, *this, nullptr, &Preds);  } -const SCEVAddRecExpr * -ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L, -                                                   SCEVUnionPredicate &Preds) { -  SCEVUnionPredicate TransformPreds; -  S = SCEVPredicateRewriter::rewrite(S, L, *this, TransformPreds, true); +const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates( +    const SCEV *S, const Loop *L, +    SmallPtrSetImpl<const SCEVPredicate *> &Preds) { + +  SmallPtrSet<const SCEVPredicate *, 4> TransformPreds; +  S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);    auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);    if (!AddRec) @@ -10198,7 +10210,9 @@ ScalarEvolution::convertSCEVToAddRecWithPredicates(const SCEV *S, const Loop *L,    // Since the transformation was successful, we can now transfer the SCEV    // predicates. -  Preds.add(&TransformPreds); +  for (auto *P : TransformPreds) +    Preds.insert(P); +    return AddRec;  } @@ -10351,7 +10365,7 @@ const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {      return Entry.second;    // We found an entry but it's stale. Rewrite the stale entry -  // acording to the current predicate. +  // according to the current predicate.    if (Entry.second)      Expr = Entry.second; @@ -10425,11 +10439,15 @@ bool PredicatedScalarEvolution::hasNoOverflow(  const SCEVAddRecExpr *PredicatedScalarEvolution::getAsAddRec(Value *V) {    const SCEV *Expr = this->getSCEV(V); -  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, Preds); +  SmallPtrSet<const SCEVPredicate *, 4> NewPreds; +  auto *New = SE.convertSCEVToAddRecWithPredicates(Expr, &L, NewPreds);    if (!New)      return nullptr; +  for (auto *P : NewPreds) +    Preds.add(P); +    updateGeneration();    RewriteMap[SE.getSCEV(V)] = {Generation, New};    return New;  | 
