Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')

 lib/Analysis/ScalarEvolution.cpp | 933
 1 file changed, 645 insertions(+), 288 deletions(-)
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 9539fd7c7559..0b8604187121 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -59,12 +59,23 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/EquivalenceClasses.h"
+#include "llvm/ADT/FoldingSet.h"
+#include "llvm/ADT/None.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Analysis/AssumptionCache.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
@@ -72,28 +83,55 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
 #include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Dominators.h"
-#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/Function.h"
 #include "llvm/IR/GlobalAlias.h"
+#include "llvm/IR/GlobalValue.h"
 #include "llvm/IR/GlobalVariable.h"
 #include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
 #include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Metadata.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/KnownBits.h"
-#include "llvm/Support/MathExtras.h"
 #include "llvm/Support/SaveAndRestore.h"
 #include "llvm/Support/raw_ostream.h"
 #include <algorithm>
+#include <cassert>
+#include <climits>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <map>
+#include <memory>
+#include <tuple>
+#include <utility>
+#include <vector>
+
 using namespace llvm;
 
 #define DEBUG_TYPE "scalar-evolution"
 
@@ -115,11 +153,11 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden,
                         cl::init(100));
 
 // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean.
+static cl::opt<bool> VerifySCEV(
+    "verify-scev", cl::Hidden,
+    cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
 static cl::opt<bool>
-VerifySCEV("verify-scev",
-           cl::desc("Verify ScalarEvolution's backedge taken counts (slow)"));
-static cl::opt<bool>
-    VerifySCEVMap("verify-scev-maps",
+    VerifySCEVMap("verify-scev-maps", cl::Hidden,
                   cl::desc("Verify no dangling value in ScalarEvolution's "
                            "ExprValueMap (slow)"));
 
@@ -415,9 +453,6 @@ void SCEVUnknown::deleted() {
 }
 
 void SCEVUnknown::allUsesReplacedWith(Value *New) {
-  // Clear this SCEVUnknown from various maps.
-  SE->forgetMemoizedResults(this);
-
   // Remove this SCEVUnknown from the uniquing map.
   SE->UniqueSCEVs.RemoveNode(this);
 
@@ -514,10 +549,10 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const {
 /// Since we do not continue running this routine on expression trees once we
 /// have seen unequal values, there is no need to track them in the cache.
 static int
-CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
+CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue,
                        const LoopInfo *const LI, Value *LV, Value *RV,
                        unsigned Depth) {
-  if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV}))
+  if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV))
     return 0;
 
   // Order pointer values after integer values. This helps SCEVExpander form
@@ -577,14 +612,14 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
 
     for (unsigned Idx : seq(0u, LNumOps)) {
       int Result =
-          CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx),
+          CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx),
                                  RInst->getOperand(Idx), Depth + 1);
       if (Result != 0)
         return Result;
     }
   }
 
-  EqCache.insert({LV, RV});
+  EqCacheValue.unionSets(LV, RV);
   return 0;
 }
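The EquivalenceClasses caches introduced above replace SmallSet-of-pairs caches, so equalities learned during comparison become transitive: knowing A == B and B == C answers A == C without re-walking the trees. A minimal, self-contained sketch of the llvm::EquivalenceClasses API as this patch uses it (toy integer keys, not real SCEV/Value nodes):

    #include "llvm/ADT/EquivalenceClasses.h"
    using namespace llvm;

    static bool equivalenceDemo() {
      EquivalenceClasses<const int *> EC;
      int A = 0, B = 0, C = 0;
      EC.unionSets(&A, &B);            // learned: A == B
      EC.unionSets(&B, &C);            // learned: B == C
      // isEquivalent is transitive, unlike membership in a set of pairs:
      return EC.isEquivalent(&A, &C);  // true
    }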
@@ -592,7 +627,8 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache,
 // than RHS, respectively. A three-way result allows recursive comparisons to be
 // more efficient.
 static int CompareSCEVComplexity(
-    SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV,
+    EquivalenceClasses<const SCEV *> &EqCacheSCEV,
+    EquivalenceClasses<const Value *> &EqCacheValue,
     const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS,
     DominatorTree &DT, unsigned Depth = 0) {
   // Fast-path: SCEVs are uniqued so we can do a quick equality check.
@@ -604,7 +640,7 @@ static int CompareSCEVComplexity(
   if (LType != RType)
     return (int)LType - (int)RType;
 
-  if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS}))
+  if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS))
     return 0;
 
   // Aside from the getSCEVType() ordering, the particular ordering
   // isn't very important except that it's beneficial to be consistent,
@@ -614,11 +650,10 @@ static int CompareSCEVComplexity(
     const SCEVUnknown *LU = cast<SCEVUnknown>(LHS);
     const SCEVUnknown *RU = cast<SCEVUnknown>(RHS);
 
-    SmallSet<std::pair<Value *, Value *>, 8> EqCache;
-    int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(),
-                                   Depth + 1);
+    int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(),
+                                   RU->getValue(), Depth + 1);
     if (X == 0)
-      EqCacheSCEV.insert({LHS, RHS});
+      EqCacheSCEV.unionSets(LHS, RHS);
     return X;
   }
 
@@ -659,14 +694,19 @@ static int CompareSCEVComplexity(
     if (LNumOps != RNumOps)
       return (int)LNumOps - (int)RNumOps;
 
+    // Compare NoWrap flags.
+    if (LA->getNoWrapFlags() != RA->getNoWrapFlags())
+      return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags();
+
     // Lexicographically compare.
     for (unsigned i = 0; i != LNumOps; ++i) {
-      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i),
-                                    RA->getOperand(i), DT,  Depth + 1);
+      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                    LA->getOperand(i), RA->getOperand(i), DT,
+                                    Depth + 1);
       if (X != 0)
         return X;
     }
-    EqCacheSCEV.insert({LHS, RHS});
+    EqCacheSCEV.unionSets(LHS, RHS);
     return 0;
   }
 
@@ -682,15 +722,18 @@ static int CompareSCEVComplexity(
     if (LNumOps != RNumOps)
       return (int)LNumOps - (int)RNumOps;
 
+    // Compare NoWrap flags.
+    if (LC->getNoWrapFlags() != RC->getNoWrapFlags())
+      return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags();
+
     for (unsigned i = 0; i != LNumOps; ++i) {
-      if (i >= RNumOps)
-        return 1;
-      int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i),
-                                    RC->getOperand(i), DT, Depth + 1);
+      int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                    LC->getOperand(i), RC->getOperand(i), DT,
+                                    Depth + 1);
      if (X != 0)
         return X;
     }
-    EqCacheSCEV.insert({LHS, RHS});
+    EqCacheSCEV.unionSets(LHS, RHS);
     return 0;
   }
 
@@ -699,14 +742,14 @@ static int CompareSCEVComplexity(
     const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS);
 
     // Lexicographically compare udiv expressions.
-    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(),
-                                  DT, Depth + 1);
+    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(),
+                                  RC->getLHS(), DT, Depth + 1);
     if (X != 0)
       return X;
-    X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT,
-                              Depth + 1);
+    X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(),
+                              RC->getRHS(), DT, Depth + 1);
     if (X == 0)
-      EqCacheSCEV.insert({LHS, RHS});
+      EqCacheSCEV.unionSets(LHS, RHS);
     return X;
   }
 
@@ -717,10 +760,11 @@ static int CompareSCEVComplexity(
     const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS);
 
     // Compare cast expressions by operand.
-    int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(),
-                                  RC->getOperand(), DT, Depth + 1);
+    int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
+                                  LC->getOperand(), RC->getOperand(), DT,
+                                  Depth + 1);
     if (X == 0)
-      EqCacheSCEV.insert({LHS, RHS});
+      EqCacheSCEV.unionSets(LHS, RHS);
     return X;
   }
 
@@ -739,26 +783,26 @@ static int CompareSCEVComplexity(
 /// results from this routine.  In other words, we don't want the results of
 /// this to depend on where the addresses of various SCEV objects happened to
 /// land in memory.
-///
 static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
                               LoopInfo *LI, DominatorTree &DT) {
   if (Ops.size() < 2) return;  // Noop
 
-  SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache;
+  EquivalenceClasses<const SCEV *> EqCacheSCEV;
+  EquivalenceClasses<const Value *> EqCacheValue;
 
   if (Ops.size() == 2) {
     // This is the common case, which also happens to be trivially simple.
     // Special case it.
     const SCEV *&LHS = Ops[0], *&RHS = Ops[1];
-    if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0)
+    if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0)
       std::swap(LHS, RHS);
     return;
   }
 
   // Do the rough sort by complexity.
   std::stable_sort(Ops.begin(), Ops.end(),
-                   [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) {
-                     return
-                         CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0;
+                   [&](const SCEV *LHS, const SCEV *RHS) {
+                     return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI,
                                                  LHS, RHS, DT) < 0;
                    });
 
   // Now that we are sorted by complexity, group elements of the same
@@ -785,14 +829,16 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops,
 // Returns the size of the SCEV S.
 static inline int sizeOfSCEV(const SCEV *S) {
   struct FindSCEVSize {
-    int Size;
-    FindSCEVSize() : Size(0) {}
+    int Size = 0;
+
+    FindSCEVSize() = default;
 
     bool follow(const SCEV *S) {
       ++Size;
      // Keep looking at all operands of S.
      return true;
     }
+
     bool isDone() const {
       return false;
     }
@@ -1032,7 +1078,7 @@ private:
   const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One;
 };
 
-}
+} // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
 //                      Simple SCEV method implementations
@@ -1157,7 +1203,6 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K,
 ///   A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3)
 ///
 /// where BC(It, k) stands for binomial coefficient.
-///
 const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It,
                                                 ScalarEvolution &SE) const {
   const SCEV *Result = getStart();
@@ -1256,6 +1301,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op,
   SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator),
                                                  Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -1343,7 +1389,8 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase {
 const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits<
     SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr;
-}
+
+} // end anonymous namespace
 
 // The recurrence AR has been shown to have no signed/unsigned wrap or something
 // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as
@@ -1473,7 +1520,6 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty,
 //
 // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T
 // is `Delta` (defined below).
-//
 template <typename ExtendOpTy>
 bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
                                                 const SCEV *Step,
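This is the first of many `addToLoopUseLists(S)` calls added after every UniqueSCEVs.InsertNode in this patch. The helper's definition is not part of this excerpt; below is a hedged sketch of what it plausibly does, inferred from the LoopUsers map consulted by the new forgetLoop later in the diff (member and struct names here are assumptions):

    // Hypothetical sketch: record S with every loop it mentions so that
    // forgetLoop(L) can invalidate exactly LoopUsers[L] later on.
    void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
      struct FindUsedLoops {
        SmallPtrSet<const Loop *, 8> LoopsUsed;
        bool follow(const SCEV *X) {
          if (auto *AR = dyn_cast<SCEVAddRecExpr>(X))
            LoopsUsed.insert(AR->getLoop());
          return true;                          // visit all operands
        }
        bool isDone() const { return false; }   // never stop early
      } F;
      SCEVTraversal<FindUsedLoops>(F).visitAll(S);
      for (const Loop *L : F.LoopsUsed)
        LoopUsers[L].push_back(S);
    }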
@@ -1484,7 +1530,6 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start,
   // time here.  It is correct (but more expensive) to continue with a
   // non-constant `Start` and do a general SCEV subtraction to compute
   // `PreStart` below.
-  //
   const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start);
   if (!StartC)
     return false;
@@ -1547,6 +1592,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
     return S;
   }
 
@@ -1733,6 +1779,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
   SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -1770,6 +1817,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
     SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                      Op, Ty);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
     return S;
   }
 
@@ -1981,12 +2029,12 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) {
   SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator),
                                                    Op, Ty);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
 /// getAnyExtendExpr - Return a SCEV for the given operand extended with
 /// unspecified bits out to the given type.
-///
 const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
                                               Type *Ty) {
   assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) &&
@@ -2057,7 +2105,6 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op,
 /// may be exposed. This helps getAddRecExpr short-circuit extra work in
 /// the common case where no interesting opportunities are present, and
 /// is also used as a check to avoid infinite recursion.
-///
 static bool
 CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M,
                              SmallVectorImpl<const SCEV *> &NewOps,
@@ -2132,7 +2179,8 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type,
                       const SmallVectorImpl<const SCEV *> &Ops,
                       SCEV::NoWrapFlags Flags) {
   using namespace std::placeholders;
-  typedef OverflowingBinaryOperator OBO;
+
+  using OBO = OverflowingBinaryOperator;
 
   bool CanAnalyze =
       Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr;
@@ -2306,12 +2354,23 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
   // Check for truncates. If all the operands are truncated from the same
   // type, see if factoring out the truncate would permit the result to be
-  // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
+  // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y)
   // if the contents of the resulting outer trunc fold to something simple.
-  for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
-    const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
-    Type *DstType = Trunc->getType();
-    Type *SrcType = Trunc->getOperand()->getType();
+  auto FindTruncSrcType = [&]() -> Type * {
+    // We're ultimately looking to fold an addrec of truncs and muls of only
+    // constants and truncs, so if we find any other types of SCEV
+    // as operands of the addrec then we bail and return nullptr here.
+    // Otherwise, we return the type of the operand of a trunc that we find.
+    if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+      return T->getOperand()->getType();
+    if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+      const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
+      if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
+        return T->getOperand()->getType();
+    }
+    return nullptr;
+  };
+  if (auto *SrcType = FindTruncSrcType()) {
     SmallVector<const SCEV *, 8> LargeOps;
     bool Ok = true;
     // Check all the operands to see if they can be represented in the
@@ -2354,7 +2413,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops,
       const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
       // If it folds to something simple, use it. Otherwise, don't.
       if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
-        return getTruncateExpr(Fold, DstType);
+        return getTruncateExpr(Fold, Ty);
     }
   }
 
@@ -2608,8 +2667,8 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
                                     SCEV::NoWrapFlags Flags) {
   FoldingSetNodeID ID;
   ID.AddInteger(scAddExpr);
-  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
-    ID.AddPointer(Ops[i]);
+  for (const SCEV *Op : Ops)
+    ID.AddPointer(Op);
   void *IP = nullptr;
   SCEVAddExpr *S =
       static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
@@ -2619,6 +2678,7 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops,
     S = new (SCEVAllocator)
         SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -2640,6 +2700,7 @@ ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops,
     S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator),
                                         O, Ops.size());
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -2679,20 +2740,24 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) {
 /// Determine if any of the operands in this SCEV are a constant or if
 /// any of the add or multiply expressions in this SCEV contain a constant.
-static bool containsConstantSomewhere(const SCEV *StartExpr) {
-  SmallVector<const SCEV *, 4> Ops;
-  Ops.push_back(StartExpr);
-  while (!Ops.empty()) {
-    const SCEV *CurrentExpr = Ops.pop_back_val();
-    if (isa<SCEVConstant>(*CurrentExpr))
-      return true;
+static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+  struct FindConstantInAddMulChain {
+    bool FoundConstant = false;
 
-    if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
-      const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
-      Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
+    bool follow(const SCEV *S) {
+      FoundConstant |= isa<SCEVConstant>(S);
+      return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
     }
-  }
-  return false;
+
+    bool isDone() const {
+      return FoundConstant;
+    }
+  };
+
+  FindConstantInAddMulChain F;
+  SCEVTraversal<FindConstantInAddMulChain> ST(F);
+  ST.visitAll(StartExpr);
+  return F.FoundConstant;
 }
 
 /// Get a canonical multiply expression, or something simpler if possible.
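Both FindSCEVSize (earlier) and FindConstantInAddMulChain (above) are visitors for SCEVTraversal, whose contract is: follow(S) returns whether to descend into S's operands, and isDone() lets the walk stop as soon as the answer is known. A toy visitor spelling that contract out (illustrative only; requires ScalarEvolutionExpressions.h):

    // Count SCEVConstant nodes anywhere in an expression tree.
    struct CountConstants {
      unsigned Count = 0;
      bool follow(const SCEV *S) {
        if (isa<SCEVConstant>(S))
          ++Count;
        return true;                          // descend into every operand
      }
      bool isDone() const { return false; }   // walk the whole tree
    };

    static unsigned countConstants(const SCEV *Root) {
      CountConstants C;
      SCEVTraversal<CountConstants> T(C);
      T.visitAll(Root);
      return C.Count;
    }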
@@ -2729,7 +2794,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
           // If any of Add's ops are Adds or Muls with a constant,
           // apply this transformation as well.
           if (Add->getNumOperands() == 2)
-            if (containsConstantSomewhere(Add))
+            // TODO: There are some cases where this transformation is not
+            // profitable, for example:
+            // Add = (C0 + X) * Y + Z.
+            // Maybe the scope of this transformation should be narrowed down.
+            if (containsConstantInAddMulChain(Add))
               return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
                                            SCEV::FlagAnyWrap, Depth + 1),
                                 getMulExpr(LHSC, Add->getOperand(1),
@@ -2941,6 +3010,34 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
   return getOrCreateMulExpr(Ops, Flags);
 }
 
+/// Represents an unsigned remainder expression based on unsigned division.
+const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS,
+                                         const SCEV *RHS) {
+  assert(getEffectiveSCEVType(LHS->getType()) ==
+         getEffectiveSCEVType(RHS->getType()) &&
+         "SCEVURemExpr operand types don't match!");
+
+  // Short-circuit easy cases
+  if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) {
+    // If constant is one, the result is trivial
+    if (RHSC->getValue()->isOne())
+      return getZero(LHS->getType()); // X urem 1 --> 0
+
+    // If constant is a power of two, fold into a zext(trunc(LHS)).
+    if (RHSC->getAPInt().isPowerOf2()) {
+      Type *FullTy = LHS->getType();
+      Type *TruncTy =
+          IntegerType::get(getContext(), RHSC->getAPInt().logBase2());
+      return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy);
+    }
+  }
+
+  // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y)
+  const SCEV *UDiv = getUDivExpr(LHS, RHS);
+  const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW);
+  return getMinusSCEV(LHS, Mult, SCEV::FlagNUW);
+}
+
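getURemExpr introduces no dedicated urem SCEV node; it canonicalizes on two arithmetic identities. A plain-C++ sanity check of both (not LLVM code, just the underlying math):

    #include <cassert>
    #include <cstdint>

    void uremIdentities(uint64_t X, uint64_t Y) {
      assert(Y != 0);
      // Generic rewrite:  x urem y == x - (x udiv y) * y
      assert(X % Y == X - (X / Y) * Y);
      // Power-of-two rewrite: x urem 2^k keeps the low k bits,
      // i.e. zext(trunc(x to k bits)) == x & (2^k - 1).
      uint64_t Pow2 = 8;  // any power of two
      assert(X % Pow2 == (X & (Pow2 - 1)));
    }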
 /// Get a canonical unsigned division expression, or something simpler if
 /// possible.
 const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
@@ -3056,6 +3153,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS,
   SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator),
                                              LHS, RHS);
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -3236,6 +3334,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands,
     S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator),
                                            O, Operands.size(), L);
     UniqueSCEVs.InsertNode(S, IP);
+    addToLoopUseLists(S);
   }
   S->setNoWrapFlags(Flags);
   return S;
@@ -3391,6 +3490,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator),
                                              O, Ops.size());
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -3492,6 +3592,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) {
   SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator),
                                              O, Ops.size());
   UniqueSCEVs.InsertNode(S, IP);
+  addToLoopUseLists(S);
   return S;
 }
 
@@ -3714,7 +3815,6 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) {
 }
 
 /// Return a SCEV corresponding to -V = -1*V
-///
 const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V,
                                              SCEV::NoWrapFlags Flags) {
   if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V))
@@ -3957,6 +4057,7 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) {
 }
 
 namespace {
+
 class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *S, const Loop *L,
@@ -3966,9 +4067,6 @@ public:
     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
   }
 
-  SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
-      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
-
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     if (!SE.isLoopInvariant(Expr, L))
       Valid = false;
@@ -3986,10 +4084,93 @@ public:
   bool isValid() { return Valid; }
 
 private:
+  explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L) {}
+
+  const Loop *L;
+  bool Valid = true;
+};
+
+/// This class evaluates the compare condition by matching it against the
+/// condition of loop latch. If there is a match we assume a true value
+/// for the condition while building SCEV nodes.
+class SCEVBackedgeConditionFolder
+    : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> {
+public:
+  static const SCEV *rewrite(const SCEV *S, const Loop *L,
+                             ScalarEvolution &SE) {
+    bool IsPosBECond = false;
+    Value *BECond = nullptr;
+    if (BasicBlock *Latch = L->getLoopLatch()) {
+      BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator());
+      if (BI && BI->isConditional()) {
+        assert(BI->getSuccessor(0) != BI->getSuccessor(1) &&
+               "Both outgoing branches should not target same header!");
+        BECond = BI->getCondition();
+        IsPosBECond = BI->getSuccessor(0) == L->getHeader();
+      } else {
+        return S;
+      }
+    }
+    SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE);
+    return Rewriter.visit(S);
+  }
+
+  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+    const SCEV *Result = Expr;
+    bool InvariantF = SE.isLoopInvariant(Expr, L);
+
+    if (!InvariantF) {
+      Instruction *I = cast<Instruction>(Expr->getValue());
+      switch (I->getOpcode()) {
+      case Instruction::Select: {
+        SelectInst *SI = cast<SelectInst>(I);
+        Optional<const SCEV *> Res =
+            compareWithBackedgeCondition(SI->getCondition());
+        if (Res.hasValue()) {
+          bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne();
+          Result = SE.getSCEV(IsOne ? SI->getTrueValue() : SI->getFalseValue());
+        }
+        break;
+      }
+      default: {
+        Optional<const SCEV *> Res = compareWithBackedgeCondition(I);
+        if (Res.hasValue())
+          Result = Res.getValue();
+        break;
+      }
+      }
+    }
+    return Result;
+  }
+
+private:
+  explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond,
+                                       bool IsPosBECond, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond),
+        IsPositiveBECond(IsPosBECond) {}
+
+  Optional<const SCEV *> compareWithBackedgeCondition(Value *IC);
+
   const Loop *L;
-  bool Valid;
+  /// Loop back condition.
+  Value *BackedgeCond = nullptr;
+  /// Set to true if loop back is on positive branch condition.
+  bool IsPositiveBECond;
 };
 
+Optional<const SCEV *>
+SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) {
+
+  // If value matches the backedge condition for loop latch,
+  // then return a constant evolution node based on loopback
+  // branch taken.
+  if (BackedgeCond == IC)
+    return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext()))
+                            : SE.getZero(Type::getInt1Ty(SE.getContext()));
+  return None;
+}
+
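Roughly, the folder lets SCEV assume the latch condition's truth value while computing the backedge value of a PHI. A hypothetical C-level shape that benefits, assuming earlier passes make the select condition the very same IR value the latch branches on (whether that happens depends on prior simplification):

    int sum(const int *v, int n, int a, int b) {
      int acc = 0;
      for (int i = 0; i != n; ++i) {
        int notLast = (i + 1 != n);       // same compare the backedge tests
        acc += v[i] * (notLast ? a : b);  // select folds to 'a' while the
      }                                   // backedge is being taken
      return acc;
    }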
 class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> {
 public:
   static const SCEV *rewrite(const SCEV *S, const Loop *L,
@@ -3999,9 +4180,6 @@ public:
     return Rewriter.isValid() ? Result : SE.getCouldNotCompute();
   }
 
-  SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
-      : SCEVRewriteVisitor(SE), L(L), Valid(true) {}
-
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     // Only allow AddRecExprs for this loop.
     if (!SE.isLoopInvariant(Expr, L))
@@ -4015,12 +4193,17 @@ public:
     Valid = false;
     return Expr;
   }
+
   bool isValid() { return Valid; }
 
 private:
+  explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE)
+      : SCEVRewriteVisitor(SE), L(L) {}
+
   const Loop *L;
-  bool Valid;
+  bool Valid = true;
 };
+
 } // end anonymous namespace
 
 SCEV::NoWrapFlags
@@ -4028,7 +4211,8 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
   if (!AR->isAffine())
     return SCEV::FlagAnyWrap;
 
-  typedef OverflowingBinaryOperator OBO;
+  using OBO = OverflowingBinaryOperator;
+
   SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap;
 
   if (!AR->hasNoSignedWrap()) {
@@ -4055,6 +4239,7 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) {
 }
 
 namespace {
+
 /// Represents an abstract binary operation.  This may exist as a
 /// normal instruction or constant expression, or may have been
 /// derived from an expression tree.
@@ -4062,16 +4247,16 @@ struct BinaryOp {
   unsigned Opcode;
   Value *LHS;
   Value *RHS;
-  bool IsNSW;
-  bool IsNUW;
+  bool IsNSW = false;
+  bool IsNUW = false;
 
   /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or
   /// constant expression.
-  Operator *Op;
+  Operator *Op = nullptr;
 
   explicit BinaryOp(Operator *Op)
       : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)),
-        IsNSW(false), IsNUW(false), Op(Op) {
+        Op(Op) {
     if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) {
       IsNSW = OBO->hasNoSignedWrap();
       IsNUW = OBO->hasNoUnsignedWrap();
@@ -4080,11 +4265,10 @@ struct BinaryOp {
 
   explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false,
                     bool IsNUW = false)
-      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW),
-        Op(nullptr) {}
+      : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {}
 };
 
-}
+} // end anonymous namespace
 
 /// Try to map \p V into a BinaryOp, and return \c None on failure.
 static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
@@ -4101,6 +4285,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
   case Instruction::Sub:
   case Instruction::Mul:
   case Instruction::UDiv:
+  case Instruction::URem:
   case Instruction::And:
   case Instruction::Or:
   case Instruction::AShr:
@@ -4145,7 +4330,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
     if (auto *F = CI->getCalledFunction())
       switch (F->getIntrinsicID()) {
       case Intrinsic::sadd_with_overflow:
-      case Intrinsic::uadd_with_overflow: {
+      case Intrinsic::uadd_with_overflow:
        if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
           return BinaryOp(Instruction::Add, CI->getArgOperand(0),
                           CI->getArgOperand(1));
@@ -4161,13 +4346,21 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
           return BinaryOp(Instruction::Add, CI->getArgOperand(0),
                           CI->getArgOperand(1), /* IsNSW = */ false,
                           /* IsNUW*/ true);
-      }
-
       case Intrinsic::ssub_with_overflow:
       case Intrinsic::usub_with_overflow:
-        return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
-                        CI->getArgOperand(1));
+        if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT))
+          return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+                          CI->getArgOperand(1));
+
+        // The same reasoning as sadd/uadd above.
+        if (F->getIntrinsicID() == Intrinsic::ssub_with_overflow)
+          return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+                          CI->getArgOperand(1), /* IsNSW = */ true,
+                          /* IsNUW = */ false);
+        else
+          return BinaryOp(Instruction::Sub, CI->getArgOperand(0),
+                          CI->getArgOperand(1), /* IsNSW = */ false,
+                          /* IsNUW = */ true);
+
       case Intrinsic::smul_with_overflow:
       case Intrinsic::umul_with_overflow:
         return BinaryOp(Instruction::Mul, CI->getArgOperand(0),
@@ -4184,28 +4377,27 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) {
   return None;
 }
 
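The subtract intrinsics now get the same treatment as the adds: when isOverflowIntrinsicNoWrap proves the overflow bit never lets a wrapped value escape, the extracted BinaryOp carries NSW/NUW. A plain-C++ sketch of the guarantee being exploited, using the GCC/Clang checked-arithmetic builtin (illustration only, not the pass's code):

    // If every path guards on the overflow flag before using the result,
    // the result equals the exact mathematical a - b, which is precisely
    // the "sub nsw" fact SCEV wants to record.
    int checkedSub(int a, int b) {
      int r;
      if (__builtin_sub_overflow(a, b, &r))
        __builtin_trap();  // a wrapped value never reaches the use below
      return r;            // here r == a - b with no signed wrap
    }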
-/// Helper function to createAddRecFromPHIWithCasts. We have a phi 
+/// Helper function to createAddRecFromPHIWithCasts. We have a phi
 /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via
-/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the 
-/// way. This function checks if \p Op, an operand of this SCEVAddExpr, 
+/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the
+/// way. This function checks if \p Op, an operand of this SCEVAddExpr,
 /// follows one of the following patterns:
 /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
 /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy)
 /// If the SCEV expression of \p Op conforms with one of the expected patterns
 /// we return the type of the truncation operation, and indicate whether the
-/// truncated type should be treated as signed/unsigned by setting 
+/// truncated type should be treated as signed/unsigned by setting
 /// \p Signed to true/false, respectively.
 static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
                                bool &Signed, ScalarEvolution &SE) {
-
-  // The case where Op == SymbolicPHI (that is, with no type conversions on 
-  // the way) is handled by the regular add recurrence creating logic and 
+  // The case where Op == SymbolicPHI (that is, with no type conversions on
+  // the way) is handled by the regular add recurrence creating logic and
   // would have already been triggered in createAddRecForPHI. Reaching it here
-  // means that createAddRecFromPHI had failed for this PHI before (e.g., 
+  // means that createAddRecFromPHI had failed for this PHI before (e.g.,
   // because one of the other operands of the SCEVAddExpr updating this PHI is
-  // not invariant). 
+  // not invariant).
   //
-  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in 
+  // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in
   // this case predicates that allow us to prove that Op == SymbolicPHI will
   // be added.
   if (Op == SymbolicPHI)
@@ -4228,7 +4420,7 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI,
   const SCEV *X = Trunc->getOperand();
   if (X != SymbolicPHI)
     return nullptr;
-  Signed = SExt ? true : false; 
+  Signed = SExt != nullptr;
   return Trunc->getType();
 }
 
@@ -4257,7 +4449,7 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
 //    It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X),
 //    and call this function with %SymbolicPHI = %X.
 //
-//    The analysis will find that the value coming around the backedge has 
+//    The analysis will find that the value coming around the backedge has
 //    the following SCEV:
 //         BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step)
 //    Upon concluding that this matches the desired pattern, the function
@@ -4270,21 +4462,21 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
 //    The returned pair means that SymbolicPHI can be rewritten into NewAddRec
 //    under the predicates {P1,P2,P3}.
 //    This predicated rewrite will be cached in PredicatedSCEVRewrites:
-//         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)} 
+//         PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)}
 //
 // TODO's:
 //
 // 1) Extend the Induction descriptor to also support inductions that involve
-//    casts: When needed (namely, when we are called in the context of the 
-//    vectorizer induction analysis), a Set of cast instructions will be 
+//    casts: When needed (namely, when we are called in the context of the
+//    vectorizer induction analysis), a Set of cast instructions will be
 //    populated by this method, and provided back to isInductionPHI. This is
 //    needed to allow the vectorizer to properly record them to be ignored by
 //    the cost model and to avoid vectorizing them (otherwise these casts,
-//    which are redundant under the runtime overflow checks, will be 
-//    vectorized, which can be costly).  
+//    which are redundant under the runtime overflow checks, will be
+//    vectorized, which can be costly).
 //
 // 2) Support additional induction/PHISCEV patterns: We also want to support
-//    inductions where the sext-trunc / zext-trunc operations (partly) occur 
+//    inductions where the sext-trunc / zext-trunc operations (partly) occur
 //    after the induction update operation (the induction increment):
 //
 //      (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix)
@@ -4294,17 +4486,16 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) {
 //    which correspond to a phi->trunc->add->sext/zext->phi update chain.
 //
 // 3) Outline common code with createAddRecFromPHI to avoid duplication.
-//
 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
 ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) {
   SmallVector<const SCEVPredicate *, 3> Predicates;
 
-  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can 
+  // *** Part1: Analyze if we have a phi-with-cast pattern for which we can
   // return an AddRec expression under some predicate.
-  
+
   auto *PN = cast<PHINode>(SymbolicPHI->getValue());
   const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
-  assert (L && "Expecting an integer loop header phi");
+  assert(L && "Expecting an integer loop header phi");
 
   // The loop may have multiple entrances or multiple exits; we can analyze
   // this phi as an addrec if it has a unique entry value and a unique
@@ -4339,12 +4530,12 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
     return None;
 
   // If there is a single occurrence of the symbolic value, possibly
-  // casted, replace it with a recurrence. 
+  // casted, replace it with a recurrence.
   unsigned FoundIndex = Add->getNumOperands();
   Type *TruncTy = nullptr;
   bool Signed;
   for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
-    if ((TruncTy = 
+    if ((TruncTy =
             isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this)))
       if (FoundIndex == e) {
         FoundIndex = i;
@@ -4366,77 +4557,122 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
   if (!isLoopInvariant(Accum, L))
     return None;
 
-  
-  // *** Part2: Create the predicates 
+  // *** Part2: Create the predicates
 
   // Analysis was successful: we have a phi-with-cast pattern for which we
   // can return an AddRec expression under the following predicates:
   //
   // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum)
   //     fits within the truncated type (does not overflow) for i = 0 to n-1.
-  // P2: An Equal predicate that guarantees that 
+  // P2: An Equal predicate that guarantees that
   //     Start = (Ext ix (Trunc iy (Start) to ix) to iy)
-  // P3: An Equal predicate that guarantees that 
+  // P3: An Equal predicate that guarantees that
   //     Accum = (Ext ix (Trunc iy (Accum) to ix) to iy)
   //
-  // As we next prove, the above predicates guarantee that: 
+  // As we next prove, the above predicates guarantee that:
   //     Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy)
   //
   //
   // More formally, we want to prove that:
-  //     Expr(i+1) = Start + (i+1) * Accum 
-  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 
+  //     Expr(i+1) = Start + (i+1) * Accum
+  //               = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
   //
   // Given that:
-  // 1) Expr(0) = Start 
-  // 2) Expr(1) = Start + Accum 
+  // 1) Expr(0) = Start
+  // 2) Expr(1) = Start + Accum
   //            = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2
   // 3) Induction hypothesis (step i):
-  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum 
+  //    Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum
   //
   // Proof:
   //  Expr(i+1) =
   //   = Start + (i+1)*Accum
   //   = (Start + i*Accum) + Accum
-  //   = Expr(i) + Accum  
-  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum 
+  //   = Expr(i) + Accum
+  //   = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum
   //                                                             :: from step i
   //
-  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum 
+  //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum
   //
   //   = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy)
   //     + (Ext ix (Trunc iy (Accum) to ix) to iy)
   //     + Accum                                                     :: from P3
   //
-  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) 
+  //   = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy)
   //     + Accum                            :: from P1: Ext(x)+Ext(y)=>Ext(x+y)
   //
   //   = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum
-  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum 
+  //   = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum
   //
   // By induction, the same applies to all iterations 1<=i<n:
   //
-   
+
   // Create a truncated addrec for which we will add a no overflow check (P1).
   const SCEV *StartVal = getSCEV(StartValueV);
-  const SCEV *PHISCEV = 
+  const SCEV *PHISCEV =
       getAddRecExpr(getTruncateExpr(StartVal, TruncTy),
-                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); 
-  const auto *AR = cast<SCEVAddRecExpr>(PHISCEV);
+                    getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap);
 
-  SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
-      Signed ? SCEVWrapPredicate::IncrementNSSW
-             : SCEVWrapPredicate::IncrementNUSW;
-  const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
-  Predicates.push_back(AddRecPred);
+  // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr.
+  // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV
+  // will be constant.
+  //
+  //  If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't
+  // add P1.
+  if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) {
+    SCEVWrapPredicate::IncrementWrapFlags AddedFlags =
+        Signed ? SCEVWrapPredicate::IncrementNSSW
+               : SCEVWrapPredicate::IncrementNUSW;
+    const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags);
+    Predicates.push_back(AddRecPred);
+  }
 
   // Create the Equal Predicates P2,P3:
-  auto AppendPredicate = [&](const SCEV *Expr) -> void {
-    assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
+
+  // It is possible that the predicates P2 and/or P3 are computable at
+  // compile time due to StartVal and/or Accum being constants.
+  // If either one is, then we can check that now and escape if either P2
+  // or P3 is false.
+
+  // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy)
+  // for each of StartVal and Accum
+  auto getExtendedExpr = [&](const SCEV *Expr,
+                             bool CreateSignExtend) -> const SCEV * {
+    assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant");
     const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy);
     const SCEV *ExtendedExpr =
-        Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType())
-               : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+        CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType())
+                         : getZeroExtendExpr(TruncatedExpr, Expr->getType());
+    return ExtendedExpr;
+  };
+
+  // Given:
+  //  ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy
+  //               = getExtendedExpr(Expr)
+  // Determine whether the predicate P: Expr == ExtendedExpr
+  // is known to be false at compile time
+  auto PredIsKnownFalse = [&](const SCEV *Expr,
+                              const SCEV *ExtendedExpr) -> bool {
+    return Expr != ExtendedExpr &&
+           isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr);
+  };
+
+  const SCEV *StartExtended = getExtendedExpr(StartVal, Signed);
+  if (PredIsKnownFalse(StartVal, StartExtended)) {
+    DEBUG(dbgs() << "P2 is compile-time false\n";);
+    return None;
+  }
+
+  // The Step is always Signed (because the overflow checks are either
+  // NSSW or NUSW)
+  const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true);
+  if (PredIsKnownFalse(Accum, AccumExtended)) {
+    DEBUG(dbgs() << "P3 is compile-time false\n";);
+    return None;
+  }
+
+  auto AppendPredicate = [&](const SCEV *Expr,
+                             const SCEV *ExtendedExpr) -> void {
     if (Expr != ExtendedExpr &&
         !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) {
       const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr);
@@ -4444,14 +4680,14 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
       Predicates.push_back(Pred);
     }
   };
-  
-  AppendPredicate(StartVal);
-  AppendPredicate(Accum);
-  
+
+  AppendPredicate(StartVal, StartExtended);
+  AppendPredicate(Accum, AccumExtended);
+
   // *** Part3: Predicates are ready. Now go ahead and create the new addrec in
   // which the casts had been folded away. The caller can rewrite SymbolicPHI
   // into NewAR if it will also add the runtime overflow checks specified in
-  // Predicates. 
+  // Predicates.
   auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap);
 
   std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite =
@@ -4463,7 +4699,6 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI
 
 Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>>
 ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
-
   auto *PN = cast<PHINode>(SymbolicPHI->getValue());
   const Loop *L = isIntegerLoopHeaderPHI(PN, LI);
   if (!L)
@@ -4475,7 +4710,7 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
     std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite =
         I->second;
     // Analysis was done before and failed to create an AddRec:
-    if (Rewrite.first == SymbolicPHI) 
+    if (Rewrite.first == SymbolicPHI)
       return None;
     // Analysis was done before and succeeded to create an AddRec under
     // a predicate:
@@ -4497,6 +4732,30 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) {
   return Rewrite;
 }
 
+// FIXME: This utility is currently required because the Rewriter currently
+// does not rewrite this expression:
+// {0, +, (sext ix (trunc iy to ix) to iy)}
+// into {0, +, %step},
+// even when the following Equal predicate exists:
+// "%step == (sext ix (trunc iy to ix) to iy)".
+bool PredicatedScalarEvolution::areAddRecsEqualWithPreds(
+    const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const {
+  if (AR1 == AR2)
+    return true;
+
+  auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool {
+    if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) &&
+        !Preds.implies(SE.getEqualPredicate(Expr2, Expr1)))
+      return false;
+    return true;
+  };
+
+  if (!areExprsEqual(AR1->getStart(), AR2->getStart()) ||
+      !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE)))
+    return false;
+  return true;
+}
+
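For reference, the phi-with-casts pattern this whole section recognizes corresponds to an induction kept wide but truncated and re-extended on every backedge, i.e. x.next = (sext ix (trunc iy x)) + step. A hypothetical C shape (iy = 64-bit, ix = 32-bit):

    long loop(long start, long step, long n) {
      long total = 0;
      for (long x = start; x < n; x = (long)(int)x + step)  // sext(trunc(x)) + step
        total += x;
      return total;
    }

Under predicates P1-P3 the casts are provably no-ops, so the PHI can be rewritten as the plain recurrence {start,+,step}.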
 /// A helper function for createAddRecFromPHI to handle simple cases.
 ///
 /// This function tries to find an AddRec expression for the simplest (yet most
@@ -4612,7 +4871,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) {
       SmallVector<const SCEV *, 8> Ops;
       for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i)
         if (i != FoundIndex)
-          Ops.push_back(Add->getOperand(i));
+          Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i),
+                                                             L, *this));
       const SCEV *Accum = getAddExpr(Ops);
 
       // This is not a valid addrec if the step amount is varying each
@@ -5599,7 +5859,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) {
 
 ScalarEvolution::LoopProperties
 ScalarEvolution::getLoopProperties(const Loop *L) {
-  typedef ScalarEvolution::LoopProperties LoopProperties;
+  using LoopProperties = ScalarEvolution::LoopProperties;
 
   auto Itr = LoopPropertiesCache.find(L);
   if (Itr == LoopPropertiesCache.end()) {
@@ -5735,6 +5995,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     }
     case Instruction::UDiv:
       return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
+    case Instruction::URem:
+      return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS));
     case Instruction::Sub: {
       SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap;
       if (BO->Op)
@@ -5886,7 +6148,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     }
     break;
 
-    case Instruction::AShr:
+    case Instruction::AShr: {
       // AShr X, C, where C is a constant.
       ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS);
       if (!CI)
@@ -5938,6 +6200,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
       }
       break;
     }
+    }
   }
 
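The next hunk pushes sign extension through an nsw subtract before createSCEV canonicalizes it to A + (-1)*B and risks losing the NSW flag. The identity it relies on, sext(a - b) == sext(a) - sext(b) whenever the narrow subtract does not wrap, can be sanity-checked in plain C++ (8-bit to 32-bit for concreteness):

    #include <cassert>
    #include <cstdint>

    void sextOverNswSub(int8_t a, int8_t b) {
      int wide = (int)a - (int)b;                // sext(a) - sext(b)
      if (wide >= INT8_MIN && wide <= INT8_MAX)  // i.e. the i8 sub is 'nsw'
        assert((int)(int8_t)(a - b) == wide);    // sext(a - b) matches
    }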
@@ -5948,6 +6211,21 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
     return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType());
 
   case Instruction::SExt:
+    if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) {
+      // The NSW flag of a subtract does not always survive the conversion to
+      // A + (-1)*B.  By pushing sign extension onto its operands we are much
+      // more likely to preserve NSW and allow later AddRec optimisations.
+      //
+      // NOTE: This is effectively duplicating this logic from getSignExtend:
+      //   sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw>
+      // but by that point the NSW information has potentially been lost.
+      if (BO->Opcode == Instruction::Sub && BO->IsNSW) {
+        Type *Ty = U->getType();
+        auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty);
+        auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty);
+        return getMinusSCEV(V1, V2, SCEV::FlagNSW);
+      }
+    }
     return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType());
 
   case Instruction::BitCast:
@@ -5987,8 +6265,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) {
   return getUnknown(V);
 }
 
-
-
 //===----------------------------------------------------------------------===//
 //                   Iteration Count Computation Code
 //
@@ -6177,11 +6453,9 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
     SmallVector<Instruction *, 16> Worklist;
     PushLoopPHIs(L, Worklist);
 
-    SmallPtrSet<Instruction *, 8> Visited;
+    SmallPtrSet<Instruction *, 8> Discovered;
     while (!Worklist.empty()) {
       Instruction *I = Worklist.pop_back_val();
-      if (!Visited.insert(I).second)
-        continue;
 
       ValueExprMapType::iterator It =
         ValueExprMap.find_as(static_cast<Value *>(I));
@@ -6202,7 +6476,31 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
           ConstantEvolutionLoopExitValue.erase(PN);
       }
 
-      PushDefUseChildren(I, Worklist);
+      // Since we don't need to invalidate anything for correctness and we're
+      // only invalidating to make SCEV's results more precise, we get to stop
+      // early to avoid invalidating too much.  This is especially important in
+      // cases like:
+      //
+      //   %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node
+      // loop0:
+      //   %pn0 = phi
+      //   ...
+      // loop1:
+      //   %pn1 = phi
+      //   ...
+      //
+      // where both loop0 and loop1's backedge taken count uses the SCEV
+      // expression for %v.  If we don't have the early stop below then in cases
+      // like the above, getBackedgeTakenInfo(loop1) will clear out the trip
      // count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip
      // count for loop1, effectively nullifying SCEV's trip count cache.
+      for (auto *U : I->users())
+        if (auto *I = dyn_cast<Instruction>(U)) {
+          auto *LoopForUser = LI.getLoopFor(I->getParent());
+          if (LoopForUser && L->contains(LoopForUser) &&
+              Discovered.insert(I).second)
+            Worklist.push_back(I);
+        }
     }
   }
 
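A hedged C-level illustration of the loop0/loop1 comment above (hypothetical shape, not taken from the patch): a value accumulated by one loop's PHIs can feed a sibling loop's bound, so the SCEVs behind both cached trip counts share subexpressions, and indiscriminate def-use invalidation would keep evicting each other's cache entries:

    int f(int k, const int *v) {
      int pn0 = 0;
      for (int i = 0; i != k; ++i)    // loop0: trip count k
        pn0 += 2;
      int s = 0;
      for (int j = 0; j != pn0; ++j)  // loop1: trip count built from pn0
        s += v[j];
      return s;
    }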
@@ -6217,7 +6515,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) {
 void ScalarEvolution::forgetLoop(const Loop *L) {
   // Drop any stored trip count value.
   auto RemoveLoopFromBackedgeMap =
-      [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
+      [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) {
         auto BTCPos = Map.find(L);
         if (BTCPos != Map.end()) {
           BTCPos->second.clear();
@@ -6225,47 +6523,59 @@ void ScalarEvolution::forgetLoop(const Loop *L) {
         }
       };
 
-  RemoveLoopFromBackedgeMap(BackedgeTakenCounts);
-  RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts);
+  SmallVector<const Loop *, 16> LoopWorklist(1, L);
+  SmallVector<Instruction *, 32> Worklist;
+  SmallPtrSet<Instruction *, 16> Visited;
 
-  // Drop information about predicated SCEV rewrites for this loop.
-  for (auto I = PredicatedSCEVRewrites.begin();
-       I != PredicatedSCEVRewrites.end();) {
-    std::pair<const SCEV *, const Loop *> Entry = I->first;
-    if (Entry.second == L)
-      PredicatedSCEVRewrites.erase(I++);
-    else
-      ++I;
-  }
+  // Iterate over all the loops and sub-loops to drop SCEV information.
+  while (!LoopWorklist.empty()) {
+    auto *CurrL = LoopWorklist.pop_back_val();
 
-  // Drop information about expressions based on loop-header PHIs.
-  SmallVector<Instruction *, 16> Worklist;
-  PushLoopPHIs(L, Worklist);
+    RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL);
+    RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL);
 
-  SmallPtrSet<Instruction *, 8> Visited;
-  while (!Worklist.empty()) {
-    Instruction *I = Worklist.pop_back_val();
-    if (!Visited.insert(I).second)
-      continue;
+    // Drop information about predicated SCEV rewrites for this loop.
+    for (auto I = PredicatedSCEVRewrites.begin();
+         I != PredicatedSCEVRewrites.end();) {
+      std::pair<const SCEV *, const Loop *> Entry = I->first;
+      if (Entry.second == CurrL)
+        PredicatedSCEVRewrites.erase(I++);
+      else
+        ++I;
+    }
 
-    ValueExprMapType::iterator It =
-      ValueExprMap.find_as(static_cast<Value *>(I));
-    if (It != ValueExprMap.end()) {
-      eraseValueFromMap(It->first);
-      forgetMemoizedResults(It->second);
-      if (PHINode *PN = dyn_cast<PHINode>(I))
-        ConstantEvolutionLoopExitValue.erase(PN);
+    auto LoopUsersItr = LoopUsers.find(CurrL);
+    if (LoopUsersItr != LoopUsers.end()) {
+      for (auto *S : LoopUsersItr->second)
+        forgetMemoizedResults(S);
+      LoopUsers.erase(LoopUsersItr);
     }
 
-    PushDefUseChildren(I, Worklist);
-  }
+    // Drop information about expressions based on loop-header PHIs.
+    PushLoopPHIs(CurrL, Worklist);
 
-  // Forget all contained loops too, to avoid dangling entries in the
-  // ValuesAtScopes map.
-  for (Loop *I : *L)
-    forgetLoop(I);
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+      if (!Visited.insert(I).second)
+        continue;
 
-  LoopPropertiesCache.erase(L);
+      ValueExprMapType::iterator It =
+          ValueExprMap.find_as(static_cast<Value *>(I));
+      if (It != ValueExprMap.end()) {
+        eraseValueFromMap(It->first);
+        forgetMemoizedResults(It->second);
+        if (PHINode *PN = dyn_cast<PHINode>(I))
+          ConstantEvolutionLoopExitValue.erase(PN);
+      }
+
+      PushDefUseChildren(I, Worklist);
+    }
+
+    LoopPropertiesCache.erase(CurrL);
+    // Forget all contained loops too, to avoid dangling entries in the
+    // ValuesAtScopes map.
+    LoopWorklist.append(CurrL->begin(), CurrL->end());
+  }
 }
 
 void ScalarEvolution::forgetValue(Value *V) {
@@ -6377,7 +6687,7 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S,
 }
 
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
-    : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) {
+    : ExactNotTaken(E), MaxNotTaken(E) {
   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
           isa<SCEVConstant>(MaxNotTaken)) &&
          "No point in having a non-constant max backedge taken count!");
@@ -6422,7 +6732,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo(
         &&ExitCounts,
     bool Complete, const SCEV *MaxCount, bool MaxOrZero)
     : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) {
-  typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
+
   ExitNotTaken.reserve(ExitCounts.size());
   std::transform(
       ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken),
@@ -6454,7 +6765,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
   SmallVector<BasicBlock *, 8> ExitingBlocks;
   L->getExitingBlocks(ExitingBlocks);
 
-  typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo;
+  using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo;
 
   SmallVector<EdgeExitInfo, 4> ExitCounts;
   bool CouldComputeBECount = true;
@@ -6521,8 +6832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L,
 
 ScalarEvolution::ExitLimit
 ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock,
-                                  bool AllowPredicates) {
-
+                                      bool AllowPredicates) {
   // Okay, we've chosen an exiting block.  See what condition causes us to exit
   // at this block and remember the exit block and whether all other targets
   // lead to the loop header.
@@ -6785,19 +7095,19 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
                                           BasicBlock *FBB,
                                           bool ControlsExit,
                                           bool AllowPredicates) {
-
   // If the condition was exit on true, convert the condition to exit on false
-  ICmpInst::Predicate Cond;
+  ICmpInst::Predicate Pred;
   if (!L->contains(FBB))
-    Cond = ExitCond->getPredicate();
+    Pred = ExitCond->getPredicate();
   else
-    Cond = ExitCond->getInversePredicate();
+    Pred = ExitCond->getInversePredicate();
+  const ICmpInst::Predicate OriginalPred = Pred;
 
   // Handle common loops like: for (X = "string"; *X; ++X)
   if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0)))
     if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) {
       ExitLimit ItCnt =
-        computeLoadConstantCompareExitLimit(LI, RHS, L, Cond);
+        computeLoadConstantCompareExitLimit(LI, RHS, L, Pred);
       if (ItCnt.hasAnyInfo())
         return ItCnt;
     }
@@ -6814,11 +7124,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
   if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) {
     // If there is a loop-invariant, force it into the RHS.
     std::swap(LHS, RHS);
-    Cond = ICmpInst::getSwappedPredicate(Cond);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
   }
 
   // Simplify the operands before analyzing them.
-  (void)SimplifyICmpOperands(Cond, LHS, RHS);
+  (void)SimplifyICmpOperands(Pred, LHS, RHS);
 
   // If we have a comparison of a chrec against a constant, try to use value
   // ranges to answer this query.
@@ -6827,13 +7137,13 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
       if (AddRec->getLoop() == L) {
         // Form the constant range.
         ConstantRange CompRange =
-            ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt());
+            ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt());
 
         const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this);
         if (!isa<SCEVCouldNotCompute>(Ret)) return Ret;
       }
 
-  switch (Cond) {
+  switch (Pred) {
   case ICmpInst::ICMP_NE: {                     // while (X != Y)
     // Convert to: while (X-Y != 0)
     ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit,
                                 AllowPredicates);
    if (EL.hasAnyInfo()) return EL;
    break;
  }
@@ -6849,7 +7159,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
   case ICmpInst::ICMP_SLT:
   case ICmpInst::ICMP_ULT: {                    // while (X < Y)
-    bool IsSigned = Cond == ICmpInst::ICMP_SLT;
+    bool IsSigned = Pred == ICmpInst::ICMP_SLT;
     ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit,
                                     AllowPredicates);
     if (EL.hasAnyInfo()) return EL;
     break;
   }
@@ -6857,7 +7167,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
   case ICmpInst::ICMP_SGT:
   case ICmpInst::ICMP_UGT: {                    // while (X > Y)
-    bool IsSigned = Cond == ICmpInst::ICMP_SGT;
+    bool IsSigned = Pred == ICmpInst::ICMP_SGT;
     ExitLimit EL =
         howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit,
                             AllowPredicates);
     if (EL.hasAnyInfo()) return EL;
     break;
   }
@@ -6875,7 +7185,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L,
     return ExhaustiveCount;
 
   return computeShiftCompareExitLimit(ExitCond->getOperand(0),
-                                      ExitCond->getOperand(1), L, Cond);
+                                      ExitCond->getOperand(1), L, OriginalPred);
 }
 
 ScalarEvolution::ExitLimit
@@ -6920,7 +7230,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
   Constant *RHS,
   const Loop *L,
   ICmpInst::Predicate predicate) {
-
   if (LI->isVolatile()) return getCouldNotCompute();
 
   // Check to see if the loaded pointer is a getelementptr of a global.
@@ -7333,8 +7642,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   Value *BEValue = PN->getIncomingValueForBlock(Latch);
 
   // Execute the loop symbolically to determine the exit value.
-  if (BEs.getActiveBits() >= 32)
-    return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!
+  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
+         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
 
   unsigned NumIterations = BEs.getZExtValue(); // must be in range
   unsigned IterationNum = 0;
@@ -7839,7 +8148,6 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
 
 /// Find the roots of the quadratic equation for the given quadratic chrec
 /// {L,+,M,+,N}.  This returns either the two roots (which might be the same) or
 /// two SCEVCouldNotCompute objects.
-///
 static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
@@ -8080,7 +8388,6 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
 /// expressions are equal, however for the purposes of looking for a condition
 /// guarding a loop, it can be useful to be a little more general, since a
 /// front-end may have replicated the controlling expression.
-///
 static bool HasSameValue(const SCEV *A, const SCEV *B) {
   // Quick check to see if they are the same SCEV.
   if (A == B) return true;
@@ -8527,7 +8834,6 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
-
   // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
   // Return Y via OutY.
   auto MatchBinaryAddToConst =
@@ -8693,7 +8999,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
 
   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
        DTN != HeaderDTN; DTN = DTN->getIDom()) {
-
     assert(DTN && "should reach the loop header before reaching the root!");
 
     BasicBlock *BB = DTN->getBlock();
@@ -9116,7 +9421,6 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                      getNotSCEV(FoundLHS));
 }
 
-
 /// If Expr computes ~A, return A else return nullptr
 static const SCEV *MatchNotExpr(const SCEV *Expr) {
   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
@@ -9132,7 +9436,6 @@ static const SCEV *MatchNotExpr(const SCEV *Expr) {
   return AddRHS->getOperand(1);
 }
 
-
 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
 template<typename MaxExprType>
 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
@@ -9143,7 +9446,6 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
   return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
 }
 
-
 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
 template<typename MaxExprType>
 static bool IsMinConsistingOf(ScalarEvolution &SE,
@@ -9159,7 +9461,6 @@ static bool IsMinConsistingOf(ScalarEvolution &SE,
 
 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                            ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS) {
-
   // If both sides are affine addrecs for the same loop, with equal
   // steps, and we know the recurrences don't wrap, then we only
   // need to check the predicate on the starting values.
@@ -9295,7 +9596,9 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
     Value *LL, *LR;
     // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+
     using namespace llvm::PatternMatch;
+
     if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
       // Rules for division.
 // We are going to perform some comparisons with Denominator and its
@@ -9510,14 +9813,54 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
   return getUDivExpr(Delta, Step);
 }
 
+const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
+                                                    const SCEV *Stride,
+                                                    const SCEV *End,
+                                                    unsigned BitWidth,
+                                                    bool IsSigned) {
+
+  assert(!isKnownNonPositive(Stride) &&
+         "Stride is expected strictly positive!");
+  // Calculate the maximum backedge count based on the range of values
+  // permitted by Start, End, and Stride.
+  const SCEV *MaxBECount;
+  APInt MinStart =
+      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
+
+  APInt StrideForMaxBECount =
+      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
+
+  // We already know that the stride is positive, so we paper over conservatism
+  // in our range computation by forcing StrideForMaxBECount to be at least one.
+  // In theory this is unnecessary, but we expect MaxBECount to be a
+  // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there
+  // is nothing to constant fold it to).
+  APInt One(BitWidth, 1, IsSigned);
+  StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount);
+
+  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
+                            : APInt::getMaxValue(BitWidth);
+  APInt Limit = MaxValue - (StrideForMaxBECount - 1);
+
+  // Although End can be a MAX expression we estimate MaxEnd considering only
+  // the case End = RHS of the loop termination condition. This is safe because
+  // in the other case (End - Start) is zero, leading to a zero maximum backedge
+  // taken count.
+  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
+                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+
+  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
+                              getConstant(StrideForMaxBECount) /* Step */,
+                              false /* Equality */);
+
+  return MaxBECount;
+}
+
 ScalarEvolution::ExitLimit
 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
                                   bool ControlsExit, bool AllowPredicates) {
   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
-  // We handle only IV < Invariant
-  if (!isLoopInvariant(RHS, L))
-    return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
   bool PredicatedIV = false;
@@ -9588,7 +9931,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
         !loopHasNoSideEffects(L))
       return getCouldNotCompute();
-
   } else if (!Stride->isOne() &&
              doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
     // Avoid proven overflow cases: this will ensure that the backedge taken
@@ -9601,6 +9943,17 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                       : ICmpInst::ICMP_ULT;
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
+  // When the RHS is not invariant, we do not know the end bound of the loop and
+  // cannot calculate the ExactBECount needed by ExitLimit. However, we can
+  // calculate the MaxBECount, given the start, stride and max value for the end
+  // bound of the loop (RHS), and the fact that IV does not overflow (which is
+  // checked above).
+  if (!isLoopInvariant(RHS, L)) {
+    const SCEV *MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+                     false /*MaxOrZero*/, Predicates);
+  }
   // If the backedge is taken at least once, then it will be taken
   // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
   // is the LHS value of the less-than comparison the first time it is evaluated
@@ -9633,37 +9986,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     MaxBECount = BECountIfBackedgeTaken;
     MaxOrZero = true;
   } else {
-    // Calculate the maximum backedge count based on the range of values
-    // permitted by Start, End, and Stride.
-    APInt MinStart = IsSigned ? getSignedRangeMin(Start)
-                              : getUnsignedRangeMin(Start);
-
-    unsigned BitWidth = getTypeSizeInBits(LHS->getType());
-
-    APInt StrideForMaxBECount;
-
-    if (PositiveStride)
-      StrideForMaxBECount =
-        IsSigned ? getSignedRangeMin(Stride)
-                 : getUnsignedRangeMin(Stride);
-    else
-      // Using a stride of 1 is safe when computing max backedge taken count for
-      // a loop with unknown stride.
-      StrideForMaxBECount = APInt(BitWidth, 1, IsSigned);
-
-    APInt Limit =
-      IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1)
-               : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1);
-
-    // Although End can be a MAX expression we estimate MaxEnd considering only
-    // the case End = RHS. This is safe because in the other case (End - Start)
-    // is zero, leading to a zero maximum backedge taken count.
-    APInt MaxEnd =
-      IsSigned ? APIntOps::smin(getSignedRangeMax(RHS), Limit)
-               : APIntOps::umin(getUnsignedRangeMax(RHS), Limit);
-
-    MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
-                                getConstant(StrideForMaxBECount), false);
+    MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
   }
 
   if (isa<SCEVCouldNotCompute>(MaxBECount) &&
@@ -9874,6 +10198,7 @@ static inline bool containsUndefs(const SCEV *S) {
 }
 
 namespace {
+
 // Collect all steps of SCEV expressions.
 struct SCEVCollectStrides {
   ScalarEvolution &SE;
@@ -9887,6 +10212,7 @@ struct SCEVCollectStrides {
       Strides.push_back(AR->getStepRecurrence(SE));
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9894,8 +10220,7 @@ struct SCEVCollectStrides {
 struct SCEVCollectTerms {
   SmallVectorImpl<const SCEV *> &Terms;
 
-  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
-      : Terms(T) {}
+  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
 
   bool follow(const SCEV *S) {
     if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
@@ -9910,6 +10235,7 @@ struct SCEVCollectTerms {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9918,7 +10244,7 @@ struct SCEVHasAddRec {
   bool &ContainsAddRec;
 
   SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
-   ContainsAddRec = false;
+    ContainsAddRec = false;
   }
 
   bool follow(const SCEV *S) {
@@ -9932,6 +10258,7 @@ struct SCEVHasAddRec {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9985,9 +10312,11 @@ struct SCEVCollectAddRecMultiplies {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
-}
+
+} // end anonymous namespace
 
 /// Find parametric terms in this SCEVAddRecExpr. We first look for parameters
 /// in two places:
@@ -10066,7 +10395,6 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   return true;
 }
 
-
 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
 static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
   for (const SCEV *T : Terms)
@@ -10181,7 +10509,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
 void ScalarEvolution::computeAccessFunctions(
     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
     SmallVectorImpl<const SCEV *> &Sizes) {
-
   // Early exit in case this SCEV is not an affine multivariate function.
   if (Sizes.empty())
     return;
@@ -10285,7 +10612,6 @@ void ScalarEvolution::computeAccessFunctions(
 /// DelinearizationPass that walks through all loads and stores of a function
 /// asking for the SCEV of the memory access with respect to all enclosing
 /// loops, calling SCEV->delinearize on that and printing the results.
-
 void ScalarEvolution::delinearize(const SCEV *Expr,
                                  SmallVectorImpl<const SCEV *> &Subscripts,
                                  SmallVectorImpl<const SCEV *> &Sizes,
@@ -10374,11 +10700,8 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                  AssumptionCache &AC, DominatorTree &DT,
                                  LoopInfo &LI)
     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
-      CouldNotCompute(new SCEVCouldNotCompute()),
-      WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
-      ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
-      FirstUnknown(nullptr) {
-
+      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
+      LoopDispositions(64), BlockDispositions(64) {
   // To use guards for proving predicates, we need to scan every instruction in
   // relevant basic blocks, and not just terminators.  Doing this is a waste of
   // time if the IR does not actually contain any calls to
@@ -10399,7 +10722,6 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
-      WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
@@ -10415,6 +10737,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
       UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
+      LoopUsers(std::move(Arg.LoopUsers)),
      PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -10647,9 +10970,11 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     if (!L)
       return LoopVariant;
 
-    // This recurrence is variant w.r.t. L if L contains AR's loop.
-    if (L->contains(AR->getLoop()))
+    // Everything that is not defined at loop entry is variant.
+    if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
       return LoopVariant;
+    assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
+           " dominate the contained loop's header?");
 
     // This recurrence is invariant w.r.t. L if AR's loop contains L.
     if (AR->getLoop()->contains(L))
@@ -10806,7 +11131,16 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
-void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
+bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const {
+  auto IsS = [&](const SCEV *X) { return S == X; };
+  auto ContainsS = [&](const SCEV *X) {
+    return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS);
+  };
+  return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken);
+}
+
+void
+ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   ValuesAtScopes.erase(S);
   LoopDispositions.erase(S);
   BlockDispositions.erase(S);
@@ -10816,7 +11150,7 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   HasRecMap.erase(S);
   MinTrailingZerosCache.erase(S);
 
-  for (auto I = PredicatedSCEVRewrites.begin();
+  for (auto I = PredicatedSCEVRewrites.begin();
        I != PredicatedSCEVRewrites.end();) {
     std::pair<const SCEV *, const Loop *> Entry = I->first;
     if (Entry.first == S)
@@ -10841,6 +11175,25 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+  struct FindUsedLoops {
+    SmallPtrSet<const Loop *, 8> LoopsUsed;
+    bool follow(const SCEV *S) {
+      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+        LoopsUsed.insert(AR->getLoop());
+      return true;
+    }
+
+    bool isDone() const { return false; }
+  };
+
+  FindUsedLoops F;
+  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+
+  for (auto *L : F.LoopsUsed)
+    LoopUsers[L].push_back(S);
+}
+
 void ScalarEvolution::verify() const {
   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
   ScalarEvolution SE2(F, TLI, AC, DT, LI);
@@ -10849,9 +11202,12 @@ void ScalarEvolution::verify() const {
 
   // Maps SCEV expressions from one ScalarEvolution "universe" to another.
   struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
+    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
+
     const SCEV *visitConstant(const SCEVConstant *Constant) {
       return SE.getConstant(Constant->getAPInt());
     }
+
     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
       return SE.getUnknown(Expr->getValue());
     }
@@ -10859,7 +11215,6 @@ void ScalarEvolution::verify() const {
     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
       return SE.getCouldNotCompute();
     }
-    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
   };
 
   SCEVMapper SCM(SE2);
@@ -10948,6 +11303,7 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
                     "Scalar Evolution Analysis", false, true)
+
 char ScalarEvolutionWrapperPass::ID = 0;
 
 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
@@ -11023,6 +11379,7 @@ namespace {
 
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
+
   /// Rewrites \p S in the context of a loop L and the SCEV predication
   /// infrastructure.
   ///
@@ -11038,11 +11395,6 @@ public:
     return Rewriter.visit(S);
   }
 
-  SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
-                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
-                        SCEVUnionPredicate *Pred)
-      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
-
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     if (Pred) {
       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
@@ -11087,6 +11439,11 @@ public:
   }
 
 private:
+  explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+                        SCEVUnionPredicate *Pred)
+      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
+
   bool addOverflowAssumption(const SCEVPredicate *P) {
     if (!NewPreds) {
       // Check if we've already made this assumption.
@@ -11103,10 +11460,10 @@ private:
   }
 
   // If \p Expr represents a PHINode, we try to see if it can be represented
-  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
+  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
   // to add this predicate as a runtime overflow check, we return the AddRec.
-  // If \p Expr does not meet these conditions (is not a PHI node, or we
-  // couldn't create an AddRec for it, or couldn't add the predicate), we just
+  // If \p Expr does not meet these conditions (is not a PHI node, or we
+  // couldn't create an AddRec for it, or couldn't add the predicate), we just
   // return \p Expr.
   const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
     if (!isa<PHINode>(Expr->getValue()))
@@ -11121,11 +11478,12 @@ private:
     }
     return PredicatedRewrite->first;
   }
-
+
   SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
   SCEVUnionPredicate *Pred;
   const Loop *L;
 };
+
 } // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
@@ -11136,7 +11494,6 @@ const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
     const SCEV *S, const Loop *L,
     SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
-
   SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
@@ -11292,7 +11649,7 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
 
 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
                                                      Loop &L)
-    : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
+    : SE(SE), L(L) {}
 
 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   const SCEV *Expr = SE.getSCEV(V);
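
Notes on selected hunks above (illustrative sketches, not part of the patch).

The early stop added to getBackedgeTakenInfo is easier to see outside diff form: the def-use walk only follows users that live in a loop contained in L, so invalidation is bounded by L's nest and cached trip counts of sibling loops survive. A minimal standalone sketch, assuming LLVM's LoopInfo and def-use APIs; collectUsersInLoop is a name invented here for illustration:

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/IR/Instruction.h"
using namespace llvm;

// Gather Root and every transitive user of it that still lives inside L.
// Users outside L are never enqueued; that is the early stop the patch
// relies on to keep other loops' cached trip counts intact.
static void collectUsersInLoop(Instruction *Root, const Loop *L,
                               const LoopInfo &LI,
                               SmallVectorImpl<Instruction *> &Out) {
  SmallVector<Instruction *, 16> Worklist(1, Root);
  SmallPtrSet<Instruction *, 8> Discovered;
  while (!Worklist.empty()) {
    Instruction *I = Worklist.pop_back_val();
    Out.push_back(I);
    for (User *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U)) {
        const Loop *LoopForUser = LI.getLoopFor(UI->getParent());
        if (LoopForUser && L->contains(LoopForUser) &&
            Discovered.insert(UI).second)
          Worklist.push_back(UI);
      }
  }
}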
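The reworked forgetLoop replaces recursion over sub-loops with an explicit worklist, so deeply nested loop nests cannot overflow the stack. The shape of that traversal, reduced to a sketch against the same LLVM APIs (forEachLoopInNest is a hypothetical helper, not in the patch):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/LoopInfo.h"
using namespace llvm;

// Visit L and every loop nested inside it without recursing: pop one loop,
// process it, then enqueue its direct sub-loops (Loop::begin()/end()).
template <typename Callback>
static void forEachLoopInNest(const Loop *L, Callback Fn) {
  SmallVector<const Loop *, 16> Worklist(1, L);
  while (!Worklist.empty()) {
    const Loop *CurrL = Worklist.pop_back_val();
    Fn(CurrL); // e.g. drop caches keyed by CurrL, as forgetLoop does
    Worklist.append(CurrL->begin(), CurrL->end());
  }
}

In the patch itself the shared Visited set additionally ensures that an instruction reachable from two loops in the nest is only erased from ValueExprMap once.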
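computeMaxBECountForLT bounds the trip count of a loop of the form for (i = Start; i < End; i += Stride) purely from value ranges, which is what lets howManyLessThans now return a MaxNotTaken even when RHS is not loop-invariant. A worked unsigned i8 instance with made-up range values; the helper below is hypothetical and assumes MinStart <= MaxEnd:

#include <algorithm>
#include <cstdint>

// Mirrors the arithmetic of computeMaxBECountForLT for unsigned i8:
//   Stride' = max(1, min-stride)        // paper over a possible zero stride
//   Limit   = 255 - (Stride' - 1)       // largest end value we can trust
//   MaxEnd  = min(max-end, Limit)
//   MaxBE   = ceil((MaxEnd - MinStart) / Stride')
// e.g. MinStart = 0, MinStride = 2, MaxEnd = 200 gives (200 + 1) / 2 = 100.
static uint8_t maxBECountULT(uint8_t MinStart, uint8_t MinStride,
                             uint8_t MaxEnd) {
  uint32_t Stride = std::max<uint32_t>(1, MinStride);
  uint32_t Limit = 255 - (Stride - 1);
  uint32_t End = std::min<uint32_t>(MaxEnd, Limit);
  return static_cast<uint8_t>((End - MinStart + Stride - 1) / Stride);
}

Clamping End to Limit is what keeps the final add of Stride from wrapping, matching the comment in the hunk about (udiv <constant> 0) and the MAX-expression case.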
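Finally, the FindUsedLoops visitor inside addToLoopUseLists follows the small protocol SCEVTraversal expects: follow() decides whether to descend into a node's operands, isDone() can cut the walk short. Restated as a free function for clarity (loopsUsedBy is a name chosen here; the visitor body matches the patch):

#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
using namespace llvm;

// Collect every loop that appears as the loop of an add-rec within S.
static SmallPtrSet<const Loop *, 8> loopsUsedBy(const SCEV *S) {
  struct FindUsedLoops {
    SmallPtrSet<const Loop *, 8> LoopsUsed;
    bool follow(const SCEV *S) {
      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
        LoopsUsed.insert(AR->getLoop());
      return true; // keep walking into operands
    }
    bool isDone() const { return false; } // never cut the traversal short
  };
  FindUsedLoops F;
  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
  return F.LoopsUsed;
}

With this reverse index populated, forgetLoop can consult LoopUsers[L] and invalidate exactly the memoized expressions that mention L instead of rescanning every cache.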
