Diffstat (limited to 'lib/Analysis/ScalarEvolution.cpp')
-rw-r--r--  lib/Analysis/ScalarEvolution.cpp  933
1 file changed, 645 insertions(+), 288 deletions(-)
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp index 9539fd7c7559..0b8604187121 100644 --- a/lib/Analysis/ScalarEvolution.cpp +++ b/lib/Analysis/ScalarEvolution.cpp @@ -59,12 +59,23 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/FoldingSet.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" @@ -72,28 +83,55 @@ #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/SaveAndRestore.h" #include "llvm/Support/raw_ostream.h" #include <algorithm> +#include <cassert> +#include <climits> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <map> +#include <memory> +#include <tuple> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "scalar-evolution" @@ -115,11 +153,11 @@ MaxBruteForceIterations("scalar-evolution-max-iterations", cl::ReallyHidden, cl::init(100)); // FIXME: Enable this with EXPENSIVE_CHECKS when the test suite is clean. +static cl::opt<bool> VerifySCEV( + "verify-scev", cl::Hidden, + cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); static cl::opt<bool> -VerifySCEV("verify-scev", - cl::desc("Verify ScalarEvolution's backedge taken counts (slow)")); -static cl::opt<bool> - VerifySCEVMap("verify-scev-maps", + VerifySCEVMap("verify-scev-maps", cl::Hidden, cl::desc("Verify no dangling value in ScalarEvolution's " "ExprValueMap (slow)")); @@ -415,9 +453,6 @@ void SCEVUnknown::deleted() { } void SCEVUnknown::allUsesReplacedWith(Value *New) { - // Clear this SCEVUnknown from various maps. 
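[Note: the comparison cache in this hunk changes from a SmallSet of ordered pairs to llvm::EquivalenceClasses, a union-find structure: recording LV ~ RV via unionSets also makes anything previously merged with either side compare equivalent. A standalone sketch of that behavior with a toy union-find (not the LLVM class itself):

    #include <cassert>
    #include <map>

    // Toy union-find standing in for llvm::EquivalenceClasses<T>.
    struct ToyEquivalence {
      std::map<int, int> Parent;
      int find(int X) {
        auto It = Parent.find(X);
        if (It == Parent.end()) { Parent[X] = X; return X; }
        if (It->second != X) It->second = find(It->second); // path compression
        return Parent[X];
      }
      void unionSets(int A, int B) { Parent[find(A)] = find(B); }
      bool isEquivalent(int A, int B) { return find(A) == find(B); }
    };

    int main() {
      ToyEquivalence Eq;
      Eq.unionSets(1, 2);
      Eq.unionSets(2, 3);
      assert(Eq.isEquivalent(1, 3)); // transitive, unlike a pair-set cache
    }

The transitivity is the point: with the old pair cache, having seen {A,B} and {B,C} said nothing about {A,C}.]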
- SE->forgetMemoizedResults(this); - // Remove this SCEVUnknown from the uniquing map. SE->UniqueSCEVs.RemoveNode(this); @@ -514,10 +549,10 @@ bool SCEVUnknown::isOffsetOf(Type *&CTy, Constant *&FieldNo) const { /// Since we do not continue running this routine on expression trees once we /// have seen unequal values, there is no need to track them in the cache. static int -CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, +CompareValueComplexity(EquivalenceClasses<const Value *> &EqCacheValue, const LoopInfo *const LI, Value *LV, Value *RV, unsigned Depth) { - if (Depth > MaxValueCompareDepth || EqCache.count({LV, RV})) + if (Depth > MaxValueCompareDepth || EqCacheValue.isEquivalent(LV, RV)) return 0; // Order pointer values after integer values. This helps SCEVExpander form @@ -577,14 +612,14 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, for (unsigned Idx : seq(0u, LNumOps)) { int Result = - CompareValueComplexity(EqCache, LI, LInst->getOperand(Idx), + CompareValueComplexity(EqCacheValue, LI, LInst->getOperand(Idx), RInst->getOperand(Idx), Depth + 1); if (Result != 0) return Result; } } - EqCache.insert({LV, RV}); + EqCacheValue.unionSets(LV, RV); return 0; } @@ -592,7 +627,8 @@ CompareValueComplexity(SmallSet<std::pair<Value *, Value *>, 8> &EqCache, // than RHS, respectively. A three-way result allows recursive comparisons to be // more efficient. static int CompareSCEVComplexity( - SmallSet<std::pair<const SCEV *, const SCEV *>, 8> &EqCacheSCEV, + EquivalenceClasses<const SCEV *> &EqCacheSCEV, + EquivalenceClasses<const Value *> &EqCacheValue, const LoopInfo *const LI, const SCEV *LHS, const SCEV *RHS, DominatorTree &DT, unsigned Depth = 0) { // Fast-path: SCEVs are uniqued so we can do a quick equality check. @@ -604,7 +640,7 @@ static int CompareSCEVComplexity( if (LType != RType) return (int)LType - (int)RType; - if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.count({LHS, RHS})) + if (Depth > MaxSCEVCompareDepth || EqCacheSCEV.isEquivalent(LHS, RHS)) return 0; // Aside from the getSCEVType() ordering, the particular ordering // isn't very important except that it's beneficial to be consistent, @@ -614,11 +650,10 @@ static int CompareSCEVComplexity( const SCEVUnknown *LU = cast<SCEVUnknown>(LHS); const SCEVUnknown *RU = cast<SCEVUnknown>(RHS); - SmallSet<std::pair<Value *, Value *>, 8> EqCache; - int X = CompareValueComplexity(EqCache, LI, LU->getValue(), RU->getValue(), - Depth + 1); + int X = CompareValueComplexity(EqCacheValue, LI, LU->getValue(), + RU->getValue(), Depth + 1); if (X == 0) - EqCacheSCEV.insert({LHS, RHS}); + EqCacheSCEV.unionSets(LHS, RHS); return X; } @@ -659,14 +694,19 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; + // Compare NoWrap flags. + if (LA->getNoWrapFlags() != RA->getNoWrapFlags()) + return (int)LA->getNoWrapFlags() - (int)RA->getNoWrapFlags(); + // Lexicographically compare. for (unsigned i = 0; i != LNumOps; ++i) { - int X = CompareSCEVComplexity(EqCacheSCEV, LI, LA->getOperand(i), - RA->getOperand(i), DT, Depth + 1); + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LA->getOperand(i), RA->getOperand(i), DT, + Depth + 1); if (X != 0) return X; } - EqCacheSCEV.insert({LHS, RHS}); + EqCacheSCEV.unionSets(LHS, RHS); return 0; } @@ -682,15 +722,18 @@ static int CompareSCEVComplexity( if (LNumOps != RNumOps) return (int)LNumOps - (int)RNumOps; + // Compare NoWrap flags. 
+ if (LC->getNoWrapFlags() != RC->getNoWrapFlags()) + return (int)LC->getNoWrapFlags() - (int)RC->getNoWrapFlags(); + for (unsigned i = 0; i != LNumOps; ++i) { - if (i >= RNumOps) - return 1; - int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(i), - RC->getOperand(i), DT, Depth + 1); + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LC->getOperand(i), RC->getOperand(i), DT, + Depth + 1); if (X != 0) return X; } - EqCacheSCEV.insert({LHS, RHS}); + EqCacheSCEV.unionSets(LHS, RHS); return 0; } @@ -699,14 +742,14 @@ static int CompareSCEVComplexity( const SCEVUDivExpr *RC = cast<SCEVUDivExpr>(RHS); // Lexicographically compare udiv expressions. - int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getLHS(), RC->getLHS(), - DT, Depth + 1); + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getLHS(), + RC->getLHS(), DT, Depth + 1); if (X != 0) return X; - X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getRHS(), RC->getRHS(), DT, - Depth + 1); + X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, LC->getRHS(), + RC->getRHS(), DT, Depth + 1); if (X == 0) - EqCacheSCEV.insert({LHS, RHS}); + EqCacheSCEV.unionSets(LHS, RHS); return X; } @@ -717,10 +760,11 @@ static int CompareSCEVComplexity( const SCEVCastExpr *RC = cast<SCEVCastExpr>(RHS); // Compare cast expressions by operand. - int X = CompareSCEVComplexity(EqCacheSCEV, LI, LC->getOperand(), - RC->getOperand(), DT, Depth + 1); + int X = CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LC->getOperand(), RC->getOperand(), DT, + Depth + 1); if (X == 0) - EqCacheSCEV.insert({LHS, RHS}); + EqCacheSCEV.unionSets(LHS, RHS); return X; } @@ -739,26 +783,26 @@ static int CompareSCEVComplexity( /// results from this routine. In other words, we don't want the results of /// this to depend on where the addresses of various SCEV objects happened to /// land in memory. -/// static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, LoopInfo *LI, DominatorTree &DT) { if (Ops.size() < 2) return; // Noop - SmallSet<std::pair<const SCEV *, const SCEV *>, 8> EqCache; + EquivalenceClasses<const SCEV *> EqCacheSCEV; + EquivalenceClasses<const Value *> EqCacheValue; if (Ops.size() == 2) { // This is the common case, which also happens to be trivially simple. // Special case it. const SCEV *&LHS = Ops[0], *&RHS = Ops[1]; - if (CompareSCEVComplexity(EqCache, LI, RHS, LHS, DT) < 0) + if (CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, RHS, LHS, DT) < 0) std::swap(LHS, RHS); return; } // Do the rough sort by complexity. std::stable_sort(Ops.begin(), Ops.end(), - [&EqCache, LI, &DT](const SCEV *LHS, const SCEV *RHS) { - return - CompareSCEVComplexity(EqCache, LI, LHS, RHS, DT) < 0; + [&](const SCEV *LHS, const SCEV *RHS) { + return CompareSCEVComplexity(EqCacheSCEV, EqCacheValue, LI, + LHS, RHS, DT) < 0; }); // Now that we are sorted by complexity, group elements of the same @@ -785,14 +829,16 @@ static void GroupByComplexity(SmallVectorImpl<const SCEV *> &Ops, // Returns the size of the SCEV S. static inline int sizeOfSCEV(const SCEV *S) { struct FindSCEVSize { - int Size; - FindSCEVSize() : Size(0) {} + int Size = 0; + + FindSCEVSize() = default; bool follow(const SCEV *S) { ++Size; // Keep looking at all operands of S. 
return true; } + bool isDone() const { return false; } @@ -1032,7 +1078,7 @@ private: const SCEV *Denominator, *Quotient, *Remainder, *Zero, *One; }; -} +} // end anonymous namespace //===----------------------------------------------------------------------===// // Simple SCEV method implementations @@ -1157,7 +1203,6 @@ static const SCEV *BinomialCoefficient(const SCEV *It, unsigned K, /// A*BC(It, 0) + B*BC(It, 1) + C*BC(It, 2) + D*BC(It, 3) /// /// where BC(It, k) stands for binomial coefficient. -/// const SCEV *SCEVAddRecExpr::evaluateAtIteration(const SCEV *It, ScalarEvolution &SE) const { const SCEV *Result = getStart(); @@ -1256,6 +1301,7 @@ const SCEV *ScalarEvolution::getTruncateExpr(const SCEV *Op, SCEV *S = new (SCEVAllocator) SCEVTruncateExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1343,7 +1389,8 @@ struct ExtendOpTraits<SCEVZeroExtendExpr> : public ExtendOpTraitsBase { const ExtendOpTraitsBase::GetExtendExprTy ExtendOpTraits< SCEVZeroExtendExpr>::GetExtendExpr = &ScalarEvolution::getZeroExtendExpr; -} + +} // end anonymous namespace // The recurrence AR has been shown to have no signed/unsigned wrap or something // close to it. Typically, if we can prove NSW/NUW for AR, then we can just as @@ -1473,7 +1520,6 @@ static const SCEV *getExtendAddRecStart(const SCEVAddRecExpr *AR, Type *Ty, // // In the current context, S is `Start`, X is `Step`, Ext is `ExtendOpTy` and T // is `Delta` (defined below). -// template <typename ExtendOpTy> bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, const SCEV *Step, @@ -1484,7 +1530,6 @@ bool ScalarEvolution::proveNoWrapByVaryingStart(const SCEV *Start, // time here. It is correct (but more expensive) to continue with a // non-constant `Start` and do a general SCEV subtraction to compute // `PreStart` below. - // const SCEVConstant *StartC = dyn_cast<SCEVConstant>(Start); if (!StartC) return false; @@ -1547,6 +1592,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1733,6 +1779,7 @@ ScalarEvolution::getZeroExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV *S = new (SCEVAllocator) SCEVZeroExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1770,6 +1817,7 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -1981,12 +2029,12 @@ ScalarEvolution::getSignExtendExpr(const SCEV *Op, Type *Ty, unsigned Depth) { SCEV *S = new (SCEVAllocator) SCEVSignExtendExpr(ID.Intern(SCEVAllocator), Op, Ty); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } /// getAnyExtendExpr - Return a SCEV for the given operand extended with /// unspecified bits out to the given type. -/// const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, Type *Ty) { assert(getTypeSizeInBits(Op->getType()) < getTypeSizeInBits(Ty) && @@ -2057,7 +2105,6 @@ const SCEV *ScalarEvolution::getAnyExtendExpr(const SCEV *Op, /// may be exposed. This helps getAddRecExpr short-circuit extra work in /// the common case where no interesting opportunities are present, and /// is also used as a check to avoid infinite recursion. 
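[Note: evaluateAtIteration, documented earlier in this hunk, evaluates a chrec {A,+,B,+,C} at iteration It as A*BC(It,0) + B*BC(It,1) + C*BC(It,2). A quick numeric check of that closed form against directly iterating the recurrence, with made-up constants:

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t A = 7, B = 5, C = 2; // the chrec {7,+,5,+,2}
      // Iterate the recurrence directly: the step is itself a chrec {B,+,C}.
      int64_t V = A, Step = B;
      for (int64_t I = 1; I <= 6; ++I) { V += Step; Step += C; }
      // Binomial form at It = 6: A*C(6,0) + B*C(6,1) + C*C(6,2).
      int64_t It = 6;
      assert(V == A + B * It + C * (It * (It - 1) / 2)); // both give 67
    }
]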
-/// static bool CollectAddOperandsWithScales(DenseMap<const SCEV *, APInt> &M, SmallVectorImpl<const SCEV *> &NewOps, @@ -2132,7 +2179,8 @@ StrengthenNoWrapFlags(ScalarEvolution *SE, SCEVTypes Type, const SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { using namespace std::placeholders; - typedef OverflowingBinaryOperator OBO; + + using OBO = OverflowingBinaryOperator; bool CanAnalyze = Type == scAddExpr || Type == scAddRecExpr || Type == scMulExpr; @@ -2306,12 +2354,23 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, // Check for truncates. If all the operands are truncated from the same // type, see if factoring out the truncate would permit the result to be - // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n) + // folded. eg., n*trunc(x) + m*trunc(y) --> trunc(trunc(m)*x + trunc(n)*y) // if the contents of the resulting outer trunc fold to something simple. - for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) { - const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]); - Type *DstType = Trunc->getType(); - Type *SrcType = Trunc->getOperand()->getType(); + auto FindTruncSrcType = [&]() -> Type * { + // We're ultimately looking to fold an addrec of truncs and muls of only + // constants and truncs, so if we find any other types of SCEV + // as operands of the addrec then we bail and return nullptr here. + // Otherwise, we return the type of the operand of a trunc that we find. + if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx])) + return T->getOperand()->getType(); + if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) { + const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1); + if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp)) + return T->getOperand()->getType(); + } + return nullptr; + }; + if (auto *SrcType = FindTruncSrcType()) { SmallVector<const SCEV *, 8> LargeOps; bool Ok = true; // Check all the operands to see if they can be represented in the @@ -2354,7 +2413,7 @@ const SCEV *ScalarEvolution::getAddExpr(SmallVectorImpl<const SCEV *> &Ops, const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1); // If it folds to something simple, use it. Otherwise, don't. if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold)) - return getTruncateExpr(Fold, DstType); + return getTruncateExpr(Fold, Ty); } } @@ -2608,8 +2667,8 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, SCEV::NoWrapFlags Flags) { FoldingSetNodeID ID; ID.AddInteger(scAddExpr); - for (unsigned i = 0, e = Ops.size(); i != e; ++i) - ID.AddPointer(Ops[i]); + for (const SCEV *Op : Ops) + ID.AddPointer(Op); void *IP = nullptr; SCEVAddExpr *S = static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP)); @@ -2619,6 +2678,7 @@ ScalarEvolution::getOrCreateAddExpr(SmallVectorImpl<const SCEV *> &Ops, S = new (SCEVAllocator) SCEVAddExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -2640,6 +2700,7 @@ ScalarEvolution::getOrCreateMulExpr(SmallVectorImpl<const SCEV *> &Ops, S = new (SCEVAllocator) SCEVMulExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -2679,20 +2740,24 @@ static uint64_t Choose(uint64_t n, uint64_t k, bool &Overflow) { /// Determine if any of the operands in this SCEV are a constant or if /// any of the add or multiply expressions in this SCEV contain a constant. 
-static bool containsConstantSomewhere(const SCEV *StartExpr) { - SmallVector<const SCEV *, 4> Ops; - Ops.push_back(StartExpr); - while (!Ops.empty()) { - const SCEV *CurrentExpr = Ops.pop_back_val(); - if (isa<SCEVConstant>(*CurrentExpr)) - return true; +static bool containsConstantInAddMulChain(const SCEV *StartExpr) { + struct FindConstantInAddMulChain { + bool FoundConstant = false; - if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) { - const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr); - Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end()); + bool follow(const SCEV *S) { + FoundConstant |= isa<SCEVConstant>(S); + return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S); } - } - return false; + + bool isDone() const { + return FoundConstant; + } + }; + + FindConstantInAddMulChain F; + SCEVTraversal<FindConstantInAddMulChain> ST(F); + ST.visitAll(StartExpr); + return F.FoundConstant; } /// Get a canonical multiply expression, or something simpler if possible. @@ -2729,7 +2794,11 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, // If any of Add's ops are Adds or Muls with a constant, // apply this transformation as well. if (Add->getNumOperands() == 2) - if (containsConstantSomewhere(Add)) + // TODO: There are some cases where this transformation is not + // profitable, for example: + // Add = (C0 + X) * Y + Z. + // Maybe the scope of this transformation should be narrowed down. + if (containsConstantInAddMulChain(Add)) return getAddExpr(getMulExpr(LHSC, Add->getOperand(0), SCEV::FlagAnyWrap, Depth + 1), getMulExpr(LHSC, Add->getOperand(1), @@ -2941,6 +3010,34 @@ const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops, return getOrCreateMulExpr(Ops, Flags); } +/// Represents an unsigned remainder expression based on unsigned division. +const SCEV *ScalarEvolution::getURemExpr(const SCEV *LHS, + const SCEV *RHS) { + assert(getEffectiveSCEVType(LHS->getType()) == + getEffectiveSCEVType(RHS->getType()) && + "SCEVURemExpr operand types don't match!"); + + // Short-circuit easy cases + if (const SCEVConstant *RHSC = dyn_cast<SCEVConstant>(RHS)) { + // If constant is one, the result is trivial + if (RHSC->getValue()->isOne()) + return getZero(LHS->getType()); // X urem 1 --> 0 + + // If constant is a power of two, fold into a zext(trunc(LHS)). + if (RHSC->getAPInt().isPowerOf2()) { + Type *FullTy = LHS->getType(); + Type *TruncTy = + IntegerType::get(getContext(), RHSC->getAPInt().logBase2()); + return getZeroExtendExpr(getTruncateExpr(LHS, TruncTy), FullTy); + } + } + + // Fallback to %a == %x urem %y == %x -<nuw> ((%x udiv %y) *<nuw> %y) + const SCEV *UDiv = getUDivExpr(LHS, RHS); + const SCEV *Mult = getMulExpr(UDiv, RHS, SCEV::FlagNUW); + return getMinusSCEV(LHS, Mult, SCEV::FlagNUW); +} + /// Get a canonical unsigned division expression, or something simpler if /// possible. 
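[Note: the three rewrites in the new getURemExpr above are ordinary unsigned-remainder identities; checked here on plain machine integers (arbitrary values, nothing SCEV-specific):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint32_t X = 1234567;
      assert(X % 1u == 0u);             // X urem 1 --> 0
      assert(X % 8u == (X & 7u));       // power of 2: keep the low log2 bits,
                                        // i.e. zext(trunc X to i3) back out
      uint32_t Y = 37;
      assert(X % Y == X - (X / Y) * Y); // fallback: X - (X udiv Y) * Y
    }
]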
const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, @@ -3056,6 +3153,7 @@ const SCEV *ScalarEvolution::getUDivExpr(const SCEV *LHS, SCEV *S = new (SCEVAllocator) SCEVUDivExpr(ID.Intern(SCEVAllocator), LHS, RHS); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -3236,6 +3334,7 @@ ScalarEvolution::getAddRecExpr(SmallVectorImpl<const SCEV *> &Operands, S = new (SCEVAllocator) SCEVAddRecExpr(ID.Intern(SCEVAllocator), O, Operands.size(), L); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); } S->setNoWrapFlags(Flags); return S; @@ -3391,6 +3490,7 @@ ScalarEvolution::getSMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { SCEV *S = new (SCEVAllocator) SCEVSMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -3492,6 +3592,7 @@ ScalarEvolution::getUMaxExpr(SmallVectorImpl<const SCEV *> &Ops) { SCEV *S = new (SCEVAllocator) SCEVUMaxExpr(ID.Intern(SCEVAllocator), O, Ops.size()); UniqueSCEVs.InsertNode(S, IP); + addToLoopUseLists(S); return S; } @@ -3714,7 +3815,6 @@ const SCEV *ScalarEvolution::getExistingSCEV(Value *V) { } /// Return a SCEV corresponding to -V = -1*V -/// const SCEV *ScalarEvolution::getNegativeSCEV(const SCEV *V, SCEV::NoWrapFlags Flags) { if (const SCEVConstant *VC = dyn_cast<SCEVConstant>(V)) @@ -3957,6 +4057,7 @@ void ScalarEvolution::forgetSymbolicName(Instruction *PN, const SCEV *SymName) { } namespace { + class SCEVInitRewriter : public SCEVRewriteVisitor<SCEVInitRewriter> { public: static const SCEV *rewrite(const SCEV *S, const Loop *L, @@ -3966,9 +4067,6 @@ public: return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } - SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) - : SCEVRewriteVisitor(SE), L(L), Valid(true) {} - const SCEV *visitUnknown(const SCEVUnknown *Expr) { if (!SE.isLoopInvariant(Expr, L)) Valid = false; @@ -3986,10 +4084,93 @@ public: bool isValid() { return Valid; } private: + explicit SCEVInitRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L) {} + + const Loop *L; + bool Valid = true; +}; + +/// This class evaluates the compare condition by matching it against the +/// condition of loop latch. If there is a match we assume a true value +/// for the condition while building SCEV nodes. +class SCEVBackedgeConditionFolder + : public SCEVRewriteVisitor<SCEVBackedgeConditionFolder> { +public: + static const SCEV *rewrite(const SCEV *S, const Loop *L, + ScalarEvolution &SE) { + bool IsPosBECond = false; + Value *BECond = nullptr; + if (BasicBlock *Latch = L->getLoopLatch()) { + BranchInst *BI = dyn_cast<BranchInst>(Latch->getTerminator()); + if (BI && BI->isConditional()) { + assert(BI->getSuccessor(0) != BI->getSuccessor(1) && + "Both outgoing branches should not target same header!"); + BECond = BI->getCondition(); + IsPosBECond = BI->getSuccessor(0) == L->getHeader(); + } else { + return S; + } + } + SCEVBackedgeConditionFolder Rewriter(L, BECond, IsPosBECond, SE); + return Rewriter.visit(S); + } + + const SCEV *visitUnknown(const SCEVUnknown *Expr) { + const SCEV *Result = Expr; + bool InvariantF = SE.isLoopInvariant(Expr, L); + + if (!InvariantF) { + Instruction *I = cast<Instruction>(Expr->getValue()); + switch (I->getOpcode()) { + case Instruction::Select: { + SelectInst *SI = cast<SelectInst>(I); + Optional<const SCEV *> Res = + compareWithBackedgeCondition(SI->getCondition()); + if (Res.hasValue()) { + bool IsOne = cast<SCEVConstant>(Res.getValue())->getValue()->isOne(); + Result = SE.getSCEV(IsOne ? 
SI->getTrueValue() : SI->getFalseValue()); + } + break; + } + default: { + Optional<const SCEV *> Res = compareWithBackedgeCondition(I); + if (Res.hasValue()) + Result = Res.getValue(); + break; + } + } + } + return Result; + } + +private: + explicit SCEVBackedgeConditionFolder(const Loop *L, Value *BECond, + bool IsPosBECond, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L), BackedgeCond(BECond), + IsPositiveBECond(IsPosBECond) {} + + Optional<const SCEV *> compareWithBackedgeCondition(Value *IC); + const Loop *L; - bool Valid; + /// Loop back condition. + Value *BackedgeCond = nullptr; + /// Set to true if loop back is on positive branch condition. + bool IsPositiveBECond; }; +Optional<const SCEV *> +SCEVBackedgeConditionFolder::compareWithBackedgeCondition(Value *IC) { + + // If value matches the backedge condition for loop latch, + // then return a constant evolution node based on loopback + // branch taken. + if (BackedgeCond == IC) + return IsPositiveBECond ? SE.getOne(Type::getInt1Ty(SE.getContext())) + : SE.getZero(Type::getInt1Ty(SE.getContext())); + return None; +} + class SCEVShiftRewriter : public SCEVRewriteVisitor<SCEVShiftRewriter> { public: static const SCEV *rewrite(const SCEV *S, const Loop *L, @@ -3999,9 +4180,6 @@ public: return Rewriter.isValid() ? Result : SE.getCouldNotCompute(); } - SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) - : SCEVRewriteVisitor(SE), L(L), Valid(true) {} - const SCEV *visitUnknown(const SCEVUnknown *Expr) { // Only allow AddRecExprs for this loop. if (!SE.isLoopInvariant(Expr, L)) @@ -4015,12 +4193,17 @@ public: Valid = false; return Expr; } + bool isValid() { return Valid; } private: + explicit SCEVShiftRewriter(const Loop *L, ScalarEvolution &SE) + : SCEVRewriteVisitor(SE), L(L) {} + const Loop *L; - bool Valid; + bool Valid = true; }; + } // end anonymous namespace SCEV::NoWrapFlags @@ -4028,7 +4211,8 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { if (!AR->isAffine()) return SCEV::FlagAnyWrap; - typedef OverflowingBinaryOperator OBO; + using OBO = OverflowingBinaryOperator; + SCEV::NoWrapFlags Result = SCEV::FlagAnyWrap; if (!AR->hasNoSignedWrap()) { @@ -4055,6 +4239,7 @@ ScalarEvolution::proveNoWrapViaConstantRanges(const SCEVAddRecExpr *AR) { } namespace { + /// Represents an abstract binary operation. This may exist as a /// normal instruction or constant expression, or may have been /// derived from an expression tree. @@ -4062,16 +4247,16 @@ struct BinaryOp { unsigned Opcode; Value *LHS; Value *RHS; - bool IsNSW; - bool IsNUW; + bool IsNSW = false; + bool IsNUW = false; /// Op is set if this BinaryOp corresponds to a concrete LLVM instruction or /// constant expression. - Operator *Op; + Operator *Op = nullptr; explicit BinaryOp(Operator *Op) : Opcode(Op->getOpcode()), LHS(Op->getOperand(0)), RHS(Op->getOperand(1)), - IsNSW(false), IsNUW(false), Op(Op) { + Op(Op) { if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(Op)) { IsNSW = OBO->hasNoSignedWrap(); IsNUW = OBO->hasNoUnsignedWrap(); @@ -4080,11 +4265,10 @@ struct BinaryOp { explicit BinaryOp(unsigned Opcode, Value *LHS, Value *RHS, bool IsNSW = false, bool IsNUW = false) - : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW), - Op(nullptr) {} + : Opcode(Opcode), LHS(LHS), RHS(RHS), IsNSW(IsNSW), IsNUW(IsNUW) {} }; -} +} // end anonymous namespace /// Try to map \p V into a BinaryOp, and return \c None on failure. 
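[Note: what SCEVBackedgeConditionFolder above exploits: a select inside the loop whose condition is the very value tested by the latch branch always takes the "loop continues" arm on backedge-taken iterations. A minimal stand-in, using a hypothetical Select struct rather than LLVM IR:

    #include <cassert>
    #include <optional>

    struct Select { const void *Cond; int TrueV, FalseV; };

    // If the select's condition is the value tested by the latch branch, it
    // evaluates to the branch-taken arm on every iteration that continues the
    // loop (true arm iff the backedge is the true successor).
    std::optional<int> foldAgainstBackedge(const Select &S, const void *BECond,
                                           bool IsPositiveBECond) {
      if (S.Cond == BECond)
        return IsPositiveBECond ? S.TrueV : S.FalseV;
      return std::nullopt; // condition unrelated to the backedge: no fold
    }

    int main() {
      int Cond = 1; // stands in for the shared i1 condition value
      Select S{&Cond, 10, 20};
      assert(*foldAgainstBackedge(S, &Cond, /*IsPositiveBECond=*/true) == 10);
    }
]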
static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { @@ -4101,6 +4285,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { case Instruction::Sub: case Instruction::Mul: case Instruction::UDiv: + case Instruction::URem: case Instruction::And: case Instruction::Or: case Instruction::AShr: @@ -4145,7 +4330,7 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { if (auto *F = CI->getCalledFunction()) switch (F->getIntrinsicID()) { case Intrinsic::sadd_with_overflow: - case Intrinsic::uadd_with_overflow: { + case Intrinsic::uadd_with_overflow: if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) return BinaryOp(Instruction::Add, CI->getArgOperand(0), CI->getArgOperand(1)); @@ -4161,13 +4346,21 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { return BinaryOp(Instruction::Add, CI->getArgOperand(0), CI->getArgOperand(1), /* IsNSW = */ false, /* IsNUW*/ true); - } - case Intrinsic::ssub_with_overflow: case Intrinsic::usub_with_overflow: - return BinaryOp(Instruction::Sub, CI->getArgOperand(0), - CI->getArgOperand(1)); + if (!isOverflowIntrinsicNoWrap(cast<IntrinsicInst>(CI), DT)) + return BinaryOp(Instruction::Sub, CI->getArgOperand(0), + CI->getArgOperand(1)); + // The same reasoning as sadd/uadd above. + if (F->getIntrinsicID() == Intrinsic::ssub_with_overflow) + return BinaryOp(Instruction::Sub, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ true, + /* IsNUW = */ false); + else + return BinaryOp(Instruction::Sub, CI->getArgOperand(0), + CI->getArgOperand(1), /* IsNSW = */ false, + /* IsNUW = */ true); case Intrinsic::smul_with_overflow: case Intrinsic::umul_with_overflow: return BinaryOp(Instruction::Mul, CI->getArgOperand(0), @@ -4184,28 +4377,27 @@ static Optional<BinaryOp> MatchBinaryOp(Value *V, DominatorTree &DT) { return None; } -/// Helper function to createAddRecFromPHIWithCasts. We have a phi +/// Helper function to createAddRecFromPHIWithCasts. We have a phi /// node whose symbolic (unknown) SCEV is \p SymbolicPHI, which is updated via -/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the -/// way. This function checks if \p Op, an operand of this SCEVAddExpr, +/// the loop backedge by a SCEVAddExpr, possibly also with a few casts on the +/// way. This function checks if \p Op, an operand of this SCEVAddExpr, /// follows one of the following patterns: /// Op == (SExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) /// Op == (ZExt ix (Trunc iy (%SymbolicPHI) to ix) to iy) /// If the SCEV expression of \p Op conforms with one of the expected patterns /// we return the type of the truncation operation, and indicate whether the -/// truncated type should be treated as signed/unsigned by setting +/// truncated type should be treated as signed/unsigned by setting /// \p Signed to true/false, respectively. static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, bool &Signed, ScalarEvolution &SE) { - - // The case where Op == SymbolicPHI (that is, with no type conversions on - // the way) is handled by the regular add recurrence creating logic and + // The case where Op == SymbolicPHI (that is, with no type conversions on + // the way) is handled by the regular add recurrence creating logic and // would have already been triggered in createAddRecForPHI. 
Reaching it here - // means that createAddRecFromPHI had failed for this PHI before (e.g., + // means that createAddRecFromPHI had failed for this PHI before (e.g., // because one of the other operands of the SCEVAddExpr updating this PHI is - // not invariant). + // not invariant). // - // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in + // Here we look for the case where Op = (ext(trunc(SymbolicPHI))), and in // this case predicates that allow us to prove that Op == SymbolicPHI will // be added. if (Op == SymbolicPHI) @@ -4228,7 +4420,7 @@ static Type *isSimpleCastedPHI(const SCEV *Op, const SCEVUnknown *SymbolicPHI, const SCEV *X = Trunc->getOperand(); if (X != SymbolicPHI) return nullptr; - Signed = SExt ? true : false; + Signed = SExt != nullptr; return Trunc->getType(); } @@ -4257,7 +4449,7 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { // It will visitMul->visitAdd->visitSExt->visitTrunc->visitUnknown(%X), // and call this function with %SymbolicPHI = %X. // -// The analysis will find that the value coming around the backedge has +// The analysis will find that the value coming around the backedge has // the following SCEV: // BEValue = ((sext i32 (trunc i64 %X to i32) to i64) + %Step) // Upon concluding that this matches the desired pattern, the function @@ -4270,21 +4462,21 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { // The returned pair means that SymbolicPHI can be rewritten into NewAddRec // under the predicates {P1,P2,P3}. // This predicated rewrite will be cached in PredicatedSCEVRewrites: -// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)} +// PredicatedSCEVRewrites[{%X,L}] = {NewAddRec, {P1,P2,P3)} // // TODO's: // // 1) Extend the Induction descriptor to also support inductions that involve -// casts: When needed (namely, when we are called in the context of the -// vectorizer induction analysis), a Set of cast instructions will be +// casts: When needed (namely, when we are called in the context of the +// vectorizer induction analysis), a Set of cast instructions will be // populated by this method, and provided back to isInductionPHI. This is // needed to allow the vectorizer to properly record them to be ignored by // the cost model and to avoid vectorizing them (otherwise these casts, -// which are redundant under the runtime overflow checks, will be -// vectorized, which can be costly). +// which are redundant under the runtime overflow checks, will be +// vectorized, which can be costly). // // 2) Support additional induction/PHISCEV patterns: We also want to support -// inductions where the sext-trunc / zext-trunc operations (partly) occur +// inductions where the sext-trunc / zext-trunc operations (partly) occur // after the induction update operation (the induction increment): // // (Trunc iy (SExt/ZExt ix (%SymbolicPHI + InvariantAccum) to iy) to ix) @@ -4294,17 +4486,16 @@ static const Loop *isIntegerLoopHeaderPHI(const PHINode *PN, LoopInfo &LI) { // which correspond to a phi->trunc->add->sext/zext->phi update chain. // // 3) Outline common code with createAddRecFromPHI to avoid duplication. 
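[Note: a concrete run of the pattern analyzed here, with made-up small constants: when Start and every intermediate value fit in the narrow type (so predicates P1-P3 hold), the sext(trunc(...)) + Accum update reproduces the plain addrec Start + (i+1)*Accum:

    #include <cassert>
    #include <cstdint>

    int main() {
      int64_t Start = 5, Accum = 3; // both representable in the narrow i8
      int64_t Expr = Start;
      for (int i = 0; i < 10; ++i) {
        int64_t Narrow = (int8_t)Expr; // trunc to i8, then sext back to i64
        assert(Narrow == Expr);        // P1: no narrow overflow so far
        Expr = Narrow + Accum;         // the casted PHI update
        assert(Expr == Start + (i + 1) * Accum); // matches {Start,+,Accum}
      }
    }
]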
-// Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI) { SmallVector<const SCEVPredicate *, 3> Predicates; - // *** Part1: Analyze if we have a phi-with-cast pattern for which we can + // *** Part1: Analyze if we have a phi-with-cast pattern for which we can // return an AddRec expression under some predicate. - + auto *PN = cast<PHINode>(SymbolicPHI->getValue()); const Loop *L = isIntegerLoopHeaderPHI(PN, LI); - assert (L && "Expecting an integer loop header phi"); + assert(L && "Expecting an integer loop header phi"); // The loop may have multiple entrances or multiple exits; we can analyze // this phi as an addrec if it has a unique entry value and a unique @@ -4339,12 +4530,12 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI return None; // If there is a single occurrence of the symbolic value, possibly - // casted, replace it with a recurrence. + // casted, replace it with a recurrence. unsigned FoundIndex = Add->getNumOperands(); Type *TruncTy = nullptr; bool Signed; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) - if ((TruncTy = + if ((TruncTy = isSimpleCastedPHI(Add->getOperand(i), SymbolicPHI, Signed, *this))) if (FoundIndex == e) { FoundIndex = i; @@ -4366,77 +4557,122 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI if (!isLoopInvariant(Accum, L)) return None; - - // *** Part2: Create the predicates + // *** Part2: Create the predicates // Analysis was successful: we have a phi-with-cast pattern for which we // can return an AddRec expression under the following predicates: // // P1: A Wrap predicate that guarantees that Trunc(Start) + i*Trunc(Accum) // fits within the truncated type (does not overflow) for i = 0 to n-1. 
- // P2: An Equal predicate that guarantees that + // P2: An Equal predicate that guarantees that // Start = (Ext ix (Trunc iy (Start) to ix) to iy) - // P3: An Equal predicate that guarantees that + // P3: An Equal predicate that guarantees that // Accum = (Ext ix (Trunc iy (Accum) to ix) to iy) // - // As we next prove, the above predicates guarantee that: + // As we next prove, the above predicates guarantee that: // Start + i*Accum = (Ext ix (Trunc iy ( Start + i*Accum ) to ix) to iy) // // // More formally, we want to prove that: - // Expr(i+1) = Start + (i+1) * Accum - // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // Expr(i+1) = Start + (i+1) * Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum // // Given that: - // 1) Expr(0) = Start - // 2) Expr(1) = Start + Accum + // 1) Expr(0) = Start + // 2) Expr(1) = Start + Accum // = (Ext ix (Trunc iy (Start) to ix) to iy) + Accum :: from P2 // 3) Induction hypothesis (step i): - // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + // Expr(i) = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum // // Proof: // Expr(i+1) = // = Start + (i+1)*Accum // = (Start + i*Accum) + Accum - // = Expr(i) + Accum - // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum + // = Expr(i) + Accum + // = (Ext ix (Trunc iy (Expr(i-1)) to ix) to iy) + Accum + Accum // :: from step i // - // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum + // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) + Accum + Accum // // = (Ext ix (Trunc iy (Start + (i-1)*Accum) to ix) to iy) // + (Ext ix (Trunc iy (Accum) to ix) to iy) // + Accum :: from P3 // - // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) + // = (Ext ix (Trunc iy ((Start + (i-1)*Accum) + Accum) to ix) to iy) // + Accum :: from P1: Ext(x)+Ext(y)=>Ext(x+y) // // = (Ext ix (Trunc iy (Start + i*Accum) to ix) to iy) + Accum - // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum + // = (Ext ix (Trunc iy (Expr(i)) to ix) to iy) + Accum // // By induction, the same applies to all iterations 1<=i<n: // - + // Create a truncated addrec for which we will add a no overflow check (P1). const SCEV *StartVal = getSCEV(StartValueV); - const SCEV *PHISCEV = + const SCEV *PHISCEV = getAddRecExpr(getTruncateExpr(StartVal, TruncTy), - getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); - const auto *AR = cast<SCEVAddRecExpr>(PHISCEV); + getTruncateExpr(Accum, TruncTy), L, SCEV::FlagAnyWrap); - SCEVWrapPredicate::IncrementWrapFlags AddedFlags = - Signed ? SCEVWrapPredicate::IncrementNSSW - : SCEVWrapPredicate::IncrementNUSW; - const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); - Predicates.push_back(AddRecPred); + // PHISCEV can be either a SCEVConstant or a SCEVAddRecExpr. + // ex: If truncated Accum is 0 and StartVal is a constant, then PHISCEV + // will be constant. + // + // If PHISCEV is a constant, then P1 degenerates into P2 or P3, so we don't + // add P1. + if (const auto *AR = dyn_cast<SCEVAddRecExpr>(PHISCEV)) { + SCEVWrapPredicate::IncrementWrapFlags AddedFlags = + Signed ? 
SCEVWrapPredicate::IncrementNSSW + : SCEVWrapPredicate::IncrementNUSW; + const SCEVPredicate *AddRecPred = getWrapPredicate(AR, AddedFlags); + Predicates.push_back(AddRecPred); + } // Create the Equal Predicates P2,P3: - auto AppendPredicate = [&](const SCEV *Expr) -> void { - assert (isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); + + // It is possible that the predicates P2 and/or P3 are computable at + // compile time due to StartVal and/or Accum being constants. + // If either one is, then we can check that now and escape if either P2 + // or P3 is false. + + // Construct the extended SCEV: (Ext ix (Trunc iy (Expr) to ix) to iy) + // for each of StartVal and Accum + auto getExtendedExpr = [&](const SCEV *Expr, + bool CreateSignExtend) -> const SCEV * { + assert(isLoopInvariant(Expr, L) && "Expr is expected to be invariant"); const SCEV *TruncatedExpr = getTruncateExpr(Expr, TruncTy); const SCEV *ExtendedExpr = - Signed ? getSignExtendExpr(TruncatedExpr, Expr->getType()) - : getZeroExtendExpr(TruncatedExpr, Expr->getType()); + CreateSignExtend ? getSignExtendExpr(TruncatedExpr, Expr->getType()) + : getZeroExtendExpr(TruncatedExpr, Expr->getType()); + return ExtendedExpr; + }; + + // Given: + // ExtendedExpr = (Ext ix (Trunc iy (Expr) to ix) to iy + // = getExtendedExpr(Expr) + // Determine whether the predicate P: Expr == ExtendedExpr + // is known to be false at compile time + auto PredIsKnownFalse = [&](const SCEV *Expr, + const SCEV *ExtendedExpr) -> bool { + return Expr != ExtendedExpr && + isKnownPredicate(ICmpInst::ICMP_NE, Expr, ExtendedExpr); + }; + + const SCEV *StartExtended = getExtendedExpr(StartVal, Signed); + if (PredIsKnownFalse(StartVal, StartExtended)) { + DEBUG(dbgs() << "P2 is compile-time false\n";); + return None; + } + + // The Step is always Signed (because the overflow checks are either + // NSSW or NUSW) + const SCEV *AccumExtended = getExtendedExpr(Accum, /*CreateSignExtend=*/true); + if (PredIsKnownFalse(Accum, AccumExtended)) { + DEBUG(dbgs() << "P3 is compile-time false\n";); + return None; + } + + auto AppendPredicate = [&](const SCEV *Expr, + const SCEV *ExtendedExpr) -> void { if (Expr != ExtendedExpr && !isKnownPredicate(ICmpInst::ICMP_EQ, Expr, ExtendedExpr)) { const SCEVPredicate *Pred = getEqualPredicate(Expr, ExtendedExpr); @@ -4444,14 +4680,14 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI Predicates.push_back(Pred); } }; - - AppendPredicate(StartVal); - AppendPredicate(Accum); - + + AppendPredicate(StartVal, StartExtended); + AppendPredicate(Accum, AccumExtended); + // *** Part3: Predicates are ready. Now go ahead and create the new addrec in // which the casts had been folded away. The caller can rewrite SymbolicPHI // into NewAR if it will also add the runtime overflow checks specified in - // Predicates. + // Predicates. 
auto *NewAR = getAddRecExpr(StartVal, Accum, L, SCEV::FlagAnyWrap); std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> PredRewrite = @@ -4463,7 +4699,6 @@ ScalarEvolution::createAddRecFromPHIWithCastsImpl(const SCEVUnknown *SymbolicPHI Optional<std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>>> ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { - auto *PN = cast<PHINode>(SymbolicPHI->getValue()); const Loop *L = isIntegerLoopHeaderPHI(PN, LI); if (!L) @@ -4475,7 +4710,7 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { std::pair<const SCEV *, SmallVector<const SCEVPredicate *, 3>> Rewrite = I->second; // Analysis was done before and failed to create an AddRec: - if (Rewrite.first == SymbolicPHI) + if (Rewrite.first == SymbolicPHI) return None; // Analysis was done before and succeeded to create an AddRec under // a predicate: @@ -4497,6 +4732,30 @@ ScalarEvolution::createAddRecFromPHIWithCasts(const SCEVUnknown *SymbolicPHI) { return Rewrite; } +// FIXME: This utility is currently required because the Rewriter currently +// does not rewrite this expression: +// {0, +, (sext ix (trunc iy to ix) to iy)} +// into {0, +, %step}, +// even when the following Equal predicate exists: +// "%step == (sext ix (trunc iy to ix) to iy)". +bool PredicatedScalarEvolution::areAddRecsEqualWithPreds( + const SCEVAddRecExpr *AR1, const SCEVAddRecExpr *AR2) const { + if (AR1 == AR2) + return true; + + auto areExprsEqual = [&](const SCEV *Expr1, const SCEV *Expr2) -> bool { + if (Expr1 != Expr2 && !Preds.implies(SE.getEqualPredicate(Expr1, Expr2)) && + !Preds.implies(SE.getEqualPredicate(Expr2, Expr1))) + return false; + return true; + }; + + if (!areExprsEqual(AR1->getStart(), AR2->getStart()) || + !areExprsEqual(AR1->getStepRecurrence(SE), AR2->getStepRecurrence(SE))) + return false; + return true; +} + /// A helper function for createAddRecFromPHI to handle simple cases. /// /// This function tries to find an AddRec expression for the simplest (yet most @@ -4612,7 +4871,8 @@ const SCEV *ScalarEvolution::createAddRecFromPHI(PHINode *PN) { SmallVector<const SCEV *, 8> Ops; for (unsigned i = 0, e = Add->getNumOperands(); i != e; ++i) if (i != FoundIndex) - Ops.push_back(Add->getOperand(i)); + Ops.push_back(SCEVBackedgeConditionFolder::rewrite(Add->getOperand(i), + L, *this)); const SCEV *Accum = getAddExpr(Ops); // This is not a valid addrec if the step amount is varying each @@ -5599,7 +5859,7 @@ bool ScalarEvolution::isAddRecNeverPoison(const Instruction *I, const Loop *L) { ScalarEvolution::LoopProperties ScalarEvolution::getLoopProperties(const Loop *L) { - typedef ScalarEvolution::LoopProperties LoopProperties; + using LoopProperties = ScalarEvolution::LoopProperties; auto Itr = LoopPropertiesCache.find(L); if (Itr == LoopPropertiesCache.end()) { @@ -5735,6 +5995,8 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } case Instruction::UDiv: return getUDivExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); + case Instruction::URem: + return getURemExpr(getSCEV(BO->LHS), getSCEV(BO->RHS)); case Instruction::Sub: { SCEV::NoWrapFlags Flags = SCEV::FlagAnyWrap; if (BO->Op) @@ -5886,7 +6148,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } break; - case Instruction::AShr: + case Instruction::AShr: { // AShr X, C, where C is a constant. 
ConstantInt *CI = dyn_cast<ConstantInt>(BO->RHS); if (!CI) @@ -5938,6 +6200,7 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { } break; } + } } switch (U->getOpcode()) { @@ -5948,6 +6211,21 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getZeroExtendExpr(getSCEV(U->getOperand(0)), U->getType()); case Instruction::SExt: + if (auto BO = MatchBinaryOp(U->getOperand(0), DT)) { + // The NSW flag of a subtract does not always survive the conversion to + // A + (-1)*B. By pushing sign extension onto its operands we are much + // more likely to preserve NSW and allow later AddRec optimisations. + // + // NOTE: This is effectively duplicating this logic from getSignExtend: + // sext((A + B + ...)<nsw>) --> (sext(A) + sext(B) + ...)<nsw> + // but by that point the NSW information has potentially been lost. + if (BO->Opcode == Instruction::Sub && BO->IsNSW) { + Type *Ty = U->getType(); + auto *V1 = getSignExtendExpr(getSCEV(BO->LHS), Ty); + auto *V2 = getSignExtendExpr(getSCEV(BO->RHS), Ty); + return getMinusSCEV(V1, V2, SCEV::FlagNSW); + } + } return getSignExtendExpr(getSCEV(U->getOperand(0)), U->getType()); case Instruction::BitCast: @@ -5987,8 +6265,6 @@ const SCEV *ScalarEvolution::createSCEV(Value *V) { return getUnknown(V); } - - //===----------------------------------------------------------------------===// // Iteration Count Computation Code // @@ -6177,11 +6453,9 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { SmallVector<Instruction *, 16> Worklist; PushLoopPHIs(L, Worklist); - SmallPtrSet<Instruction *, 8> Visited; + SmallPtrSet<Instruction *, 8> Discovered; while (!Worklist.empty()) { Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I).second) - continue; ValueExprMapType::iterator It = ValueExprMap.find_as(static_cast<Value *>(I)); @@ -6202,7 +6476,31 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { ConstantEvolutionLoopExitValue.erase(PN); } - PushDefUseChildren(I, Worklist); + // Since we don't need to invalidate anything for correctness and we're + // only invalidating to make SCEV's results more precise, we get to stop + // early to avoid invalidating too much. This is especially important in + // cases like: + // + // %v = f(pn0, pn1) // pn0 and pn1 used through some other phi node + // loop0: + // %pn0 = phi + // ... + // loop1: + // %pn1 = phi + // ... + // + // where both loop0 and loop1's backedge taken count uses the SCEV + // expression for %v. If we don't have the early stop below then in cases + // like the above, getBackedgeTakenInfo(loop1) will clear out the trip + // count for loop0 and getBackedgeTakenInfo(loop0) will clear out the trip + // count for loop1, effectively nullifying SCEV's trip count cache. + for (auto *U : I->users()) + if (auto *I = dyn_cast<Instruction>(U)) { + auto *LoopForUser = LI.getLoopFor(I->getParent()); + if (LoopForUser && L->contains(LoopForUser) && + Discovered.insert(I).second) + Worklist.push_back(I); + } } } @@ -6217,7 +6515,7 @@ ScalarEvolution::getBackedgeTakenInfo(const Loop *L) { void ScalarEvolution::forgetLoop(const Loop *L) { // Drop any stored trip count value. 
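[Note: the new SExt handling above relies on a standard fact: if a - b has no signed wrap in the narrow type, sign-extending the operands first yields the same wide value. Spot-checked on i32 -> i64 with arbitrary values:

    #include <cassert>
    #include <cstdint>

    int main() {
      int32_t A = 100, B = -7;                     // A - B does not wrap in i32
      int64_t SExtOfSub = (int64_t)(A - B);        // sext(sub nsw A, B)
      int64_t SubOfSExt = (int64_t)A - (int64_t)B; // sub nsw(sext A, sext B)
      assert(SExtOfSub == SubOfSExt);
    }
]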
auto RemoveLoopFromBackedgeMap = - [L](DenseMap<const Loop *, BackedgeTakenInfo> &Map) { + [](DenseMap<const Loop *, BackedgeTakenInfo> &Map, const Loop *L) { auto BTCPos = Map.find(L); if (BTCPos != Map.end()) { BTCPos->second.clear(); @@ -6225,47 +6523,59 @@ void ScalarEvolution::forgetLoop(const Loop *L) { } }; - RemoveLoopFromBackedgeMap(BackedgeTakenCounts); - RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts); + SmallVector<const Loop *, 16> LoopWorklist(1, L); + SmallVector<Instruction *, 32> Worklist; + SmallPtrSet<Instruction *, 16> Visited; - // Drop information about predicated SCEV rewrites for this loop. - for (auto I = PredicatedSCEVRewrites.begin(); - I != PredicatedSCEVRewrites.end();) { - std::pair<const SCEV *, const Loop *> Entry = I->first; - if (Entry.second == L) - PredicatedSCEVRewrites.erase(I++); - else - ++I; - } + // Iterate over all the loops and sub-loops to drop SCEV information. + while (!LoopWorklist.empty()) { + auto *CurrL = LoopWorklist.pop_back_val(); - // Drop information about expressions based on loop-header PHIs. - SmallVector<Instruction *, 16> Worklist; - PushLoopPHIs(L, Worklist); + RemoveLoopFromBackedgeMap(BackedgeTakenCounts, CurrL); + RemoveLoopFromBackedgeMap(PredicatedBackedgeTakenCounts, CurrL); - SmallPtrSet<Instruction *, 8> Visited; - while (!Worklist.empty()) { - Instruction *I = Worklist.pop_back_val(); - if (!Visited.insert(I).second) - continue; + // Drop information about predicated SCEV rewrites for this loop. + for (auto I = PredicatedSCEVRewrites.begin(); + I != PredicatedSCEVRewrites.end();) { + std::pair<const SCEV *, const Loop *> Entry = I->first; + if (Entry.second == CurrL) + PredicatedSCEVRewrites.erase(I++); + else + ++I; + } - ValueExprMapType::iterator It = - ValueExprMap.find_as(static_cast<Value *>(I)); - if (It != ValueExprMap.end()) { - eraseValueFromMap(It->first); - forgetMemoizedResults(It->second); - if (PHINode *PN = dyn_cast<PHINode>(I)) - ConstantEvolutionLoopExitValue.erase(PN); + auto LoopUsersItr = LoopUsers.find(CurrL); + if (LoopUsersItr != LoopUsers.end()) { + for (auto *S : LoopUsersItr->second) + forgetMemoizedResults(S); + LoopUsers.erase(LoopUsersItr); } - PushDefUseChildren(I, Worklist); - } + // Drop information about expressions based on loop-header PHIs. + PushLoopPHIs(CurrL, Worklist); - // Forget all contained loops too, to avoid dangling entries in the - // ValuesAtScopes map. - for (Loop *I : *L) - forgetLoop(I); + while (!Worklist.empty()) { + Instruction *I = Worklist.pop_back_val(); + if (!Visited.insert(I).second) + continue; - LoopPropertiesCache.erase(L); + ValueExprMapType::iterator It = + ValueExprMap.find_as(static_cast<Value *>(I)); + if (It != ValueExprMap.end()) { + eraseValueFromMap(It->first); + forgetMemoizedResults(It->second); + if (PHINode *PN = dyn_cast<PHINode>(I)) + ConstantEvolutionLoopExitValue.erase(PN); + } + + PushDefUseChildren(I, Worklist); + } + + LoopPropertiesCache.erase(CurrL); + // Forget all contained loops too, to avoid dangling entries in the + // ValuesAtScopes map. 
+ LoopWorklist.append(CurrL->begin(), CurrL->end()); + } } void ScalarEvolution::forgetValue(Value *V) { @@ -6377,7 +6687,7 @@ bool ScalarEvolution::BackedgeTakenInfo::hasOperand(const SCEV *S, } ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E) - : ExactNotTaken(E), MaxNotTaken(E), MaxOrZero(false) { + : ExactNotTaken(E), MaxNotTaken(E) { assert((isa<SCEVCouldNotCompute>(MaxNotTaken) || isa<SCEVConstant>(MaxNotTaken)) && "No point in having a non-constant max backedge taken count!"); @@ -6422,7 +6732,8 @@ ScalarEvolution::BackedgeTakenInfo::BackedgeTakenInfo( &&ExitCounts, bool Complete, const SCEV *MaxCount, bool MaxOrZero) : MaxAndComplete(MaxCount, Complete), MaxOrZero(MaxOrZero) { - typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; + using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; + ExitNotTaken.reserve(ExitCounts.size()); std::transform( ExitCounts.begin(), ExitCounts.end(), std::back_inserter(ExitNotTaken), @@ -6454,7 +6765,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, SmallVector<BasicBlock *, 8> ExitingBlocks; L->getExitingBlocks(ExitingBlocks); - typedef ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo EdgeExitInfo; + using EdgeExitInfo = ScalarEvolution::BackedgeTakenInfo::EdgeExitInfo; SmallVector<EdgeExitInfo, 4> ExitCounts; bool CouldComputeBECount = true; @@ -6521,8 +6832,7 @@ ScalarEvolution::computeBackedgeTakenCount(const Loop *L, ScalarEvolution::ExitLimit ScalarEvolution::computeExitLimit(const Loop *L, BasicBlock *ExitingBlock, - bool AllowPredicates) { - + bool AllowPredicates) { // Okay, we've chosen an exiting block. See what condition causes us to exit // at this block and remember the exit block and whether all other targets // lead to the loop header. @@ -6785,19 +7095,19 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, BasicBlock *FBB, bool ControlsExit, bool AllowPredicates) { - // If the condition was exit on true, convert the condition to exit on false - ICmpInst::Predicate Cond; + ICmpInst::Predicate Pred; if (!L->contains(FBB)) - Cond = ExitCond->getPredicate(); + Pred = ExitCond->getPredicate(); else - Cond = ExitCond->getInversePredicate(); + Pred = ExitCond->getInversePredicate(); + const ICmpInst::Predicate OriginalPred = Pred; // Handle common loops like: for (X = "string"; *X; ++X) if (LoadInst *LI = dyn_cast<LoadInst>(ExitCond->getOperand(0))) if (Constant *RHS = dyn_cast<Constant>(ExitCond->getOperand(1))) { ExitLimit ItCnt = - computeLoadConstantCompareExitLimit(LI, RHS, L, Cond); + computeLoadConstantCompareExitLimit(LI, RHS, L, Pred); if (ItCnt.hasAnyInfo()) return ItCnt; } @@ -6814,11 +7124,11 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (isLoopInvariant(LHS, L) && !isLoopInvariant(RHS, L)) { // If there is a loop-invariant, force it into the RHS. std::swap(LHS, RHS); - Cond = ICmpInst::getSwappedPredicate(Cond); + Pred = ICmpInst::getSwappedPredicate(Pred); } // Simplify the operands before analyzing them. - (void)SimplifyICmpOperands(Cond, LHS, RHS); + (void)SimplifyICmpOperands(Pred, LHS, RHS); // If we have a comparison of a chrec against a constant, try to use value // ranges to answer this query. @@ -6827,13 +7137,13 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, if (AddRec->getLoop() == L) { // Form the constant range. 
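[Note: forgetLoop above was restructured from recursing into contained loops to an explicit worklist; the generic shape of that rewrite, on a hypothetical Node tree with nothing LLVM-specific:

    #include <vector>

    struct Node { std::vector<Node *> Children; bool Cached = true; };

    // Iterative replacement for "invalidate(N) { drop(N); for each child C:
    // invalidate(C); }". Visit order does not matter; only that every
    // contained node is eventually dropped.
    void invalidateAll(Node *Root) {
      std::vector<Node *> Worklist{Root};
      while (!Worklist.empty()) {
        Node *N = Worklist.back();
        Worklist.pop_back();
        N->Cached = false; // drop per-node cached results
        Worklist.insert(Worklist.end(), N->Children.begin(), N->Children.end());
      }
    }

    int main() {
      Node Leaf, Mid{{&Leaf}}, Root{{&Mid}};
      invalidateAll(&Root);
      return Leaf.Cached ? 1 : 0; // 0: the whole subtree was invalidated
    }
]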
ConstantRange CompRange = - ConstantRange::makeExactICmpRegion(Cond, RHSC->getAPInt()); + ConstantRange::makeExactICmpRegion(Pred, RHSC->getAPInt()); const SCEV *Ret = AddRec->getNumIterationsInRange(CompRange, *this); if (!isa<SCEVCouldNotCompute>(Ret)) return Ret; } - switch (Cond) { + switch (Pred) { case ICmpInst::ICMP_NE: { // while (X != Y) // Convert to: while (X-Y != 0) ExitLimit EL = howFarToZero(getMinusSCEV(LHS, RHS), L, ControlsExit, @@ -6849,7 +7159,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, } case ICmpInst::ICMP_SLT: case ICmpInst::ICMP_ULT: { // while (X < Y) - bool IsSigned = Cond == ICmpInst::ICMP_SLT; + bool IsSigned = Pred == ICmpInst::ICMP_SLT; ExitLimit EL = howManyLessThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates); if (EL.hasAnyInfo()) return EL; @@ -6857,7 +7167,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, } case ICmpInst::ICMP_SGT: case ICmpInst::ICMP_UGT: { // while (X > Y) - bool IsSigned = Cond == ICmpInst::ICMP_SGT; + bool IsSigned = Pred == ICmpInst::ICMP_SGT; ExitLimit EL = howManyGreaterThans(LHS, RHS, L, IsSigned, ControlsExit, AllowPredicates); @@ -6875,7 +7185,7 @@ ScalarEvolution::computeExitLimitFromICmp(const Loop *L, return ExhaustiveCount; return computeShiftCompareExitLimit(ExitCond->getOperand(0), - ExitCond->getOperand(1), L, Cond); + ExitCond->getOperand(1), L, OriginalPred); } ScalarEvolution::ExitLimit @@ -6920,7 +7230,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit( Constant *RHS, const Loop *L, ICmpInst::Predicate predicate) { - if (LI->isVolatile()) return getCouldNotCompute(); // Check to see if the loaded pointer is a getelementptr of a global. @@ -7333,8 +7642,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN, Value *BEValue = PN->getIncomingValueForBlock(Latch); // Execute the loop symbolically to determine the exit value. - if (BEs.getActiveBits() >= 32) - return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it! + assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) && + "BEs is <= MaxBruteForceIterations which is an 'unsigned'!"); unsigned NumIterations = BEs.getZExtValue(); // must be in range unsigned IterationNum = 0; @@ -7839,7 +8148,6 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B, /// Find the roots of the quadratic equation for the given quadratic chrec /// {L,+,M,+,N}. This returns either the two roots (which might be the same) or /// two SCEVCouldNotCompute objects. -/// static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>> SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) { assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!"); @@ -8080,7 +8388,6 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) { /// expressions are equal, however for the purposes of looking for a condition /// guarding a loop, it can be useful to be a little more general, since a /// front-end may have replicated the controlling expression. -/// static bool HasSameValue(const SCEV *A, const SCEV *B) { // Quick check to see if they are the same SCEV. if (A == B) return true; @@ -8527,7 +8834,6 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges( bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS) { - // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer. // Return Y via OutY. 
@@ -6920,7 +7230,6 @@ ScalarEvolution::computeLoadConstantCompareExitLimit(
   Constant *RHS,
   const Loop *L,
   ICmpInst::Predicate predicate) {
-
   if (LI->isVolatile()) return getCouldNotCompute();
 
   // Check to see if the loaded pointer is a getelementptr of a global.
@@ -7333,8 +7642,8 @@ ScalarEvolution::getConstantEvolutionLoopExitValue(PHINode *PN,
   Value *BEValue = PN->getIncomingValueForBlock(Latch);
 
   // Execute the loop symbolically to determine the exit value.
-  if (BEs.getActiveBits() >= 32)
-    return RetVal = nullptr; // More than 2^32-1 iterations?? Not doing it!
+  assert(BEs.getActiveBits() < CHAR_BIT * sizeof(unsigned) &&
+         "BEs is <= MaxBruteForceIterations which is an 'unsigned'!");
 
   unsigned NumIterations = BEs.getZExtValue(); // must be in range
   unsigned IterationNum = 0;
@@ -7839,7 +8148,6 @@ static const SCEV *SolveLinEquationWithOverflow(const APInt &A, const SCEV *B,
 /// Find the roots of the quadratic equation for the given quadratic chrec
 /// {L,+,M,+,N}. This returns either the two roots (which might be the same) or
 /// two SCEVCouldNotCompute objects.
-///
 static Optional<std::pair<const SCEVConstant *,const SCEVConstant *>>
 SolveQuadraticEquation(const SCEVAddRecExpr *AddRec, ScalarEvolution &SE) {
   assert(AddRec->getNumOperands() == 3 && "This is not a quadratic chrec!");
@@ -8080,7 +8388,6 @@ ScalarEvolution::getPredecessorWithUniqueSuccessorForBB(BasicBlock *BB) {
 /// expressions are equal, however for the purposes of looking for a condition
 /// guarding a loop, it can be useful to be a little more general, since a
 /// front-end may have replicated the controlling expression.
-///
 static bool HasSameValue(const SCEV *A, const SCEV *B) {
   // Quick check to see if they are the same SCEV.
   if (A == B) return true;
@@ -8527,7 +8834,6 @@ bool ScalarEvolution::isKnownPredicateViaConstantRanges(
 bool ScalarEvolution::isKnownPredicateViaNoOverflow(ICmpInst::Predicate Pred,
                                                     const SCEV *LHS,
                                                     const SCEV *RHS) {
-
   // Match Result to (X + Y)<ExpectedFlags> where Y is a constant integer.
   // Return Y via OutY.
   auto MatchBinaryAddToConst =
@@ -8693,7 +8999,6 @@ ScalarEvolution::isLoopBackedgeGuardedByCond(const Loop *L,
 
   for (DomTreeNode *DTN = DT[Latch], *HeaderDTN = DT[L->getHeader()];
        DTN != HeaderDTN; DTN = DTN->getIDom()) {
-
     assert(DTN && "should reach the loop header before reaching the root!");
 
     BasicBlock *BB = DTN->getBlock();
@@ -9116,7 +9421,6 @@ bool ScalarEvolution::isImpliedCondOperands(ICmpInst::Predicate Pred,
                                      getNotSCEV(FoundLHS));
 }
 
-
 /// If Expr computes ~A, return A else return nullptr
 static const SCEV *MatchNotExpr(const SCEV *Expr) {
   const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Expr);
@@ -9132,7 +9436,6 @@ static const SCEV *MatchNotExpr(const SCEV *Expr) {
   return AddRHS->getOperand(1);
 }
 
-
 /// Is MaybeMaxExpr an SMax or UMax of Candidate and some other values?
 template<typename MaxExprType>
 static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
@@ -9143,7 +9446,6 @@ static bool IsMaxConsistingOf(const SCEV *MaybeMaxExpr,
   return find(MaxExpr->operands(), Candidate) != MaxExpr->op_end();
 }
 
-
 /// Is MaybeMinExpr an SMin or UMin of Candidate and some other values?
 template<typename MaxExprType>
 static bool IsMinConsistingOf(ScalarEvolution &SE,
@@ -9159,7 +9461,6 @@ static bool IsMinConsistingOf(ScalarEvolution &SE,
 static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
                                            ICmpInst::Predicate Pred,
                                            const SCEV *LHS, const SCEV *RHS) {
-
   // If both sides are affine addrecs for the same loop, with equal
   // steps, and we know the recurrences don't wrap, then we only
   // need to check the predicate on the starting values.
@@ -9295,7 +9596,9 @@ bool ScalarEvolution::isImpliedViaOperations(ICmpInst::Predicate Pred,
   } else if (auto *LHSUnknownExpr = dyn_cast<SCEVUnknown>(LHS)) {
     Value *LL, *LR;
     // FIXME: Once we have SDiv implemented, we can get rid of this matching.
+
     using namespace llvm::PatternMatch;
+
     if (match(LHSUnknownExpr->getValue(), m_SDiv(m_Value(LL), m_Value(LR)))) {
       // Rules for division.
       // We are going to perform some comparisons with Denominator and its
@@ -9510,14 +9813,54 @@ const SCEV *ScalarEvolution::computeBECount(const SCEV *Delta, const SCEV *Step,
   return getUDivExpr(Delta, Step);
 }
 
+const SCEV *ScalarEvolution::computeMaxBECountForLT(const SCEV *Start,
+                                                    const SCEV *Stride,
+                                                    const SCEV *End,
+                                                    unsigned BitWidth,
+                                                    bool IsSigned) {
+
+  assert(!isKnownNonPositive(Stride) &&
+         "Stride is expected strictly positive!");
+  // Calculate the maximum backedge count based on the range of values
+  // permitted by Start, End, and Stride.
+  const SCEV *MaxBECount;
+  APInt MinStart =
+      IsSigned ? getSignedRangeMin(Start) : getUnsignedRangeMin(Start);
+
+  APInt StrideForMaxBECount =
+      IsSigned ? getSignedRangeMin(Stride) : getUnsignedRangeMin(Stride);
+
+  // We already know that the stride is positive, so we paper over conservatism
+  // in our range computation by forcing StrideForMaxBECount to be at least one.
+  // In theory this is unnecessary, but we expect MaxBECount to be a
+  // SCEVConstant, and (udiv <constant> 0) is not constant folded by SCEV (there
+  // is nothing to constant fold it to).
+  APInt One(BitWidth, 1, IsSigned);
+  StrideForMaxBECount = APIntOps::smax(One, StrideForMaxBECount);
+
+  APInt MaxValue = IsSigned ? APInt::getSignedMaxValue(BitWidth)
                            : APInt::getMaxValue(BitWidth);
+  APInt Limit = MaxValue - (StrideForMaxBECount - 1);
+
+  // Although End can be a MAX expression we estimate MaxEnd considering only
+  // the case End = RHS of the loop termination condition. This is safe because
+  // in the other case (End - Start) is zero, leading to a zero maximum backedge
+  // taken count.
+  APInt MaxEnd = IsSigned ? APIntOps::smin(getSignedRangeMax(End), Limit)
+                          : APIntOps::umin(getUnsignedRangeMax(End), Limit);
+
+  MaxBECount = computeBECount(getConstant(MaxEnd - MinStart) /* Delta */,
+                              getConstant(StrideForMaxBECount) /* Step */,
+                              false /* Equality */);
+
+  return MaxBECount;
+}
+
 ScalarEvolution::ExitLimit
 ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                   const Loop *L, bool IsSigned,
                                   bool ControlsExit, bool AllowPredicates) {
   SmallPtrSet<const SCEVPredicate *, 4> Predicates;
-  // We handle only IV < Invariant
-  if (!isLoopInvariant(RHS, L))
-    return getCouldNotCompute();
 
   const SCEVAddRecExpr *IV = dyn_cast<SCEVAddRecExpr>(LHS);
   bool PredicatedIV = false;
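A worked instance of the bound introduced in computeMaxBECountForLT may help. The numbers are hypothetical and assume an unsigned 8-bit IV whose stride is known to be at least 2: MinStart = 0, StrideForMaxBECount = 2, Limit = 255 - (2 - 1) = 254, and MaxEnd = min(rangeMax(End), 254) = 254. The stand-alone recomputation below is a sketch of that arithmetic only, not the APInt-based implementation; that computeBECount's inequality form rounds the division up is inferred from its use here.

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint32_t MinStart  = 0;    // getUnsignedRangeMin(Start)
      const uint32_t StrideMin = 2;    // getUnsignedRangeMin(Stride)
      const uint32_t MaxValue  = 255;  // APInt::getMaxValue(8)
      const uint32_t Limit  = MaxValue - (StrideMin - 1);       // 254
      const uint32_t MaxEnd = std::min<uint32_t>(255u, Limit);  // 254
      // computeBECount(Delta, Step, /*Equality=*/false) computes
      // (Delta + Step - 1) udiv Step, i.e. a rounding-up division.
      const uint32_t MaxBECount =
          (MaxEnd - MinStart + StrideMin - 1) / StrideMin;
      std::printf("MaxBECount = %u\n", MaxBECount);             // prints 127
    }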
@@ -9588,7 +9931,6 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     if (PredicatedIV || !NoWrap || isKnownNonPositive(Stride) ||
         !loopHasNoSideEffects(L))
       return getCouldNotCompute();
-
   } else if (!Stride->isOne() &&
              doesIVOverflowOnLT(RHS, Stride, IsSigned, NoWrap))
     // Avoid proven overflow cases: this will ensure that the backedge taken
@@ -9601,6 +9943,17 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
                                       : ICmpInst::ICMP_ULT;
   const SCEV *Start = IV->getStart();
   const SCEV *End = RHS;
+  // When the RHS is not invariant, we do not know the end bound of the loop and
+  // cannot calculate the ExactBECount needed by ExitLimit. However, we can
+  // calculate the MaxBECount, given the start, stride and max value for the end
+  // bound of the loop (RHS), and the fact that IV does not overflow (which is
+  // checked above).
+  if (!isLoopInvariant(RHS, L)) {
+    const SCEV *MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
+    return ExitLimit(getCouldNotCompute() /* ExactNotTaken */, MaxBECount,
+                     false /*MaxOrZero*/, Predicates);
+  }
   // If the backedge is taken at least once, then it will be taken
   // (End-Start)/Stride times (rounded up to a multiple of Stride), where Start
   // is the LHS value of the less-than comparison the first time it is evaluated
@@ -9633,37 +9986,8 @@ ScalarEvolution::howManyLessThans(const SCEV *LHS, const SCEV *RHS,
     MaxBECount = BECountIfBackedgeTaken;
     MaxOrZero = true;
   } else {
-    // Calculate the maximum backedge count based on the range of values
-    // permitted by Start, End, and Stride.
-    APInt MinStart = IsSigned ? getSignedRangeMin(Start)
-                              : getUnsignedRangeMin(Start);
-
-    unsigned BitWidth = getTypeSizeInBits(LHS->getType());
-
-    APInt StrideForMaxBECount;
-
-    if (PositiveStride)
-      StrideForMaxBECount =
-        IsSigned ? getSignedRangeMin(Stride)
-                 : getUnsignedRangeMin(Stride);
-    else
-      // Using a stride of 1 is safe when computing max backedge taken count for
-      // a loop with unknown stride.
-      StrideForMaxBECount = APInt(BitWidth, 1, IsSigned);
-
-    APInt Limit =
-      IsSigned ? APInt::getSignedMaxValue(BitWidth) - (StrideForMaxBECount - 1)
               : APInt::getMaxValue(BitWidth) - (StrideForMaxBECount - 1);
-
-    // Although End can be a MAX expression we estimate MaxEnd considering only
-    // the case End = RHS. This is safe because in the other case (End - Start)
-    // is zero, leading to a zero maximum backedge taken count.
-    APInt MaxEnd =
-      IsSigned ? APIntOps::smin(getSignedRangeMax(RHS), Limit)
-               : APIntOps::umin(getUnsignedRangeMax(RHS), Limit);
-
-    MaxBECount = computeBECount(getConstant(MaxEnd - MinStart),
-                                getConstant(StrideForMaxBECount), false);
+    MaxBECount = computeMaxBECountForLT(
+        Start, Stride, RHS, getTypeSizeInBits(LHS->getType()), IsSigned);
   }
 
   if (isa<SCEVCouldNotCompute>(MaxBECount) &&
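The hunks above are the point of the whole change: the old guard that rejected any non-invariant RHS outright is gone, and such loops now get a conservative max count instead of nothing. A hypothetical source-level example of a loop this now covers:

    // The bound a[i] changes every iteration, so there is no exact
    // backedge-taken count; but the 8-bit IV cannot step past its range
    // without overflowing (ruled out by the NoWrap checks above), so
    // MaxNotTaken is still bounded -- at most 255 for this loop.
    void walk(const unsigned char *a) {
      for (unsigned char i = 0; i < a[i]; ++i) {
        // loop body without side effects relevant to the analysis
      }
    }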
@@ -9874,6 +10198,7 @@ static inline bool containsUndefs(const SCEV *S) {
 }
 
 namespace {
+
 // Collect all steps of SCEV expressions.
 struct SCEVCollectStrides {
   ScalarEvolution &SE;
@@ -9887,6 +10212,7 @@ struct SCEVCollectStrides {
       Strides.push_back(AR->getStepRecurrence(SE));
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9894,8 +10220,7 @@ struct SCEVCollectStrides {
 struct SCEVCollectTerms {
   SmallVectorImpl<const SCEV *> &Terms;
 
-  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T)
-      : Terms(T) {}
+  SCEVCollectTerms(SmallVectorImpl<const SCEV *> &T) : Terms(T) {}
 
   bool follow(const SCEV *S) {
     if (isa<SCEVUnknown>(S) || isa<SCEVMulExpr>(S) ||
@@ -9910,6 +10235,7 @@ struct SCEVCollectTerms {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9918,7 +10244,7 @@ struct SCEVHasAddRec {
   bool &ContainsAddRec;
 
   SCEVHasAddRec(bool &ContainsAddRec) : ContainsAddRec(ContainsAddRec) {
-   ContainsAddRec = false;
+    ContainsAddRec = false;
   }
 
   bool follow(const SCEV *S) {
@@ -9932,6 +10258,7 @@ struct SCEVHasAddRec {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
 
@@ -9985,9 +10312,11 @@ struct SCEVCollectAddRecMultiplies {
     // Keep looking.
     return true;
   }
+
   bool isDone() const { return false; }
 };
-}
+
+} // end anonymous namespace
 
 /// Find parametric terms in this SCEVAddRecExpr. We first look for parameters
 /// in two places:
@@ -10066,7 +10395,6 @@ static bool findArrayDimensionsRec(ScalarEvolution &SE,
   return true;
 }
 
-
 // Returns true when one of the SCEVs of Terms contains a SCEVUnknown parameter.
 static inline bool containsParameters(SmallVectorImpl<const SCEV *> &Terms) {
   for (const SCEV *T : Terms)
@@ -10181,7 +10509,6 @@ void ScalarEvolution::findArrayDimensions(SmallVectorImpl<const SCEV *> &Terms,
 void ScalarEvolution::computeAccessFunctions(
     const SCEV *Expr, SmallVectorImpl<const SCEV *> &Subscripts,
     SmallVectorImpl<const SCEV *> &Sizes) {
-
   // Early exit in case this SCEV is not an affine multivariate function.
   if (Sizes.empty())
     return;
@@ -10285,7 +10612,6 @@ void ScalarEvolution::computeAccessFunctions(
 /// DelinearizationPass that walks through all loads and stores of a function
 /// asking for the SCEV of the memory access with respect to all enclosing
 /// loops, calling SCEV->delinearize on that and printing the results.
-
 void ScalarEvolution::delinearize(const SCEV *Expr,
                                  SmallVectorImpl<const SCEV *> &Subscripts,
                                  SmallVectorImpl<const SCEV *> &Sizes,
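The helpers touched above implement SCEV delinearization: strides are collected from the AddRec steps, candidate array sizes are derived from them, and the per-dimension access functions are peeled off by successive division. A hedged usage sketch of the public entry point shown here (SE, AccessExpr, and ElementSize are assumed to be in scope; the names are illustrative):

    // Sketch: recover A[i][j]-style subscripts from a flat SCEV access
    // function, using the delinearize signature visible in this diff.
    SmallVector<const SCEV *, 4> Subscripts, Sizes;
    SE.delinearize(AccessExpr, Subscripts, Sizes, ElementSize);
    // On success, Sizes holds the recovered dimensions and Subscripts the
    // per-dimension access functions; both stay empty when the expression
    // is not an affine multivariate function.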
@@ -10374,11 +10700,8 @@ ScalarEvolution::ScalarEvolution(Function &F, TargetLibraryInfo &TLI,
                                  AssumptionCache &AC, DominatorTree &DT,
                                  LoopInfo &LI)
     : F(F), TLI(TLI), AC(AC), DT(DT), LI(LI),
-      CouldNotCompute(new SCEVCouldNotCompute()),
-      WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
-      ValuesAtScopes(64), LoopDispositions(64), BlockDispositions(64),
-      FirstUnknown(nullptr) {
-
+      CouldNotCompute(new SCEVCouldNotCompute()), ValuesAtScopes(64),
+      LoopDispositions(64), BlockDispositions(64) {
   // To use guards for proving predicates, we need to scan every instruction in
   // relevant basic blocks, and not just terminators. Doing this is a waste of
   // time if the IR does not actually contain any calls to
@@ -10399,7 +10722,6 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       LI(Arg.LI), CouldNotCompute(std::move(Arg.CouldNotCompute)),
       ValueExprMap(std::move(Arg.ValueExprMap)),
       PendingLoopPredicates(std::move(Arg.PendingLoopPredicates)),
-      WalkingBEDominatingConds(false), ProvingSplitPredicate(false),
       MinTrailingZerosCache(std::move(Arg.MinTrailingZerosCache)),
       BackedgeTakenCounts(std::move(Arg.BackedgeTakenCounts)),
       PredicatedBackedgeTakenCounts(
@@ -10415,6 +10737,7 @@ ScalarEvolution::ScalarEvolution(ScalarEvolution &&Arg)
       UniqueSCEVs(std::move(Arg.UniqueSCEVs)),
       UniquePreds(std::move(Arg.UniquePreds)),
       SCEVAllocator(std::move(Arg.SCEVAllocator)),
+      LoopUsers(std::move(Arg.LoopUsers)),
       PredicatedSCEVRewrites(std::move(Arg.PredicatedSCEVRewrites)),
       FirstUnknown(Arg.FirstUnknown) {
   Arg.FirstUnknown = nullptr;
@@ -10647,9 +10970,11 @@ ScalarEvolution::computeLoopDisposition(const SCEV *S, const Loop *L) {
     if (!L)
       return LoopVariant;
 
-    // This recurrence is variant w.r.t. L if L contains AR's loop.
-    if (L->contains(AR->getLoop()))
+    // Everything that is not defined at loop entry is variant.
+    if (DT.dominates(L->getHeader(), AR->getLoop()->getHeader()))
       return LoopVariant;
+    assert(!L->contains(AR->getLoop()) && "Containing loop's header does not"
+           " dominate the contained loop's header?");
 
     // This recurrence is invariant w.r.t. L if AR's loop contains L.
     if (AR->getLoop()->contains(L))
@@ -10806,7 +11131,16 @@ bool ScalarEvolution::hasOperand(const SCEV *S, const SCEV *Op) const {
   return SCEVExprContains(S, [&](const SCEV *Expr) { return Expr == Op; });
 }
 
-void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
+bool ScalarEvolution::ExitLimit::hasOperand(const SCEV *S) const {
+  auto IsS = [&](const SCEV *X) { return S == X; };
+  auto ContainsS = [&](const SCEV *X) {
+    return !isa<SCEVCouldNotCompute>(X) && SCEVExprContains(X, IsS);
+  };
+  return ContainsS(ExactNotTaken) || ContainsS(MaxNotTaken);
+}
+
+void
+ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   ValuesAtScopes.erase(S);
   LoopDispositions.erase(S);
   BlockDispositions.erase(S);
@@ -10816,7 +11150,7 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   HasRecMap.erase(S);
   MinTrailingZerosCache.erase(S);
 
-  for (auto I = PredicatedSCEVRewrites.begin();
+  for (auto I = PredicatedSCEVRewrites.begin();
        I != PredicatedSCEVRewrites.end();) {
     std::pair<const SCEV *, const Loop *> Entry = I->first;
     if (Entry.first == S)
@@ -10841,6 +11175,25 @@ void ScalarEvolution::forgetMemoizedResults(const SCEV *S) {
   RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
 }
 
+void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
+  struct FindUsedLoops {
+    SmallPtrSet<const Loop *, 8> LoopsUsed;
+    bool follow(const SCEV *S) {
+      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+        LoopsUsed.insert(AR->getLoop());
+      return true;
+    }
+
+    bool isDone() const { return false; }
+  };
+
+  FindUsedLoops F;
+  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+
+  for (auto *L : F.LoopsUsed)
+    LoopUsers[L].push_back(S);
+}
+
 void ScalarEvolution::verify() const {
   ScalarEvolution &SE = *const_cast<ScalarEvolution *>(this);
   ScalarEvolution SE2(F, TLI, AC, DT, LI);
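addToLoopUseLists gives invalidation an inverted index: instead of scanning every cached SCEV to find expressions that mention a loop, the loop's users can be dropped directly. A hedged sketch of the consumer side (the real forgetLoop body is outside this excerpt; member names are assumed to be in scope):

    // Sketch (illustrative): a forgetLoop-style consumer of the LoopUsers
    // map populated above.
    void forgetLoopUsersOf(const Loop *CurrL) {
      auto LoopUsersIt = LoopUsers.find(CurrL);
      if (LoopUsersIt != LoopUsers.end()) {
        for (const SCEV *S : LoopUsersIt->second)
          forgetMemoizedResults(S); // drop every cached fact mentioning CurrL
        LoopUsers.erase(LoopUsersIt);
      }
    }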
@@ -10849,9 +11202,12 @@ void ScalarEvolution::verify() const {
 
   // Maps SCEV expressions from one ScalarEvolution "universe" to another.
   struct SCEVMapper : public SCEVRewriteVisitor<SCEVMapper> {
+    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
+
     const SCEV *visitConstant(const SCEVConstant *Constant) {
       return SE.getConstant(Constant->getAPInt());
     }
+
     const SCEV *visitUnknown(const SCEVUnknown *Expr) {
       return SE.getUnknown(Expr->getValue());
     }
@@ -10859,7 +11215,6 @@ void ScalarEvolution::verify() const {
     const SCEV *visitCouldNotCompute(const SCEVCouldNotCompute *Expr) {
       return SE.getCouldNotCompute();
     }
-    SCEVMapper(ScalarEvolution &SE) : SCEVRewriteVisitor<SCEVMapper>(SE) {}
   };
 
   SCEVMapper SCM(SE2);
@@ -10948,6 +11303,7 @@ INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
 INITIALIZE_PASS_END(ScalarEvolutionWrapperPass, "scalar-evolution",
                     "Scalar Evolution Analysis", false, true)
+
 char ScalarEvolutionWrapperPass::ID = 0;
 
 ScalarEvolutionWrapperPass::ScalarEvolutionWrapperPass() : FunctionPass(ID) {
@@ -11023,6 +11379,7 @@ namespace {
 
 class SCEVPredicateRewriter : public SCEVRewriteVisitor<SCEVPredicateRewriter> {
 public:
+
   /// Rewrites \p S in the context of a loop L and the SCEV predication
   /// infrastructure.
   ///
@@ -11038,11 +11395,6 @@ public:
     return Rewriter.visit(S);
   }
 
-  SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
-                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
-                        SCEVUnionPredicate *Pred)
-      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
-
   const SCEV *visitUnknown(const SCEVUnknown *Expr) {
     if (Pred) {
       auto ExprPreds = Pred->getPredicatesForExpr(Expr);
@@ -11087,6 +11439,11 @@ public:
   }
 
 private:
+  explicit SCEVPredicateRewriter(const Loop *L, ScalarEvolution &SE,
+                        SmallPtrSetImpl<const SCEVPredicate *> *NewPreds,
+                        SCEVUnionPredicate *Pred)
+      : SCEVRewriteVisitor(SE), NewPreds(NewPreds), Pred(Pred), L(L) {}
+
   bool addOverflowAssumption(const SCEVPredicate *P) {
     if (!NewPreds) {
       // Check if we've already made this assumption.
@@ -11103,10 +11460,10 @@ private:
   }
 
   // If \p Expr represents a PHINode, we try to see if it can be represented
-  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
+  // as an AddRec, possibly under a predicate (PHISCEVPred). If it is possible
   // to add this predicate as a runtime overflow check, we return the AddRec.
-  // If \p Expr does not meet these conditions (is not a PHI node, or we
-  // couldn't create an AddRec for it, or couldn't add the predicate), we just
+  // If \p Expr does not meet these conditions (is not a PHI node, or we
+  // couldn't create an AddRec for it, or couldn't add the predicate), we just
  // return \p Expr.
  const SCEV *convertToAddRecWithPreds(const SCEVUnknown *Expr) {
    if (!isa<PHINode>(Expr->getValue()))
      return Expr;
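SCEVPredicateRewriter is the engine behind the two public entry points further down: rewriteUsingPredicate folds an existing union predicate into an expression, while convertSCEVToAddRecWithPredicates additionally collects the new overflow assumptions needed to treat a PHI as an AddRec. A hedged usage sketch (SE, S, and L assumed to be in scope; mirrors the entry point defined below):

    // Sketch: try to view S as an add-recurrence in L, accepting runtime
    // overflow predicates if needed.
    SmallPtrSet<const SCEVPredicate *, 4> Preds;
    if (const SCEVAddRecExpr *AR =
            SE.convertSCEVToAddRecWithPredicates(S, L, Preds)) {
      // AR is valid only under every predicate in Preds; a client such as
      // the loop vectorizer must emit the corresponding runtime checks.
    }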
@@ -11121,11 +11478,12 @@ private:
     }
     return PredicatedRewrite->first;
   }
-
+
   SmallPtrSetImpl<const SCEVPredicate *> *NewPreds;
   SCEVUnionPredicate *Pred;
   const Loop *L;
 };
+
 } // end anonymous namespace
 
 const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
@@ -11136,7 +11494,6 @@ const SCEV *ScalarEvolution::rewriteUsingPredicate(const SCEV *S, const Loop *L,
 
 const SCEVAddRecExpr *ScalarEvolution::convertSCEVToAddRecWithPredicates(
     const SCEV *S, const Loop *L,
     SmallPtrSetImpl<const SCEVPredicate *> &Preds) {
-
   SmallPtrSet<const SCEVPredicate *, 4> TransformPreds;
   S = SCEVPredicateRewriter::rewrite(S, L, *this, &TransformPreds, nullptr);
   auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
@@ -11292,7 +11649,7 @@ void SCEVUnionPredicate::add(const SCEVPredicate *N) {
 
 PredicatedScalarEvolution::PredicatedScalarEvolution(ScalarEvolution &SE,
                                                      Loop &L)
-    : SE(SE), L(L), Generation(0), BackedgeCount(nullptr) {}
+    : SE(SE), L(L) {}
 
 const SCEV *PredicatedScalarEvolution::getSCEV(Value *V) {
   const SCEV *Expr = SE.getSCEV(V);